diff --git a/tests/expressions_window/models.py b/tests/expressions_window/models.py index d9c7568d10..e1cc02323d 100644 --- a/tests/expressions_window/models.py +++ b/tests/expressions_window/models.py @@ -13,3 +13,10 @@ class Employee(models.Model): age = models.IntegerField(blank=False, null=False) classification = models.ForeignKey('Classification', on_delete=models.CASCADE, null=True) bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True) + + +class Detail(models.Model): + value = models.JSONField() + + class Meta: + required_db_features = {'supports_json_field'} diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index 6d00d6543c..b837f73eb9 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -5,16 +5,17 @@ from unittest import mock, skipIf from django.core.exceptions import FieldError from django.db import NotSupportedError, connection from django.db.models import ( - Avg, BooleanField, Case, F, Func, Max, Min, OuterRef, Q, RowRange, - Subquery, Sum, Value, ValueRange, When, Window, WindowFrame, + Avg, BooleanField, Case, F, Func, IntegerField, Max, Min, OuterRef, Q, + RowRange, Subquery, Sum, Value, ValueRange, When, Window, WindowFrame, ) +from django.db.models.fields.json import KeyTextTransform, KeyTransform from django.db.models.functions import ( - CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead, + Cast, CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead, NthValue, Ntile, PercentRank, Rank, RowNumber, Upper, ) from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature -from .models import Employee +from .models import Detail, Employee @skipUnlessDBFeature('supports_over_clause') @@ -743,6 +744,36 @@ class WindowFunctionTests(TestCase): {'department': 'Management', 'salary': 100000} ]) + @skipUnlessDBFeature('supports_json_field') + def test_key_transform(self): + Detail.objects.bulk_create([ + Detail(value={'department': 'IT', 'name': 'Smith', 'salary': 37000}), + Detail(value={'department': 'IT', 'name': 'Nowak', 'salary': 32000}), + Detail(value={'department': 'HR', 'name': 'Brown', 'salary': 50000}), + Detail(value={'department': 'HR', 'name': 'Smith', 'salary': 55000}), + Detail(value={'department': 'PR', 'name': 'Moore', 'salary': 90000}), + ]) + qs = Detail.objects.annotate(department_sum=Window( + expression=Sum(Cast( + KeyTextTransform('salary', 'value'), + output_field=IntegerField(), + )), + partition_by=[KeyTransform('department', 'value')], + order_by=[KeyTransform('name', 'value')], + )).order_by('value__department', 'department_sum') + self.assertQuerysetEqual(qs, [ + ('Brown', 'HR', 50000, 50000), + ('Smith', 'HR', 55000, 105000), + ('Nowak', 'IT', 32000, 32000), + ('Smith', 'IT', 37000, 69000), + ('Moore', 'PR', 90000, 90000), + ], lambda entry: ( + entry.value['name'], + entry.value['department'], + entry.value['salary'], + entry.department_sum, + )) + def test_invalid_start_value_range(self): msg = "start argument must be a negative integer, zero, or None, but got '3'." with self.assertRaisesMessage(ValueError, msg):