From 8b1e324ca4aa1ae0721f6f5dcfba8325a751ef3c Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 7 Feb 2025 16:34:17 -0500 Subject: [PATCH] [5.2.x] Fixed #36173 -- Stabilized identity of Concat with an explicit output_field. When Expression.__init__() overrides make use of *args, **kwargs captures their argument values are respectively bound as a tuple and dict instances. These composite values might themselves contain values that require special identity treatments such as Concat(output_field) as it's a Field instance. Refs #30628 which introduced bound Field differentiation but lacked argument captures handling. Thanks erchenstein for the report. Backport of df2c4952df6d93c575fb8a3c853dc9d4c2449f36 from main --- django/db/models/expressions.py | 23 ++++++++++++++++------- tests/db_functions/text/test_concat.py | 14 ++++++++++++++ tests/expressions/tests.py | 23 +++++++++++++++++++++++ 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 57ceadcec4..444e2fab7b 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -523,6 +523,18 @@ class Expression(BaseExpression, Combinable): def _constructor_signature(cls): return inspect.signature(cls.__init__) + @classmethod + def _identity(cls, value): + if isinstance(value, tuple): + return tuple(map(cls._identity, value)) + if isinstance(value, dict): + return tuple((key, cls._identity(val)) for key, val in value.items()) + if isinstance(value, fields.Field): + if value.name and value.model: + return value.model._meta.label, value.name + return type(value) + return make_hashable(value) + @cached_property def identity(self): args, kwargs = self._constructor_args @@ -532,13 +544,10 @@ class Expression(BaseExpression, Combinable): next(arguments) identity = [self.__class__] for arg, value in arguments: - if isinstance(value, fields.Field): - if value.name and value.model: - value = (value.model._meta.label, value.name) - else: - value = type(value) - else: - value = make_hashable(value) + # If __init__() makes use of *args or **kwargs captures `value` + # will respectively be a tuple or a dict that must have its + # constituents unpacked (mainly if contain Field instances). + value = self._identity(value) identity.append((arg, value)) return tuple(identity) diff --git a/tests/db_functions/text/test_concat.py b/tests/db_functions/text/test_concat.py index 6e4cb91d3a..ffcd19fad6 100644 --- a/tests/db_functions/text/test_concat.py +++ b/tests/db_functions/text/test_concat.py @@ -107,3 +107,17 @@ class ConcatTests(TestCase): ctx.captured_queries[0]["sql"].count("::text"), 1 if connection.vendor == "postgresql" else 0, ) + + def test_equal(self): + self.assertEqual( + Concat("foo", "bar", output_field=TextField()), + Concat("foo", "bar", output_field=TextField()), + ) + self.assertNotEqual( + Concat("foo", "bar", output_field=TextField()), + Concat("foo", "bar", output_field=CharField()), + ) + self.assertNotEqual( + Concat("foo", "bar", output_field=TextField()), + Concat("bar", "foo", output_field=TextField()), + ) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index cfa33b6f45..89601de85b 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1433,6 +1433,29 @@ class SimpleExpressionTests(SimpleTestCase): Expression(TestModel._meta.get_field("other_field")), ) + class InitCaptureExpression(Expression): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # The identity of expressions that obscure their __init__() signature + # with *args and **kwargs cannot be determined when bound with + # different combinations or *args and **kwargs. + self.assertNotEqual( + InitCaptureExpression(IntegerField()), + InitCaptureExpression(output_field=IntegerField()), + ) + + # However, they should be considered equal when their bindings are + # equal. + self.assertEqual( + InitCaptureExpression(IntegerField()), + InitCaptureExpression(IntegerField()), + ) + self.assertEqual( + InitCaptureExpression(output_field=IntegerField()), + InitCaptureExpression(output_field=IntegerField()), + ) + def test_hash(self): self.assertEqual(hash(Expression()), hash(Expression())) self.assertEqual(