From 4511aeb6b8b843ee913fb43a37c9686980210948 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= <akaariai@gmail.com>
Date: Mon, 17 Dec 2012 17:09:07 +0200
Subject: [PATCH] Moved join path generation to Field

Refs #19385
---
 django/contrib/contenttypes/generic.py |  11 +++
 django/db/models/fields/related.py     |  52 +++++++++++-
 django/db/models/related.py            |  12 +++
 django/db/models/sql/constants.py      |   6 --
 django/db/models/sql/query.py          | 110 ++++---------------------
 5 files changed, 91 insertions(+), 100 deletions(-)

diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
index 6aff07e568..be7a5e5a22 100644
--- a/django/contrib/contenttypes/generic.py
+++ b/django/contrib/contenttypes/generic.py
@@ -11,6 +11,7 @@ from django.db import connection
 from django.db.models import signals
 from django.db import models, router, DEFAULT_DB_ALIAS
 from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
+from django.db.models.related import PathInfo
 from django.forms import ModelForm
 from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
 from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@@ -160,6 +161,16 @@ class GenericRelation(RelatedField, Field):
         kwargs['serialize'] = False
         Field.__init__(self, **kwargs)
 
+    def get_path_info(self):
+        from_field = self.model._meta.pk
+        opts = self.rel.to._meta
+        target = opts.get_field_by_name(self.object_id_field_name)[0]
+        # Note that we are using different field for the join_field
+        # than from_field or to_field. This is a hack, but we need the
+        # GenericRelation to generate the extra SQL.
+        return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
+                opts, target, self)
+
     def get_choices_default(self):
         return Field.get_choices(self, include_blank=False)
 
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 90fe69e23c..4b6a5b0aed 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -5,7 +5,7 @@ from django.db.backends import util
 from django.db.models import signals, get_model
 from django.db.models.fields import (AutoField, Field, IntegerField,
     PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
-from django.db.models.related import RelatedObject
+from django.db.models.related import RelatedObject, PathInfo
 from django.db.models.query import QuerySet
 from django.db.models.query_utils import QueryWrapper
 from django.db.models.deletion import CASCADE
@@ -16,7 +16,6 @@ from django.utils.functional import curry, cached_property
 from django.core import exceptions
 from django import forms
 
-
 RECURSIVE_RELATIONSHIP_CONSTANT = 'self'
 
 pending_lookups = {}
@@ -1004,6 +1003,31 @@ class ForeignKey(RelatedField, Field):
         )
         Field.__init__(self, **kwargs)
 
+    def get_path_info(self):
+        """
+        Get path from this field to the related model.
+        """
+        opts = self.rel.to._meta
+        target = self.rel.get_related_field()
+        from_opts = self.model._meta
+        return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
+
+    def get_reverse_path_info(self):
+        """
+        Get path from the related model to this field's model.
+        """
+        opts = self.model._meta
+        from_field = self.rel.get_related_field()
+        from_opts = from_field.model._meta
+        pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
+        if from_field.model is self.model:
+            # Recursive foreign key to self.
+            target = opts.get_field_by_name(
+                self.rel.field_name)[0]
+        else:
+            target = opts.pk
+        return pathinfos, opts, target, self
+
     def validate(self, value, model_instance):
         if self.rel.parent_link:
             return
@@ -1198,6 +1222,30 @@ class ManyToManyField(RelatedField, Field):
         msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
         self.help_text = string_concat(self.help_text, ' ', msg)
 
+    def _get_path_info(self, direct=False):
+        """
+        Called by both direct an indirect m2m traversal.
+        """
+        pathinfos = []
+        int_model = self.rel.through
+        linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
+        linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+        if direct:
+            join1infos, _, _, _ = linkfield1.get_reverse_path_info()
+            join2infos, opts, target, final_field = linkfield2.get_path_info()
+        else:
+            join1infos, _, _, _ = linkfield2.get_reverse_path_info()
+            join2infos, opts, target, final_field = linkfield1.get_path_info()
+        pathinfos.extend(join1infos)
+        pathinfos.extend(join2infos)
+        return pathinfos, opts, target, final_field
+
+    def get_path_info(self):
+        return self._get_path_info(direct=True)
+
+    def get_reverse_path_info(self):
+        return self._get_path_info(direct=False)
+
     def get_choices_default(self):
         return Field.get_choices(self, include_blank=False)
 
diff --git a/django/db/models/related.py b/django/db/models/related.py
index a0dcec7132..702853533d 100644
--- a/django/db/models/related.py
+++ b/django/db/models/related.py
@@ -1,6 +1,15 @@
+from collections import namedtuple
+
 from django.utils.encoding import smart_text
 from django.db.models.fields import BLANK_CHOICE_DASH
 
+# PathInfo is used when converting lookups (fk__somecol). The contents
+# describe the relation in Model terms (model Options and Fields for both
+# sides of the relation. The join_field is the field backing the relation.
+PathInfo = namedtuple('PathInfo',
+                      'from_field to_field from_opts to_opts join_field '
+                      'm2m direct')
+
 class BoundRelatedObject(object):
     def __init__(self, related_object, field_mapping, original):
         self.relation = related_object
@@ -67,3 +76,6 @@ class RelatedObject(object):
 
     def get_cache_name(self):
         return "_%s_cache" % self.get_accessor_name()
+
+    def get_path_info(self):
+        return self.field.get_reverse_path_info()
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index 9f82f426ed..1764db7fcc 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -26,12 +26,6 @@ JoinInfo = namedtuple('JoinInfo',
                       'table_name rhs_alias join_type lhs_alias '
                       'lhs_join_col rhs_join_col nullable join_field')
 
-# PathInfo is used when converting lookups (fk__somecol). The contents
-# describe the join in Model terms (model Options and Fields for both
-# sides of the join. The rel_field is the field we are joining along.
-PathInfo = namedtuple('PathInfo',
-                      'from_field to_field from_opts to_opts join_field')
-
 # Pairs of column clauses to select, and (possibly None) field for the clause.
 SelectInfo = namedtuple('SelectInfo', 'col field')
 
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 841452636b..e5833b2b51 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -18,9 +18,10 @@ from django.db.models.constants import LOOKUP_SEP
 from django.db.models.expressions import ExpressionNode
 from django.db.models.fields import FieldDoesNotExist
 from django.db.models.loading import get_model
+from django.db.models.related import PathInfo
 from django.db.models.sql import aggregates as base_aggregates_module
 from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
-        ORDER_PATTERN, JoinInfo, SelectInfo, PathInfo)
+        ORDER_PATTERN, JoinInfo, SelectInfo)
 from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
 from django.db.models.sql.expressions import SQLEvaluator
 from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -1294,7 +1295,6 @@ class Query(object):
         contain the same value as the final field).
         """
         path = []
-        multijoin_pos = None
         for pos, name in enumerate(names):
             if name == 'pk':
                 name = opts.pk.name
@@ -1328,92 +1328,19 @@ class Query(object):
                         target = final_field.rel.get_related_field()
                         opts = int_model._meta
                         path.append(PathInfo(final_field, target, final_field.model._meta,
-                                             opts, final_field))
-            # We have five different cases to solve: foreign keys, reverse
-            # foreign keys, m2m fields (also reverse) and non-relational
-            # fields. We are mostly just using the related field API to
-            # fetch the from and to fields. The m2m fields are handled as
-            # two foreign keys, first one reverse, the second one direct.
-            if direct and not field.rel and not m2m:
+                                             opts, final_field, False, True))
+            if hasattr(field, 'get_path_info'):
+                pathinfos, opts, target, final_field = field.get_path_info()
+                path.extend(pathinfos)
+            else:
                 # Local non-relational field.
                 final_field = target = field
                 break
-            elif direct and not m2m:
-                # Foreign Key
-                opts = field.rel.to._meta
-                target = field.rel.get_related_field()
-                final_field = field
-                from_opts = field.model._meta
-                path.append(PathInfo(field, target, from_opts, opts, field))
-            elif not direct and not m2m:
-                # Revere foreign key
-                final_field = to_field = field.field
-                opts = to_field.model._meta
-                from_field = to_field.rel.get_related_field()
-                from_opts = from_field.model._meta
-                path.append(
-                    PathInfo(from_field, to_field, from_opts, opts, to_field))
-                if from_field.model is to_field.model:
-                    # Recursive foreign key to self.
-                    target = opts.get_field_by_name(
-                        field.field.rel.field_name)[0]
-                else:
-                    target = opts.pk
-            elif direct and m2m:
-                if not field.rel.through:
-                    # Gotcha! This is just a fake m2m field - a generic relation
-                    # field).
-                    from_field = opts.pk
-                    opts = field.rel.to._meta
-                    target = opts.get_field_by_name(field.object_id_field_name)[0]
-                    final_field = field
-                    # Note that we are using different field for the join_field
-                    # than from_field or to_field. This is a hack, but we need the
-                    # GenericRelation to generate the extra SQL.
-                    path.append(PathInfo(from_field, target, field.model._meta, opts,
-                                         field))
-                else:
-                    # m2m field. We are travelling first to the m2m table along a
-                    # reverse relation, then from m2m table to the target table.
-                    from_field1 = opts.get_field_by_name(
-                        field.m2m_target_field_name())[0]
-                    opts = field.rel.through._meta
-                    to_field1 = opts.get_field_by_name(field.m2m_field_name())[0]
-                    path.append(
-                        PathInfo(from_field1, to_field1, from_field1.model._meta,
-                                 opts, to_field1))
-                    final_field = from_field2 = opts.get_field_by_name(
-                        field.m2m_reverse_field_name())[0]
-                    opts = field.rel.to._meta
-                    target = to_field2 = opts.get_field_by_name(
-                        field.m2m_reverse_target_field_name())[0]
-                    path.append(
-                        PathInfo(from_field2, to_field2, from_field2.model._meta,
-                                 opts, from_field2))
-            elif not direct and m2m:
-                # This one is just like above, except we are travelling the
-                # fields in opposite direction.
-                field = field.field
-                from_field1 = opts.get_field_by_name(
-                    field.m2m_reverse_target_field_name())[0]
-                int_opts = field.rel.through._meta
-                to_field1 = int_opts.get_field_by_name(
-                    field.m2m_reverse_field_name())[0]
-                path.append(
-                    PathInfo(from_field1, to_field1, from_field1.model._meta,
-                             int_opts, to_field1))
-                final_field = from_field2 = int_opts.get_field_by_name(
-                    field.m2m_field_name())[0]
-                opts = field.opts
-                target = to_field2 = opts.get_field_by_name(
-                    field.m2m_target_field_name())[0]
-                path.append(PathInfo(from_field2, to_field2, from_field2.model._meta,
-                                     opts, from_field2))
-
-            if m2m and multijoin_pos is None:
-                multijoin_pos = pos
-            if not direct and not path[-1].to_field.unique and multijoin_pos is None:
-                multijoin_pos = pos
+        multijoin_pos = None
+        for m2mpos, pathinfo in enumerate(path):
+            if pathinfo.m2m:
+                multijoin_pos = m2mpos
+                break
 
         if pos != len(names) - 1:
             if pos == len(names) - 2:
@@ -1463,16 +1390,15 @@ class Query(object):
         # joins at this stage - we will need the information about join type
         # of the trimmed joins.
         for pos, join in enumerate(path):
-            from_field, to_field, from_opts, opts, join_field = join
-            direct = join_field == from_field
-            if direct:
-                nullable = self.is_nullable(from_field)
+            opts = join.to_opts
+            if join.direct:
+                nullable = self.is_nullable(join.from_field)
             else:
                 nullable = True
-            connection = alias, opts.db_table, from_field.column, to_field.column
-            reuse = None if direct or to_field.unique else can_reuse
+            connection = alias, opts.db_table, join.from_field.column, join.to_field.column
+            reuse = can_reuse if join.m2m else None
             alias = self.join(connection, reuse=reuse,
-                              nullable=nullable, join_field=join_field)
+                              nullable=nullable, join_field=join.join_field)
             joins.append(alias)
         return final_field, target, opts, joins, path