mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #21391 -- Allow model signals to lazily reference their senders.
This commit is contained in:
		| @@ -1,5 +1,6 @@ | |||||||
| import collections | import collections | ||||||
| import sys | import sys | ||||||
|  | import types | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.management.color import color_style | from django.core.management.color import color_style | ||||||
| @@ -25,7 +26,7 @@ def get_validation_errors(outfile, app=None): | |||||||
|     validates all models of all installed apps. Writes errors, if any, to outfile. |     validates all models of all installed apps. Writes errors, if any, to outfile. | ||||||
|     Returns number of errors. |     Returns number of errors. | ||||||
|     """ |     """ | ||||||
|     from django.db import models, connection |     from django.db import connection, models | ||||||
|     from django.db.models.loading import get_app_errors |     from django.db.models.loading import get_app_errors | ||||||
|     from django.db.models.deletion import SET_NULL, SET_DEFAULT |     from django.db.models.deletion import SET_NULL, SET_DEFAULT | ||||||
|  |  | ||||||
| @@ -363,6 +364,8 @@ def get_validation_errors(outfile, app=None): | |||||||
|             for it in opts.index_together: |             for it in opts.index_together: | ||||||
|                 validate_local_fields(e, opts, "index_together", it) |                 validate_local_fields(e, opts, "index_together", it) | ||||||
|  |  | ||||||
|  |     validate_model_signals(e) | ||||||
|  |  | ||||||
|     return len(e.errors) |     return len(e.errors) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -382,3 +385,28 @@ def validate_local_fields(e, opts, field_name, fields): | |||||||
|                     e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name)) |                     e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name)) | ||||||
|                 if f not in opts.local_fields: |                 if f not in opts.local_fields: | ||||||
|                     e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name)) |                     e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def validate_model_signals(e): | ||||||
|  |     """Ensure lazily referenced model signals senders are installed.""" | ||||||
|  |     from django.db import models | ||||||
|  |  | ||||||
|  |     for name in dir(models.signals): | ||||||
|  |         obj = getattr(models.signals, name) | ||||||
|  |         if isinstance(obj, models.signals.ModelSignal): | ||||||
|  |             for reference, receivers in obj.unresolved_references.items(): | ||||||
|  |                 for receiver, _, _ in receivers: | ||||||
|  |                     # The receiver is either a function or an instance of class | ||||||
|  |                     # defining a `__call__` method. | ||||||
|  |                     if isinstance(receiver, types.FunctionType): | ||||||
|  |                         description = "The `%s` function" % receiver.__name__ | ||||||
|  |                     else: | ||||||
|  |                         description = "An instance of the `%s` class" % receiver.__class__.__name__ | ||||||
|  |                     e.add( | ||||||
|  |                         receiver.__module__, | ||||||
|  |                         "%s was connected to the `%s` signal " | ||||||
|  |                         "with a lazy reference to the '%s' sender, " | ||||||
|  |                         "which has not been installed." % ( | ||||||
|  |                             description, name, '.'.join(reference) | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|   | |||||||
| @@ -1,20 +1,70 @@ | |||||||
|  | from collections import defaultdict | ||||||
|  |  | ||||||
|  | from django.db.models.loading import get_model | ||||||
| from django.dispatch import Signal | from django.dispatch import Signal | ||||||
|  | from django.utils import six | ||||||
|  |  | ||||||
|  |  | ||||||
| class_prepared = Signal(providing_args=["class"]) | class_prepared = Signal(providing_args=["class"]) | ||||||
|  |  | ||||||
| pre_init = Signal(providing_args=["instance", "args", "kwargs"], use_caching=True) |  | ||||||
| post_init = Signal(providing_args=["instance"], use_caching=True) |  | ||||||
|  |  | ||||||
| pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"], | class ModelSignal(Signal): | ||||||
|                  use_caching=True) |     """ | ||||||
| post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True) |     Signal subclass that allows the sender to be lazily specified as a string | ||||||
|  |     of the `app_label.ModelName` form. | ||||||
|  |     """ | ||||||
|  |  | ||||||
| pre_delete = Signal(providing_args=["instance", "using"], use_caching=True) |     def __init__(self, *args, **kwargs): | ||||||
| post_delete = Signal(providing_args=["instance", "using"], use_caching=True) |         super(ModelSignal, self).__init__(*args, **kwargs) | ||||||
|  |         self.unresolved_references = defaultdict(list) | ||||||
|  |         class_prepared.connect(self._resolve_references) | ||||||
|  |  | ||||||
|  |     def _resolve_references(self, sender, **kwargs): | ||||||
|  |         opts = sender._meta | ||||||
|  |         reference = (opts.app_label, opts.object_name) | ||||||
|  |         try: | ||||||
|  |             receivers = self.unresolved_references.pop(reference) | ||||||
|  |         except KeyError: | ||||||
|  |             pass | ||||||
|  |         else: | ||||||
|  |             for receiver, weak, dispatch_uid in receivers: | ||||||
|  |                 super(ModelSignal, self).connect( | ||||||
|  |                     receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def connect(self, receiver, sender=None, weak=True, dispatch_uid=None): | ||||||
|  |         if isinstance(sender, six.string_types): | ||||||
|  |             try: | ||||||
|  |                 app_label, object_name = sender.split('.') | ||||||
|  |             except ValueError: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Specified sender must either be a model or a " | ||||||
|  |                     "model name of the 'app_label.ModelName' form." | ||||||
|  |                 ) | ||||||
|  |             sender = get_model(app_label, object_name, only_installed=False) | ||||||
|  |             if sender is None: | ||||||
|  |                 reference = (app_label, object_name) | ||||||
|  |                 self.unresolved_references[reference].append( | ||||||
|  |                     (receiver, weak, dispatch_uid) | ||||||
|  |                 ) | ||||||
|  |                 return | ||||||
|  |         super(ModelSignal, self).connect( | ||||||
|  |             receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  | pre_init = ModelSignal(providing_args=["instance", "args", "kwargs"], use_caching=True) | ||||||
|  | post_init = ModelSignal(providing_args=["instance"], use_caching=True) | ||||||
|  |  | ||||||
|  | pre_save = ModelSignal(providing_args=["instance", "raw", "using", "update_fields"], | ||||||
|  |                        use_caching=True) | ||||||
|  | post_save = ModelSignal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True) | ||||||
|  |  | ||||||
|  | pre_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True) | ||||||
|  | post_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True) | ||||||
|  |  | ||||||
|  | m2m_changed = ModelSignal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True) | ||||||
|  |  | ||||||
| pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"]) | pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"]) | ||||||
| pre_syncdb = pre_migrate | pre_syncdb = pre_migrate | ||||||
| post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"]) | post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"]) | ||||||
| post_syncdb = post_migrate | post_syncdb = post_migrate | ||||||
|  |  | ||||||
| m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True) |  | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ Model signals | |||||||
|    :synopsis: Signals sent by the model system. |    :synopsis: Signals sent by the model system. | ||||||
|  |  | ||||||
| The :mod:`django.db.models.signals` module defines a set of signals sent by the | The :mod:`django.db.models.signals` module defines a set of signals sent by the | ||||||
| module system. | model system. | ||||||
|  |  | ||||||
| .. warning:: | .. warning:: | ||||||
|  |  | ||||||
| @@ -37,6 +37,14 @@ module system. | |||||||
|     so if your handler is a local function, it may be garbage collected.  To |     so if your handler is a local function, it may be garbage collected.  To | ||||||
|     prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`. |     prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`. | ||||||
|  |  | ||||||
|  | .. versionadded:: 1.7 | ||||||
|  |  | ||||||
|  |     Model signals ``sender`` model can be lazily referenced when connecting a | ||||||
|  |     receiver by specifying its full application label. For example, an | ||||||
|  |     ``Answer`` model defined in the ``polls`` application could be referenced | ||||||
|  |     as ``'polls.Answer'``. This sort of reference can be quite handy when | ||||||
|  |     dealing with circular import dependencies and swappable models. | ||||||
|  |  | ||||||
| pre_init | pre_init | ||||||
| -------- | -------- | ||||||
|  |  | ||||||
|   | |||||||
| @@ -425,7 +425,7 @@ Models | |||||||
| * Is it now possible to avoid creating a backward relation for | * Is it now possible to avoid creating a backward relation for | ||||||
|   :class:`~django.db.models.OneToOneField` by setting its |   :class:`~django.db.models.OneToOneField` by setting its | ||||||
|   :attr:`~django.db.models.ForeignKey.related_name` to |   :attr:`~django.db.models.ForeignKey.related_name` to | ||||||
|   `'+'` or ending it with `'+'`. |   ``'+'`` or ending it with ``'+'``. | ||||||
|  |  | ||||||
| * :class:`F expressions <django.db.models.F>` support the power operator | * :class:`F expressions <django.db.models.F>` support the power operator | ||||||
|   (``**``). |   (``**``). | ||||||
| @@ -436,6 +436,10 @@ Signals | |||||||
| * The ``enter`` argument was added to the | * The ``enter`` argument was added to the | ||||||
|   :data:`~django.test.signals.setting_changed` signal. |   :data:`~django.test.signals.setting_changed` signal. | ||||||
|  |  | ||||||
|  | * The model signals can be now be connected to using a ``str`` of the | ||||||
|  |   ``'app_label.ModelName'`` form – just like related fields – to lazily | ||||||
|  |   reference their senders. | ||||||
|  |  | ||||||
| Templates | Templates | ||||||
| ^^^^^^^^^ | ^^^^^^^^^ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -413,6 +413,19 @@ different User model. | |||||||
|         class Article(models.Model): |         class Article(models.Model): | ||||||
|             author = models.ForeignKey(settings.AUTH_USER_MODEL) |             author = models.ForeignKey(settings.AUTH_USER_MODEL) | ||||||
|  |  | ||||||
|  |     .. versionadded:: 1.7 | ||||||
|  |  | ||||||
|  |         When connecting to signals sent by the User model, you should specify the | ||||||
|  |         custom model using the :setting:`AUTH_USER_MODEL` setting. For example:: | ||||||
|  |  | ||||||
|  |             from django.conf import settings | ||||||
|  |             from django.db.models.signals import post_save | ||||||
|  |  | ||||||
|  |             def post_save_receiver(signal, sender, instance, **kwargs): | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             post_save.connect(post_save_receiver, sender=settings.AUTH_USER_MODEL) | ||||||
|  |  | ||||||
| Specifying a custom User model | Specifying a custom User model | ||||||
| ------------------------------ | ------------------------------ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,10 +1,22 @@ | |||||||
| from django.core import management | from django.core import management | ||||||
|  | from django.core.management.validation import ( | ||||||
|  |     ModelErrorCollection, validate_model_signals | ||||||
|  | ) | ||||||
|  | from django.db.models.signals import post_init | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  |  | ||||||
|  |  | ||||||
| class ModelValidationTest(TestCase): | class OnPostInit(object): | ||||||
|  |     def __call__(self, **kwargs): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def on_post_init(**kwargs): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelValidationTest(TestCase): | ||||||
|     def test_models_validate(self): |     def test_models_validate(self): | ||||||
|         # All our models should validate properly |         # All our models should validate properly | ||||||
|         # Validation Tests: |         # Validation Tests: | ||||||
| @@ -13,3 +25,23 @@ class ModelValidationTest(TestCase): | |||||||
|         #   * related_name='+' doesn't clash with another '+' |         #   * related_name='+' doesn't clash with another '+' | ||||||
|         #       See: https://code.djangoproject.com/ticket/21375 |         #       See: https://code.djangoproject.com/ticket/21375 | ||||||
|         management.call_command("validate", stdout=six.StringIO()) |         management.call_command("validate", stdout=six.StringIO()) | ||||||
|  |  | ||||||
|  |     def test_model_signal(self): | ||||||
|  |         unresolved_references = post_init.unresolved_references.copy() | ||||||
|  |         post_init.connect(on_post_init, sender='missing-app.Model') | ||||||
|  |         post_init.connect(OnPostInit(), sender='missing-app.Model') | ||||||
|  |         e = ModelErrorCollection(six.StringIO()) | ||||||
|  |         validate_model_signals(e) | ||||||
|  |         self.assertSetEqual(set(e.errors), { | ||||||
|  |             ('model_validation.tests', | ||||||
|  |                 "The `on_post_init` function was connected to the `post_init` " | ||||||
|  |                 "signal with a lazy reference to the 'missing-app.Model' " | ||||||
|  |                 "sender, which has not been installed." | ||||||
|  |             ), | ||||||
|  |             ('model_validation.tests', | ||||||
|  |                 "An instance of the `OnPostInit` class was connected to " | ||||||
|  |                 "the `post_init` signal with a lazy reference to the " | ||||||
|  |                 "'missing-app.Model' sender, which has not been installed." | ||||||
|  |             ) | ||||||
|  |         }) | ||||||
|  |         post_init.unresolved_references = unresolved_references | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
|  |  | ||||||
|  | from django.db import models | ||||||
| from django.db.models import signals | from django.db.models import signals | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| @@ -8,8 +9,7 @@ from django.utils import six | |||||||
| from .models import Author, Book, Car, Person | from .models import Author, Book, Car, Person | ||||||
|  |  | ||||||
|  |  | ||||||
| class SignalTests(TestCase): | class BaseSignalTest(TestCase): | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         # Save up the number of connected signals so that we can check at the |         # Save up the number of connected signals so that we can check at the | ||||||
|         # end that all the signals we register get properly unregistered (#9989) |         # end that all the signals we register get properly unregistered (#9989) | ||||||
| @@ -30,6 +30,8 @@ class SignalTests(TestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(self.pre_signals, post_signals) |         self.assertEqual(self.pre_signals, post_signals) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SignalTests(BaseSignalTest): | ||||||
|     def test_save_signals(self): |     def test_save_signals(self): | ||||||
|         data = [] |         data = [] | ||||||
|  |  | ||||||
| @@ -239,3 +241,48 @@ class SignalTests(TestCase): | |||||||
|         self.assertTrue(a._run) |         self.assertTrue(a._run) | ||||||
|         self.assertTrue(b._run) |         self.assertTrue(b._run) | ||||||
|         self.assertEqual(signals.post_save.receivers, []) |         self.assertEqual(signals.post_save.receivers, []) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LazyModelRefTest(BaseSignalTest): | ||||||
|  |     def setUp(self): | ||||||
|  |         super(LazyModelRefTest, self).setUp() | ||||||
|  |         self.received = [] | ||||||
|  |  | ||||||
|  |     def receiver(self, **kwargs): | ||||||
|  |         self.received.append(kwargs) | ||||||
|  |  | ||||||
|  |     def test_invalid_sender_model_name(self): | ||||||
|  |         with self.assertRaisesMessage(ValueError, | ||||||
|  |                     "Specified sender must either be a model or a " | ||||||
|  |                     "model name of the 'app_label.ModelName' form."): | ||||||
|  |             signals.post_init.connect(self.receiver, sender='invalid') | ||||||
|  |  | ||||||
|  |     def test_already_loaded_model(self): | ||||||
|  |         signals.post_init.connect( | ||||||
|  |             self.receiver, sender='signals.Book', weak=False | ||||||
|  |         ) | ||||||
|  |         try: | ||||||
|  |             instance = Book() | ||||||
|  |             self.assertEqual(self.received, [{ | ||||||
|  |                 'signal': signals.post_init, | ||||||
|  |                 'sender': Book, | ||||||
|  |                 'instance': instance | ||||||
|  |             }]) | ||||||
|  |         finally: | ||||||
|  |             signals.post_init.disconnect(self.receiver, sender=Book) | ||||||
|  |  | ||||||
|  |     def test_not_loaded_model(self): | ||||||
|  |         signals.post_init.connect( | ||||||
|  |             self.receiver, sender='signals.Created', weak=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             class Created(models.Model): | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             instance = Created() | ||||||
|  |             self.assertEqual(self.received, [{ | ||||||
|  |                 'signal': signals.post_init, 'sender': Created, 'instance': instance | ||||||
|  |             }]) | ||||||
|  |         finally: | ||||||
|  |             signals.post_init.disconnect(self.receiver, sender=Created) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user