diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 911b60a86d..ea16cc440c 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -158,6 +158,7 @@ class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate): function = "AVG" name = "Avg" allow_distinct = True + arity = 1 class Count(Aggregate): @@ -166,6 +167,7 @@ class Count(Aggregate): output_field = IntegerField() allow_distinct = True empty_result_set_value = 0 + arity = 1 allows_composite_expressions = True def __init__(self, expression, filter=None, **extra): @@ -195,15 +197,18 @@ class Count(Aggregate): class Max(Aggregate): function = "MAX" name = "Max" + arity = 1 class Min(Aggregate): function = "MIN" name = "Min" + arity = 1 class StdDev(NumericOutputFieldMixin, Aggregate): name = "StdDev" + arity = 1 def __init__(self, expression, sample=False, **extra): self.function = "STDDEV_SAMP" if sample else "STDDEV_POP" @@ -217,10 +222,12 @@ class Sum(FixDurationInputMixin, Aggregate): function = "SUM" name = "Sum" allow_distinct = True + arity = 1 class Variance(NumericOutputFieldMixin, Aggregate): name = "Variance" + arity = 1 def __init__(self, expression, sample=False, **extra): self.function = "VAR_SAMP" if sample else "VAR_POP" diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 6faec969c3..5d5504dbe7 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -516,6 +516,7 @@ generated. Here's a brief example:: function = "SUM" template = "%(function)s(%(all_values)s%(expressions)s)" allow_distinct = False + arity = 1 def __init__(self, expression, all_values=False, **extra): super().__init__(expression, all_values="ALL " if all_values else "", **extra) diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index 3d3a958b6d..716f217aee 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -511,6 +511,10 @@ Miscellaneous * The minimum supported version of ``oracledb`` is increased from 1.3.2 to 2.3.0. +* Built-in aggregate functions accepting only one argument (``Avg``, ``Count``, + ``Max``, ``Min``, ``StdDev``, ``Sum``, and ``Variance``) now raise + :exc:`TypeError` when called with an incorrect number of arguments. + .. _deprecated-features-5.2: Features deprecated in 5.2 diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index b6ba728e77..861b2c5dfc 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1276,6 +1276,8 @@ class AggregateTestCase(TestCase): Book.objects.annotate(Max("id")).annotate(Sum("id__max")) class MyMax(Max): + arity = None + def as_sql(self, compiler, connection): self.set_source_expressions(self.get_source_expressions()[0:1]) return super().as_sql(compiler, connection) @@ -1288,6 +1290,7 @@ class AggregateTestCase(TestCase): def test_multi_arg_aggregate(self): class MyMax(Max): output_field = DecimalField() + arity = None def as_sql(self, compiler, connection): copy = self.copy() @@ -2178,6 +2181,27 @@ class AggregateTestCase(TestCase): ) self.assertEqual(list(author_qs), [337]) + def test_aggregate_arity(self): + funcs_with_inherited_constructors = [Avg, Max, Min, Sum] + msg = "takes exactly 1 argument (2 given)" + for function in funcs_with_inherited_constructors: + with ( + self.subTest(function=function), + self.assertRaisesMessage(TypeError, msg), + ): + function(Value(1), Value(2)) + + funcs_with_custom_constructors = [Count, StdDev, Variance] + for function in funcs_with_custom_constructors: + with self.subTest(function=function): + # Extra arguments are rejected via the constructor. + with self.assertRaises(TypeError): + function(Value(1), True, Value(2)) + # If the constructor is skipped, the arity check runs. + func_instance = function(Value(1), True) + with self.assertRaisesMessage(TypeError, msg): + super(function, func_instance).__init__(Value(1), Value(2)) + class AggregateAnnotationPruningTests(TestCase): @classmethod