diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 66c978d42e..6584e864db 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -76,6 +76,8 @@ from django.db import ( transaction, ) from django.db.models import Manager, Q, Window, signals +from django.db.models.expressions import ColPairs +from django.db.models.fields.tuple_lookups import TupleIn from django.db.models.functions import RowNumber from django.db.models.lookups import GreaterThan, LessThanOrEqual from django.db.models.query import QuerySet @@ -178,23 +180,19 @@ class ForwardManyToOneDescriptor: rel_obj_attr = self.field.get_foreign_related_value instance_attr = self.field.get_local_related_value instances_dict = {instance_attr(inst): inst for inst in instances} - related_field = self.field.foreign_related_fields[0] + related_fields = self.field.foreign_related_fields remote_field = self.field.remote_field - - # FIXME: This will need to be revisited when we introduce support for - # composite fields. In the meantime we take this practical approach to - # solve a regression on 1.6 when the reverse manager is hidden - # (related_name ends with a '+'). Refs #21410. - # The check for len(...) == 1 is a special case that allows the query - # to be join-less and smaller. Refs #21760. - if remote_field.hidden or len(self.field.foreign_related_fields) == 1: - query = { - "%s__in" - % related_field.name: {instance_attr(inst)[0] for inst in instances} - } - else: - query = {"%s__in" % self.field.related_query_name(): instances} - queryset = queryset.filter(**query) + queryset = queryset.filter( + TupleIn( + ColPairs( + queryset.model._meta.db_table, + related_fields, + related_fields, + self.field, + ), + list(instances_dict), + ) + ) # There can be only one object prefetched for each instance so clear # ordering if the query allows it without side effects. queryset.query.clear_ordering() diff --git a/tests/foreign_object/models/person.py b/tests/foreign_object/models/person.py index f0848e6c3e..d536ab63d7 100644 --- a/tests/foreign_object/models/person.py +++ b/tests/foreign_object/models/person.py @@ -107,6 +107,6 @@ class Friendship(models.Model): Person, from_fields=["to_friend_country_id", "to_friend_id"], to_fields=["person_country_id", "id"], - related_name="to_friend", + related_name="to_friend+", on_delete=models.CASCADE, ) diff --git a/tests/foreign_object/tests.py b/tests/foreign_object/tests.py index e288ecd7d4..8b36df29d7 100644 --- a/tests/foreign_object/tests.py +++ b/tests/foreign_object/tests.py @@ -4,7 +4,7 @@ import pickle from operator import attrgetter from django.core.exceptions import FieldError -from django.db import models +from django.db import connection, models from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import isolate_apps from django.utils import translation @@ -247,7 +247,7 @@ class MultiColumnFKTests(TestCase): normal_people = [m.person for m in Membership.objects.order_by("pk")] self.assertEqual(people, normal_people) - def test_prefetch_foreignkey_forward_works(self): + def test_prefetch_foreignobject_forward(self): Membership.objects.create( membership_country=self.usa, person=self.bob, group=self.cia ) @@ -264,7 +264,40 @@ class MultiColumnFKTests(TestCase): normal_people = [m.person for m in Membership.objects.order_by("pk")] self.assertEqual(people, normal_people) - def test_prefetch_foreignkey_reverse_works(self): + def test_prefetch_foreignobject_hidden_forward(self): + Friendship.objects.create( + from_friend_country=self.usa, + from_friend_id=self.bob.id, + to_friend_country_id=self.usa.id, + to_friend_id=self.george.id, + ) + Friendship.objects.create( + from_friend_country=self.usa, + from_friend_id=self.bob.id, + to_friend_country_id=self.soviet_union.id, + to_friend_id=self.sam.id, + ) + with self.assertNumQueries(2) as ctx: + friendships = list( + Friendship.objects.prefetch_related("to_friend").order_by("pk") + ) + prefetch_sql = ctx[-1]["sql"] + # Prefetch queryset should be filtered by all foreign related fields + # to prevent extra rows from being eagerly fetched. + prefetch_where_sql = prefetch_sql.split("WHERE")[-1] + for to_field_name in Friendship.to_friend.field.to_fields: + to_field = Person._meta.get_field(to_field_name) + with self.subTest(to_field=to_field): + self.assertIn( + connection.ops.quote_name(to_field.column), + prefetch_where_sql, + ) + self.assertNotIn(" JOIN ", prefetch_sql) + with self.assertNumQueries(0): + self.assertEqual(friendships[0].to_friend, self.george) + self.assertEqual(friendships[1].to_friend, self.sam) + + def test_prefetch_foreignobject_reverse(self): Membership.objects.create( membership_country=self.usa, person=self.bob, group=self.cia )