From 35cecb1ebd0ccda0be7a518d1b7273333d26fbae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= <akaariai@gmail.com>
Date: Wed, 8 Jan 2014 19:35:47 +0200
Subject: [PATCH] Fixed #21748 -- join promotion for negated AND conditions

Made sure Django treats the case .filter(NOT (a AND b)) the same way as
.filter(NOT a OR NOT b) for join promotion.
---
 django/db/models/sql/query.py | 24 +++++++----
 tests/queries/tests.py        | 79 +++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 838c2914d9..79a4459e2f 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -483,7 +483,7 @@ class Query(object):
         # Base table must be present in the query - this is the same
         # table on both sides.
         self.get_initial_alias()
-        joinpromoter = JoinPromoter(connector, 2)
+        joinpromoter = JoinPromoter(connector, 2, False)
         joinpromoter.add_votes(
             j for j in self.alias_map if self.alias_map[j].join_type == self.INNER)
         rhs_votes = set()
@@ -1299,11 +1299,9 @@ class Query(object):
         connector = q_object.connector
         current_negated = current_negated ^ q_object.negated
         branch_negated = branch_negated or q_object.negated
-        # Note that if the connector happens to match what we have already in
-        # the tree, the add will be a no-op.
         target_clause = self.where_class(connector=connector,
                                          negated=q_object.negated)
-        joinpromoter = JoinPromoter(q_object.connector, len(q_object.children))
+        joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
         for child in q_object.children:
             if isinstance(child, Node):
                 child_clause, needed_inner = self._add_q(
@@ -2013,8 +2011,16 @@ class JoinPromoter(object):
     conditions.
     """
 
-    def __init__(self, connector, num_children):
+    def __init__(self, connector, num_children, negated):
         self.connector = connector
+        self.negated = negated
+        if self.negated:
+            if connector == AND:
+                self.effective_connector = OR
+            else:
+                self.effective_connector = AND
+        else:
+            self.effective_connector = self.connector
         self.num_children = num_children
         # Maps of table alias to how many times it is seen as required for
         # inner and/or outer joins.
@@ -2038,6 +2044,8 @@ class JoinPromoter(object):
         """
         to_promote = set()
         to_demote = set()
+        # The effective_connector is used so that NOT (a AND b) is treated
+        # like (NOT a OR NOT b), i.e. as an OR, for join promotion.
         for table, votes in self.inner_votes.items():
             # We must use outer joins in OR case when the join isn't contained
             # in all of the joins. Otherwise the INNER JOIN itself could remove
@@ -2049,7 +2057,7 @@ class JoinPromoter(object):
             # to rel_a would remove a valid match from the query. So, we need
             # to promote any existing INNER to LOUTER (it is possible this
             # promotion in turn will be demoted later on).
-            if self.connector == 'OR' and votes < self.num_children:
+            if self.effective_connector == 'OR' and votes < self.num_children:
                 to_promote.add(table)
             # If connector is AND and there is a filter that can match only
             # when there is a joinable row, then use INNER. For example, in
@@ -2061,8 +2069,8 @@ class JoinPromoter(object):
             #     (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell)
             # then if rel_a doesn't produce any rows, the whole condition
             # can't match. Hence we can safely use INNER join.
-            if self.connector == 'AND' or (self.connector == 'OR' and
-                                           votes == self.num_children):
+            if self.effective_connector == 'AND' or (
+                    self.effective_connector == 'OR' and votes == self.num_children):
                 to_demote.add(table)
             # Finally, what happens in cases where we have:
             #    (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0
diff --git a/tests/queries/tests.py b/tests/queries/tests.py
index 6bf90cdbf9..3f861db5d1 100644
--- a/tests/queries/tests.py
+++ b/tests/queries/tests.py
@@ -2799,6 +2799,85 @@ class NullJoinPromotionOrTest(TestCase):
         self.assertQuerysetEqual(
             qs.order_by('name'), [r2, r1], lambda x: x)
 
+    def test_ticket_21748(self):
+        i1 = Identifier.objects.create(name='i1')
+        i2 = Identifier.objects.create(name='i2')
+        i3 = Identifier.objects.create(name='i3')
+        Program.objects.create(identifier=i1)
+        Channel.objects.create(identifier=i1)
+        Program.objects.create(identifier=i2)
+        self.assertQuerysetEqual(
+            Identifier.objects.filter(program=None, channel=None),
+            [i3], lambda x: x)
+        self.assertQuerysetEqual(
+            Identifier.objects.exclude(program=None, channel=None).order_by('name'),
+            [i1, i2], lambda x: x)
+
+    def test_ticket_21748_double_negated_and(self):
+        i1 = Identifier.objects.create(name='i1')
+        i2 = Identifier.objects.create(name='i2')
+        Identifier.objects.create(name='i3')
+        p1 = Program.objects.create(identifier=i1)
+        c1 = Channel.objects.create(identifier=i1)
+        Program.objects.create(identifier=i2)
+        # Check that ~~Q() (or equivalently .exclude(~Q())) works like Q() for
+        # join promotion.
+        qs1_doubleneg = Identifier.objects.exclude(~Q(program__id=p1.id, channel__id=c1.id)).order_by('pk')
+        qs1_filter = Identifier.objects.filter(program__id=p1.id, channel__id=c1.id).order_by('pk')
+        self.assertQuerysetEqual(qs1_doubleneg, qs1_filter, lambda x: x)
+        self.assertEqual(str(qs1_filter.query).count('JOIN'),
+                         str(qs1_doubleneg.query).count('JOIN'))
+        self.assertEqual(2, str(qs1_doubleneg.query).count('INNER JOIN'))
+        self.assertEqual(str(qs1_filter.query).count('INNER JOIN'),
+                         str(qs1_doubleneg.query).count('INNER JOIN'))
+
+    def test_ticket_21748_double_negated_or(self):
+        i1 = Identifier.objects.create(name='i1')
+        i2 = Identifier.objects.create(name='i2')
+        Identifier.objects.create(name='i3')
+        p1 = Program.objects.create(identifier=i1)
+        c1 = Channel.objects.create(identifier=i1)
+        p2 = Program.objects.create(identifier=i2)
+        # Test OR + double negation. The expected result is that channel is
+        # LOUTER joined and program is INNER joined.
+        qs1_filter = Identifier.objects.filter(
+            Q(program__id=p2.id, channel__id=c1.id)
+            | Q(program__id=p1.id)
+        ).order_by('pk')
+        qs1_doubleneg = Identifier.objects.exclude(
+            ~Q(Q(program__id=p2.id, channel__id=c1.id)
+            | Q(program__id=p1.id))
+        ).order_by('pk')
+        self.assertQuerysetEqual(qs1_doubleneg, qs1_filter, lambda x: x)
+        self.assertEqual(str(qs1_filter.query).count('JOIN'),
+                         str(qs1_doubleneg.query).count('JOIN'))
+        self.assertEqual(1, str(qs1_doubleneg.query).count('INNER JOIN'))
+        self.assertEqual(str(qs1_filter.query).count('INNER JOIN'),
+                         str(qs1_doubleneg.query).count('INNER JOIN'))
+
+    def test_ticket_21748_complex_filter(self):
+        i1 = Identifier.objects.create(name='i1')
+        i2 = Identifier.objects.create(name='i2')
+        Identifier.objects.create(name='i3')
+        p1 = Program.objects.create(identifier=i1)
+        c1 = Channel.objects.create(identifier=i1)
+        p2 = Program.objects.create(identifier=i2)
+        # Finally, a more complex case written two ways: qs1 keeps the
+        # NOTs high up in the boolean tree, while qs2 is the equivalent
+        # query with each NOT pushed to the lowest level of the tree.
+        qs1 = Identifier.objects.filter(
+            ~Q(~Q(program__id=p2.id, channel__id=c1.id)
+            & Q(program__id=p1.id))).order_by('pk')
+        qs2 = Identifier.objects.filter(
+            Q(Q(program__id=p2.id, channel__id=c1.id)
+            | ~Q(program__id=p1.id))).order_by('pk')
+        self.assertQuerysetEqual(qs1, qs2, lambda x: x)
+        self.assertEqual(str(qs1.query).count('JOIN'),
+                         str(qs2.query).count('JOIN'))
+        self.assertEqual(0, str(qs1.query).count('INNER JOIN'))
+        self.assertEqual(str(qs1.query).count('INNER JOIN'),
+                         str(qs2.query).count('INNER JOIN'))
+
 
 class ReverseJoinTrimmingTest(TestCase):
     def test_reverse_trimming(self):