From 052a011ee6122482a471795c1994bbcfdb069611 Mon Sep 17 00:00:00 2001
From: Luke Plant <L.Plant.98@cantab.net>
Date: Fri, 7 Oct 2011 16:05:53 +0000
Subject: [PATCH] Fixed #17003 - prefetch_related should support foreign
 keys/one-to-one

Support for `GenericForeignKey` is also included.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16939 bcc190cf-cafb-0310-a4f2-bffc1f526a37
---
 django/contrib/contenttypes/generic.py      |  52 +++++++-
 django/contrib/contenttypes/models.py       |   6 +
 django/db/models/fields/related.py          |  86 ++++++++++----
 django/db/models/query.py                   | 125 +++++++++++++++-----
 docs/ref/models/querysets.txt               |  87 +++++++++-----
 docs/releases/1.4.txt                       |  19 +--
 tests/modeltests/prefetch_related/models.py |   6 +-
 tests/modeltests/prefetch_related/tests.py  | 104 ++++++++++++----
 8 files changed, 366 insertions(+), 119 deletions(-)

diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py
index 9c6048a7ee..23abe01f51 100644
--- a/django/contrib/contenttypes/generic.py
+++ b/django/contrib/contenttypes/generic.py
@@ -2,7 +2,10 @@
 Classes allowing "generic" relations through ContentType and object-id fields.
 """
 
+from collections import defaultdict
 from functools import partial
+from operator import attrgetter
+
 from django.core.exceptions import ObjectDoesNotExist
 from django.db import connection
 from django.db.models import signals
@@ -59,6 +62,49 @@ class GenericForeignKey(object):
             # This should never happen. I love comments like this, don't you?
             raise Exception("Impossible arguments to GFK.get_content_type!")
 
+    def get_prefetch_query_set(self, instances):
+        # For efficiency, group the instances by content type and then do one
+        # query per model
+        fk_dict = defaultdict(list)
+        # We need one instance for each group in order to get the right db:
+        instance_dict = {}
+        ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
+        for instance in instances:
+            # We avoid looking for values if either ct_id or fkey value is None
+            ct_id = getattr(instance, ct_attname)
+            if ct_id is not None:
+                fk_val = getattr(instance, self.fk_field)
+                if fk_val is not None:
+                    fk_dict[ct_id].append(fk_val)
+                    instance_dict[ct_id] = instance
+
+        ret_val = []
+        for ct_id, fkeys in fk_dict.items():
+            instance = instance_dict[ct_id]
+            ct = self.get_content_type(id=ct_id, using=instance._state.db)
+            ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
+
+        # For doing the join in Python, we have to match both the FK val and the
+        # content type, so the 'attr' vals we return need to be callables that
+        # will return a (fk, class) pair.
+        def gfk_key(obj):
+            ct_id = getattr(obj, ct_attname)
+            if ct_id is None:
+                return None
+            else:
+                return (getattr(obj, self.fk_field),
+                        self.get_content_type(id=ct_id,
+                                              using=obj._state.db).model_class())
+
+        return (ret_val,
+                lambda obj: (obj._get_pk_val(), obj.__class__),
+                gfk_key,
+                True,
+                self.cache_attr)
+
+    def is_cached(self, instance):
+        return hasattr(instance, self.cache_attr)
+
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
@@ -282,7 +328,11 @@ def create_generic_related_manager(superclass):
                     [obj._get_pk_val() for obj in instances]
                 }
             qs = super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)
-            return (qs, self.object_id_field_name, 'pk')
+            return (qs,
+                    attrgetter(self.object_id_field_name),
+                    lambda obj: obj._get_pk_val(),
+                    False,
+                    self.prefetch_cache_name)
 
         def add(self, *objs):
             for obj in objs:
diff --git a/django/contrib/contenttypes/models.py b/django/contrib/contenttypes/models.py
index 9ab059eaec..c7e3dd79af 100644
--- a/django/contrib/contenttypes/models.py
+++ b/django/contrib/contenttypes/models.py
@@ -113,5 +113,11 @@ class ContentType(models.Model):
         """
         return self.model_class()._base_manager.using(self._state.db).get(**kwargs)
 
+    def get_all_objects_for_this_type(self, **kwargs):
+        """
+        Returns all objects of this type for the keyword arguments given.
+        """
+        return self.model_class()._base_manager.using(self._state.db).filter(**kwargs)
+
     def natural_key(self):
         return (self.app_label, self.model)
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 7bcae2ee54..c82c397288 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -1,3 +1,5 @@
+from operator import attrgetter
+
 from django.db import connection, router
 from django.db.backends import util
 from django.db.models import signals, get_model
@@ -227,6 +229,31 @@ class SingleRelatedObjectDescriptor(object):
         self.related = related
         self.cache_name = related.get_cache_name()
 
+    def is_cached(self, instance):
+        # True once the related object has been stored on the instance.
+        return hasattr(instance, self.cache_name)
+
+    def get_query_set(self, **db_hints):
+        # Base manager of the related model, on the db picked by the router.
+        db = router.db_for_read(self.related.model, **db_hints)
+        return self.related.model._base_manager.using(db)
+
+    def get_prefetch_query_set(self, instances):
+        """
+        Returns the 5-tuple unpacked by prefetch_one_level(): (queryset,
+        key getter for fetched objects, key getter for instances,
+        single-object flag, cache attribute name).
+        """
+        vals = [instance._get_pk_val() for instance in instances]
+        params = {'%s__pk__in' % self.related.field.name: vals}
+        # Apply the filter so only objects related to the given instances
+        # are fetched, rather than the whole related table.
+        return (self.get_query_set().filter(**params),
+                attrgetter(self.related.field.attname),
+                lambda obj: obj._get_pk_val(),
+                True,
+                self.cache_name)
+
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
@@ -234,8 +252,7 @@ class SingleRelatedObjectDescriptor(object):
             return getattr(instance, self.cache_name)
         except AttributeError:
             params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
-            db = router.db_for_read(self.related.model, instance=instance)
-            rel_obj = self.related.model._base_manager.using(db).get(**params)
+            rel_obj = self.get_query_set(instance=instance).get(**params)
             setattr(instance, self.cache_name, rel_obj)
             return rel_obj
 
@@ -283,14 +300,40 @@ class ReverseSingleRelatedObjectDescriptor(object):
     # ReverseSingleRelatedObjectDescriptor instance.
     def __init__(self, field_with_rel):
         self.field = field_with_rel
+        self.cache_name = self.field.get_cache_name()
+
+    def is_cached(self, instance):
+        return hasattr(instance, self.cache_name)
+
+    def get_query_set(self, **db_hints):
+        db = router.db_for_read(self.field.rel.to, **db_hints)
+        rel_mgr = self.field.rel.to._default_manager
+        # If the related manager indicates that it should be used for
+        # related fields, respect that.
+        if getattr(rel_mgr, 'use_for_related_fields', False):
+            return rel_mgr.using(db)
+        else:
+            return QuerySet(self.field.rel.to).using(db)
+
+    def get_prefetch_query_set(self, instances):
+        vals = [getattr(instance, self.field.attname) for instance in instances]
+        other_field = self.field.rel.get_related_field()
+        if other_field.rel:
+            params = {'%s__pk__in' % self.field.rel.field_name: vals}
+        else:
+            params = {'%s__in' % self.field.rel.field_name: vals}
+        return (self.get_query_set().filter(**params),
+                attrgetter(self.field.rel.field_name),
+                attrgetter(self.field.attname),
+                True,
+                self.cache_name)
 
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
 
-        cache_name = self.field.get_cache_name()
         try:
-            return getattr(instance, cache_name)
+            return getattr(instance, self.cache_name)
         except AttributeError:
             val = getattr(instance, self.field.attname)
             if val is None:
@@ -303,16 +346,9 @@ class ReverseSingleRelatedObjectDescriptor(object):
                 params = {'%s__pk' % self.field.rel.field_name: val}
             else:
                 params = {'%s__exact' % self.field.rel.field_name: val}
-
-            # If the related manager indicates that it should be used for
-            # related fields, respect that.
-            rel_mgr = self.field.rel.to._default_manager
-            db = router.db_for_read(self.field.rel.to, instance=instance)
-            if getattr(rel_mgr, 'use_for_related_fields', False):
-                rel_obj = rel_mgr.using(db).get(**params)
-            else:
-                rel_obj = QuerySet(self.field.rel.to).using(db).get(**params)
-            setattr(instance, cache_name, rel_obj)
+            qs = self.get_query_set(instance=instance)
+            rel_obj = qs.get(**params)
+            setattr(instance, self.cache_name, rel_obj)
             return rel_obj
 
     def __set__(self, instance, value):
@@ -425,15 +461,15 @@ class ForeignRelatedObjectsDescriptor(object):
                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
 
             def get_prefetch_query_set(self, instances):
-                """
-                Return a queryset that does the bulk lookup needed
-                by prefetch_related functionality.
-                """
                 db = self._db or router.db_for_read(self.model)
                 query = {'%s__%s__in' % (rel_field.name, attname):
                              [getattr(obj, attname) for obj in instances]}
                 qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
-                return (qs, rel_field.get_attname(), attname)
+                return (qs,
+                        attrgetter(rel_field.get_attname()),
+                        attrgetter(attname),
+                        False,
+                        rel_field.related_query_name())
 
             def add(self, *objs):
                 for obj in objs:
@@ -507,12 +543,6 @@ def create_many_related_manager(superclass, rel):
                 return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**self.core_filters)
 
         def get_prefetch_query_set(self, instances):
-            """
-            Returns a tuple:
-            (queryset of instances of self.model that are related to passed in instances
-             attr of returned instances needed for matching
-             attr of passed in instances needed for matching)
-            """
             from django.db import connections
             db = self._db or router.db_for_read(self.model)
             query = {'%s__pk__in' % self.query_field_name:
@@ -534,7 +564,11 @@ def create_many_related_manager(superclass, rel):
             qs = qs.extra(select={'_prefetch_related_val':
                                       '%s.%s' % (qn(join_table), qn(source_col))})
             select_attname = fk.rel.get_related_field().get_attname()
-            return (qs, '_prefetch_related_val', select_attname)
+            return (qs,
+                    attrgetter('_prefetch_related_val'),
+                    attrgetter(select_attname),
+                    False,
+                    self.prefetch_cache_name)
 
         # If the ManyToMany relation has an intermediary model,
         # the add and remove methods do not exist.
diff --git a/django/db/models/query.py b/django/db/models/query.py
index b21db2e521..1461125af4 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1612,36 +1612,42 @@ def prefetch_related_objects(result_cache, related_lookups):
                 break
 
             # Descend down tree
-            try:
-                rel_obj = getattr(obj_list[0], attr)
-            except AttributeError:
+
+            # We assume that objects retrieved are homogenous (which is the premise
+            # of prefetch_related), so what applies to first object applies to all.
+            first_obj = obj_list[0]
+            prefetcher, attr_found, is_fetched = get_prefetcher(first_obj, attr)
+
+            if not attr_found:
                 raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid "
                                      "parameter to prefetch_related()" %
-                                     (attr, obj_list[0].__class__.__name__, lookup))
+                                     (attr, first_obj.__class__.__name__, lookup))
 
-            can_prefetch = hasattr(rel_obj, 'get_prefetch_query_set')
-            if level == len(attrs) - 1 and not can_prefetch:
-                # Last one, this *must* resolve to a related manager.
-                raise ValueError("'%s' does not resolve to a supported 'many related"
-                                 " manager' for model %s - this is an invalid"
-                                 " parameter to prefetch_related()."
-                                 % (lookup, model.__name__))
+            if level == len(attrs) - 1 and prefetcher is None:
+                # Last one, this *must* resolve to something that supports
+                # prefetching, otherwise there is no point adding it and the
+                # developer asking for it has made a mistake.
+                raise ValueError("'%s' does not resolve to a item that supports "
+                                 "prefetching - this is an invalid parameter to "
+                                 "prefetch_related()." % lookup)
 
-            if can_prefetch:
+            if prefetcher is not None and not is_fetched:
                 # Check we didn't do this already
                 current_lookup = LOOKUP_SEP.join(attrs[0:level+1])
                 if current_lookup in done_queries:
                     obj_list = done_queries[current_lookup]
                 else:
-                    relmanager = rel_obj
-                    obj_list, additional_prl = prefetch_one_level(obj_list, relmanager, attr)
+                    obj_list, additional_prl = prefetch_one_level(obj_list, prefetcher, attr)
                     for f in additional_prl:
                         new_prl = LOOKUP_SEP.join([current_lookup, f])
                         related_lookups.append(new_prl)
                     done_queries[current_lookup] = obj_list
             else:
-                # Assume we've got some singly related object. We replace
-                # the current list of parent objects with that list.
+                # Either a singly related object that has already been fetched
+                # (e.g. via select_related), or hopefully some other property
+                # that doesn't support prefetching but needs to be traversed.
+
+                # We replace the current list of parent objects with that list.
                 obj_list = [getattr(obj, attr) for obj in obj_list]
 
                 # Filter out 'None' so that we can continue with nullable
@@ -1649,18 +1655,73 @@ def prefetch_related_objects(result_cache, related_lookups):
                 obj_list = [obj for obj in obj_list if obj is not None]
 
 
-def prefetch_one_level(instances, relmanager, attname):
+def get_prefetcher(instance, attr):
+    """
+    For the attribute 'attr' on the given instance, finds
+    an object that has a get_prefetch_query_set().
+    Return a 3 tuple containing:
+    (the object with get_prefetch_query_set (or None),
+     a boolean that is False if the attribute was not found at all,
+     a boolean that is True if the attribute has already been fetched)
+    """
+    prefetcher = None
+    attr_found = False
+    is_fetched = False
+
+    # For singly related objects, we have to avoid getting the attribute
+    # from the object, as this will trigger the query. So we first try
+    # on the class, in order to get the descriptor object.
+    rel_obj_descriptor = getattr(instance.__class__, attr, None)
+    if rel_obj_descriptor is None:
+        try:
+            rel_obj = getattr(instance, attr)
+            attr_found = True
+        except AttributeError:
+            pass
+    else:
+        attr_found = True
+        if rel_obj_descriptor:
+            # singly related object, descriptor object has the
+            # get_prefetch_query_set() method.
+            if hasattr(rel_obj_descriptor, 'get_prefetch_query_set'):
+                prefetcher = rel_obj_descriptor
+                if rel_obj_descriptor.is_cached(instance):
+                    is_fetched = True
+            else:
+                # descriptor doesn't support prefetching, so we go ahead and get
+                # the attribute on the instance rather than the class to
+                # support many related managers
+                rel_obj = getattr(instance, attr)
+                if hasattr(rel_obj, 'get_prefetch_query_set'):
+                    prefetcher = rel_obj
+    return prefetcher, attr_found, is_fetched
+
+
+def prefetch_one_level(instances, prefetcher, attname):
     """
     Helper function for prefetch_related_objects
 
-    Runs prefetches on all instances using the manager relmanager,
-    assigning results to queryset against instance.attname.
+    Runs prefetches on all instances using the prefetcher object,
+    assigning results to relevant caches in instance.
 
     The prefetched objects are returned, along with any additional
     prefetches that must be done due to prefetch_related lookups
     found from default managers.
     """
-    rel_qs, rel_obj_attr, instance_attr = relmanager.get_prefetch_query_set(instances)
+    # prefetcher must have a method get_prefetch_query_set() which takes a list
+    # of instances, and returns a tuple:
+
+    # (queryset of instances of self.model that are related to passed in instances,
+    #  callable that gets value to be matched for returned instances,
+    #  callable that gets value to be matched for passed in instances,
+    #  boolean that is True for singly related objects,
+    #  cache name to assign to).
+
+    # The 'values to be matched' must be hashable as they will be used
+    # in a dictionary.
+
+    rel_qs, rel_obj_attr, instance_attr, single, cache_name =\
+        prefetcher.get_prefetch_query_set(instances)
     # We have to handle the possibility that the default manager itself added
     # prefetch_related lookups to the QuerySet we just got back. We don't want to
     # trigger the prefetch_related functionality by evaluating the query.
@@ -1676,17 +1737,25 @@ def prefetch_one_level(instances, relmanager, attname):
 
     rel_obj_cache = {}
     for rel_obj in all_related_objects:
-        rel_attr_val = getattr(rel_obj, rel_obj_attr)
+        rel_attr_val = rel_obj_attr(rel_obj)
         if rel_attr_val not in rel_obj_cache:
             rel_obj_cache[rel_attr_val] = []
         rel_obj_cache[rel_attr_val].append(rel_obj)
 
     for obj in instances:
-        qs = getattr(obj, attname).all()
-        instance_attr_val = getattr(obj, instance_attr)
-        qs._result_cache = rel_obj_cache.get(instance_attr_val, [])
-        # We don't want the individual qs doing prefetch_related now, since we
-        # have merged this into the current work.
-        qs._prefetch_done = True
-        obj._prefetched_objects_cache[attname] = qs
+        instance_attr_val = instance_attr(obj)
+        vals = rel_obj_cache.get(instance_attr_val, [])
+        if single:
+            # Need to assign to single cache on instance
+            if vals:
+                setattr(obj, cache_name, vals[0])
+        else:
+            # Multi, attribute represents a manager with an .all() method that
+            # returns a QuerySet
+            qs = getattr(obj, attname).all()
+            qs._result_cache = vals
+            # We don't want the individual qs doing prefetch_related now, since we
+            # have merged this into the current work.
+            qs._prefetch_done = True
+            obj._prefetched_objects_cache[cache_name] = qs
     return all_related_objects, additional_prl
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index ea8e0ff6e3..238fe64915 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -696,14 +696,26 @@ prefetch_related
 .. versionadded:: 1.4
 
 Returns a ``QuerySet`` that will automatically retrieve, in a single batch,
-related many-to-many and many-to-one objects for each of the specified lookups.
+related objects for each of the specified lookups.
 
-This is similar to ``select_related`` for the 'many related objects' case, but
-note that ``prefetch_related`` causes a separate query to be issued for each set
-of related objects that you request, unlike ``select_related`` which modifies
-the original query with joins in order to get the related objects. With
-``prefetch_related``, the additional queries are done as soon as the QuerySet
-begins to be evaluated.
+This has a similar purpose to ``select_related``, in that both are designed to
+stop the deluge of database queries that is caused by accessing related objects,
+but the strategy is quite different.
+
+``select_related`` works by creating a SQL join and including the fields of the
+related object in the SELECT statement. For this reason, ``select_related`` gets
+the related objects in the same database query. However, to avoid the much
+larger result set that would result from joining across a 'many' relationship,
+``select_related`` is limited to single-valued relationships - foreign key and
+one-to-one.
+
+``prefetch_related``, on the other hand, does a separate lookup for each
+relationship, and does the 'joining' in Python. This allows it to prefetch
+many-to-many and many-to-one objects, which cannot be done using
+``select_related``, in addition to the foreign key and one-to-one relationships
+that are supported by ``select_related``. It also supports prefetching of
+:class:`~django.contrib.contenttypes.generic.GenericRelation` and
+:class:`~django.contrib.contenttypes.generic.GenericForeignKey`.
 
 For example, suppose you have these models::
 
@@ -733,14 +745,17 @@ All the relevant toppings will be fetched in a single query, and used to make
 ``QuerySets`` that have a pre-filled cache of the relevant results. These
 ``QuerySets`` are then used in the ``self.toppings.all()`` calls.
 
-Please note that use of ``prefetch_related`` will mean that the additional
-queries run will **always** be executed - even if you never use the related
-objects - and it always fully populates the result cache on the primary
-``QuerySet`` (which can sometimes be avoided in other cases).
+The additional queries are executed after the QuerySet has begun to be evaluated
+and the primary query has been executed. Note that the result cache of the
+primary QuerySet and all specified related objects will then be fully loaded
+into memory, which is often avoided in other cases - even after a query has been
+executed in the database, QuerySet normally tries to make use of chunking
+when reading from the database to avoid loading all objects into memory before
+you need them.
 
 Also remember that, as always with QuerySets, any subsequent chained methods
-will ignore previously cached results, and retrieve data using a fresh database
-query. So, if you write the following:
+which imply a different database query will ignore previously cached results,
+and retrieve data using a fresh database query. So, if you write the following:
 
     >>> pizzas = Pizza.objects.prefetch_related('toppings')
     >>> [list(pizza.toppings.filter(spicy=True)) for pizza in pizzas]
@@ -749,12 +764,6 @@ query. So, if you write the following:
 you - in fact it hurts performance, since you have done a database query that
 you haven't used. So use this feature with caution!
 
-The lookups that must be supplied to this method can be any attributes on the
-model instances which represent related queries that return multiple
-objects. This includes attributes representing the 'many' side of ``ForeignKey``
-relationships, forward and reverse ``ManyToManyField`` attributes, and also any
-``GenericRelations``.
-
 You can also use the normal join syntax to do related fields of related
 fields. Suppose we have an additional model to the example above::
 
@@ -770,24 +779,40 @@ This will prefetch all pizzas belonging to restaurants, and all toppings
 belonging to those pizzas. This will result in a total of 3 database queries -
 one for the restaurants, one for the pizzas, and one for the toppings.
 
-    >>> Restaurant.objects.select_related('best_pizza').prefetch_related('best_pizza__toppings')
+    >>> Restaurant.objects.prefetch_related('best_pizza__toppings')
 
 This will fetch the best pizza and all the toppings for the best pizza for each
-restaurant. This will be done in 2 database queries - one for the restaurants
-and 'best pizzas' combined (achieved through use of ``select_related``), and one
-for the toppings.
+restaurant. This will be done in 3 database queries - one for the restaurants,
+one for the 'best pizzas', and one for the toppings.
 
-Chaining ``prefetch_related`` calls will accumulate the fields that should have
-this behavior applied. To clear any ``prefetch_related`` behavior, pass `None`
-as a parameter::
+Of course, the ``best_pizza`` relationship could also be fetched using
+``select_related`` to reduce the query count to 2:
+
+    >>> Restaurant.objects.select_related('best_pizza').prefetch_related('best_pizza__toppings')
+
+Since the prefetch is executed after the main query (which includes the joins
+needed by ``select_related``), it is able to detect that the ``best_pizza``
+objects have already been fetched, and it will skip fetching them again.
+
+Chaining ``prefetch_related`` calls will accumulate the lookups that are
+prefetched. To clear any ``prefetch_related`` behavior, pass ``None`` as a
+parameter::
 
    >>> non_prefetched = qs.prefetch_related(None)
 
-One difference when using ``prefetch_related`` is that, in some circumstances,
-objects created by a query can be shared between the different objects that they
-are related to i.e. a single Python model instance can appear at more than one
-point in the tree of objects that are returned. Normally this behavior will not
-be a problem, and will in fact save both memory and CPU time.
+One difference to note when using ``prefetch_related`` is that objects created
+by a query can be shared between the different objects that they are related to,
+i.e. a single Python model instance can appear at more than one point in the
+tree of objects that are returned. This will normally happen with foreign key
+relationships. Typically this behavior will not be a problem, and will in fact
+save both memory and CPU time.
+
+While ``prefetch_related`` supports prefetching ``GenericForeignKey``
+relationships, the number of queries will depend on the data. Since a
+``GenericForeignKey`` can reference data in multiple tables, one query per table
+referenced is needed, rather than one query for all the items. There could be
+additional queries on the ``ContentType`` table if the relevant rows have not
+already been fetched.
 
 extra
 ~~~~~
diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt
index 6a97060f40..c6b547dab6 100644
--- a/docs/releases/1.4.txt
+++ b/docs/releases/1.4.txt
@@ -66,15 +66,18 @@ information.
 ``QuerySet.prefetch_related``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Analagous to :meth:`~django.db.models.query.QuerySet.select_related` but for
-many-to-many relationships,
+Similar to :meth:`~django.db.models.query.QuerySet.select_related` but with a
+different strategy and broader scope,
 :meth:`~django.db.models.query.QuerySet.prefetch_related` has been added to
-:class:`~django.db.models.query.QuerySet`. This method returns a new ``QuerySet``
-that will prefetch in a single batch each of the specified related lookups as
-soon as it begins to be evaluated (e.g. by iterating over it). This enables you
-to fix many instances of a very common performance problem, in which your code
-ends up doing O(n) database queries (or worse) if objects on your primary
-``QuerySet`` each have many related objects that you also need.
+:class:`~django.db.models.query.QuerySet`. This method returns a new
+``QuerySet`` that will prefetch in a single batch each of the specified related
+lookups as soon as it begins to be evaluated. Unlike ``select_related``, it does
+the joins in Python, not in the database, and supports many-to-many
+relationships, :class:`~django.contrib.contenttypes.generic.GenericForeignKey`
+and more. This enables you to fix many instances of a very common performance
+problem, in which your code ends up doing O(n) database queries (or worse) if
+objects on your primary ``QuerySet`` each have many related objects that you
+also need.
 
 HTML5
 ~~~~~
diff --git a/tests/modeltests/prefetch_related/models.py b/tests/modeltests/prefetch_related/models.py
index ab28496f37..1c14c88818 100644
--- a/tests/modeltests/prefetch_related/models.py
+++ b/tests/modeltests/prefetch_related/models.py
@@ -104,13 +104,17 @@ class Department(models.Model):
         ordering = ['id']
 
 
-## Generic relation tests
+## GenericRelation/GenericForeignKey tests
 
 class TaggedItem(models.Model):
     tag = models.SlugField()
     content_type = models.ForeignKey(ContentType, related_name="taggeditem_set2")
     object_id = models.PositiveIntegerField()
     content_object = generic.GenericForeignKey('content_type', 'object_id')
+    created_by_ct = models.ForeignKey(ContentType, null=True,
+                                      related_name='taggeditem_set3')
+    created_by_fkey = models.PositiveIntegerField(null=True)
+    created_by = generic.GenericForeignKey('created_by_ct', 'created_by_fkey',)
 
     def __unicode__(self):
         return self.tag
diff --git a/tests/modeltests/prefetch_related/tests.py b/tests/modeltests/prefetch_related/tests.py
index 45202f2af8..bdbb0568c3 100644
--- a/tests/modeltests/prefetch_related/tests.py
+++ b/tests/modeltests/prefetch_related/tests.py
@@ -54,6 +54,13 @@ class PrefetchRelatedTests(TestCase):
         normal_lists = [list(a.books.all()) for a in Author.objects.all()]
         self.assertEqual(lists, normal_lists)
 
+    def test_foreignkey_forward(self):
+        with self.assertNumQueries(2):
+            books = [a.first_book for a in Author.objects.prefetch_related('first_book')]
+
+        normal_books = [a.first_book for a in Author.objects.all()]
+        self.assertEqual(books, normal_books)
+
     def test_foreignkey_reverse(self):
         with self.assertNumQueries(2):
             lists = [list(b.first_time_authors.all())
@@ -175,12 +182,12 @@ class PrefetchRelatedTests(TestCase):
         self.assertTrue('prefetch_related' in str(cm.exception))
 
     def test_invalid_final_lookup(self):
-        qs = Book.objects.prefetch_related('authors__first_book')
+        qs = Book.objects.prefetch_related('authors__name')
         with self.assertRaises(ValueError) as cm:
             list(qs)
 
         self.assertTrue('prefetch_related' in str(cm.exception))
-        self.assertTrue("first_book" in str(cm.exception))
+        self.assertTrue("name" in str(cm.exception))
 
 
 class DefaultManagerTests(TestCase):
@@ -222,39 +229,68 @@ class DefaultManagerTests(TestCase):
 
 class GenericRelationTests(TestCase):
 
-    def test_traverse_GFK(self):
-        """
-        Test that we can traverse a 'content_object' with prefetch_related()
-        """
-        # In fact, there is no special support for this in prefetch_related code
-        # - we can traverse any object that will lead us to objects that have
-        # related managers.
-
+    def setUp(self):
         book1 = Book.objects.create(title="Winnie the Pooh")
         book2 = Book.objects.create(title="Do you like green eggs and spam?")
+        book3 = Book.objects.create(title="Three Men In A Boat")
 
         reader1 = Reader.objects.create(name="me")
         reader2 = Reader.objects.create(name="you")
+        reader3 = Reader.objects.create(name="someone")
 
-        book1.read_by.add(reader1)
+        book1.read_by.add(reader1, reader2)
         book2.read_by.add(reader2)
+        book3.read_by.add(reader3)
 
-        TaggedItem.objects.create(tag="awesome", content_object=book1)
-        TaggedItem.objects.create(tag="awesome", content_object=book2)
+        self.book1, self.book2, self.book3 = book1, book2, book3
+        self.reader1, self.reader2, self.reader3 = reader1, reader2, reader3
+
+    def test_prefetch_GFK(self):
+        """Prefetch 'content_object' for tags that point at books and readers."""
+        tagged = [("awesome", self.book1), ("great", self.reader1),
+                  ("stupid", self.book2), ("amazing", self.reader3)]
+        for tag, target in tagged:
+            TaggedItem.objects.create(tag=tag, content_object=target)
+
+        # One query each for the TaggedItem, Book and Reader tables.
+        with self.assertNumQueries(3):
+            list(TaggedItem.objects.prefetch_related('content_object'))
+
+    def test_traverse_GFK(self):
+        """
+        Test that we can traverse a 'content_object' with prefetch_related() and
+        get to related objects on the other side (assuming it is suitably
+        filtered)
+        """
+        TaggedItem.objects.create(tag="awesome", content_object=self.book1)
+        TaggedItem.objects.create(tag="awesome", content_object=self.book2)
+        TaggedItem.objects.create(tag="awesome", content_object=self.book3)
+        TaggedItem.objects.create(tag="awesome", content_object=self.reader1)
+        TaggedItem.objects.create(tag="awesome", content_object=self.reader2)
 
         ct = ContentType.objects.get_for_model(Book)
 
-        # We get 4 queries - 1 for main query, 2 for each access to
-        # 'content_object' because these can't be handled by select_related, and
-        # 1 for the 'read_by' relation.
-        with self.assertNumQueries(4):
+        # We get 3 queries - 1 for main query, 1 for content_objects since they
+        # all use the same table, and 1 for the 'read_by' relation.
+        with self.assertNumQueries(3):
             # If we limit to books, we know that they will have 'read_by'
             # attributes, so the following makes sense:
-            qs = TaggedItem.objects.select_related('content_type').prefetch_related('content_object__read_by').filter(tag='awesome').filter(content_type=ct, tag='awesome')
-            readers_of_awesome_books = [r.name for tag in qs
-                                        for r in tag.content_object.read_by.all()]
-            self.assertEqual(readers_of_awesome_books, ["me", "you"])
+            qs = TaggedItem.objects.filter(content_type=ct, tag='awesome').prefetch_related('content_object__read_by')
+            readers_of_awesome_books = set([r.name for tag in qs
+                                            for r in tag.content_object.read_by.all()])
+            self.assertEqual(readers_of_awesome_books, set(["me", "you", "someone"]))
 
+    def test_nullable_GFK(self):
+        """Prefetching 'created_by' must work when some items leave it unset."""
+        make_tag = TaggedItem.objects.create
+        make_tag(tag="awesome", content_object=self.book1, created_by=self.reader1)
+        make_tag(tag="great", content_object=self.book2)
+        make_tag(tag="rubbish", content_object=self.book3)
+
+        with self.assertNumQueries(2):
+            prefetched = [item.created_by for item in TaggedItem.objects.prefetch_related('created_by')]
+        plain = [item.created_by for item in TaggedItem.objects.all()]
+        self.assertEqual(prefetched, plain)
 
     def test_generic_relation(self):
         b = Bookmark.objects.create(url='http://www.djangoproject.com/')
@@ -311,9 +347,14 @@ class MultiTableInheritanceTest(TestCase):
         self.assertEquals(lst, lst2)
 
     def test_parent_link_prefetch(self):
-        with self.assertRaises(ValueError) as cm:
-            qs = list(AuthorWithAge.objects.prefetch_related('author'))
-        self.assertTrue('prefetch_related' in str(cm.exception))
+        with self.assertNumQueries(2):
+            [a.author for a in AuthorWithAge.objects.prefetch_related('author')]
+
+    def test_child_link_prefetch(self):
+        # Parent -> child link ('authorwithage'); two queries expected in total.
+        with self.assertNumQueries(2):
+            children = [a.authorwithage for a in Author.objects.prefetch_related('authorwithage')]
+        self.assertEqual(children, [a.authorwithage for a in Author.objects.all()])
 
 
 class ForeignKeyToFieldTest(TestCase):
@@ -406,6 +447,8 @@ class NullableTest(TestCase):
         worker2 = Employee.objects.create(name="Angela", boss=boss)
 
     def test_traverse_nullable(self):
+        # Because we use select_related() for 'boss', it doesn't need to be
+        # prefetched, but we can still traverse it although it contains some nulls
         with self.assertNumQueries(2):
             qs = Employee.objects.select_related('boss').prefetch_related('boss__serfs')
             co_serfs = [list(e.boss.serfs.all()) if e.boss is not None else []
@@ -416,3 +459,16 @@ class NullableTest(TestCase):
                         for e in qs2]
 
         self.assertEqual(co_serfs, co_serfs2)
+
+    def test_prefetch_nullable(self):
+        """'boss__serfs' can be prefetched even though 'boss' may be None."""
+        # One for main employee, one for boss, one for serfs
+        with self.assertNumQueries(3):
+            qs = Employee.objects.prefetch_related('boss__serfs')
+            co_serfs = [list(e.boss.serfs.all()) if e.boss is not None else []
+                        for e in qs]
+
+        qs2 = Employee.objects.all()
+        co_serfs2 = [list(e.boss.serfs.all()) if e.boss is not None else []
+                     for e in qs2]
+        self.assertEqual(co_serfs, co_serfs2)