1
0
mirror of https://github.com/django/django.git synced 2025-10-24 14:16:09 +00:00

Fixed #35444 -- Added generic support for Aggregate.order_by.

This moves the behaviors of `order_by` used in Postgres aggregates into
the `Aggregate` class. This allows for creating aggregate functions that
support this behavior across all database engines. This is shown by
moving the `StringAgg` class into the shared `aggregates` module and
adding support for all databases. The Postgres `StringAgg` class is now
a thin wrapper on the new shared `StringAgg` class.

Thank you Simon Charette for the review.
This commit is contained in:
Chris Muthig
2024-12-22 16:30:55 +01:00
committed by Sarah Boyce
parent 6d1cf5375f
commit 4b977a5d72
19 changed files with 659 additions and 291 deletions

View File

@@ -33,15 +33,14 @@ class GeoAggregate(Aggregate):
if not self.is_extent:
tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
clone = self.copy()
source_expressions = self.get_source_expressions()
source_expressions.pop() # Don't wrap filters with SDOAGGRTYPE().
*source_exprs, filter_expr, order_by_expr = self.get_source_expressions()
spatial_type_expr = Func(
*source_expressions,
*source_exprs,
Value(tolerance),
function="SDOAGGRTYPE",
output_field=self.output_field,
)
source_expressions = [spatial_type_expr, self.filter]
source_expressions = [spatial_type_expr, filter_expr, order_by_expr]
clone.set_source_expressions(source_expressions)
return clone.as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)

View File

@@ -1,7 +1,12 @@
from django.contrib.postgres.fields import ArrayField
from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value
import warnings
from .mixins import OrderableAggMixin
from django.contrib.postgres.fields import ArrayField
from django.db.models import Aggregate, BooleanField, JSONField
from django.db.models import StringAgg as _StringAgg
from django.db.models import Value
from django.utils.deprecation import RemovedInDjango70Warning
from .mixins import _DeprecatedOrdering
__all__ = [
"ArrayAgg",
@@ -11,14 +16,16 @@ __all__ = [
"BoolAnd",
"BoolOr",
"JSONBAgg",
"StringAgg",
"StringAgg", # RemovedInDjango70Warning.
]
class ArrayAgg(OrderableAggMixin, Aggregate):
# RemovedInDjango61Warning: When the deprecation ends, replace with:
# class ArrayAgg(Aggregate):
class ArrayAgg(_DeprecatedOrdering, Aggregate):
function = "ARRAY_AGG"
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
allow_distinct = True
allow_order_by = True
@property
def output_field(self):
@@ -47,19 +54,37 @@ class BoolOr(Aggregate):
output_field = BooleanField()
class JSONBAgg(OrderableAggMixin, Aggregate):
# RemovedInDjango61Warning: When the deprecation ends, replace with:
# class JSONBAgg(Aggregate):
class JSONBAgg(_DeprecatedOrdering, Aggregate):
function = "JSONB_AGG"
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
allow_distinct = True
allow_order_by = True
output_field = JSONField()
class StringAgg(OrderableAggMixin, Aggregate):
function = "STRING_AGG"
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
allow_distinct = True
output_field = TextField()
# RemovedInDjango61Warning: When the deprecation ends, replace with:
# class StringAgg(_StringAgg):
# RemovedInDjango70Warning: When the deprecation ends, remove completely.
class StringAgg(_DeprecatedOrdering, _StringAgg):
def __init__(self, expression, delimiter, **extra):
delimiter_expr = Value(str(delimiter))
super().__init__(expression, delimiter_expr, **extra)
if isinstance(delimiter, str):
warnings.warn(
"delimiter: str will be resolved as a field reference instead "
"of a string literal on Django 7.0. Pass "
f"`delimiter=Value({delimiter!r})` to preserve the previous behaviour.",
category=RemovedInDjango70Warning,
stacklevel=2,
)
delimiter = Value(delimiter)
warnings.warn(
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead.",
category=RemovedInDjango70Warning,
stacklevel=2,
)
super().__init__(expression, delimiter, **extra)

View File

@@ -1,15 +1,11 @@
import warnings
from django.core.exceptions import FullResultSet
from django.db.models.expressions import OrderByList
from django.utils.deprecation import RemovedInDjango61Warning
class OrderableAggMixin:
# RemovedInDjango61Warning: When the deprecation ends, replace with:
# def __init__(self, *expressions, order_by=(), **extra):
# RemovedInDjango61Warning.
class _DeprecatedOrdering:
def __init__(self, *expressions, ordering=(), order_by=(), **extra):
# RemovedInDjango61Warning.
if ordering:
warnings.warn(
"The ordering argument is deprecated. Use order_by instead.",
@@ -19,44 +15,14 @@ class OrderableAggMixin:
if order_by:
raise TypeError("Cannot specify both order_by and ordering.")
order_by = ordering
if not order_by:
self.order_by = None
elif isinstance(order_by, (list, tuple)):
self.order_by = OrderByList(*order_by)
else:
self.order_by = OrderByList(order_by)
super().__init__(*expressions, **extra)
def resolve_expression(self, *args, **kwargs):
if self.order_by is not None:
self.order_by = self.order_by.resolve_expression(*args, **kwargs)
return super().resolve_expression(*args, **kwargs)
super().__init__(*expressions, order_by=order_by, **extra)
def get_source_expressions(self):
return super().get_source_expressions() + [self.order_by]
def set_source_expressions(self, exprs):
*exprs, self.order_by = exprs
return super().set_source_expressions(exprs)
# RemovedInDjango61Warning: When the deprecation ends, replace with:
# class OrderableAggMixin:
class OrderableAggMixin(_DeprecatedOrdering):
allow_order_by = True
def as_sql(self, compiler, connection):
*source_exprs, filtering_expr, order_by_expr = self.get_source_expressions()
order_by_sql = ""
order_by_params = []
if order_by_expr is not None:
order_by_sql, order_by_params = compiler.compile(order_by_expr)
filter_params = []
if filtering_expr is not None:
try:
_, filter_params = compiler.compile(filtering_expr)
except FullResultSet:
pass
source_params = []
for source_expr in source_exprs:
source_params += compiler.compile(source_expr)[1]
sql, _ = super().as_sql(compiler, connection, order_by=order_by_sql)
return sql, (*source_params, *order_by_params, *filter_params)
def __init_subclass__(cls, /, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)

View File

@@ -257,6 +257,15 @@ class BaseDatabaseFeatures:
# expressions?
supports_aggregate_filter_clause = False
# Does the database support ORDER BY in aggregate expressions?
supports_aggregate_order_by_clause = False
# Does the database backend support DISTINCT when using multiple arguments in an
# aggregate expression? For example, Sqlite treats the "delimiter" argument of
# STRING_AGG/GROUP_CONCAT as an extra argument and does not allow using a custom
# delimiter along with DISTINCT.
supports_aggregate_distinct_multiple_argument = True
# Does the backend support indexing a TextField?
supports_index_on_text_field = True

View File

@@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_explicit_null_ordering_when_grouping = True
atomic_transactions = False
can_clone_databases = True
supports_aggregate_order_by_clause = True
supports_comments = True
supports_comments_inline = True
supports_temporal_subtraction = True

View File

@@ -45,6 +45,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# does by uppercasing all identifiers.
ignores_table_name_case = True
supports_index_on_text_field = False
supports_aggregate_order_by_clause = True
create_test_procedure_without_params_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" AS
V_I INTEGER;

View File

@@ -64,6 +64,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supports_aggregate_order_by_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
supports_deferrable_unique_constraints = True
has_json_operators = True

View File

@@ -34,6 +34,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_frame_range_fixed_distance = True
supports_frame_exclusion = True
supports_aggregate_filter_clause = True
supports_aggregate_order_by_clause = Database.sqlite_version_info >= (3, 44, 0)
supports_aggregate_distinct_multiple_argument = False
order_by_nulls_first = True
supports_json_field_contains = False
supports_update_conflicts = True

View File

@@ -3,8 +3,17 @@ Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
from django.db.models.fields import IntegerField
from django.db import NotSupportedError
from django.db.models.expressions import (
Case,
ColPairs,
Func,
OrderByList,
Star,
Value,
When,
)
from django.db.models.fields import IntegerField, TextField
from django.db.models.functions import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin,
@@ -18,42 +27,91 @@ __all__ = [
"Max",
"Min",
"StdDev",
"StringAgg",
"Sum",
"Variance",
]
class AggregateFilter(Func):
arity = 1
template = " FILTER (WHERE %(expressions)s)"
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_aggregate_filter_clause:
raise NotSupportedError(
"Aggregate filter clauses are not supported on this database backend."
)
try:
return super().as_sql(compiler, connection, **extra_context)
except FullResultSet:
return "", ()
@property
def condition(self):
return self.source_expressions[0]
def __str__(self):
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
class AggregateOrderBy(OrderByList):
template = " ORDER BY %(expressions)s"
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_aggregate_order_by_clause:
raise NotSupportedError(
"This database backend does not support specifying an order on "
"aggregates."
)
return super().as_sql(compiler, connection, **extra_context)
class Aggregate(Func):
template = "%(function)s(%(distinct)s%(expressions)s)"
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
contains_aggregate = True
name = None
filter_template = "%s FILTER (WHERE %%(filter)s)"
window_compatible = True
allow_distinct = False
allow_order_by = False
empty_result_set_value = None
def __init__(
self, *expressions, distinct=False, filter=None, default=None, **extra
self,
*expressions,
distinct=False,
filter=None,
default=None,
order_by=None,
**extra,
):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if order_by and not self.allow_order_by:
raise TypeError("%s does not allow order_by." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None:
raise TypeError(f"{self.__class__.__name__} does not allow default.")
self.distinct = distinct
self.filter = filter
self.filter = filter and AggregateFilter(filter)
self.default = default
self.order_by = AggregateOrderBy.from_param(
f"{self.__class__.__name__}.order_by", order_by
)
super().__init__(*expressions, **extra)
def get_source_fields(self):
# Don't return the filter expression since it's not a source field.
# Don't consider filter and order by expression as they have nothing
# to do with the output field resolution.
return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
return source_expressions + [self.filter]
return source_expressions + [self.filter, self.order_by]
def set_source_expressions(self, exprs):
*exprs, self.filter = exprs
*exprs, self.filter, self.order_by = exprs
return super().set_source_expressions(exprs)
def resolve_expression(
@@ -66,6 +124,11 @@ class Aggregate(Func):
if c.filter
else None
)
c.order_by = (
c.order_by.resolve_expression(query, allow_joins, reuse, summarize)
if c.order_by
else None
)
if summarize:
# Summarized aggregates cannot refer to summarized aggregates.
for ref in c.get_refs():
@@ -115,35 +178,45 @@ class Aggregate(Func):
return []
def as_sql(self, compiler, connection, **extra_context):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
try:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
except FullResultSet:
pass
else:
template = self.filter_template % extra_context.get(
"template", self.template
)
sql, params = super().as_sql(
compiler,
connection,
template=template,
filter=filter_sql,
**extra_context,
)
return sql, (*params, *filter_params)
else:
if (
self.distinct
and not connection.features.supports_aggregate_distinct_multiple_argument
and len(super().get_source_expressions()) > 1
):
raise NotSupportedError(
f"{self.name} does not support distinct with multiple expressions on "
f"this database backend."
)
distinct_sql = "DISTINCT " if self.distinct else ""
order_by_sql = ""
order_by_params = []
filter_sql = ""
filter_params = []
if (order_by := self.order_by) is not None:
order_by_sql, order_by_params = compiler.compile(order_by)
if self.filter is not None:
try:
filter_sql, filter_params = compiler.compile(self.filter)
except NotSupportedError:
# Fallback to a CASE statement on backends that don't support
# the FILTER clause.
copy = self.copy()
copy.filter = None
source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
condition = When(self.filter.condition, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(
compiler, connection, **extra_context
)
return super().as_sql(compiler, connection, **extra_context)
return copy.as_sql(compiler, connection, **extra_context)
extra_context.update(
distinct=distinct_sql,
filter=filter_sql,
order_by=order_by_sql,
)
sql, params = super().as_sql(compiler, connection, **extra_context)
return sql, (*params, *order_by_params, *filter_params)
def _get_repr_options(self):
options = super()._get_repr_options()
@@ -151,6 +224,8 @@ class Aggregate(Func):
options["distinct"] = self.distinct
if self.filter:
options["filter"] = self.filter
if self.order_by:
options["order_by"] = self.order_by
return options
@@ -179,17 +254,17 @@ class Count(Aggregate):
def resolve_expression(self, *args, **kwargs):
result = super().resolve_expression(*args, **kwargs)
expr = result.source_expressions[0]
source_expressions = result.get_source_expressions()
# In case of composite primary keys, count the first column.
if isinstance(expr, ColPairs):
if isinstance(expr := source_expressions[0], ColPairs):
if self.distinct:
raise ValueError(
"COUNT(DISTINCT) doesn't support composite primary keys"
)
cols = expr.get_cols()
return Count(cols[0], filter=result.filter)
source_expressions[0] = expr.get_cols()[0]
result.set_source_expressions(source_expressions)
return result
@@ -218,6 +293,88 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
class StringAggDelimiter(Func):
arity = 1
template = "%(expressions)s"
def __init__(self, value):
self.value = value
super().__init__(value)
def as_mysql(self, compiler, connection, **extra_context):
template = " SEPARATOR %(expressions)s"
return self.as_sql(
compiler,
connection,
template=template,
**extra_context,
)
class StringAgg(Aggregate):
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
function = "STRING_AGG"
name = "StringAgg"
allow_distinct = True
allow_order_by = True
output_field = TextField()
def __init__(self, expression, delimiter, **extra):
self.delimiter = StringAggDelimiter(delimiter)
super().__init__(expression, self.delimiter, **extra)
def as_oracle(self, compiler, connection, **extra_context):
if self.order_by:
template = (
"%(function)s(%(distinct)s%(expressions)s) WITHIN GROUP (%(order_by)s)"
"%(filter)s"
)
else:
template = "%(function)s(%(distinct)s%(expressions)s)%(filter)s"
return self.as_sql(
compiler,
connection,
function="LISTAGG",
template=template,
**extra_context,
)
def as_mysql(self, compiler, connection, **extra_context):
extra_context["function"] = "GROUP_CONCAT"
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(delimiter)s)"
extra_context["template"] = template
c = self.copy()
# The creation of the delimiter SQL and the ordering of the parameters must be
# handled explicitly, as MySQL puts the delimiter at the end of the aggregate
# using the `SEPARATOR` declaration (rather than treating as an expression like
# other database backends).
delimiter_params = []
if c.delimiter:
delimiter_sql, delimiter_params = compiler.compile(c.delimiter)
# Drop the delimiter from the source expressions.
c.source_expressions = c.source_expressions[:-1]
extra_context["delimiter"] = delimiter_sql
sql, params = c.as_sql(compiler, connection, **extra_context)
return sql, (*params, *delimiter_params)
def as_sqlite(self, compiler, connection, **extra_context):
if connection.get_database_version() < (3, 44):
return self.as_sql(
compiler,
connection,
function="GROUP_CONCAT",
**extra_context,
)
return self.as_sql(compiler, connection, **extra_context)
class Sum(FixDurationInputMixin, Aggregate):
function = "SUM"
name = "Sum"

View File

@@ -1481,6 +1481,21 @@ class OrderByList(ExpressionList):
)
super().__init__(*expressions, **extra)
@classmethod
def from_param(cls, context, param):
if param is None:
return None
if isinstance(param, (list, tuple)):
if not param:
return None
return cls(*param)
elif isinstance(param, str) or hasattr(param, "resolve_expression"):
return cls(param)
raise ValueError(
f"{context} must be either a string reference to a "
f"field, an expression, or a list or tuple of them not {param!r}."
)
@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
@@ -1943,16 +1958,7 @@ class Window(SQLiteNumericMixin, Expression):
self.partition_by = (self.partition_by,)
self.partition_by = ExpressionList(*self.partition_by)
if self.order_by is not None:
if isinstance(self.order_by, (list, tuple)):
self.order_by = OrderByList(*self.order_by)
elif isinstance(self.order_by, (BaseExpression, str)):
self.order_by = OrderByList(self.order_by)
else:
raise ValueError(
"Window.order_by must be either a string reference to a "
"field, an expression, or a list or tuple of them."
)
self.order_by = OrderByList.from_param("Window.order_by", self.order_by)
super().__init__(output_field=output_field)
self.source_expression = self._parse_expressions(expression)[0]

View File

@@ -18,6 +18,8 @@ details on these changes.
* The ``serialize`` keyword argument of
``BaseDatabaseCreation.create_test_db()`` will be removed.
* The ``django.contrib.postgres.aggregates.StringAgg`` class will be removed.
.. _deprecation-removed-in-6.1:
6.1

View File

@@ -194,6 +194,8 @@ General-purpose aggregation functions
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=())
.. deprecated:: 6.0
Returns the input values concatenated into a string, separated by
the ``delimiter`` string, or ``default`` if there are no values.

View File

@@ -448,7 +448,7 @@ some complex computations::
The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra)
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, order_by=None, **extra)
.. attribute:: template
@@ -473,6 +473,15 @@ The ``Aggregate`` API is as follows:
allows passing a ``distinct`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
.. attribute:: allow_order_by
.. versionadded:: 6.0
A class attribute determining whether or not this aggregate function
allows passing a ``order_by`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``order_by`` is passed as a value
other than ``None``.
.. attribute:: empty_result_set_value
Defaults to ``None`` since most aggregate functions result in ``NULL``
@@ -491,6 +500,12 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage.
The ``order_by`` argument behaves similarly to the ``field_names`` input of the
:meth:`~.QuerySet.order_by` function, accepting a field name (with an optional
``"-"`` prefix which indicates descending order) or an expression (or a tuple
or list of strings and/or expressions) that specifies the ordering of the
elements in the result.
The ``default`` argument takes a value that will be passed along with the
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
specifying a value to be returned other than ``None`` when the queryset (or
@@ -499,6 +514,10 @@ grouping) contains no entries.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute.
.. versionchanged:: 6.0
The ``order_by`` argument was added.
Creating your own Aggregate Functions
-------------------------------------

View File

@@ -4046,6 +4046,25 @@ by the aggregate.
However, if ``sample=True``, the return value will be the sample
variance.
``StringAgg``
~~~~~~~~~~~~~
.. versionadded:: 6.0
.. class:: StringAgg(expression, delimiter, output_field=None, distinct=False, filter=None, order_by=None, default=None, **extra)
Returns the input values concatenated into a string, separated by the
``delimiter`` string, or ``default`` if there are no values.
* Default alias: ``<field>__stringagg``
* Return type: ``string`` or ``output_field`` if supplied. If the
queryset or grouping is empty, ``default`` is returned.
.. attribute:: delimiter
A ``Value`` or expression representing the string that should separate
each of the values. For example, ``Value(",")``.
Query-related tools
===================

View File

@@ -184,6 +184,16 @@ Models
* :doc:`Constraints </ref/models/constraints>` now implement a ``check()``
method that is already registered with the check framework.
* The new ``order_by`` argument for :class:`~django.db.models.Aggregate` allows
specifying the ordering of the elements in the result.
* The new :attr:`.Aggregate.allow_order_by` class attribute determines whether
the aggregate function allows passing an ``order_by`` keyword argument.
* The new :class:`~django.db.models.StringAgg` aggregate returns the input
values concatenated into a string, separated by the ``delimiter`` string.
This aggregate was previously supported only for PostgreSQL.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@@ -288,6 +298,9 @@ Miscellaneous
* ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use
``serialize_db_to_string()`` instead.
* The PostgreSQL ``StringAgg`` class is deprecated in favor of the generally
available :class:`~django.db.models.StringAgg` class.
Features removed in 6.0
=======================

View File

@@ -43,3 +43,7 @@ class Store(models.Model):
def __str__(self):
return self.name
class Employee(models.Model):
work_day_preferences = models.JSONField()

View File

@@ -4,10 +4,11 @@ import re
from decimal import Decimal
from django.core.exceptions import FieldError
from django.db import connection
from django.db import NotSupportedError, connection
from django.db.models import (
Avg,
Case,
CharField,
Count,
DateField,
DateTimeField,
@@ -22,6 +23,7 @@ from django.db.models import (
OuterRef,
Q,
StdDev,
StringAgg,
Subquery,
Sum,
TimeField,
@@ -32,9 +34,11 @@ from django.db.models import (
Window,
)
from django.db.models.expressions import Func, RawSQL
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import (
Cast,
Coalesce,
Concat,
Greatest,
Least,
Lower,
@@ -45,11 +49,11 @@ from django.db.models.functions import (
TruncHour,
)
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext
from django.utils import timezone
from .models import Author, Book, Publisher, Store
from .models import Author, Book, Employee, Publisher, Store
class NowUTC(Now):
@@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(books["ratings"], expected_result)
@skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
def test_distinct_on_stringagg(self):
books = Book.objects.aggregate(
ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
)
self.assertEqual(books["ratings"], "3,4,4.5,5")
@skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
def test_raises_error_on_multiple_argument_distinct(self):
message = (
"StringAgg does not support distinct with multiple expressions on this "
"database backend."
)
with self.assertRaisesMessage(NotSupportedError, message):
Book.objects.aggregate(
ratings=StringAgg(
Cast(F("rating"), CharField()),
Value(","),
distinct=True,
)
)
def test_non_grouped_annotation_not_in_group_by(self):
"""
An annotation not included in values() before an aggregate should be
@@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
def test_multi_arg_aggregate(self):
class MyMax(Max):
class MultiArgAgg(Max):
output_field = DecimalField()
arity = None
def as_sql(self, compiler, connection):
def as_sql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
return super(MyMax, copy).as_sql(compiler, connection)
# Most database backends do not support compiling multiple arguments on
# the Max aggregate, and that isn't what is being tested here anyway. To
# avoid errors, the extra argument is just dropped.
copy.set_source_expressions(
copy.get_source_expressions()[0:1] + [None, None]
)
return super(MultiArgAgg, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
Book.objects.aggregate(MyMax("pages", "price"))
Book.objects.aggregate(MultiArgAgg("pages", "price"))
with self.assertRaisesMessage(
TypeError, "Complex annotations require an alias"
):
Book.objects.annotate(MyMax("pages", "price"))
Book.objects.annotate(MultiArgAgg("pages", "price"))
Book.objects.aggregate(max_field=MyMax("pages", "price"))
Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
def test_add_implementation(self):
class MySum(Sum):
@@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
"function": self.function.lower(),
"expressions": sql,
"distinct": "",
"filter": "",
"order_by": "",
}
substitutions.update(self.extra)
return self.template % substitutions, params
@@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template
def be_evil(self, compiler, connection):
substitutions = {"function": "MAX", "expressions": "2", "distinct": ""}
substitutions = {
"function": "MAX",
"expressions": "2",
"distinct": "",
"filter": "",
"order_by": "",
}
substitutions.update(self.extra)
return self.template % substitutions, ()
@@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
Publisher.objects.none().aggregate(
sum_awards=Sum("num_awards"),
books_count=Count("book"),
all_names=StringAgg("name", Value(",")),
),
{
"sum_awards": None,
"books_count": 0,
"all_names": None,
},
)
# Expression without empty_result_set_value forces queries to be
@@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(result["value"], 35)
def test_stringagg_default_value(self):
result = Author.objects.filter(age__gt=100).aggregate(
value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
)
self.assertEqual(result["value"], "<empty>")
def test_aggregation_default_group_by(self):
qs = (
Publisher.objects.values("name")
@@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
with self.assertRaisesMessage(TypeError, msg):
super(function, func_instance).__init__(Value(1), Value(2))
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
Book.objects.aggregate(stringagg=StringAgg("name"))
def test_string_agg_escapes_delimiter(self):
values = Publisher.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value("'"))
)
self.assertEqual(
values,
{
"stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
"of Books",
},
)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by(self):
order_by_test_cases = (
(
F("original_opening").desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
(
F("original_opening").asc(),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
(
Concat("original_opening", Value("@")),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(
Concat("original_opening", Value("@")).desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = Store.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
@skipIfDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by_is_not_supported(self):
message = (
"This database backend does not support specifying an order on aggregates."
)
with self.assertRaisesMessage(NotSupportedError, message):
Store.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
order_by="original_opening",
)
)
def test_string_agg_filter(self):
values = Book.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
filter=Q(name__startswith="P"),
)
)
expected_values = {
"stringagg": "Practical Django Projects;"
"Python Web Development with Django;Paradigms of Artificial "
"Intelligence Programming: Case Studies in Common Lisp",
}
self.assertEqual(values, expected_values)
@skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
def test_string_agg_jsonfield_order_by(self):
Employee.objects.bulk_create(
[
Employee(work_day_preferences={"Monday": "morning"}),
Employee(work_day_preferences={"Monday": "afternoon"}),
]
)
values = Employee.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("Monday", "work_day_preferences"),
delimiter=Value(","),
order_by=KeyTextTransform(
"Monday",
"work_day_preferences",
),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "afternoon,morning"})
def test_string_agg_filter_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
filter=~Q(authors__name__startswith="J"),
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
).values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty",
"Brad Dayley",
"Paul Bissex;Wesley J. Chun",
"Peter Norvig;Stuart Russell",
"Peter Norvig",
"" if connection.features.interprets_empty_strings_as_nulls else None,
]
self.assertQuerySetEqual(expected_values, values, ordered=False)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_order_by_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
order_by="authors__name",
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
)
.order_by("agg")
.values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty;Jacob Kaplan-Moss",
"Brad Dayley",
"James Bennett",
"Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
"Peter Norvig",
"Peter Norvig;Stuart Russell",
]
self.assertEqual(expected_values, values)
class AggregateAnnotationPruningTests(TestCase):
@classmethod

View File

@@ -1720,14 +1720,14 @@ class WindowFunctionTests(TestCase):
"""Window expressions can't be used in an INSERT statement."""
msg = (
"Window expressions are not allowed in this query (salary=<Window: "
"Sum(Value(10000), order_by=OrderBy(F(pk), descending=False)) OVER ()"
"Sum(Value(10000)) OVER ()"
)
with self.assertRaisesMessage(FieldError, msg):
Employee.objects.create(
name="Jameson",
department="Management",
hire_date=datetime.date(2007, 7, 1),
salary=Window(expression=Sum(Value(10000), order_by=F("pk").asc())),
salary=Window(expression=Sum(Value(10000))),
)
def test_window_expression_within_subquery(self):
@@ -2025,7 +2025,7 @@ class NonQueryWindowTests(SimpleTestCase):
def test_invalid_order_by(self):
msg = (
"Window.order_by must be either a string reference to a field, an "
"expression, or a list or tuple of them."
"expression, or a list or tuple of them not {'-horse'}."
)
with self.assertRaisesMessage(ValueError, msg):
Window(expression=Sum("power"), order_by={"-horse"})

View File

@@ -1,3 +1,5 @@
import warnings
from django.db import transaction
from django.db.models import (
CharField,
@@ -11,16 +13,19 @@ from django.db.models import (
Value,
Window,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.fields.json import KeyTransform
from django.db.models.functions import Cast, Concat, LPad, Substr
from django.test.utils import Approximate
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango61Warning
from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
from . import PostgreSQLTestCase
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
try:
from django.contrib.postgres.aggregates import (
StringAgg, # RemovedInDjango70Warning.
)
from django.contrib.postgres.aggregates import (
ArrayAgg,
BitAnd,
@@ -41,7 +46,6 @@ try:
RegrSXY,
RegrSYY,
StatAggregate,
StringAgg,
)
from django.contrib.postgres.fields import ArrayField
except ImportError:
@@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
BoolAnd("boolean_field"),
BoolOr("boolean_field"),
JSONBAgg("integer_field"),
StringAgg("char_field", delimiter=";"),
BitXor("integer_field"),
]
for aggregation in tests:
@@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
["<empty>"],
),
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
(
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
"<empty>",
),
(BitXor("integer_field", default=0), 0),
]
for aggregation, expected_result in tests:
@@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
self.assertEqual(ctx.filename, __file__)
# RemovedInDjango61Warning: Remove this test
def test_ordering_and_order_by_causes_error(self):
with self.assertWarns(RemovedInDjango61Warning):
with warnings.catch_warnings(record=True, action="always") as wm:
with self.assertRaisesMessage(
TypeError,
"Cannot specify both order_by and ordering.",
@@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
)
first_warning = wm[0]
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
self.assertEqual(
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead.",
str(first_warning.message),
)
second_warning = wm[1]
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
self.assertEqual(
"The ordering argument is deprecated. Use order_by instead.",
str(second_warning.message),
)
def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
@@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
self.assertEqual(values, {"boolor": False})
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
def test_string_agg_delimiter_escaping(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
def test_string_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";")
)
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
def test_string_agg_default_output_field(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("text_field", delimiter=";"),
)
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
def test_string_agg_charfield_order_by(self):
order_by_test_cases = (
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
("char_field", "Foo1;Foo2;Foo3;Foo4"),
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
def test_string_agg_jsonfield_order_by(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("lang", "json_field"),
delimiter=";",
order_by=KeyTextTransform("lang", "json_field"),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "en;pl"})
def test_string_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
"char_field",
delimiter=";",
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
)
)
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
def test_orderable_agg_alternative_fields(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
@@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
],
)
def test_string_agg_array_agg_order_by_in_subquery(self):
def test_array_agg_order_by_in_subquery(self):
stats = []
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
StatTestModel.objects.bulk_create(stats)
for aggregate, expected_result in (
(
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"),
[
("Foo1", [0, 1]),
("Foo2", [1, 2]),
("Foo3", [2, 3]),
("Foo4", [3, 4]),
],
),
(
StringAgg(
Cast("stattestmodel__int1", CharField()),
delimiter=";",
order_by="-stattestmodel__int2",
),
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
expected_result = [
("Foo1", [0, 1]),
("Foo2", [1, 2]),
("Foo3", [2, 3]),
("Foo4", [3, 4]),
]
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_array_agg_filter_in_subquery(self):
StatTestModel.objects.bulk_create(
@@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
]
)
for aggregate, expected_result in (
(
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
[("Foo1", [0, 1]), ("Foo2", None)],
),
(
StringAgg(
Cast("stattestmodel__int2", CharField()),
delimiter=";",
filter=Q(stattestmodel__int1__lt=2),
),
[("Foo1", "5;4"), ("Foo2", None)],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_filter_in_subquery_with_exclude(self):
aggregate = ArrayAgg(
"stattestmodel__int1",
filter=Q(stattestmodel__int2__gt=3),
)
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
subquery = (
AggregateTestModel.objects.annotate(
stringagg=StringAgg(
"char_field",
delimiter=";",
filter=Q(char_field__endswith="1"),
)
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.exclude(stringagg="")
.values("id")
.annotate(agg=aggregate)
.values("agg")
)
self.assertSequenceEqual(
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
[self.aggs[0]],
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_ordering_isnt_cleared_for_array_subquery(self):
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
@@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
for aggregation in tests:
with self.subTest(aggregation=aggregation):
results = AggregateTestModel.objects.annotate(
agg=aggregation
).values_list("agg")
self.assertCountEqual(
AggregateTestModel.objects.values_list(aggregation),
results,
[([0],), ([1],), ([2],), ([0],)],
)
def test_string_agg_delimiter_deprecation(self):
msg = (
"delimiter: str will be resolved as a field reference instead "
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
"preserve the previous behaviour."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
def test_string_agg_deprecation(self):
msg = (
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=Value("'"))
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
class TestAggregateDistinct(PostgreSQLTestCase):
@classmethod
@@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
AggregateTestModel.objects.create(char_field="Foo")
AggregateTestModel.objects.create(char_field="Bar")
def test_string_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
)
self.assertEqual(values["stringagg"].count("Foo"), 2)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_string_agg_distinct_true(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
)
self.assertEqual(values["stringagg"].count("Foo"), 1)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_array_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("char_field", distinct=False)