From 7aeb7390fc4231119494a9ebdee3c6ee0d5af053 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 11 Aug 2016 14:16:48 -0400 Subject: [PATCH] Fixed #26891 -- Fixed lookup registration for ForeignObject. --- django/db/models/fields/related.py | 37 ++++++++++------------ django/db/models/query_utils.py | 51 ++++++++++++++++++++++-------- tests/custom_lookups/models.py | 4 +++ tests/custom_lookups/tests.py | 24 +++++++++++++- 4 files changed, 81 insertions(+), 35 deletions(-) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index a41e93b73f..67ecefb336 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import inspect import warnings from functools import partial @@ -17,6 +18,7 @@ from django.utils import six from django.utils.deprecation import RemovedInDjango20Warning from django.utils.encoding import force_text from django.utils.functional import cached_property, curry +from django.utils.lru_cache import lru_cache from django.utils.translation import ugettext_lazy as _ from django.utils.version import get_docs_version @@ -731,26 +733,13 @@ class ForeignObject(RelatedField): pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] return pathinfos - def get_lookup(self, lookup_name): - if lookup_name == 'in': - return RelatedIn - elif lookup_name == 'exact': - return RelatedExact - elif lookup_name == 'gt': - return RelatedGreaterThan - elif lookup_name == 'gte': - return RelatedGreaterThanOrEqual - elif lookup_name == 'lt': - return RelatedLessThan - elif lookup_name == 'lte': - return RelatedLessThanOrEqual - elif lookup_name == 'isnull': - return RelatedIsNull - else: - raise TypeError('Related Field got invalid lookup: %s' % lookup_name) - - def get_transform(self, *args, **kwargs): - raise NotImplementedError('Relational fields do not support transforms.') + @classmethod + @lru_cache(maxsize=None) + def get_lookups(cls): + bases = inspect.getmro(cls) + bases = bases[:bases.index(ForeignObject) + 1] + class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in bases] + return cls.merge_dicts(class_lookups) def contribute_to_class(self, cls, name, private_only=False, **kwargs): super(ForeignObject, self).contribute_to_class(cls, name, private_only=private_only, **kwargs) @@ -767,6 +756,14 @@ class ForeignObject(RelatedField): if self.remote_field.limit_choices_to: cls._meta.related_fkey_lookups.append(self.remote_field.limit_choices_to) +ForeignObject.register_lookup(RelatedIn) +ForeignObject.register_lookup(RelatedExact) +ForeignObject.register_lookup(RelatedLessThan) +ForeignObject.register_lookup(RelatedGreaterThan) +ForeignObject.register_lookup(RelatedGreaterThanOrEqual) +ForeignObject.register_lookup(RelatedLessThanOrEqual) +ForeignObject.register_lookup(RelatedIsNull) + class ForeignKey(ForeignObject): """ diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index b71cd5649c..111ba65f3c 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -13,6 +13,7 @@ from collections import namedtuple from django.core.exceptions import FieldDoesNotExist from django.db.models.constants import LOOKUP_SEP from django.utils import tree +from django.utils.lru_cache import lru_cache # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both @@ -27,6 +28,15 @@ class InvalidQuery(Exception): pass +def subclasses(cls): + yield cls + # Python 2 lacks 'yield from', which could replace the inner loop + for subclass in cls.__subclasses__(): + # yield from subclasses(subclass) + for item in subclasses(subclass): + yield item + + class QueryWrapper(object): """ A type that indicates the contents are an SQL fragment and the associate @@ -132,20 +142,16 @@ class DeferredAttribute(object): class RegisterLookupMixin(object): - def _get_lookup(self, lookup_name): - try: - return self.class_lookups[lookup_name] - except KeyError: - # To allow for inheritance, check parent class' class_lookups. - for parent in inspect.getmro(self.__class__): - if 'class_lookups' not in parent.__dict__: - continue - if lookup_name in parent.class_lookups: - return parent.class_lookups[lookup_name] - except AttributeError: - # This class didn't have any class_lookups - pass - return None + + @classmethod + def _get_lookup(cls, lookup_name): + return cls.get_lookups().get(lookup_name, None) + + @classmethod + @lru_cache(maxsize=None) + def get_lookups(cls): + class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(cls)] + return cls.merge_dicts(class_lookups) def get_lookup(self, lookup_name): from django.db.models.lookups import Lookup @@ -165,6 +171,22 @@ class RegisterLookupMixin(object): return None return found + @staticmethod + def merge_dicts(dicts): + """ + Merge dicts in reverse to preference the order of the original list. e.g., + merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'. + """ + merged = {} + for d in reversed(dicts): + merged.update(d) + return merged + + @classmethod + def _clear_cached_lookups(cls): + for subclass in subclasses(cls): + subclass.get_lookups.cache_clear() + @classmethod def register_lookup(cls, lookup, lookup_name=None): if lookup_name is None: @@ -172,6 +194,7 @@ class RegisterLookupMixin(object): if 'class_lookups' not in cls.__dict__: cls.class_lookups = {} cls.class_lookups[lookup_name] = lookup + cls._clear_cached_lookups() return lookup @classmethod diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py index 82a835e160..97979dd953 100644 --- a/tests/custom_lookups/models.py +++ b/tests/custom_lookups/models.py @@ -13,6 +13,10 @@ class Author(models.Model): return self.name +class Article(models.Model): + author = models.ForeignKey(Author, on_delete=models.CASCADE) + + @python_2_unicode_compatible class MySQLUnixTimestamp(models.Model): timestamp = models.PositiveIntegerField() diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index c538d23b76..da9274904c 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -10,7 +10,7 @@ from django.db import connection, models from django.test import TestCase, override_settings from django.utils import timezone -from .models import Author, MySQLUnixTimestamp +from .models import Article, Author, MySQLUnixTimestamp @contextlib.contextmanager @@ -319,6 +319,28 @@ class LookupTests(TestCase): baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4], lambda x: x) + def test_foreignobject_lookup_registration(self): + field = Article._meta.get_field('author') + + with register_lookup(models.ForeignObject, Exactly): + self.assertIs(field.get_lookup('exactly'), Exactly) + + # ForeignObject should ignore regular Field lookups + with register_lookup(models.Field, Exactly): + self.assertIsNone(field.get_lookup('exactly')) + + def test_lookups_caching(self): + field = Article._meta.get_field('author') + + # clear and re-cache + field.get_lookups.cache_clear() + self.assertNotIn('exactly', field.get_lookups()) + + # registration should bust the cache + with register_lookup(models.ForeignObject, Exactly): + # getting the lookups again should re-cache + self.assertIn('exactly', field.get_lookups()) + class BilateralTransformTests(TestCase):