mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #10929 -- Added default argument to aggregates.
Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							59942a66ce
						
					
				
				
					commit
					501a8db465
				
			| @@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate): | ||||
|         return ArrayField(self.source_expressions[0].output_field) | ||||
|  | ||||
|     def convert_value(self, value, expression, connection): | ||||
|         if not value: | ||||
|         if value is None and self.default is None: | ||||
|             return [] | ||||
|         return value | ||||
|  | ||||
| @@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate): | ||||
|     output_field = JSONField() | ||||
|  | ||||
|     def convert_value(self, value, expression, connection): | ||||
|         if not value: | ||||
|         if value is None and self.default is None: | ||||
|             return '[]' | ||||
|         return value | ||||
|  | ||||
| @@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate): | ||||
|         super().__init__(expression, delimiter_expr, **extra) | ||||
|  | ||||
|     def convert_value(self, value, expression, connection): | ||||
|         if not value: | ||||
|         if value is None and self.default is None: | ||||
|             return '' | ||||
|         return value | ||||
|   | ||||
| @@ -9,10 +9,10 @@ __all__ = [ | ||||
| class StatAggregate(Aggregate): | ||||
|     output_field = FloatField() | ||||
|  | ||||
|     def __init__(self, y, x, output_field=None, filter=None): | ||||
|     def __init__(self, y, x, output_field=None, filter=None, default=None): | ||||
|         if not x or not y: | ||||
|             raise ValueError('Both y and x must be provided.') | ||||
|         super().__init__(y, x, output_field=output_field, filter=filter) | ||||
|         super().__init__(y, x, output_field=output_field, filter=filter, default=default) | ||||
|  | ||||
|  | ||||
| class Corr(StatAggregate): | ||||
| @@ -20,9 +20,9 @@ class Corr(StatAggregate): | ||||
|  | ||||
|  | ||||
| class CovarPop(StatAggregate): | ||||
|     def __init__(self, y, x, sample=False, filter=None): | ||||
|     def __init__(self, y, x, sample=False, filter=None, default=None): | ||||
|         self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' | ||||
|         super().__init__(y, x, filter=filter) | ||||
|         super().__init__(y, x, filter=filter, default=default) | ||||
|  | ||||
|  | ||||
| class RegrAvgX(StatAggregate): | ||||
|   | ||||
| @@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures): | ||||
|                     'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o', | ||||
|                 }, | ||||
|             }) | ||||
|         if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,): | ||||
|             skips.update({ | ||||
|                 'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': { | ||||
|                     'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python', | ||||
|                     'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python', | ||||
|                 }, | ||||
|                 'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': { | ||||
|                     'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database', | ||||
|                     'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database', | ||||
|                 }, | ||||
|             }) | ||||
|         if ( | ||||
|             self.connection.mysql_is_mariadb and | ||||
|             (10, 4, 3) < self.connection.mysql_version < (10, 5, 2) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions. | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db.models.expressions import Case, Func, Star, When | ||||
| from django.db.models.fields import IntegerField | ||||
| from django.db.models.functions.comparison import Coalesce | ||||
| from django.db.models.functions.mixins import ( | ||||
|     FixDurationInputMixin, NumericOutputFieldMixin, | ||||
| ) | ||||
| @@ -22,11 +23,14 @@ class Aggregate(Func): | ||||
|     allow_distinct = False | ||||
|     empty_aggregate_value = None | ||||
|  | ||||
|     def __init__(self, *expressions, distinct=False, filter=None, **extra): | ||||
|     def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra): | ||||
|         if distinct and not self.allow_distinct: | ||||
|             raise TypeError("%s does not allow distinct." % self.__class__.__name__) | ||||
|         if default is not None and self.empty_aggregate_value is not None: | ||||
|             raise TypeError(f'{self.__class__.__name__} does not allow default.') | ||||
|         self.distinct = distinct | ||||
|         self.filter = filter | ||||
|         self.default = default | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def get_source_fields(self): | ||||
| @@ -56,7 +60,12 @@ class Aggregate(Func): | ||||
|                     before_resolved = self.get_source_expressions()[index] | ||||
|                     name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved) | ||||
|                     raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name)) | ||||
|         return c | ||||
|         if (default := c.default) is None: | ||||
|             return c | ||||
|         if hasattr(default, 'resolve_expression'): | ||||
|             default = default.resolve_expression(query, allow_joins, reuse, summarize) | ||||
|         c.default = None  # Reset the default argument before wrapping. | ||||
|         return Coalesce(c, default, output_field=c._output_field_or_none) | ||||
|  | ||||
|     @property | ||||
|     def default_alias(self): | ||||
|   | ||||
| @@ -19,8 +19,8 @@ module. They are described in more detail in the `PostgreSQL docs | ||||
|  | ||||
| .. admonition:: Common aggregate options | ||||
|  | ||||
|     All aggregates have the :ref:`filter <aggregate-filter>` keyword | ||||
|     argument. | ||||
|     All aggregates have the :ref:`filter <aggregate-filter>` keyword argument | ||||
|     and most also have the :ref:`default <aggregate-default>` keyword argument. | ||||
|  | ||||
| General-purpose aggregation functions | ||||
| ===================================== | ||||
| @@ -28,9 +28,10 @@ General-purpose aggregation functions | ||||
| ``ArrayAgg`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra) | ||||
| .. class:: ArrayAgg(expression, distinct=False, filter=None, default=None, ordering=(), **extra) | ||||
|  | ||||
|     Returns a list of values, including nulls, concatenated into an array. | ||||
|     Returns a list of values, including nulls, concatenated into an array, or | ||||
|     ``default`` if there are no values. | ||||
|  | ||||
|     .. attribute:: distinct | ||||
|  | ||||
| @@ -54,26 +55,26 @@ General-purpose aggregation functions | ||||
| ``BitAnd`` | ||||
| ---------- | ||||
|  | ||||
| .. class:: BitAnd(expression, filter=None, **extra) | ||||
| .. class:: BitAnd(expression, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or | ||||
|     ``None`` if all values are null. | ||||
|     ``default`` if all values are null. | ||||
|  | ||||
| ``BitOr`` | ||||
| --------- | ||||
|  | ||||
| .. class:: BitOr(expression, filter=None, **extra) | ||||
| .. class:: BitOr(expression, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or | ||||
|     ``None`` if all values are null. | ||||
|     ``default`` if all values are null. | ||||
|  | ||||
| ``BoolAnd`` | ||||
| ----------- | ||||
|  | ||||
| .. class:: BoolAnd(expression, filter=None, **extra) | ||||
| .. class:: BoolAnd(expression, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns ``True``, if all input values are true, ``None`` if all values are | ||||
|     null or if there are no values, otherwise ``False`` . | ||||
|     Returns ``True``, if all input values are true, ``default`` if all values | ||||
|     are null or if there are no values, otherwise ``False``. | ||||
|  | ||||
|     Usage example:: | ||||
|  | ||||
| @@ -92,9 +93,9 @@ General-purpose aggregation functions | ||||
| ``BoolOr`` | ||||
| ---------- | ||||
|  | ||||
| .. class:: BoolOr(expression, filter=None, **extra) | ||||
| .. class:: BoolOr(expression, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns ``True`` if at least one input value is true, ``None`` if all | ||||
|     Returns ``True`` if at least one input value is true, ``default`` if all | ||||
|     values are null or if there are no values, otherwise ``False``. | ||||
|  | ||||
|     Usage example:: | ||||
| @@ -114,9 +115,10 @@ General-purpose aggregation functions | ||||
| ``JSONBAgg`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: JSONBAgg(expressions, distinct=False, filter=None, ordering=(), **extra) | ||||
| .. class:: JSONBAgg(expressions, distinct=False, filter=None, default=None, ordering=(), **extra) | ||||
|  | ||||
|     Returns the input values as a ``JSON`` array. | ||||
|     Returns the input values as a ``JSON`` array, or ``default`` if there are | ||||
|     no values. | ||||
|  | ||||
|     .. attribute:: distinct | ||||
|  | ||||
| @@ -139,10 +141,10 @@ General-purpose aggregation functions | ||||
| ``StringAgg`` | ||||
| ------------- | ||||
|  | ||||
| .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=()) | ||||
| .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, ordering=()) | ||||
|  | ||||
|     Returns the input values concatenated into a string, separated by | ||||
|     the ``delimiter`` string. | ||||
|     the ``delimiter`` string, or ``default`` if there are no values. | ||||
|  | ||||
|     .. attribute:: delimiter | ||||
|  | ||||
| @@ -174,17 +176,17 @@ field or an expression returning a numeric data. Both are required. | ||||
| ``Corr`` | ||||
| -------- | ||||
|  | ||||
| .. class:: Corr(y, x, filter=None) | ||||
| .. class:: Corr(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the correlation coefficient as a ``float``, or ``None`` if there | ||||
|     Returns the correlation coefficient as a ``float``, or ``default`` if there | ||||
|     aren't any matching rows. | ||||
|  | ||||
| ``CovarPop`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: CovarPop(y, x, sample=False, filter=None) | ||||
| .. class:: CovarPop(y, x, sample=False, filter=None, default=None) | ||||
|  | ||||
|     Returns the population covariance as a ``float``, or ``None`` if there | ||||
|     Returns the population covariance as a ``float``, or ``default`` if there | ||||
|     aren't any matching rows. | ||||
|  | ||||
|     Has one optional argument: | ||||
| @@ -198,18 +200,18 @@ field or an expression returning a numeric data. Both are required. | ||||
| ``RegrAvgX`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: RegrAvgX(y, x, filter=None) | ||||
| .. class:: RegrAvgX(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the average of the independent variable (``sum(x)/N``) as a | ||||
|     ``float``, or ``None`` if there aren't any matching rows. | ||||
|     ``float``, or ``default`` if there aren't any matching rows. | ||||
|  | ||||
| ``RegrAvgY`` | ||||
| ------------ | ||||
|  | ||||
| .. class:: RegrAvgY(y, x, filter=None) | ||||
| .. class:: RegrAvgY(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the average of the dependent variable (``sum(y)/N``) as a | ||||
|     ``float``, or ``None`` if there aren't any matching rows. | ||||
|     ``float``, or ``default`` if there aren't any matching rows. | ||||
|  | ||||
| ``RegrCount`` | ||||
| ------------- | ||||
| @@ -219,56 +221,60 @@ field or an expression returning a numeric data. Both are required. | ||||
|     Returns an ``int`` of the number of input rows in which both expressions | ||||
|     are not null. | ||||
|  | ||||
|     .. note:: | ||||
|  | ||||
|         The ``default`` argument is not supported. | ||||
|  | ||||
| ``RegrIntercept`` | ||||
| ----------------- | ||||
|  | ||||
| .. class:: RegrIntercept(y, x, filter=None) | ||||
| .. class:: RegrIntercept(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the y-intercept of the least-squares-fit linear equation determined | ||||
|     by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any | ||||
|     by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any | ||||
|     matching rows. | ||||
|  | ||||
| ``RegrR2`` | ||||
| ---------- | ||||
|  | ||||
| .. class:: RegrR2(y, x, filter=None) | ||||
| .. class:: RegrR2(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the square of the correlation coefficient as a ``float``, or | ||||
|     ``None`` if there aren't any matching rows. | ||||
|     ``default`` if there aren't any matching rows. | ||||
|  | ||||
| ``RegrSlope`` | ||||
| ------------- | ||||
|  | ||||
| .. class:: RegrSlope(y, x, filter=None) | ||||
| .. class:: RegrSlope(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns the slope of the least-squares-fit linear equation determined | ||||
|     by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any | ||||
|     by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any | ||||
|     matching rows. | ||||
|  | ||||
| ``RegrSXX`` | ||||
| ----------- | ||||
|  | ||||
| .. class:: RegrSXX(y, x, filter=None) | ||||
| .. class:: RegrSXX(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent | ||||
|     variable) as a ``float``, or ``None`` if there aren't any matching rows. | ||||
|     variable) as a ``float``, or ``default`` if there aren't any matching rows. | ||||
|  | ||||
| ``RegrSXY`` | ||||
| ----------- | ||||
|  | ||||
| .. class:: RegrSXY(y, x, filter=None) | ||||
| .. class:: RegrSXY(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent | ||||
|     times dependent variable) as a ``float``, or ``None`` if there aren't any | ||||
|     matching rows. | ||||
|     times dependent variable) as a ``float``, or ``default`` if there aren't | ||||
|     any matching rows. | ||||
|  | ||||
| ``RegrSYY`` | ||||
| ----------- | ||||
|  | ||||
| .. class:: RegrSYY(y, x, filter=None) | ||||
| .. class:: RegrSYY(y, x, filter=None, default=None) | ||||
|  | ||||
|     Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent | ||||
|     variable)  as a ``float``, or ``None`` if there aren't any matching rows. | ||||
|     variable) as a ``float``, or ``default`` if there aren't any matching rows. | ||||
|  | ||||
| Usage examples | ||||
| ============== | ||||
|   | ||||
| @@ -59,7 +59,7 @@ will result in a database error. | ||||
| Usage examples:: | ||||
|  | ||||
|     >>> # Get a screen name from least to most public | ||||
|     >>> from django.db.models import Sum, Value as V | ||||
|     >>> from django.db.models import Sum | ||||
|     >>> from django.db.models.functions import Coalesce | ||||
|     >>> Author.objects.create(name='Margaret Smith', goes_by='Maggie') | ||||
|     >>> author = Author.objects.annotate( | ||||
| @@ -68,13 +68,18 @@ Usage examples:: | ||||
|     Maggie | ||||
|  | ||||
|     >>> # Prevent an aggregate Sum() from returning None | ||||
|     >>> # The aggregate default argument uses Coalesce() under the hood. | ||||
|     >>> aggregated = Author.objects.aggregate( | ||||
|     ...    combined_age=Coalesce(Sum('age'), V(0)), | ||||
|     ...    combined_age_default=Sum('age')) | ||||
|     ...    combined_age=Sum('age'), | ||||
|     ...    combined_age_default=Sum('age', default=0), | ||||
|     ...    combined_age_coalesce=Coalesce(Sum('age'), 0), | ||||
|     ... ) | ||||
|     >>> print(aggregated['combined_age']) | ||||
|     0 | ||||
|     >>> print(aggregated['combined_age_default']) | ||||
|     None | ||||
|     >>> print(aggregated['combined_age_default']) | ||||
|     0 | ||||
|     >>> print(aggregated['combined_age_coalesce']) | ||||
|     0 | ||||
|  | ||||
| .. warning:: | ||||
|  | ||||
|   | ||||
| @@ -393,7 +393,7 @@ some complex computations:: | ||||
|  | ||||
| The ``Aggregate`` API is as follows: | ||||
|  | ||||
| .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra) | ||||
| .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra) | ||||
|  | ||||
|     .. attribute:: template | ||||
|  | ||||
| @@ -452,6 +452,11 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's | ||||
| used to filter the rows that are aggregated. See :ref:`conditional-aggregation` | ||||
| and :ref:`filtering-on-annotations` for example usage. | ||||
|  | ||||
| The ``default`` argument takes a value that will be passed along with the | ||||
| aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for | ||||
| specifying a value to be returned other than ``None`` when the queryset (or | ||||
| grouping) contains no entries. | ||||
|  | ||||
| The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated | ||||
| into the ``template`` attribute. | ||||
|  | ||||
| @@ -459,6 +464,10 @@ into the ``template`` attribute. | ||||
|  | ||||
|     Support for transforms of the field was added. | ||||
|  | ||||
| .. versionchanged:: 4.0 | ||||
|  | ||||
|     The ``default`` argument was added. | ||||
|  | ||||
| Creating your own Aggregate Functions | ||||
| ------------------------------------- | ||||
|  | ||||
|   | ||||
| @@ -3540,8 +3540,10 @@ documentation to learn how to create your aggregates. | ||||
|  | ||||
|     Aggregation functions return ``None`` when used with an empty | ||||
|     ``QuerySet``. For example, the ``Sum`` aggregation function returns ``None`` | ||||
|     instead of ``0`` if the ``QuerySet`` contains no entries. An exception is | ||||
|     ``Count``, which does return ``0`` if the ``QuerySet`` is empty. | ||||
|     instead of ``0`` if the ``QuerySet`` contains no entries. To return another | ||||
|     value instead, pass a value to the ``default`` argument. An exception is | ||||
|     ``Count``, which does return ``0`` if the ``QuerySet`` is empty. ``Count`` | ||||
|     does not support the ``default`` argument. | ||||
|  | ||||
| All aggregates have the following parameters in common: | ||||
|  | ||||
| @@ -3578,6 +3580,16 @@ rows that are aggregated. | ||||
| See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for | ||||
| example usage. | ||||
|  | ||||
| .. _aggregate-default: | ||||
|  | ||||
| ``default`` | ||||
| ~~~~~~~~~~~ | ||||
|  | ||||
| .. versionadded:: 4.0 | ||||
|  | ||||
| An optional argument that allows specifying a value to use as a default value | ||||
| when the queryset (or grouping) contains no entries. | ||||
|  | ||||
| ``**extra`` | ||||
| ~~~~~~~~~~~ | ||||
|  | ||||
| @@ -3587,7 +3599,7 @@ by the aggregate. | ||||
| ``Avg`` | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. class:: Avg(expression, output_field=None, distinct=False, filter=None, **extra) | ||||
| .. class:: Avg(expression, output_field=None, distinct=False, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns the mean value of the given expression, which must be numeric | ||||
|     unless you specify a different ``output_field``. | ||||
| @@ -3623,10 +3635,14 @@ by the aggregate. | ||||
|         This is the SQL equivalent of ``COUNT(DISTINCT <field>)``. The default | ||||
|         value is ``False``. | ||||
|  | ||||
|     .. note:: | ||||
|  | ||||
|         The ``default`` argument is not supported. | ||||
|  | ||||
| ``Max`` | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. class:: Max(expression, output_field=None, filter=None, **extra) | ||||
| .. class:: Max(expression, output_field=None, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns the maximum value of the given expression. | ||||
|  | ||||
| @@ -3636,7 +3652,7 @@ by the aggregate. | ||||
| ``Min`` | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. class:: Min(expression, output_field=None, filter=None, **extra) | ||||
| .. class:: Min(expression, output_field=None, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns the minimum value of the given expression. | ||||
|  | ||||
| @@ -3646,7 +3662,7 @@ by the aggregate. | ||||
| ``StdDev`` | ||||
| ~~~~~~~~~~ | ||||
|  | ||||
| .. class:: StdDev(expression, output_field=None, sample=False, filter=None, **extra) | ||||
| .. class:: StdDev(expression, output_field=None, sample=False, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns the standard deviation of the data in the provided expression. | ||||
|  | ||||
| @@ -3664,7 +3680,7 @@ by the aggregate. | ||||
| ``Sum`` | ||||
| ~~~~~~~ | ||||
|  | ||||
| .. class:: Sum(expression, output_field=None, distinct=False, filter=None, **extra) | ||||
| .. class:: Sum(expression, output_field=None, distinct=False, filter=None, default=None, **extra) | ||||
|  | ||||
|     Computes the sum of all values of the given expression. | ||||
|  | ||||
| @@ -3682,7 +3698,7 @@ by the aggregate. | ||||
| ``Variance`` | ||||
| ~~~~~~~~~~~~ | ||||
|  | ||||
| .. class:: Variance(expression, output_field=None, sample=False, filter=None, **extra) | ||||
| .. class:: Variance(expression, output_field=None, sample=False, filter=None, default=None, **extra) | ||||
|  | ||||
|     Returns the variance of the data in the provided expression. | ||||
|  | ||||
|   | ||||
| @@ -288,6 +288,10 @@ Models | ||||
| * :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet`` | ||||
|   annotations, aggregations, and directly in filters. | ||||
|  | ||||
| * The new :ref:`default <aggregate-default>` argument for built-in aggregates | ||||
|   allows specifying a value to be returned when the queryset (or grouping) | ||||
|   contains no entries, rather than ``None``. | ||||
|  | ||||
| Requests and Responses | ||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -1,15 +1,19 @@ | ||||
| import datetime | ||||
| import math | ||||
| import re | ||||
| from decimal import Decimal | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.db import connection | ||||
| from django.db.models import ( | ||||
|     Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, | ||||
|     IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When, | ||||
|     Avg, Case, Count, DateField, DateTimeField, DecimalField, DurationField, | ||||
|     Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Q, StdDev, | ||||
|     Subquery, Sum, TimeField, Value, Variance, When, | ||||
| ) | ||||
| from django.db.models.expressions import Func, RawSQL | ||||
| from django.db.models.functions import Coalesce, Greatest | ||||
| from django.db.models.functions import ( | ||||
|     Cast, Coalesce, Greatest, Now, Pi, TruncDate, TruncHour, | ||||
| ) | ||||
| from django.test import TestCase | ||||
| from django.test.testcases import skipUnlessDBFeature | ||||
| from django.test.utils import Approximate, CaptureQueriesContext | ||||
| @@ -18,6 +22,20 @@ from django.utils import timezone | ||||
| from .models import Author, Book, Publisher, Store | ||||
|  | ||||
|  | ||||
| class NowUTC(Now): | ||||
|     template = 'CURRENT_TIMESTAMP' | ||||
|     output_field = DateTimeField() | ||||
|  | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         return self.as_sql(compiler, connection, template='UTC_TIMESTAMP', **extra_context) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return self.as_sql(compiler, connection, template="CURRENT_TIMESTAMP AT TIME ZONE 'UTC'", **extra_context) | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         return self.as_sql(compiler, connection, template="STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'", **extra_context) | ||||
|  | ||||
|  | ||||
| class AggregateTestCase(TestCase): | ||||
|  | ||||
|     @classmethod | ||||
| @@ -1402,3 +1420,190 @@ class AggregateTestCase(TestCase): | ||||
|                 )['latest_opening'], | ||||
|                 datetime.datetime, | ||||
|             ) | ||||
|  | ||||
|     def test_aggregation_default_unsupported_by_count(self): | ||||
|         msg = 'Count does not allow default.' | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             Count('age', default=0) | ||||
|  | ||||
|     def test_aggregation_default_unset(self): | ||||
|         for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: | ||||
|             with self.subTest(Aggregate): | ||||
|                 result = Author.objects.filter(age__gt=100).aggregate( | ||||
|                     value=Aggregate('age'), | ||||
|                 ) | ||||
|                 self.assertIsNone(result['value']) | ||||
|  | ||||
|     def test_aggregation_default_zero(self): | ||||
|         for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: | ||||
|             with self.subTest(Aggregate): | ||||
|                 result = Author.objects.filter(age__gt=100).aggregate( | ||||
|                     value=Aggregate('age', default=0), | ||||
|                 ) | ||||
|                 self.assertEqual(result['value'], 0) | ||||
|  | ||||
|     def test_aggregation_default_integer(self): | ||||
|         for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: | ||||
|             with self.subTest(Aggregate): | ||||
|                 result = Author.objects.filter(age__gt=100).aggregate( | ||||
|                     value=Aggregate('age', default=21), | ||||
|                 ) | ||||
|                 self.assertEqual(result['value'], 21) | ||||
|  | ||||
|     def test_aggregation_default_expression(self): | ||||
|         for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: | ||||
|             with self.subTest(Aggregate): | ||||
|                 result = Author.objects.filter(age__gt=100).aggregate( | ||||
|                     value=Aggregate('age', default=Value(5) * Value(7)), | ||||
|                 ) | ||||
|                 self.assertEqual(result['value'], 35) | ||||
|  | ||||
|     def test_aggregation_default_group_by(self): | ||||
|         qs = Publisher.objects.values('name').annotate( | ||||
|             books=Count('book'), | ||||
|             pages=Sum('book__pages', default=0), | ||||
|         ).filter(books=0) | ||||
|         self.assertSequenceEqual( | ||||
|             qs, | ||||
|             [{'name': "Jonno's House of Books", 'books': 0, 'pages': 0}], | ||||
|         ) | ||||
|  | ||||
|     def test_aggregation_default_compound_expression(self): | ||||
|         # Scale rating to a percentage; default to 50% if no books published. | ||||
|         formula = Avg('book__rating', default=2.5) * 20.0 | ||||
|         queryset = Publisher.objects.annotate(rating=formula).order_by('name') | ||||
|         self.assertSequenceEqual(queryset.values('name', 'rating'), [ | ||||
|             {'name': 'Apress', 'rating': 85.0}, | ||||
|             {'name': "Jonno's House of Books", 'rating': 50.0}, | ||||
|             {'name': 'Morgan Kaufmann', 'rating': 100.0}, | ||||
|             {'name': 'Prentice Hall', 'rating': 80.0}, | ||||
|             {'name': 'Sams', 'rating': 60.0}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_time_from_python(self): | ||||
|         expr = Min( | ||||
|             'store__friday_night_closing', | ||||
|             filter=~Q(store__name='Amazon.com'), | ||||
|             default=datetime.time(17), | ||||
|         ) | ||||
|         if connection.vendor == 'mysql': | ||||
|             # Workaround for #30224 for MySQL 8.0+ & MariaDB. | ||||
|             expr.default = Cast(expr.default, TimeField()) | ||||
|         queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') | ||||
|         self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ | ||||
|             {'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|             {'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)}, | ||||
|             {'isbn': '067232959', 'oldest_store_opening': datetime.time(17)}, | ||||
|             {'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|             {'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)}, | ||||
|             {'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_time_from_database(self): | ||||
|         now = timezone.now().astimezone(timezone.utc) | ||||
|         expr = Min( | ||||
|             'store__friday_night_closing', | ||||
|             filter=~Q(store__name='Amazon.com'), | ||||
|             default=TruncHour(NowUTC(), output_field=TimeField()), | ||||
|         ) | ||||
|         queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') | ||||
|         self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ | ||||
|             {'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|             {'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)}, | ||||
|             {'isbn': '067232959', 'oldest_store_opening': datetime.time(now.hour)}, | ||||
|             {'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|             {'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)}, | ||||
|             {'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_date_from_python(self): | ||||
|         expr = Min('book__pubdate', default=datetime.date(1970, 1, 1)) | ||||
|         if connection.vendor == 'mysql': | ||||
|             # Workaround for #30224 for MySQL 5.7+ & MariaDB. | ||||
|             expr.default = Cast(expr.default, DateField()) | ||||
|         queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name') | ||||
|         self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [ | ||||
|             {'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)}, | ||||
|             {'name': "Jonno's House of Books", 'earliest_pubdate': datetime.date(1970, 1, 1)}, | ||||
|             {'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)}, | ||||
|             {'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)}, | ||||
|             {'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_date_from_database(self): | ||||
|         now = timezone.now().astimezone(timezone.utc) | ||||
|         expr = Min('book__pubdate', default=TruncDate(NowUTC())) | ||||
|         queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name') | ||||
|         self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [ | ||||
|             {'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)}, | ||||
|             {'name': "Jonno's House of Books", 'earliest_pubdate': now.date()}, | ||||
|             {'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)}, | ||||
|             {'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)}, | ||||
|             {'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_datetime_from_python(self): | ||||
|         expr = Min( | ||||
|             'store__original_opening', | ||||
|             filter=~Q(store__name='Amazon.com'), | ||||
|             default=datetime.datetime(1970, 1, 1), | ||||
|         ) | ||||
|         if connection.vendor == 'mysql': | ||||
|             # Workaround for #30224 for MySQL 8.0+ & MariaDB. | ||||
|             expr.default = Cast(expr.default, DateTimeField()) | ||||
|         queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') | ||||
|         self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ | ||||
|             {'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|             {'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, | ||||
|             {'isbn': '067232959', 'oldest_store_opening': datetime.datetime(1970, 1, 1)}, | ||||
|             {'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|             {'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, | ||||
|             {'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_datetime_from_database(self): | ||||
|         now = timezone.now().astimezone(timezone.utc) | ||||
|         expr = Min( | ||||
|             'store__original_opening', | ||||
|             filter=~Q(store__name='Amazon.com'), | ||||
|             default=TruncHour(NowUTC(), output_field=DateTimeField()), | ||||
|         ) | ||||
|         queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') | ||||
|         self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ | ||||
|             {'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|             {'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, | ||||
|             {'isbn': '067232959', 'oldest_store_opening': now.replace(minute=0, second=0, microsecond=0, tzinfo=None)}, | ||||
|             {'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|             {'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, | ||||
|             {'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, | ||||
|         ]) | ||||
|  | ||||
|     def test_aggregation_default_using_duration_from_python(self): | ||||
|         result = Publisher.objects.filter(num_awards__gt=3).aggregate( | ||||
|             value=Sum('duration', default=datetime.timedelta(0)), | ||||
|         ) | ||||
|         self.assertEqual(result['value'], datetime.timedelta(0)) | ||||
|  | ||||
|     def test_aggregation_default_using_duration_from_database(self): | ||||
|         result = Publisher.objects.filter(num_awards__gt=3).aggregate( | ||||
|             value=Sum('duration', default=Now() - Now()), | ||||
|         ) | ||||
|         self.assertEqual(result['value'], datetime.timedelta(0)) | ||||
|  | ||||
|     def test_aggregation_default_using_decimal_from_python(self): | ||||
|         result = Book.objects.filter(rating__lt=3.0).aggregate( | ||||
|             value=Sum('price', default=Decimal('0.00')), | ||||
|         ) | ||||
|         self.assertEqual(result['value'], Decimal('0.00')) | ||||
|  | ||||
|     def test_aggregation_default_using_decimal_from_database(self): | ||||
|         result = Book.objects.filter(rating__lt=3.0).aggregate( | ||||
|             value=Sum('price', default=Pi()), | ||||
|         ) | ||||
|         self.assertAlmostEqual(result['value'], Decimal.from_float(math.pi), places=6) | ||||
|  | ||||
|     def test_aggregation_default_passed_another_aggregate(self): | ||||
|         result = Book.objects.aggregate( | ||||
|             value=Sum('price', filter=Q(rating__lt=3.0), default=Avg('pages') / 10.0), | ||||
|         ) | ||||
|         self.assertAlmostEqual(result['value'], Decimal('61.72'), places=2) | ||||
|   | ||||
| @@ -72,6 +72,34 @@ class TestGeneralAggregate(PostgreSQLTestCase): | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|  | ||||
|     def test_default_argument(self): | ||||
|         AggregateTestModel.objects.all().delete() | ||||
|         tests = [ | ||||
|             (ArrayAgg('char_field', default=['<empty>']), ['<empty>']), | ||||
|             (ArrayAgg('integer_field', default=[0]), [0]), | ||||
|             (ArrayAgg('boolean_field', default=[False]), [False]), | ||||
|             (BitAnd('integer_field', default=0), 0), | ||||
|             (BitOr('integer_field', default=0), 0), | ||||
|             (BoolAnd('boolean_field', default=False), False), | ||||
|             (BoolOr('boolean_field', default=False), False), | ||||
|             (JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']), | ||||
|             (StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'), | ||||
|         ] | ||||
|         for aggregation, expected_result in tests: | ||||
|             with self.subTest(aggregation=aggregation): | ||||
|                 # Empty result with non-execution optimization. | ||||
|                 with self.assertNumQueries(0): | ||||
|                     values = AggregateTestModel.objects.none().aggregate( | ||||
|                         aggregation=aggregation, | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|                 # Empty result when query must be executed. | ||||
|                 with self.assertNumQueries(1): | ||||
|                     values = AggregateTestModel.objects.aggregate( | ||||
|                         aggregation=aggregation, | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|  | ||||
|     def test_array_agg_charfield(self): | ||||
|         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) | ||||
|         self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) | ||||
| @@ -515,6 +543,37 @@ class TestStatisticsAggregate(PostgreSQLTestCase): | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|  | ||||
|     def test_default_argument(self): | ||||
|         StatTestModel.objects.all().delete() | ||||
|         tests = [ | ||||
|             (Corr(y='int2', x='int1', default=0), 0), | ||||
|             (CovarPop(y='int2', x='int1', default=0), 0), | ||||
|             (CovarPop(y='int2', x='int1', sample=True, default=0), 0), | ||||
|             (RegrAvgX(y='int2', x='int1', default=0), 0), | ||||
|             (RegrAvgY(y='int2', x='int1', default=0), 0), | ||||
|             # RegrCount() doesn't support the default argument. | ||||
|             (RegrIntercept(y='int2', x='int1', default=0), 0), | ||||
|             (RegrR2(y='int2', x='int1', default=0), 0), | ||||
|             (RegrSlope(y='int2', x='int1', default=0), 0), | ||||
|             (RegrSXX(y='int2', x='int1', default=0), 0), | ||||
|             (RegrSXY(y='int2', x='int1', default=0), 0), | ||||
|             (RegrSYY(y='int2', x='int1', default=0), 0), | ||||
|         ] | ||||
|         for aggregation, expected_result in tests: | ||||
|             with self.subTest(aggregation=aggregation): | ||||
|                 # Empty result with non-execution optimization. | ||||
|                 with self.assertNumQueries(0): | ||||
|                     values = StatTestModel.objects.none().aggregate( | ||||
|                         aggregation=aggregation, | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|                 # Empty result when query must be executed. | ||||
|                 with self.assertNumQueries(1): | ||||
|                     values = StatTestModel.objects.aggregate( | ||||
|                         aggregation=aggregation, | ||||
|                     ) | ||||
|                     self.assertEqual(values, {'aggregation': expected_result}) | ||||
|  | ||||
|     def test_corr_general(self): | ||||
|         values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1')) | ||||
|         self.assertEqual(values, {'corr': -1.0}) | ||||
| @@ -539,6 +598,11 @@ class TestStatisticsAggregate(PostgreSQLTestCase): | ||||
|         values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1')) | ||||
|         self.assertEqual(values, {'regrcount': 3}) | ||||
|  | ||||
|     def test_regr_count_default(self): | ||||
|         msg = 'RegrCount does not allow default.' | ||||
|         with self.assertRaisesMessage(TypeError, msg): | ||||
|             RegrCount(y='int2', x='int1', default=0) | ||||
|  | ||||
|     def test_regr_intercept_general(self): | ||||
|         values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1')) | ||||
|         self.assertEqual(values, {'regrintercept': 4}) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user