mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #10414 -- Made select_related() fail on invalid field names.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							b27db97b23
						
					
				
				
					commit
					3daa9d60be
				
			| @@ -1,3 +1,4 @@ | ||||
| from itertools import chain | ||||
| import warnings | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| @@ -599,6 +600,14 @@ class SQLCompiler(object): | ||||
|         (for example, cur_depth=1 means we are looking at models with direct | ||||
|         connections to the root model). | ||||
|         """ | ||||
|         def _get_field_choices(): | ||||
|             direct_choices = (f.name for (f, _) in opts.get_fields_with_model() if f.rel) | ||||
|             reverse_choices = ( | ||||
|                 f.field.related_query_name() | ||||
|                 for f in opts.get_all_related_objects() if f.field.unique | ||||
|             ) | ||||
|             return chain(direct_choices, reverse_choices) | ||||
|  | ||||
|         if not restricted and self.query.max_depth and cur_depth > self.query.max_depth: | ||||
|             # We've recursed far enough; bail out. | ||||
|             return | ||||
| @@ -611,6 +620,7 @@ class SQLCompiler(object): | ||||
|  | ||||
|         # Setup for the case when only particular related fields should be | ||||
|         # included in the related selection. | ||||
|         fields_found = set() | ||||
|         if requested is None: | ||||
|             if isinstance(self.query.select_related, dict): | ||||
|                 requested = self.query.select_related | ||||
| @@ -619,6 +629,24 @@ class SQLCompiler(object): | ||||
|                 restricted = False | ||||
|  | ||||
|         for f, model in opts.get_fields_with_model(): | ||||
|             fields_found.add(f.name) | ||||
|  | ||||
|             if restricted: | ||||
|                 next = requested.get(f.name, {}) | ||||
|                 if not f.rel: | ||||
|                     # If a non-related field is used like a relation, | ||||
|                     # or if a single non-relational field is given. | ||||
|                     if next or (cur_depth == 1 and f.name in requested): | ||||
|                         raise FieldError( | ||||
|                             "Non-relational field given in select_related: '%s'. " | ||||
|                             "Choices are: %s" % ( | ||||
|                                 f.name, | ||||
|                                 ", ".join(_get_field_choices()) or '(none)', | ||||
|                             ) | ||||
|                         ) | ||||
|             else: | ||||
|                 next = False | ||||
|  | ||||
|             # The get_fields_with_model() returns None for fields that live | ||||
|             # in the field's local model. So, for those fields we want to use | ||||
|             # the f.model - that is the field's local model. | ||||
| @@ -632,13 +660,9 @@ class SQLCompiler(object): | ||||
|             columns, _ = self.get_default_columns(start_alias=alias, | ||||
|                     opts=f.rel.to._meta, as_pairs=True) | ||||
|             self.query.related_select_cols.extend( | ||||
|                 SelectInfo((col[0], col[1].column), col[1]) for col in columns) | ||||
|             if restricted: | ||||
|                 next = requested.get(f.name, {}) | ||||
|             else: | ||||
|                 next = False | ||||
|             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, | ||||
|                     next, restricted) | ||||
|                 SelectInfo((col[0], col[1].column), col[1]) for col in columns | ||||
|             ) | ||||
|             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted) | ||||
|  | ||||
|         if restricted: | ||||
|             related_fields = [ | ||||
| @@ -651,8 +675,10 @@ class SQLCompiler(object): | ||||
|                                               only_load.get(model), reverse=True): | ||||
|                     continue | ||||
|  | ||||
|                 _, _, _, joins, _ = self.query.setup_joins( | ||||
|                     [f.related_query_name()], opts, root_alias) | ||||
|                 related_field_name = f.related_query_name() | ||||
|                 fields_found.add(related_field_name) | ||||
|  | ||||
|                 _, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias) | ||||
|                 alias = joins[-1] | ||||
|                 from_parent = (opts.model if issubclass(model, opts.model) | ||||
|                                else None) | ||||
| @@ -664,6 +690,17 @@ class SQLCompiler(object): | ||||
|                 self.fill_related_selections(model._meta, alias, cur_depth + 1, | ||||
|                                              next, restricted) | ||||
|  | ||||
|             fields_not_found = set(requested.keys()).difference(fields_found) | ||||
|             if fields_not_found: | ||||
|                 invalid_fields = ("'%s'" % s for s in fields_not_found) | ||||
|                 raise FieldError( | ||||
|                     'Invalid field name(s) given in select_related: %s. ' | ||||
|                     'Choices are: %s' % ( | ||||
|                         ', '.join(invalid_fields), | ||||
|                         ', '.join(_get_field_choices()) or '(none)', | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
|     def deferred_to_columns(self): | ||||
|         """ | ||||
|         Converts the self.deferred_loading data structure to mapping of table | ||||
|   | ||||
| @@ -681,6 +681,24 @@ lookups:: | ||||
|     ... | ||||
|     ValueError: Cannot query "<Book: Django>": Must be "Author" instance. | ||||
|  | ||||
| ``select_related()`` now checks given fields | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| ``select_related()`` now validates that the given fields actually exist. | ||||
| Previously, nonexistent fields were silently ignored. Now, an error is raised:: | ||||
|  | ||||
|     >>> book = Book.objects.select_related('nonexistent_field') | ||||
|     Traceback (most recent call last): | ||||
|     ... | ||||
|     FieldError: Invalid field name(s) given in select_related: 'nonexistent_field' | ||||
|  | ||||
| The validation also makes sure that the given field is relational:: | ||||
|  | ||||
|     >>> book = Book.objects.select_related('name') | ||||
|     Traceback (most recent call last): | ||||
|     ... | ||||
|     FieldError: Non-relational field given in select_related: 'name' | ||||
|  | ||||
| Default ``EmailField.max_length`` increased to 254 | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -181,7 +181,7 @@ class NonAggregateAnnotationTestCase(TestCase): | ||||
|             other_chain=F('chain'), | ||||
|             is_open=Value(True, BooleanField()), | ||||
|             book_isbn=F('books__isbn') | ||||
|         ).select_related('store').order_by('book_isbn').filter(chain='Westfield') | ||||
|         ).order_by('book_isbn').filter(chain='Westfield') | ||||
|  | ||||
|         self.assertQuerysetEqual( | ||||
|             qs, [ | ||||
|   | ||||
| @@ -10,6 +10,9 @@ the select-related behavior will traverse. | ||||
| from django.db import models | ||||
| from django.utils.encoding import python_2_unicode_compatible | ||||
|  | ||||
| from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation | ||||
| from django.contrib.contenttypes.models import ContentType | ||||
|  | ||||
| # Who remembers high school biology? | ||||
|  | ||||
|  | ||||
| @@ -94,3 +97,41 @@ class HybridSpecies(models.Model): | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.name | ||||
|  | ||||
|  | ||||
| @python_2_unicode_compatible | ||||
| class Topping(models.Model): | ||||
|     name = models.CharField(max_length=30) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.name | ||||
|  | ||||
|  | ||||
| @python_2_unicode_compatible | ||||
| class Pizza(models.Model): | ||||
|     name = models.CharField(max_length=100) | ||||
|     toppings = models.ManyToManyField(Topping) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.name | ||||
|  | ||||
|  | ||||
| @python_2_unicode_compatible | ||||
| class TaggedItem(models.Model): | ||||
|     tag = models.CharField(max_length=30) | ||||
|  | ||||
|     content_type = models.ForeignKey(ContentType, related_name='select_related_tagged_items') | ||||
|     object_id = models.PositiveIntegerField() | ||||
|     content_object = GenericForeignKey('content_type', 'object_id') | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.tag | ||||
|  | ||||
|  | ||||
| @python_2_unicode_compatible | ||||
| class Bookmark(models.Model): | ||||
|     url = models.URLField() | ||||
|     tags = GenericRelation(TaggedItem) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.url | ||||
|   | ||||
| @@ -1,8 +1,12 @@ | ||||
| from __future__ import unicode_literals | ||||
|  | ||||
| from django.test import TestCase | ||||
| from django.core.exceptions import FieldError | ||||
|  | ||||
| from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies | ||||
| from .models import ( | ||||
|     Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies, | ||||
|     Pizza, TaggedItem, Bookmark, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class SelectRelatedTests(TestCase): | ||||
| @@ -126,6 +130,12 @@ class SelectRelatedTests(TestCase): | ||||
|             orders = [o.genus.family.order.name for o in world] | ||||
|             self.assertEqual(orders, ['Agaricales']) | ||||
|  | ||||
|     def test_single_related_field(self): | ||||
|         with self.assertNumQueries(1): | ||||
|             species = Species.objects.select_related('genus__name') | ||||
|             names = [s.genus.name for s in species] | ||||
|             self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum']) | ||||
|  | ||||
|     def test_field_traversal(self): | ||||
|         with self.assertNumQueries(1): | ||||
|             s = (Species.objects.all() | ||||
| @@ -152,3 +162,47 @@ class SelectRelatedTests(TestCase): | ||||
|             obj = queryset[0] | ||||
|             self.assertEqual(obj.parent_1, parent_1) | ||||
|             self.assertEqual(obj.parent_2, parent_2) | ||||
|  | ||||
|  | ||||
| class SelectRelatedValidationTests(TestCase): | ||||
|     """ | ||||
|     select_related() should thrown an error on fields that do not exist and | ||||
|     non-relational fields. | ||||
|     """ | ||||
|     non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s" | ||||
|     invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s" | ||||
|  | ||||
|     def test_non_relational_field(self): | ||||
|         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')): | ||||
|             list(Species.objects.select_related('name__some_field')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')): | ||||
|             list(Species.objects.select_related('name')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', '(none)')): | ||||
|             list(Domain.objects.select_related('name')) | ||||
|  | ||||
|     def test_many_to_many_field(self): | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('toppings', '(none)')): | ||||
|             list(Pizza.objects.select_related('toppings')) | ||||
|  | ||||
|     def test_reverse_relational_field(self): | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('child_1', 'genus')): | ||||
|             list(Species.objects.select_related('child_1')) | ||||
|  | ||||
|     def test_invalid_field(self): | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', 'genus')): | ||||
|             list(Species.objects.select_related('invalid_field')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('related_invalid_field', 'family')): | ||||
|             list(Species.objects.select_related('genus__related_invalid_field')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', '(none)')): | ||||
|             list(Domain.objects.select_related('invalid_field')) | ||||
|  | ||||
|     def test_generic_relations(self): | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('tags', '')): | ||||
|             list(Bookmark.objects.select_related('tags')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('content_object', 'content_type')): | ||||
|             list(TaggedItem.objects.select_related('content_object')) | ||||
|   | ||||
| @@ -2,6 +2,7 @@ from __future__ import unicode_literals | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| from django.core.exceptions import FieldError | ||||
| from django.test import TestCase | ||||
|  | ||||
| from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, | ||||
| @@ -208,3 +209,21 @@ class ReverseSelectRelatedTestCase(TestCase): | ||||
|             self.assertEqual(p.child1.name1, 'n1') | ||||
|         with self.assertNumQueries(1): | ||||
|             self.assertEqual(p.child1.child4.name1, 'n1') | ||||
|  | ||||
|  | ||||
| class ReverseSelectRelatedValidationTests(TestCase): | ||||
|     """ | ||||
|     Rverse related fields should be listed in the validation message when an | ||||
|     invalid field is given in select_related(). | ||||
|     """ | ||||
|     non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s" | ||||
|     invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s" | ||||
|  | ||||
|     def test_reverse_related_validation(self): | ||||
|         fields = 'userprofile, userstat' | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)): | ||||
|             list(User.objects.select_related('foobar')) | ||||
|  | ||||
|         with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): | ||||
|             list(User.objects.select_related('username')) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user