mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #22288 -- Fixed F() expressions with the __range lookup.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							f6cd669ff2
						
					
				
				
					commit
					4f138fe5a4
				
			
							
								
								
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -498,6 +498,7 @@ answer newbie questions, and generally made Django that much better: | |||||||
|     Matthew Schinckel <matt@schinckel.net> |     Matthew Schinckel <matt@schinckel.net> | ||||||
|     Matthew Somerville <matthew-django@dracos.co.uk> |     Matthew Somerville <matthew-django@dracos.co.uk> | ||||||
|     Matthew Tretter <m@tthewwithanm.com> |     Matthew Tretter <m@tthewwithanm.com> | ||||||
|  |     Matthew Wilkes <matt@matthewwilkes.name> | ||||||
|     Matthias Kestenholz <mk@406.ch> |     Matthias Kestenholz <mk@406.ch> | ||||||
|     Matthias Pronk <django@masida.nl> |     Matthias Pronk <django@masida.nl> | ||||||
|     Matt Hoskins <skaffenuk@googlemail.com> |     Matt Hoskins <skaffenuk@googlemail.com> | ||||||
|   | |||||||
| @@ -239,7 +239,13 @@ class ArrayInLookup(In): | |||||||
|         values = super(ArrayInLookup, self).get_prep_lookup() |         values = super(ArrayInLookup, self).get_prep_lookup() | ||||||
|         # In.process_rhs() expects values to be hashable, so convert lists |         # In.process_rhs() expects values to be hashable, so convert lists | ||||||
|         # to tuples. |         # to tuples. | ||||||
|         return [tuple(value) for value in values] |         prepared_values = [] | ||||||
|  |         for value in values: | ||||||
|  |             if hasattr(value, 'resolve_expression'): | ||||||
|  |                 prepared_values.append(value) | ||||||
|  |             else: | ||||||
|  |                 prepared_values.append(tuple(value)) | ||||||
|  |         return prepared_values | ||||||
|  |  | ||||||
|  |  | ||||||
| class IndexTransform(Transform): | class IndexTransform(Transform): | ||||||
|   | |||||||
| @@ -155,6 +155,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         # MySQL doesn't support tz-aware datetimes |         # MySQL doesn't support tz-aware datetimes | ||||||
|         if timezone.is_aware(value): |         if timezone.is_aware(value): | ||||||
|             if settings.USE_TZ: |             if settings.USE_TZ: | ||||||
| @@ -171,6 +175,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         # MySQL doesn't support tz-aware times |         # MySQL doesn't support tz-aware times | ||||||
|         if timezone.is_aware(value): |         if timezone.is_aware(value): | ||||||
|             raise ValueError("MySQL backend does not support timezone-aware times.") |             raise ValueError("MySQL backend does not support timezone-aware times.") | ||||||
|   | |||||||
| @@ -408,6 +408,10 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         # cx_Oracle doesn't support tz-aware datetimes |         # cx_Oracle doesn't support tz-aware datetimes | ||||||
|         if timezone.is_aware(value): |         if timezone.is_aware(value): | ||||||
|             if settings.USE_TZ: |             if settings.USE_TZ: | ||||||
| @@ -421,6 +425,10 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         if isinstance(value, six.string_types): |         if isinstance(value, six.string_types): | ||||||
|             return datetime.datetime.strptime(value, '%H:%M:%S') |             return datetime.datetime.strptime(value, '%H:%M:%S') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -182,6 +182,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         # SQLite doesn't support tz-aware datetimes |         # SQLite doesn't support tz-aware datetimes | ||||||
|         if timezone.is_aware(value): |         if timezone.is_aware(value): | ||||||
|             if settings.USE_TZ: |             if settings.USE_TZ: | ||||||
| @@ -195,6 +199,10 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |         # Expression values are adapted by the database. | ||||||
|  |         if hasattr(value, 'resolve_expression'): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|         # SQLite doesn't support tz-aware datetimes |         # SQLite doesn't support tz-aware datetimes | ||||||
|         if timezone.is_aware(value): |         if timezone.is_aware(value): | ||||||
|             raise ValueError("SQLite backend does not support timezone-aware times.") |             raise ValueError("SQLite backend does not support timezone-aware times.") | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | import itertools | ||||||
| import math | import math | ||||||
| import warnings | import warnings | ||||||
| from copy import copy | from copy import copy | ||||||
| @@ -170,6 +171,12 @@ class FieldGetDbPrepValueMixin(object): | |||||||
|     """ |     """ | ||||||
|     get_db_prep_lookup_value_is_iterable = False |     get_db_prep_lookup_value_is_iterable = False | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def get_prep_lookup_value(cls, value, output_field): | ||||||
|  |         if hasattr(value, '_prepare'): | ||||||
|  |             return value._prepare(output_field) | ||||||
|  |         return output_field.get_prep_value(value) | ||||||
|  |  | ||||||
|     def get_db_prep_lookup(self, value, connection): |     def get_db_prep_lookup(self, value, connection): | ||||||
|         # For relational fields, use the output_field of the 'field' attribute. |         # For relational fields, use the output_field of the 'field' attribute. | ||||||
|         field = getattr(self.lhs.output_field, 'field', None) |         field = getattr(self.lhs.output_field, 'field', None) | ||||||
| @@ -191,6 +198,51 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): | |||||||
|     """ |     """ | ||||||
|     get_db_prep_lookup_value_is_iterable = True |     get_db_prep_lookup_value_is_iterable = True | ||||||
|  |  | ||||||
|  |     def get_prep_lookup(self): | ||||||
|  |         prepared_values = [] | ||||||
|  |         if hasattr(self.rhs, '_prepare'): | ||||||
|  |             # A subquery is like an iterable but its items shouldn't be | ||||||
|  |             # prepared independently. | ||||||
|  |             return self.rhs._prepare(self.lhs.output_field) | ||||||
|  |         for rhs_value in self.rhs: | ||||||
|  |             if hasattr(rhs_value, 'resolve_expression'): | ||||||
|  |                 # An expression will be handled by the database but can coexist | ||||||
|  |                 # alongside real values. | ||||||
|  |                 pass | ||||||
|  |             elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'): | ||||||
|  |                 rhs_value = self.lhs.output_field.get_prep_value(rhs_value) | ||||||
|  |             prepared_values.append(rhs_value) | ||||||
|  |         return prepared_values | ||||||
|  |  | ||||||
|  |     def process_rhs(self, compiler, connection): | ||||||
|  |         if self.rhs_is_direct_value(): | ||||||
|  |             # rhs should be an iterable of values. Use batch_process_rhs() | ||||||
|  |             # to prepare/transform those values. | ||||||
|  |             return self.batch_process_rhs(compiler, connection) | ||||||
|  |         else: | ||||||
|  |             return super(FieldGetDbPrepValueIterableMixin, self).process_rhs(compiler, connection) | ||||||
|  |  | ||||||
|  |     def resolve_expression_parameter(self, compiler, connection, sql, param): | ||||||
|  |         params = [param] | ||||||
|  |         if hasattr(param, 'resolve_expression'): | ||||||
|  |             param = param.resolve_expression(compiler.query) | ||||||
|  |         if hasattr(param, 'as_sql'): | ||||||
|  |             sql, params = param.as_sql(compiler, connection) | ||||||
|  |         return sql, params | ||||||
|  |  | ||||||
|  |     def batch_process_rhs(self, compiler, connection, rhs=None): | ||||||
|  |         pre_processed = super(FieldGetDbPrepValueIterableMixin, self).batch_process_rhs(compiler, connection, rhs) | ||||||
|  |         # The params list may contain expressions which compile to a | ||||||
|  |         # sql/param pair. Zip them to get sql and param pairs that refer to the | ||||||
|  |         # same argument and attempt to replace them with the result of | ||||||
|  |         # compiling the param step. | ||||||
|  |         sql, params = zip(*( | ||||||
|  |             self.resolve_expression_parameter(compiler, connection, sql, param) | ||||||
|  |             for sql, param in zip(*pre_processed) | ||||||
|  |         )) | ||||||
|  |         params = itertools.chain.from_iterable(params) | ||||||
|  |         return sql, tuple(params) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): | class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): | ||||||
|     lookup_name = 'exact' |     lookup_name = 'exact' | ||||||
| @@ -255,13 +307,6 @@ IntegerField.register_lookup(IntegerLessThan) | |||||||
| class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): | class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): | ||||||
|     lookup_name = 'in' |     lookup_name = 'in' | ||||||
|  |  | ||||||
|     def get_prep_lookup(self): |  | ||||||
|         if hasattr(self.rhs, '_prepare'): |  | ||||||
|             return self.rhs._prepare(self.lhs.output_field) |  | ||||||
|         if hasattr(self.lhs.output_field, 'get_prep_value'): |  | ||||||
|             return [self.lhs.output_field.get_prep_value(v) for v in self.rhs] |  | ||||||
|         return self.rhs |  | ||||||
|  |  | ||||||
|     def process_rhs(self, compiler, connection): |     def process_rhs(self, compiler, connection): | ||||||
|         db_rhs = getattr(self.rhs, '_db', None) |         db_rhs = getattr(self.rhs, '_db', None) | ||||||
|         if db_rhs is not None and db_rhs != connection.alias: |         if db_rhs is not None and db_rhs != connection.alias: | ||||||
| @@ -409,21 +454,9 @@ Field.register_lookup(IEndsWith) | |||||||
| class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): | class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): | ||||||
|     lookup_name = 'range' |     lookup_name = 'range' | ||||||
|  |  | ||||||
|     def get_prep_lookup(self): |  | ||||||
|         if hasattr(self.rhs, '_prepare'): |  | ||||||
|             return self.rhs._prepare(self.lhs.output_field) |  | ||||||
|         return [self.lhs.output_field.get_prep_value(v) for v in self.rhs] |  | ||||||
|  |  | ||||||
|     def get_rhs_op(self, connection, rhs): |     def get_rhs_op(self, connection, rhs): | ||||||
|         return "BETWEEN %s AND %s" % (rhs[0], rhs[1]) |         return "BETWEEN %s AND %s" % (rhs[0], rhs[1]) | ||||||
|  |  | ||||||
|     def process_rhs(self, compiler, connection): |  | ||||||
|         if self.rhs_is_direct_value(): |  | ||||||
|             # rhs should be an iterable of 2 values, we use batch_process_rhs |  | ||||||
|             # to prepare/transform those values |  | ||||||
|             return self.batch_process_rhs(compiler, connection) |  | ||||||
|         else: |  | ||||||
|             return super(Range, self).process_rhs(compiler, connection) |  | ||||||
| Field.register_lookup(Range) | Field.register_lookup(Range) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -990,6 +990,20 @@ class Query(object): | |||||||
|             pre_joins = self.alias_refcount.copy() |             pre_joins = self.alias_refcount.copy() | ||||||
|             value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) |             value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) | ||||||
|             used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)] |             used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)] | ||||||
|  |         elif isinstance(value, (list, tuple)): | ||||||
|  |             # The items of the iterable may be expressions and therefore need | ||||||
|  |             # to be resolved independently. | ||||||
|  |             processed_values = [] | ||||||
|  |             used_joins = set() | ||||||
|  |             for sub_value in value: | ||||||
|  |                 if hasattr(sub_value, 'resolve_expression'): | ||||||
|  |                     pre_joins = self.alias_refcount.copy() | ||||||
|  |                     processed_values.append( | ||||||
|  |                         sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) | ||||||
|  |                     ) | ||||||
|  |                     # The used_joins for a tuple of expressions is the union of | ||||||
|  |                     # the used_joins for the individual expressions. | ||||||
|  |                     used_joins |= set(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)) | ||||||
|         # Subqueries need to use a different set of aliases than the |         # Subqueries need to use a different set of aliases than the | ||||||
|         # outer query. Call bump_prefix to change aliases of the inner |         # outer query. Call bump_prefix to change aliases of the inner | ||||||
|         # query (the value). |         # query (the value). | ||||||
|   | |||||||
| @@ -234,6 +234,9 @@ Models | |||||||
| * Added support for expressions in :meth:`.QuerySet.values` and | * Added support for expressions in :meth:`.QuerySet.values` and | ||||||
|   :meth:`~.QuerySet.values_list`. |   :meth:`~.QuerySet.values_list`. | ||||||
|  |  | ||||||
|  | * Added support for query expressions on lookups that take multiple arguments, | ||||||
|  |   such as ``range``. | ||||||
|  |  | ||||||
| Requests and Responses | Requests and Responses | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -61,6 +61,15 @@ class Experiment(models.Model): | |||||||
|         return self.end - self.start |         return self.end - self.start | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @python_2_unicode_compatible | ||||||
|  | class Result(models.Model): | ||||||
|  |     experiment = models.ForeignKey(Experiment, models.CASCADE) | ||||||
|  |     result_time = models.DateTimeField() | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return "Result at %s" % self.result_time | ||||||
|  |  | ||||||
|  |  | ||||||
| @python_2_unicode_compatible | @python_2_unicode_compatible | ||||||
| class Time(models.Model): | class Time(models.Model): | ||||||
|     time = models.TimeField(null=True) |     time = models.TimeField(null=True) | ||||||
| @@ -69,6 +78,16 @@ class Time(models.Model): | |||||||
|         return "%s" % self.time |         return "%s" % self.time | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @python_2_unicode_compatible | ||||||
|  | class SimulationRun(models.Model): | ||||||
|  |     start = models.ForeignKey(Time, models.CASCADE, null=True) | ||||||
|  |     end = models.ForeignKey(Time, models.CASCADE, null=True) | ||||||
|  |     midpoint = models.TimeField() | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return "%s (%s to %s)" % (self.midpoint, self.start, self.end) | ||||||
|  |  | ||||||
|  |  | ||||||
| @python_2_unicode_compatible | @python_2_unicode_compatible | ||||||
| class UUID(models.Model): | class UUID(models.Model): | ||||||
|     uuid = models.UUIDField(null=True) |     uuid = models.UUIDField(null=True) | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
|  |  | ||||||
| import datetime | import datetime | ||||||
|  | import unittest | ||||||
| import uuid | import uuid | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  |  | ||||||
| @@ -17,11 +18,15 @@ from django.db.models.expressions import ( | |||||||
| from django.db.models.functions import ( | from django.db.models.functions import ( | ||||||
|     Coalesce, Concat, Length, Lower, Substr, Upper, |     Coalesce, Concat, Length, Lower, Substr, Upper, | ||||||
| ) | ) | ||||||
|  | from django.db.models.sql import constants | ||||||
|  | from django.db.models.sql.datastructures import Join | ||||||
| from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature | from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature | ||||||
| from django.test.utils import Approximate | from django.test.utils import Approximate | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  |  | ||||||
| from .models import UUID, Company, Employee, Experiment, Number, Time | from .models import ( | ||||||
|  |     UUID, Company, Employee, Experiment, Number, Result, SimulationRun, Time, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BasicExpressionsTests(TestCase): | class BasicExpressionsTests(TestCase): | ||||||
| @@ -391,6 +396,144 @@ class BasicExpressionsTests(TestCase): | |||||||
|         self.assertEqual(str(qs.query).count('JOIN'), 2) |         self.assertEqual(str(qs.query).count('JOIN'), 2) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class IterableLookupInnerExpressionsTests(TestCase): | ||||||
|  |     @classmethod | ||||||
|  |     def setUpTestData(cls): | ||||||
|  |         ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30) | ||||||
|  |         # MySQL requires that the values calculated for expressions don't pass | ||||||
|  |         # outside of the field's range, so it's inconvenient to use the values | ||||||
|  |         # in the more general tests. | ||||||
|  |         Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo) | ||||||
|  |         Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo) | ||||||
|  |         Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo) | ||||||
|  |         Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo) | ||||||
|  |         Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo) | ||||||
|  |  | ||||||
|  |     def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self): | ||||||
|  |         # __in lookups can use F() expressions for integers. | ||||||
|  |         queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10])) | ||||||
|  |         self.assertQuerysetEqual(queryset, ['<Company: 5060 Ltd>'], ordered=False) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])), | ||||||
|  |             ['<Company: 5040 Ltd>', '<Company: 5060 Ltd>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter( | ||||||
|  |                 num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10]) | ||||||
|  |             ), | ||||||
|  |             ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_expressions_in_lookups_join_choice(self): | ||||||
|  |         midpoint = datetime.time(13, 0) | ||||||
|  |         t1 = Time.objects.create(time=datetime.time(12, 0)) | ||||||
|  |         t2 = Time.objects.create(time=datetime.time(14, 0)) | ||||||
|  |         SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint) | ||||||
|  |         SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint) | ||||||
|  |         SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint) | ||||||
|  |         SimulationRun.objects.create(start=None, end=None, midpoint=midpoint) | ||||||
|  |  | ||||||
|  |         queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')]) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             queryset, | ||||||
|  |             ['<SimulationRun: 13:00:00 (12:00:00 to 14:00:00)>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |         for alias in queryset.query.alias_map.values(): | ||||||
|  |             if isinstance(alias, Join): | ||||||
|  |                 self.assertEqual(alias.join_type, constants.INNER) | ||||||
|  |  | ||||||
|  |         queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')]) | ||||||
|  |         self.assertQuerysetEqual(queryset, [], ordered=False) | ||||||
|  |         for alias in queryset.query.alias_map.values(): | ||||||
|  |             if isinstance(alias, Join): | ||||||
|  |                 self.assertEqual(alias.join_type, constants.LOUTER) | ||||||
|  |  | ||||||
|  |     def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self): | ||||||
|  |         # Range lookups can use F() expressions for integers. | ||||||
|  |         Company.objects.filter(num_employees__exact=F("num_chairs")) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter(num_employees__range=(F('num_chairs'), 100)), | ||||||
|  |             ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)), | ||||||
|  |             ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)), | ||||||
|  |             ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |         self.assertQuerysetEqual( | ||||||
|  |             Company.objects.filter(num_employees__range=(1, 100)), | ||||||
|  |             [ | ||||||
|  |                 '<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', | ||||||
|  |                 '<Company: 5060 Ltd>', '<Company: 99300 Ltd>', | ||||||
|  |             ], | ||||||
|  |             ordered=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @unittest.skipUnless(connection.vendor == 'sqlite', | ||||||
|  |                          "This defensive test only works on databases that don't validate parameter types") | ||||||
|  |     def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self): | ||||||
|  |         """ | ||||||
|  |         This tests that SQL injection isn't possible using compilation of | ||||||
|  |         expressions in iterable filters, as their compilation happens before | ||||||
|  |         the main query compilation. It's limited to SQLite, as PostgreSQL, | ||||||
|  |         Oracle and other vendors have defense in depth against this by type | ||||||
|  |         checking. Testing against SQLite (the most permissive of the built-in | ||||||
|  |         databases) demonstrates that the problem doesn't exist while keeping | ||||||
|  |         the test simple. | ||||||
|  |         """ | ||||||
|  |         queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1']) | ||||||
|  |         self.assertQuerysetEqual(queryset, [], ordered=False) | ||||||
|  |  | ||||||
|  |     def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self): | ||||||
|  |         start = datetime.datetime(2016, 2, 3, 15, 0, 0) | ||||||
|  |         end = datetime.datetime(2016, 2, 5, 15, 0, 0) | ||||||
|  |         experiment_1 = Experiment.objects.create( | ||||||
|  |             name='Integrity testing', | ||||||
|  |             assigned=start.date(), | ||||||
|  |             start=start, | ||||||
|  |             end=end, | ||||||
|  |             completed=end.date(), | ||||||
|  |             estimated_time=end - start, | ||||||
|  |         ) | ||||||
|  |         experiment_2 = Experiment.objects.create( | ||||||
|  |             name='Taste testing', | ||||||
|  |             assigned=start.date(), | ||||||
|  |             start=start, | ||||||
|  |             end=end, | ||||||
|  |             completed=end.date(), | ||||||
|  |             estimated_time=end - start, | ||||||
|  |         ) | ||||||
|  |         Result.objects.create( | ||||||
|  |             experiment=experiment_1, | ||||||
|  |             result_time=datetime.datetime(2016, 2, 4, 15, 0, 0), | ||||||
|  |         ) | ||||||
|  |         Result.objects.create( | ||||||
|  |             experiment=experiment_1, | ||||||
|  |             result_time=datetime.datetime(2016, 3, 10, 2, 0, 0), | ||||||
|  |         ) | ||||||
|  |         Result.objects.create( | ||||||
|  |             experiment=experiment_2, | ||||||
|  |             result_time=datetime.datetime(2016, 1, 8, 5, 0, 0), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         within_experiment_time = [F('experiment__start'), F('experiment__end')] | ||||||
|  |         queryset = Result.objects.filter(result_time__range=within_experiment_time) | ||||||
|  |         self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"]) | ||||||
|  |  | ||||||
|  |         within_experiment_time = [F('experiment__start'), F('experiment__end')] | ||||||
|  |         queryset = Result.objects.filter(result_time__range=within_experiment_time) | ||||||
|  |         self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ExpressionsTests(TestCase): | class ExpressionsTests(TestCase): | ||||||
|  |  | ||||||
|     def test_F_object_deepcopy(self): |     def test_F_object_deepcopy(self): | ||||||
|   | |||||||
| @@ -173,12 +173,40 @@ class TestQuerying(PostgreSQLTestCase): | |||||||
|             self.objs[:2] |             self.objs[:2] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @unittest.expectedFailure | ||||||
|  |     def test_in_including_F_object(self): | ||||||
|  |         # This test asserts that Array objects passed to filters can be | ||||||
|  |         # constructed to contain F objects. This currently doesn't work as the | ||||||
|  |         # psycopg2 mogrify method that generates the ARRAY() syntax is | ||||||
|  |         # expecting literals, not column references (#27095). | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]), | ||||||
|  |             self.objs[:2] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_in_as_F_object(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]), | ||||||
|  |             self.objs[:4] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_contained_by(self): |     def test_contained_by(self): | ||||||
|         self.assertSequenceEqual( |         self.assertSequenceEqual( | ||||||
|             NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), |             NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), | ||||||
|             self.objs[:2] |             self.objs[:2] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @unittest.expectedFailure | ||||||
|  |     def test_contained_by_including_F_object(self): | ||||||
|  |         # This test asserts that Array objects passed to filters can be | ||||||
|  |         # constructed to contain F objects. This currently doesn't work as the | ||||||
|  |         # psycopg2 mogrify method that generates the ARRAY() syntax is | ||||||
|  |         # expecting literals, not column references (#27095). | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]), | ||||||
|  |             self.objs[:2] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_contains(self): |     def test_contains(self): | ||||||
|         self.assertSequenceEqual( |         self.assertSequenceEqual( | ||||||
|             NullableIntegerArrayModel.objects.filter(field__contains=[2]), |             NullableIntegerArrayModel.objects.filter(field__contains=[2]), | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user