From 013bcf57d54afea02611413c0169351a1521ee7c Mon Sep 17 00:00:00 2001
From: Simon Charette <charette.s@gmail.com>
Date: Fri, 3 Feb 2017 00:42:15 -0500
Subject: [PATCH] Introduced ModelTuple to remove migrations boilerplate.

---
 django/db/migrations/operations/base.py   | 10 ------
 django/db/migrations/operations/fields.py | 41 ++++-----------------
 django/db/migrations/operations/models.py | 27 +++++++-------
 django/db/migrations/operations/utils.py  | 44 +++++++++++++++++++++++
 4 files changed, 62 insertions(+), 60 deletions(-)

diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py
index 2448284a2b..3fb1002c44 100644
--- a/django/db/migrations/operations/base.py
+++ b/django/db/migrations/operations/base.py
@@ -80,16 +80,6 @@ class Operation:
         """
         return "%s: %s" % (self.__class__.__name__, self._constructor_args)
 
-    def model_to_key(self, model):
-        """
-        Take either a model class or an 'app_label.ModelName' string and return
-        (app_label, model_name).
-        """
-        if isinstance(model, str):
-            return tuple(model.lower().split('.', 1))
-        else:
-            return model._meta.app_label, model._meta.model_name
-
     def references_model(self, name, app_label=None):
         """
         Return True if there is a chance this operation references the given
diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py
index 66231292e8..34f1d2b64d 100644
--- a/django/db/migrations/operations/fields.py
+++ b/django/db/migrations/operations/fields.py
@@ -3,7 +3,9 @@ from django.db.models.fields import NOT_PROVIDED
 from django.utils.functional import cached_property
 
 from .base import Operation
-from .utils import is_referenced_by_foreign_key
+from .utils import (
+    ModelTuple, field_references_model, is_referenced_by_foreign_key,
+)
 
 
 class FieldOperation(Operation):
@@ -31,40 +33,9 @@ class FieldOperation(Operation):
         if name_lower == self.model_name_lower:
             return True
         if self.field:
-            if self.field.remote_field:
-                remote_app_label, remote_model_name = self.model_to_key(self.field.remote_field.model)
-                if (remote_model_name == name_lower and app_label is None or
-                        not remote_app_label or remote_app_label == app_label):
-                    return True
-                through = getattr(self.field.remote_field, 'through', None)
-                if through and self.model_to_key(through) == (app_label, name_lower):
-                    through_app_label, through_model_name = self.model_to_key(through)
-                    if (through_model_name == name_lower and app_label is None or
-                            not through_app_label or through_app_label == app_label):
-                        return True
-            return False
-        return True
-
-    def references_field(self, model_name, name, app_label=None):
-        if self.field:
-            model_name_lower = model_name.lower()
-            remote_field = self.field.remote_field
-            if remote_field:
-                remote_app_label, remote_model_name = self.model_to_key(remote_field.model)
-                if (remote_model_name == model_name_lower and
-                        (app_label is None or not remote_app_label or remote_app_label == app_label)):
-                    # TODO: Consider to_fields/from_fields.
-                    return True
-                through = getattr(remote_field, 'through', None)
-                if through and self.model_to_key(through) == (app_label, model_name_lower):
-                    through_app_label, through_model_name = self.model_to_key(through)
-                    if (through_model_name == model_name_lower and
-                        (app_label is None or not through_app_label or through_app_label == app_label) and
-                            (remote_field.through_fields is None or name in remote_field.through_fields)):
-                            return True
-            elif model_name_lower == self.model_name_lower and name == self.name:
-                return True
-            return False
+            return field_references_model(self.field, ModelTuple(app_label, name_lower))
+        # Refuse the temptation to guess. This operation could be performed on
+        # a field referencing the specified model.
         return True
 
     def reduce(self, operation, in_between, app_label=None):
diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py
index 88f3507c22..b2d3f70fea 100644
--- a/django/db/migrations/operations/models.py
+++ b/django/db/migrations/operations/models.py
@@ -7,6 +7,7 @@ from django.utils.functional import cached_property
 from .fields import (
     AddField, AlterField, FieldOperation, RemoveField, RenameField,
 )
+from .utils import ModelTuple, field_references_model
 
 
 def _check_for_duplicates(arg_name, objs):
@@ -104,19 +105,15 @@ class CreateModel(ModelOperation):
             return True
 
         # Check we didn't inherit from the model
-        models_to_check = [
-            base for base in self.bases
-            if base is not models.Model and isinstance(base, (models.base.ModelBase, str))
-        ]
+        model_tuple = ModelTuple(app_label, name_lower)
+        for base in self.bases:
+            if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
+                    ModelTuple.from_model(base) == model_tuple):
+                return True
+
         # Check we have no FKs/M2Ms with it
-        for fname, field in self.fields:
-            if field.remote_field:
-                models_to_check.append(field.remote_field.model)
-        # Now go over all the models and check against them
-        for model in models_to_check:
-            model_app_label, model_name = self.model_to_key(model)
-            if (model_name == name_lower and app_label is None or
-                    not model_app_label or model_app_label == app_label):
+        for _name, field in self.fields:
+            if field_references_model(field, model_tuple):
                 return True
         return False
 
@@ -267,7 +264,7 @@ class RenameModel(ModelOperation):
         renamed_model.name = self.new_name
         state.models[app_label, self.new_name_lower] = renamed_model
         # Repoint all fields pointing to the old model to the new one.
-        old_model_tuple = app_label, self.old_name_lower
+        old_model_tuple = ModelTuple(app_label, self.old_name_lower)
         new_remote_model = '%s.%s' % (app_label, self.new_name)
         to_reload = []
         for (model_app_label, model_name), model_state in state.models.items():
@@ -276,7 +273,7 @@ class RenameModel(ModelOperation):
                 changed_field = None
                 remote_field = field.remote_field
                 if remote_field:
-                    remote_model_tuple = self._get_model_tuple(
+                    remote_model_tuple = ModelTuple.from_model(
                         remote_field.model, model_app_label, model_name
                     )
                     if remote_model_tuple == old_model_tuple:
@@ -284,7 +281,7 @@ class RenameModel(ModelOperation):
                         changed_field.remote_field.model = new_remote_model
                     through_model = getattr(remote_field, 'through', None)
                     if through_model:
-                        through_model_tuple = self._get_model_tuple(
+                        through_model_tuple = ModelTuple.from_model(
                             through_model, model_app_label, model_name
                         )
                         if through_model_tuple == old_model_tuple:
diff --git a/django/db/migrations/operations/utils.py b/django/db/migrations/operations/utils.py
index af23ea9563..34fdaba821 100644
--- a/django/db/migrations/operations/utils.py
+++ b/django/db/migrations/operations/utils.py
@@ -1,3 +1,8 @@
+from collections import namedtuple
+
+from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
+
+
 def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
     for state_app_label, state_model in state.models:
         for _, f in state.models[state_app_label, state_model].fields:
@@ -7,3 +12,42 @@ def is_referenced_by_foreign_key(state, model_name_lower, field, field_name):
                 if (f.to_fields[0] is None and field.primary_key) or field_name in f.to_fields:
                     return True
     return False
+
+
+class ModelTuple(namedtuple('ModelTupleBase', ('app_label', 'model_name'))):
+    @classmethod
+    def from_model(cls, model, app_label=None, model_name=None):
+        """
+        Take a model class or a 'app_label.ModelName' string and return a
+        ModelTuple('app_label', 'modelname'). The optional app_label and
+        model_name arguments are the defaults if "self" or "ModelName" are
+        passed.
+        """
+        if isinstance(model, str):
+            if model == RECURSIVE_RELATIONSHIP_CONSTANT:
+                return cls(app_label, model_name)
+            if '.' in model:
+                return cls(*model.lower().split('.', 1))
+            return cls(app_label, model.lower())
+        return cls(model._meta.app_label, model._meta.model_name)
+
+    def __eq__(self, other):
+        if isinstance(other, ModelTuple):
+            # Consider ModelTuple equal if their model_name is equal and either
+            # one of them is missing an app_label.
+            return self.model_name == other.model_name and (
+                self.app_label is None or other.app_label is None or self.app_label == other.app_label
+            )
+        return super().__eq__(other)
+
+
+def field_references_model(field, model_tuple):
+    """Return whether or not field references model_tuple."""
+    remote_field = field.remote_field
+    if remote_field:
+        if ModelTuple.from_model(remote_field.model) == model_tuple:
+            return True
+        through = getattr(remote_field, 'through', None)
+        if through and ModelTuple.from_model(through) == model_tuple:
+            return True
+    return False