mirror of
https://github.com/django/django.git
synced 2025-10-23 21:59:11 +00:00
Fixed #35444 -- Added generic support for Aggregate.order_by.
This moves the behaviors of `order_by` used in Postgres aggregates into the `Aggregate` class. This allows for creating aggregate functions that support this behavior across all database engines. This is shown by moving the `StringAgg` class into the shared `aggregates` module and adding support for all databases. The Postgres `StringAgg` class is now a thin wrapper on the new shared `StringAgg` class. Thank you Simon Charette for the review.
This commit is contained in:
committed by
Sarah Boyce
parent
6d1cf5375f
commit
4b977a5d72
@@ -1,3 +1,5 @@
|
||||
import warnings
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.models import (
|
||||
CharField,
|
||||
@@ -11,16 +13,19 @@ from django.db.models import (
|
||||
Value,
|
||||
Window,
|
||||
)
|
||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
||||
from django.db.models.fields.json import KeyTransform
|
||||
from django.db.models.functions import Cast, Concat, LPad, Substr
|
||||
from django.test.utils import Approximate
|
||||
from django.utils import timezone
|
||||
from django.utils.deprecation import RemovedInDjango61Warning
|
||||
from django.utils.deprecation import RemovedInDjango61Warning, RemovedInDjango70Warning
|
||||
|
||||
from . import PostgreSQLTestCase
|
||||
from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
|
||||
|
||||
try:
|
||||
from django.contrib.postgres.aggregates import (
|
||||
StringAgg, # RemovedInDjango70Warning.
|
||||
)
|
||||
from django.contrib.postgres.aggregates import (
|
||||
ArrayAgg,
|
||||
BitAnd,
|
||||
@@ -41,7 +46,6 @@ try:
|
||||
RegrSXY,
|
||||
RegrSYY,
|
||||
StatAggregate,
|
||||
StringAgg,
|
||||
)
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
except ImportError:
|
||||
@@ -94,7 +98,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
BoolAnd("boolean_field"),
|
||||
BoolOr("boolean_field"),
|
||||
JSONBAgg("integer_field"),
|
||||
StringAgg("char_field", delimiter=";"),
|
||||
BitXor("integer_field"),
|
||||
]
|
||||
for aggregation in tests:
|
||||
@@ -127,11 +130,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
|
||||
["<empty>"],
|
||||
),
|
||||
(StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
|
||||
(
|
||||
StringAgg("char_field", delimiter=";", default=Value("<empty>")),
|
||||
"<empty>",
|
||||
),
|
||||
(BitXor("integer_field", default=0), 0),
|
||||
]
|
||||
for aggregation, expected_result in tests:
|
||||
@@ -158,8 +156,9 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
# RemovedInDjango61Warning: Remove this test
|
||||
def test_ordering_and_order_by_causes_error(self):
|
||||
with self.assertWarns(RemovedInDjango61Warning):
|
||||
with warnings.catch_warnings(record=True, action="always") as wm:
|
||||
with self.assertRaisesMessage(
|
||||
TypeError,
|
||||
"Cannot specify both order_by and ordering.",
|
||||
@@ -173,6 +172,21 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
first_warning = wm[0]
|
||||
self.assertEqual(first_warning.category, RemovedInDjango70Warning)
|
||||
self.assertEqual(
|
||||
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||
"django.db.models.aggregate.StringAgg instead.",
|
||||
str(first_warning.message),
|
||||
)
|
||||
|
||||
second_warning = wm[1]
|
||||
self.assertEqual(second_warning.category, RemovedInDjango61Warning)
|
||||
self.assertEqual(
|
||||
"The ordering argument is deprecated. Use order_by instead.",
|
||||
str(second_warning.message),
|
||||
)
|
||||
|
||||
def test_array_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
|
||||
self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
|
||||
@@ -425,66 +439,6 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
)
|
||||
self.assertEqual(values, {"boolor": False})
|
||||
|
||||
def test_string_agg_requires_delimiter(self):
|
||||
with self.assertRaises(TypeError):
|
||||
AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
|
||||
|
||||
def test_string_agg_delimiter_escaping(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter="'")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
|
||||
def test_string_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=";")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
|
||||
|
||||
def test_string_agg_default_output_field(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("text_field", delimiter=";"),
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
|
||||
|
||||
def test_string_agg_charfield_order_by(self):
|
||||
order_by_test_cases = (
|
||||
(F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
|
||||
(F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
|
||||
(F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
|
||||
("char_field", "Foo1;Foo2;Foo3;Foo4"),
|
||||
("-char_field", "Foo4;Foo3;Foo2;Foo1"),
|
||||
(Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
|
||||
(Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
|
||||
)
|
||||
for order_by, expected_output in order_by_test_cases:
|
||||
with self.subTest(order_by=order_by, expected_output=expected_output):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=";", order_by=order_by)
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": expected_output})
|
||||
|
||||
def test_string_agg_jsonfield_order_by(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
KeyTextTransform("lang", "json_field"),
|
||||
delimiter=";",
|
||||
order_by=KeyTextTransform("lang", "json_field"),
|
||||
output_field=CharField(),
|
||||
),
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "en;pl"})
|
||||
|
||||
def test_string_agg_filter(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg(
|
||||
"char_field",
|
||||
delimiter=";",
|
||||
filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
|
||||
)
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
|
||||
|
||||
def test_orderable_agg_alternative_fields(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg("integer_field", order_by=F("char_field").asc())
|
||||
@@ -593,48 +547,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_string_agg_array_agg_order_by_in_subquery(self):
|
||||
def test_array_agg_order_by_in_subquery(self):
|
||||
stats = []
|
||||
for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
|
||||
stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
|
||||
stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
|
||||
StatTestModel.objects.bulk_create(stats)
|
||||
|
||||
for aggregate, expected_result in (
|
||||
(
|
||||
ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2"),
|
||||
[
|
||||
("Foo1", [0, 1]),
|
||||
("Foo2", [1, 2]),
|
||||
("Foo3", [2, 3]),
|
||||
("Foo4", [3, 4]),
|
||||
],
|
||||
),
|
||||
(
|
||||
StringAgg(
|
||||
Cast("stattestmodel__int1", CharField()),
|
||||
delimiter=";",
|
||||
order_by="-stattestmodel__int2",
|
||||
),
|
||||
[("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
|
||||
),
|
||||
):
|
||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
aggregate = ArrayAgg("stattestmodel__int1", order_by="-stattestmodel__int2")
|
||||
expected_result = [
|
||||
("Foo1", [0, 1]),
|
||||
("Foo2", [1, 2]),
|
||||
("Foo3", [2, 3]),
|
||||
("Foo4", [3, 4]),
|
||||
]
|
||||
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_string_agg_array_agg_filter_in_subquery(self):
|
||||
StatTestModel.objects.bulk_create(
|
||||
@@ -644,56 +586,31 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
|
||||
]
|
||||
)
|
||||
for aggregate, expected_result in (
|
||||
(
|
||||
ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
|
||||
[("Foo1", [0, 1]), ("Foo2", None)],
|
||||
),
|
||||
(
|
||||
StringAgg(
|
||||
Cast("stattestmodel__int2", CharField()),
|
||||
delimiter=";",
|
||||
filter=Q(stattestmodel__int1__lt=2),
|
||||
),
|
||||
[("Foo1", "5;4"), ("Foo2", None)],
|
||||
),
|
||||
):
|
||||
with self.subTest(aggregate=aggregate.__class__.__name__):
|
||||
subquery = (
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.filter(
|
||||
char_field__in=["Foo1", "Foo2"],
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_string_agg_filter_in_subquery_with_exclude(self):
|
||||
aggregate = ArrayAgg(
|
||||
"stattestmodel__int1",
|
||||
filter=Q(stattestmodel__int2__gt=3),
|
||||
)
|
||||
expected_result = [("Foo1", [0, 1]), ("Foo2", None)]
|
||||
|
||||
subquery = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
stringagg=StringAgg(
|
||||
"char_field",
|
||||
delimiter=";",
|
||||
filter=Q(char_field__endswith="1"),
|
||||
)
|
||||
AggregateTestModel.objects.filter(
|
||||
pk=OuterRef("pk"),
|
||||
)
|
||||
.exclude(stringagg="")
|
||||
.values("id")
|
||||
.annotate(agg=aggregate)
|
||||
.values("agg")
|
||||
)
|
||||
self.assertSequenceEqual(
|
||||
AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
|
||||
[self.aggs[0]],
|
||||
values = (
|
||||
AggregateTestModel.objects.annotate(
|
||||
agg=Subquery(subquery),
|
||||
)
|
||||
.filter(
|
||||
char_field__in=["Foo1", "Foo2"],
|
||||
)
|
||||
.order_by("char_field")
|
||||
.values_list("char_field", "agg")
|
||||
)
|
||||
self.assertEqual(list(values), expected_result)
|
||||
|
||||
def test_ordering_isnt_cleared_for_array_subquery(self):
|
||||
inner_qs = AggregateTestModel.objects.order_by("-integer_field")
|
||||
@@ -729,11 +646,41 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
||||
tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
|
||||
for aggregation in tests:
|
||||
with self.subTest(aggregation=aggregation):
|
||||
results = AggregateTestModel.objects.annotate(
|
||||
agg=aggregation
|
||||
).values_list("agg")
|
||||
self.assertCountEqual(
|
||||
AggregateTestModel.objects.values_list(aggregation),
|
||||
results,
|
||||
[([0],), ([1],), ([2],), ([0],)],
|
||||
)
|
||||
|
||||
def test_string_agg_delimiter_deprecation(self):
|
||||
msg = (
|
||||
"delimiter: str will be resolved as a field reference instead "
|
||||
'of a string literal on Django 7.0. Pass `delimiter=Value("\'")` to '
|
||||
"preserve the previous behaviour."
|
||||
)
|
||||
|
||||
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter="'")
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
def test_string_agg_deprecation(self):
|
||||
msg = (
|
||||
"The PostgreSQL specific StringAgg function is deprecated. Use "
|
||||
"django.db.models.aggregate.StringAgg instead."
|
||||
)
|
||||
|
||||
with self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx:
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=Value("'"))
|
||||
)
|
||||
self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
|
||||
self.assertEqual(ctx.filename, __file__)
|
||||
|
||||
|
||||
class TestAggregateDistinct(PostgreSQLTestCase):
|
||||
@classmethod
|
||||
@@ -742,20 +689,6 @@ class TestAggregateDistinct(PostgreSQLTestCase):
|
||||
AggregateTestModel.objects.create(char_field="Foo")
|
||||
AggregateTestModel.objects.create(char_field="Bar")
|
||||
|
||||
def test_string_agg_distinct_false(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
|
||||
)
|
||||
self.assertEqual(values["stringagg"].count("Foo"), 2)
|
||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
||||
|
||||
def test_string_agg_distinct_true(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
|
||||
)
|
||||
self.assertEqual(values["stringagg"].count("Foo"), 1)
|
||||
self.assertEqual(values["stringagg"].count("Bar"), 1)
|
||||
|
||||
def test_array_agg_distinct_false(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg("char_field", distinct=False)
|
||||
|
||||
Reference in New Issue
Block a user