diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 96eb8c8776..706e37a6bd 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -19,6 +19,9 @@ class MultiColSource: return self.__class__(relabels.get(self.alias, self.alias), self.targets, self.sources, self.field) + def get_lookup(self, lookup): + return self.output_field.get_lookup(lookup) + def get_normalized_value(value, lhs): from django.db.models import Model diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index bdec204351..ccbbe0cd5f 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -964,19 +964,9 @@ class Query: def as_sql(self, compiler, connection): return self.get_compiler(connection=connection).as_sql() - def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True): - # Default lookup if none given is exact. + def resolve_lookup_value(self, value, can_reuse, allow_joins): used_joins = set() - if len(lookups) == 0: - lookups = ['exact'] - # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all - # uses of None as a query value. - if value is None: - if lookups[-1] not in ('exact', 'iexact'): - raise ValueError("Cannot use None as a query value") - lookups[-1] = 'isnull' - return True, lookups, used_joins - elif hasattr(value, 'resolve_expression'): + if hasattr(value, 'resolve_expression'): pre_joins = self.alias_refcount.copy() value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} @@ -993,15 +983,7 @@ class Query: # The used_joins for a tuple of expressions is the union of # the used_joins for the individual expressions. used_joins.update(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)) - # For Oracle '' is equivalent to null. The check needs to be done - # at this stage because join promotion can't be done at compiler - # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we - # can do here. Similar thing is done in is_nullable(), too. - if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and - lookups[-1] == 'exact' and value == ''): - value = True - lookups[-1] = 'isnull' - return value, lookups, used_joins + return value, used_joins def solve_lookup_type(self, lookup): """ @@ -1014,13 +996,11 @@ class Query: return expression_lookups, (), expression _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] - if len(lookup_parts) == 0: - lookup_parts = ['exact'] - elif len(lookup_parts) > 1: - if not field_parts: - raise FieldError( - 'Invalid lookup "%s" for model %s".' % - (lookup, self.get_meta().model.__name__)) + if len(lookup_parts) > 1 and not field_parts: + raise FieldError( + 'Invalid lookup "%s" for model %s".' % + (lookup, self.get_meta().model.__name__) + ) return lookup_parts, field_parts, False def check_query_object_type(self, value, opts, field): @@ -1063,23 +1043,43 @@ class Query: The lookups is a list of names to extract using get_lookup() and get_transform(). """ - lookups = lookups[:] - while lookups: - name = lookups[0] - # If there is just one part left, try first get_lookup() so - # that if the lhs supports both transform and lookup for the - # name, then lookup will be picked. - if len(lookups) == 1: - final_lookup = lhs.get_lookup(name) - if not final_lookup: - # We didn't find a lookup. We are going to interpret - # the name as transform, and do an Exact lookup against - # it. - lhs = self.try_transform(lhs, name) - final_lookup = lhs.get_lookup('exact') - return final_lookup(lhs, rhs) + # __exact is the default lookup if one isn't given. + if len(lookups) == 0: + lookups = ['exact'] + + for name in lookups[:-1]: lhs = self.try_transform(lhs, name) - lookups = lookups[1:] + # First try get_lookup() so that the lookup takes precedence if the lhs + # supports both transform and lookup for the name. + lookup_class = lhs.get_lookup(lookups[-1]) + if not lookup_class: + if lhs.field.is_relation: + raise FieldError('Related Field got invalid lookup: {}'.format(lookups[-1])) + # A lookup wasn't found. Try to interpret the name as a transform + # and do an Exact lookup against it. + lhs = self.try_transform(lhs, lookups[-1]) + lookup_class = lhs.get_lookup('exact') + + if not lookup_class: + return + + lookup = lookup_class(lhs, rhs) + # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all + # uses of None as a query value. + if lookup.rhs is None: + if lookup.lookup_name not in ('exact', 'iexact'): + raise ValueError("Cannot use None as a query value") + return lhs.get_lookup('isnull')(lhs, True) + + # For Oracle '' is equivalent to null. The check must be done at this + # stage because join promotion can't be done in the compiler. Using + # DEFAULT_DB_ALIAS isn't nice but it's the best that can be done here. + # A similar thing is done in is_nullable(), too. + if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and + lookup.lookup_name == 'exact' and lookup.rhs == ''): + return lhs.get_lookup('isnull')(lhs, True) + + return lookup def try_transform(self, lhs, name): """ @@ -1133,7 +1133,7 @@ class Query: # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins) + value, used_joins = self.resolve_lookup_value(value, can_reuse, allow_joins) clause = self.where_class() if reffed_expression: @@ -1173,25 +1173,19 @@ class Query: num_lookups = len(lookups) if num_lookups > 1: raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) - assert num_lookups > 0 # Likely a bug in Django if this fails. - lookup_class = field.get_lookup(lookups[0]) - if lookup_class is None: - raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) if len(targets) == 1: - lhs = targets[0].get_col(alias, field) + col = targets[0].get_col(alias, field) else: - lhs = MultiColSource(alias, targets, sources, field) - condition = lookup_class(lhs, value) - lookup_type = lookup_class.lookup_name + col = MultiColSource(alias, targets, sources, field) else: col = targets[0].get_col(alias, field) - condition = self.build_lookup(lookups, col, value) - lookup_type = condition.lookup_name + condition = self.build_lookup(lookups, col, value) + lookup_type = condition.lookup_name clause.add(condition, AND) - require_outer = lookup_type == 'isnull' and value is True and not current_negated - if current_negated and (lookup_type != 'isnull' or value is False): + require_outer = lookup_type == 'isnull' and condition.rhs is True and not current_negated + if current_negated and (lookup_type != 'isnull' or condition.rhs is False): require_outer = True if (lookup_type != 'isnull' and ( self.is_nullable(targets[0]) or diff --git a/tests/lookup/models.py b/tests/lookup/models.py index 1c7ea799a6..d58d863885 100644 --- a/tests/lookup/models.py +++ b/tests/lookup/models.py @@ -44,7 +44,8 @@ class Tag(models.Model): class NulledTextField(models.TextField): - pass + def get_prep_value(self, value): + return None if value == '' else value @NulledTextField.register_lookup diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index 37e98f43e2..7b08c778df 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -846,3 +846,13 @@ class LookupTests(TestCase): self.assertFalse(Season.objects.filter(nulled_text_field__isnull=True)) self.assertTrue(Season.objects.filter(nulled_text_field__nulled__isnull=True)) self.assertTrue(Season.objects.filter(nulled_text_field__nulled__exact=None)) + self.assertTrue(Season.objects.filter(nulled_text_field__nulled=None)) + + def test_custom_field_none_rhs(self): + """ + __exact=value is transformed to __isnull=True if Field.get_prep_value() + converts value to None. + """ + season = Season.objects.create(year=2012, nulled_text_field=None) + self.assertTrue(Season.objects.filter(pk=season.pk, nulled_text_field__isnull=True)) + self.assertTrue(Season.objects.filter(pk=season.pk, nulled_text_field=''))