1
0
mirror of https://github.com/django/django.git synced 2025-10-23 21:59:11 +00:00

Fixed #35444 -- Added generic support for Aggregate.order_by.

This moves the behaviors of `order_by` used in Postgres aggregates into
the `Aggregate` class. This allows for creating aggregate functions that
support this behavior across all database engines. This is shown by
moving the `StringAgg` class into the shared `aggregates` module and
adding support for all databases. The Postgres `StringAgg` class is now
a thin wrapper on the new shared `StringAgg` class.

Thank you Simon Charette for the review.
This commit is contained in:
Chris Muthig
2024-12-22 16:30:55 +01:00
committed by Sarah Boyce
parent 6d1cf5375f
commit 4b977a5d72
19 changed files with 659 additions and 291 deletions

View File

@@ -1,3 +1,5 @@
import warnings
from django.db import transaction
from django.db.models import (
CharField,
@@ -11,16 +13,19 @@ from django.db.models import (
Value,
Window,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.fields.json import KeyTransform
from django.db.models.functions import Cast, Concat, LPad, Substr
from django.test.utils import Approximate
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango61Warning
from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
from . import PostgreSQLTestCase
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
try:
from django.contrib.postgres.aggregates import (
StringAgg, # RemovedInDjango70Warning.
)
from django.contrib.postgres.aggregates import (
ArrayAgg,
BitAnd,
@@ -41,7 +46,6 @@ try:
RegrSXY,
RegrSYY,
StatAggregate,
StringAgg,
)
from django.contrib.postgres.fields import ArrayField
except ImportError:
@@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
BoolAnd("boolean_field"),
BoolOr("boolean_field"),
JSONBAgg("integer_field"),
StringAgg("char_field", delimiter=";"),
BitXor("integer_field"),
]
for aggregation in tests:
@@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
["<empty>"],
),
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
(
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
"<empty>",
),
(BitXor("integer_field", default=0), 0),
]
for aggregation, expected_result in tests:
@@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
self.assertEqual(ctx.filename, __file__)
# RemovedInDjango61Warning: Remove this test
def test_ordering_and_order_by_causes_error(self):
with self.assertWarns(RemovedInDjango61Warning):
with warnings.catch_warnings(record=True, action="always") as wm:
with self.assertRaisesMessage(
TypeError,
"Cannot specify both order_by and ordering.",
@@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
)
first_warning = wm[0]
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
self.assertEqual(
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead.",
str(first_warning.message),
)
second_warning = wm[1]
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
self.assertEqual(
"The ordering argument is deprecated. Use order_by instead.",
str(second_warning.message),
)
def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
@@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
self.assertEqual(values, {"boolor": False})
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
def test_string_agg_delimiter_escaping(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
def test_string_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";")
)
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
def test_string_agg_default_output_field(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("text_field", delimiter=";"),
)
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
def test_string_agg_charfield_order_by(self):
order_by_test_cases = (
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
("char_field", "Foo1;Foo2;Foo3;Foo4"),
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
def test_string_agg_jsonfield_order_by(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("lang", "json_field"),
delimiter=";",
order_by=KeyTextTransform("lang", "json_field"),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "en;pl"})
def test_string_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg(
"char_field",
delimiter=";",
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
)
)
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
def test_orderable_agg_alternative_fields(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
@@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
],
)
def test_string_agg_array_agg_order_by_in_subquery(self):
def test_array_agg_order_by_in_subquery(self):
stats = []
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
StatTestModel.objects.bulk_create(stats)
for aggregate, expected_result in (
(
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"),
[
("Foo1", [0, 1]),
("Foo2", [1, 2]),
("Foo3", [2, 3]),
("Foo4", [3, 4]),
],
),
(
StringAgg(
Cast("stattestmodel__int1", CharField()),
delimiter=";",
order_by="-stattestmodel__int2",
),
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
expected_result = [
("Foo1", [0, 1]),
("Foo2", [1, 2]),
("Foo3", [2, 3]),
("Foo4", [3, 4]),
]
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_array_agg_filter_in_subquery(self):
StatTestModel.objects.bulk_create(
@@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
]
)
for aggregate, expected_result in (
(
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
[("Foo1", [0, 1]), ("Foo2", None)],
),
(
StringAgg(
Cast("stattestmodel__int2", CharField()),
delimiter=";",
filter=Q(stattestmodel__int1__lt=2),
),
[("Foo1", "5;4"), ("Foo2", None)],
),
):
with self.subTest(aggregate=aggregate.__class__.__name__):
subquery = (
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_string_agg_filter_in_subquery_with_exclude(self):
aggregate = ArrayAgg(
"stattestmodel__int1",
filter=Q(stattestmodel__int2__gt=3),
)
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
subquery = (
AggregateTestModel.objects.annotate(
stringagg=StringAgg(
"char_field",
delimiter=";",
filter=Q(char_field__endswith="1"),
)
AggregateTestModel.objects.filter(
pk=OuterRef("pk"),
)
.exclude(stringagg="")
.values("id")
.annotate(agg=aggregate)
.values("agg")
)
self.assertSequenceEqual(
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
[self.aggs[0]],
values = (
AggregateTestModel.objects.annotate(
agg=Subquery(subquery),
)
.filter(
char_field__in=["Foo1", "Foo2"],
)
.order_by("char_field")
.values_list("char_field", "agg")
)
self.assertEqual(list(values), expected_result)
def test_ordering_isnt_cleared_for_array_subquery(self):
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
@@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
for aggregation in tests:
with self.subTest(aggregation=aggregation):
results = AggregateTestModel.objects.annotate(
agg=aggregation
).values_list("agg")
self.assertCountEqual(
AggregateTestModel.objects.values_list(aggregation),
results,
[([0],), ([1],), ([2],), ([0],)],
)
def test_string_agg_delimiter_deprecation(self):
msg = (
"delimiter: str will be resolved as a field reference instead "
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
"preserve the previous behaviour."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter="'")
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
def test_string_agg_deprecation(self):
msg = (
"The PostgreSQL specific StringAgg function is deprecated. Use "
"django.db.models.aggregate.StringAgg instead."
)
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=Value("'"))
)
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
self.assertEqual(ctx.filename, __file__)
class TestAggregateDistinct(PostgreSQLTestCase):
@classmethod
@@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
AggregateTestModel.objects.create(char_field="Foo")
AggregateTestModel.objects.create(char_field="Bar")
def test_string_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
)
self.assertEqual(values["stringagg"].count("Foo"), 2)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_string_agg_distinct_true(self):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
)
self.assertEqual(values["stringagg"].count("Foo"), 1)
self.assertEqual(values["stringagg"].count("Bar"), 1)
def test_array_agg_distinct_false(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("char_field", distinct=False)