mirror of
https://github.com/django/django.git
synced 2025-01-18 14:24:39 +00:00
Fixed #36051 -- Declared arity on aggregate functions.
Follow-up to 4a66a69239c493c05b322815b18c605cd4c96e7c.
This commit is contained in:
parent
f07360e808
commit
d206d4c200
@ -158,6 +158,7 @@ class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = "AVG"
|
||||
name = "Avg"
|
||||
allow_distinct = True
|
||||
arity = 1
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
@ -166,6 +167,7 @@ class Count(Aggregate):
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
arity = 1
|
||||
allows_composite_expressions = True
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
@ -195,15 +197,18 @@ class Count(Aggregate):
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
name = "Max"
|
||||
arity = 1
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = "MIN"
|
||||
name = "Min"
|
||||
arity = 1
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = "StdDev"
|
||||
arity = 1
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
|
||||
@ -217,10 +222,12 @@ class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
allow_distinct = True
|
||||
arity = 1
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = "Variance"
|
||||
arity = 1
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "VAR_SAMP" if sample else "VAR_POP"
|
||||
|
@ -516,6 +516,7 @@ generated. Here's a brief example::
|
||||
function = "SUM"
|
||||
template = "%(function)s(%(all_values)s%(expressions)s)"
|
||||
allow_distinct = False
|
||||
arity = 1
|
||||
|
||||
def __init__(self, expression, all_values=False, **extra):
|
||||
super().__init__(expression, all_values="ALL " if all_values else "", **extra)
|
||||
|
@ -511,6 +511,10 @@ Miscellaneous
|
||||
* The minimum supported version of ``oracledb`` is increased from 1.3.2 to
|
||||
2.3.0.
|
||||
|
||||
* Built-in aggregate functions accepting only one argument (``Avg``, ``Count``,
|
||||
``Max``, ``Min``, ``StdDev``, ``Sum``, and ``Variance``) now raise
|
||||
:exc:`TypeError` when called with an incorrect number of arguments.
|
||||
|
||||
.. _deprecated-features-5.2:
|
||||
|
||||
Features deprecated in 5.2
|
||||
|
@ -1276,6 +1276,8 @@ class AggregateTestCase(TestCase):
|
||||
Book.objects.annotate(Max("id")).annotate(Sum("id__max"))
|
||||
|
||||
class MyMax(Max):
|
||||
arity = None
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
self.set_source_expressions(self.get_source_expressions()[0:1])
|
||||
return super().as_sql(compiler, connection)
|
||||
@ -1288,6 +1290,7 @@ class AggregateTestCase(TestCase):
|
||||
def test_multi_arg_aggregate(self):
|
||||
class MyMax(Max):
|
||||
output_field = DecimalField()
|
||||
arity = None
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
copy = self.copy()
|
||||
@ -2178,6 +2181,27 @@ class AggregateTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(list(author_qs), [337])
|
||||
|
||||
def test_aggregate_arity(self):
|
||||
funcs_with_inherited_constructors = [Avg, Max, Min, Sum]
|
||||
msg = "takes exactly 1 argument (2 given)"
|
||||
for function in funcs_with_inherited_constructors:
|
||||
with (
|
||||
self.subTest(function=function),
|
||||
self.assertRaisesMessage(TypeError, msg),
|
||||
):
|
||||
function(Value(1), Value(2))
|
||||
|
||||
funcs_with_custom_constructors = [Count, StdDev, Variance]
|
||||
for function in funcs_with_custom_constructors:
|
||||
with self.subTest(function=function):
|
||||
# Extra arguments are rejected via the constructor.
|
||||
with self.assertRaises(TypeError):
|
||||
function(Value(1), True, Value(2))
|
||||
# If the constructor is skipped, the arity check runs.
|
||||
func_instance = function(Value(1), True)
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
super(function, func_instance).__init__(Value(1), Value(2))
|
||||
|
||||
|
||||
class AggregateAnnotationPruningTests(TestCase):
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user