mirror of
https://github.com/django/django.git
synced 2025-10-24 14:16:09 +00:00
Fixed #35444 -- Added generic support for Aggregate.order_by.
This moves the behaviors of `order_by` used in Postgres aggregates into the `Aggregate` class. This allows for creating aggregate functions that support this behavior across all database engines. This is shown by moving the `StringAgg` class into the shared `aggregates` module and adding support for all databases. The Postgres `StringAgg` class is now a thin wrapper on the new shared `StringAgg` class. Thank you Simon Charette for the review.
This commit is contained in:
committed by
Sarah Boyce
parent
6d1cf5375f
commit
4b977a5d72
@@ -33,15 +33,14 @@ class GeoAggregate(Aggregate):
|
|||||||
if not self.is_extent:
|
if not self.is_extent:
|
||||||
tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
|
tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
source_expressions = self.get_source_expressions()
|
*source_exprs, filter_expr, order_by_expr = self.get_source_expressions()
|
||||||
source_expressions.pop() # Don't wrap filters with SDOAGGRTYPE().
|
|
||||||
spatial_type_expr = Func(
|
spatial_type_expr = Func(
|
||||||
*source_expressions,
|
*source_exprs,
|
||||||
Value(tolerance),
|
Value(tolerance),
|
||||||
function="SDOAGGRTYPE",
|
function="SDOAGGRTYPE",
|
||||||
output_field=self.output_field,
|
output_field=self.output_field,
|
||||||
)
|
)
|
||||||
source_expressions = [spatial_type_expr, self.filter]
|
source_expressions = [spatial_type_expr, filter_expr, order_by_expr]
|
||||||
clone.set_source_expressions(source_expressions)
|
clone.set_source_expressions(source_expressions)
|
||||||
return clone.as_sql(compiler, connection, **extra_context)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
return self.as_sql(compiler, connection, **extra_context)
|
return self.as_sql(compiler, connection, **extra_context)
|
||||||
|
@@ -1,7 +1,12 @@
|
|||||||
from django.contrib.postgres.fields import ArrayField
|
import warnings
|
||||||
from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value
|
|
||||||
|
|
||||||
from .mixins import OrderableAggMixin
|
from django.contrib.postgres.fields import ArrayField
|
||||||
|
from django.db.models import Aggregate, BooleanField, JSONField
|
||||||
|
from django.db.models import StringAgg as _StringAgg
|
||||||
|
from django.db.models import Value
|
||||||
|
from django.utils.deprecation import RemovedInDjango70Warning
|
||||||
|
|
||||||
|
from .mixins import _DeprecatedOrdering
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ArrayAgg",
|
"ArrayAgg",
|
||||||
@@ -11,14 +16,16 @@ __all__ = [
|
|||||||
"BoolAnd",
|
"BoolAnd",
|
||||||
"BoolOr",
|
"BoolOr",
|
||||||
"JSONBAgg",
|
"JSONBAgg",
|
||||||
"StringAgg",
|
"StringAgg", # RemovedInDjango70Warning.
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ArrayAgg(OrderableAggMixin, Aggregate):
|
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||||
|
# class ArrayAgg(Aggregate):
|
||||||
|
class ArrayAgg(_DeprecatedOrdering, Aggregate):
|
||||||
function = "ARRAY_AGG"
|
function = "ARRAY_AGG"
|
||||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
|
||||||
allow_distinct = True
|
allow_distinct = True
|
||||||
|
allow_order_by = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
@@ -47,19 +54,37 @@ class BoolOr(Aggregate):
|
|||||||
output_field = BooleanField()
|
output_field = BooleanField()
|
||||||
|
|
||||||
|
|
||||||
class JSONBAgg(OrderableAggMixin, Aggregate):
|
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||||
|
# class JSONBAgg(Aggregate):
|
||||||
|
class JSONBAgg(_DeprecatedOrdering, Aggregate):
|
||||||
function = "JSONB_AGG"
|
function = "JSONB_AGG"
|
||||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
|
||||||
allow_distinct = True
|
allow_distinct = True
|
||||||
|
allow_order_by = True
|
||||||
output_field = JSONField()
|
output_field = JSONField()
|
||||||
|
|
||||||
|
|
||||||
class StringAgg(OrderableAggMixin, Aggregate):
|
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||||
function = "STRING_AGG"
|
# class StringAgg(_StringAgg):
|
||||||
template = "%(function)s(%(distinct)s%(expressions)s %(order_by)s)"
|
# RemovedInDjango70Warning: When the deprecation ends, remove completely.
|
||||||
allow_distinct = True
|
class StringAgg(_DeprecatedOrdering, _StringAgg):
|
||||||
output_field = TextField()
|
|
||||||
|
|
||||||
def __init__(self, expression, delimiter, **extra):
|
def __init__(self, expression, delimiter, **extra):
|
||||||
delimiter_expr = Value(str(delimiter))
|
if isinstance(delimiter, str):
|
||||||
super().__init__(expression, delimiter_expr, **extra)
|
warnings.warn(
|
||||||
|
"delimiter: str will be resolved as a field reference instead "
|
||||||
|
"of a string literal on Django 7.0. Pass "
|
||||||
|
f"`delimiter=Value({delimiter!r})` to preserve the previous behaviour.",
|
||||||
|
category=RemovedInDjango70Warning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
delimiter = Value(delimiter)
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||||
|
"django.db.models.aggregate.StringAgg instead.",
|
||||||
|
category=RemovedInDjango70Warning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(expression, delimiter, **extra)
|
||||||
|
@@ -1,15 +1,11 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from django.core.exceptions import FullResultSet
|
|
||||||
from django.db.models.expressions import OrderByList
|
|
||||||
from django.utils.deprecation import RemovedInDjango61Warning
|
from django.utils.deprecation import RemovedInDjango61Warning
|
||||||
|
|
||||||
|
|
||||||
class OrderableAggMixin:
|
# RemovedInDjango61Warning.
|
||||||
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
class _DeprecatedOrdering:
|
||||||
# def __init__(self, *expressions, order_by=(), **extra):
|
|
||||||
def __init__(self, *expressions, ordering=(), order_by=(), **extra):
|
def __init__(self, *expressions, ordering=(), order_by=(), **extra):
|
||||||
# RemovedInDjango61Warning.
|
|
||||||
if ordering:
|
if ordering:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The ordering argument is deprecated. Use order_by instead.",
|
"The ordering argument is deprecated. Use order_by instead.",
|
||||||
@@ -19,44 +15,14 @@ class OrderableAggMixin:
|
|||||||
if order_by:
|
if order_by:
|
||||||
raise TypeError("Cannot specify both order_by and ordering.")
|
raise TypeError("Cannot specify both order_by and ordering.")
|
||||||
order_by = ordering
|
order_by = ordering
|
||||||
if not order_by:
|
|
||||||
self.order_by = None
|
|
||||||
elif isinstance(order_by, (list, tuple)):
|
|
||||||
self.order_by = OrderByList(*order_by)
|
|
||||||
else:
|
|
||||||
self.order_by = OrderByList(order_by)
|
|
||||||
super().__init__(*expressions, **extra)
|
|
||||||
|
|
||||||
def resolve_expression(self, *args, **kwargs):
|
super().__init__(*expressions, order_by=order_by, **extra)
|
||||||
if self.order_by is not None:
|
|
||||||
self.order_by = self.order_by.resolve_expression(*args, **kwargs)
|
|
||||||
return super().resolve_expression(*args, **kwargs)
|
|
||||||
|
|
||||||
def get_source_expressions(self):
|
|
||||||
return super().get_source_expressions() + [self.order_by]
|
|
||||||
|
|
||||||
def set_source_expressions(self, exprs):
|
# RemovedInDjango61Warning: When the deprecation ends, replace with:
|
||||||
*exprs, self.order_by = exprs
|
# class OrderableAggMixin:
|
||||||
return super().set_source_expressions(exprs)
|
class OrderableAggMixin(_DeprecatedOrdering):
|
||||||
|
allow_order_by = True
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def __init_subclass__(cls, /, *args, **kwargs):
|
||||||
*source_exprs, filtering_expr, order_by_expr = self.get_source_expressions()
|
super().__init_subclass__(*args, **kwargs)
|
||||||
|
|
||||||
order_by_sql = ""
|
|
||||||
order_by_params = []
|
|
||||||
if order_by_expr is not None:
|
|
||||||
order_by_sql, order_by_params = compiler.compile(order_by_expr)
|
|
||||||
|
|
||||||
filter_params = []
|
|
||||||
if filtering_expr is not None:
|
|
||||||
try:
|
|
||||||
_, filter_params = compiler.compile(filtering_expr)
|
|
||||||
except FullResultSet:
|
|
||||||
pass
|
|
||||||
|
|
||||||
source_params = []
|
|
||||||
for source_expr in source_exprs:
|
|
||||||
source_params += compiler.compile(source_expr)[1]
|
|
||||||
|
|
||||||
sql, _ = super().as_sql(compiler, connection, order_by=order_by_sql)
|
|
||||||
return sql, (*source_params, *order_by_params, *filter_params)
|
|
||||||
|
@@ -257,6 +257,15 @@ class BaseDatabaseFeatures:
|
|||||||
# expressions?
|
# expressions?
|
||||||
supports_aggregate_filter_clause = False
|
supports_aggregate_filter_clause = False
|
||||||
|
|
||||||
|
# Does the database support ORDER BY in aggregate expressions?
|
||||||
|
supports_aggregate_order_by_clause = False
|
||||||
|
|
||||||
|
# Does the database backend support DISTINCT when using multiple arguments in an
|
||||||
|
# aggregate expression? For example, Sqlite treats the "delimiter" argument of
|
||||||
|
# STRING_AGG/GROUP_CONCAT as an extra argument and does not allow using a custom
|
||||||
|
# delimiter along with DISTINCT.
|
||||||
|
supports_aggregate_distinct_multiple_argument = True
|
||||||
|
|
||||||
# Does the backend support indexing a TextField?
|
# Does the backend support indexing a TextField?
|
||||||
supports_index_on_text_field = True
|
supports_index_on_text_field = True
|
||||||
|
|
||||||
|
@@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
requires_explicit_null_ordering_when_grouping = True
|
requires_explicit_null_ordering_when_grouping = True
|
||||||
atomic_transactions = False
|
atomic_transactions = False
|
||||||
can_clone_databases = True
|
can_clone_databases = True
|
||||||
|
supports_aggregate_order_by_clause = True
|
||||||
supports_comments = True
|
supports_comments = True
|
||||||
supports_comments_inline = True
|
supports_comments_inline = True
|
||||||
supports_temporal_subtraction = True
|
supports_temporal_subtraction = True
|
||||||
|
@@ -45,6 +45,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
# does by uppercasing all identifiers.
|
# does by uppercasing all identifiers.
|
||||||
ignores_table_name_case = True
|
ignores_table_name_case = True
|
||||||
supports_index_on_text_field = False
|
supports_index_on_text_field = False
|
||||||
|
supports_aggregate_order_by_clause = True
|
||||||
create_test_procedure_without_params_sql = """
|
create_test_procedure_without_params_sql = """
|
||||||
CREATE PROCEDURE "TEST_PROCEDURE" AS
|
CREATE PROCEDURE "TEST_PROCEDURE" AS
|
||||||
V_I INTEGER;
|
V_I INTEGER;
|
||||||
|
@@ -64,6 +64,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
supports_frame_exclusion = True
|
supports_frame_exclusion = True
|
||||||
only_supports_unbounded_with_preceding_and_following = True
|
only_supports_unbounded_with_preceding_and_following = True
|
||||||
supports_aggregate_filter_clause = True
|
supports_aggregate_filter_clause = True
|
||||||
|
supports_aggregate_order_by_clause = True
|
||||||
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
|
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
|
||||||
supports_deferrable_unique_constraints = True
|
supports_deferrable_unique_constraints = True
|
||||||
has_json_operators = True
|
has_json_operators = True
|
||||||
|
@@ -34,6 +34,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
supports_frame_range_fixed_distance = True
|
supports_frame_range_fixed_distance = True
|
||||||
supports_frame_exclusion = True
|
supports_frame_exclusion = True
|
||||||
supports_aggregate_filter_clause = True
|
supports_aggregate_filter_clause = True
|
||||||
|
supports_aggregate_order_by_clause = Database.sqlite_version_info >= (3, 44, 0)
|
||||||
|
supports_aggregate_distinct_multiple_argument = False
|
||||||
order_by_nulls_first = True
|
order_by_nulls_first = True
|
||||||
supports_json_field_contains = False
|
supports_json_field_contains = False
|
||||||
supports_update_conflicts = True
|
supports_update_conflicts = True
|
||||||
|
@@ -3,8 +3,17 @@ Classes to represent the definitions of aggregate functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from django.core.exceptions import FieldError, FullResultSet
|
from django.core.exceptions import FieldError, FullResultSet
|
||||||
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
|
from django.db import NotSupportedError
|
||||||
from django.db.models.fields import IntegerField
|
from django.db.models.expressions import (
|
||||||
|
Case,
|
||||||
|
ColPairs,
|
||||||
|
Func,
|
||||||
|
OrderByList,
|
||||||
|
Star,
|
||||||
|
Value,
|
||||||
|
When,
|
||||||
|
)
|
||||||
|
from django.db.models.fields import IntegerField, TextField
|
||||||
from django.db.models.functions import Coalesce
|
from django.db.models.functions import Coalesce
|
||||||
from django.db.models.functions.mixins import (
|
from django.db.models.functions.mixins import (
|
||||||
FixDurationInputMixin,
|
FixDurationInputMixin,
|
||||||
@@ -18,42 +27,91 @@ __all__ = [
|
|||||||
"Max",
|
"Max",
|
||||||
"Min",
|
"Min",
|
||||||
"StdDev",
|
"StdDev",
|
||||||
|
"StringAgg",
|
||||||
"Sum",
|
"Sum",
|
||||||
"Variance",
|
"Variance",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateFilter(Func):
|
||||||
|
arity = 1
|
||||||
|
template = " FILTER (WHERE %(expressions)s)"
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
if not connection.features.supports_aggregate_filter_clause:
|
||||||
|
raise NotSupportedError(
|
||||||
|
"Aggregate filter clauses are not supported on this database backend."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
except FullResultSet:
|
||||||
|
return "", ()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def condition(self):
|
||||||
|
return self.source_expressions[0]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateOrderBy(OrderByList):
|
||||||
|
template = " ORDER BY %(expressions)s"
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
if not connection.features.supports_aggregate_order_by_clause:
|
||||||
|
raise NotSupportedError(
|
||||||
|
"This database backend does not support specifying an order on "
|
||||||
|
"aggregates."
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Aggregate(Func):
|
class Aggregate(Func):
|
||||||
template = "%(function)s(%(distinct)s%(expressions)s)"
|
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||||
contains_aggregate = True
|
contains_aggregate = True
|
||||||
name = None
|
name = None
|
||||||
filter_template = "%s FILTER (WHERE %%(filter)s)"
|
|
||||||
window_compatible = True
|
window_compatible = True
|
||||||
allow_distinct = False
|
allow_distinct = False
|
||||||
|
allow_order_by = False
|
||||||
empty_result_set_value = None
|
empty_result_set_value = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *expressions, distinct=False, filter=None, default=None, **extra
|
self,
|
||||||
|
*expressions,
|
||||||
|
distinct=False,
|
||||||
|
filter=None,
|
||||||
|
default=None,
|
||||||
|
order_by=None,
|
||||||
|
**extra,
|
||||||
):
|
):
|
||||||
if distinct and not self.allow_distinct:
|
if distinct and not self.allow_distinct:
|
||||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||||
|
if order_by and not self.allow_order_by:
|
||||||
|
raise TypeError("%s does not allow order_by." % self.__class__.__name__)
|
||||||
if default is not None and self.empty_result_set_value is not None:
|
if default is not None and self.empty_result_set_value is not None:
|
||||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||||
|
|
||||||
self.distinct = distinct
|
self.distinct = distinct
|
||||||
self.filter = filter
|
self.filter = filter and AggregateFilter(filter)
|
||||||
self.default = default
|
self.default = default
|
||||||
|
self.order_by = AggregateOrderBy.from_param(
|
||||||
|
f"{self.__class__.__name__}.order_by", order_by
|
||||||
|
)
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def get_source_fields(self):
|
def get_source_fields(self):
|
||||||
# Don't return the filter expression since it's not a source field.
|
# Don't consider filter and order by expression as they have nothing
|
||||||
|
# to do with the output field resolution.
|
||||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
source_expressions = super().get_source_expressions()
|
source_expressions = super().get_source_expressions()
|
||||||
return source_expressions + [self.filter]
|
return source_expressions + [self.filter, self.order_by]
|
||||||
|
|
||||||
def set_source_expressions(self, exprs):
|
def set_source_expressions(self, exprs):
|
||||||
*exprs, self.filter = exprs
|
*exprs, self.filter, self.order_by = exprs
|
||||||
return super().set_source_expressions(exprs)
|
return super().set_source_expressions(exprs)
|
||||||
|
|
||||||
def resolve_expression(
|
def resolve_expression(
|
||||||
@@ -66,6 +124,11 @@ class Aggregate(Func):
|
|||||||
if c.filter
|
if c.filter
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
c.order_by = (
|
||||||
|
c.order_by.resolve_expression(query, allow_joins, reuse, summarize)
|
||||||
|
if c.order_by
|
||||||
|
else None
|
||||||
|
)
|
||||||
if summarize:
|
if summarize:
|
||||||
# Summarized aggregates cannot refer to summarized aggregates.
|
# Summarized aggregates cannot refer to summarized aggregates.
|
||||||
for ref in c.get_refs():
|
for ref in c.get_refs():
|
||||||
@@ -115,35 +178,45 @@ class Aggregate(Func):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
if (
|
||||||
if self.filter:
|
self.distinct
|
||||||
if connection.features.supports_aggregate_filter_clause:
|
and not connection.features.supports_aggregate_distinct_multiple_argument
|
||||||
try:
|
and len(super().get_source_expressions()) > 1
|
||||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
):
|
||||||
except FullResultSet:
|
raise NotSupportedError(
|
||||||
pass
|
f"{self.name} does not support distinct with multiple expressions on "
|
||||||
else:
|
f"this database backend."
|
||||||
template = self.filter_template % extra_context.get(
|
)
|
||||||
"template", self.template
|
|
||||||
)
|
distinct_sql = "DISTINCT " if self.distinct else ""
|
||||||
sql, params = super().as_sql(
|
order_by_sql = ""
|
||||||
compiler,
|
order_by_params = []
|
||||||
connection,
|
filter_sql = ""
|
||||||
template=template,
|
filter_params = []
|
||||||
filter=filter_sql,
|
|
||||||
**extra_context,
|
if (order_by := self.order_by) is not None:
|
||||||
)
|
order_by_sql, order_by_params = compiler.compile(order_by)
|
||||||
return sql, (*params, *filter_params)
|
|
||||||
else:
|
if self.filter is not None:
|
||||||
|
try:
|
||||||
|
filter_sql, filter_params = compiler.compile(self.filter)
|
||||||
|
except NotSupportedError:
|
||||||
|
# Fallback to a CASE statement on backends that don't support
|
||||||
|
# the FILTER clause.
|
||||||
copy = self.copy()
|
copy = self.copy()
|
||||||
copy.filter = None
|
copy.filter = None
|
||||||
source_expressions = copy.get_source_expressions()
|
source_expressions = copy.get_source_expressions()
|
||||||
condition = When(self.filter, then=source_expressions[0])
|
condition = When(self.filter.condition, then=source_expressions[0])
|
||||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||||
return super(Aggregate, copy).as_sql(
|
return copy.as_sql(compiler, connection, **extra_context)
|
||||||
compiler, connection, **extra_context
|
|
||||||
)
|
extra_context.update(
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
distinct=distinct_sql,
|
||||||
|
filter=filter_sql,
|
||||||
|
order_by=order_by_sql,
|
||||||
|
)
|
||||||
|
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||||
|
return sql, (*params, *order_by_params, *filter_params)
|
||||||
|
|
||||||
def _get_repr_options(self):
|
def _get_repr_options(self):
|
||||||
options = super()._get_repr_options()
|
options = super()._get_repr_options()
|
||||||
@@ -151,6 +224,8 @@ class Aggregate(Func):
|
|||||||
options["distinct"] = self.distinct
|
options["distinct"] = self.distinct
|
||||||
if self.filter:
|
if self.filter:
|
||||||
options["filter"] = self.filter
|
options["filter"] = self.filter
|
||||||
|
if self.order_by:
|
||||||
|
options["order_by"] = self.order_by
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
@@ -179,17 +254,17 @@ class Count(Aggregate):
|
|||||||
|
|
||||||
def resolve_expression(self, *args, **kwargs):
|
def resolve_expression(self, *args, **kwargs):
|
||||||
result = super().resolve_expression(*args, **kwargs)
|
result = super().resolve_expression(*args, **kwargs)
|
||||||
expr = result.source_expressions[0]
|
source_expressions = result.get_source_expressions()
|
||||||
|
|
||||||
# In case of composite primary keys, count the first column.
|
# In case of composite primary keys, count the first column.
|
||||||
if isinstance(expr, ColPairs):
|
if isinstance(expr := source_expressions[0], ColPairs):
|
||||||
if self.distinct:
|
if self.distinct:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"COUNT(DISTINCT) doesn't support composite primary keys"
|
"COUNT(DISTINCT) doesn't support composite primary keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
cols = expr.get_cols()
|
source_expressions[0] = expr.get_cols()[0]
|
||||||
return Count(cols[0], filter=result.filter)
|
result.set_source_expressions(source_expressions)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -218,6 +293,88 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
|
|||||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||||
|
|
||||||
|
|
||||||
|
class StringAggDelimiter(Func):
|
||||||
|
arity = 1
|
||||||
|
template = "%(expressions)s"
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
super().__init__(value)
|
||||||
|
|
||||||
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
|
template = " SEPARATOR %(expressions)s"
|
||||||
|
|
||||||
|
return self.as_sql(
|
||||||
|
compiler,
|
||||||
|
connection,
|
||||||
|
template=template,
|
||||||
|
**extra_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StringAgg(Aggregate):
|
||||||
|
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s)%(filter)s"
|
||||||
|
function = "STRING_AGG"
|
||||||
|
name = "StringAgg"
|
||||||
|
allow_distinct = True
|
||||||
|
allow_order_by = True
|
||||||
|
output_field = TextField()
|
||||||
|
|
||||||
|
def __init__(self, expression, delimiter, **extra):
|
||||||
|
self.delimiter = StringAggDelimiter(delimiter)
|
||||||
|
super().__init__(expression, self.delimiter, **extra)
|
||||||
|
|
||||||
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
|
if self.order_by:
|
||||||
|
template = (
|
||||||
|
"%(function)s(%(distinct)s%(expressions)s) WITHIN GROUP (%(order_by)s)"
|
||||||
|
"%(filter)s"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
template = "%(function)s(%(distinct)s%(expressions)s)%(filter)s"
|
||||||
|
|
||||||
|
return self.as_sql(
|
||||||
|
compiler,
|
||||||
|
connection,
|
||||||
|
function="LISTAGG",
|
||||||
|
template=template,
|
||||||
|
**extra_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
|
extra_context["function"] = "GROUP_CONCAT"
|
||||||
|
|
||||||
|
template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(delimiter)s)"
|
||||||
|
extra_context["template"] = template
|
||||||
|
|
||||||
|
c = self.copy()
|
||||||
|
# The creation of the delimiter SQL and the ordering of the parameters must be
|
||||||
|
# handled explicitly, as MySQL puts the delimiter at the end of the aggregate
|
||||||
|
# using the `SEPARATOR` declaration (rather than treating as an expression like
|
||||||
|
# other database backends).
|
||||||
|
delimiter_params = []
|
||||||
|
if c.delimiter:
|
||||||
|
delimiter_sql, delimiter_params = compiler.compile(c.delimiter)
|
||||||
|
# Drop the delimiter from the source expressions.
|
||||||
|
c.source_expressions = c.source_expressions[:-1]
|
||||||
|
extra_context["delimiter"] = delimiter_sql
|
||||||
|
|
||||||
|
sql, params = c.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
return sql, (*params, *delimiter_params)
|
||||||
|
|
||||||
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
|
if connection.get_database_version() < (3, 44):
|
||||||
|
return self.as_sql(
|
||||||
|
compiler,
|
||||||
|
connection,
|
||||||
|
function="GROUP_CONCAT",
|
||||||
|
**extra_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Sum(FixDurationInputMixin, Aggregate):
|
class Sum(FixDurationInputMixin, Aggregate):
|
||||||
function = "SUM"
|
function = "SUM"
|
||||||
name = "Sum"
|
name = "Sum"
|
||||||
|
@@ -1481,6 +1481,21 @@ class OrderByList(ExpressionList):
|
|||||||
)
|
)
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_param(cls, context, param):
|
||||||
|
if param is None:
|
||||||
|
return None
|
||||||
|
if isinstance(param, (list, tuple)):
|
||||||
|
if not param:
|
||||||
|
return None
|
||||||
|
return cls(*param)
|
||||||
|
elif isinstance(param, str) or hasattr(param, "resolve_expression"):
|
||||||
|
return cls(param)
|
||||||
|
raise ValueError(
|
||||||
|
f"{context} must be either a string reference to a "
|
||||||
|
f"field, an expression, or a list or tuple of them not {param!r}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@deconstructible(path="django.db.models.ExpressionWrapper")
|
@deconstructible(path="django.db.models.ExpressionWrapper")
|
||||||
class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||||
@@ -1943,16 +1958,7 @@ class Window(SQLiteNumericMixin, Expression):
|
|||||||
self.partition_by = (self.partition_by,)
|
self.partition_by = (self.partition_by,)
|
||||||
self.partition_by = ExpressionList(*self.partition_by)
|
self.partition_by = ExpressionList(*self.partition_by)
|
||||||
|
|
||||||
if self.order_by is not None:
|
self.order_by = OrderByList.from_param("Window.order_by", self.order_by)
|
||||||
if isinstance(self.order_by, (list, tuple)):
|
|
||||||
self.order_by = OrderByList(*self.order_by)
|
|
||||||
elif isinstance(self.order_by, (BaseExpression, str)):
|
|
||||||
self.order_by = OrderByList(self.order_by)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Window.order_by must be either a string reference to a "
|
|
||||||
"field, an expression, or a list or tuple of them."
|
|
||||||
)
|
|
||||||
super().__init__(output_field=output_field)
|
super().__init__(output_field=output_field)
|
||||||
self.source_expression = self._parse_expressions(expression)[0]
|
self.source_expression = self._parse_expressions(expression)[0]
|
||||||
|
|
||||||
|
@@ -18,6 +18,8 @@ details on these changes.
|
|||||||
* The ``serialize`` keyword argument of
|
* The ``serialize`` keyword argument of
|
||||||
``BaseDatabaseCreation.create_test_db()`` will be removed.
|
``BaseDatabaseCreation.create_test_db()`` will be removed.
|
||||||
|
|
||||||
|
* The ``django.contrib.postgres.aggregates.StringAgg`` class will be removed.
|
||||||
|
|
||||||
.. _deprecation-removed-in-6.1:
|
.. _deprecation-removed-in-6.1:
|
||||||
|
|
||||||
6.1
|
6.1
|
||||||
|
@@ -194,6 +194,8 @@ General-purpose aggregation functions
|
|||||||
|
|
||||||
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=())
|
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, order_by=())
|
||||||
|
|
||||||
|
.. deprecated:: 6.0
|
||||||
|
|
||||||
Returns the input values concatenated into a string, separated by
|
Returns the input values concatenated into a string, separated by
|
||||||
the ``delimiter`` string, or ``default`` if there are no values.
|
the ``delimiter`` string, or ``default`` if there are no values.
|
||||||
|
|
||||||
|
@@ -448,7 +448,7 @@ some complex computations::
|
|||||||
|
|
||||||
The ``Aggregate`` API is as follows:
|
The ``Aggregate`` API is as follows:
|
||||||
|
|
||||||
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra)
|
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, order_by=None, **extra)
|
||||||
|
|
||||||
.. attribute:: template
|
.. attribute:: template
|
||||||
|
|
||||||
@@ -473,6 +473,15 @@ The ``Aggregate`` API is as follows:
|
|||||||
allows passing a ``distinct`` keyword argument. If set to ``False``
|
allows passing a ``distinct`` keyword argument. If set to ``False``
|
||||||
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
|
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
|
||||||
|
|
||||||
|
.. attribute:: allow_order_by
|
||||||
|
|
||||||
|
.. versionadded:: 6.0
|
||||||
|
|
||||||
|
A class attribute determining whether or not this aggregate function
|
||||||
|
allows passing a ``order_by`` keyword argument. If set to ``False``
|
||||||
|
(default), ``TypeError`` is raised if ``order_by`` is passed as a value
|
||||||
|
other than ``None``.
|
||||||
|
|
||||||
.. attribute:: empty_result_set_value
|
.. attribute:: empty_result_set_value
|
||||||
|
|
||||||
Defaults to ``None`` since most aggregate functions result in ``NULL``
|
Defaults to ``None`` since most aggregate functions result in ``NULL``
|
||||||
@@ -491,6 +500,12 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
|
|||||||
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
||||||
and :ref:`filtering-on-annotations` for example usage.
|
and :ref:`filtering-on-annotations` for example usage.
|
||||||
|
|
||||||
|
The ``order_by`` argument behaves similarly to the ``field_names`` input of the
|
||||||
|
:meth:`~.QuerySet.order_by` function, accepting a field name (with an optional
|
||||||
|
``"-"`` prefix which indicates descending order) or an expression (or a tuple
|
||||||
|
or list of strings and/or expressions) that specifies the ordering of the
|
||||||
|
elements in the result.
|
||||||
|
|
||||||
The ``default`` argument takes a value that will be passed along with the
|
The ``default`` argument takes a value that will be passed along with the
|
||||||
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
|
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
|
||||||
specifying a value to be returned other than ``None`` when the queryset (or
|
specifying a value to be returned other than ``None`` when the queryset (or
|
||||||
@@ -499,6 +514,10 @@ grouping) contains no entries.
|
|||||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||||
into the ``template`` attribute.
|
into the ``template`` attribute.
|
||||||
|
|
||||||
|
.. versionchanged:: 6.0
|
||||||
|
|
||||||
|
The ``order_by`` argument was added.
|
||||||
|
|
||||||
Creating your own Aggregate Functions
|
Creating your own Aggregate Functions
|
||||||
-------------------------------------
|
-------------------------------------
|
||||||
|
|
||||||
|
@@ -4046,6 +4046,25 @@ by the aggregate.
|
|||||||
However, if ``sample=True``, the return value will be the sample
|
However, if ``sample=True``, the return value will be the sample
|
||||||
variance.
|
variance.
|
||||||
|
|
||||||
|
``StringAgg``
|
||||||
|
~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. versionadded:: 6.0
|
||||||
|
|
||||||
|
.. class:: StringAgg(expression, delimiter, output_field=None, distinct=False, filter=None, order_by=None, default=None, **extra)
|
||||||
|
|
||||||
|
Returns the input values concatenated into a string, separated by the
|
||||||
|
``delimiter`` string, or ``default`` if there are no values.
|
||||||
|
|
||||||
|
* Default alias: ``<field>__stringagg``
|
||||||
|
* Return type: ``string`` or ``output_field`` if supplied. If the
|
||||||
|
queryset or grouping is empty, ``default`` is returned.
|
||||||
|
|
||||||
|
.. attribute:: delimiter
|
||||||
|
|
||||||
|
A ``Value`` or expression representing the string that should separate
|
||||||
|
each of the values. For example, ``Value(",")``.
|
||||||
|
|
||||||
Query-related tools
|
Query-related tools
|
||||||
===================
|
===================
|
||||||
|
|
||||||
|
@@ -184,6 +184,16 @@ Models
|
|||||||
* :doc:`Constraints </ref/models/constraints>` now implement a ``check()``
|
* :doc:`Constraints </ref/models/constraints>` now implement a ``check()``
|
||||||
method that is already registered with the check framework.
|
method that is already registered with the check framework.
|
||||||
|
|
||||||
|
* The new ``order_by`` argument for :class:`~django.db.models.Aggregate` allows
|
||||||
|
specifying the ordering of the elements in the result.
|
||||||
|
|
||||||
|
* The new :attr:`.Aggregate.allow_order_by` class attribute determines whether
|
||||||
|
the aggregate function allows passing an ``order_by`` keyword argument.
|
||||||
|
|
||||||
|
* The new :class:`~django.db.models.StringAgg` aggregate returns the input
|
||||||
|
values concatenated into a string, separated by the ``delimiter`` string.
|
||||||
|
This aggregate was previously supported only for PostgreSQL.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -288,6 +298,9 @@ Miscellaneous
|
|||||||
* ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use
|
* ``BaseDatabaseCreation.create_test_db(serialize)`` is deprecated. Use
|
||||||
``serialize_db_to_string()`` instead.
|
``serialize_db_to_string()`` instead.
|
||||||
|
|
||||||
|
* The PostgreSQL ``StringAgg`` class is deprecated in favor of the generally
|
||||||
|
available :class:`~django.db.models.StringAgg` class.
|
||||||
|
|
||||||
Features removed in 6.0
|
Features removed in 6.0
|
||||||
=======================
|
=======================
|
||||||
|
|
||||||
|
@@ -43,3 +43,7 @@ class Store(models.Model):
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class Employee(models.Model):
|
||||||
|
work_day_preferences = models.JSONField()
|
||||||
|
@@ -4,10 +4,11 @@ import re
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import connection
|
from django.db import NotSupportedError, connection
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
Avg,
|
Avg,
|
||||||
Case,
|
Case,
|
||||||
|
CharField,
|
||||||
Count,
|
Count,
|
||||||
DateField,
|
DateField,
|
||||||
DateTimeField,
|
DateTimeField,
|
||||||
@@ -22,6 +23,7 @@ from django.db.models import (
|
|||||||
OuterRef,
|
OuterRef,
|
||||||
Q,
|
Q,
|
||||||
StdDev,
|
StdDev,
|
||||||
|
StringAgg,
|
||||||
Subquery,
|
Subquery,
|
||||||
Sum,
|
Sum,
|
||||||
TimeField,
|
TimeField,
|
||||||
@@ -32,9 +34,11 @@ from django.db.models import (
|
|||||||
Window,
|
Window,
|
||||||
)
|
)
|
||||||
from django.db.models.expressions import Func, RawSQL
|
from django.db.models.expressions import Func, RawSQL
|
||||||
|
from django.db.models.fields.json import KeyTextTransform
|
||||||
from django.db.models.functions import (
|
from django.db.models.functions import (
|
||||||
Cast,
|
Cast,
|
||||||
Coalesce,
|
Coalesce,
|
||||||
|
Concat,
|
||||||
Greatest,
|
Greatest,
|
||||||
Least,
|
Least,
|
||||||
Lower,
|
Lower,
|
||||||
@@ -45,11 +49,11 @@ from django.db.models.functions import (
|
|||||||
TruncHour,
|
TruncHour,
|
||||||
)
|
)
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.testcases import skipUnlessDBFeature
|
from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
|
||||||
from django.test.utils import Approximate, CaptureQueriesContext
|
from django.test.utils import Approximate, CaptureQueriesContext
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from .models import Author, Book, Publisher, Store
|
from .models import Author, Book, Employee, Publisher, Store
|
||||||
|
|
||||||
|
|
||||||
class NowUTC(Now):
|
class NowUTC(Now):
|
||||||
@@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(books["ratings"], expected_result)
|
self.assertEqual(books["ratings"], expected_result)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
|
||||||
|
def test_distinct_on_stringagg(self):
|
||||||
|
books = Book.objects.aggregate(
|
||||||
|
ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
|
||||||
|
)
|
||||||
|
self.assertEqual(books["ratings"], "3,4,4.5,5")
|
||||||
|
|
||||||
|
@skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
|
||||||
|
def test_raises_error_on_multiple_argument_distinct(self):
|
||||||
|
message = (
|
||||||
|
"StringAgg does not support distinct with multiple expressions on this "
|
||||||
|
"database backend."
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(NotSupportedError, message):
|
||||||
|
Book.objects.aggregate(
|
||||||
|
ratings=StringAgg(
|
||||||
|
Cast(F("rating"), CharField()),
|
||||||
|
Value(","),
|
||||||
|
distinct=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_non_grouped_annotation_not_in_group_by(self):
|
def test_non_grouped_annotation_not_in_group_by(self):
|
||||||
"""
|
"""
|
||||||
An annotation not included in values() before an aggregate should be
|
An annotation not included in values() before an aggregate should be
|
||||||
@@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
|
|||||||
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
|
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
|
||||||
|
|
||||||
def test_multi_arg_aggregate(self):
|
def test_multi_arg_aggregate(self):
|
||||||
class MyMax(Max):
|
class MultiArgAgg(Max):
|
||||||
output_field = DecimalField()
|
output_field = DecimalField()
|
||||||
arity = None
|
arity = None
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
copy = self.copy()
|
copy = self.copy()
|
||||||
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
|
# Most database backends do not support compiling multiple arguments on
|
||||||
return super(MyMax, copy).as_sql(compiler, connection)
|
# the Max aggregate, and that isn't what is being tested here anyway. To
|
||||||
|
# avoid errors, the extra argument is just dropped.
|
||||||
|
copy.set_source_expressions(
|
||||||
|
copy.get_source_expressions()[0:1] + [None, None]
|
||||||
|
)
|
||||||
|
|
||||||
|
return super(MultiArgAgg, copy).as_sql(compiler, connection)
|
||||||
|
|
||||||
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
|
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
|
||||||
Book.objects.aggregate(MyMax("pages", "price"))
|
Book.objects.aggregate(MultiArgAgg("pages", "price"))
|
||||||
|
|
||||||
with self.assertRaisesMessage(
|
with self.assertRaisesMessage(
|
||||||
TypeError, "Complex annotations require an alias"
|
TypeError, "Complex annotations require an alias"
|
||||||
):
|
):
|
||||||
Book.objects.annotate(MyMax("pages", "price"))
|
Book.objects.annotate(MultiArgAgg("pages", "price"))
|
||||||
|
|
||||||
Book.objects.aggregate(max_field=MyMax("pages", "price"))
|
Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
|
||||||
|
|
||||||
def test_add_implementation(self):
|
def test_add_implementation(self):
|
||||||
class MySum(Sum):
|
class MySum(Sum):
|
||||||
@@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
|
|||||||
"function": self.function.lower(),
|
"function": self.function.lower(),
|
||||||
"expressions": sql,
|
"expressions": sql,
|
||||||
"distinct": "",
|
"distinct": "",
|
||||||
|
"filter": "",
|
||||||
|
"order_by": "",
|
||||||
}
|
}
|
||||||
substitutions.update(self.extra)
|
substitutions.update(self.extra)
|
||||||
return self.template % substitutions, params
|
return self.template % substitutions, params
|
||||||
@@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
|
|||||||
|
|
||||||
# test overriding all parts of the template
|
# test overriding all parts of the template
|
||||||
def be_evil(self, compiler, connection):
|
def be_evil(self, compiler, connection):
|
||||||
substitutions = {"function": "MAX", "expressions": "2", "distinct": ""}
|
substitutions = {
|
||||||
|
"function": "MAX",
|
||||||
|
"expressions": "2",
|
||||||
|
"distinct": "",
|
||||||
|
"filter": "",
|
||||||
|
"order_by": "",
|
||||||
|
}
|
||||||
substitutions.update(self.extra)
|
substitutions.update(self.extra)
|
||||||
return self.template % substitutions, ()
|
return self.template % substitutions, ()
|
||||||
|
|
||||||
@@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
|
|||||||
Publisher.objects.none().aggregate(
|
Publisher.objects.none().aggregate(
|
||||||
sum_awards=Sum("num_awards"),
|
sum_awards=Sum("num_awards"),
|
||||||
books_count=Count("book"),
|
books_count=Count("book"),
|
||||||
|
all_names=StringAgg("name", Value(",")),
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
"sum_awards": None,
|
"sum_awards": None,
|
||||||
"books_count": 0,
|
"books_count": 0,
|
||||||
|
"all_names": None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Expression without empty_result_set_value forces queries to be
|
# Expression without empty_result_set_value forces queries to be
|
||||||
@@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(result["value"], 35)
|
self.assertEqual(result["value"], 35)
|
||||||
|
|
||||||
|
def test_stringagg_default_value(self):
|
||||||
|
result = Author.objects.filter(age__gt=100).aggregate(
|
||||||
|
value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
|
||||||
|
)
|
||||||
|
self.assertEqual(result["value"], "<empty>")
|
||||||
|
|
||||||
def test_aggregation_default_group_by(self):
|
def test_aggregation_default_group_by(self):
|
||||||
qs = (
|
qs = (
|
||||||
Publisher.objects.values("name")
|
Publisher.objects.values("name")
|
||||||
@@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
|
|||||||
with self.assertRaisesMessage(TypeError, msg):
|
with self.assertRaisesMessage(TypeError, msg):
|
||||||
super(function, func_instance).__init__(Value(1), Value(2))
|
super(function, func_instance).__init__(Value(1), Value(2))
|
||||||
|
|
||||||
|
def test_string_agg_requires_delimiter(self):
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
Book.objects.aggregate(stringagg=StringAgg("name"))
|
||||||
|
|
||||||
|
def test_string_agg_escapes_delimiter(self):
|
||||||
|
values = Publisher.objects.aggregate(
|
||||||
|
stringagg=StringAgg("name", delimiter=Value("'"))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
values,
|
||||||
|
{
|
||||||
|
"stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
|
||||||
|
"of Books",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
|
||||||
|
def test_string_agg_order_by(self):
|
||||||
|
order_by_test_cases = (
|
||||||
|
(
|
||||||
|
F("original_opening").desc(),
|
||||||
|
"Books.com;Amazon.com;Mamma and Pappa's Books",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
F("original_opening").asc(),
|
||||||
|
"Mamma and Pappa's Books;Amazon.com;Books.com",
|
||||||
|
),
|
||||||
|
(F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
|
||||||
|
("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
|
||||||
|
("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
|
||||||
|
(
|
||||||
|
Concat("original_opening", Value("@")),
|
||||||
|
"Mamma and Pappa's Books;Amazon.com;Books.com",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
Concat("original_opening", Value("@")).desc(),
|
||||||
|
"Books.com;Amazon.com;Mamma and Pappa's Books",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for order_by, expected_output in order_by_test_cases:
|
||||||
|
with self.subTest(order_by=order_by, expected_output=expected_output):
|
||||||
|
values = Store.objects.aggregate(
|
||||||
|
stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
|
||||||
|
)
|
||||||
|
self.assertEqual(values, {"stringagg": expected_output})
|
||||||
|
|
||||||
|
@skipIfDBFeature("supports_aggregate_order_by_clause")
|
||||||
|
def test_string_agg_order_by_is_not_supported(self):
|
||||||
|
message = (
|
||||||
|
"This database backend does not support specifying an order on aggregates."
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(NotSupportedError, message):
|
||||||
|
Store.objects.aggregate(
|
||||||
|
stringagg=StringAgg(
|
||||||
|
"name",
|
||||||
|
delimiter=Value(";"),
|
||||||
|
order_by="original_opening",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_string_agg_filter(self):
|
||||||
|
values = Book.objects.aggregate(
|
||||||
|
stringagg=StringAgg(
|
||||||
|
"name",
|
||||||
|
delimiter=Value(";"),
|
||||||
|
filter=Q(name__startswith="P"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_values = {
|
||||||
|
"stringagg": "Practical Django Projects;"
|
||||||
|
"Python Web Development with Django;Paradigms of Artificial "
|
||||||
|
"Intelligence Programming: Case Studies in Common Lisp",
|
||||||
|
}
|
||||||
|
self.assertEqual(values, expected_values)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
|
||||||
|
def test_string_agg_jsonfield_order_by(self):
|
||||||
|
Employee.objects.bulk_create(
|
||||||
|
[
|
||||||
|
Employee(work_day_preferences={"Monday": "morning"}),
|
||||||
|
Employee(work_day_preferences={"Monday": "afternoon"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
values = Employee.objects.aggregate(
|
||||||
|
stringagg=StringAgg(
|
||||||
|
KeyTextTransform("Monday", "work_day_preferences"),
|
||||||
|
delimiter=Value(","),
|
||||||
|
order_by=KeyTextTransform(
|
||||||
|
"Monday",
|
||||||
|
"work_day_preferences",
|
||||||
|
),
|
||||||
|
output_field=CharField(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(values, {"stringagg": "afternoon,morning"})
|
||||||
|
|
||||||
|
def test_string_agg_filter_in_subquery(self):
|
||||||
|
aggregate = StringAgg(
|
||||||
|
"authors__name",
|
||||||
|
delimiter=Value(";"),
|
||||||
|
filter=~Q(authors__name__startswith="J"),
|
||||||
|
)
|
||||||
|
subquery = (
|
||||||
|
Book.objects.filter(
|
||||||
|
pk=OuterRef("pk"),
|
||||||
|
)
|
||||||
|
.annotate(agg=aggregate)
|
||||||
|
.values("agg")
|
||||||
|
)
|
||||||
|
values = list(
|
||||||
|
Book.objects.annotate(
|
||||||
|
agg=Subquery(subquery),
|
||||||
|
).values_list("agg", flat=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_values = [
|
||||||
|
"Adrian Holovaty",
|
||||||
|
"Brad Dayley",
|
||||||
|
"Paul Bissex;Wesley J. Chun",
|
||||||
|
"Peter Norvig;Stuart Russell",
|
||||||
|
"Peter Norvig",
|
||||||
|
"" if connection.features.interprets_empty_strings_as_nulls else None,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertQuerySetEqual(expected_values, values, ordered=False)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
|
||||||
|
def test_order_by_in_subquery(self):
|
||||||
|
aggregate = StringAgg(
|
||||||
|
"authors__name",
|
||||||
|
delimiter=Value(";"),
|
||||||
|
order_by="authors__name",
|
||||||
|
)
|
||||||
|
subquery = (
|
||||||
|
Book.objects.filter(
|
||||||
|
pk=OuterRef("pk"),
|
||||||
|
)
|
||||||
|
.annotate(agg=aggregate)
|
||||||
|
.values("agg")
|
||||||
|
)
|
||||||
|
values = list(
|
||||||
|
Book.objects.annotate(
|
||||||
|
agg=Subquery(subquery),
|
||||||
|
)
|
||||||
|
.order_by("agg")
|
||||||
|
.values_list("agg", flat=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_values = [
|
||||||
|
"Adrian Holovaty;Jacob Kaplan-Moss",
|
||||||
|
"Brad Dayley",
|
||||||
|
"James Bennett",
|
||||||
|
"Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
|
||||||
|
"Peter Norvig",
|
||||||
|
"Peter Norvig;Stuart Russell",
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(expected_values, values)
|
||||||
|
|
||||||
|
|
||||||
class AggregateAnnotationPruningTests(TestCase):
|
class AggregateAnnotationPruningTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@@ -1720,14 +1720,14 @@ class WindowFunctionTests(TestCase):
|
|||||||
"""Window expressions can't be used in an INSERT statement."""
|
"""Window expressions can't be used in an INSERT statement."""
|
||||||
msg = (
|
msg = (
|
||||||
"Window expressions are not allowed in this query (salary=<Window: "
|
"Window expressions are not allowed in this query (salary=<Window: "
|
||||||
"Sum(Value(10000), order_by=OrderBy(F(pk), descending=False)) OVER ()"
|
"Sum(Value(10000)) OVER ()"
|
||||||
)
|
)
|
||||||
with self.assertRaisesMessage(FieldError, msg):
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
Employee.objects.create(
|
Employee.objects.create(
|
||||||
name="Jameson",
|
name="Jameson",
|
||||||
department="Management",
|
department="Management",
|
||||||
hire_date=datetime.date(2007, 7, 1),
|
hire_date=datetime.date(2007, 7, 1),
|
||||||
salary=Window(expression=Sum(Value(10000), order_by=F("pk").asc())),
|
salary=Window(expression=Sum(Value(10000))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_window_expression_within_subquery(self):
|
def test_window_expression_within_subquery(self):
|
||||||
@@ -2025,7 +2025,7 @@ class NonQueryWindowTests(SimpleTestCase):
|
|||||||
def test_invalid_order_by(self):
|
def test_invalid_order_by(self):
|
||||||
msg = (
|
msg = (
|
||||||
"Window.order_by must be either a string reference to a field, an "
|
"Window.order_by must be either a string reference to a field, an "
|
||||||
"expression, or a list or tuple of them."
|
"expression, or a list or tuple of them not {'-horse'}."
|
||||||
)
|
)
|
||||||
with self.assertRaisesMessage(ValueError, msg):
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
Window(expression=Sum("power"), order_by={"-horse"})
|
Window(expression=Sum("power"), order_by={"-horse"})
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
CharField,
|
CharField,
|
||||||
@@ -11,16 +13,19 @@ from django.db.models import (
|
|||||||
Value,
|
Value,
|
||||||
Window,
|
Window,
|
||||||
)
|
)
|
||||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
from django.db.models.fields.json import KeyTransform
|
||||||
from django.db.models.functions import Cast, Concat, LPad, Substr
|
from django.db.models.functions import Cast, Concat, LPad, Substr
|
||||||
from django.test.utils import Approximate
|
from django.test.utils import Approximate
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.deprecation import RemovedInDjango61Warning
|
from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
|
||||||
|
|
||||||
from . import PostgreSQLTestCase
|
from . import PostgreSQLTestCase
|
||||||
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
|
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from django.contrib.postgres.aggregates import (
|
||||||
|
StringAgg, # RemovedInDjango70Warning.
|
||||||
|
)
|
||||||
from django.contrib.postgres.aggregates import (
|
from django.contrib.postgres.aggregates import (
|
||||||
ArrayAgg,
|
ArrayAgg,
|
||||||
BitAnd,
|
BitAnd,
|
||||||
@@ -41,7 +46,6 @@ try:
|
|||||||
RegrSXY,
|
RegrSXY,
|
||||||
RegrSYY,
|
RegrSYY,
|
||||||
StatAggregate,
|
StatAggregate,
|
||||||
StringAgg,
|
|
||||||
)
|
)
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
BoolAnd("boolean_field"),
|
BoolAnd("boolean_field"),
|
||||||
BoolOr("boolean_field"),
|
BoolOr("boolean_field"),
|
||||||
JSONBAgg("integer_field"),
|
JSONBAgg("integer_field"),
|
||||||
StringAgg("char_field", delimiter=";"),
|
|
||||||
BitXor("integer_field"),
|
BitXor("integer_field"),
|
||||||
]
|
]
|
||||||
for aggregation in tests:
|
for aggregation in tests:
|
||||||
@@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
|
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
|
||||||
["<empty>"],
|
["<empty>"],
|
||||||
),
|
),
|
||||||
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
|
|
||||||
(
|
|
||||||
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
|
|
||||||
"<empty>",
|
|
||||||
),
|
|
||||||
(BitXor("integer_field", default=0), 0),
|
(BitXor("integer_field", default=0), 0),
|
||||||
]
|
]
|
||||||
for aggregation, expected_result in tests:
|
for aggregation, expected_result in tests:
|
||||||
@@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
|
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
|
||||||
self.assertEqual(ctx.filename, __file__)
|
self.assertEqual(ctx.filename, __file__)
|
||||||
|
|
||||||
|
# RemovedInDjango61Warning: Remove this test
|
||||||
def test_ordering_and_order_by_causes_error(self):
|
def test_ordering_and_order_by_causes_error(self):
|
||||||
with self.assertWarns(RemovedInDjango61Warning):
|
with warnings.catch_warnings(record=True, action="always") as wm:
|
||||||
with self.assertRaisesMessage(
|
with self.assertRaisesMessage(
|
||||||
TypeError,
|
TypeError,
|
||||||
"Cannot specify both order_by and ordering.",
|
"Cannot specify both order_by and ordering.",
|
||||||
@@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
first_warning = wm[0]
|
||||||
|
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
|
||||||
|
self.assertEqual(
|
||||||
|
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||||
|
"django.db.models.aggregate.StringAgg instead.",
|
||||||
|
str(first_warning.message),
|
||||||
|
)
|
||||||
|
|
||||||
|
second_warning = wm[1]
|
||||||
|
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
|
||||||
|
self.assertEqual(
|
||||||
|
"The ordering argument is deprecated. Use order_by instead.",
|
||||||
|
str(second_warning.message),
|
||||||
|
)
|
||||||
|
|
||||||
def test_array_agg_charfield(self):
|
def test_array_agg_charfield(self):
|
||||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
|
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
|
||||||
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
|
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
|
||||||
@@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(values, {"boolor": False})
|
self.assertEqual(values, {"boolor": False})
|
||||||
|
|
||||||
def test_string_agg_requires_delimiter(self):
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
|
|
||||||
|
|
||||||
def test_string_agg_delimiter_escaping(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("char_field", delimiter="'")
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
|
||||||
|
|
||||||
def test_string_agg_charfield(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("char_field", delimiter=";")
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
|
|
||||||
|
|
||||||
def test_string_agg_default_output_field(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("text_field", delimiter=";"),
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
|
|
||||||
|
|
||||||
def test_string_agg_charfield_order_by(self):
|
|
||||||
order_by_test_cases = (
|
|
||||||
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
|
|
||||||
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
|
|
||||||
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
|
|
||||||
("char_field", "Foo1;Foo2;Foo3;Foo4"),
|
|
||||||
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
|
|
||||||
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
|
|
||||||
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
|
|
||||||
)
|
|
||||||
for order_by, expected_output in order_by_test_cases:
|
|
||||||
with self.subTest(order_by=order_by, expected_output=expected_output):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": expected_output})
|
|
||||||
|
|
||||||
def test_string_agg_jsonfield_order_by(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg(
|
|
||||||
KeyTextTransform("lang", "json_field"),
|
|
||||||
delimiter=";",
|
|
||||||
order_by=KeyTextTransform("lang", "json_field"),
|
|
||||||
output_field=CharField(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": "en;pl"})
|
|
||||||
|
|
||||||
def test_string_agg_filter(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg(
|
|
||||||
"char_field",
|
|
||||||
delimiter=";",
|
|
||||||
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
|
|
||||||
|
|
||||||
def test_orderable_agg_alternative_fields(self):
|
def test_orderable_agg_alternative_fields(self):
|
||||||
values = AggregateTestModel.objects.aggregate(
|
values = AggregateTestModel.objects.aggregate(
|
||||||
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
|
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
|
||||||
@@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_string_agg_array_agg_order_by_in_subquery(self):
|
def test_array_agg_order_by_in_subquery(self):
|
||||||
stats = []
|
stats = []
|
||||||
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
|
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
|
||||||
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
|
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
|
||||||
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
|
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
|
||||||
StatTestModel.objects.bulk_create(stats)
|
StatTestModel.objects.bulk_create(stats)
|
||||||
|
|
||||||
for aggregate, expected_result in (
|
aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
|
||||||
(
|
expected_result = [
|
||||||
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"),
|
("Foo1", [0, 1]),
|
||||||
[
|
("Foo2", [1, 2]),
|
||||||
("Foo1", [0, 1]),
|
("Foo3", [2, 3]),
|
||||||
("Foo2", [1, 2]),
|
("Foo4", [3, 4]),
|
||||||
("Foo3", [2, 3]),
|
]
|
||||||
("Foo4", [3, 4]),
|
|
||||||
],
|
subquery = (
|
||||||
),
|
AggregateTestModel.objects.filter(
|
||||||
(
|
pk=OuterRef("pk"),
|
||||||
StringAgg(
|
)
|
||||||
Cast("stattestmodel__int1", CharField()),
|
.annotate(agg=aggregate)
|
||||||
delimiter=";",
|
.values("agg")
|
||||||
order_by="-stattestmodel__int2",
|
)
|
||||||
),
|
values = (
|
||||||
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
|
AggregateTestModel.objects.annotate(
|
||||||
),
|
agg=Subquery(subquery),
|
||||||
):
|
)
|
||||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
.order_by("char_field")
|
||||||
subquery = (
|
.values_list("char_field", "agg")
|
||||||
AggregateTestModel.objects.filter(
|
)
|
||||||
pk=OuterRef("pk"),
|
self.assertEqual(list(values), expected_result)
|
||||||
)
|
|
||||||
.annotate(agg=aggregate)
|
|
||||||
.values("agg")
|
|
||||||
)
|
|
||||||
values = (
|
|
||||||
AggregateTestModel.objects.annotate(
|
|
||||||
agg=Subquery(subquery),
|
|
||||||
)
|
|
||||||
.order_by("char_field")
|
|
||||||
.values_list("char_field", "agg")
|
|
||||||
)
|
|
||||||
self.assertEqual(list(values), expected_result)
|
|
||||||
|
|
||||||
def test_string_agg_array_agg_filter_in_subquery(self):
|
def test_string_agg_array_agg_filter_in_subquery(self):
|
||||||
StatTestModel.objects.bulk_create(
|
StatTestModel.objects.bulk_create(
|
||||||
@@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
|
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
for aggregate, expected_result in (
|
|
||||||
(
|
|
||||||
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
|
|
||||||
[("Foo1", [0, 1]), ("Foo2", None)],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
StringAgg(
|
|
||||||
Cast("stattestmodel__int2", CharField()),
|
|
||||||
delimiter=";",
|
|
||||||
filter=Q(stattestmodel__int1__lt=2),
|
|
||||||
),
|
|
||||||
[("Foo1", "5;4"), ("Foo2", None)],
|
|
||||||
),
|
|
||||||
):
|
|
||||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
|
||||||
subquery = (
|
|
||||||
AggregateTestModel.objects.filter(
|
|
||||||
pk=OuterRef("pk"),
|
|
||||||
)
|
|
||||||
.annotate(agg=aggregate)
|
|
||||||
.values("agg")
|
|
||||||
)
|
|
||||||
values = (
|
|
||||||
AggregateTestModel.objects.annotate(
|
|
||||||
agg=Subquery(subquery),
|
|
||||||
)
|
|
||||||
.filter(
|
|
||||||
char_field__in=["Foo1", "Foo2"],
|
|
||||||
)
|
|
||||||
.order_by("char_field")
|
|
||||||
.values_list("char_field", "agg")
|
|
||||||
)
|
|
||||||
self.assertEqual(list(values), expected_result)
|
|
||||||
|
|
||||||
def test_string_agg_filter_in_subquery_with_exclude(self):
|
aggregate = ArrayAgg(
|
||||||
|
"stattestmodel__int1",
|
||||||
|
filter=Q(stattestmodel__int2__gt=3),
|
||||||
|
)
|
||||||
|
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
|
||||||
|
|
||||||
subquery = (
|
subquery = (
|
||||||
AggregateTestModel.objects.annotate(
|
AggregateTestModel.objects.filter(
|
||||||
stringagg=StringAgg(
|
pk=OuterRef("pk"),
|
||||||
"char_field",
|
|
||||||
delimiter=";",
|
|
||||||
filter=Q(char_field__endswith="1"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.exclude(stringagg="")
|
.annotate(agg=aggregate)
|
||||||
.values("id")
|
.values("agg")
|
||||||
)
|
)
|
||||||
self.assertSequenceEqual(
|
values = (
|
||||||
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
|
AggregateTestModel.objects.annotate(
|
||||||
[self.aggs[0]],
|
agg=Subquery(subquery),
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
char_field__in=["Foo1", "Foo2"],
|
||||||
|
)
|
||||||
|
.order_by("char_field")
|
||||||
|
.values_list("char_field", "agg")
|
||||||
)
|
)
|
||||||
|
self.assertEqual(list(values), expected_result)
|
||||||
|
|
||||||
def test_ordering_isnt_cleared_for_array_subquery(self):
|
def test_ordering_isnt_cleared_for_array_subquery(self):
|
||||||
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
|
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
|
||||||
@@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
|
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
|
||||||
for aggregation in tests:
|
for aggregation in tests:
|
||||||
with self.subTest(aggregation=aggregation):
|
with self.subTest(aggregation=aggregation):
|
||||||
|
results = AggregateTestModel.objects.annotate(
|
||||||
|
agg=aggregation
|
||||||
|
).values_list("agg")
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
AggregateTestModel.objects.values_list(aggregation),
|
results,
|
||||||
[([0],), ([1],), ([2],), ([0],)],
|
[([0],), ([1],), ([2],), ([0],)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_string_agg_delimiter_deprecation(self):
|
||||||
|
msg = (
|
||||||
|
"delimiter: str will be resolved as a field reference instead "
|
||||||
|
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
|
||||||
|
"preserve the previous behaviour."
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||||
|
values = AggregateTestModel.objects.aggregate(
|
||||||
|
stringagg=StringAgg("char_field", delimiter="'")
|
||||||
|
)
|
||||||
|
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||||
|
self.assertEqual(ctx.filename, __file__)
|
||||||
|
|
||||||
|
def test_string_agg_deprecation(self):
|
||||||
|
msg = (
|
||||||
|
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||||
|
"django.db.models.aggregate.StringAgg instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||||
|
values = AggregateTestModel.objects.aggregate(
|
||||||
|
stringagg=StringAgg("char_field", delimiter=Value("'"))
|
||||||
|
)
|
||||||
|
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||||
|
self.assertEqual(ctx.filename, __file__)
|
||||||
|
|
||||||
|
|
||||||
class TestAggregateDistinct(PostgreSQLTestCase):
|
class TestAggregateDistinct(PostgreSQLTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
|
|||||||
AggregateTestModel.objects.create(char_field="Foo")
|
AggregateTestModel.objects.create(char_field="Foo")
|
||||||
AggregateTestModel.objects.create(char_field="Bar")
|
AggregateTestModel.objects.create(char_field="Bar")
|
||||||
|
|
||||||
def test_string_agg_distinct_false(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
|
|
||||||
)
|
|
||||||
self.assertEqual(values["stringagg"].count("Foo"), 2)
|
|
||||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
|
||||||
|
|
||||||
def test_string_agg_distinct_true(self):
|
|
||||||
values = AggregateTestModel.objects.aggregate(
|
|
||||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
|
|
||||||
)
|
|
||||||
self.assertEqual(values["stringagg"].count("Foo"), 1)
|
|
||||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
|
||||||
|
|
||||||
def test_array_agg_distinct_false(self):
|
def test_array_agg_distinct_false(self):
|
||||||
values = AggregateTestModel.objects.aggregate(
|
values = AggregateTestModel.objects.aggregate(
|
||||||
arrayagg=ArrayAgg("char_field", distinct=False)
|
arrayagg=ArrayAgg("char_field", distinct=False)
|
||||||
|
Reference in New Issue
Block a user