1
0
mirror of https://github.com/django/django.git synced 2025-10-24 14:16:09 +00:00

Fixed #35444 -- Added generic support for Aggregate.order_by.

This moves the behaviors of `order_by` used in Postgres aggregates into
the `Aggregate` class. This allows for creating aggregate functions that
support this behavior across all database engines. This is shown by
moving the `StringAgg` class into the shared `aggregates` module and
adding support for all databases. The Postgres `StringAgg` class is now
a thin wrapper on the new shared `StringAgg` class.

Thank you Simon Charette for the review.
This commit is contained in:
Chris Muthig
2024-12-22 16:30:55 +01:00
committed by Sarah Boyce
parent 6d1cf5375f
commit 4b977a5d72
19 changed files with 659 additions and 291 deletions

View File

@@ -33,15 +33,14 @@ class GeoAggregate(Aggregate):
if not self.is_extent: if not self.is_extent:
tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05) tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
clone = self.copy() clone = self.copy()
source_expressions = self.get_source_expressions() *source_exprs, filter_expr, order_by_expr = self.get_source_expressions()
source_expressions.pop() # Don't wrap filters with SDOAGGRTYPE().
spatial_type_expr = Func( spatial_type_expr = Func(
*source_expressions, *source_exprs,
Value(tolerance), Value(tolerance),
function="SDOAGGRTYPE", function="SDOAGGRTYPE",
output_field=self.output_field, output_field=self.output_field,
) )
source_expressions = [spatial_type_expr, self.filter] source_expressions = [spatial_type_expr, filter_expr, order_by_expr]
clone.set_source_expressions(source_expressions) clone.set_source_expressions(source_expressions)
return clone.as_sql(compiler, connection, **extra_context) return clone.as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context) return self.as_sql(compiler, connection, **extra_context)

View File

@@ -1,7 +1,12 @@
from django.contrib.postgres.fields import ArrayField import warnings
from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value
from .mixins import OrderableAggMixin from django.contrib.postgres.fields import ArrayField
from django.db.models import Aggregate, BooleanField, JSONField
from django.db.models import StringAgg as _StringAgg
from django.db.models import Value
from django.utils.deprecation import RemovedInDjango70Warning
from .mixins import _DeprecatedOrdering
__all__ = [ __all__ = [
"ArrayAgg", "ArrayAgg",
@@ -11,14 +16,16 @@ __all__ = [
"BoolAnd", "BoolAnd",
"BoolOr", "BoolOr",
"JSONBAgg", "JSONBAgg",
"StringAgg", "StringAgg", # RemovedInDjango70Warning.
] ]
class ArrayAgg(OrderableAggMixin, Aggregate): # RemovedInDjango61Warning: When the deprecation ends, replace with:
# class ArrayAgg(Aggregate):
class ArrayAgg(_DeprecatedOrdering, Aggregate):
function = "ARRAY_AGG" function = "ARRAY_AGG"
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
allow_distinct = True allow_distinct = True
allow_order_by = True
@property @property
def output_field(self): def output_field(self):
@@ -47,19 +54,37 @@ class BoolOr(Aggregate):
output_field = BooleanField() output_field = BooleanField()
class JSONBAgg(OrderableAggMixin, Aggregate): # RemovedInDjango61Warning: When the deprecation ends, replace with:
# class JSONBAgg(Aggregate):
class JSONBAgg(_DeprecatedOrdering, Aggregate):
function = "JSONB_AGG" function = "JSONB_AGG"
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
allow_distinct = True allow_distinct = True
allow_order_by = True
output_field = JSONField() output_field = JSONField()
class StringAgg(OrderableAggMixin, Aggregate): # RemovedInDjango61Warning: When the deprecation ends, replace with:
function = "STRING_AGG" # class StringAgg(_StringAgg):
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)" # RemovedInDjango70Warning: When the deprecation ends, remove completely.
allow_distinct = True class StringAgg(_DeprecatedOrdering, _StringAgg):
output_field = TextField()
def __init__(self, expression, delimiter, **extra): def __init__(self, expression, delimiter, **extra):
delimiter_expr = Value(str(delimiter)) if isinstance(delimiter, str):
super().__init__(expression, delimiter_expr, **extra) warnings.warn(
"delimiter: str will be resolved as a field reference instead "
"of a string literal on Django 7.0. Pass "
f"`delimiter=Value({delimiter!r})` to preserve the previous behaviour.",
category=RemovedInDjango70Warning,
stacklevel=2,
)
delimiter = Value(delimiter)
warnings.warn(
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead.",
category=RemovedInDjango70Warning,
stacklevel=2,
)
super().__init__(expression, delimiter, **extra)

View File

@@ -1,15 +1,11 @@
import warnings import warnings
from django.core.exceptions import FullResultSet
from django.db.models.expressions import OrderByList
from django.utils.deprecation import RemovedInDjango61Warning from django.utils.deprecation import RemovedInDjango61Warning
class OrderableAggMixin: # RemovedInDjango61Warning.
# RemovedInDjango61Warning: When the deprecation ends, replace with: class _DeprecatedOrdering:
# def __init__(self, *expressions, order_by=(), **extra):
def __init__(self, *expressions, ordering=(), order_by=(), **extra): def __init__(self, *expressions, ordering=(), order_by=(), **extra):
# RemovedInDjango61Warning.
if ordering: if ordering:
warnings.warn( warnings.warn(
"The ordering argument is deprecated. Use order_by instead.", "The ordering argument is deprecated. Use order_by instead.",
@@ -19,44 +15,14 @@ class OrderableAggMixin:
if order_by: if order_by:
raise TypeError("Cannot specify both order_by and ordering.") raise TypeError("Cannot specify both order_by and ordering.")
order_by = ordering order_by = ordering
if not order_by:
self.order_by = None
elif isinstance(order_by, (list, tuple)):
self.order_by = OrderByList(*order_by)
else:
self.order_by = OrderByList(order_by)
super().__init__(*expressions, **extra)
def resolve_expression(self, *args, **kwargs): super().__init__(*expressions, order_by=order_by, **extra)
if self.order_by is not None:
self.order_by = self.order_by.resolve_expression(*args, **kwargs)
return super().resolve_expression(*args, **kwargs)
def get_source_expressions(self):
return super().get_source_expressions() + [self.order_by]
def set_source_expressions(self, exprs): # RemovedInDjango61Warning: When the deprecation ends, replace with:
*exprs, self.order_by = exprs # class OrderableAggMixin:
return super().set_source_expressions(exprs) class OrderableAggMixin(_DeprecatedOrdering):
allow_order_by = True
def as_sql(self, compiler, connection): def __init_subclass__(cls, /, *args, **kwargs):
*source_exprs, filtering_expr, order_by_expr = self.get_source_expressions() super().__init_subclass__(*args, **kwargs)
order_by_sql = ""
order_by_params = []
if order_by_expr is not None:
order_by_sql, order_by_params = compiler.compile(order_by_expr)
filter_params = []
if filtering_expr is not None:
try:
_, filter_params = compiler.compile(filtering_expr)
except FullResultSet:
pass
source_params = []
for source_expr in source_exprs:
source_params += compiler.compile(source_expr)[1]
sql, _ = super().as_sql(compiler, connection, order_by=order_by_sql)
return sql, (*source_params, *order_by_params, *filter_params)

View File

@@ -257,6 +257,15 @@ class BaseDatabaseFeatures:
# expressions? # expressions?
supports_aggregate_filter_clause = False supports_aggregate_filter_clause = False
# Does the database support ORDER BY in aggregate expressions?
supports_aggregate_order_by_clause = False
# Does the database backend support DISTINCT when using multiple arguments in an
# aggregate expression? For example, Sqlite treats the "delimiter" argument of
# STRING_AGG/GROUP_CONCAT as an extra argument and does not allow using a custom
# delimiter along with DISTINCT.
supports_aggregate_distinct_multiple_argument = True
# Does the backend support indexing a TextField? # Does the backend support indexing a TextField?
supports_index_on_text_field = True supports_index_on_text_field = True

View File

@@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_explicit_null_ordering_when_grouping = True requires_explicit_null_ordering_when_grouping = True
atomic_transactions = False atomic_transactions = False
can_clone_databases = True can_clone_databases = True
supports_aggregate_order_by_clause = True
supports_comments = True supports_comments = True
supports_comments_inline = True supports_comments_inline = True
supports_temporal_subtraction = True supports_temporal_subtraction = True

View File

@@ -45,6 +45,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# does by uppercasing all identifiers. # does by uppercasing all identifiers.
ignores_table_name_case = True ignores_table_name_case = True
supports_index_on_text_field = False supports_index_on_text_field = False
supports_aggregate_order_by_clause = True
create_test_procedure_without_params_sql = """ create_test_procedure_without_params_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" AS CREATE PROCEDURE "TEST_PROCEDURE" AS
V_I INTEGER; V_I INTEGER;

View File

@@ -64,6 +64,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_exclusion = True supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True supports_aggregate_filter_clause = True
supports_aggregate_order_by_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"} supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
supports_deferrable_unique_constraints = True supports_deferrable_unique_constraints = True
has_json_operators = True has_json_operators = True

View File

@@ -34,6 +34,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_range_fixed_distance = True supports_frame_range_fixed_distance = True
supports_frame_exclusion = True supports_frame_exclusion = True
supports_aggregate_filter_clause = True supports_aggregate_filter_clause = True
supports_aggregate_order_by_clause = Database.sqlite_version_info >= (3, 44, 0)
supports_aggregate_distinct_multiple_argument = False
order_by_nulls_first = True order_by_nulls_first = True
supports_json_field_contains = False supports_json_field_contains = False
supports_update_conflicts = True supports_update_conflicts = True

View File

@@ -3,8 +3,17 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError, FullResultSet from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When from django.db import NotSupportedError
from django.db.models.fields import IntegerField from django.db.models.expressions import (
Case,
ColPairs,
Func,
OrderByList,
Star,
Value,
When,
)
from django.db.models.fields import IntegerField, TextField
from django.db.models.functions import Coalesce from django.db.models.functions import Coalesce
from django.db.models.functions.mixins import ( from django.db.models.functions.mixins import (
FixDurationInputMixin, FixDurationInputMixin,
@@ -18,42 +27,91 @@ __all__ = [
"Max", "Max",
"Min", "Min",
"StdDev", "StdDev",
"StringAgg",
"Sum", "Sum",
"Variance", "Variance",
] ]
class AggregateFilter(Func):
arity = 1
template = " FILTER (WHERE %(expressions)s)"
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_aggregate_filter_clause:
raise NotSupportedError(
"Aggregate filter clauses are not supported on this database backend."
)
try:
return super().as_sql(compiler, connection, **extra_context)
except FullResultSet:
return "", ()
@property
def condition(self):
return self.source_expressions[0]
def __str__(self):
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
class AggregateOrderBy(OrderByList):
template = " ORDER BY %(expressions)s"
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_aggregate_order_by_clause:
raise NotSupportedError(
"This database backend does not support specifying an order on "
"aggregates."
)
return super().as_sql(compiler, connection, **extra_context)
class Aggregate(Func): class Aggregate(Func):
template = "%(function)s(%(distinct)s%(expressions)s)" template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
contains_aggregate = True contains_aggregate = True
name = None name = None
filter_template = "%s FILTER (WHERE %%(filter)s)"
window_compatible = True window_compatible = True
allow_distinct = False allow_distinct = False
allow_order_by = False
empty_result_set_value = None empty_result_set_value = None
def __init__( def __init__(
self, *expressions, distinct=False, filter=None, default=None, **extra self,
*expressions,
distinct=False,
filter=None,
default=None,
order_by=None,
**extra,
): ):
if distinct and not self.allow_distinct: if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__) raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if order_by and not self.allow_order_by:
raise TypeError("%s does not allow order_by." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None: if default is not None and self.empty_result_set_value is not None:
raise TypeError(f"{self.__class__.__name__} does not allow default.") raise TypeError(f"{self.__class__.__name__} does not allow default.")
self.distinct = distinct self.distinct = distinct
self.filter = filter self.filter = filter and AggregateFilter(filter)
self.default = default self.default = default
self.order_by = AggregateOrderBy.from_param(
f"{self.__class__.__name__}.order_by", order_by
)
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def get_source_fields(self): def get_source_fields(self):
# Don't return the filter expression since it's not a source field. # Don't consider filter and order by expression as they have nothing
# to do with the output field resolution.
return [e._output_field_or_none for e in super().get_source_expressions()] return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self): def get_source_expressions(self):
source_expressions = super().get_source_expressions() source_expressions = super().get_source_expressions()
return source_expressions + [self.filter] return source_expressions + [self.filter, self.order_by]
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
*exprs, self.filter = exprs *exprs, self.filter, self.order_by = exprs
return super().set_source_expressions(exprs) return super().set_source_expressions(exprs)
def resolve_expression( def resolve_expression(
@@ -66,6 +124,11 @@ class Aggregate(Func):
if c.filter if c.filter
else None else None
) )
c.order_by = (
c.order_by.resolve_expression(query, allow_joins, reuse, summarize)
if c.order_by
else None
)
if summarize: if summarize:
# Summarized aggregates cannot refer to summarized aggregates. # Summarized aggregates cannot refer to summarized aggregates.
for ref in c.get_refs(): for ref in c.get_refs():
@@ -115,35 +178,45 @@ class Aggregate(Func):
return [] return []
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):
extra_context["distinct"] = "DISTINCT " if self.distinct else "" if (
if self.filter: self.distinct
if connection.features.supports_aggregate_filter_clause: and not connection.features.supports_aggregate_distinct_multiple_argument
try: and len(super().get_source_expressions()) > 1
filter_sql, filter_params = self.filter.as_sql(compiler, connection) ):
except FullResultSet: raise NotSupportedError(
pass f"{self.name} does not support distinct with multiple expressions on "
else: f"this database backend."
template = self.filter_template % extra_context.get( )
"template", self.template
) distinct_sql = "DISTINCT " if self.distinct else ""
sql, params = super().as_sql( order_by_sql = ""
compiler, order_by_params = []
connection, filter_sql = ""
template=template, filter_params = []
filter=filter_sql,
**extra_context, if (order_by := self.order_by) is not None:
) order_by_sql, order_by_params = compiler.compile(order_by)
return sql, (*params, *filter_params)
else: if self.filter is not None:
try:
filter_sql, filter_params = compiler.compile(self.filter)
except NotSupportedError:
# Fallback to a CASE statement on backends that don't support
# the FILTER clause.
copy = self.copy() copy = self.copy()
copy.filter = None copy.filter = None
source_expressions = copy.get_source_expressions() source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0]) condition = When(self.filter.condition, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:]) copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql( return copy.as_sql(compiler, connection, **extra_context)
compiler, connection, **extra_context
) extra_context.update(
return super().as_sql(compiler, connection, **extra_context) distinct=distinct_sql,
filter=filter_sql,
order_by=order_by_sql,
)
sql, params = super().as_sql(compiler, connection, **extra_context)
return sql, (*params, *order_by_params, *filter_params)
def _get_repr_options(self): def _get_repr_options(self):
options = super()._get_repr_options() options = super()._get_repr_options()
@@ -151,6 +224,8 @@ class Aggregate(Func):
options["distinct"] = self.distinct options["distinct"] = self.distinct
if self.filter: if self.filter:
options["filter"] = self.filter options["filter"] = self.filter
if self.order_by:
options["order_by"] = self.order_by
return options return options
@@ -179,17 +254,17 @@ class Count(Aggregate):
def resolve_expression(self, *args, **kwargs): def resolve_expression(self, *args, **kwargs):
result = super().resolve_expression(*args, **kwargs) result = super().resolve_expression(*args, **kwargs)
expr = result.source_expressions[0] source_expressions = result.get_source_expressions()
# In case of composite primary keys, count the first column. # In case of composite primary keys, count the first column.
if isinstance(expr, ColPairs): if isinstance(expr := source_expressions[0], ColPairs):
if self.distinct: if self.distinct:
raise ValueError( raise ValueError(
"COUNT(DISTINCT) doesn't support composite primary keys" "COUNT(DISTINCT) doesn't support composite primary keys"
) )
cols = expr.get_cols() source_expressions[0] = expr.get_cols()[0]
return Count(cols[0], filter=result.filter) result.set_source_expressions(source_expressions)
return result return result
@@ -218,6 +293,88 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"} return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
class StringAggDelimiter(Func):
arity = 1
template = "%(expressions)s"
def __init__(self, value):
self.value = value
super().__init__(value)
def as_mysql(self, compiler, connection, **extra_context):
template = " SEPARATOR %(expressions)s"
return self.as_sql(
compiler,
connection,
template=template,
**extra_context,
)
class StringAgg(Aggregate):
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
function = "STRING_AGG"
name = "StringAgg"
allow_distinct = True
allow_order_by = True
output_field = TextField()
def __init__(self, expression, delimiter, **extra):
self.delimiter = StringAggDelimiter(delimiter)
super().__init__(expression, self.delimiter, **extra)
def as_oracle(self, compiler, connection, **extra_context):
if self.order_by:
template = (
"%(function)s(%(distinct)s%(expressions)s) WITHIN GROUP (%(order_by)s)"
"%(filter)s"
)
else:
template = "%(function)s(%(distinct)s%(expressions)s)%(filter)s"
return self.as_sql(
compiler,
connection,
function="LISTAGG",
template=template,
**extra_context,
)
def as_mysql(self, compiler, connection, **extra_context):
extra_context["function"] = "GROUP_CONCAT"
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(delimiter)s)"
extra_context["template"] = template
c = self.copy()
# The creation of the delimiter SQL and the ordering of the parameters must be
# handled explicitly, as MySQL puts the delimiter at the end of the aggregate
# using the `SEPARATOR` declaration (rather than treating as an expression like
# other database backends).
delimiter_params = []
if c.delimiter:
delimiter_sql, delimiter_params = compiler.compile(c.delimiter)
# Drop the delimiter from the source expressions.
c.source_expressions = c.source_expressions[:-1]
extra_context["delimiter"] = delimiter_sql
sql, params = c.as_sql(compiler, connection, **extra_context)
return sql, (*params, *delimiter_params)
def as_sqlite(self, compiler, connection, **extra_context):
if connection.get_database_version() < (3, 44):
return self.as_sql(
compiler,
connection,
function="GROUP_CONCAT",
**extra_context,
)
return self.as_sql(compiler, connection, **extra_context)
class Sum(FixDurationInputMixin, Aggregate): class Sum(FixDurationInputMixin, Aggregate):
function = "SUM" function = "SUM"
name = "Sum" name = "Sum"

View File

@@ -1481,6 +1481,21 @@ class OrderByList(ExpressionList):
) )
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
@classmethod
def from_param(cls, context, param):
if param is None:
return None
if isinstance(param, (list, tuple)):
if not param:
return None
return cls(*param)
elif isinstance(param, str) or hasattr(param, "resolve_expression"):
return cls(param)
raise ValueError(
f"{context} must be either a string reference to a "
f"field, an expression, or a list or tuple of them not {param!r}."
)
@deconstructible(path="django.db.models.ExpressionWrapper") @deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression): class ExpressionWrapper(SQLiteNumericMixin, Expression):
@@ -1943,16 +1958,7 @@ class Window(SQLiteNumericMixin, Expression):
self.partition_by = (self.partition_by,) self.partition_by = (self.partition_by,)
self.partition_by = ExpressionList(*self.partition_by) self.partition_by = ExpressionList(*self.partition_by)
if self.order_by is not None: self.order_by = OrderByList.from_param("Window.order_by", self.order_by)
if isinstance(self.order_by, (list, tuple)):
self.order_by = OrderByList(*self.order_by)
elif isinstance(self.order_by, (BaseExpression, str)):
self.order_by = OrderByList(self.order_by)
else:
raise ValueError(
"Window.order_by must be either a string reference to a "
"field, an expression, or a list or tuple of them."
)
super().__init__(output_field=output_field) super().__init__(output_field=output_field)
self.source_expression = self._parse_expressions(expression)[0] self.source_expression = self._parse_expressions(expression)[0]

View File

@@ -18,6 +18,8 @@ details on these changes.
* The ``serialize`` keyword argument of * The ``serialize`` keyword argument of
``BaseDatabaseCreation.create_test_db()`` will be removed. ``BaseDatabaseCreation.create_test_db()`` will be removed.
* The ``django.contrib.postgres.aggregates.StringAgg`` class will be removed.
.. _deprecation-removed-in-6.1: .. _deprecation-removed-in-6.1:
6.1 6.1

View File

@@ -194,6 +194,8 @@ General-purpose aggregation functions
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=()) .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=())
.. deprecated:: 6.0
Returns the input values concatenated into a string, separated by Returns the input values concatenated into a string, separated by
the ``delimiter`` string, or ``default`` if there are no values. the ``delimiter`` string, or ``default`` if there are no values.

View File

@@ -448,7 +448,7 @@ some complex computations::
The ``Aggregate`` API is as follows: The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra) .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, order_by=None, **extra)
.. attribute:: template .. attribute:: template
@@ -473,6 +473,15 @@ The ``Aggregate`` API is as follows:
allows passing a ``distinct`` keyword argument. If set to ``False`` allows passing a ``distinct`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``distinct=True`` is passed. (default), ``TypeError`` is raised if ``distinct=True`` is passed.
.. attribute:: allow_order_by
.. versionadded:: 6.0
A class attribute determining whether or not this aggregate function
allows passing a ``order_by`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``order_by`` is passed as a value
other than ``None``.
.. attribute:: empty_result_set_value .. attribute:: empty_result_set_value
Defaults to ``None`` since most aggregate functions result in ``NULL`` Defaults to ``None`` since most aggregate functions result in ``NULL``
@@ -491,6 +500,12 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation` used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage. and :ref:`filtering-on-annotations` for example usage.
The ``order_by`` argument behaves similarly to the ``field_names`` input of the
:meth:`~.QuerySet.order_by` function, accepting a field name (with an optional
``"-"`` prefix which indicates descending order) or an expression (or a tuple
or list of strings and/or expressions) that specifies the ordering of the
elements in the result.
The ``default`` argument takes a value that will be passed along with the The ``default`` argument takes a value that will be passed along with the
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
specifying a value to be returned other than ``None`` when the queryset (or specifying a value to be returned other than ``None`` when the queryset (or
@@ -499,6 +514,10 @@ grouping) contains no entries.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
.. versionchanged:: 6.0
The ``order_by`` argument was added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------

View File

@@ -4046,6 +4046,25 @@ by the aggregate.
However, if ``sample=True``, the return value will be the sample However, if ``sample=True``, the return value will be the sample
variance. variance.
``StringAgg``
~~~~~~~~~~~~~
.. versionadded:: 6.0
.. class:: StringAgg(expression, delimiter, output_field=None, distinct=False, filter=None, order_by=None, default=None, **extra)
Returns the input values concatenated into a string, separated by the
``delimiter`` string, or ``default`` if there are no values.
* Default alias: ``<field>__stringagg``
* Return type: ``string`` or ``output_field`` if supplied. If the
queryset or grouping is empty, ``default`` is returned.
.. attribute:: delimiter
A ``Value`` or expression representing the string that should separate
each of the values. For example, ``Value(",")``.
Query-related tools Query-related tools
=================== ===================

View File

@@ -184,6 +184,16 @@ Models
* :doc:`Constraints </ref/models/constraints>` now implement a ``check()`` * :doc:`Constraints </ref/models/constraints>` now implement a ``check()``
method that is already registered with the check framework. method that is already registered with the check framework.
* The new ``order_by`` argument for :class:`~django.db.models.Aggregate` allows
specifying the ordering of the elements in the result.
* The new :attr:`.Aggregate.allow_order_by` class attribute determines whether
the aggregate function allows passing an ``order_by`` keyword argument.
* The new :class:`~django.db.models.StringAgg` aggregate returns the input
values concatenated into a string, separated by the ``delimiter`` string.
This aggregate was previously supported only for PostgreSQL.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
@@ -288,6 +298,9 @@ Miscellaneous
* ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use * ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use
``serialize_db_to_string()`` instead. ``serialize_db_to_string()`` instead.
* The PostgreSQL ``StringAgg`` class is deprecated in favor of the generally
available :class:`~django.db.models.StringAgg` class.
Features removed in 6.0 Features removed in 6.0
======================= =======================

View File

@@ -43,3 +43,7 @@ class Store(models.Model):
def __str__(self): def __str__(self):
return self.name return self.name
class Employee(models.Model):
work_day_preferences = models.JSONField()

View File

@@ -4,10 +4,11 @@ import re
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection from django.db import NotSupportedError, connection
from django.db.models import ( from django.db.models import (
Avg, Avg,
Case, Case,
CharField,
Count, Count,
DateField, DateField,
DateTimeField, DateTimeField,
@@ -22,6 +23,7 @@ from django.db.models import (
OuterRef, OuterRef,
Q, Q,
StdDev, StdDev,
StringAgg,
Subquery, Subquery,
Sum, Sum,
TimeField, TimeField,
@@ -32,9 +34,11 @@ from django.db.models import (
Window, Window,
) )
from django.db.models.expressions import Func, RawSQL from django.db.models.expressions import Func, RawSQL
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import ( from django.db.models.functions import (
Cast, Cast,
Coalesce, Coalesce,
Concat,
Greatest, Greatest,
Least, Least,
Lower, Lower,
@@ -45,11 +49,11 @@ from django.db.models.functions import (
TruncHour, TruncHour,
) )
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext from django.test.utils import Approximate, CaptureQueriesContext
from django.utils import timezone from django.utils import timezone
from .models import Author, Book, Publisher, Store from .models import Author, Book, Employee, Publisher, Store
class NowUTC(Now): class NowUTC(Now):
@@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
) )
self.assertEqual(books["ratings"], expected_result) self.assertEqual(books["ratings"], expected_result)
@skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
def test_distinct_on_stringagg(self):
books = Book.objects.aggregate(
ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
)
self.assertEqual(books["ratings"], "3,4,4.5,5")
@skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
def test_raises_error_on_multiple_argument_distinct(self):
message = (
"StringAgg does not support distinct with multiple expressions on this "
"database backend."
)
with self.assertRaisesMessage(NotSupportedError, message):
Book.objects.aggregate(
ratings=StringAgg(
Cast(F("rating"), CharField()),
Value(","),
distinct=True,
)
)
def test_non_grouped_annotation_not_in_group_by(self): def test_non_grouped_annotation_not_in_group_by(self):
""" """
An annotation not included in values() before an aggregate should be An annotation not included in values() before an aggregate should be
@@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price")) Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
def test_multi_arg_aggregate(self): def test_multi_arg_aggregate(self):
class MyMax(Max): class MultiArgAgg(Max):
output_field = DecimalField() output_field = DecimalField()
arity = None arity = None
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection, **extra_context):
copy = self.copy() copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None]) # Most database backends do not support compiling multiple arguments on
return super(MyMax, copy).as_sql(compiler, connection) # the Max aggregate, and that isn't what is being tested here anyway. To
# avoid errors, the extra argument is just dropped.
copy.set_source_expressions(
copy.get_source_expressions()[0:1] + [None, None]
)
return super(MultiArgAgg, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"): with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
Book.objects.aggregate(MyMax("pages", "price")) Book.objects.aggregate(MultiArgAgg("pages", "price"))
with self.assertRaisesMessage( with self.assertRaisesMessage(
TypeError, "Complex annotations require an alias" TypeError, "Complex annotations require an alias"
): ):
Book.objects.annotate(MyMax("pages", "price")) Book.objects.annotate(MultiArgAgg("pages", "price"))
Book.objects.aggregate(max_field=MyMax("pages", "price")) Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
def test_add_implementation(self): def test_add_implementation(self):
class MySum(Sum): class MySum(Sum):
@@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
"function": self.function.lower(), "function": self.function.lower(),
"expressions": sql, "expressions": sql,
"distinct": "", "distinct": "",
"filter": "",
"order_by": "",
} }
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, params return self.template % substitutions, params
@@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template # test overriding all parts of the template
def be_evil(self, compiler, connection): def be_evil(self, compiler, connection):
substitutions = {"function": "MAX", "expressions": "2", "distinct": ""} substitutions = {
"function": "MAX",
"expressions": "2",
"distinct": "",
"filter": "",
"order_by": "",
}
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, () return self.template % substitutions, ()
@@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
Publisher.objects.none().aggregate( Publisher.objects.none().aggregate(
sum_awards=Sum("num_awards"), sum_awards=Sum("num_awards"),
books_count=Count("book"), books_count=Count("book"),
all_names=StringAgg("name", Value(",")),
), ),
{ {
"sum_awards": None, "sum_awards": None,
"books_count": 0, "books_count": 0,
"all_names": None,
}, },
) )
# Expression without empty_result_set_value forces queries to be # Expression without empty_result_set_value forces queries to be
@@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
) )
self.assertEqual(result["value"], 35) self.assertEqual(result["value"], 35)
def test_stringagg_default_value(self):
result = Author.objects.filter(age__gt=100).aggregate(
value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
)
self.assertEqual(result["value"], "<empty>")
def test_aggregation_default_group_by(self): def test_aggregation_default_group_by(self):
qs = ( qs = (
Publisher.objects.values("name") Publisher.objects.values("name")
@@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
super(function, func_instance).__init__(Value(1), Value(2)) super(function, func_instance).__init__(Value(1), Value(2))
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
Book.objects.aggregate(stringagg=StringAgg("name"))
def test_string_agg_escapes_delimiter(self):
values = Publisher.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value("'"))
)
self.assertEqual(
values,
{
"stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
"of Books",
},
)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by(self):
order_by_test_cases = (
(
F("original_opening").desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
(
F("original_opening").asc(),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
(
Concat("original_opening", Value("@")),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(
Concat("original_opening", Value("@")).desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = Store.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
@skipIfDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by_is_not_supported(self):
message = (
"This database backend does not support specifying an order on aggregates."
)
with self.assertRaisesMessage(NotSupportedError, message):
Store.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
order_by="original_opening",
)
)
def test_string_agg_filter(self):
values = Book.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
filter=Q(name__startswith="P"),
)
)
expected_values = {
"stringagg": "Practical Django Projects;"
"Python Web Development with Django;Paradigms of Artificial "
"Intelligence Programming: Case Studies in Common Lisp",
}
self.assertEqual(values, expected_values)
@skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
def test_string_agg_jsonfield_order_by(self):
Employee.objects.bulk_create(
[
Employee(work_day_preferences={"Monday": "morning"}),
Employee(work_day_preferences={"Monday": "afternoon"}),
]
)
values = Employee.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("Monday", "work_day_preferences"),
delimiter=Value(","),
order_by=KeyTextTransform(
"Monday",
"work_day_preferences",
),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "afternoon,morning"})
def test_string_agg_filter_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
filter=~Q(authors__name__startswith="J"),
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
).values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty",
"Brad Dayley",
"Paul Bissex;Wesley J. Chun",
"Peter Norvig;Stuart Russell",
"Peter Norvig",
"" if connection.features.interprets_empty_strings_as_nulls else None,
]
self.assertQuerySetEqual(expected_values, values, ordered=False)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_order_by_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
order_by="authors__name",
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
)
.order_by("agg")
.values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty;Jacob Kaplan-Moss",
"Brad Dayley",
"James Bennett",
"Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
"Peter Norvig",
"Peter Norvig;Stuart Russell",
]
self.assertEqual(expected_values, values)
class AggregateAnnotationPruningTests(TestCase): class AggregateAnnotationPruningTests(TestCase):
@classmethod @classmethod

View File

@@ -1720,14 +1720,14 @@ class WindowFunctionTests(TestCase):
"""Window expressions can't be used in an INSERT statement.""" """Window expressions can't be used in an INSERT statement."""
msg = ( msg = (
"Window expressions are not allowed in this query (salary=<Window: " "Window expressions are not allowed in this query (salary=<Window: "
"Sum(Value(10000), order_by=OrderBy(F(pk), descending=False)) OVER ()" "Sum(Value(10000)) OVER ()"
) )
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):
Employee.objects.create( Employee.objects.create(
name="Jameson", name="Jameson",
department="Management", department="Management",
hire_date=datetime.date(2007, 7, 1), hire_date=datetime.date(2007, 7, 1),
salary=Window(expression=Sum(Value(10000), order_by=F("pk").asc())), salary=Window(expression=Sum(Value(10000))),
) )
def test_window_expression_within_subquery(self): def test_window_expression_within_subquery(self):
@@ -2025,7 +2025,7 @@ class NonQueryWindowTests(SimpleTestCase):
def test_invalid_order_by(self): def test_invalid_order_by(self):
msg = ( msg = (
"Window.order_by must be either a string reference to a field, an " "Window.order_by must be either a string reference to a field, an "
"expression, or a list or tuple of them." "expression, or a list or tuple of them not {'-horse'}."
) )
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
Window(expression=Sum("power"), order_by={"-horse"}) Window(expression=Sum("power"), order_by={"-horse"})

View File

@@ -1,3 +1,5 @@
import warnings
from django.db import transaction from django.db import transaction
from django.db.models import ( from django.db.models import (
CharField, CharField,
@@ -11,16 +13,19 @@ from django.db.models import (
Value, Value,
Window, Window,
) )
from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.fields.json import KeyTransform
from django.db.models.functions import Cast, Concat, LPad, Substr from django.db.models.functions import Cast, Concat, LPad, Substr
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import timezone from django.utils import timezone
from django.utils.deprecation import RemovedInDjango61Warning from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
try: try:
from django.contrib.postgres.aggregates import (
StringAgg, # RemovedInDjango70Warning.
)
from django.contrib.postgres.aggregates import ( from django.contrib.postgres.aggregates import (
ArrayAgg, ArrayAgg,
BitAnd, BitAnd,
@@ -41,7 +46,6 @@ try:
RegrSXY, RegrSXY,
RegrSYY, RegrSYY,
StatAggregate, StatAggregate,
StringAgg,
) )
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
except ImportError: except ImportError:
@@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
BoolAnd("boolean_field"), BoolAnd("boolean_field"),
BoolOr("boolean_field"), BoolOr("boolean_field"),
JSONBAgg("integer_field"), JSONBAgg("integer_field"),
StringAgg("char_field", delimiter=";"),
BitXor("integer_field"), BitXor("integer_field"),
] ]
for aggregation in tests: for aggregation in tests:
@@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())), JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
["<empty>"], ["<empty>"],
), ),
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
(
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
"<empty>",
),
(BitXor("integer_field", default=0), 0), (BitXor("integer_field", default=0), 0),
] ]
for aggregation, expected_result in tests: for aggregation, expected_result in tests:
@@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]}) self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
self.assertEqual(ctx.filename, __file__) self.assertEqual(ctx.filename, __file__)
# RemovedInDjango61Warning: Remove this test
def test_ordering_and_order_by_causes_error(self): def test_ordering_and_order_by_causes_error(self):
with self.assertWarns(RemovedInDjango61Warning): with warnings.catch_warnings(record=True, action="always") as wm:
with self.assertRaisesMessage( with self.assertRaisesMessage(
TypeError, TypeError,
"Cannot specify both order_by and ordering.", "Cannot specify both order_by and ordering.",
@@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
) )
) )
first_warning = wm[0]
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
self.assertEqual(
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead.",
str(first_warning.message),
)
second_warning = wm[1]
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
self.assertEqual(
"The ordering argument is deprecated. Use order_by instead.",
str(second_warning.message),
)
def test_array_agg_charfield(self): def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field")) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]}) self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
@@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
) )
self.assertEqual(values, {"boolor": False}) self.assertEqual(values, {"boolor": False})
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
def test_string_agg_delimiter_escaping(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
def test_string_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";")
)
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
def test_string_agg_default_output_field(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("text_field", delimiter=";"),
)
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
def test_string_agg_charfield_order_by(self):
order_by_test_cases = (
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
("char_field", "Foo1;Foo2;Foo3;Foo4"),
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
def test_string_agg_jsonfield_order_by(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("lang", "json_field"),
delimiter=";",
order_by=KeyTextTransform("lang", "json_field"),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "en;pl"})
def test_string_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
"char_field",
delimiter=";",
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
)
)
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
def test_orderable_agg_alternative_fields(self): def test_orderable_agg_alternative_fields(self):
values = AggregateTestModel.objects.aggregate( values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc()) arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
@@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
], ],
) )
def test_string_agg_array_agg_order_by_in_subquery(self): def test_array_agg_order_by_in_subquery(self):
stats = [] stats = []
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")): for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1)) stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i)) stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
StatTestModel.objects.bulk_create(stats) StatTestModel.objects.bulk_create(stats)
for aggregate, expected_result in ( aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
( expected_result = [
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"), ("Foo1", [0, 1]),
[ ("Foo2", [1, 2]),
("Foo1", [0, 1]), ("Foo3", [2, 3]),
("Foo2", [1, 2]), ("Foo4", [3, 4]),
("Foo3", [2, 3]), ]
("Foo4", [3, 4]),
], subquery = (
), AggregateTestModel.objects.filter(
( pk=OuterRef("pk"),
StringAgg( )
Cast("stattestmodel__int1", CharField()), .annotate(agg=aggregate)
delimiter=";", .values("agg")
order_by="-stattestmodel__int2", )
), values = (
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")], AggregateTestModel.objects.annotate(
), agg=Subquery(subquery),
): )
with self.subTest(aggregate=aggregate.__class__.__name__): .order_by("char_field")
subquery = ( .values_list("char_field", "agg")
AggregateTestModel.objects.filter( )
pk=OuterRef("pk"), self.assertEqual(list(values), expected_result)
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_array_agg_filter_in_subquery(self): def test_string_agg_array_agg_filter_in_subquery(self):
StatTestModel.objects.bulk_create( StatTestModel.objects.bulk_create(
@@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
StatTestModel(related_field=self.aggs[0], int1=2, int2=3), StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
] ]
) )
for aggregate, expected_result in (
(
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
[("Foo1", [0, 1]), ("Foo2", None)],
),
(
StringAgg(
Cast("stattestmodel__int2", CharField()),
delimiter=";",
filter=Q(stattestmodel__int1__lt=2),
),
[("Foo1", "5;4"), ("Foo2", None)],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_filter_in_subquery_with_exclude(self): aggregate = ArrayAgg(
"stattestmodel__int1",
filter=Q(stattestmodel__int2__gt=3),
)
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
subquery = ( subquery = (
AggregateTestModel.objects.annotate( AggregateTestModel.objects.filter(
stringagg=StringAgg( pk=OuterRef("pk"),
"char_field",
delimiter=";",
filter=Q(char_field__endswith="1"),
)
) )
.exclude(stringagg="") .annotate(agg=aggregate)
.values("id") .values("agg")
) )
self.assertSequenceEqual( values = (
AggregateTestModel.objects.filter(id__in=Subquery(subquery)), AggregateTestModel.objects.annotate(
[self.aggs[0]], agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
) )
self.assertEqual(list(values), expected_result)
def test_ordering_isnt_cleared_for_array_subquery(self): def test_ordering_isnt_cleared_for_array_subquery(self):
inner_qs = AggregateTestModel.objects.order_by("-integer_field") inner_qs = AggregateTestModel.objects.order_by("-integer_field")
@@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")] tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
for aggregation in tests: for aggregation in tests:
with self.subTest(aggregation=aggregation): with self.subTest(aggregation=aggregation):
results = AggregateTestModel.objects.annotate(
agg=aggregation
).values_list("agg")
self.assertCountEqual( self.assertCountEqual(
AggregateTestModel.objects.values_list(aggregation), results,
[([0],), ([1],), ([2],), ([0],)], [([0],), ([1],), ([2],), ([0],)],
) )
def test_string_agg_delimiter_deprecation(self):
msg = (
"delimiter: str will be resolved as a field reference instead "
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
"preserve the previous behaviour."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
def test_string_agg_deprecation(self):
msg = (
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=Value("'"))
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
class TestAggregateDistinct(PostgreSQLTestCase): class TestAggregateDistinct(PostgreSQLTestCase):
@classmethod @classmethod
@@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
AggregateTestModel.objects.create(char_field="Foo") AggregateTestModel.objects.create(char_field="Foo")
AggregateTestModel.objects.create(char_field="Bar") AggregateTestModel.objects.create(char_field="Bar")
def test_string_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
)
self.assertEqual(values["stringagg"].count("Foo"), 2)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_string_agg_distinct_true(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
)
self.assertEqual(values["stringagg"].count("Foo"), 1)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_array_agg_distinct_false(self): def test_array_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate( values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("char_field", distinct=False) arrayagg=ArrayAgg("char_field", distinct=False)