mirror of
				https://github.com/django/django.git
				synced 2025-10-30 17:16:10 +00:00 
			
		
		
		
	Fixed #36605 -- Added support for QuerySet.in_bulk() after .values() or .values_list().
co-authored-by: Adam Johnson <me@adamj.eu> co-authored-by: Simon Charette <charette.s@gmail.com>
This commit is contained in:
		| @@ -1166,8 +1166,6 @@ class QuerySet(AltersData): | ||||
|         """ | ||||
|         if self.query.is_sliced: | ||||
|             raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().") | ||||
|         if not issubclass(self._iterable_class, ModelIterable): | ||||
|             raise TypeError("in_bulk() cannot be used with values() or values_list().") | ||||
|         opts = self.model._meta | ||||
|         unique_fields = [ | ||||
|             constraint.fields[0] | ||||
| @@ -1184,6 +1182,59 @@ class QuerySet(AltersData): | ||||
|                 "in_bulk()'s field_name must be a unique field but %r isn't." | ||||
|                 % field_name | ||||
|             ) | ||||
|  | ||||
|         qs = self | ||||
|  | ||||
|         def get_obj(obj): | ||||
|             return obj | ||||
|  | ||||
|         if issubclass(self._iterable_class, ModelIterable): | ||||
|             # Raise an AttributeError if field_name is deferred. | ||||
|             get_key = operator.attrgetter(field_name) | ||||
|  | ||||
|         elif issubclass(self._iterable_class, ValuesIterable): | ||||
|             if field_name not in self.query.values_select: | ||||
|                 qs = qs.values(field_name, *self.query.values_select) | ||||
|  | ||||
|                 def get_obj(obj):  # noqa: F811 | ||||
|                     # We can safely mutate the dictionaries returned by | ||||
|                     # ValuesIterable here, since they are limited to the scope | ||||
|                     # of this function, and get_key runs before get_obj. | ||||
|                     del obj[field_name] | ||||
|                     return obj | ||||
|  | ||||
|             get_key = operator.itemgetter(field_name) | ||||
|  | ||||
|         elif issubclass(self._iterable_class, ValuesListIterable): | ||||
|             try: | ||||
|                 field_index = self.query.values_select.index(field_name) | ||||
|             except ValueError: | ||||
|                 # field_name is missing from values_select, so add it. | ||||
|                 field_index = 0 | ||||
|                 if issubclass(self._iterable_class, NamedValuesListIterable): | ||||
|                     kwargs = {"named": True} | ||||
|                 else: | ||||
|                     kwargs = {} | ||||
|                     get_obj = operator.itemgetter(slice(1, None)) | ||||
|                 qs = qs.values_list(field_name, *self.query.values_select, **kwargs) | ||||
|  | ||||
|             get_key = operator.itemgetter(field_index) | ||||
|  | ||||
|         elif issubclass(self._iterable_class, FlatValuesListIterable): | ||||
|             if self.query.values_select == (field_name,): | ||||
|                 # Mapping field_name to itself. | ||||
|                 get_key = get_obj | ||||
|             else: | ||||
|                 # Transform it back into a non-flat values_list(). | ||||
|                 qs = qs.values_list(field_name, *self.query.values_select) | ||||
|                 get_key = operator.itemgetter(0) | ||||
|                 get_obj = operator.itemgetter(1) | ||||
|  | ||||
|         else: | ||||
|             raise TypeError( | ||||
|                 f"in_bulk() cannot be used with {self._iterable_class.__name__}." | ||||
|             ) | ||||
|  | ||||
|         if id_list is not None: | ||||
|             if not id_list: | ||||
|                 return {} | ||||
| @@ -1193,15 +1244,16 @@ class QuerySet(AltersData): | ||||
|             # If the database has a limit on the number of query parameters | ||||
|             # (e.g. SQLite), retrieve objects in batches if necessary. | ||||
|             if batch_size and batch_size < len(id_list): | ||||
|                 qs = () | ||||
|                 results = () | ||||
|                 for offset in range(0, len(id_list), batch_size): | ||||
|                     batch = id_list[offset : offset + batch_size] | ||||
|                     qs += tuple(self.filter(**{filter_key: batch})) | ||||
|                     results += tuple(qs.filter(**{filter_key: batch})) | ||||
|                 qs = results | ||||
|             else: | ||||
|                 qs = self.filter(**{filter_key: id_list}) | ||||
|                 qs = qs.filter(**{filter_key: id_list}) | ||||
|         else: | ||||
|             qs = self._chain() | ||||
|         return {getattr(obj, field_name): obj for obj in qs} | ||||
|             qs = qs._chain() | ||||
|         return {get_key(obj): get_obj(obj) for obj in qs} | ||||
|  | ||||
|     async def ain_bulk(self, id_list=None, *, field_name="pk"): | ||||
|         return await sync_to_async(self.in_bulk)( | ||||
|   | ||||
| @@ -2588,6 +2588,11 @@ Example: | ||||
|  | ||||
| If you pass ``in_bulk()`` an empty list, you'll get an empty dictionary. | ||||
|  | ||||
| .. versionchanged:: 6.1 | ||||
|  | ||||
|     Support for chaining ``in_bulk()`` after :meth:`values` or | ||||
|     :meth:`values_list` was added. | ||||
|  | ||||
| ``iterator()`` | ||||
| ~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -175,7 +175,8 @@ Migrations | ||||
| Models | ||||
| ~~~~~~ | ||||
|  | ||||
| * ... | ||||
| * :meth:`.QuerySet.in_bulk` now supports chaining after | ||||
|   :meth:`.QuerySet.values` and :meth:`.QuerySet.values_list`. | ||||
|  | ||||
| Pagination | ||||
| ~~~~~~~~~~ | ||||
|   | ||||
| @@ -167,6 +167,67 @@ class CompositePKTests(TestCase): | ||||
|                 comment_dict = Comment.objects.in_bulk(id_list=id_list) | ||||
|         self.assertQuerySetEqual(comment_dict, id_list) | ||||
|  | ||||
|     def test_in_bulk_values(self): | ||||
|         result = Comment.objects.values().in_bulk([self.comment.pk]) | ||||
|         self.assertEqual( | ||||
|             result, | ||||
|             { | ||||
|                 self.comment.pk: { | ||||
|                     "tenant_id": self.comment.tenant_id, | ||||
|                     "id": self.comment.id, | ||||
|                     "user_id": self.comment.user_id, | ||||
|                     "text": self.comment.text, | ||||
|                     "integer": self.comment.integer, | ||||
|                 } | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_field(self): | ||||
|         result = Comment.objects.values("text").in_bulk([self.comment.pk]) | ||||
|         self.assertEqual( | ||||
|             result, | ||||
|             {self.comment.pk: {"text": self.comment.text}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_fields(self): | ||||
|         result = Comment.objects.values("pk", "text").in_bulk([self.comment.pk]) | ||||
|         self.assertEqual( | ||||
|             result, | ||||
|             {self.comment.pk: {"pk": self.comment.pk, "text": self.comment.text}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list(self): | ||||
|         result = Comment.objects.values_list("text").in_bulk([self.comment.pk]) | ||||
|         self.assertEqual(result, {self.comment.pk: (self.comment.text,)}) | ||||
|  | ||||
|     def test_in_bulk_values_list_multiple_fields(self): | ||||
|         result = Comment.objects.values_list("pk", "text").in_bulk([self.comment.pk]) | ||||
|         self.assertEqual( | ||||
|             result, {self.comment.pk: (self.comment.pk, self.comment.text)} | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_fields_are_pk(self): | ||||
|         result = Comment.objects.values_list("tenant", "id").in_bulk([self.comment.pk]) | ||||
|         self.assertEqual( | ||||
|             result, {self.comment.pk: (self.comment.tenant_id, self.comment.id)} | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat(self): | ||||
|         result = Comment.objects.values_list("text", flat=True).in_bulk( | ||||
|             [self.comment.pk] | ||||
|         ) | ||||
|         self.assertEqual(result, {self.comment.pk: self.comment.text}) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_pk(self): | ||||
|         result = Comment.objects.values_list("pk", flat=True).in_bulk([self.comment.pk]) | ||||
|         self.assertEqual(result, {self.comment.pk: self.comment.pk}) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_tenant(self): | ||||
|         result = Comment.objects.values_list("tenant", flat=True).in_bulk( | ||||
|             [self.comment.pk] | ||||
|         ) | ||||
|         self.assertEqual(result, {self.comment.pk: self.tenant.id}) | ||||
|  | ||||
|     def test_iterator(self): | ||||
|         """ | ||||
|         Test the .iterator() method of composite_pk models. | ||||
|   | ||||
| @@ -317,12 +317,246 @@ class LookupTests(TestCase): | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             Article.objects.all()[0:5].in_bulk([self.a1.id, self.a2.id]) | ||||
|  | ||||
|     def test_in_bulk_not_model_iterable(self): | ||||
|         msg = "in_bulk() cannot be used with values() or values_list()." | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             Author.objects.values().in_bulk() | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             Author.objects.values_list().in_bulk() | ||||
|     def test_in_bulk_values_empty(self): | ||||
|         arts = Article.objects.values().in_bulk([]) | ||||
|         self.assertEqual(arts, {}) | ||||
|  | ||||
|     def test_in_bulk_values_all(self): | ||||
|         Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() | ||||
|         arts = Article.objects.values().in_bulk() | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: { | ||||
|                     "id": self.a1.pk, | ||||
|                     "author_id": self.au1.pk, | ||||
|                     "headline": "Article 1", | ||||
|                     "pub_date": self.a1.pub_date, | ||||
|                     "slug": "a1", | ||||
|                 }, | ||||
|                 self.a2.pk: { | ||||
|                     "id": self.a2.pk, | ||||
|                     "author_id": self.au1.pk, | ||||
|                     "headline": "Article 2", | ||||
|                     "pub_date": self.a2.pub_date, | ||||
|                     "slug": "a2", | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_pks(self): | ||||
|         arts = Article.objects.values().in_bulk([self.a1.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: { | ||||
|                     "id": self.a1.pk, | ||||
|                     "author_id": self.au1.pk, | ||||
|                     "headline": "Article 1", | ||||
|                     "pub_date": self.a1.pub_date, | ||||
|                     "slug": "a1", | ||||
|                 } | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_fields(self): | ||||
|         arts = Article.objects.values("headline").in_bulk([self.a1.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.pk: {"headline": "Article 1"}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_fields_including_pk(self): | ||||
|         arts = Article.objects.values("pk", "headline").in_bulk([self.a1.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.pk: {"pk": self.a1.pk, "headline": "Article 1"}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_fields_pk(self): | ||||
|         arts = Article.objects.values("pk").in_bulk([self.a1.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.pk: {"pk": self.a1.pk}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_fields_id(self): | ||||
|         arts = Article.objects.values("id").in_bulk([self.a1.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.pk: {"id": self.a1.pk}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_alternative_field_name(self): | ||||
|         arts = Article.objects.values("headline").in_bulk( | ||||
|             [self.a1.slug], field_name="slug" | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.slug: {"headline": "Article 1"}}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_empty(self): | ||||
|         arts = Article.objects.values_list().in_bulk([]) | ||||
|         self.assertEqual(arts, {}) | ||||
|  | ||||
|     def test_in_bulk_values_list_all(self): | ||||
|         Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() | ||||
|         arts = Article.objects.values_list().in_bulk() | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: ( | ||||
|                     self.a1.pk, | ||||
|                     "Article 1", | ||||
|                     self.a1.pub_date, | ||||
|                     self.au1.pk, | ||||
|                     "a1", | ||||
|                 ), | ||||
|                 self.a2.pk: ( | ||||
|                     self.a2.pk, | ||||
|                     "Article 2", | ||||
|                     self.a2.pub_date, | ||||
|                     self.au1.pk, | ||||
|                     "a2", | ||||
|                 ), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_fields(self): | ||||
|         arts = Article.objects.values_list("headline").in_bulk([self.a1.pk, self.a2.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: ("Article 1",), | ||||
|                 self.a2.pk: ("Article 2",), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_fields_including_pk(self): | ||||
|         arts = Article.objects.values_list("pk", "headline").in_bulk( | ||||
|             [self.a1.pk, self.a2.pk] | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: (self.a1.pk, "Article 1"), | ||||
|                 self.a2.pk: (self.a2.pk, "Article 2"), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_fields_pk(self): | ||||
|         arts = Article.objects.values_list("pk").in_bulk([self.a1.pk, self.a2.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: (self.a1.pk,), | ||||
|                 self.a2.pk: (self.a2.pk,), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_fields_id(self): | ||||
|         arts = Article.objects.values_list("id").in_bulk([self.a1.pk, self.a2.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: (self.a1.pk,), | ||||
|                 self.a2.pk: (self.a2.pk,), | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_named(self): | ||||
|         arts = Article.objects.values_list(named=True).in_bulk([self.a1.pk, self.a2.pk]) | ||||
|         self.assertIsInstance(arts, dict) | ||||
|         self.assertEqual(len(arts), 2) | ||||
|         arts1 = arts[self.a1.pk] | ||||
|         self.assertEqual( | ||||
|             arts1._fields, ("pk", "id", "headline", "pub_date", "author_id", "slug") | ||||
|         ) | ||||
|         self.assertEqual(arts1.pk, self.a1.pk) | ||||
|         self.assertEqual(arts1.headline, "Article 1") | ||||
|         self.assertEqual(arts1.pub_date, self.a1.pub_date) | ||||
|         self.assertEqual(arts1.author_id, self.au1.pk) | ||||
|         self.assertEqual(arts1.slug, "a1") | ||||
|  | ||||
|     def test_in_bulk_values_list_named_fields(self): | ||||
|         arts = Article.objects.values_list("pk", "headline", named=True).in_bulk( | ||||
|             [self.a1.pk, self.a2.pk] | ||||
|         ) | ||||
|         self.assertIsInstance(arts, dict) | ||||
|         self.assertEqual(len(arts), 2) | ||||
|         arts1 = arts[self.a1.pk] | ||||
|         self.assertEqual(arts1._fields, ("pk", "headline")) | ||||
|         self.assertEqual(arts1.pk, self.a1.pk) | ||||
|         self.assertEqual(arts1.headline, "Article 1") | ||||
|  | ||||
|     def test_in_bulk_values_list_named_fields_alternative_field(self): | ||||
|         arts = Article.objects.values_list("headline", named=True).in_bulk( | ||||
|             [self.a1.slug, self.a2.slug], field_name="slug" | ||||
|         ) | ||||
|         self.assertEqual(len(arts), 2) | ||||
|         arts1 = arts[self.a1.slug] | ||||
|         self.assertEqual(arts1._fields, ("slug", "headline")) | ||||
|         self.assertEqual(arts1.slug, "a1") | ||||
|         self.assertEqual(arts1.headline, "Article 1") | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_empty(self): | ||||
|         arts = Article.objects.values_list(flat=True).in_bulk([]) | ||||
|         self.assertEqual(arts, {}) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_all(self): | ||||
|         Article.objects.exclude(pk__in=[self.a1.pk, self.a2.pk]).delete() | ||||
|         arts = Article.objects.values_list(flat=True).in_bulk() | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: self.a1.pk, | ||||
|                 self.a2.pk: self.a2.pk, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_pks(self): | ||||
|         arts = Article.objects.values_list(flat=True).in_bulk([self.a1.pk, self.a2.pk]) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: self.a1.pk, | ||||
|                 self.a2.pk: self.a2.pk, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_field(self): | ||||
|         arts = Article.objects.values_list("headline", flat=True).in_bulk( | ||||
|             [self.a1.pk, self.a2.pk] | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             {self.a1.pk: "Article 1", self.a2.pk: "Article 2"}, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_field_pk(self): | ||||
|         arts = Article.objects.values_list("pk", flat=True).in_bulk( | ||||
|             [self.a1.pk, self.a2.pk] | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: self.a1.pk, | ||||
|                 self.a2.pk: self.a2.pk, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_in_bulk_values_list_flat_field_id(self): | ||||
|         arts = Article.objects.values_list("id", flat=True).in_bulk( | ||||
|             [self.a1.pk, self.a2.pk] | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             arts, | ||||
|             { | ||||
|                 self.a1.pk: self.a1.pk, | ||||
|                 self.a2.pk: self.a2.pk, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_values(self): | ||||
|         # values() returns a list of dictionaries instead of object instances, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user