1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #35444 -- Added generic support for Aggregate.order_by.

This moves the behaviors of `order_by` used in Postgres aggregates into
the `Aggregate` class. This allows for creating aggregate functions that
support this behavior across all database engines. This is shown by
moving the `StringAgg` class into the shared `aggregates` module and
adding support for all databases. The Postgres `StringAgg` class is now
a thin wrapper on the new shared `StringAgg` class.

Thank you Simon Charette for the review.
This commit is contained in:
Chris Muthig
2024-12-22 16:30:55 +01:00
committed by Sarah Boyce
parent 6d1cf5375f
commit 4b977a5d72
19 changed files with 659 additions and 291 deletions

View File

@@ -4,10 +4,11 @@ import re
from decimal import Decimal
from django.core.exceptions import FieldError
from django.db import connection
from django.db import NotSupportedError, connection
from django.db.models import (
Avg,
Case,
CharField,
Count,
DateField,
DateTimeField,
@@ -22,6 +23,7 @@ from django.db.models import (
OuterRef,
Q,
StdDev,
StringAgg,
Subquery,
Sum,
TimeField,
@@ -32,9 +34,11 @@ from django.db.models import (
Window,
)
from django.db.models.expressions import Func, RawSQL
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import (
Cast,
Coalesce,
Concat,
Greatest,
Least,
Lower,
@@ -45,11 +49,11 @@ from django.db.models.functions import (
TruncHour,
)
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from django.test.testcases import skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext
from django.utils import timezone
from .models import Author, Book, Publisher, Store
from .models import Author, Book, Employee, Publisher, Store
class NowUTC(Now):
@@ -566,6 +570,28 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(books["ratings"], expected_result)
@skipUnlessDBFeature("supports_aggregate_distinct_multiple_argument")
def test_distinct_on_stringagg(self):
books = Book.objects.aggregate(
ratings=StringAgg(Cast(F("rating"), CharField()), Value(","), distinct=True)
)
self.assertEqual(books["ratings"], "3,4,4.5,5")
@skipIfDBFeature("supports_aggregate_distinct_multiple_argument")
def test_raises_error_on_multiple_argument_distinct(self):
message = (
"StringAgg does not support distinct with multiple expressions on this "
"database backend."
)
with self.assertRaisesMessage(NotSupportedError, message):
Book.objects.aggregate(
ratings=StringAgg(
Cast(F("rating"), CharField()),
Value(","),
distinct=True,
)
)
def test_non_grouped_annotation_not_in_group_by(self):
"""
An annotation not included in values() before an aggregate should be
@@ -1288,24 +1314,30 @@ class AggregateTestCase(TestCase):
Book.objects.annotate(Max("id")).annotate(my_max=MyMax("id__max", "price"))
def test_multi_arg_aggregate(self):
class MyMax(Max):
class MultiArgAgg(Max):
output_field = DecimalField()
arity = None
def as_sql(self, compiler, connection):
def as_sql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
return super(MyMax, copy).as_sql(compiler, connection)
# Most database backends do not support compiling multiple arguments on
# the Max aggregate, and that isn't what is being tested here anyway. To
# avoid errors, the extra argument is just dropped.
copy.set_source_expressions(
copy.get_source_expressions()[0:1] + [None, None]
)
return super(MultiArgAgg, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
Book.objects.aggregate(MyMax("pages", "price"))
Book.objects.aggregate(MultiArgAgg("pages", "price"))
with self.assertRaisesMessage(
TypeError, "Complex annotations require an alias"
):
Book.objects.annotate(MyMax("pages", "price"))
Book.objects.annotate(MultiArgAgg("pages", "price"))
Book.objects.aggregate(max_field=MyMax("pages", "price"))
Book.objects.aggregate(max_field=MultiArgAgg("pages", "price"))
def test_add_implementation(self):
class MySum(Sum):
@@ -1318,6 +1350,8 @@ class AggregateTestCase(TestCase):
"function": self.function.lower(),
"expressions": sql,
"distinct": "",
"filter": "",
"order_by": "",
}
substitutions.update(self.extra)
return self.template % substitutions, params
@@ -1351,7 +1385,13 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template
def be_evil(self, compiler, connection):
substitutions = {"function": "MAX", "expressions": "2", "distinct": ""}
substitutions = {
"function": "MAX",
"expressions": "2",
"distinct": "",
"filter": "",
"order_by": "",
}
substitutions.update(self.extra)
return self.template % substitutions, ()
@@ -1779,10 +1819,12 @@ class AggregateTestCase(TestCase):
Publisher.objects.none().aggregate(
sum_awards=Sum("num_awards"),
books_count=Count("book"),
all_names=StringAgg("name", Value(",")),
),
{
"sum_awards": None,
"books_count": 0,
"all_names": None,
},
)
# Expression without empty_result_set_value forces queries to be
@@ -1874,6 +1916,12 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(result["value"], 35)
def test_stringagg_default_value(self):
result = Author.objects.filter(age__gt=100).aggregate(
value=StringAgg("name", delimiter=Value(";"), default=Value("<empty>")),
)
self.assertEqual(result["value"], "<empty>")
def test_aggregation_default_group_by(self):
qs = (
Publisher.objects.values("name")
@@ -2202,6 +2250,167 @@ class AggregateTestCase(TestCase):
with self.assertRaisesMessage(TypeError, msg):
super(function, func_instance).__init__(Value(1), Value(2))
def test_string_agg_requires_delimiter(self):
with self.assertRaises(TypeError):
Book.objects.aggregate(stringagg=StringAgg("name"))
def test_string_agg_escapes_delimiter(self):
values = Publisher.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value("'"))
)
self.assertEqual(
values,
{
"stringagg": "Apress'Sams'Prentice Hall'Morgan Kaufmann'Jonno's House "
"of Books",
},
)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by(self):
order_by_test_cases = (
(
F("original_opening").desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
(
F("original_opening").asc(),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(F("original_opening"), "Mamma and Pappa's Books;Amazon.com;Books.com"),
("original_opening", "Mamma and Pappa's Books;Amazon.com;Books.com"),
("-original_opening", "Books.com;Amazon.com;Mamma and Pappa's Books"),
(
Concat("original_opening", Value("@")),
"Mamma and Pappa's Books;Amazon.com;Books.com",
),
(
Concat("original_opening", Value("@")).desc(),
"Books.com;Amazon.com;Mamma and Pappa's Books",
),
)
for order_by, expected_output in order_by_test_cases:
with self.subTest(order_by=order_by, expected_output=expected_output):
values = Store.objects.aggregate(
stringagg=StringAgg("name", delimiter=Value(";"), order_by=order_by)
)
self.assertEqual(values, {"stringagg": expected_output})
@skipIfDBFeature("supports_aggregate_order_by_clause")
def test_string_agg_order_by_is_not_supported(self):
message = (
"This database backend does not support specifying an order on aggregates."
)
with self.assertRaisesMessage(NotSupportedError, message):
Store.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
order_by="original_opening",
)
)
def test_string_agg_filter(self):
values = Book.objects.aggregate(
stringagg=StringAgg(
"name",
delimiter=Value(";"),
filter=Q(name__startswith="P"),
)
)
expected_values = {
"stringagg": "Practical Django Projects;"
"Python Web Development with Django;Paradigms of Artificial "
"Intelligence Programming: Case Studies in Common Lisp",
}
self.assertEqual(values, expected_values)
@skipUnlessDBFeature("supports_json_field", "supports_aggregate_order_by_clause")
def test_string_agg_jsonfield_order_by(self):
Employee.objects.bulk_create(
[
Employee(work_day_preferences={"Monday": "morning"}),
Employee(work_day_preferences={"Monday": "afternoon"}),
]
)
values = Employee.objects.aggregate(
stringagg=StringAgg(
KeyTextTransform("Monday", "work_day_preferences"),
delimiter=Value(","),
order_by=KeyTextTransform(
"Monday",
"work_day_preferences",
),
output_field=CharField(),
),
)
self.assertEqual(values, {"stringagg": "afternoon,morning"})
def test_string_agg_filter_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
filter=~Q(authors__name__startswith="J"),
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
).values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty",
"Brad Dayley",
"Paul Bissex;Wesley J. Chun",
"Peter Norvig;Stuart Russell",
"Peter Norvig",
"" if connection.features.interprets_empty_strings_as_nulls else None,
]
self.assertQuerySetEqual(expected_values, values, ordered=False)
@skipUnlessDBFeature("supports_aggregate_order_by_clause")
def test_order_by_in_subquery(self):
aggregate = StringAgg(
"authors__name",
delimiter=Value(";"),
order_by="authors__name",
)
subquery = (
Book.objects.filter(
pk=OuterRef("pk"),
)
.annotate(agg=aggregate)
.values("agg")
)
values = list(
Book.objects.annotate(
agg=Subquery(subquery),
)
.order_by("agg")
.values_list("agg", flat=True)
)
expected_values = [
"Adrian Holovaty;Jacob Kaplan-Moss",
"Brad Dayley",
"James Bennett",
"Jeffrey Forcier;Paul Bissex;Wesley J. Chun",
"Peter Norvig",
"Peter Norvig;Stuart Russell",
]
self.assertEqual(expected_values, values)
class AggregateAnnotationPruningTests(TestCase):
@classmethod