From 504ce3914fa86a58f29f5369a806f3fe56a0d59a Mon Sep 17 00:00:00 2001
From: Sergey Fedoseev <fedoseev.sergey@gmail.com>
Date: Sat, 15 Jul 2017 06:56:01 +0500
Subject: [PATCH] Fixed #28394 -- Allowed setting BaseExpression.output_field
 (renamed from _output_field).

---
 django/contrib/postgres/aggregates/general.py |  2 +-
 django/contrib/postgres/fields/jsonb.py       |  2 +-
 django/contrib/postgres/search.py             |  6 +--
 django/db/models/expressions.py               | 48 ++++++++++---------
 django/db/models/functions/base.py            |  4 +-
 django/db/models/functions/datetime.py        | 10 ++--
 docs/releases/2.0.txt                         |  3 ++
 tests/expressions/tests.py                    | 10 ++++
 8 files changed, 49 insertions(+), 36 deletions(-)

diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py
index 91835a9ca3..5bbf29e8ab 100644
--- a/django/contrib/postgres/aggregates/general.py
+++ b/django/contrib/postgres/aggregates/general.py
@@ -37,7 +37,7 @@ class BoolOr(Aggregate):
 
 class JSONBAgg(Aggregate):
     function = 'JSONB_AGG'
-    _output_field = JSONField()
+    output_field = JSONField()
 
     def convert_value(self, value, expression, connection, context):
         if not value:
diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py
index a3a3381745..a06187c4bc 100644
--- a/django/contrib/postgres/fields/jsonb.py
+++ b/django/contrib/postgres/fields/jsonb.py
@@ -115,7 +115,7 @@ class KeyTransform(Transform):
 class KeyTextTransform(KeyTransform):
     operator = '->>'
     nested_operator = '#>>'
-    _output_field = TextField()
+    output_field = TextField()
 
 
 class KeyTransformTextLookupMixin:
diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py
index 9d66976ae0..cc47dbfeb6 100644
--- a/django/contrib/postgres/search.py
+++ b/django/contrib/postgres/search.py
@@ -47,7 +47,7 @@ class SearchVectorCombinable:
 class SearchVector(SearchVectorCombinable, Func):
     function = 'to_tsvector'
     arg_joiner = " || ' ' || "
-    _output_field = SearchVectorField()
+    output_field = SearchVectorField()
     config = None
 
     def __init__(self, *expressions, **extra):
@@ -125,7 +125,7 @@ class SearchQueryCombinable:
 
 
 class SearchQuery(SearchQueryCombinable, Value):
-    _output_field = SearchQueryField()
+    output_field = SearchQueryField()
 
     def __init__(self, value, output_field=None, *, config=None, invert=False):
         self.config = config
@@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
 
 class SearchRank(Func):
     function = 'ts_rank'
-    _output_field = FloatField()
+    output_field = FloatField()
 
     def __init__(self, vector, query, **extra):
         if not hasattr(vector, 'resolve_expression'):
diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index 7e5c7313ce..dc2be87453 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -125,11 +125,11 @@ class BaseExpression:
 
     # aggregate specific fields
     is_summary = False
-    _output_field = None
+    _output_field_resolved_to_none = False
 
     def __init__(self, output_field=None):
         if output_field is not None:
-            self._output_field = output_field
+            self.output_field = output_field
 
     def get_db_converters(self, connection):
         return [self.convert_value] + self.output_field.get_db_converters(connection)
@@ -223,21 +223,23 @@ class BaseExpression:
     @cached_property
     def output_field(self):
         """Return the output type of this expressions."""
-        if self._output_field_or_none is None:
-            raise FieldError("Cannot resolve expression type, unknown output_field")
-        return self._output_field_or_none
+        output_field = self._resolve_output_field()
+        if output_field is None:
+            self._output_field_resolved_to_none = True
+            raise FieldError('Cannot resolve expression type, unknown output_field')
+        return output_field
 
     @cached_property
     def _output_field_or_none(self):
         """
-        Return the output field of this expression, or None if no output type
-        can be resolved. Note that the 'output_field' property will raise
-        FieldError if no type can be resolved, but this attribute allows for
-        None values.
+        Return the output field of this expression, or None if
+        _resolve_output_field() didn't return an output type.
         """
-        if self._output_field is None:
-            self._output_field = self._resolve_output_field()
-        return self._output_field
+        try:
+            return self.output_field
+        except FieldError:
+            if not self._output_field_resolved_to_none:
+                raise
 
     def _resolve_output_field(self):
         """
@@ -249,9 +251,9 @@ class BaseExpression:
         the type here is a convenience for the common case. The user should
         supply their own output_field with more complex computations.
 
-        If a source does not have an `_output_field` then we exclude it from
-        this check. If all sources are `None`, then an error will be thrown
-        higher up the stack in the `output_field` property.
+        If a source's output field resolves to None, exclude it from this check.
+        If all sources are None, then an error is raised higher up the stack in
+        the output_field property.
         """
         sources_iter = (source for source in self.get_source_fields() if source is not None)
         for output_field in sources_iter:
@@ -603,14 +605,14 @@ class Value(Expression):
     def as_sql(self, compiler, connection):
         connection.ops.check_expression_support(self)
         val = self.value
-        # check _output_field to avoid triggering an exception
-        if self._output_field is not None:
+        output_field = self._output_field_or_none
+        if output_field is not None:
             if self.for_save:
-                val = self.output_field.get_db_prep_save(val, connection=connection)
+                val = output_field.get_db_prep_save(val, connection=connection)
             else:
-                val = self.output_field.get_db_prep_value(val, connection=connection)
-            if hasattr(self._output_field, 'get_placeholder'):
-                return self._output_field.get_placeholder(val, compiler, connection), [val]
+                val = output_field.get_db_prep_value(val, connection=connection)
+            if hasattr(output_field, 'get_placeholder'):
+                return output_field.get_placeholder(val, compiler, connection), [val]
         if val is None:
             # cx_Oracle does not always convert None to the appropriate
             # NULL type (like in case expressions using numbers), so we
@@ -652,7 +654,7 @@ class RawSQL(Expression):
         return [self]
 
     def __hash__(self):
-        h = hash(self.sql) ^ hash(self._output_field)
+        h = hash(self.sql) ^ hash(self.output_field)
         for param in self.params:
             h ^= hash(param)
         return h
@@ -998,7 +1000,7 @@ class Exists(Subquery):
         super().__init__(*args, **kwargs)
 
     def __invert__(self):
-        return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra)
+        return type(self)(self.queryset, negated=(not self.negated), **self.extra)
 
     @property
     def output_field(self):
diff --git a/django/db/models/functions/base.py b/django/db/models/functions/base.py
index c487bb4ab5..82a6083c58 100644
--- a/django/db/models/functions/base.py
+++ b/django/db/models/functions/base.py
@@ -24,12 +24,12 @@ class Cast(Func):
 
     def as_sql(self, compiler, connection, **extra_context):
         if 'db_type' not in extra_context:
-            extra_context['db_type'] = self._output_field.db_type(connection)
+            extra_context['db_type'] = self.output_field.db_type(connection)
         return super().as_sql(compiler, connection, **extra_context)
 
     def as_mysql(self, compiler, connection):
         extra_context = {}
-        output_field_class = type(self._output_field)
+        output_field_class = type(self.output_field)
         if output_field_class in self.mysql_types:
             extra_context['db_type'] = self.mysql_types[output_field_class]
         return self.as_sql(compiler, connection, **extra_context)
diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py
index a56731e48f..52f1f73ae8 100644
--- a/django/db/models/functions/datetime.py
+++ b/django/db/models/functions/datetime.py
@@ -243,9 +243,8 @@ class TruncDate(TruncBase):
     kind = 'date'
     lookup_name = 'date'
 
-    @cached_property
-    def output_field(self):
-        return DateField()
+    def __init__(self, *args, output_field=None, **kwargs):
+        super().__init__(*args, output_field=DateField(), **kwargs)
 
     def as_sql(self, compiler, connection):
         # Cast to date rather than truncate to date.
@@ -259,9 +258,8 @@ class TruncTime(TruncBase):
     kind = 'time'
     lookup_name = 'time'
 
-    @cached_property
-    def output_field(self):
-        return TimeField()
+    def __init__(self, *args, output_field=None, **kwargs):
+        super().__init__(*args, output_field=TimeField(), **kwargs)
 
     def as_sql(self, compiler, connection):
         # Cast to date rather than truncate to date.
diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt
index 8b41076aa7..81162e549f 100644
--- a/docs/releases/2.0.txt
+++ b/docs/releases/2.0.txt
@@ -551,6 +551,9 @@ Miscellaneous
   in the cache backend as an intermediate class in ``CacheKeyWarning``'s
   inheritance of ``RuntimeWarning``.
 
+* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need
+  to update custom expressions.
+
 .. _deprecated-features-2.0:
 
 Features deprecated in 2.0
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index ed3e2ce628..532690ca8d 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase):
         outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
         self.assertFalse(outer.exists())
 
+    def test_explicit_output_field(self):
+        class FuncA(Func):
+            output_field = models.CharField()
+
+        class FuncB(Func):
+            pass
+
+        expr = FuncB(FuncA())
+        self.assertEqual(expr.output_field, FuncA.output_field)
+
 
 class IterableLookupInnerExpressionsTests(TestCase):
     @classmethod