From 97c05a64ca87253e9789ebaab4b6d20a1b2370cf Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Fri, 27 Sep 2024 00:04:33 +0800 Subject: [PATCH] Refs #373 -- Added additional validations to tuple lookups. --- django/db/models/fields/tuple_lookups.py | 36 ++++++- tests/foreign_object/test_tuple_lookups.py | 110 ++++++++++++++++++++- 2 files changed, 139 insertions(+), 7 deletions(-) diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index eb2d80b20f..a94582db95 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -2,7 +2,7 @@ import itertools from django.core.exceptions import EmptyResultSet from django.db.models import Field -from django.db.models.expressions import Func, Value +from django.db.models.expressions import ColPairs, Func, Value from django.db.models.lookups import ( Exact, GreaterThan, @@ -28,17 +28,32 @@ class Tuple(Func): class TupleLookupMixin: def get_prep_lookup(self): + self.check_rhs_is_tuple_or_list() self.check_rhs_length_equals_lhs_length() return self.rhs + def check_rhs_is_tuple_or_list(self): + if not isinstance(self.rhs, (tuple, list)): + lhs_str = self.get_lhs_str() + raise ValueError( + f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list" + ) + def check_rhs_length_equals_lhs_length(self): len_lhs = len(self.lhs) if len_lhs != len(self.rhs): + lhs_str = self.get_lhs_str() raise ValueError( - f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " - f"must have {len_lhs} elements" + f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements" ) + def get_lhs_str(self): + if isinstance(self.lhs, ColPairs): + return repr(self.lhs.field.name) + else: + names = ", ".join(repr(f.name) for f in self.lhs) + return f"({names})" + def get_prep_lhs(self): if isinstance(self.lhs, (tuple, list)): return Tuple(*self.lhs) @@ -196,14 +211,25 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): class TupleIn(TupleLookupMixin, In): def get_prep_lookup(self): + self.check_rhs_is_tuple_or_list() + self.check_rhs_is_collection_of_tuples_or_lists() self.check_rhs_elements_length_equals_lhs_length() - return super(TupleLookupMixin, self).get_prep_lookup() + return self.rhs # skip checks from mixin + + def check_rhs_is_collection_of_tuples_or_lists(self): + if not all(isinstance(vals, (tuple, list)) for vals in self.rhs): + lhs_str = self.get_lhs_str() + raise ValueError( + f"{self.lookup_name!r} lookup of {lhs_str} " + "must be a collection of tuples or lists" + ) def check_rhs_elements_length_equals_lhs_length(self): len_lhs = len(self.lhs) if not all(len_lhs == len(vals) for vals in self.rhs): + lhs_str = self.get_lhs_str() raise ValueError( - f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " + f"{self.lookup_name!r} lookup of {lhs_str} " f"must have {len_lhs} elements each" ) diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index e2561676f3..06182d3bb5 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -1,3 +1,4 @@ +import itertools import unittest from django.db import NotSupportedError, connection @@ -129,6 +130,37 @@ class TupleLookupsTests(TestCase): (self.contact_1, self.contact_2, self.contact_5), ) + def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self): + test_cases = ( + (1, 2, 3), + ((1, 2), (3, 4), None), + ) + + for rhs in test_cases: + with self.subTest(rhs=rhs): + with self.assertRaisesMessage( + ValueError, + "'in' lookup of ('customer_code', 'company_code') " + "must be a collection of tuples or lists", + ): + TupleIn((F("customer_code"), F("company_code")), rhs) + + def test_tuple_in_rhs_must_have_2_elements_each(self): + test_cases = ( + ((),), + ((1,),), + ((1, 2, 3),), + ) + + for rhs in test_cases: + with self.subTest(rhs=rhs): + with self.assertRaisesMessage( + ValueError, + "'in' lookup of ('customer_code', 'company_code') " + "must have 2 elements each", + ): + TupleIn((F("customer_code"), F("company_code")), rhs) + def test_lt(self): c1, c2, c3, c4, c5, c6 = ( self.contact_1, @@ -358,8 +390,8 @@ class TupleLookupsTests(TestCase): ) def test_lookup_errors(self): - m_2_elements = "'%s' lookup of 'customer' field must have 2 elements" - m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each" + m_2_elements = "'%s' lookup of 'customer' must have 2 elements" + m_2_elements_each = "'in' lookup of 'customer' must have 2 elements each" test_cases = ( ({"customer": 1}, m_2_elements % "exact"), ({"customer": (1, 2, 3)}, m_2_elements % "exact"), @@ -381,3 +413,77 @@ class TupleLookupsTests(TestCase): self.assertRaisesMessage(ValueError, message), ): Contact.objects.get(**kwargs) + + def test_tuple_lookup_names(self): + test_cases = ( + (TupleExact, "exact"), + (TupleGreaterThan, "gt"), + (TupleGreaterThanOrEqual, "gte"), + (TupleLessThan, "lt"), + (TupleLessThanOrEqual, "lte"), + (TupleIn, "in"), + (TupleIsNull, "isnull"), + ) + + for lookup_class, lookup_name in test_cases: + with self.subTest(lookup_name): + self.assertEqual(lookup_class.lookup_name, lookup_name) + + def test_tuple_lookup_rhs_must_be_tuple_or_list(self): + test_cases = itertools.product( + ( + TupleExact, + TupleGreaterThan, + TupleGreaterThanOrEqual, + TupleLessThan, + TupleLessThanOrEqual, + TupleIn, + ), + ( + 0, + 1, + None, + True, + False, + {"foo": "bar"}, + ), + ) + + for lookup_cls, rhs in test_cases: + lookup_name = lookup_cls.lookup_name + with self.subTest(lookup_name=lookup_name, rhs=rhs): + with self.assertRaisesMessage( + ValueError, + f"'{lookup_name}' lookup of ('customer_code', 'company_code') " + "must be a tuple or a list", + ): + lookup_cls((F("customer_code"), F("company_code")), rhs) + + def test_tuple_lookup_rhs_must_have_2_elements(self): + test_cases = itertools.product( + ( + TupleExact, + TupleGreaterThan, + TupleGreaterThanOrEqual, + TupleLessThan, + TupleLessThanOrEqual, + ), + ( + [], + [1], + [1, 2, 3], + (), + (1,), + (1, 2, 3), + ), + ) + + for lookup_cls, rhs in test_cases: + lookup_name = lookup_cls.lookup_name + with self.subTest(lookup_name=lookup_name, rhs=rhs): + with self.assertRaisesMessage( + ValueError, + f"'{lookup_name}' lookup of ('customer_code', 'company_code') " + "must have 2 elements", + ): + lookup_cls((F("customer_code"), F("company_code")), rhs)