1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Fixed #28908 -- Allowed ArrayField lookups on ArrayAgg annotations.

This commit is contained in:
Sergey Fedoseev
2017-12-31 00:46:52 +05:00
committed by Tim Graham
parent 58ec55b157
commit 1490611038
2 changed files with 17 additions and 1 deletions

View File

@@ -1,4 +1,4 @@
from django.contrib.postgres.fields import JSONField from django.contrib.postgres.fields import ArrayField, JSONField
from django.db.models.aggregates import Aggregate from django.db.models.aggregates import Aggregate
__all__ = [ __all__ = [
@@ -10,6 +10,10 @@ class ArrayAgg(Aggregate):
function = 'ARRAY_AGG' function = 'ARRAY_AGG'
template = '%(function)s(%(distinct)s%(expressions)s)' template = '%(function)s(%(distinct)s%(expressions)s)'
@property
def output_field(self):
return ArrayField(self.source_expressions[0].output_field)
def __init__(self, expression, distinct=False, **extra): def __init__(self, expression, distinct=False, **extra):
super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra) super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra)

View File

@@ -46,6 +46,18 @@ class TestGeneralAggregate(PostgreSQLTestCase):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
self.assertEqual(values, {'arrayagg': []}) self.assertEqual(values, {'arrayagg': []})
def test_array_agg_lookups(self):
aggr1 = AggregateTestModel.objects.create()
aggr2 = AggregateTestModel.objects.create()
StatTestModel.objects.create(related_field=aggr1, int1=1, int2=0)
StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0)
StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0)
StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0)
qs = StatTestModel.objects.values('related_field').annotate(
array=ArrayAgg('int1')
).filter(array__overlap=[2]).values_list('array', flat=True)
self.assertCountEqual(qs.get(), [1, 2])
def test_bit_and_general(self): def test_bit_and_general(self):
values = AggregateTestModel.objects.filter( values = AggregateTestModel.objects.filter(
integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field')) integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field'))