
Fixed #3566 -- Added support for aggregation to the ORM. See the documentation for details on usage.

Many thanks to:
 * Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
 * Alex Gaynor for his help debugging and fixing a number of issues.
 * Justin Bronn for his help integrating with contrib.gis.
 * Karen Tracey for her help with cross-platform testing.
 * Ian Kelly for his help testing and fixing Oracle support.
 * Malcolm Tredinnick for his invaluable review notes.
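
A rough taste of the new API (model names are hypothetical; see the
aggregation documentation for the authoritative examples):

    from django.db.models import Avg, Count

    # A summary aggregate over the whole queryset returns a dict keyed
    # by the generated alias.
    Book.objects.aggregate(Avg('price'))    # {'price__avg': ...}

    # A per-object annotation attaches a computed value to each result,
    # which can then be filtered and ordered on.
    Publisher.objects.annotate(num_books=Count('book')).filter(num_books__gt=1)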

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9742 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Author: Russell Keith-Magee
Date: 2009-01-15 11:06:34 +00:00
parent 50a293a0c3
commit cc4e4d9aee
30 changed files with 2357 additions and 325 deletions

django/db/models/sql/query.py

@@ -12,12 +12,13 @@ from copy import deepcopy
from django.utils.tree import Node
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode
from django.db.backends.util import truncate_name
from django.db import connection
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import select_related_descend
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
from django.db.models.sql.datastructures import Count
from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin
from constants import *
@@ -40,6 +41,7 @@ class BaseQuery(object):
alias_prefix = 'T'
query_terms = QUERY_TERMS
aggregates_module = base_aggregates_module
def __init__(self, model, connection, where=WhereNode):
self.model = model
@@ -73,6 +75,9 @@ class BaseQuery(object):
self.select_related = False
self.related_select_cols = []
# SQL aggregate-related attributes
self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
# Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related().
self.max_depth = 5
@@ -178,6 +183,7 @@ class BaseQuery(object):
obj.distinct = self.distinct
obj.select_related = self.select_related
obj.related_select_cols = []
obj.aggregate_select = self.aggregate_select.copy()
obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy()
obj.extra_tables = self.extra_tables
@@ -194,6 +200,35 @@ class BaseQuery(object):
obj._setup_query()
return obj
def convert_values(self, value, field):
"""Convert the database-returned value into a type that is consistent
across database backends.
By default, this defers to the underlying backend operations, but
it can be overridden by Query classes for specific backends.
"""
return self.connection.ops.convert_values(value, field)
def resolve_aggregate(self, value, aggregate):
"""Resolve the value of aggregates returned by the database to
consistent (and reasonable) types.
This is required because of the predisposition of certain backends
to return Decimal and long types when they are not needed.
"""
if value is None:
# Return None as-is
return value
elif aggregate.is_ordinal:
# Any ordinal aggregate (e.g., count) returns an int
return int(value)
elif aggregate.is_computed:
# Any computed aggregate (e.g., avg) returns a float
return float(value)
else:
# Return value depends on the type of the field being processed.
return self.convert_values(value, aggregate.field)
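# For example (a sketch, assuming the stock Count and Avg SQL aggregates):
# a COUNT returned by the backend as a long or Decimal is coerced to a
# plain int (is_ordinal), an AVG returned as Decimal('2.5') becomes the
# float 2.5 (is_computed), and something like a MAX over a DateField
# falls through to convert_values() so each backend can normalize it.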
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
@@ -212,29 +247,78 @@ class BaseQuery(object):
else:
fields = self.model._meta.fields
row = self.resolve_columns(row, fields)
if self.aggregate_select:
aggregate_start = len(self.extra_select.keys()) + len(self.select)
row = tuple(row[:aggregate_start]) + tuple([
self.resolve_aggregate(value, aggregate)
for (alias, aggregate), value
in zip(self.aggregate_select.items(), row[aggregate_start:])
])
yield row
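# At this point each raw row is laid out as [extra_select columns,
# selected columns, aggregate columns], so the aggregate_start computed
# above indexes the first aggregate value in the row.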
def get_aggregation(self):
"""
Returns the dictionary with the values of the existing aggregations.
"""
if not self.aggregate_select:
return {}
# If there is a group by clause, aggregating does not add useful
# information but retrieves only the first row. Aggregate
# over the subquery instead.
if self.group_by:
from subqueries import AggregateQuery
query = AggregateQuery(self.model, self.connection)
obj = self.clone()
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
for alias, aggregate in self.aggregate_select.items():
if aggregate.is_summary:
query.aggregate_select[alias] = aggregate
del obj.aggregate_select[alias]
query.add_subquery(obj)
else:
query = self
self.select = []
self.default_cols = False
self.extra_select = {}
query.clear_ordering(True)
query.clear_limits()
query.select_related = False
query.related_select_cols = []
query.related_select_fields = []
return dict([
(alias, self.resolve_aggregate(val, aggregate))
for (alias, aggregate), val
in zip(query.aggregate_select.items(), query.execute_sql(SINGLE))
])
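# The grouped case wraps the original query, so the SQL takes roughly
# this shape (a sketch, with hypothetical table and column names):
#
#   SELECT MAX(num_books)
#   FROM (
#       SELECT publisher_id, COUNT(book.id) AS num_books
#       FROM book GROUP BY publisher_id
#   ) subquery
#
# Only aggregates marked is_summary are hoisted into the outer
# AggregateQuery; the per-group annotations stay in the subquery.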
def get_count(self):
"""
Performs a COUNT() query using the current filter constraints.
"""
from subqueries import CountQuery
obj = self.clone()
obj.clear_ordering(True)
obj.clear_limits()
obj.select_related = False
obj.related_select_cols = []
obj.related_select_fields = []
if len(obj.select) > 1:
obj = self.clone(CountQuery, _query=obj, where=self.where_class(),
distinct=False)
obj.select = []
obj.extra_select = SortedDict()
if len(self.select) > 1:
# If a select clause exists, then the query has already started to
# specify the columns that are to be returned.
# In this case, we need to use a subquery to evaluate the count.
from subqueries import AggregateQuery
subquery = obj
subquery.clear_ordering(True)
subquery.clear_limits()
obj = AggregateQuery(obj.model, obj.connection)
obj.add_subquery(subquery)
obj.add_count_column()
data = obj.execute_sql(SINGLE)
if not data:
return 0
number = data[0]
number = obj.get_aggregation()[None]
# Apply offset and limit constraints manually, since using LIMIT/OFFSET
# in SQL (in variants that provide them) doesn't change the COUNT
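# The adjustment itself falls just below this hunk; in outline, the raw
# count is clamped into the sliced window:
#   number = max(0, number - self.low_mark)
#   if self.high_mark is not None:
#       number = min(number, self.high_mark - self.low_mark)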
@@ -450,25 +534,41 @@ class BaseQuery(object):
for col in self.select:
if isinstance(col, (list, tuple)):
r = '%s.%s' % (qn(col[0]), qn(col[1]))
if with_aliases and col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
if with_aliases:
if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append('%s AS %s' % (r, col[1]))
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(r)
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(quote_func=qn))
if hasattr(col, 'alias'):
aliases.add(col.alias)
col_aliases.add(col.alias)
elif self.default_cols:
cols, new_aliases = self.get_default_columns(with_aliases,
col_aliases)
result.extend(cols)
aliases.update(new_aliases)
result.extend([
'%s%s' % (
aggregate.as_sql(quote_func=qn),
alias is not None and ' AS %s' % qn(alias) or ''
)
for alias, aggregate in self.aggregate_select.items()
])
for table, col in self.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
if with_aliases and col in col_aliases:
@@ -538,7 +638,7 @@ class BaseQuery(object):
Returns a list of strings that are joined together to go after the
"FROM" part of the query, as well as a list of any extra parameters that
need to be included. Subclasses can override this to create a
from-clause via a "select", for example (e.g. CountQuery).
from-clause via a "select".
This should only be called after any SQL construction methods that
might change the tables we need. This means the select columns and
@@ -635,10 +735,13 @@ class BaseQuery(object):
order = asc
result.append('%s %s' % (field, order))
continue
col, order = get_order_dir(field, asc)
if col in self.aggregate_select:
result.append('%s %s' % (col, order))
continue
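# This is what allows an annotation alias in order_by(), e.g.
# (hypothetical models):
#   Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')
# The alias is emitted into ORDER BY as-is rather than resolved as a field.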
if '.' in field:
# This came in through an extra(order_by=...) addition. Pass it
# on verbatim.
col, order = get_order_dir(field, asc)
table, col = col.split('.', 1)
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), col)
@@ -657,7 +760,6 @@ class BaseQuery(object):
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
else:
col, order = get_order_dir(field, asc)
elt = qn2(col)
if distinct and col not in select_aliases:
ordering_aliases.append(elt)
@@ -1068,6 +1170,48 @@ class BaseQuery(object):
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid)
def add_aggregate(self, aggregate, model, alias, is_summary):
"""
Adds a single aggregate expression to the Query
"""
opts = model._meta
field_list = aggregate.lookup.split(LOOKUP_SEP)
if (len(field_list) == 1 and
aggregate.lookup in self.aggregate_select.keys()):
# Aggregate is over an annotation
field_name = field_list[0]
col = field_name
source = self.aggregate_select[field_name]
elif (len(field_list) > 1 or
field_list[0] not in [i.name for i in opts.fields]):
field, source, opts, join_list, last, _ = self.setup_joins(
field_list, opts, self.get_initial_alias(), False)
# Process the join chain to see if it can be trimmed
_, _, col, _, join_list = self.trim_joins(source, join_list, last, False)
# If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned
# in order for zeros to be returned for those aggregates.
for column_alias in join_list:
self.promote_alias(column_alias, unconditional=True)
col = (join_list[-1], col)
else:
# Aggregate references a normal field
field_name = field_list[0]
source = opts.get_field(field_name)
if not (self.group_by and is_summary):
# Only use a column alias if this is a
# standalone aggregate, or an annotation
col = (opts.db_table, source.column)
else:
col = field_name
# Add the aggregate to the query
alias = truncate_name(alias, self.connection.ops.max_name_length())
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
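# The three branches above map onto calls such as (hypothetical models):
#   Book.objects.aggregate(Avg('price'))                  # plain local field
#   Store.objects.aggregate(Max('books__price'))          # field behind a join
#   Publisher.objects.annotate(n=Count('book')).aggregate(Avg('n'))  # over an annotation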
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None, process_extras=True):
"""
@@ -1119,6 +1263,11 @@ class BaseQuery(object):
elif callable(value):
value = value()
for alias, aggregate in self.aggregate_select.items():
if alias == parts[0]:
self.having.add((aggregate, lookup_type, value), AND)
return
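# This is the branch that turns a filter on an annotation alias into a
# HAVING clause, e.g. (hypothetical models):
#   Publisher.objects.annotate(num_books=Count('book')).filter(num_books__gt=1)
# emits ... HAVING COUNT(book.id) > 1 instead of a WHERE condition.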
opts = self.get_meta()
alias = self.get_initial_alias()
allow_many = trim or not negate
@@ -1131,38 +1280,9 @@ class BaseQuery(object):
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse)
return
final = len(join_list)
penultimate = last.pop()
if penultimate == final:
penultimate = last.pop()
if trim and len(join_list) > 1:
extra = join_list[penultimate:]
join_list = join_list[:penultimate]
final = penultimate
penultimate = last.pop()
col = self.alias_map[extra[0]][LHS_JOIN_COL]
for alias in extra:
self.unref_alias(alias)
else:
col = target.column
alias = join_list[-1]
while final > 1:
# An optimization: if the final join is against the same column as
# we are comparing against, we can go back one step in the join
# chain and compare against the lhs of the join instead (and then
# repeat the optimization). The result, potentially, involves fewer
# table joins.
join = self.alias_map[alias]
if col != join[RHS_JOIN_COL]:
break
self.unref_alias(alias)
alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL]
join_list = join_list[:-1]
final -= 1
if final == penultimate:
penultimate = last.pop()
# Process the join chain to see if it can be trimmed
final, penultimate, col, alias, join_list = self.trim_joins(target, join_list, last, trim)
if (lookup_type == 'isnull' and value is True and not negate and
final > 1):
@@ -1313,7 +1433,7 @@ class BaseQuery(object):
field, model, direct, m2m = opts.get_field_by_name(f.name)
break
else:
names = opts.get_all_field_names()
names = opts.get_all_field_names() + self.aggregate_select.keys()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@@ -1462,6 +1582,43 @@ class BaseQuery(object):
return field, target, opts, joins, last, extra_filters
def trim_joins(self, target, join_list, last, trim):
"""An optimization: if the final join is against the same column as
we are comparing against, we can go back one step in a join
chain and compare against the LHS of the join instead (and then
repeat the optimization). The result, potentially, involves fewer
table joins.
Returns a tuple of (final, penultimate, col, alias, join_list).
"""
final = len(join_list)
penultimate = last.pop()
if penultimate == final:
penultimate = last.pop()
if trim and len(join_list) > 1:
extra = join_list[penultimate:]
join_list = join_list[:penultimate]
final = penultimate
penultimate = last.pop()
col = self.alias_map[extra[0]][LHS_JOIN_COL]
for alias in extra:
self.unref_alias(alias)
else:
col = target.column
alias = join_list[-1]
while final > 1:
join = self.alias_map[alias]
if col != join[RHS_JOIN_COL]:
break
self.unref_alias(alias)
alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL]
join_list = join_list[:-1]
final -= 1
if final == penultimate:
penultimate = last.pop()
return final, penultimate, col, alias, join_list
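# Example of the effect (hypothetical models): filtering with
# Book.objects.filter(publisher__pk=3) compares against publisher.id,
# which is the RHS column of the book -> publisher join, so the join can
# be stepped back to its LHS and the condition becomes
# WHERE book.publisher_id = 3 with no JOIN at all.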
def update_dupe_avoidance(self, opts, col, alias):
"""
For a column that is one of multiple pointing to the same table, update
@@ -1554,6 +1711,7 @@ class BaseQuery(object):
"""
alias = self.get_initial_alias()
opts = self.get_meta()
try:
for name in field_names:
field, target, u2, joins, u3, u4 = self.setup_joins(
@@ -1574,7 +1732,7 @@ class BaseQuery(object):
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
names = opts.get_all_field_names() + self.extra_select.keys()
names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
@@ -1609,38 +1767,52 @@ class BaseQuery(object):
if force_empty:
self.default_ordering = False
def set_group_by(self):
"""
Expands the GROUP BY clause required by the query.
This will usually be the set of all non-aggregate fields in the
return data. If the database backend supports grouping by the
primary key, and the query would be equivalent, the optimization
will be made automatically.
"""
if self.connection.features.allows_group_by_pk:
if len(self.select) == len(self.model._meta.fields):
self.group_by.append('.'.join([self.model._meta.db_table,
self.model._meta.pk.column]))
return
for sel in self.select:
self.group_by.append(sel)
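# On backends where connection.features.allows_group_by_pk is true
# (MySQL, for instance), an annotated query selecting a full model can
# emit GROUP BY book.id alone instead of repeating every selected column
# in the GROUP BY clause.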
def add_count_column(self):
"""
Converts the query to do count(...) or count(distinct(pk)) in order to
get its size.
"""
# TODO: When group_by support is added, this needs to be adjusted so
# that it doesn't totally overwrite the select list.
if not self.distinct:
if not self.select:
select = Count()
count = self.aggregates_module.Count('*', is_summary=True)
else:
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select
select = Count(self.select[0])
count = self.aggregates_module.Count(self.select[0])
else:
opts = self.model._meta
if not self.select:
select = Count((self.join((None, opts.db_table, None, None)),
opts.pk.column), True)
count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
is_summary=True, distinct=True)
else:
# Because of SQL portability issues, multi-column, distinct
# counts need a sub-query -- see get_count() for details.
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'."
select = Count(self.select[0], True)
count = self.aggregates_module.Count(self.select[0], distinct=True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
self.select = [select]
self.select_fields = [None]
self.extra_select = {}
self.aggregate_select = {None: count}
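# Net effect (a sketch): qs.count() now runs through the aggregation
# machinery as SELECT COUNT(*) ..., and qs.distinct().count() as
# SELECT COUNT(DISTINCT book.id) ...; the count is stored under the
# anonymous alias None, which is why get_count() above reads
# obj.get_aggregation()[None].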
def add_select_related(self, fields):
"""
@@ -1758,7 +1930,6 @@ class BaseQuery(object):
return empty_iter()
else:
return
cursor = self.connection.cursor()
cursor.execute(sql, params)