mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			784 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			784 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import itertools
 | |
| import math
 | |
| import warnings
 | |
| 
 | |
| from django.core.exceptions import EmptyResultSet, FullResultSet
 | |
| from django.db.backends.base.operations import BaseDatabaseOperations
 | |
| from django.db.models.expressions import Case, Expression, Func, Value, When
 | |
| from django.db.models.fields import (
 | |
|     BooleanField,
 | |
|     CharField,
 | |
|     DateTimeField,
 | |
|     Field,
 | |
|     IntegerField,
 | |
|     UUIDField,
 | |
| )
 | |
| from django.db.models.query_utils import RegisterLookupMixin
 | |
| from django.utils.datastructures import OrderedSet
 | |
| from django.utils.deprecation import RemovedInDjango60Warning
 | |
| from django.utils.functional import cached_property
 | |
| from django.utils.hashable import make_hashable
 | |
| 
 | |
| 
 | |
class Lookup(Expression):
    """
    Base class for lookups: a boolean expression comparing ``lhs`` to ``rhs``.

    On construction the rhs is prepared with ``get_prep_lookup()`` and then the
    lhs is wrapped in ``Value`` if it is not an expression (``get_prep_lhs()``).
    Transforms on the lhs that are flagged as bilateral are also applied to the
    rhs when it is compiled.
    """

    lookup_name = None
    # When False, get_prep_lookup() returns the rhs unchanged.
    prepare_rhs = True
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # Order matters: get_prep_lookup() inspects the raw lhs (e.g. whether it
        # has an output_field) before get_prep_lhs() may wrap it in Value.
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from django.db.models.sql.query import Query  # avoid circular import

            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in each bilateral transform, innermost transform first."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Return ``(sqls, params)`` for an iterable rhs (default: ``self.rhs``).

        With bilateral transforms, each element is wrapped in Value, transformed
        and compiled; otherwise each prepared param gets a "%s" placeholder.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls, sqls_params = ["%s"] * len(params), params
        return sqls, sqls_params

    def get_source_expressions(self):
        # A plain Python rhs is not an expression, so only lhs is exposed then.
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, new_exprs):
        if len(new_exprs) == 1:
            self.lhs = new_exprs[0]
        else:
            self.lhs, self.rhs = new_exprs

    def get_prep_lookup(self):
        """
        Prepare the rhs: expressions (and everything when prepare_rhs is False)
        pass through; otherwise use lhs.output_field.get_prep_value() when
        available, or wrap a direct value in Value when lhs has no output_field.
        """
        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        if hasattr(self.lhs, "output_field"):
            if hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self):
        """Return the lhs, wrapped in Value if it is not an expression."""
        if hasattr(self.lhs, "resolve_expression"):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(self, value, connection):
        """Base behaviour: send ``value`` unchanged as a single query param."""
        return ("%s", [value])

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile ``lhs`` (default ``self.lhs``), resolving it against
        ``compiler.query`` first when possible. A nested Lookup is parenthesized.
        """
        lhs = lhs or self.lhs
        if hasattr(lhs, "resolve_expression"):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, params

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs. Expressions are compiled and parenthesized; plain
        values go through get_db_prep_lookup(). With bilateral transforms, a
        direct value is first wrapped in Value and transformed.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = "(%s)" % sql
            return sql, params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """Return True when the rhs is a plain value rather than an SQL expression."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def as_oracle(self, compiler, connection):
        # Oracle doesn't allow EXISTS() and filters to be compared to another
        # expression unless they're wrapped in a CASE WHEN.
        wrapped = False
        exprs = []
        for expr in (self.lhs, self.rhs):
            if connection.ops.conditional_expression_supported_in_where_clause(expr):
                expr = Case(When(expr, then=True), default=False)
                wrapped = True
            exprs.append(expr)
        lookup = type(self)(*exprs) if wrapped else self
        return lookup.as_sql(compiler, connection)

    @cached_property
    def output_field(self):
        return BooleanField()

    @property
    def identity(self):
        # Used by __eq__ and __hash__.
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other):
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with lhs (and an expression rhs) resolved against query."""
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    @cached_property
    def allowed_default(self):
        """Whether this lookup may be used in Field.db_default."""
        if not self.rhs_is_direct_value():
            return self.lhs.allowed_default and self.rhs.allowed_default
        # A direct rhs is a plain Python value with no allowed_default
        # attribute, so reading it would raise AttributeError. Plain values are
        # sent as query params (see get_db_prep_lookup()), so only expressions
        # nested in a list/tuple rhs (e.g. for "in") need checking.
        values = self.rhs if isinstance(self.rhs, (list, tuple)) else [self.rhs]
        return self.lhs.allowed_default and all(
            getattr(value, "allowed_default", not hasattr(value, "resolve_expression"))
            for value in values
        )
 | |
| 
 | |
| 
 | |
class Transform(RegisterLookupMixin, Func):
    """
    A Func with a single source expression, exposed as ``lhs``.

    RegisterLookupMixin() is listed first so get_lookup() and get_transform()
    look at this class before falling back to the output_field.
    """

    # When True, Lookup also applies this transform to the rhs.
    bilateral = False
    arity = 1

    @property
    def lhs(self):
        # The transformed expression is the first source expression.
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Return the bilateral transform classes in the chain, innermost first,
        with this class appended when it is itself bilateral.
        """
        inner = self.lhs
        if hasattr(inner, "get_bilateral_transforms"):
            chain = inner.get_bilateral_transforms()
        else:
            chain = []
        if self.bilateral:
            chain.append(self.__class__)
        return chain
 | |
| 
 | |
| 
 | |
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>``, with the operator template taken
    from ``connection.operators[lookup_name]``.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs and wrap it in the backend's lookup_cast() template.

        If the backend still overrides the deprecated field_cast_sql(), emit
        RemovedInDjango60Warning and apply that template first. Params are
        returned as a list.
        """
        lhs_sql, params = super().process_lhs(compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        ops_class = connection.ops.__class__
        overrides_field_cast_sql = (
            hasattr(ops_class, "field_cast_sql")
            and ops_class.field_cast_sql is not BaseDatabaseOperations.field_cast_sql
        )
        if overrides_field_cast_sql:
            warnings.warn(
                "The usage of DatabaseOperations.field_cast_sql() is deprecated. "
                "Implement DatabaseOperations.lookup_cast() instead.",
                RemovedInDjango60Warning,
            )
            column_type = output_field.db_type(connection=connection)
            legacy_template = connection.ops.field_cast_sql(column_type, internal_type)
            lhs_sql = legacy_template % lhs_sql
        cast_template = connection.ops.lookup_cast(self.lookup_name, internal_type)
        return cast_template % lhs_sql, list(params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        operator_sql = self.get_rhs_op(connection, rhs_sql)
        return "%s %s" % (lhs_sql, operator_sql), params

    def get_rhs_op(self, connection, rhs):
        """Fill the backend operator template for this lookup with the rhs SQL."""
        template = connection.operators[self.lookup_name]
        return template % rhs
 | |
| 
 | |
| 
 | |
class FieldGetDbPrepValueMixin:
    """
    Mixin for lookups whose rhs must be passed through the field's
    get_db_prep_value() (with prepared=True) before reaching the database.
    """

    # When True, the rhs is iterated and each element is prepared separately.
    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        output_field = self.lhs.output_field
        # Relational fields expose a target_field; prefer its preparation.
        target_field = getattr(output_field, "target_field", None)
        prepare = getattr(target_field, "get_db_prep_value", None)
        if not prepare:
            prepare = output_field.get_db_prep_value
        if self.get_db_prep_lookup_value_is_iterable:
            prepared = [prepare(item, connection, prepared=True) for item in value]
        else:
            prepared = [prepare(value, connection, prepared=True)]
        return "%s", prepared
 | |
| 
 | |
| 
 | |
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Mixin for lookups whose rhs is an iterable; the field's get_db_prep_value()
    is applied to every element.
    """

    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        """
        Return an expression rhs unchanged, otherwise a list of the elements
        with each non-expression element run through get_prep_value() when
        prepare_rhs is set and the lhs output_field supports it.
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs

        def prepare(item):
            if hasattr(item, "resolve_expression"):
                # Expressions are left for the database and may be mixed with
                # plain values.
                return item
            if self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(item)
            return item

        return [prepare(item) for item in self.rhs]

    def process_rhs(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)
        # A direct rhs is an iterable of values; batch_process_rhs() prepares
        # and transforms them element by element.
        return self.batch_process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return ``(sql, params)`` for a single rhs element: a compiled expression
        when the (resolved) param has as_sql(), otherwise the given placeholder
        with the original param.
        """
        original = param
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            compiled_sql, compiled_params = compiler.compile(param)
            return compiled_sql, compiled_params
        return sql, [original]

    def batch_process_rhs(self, compiler, connection, rhs=None):
        sqls, flat_params = super().batch_process_rhs(compiler, connection, rhs)
        # Pair each placeholder with its param and compile any expression
        # params, replacing the placeholder with the compiled SQL.
        pairs = [
            self.resolve_expression_parameter(compiler, connection, fragment, value)
            for fragment, value in zip(sqls, flat_params)
        ]
        sql, grouped_params = zip(*pairs)
        return sql, tuple(itertools.chain.from_iterable(grouped_params))
 | |
| 
 | |
| 
 | |
class PostgresOperatorLookup(Lookup):
    """Lookup rendered on PostgreSQL as ``<lhs> <postgres_operator> <rhs>``."""

    # Operator text placed between lhs and rhs; set by subclasses.
    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = "%s %s %s" % (lhs_sql, self.postgres_operator, rhs_sql)
        return sql, (*lhs_params, *rhs_params)
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``exact`` lookup."""

    lookup_name = "exact"

    def get_prep_lookup(self):
        """
        Validate a Query rhs (it must be limited to one row). If it has no
        explicit select list, make it select "pk".
        """
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            if not self.rhs.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        # For a conditional lhs compared to a bool, emit "WHERE <lhs>" or
        # "WHERE NOT <lhs>" instead of "<lhs> = True", when the backend allows
        # the condition directly in WHERE.
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)
        lhs_sql, params = self.process_lhs(compiler, connection)
        return ("%s" if self.rhs else "NOT %s") % lhs_sql, params
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IExact(BuiltinLookup):
    """``iexact`` lookup; the rhs is not passed through get_prep_value()."""

    lookup_name = "iexact"
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        """
        Return the rhs SQL and params, with the first param passed through the
        backend's prep_for_iexact_query().
        """
        rhs, params = super().process_rhs(qn, connection)
        if params:
            # Params produced by compiling an expression rhs are not guaranteed
            # to be a list (PostgresOperatorLookup normalizes with tuple()), and
            # item assignment on a tuple raises TypeError. Copy to a list first.
            params = list(params)
            # NOTE(review): for an expression rhs, params[0] is just the first
            # param of the compiled expression, which may not be the compared
            # value — confirm escaping it here is intended.
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gt`` lookup; SQL comes from ``connection.operators["gt"]``."""

    lookup_name = "gt"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gte`` lookup; SQL comes from ``connection.operators["gte"]``."""

    lookup_name = "gte"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lt`` lookup; SQL comes from ``connection.operators["lt"]``."""

    lookup_name = "lt"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lte`` lookup; SQL comes from ``connection.operators["lte"]``."""

    lookup_name = "lte"
 | |
| 
 | |
| 
 | |
class IntegerFieldOverflow:
    """
    Mixin for integer lookups: an int rhs outside the range that
    connection.ops.integer_field_range() reports for the lhs field type raises
    underflow_exception / overflow_exception instead of compiling SQL.
    """

    underflow_exception = EmptyResultSet
    overflow_exception = EmptyResultSet

    def _check_rhs_in_range(self, connection):
        """Raise the configured exception if the int rhs is out of range."""
        internal_type = self.lhs.output_field.get_internal_type()
        lower, upper = connection.ops.integer_field_range(internal_type)
        value = self.rhs
        if lower is not None and value < lower:
            raise self.underflow_exception
        if upper is not None and value > upper:
            raise self.overflow_exception

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, int):
            self._check_rhs_in_range(connection)
        return super().process_rhs(compiler, connection)
 | |
| 
 | |
| 
 | |
class IntegerFieldFloatRounding:
    """
    Let float rhs values work against an IntegerField by rounding them up
    (math.ceil) before preparation, rather than having their decimal part
    discarded.
    """

    def get_prep_lookup(self):
        rhs = self.rhs
        if isinstance(rhs, float):
            # e.g. 1.2 -> 2, -1.5 -> -1
            self.rhs = math.ceil(rhs)
        return super().get_prep_lookup()
 | |
| 
 | |
| 
 | |
@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` on IntegerField; an out-of-range int rhs raises EmptyResultSet."""
 | |
| 
 | |
| 
 | |
@IntegerField.register_lookup
class IntegerGreaterThan(IntegerFieldOverflow, GreaterThan):
    """
    ``gt`` on IntegerField. An int rhs below the field's range raises
    FullResultSet; above the range, EmptyResultSet (inherited default).
    """

    underflow_exception = FullResultSet

    def get_prep_lookup(self):
        # For integer x: x > 1.5 <=> x > 1 and x > -1.5 <=> x > -2, i.e. the
        # float must be floored. Left as a float, its decimal part is discarded
        # by truncation toward zero (-1.5 -> -1), which would wrongly exclude
        # x == -1.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()
 | |
| 
 | |
| 
 | |
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """
    ``gte`` on IntegerField: a float rhs is rounded up, and an int rhs below
    the field's range raises FullResultSet.
    """

    underflow_exception = FullResultSet
 | |
| 
 | |
| 
 | |
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """
    ``lt`` on IntegerField: a float rhs is rounded up, and an int rhs above
    the field's range raises FullResultSet.
    """

    overflow_exception = FullResultSet
 | |
| 
 | |
| 
 | |
@IntegerField.register_lookup
class IntegerLessThanOrEqual(IntegerFieldOverflow, LessThanOrEqual):
    """
    ``lte`` on IntegerField. An int rhs above the field's range raises
    FullResultSet; below the range, EmptyResultSet (inherited default).
    """

    overflow_exception = FullResultSet

    def get_prep_lookup(self):
        # For integer x: x <= 1.5 <=> x <= 1 and x <= -1.5 <=> x <= -2, i.e.
        # the float must be floored. Left as a float, its decimal part is
        # discarded by truncation toward zero (-1.5 -> -1), which would wrongly
        # include x == -1.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    ``in`` lookup: ``lhs IN (...)`` against an iterable of values or a subquery.

    When the backend reports a max_in_list_size() and the value list is
    longer, the SQL is split into ``(lhs IN (...) OR lhs IN (...) ...)``.
    """

    lookup_name = "in"

    def get_prep_lookup(self):
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            # Ordering is cleared on the subquery; without an explicit select
            # list it selects "pk".
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def _get_direct_rhs_values(self):
        """
        Return the direct rhs values with None removed, as NULL is never equal
        to anything, and with duplicates dropped when the values are hashable.

        Raise EmptyResultSet if no value remains.
        """
        try:
            rhs = OrderedSet(self.rhs)
            rhs.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            rhs = [r for r in self.rhs if r is not None]
        if not rhs:
            raise EmptyResultSet
        return rhs

    def process_rhs(self, compiler, connection):
        db_rhs = getattr(self.rhs, "_db", None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            rhs = self._get_direct_rhs_values()
            # rhs is an iterable; batch_process_rhs() prepares/transforms each
            # value and returns one SQL fragment per value.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return "IN %s" % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Build ``(lhs IN (chunk1) OR lhs IN (chunk2) ...)`` for databases that
        limit the number of elements in an IN clause.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        # Same None removal / deduplication / EmptyResultSet as process_rhs(),
        # so the split and non-split forms agree (a NULL left in the list makes
        # a negated "lhs IN (..., NULL)" evaluate to unknown).
        values = list(self._get_direct_rhs_values())
        lhs, lhs_params = self.process_lhs(compiler, connection)
        in_clause_elements = ["("]
        params = []
        for offset in range(0, len(values), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(" OR ")
            # Compile each chunk separately so its placeholders and params stay
            # aligned even when an element is an expression that compiles to
            # zero or several params.
            sqls, sqls_params = self.batch_process_rhs(
                compiler, connection, values[offset : offset + max_in_list_size]
            )
            in_clause_elements.append("%s IN (" % lhs)
            params.extend(lhs_params)
            in_clause_elements.append(", ".join(sqls))
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
 | |
| 
 | |
| 
 | |
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups.

    A plain rhs value is LIKE-escaped (prep_for_like_query()) and formatted
    with ``param_pattern`` (e.g. "%%%s%%" turns "foo" into "%foo%"). An
    expression rhs, or any rhs when bilateral transforms apply, instead uses
    the backend's ``pattern_ops`` template so the wildcards are added in SQL.
    """

    param_pattern = "%%%s%%"
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # A Python value already carries its wildcards in the param, so the
        # plain operator (e.g. "LIKE %s") is enough. For an SQL rhs (a column
        # reference or transformed value) the wildcards must be concatenated
        # in SQL, which is what pattern_ops provides.
        if not (hasattr(self.rhs, "as_sql") or self.bilateral_transforms):
            return super().get_rhs_op(connection, rhs)
        pattern = connection.pattern_ops[self.lookup_name].format(
            connection.pattern_esc
        )
        return pattern.format(rhs)

    def process_rhs(self, qn, connection):
        rhs, params = super().process_rhs(qn, connection)
        wrap_param = (
            self.rhs_is_direct_value() and params and not self.bilateral_transforms
        )
        if wrap_param:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return rhs, params
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class Contains(PatternLookup):
    """``contains`` lookup; a plain rhs becomes the pattern ``%value%``."""

    lookup_name = "contains"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IContains(Contains):
    """``icontains`` variant of Contains, using the backend's ``icontains`` operator."""

    lookup_name = "icontains"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class StartsWith(PatternLookup):
    """``startswith`` lookup; a plain rhs becomes the pattern ``value%``."""

    lookup_name = "startswith"
    param_pattern = "%s%%"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IStartsWith(StartsWith):
    """``istartswith`` variant of StartsWith, using the backend's ``istartswith`` operator."""

    lookup_name = "istartswith"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class EndsWith(PatternLookup):
    """``endswith`` lookup; a plain rhs becomes the pattern ``%value``."""

    lookup_name = "endswith"
    param_pattern = "%%%s"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IEndsWith(EndsWith):
    """``iendswith`` variant of EndsWith, using the backend's ``iendswith`` operator."""

    lookup_name = "iendswith"
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``range`` lookup rendered as ``BETWEEN <first> AND <second>``."""

    lookup_name = "range"

    def get_rhs_op(self, connection, rhs):
        # rhs holds the SQL fragments for the two bounds.
        low, high = rhs[0], rhs[1]
        return "BETWEEN %s AND %s" % (low, high)
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``isnull`` lookup; the rhs must be a bool."""

    lookup_name = "isnull"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        if isinstance(self.lhs, Value):
            # A constant lhs decides the outcome without any SQL: raise
            # FullResultSet when "is null" matches the requested rhs, and
            # EmptyResultSet otherwise.
            constant = self.lhs.value
            lhs_is_null = constant is None or (
                constant == ""
                and connection.features.interprets_empty_strings_as_nulls
            )
            if bool(lhs_is_null) == self.rhs:
                raise FullResultSet
            raise EmptyResultSet
        sql, params = self.process_lhs(compiler, connection)
        suffix = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{sql} {suffix}", params
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    ``regex`` lookup. Uses ``connection.operators[lookup_name]`` when the
    backend defines it, otherwise the template from
    ``connection.ops.regex_lookup()``.
    """

    lookup_name = "regex"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.regex_lookup(self.lookup_name)
        # For an expression rhs, process_rhs() returns whatever params the
        # compiled expression produced, which is not guaranteed to be a list;
        # `list + tuple` would raise TypeError, so build a fresh list.
        return sql_template % (lhs, rhs), [*lhs_params, *rhs_params]
 | |
| 
 | |
| 
 | |
@Field.register_lookup
class IRegex(Regex):
    """``iregex`` variant of Regex, keyed on the ``iregex`` operator / template."""

    lookup_name = "iregex"
 | |
| 
 | |
| 
 | |
class YearLookup(Lookup):
    """
    Base for lookups on an extracted year, where ``self.lhs`` is the extract
    and ``self.lhs.lhs`` the underlying date/datetime expression.

    With a plain rhs value, the underlying column is compared against bounds
    from year_lookup_bounds() instead of extracting the year, so indexes can
    be used.
    """

    def year_lookup_bounds(self, connection, year):
        """Return the backend's (start, finish) bounds for ``year``."""
        from django.db.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        if isinstance(self.lhs.lhs.output_field, DateTimeField):
            compute_bounds = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            compute_bounds = connection.ops.year_lookup_bounds_for_date_field
        return compute_bounds(year, iso_year=iso_year)

    def as_sql(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # Compile the originating column (self.lhs.lhs) rather than the
        # extract, and replace the rhs params with the year bounds.
        column_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        value_sql, _ = self.process_rhs(compiler, connection)
        operator_sql = self.get_direct_rhs_sql(connection, value_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return "%s %s" % (column_sql, operator_sql), params

    def get_direct_rhs_sql(self, connection, rhs):
        """Fill the backend operator template for this lookup with ``rhs``."""
        template = connection.operators[self.lookup_name]
        return template % rhs

    def get_bound_params(self, start, finish):
        """Return the bound params for the direct-value SQL; subclasses must override."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
 | |
| 
 | |
| 
 | |
class YearExact(YearLookup, Exact):
    """Year equality as ``BETWEEN start AND finish`` of the year's bounds."""

    def get_direct_rhs_sql(self, connection, rhs):
        # The rhs SQL is ignored; both placeholders receive the bound params.
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start, finish):
        return start, finish
 | |
| 
 | |
| 
 | |
class YearGt(YearLookup, GreaterThan):
    """Year ``gt``: the column is compared against the year's finish bound."""

    def get_bound_params(self, start, finish):
        bound = finish
        return (bound,)
 | |
| 
 | |
| 
 | |
class YearGte(YearLookup, GreaterThanOrEqual):
    """Year ``gte``: the column is compared against the year's start bound."""

    def get_bound_params(self, start, finish):
        bound = start
        return (bound,)
 | |
| 
 | |
| 
 | |
class YearLt(YearLookup, LessThan):
    """Year ``lt``: the column is compared against the year's start bound."""

    def get_bound_params(self, start, finish):
        bound = start
        return (bound,)
 | |
| 
 | |
| 
 | |
class YearLte(YearLookup, LessThanOrEqual):
    """Year ``lte``: the column is compared against the year's finish bound."""

    def get_bound_params(self, start, finish):
        bound = finish
        return (bound,)
 | |
| 
 | |
| 
 | |
class UUIDTextMixin:
    """
    On backends where ``features.has_native_uuid_field`` is False, compare
    against the rhs with every "-" removed, via a Replace() expression.
    """

    def process_rhs(self, qn, connection):
        if connection.features.has_native_uuid_field:
            rhs, params = super().process_rhs(qn, connection)
            return rhs, params
        from django.db.models.functions import Replace

        # This rebinds self.rhs to an expression. Later steps of the same
        # compilation, e.g. PatternLookup.get_rhs_op() checking
        # hasattr(self.rhs, "as_sql"), see the wrapped rhs.
        # NOTE(review): compiling the same lookup again wraps it a second time.
        source = Value(self.rhs) if self.rhs_is_direct_value() else self.rhs
        self.rhs = Replace(source, Value("-"), Value(""), output_field=CharField())
        rhs, params = super().process_rhs(qn, connection)
        return rhs, params
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 | |
| 
 | |
| 
 | |
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` on UUIDField, with hyphens stripped on non-native-UUID backends."""
 |