mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.
This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details. git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -46,7 +46,7 @@ class Aggregate(object): | ||||
|         # Validate that the backend has a fully supported, correct | ||||
|         # implementation of this aggregate | ||||
|         query.connection.ops.check_aggregate_support(aggregate) | ||||
|         query.aggregate_select[alias] = aggregate | ||||
|         query.aggregates[alias] = aggregate | ||||
|  | ||||
| class Avg(Aggregate): | ||||
|     name = 'Avg' | ||||
|   | ||||
| @@ -596,7 +596,7 @@ class QuerySet(object): | ||||
|  | ||||
|         obj = self._clone() | ||||
|  | ||||
|         obj._setup_aggregate_query() | ||||
|         obj._setup_aggregate_query(kwargs.keys()) | ||||
|  | ||||
|         # Add the aggregates to the query | ||||
|         for (alias, aggregate_expr) in kwargs.items(): | ||||
| @@ -693,7 +693,7 @@ class QuerySet(object): | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def _setup_aggregate_query(self): | ||||
|     def _setup_aggregate_query(self, aggregates): | ||||
|         """ | ||||
|         Prepare the query for computing a result that contains aggregate annotations. | ||||
|         """ | ||||
| @@ -773,6 +773,8 @@ class ValuesQuerySet(QuerySet): | ||||
|  | ||||
|         self.query.select = [] | ||||
|         self.query.add_fields(self.field_names, False) | ||||
|         if self.aggregate_names is not None: | ||||
|             self.query.set_aggregate_mask(self.aggregate_names) | ||||
|  | ||||
|     def _clone(self, klass=None, setup=False, **kwargs): | ||||
|         """ | ||||
| @@ -798,13 +800,17 @@ class ValuesQuerySet(QuerySet): | ||||
|             raise TypeError("Merging '%s' classes must involve the same values in each case." | ||||
|                     % self.__class__.__name__) | ||||
|  | ||||
|     def _setup_aggregate_query(self): | ||||
|     def _setup_aggregate_query(self, aggregates): | ||||
|         """ | ||||
|         Prepare the query for computing a result that contains aggregate annotations. | ||||
|         """ | ||||
|         self.query.set_group_by() | ||||
|  | ||||
|         super(ValuesQuerySet, self)._setup_aggregate_query() | ||||
|         if self.aggregate_names is not None: | ||||
|             self.aggregate_names.extend(aggregates) | ||||
|             self.query.set_aggregate_mask(self.aggregate_names) | ||||
|  | ||||
|         super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) | ||||
|  | ||||
|     def as_sql(self): | ||||
|         """ | ||||
| @@ -824,6 +830,7 @@ class ValuesListQuerySet(ValuesQuerySet): | ||||
|     def iterator(self): | ||||
|         if self.extra_names is not None: | ||||
|             self.query.trim_extra_select(self.extra_names) | ||||
|  | ||||
|         if self.flat and len(self._fields) == 1: | ||||
|             for row in self.query.results_iter(): | ||||
|                 yield row[0] | ||||
| @@ -837,6 +844,7 @@ class ValuesListQuerySet(ValuesQuerySet): | ||||
|             extra_names = self.query.extra_select.keys() | ||||
|             field_names = self.field_names | ||||
|             aggregate_names = self.query.aggregate_select.keys() | ||||
|  | ||||
|             names = extra_names + field_names + aggregate_names | ||||
|  | ||||
|             # If a field list has been specified, use it. Otherwise, use the | ||||
|   | ||||
| @@ -77,7 +77,9 @@ class BaseQuery(object): | ||||
|         self.related_select_cols = [] | ||||
|  | ||||
|         # SQL aggregate-related attributes | ||||
|         self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function | ||||
|         self.aggregates = SortedDict() # Maps alias -> SQL aggregate function | ||||
|         self.aggregate_select_mask = None | ||||
|         self._aggregate_select_cache = None | ||||
|  | ||||
|         # Arbitrary maximum limit for select_related. Prevents infinite | ||||
|         # recursion. Can be changed by the depth parameter to select_related(). | ||||
| @@ -187,7 +189,15 @@ class BaseQuery(object): | ||||
|         obj.distinct = self.distinct | ||||
|         obj.select_related = self.select_related | ||||
|         obj.related_select_cols = [] | ||||
|         obj.aggregate_select = self.aggregate_select.copy() | ||||
|         obj.aggregates = self.aggregates.copy() | ||||
|         if self.aggregate_select_mask is None: | ||||
|             obj.aggregate_select_mask = None | ||||
|         else: | ||||
|             obj.aggregate_select_mask = self.aggregate_select_mask[:] | ||||
|         if self._aggregate_select_cache is None: | ||||
|             obj._aggregate_select_cache = None | ||||
|         else: | ||||
|             obj._aggregate_select_cache = self._aggregate_select_cache.copy() | ||||
|         obj.max_depth = self.max_depth | ||||
|         obj.extra_select = self.extra_select.copy() | ||||
|         obj.extra_tables = self.extra_tables | ||||
| @@ -940,12 +950,15 @@ class BaseQuery(object): | ||||
|         """ | ||||
|         assert set(change_map.keys()).intersection(set(change_map.values())) == set() | ||||
|  | ||||
|         # 1. Update references in "select" and "where". | ||||
|         # 1. Update references in "select" (normal columns plus aliases), | ||||
|         # "group by", "where" and "having". | ||||
|         self.where.relabel_aliases(change_map) | ||||
|         for pos, col in enumerate(self.select): | ||||
|         self.having.relabel_aliases(change_map) | ||||
|         for columns in (self.select, self.aggregates.values(), self.group_by or []): | ||||
|             for pos, col in enumerate(columns): | ||||
|                 if isinstance(col, (list, tuple)): | ||||
|                     old_alias = col[0] | ||||
|                 self.select[pos] = (change_map.get(old_alias, old_alias), col[1]) | ||||
|                     columns[pos] = (change_map.get(old_alias, old_alias), col[1]) | ||||
|                 else: | ||||
|                     col.relabel_aliases(change_map) | ||||
|  | ||||
| @@ -1205,11 +1218,11 @@ class BaseQuery(object): | ||||
|         opts = model._meta | ||||
|         field_list = aggregate.lookup.split(LOOKUP_SEP) | ||||
|         if (len(field_list) == 1 and | ||||
|             aggregate.lookup in self.aggregate_select.keys()): | ||||
|             aggregate.lookup in self.aggregates.keys()): | ||||
|             # Aggregate is over an annotation | ||||
|             field_name = field_list[0] | ||||
|             col = field_name | ||||
|             source = self.aggregate_select[field_name] | ||||
|             source = self.aggregates[field_name] | ||||
|         elif (len(field_list) > 1 or | ||||
|             field_list[0] not in [i.name for i in opts.fields]): | ||||
|             field, source, opts, join_list, last, _ = self.setup_joins( | ||||
| @@ -1299,7 +1312,7 @@ class BaseQuery(object): | ||||
|             value = SQLEvaluator(value, self) | ||||
|             having_clause = value.contains_aggregate | ||||
|  | ||||
|         for alias, aggregate in self.aggregate_select.items(): | ||||
|         for alias, aggregate in self.aggregates.items(): | ||||
|             if alias == parts[0]: | ||||
|                 entry = self.where_class() | ||||
|                 entry.add((aggregate, lookup_type, value), AND) | ||||
| @@ -1824,8 +1837,8 @@ class BaseQuery(object): | ||||
|         self.group_by = [] | ||||
|         if self.connection.features.allows_group_by_pk: | ||||
|             if len(self.select) == len(self.model._meta.fields): | ||||
|                 self.group_by.append('.'.join([self.model._meta.db_table, | ||||
|                                                self.model._meta.pk.column])) | ||||
|                 self.group_by.append((self.model._meta.db_table, | ||||
|                                       self.model._meta.pk.column)) | ||||
|                 return | ||||
|  | ||||
|         for sel in self.select: | ||||
| @@ -1858,7 +1871,11 @@ class BaseQuery(object): | ||||
|             # Distinct handling is done in Count(), so don't do it at this | ||||
|             # level. | ||||
|             self.distinct = False | ||||
|         self.aggregate_select = {None: count} | ||||
|  | ||||
|         # Set only aggregate to be the count column. | ||||
|         # Clear out the select cache to reflect the new unmasked aggregates. | ||||
|         self.aggregates = {None: count} | ||||
|         self.set_aggregate_mask(None) | ||||
|  | ||||
|     def add_select_related(self, fields): | ||||
|         """ | ||||
| @@ -1920,6 +1937,29 @@ class BaseQuery(object): | ||||
|         for key in set(self.extra_select).difference(set(names)): | ||||
|             del self.extra_select[key] | ||||
|  | ||||
|     def set_aggregate_mask(self, names): | ||||
|         "Set the mask of aggregates that will actually be returned by the SELECT" | ||||
|         self.aggregate_select_mask = names | ||||
|         self._aggregate_select_cache = None | ||||
|  | ||||
|     def _aggregate_select(self): | ||||
|         """The SortedDict of aggregate columns that are not masked, and should | ||||
|         be used in the SELECT clause. | ||||
|  | ||||
|         This result is cached for optimization purposes. | ||||
|         """ | ||||
|         if self._aggregate_select_cache is not None: | ||||
|             return self._aggregate_select_cache | ||||
|         elif self.aggregate_select_mask is not None: | ||||
|             self._aggregate_select_cache = SortedDict([ | ||||
|                 (k,v) for k,v in self.aggregates.items() | ||||
|                 if k in self.aggregate_select_mask | ||||
|             ]) | ||||
|             return self._aggregate_select_cache | ||||
|         else: | ||||
|             return self.aggregates | ||||
|     aggregate_select = property(_aggregate_select) | ||||
|  | ||||
|     def set_start(self, start): | ||||
|         """ | ||||
|         Sets the table from which to start joining. The start position is | ||||
|   | ||||
| @@ -213,10 +213,14 @@ class WhereNode(tree.Node): | ||||
|             elif isinstance(child, tree.Node): | ||||
|                 self.relabel_aliases(change_map, child) | ||||
|             else: | ||||
|                 if isinstance(child[0], (list, tuple)): | ||||
|                     elt = list(child[0]) | ||||
|                     if elt[0] in change_map: | ||||
|                         elt[0] = change_map[elt[0]] | ||||
|                         node.children[pos] = (tuple(elt),) + child[1:] | ||||
|                 else: | ||||
|                     child[0].relabel_aliases(change_map) | ||||
|  | ||||
|                 # Check if the query value also requires relabelling | ||||
|                 if hasattr(child[3], 'relabel_aliases'): | ||||
|                     child[3].relabel_aliases(change_map) | ||||
|   | ||||
| @@ -284,9 +284,6 @@ two authors with the same name, their results will be merged into a single | ||||
| result in the output of the query; the average will be computed as the | ||||
| average over the books written by both authors. | ||||
|  | ||||
| The annotation name will be added to the fields returned | ||||
| as part of the ``ValuesQuerySet``. | ||||
|  | ||||
| Order of ``annotate()`` and ``values()`` clauses | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| @@ -303,12 +300,21 @@ output. | ||||
| For example, if we reverse the order of the ``values()`` and ``annotate()`` | ||||
| clause from our previous example:: | ||||
|  | ||||
|     >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name') | ||||
|     >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name', 'average_rating') | ||||
|  | ||||
| This will now yield one unique result for each author; however, only | ||||
| the author's name and the ``average_rating`` annotation will be returned | ||||
| in the output data. | ||||
|  | ||||
| You should also note that ``average_rating`` has been explicitly included | ||||
| in the list of values to be returned. This is required because of the | ||||
| ordering of the ``values()`` and ``annotate()`` clause. | ||||
|  | ||||
| If the ``values()`` clause precedes the ``annotate()`` clause, any annotations | ||||
| will be automatically added to the result set. However, if the ``values()`` | ||||
| clause is applied after the ``annotate()`` clause, you need to explicitly | ||||
| include the aggregate column. | ||||
|  | ||||
| Aggregating annotations | ||||
| ----------------------- | ||||
|  | ||||
|   | ||||
| @@ -207,10 +207,9 @@ u'The Definitive Guide to Django: Web Development Done Right' | ||||
| >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('pk', 'isbn', 'mean_age') | ||||
| [{'pk': 1, 'isbn': u'159059725', 'mean_age': 34.5}] | ||||
|  | ||||
| # Calling it with paramters reduces the output but does not remove the | ||||
| # annotation. | ||||
| # Calling values() with parameters reduces the output | ||||
| >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('name') | ||||
| [{'name': u'The Definitive Guide to Django: Web Development Done Right', 'mean_age': 34.5}] | ||||
| [{'name': u'The Definitive Guide to Django: Web Development Done Right'}] | ||||
|  | ||||
| # An empty values() call before annotating has the same effect as an | ||||
| # empty values() call after annotating | ||||
|   | ||||
| @@ -95,10 +95,18 @@ __test__ = {'API_TESTS': """ | ||||
| >>> sorted(Book.objects.all().values().annotate(mean_auth_age=Avg('authors__age')).extra(select={'manufacture_cost' : 'price * .5'}).get(pk=2).items()) | ||||
| [('contact_id', 3), ('id', 2), ('isbn', u'067232959'), ('manufacture_cost', ...11.545...), ('mean_auth_age', 45.0), ('name', u'Sams Teach Yourself Django in 24 Hours'), ('pages', 528), ('price', Decimal("23.09")), ('pubdate', datetime.date(2008, 3, 3)), ('publisher_id', 2), ('rating', 3.0)] | ||||
|  | ||||
| # A values query that selects specific columns reduces the output | ||||
| # If the annotation precedes the values clause, it won't be included | ||||
| # unless it is explicitly named | ||||
| >>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name').get(pk=1).items()) | ||||
| [('name', u'The Definitive Guide to Django: Web Development Done Right')] | ||||
|  | ||||
| >>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name','mean_auth_age').get(pk=1).items()) | ||||
| [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] | ||||
|  | ||||
| # If an annotation isn't included in the values, it can still be used in a filter | ||||
| >>> Book.objects.annotate(n_authors=Count('authors')).values('name').filter(n_authors__gt=2) | ||||
| [{'name': u'Python Web Development with Django'}] | ||||
|  | ||||
| # The annotations are added to values output if values() precedes annotate() | ||||
| >>> sorted(Book.objects.all().values('name').annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).get(pk=1).items()) | ||||
| [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')] | ||||
| @@ -207,6 +215,11 @@ FieldError: Cannot resolve keyword 'foo' into field. Choices are: authors, conta | ||||
| >>> Book.objects.extra(select={'pub':'publisher_id','foo':'pages'}).values('pub').annotate(Count('id')).order_by('pub') | ||||
| [{'pub': 1, 'id__count': 2}, {'pub': 2, 'id__count': 1}, {'pub': 3, 'id__count': 2}, {'pub': 4, 'id__count': 1}] | ||||
|  | ||||
| # Regression for #10182 - Queries with aggregate calls are correctly realiased when used in a subquery | ||||
| >>> ids = Book.objects.filter(pages__gt=100).annotate(n_authors=Count('authors')).filter(n_authors__gt=2).order_by('n_authors') | ||||
| >>> Book.objects.filter(id__in=ids) | ||||
| [<Book: Python Web Development with Django>] | ||||
|  | ||||
| # Regression for #10199 - Aggregate calls clone the original query so the original query can still be used | ||||
| >>> books = Book.objects.all() | ||||
| >>> _ = books.aggregate(Avg('authors__age')) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user