diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 612eb8f2d9..4bc9e6ed34 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -12,8 +12,10 @@ class MultiJoin(Exception): multi-valued join was attempted (if the caller wants to treat that exceptionally). """ - def __init__(self, level): - self.level = level + def __init__(self, names_pos, path_with_names): + self.level = names_pos + # The path travelled, this includes the path to the multijoin. + self.names_with_path = path_with_names class Empty(object): pass diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fb0f09efde..422029c5e0 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1200,7 +1200,7 @@ class Query(object): can_reuse.update(join_list) except MultiJoin as e: self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), - can_reuse) + can_reuse, e.names_with_path) return if (lookup_type == 'isnull' and value is True and not negate and @@ -1324,7 +1324,7 @@ class Query(object): (the last used join field), and target (which is a field guaranteed to contain the same value as the final field). """ - path = [] + path, names_with_path = [], [] for pos, name in enumerate(names): if name == 'pk': name = opts.pk.name @@ -1361,16 +1361,17 @@ class Query(object): opts, final_field, False, True)) if hasattr(field, 'get_path_info'): pathinfos, opts, target, final_field = field.get_path_info() + if not allow_many: + for inner_pos, p in enumerate(pathinfos): + if p.m2m: + names_with_path.append((name, pathinfos[0:inner_pos + 1])) + raise MultiJoin(pos + 1, names_with_path) path.extend(pathinfos) + names_with_path.append((name, pathinfos)) else: # Local non-relational field. final_field = target = field break - multijoin_pos = None - for m2mpos, pathinfo in enumerate(path): - if pathinfo.m2m: - multijoin_pos = m2mpos - break if pos != len(names) - 1: if pos == len(names) - 2: @@ -1379,8 +1380,6 @@ class Query(object): "the lookup type?" % (name, names[pos + 1])) else: raise FieldError("Join on field %r not permitted." % name) - if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many: - raise MultiJoin(multijoin_pos + 1) return path, final_field, target def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, @@ -1454,7 +1453,7 @@ class Query(object): break return target.column, joins[-1], joins - def split_exclude(self, filter_expr, prefix, can_reuse): + def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): """ When doing an exclude against any kind of N-to-many relation, we need to use a subquery. This method constructs the nested query, given the @@ -1462,11 +1461,10 @@ class Query(object): N-to-many relation field. As an example we could have original filter ~Q(child__name='foo'). - We would get here with filter_expr = child_name, prefix = child and - can_reuse is a set of joins we can reuse for filtering in the original - query. + We would get here with filter_expr = child__name, prefix = child and + can_reuse is a set of joins usable for filters in the original query. - We will turn this into + We will turn this into equivalent of: WHERE pk NOT IN (SELECT parent_id FROM thetable WHERE name = 'foo' AND parent_id IS NOT NULL) @@ -1474,44 +1472,48 @@ class Query(object): saner null handling, and is easier for the backend's optimizer to handle. """ + # Generate the inner query. query = Query(self.model) query.add_filter(filter_expr) query.bump_prefix() query.clear_ordering(True) - query.set_start(prefix) - # Adding extra check to make sure the selected field will not be null + # Try to have as simple as possible subquery -> trim leading joins from + # the subquery. + trimmed_joins = query.trim_start(names_with_path) + # Add extra check to make sure the selected field will not be null # since we are adding a IN clause. This prevents the # database from tripping over IN (...,NULL,...) selects and returning # nothing alias, col = query.select[0].col query.where.add((Constraint(alias, col, None), 'isnull', False), AND) - # We need to trim the last part from the prefix. - trimmed_prefix = LOOKUP_SEP.join(prefix.split(LOOKUP_SEP)[0:-1]) - if not trimmed_prefix: - rel, _, direct, m2m = self.model._meta.get_field_by_name(prefix) - if not m2m: - trimmed_prefix = rel.field.rel.field_name + + # Still make sure that the trimmed parts in the inner query and + # trimmed prefix are in sync. So, use the trimmed_joins to make sure + # as many path elements are in the prefix as there were trimmed joins. + # In addition, convert the path elements back to names so that + # add_filter() can handle them. + trimmed_prefix = [] + paths_in_prefix = trimmed_joins + for name, path in names_with_path: + if paths_in_prefix - len(path) > 0: + trimmed_prefix.append(name) + paths_in_prefix -= len(path) else: - if direct: - trimmed_prefix = rel.m2m_target_field_name() - else: - trimmed_prefix = rel.field.m2m_reverse_target_field_name() - + trimmed_prefix.append( + path[paths_in_prefix - len(path)].from_field.name) + break + trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix) self.add_filter(('%s__in' % trimmed_prefix, query), negate=True, - can_reuse=can_reuse) + can_reuse=can_reuse) - # If there's more than one join in the inner query (before any initial - # bits were trimmed -- which means the last active table is more than - # two places into the alias list), we need to also handle the - # possibility that the earlier joins don't match anything by adding a - # comparison to NULL (e.g. in - # Tag.objects.exclude(parent__parent__name='t1'), a tag with no parent - # would otherwise be overlooked). - active_positions = len([count for count - in query.alias_refcount.items() if count]) - if active_positions > 1: + # If there's more than one join in the inner query, we need to also + # handle the possibility that the earlier joins don't match anything + # by adding a comparison to NULL (e.g. in + # Tag.objects.exclude(parent__parent__name='t1') + # a tag with no parent would otherwise be overlooked). + if trimmed_joins > 1: self.add_filter(('%s__isnull' % trimmed_prefix, False), negate=True, - can_reuse=can_reuse) + can_reuse=can_reuse) def set_empty(self): self.where = EmptyWhere() @@ -1869,42 +1871,33 @@ class Query(object): return self.extra extra_select = property(_extra_select) - def set_start(self, start): + def trim_start(self, names_with_path): """ - Sets the table from which to start joining. The start position is - specified by the related attribute from the base model. This will - automatically set to the select column to be the column linked from the - previous table. + Trims joins from the start of the join path. The candidates for trim + are the PathInfos in names_with_path structure. Outer joins are not + eligible for removal. Also sets the select column so the start + matches the join. - This method is primarily for internal use and the error checking isn't - as friendly as add_filter(). Mostly useful for querying directly - against the join table of many-to-many relation in a subquery. - """ - opts = self.model._meta - alias = self.get_initial_alias() - field, col, opts, joins, extra = self.setup_joins( - start.split(LOOKUP_SEP), opts, alias) - select_col = self.alias_map[joins[1]].lhs_join_col - select_alias = alias - - # The call to setup_joins added an extra reference to everything in - # joins. Reverse that. - for alias in joins: - self.unref_alias(alias) - - # We might be able to trim some joins from the front of this query, - # providing that we only traverse "always equal" connections (i.e. rhs - # is *always* the same value as lhs). - for alias in joins[1:]: - join_info = self.alias_map[alias] - if (join_info.lhs_join_col != select_col - or join_info.join_type != self.INNER): - break - self.unref_alias(select_alias) - select_alias = join_info.rhs_alias - select_col = join_info.rhs_join_col + This method is mostly useful for generating the subquery joins & col + in "WHERE somecol IN (subquery)". This construct is needed by + split_exclude(). + _""" + join_pos = 0 + for _, paths in names_with_path: + for path in paths: + peek = self.tables[join_pos + 1] + if self.alias_map[peek].join_type == self.LOUTER: + # Back up one level and break + select_alias = self.tables[join_pos] + select_col = path.from_field.column + break + select_alias = self.tables[join_pos + 1] + select_col = path.to_field.column + self.unref_alias(self.tables[join_pos]) + join_pos += 1 self.select = [SelectInfo((select_alias, select_col), None)] self.remove_inherited_models() + return join_pos def is_nullable(self, field): """ diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 16583e891c..91edf71aeb 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -439,3 +439,17 @@ class BaseA(models.Model): a = models.ForeignKey(FK1, null=True) b = models.ForeignKey(FK2, null=True) c = models.ForeignKey(FK3, null=True) + +@python_2_unicode_compatible +class Identifier(models.Model): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + +class Program(models.Model): + identifier = models.OneToOneField(Identifier) + +class Channel(models.Model): + programs = models.ManyToManyField(Program) + identifier = models.OneToOneField(Identifier) diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index ea54d18451..34bfea0b94 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -24,7 +24,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, - JobResponsibilities, BaseA) + JobResponsibilities, BaseA, Identifier, Program, Channel) class BaseQuerysetTest(TestCase): @@ -2612,3 +2612,22 @@ class DisjunctionPromotionTests(TestCase): qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2))) self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) self.assertEqual(str(qs.query).count('INNER JOIN'), 0) + + +class ManyToManyExcludeTest(TestCase): + def test_exclude_many_to_many(self): + Identifier.objects.create(name='extra') + program = Program.objects.create(identifier=Identifier.objects.create(name='program')) + channel = Channel.objects.create(identifier=Identifier.objects.create(name='channel')) + channel.programs.add(program) + + # channel contains 'program1', so all Identifiers except that one + # should be returned + self.assertQuerysetEqual( + Identifier.objects.exclude(program__channel=channel).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Identifier.objects.exclude(program__channel=None).order_by('name'), + [''] + ) diff --git a/tests/tmp.txt b/tests/tmp.txt new file mode 100644 index 0000000000..4e812b2c23 --- /dev/null +++ b/tests/tmp.txt @@ -0,0 +1 @@ +SELECT "queries_tag"."id", "queries_tag"."name", "queries_tag"."parent_id", "queries_tag"."category_id" FROM "queries_tag" WHERE NOT (("queries_tag"."id" IN (SELECT U0."id" FROM "queries_tag" U0 LEFT OUTER JOIN "queries_tag" U1 ON (U0."id" = U1."parent_id") WHERE (U1."id" IS NULL AND U0."id" IS NOT NULL)) AND "queries_tag"."id" IS NOT NULL)) ORDER BY "queries_tag"."name" ASC