mirror of
https://github.com/django/django.git
synced 2025-10-25 22:56:12 +00:00
Fixed #10929 -- Added default argument to aggregates.
Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
committed by
Mariusz Felisiak
parent
59942a66ce
commit
501a8db465
@@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate):
|
||||
return ArrayField(self.source_expressions[0].output_field)
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if not value:
|
||||
if value is None and self.default is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate):
|
||||
output_field = JSONField()
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if not value:
|
||||
if value is None and self.default is None:
|
||||
return '[]'
|
||||
return value
|
||||
|
||||
@@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate):
|
||||
super().__init__(expression, delimiter_expr, **extra)
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if not value:
|
||||
if value is None and self.default is None:
|
||||
return ''
|
||||
return value
|
||||
|
||||
@@ -9,10 +9,10 @@ __all__ = [
|
||||
class StatAggregate(Aggregate):
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self, y, x, output_field=None, filter=None):
|
||||
def __init__(self, y, x, output_field=None, filter=None, default=None):
|
||||
if not x or not y:
|
||||
raise ValueError('Both y and x must be provided.')
|
||||
super().__init__(y, x, output_field=output_field, filter=filter)
|
||||
super().__init__(y, x, output_field=output_field, filter=filter, default=default)
|
||||
|
||||
|
||||
class Corr(StatAggregate):
|
||||
@@ -20,9 +20,9 @@ class Corr(StatAggregate):
|
||||
|
||||
|
||||
class CovarPop(StatAggregate):
|
||||
def __init__(self, y, x, sample=False, filter=None):
|
||||
def __init__(self, y, x, sample=False, filter=None, default=None):
|
||||
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
|
||||
super().__init__(y, x, filter=filter)
|
||||
super().__init__(y, x, filter=filter, default=default)
|
||||
|
||||
|
||||
class RegrAvgX(StatAggregate):
|
||||
|
||||
@@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
|
||||
},
|
||||
})
|
||||
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
|
||||
skips.update({
|
||||
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
|
||||
},
|
||||
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
|
||||
},
|
||||
})
|
||||
if (
|
||||
self.connection.mysql_is_mariadb and
|
||||
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)
|
||||
|
||||
@@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions.
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Case, Func, Star, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin, NumericOutputFieldMixin,
|
||||
)
|
||||
@@ -22,11 +23,14 @@ class Aggregate(Func):
|
||||
allow_distinct = False
|
||||
empty_aggregate_value = None
|
||||
|
||||
def __init__(self, *expressions, distinct=False, filter=None, **extra):
|
||||
def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if default is not None and self.empty_aggregate_value is not None:
|
||||
raise TypeError(f'{self.__class__.__name__} does not allow default.')
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.default = default
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
@@ -56,7 +60,12 @@ class Aggregate(Func):
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
|
||||
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
|
||||
return c
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, 'resolve_expression'):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
return Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
|
||||
Reference in New Issue
Block a user