mirror of
https://github.com/django/django.git
synced 2025-06-05 19:49:13 +00:00
Fixed #34285 -- Fixed index/slice lookups on filtered aggregates with ArrayField.
Thanks Simon Charette for the review.
This commit is contained in:
parent
4403432b75
commit
ae1fe72e9b
@ -325,7 +325,9 @@ class IndexTransform(Transform):
|
|||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
lhs, params = compiler.compile(self.lhs)
|
lhs, params = compiler.compile(self.lhs)
|
||||||
return "%s[%%s]" % lhs, params + [self.index]
|
if not lhs.endswith("]"):
|
||||||
|
lhs = "(%s)" % lhs
|
||||||
|
return "%s[%%s]" % lhs, (*params, self.index)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
@ -349,7 +351,9 @@ class SliceTransform(Transform):
|
|||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
lhs, params = compiler.compile(self.lhs)
|
lhs, params = compiler.compile(self.lhs)
|
||||||
return "%s[%%s:%%s]" % lhs, params + [self.start, self.end]
|
if not lhs.endswith("]"):
|
||||||
|
lhs = "(%s)" % lhs
|
||||||
|
return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end)
|
||||||
|
|
||||||
|
|
||||||
class SliceTransformFactory:
|
class SliceTransformFactory:
|
||||||
|
@ -313,6 +313,49 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||||||
)
|
)
|
||||||
self.assertCountEqual(qs.get(), [1, 2])
|
self.assertCountEqual(qs.get(), [1, 2])
|
||||||
|
|
||||||
|
def test_array_agg_filter_index(self):
|
||||||
|
aggr1 = AggregateTestModel.objects.create(integer_field=1)
|
||||||
|
aggr2 = AggregateTestModel.objects.create(integer_field=2)
|
||||||
|
StatTestModel.objects.bulk_create(
|
||||||
|
[
|
||||||
|
StatTestModel(related_field=aggr1, int1=1, int2=0),
|
||||||
|
StatTestModel(related_field=aggr1, int1=2, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=3, int2=0),
|
||||||
|
StatTestModel(related_field=aggr2, int1=4, int2=1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
qs = (
|
||||||
|
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
|
||||||
|
.annotate(
|
||||||
|
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
|
||||||
|
)
|
||||||
|
.annotate(array_value=F("array__0"))
|
||||||
|
.values_list("array_value", flat=True)
|
||||||
|
)
|
||||||
|
self.assertCountEqual(qs, [1, 3])
|
||||||
|
|
||||||
|
def test_array_agg_filter_slice(self):
|
||||||
|
aggr1 = AggregateTestModel.objects.create(integer_field=1)
|
||||||
|
aggr2 = AggregateTestModel.objects.create(integer_field=2)
|
||||||
|
StatTestModel.objects.bulk_create(
|
||||||
|
[
|
||||||
|
StatTestModel(related_field=aggr1, int1=1, int2=0),
|
||||||
|
StatTestModel(related_field=aggr1, int1=2, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=3, int2=0),
|
||||||
|
StatTestModel(related_field=aggr2, int1=4, int2=1),
|
||||||
|
StatTestModel(related_field=aggr2, int1=5, int2=0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
qs = (
|
||||||
|
AggregateTestModel.objects.filter(pk__in=[aggr1.pk, aggr2.pk])
|
||||||
|
.annotate(
|
||||||
|
array=ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2=0))
|
||||||
|
)
|
||||||
|
.annotate(array_value=F("array__1_2"))
|
||||||
|
.values_list("array_value", flat=True)
|
||||||
|
)
|
||||||
|
self.assertCountEqual(qs, [[], [5]])
|
||||||
|
|
||||||
def test_bit_and_general(self):
|
def test_bit_and_general(self):
|
||||||
values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
|
values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
|
||||||
bitand=BitAnd("integer_field")
|
bitand=BitAnd("integer_field")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user