From 534d8d875eebac6aee278f0ffa6cc59760dac546 Mon Sep 17 00:00:00 2001 From: Adnan Umer Date: Wed, 18 Apr 2018 22:30:25 +0500 Subject: [PATCH] Fixed #28600 -- Added prefetch_related() support to RawQuerySet. --- django/db/models/query.py | 30 ++++++++++++++++++++++++++- docs/releases/2.1.txt | 2 ++ tests/prefetch_related/tests.py | 36 ++++++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index f16af1e91d..8d4c2d083c 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -727,7 +727,9 @@ class QuerySet: def raw(self, raw_query, params=None, translations=None, using=None): if using is None: using = self.db - return RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using) + qs = RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using) + qs._prefetch_related_lookups = self._prefetch_related_lookups[:] + return qs def _values(self, *fields, **expressions): clone = self._chain() @@ -1278,6 +1280,8 @@ class RawQuerySet: self.params = params or () self.translations = translations or {} self._result_cache = None + self._prefetch_related_lookups = () + self._prefetch_done = False def resolve_model_init_order(self): """Resolve the init field names and value positions.""" @@ -1289,9 +1293,33 @@ class RawQuerySet: model_init_names = [f.attname for f in model_init_fields] return model_init_names, model_init_order, annotation_fields + def prefetch_related(self, *lookups): + """Same as QuerySet.prefetch_related()""" + clone = self._clone() + if lookups == (None,): + clone._prefetch_related_lookups = () + else: + clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups + return clone + + def _prefetch_related_objects(self): + prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) + self._prefetch_done = True + + def _clone(self): + """Same as QuerySet._clone()""" + c = self.__class__( + self.raw_query, model=self.model, query=self.query, params=self.params, + translations=self.translations, using=self._db, hints=self._hints + ) + c._prefetch_related_lookups = self._prefetch_related_lookups[:] + return c + def _fetch_all(self): if self._result_cache is None: self._result_cache = list(self.iterator()) + if self._prefetch_related_lookups and not self._prefetch_done: + self._prefetch_related_objects() def __len__(self): self._fetch_all() diff --git a/docs/releases/2.1.txt b/docs/releases/2.1.txt index f3a99e5fc0..7595be46c0 100644 --- a/docs/releases/2.1.txt +++ b/docs/releases/2.1.txt @@ -239,6 +239,8 @@ Models * The new :meth:`.QuerySet.explain` method displays the database's execution plan of a queryset's query. +* :meth:`.QuerySet.raw` now supports :meth:`~.QuerySet.prefetch_related`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index e92d7f349f..5a701bffec 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -14,7 +14,7 @@ from .models import ( ) -class PrefetchRelatedTests(TestCase): +class TestDataMixin: @classmethod def setUpTestData(cls): cls.book1 = Book.objects.create(title='Poems') @@ -38,6 +38,8 @@ class PrefetchRelatedTests(TestCase): cls.reader1.books_read.add(cls.book1, cls.book4) cls.reader2.books_read.add(cls.book2, cls.book4) + +class PrefetchRelatedTests(TestDataMixin, TestCase): def assertWhereContains(self, sql, needle): where_idx = sql.index('WHERE') self.assertEqual( @@ -281,6 +283,38 @@ class PrefetchRelatedTests(TestCase): self.assertWhereContains(sql, self.author1.id) +class RawQuerySetTests(TestDataMixin, TestCase): + def test_basic(self): + with self.assertNumQueries(2): + books = Book.objects.raw( + "SELECT * FROM prefetch_related_book WHERE id = %s", + (self.book1.id,) + ).prefetch_related('authors') + book1 = list(books)[0] + + with self.assertNumQueries(0): + self.assertCountEqual(book1.authors.all(), [self.author1, self.author2, self.author3]) + + def test_prefetch_before_raw(self): + with self.assertNumQueries(2): + books = Book.objects.prefetch_related('authors').raw( + "SELECT * FROM prefetch_related_book WHERE id = %s", + (self.book1.id,) + ) + book1 = list(books)[0] + + with self.assertNumQueries(0): + self.assertCountEqual(book1.authors.all(), [self.author1, self.author2, self.author3]) + + def test_clear(self): + with self.assertNumQueries(5): + with_prefetch = Author.objects.raw( + "SELECT * FROM prefetch_related_author" + ).prefetch_related('books') + without_prefetch = with_prefetch.prefetch_related(None) + [list(a.books.all()) for a in without_prefetch] + + class CustomPrefetchTests(TestCase): @classmethod def traverse_qs(cls, obj_iter, path):