
Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.

This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Russell Keith-Magee
2009-02-23 14:47:59 +00:00
parent 4bd24474c0
commit 542709d0d1
7 changed files with 102 additions and 32 deletions
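To make the behaviour concrete, here is a hedged sketch of the two query patterns the commit message refers to, using the Book/Publisher models from the aggregation docs purely as an illustration (none of these names come from this changeset):

    from django.db.models import Count

    # An aggregate query used as a subquery: publishers annotated with a book
    # count, filtered on that annotation, then pushed into an __in clause.
    busy = Publisher.objects.annotate(num_books=Count('book')).filter(num_books__gt=1)
    Book.objects.filter(publisher__in=busy.values('pk'))

    # values() combined with annotate(): listing columns in values() before
    # annotate() makes those columns the grouping key for the annotation.
    Book.objects.values('publisher').annotate(num_books=Count('id'))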

django/db/models/sql/query.py

@@ -77,7 +77,9 @@ class BaseQuery(object):
         self.related_select_cols = []
 
         # SQL aggregate-related attributes
-        self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
+        self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
+        self.aggregate_select_mask = None
+        self._aggregate_select_cache = None
 
         # Arbitrary maximum limit for select_related. Prevents infinite
         # recursion. Can be changed by the depth parameter to select_related().
@@ -187,7 +189,15 @@ class BaseQuery(object):
         obj.distinct = self.distinct
         obj.select_related = self.select_related
         obj.related_select_cols = []
-        obj.aggregate_select = self.aggregate_select.copy()
+        obj.aggregates = self.aggregates.copy()
+        if self.aggregate_select_mask is None:
+            obj.aggregate_select_mask = None
+        else:
+            obj.aggregate_select_mask = self.aggregate_select_mask[:]
+        if self._aggregate_select_cache is None:
+            obj._aggregate_select_cache = None
+        else:
+            obj._aggregate_select_cache = self._aggregate_select_cache.copy()
         obj.max_depth = self.max_depth
         obj.extra_select = self.extra_select.copy()
         obj.extra_tables = self.extra_tables
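The three attributes being cloned above work together: aggregates holds every aggregate expression the query knows about, aggregate_select_mask records which aliases should actually appear in the SELECT clause, and _aggregate_select_cache memoizes the filtered mapping (the property that consumes them is added near the end of this patch). A rough usage sketch, with an illustrative alias name and query standing in for the sql Query instance behind a queryset:

    # All aggregates stay registered on the query; the mask only controls
    # which of them the aggregate_select property exposes to the SELECT list.
    query.set_aggregate_mask(['num_books'])
    query.aggregate_select      # only the masked subset (result is cached)
    query.set_aggregate_mask(None)
    query.aggregate_select      # no mask: same mapping as query.aggregates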
@@ -940,14 +950,17 @@ class BaseQuery(object):
"""
assert set(change_map.keys()).intersection(set(change_map.values())) == set()
# 1. Update references in "select" and "where".
# 1. Update references in "select" (normal columns plus aliases),
# "group by", "where" and "having".
self.where.relabel_aliases(change_map)
for pos, col in enumerate(self.select):
if isinstance(col, (list, tuple)):
old_alias = col[0]
self.select[pos] = (change_map.get(old_alias, old_alias), col[1])
else:
col.relabel_aliases(change_map)
self.having.relabel_aliases(change_map)
for columns in (self.select, self.aggregates.values(), self.group_by or []):
for pos, col in enumerate(columns):
if isinstance(col, (list, tuple)):
old_alias = col[0]
columns[pos] = (change_map.get(old_alias, old_alias), col[1])
else:
col.relabel_aliases(change_map)
# 2. Rename the alias in the internal table/alias datastructures.
for old_alias, new_alias in change_map.iteritems():
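The point of widening this loop: when the query is embedded as a subquery, its table aliases get rewritten, and the rewrite now also has to reach aggregate columns and the GROUP BY list, not just the plain SELECT columns. A minimal sketch of the relabelling pattern for the tuple case (the alias names and change_map here are invented for the example):

    change_map = {'T2': 'T4'}

    # Plain columns are (table_alias, column) pairs, so renaming rebuilds the
    # tuple; aggregate expressions instead know how to relabel themselves,
    # which is why the loop above checks isinstance(col, (list, tuple)) first.
    columns = [('T2', 'name'), ('T2', 'id')]
    columns = [(change_map.get(alias, alias), col) for alias, col in columns]
    # -> [('T4', 'name'), ('T4', 'id')]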
@@ -1205,11 +1218,11 @@ class BaseQuery(object):
         opts = model._meta
         field_list = aggregate.lookup.split(LOOKUP_SEP)
         if (len(field_list) == 1 and
-            aggregate.lookup in self.aggregate_select.keys()):
+            aggregate.lookup in self.aggregates.keys()):
             # Aggregate is over an annotation
             field_name = field_list[0]
             col = field_name
-            source = self.aggregate_select[field_name]
+            source = self.aggregates[field_name]
         elif (len(field_list) > 1 or
             field_list[0] not in [i.name for i in opts.fields]):
             field, source, opts, join_list, last, _ = self.setup_joins(
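This branch resolves an aggregate over an existing annotation alias rather than over a model field; looking the alias up in self.aggregates (the full mapping) instead of the masked aggregate_select keeps it resolvable even when a values() call has hidden it from the SELECT list. The classic example from the aggregation docs (model names are illustrative here):

    from django.db.models import Avg, Count

    # 'num_authors' is an annotation alias, not a field, so the Avg() lookup
    # goes through the branch above.
    Book.objects.annotate(num_authors=Count('authors')).aggregate(Avg('num_authors'))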
@@ -1299,7 +1312,7 @@ class BaseQuery(object):
             value = SQLEvaluator(value, self)
             having_clause = value.contains_aggregate
 
-        for alias, aggregate in self.aggregate_select.items():
+        for alias, aggregate in self.aggregates.items():
             if alias == parts[0]:
                 entry = self.where_class()
                 entry.add((aggregate, lookup_type, value), AND)
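For context, this is the path that routes a filter on an annotation alias into the HAVING clause rather than WHERE; only the dictionary being consulted changes in this hunk. A query of roughly this shape reaches it (again with illustrative models):

    from django.db.models import Count

    # Filtering on the annotation alias is matched against self.aggregates and
    # ends up as HAVING COUNT(...) > 1 in the generated SQL.
    Publisher.objects.annotate(num_books=Count('book')).filter(num_books__gt=1)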
@@ -1824,8 +1837,8 @@ class BaseQuery(object):
         self.group_by = []
         if self.connection.features.allows_group_by_pk:
             if len(self.select) == len(self.model._meta.fields):
-                self.group_by.append('.'.join([self.model._meta.db_table,
-                    self.model._meta.pk.column]))
+                self.group_by.append((self.model._meta.db_table,
+                    self.model._meta.pk.column))
                 return
 
         for sel in self.select:
@@ -1858,7 +1871,11 @@ class BaseQuery(object):
             # Distinct handling is done in Count(), so don't do it at this
             # level.
             self.distinct = False
-        self.aggregate_select = {None: count}
+
+        # Set only aggregate to be the count column.
+        # Clear out the select cache to reflect the new unmasked aggregates.
+        self.aggregates = {None: count}
+        self.set_aggregate_mask(None)
 
     def add_select_related(self, fields):
         """
@@ -1920,6 +1937,29 @@ class BaseQuery(object):
         for key in set(self.extra_select).difference(set(names)):
             del self.extra_select[key]
 
+    def set_aggregate_mask(self, names):
+        "Set the mask of aggregates that will actually be returned by the SELECT"
+        self.aggregate_select_mask = names
+        self._aggregate_select_cache = None
+
+    def _aggregate_select(self):
+        """The SortedDict of aggregate columns that are not masked, and should
+        be used in the SELECT clause.
+
+        This result is cached for optimization purposes.
+        """
+        if self._aggregate_select_cache is not None:
+            return self._aggregate_select_cache
+        elif self.aggregate_select_mask is not None:
+            self._aggregate_select_cache = SortedDict([
+                (k,v) for k,v in self.aggregates.items()
+                if k in self.aggregate_select_mask
+            ])
+            return self._aggregate_select_cache
+        else:
+            return self.aggregates
+    aggregate_select = property(_aggregate_select)
+
     def set_start(self, start):
         """
         Sets the table from which to start joining. The start position is