diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index a89acaf5a9..57ceadcec4 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1367,6 +1367,9 @@ class ColPairs(Expression): def resolve_expression(self, *args, **kwargs): return self + def select_format(self, compiler, sql, params): + return sql, params + class Ref(Expression): """ diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index a6e28b11fb..38d6308f53 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -40,7 +40,16 @@ def get_normalized_value(value, lhs): class RelatedIn(In): def get_prep_lookup(self): - if not isinstance(self.lhs, ColPairs): + from django.db.models.sql.query import Query # avoid circular import + + if isinstance(self.lhs, ColPairs): + if ( + isinstance(self.rhs, Query) + and not self.rhs.has_select_fields + and self.lhs.output_field.related_model is self.rhs.model + ): + self.rhs.set_values([f.name for f in self.lhs.sources]) + else: if self.rhs_is_direct_value(): # If we get here, we are dealing with single-column relations. self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index 1e77f095c8..161e2973a0 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -36,7 +36,8 @@ class TupleLookupMixin: self.check_rhs_is_tuple_or_list() self.check_rhs_length_equals_lhs_length() else: - self.check_rhs_is_outer_ref() + self.check_rhs_is_supported_expression() + super().get_prep_lookup() return self.rhs def check_rhs_is_tuple_or_list(self): @@ -54,13 +55,13 @@ class TupleLookupMixin: f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements" ) - def check_rhs_is_outer_ref(self): - if not isinstance(self.rhs, ResolvedOuterRef): + def check_rhs_is_supported_expression(self): + if not isinstance(self.rhs, (ResolvedOuterRef, Query)): lhs_str = self.get_lhs_str() rhs_cls = self.rhs.__class__.__name__ raise ValueError( f"{self.lookup_name!r} subquery lookup of {lhs_str} " - f"only supports OuterRef objects (received {rhs_cls!r})" + f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})" ) def get_lhs_str(self): @@ -90,11 +91,14 @@ class TupleLookupMixin: return compiler.compile(Tuple(*args)) else: sql, params = compiler.compile(self.rhs) - if not isinstance(self.rhs, ColPairs): + if isinstance(self.rhs, ColPairs): + return "(%s)" % sql, params + elif isinstance(self.rhs, Query): + return super().process_rhs(compiler, connection) + else: raise ValueError( "Composite field lookups only work with composite expressions." ) - return "(%s)" % sql, params def get_fallback_sql(self, compiler, connection): raise NotImplementedError( @@ -110,6 +114,8 @@ class TupleLookupMixin: class TupleExact(TupleLookupMixin, Exact): def get_fallback_sql(self, compiler, connection): + if isinstance(self.rhs, Query): + return super(TupleLookupMixin, self).as_sql(compiler, connection) # Process right-hand-side to trigger sanitization. self.process_rhs(compiler, connection) # e.g.: (a, b, c) == (x, y, z) as SQL: @@ -262,7 +268,7 @@ class TupleIn(TupleLookupMixin, In): self.check_rhs_elements_length_equals_lhs_length() else: self.check_rhs_is_query() - self.check_rhs_select_length_equals_lhs_length() + super(TupleLookupMixin, self).get_prep_lookup() return self.rhs # skip checks from mixin @@ -292,19 +298,10 @@ class TupleIn(TupleLookupMixin, In): f"must be a Query object (received {rhs_cls!r})" ) - def check_rhs_select_length_equals_lhs_length(self): - len_rhs = len(self.rhs.select) - if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs): - len_rhs = len(self.rhs.select[0]) - len_lhs = len(self.lhs) - if len_rhs != len_lhs: - lhs_str = self.get_lhs_str() - raise ValueError( - f"{self.lookup_name!r} subquery lookup of {lhs_str} " - f"must have {len_lhs} fields (received {len_rhs})" - ) - def process_rhs(self, compiler, connection): + if not self.rhs_is_direct_value(): + return super(TupleLookupMixin, self).process_rhs(compiler, connection) + rhs = self.rhs if not rhs: raise EmptyResultSet @@ -326,19 +323,12 @@ class TupleIn(TupleLookupMixin, In): return compiler.compile(Tuple(*result)) - def as_subquery_sql(self, compiler, connection): - lhs = self.lhs - rhs = self.rhs - if isinstance(lhs, ColPairs): - rhs = rhs.clone() - rhs.set_values([source.name for source in lhs.sources]) - lhs = Tuple(lhs) - return compiler.compile(In(lhs, rhs)) - def get_fallback_sql(self, compiler, connection): rhs = self.rhs if not rhs: raise EmptyResultSet + if not self.rhs_is_direct_value(): + return super(TupleLookupMixin, self).as_sql(compiler, connection) # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) @@ -351,11 +341,6 @@ class TupleIn(TupleLookupMixin, In): return root.as_sql(compiler, connection) - def as_sql(self, compiler, connection): - if not self.rhs_is_direct_value(): - return self.as_subquery_sql(compiler, connection) - return super().as_sql(compiler, connection) - tuple_lookups = { "exact": TupleExact, diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 2d80385fde..c5b2a7a0bb 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -392,16 +392,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): def get_prep_lookup(self): from django.db.models.sql.query import Query # avoid circular import - if isinstance(self.rhs, Query): - if self.rhs.has_limit_one(): - if not self.rhs.has_select_fields: - self.rhs.clear_select_clause() - self.rhs.add_fields(["pk"]) - else: + if isinstance(query := self.rhs, Query): + if not query.has_limit_one(): raise ValueError( "The QuerySet value for an exact lookup must be limited to " "one result using slicing." ) + lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 + if (rhs_len := query._subquery_fields_len) != lhs_len: + raise ValueError( + f"The QuerySet value for the exact lookup must have {lhs_len} " + f"selected fields (received {rhs_len})" + ) + if not query.has_select_fields: + query.clear_select_clause() + query.add_fields(["pk"]) return super().get_prep_lookup() def as_sql(self, compiler, connection): @@ -518,6 +523,12 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): from django.db.models.sql.query import Query # avoid circular import if isinstance(self.rhs, Query): + lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1 + if (rhs_len := self.rhs._subquery_fields_len) != lhs_len: + raise ValueError( + f"The QuerySet value for the 'in' lookup must have {lhs_len} " + f"selected fields (received {rhs_len})" + ) self.rhs.clear_ordering(clear_default=True) if not self.rhs.has_select_fields: self.rhs.clear_select_clause() diff --git a/django/db/models/query.py b/django/db/models/query.py index aaeb8d30cc..b898f3ec1a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1970,10 +1970,6 @@ class QuerySet(AltersData): self._known_related_objects.setdefault(field, {}).update(objects) def resolve_expression(self, *args, **kwargs): - if self._fields and len(self._fields) > 1: - # values() queryset can only be used as nested queries - # if they are set up to select only a single field. - raise TypeError("Cannot use multi-field values as a filter value.") query = self.query.resolve_expression(*args, **kwargs) query._db = self._db return query diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ec47d9aa24..0d1fe5fb43 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1224,6 +1224,12 @@ class Query(BaseExpression): if self.selected: self.selected[alias] = alias + @property + def _subquery_fields_len(self): + if self.has_select_fields: + return len(self.selected) + return len(self.model._meta.pk_fields) + def resolve_expression(self, query, *args, **kwargs): clone = self.clone() # Subqueries need to use a different set of aliases than the outer query. diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py index 6286ed2354..c85869afa7 100644 --- a/tests/composite_pk/models/tenant.py +++ b/tests/composite_pk/models/tenant.py @@ -44,6 +44,7 @@ class Comment(models.Model): related_name="comments", ) text = models.TextField(default="", blank=True) + integer = models.IntegerField(default=0) class Post(models.Model): diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py index fe942b9e5b..d4c6ef13e0 100644 --- a/tests/composite_pk/test_filter.py +++ b/tests/composite_pk/test_filter.py @@ -10,7 +10,7 @@ from django.db.models import ( ) from django.db.models.functions import Cast from django.db.models.lookups import Exact -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature from .models import Comment, Tenant, User @@ -182,6 +182,30 @@ class CompositePKFilterTests(TestCase): Comment.objects.filter(pk__in=pks).order_by("pk"), objs ) + def test_filter_comments_by_pk_in_subquery(self): + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk), + ), + [self.comment_1], + ) + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( + "tenant_id", "id" + ), + ), + [self.comment_1], + ) + self.comment_2.integer = self.comment_1.id + self.comment_2.save() + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.values("tenant_id", "integer"), + ), + [self.comment_1], + ) + def test_filter_comments_by_user_and_order_by_pk_asc(self): self.assertSequenceEqual( Comment.objects.filter(user=self.user_1).order_by("pk"), @@ -440,16 +464,40 @@ class CompositePKFilterTests(TestCase): queryset = Comment.objects.filter(**{f"id{lookup}": subquery}) self.assertEqual(queryset.count(), expected_count) - def test_non_outer_ref_subquery(self): - # If rhs is any non-OuterRef object with an as_sql() function. + def test_unsupported_rhs(self): pk = Exact(F("tenant_id"), 1) msg = ( - "'exact' subquery lookup of 'pk' only supports OuterRef objects " - "(received 'Exact')" + "'exact' subquery lookup of 'pk' only supports OuterRef " + "and QuerySet objects (received 'Exact')" ) with self.assertRaisesMessage(ValueError, msg): Comment.objects.filter(pk=pk) + @skipUnlessDBFeature("allow_sliced_subqueries_with_in") + def test_filter_comments_by_pk_exact_subquery(self): + self.assertSequenceEqual( + Comment.objects.filter( + pk=Comment.objects.filter(pk=self.comment_1.pk)[:1], + ), + [self.comment_1], + ) + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.filter(pk=self.comment_1.pk).values( + "tenant_id", "id" + )[:1], + ), + [self.comment_1], + ) + self.comment_2.integer = self.comment_1.id + self.comment_2.save() + self.assertSequenceEqual( + Comment.objects.filter( + pk__in=Comment.objects.values("tenant_id", "integer"), + )[:1], + [self.comment_1], + ) + def test_outer_ref_not_composite_pk(self): subquery = Comment.objects.filter(pk=OuterRef("id")).values("id") queryset = Comment.objects.filter(id=Subquery(subquery)) diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py index 6b09480fb0..18fa53d9c0 100644 --- a/tests/composite_pk/tests.py +++ b/tests/composite_pk/tests.py @@ -109,13 +109,10 @@ class CompositePKTests(TestCase): def test_composite_pk_in_fields(self): user_fields = {f.name for f in User._meta.get_fields()} - self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"}) + self.assertTrue({"pk", "tenant", "id"}.issubset(user_fields)) comment_fields = {f.name for f in Comment._meta.get_fields()} - self.assertEqual( - comment_fields, - {"pk", "tenant", "id", "user_id", "user", "text"}, - ) + self.assertTrue({"pk", "tenant", "id"}.issubset(comment_fields)) def test_pk_field(self): pk = User._meta.get_field("pk") @@ -174,7 +171,7 @@ class CompositePKTests(TestCase): self.assertEqual(user.email, self.user.email) def test_model_forms(self): - fields = ["tenant", "id", "user_id", "text"] + fields = ["tenant", "id", "user_id", "text", "integer"] self.assertEqual(list(CommentForm.base_fields), fields) form = modelform_factory(Comment, fields="__all__") diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index 42717c4f11..008f118994 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -63,9 +63,11 @@ class TupleLookupsTests(TestCase): ) def test_exact_subquery(self): - with self.assertRaisesMessage( - ValueError, "'exact' doesn't support multi-column subqueries." - ): + msg = ( + "The QuerySet value for the exact lookup must have 2 selected " + "fields (received 1)" + ) + with self.assertRaisesMessage(ValueError, msg): subquery = Customer.objects.filter(id=self.customer_1.id)[:1] self.assertSequenceEqual( Contact.objects.filter(customer=subquery).order_by("id"), () @@ -140,11 +142,11 @@ class TupleLookupsTests(TestCase): def test_tuple_in_subquery_must_have_2_fields(self): lhs = (F("customer_code"), F("company_code")) rhs = Customer.objects.values_list("customer_id").query - with self.assertRaisesMessage( - ValueError, - "'in' subquery lookup of ('customer_code', 'company_code') " - "must have 2 fields (received 1)", - ): + msg = ( + "The QuerySet value for the 'in' lookup must have 2 selected " + "fields (received 1)" + ) + with self.assertRaisesMessage(ValueError, msg): TupleIn(lhs, rhs) def test_tuple_in_subquery(self): diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index df96546d04..e19fbca521 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -789,6 +789,14 @@ class LookupTests(TestCase): sql = ctx.captured_queries[0]["sql"] self.assertIn("IN (%s)" % self.a1.pk, sql) + def test_in_select_mismatch(self): + msg = ( + "The QuerySet value for the 'in' lookup must have 1 " + "selected fields (received 2)" + ) + with self.assertRaisesMessage(ValueError, msg): + Article.objects.filter(id__in=Article.objects.values("id", "headline")) + def test_error_messages(self): # Programming errors are pointed out with nice error messages with self.assertRaisesMessage( @@ -1364,6 +1372,14 @@ class LookupTests(TestCase): authors = Author.objects.filter(id=authors_max_ids[:1]) self.assertEqual(authors.get(), newest_author) + def test_exact_query_rhs_with_selected_columns_mismatch(self): + msg = ( + "The QuerySet value for the exact lookup must have 1 " + "selected fields (received 2)" + ) + with self.assertRaisesMessage(ValueError, msg): + Author.objects.filter(id=Author.objects.values("id", "name")[:1]) + def test_isnull_non_boolean_value(self): msg = "The QuerySet value for an isnull lookup must be True or False." tests = [ diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 45866fd50f..c429a93af3 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -922,20 +922,6 @@ class Queries1Tests(TestCase): [self.t2, self.t3], ) - # Multi-valued values() and values_list() querysets should raise errors. - with self.assertRaisesMessage( - TypeError, "Cannot use multi-field values as a filter value." - ): - Tag.objects.filter( - name__in=Tag.objects.filter(parent=self.t1).values("name", "id") - ) - with self.assertRaisesMessage( - TypeError, "Cannot use multi-field values as a filter value." - ): - Tag.objects.filter( - name__in=Tag.objects.filter(parent=self.t1).values_list("name", "id") - ) - def test_ticket9985(self): # qs.values_list(...).values(...) combinations should work. self.assertSequenceEqual(