mirror of
				https://github.com/django/django.git
				synced 2025-10-26 07:06:08 +00:00 
			
		
		
		
	Refs #24267 -- Implemented lookups for related fields
Previously related fields didn't implement get_lookup, instead related fields were treated specially. This commit removed some of the special handling. In particular, related fields return Lookup instances now, too. Other notable changes in this commit is removal of support for annotations in names_to_path().
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							8654c6a732
						
					
				
				
					commit
					b68212f539
				
			| @@ -15,7 +15,10 @@ from django.db.models.fields import ( | ||||
|     BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField, | ||||
|     PositiveSmallIntegerField, | ||||
| ) | ||||
| from django.db.models.lookups import IsNull | ||||
| from django.db.models.fields.related_lookups import ( | ||||
|     RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn, | ||||
|     RelatedLessThan, RelatedLessThanOrEqual, | ||||
| ) | ||||
| from django.db.models.query import QuerySet | ||||
| from django.db.models.query_utils import PathInfo | ||||
| from django.utils import six | ||||
| @@ -1336,6 +1339,16 @@ class ForeignObjectRel(object): | ||||
|     def one_to_one(self): | ||||
|         return self.field.one_to_one | ||||
|  | ||||
|     def get_prep_lookup(self, lookup_name, value): | ||||
|         return self.field.get_prep_lookup(lookup_name, value) | ||||
|  | ||||
|     def get_internal_type(self): | ||||
|         return self.field.get_internal_type() | ||||
|  | ||||
|     @property | ||||
|     def db_type(self): | ||||
|         return self.field.db_type | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return '<%s: %s.%s>' % ( | ||||
|             type(self).__name__, | ||||
| @@ -1760,67 +1773,25 @@ class ForeignObject(RelatedField): | ||||
|         pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] | ||||
|         return pathinfos | ||||
|  | ||||
|     def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, | ||||
|                               raw_value): | ||||
|         from django.db.models.sql.where import SubqueryConstraint, AND, OR | ||||
|         root_constraint = constraint_class() | ||||
|         assert len(targets) == len(sources) | ||||
|         if len(lookups) > 1: | ||||
|             raise exceptions.FieldError( | ||||
|                 "Cannot resolve keyword %r into field. Choices are: %s" % ( | ||||
|                     lookups[0], | ||||
|                     ", ".join(f.name for f in self.model._meta.get_fields()), | ||||
|                 ) | ||||
|             ) | ||||
|         lookup_type = lookups[0] | ||||
|     def get_lookup(self, lookup_name): | ||||
|         if lookup_name == 'in': | ||||
|             return RelatedIn | ||||
|         elif lookup_name == 'exact': | ||||
|             return RelatedExact | ||||
|         elif lookup_name == 'gt': | ||||
|             return RelatedGreaterThan | ||||
|         elif lookup_name == 'gte': | ||||
|             return RelatedGreaterThanOrEqual | ||||
|         elif lookup_name == 'lt': | ||||
|             return RelatedLessThan | ||||
|         elif lookup_name == 'lte': | ||||
|             return RelatedLessThanOrEqual | ||||
|         elif lookup_name != 'isnull': | ||||
|             raise TypeError('Related Field got invalid lookup: %s' % lookup_name) | ||||
|         return super(ForeignObject, self).get_lookup(lookup_name) | ||||
|  | ||||
|         def get_normalized_value(value): | ||||
|             from django.db.models import Model | ||||
|             if isinstance(value, Model): | ||||
|                 value_list = [] | ||||
|                 for source in sources: | ||||
|                     # Account for one-to-one relations when sent a different model | ||||
|                     while not isinstance(value, source.model) and source.rel: | ||||
|                         source = source.rel.to._meta.get_field(source.rel.field_name) | ||||
|                     value_list.append(getattr(value, source.attname)) | ||||
|                 return tuple(value_list) | ||||
|             elif not isinstance(value, tuple): | ||||
|                 return (value,) | ||||
|             return value | ||||
|  | ||||
|         is_multicolumn = len(self.related_fields) > 1 | ||||
|         if (hasattr(raw_value, '_as_sql') or | ||||
|                 hasattr(raw_value, 'get_compiler')): | ||||
|             root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets], | ||||
|                                                    [source.name for source in sources], raw_value), | ||||
|                                 AND) | ||||
|         elif lookup_type == 'isnull': | ||||
|             root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND) | ||||
|         elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] | ||||
|                                          and not is_multicolumn)): | ||||
|             value = get_normalized_value(raw_value) | ||||
|             for target, source, val in zip(targets, sources, value): | ||||
|                 lookup_class = target.get_lookup(lookup_type) | ||||
|                 root_constraint.add( | ||||
|                     lookup_class(target.get_col(alias, source), val), AND) | ||||
|         elif lookup_type in ['range', 'in'] and not is_multicolumn: | ||||
|             values = [get_normalized_value(value) for value in raw_value] | ||||
|             value = [val[0] for val in values] | ||||
|             lookup_class = targets[0].get_lookup(lookup_type) | ||||
|             root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND) | ||||
|         elif lookup_type == 'in': | ||||
|             values = [get_normalized_value(value) for value in raw_value] | ||||
|             root_constraint.connector = OR | ||||
|             for value in values: | ||||
|                 value_constraint = constraint_class() | ||||
|                 for source, target, val in zip(sources, targets, value): | ||||
|                     lookup_class = target.get_lookup('exact') | ||||
|                     lookup = lookup_class(target.get_col(alias, source), val) | ||||
|                     value_constraint.add(lookup, AND) | ||||
|                 root_constraint.add(value_constraint, OR) | ||||
|         else: | ||||
|             raise TypeError('Related Field got invalid lookup: %s' % lookup_type) | ||||
|         return root_constraint | ||||
|     def get_transform(self, *args, **kwargs): | ||||
|         raise NotImplementedError('Relational fields do not support transforms.') | ||||
|  | ||||
|     @property | ||||
|     def attnames(self): | ||||
| @@ -2017,6 +1988,9 @@ class ForeignKey(ForeignObject): | ||||
|         else: | ||||
|             return self.related_field.get_db_prep_save(value, connection=connection) | ||||
|  | ||||
|     def get_db_prep_value(self, value, connection, prepared=False): | ||||
|         return self.related_field.get_db_prep_value(value, connection, prepared) | ||||
|  | ||||
|     def value_to_string(self, obj): | ||||
|         if not obj: | ||||
|             # In required many-to-one fields with only one available choice, | ||||
|   | ||||
							
								
								
									
										130
									
								
								django/db/models/fields/related_lookups.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								django/db/models/fields/related_lookups.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,130 @@ | ||||
| from django.db.models.lookups import ( | ||||
|     Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class MultiColSource(object): | ||||
|     contains_aggregate = False | ||||
|  | ||||
|     def __init__(self, alias, targets, sources, field): | ||||
|         self.targets, self.sources, self.field, self.alias = targets, sources, field, alias | ||||
|         self.output_field = self.field | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{}({}, {})".format( | ||||
|             self.__class__.__name__, self.alias, self.field) | ||||
|  | ||||
|     def relabeled_clone(self, relabels): | ||||
|         return self.__class__(relabels.get(self.alias, self.alias), | ||||
|                               self.targets, self.sources, self.field) | ||||
|  | ||||
|  | ||||
| def get_normalized_value(value, lhs): | ||||
|     from django.db.models import Model | ||||
|     if isinstance(value, Model): | ||||
|         value_list = [] | ||||
|         # Account for one-to-one relations when sent a different model | ||||
|         sources = lhs.output_field.get_path_info()[-1].target_fields | ||||
|         for source in sources: | ||||
|             while not isinstance(value, source.model) and source.rel: | ||||
|                 source = source.rel.to._meta.get_field(source.rel.field_name) | ||||
|             value_list.append(getattr(value, source.attname)) | ||||
|         return tuple(value_list) | ||||
|     if not isinstance(value, tuple): | ||||
|         return (value,) | ||||
|     return value | ||||
|  | ||||
|  | ||||
| class RelatedIn(In): | ||||
|     def get_prep_lookup(self): | ||||
|         if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): | ||||
|             # If we get here, we are dealing with single-column relations. | ||||
|             self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] | ||||
|             # We need to run the related field's get_prep_lookup(). Consider case | ||||
|             # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself | ||||
|             # doesn't have validation for non-integers, so we must run validation | ||||
|             # using the target field. | ||||
|             if hasattr(self.lhs.output_field, 'get_path_info'): | ||||
|                 # Run the target field's get_prep_lookup. We can safely assume there is | ||||
|                 # only one as we don't get to the direct value branch otherwise. | ||||
|                 self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup( | ||||
|                     self.lookup_name, self.rhs) | ||||
|         return super(RelatedIn, self).get_prep_lookup() | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         if isinstance(self.lhs, MultiColSource): | ||||
|             # For multicolumn lookups we need to build a multicolumn where clause. | ||||
|             # This clause is either a SubqueryConstraint (for values that need to be compiled to | ||||
|             # SQL) or a OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses. | ||||
|             from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR | ||||
|  | ||||
|             root_constraint = WhereNode(connector=OR) | ||||
|             if self.rhs_is_direct_value(): | ||||
|                 values = [get_normalized_value(value, self.lhs) for value in self.rhs] | ||||
|                 for value in values: | ||||
|                     value_constraint = WhereNode() | ||||
|                     for source, target, val in zip(self.lhs.sources, self.lhs.targets, value): | ||||
|                         lookup_class = target.get_lookup('exact') | ||||
|                         lookup = lookup_class(target.get_col(self.lhs.alias, source), val) | ||||
|                         value_constraint.add(lookup, AND) | ||||
|                     root_constraint.add(value_constraint, OR) | ||||
|             else: | ||||
|                 root_constraint.add( | ||||
|                     SubqueryConstraint( | ||||
|                         self.lhs.alias, [target.column for target in self.lhs.targets], | ||||
|                         [source.name for source in self.lhs.sources], self.rhs), | ||||
|                     AND) | ||||
|             return root_constraint.as_sql(compiler, connection) | ||||
|         else: | ||||
|             return super(RelatedIn, self).as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class RelatedLookupMixin(object): | ||||
|     def get_prep_lookup(self): | ||||
|         if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): | ||||
|             # If we get here, we are dealing with single-column relations. | ||||
|             self.rhs = get_normalized_value(self.rhs, self.lhs)[0] | ||||
|             # We need to run the related field's get_prep_lookup(). Consider case | ||||
|             # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself | ||||
|             # doesn't have validation for non-integers, so we must run validation | ||||
|             # using the target field. | ||||
|             if hasattr(self.lhs.output_field, 'get_path_info'): | ||||
|                 # Get the target field. We can safely assume there is only one | ||||
|                 # as we don't get to the direct value branch otherwise. | ||||
|                 self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup( | ||||
|                     self.lookup_name, self.rhs) | ||||
|  | ||||
|         return super(RelatedLookupMixin, self).get_prep_lookup() | ||||
|  | ||||
|     def as_sql(self, compiler, connection): | ||||
|         if isinstance(self.lhs, MultiColSource): | ||||
|             assert self.rhs_is_direct_value() | ||||
|             self.rhs = get_normalized_value(self.rhs, self.lhs) | ||||
|             from django.db.models.sql.where import WhereNode, AND | ||||
|             root_constraint = WhereNode() | ||||
|             for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs): | ||||
|                 lookup_class = target.get_lookup(self.lookup_name) | ||||
|                 root_constraint.add( | ||||
|                     lookup_class(target.get_col(self.lhs.alias, source), val), AND) | ||||
|             return root_constraint.as_sql(compiler, connection) | ||||
|         return super(RelatedLookupMixin, self).as_sql(compiler, connection) | ||||
|  | ||||
|  | ||||
| class RelatedExact(RelatedLookupMixin, Exact): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class RelatedLessThan(RelatedLookupMixin, LessThan): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class RelatedGreaterThan(RelatedLookupMixin, GreaterThan): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual): | ||||
|     pass | ||||
| @@ -250,7 +250,7 @@ deferred_class_factory.__safe_for_unpickling__ = True | ||||
|  | ||||
| def refs_aggregate(lookup_parts, aggregates): | ||||
|     """ | ||||
|     A little helper method to check if the lookup_parts contains references | ||||
|     A helper method to check if the lookup_parts contains references | ||||
|     to the given aggregates set. Because the LOOKUP_SEP is contained in the | ||||
|     default annotation names we must check each prefix of the lookup_parts | ||||
|     for a match. | ||||
| @@ -260,3 +260,17 @@ def refs_aggregate(lookup_parts, aggregates): | ||||
|         if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate: | ||||
|             return aggregates[level_n_lookup], lookup_parts[n:] | ||||
|     return False, () | ||||
|  | ||||
|  | ||||
| def refs_expression(lookup_parts, annotations): | ||||
|     """ | ||||
|     A helper method to check if the lookup_parts contains references | ||||
|     to the given annotations set. Because the LOOKUP_SEP is contained in the | ||||
|     default annotation names we must check each prefix of the lookup_parts | ||||
|     for a match. | ||||
|     """ | ||||
|     for n in range(len(lookup_parts) + 1): | ||||
|         level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) | ||||
|         if level_n_lookup in annotations and annotations[level_n_lookup]: | ||||
|             return annotations[level_n_lookup], lookup_parts[n:] | ||||
|     return False, () | ||||
|   | ||||
| @@ -17,7 +17,8 @@ from django.db import DEFAULT_DB_ALIAS, connections | ||||
| from django.db.models.aggregates import Count | ||||
| from django.db.models.constants import LOOKUP_SEP | ||||
| from django.db.models.expressions import Col, Ref | ||||
| from django.db.models.query_utils import Q, PathInfo, refs_aggregate | ||||
| from django.db.models.fields.related_lookups import MultiColSource | ||||
| from django.db.models.query_utils import Q, PathInfo, refs_expression | ||||
| from django.db.models.sql.constants import ( | ||||
|     INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE, | ||||
| ) | ||||
| @@ -1006,7 +1007,7 @@ class Query(object): | ||||
|         """ | ||||
|         lookup_splitted = lookup.split(LOOKUP_SEP) | ||||
|         if self._annotations: | ||||
|             aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations) | ||||
|             aggregate, aggregate_lookups = refs_expression(lookup_splitted, self.annotations) | ||||
|             if aggregate: | ||||
|                 return aggregate_lookups, (), aggregate | ||||
|         _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) | ||||
| @@ -1157,24 +1158,26 @@ class Query(object): | ||||
|         if can_reuse is not None: | ||||
|             can_reuse.update(join_list) | ||||
|         used_joins = set(used_joins).union(set(join_list)) | ||||
|  | ||||
|         # Process the join list to see if we can remove any non-needed joins from | ||||
|         # the far end (fewer tables in a query is better). | ||||
|         targets, alias, join_list = self.trim_joins(sources, join_list, path) | ||||
|  | ||||
|         if hasattr(field, 'get_lookup_constraint'): | ||||
|             # For now foreign keys get special treatment. This should be | ||||
|             # refactored when composite fields lands. | ||||
|             condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, | ||||
|                                                     lookups, value) | ||||
|             lookup_type = lookups[-1] | ||||
|         else: | ||||
|             assert(len(targets) == 1) | ||||
|             if hasattr(targets[0], 'as_sql'): | ||||
|                 # handle Expressions as annotations | ||||
|                 col = targets[0] | ||||
|         if field.is_relation: | ||||
|             # No support for transforms for relational fields | ||||
|             assert len(lookups) == 1 | ||||
|             lookup_class = field.get_lookup(lookups[0]) | ||||
|             # Undo the changes done in setup_joins() if hasattr(final_field, 'field') branch | ||||
|             # This hack is needed as long as the field.rel isn't like a real field. | ||||
|             if field.get_path_info()[-1].target_fields != sources: | ||||
|                 target_field = field.rel | ||||
|             else: | ||||
|                 col = targets[0].get_col(alias, field) | ||||
|                 target_field = field | ||||
|             if len(targets) == 1: | ||||
|                 lhs = targets[0].get_col(alias, target_field) | ||||
|             else: | ||||
|                 lhs = MultiColSource(alias, targets, sources, target_field) | ||||
|             condition = lookup_class(lhs, value) | ||||
|             lookup_type = lookup_class.lookup_name | ||||
|         else: | ||||
|             col = targets[0].get_col(alias, field) | ||||
|             condition = self.build_lookup(lookups, col, value) | ||||
|             lookup_type = condition.lookup_name | ||||
|  | ||||
| @@ -1284,14 +1287,6 @@ class Query(object): | ||||
|                     ) | ||||
|                 model = field.model._meta.concrete_model | ||||
|             except FieldDoesNotExist: | ||||
|                 # is it an annotation? | ||||
|                 if self._annotations and name in self._annotations: | ||||
|                     field, model = self._annotations[name], None | ||||
|                     if not field.contains_aggregate: | ||||
|                         # Local non-relational field. | ||||
|                         final_field = field | ||||
|                         targets = (field,) | ||||
|                         break | ||||
|                 # We didn't find the current field, so move position back | ||||
|                 # one step. | ||||
|                 pos -= 1 | ||||
| @@ -1985,7 +1980,7 @@ def is_reverse_o2o(field): | ||||
|     A little helper to check if the given field is reverse-o2o. The field is | ||||
|     expected to be some sort of relation field or related object. | ||||
|     """ | ||||
|     return not hasattr(field, 'rel') and field.field.unique | ||||
|     return field.is_relation and field.one_to_one and not field.concrete | ||||
|  | ||||
|  | ||||
| class JoinPromoter(object): | ||||
|   | ||||
| @@ -144,22 +144,26 @@ class GenericRelationTests(TestCase): | ||||
|         tag.save() | ||||
|  | ||||
|     def test_ticket_20378(self): | ||||
|         # Create a couple of extra HasLinkThing so that the autopk value | ||||
|         # isn't the same for Link and HasLinkThing. | ||||
|         hs1 = HasLinkThing.objects.create() | ||||
|         hs2 = HasLinkThing.objects.create() | ||||
|         l1 = Link.objects.create(content_object=hs1) | ||||
|         l2 = Link.objects.create(content_object=hs2) | ||||
|         hs3 = HasLinkThing.objects.create() | ||||
|         hs4 = HasLinkThing.objects.create() | ||||
|         l1 = Link.objects.create(content_object=hs3) | ||||
|         l2 = Link.objects.create(content_object=hs4) | ||||
|         self.assertQuerysetEqual( | ||||
|             HasLinkThing.objects.filter(links=l1), | ||||
|             [hs1], lambda x: x) | ||||
|             [hs3], lambda x: x) | ||||
|         self.assertQuerysetEqual( | ||||
|             HasLinkThing.objects.filter(links=l2), | ||||
|             [hs2], lambda x: x) | ||||
|             [hs4], lambda x: x) | ||||
|         self.assertQuerysetEqual( | ||||
|             HasLinkThing.objects.exclude(links=l2), | ||||
|             [hs1], lambda x: x) | ||||
|             [hs1, hs2, hs3], lambda x: x, ordered=False) | ||||
|         self.assertQuerysetEqual( | ||||
|             HasLinkThing.objects.exclude(links=l1), | ||||
|             [hs2], lambda x: x) | ||||
|             [hs1, hs2, hs4], lambda x: x, ordered=False) | ||||
|  | ||||
|     def test_ticket_20564(self): | ||||
|         b1 = B.objects.create() | ||||
|   | ||||
| @@ -3678,3 +3678,11 @@ class TestTicket24279(TestCase): | ||||
|         School.objects.create() | ||||
|         qs = School.objects.filter(Q(pk__in=()) | Q()) | ||||
|         self.assertQuerysetEqual(qs, []) | ||||
|  | ||||
|  | ||||
| class TestInvalidValuesRelation(TestCase): | ||||
|     def test_invalid_values(self): | ||||
|         with self.assertRaises(ValueError): | ||||
|             Annotation.objects.filter(tag='abc') | ||||
|         with self.assertRaises(ValueError): | ||||
|             Annotation.objects.filter(tag__in=[123, 'abc']) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user