mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
The new approach first tries to call a method named 'as_' + connection.vendor; if that method doesn't exist, it falls back to calling as_sql(). Lookup registration has also changed: there is now a RegisterLookupMixin class that is used by Field, Extract and sql.Aggregate. This allows lookups to be registered for extracts and aggregates in the same way as they are registered for fields.
318 lines
9.0 KiB
Python
318 lines
9.0 KiB
Python
from copy import copy
|
|
import inspect
|
|
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
from django.utils.functional import cached_property
|
|
|
|
|
|
class RegisterLookupMixin(object):
    """
    Mixin that gives a class its own registry of lookups.

    Lookups are stored in a ``class_lookups`` dict created lazily on the
    class that ``register_lookup()`` is called on, keyed by the lookup's
    ``lookup_name`` attribute.
    """

    def get_lookup(self, lookup_name):
        """
        Return the lookup registered under ``lookup_name``, or None.

        Resolution order:

        1. the nearest ``class_lookups`` dict visible from this instance;
        2. the ``class_lookups`` dict of every class in the MRO that
           defines one of its own;
        3. ``self.output_type.get_lookup()``, if the instance has an
           ``output_type`` attribute.
        """
        try:
            return self.class_lookups[lookup_name]
        except KeyError:
            # To allow for inheritance, check parent class class lookups.
            for parent in inspect.getmro(self.__class__):
                # Only consider dicts defined directly on the class, so that
                # an inherited registry isn't inspected more than once.
                if 'class_lookups' not in parent.__dict__:
                    continue
                if lookup_name in parent.class_lookups:
                    return parent.class_lookups[lookup_name]
        except AttributeError:
            # This class didn't have any class_lookups
            pass
        if hasattr(self, 'output_type'):
            return self.output_type.get_lookup(lookup_name)
        return None

    @classmethod
    def register_lookup(cls, lookup):
        """
        Register ``lookup`` on ``cls`` under ``lookup.lookup_name``.

        A fresh ``class_lookups`` dict is created on ``cls`` if it doesn't
        define one itself, so registering on a subclass never modifies the
        registry of a parent class.
        """
        if 'class_lookups' not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup.lookup_name] = lookup

    @classmethod
    def _unregister_lookup(cls, lookup):
        """
        Removes given lookup from cls lookups. Meant to be used in
        tests only.
        """
        del cls.class_lookups[lookup.lookup_name]
|
|
|
|
|
|
class Extract(RegisterLookupMixin):
    """
    Base class for an expression that wraps a left-hand side expression.

    Subclasses must implement ``as_sql()``. The output type and the columns
    are delegated to the wrapped ``lhs``.
    """

    def __init__(self, lhs, lookups):
        """
        Store ``lhs`` and a shallow copy of ``lookups`` (as
        ``init_lookups``), so later changes to the caller's list don't
        affect this instance.
        """
        self.lhs = lhs
        self.init_lookups = lookups[:]

    def as_sql(self, qn, connection):
        """Return the SQL for this extract; must be overridden."""
        raise NotImplementedError

    @cached_property
    def output_type(self):
        """The output type of the wrapped left-hand side."""
        return self.lhs.output_type

    def relabeled_clone(self, relabels):
        """
        Return a new instance of the same class whose ``lhs`` has been
        relabeled with ``relabels``.
        """
        # __init__ requires the lookups as well; passing only the lhs would
        # raise a TypeError.
        return self.__class__(
            self.lhs.relabeled_clone(relabels), self.init_lookups)

    def get_cols(self):
        """Return the columns of the wrapped left-hand side."""
        return self.lhs.get_cols()
|
|
|
|
|
|
class Lookup(RegisterLookupMixin):
    """
    Base class for a lookup between a left-hand side (``lhs``) and a
    right-hand side (``rhs``).

    Subclasses set ``lookup_name`` and implement ``as_sql()``.
    """
    lookup_name = None

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # get_prep_lookup() reads self.rhs, so the raw value has to be
        # assigned before it is replaced by the prepared one.
        self.rhs = self.get_prep_lookup()

    def get_prep_lookup(self):
        """Return ``self.rhs`` prepared by the lhs output type."""
        return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)

    def get_db_prep_lookup(self, value, connection):
        """
        Return a ``(placeholder_sql, params)`` pair for a plain Python
        value, with the params prepared by the lhs output type.
        """
        return (
            '%s', self.lhs.output_type.get_db_prep_lookup(
                self.lookup_name, value, connection, prepared=True))

    def process_lhs(self, qn, connection, lhs=None):
        """
        Compile the left-hand side with ``qn``.

        ``lhs`` overrides ``self.lhs`` when given.
        """
        # Compare against None explicitly so that an lhs object which
        # happens to be falsy isn't silently replaced by self.lhs.
        if lhs is None:
            lhs = self.lhs
        return qn.compile(lhs)

    def process_rhs(self, qn, connection, rhs=None):
        """
        Return a ``(sql, params)`` pair for the right-hand side.

        ``rhs`` overrides ``self.rhs`` when given.
        """
        # Compare against None explicitly so that falsy values such as 0,
        # '' or False passed by the caller are honored.
        value = self.rhs if rhs is None else rhs
        # Due to historical reasons there are a couple of different
        # ways to produce sql here. get_compiler is likely a Query
        # instance, _as_sql QuerySet and as_sql just something with
        # as_sql. Finally the value can of course be just plain
        # Python value.
        if hasattr(value, 'get_compiler'):
            value = value.get_compiler(connection=connection)
        if hasattr(value, 'as_sql'):
            sql, params = qn.compile(value)
            return '(' + sql + ')', params
        if hasattr(value, '_as_sql'):
            sql, params = value._as_sql(connection=connection)
            return '(' + sql + ')', params
        return self.get_db_prep_lookup(value, connection)

    def relabeled_clone(self, relabels):
        """
        Return a shallow copy of this lookup whose lhs (and rhs, if it
        supports it) has been relabeled with ``relabels``.
        """
        new = copy(self)
        new.lhs = new.lhs.relabeled_clone(relabels)
        if hasattr(new.rhs, 'relabeled_clone'):
            new.rhs = new.rhs.relabeled_clone(relabels)
        return new

    def get_cols(self):
        """
        Return the columns of the lhs followed by those of the rhs (when
        the rhs provides ``get_cols()``).
        """
        # Copy first: extending the object returned by lhs.get_cols()
        # in place could modify a list the lhs still holds on to.
        cols = list(self.lhs.get_cols())
        if hasattr(self.rhs, 'get_cols'):
            cols.extend(self.rhs.get_cols())
        return cols

    def as_sql(self, qn, connection):
        """Return the ``(sql, params)`` for this lookup; must be overridden."""
        raise NotImplementedError
|
|
|
|
|
|
class BuiltinLookup(Lookup):
    """
    Lookup whose right-hand side operator is taken from
    ``connection.operators`` and whose left-hand side goes through the
    backend's ``field_cast_sql()`` and ``lookup_cast()`` hooks.
    """

    def as_sql(self, qn, connection):
        """Return the ``(sql, params)`` pair ``'<lhs> <operator rhs>'``."""
        lhs_sql, lhs_params = self.process_lhs(qn, connection)
        output_type = self.lhs.output_type
        field_internal_type = output_type.get_internal_type()
        # field_cast_sql() is given the column type and the internal type,
        # so resolve the column type instead of passing the field object.
        db_type = output_type.db_type(connection=connection)
        lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
        lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
        rhs_sql, rhs_params = self.process_rhs(qn, connection)
        # Build a new list rather than extending the one returned by
        # process_lhs(), which may be a tuple or owned by the compiler.
        params = list(lhs_params)
        params.extend(rhs_params)
        operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
        return '%s %s' % (lhs_sql, operator_plus_rhs), params

    def get_rhs_op(self, connection, rhs):
        """
        Return the operator for ``self.lookup_name`` from
        ``connection.operators`` with ``rhs`` interpolated into it.
        """
        return connection.operators[self.lookup_name] % rhs
|
|
|
|
|
|
# Maps a lookup name to its lookup class. Each built-in lookup class below
# adds itself to this dict right after its class definition.
default_lookups = {}
|
|
|
|
|
|
class Exact(BuiltinLookup):
    """
    The 'exact' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['exact']``.
    """
    lookup_name = 'exact'
default_lookups[Exact.lookup_name] = Exact
|
|
|
|
|
|
class IExact(BuiltinLookup):
    """
    The 'iexact' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['iexact']``.
    """
    lookup_name = 'iexact'
default_lookups[IExact.lookup_name] = IExact
|
|
|
|
|
|
class Contains(BuiltinLookup):
    """
    The 'contains' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['contains']``.
    """
    lookup_name = 'contains'
default_lookups[Contains.lookup_name] = Contains
|
|
|
|
|
|
class IContains(BuiltinLookup):
    """
    The 'icontains' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['icontains']``.
    """
    lookup_name = 'icontains'
default_lookups[IContains.lookup_name] = IContains
|
|
|
|
|
|
class GreaterThan(BuiltinLookup):
    """
    The 'gt' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['gt']``.
    """
    lookup_name = 'gt'
default_lookups[GreaterThan.lookup_name] = GreaterThan
|
|
|
|
|
|
class GreaterThanOrEqual(BuiltinLookup):
    """
    The 'gte' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['gte']``.
    """
    lookup_name = 'gte'
default_lookups[GreaterThanOrEqual.lookup_name] = GreaterThanOrEqual
|
|
|
|
|
|
class LessThan(BuiltinLookup):
    """
    The 'lt' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['lt']``.
    """
    lookup_name = 'lt'
default_lookups[LessThan.lookup_name] = LessThan
|
|
|
|
|
|
class LessThanOrEqual(BuiltinLookup):
    """
    The 'lte' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['lte']``.
    """
    lookup_name = 'lte'
default_lookups[LessThanOrEqual.lookup_name] = LessThanOrEqual
|
|
|
|
|
|
class In(BuiltinLookup):
    """
    The 'in' lookup: renders ``IN (%s, %s, ...)`` with one placeholder per
    prepared right-hand side value.
    """
    lookup_name = 'in'

    def get_db_prep_lookup(self, value, connection):
        """
        Return ``('(%s, ..., %s)', params)`` for the prepared values.

        Raises EmptyResultSet when there are no values to match against.
        """
        field = self.lhs.output_type
        params = field.get_db_prep_lookup(
            self.lookup_name, value, connection, prepared=True)
        if params:
            placeholders = ', '.join('%s' for _ in params)
            return ('(' + placeholders + ')', params)
        # Imported locally because a module-level import leads to a
        # circular import.
        # TODO: check why this leads to circular import
        from django.db.models.sql.datastructures import EmptyResultSet
        raise EmptyResultSet

    def get_rhs_op(self, connection, rhs):
        """Return the ``IN`` operator followed by ``rhs``."""
        return 'IN %s' % rhs
default_lookups[In.lookup_name] = In
|
|
|
|
|
|
class PatternLookup(BuiltinLookup):
    """
    Base class for pattern-matching lookups whose right-hand side may be
    either a plain Python value or something that produces SQL.
    """

    def get_rhs_op(self, connection, rhs):
        """
        Return the operator with ``rhs`` interpolated.

        ``connection.pattern_ops`` is used when ``self.rhs`` produces SQL
        itself; otherwise the regular operator from BuiltinLookup is used.
        """
        # Taking startswith as the example, the SQL we need looks like:
        #     col LIKE %s, ['thevalue%']
        # With a plain Python value the pattern characters can (and
        # should) be added to the value in Python, so the regular operator
        # is enough. When the value is for example a reference to another
        # column, the pattern has to be added in SQL instead, by something
        # like:
        #     col LIKE othercol || '%%'
        # which is what the backend's pattern_ops entries are for.
        raw_rhs = self.rhs
        sql_markers = ('get_compiler', 'as_sql', '_as_sql')
        if not any(hasattr(raw_rhs, marker) for marker in sql_markers):
            return super(PatternLookup, self).get_rhs_op(connection, rhs)
        return connection.pattern_ops[self.lookup_name] % rhs
|
|
|
|
|
|
class StartsWith(PatternLookup):
    """
    The 'startswith' lookup.

    The operator is chosen by PatternLookup.get_rhs_op().
    """
    lookup_name = 'startswith'
default_lookups[StartsWith.lookup_name] = StartsWith
|
|
|
|
|
|
class IStartsWith(PatternLookup):
    """
    The 'istartswith' lookup.

    The operator is chosen by PatternLookup.get_rhs_op().
    """
    lookup_name = 'istartswith'
default_lookups[IStartsWith.lookup_name] = IStartsWith
|
|
|
|
|
|
class EndsWith(BuiltinLookup):
    """
    The 'endswith' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['endswith']``.
    """
    # NOTE(review): unlike StartsWith this derives from BuiltinLookup rather
    # than PatternLookup, so connection.pattern_ops is never consulted for
    # a right-hand side that produces SQL -- confirm this is intentional.
    lookup_name = 'endswith'
default_lookups[EndsWith.lookup_name] = EndsWith
|
|
|
|
|
|
class IEndsWith(BuiltinLookup):
    """
    The 'iendswith' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['iendswith']``.
    """
    # NOTE(review): unlike IStartsWith this derives from BuiltinLookup
    # rather than PatternLookup, so connection.pattern_ops is never
    # consulted for a right-hand side that produces SQL -- confirm this is
    # intentional.
    lookup_name = 'iendswith'
default_lookups[IEndsWith.lookup_name] = IEndsWith
|
|
|
|
|
|
class Between(BuiltinLookup):
    """
    Base class for lookups rendered as ``BETWEEN <rhs> AND <rhs>``.

    It defines no ``lookup_name`` of its own; subclasses provide one.
    """

    def get_rhs_op(self, connection, rhs):
        """Return the BETWEEN operator with ``rhs`` used for both bounds."""
        both_bounds = (rhs,) * 2
        return "BETWEEN %s AND %s" % both_bounds
|
|
|
|
|
|
class Year(Between):
    """
    The 'year' lookup, rendered with the BETWEEN operator inherited from
    Between.
    """
    lookup_name = 'year'
default_lookups[Year.lookup_name] = Year
|
|
|
|
|
|
class Range(Between):
    """
    The 'range' lookup, rendered with the BETWEEN operator inherited from
    Between.
    """
    lookup_name = 'range'
default_lookups[Range.lookup_name] = Range
|
|
|
|
|
|
class DateLookup(BuiltinLookup):
    """
    Base class for lookups that compare one component extracted from the
    left-hand side for equality.

    Subclasses set ``extract_type``, which is passed to the backend's
    ``datetime_extract_sql()``.
    """

    def process_lhs(self, qn, connection, lhs=None):
        """
        Compile the left-hand side and wrap it in the backend's extraction
        SQL for ``self.extract_type``.

        ``lhs`` is accepted and forwarded so that the signature matches the
        ``process_lhs()`` this method overrides.
        """
        lhs_sql, lhs_params = super(DateLookup, self).process_lhs(
            qn, connection, lhs)
        # The current time zone name is only passed on when time zone
        # support is enabled.
        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
        sql, tz_params = connection.ops.datetime_extract_sql(
            self.extract_type, lhs_sql, tzname)
        # Keep the params compiled for the lhs instead of dropping them.
        # NOTE(review): assumes the backend places the lhs SQL before its
        # own time zone placeholder in the extraction SQL -- confirm for
        # every backend.
        params = list(lhs_params) + list(tz_params)
        return connection.ops.lookup_cast(self.lookup_name) % sql, params

    def get_rhs_op(self, connection, rhs):
        """Return an equality comparison against ``rhs``."""
        return '= %s' % rhs
|
|
|
|
|
|
class Month(DateLookup):
    """The 'month' lookup; extracts 'month' from the left-hand side."""
    lookup_name = extract_type = 'month'
default_lookups[Month.lookup_name] = Month
|
|
|
|
|
|
class Day(DateLookup):
    """The 'day' lookup; extracts 'day' from the left-hand side."""
    lookup_name = extract_type = 'day'
default_lookups[Day.lookup_name] = Day
|
|
|
|
|
|
class WeekDay(DateLookup):
    """The 'week_day' lookup; extracts 'week_day' from the left-hand side."""
    lookup_name = extract_type = 'week_day'
default_lookups[WeekDay.lookup_name] = WeekDay
|
|
|
|
|
|
class Hour(DateLookup):
    """The 'hour' lookup; extracts 'hour' from the left-hand side."""
    lookup_name = extract_type = 'hour'
default_lookups[Hour.lookup_name] = Hour
|
|
|
|
|
|
class Minute(DateLookup):
    """The 'minute' lookup; extracts 'minute' from the left-hand side."""
    lookup_name = extract_type = 'minute'
default_lookups[Minute.lookup_name] = Minute
|
|
|
|
|
|
class Second(DateLookup):
    """The 'second' lookup; extracts 'second' from the left-hand side."""
    lookup_name = extract_type = 'second'
default_lookups[Second.lookup_name] = Second
|
|
|
|
|
|
class IsNull(BuiltinLookup):
    """
    The 'isnull' lookup: renders ``IS NULL`` when the right-hand side is
    truthy and ``IS NOT NULL`` otherwise.
    """
    lookup_name = 'isnull'

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` testing the compiled lhs for NULL."""
        lhs_sql, lhs_params = qn.compile(self.lhs)
        template = "%s IS NULL" if self.rhs else "%s IS NOT NULL"
        return template % lhs_sql, lhs_params
default_lookups[IsNull.lookup_name] = IsNull
|
|
|
|
|
|
class Search(BuiltinLookup):
    """
    The 'search' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['search']``.
    """
    lookup_name = 'search'
default_lookups[Search.lookup_name] = Search
|
|
|
|
|
|
class Regex(BuiltinLookup):
    """
    The 'regex' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['regex']``.
    """
    lookup_name = 'regex'
default_lookups[Regex.lookup_name] = Regex
|
|
|
|
|
|
class IRegex(BuiltinLookup):
    """
    The 'iregex' lookup.

    SQL generation is inherited unchanged from BuiltinLookup, which takes
    the operator from ``connection.operators['iregex']``.
    """
    lookup_name = 'iregex'
default_lookups[IRegex.lookup_name] = IRegex
|