diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 50f8b44158..fd97757b14 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -48,18 +48,37 @@ def get_normalized_value(value, lhs): class RelatedIn(In): def get_prep_lookup(self): - if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): - # If we get here, we are dealing with single-column relations. - self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] - # We need to run the related field's get_prep_value(). Consider case - # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself - # doesn't have validation for non-integers, so we must run validation - # using the target field. - if hasattr(self.lhs.output_field, 'path_infos'): - # Run the target field's get_prep_value. We can safely assume there is - # only one as we don't get to the direct value branch otherwise. - target_field = self.lhs.output_field.path_infos[-1].target_fields[-1] - self.rhs = [target_field.get_prep_value(v) for v in self.rhs] + if not isinstance(self.lhs, MultiColSource): + if self.rhs_is_direct_value(): + # If we get here, we are dealing with single-column relations. + self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs] + # We need to run the related field's get_prep_value(). Consider + # case ForeignKey to IntegerField given value 'abc'. The + # ForeignKey itself doesn't have validation for non-integers, + # so we must run validation using the target field. + if hasattr(self.lhs.output_field, 'path_infos'): + # Run the target field's get_prep_value. We can safely + # assume there is only one as we don't get to the direct + # value branch otherwise. + target_field = self.lhs.output_field.path_infos[-1].target_fields[-1] + self.rhs = [target_field.get_prep_value(v) for v in self.rhs] + elif ( + not getattr(self.rhs, 'has_select_fields', True) and + not getattr(self.lhs.field.target_field, 'primary_key', False) + ): + self.rhs.clear_select_clause() + if ( + getattr(self.lhs.output_field, 'primary_key', False) and + self.lhs.output_field.model == self.rhs.model + ): + # A case like + # Restaurant.objects.filter(place__in=restaurant_qs), where + # place is a OneToOneField and the primary key of + # Restaurant. + target_field = self.lhs.field.name + else: + target_field = self.lhs.field.target_field.name + self.rhs.add_fields([target_field], True) return super().get_prep_lookup() def as_sql(self, compiler, connection): @@ -88,20 +107,7 @@ class RelatedIn(In): [source.name for source in self.lhs.sources], self.rhs), AND) return root_constraint.as_sql(compiler, connection) - else: - if (not getattr(self.rhs, 'has_select_fields', True) and - not getattr(self.lhs.field.target_field, 'primary_key', False)): - self.rhs.clear_select_clause() - if (getattr(self.lhs.output_field, 'primary_key', False) and - self.lhs.output_field.model == self.rhs.model): - # A case like Restaurant.objects.filter(place__in=restaurant_qs), - # where place is a OneToOneField and the primary key of - # Restaurant. - target_field = self.lhs.field.name - else: - target_field = self.lhs.field.target_field.name - self.rhs.add_fields([target_field], True) - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection) class RelatedLookupMixin: diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 0f104416de..9315ae8039 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -302,8 +302,8 @@ class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup): class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): lookup_name = 'exact' - def process_rhs(self, compiler, connection): - from django.db.models.sql.query import Query + def get_prep_lookup(self): + from django.db.models.sql.query import Query # avoid circular import if isinstance(self.rhs, Query): if self.rhs.has_limit_one(): if not self.rhs.has_select_fields: @@ -314,7 +314,7 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): 'The QuerySet value for an exact lookup must be limited to ' 'one result using slicing.' ) - return super().process_rhs(compiler, connection) + return super().get_prep_lookup() def as_sql(self, compiler, connection): # Avoid comparison against direct rhs if lhs is a boolean value. That @@ -388,6 +388,15 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan): class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): lookup_name = 'in' + def get_prep_lookup(self): + from django.db.models.sql.query import Query # avoid circular import + if isinstance(self.rhs, Query): + self.rhs.clear_ordering(clear_default=True) + if not self.rhs.has_select_fields: + self.rhs.clear_select_clause() + self.rhs.add_fields(['pk']) + return super().get_prep_lookup() + def process_rhs(self, compiler, connection): db_rhs = getattr(self.rhs, '_db', None) if db_rhs is not None and db_rhs != connection.alias: @@ -412,27 +421,7 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs) placeholder = '(' + ', '.join(sqls) + ')' return (placeholder, sqls_params) - else: - from django.db.models.sql.query import ( # avoid circular import - Query, - ) - if isinstance(self.rhs, Query): - query = self.rhs - query.clear_ordering(clear_default=True) - if not query.has_select_fields: - query.clear_select_clause() - query.add_fields(['pk']) - - return super().process_rhs(compiler, connection) - - def get_group_by_cols(self, alias=None): - cols = self.lhs.get_group_by_cols() - if hasattr(self.rhs, 'get_group_by_cols'): - if not getattr(self.rhs, 'has_select_fields', True): - self.rhs.clear_select_clause() - self.rhs.add_fields(['pk']) - cols.extend(self.rhs.get_group_by_cols()) - return cols + return super().process_rhs(compiler, connection) def get_rhs_op(self, connection, rhs): return 'IN %s' % rhs