mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #24171 -- Fixed failure with complex aggregate query and expressions
The query used a construct of qs.annotate().values().aggregate() where the first annotate used an F-object reference and the values() and aggregate() calls referenced that F-object. Also made sure the inner query's select clause is as simple as possible, and made sure .values().distinct().aggreate() works correctly.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							63f2dd4ad7
						
					
				
				
					commit
					fb146193c4
				
			| @@ -351,7 +351,7 @@ class Query(object): | |||||||
|                 # is selected. |                 # is selected. | ||||||
|                 col_cnt += 1 |                 col_cnt += 1 | ||||||
|                 col_alias = '__col%d' % col_cnt |                 col_alias = '__col%d' % col_cnt | ||||||
|                 self.annotation_select[col_alias] = expr |                 self.annotations[col_alias] = expr | ||||||
|                 self.append_annotation_mask([col_alias]) |                 self.append_annotation_mask([col_alias]) | ||||||
|                 new_exprs.append(Ref(col_alias, expr)) |                 new_exprs.append(Ref(col_alias, expr)) | ||||||
|             else: |             else: | ||||||
| @@ -390,10 +390,22 @@ class Query(object): | |||||||
|             from django.db.models.sql.subqueries import AggregateQuery |             from django.db.models.sql.subqueries import AggregateQuery | ||||||
|             outer_query = AggregateQuery(self.model) |             outer_query = AggregateQuery(self.model) | ||||||
|             inner_query = self.clone() |             inner_query = self.clone() | ||||||
|             if not has_limit and not self.distinct_fields: |  | ||||||
|                 inner_query.clear_ordering(True) |  | ||||||
|             inner_query.select_for_update = False |             inner_query.select_for_update = False | ||||||
|             inner_query.select_related = False |             inner_query.select_related = False | ||||||
|  |             if not has_limit and not self.distinct_fields: | ||||||
|  |                 # Queries with distinct_fields need ordering and when a limit | ||||||
|  |                 # is applied we must take the slice from the ordered query. | ||||||
|  |                 # Otherwise no need for ordering. | ||||||
|  |                 inner_query.clear_ordering(True) | ||||||
|  |             if not inner_query.distinct: | ||||||
|  |                 # If the inner query uses default select and it has some | ||||||
|  |                 # aggregate annotations, then we must make sure the inner | ||||||
|  |                 # query is grouped by the main model's primary key. However, | ||||||
|  |                 # clearing the select clause can alter results if distinct is | ||||||
|  |                 # used. | ||||||
|  |                 if inner_query.default_cols and has_existing_annotations: | ||||||
|  |                     inner_query.group_by = [self.model._meta.pk.get_col(inner_query.get_initial_alias())] | ||||||
|  |                 inner_query.default_cols = False | ||||||
|  |  | ||||||
|             relabels = {t: 'subquery' for t in inner_query.tables} |             relabels = {t: 'subquery' for t in inner_query.tables} | ||||||
|             relabels[None] = 'subquery' |             relabels[None] = 'subquery' | ||||||
| @@ -404,7 +416,14 @@ class Query(object): | |||||||
|                 if expression.is_summary: |                 if expression.is_summary: | ||||||
|                     expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) |                     expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) | ||||||
|                     outer_query.annotations[alias] = expression.relabeled_clone(relabels) |                     outer_query.annotations[alias] = expression.relabeled_clone(relabels) | ||||||
|                     del inner_query.annotation_select[alias] |                     del inner_query.annotations[alias] | ||||||
|  |                 # Make sure the annotation_select wont use cached results. | ||||||
|  |                 inner_query.set_annotation_mask(inner_query.annotation_select_mask) | ||||||
|  |             if inner_query.select == [] and not inner_query.default_cols and not inner_query.annotation_select_mask: | ||||||
|  |                 # In case of Model.objects[0:3].count(), there would be no | ||||||
|  |                 # field selected in the inner query, yet we must use a subquery. | ||||||
|  |                 # So, make sure at least one field is selected. | ||||||
|  |                 inner_query.select = [self.model._meta.pk.get_col(inner_query.get_initial_alias())] | ||||||
|             try: |             try: | ||||||
|                 outer_query.add_subquery(inner_query, using) |                 outer_query.add_subquery(inner_query, using) | ||||||
|             except EmptyResultSet: |             except EmptyResultSet: | ||||||
|   | |||||||
| @@ -7,7 +7,9 @@ from operator import attrgetter | |||||||
|  |  | ||||||
| from django.contrib.contenttypes.models import ContentType | from django.contrib.contenttypes.models import ContentType | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db.models import F, Q, Avg, Count, Max, StdDev, Sum, Variance | from django.db.models import ( | ||||||
|  |     F, Q, Avg, Count, Max, StdDev, Sum, Value, Variance, | ||||||
|  | ) | ||||||
| from django.test import TestCase, skipUnlessDBFeature | from django.test import TestCase, skipUnlessDBFeature | ||||||
| from django.test.utils import Approximate | from django.test.utils import Approximate | ||||||
| from django.utils import six | from django.utils import six | ||||||
| @@ -1232,6 +1234,14 @@ class AggregationTests(TestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(qs['publisher_awards'], 30) |         self.assertEqual(qs['publisher_awards'], 30) | ||||||
|  |  | ||||||
|  |     def test_annotate_distinct_aggregate(self): | ||||||
|  |         # There are three books with rating of 4.0 and two of the books have | ||||||
|  |         # the same price. Hence, the distinct removes one rating of 4.0 | ||||||
|  |         # from the results. | ||||||
|  |         vals1 = Book.objects.values('rating', 'price').distinct().aggregate(result=Sum('rating')) | ||||||
|  |         vals2 = Book.objects.aggregate(result=Sum('rating') - Value(4.0)) | ||||||
|  |         self.assertEqual(vals1, vals2) | ||||||
|  |  | ||||||
|  |  | ||||||
| class JoinPromotionTests(TestCase): | class JoinPromotionTests(TestCase): | ||||||
|     def test_ticket_21150(self): |     def test_ticket_21150(self): | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ from django.utils.encoding import python_2_unicode_compatible | |||||||
| class Employee(models.Model): | class Employee(models.Model): | ||||||
|     firstname = models.CharField(max_length=50) |     firstname = models.CharField(max_length=50) | ||||||
|     lastname = models.CharField(max_length=50) |     lastname = models.CharField(max_length=50) | ||||||
|  |     salary = models.IntegerField(blank=True, null=True) | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return '%s %s' % (self.firstname, self.lastname) |         return '%s %s' % (self.firstname, self.lastname) | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import uuid | |||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
| from django.core.exceptions import FieldError | from django.core.exceptions import FieldError | ||||||
| from django.db import DatabaseError, connection, transaction | from django.db import DatabaseError, connection, models, transaction | ||||||
| from django.db.models import TimeField, UUIDField | from django.db.models import TimeField, UUIDField | ||||||
| from django.db.models.aggregates import ( | from django.db.models.aggregates import ( | ||||||
|     Avg, Count, Max, Min, StdDev, Sum, Variance, |     Avg, Count, Max, Min, StdDev, Sum, Variance, | ||||||
| @@ -30,15 +30,15 @@ class BasicExpressionsTests(TestCase): | |||||||
|     def setUpTestData(cls): |     def setUpTestData(cls): | ||||||
|         Company.objects.create( |         Company.objects.create( | ||||||
|             name="Example Inc.", num_employees=2300, num_chairs=5, |             name="Example Inc.", num_employees=2300, num_chairs=5, | ||||||
|             ceo=Employee.objects.create(firstname="Joe", lastname="Smith") |             ceo=Employee.objects.create(firstname="Joe", lastname="Smith", salary=10) | ||||||
|         ) |         ) | ||||||
|         Company.objects.create( |         Company.objects.create( | ||||||
|             name="Foobar Ltd.", num_employees=3, num_chairs=4, |             name="Foobar Ltd.", num_employees=3, num_chairs=4, | ||||||
|             ceo=Employee.objects.create(firstname="Frank", lastname="Meyer") |             ceo=Employee.objects.create(firstname="Frank", lastname="Meyer", salary=20) | ||||||
|         ) |         ) | ||||||
|         Company.objects.create( |         Company.objects.create( | ||||||
|             name="Test GmbH", num_employees=32, num_chairs=1, |             name="Test GmbH", num_employees=32, num_chairs=1, | ||||||
|             ceo=Employee.objects.create(firstname="Max", lastname="Mustermann") |             ceo=Employee.objects.create(firstname="Max", lastname="Mustermann", salary=30) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
| @@ -48,6 +48,15 @@ class BasicExpressionsTests(TestCase): | |||||||
|             "name", "num_employees", "num_chairs" |             "name", "num_employees", "num_chairs" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_annotate_values_aggregate(self): | ||||||
|  |         companies = Company.objects.annotate( | ||||||
|  |             salaries=F('ceo__salary'), | ||||||
|  |         ).values('num_employees', 'salaries').aggregate( | ||||||
|  |             result=Sum(F('salaries') + F('num_employees'), | ||||||
|  |             output_field=models.IntegerField()), | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(companies['result'], 2395) | ||||||
|  |  | ||||||
|     def test_filter_inter_attribute(self): |     def test_filter_inter_attribute(self): | ||||||
|         # We can filter on attribute relationships on same model obj, e.g. |         # We can filter on attribute relationships on same model obj, e.g. | ||||||
|         # find companies where the number of employees is greater |         # find companies where the number of employees is greater | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user