mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #12937 -- Corrected the operation of select_related() when following an reverse relation on an inherited model. Thanks to subsume for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12814 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -1113,7 +1113,7 @@ class EmptyQuerySet(QuerySet): | ||||
|  | ||||
|  | ||||
| def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|                    requested=None, offset=0, only_load=None): | ||||
|                    requested=None, offset=0, only_load=None, local_only=False): | ||||
|     """ | ||||
|     Helper function that recursively returns an object with the specified | ||||
|     related attributes already populated. | ||||
| @@ -1141,6 +1141,8 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|      * only_load - if the query has had only() or defer() applied, | ||||
|        this is the list of field names that will be returned. If None, | ||||
|        the full field list for `klass` can be assumed. | ||||
|      * local_only - Only populate local fields. This is used when building | ||||
|        following reverse select-related relations | ||||
|     """ | ||||
|     if max_depth and requested is None and cur_depth > max_depth: | ||||
|         # We've recursed deeply enough; stop now. | ||||
| @@ -1153,9 +1155,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|         skip = set() | ||||
|         init_list = [] | ||||
|         # Build the list of fields that *haven't* been requested | ||||
|         for field in klass._meta.fields: | ||||
|         for field, model in klass._meta.get_fields_with_model(): | ||||
|             if field.name not in load_fields: | ||||
|                 skip.add(field.name) | ||||
|             elif local_only and model is not None: | ||||
|                 continue | ||||
|             else: | ||||
|                 init_list.append(field.attname) | ||||
|         # Retrieve all the requested fields | ||||
| @@ -1174,7 +1178,11 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|  | ||||
|     else: | ||||
|         # Load all fields on klass | ||||
|         field_count = len(klass._meta.fields) | ||||
|         if local_only: | ||||
|             field_names = [f.attname for f in klass._meta.local_fields] | ||||
|         else: | ||||
|             field_names = [f.attname for f in klass._meta.fields] | ||||
|         field_count = len(field_names) | ||||
|         fields = row[index_start : index_start + field_count] | ||||
|         # If all the select_related columns are None, then the related | ||||
|         # object must be non-existent - set the relation to None. | ||||
| @@ -1182,7 +1190,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|         if fields == (None,) * field_count: | ||||
|             obj = None | ||||
|         else: | ||||
|             obj = klass(*fields) | ||||
|             obj = klass(**dict(zip(field_names, fields))) | ||||
|  | ||||
|     # If an object was retrieved, set the database state. | ||||
|     if obj: | ||||
| @@ -1229,7 +1237,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|             next = requested[f.related_query_name()] | ||||
|             # Recursively retrieve the data for the related object | ||||
|             cached_row = get_cached_row(model, row, index_end, using, | ||||
|                 max_depth, cur_depth+1, next) | ||||
|                 max_depth, cur_depth+1, next, local_only=True) | ||||
|             # If the recursive descent found an object, populate the | ||||
|             # descriptor caches relevant to the object | ||||
|             if cached_row: | ||||
| @@ -1242,7 +1250,20 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, | ||||
|                     # If the related object exists, populate | ||||
|                     # the descriptor cache. | ||||
|                     setattr(rel_obj, f.get_cache_name(), obj) | ||||
|  | ||||
|                     # Now populate all the non-local field values | ||||
|                     # on the related object | ||||
|                     for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): | ||||
|                         if rel_model is not None: | ||||
|                             setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) | ||||
|                             # populate the field cache for any related object | ||||
|                             # that has already been retrieved | ||||
|                             if rel_field.rel: | ||||
|                                 try: | ||||
|                                     cached_obj = getattr(obj, rel_field.get_cache_name()) | ||||
|                                     setattr(rel_obj, rel_field.get_cache_name(), cached_obj) | ||||
|                                 except AttributeError: | ||||
|                                     # Related object hasn't been cached yet | ||||
|                                     pass | ||||
|     return obj, index_end | ||||
|  | ||||
| def delete_objects(seen_objs, using): | ||||
|   | ||||
| @@ -215,7 +215,7 @@ class SQLCompiler(object): | ||||
|         return result | ||||
|  | ||||
|     def get_default_columns(self, with_aliases=False, col_aliases=None, | ||||
|             start_alias=None, opts=None, as_pairs=False): | ||||
|             start_alias=None, opts=None, as_pairs=False, local_only=False): | ||||
|         """ | ||||
|         Computes the default columns for selecting every field in the base | ||||
|         model. Will sometimes be called to pull in related models (e.g. via | ||||
| @@ -240,6 +240,8 @@ class SQLCompiler(object): | ||||
|         if start_alias: | ||||
|             seen = {None: start_alias} | ||||
|         for field, model in opts.get_fields_with_model(): | ||||
|             if local_only and model is not None: | ||||
|                 continue | ||||
|             if start_alias: | ||||
|                 try: | ||||
|                     alias = seen[model] | ||||
| @@ -643,7 +645,7 @@ class SQLCompiler(object): | ||||
|                 ) | ||||
|                 used.add(alias) | ||||
|                 columns, aliases = self.get_default_columns(start_alias=alias, | ||||
|                     opts=model._meta, as_pairs=True) | ||||
|                     opts=model._meta, as_pairs=True, local_only=True) | ||||
|                 self.query.related_select_cols.extend(columns) | ||||
|                 self.query.related_select_fields.extend(model._meta.fields) | ||||
|  | ||||
|   | ||||
| @@ -43,8 +43,7 @@ class StatDetails(models.Model): | ||||
|  | ||||
|  | ||||
| class AdvancedUserStat(UserStat): | ||||
|     pass | ||||
|  | ||||
|     karma = models.IntegerField() | ||||
|  | ||||
| class Image(models.Model): | ||||
|     name = models.CharField(max_length=100) | ||||
|   | ||||
| @@ -2,7 +2,7 @@ from django import db | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
|  | ||||
| from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,  | ||||
| from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, | ||||
|     AdvancedUserStat, Image, Product) | ||||
|  | ||||
| class ReverseSelectRelatedTestCase(TestCase): | ||||
| @@ -22,7 +22,7 @@ class ReverseSelectRelatedTestCase(TestCase): | ||||
|  | ||||
|         user2 = User.objects.create(username="bob") | ||||
|         results2 = UserStatResult.objects.create(results='moar results') | ||||
|         advstat = AdvancedUserStat.objects.create(user=user2, posts=200, | ||||
|         advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, | ||||
|                                                   results=results2) | ||||
|         StatDetails.objects.create(base_stats=advstat, comments=250) | ||||
|  | ||||
| @@ -74,18 +74,21 @@ class ReverseSelectRelatedTestCase(TestCase): | ||||
|         self.assertQueries(2) | ||||
|  | ||||
|     def test_follow_from_child_class(self): | ||||
|         stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) | ||||
|         stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) | ||||
|         self.assertEqual(stat.statdetails.comments, 250) | ||||
|         self.assertEqual(stat.user.username, 'bob') | ||||
|         self.assertQueries(1) | ||||
|  | ||||
|     def test_follow_inheritance(self): | ||||
|         stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) | ||||
|         stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) | ||||
|         self.assertEqual(stat.advanceduserstat.posts, 200) | ||||
|         self.assertEqual(stat.user.username, 'bob') | ||||
|         self.assertEqual(stat.advanceduserstat.user.username, 'bob') | ||||
|         self.assertQueries(1) | ||||
|      | ||||
|  | ||||
|     def test_nullable_relation(self): | ||||
|         im = Image.objects.create(name="imag1") | ||||
|         p1 = Product.objects.create(name="Django Plushie", image=im) | ||||
|         p2 = Product.objects.create(name="Talking Django Plushie") | ||||
|          | ||||
|  | ||||
|         self.assertEqual(len(Product.objects.select_related("image")), 2) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user