diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 9d245b93c4..d24f16947f 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -16,12 +16,9 @@ NUMERIC_TYPES = (int, float, Decimal) class GeoFuncMixin: function = None - output_field_class = None geom_param_pos = (0,) def __init__(self, *expressions, **extra): - if 'output_field' not in extra and self.output_field_class: - extra['output_field'] = self.output_field_class() super().__init__(*expressions, **extra) # Ensure that value expressions are geometric. @@ -137,13 +134,13 @@ class Area(OracleToleranceMixin, GeoFunc): class Azimuth(GeoFunc): - output_field_class = FloatField + output_field = FloatField() arity = 2 geom_param_pos = (0, 1) class AsGeoJSON(GeoFunc): - output_field_class = TextField + output_field = TextField() def __init__(self, expression, bbox=False, crs=False, precision=8, **extra): expressions = [expression] @@ -163,7 +160,7 @@ class AsGeoJSON(GeoFunc): class AsGML(GeoFunc): geom_param_pos = (1,) - output_field_class = TextField + output_field = TextField() def __init__(self, expression, version=2, precision=8, **extra): expressions = [version, expression] @@ -189,7 +186,7 @@ class AsKML(AsGML): class AsSVG(GeoFunc): - output_field_class = TextField + output_field = TextField() def __init__(self, expression, relative=False, precision=8, **extra): relative = relative if hasattr(relative, 'resolve_expression') else int(relative) @@ -281,7 +278,7 @@ class ForceRHR(GeomOutputGeoFunc): class GeoHash(GeoFunc): - output_field_class = TextField + output_field = TextField() def __init__(self, expression, precision=None, **extra): expressions = [expression] @@ -345,7 +342,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): class LineLocatePoint(GeoFunc): - output_field_class = FloatField + output_field = FloatField() arity = 2 geom_param_pos = (0, 1) @@ -355,17 +352,17 @@ class MakeValid(GeoFunc): class MemSize(GeoFunc): - output_field_class = IntegerField + output_field = IntegerField() arity = 1 class NumGeometries(GeoFunc): - output_field_class = IntegerField + output_field = IntegerField() arity = 1 class NumPoints(GeoFunc): - output_field_class = IntegerField + output_field = IntegerField() arity = 1 diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py index 19f26ec53c..5f5ddc4757 100644 --- a/django/contrib/postgres/aggregates/statistics.py +++ b/django/contrib/postgres/aggregates/statistics.py @@ -8,7 +8,9 @@ __all__ = [ class StatAggregate(Aggregate): - def __init__(self, y, x, output_field=FloatField(), filter=None): + output_field = FloatField() + + def __init__(self, y, x, output_field=None, filter=None): if not x or not y: raise ValueError('Both y and x must be provided.') super().__init__(y, x, output_field=output_field, filter=filter) @@ -37,9 +39,7 @@ class RegrAvgY(StatAggregate): class RegrCount(StatAggregate): function = 'REGR_COUNT' - - def __init__(self, y, x, filter=None): - super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter) + output_field = IntegerField() def convert_value(self, value, expression, connection): if value is None: diff --git a/django/contrib/postgres/functions.py b/django/contrib/postgres/functions.py index 36b32e0751..819ce058e5 100644 --- a/django/contrib/postgres/functions.py +++ b/django/contrib/postgres/functions.py @@ -3,17 +3,9 @@ from django.db.models import DateTimeField, Func, UUIDField class RandomUUID(Func): template = 'GEN_RANDOM_UUID()' - - def __init__(self, output_field=None, **extra): - if output_field is None: - output_field = UUIDField() - super().__init__(output_field=output_field, **extra) + output_field = UUIDField() class TransactionNow(Func): template = 'CURRENT_TIMESTAMP' - - def __init__(self, output_field=None, **extra): - if output_field is None: - output_field = DateTimeField() - super().__init__(output_field=output_field, **extra) + output_field = DateTimeField() diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 889f5253c3..a14d510208 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -202,10 +202,12 @@ SearchVectorField.register_lookup(SearchVectorExact) class TrigramBase(Func): + output_field = FloatField() + def __init__(self, expression, string, **extra): if not hasattr(string, 'resolve_expression'): string = Value(string) - super().__init__(expression, string, output_field=FloatField(), **extra) + super().__init__(expression, string, **extra) class TrigramSimilarity(TrigramBase): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index db62541559..1937ca16c7 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -467,8 +467,10 @@ class DurationExpression(CombinedExpression): class TemporalSubtraction(CombinedExpression): + output_field = fields.DurationField() + def __init__(self, lhs, rhs): - super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField()) + super().__init__(lhs, self.SUB, rhs) def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) @@ -692,8 +694,7 @@ class Star(Expression): class Random(Expression): - def __init__(self): - super().__init__(output_field=fields.FloatField()) + output_field = fields.FloatField() def __repr__(self): return "Random()" @@ -1017,6 +1018,7 @@ class Subquery(Expression): class Exists(Subquery): template = 'EXISTS(%(subquery)s)' + output_field = fields.BooleanField() def __init__(self, *args, negated=False, **kwargs): self.negated = negated @@ -1025,10 +1027,6 @@ class Exists(Subquery): def __invert__(self): return type(self)(self.queryset, negated=(not self.negated), **self.extra) - @property - def output_field(self): - return fields.BooleanField() - def resolve_expression(self, query=None, **kwargs): # As a performance optimization, remove ordering since EXISTS doesn't # care about it, just whether or not a row matches. diff --git a/django/db/models/functions/base.py b/django/db/models/functions/base.py index e1d7938eac..4b4d4f4ea5 100644 --- a/django/db/models/functions/base.py +++ b/django/db/models/functions/base.py @@ -142,9 +142,7 @@ class Length(Transform): """Return the number of characters in the expression.""" function = 'LENGTH' lookup_name = 'length' - - def __init__(self, expression, *, output_field=None, **extra): - super().__init__(expression, output_field=output_field or fields.IntegerField(), **extra) + output_field = fields.IntegerField() def as_mysql(self, compiler, connection): return super().as_sql(compiler, connection, function='CHAR_LENGTH') @@ -157,11 +155,7 @@ class Lower(Transform): class Now(Func): template = 'CURRENT_TIMESTAMP' - - def __init__(self, output_field=None, **extra): - if output_field is None: - output_field = fields.DateTimeField() - super().__init__(output_field=output_field, **extra) + output_field = fields.DateTimeField() def as_postgresql(self, compiler, connection): # Postgres' CURRENT_TIMESTAMP means "the time at the start of the @@ -178,13 +172,7 @@ class StrIndex(Func): """ function = 'INSTR' arity = 2 - - def __init__(self, string, substring, **extra): - """ - string: the name of a field, or an expression returning a string - substring: the name of a field, or an expression returning a string - """ - super().__init__(string, substring, output_field=fields.IntegerField(), **extra) + output_field = fields.IntegerField() def as_postgresql(self, compiler, connection): return super().as_sql(compiler, connection, function='STRPOS') diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index a8e1ca45c1..c6614a14c2 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -2,14 +2,13 @@ from datetime import datetime from django.conf import settings from django.db.models import ( - DateField, DateTimeField, DurationField, IntegerField, TimeField, + DateField, DateTimeField, DurationField, Field, IntegerField, TimeField, Transform, ) from django.db.models.lookups import ( YearExact, YearGt, YearGte, YearLt, YearLte, ) from django.utils import timezone -from django.utils.functional import cached_property class TimezoneMixin: @@ -31,6 +30,7 @@ class TimezoneMixin: class Extract(TimezoneMixin, Transform): lookup_name = None + output_field = IntegerField() def __init__(self, expression, lookup_name=None, tzinfo=None, **extra): if self.lookup_name is None: @@ -75,10 +75,6 @@ class Extract(TimezoneMixin, Transform): ) return copy - @cached_property - def output_field(self): - return IntegerField() - class ExtractYear(Extract): lookup_name = 'year' @@ -183,17 +179,18 @@ class TruncBase(TimezoneMixin, Transform): raise ValueError('output_field must be either DateField, TimeField, or DateTimeField') # Passing dates or times to functions expecting datetimes is most # likely a mistake. - output_field = copy.output_field - explicit_output_field = field.__class__ != copy.output_field.__class__ + class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None + output_field = class_output_field or copy.output_field + has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__ if type(field) == DateField and ( isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')): raise ValueError("Cannot truncate DateField '%s' to %s. " % ( - field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField' + field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField' )) elif isinstance(field, TimeField) and ( isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')): raise ValueError("Cannot truncate TimeField '%s' to %s. " % ( - field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField' + field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField' )) return copy @@ -241,9 +238,7 @@ class TruncDay(TruncBase): class TruncDate(TruncBase): kind = 'date' lookup_name = 'date' - - def __init__(self, *args, output_field=None, **kwargs): - super().__init__(*args, output_field=DateField(), **kwargs) + output_field = DateField() def as_sql(self, compiler, connection): # Cast to date rather than truncate to date. @@ -256,9 +251,7 @@ class TruncDate(TruncBase): class TruncTime(TruncBase): kind = 'time' lookup_name = 'time' - - def __init__(self, *args, output_field=None, **kwargs): - super().__init__(*args, output_field=TimeField(), **kwargs) + output_field = TimeField() def as_sql(self, compiler, connection): # Cast to date rather than truncate to date.