mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #24629 -- Unified Transform and Expression APIs
This commit is contained in:
@@ -81,14 +81,14 @@ class KeyTransformFactory(object):
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class KeysTransform(lookups.FunctionTransform):
|
||||
class KeysTransform(Transform):
|
||||
lookup_name = 'keys'
|
||||
function = 'akeys'
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class ValuesTransform(lookups.FunctionTransform):
|
||||
class ValuesTransform(Transform):
|
||||
lookup_name = 'values'
|
||||
function = 'avals'
|
||||
output_field = ArrayField(TextField())
|
||||
|
@@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup):
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeStartsWith(lookups.FunctionTransform):
|
||||
class RangeStartsWith(models.Transform):
|
||||
lookup_name = 'startswith'
|
||||
function = 'lower'
|
||||
|
||||
@@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform):
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeEndsWith(lookups.FunctionTransform):
|
||||
class RangeEndsWith(models.Transform):
|
||||
lookup_name = 'endswith'
|
||||
function = 'upper'
|
||||
|
||||
@@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform):
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class IsEmpty(lookups.FunctionTransform):
|
||||
class IsEmpty(models.Transform):
|
||||
lookup_name = 'isempty'
|
||||
function = 'isempty'
|
||||
output_field = models.BooleanField()
|
||||
|
@@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup):
|
||||
return '%s %s %s' % (lhs, self.operator, rhs), params
|
||||
|
||||
|
||||
class FunctionTransform(Transform):
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = qn.compile(self.lhs)
|
||||
return "%s(%s)" % (self.function, lhs), params
|
||||
|
||||
|
||||
class DataContains(PostgresSimpleLookup):
|
||||
lookup_name = 'contains'
|
||||
operator = '@>'
|
||||
@@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup):
|
||||
operator = '?|'
|
||||
|
||||
|
||||
class Unaccent(FunctionTransform):
|
||||
class Unaccent(Transform):
|
||||
bilateral = True
|
||||
lookup_name = 'unaccent'
|
||||
function = 'UNACCENT'
|
||||
|
@@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators
|
||||
# purposes.
|
||||
from django.core.exceptions import FieldDoesNotExist # NOQA
|
||||
from django.db import connection, connections, router
|
||||
from django.db.models.lookups import (
|
||||
Lookup, RegisterLookupMixin, Transform, default_lookups,
|
||||
)
|
||||
from django.db.models.query_utils import QueryWrapper
|
||||
from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin
|
||||
from django.utils import six, timezone
|
||||
from django.utils.datastructures import DictWrapper
|
||||
from django.utils.dateparse import (
|
||||
@@ -120,7 +117,6 @@ class Field(RegisterLookupMixin):
|
||||
'unique_for_date': _("%(field_label)s must be unique for "
|
||||
"%(date_field_label)s %(lookup_type)s."),
|
||||
}
|
||||
class_lookups = default_lookups.copy()
|
||||
system_check_deprecated_details = None
|
||||
system_check_removed_details = None
|
||||
|
||||
@@ -1492,22 +1488,6 @@ class DateTimeField(DateField):
|
||||
return super(DateTimeField, self).formfield(**defaults)
|
||||
|
||||
|
||||
@DateTimeField.register_lookup
|
||||
class DateTimeDateTransform(Transform):
|
||||
lookup_name = 'date'
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
|
||||
lhs_params.extend(tz_params)
|
||||
return sql, lhs_params
|
||||
|
||||
|
||||
class DecimalField(Field):
|
||||
empty_strings_allowed = False
|
||||
default_error_messages = {
|
||||
@@ -2450,146 +2430,3 @@ class UUIDField(Field):
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return super(UUIDField, self).formfield(**defaults)
|
||||
|
||||
|
||||
class DateTransform(Transform):
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
|
||||
params.extend(tz_params)
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
else:
|
||||
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
|
||||
return sql, params
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return IntegerField()
|
||||
|
||||
|
||||
class YearTransform(DateTransform):
|
||||
lookup_name = 'year'
|
||||
|
||||
|
||||
class YearLookup(Lookup):
|
||||
def year_lookup_bounds(self, connection, year):
|
||||
output_field = self.lhs.lhs.output_field
|
||||
if isinstance(output_field, DateTimeField):
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
|
||||
else:
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
|
||||
return bounds
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearExact(YearLookup):
|
||||
lookup_name = 'exact'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# We will need to skip the extract part and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
bounds = self.year_lookup_bounds(connection, rhs_params[0])
|
||||
params.extend(bounds)
|
||||
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
|
||||
|
||||
|
||||
class YearComparisonLookup(YearLookup):
|
||||
def as_sql(self, compiler, connection):
|
||||
# We will need to skip the extract part and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
start, finish = self.year_lookup_bounds(connection, rhs_params[0])
|
||||
params.append(self.get_bound(start, finish))
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
def get_bound(self):
|
||||
raise NotImplementedError(
|
||||
'subclasses of YearComparisonLookup must provide a get_bound() method'
|
||||
)
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearGt(YearComparisonLookup):
|
||||
lookup_name = 'gt'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return finish
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearGte(YearComparisonLookup):
|
||||
lookup_name = 'gte'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return start
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearLt(YearComparisonLookup):
|
||||
lookup_name = 'lt'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return start
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearLte(YearComparisonLookup):
|
||||
lookup_name = 'lte'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return finish
|
||||
|
||||
|
||||
class MonthTransform(DateTransform):
|
||||
lookup_name = 'month'
|
||||
|
||||
|
||||
class DayTransform(DateTransform):
|
||||
lookup_name = 'day'
|
||||
|
||||
|
||||
class WeekDayTransform(DateTransform):
|
||||
lookup_name = 'week_day'
|
||||
|
||||
|
||||
class HourTransform(DateTransform):
|
||||
lookup_name = 'hour'
|
||||
|
||||
|
||||
class MinuteTransform(DateTransform):
|
||||
lookup_name = 'minute'
|
||||
|
||||
|
||||
class SecondTransform(DateTransform):
|
||||
lookup_name = 'second'
|
||||
|
||||
|
||||
DateField.register_lookup(YearTransform)
|
||||
DateField.register_lookup(MonthTransform)
|
||||
DateField.register_lookup(DayTransform)
|
||||
DateField.register_lookup(WeekDayTransform)
|
||||
|
||||
TimeField.register_lookup(HourTransform)
|
||||
TimeField.register_lookup(MinuteTransform)
|
||||
TimeField.register_lookup(SecondTransform)
|
||||
|
||||
DateTimeField.register_lookup(YearTransform)
|
||||
DateTimeField.register_lookup(MonthTransform)
|
||||
DateTimeField.register_lookup(DayTransform)
|
||||
DateTimeField.register_lookup(WeekDayTransform)
|
||||
DateTimeField.register_lookup(HourTransform)
|
||||
DateTimeField.register_lookup(MinuteTransform)
|
||||
DateTimeField.register_lookup(SecondTransform)
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Classes that represent database functions.
|
||||
"""
|
||||
from django.db.models import DateTimeField, IntegerField
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models import (
|
||||
DateTimeField, Func, IntegerField, Transform, Value,
|
||||
)
|
||||
|
||||
|
||||
class Coalesce(Func):
|
||||
@@ -123,9 +124,10 @@ class Least(Func):
|
||||
return super(Least, self).as_sql(compiler, connection, function='MIN')
|
||||
|
||||
|
||||
class Length(Func):
|
||||
class Length(Transform):
|
||||
"""Returns the number of characters in the expression"""
|
||||
function = 'LENGTH'
|
||||
lookup_name = 'length'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
output_field = extra.pop('output_field', IntegerField())
|
||||
@@ -136,8 +138,9 @@ class Length(Func):
|
||||
return super(Length, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class Lower(Func):
|
||||
class Lower(Transform):
|
||||
function = 'LOWER'
|
||||
lookup_name = 'lower'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
super(Lower, self).__init__(expression, **extra)
|
||||
@@ -188,8 +191,9 @@ class Substr(Func):
|
||||
return super(Substr, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
class Upper(Func):
|
||||
class Upper(Transform):
|
||||
function = 'UPPER'
|
||||
lookup_name = 'upper'
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
super(Upper, self).__init__(expression, **extra)
|
||||
|
@@ -1,101 +1,17 @@
|
||||
import inspect
|
||||
from copy import copy
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import (
|
||||
DateField, DateTimeField, Field, IntegerField, TimeField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.six.moves import range
|
||||
|
||||
from .query_utils import QueryWrapper
|
||||
|
||||
|
||||
class RegisterLookupMixin(object):
|
||||
def _get_lookup(self, lookup_name):
|
||||
try:
|
||||
return self.class_lookups[lookup_name]
|
||||
except KeyError:
|
||||
# To allow for inheritance, check parent class' class_lookups.
|
||||
for parent in inspect.getmro(self.__class__):
|
||||
if 'class_lookups' not in parent.__dict__:
|
||||
continue
|
||||
if lookup_name in parent.class_lookups:
|
||||
return parent.class_lookups[lookup_name]
|
||||
except AttributeError:
|
||||
# This class didn't have any class_lookups
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_lookup(lookup_name)
|
||||
if found is not None and not issubclass(found, Lookup):
|
||||
return None
|
||||
return found
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_transform(lookup_name)
|
||||
if found is not None and not issubclass(found, Transform):
|
||||
return None
|
||||
return found
|
||||
|
||||
@classmethod
|
||||
def register_lookup(cls, lookup):
|
||||
if 'class_lookups' not in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup.lookup_name] = lookup
|
||||
return lookup
|
||||
|
||||
@classmethod
|
||||
def _unregister_lookup(cls, lookup):
|
||||
"""
|
||||
Removes given lookup from cls lookups. Meant to be used in
|
||||
tests only.
|
||||
"""
|
||||
del cls.class_lookups[lookup.lookup_name]
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin):
|
||||
|
||||
bilateral = False
|
||||
|
||||
def __init__(self, lhs, lookups):
|
||||
self.lhs = lhs
|
||||
self.init_lookups = lookups[:]
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
raise NotImplementedError
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return self.lhs.output_field
|
||||
|
||||
def copy(self):
|
||||
return copy(self)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
copy = self.copy()
|
||||
copy.lhs = self.lhs.relabeled_clone(relabels)
|
||||
return copy
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return self.lhs.get_group_by_cols()
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if self.bilateral:
|
||||
bilateral_transforms.append((self.__class__, self.init_lookups))
|
||||
return bilateral_transforms
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
return self.lhs.contains_aggregate
|
||||
|
||||
|
||||
class Lookup(RegisterLookupMixin):
|
||||
class Lookup(object):
|
||||
lookup_name = None
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
@@ -115,8 +31,8 @@ class Lookup(RegisterLookupMixin):
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
for transform, lookups in self.bilateral_transforms:
|
||||
value = transform(value, lookups)
|
||||
for transform in self.bilateral_transforms:
|
||||
value = transform(value)
|
||||
return value
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
@@ -125,9 +41,9 @@ class Lookup(RegisterLookupMixin):
|
||||
if self.bilateral_transforms:
|
||||
sqls, sqls_params = [], []
|
||||
for p in rhs:
|
||||
value = QueryWrapper('%s',
|
||||
[self.lhs.output_field.get_db_prep_value(p, connection)])
|
||||
value = Value(p, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
sql, sql_params = compiler.compile(value)
|
||||
sqls.append(sql)
|
||||
sqls_params.extend(sql_params)
|
||||
@@ -155,9 +71,9 @@ class Lookup(RegisterLookupMixin):
|
||||
if self.rhs_is_direct_value():
|
||||
# Do not call get_db_prep_lookup here as the value will be
|
||||
# transformed before being used for lookup
|
||||
value = QueryWrapper("%s",
|
||||
[self.lhs.output_field.get_db_prep_value(value, connection)])
|
||||
value = Value(value, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
# Due to historical reasons there are a couple of different
|
||||
# ways to produce sql here. get_compiler is likely a Query
|
||||
# instance, _as_sql QuerySet and as_sql just something with
|
||||
@@ -201,6 +117,31 @@ class Lookup(RegisterLookupMixin):
|
||||
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
RegisterLookupMixin() is first so that get_lookup() and get_transform()
|
||||
first examine self and then check output_field.
|
||||
"""
|
||||
bilateral = False
|
||||
|
||||
def __init__(self, expression, **extra):
|
||||
# Restrict Transform to allow only a single expression.
|
||||
super(Transform, self).__init__(expression, **extra)
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.get_source_expressions()[0]
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if self.bilateral:
|
||||
bilateral_transforms.append(self.__class__)
|
||||
return bilateral_transforms
|
||||
|
||||
|
||||
class BuiltinLookup(Lookup):
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs_sql, params = super(BuiltinLookup, self).process_lhs(
|
||||
@@ -223,12 +164,9 @@ class BuiltinLookup(Lookup):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
|
||||
default_lookups = {}
|
||||
|
||||
|
||||
class Exact(BuiltinLookup):
|
||||
lookup_name = 'exact'
|
||||
default_lookups['exact'] = Exact
|
||||
Field.register_lookup(Exact)
|
||||
|
||||
|
||||
class IExact(BuiltinLookup):
|
||||
@@ -241,27 +179,27 @@ class IExact(BuiltinLookup):
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['iexact'] = IExact
|
||||
Field.register_lookup(IExact)
|
||||
|
||||
|
||||
class GreaterThan(BuiltinLookup):
|
||||
lookup_name = 'gt'
|
||||
default_lookups['gt'] = GreaterThan
|
||||
Field.register_lookup(GreaterThan)
|
||||
|
||||
|
||||
class GreaterThanOrEqual(BuiltinLookup):
|
||||
lookup_name = 'gte'
|
||||
default_lookups['gte'] = GreaterThanOrEqual
|
||||
Field.register_lookup(GreaterThanOrEqual)
|
||||
|
||||
|
||||
class LessThan(BuiltinLookup):
|
||||
lookup_name = 'lt'
|
||||
default_lookups['lt'] = LessThan
|
||||
Field.register_lookup(LessThan)
|
||||
|
||||
|
||||
class LessThanOrEqual(BuiltinLookup):
|
||||
lookup_name = 'lte'
|
||||
default_lookups['lte'] = LessThanOrEqual
|
||||
Field.register_lookup(LessThanOrEqual)
|
||||
|
||||
|
||||
class In(BuiltinLookup):
|
||||
@@ -286,10 +224,14 @@ class In(BuiltinLookup):
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
if self.rhs_is_direct_value() and (max_in_list_size and
|
||||
len(self.rhs) > max_in_list_size):
|
||||
# This is a special case for Oracle which limits the number of elements
|
||||
# which can appear in an 'IN' clause.
|
||||
if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
|
||||
return self.split_parameter_list_as_sql(compiler, connection)
|
||||
return super(In, self).as_sql(compiler, connection)
|
||||
|
||||
def split_parameter_list_as_sql(self, compiler, connection):
|
||||
# This is a special case for databases which limit the number of
|
||||
# elements which can appear in an 'IN' clause.
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
|
||||
in_clause_elements = ['(']
|
||||
@@ -307,11 +249,7 @@ class In(BuiltinLookup):
|
||||
params.extend(sqls_params)
|
||||
in_clause_elements.append(')')
|
||||
return ''.join(in_clause_elements), params
|
||||
else:
|
||||
return super(In, self).as_sql(compiler, connection)
|
||||
|
||||
|
||||
default_lookups['in'] = In
|
||||
Field.register_lookup(In)
|
||||
|
||||
|
||||
class PatternLookup(BuiltinLookup):
|
||||
@@ -342,16 +280,12 @@ class Contains(PatternLookup):
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['contains'] = Contains
|
||||
Field.register_lookup(Contains)
|
||||
|
||||
|
||||
class IContains(Contains):
|
||||
lookup_name = 'icontains'
|
||||
|
||||
|
||||
default_lookups['icontains'] = IContains
|
||||
Field.register_lookup(IContains)
|
||||
|
||||
|
||||
class StartsWith(PatternLookup):
|
||||
@@ -362,9 +296,7 @@ class StartsWith(PatternLookup):
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['startswith'] = StartsWith
|
||||
Field.register_lookup(StartsWith)
|
||||
|
||||
|
||||
class IStartsWith(PatternLookup):
|
||||
@@ -375,9 +307,7 @@ class IStartsWith(PatternLookup):
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['istartswith'] = IStartsWith
|
||||
Field.register_lookup(IStartsWith)
|
||||
|
||||
|
||||
class EndsWith(PatternLookup):
|
||||
@@ -388,9 +318,7 @@ class EndsWith(PatternLookup):
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['endswith'] = EndsWith
|
||||
Field.register_lookup(EndsWith)
|
||||
|
||||
|
||||
class IEndsWith(PatternLookup):
|
||||
@@ -401,9 +329,7 @@ class IEndsWith(PatternLookup):
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['iendswith'] = IEndsWith
|
||||
Field.register_lookup(IEndsWith)
|
||||
|
||||
|
||||
class Between(BuiltinLookup):
|
||||
@@ -424,8 +350,7 @@ class Range(BuiltinLookup):
|
||||
return self.batch_process_rhs(compiler, connection)
|
||||
else:
|
||||
return super(Range, self).process_rhs(compiler, connection)
|
||||
|
||||
default_lookups['range'] = Range
|
||||
Field.register_lookup(Range)
|
||||
|
||||
|
||||
class IsNull(BuiltinLookup):
|
||||
@@ -437,7 +362,7 @@ class IsNull(BuiltinLookup):
|
||||
return "%s IS NULL" % sql, params
|
||||
else:
|
||||
return "%s IS NOT NULL" % sql, params
|
||||
default_lookups['isnull'] = IsNull
|
||||
Field.register_lookup(IsNull)
|
||||
|
||||
|
||||
class Search(BuiltinLookup):
|
||||
@@ -448,8 +373,7 @@ class Search(BuiltinLookup):
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
|
||||
return sql_template, lhs_params + rhs_params
|
||||
|
||||
default_lookups['search'] = Search
|
||||
Field.register_lookup(Search)
|
||||
|
||||
|
||||
class Regex(BuiltinLookup):
|
||||
@@ -463,9 +387,168 @@ class Regex(BuiltinLookup):
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
sql_template = connection.ops.regex_lookup(self.lookup_name)
|
||||
return sql_template % (lhs, rhs), lhs_params + rhs_params
|
||||
default_lookups['regex'] = Regex
|
||||
Field.register_lookup(Regex)
|
||||
|
||||
|
||||
class IRegex(Regex):
|
||||
lookup_name = 'iregex'
|
||||
default_lookups['iregex'] = IRegex
|
||||
Field.register_lookup(IRegex)
|
||||
|
||||
|
||||
class DateTimeDateTransform(Transform):
|
||||
lookup_name = 'date'
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
|
||||
lhs_params.extend(tz_params)
|
||||
return sql, lhs_params
|
||||
|
||||
|
||||
class DateTransform(Transform):
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
|
||||
params.extend(tz_params)
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
else:
|
||||
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
|
||||
return sql, params
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return IntegerField()
|
||||
|
||||
|
||||
class YearTransform(DateTransform):
|
||||
lookup_name = 'year'
|
||||
|
||||
|
||||
class YearLookup(Lookup):
|
||||
def year_lookup_bounds(self, connection, year):
|
||||
output_field = self.lhs.lhs.output_field
|
||||
if isinstance(output_field, DateTimeField):
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
|
||||
else:
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
|
||||
return bounds
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearExact(YearLookup):
|
||||
lookup_name = 'exact'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# We will need to skip the extract part and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
bounds = self.year_lookup_bounds(connection, rhs_params[0])
|
||||
params.extend(bounds)
|
||||
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
|
||||
|
||||
|
||||
class YearComparisonLookup(YearLookup):
|
||||
def as_sql(self, compiler, connection):
|
||||
# We will need to skip the extract part and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs.
|
||||
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
start, finish = self.year_lookup_bounds(connection, rhs_params[0])
|
||||
params.append(self.get_bound(start, finish))
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
def get_bound(self):
|
||||
raise NotImplementedError(
|
||||
'subclasses of YearComparisonLookup must provide a get_bound() method'
|
||||
)
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearGt(YearComparisonLookup):
|
||||
lookup_name = 'gt'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return finish
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearGte(YearComparisonLookup):
|
||||
lookup_name = 'gte'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return start
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearLt(YearComparisonLookup):
|
||||
lookup_name = 'lt'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return start
|
||||
|
||||
|
||||
@YearTransform.register_lookup
|
||||
class YearLte(YearComparisonLookup):
|
||||
lookup_name = 'lte'
|
||||
|
||||
def get_bound(self, start, finish):
|
||||
return finish
|
||||
|
||||
|
||||
class MonthTransform(DateTransform):
|
||||
lookup_name = 'month'
|
||||
|
||||
|
||||
class DayTransform(DateTransform):
|
||||
lookup_name = 'day'
|
||||
|
||||
|
||||
class WeekDayTransform(DateTransform):
|
||||
lookup_name = 'week_day'
|
||||
|
||||
|
||||
class HourTransform(DateTransform):
|
||||
lookup_name = 'hour'
|
||||
|
||||
|
||||
class MinuteTransform(DateTransform):
|
||||
lookup_name = 'minute'
|
||||
|
||||
|
||||
class SecondTransform(DateTransform):
|
||||
lookup_name = 'second'
|
||||
|
||||
|
||||
DateField.register_lookup(YearTransform)
|
||||
DateField.register_lookup(MonthTransform)
|
||||
DateField.register_lookup(DayTransform)
|
||||
DateField.register_lookup(WeekDayTransform)
|
||||
|
||||
TimeField.register_lookup(HourTransform)
|
||||
TimeField.register_lookup(MinuteTransform)
|
||||
TimeField.register_lookup(SecondTransform)
|
||||
|
||||
DateTimeField.register_lookup(DateTimeDateTransform)
|
||||
DateTimeField.register_lookup(YearTransform)
|
||||
DateTimeField.register_lookup(MonthTransform)
|
||||
DateTimeField.register_lookup(DayTransform)
|
||||
DateTimeField.register_lookup(WeekDayTransform)
|
||||
DateTimeField.register_lookup(HourTransform)
|
||||
DateTimeField.register_lookup(MinuteTransform)
|
||||
DateTimeField.register_lookup(SecondTransform)
|
||||
|
@@ -7,6 +7,7 @@ circular import difficulties.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
|
||||
from django.apps import apps
|
||||
@@ -169,6 +170,60 @@ class DeferredAttribute(object):
|
||||
return None
|
||||
|
||||
|
||||
class RegisterLookupMixin(object):
|
||||
def _get_lookup(self, lookup_name):
|
||||
try:
|
||||
return self.class_lookups[lookup_name]
|
||||
except KeyError:
|
||||
# To allow for inheritance, check parent class' class_lookups.
|
||||
for parent in inspect.getmro(self.__class__):
|
||||
if 'class_lookups' not in parent.__dict__:
|
||||
continue
|
||||
if lookup_name in parent.class_lookups:
|
||||
return parent.class_lookups[lookup_name]
|
||||
except AttributeError:
|
||||
# This class didn't have any class_lookups
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
from django.db.models.lookups import Lookup
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_lookup(lookup_name)
|
||||
if found is not None and not issubclass(found, Lookup):
|
||||
return None
|
||||
return found
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
from django.db.models.lookups import Transform
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_transform(lookup_name)
|
||||
if found is not None and not issubclass(found, Transform):
|
||||
return None
|
||||
return found
|
||||
|
||||
@classmethod
|
||||
def register_lookup(cls, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if 'class_lookups' not in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup_name] = lookup
|
||||
return lookup
|
||||
|
||||
@classmethod
|
||||
def _unregister_lookup(cls, lookup, lookup_name=None):
|
||||
"""
|
||||
Remove given lookup from cls lookups. For use in tests only as it's
|
||||
not thread-safe.
|
||||
"""
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
del cls.class_lookups[lookup_name]
|
||||
|
||||
|
||||
def select_related_descend(field, restricted, requested, load_fields, reverse=False):
|
||||
"""
|
||||
Returns True if this field should be used to descend deeper for
|
||||
|
@@ -5,7 +5,7 @@ import copy
|
||||
import warnings
|
||||
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.lookups import RegisterLookupMixin
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.deprecation import RemovedInDjango110Warning
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
@@ -1105,9 +1105,9 @@ class Query(object):
|
||||
Helper method for build_lookup. Tries to fetch and initialize
|
||||
a transform for name parameter from lhs.
|
||||
"""
|
||||
next = lhs.get_transform(name)
|
||||
if next:
|
||||
return next(lhs, rest_of_lookups)
|
||||
transform_class = lhs.get_transform(name)
|
||||
if transform_class:
|
||||
return transform_class(lhs)
|
||||
else:
|
||||
raise FieldError(
|
||||
"Unsupported lookup '%s' for %s or join on the field not "
|
||||
|
@@ -120,10 +120,7 @@ function ``ABS()`` to transform the value before comparison::
|
||||
|
||||
class AbsoluteValue(Transform):
|
||||
lookup_name = 'abs'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "ABS(%s)" % lhs, params
|
||||
function = 'ABS'
|
||||
|
||||
Next, let's register it for ``IntegerField``::
|
||||
|
||||
@@ -157,10 +154,7 @@ be done by adding an ``output_field`` attribute to the transform::
|
||||
|
||||
class AbsoluteValue(Transform):
|
||||
lookup_name = 'abs'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "ABS(%s)" % lhs, params
|
||||
function = 'ABS'
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
@@ -243,12 +237,9 @@ this transformation should apply to both ``lhs`` and ``rhs``::
|
||||
|
||||
class UpperCase(Transform):
|
||||
lookup_name = 'upper'
|
||||
function = 'UPPER'
|
||||
bilateral = True
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "UPPER(%s)" % lhs, params
|
||||
|
||||
Next, let's register it::
|
||||
|
||||
from django.db.models import CharField, TextField
|
||||
|
@@ -180,6 +180,18 @@ Usage example::
|
||||
>>> print(author.name_length, author.goes_by_length)
|
||||
(14, None)
|
||||
|
||||
It can also be registered as a transform. For example::
|
||||
|
||||
>>> from django.db.models import CharField
|
||||
>>> from django.db.models.functions import Length
|
||||
>>> CharField.register_lookup(Length, 'length')
|
||||
>>> # Get authors whose name is longer than 7 characters
|
||||
>>> authors = Author.objects.filter(name__length__gt=7)
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
The ability to register the function as a transform was added.
|
||||
|
||||
Lower
|
||||
------
|
||||
|
||||
@@ -188,6 +200,8 @@ Lower
|
||||
Accepts a single text field or expression and returns the lowercase
|
||||
representation.
|
||||
|
||||
It can also be registered as a transform as described in :class:`Length`.
|
||||
|
||||
Usage example::
|
||||
|
||||
>>> from django.db.models.functions import Lower
|
||||
@@ -196,6 +210,10 @@ Usage example::
|
||||
>>> print(author.name_lower)
|
||||
margaret smith
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
The ability to register the function as a transform was added.
|
||||
|
||||
Now
|
||||
---
|
||||
|
||||
@@ -246,6 +264,8 @@ Upper
|
||||
Accepts a single text field or expression and returns the uppercase
|
||||
representation.
|
||||
|
||||
It can also be registered as a transform as described in :class:`Length`.
|
||||
|
||||
Usage example::
|
||||
|
||||
>>> from django.db.models.functions import Upper
|
||||
@@ -253,3 +273,7 @@ Usage example::
|
||||
>>> author = Author.objects.annotate(name_upper=Upper('name')).get()
|
||||
>>> print(author.name_upper)
|
||||
MARGARET SMITH
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
The ability to register the function as a transform was added.
|
||||
|
@@ -42,12 +42,17 @@ register lookups on itself. The two prominent examples are
|
||||
|
||||
A mixin that implements the lookup API on a class.
|
||||
|
||||
.. classmethod:: register_lookup(lookup)
|
||||
.. classmethod:: register_lookup(lookup, lookup_name=None)
|
||||
|
||||
Registers a new lookup in the class. For example
|
||||
``DateField.register_lookup(YearExact)`` will register ``YearExact``
|
||||
lookup on ``DateField``. It overrides a lookup that already exists with
|
||||
the same name.
|
||||
the same name. ``lookup_name`` will be used for this lookup if
|
||||
provided, otherwise ``lookup.lookup_name`` will be used.
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
The ``lookup_name`` parameter was added.
|
||||
|
||||
.. method:: get_lookup(lookup_name)
|
||||
|
||||
@@ -125,7 +130,14 @@ Transform reference
|
||||
``<expression>__<transformation>`` (e.g. ``date__year``).
|
||||
|
||||
This class follows the :ref:`Query Expression API <query-expression>`, which
|
||||
implies that you can use ``<expression>__<transform1>__<transform2>``.
|
||||
implies that you can use ``<expression>__<transform1>__<transform2>``. It's
|
||||
a specialized :ref:`Func() expression <func-expressions>` that only accepts
|
||||
one argument. It can also be used on the right hand side of a filter or
|
||||
directly as an annotation.
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
``Transform`` is now a subclass of ``Func``.
|
||||
|
||||
.. attribute:: bilateral
|
||||
|
||||
@@ -152,18 +164,6 @@ Transform reference
|
||||
:class:`~django.db.models.Field` instance. By default is the same as
|
||||
its ``lhs.output_field``.
|
||||
|
||||
.. method:: as_sql
|
||||
|
||||
To be overridden; raises :exc:`NotImplementedError`.
|
||||
|
||||
.. method:: get_lookup(lookup_name)
|
||||
|
||||
Same as :meth:`~lookups.RegisterLookupMixin.get_lookup()`.
|
||||
|
||||
.. method:: get_transform(transform_name)
|
||||
|
||||
Same as :meth:`~lookups.RegisterLookupMixin.get_transform()`.
|
||||
|
||||
Lookup reference
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@@ -520,6 +520,14 @@ Models
|
||||
* Added the :class:`~django.db.models.functions.Now` database function, which
|
||||
returns the current date and time.
|
||||
|
||||
* :class:`~django.db.models.Transform` is now a subclass of
|
||||
:ref:`Func() <func-expressions>` which allows ``Transform``\s to be used on
|
||||
the right hand side of an expression, just like regular ``Func``\s. This
|
||||
allows registering some database functions like
|
||||
:class:`~django.db.models.functions.Length`,
|
||||
:class:`~django.db.models.functions.Lower`, and
|
||||
:class:`~django.db.models.functions.Upper` as transforms.
|
||||
|
||||
* :class:`~django.db.models.SlugField` now accepts an
|
||||
:attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode
|
||||
characters in slugs.
|
||||
|
@@ -126,11 +126,17 @@ class YearLte(models.lookups.LessThanOrEqual):
|
||||
return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
|
||||
|
||||
|
||||
class SQLFunc(models.Lookup):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
super(SQLFunc, self).__init__(*args, **kwargs)
|
||||
self.name = name
|
||||
class Exactly(models.lookups.Exact):
|
||||
"""
|
||||
This lookup is used to test lookup registration.
|
||||
"""
|
||||
lookup_name = 'exactly'
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators['exact'] % rhs
|
||||
|
||||
|
||||
class SQLFuncMixin(object):
|
||||
def as_sql(self, compiler, connection):
|
||||
return '%s()', [self.name]
|
||||
|
||||
@@ -139,13 +145,28 @@ class SQLFunc(models.Lookup):
|
||||
return CustomField()
|
||||
|
||||
|
||||
class SQLFuncLookup(SQLFuncMixin, models.Lookup):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
super(SQLFuncLookup, self).__init__(*args, **kwargs)
|
||||
self.name = name
|
||||
|
||||
|
||||
class SQLFuncTransform(SQLFuncMixin, models.Transform):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
super(SQLFuncTransform, self).__init__(*args, **kwargs)
|
||||
self.name = name
|
||||
|
||||
|
||||
class SQLFuncFactory(object):
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, key, name):
|
||||
self.key = key
|
||||
self.name = name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return SQLFunc(self.name, *args, **kwargs)
|
||||
if self.key == 'lookupfunc':
|
||||
return SQLFuncLookup(self.name, *args, **kwargs)
|
||||
return SQLFuncTransform(self.name, *args, **kwargs)
|
||||
|
||||
|
||||
class CustomField(models.TextField):
|
||||
@@ -153,13 +174,13 @@ class CustomField(models.TextField):
|
||||
def get_lookup(self, lookup_name):
|
||||
if lookup_name.startswith('lookupfunc_'):
|
||||
key, name = lookup_name.split('_', 1)
|
||||
return SQLFuncFactory(name)
|
||||
return SQLFuncFactory(key, name)
|
||||
return super(CustomField, self).get_lookup(lookup_name)
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
if lookup_name.startswith('transformfunc_'):
|
||||
key, name = lookup_name.split('_', 1)
|
||||
return SQLFuncFactory(name)
|
||||
return SQLFuncFactory(key, name)
|
||||
return super(CustomField, self).get_transform(lookup_name)
|
||||
|
||||
|
||||
@@ -200,6 +221,27 @@ class DateTimeTransform(models.Transform):
|
||||
|
||||
|
||||
class LookupTests(TestCase):
|
||||
|
||||
def test_custom_name_lookup(self):
|
||||
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||
Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||
custom_lookup_name = 'isactually'
|
||||
custom_transform_name = 'justtheyear'
|
||||
try:
|
||||
models.DateField.register_lookup(YearTransform)
|
||||
models.DateField.register_lookup(YearTransform, custom_transform_name)
|
||||
YearTransform.register_lookup(Exactly)
|
||||
YearTransform.register_lookup(Exactly, custom_lookup_name)
|
||||
qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
|
||||
qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
|
||||
self.assertQuerysetEqual(qs1, [a1], lambda x: x)
|
||||
self.assertQuerysetEqual(qs2, [a1], lambda x: x)
|
||||
finally:
|
||||
YearTransform._unregister_lookup(Exactly)
|
||||
YearTransform._unregister_lookup(Exactly, custom_lookup_name)
|
||||
models.DateField._unregister_lookup(YearTransform)
|
||||
models.DateField._unregister_lookup(YearTransform, custom_transform_name)
|
||||
|
||||
def test_basic_lookup(self):
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
@@ -299,6 +341,19 @@ class BilateralTransformTests(TestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
|
||||
|
||||
def test_bilateral_multi_value(self):
|
||||
with register_lookup(models.CharField, UpperBilateralTransform):
|
||||
Author.objects.bulk_create([
|
||||
Author(name='Foo'),
|
||||
Author(name='Bar'),
|
||||
Author(name='Ray'),
|
||||
])
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'),
|
||||
['Bar', 'Foo'],
|
||||
lambda a: a.name
|
||||
)
|
||||
|
||||
def test_div3_bilateral_extract(self):
|
||||
with register_lookup(models.IntegerField, Div3BilateralTransform):
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
|
@@ -547,3 +547,97 @@ class FunctionTests(TestCase):
|
||||
['How to Time Travel'],
|
||||
lambda a: a.title
|
||||
)
|
||||
|
||||
def test_length_transform(self):
|
||||
try:
|
||||
CharField.register_lookup(Length, 'length')
|
||||
Author.objects.create(name='John Smith', alias='smithj')
|
||||
Author.objects.create(name='Rhonda')
|
||||
authors = Author.objects.filter(name__length__gt=7)
|
||||
self.assertQuerysetEqual(
|
||||
authors.order_by('name'), [
|
||||
'John Smith',
|
||||
],
|
||||
lambda a: a.name
|
||||
)
|
||||
finally:
|
||||
CharField._unregister_lookup(Length, 'length')
|
||||
|
||||
def test_lower_transform(self):
|
||||
try:
|
||||
CharField.register_lookup(Lower, 'lower')
|
||||
Author.objects.create(name='John Smith', alias='smithj')
|
||||
Author.objects.create(name='Rhonda')
|
||||
authors = Author.objects.filter(name__lower__exact='john smith')
|
||||
self.assertQuerysetEqual(
|
||||
authors.order_by('name'), [
|
||||
'John Smith',
|
||||
],
|
||||
lambda a: a.name
|
||||
)
|
||||
finally:
|
||||
CharField._unregister_lookup(Lower, 'lower')
|
||||
|
||||
def test_upper_transform(self):
|
||||
try:
|
||||
CharField.register_lookup(Upper, 'upper')
|
||||
Author.objects.create(name='John Smith', alias='smithj')
|
||||
Author.objects.create(name='Rhonda')
|
||||
authors = Author.objects.filter(name__upper__exact='JOHN SMITH')
|
||||
self.assertQuerysetEqual(
|
||||
authors.order_by('name'), [
|
||||
'John Smith',
|
||||
],
|
||||
lambda a: a.name
|
||||
)
|
||||
finally:
|
||||
CharField._unregister_lookup(Upper, 'upper')
|
||||
|
||||
def test_func_transform_bilateral(self):
|
||||
class UpperBilateral(Upper):
|
||||
bilateral = True
|
||||
|
||||
try:
|
||||
CharField.register_lookup(UpperBilateral, 'upper')
|
||||
Author.objects.create(name='John Smith', alias='smithj')
|
||||
Author.objects.create(name='Rhonda')
|
||||
authors = Author.objects.filter(name__upper__exact='john smith')
|
||||
self.assertQuerysetEqual(
|
||||
authors.order_by('name'), [
|
||||
'John Smith',
|
||||
],
|
||||
lambda a: a.name
|
||||
)
|
||||
finally:
|
||||
CharField._unregister_lookup(UpperBilateral, 'upper')
|
||||
|
||||
def test_func_transform_bilateral_multivalue(self):
|
||||
class UpperBilateral(Upper):
|
||||
bilateral = True
|
||||
|
||||
try:
|
||||
CharField.register_lookup(UpperBilateral, 'upper')
|
||||
Author.objects.create(name='John Smith', alias='smithj')
|
||||
Author.objects.create(name='Rhonda')
|
||||
authors = Author.objects.filter(name__upper__in=['john smith', 'rhonda'])
|
||||
self.assertQuerysetEqual(
|
||||
authors.order_by('name'), [
|
||||
'John Smith',
|
||||
'Rhonda',
|
||||
],
|
||||
lambda a: a.name
|
||||
)
|
||||
finally:
|
||||
CharField._unregister_lookup(UpperBilateral, 'upper')
|
||||
|
||||
def test_function_as_filter(self):
|
||||
Author.objects.create(name='John Smith', alias='SMITHJ')
|
||||
Author.objects.create(name='Rhonda')
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(alias=Upper(V('smithj'))),
|
||||
['John Smith'], lambda x: x.name
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.exclude(alias=Upper(V('smithj'))),
|
||||
['Rhonda'], lambda x: x.name
|
||||
)
|
||||
|
Reference in New Issue
Block a user