From ca74e563500e291480f1976b58fcd34aac768dca Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Mon, 10 Jul 2017 19:45:09 +0200 Subject: [PATCH] Fixed #28378 -- Fixed union() and difference() when combining with a queryset raising EmptyResultSet. Thanks Jon Dufresne for the report. Thanks Tim Graham and Simon Charette for the reviews. --- django/db/models/sql/compiler.py | 28 ++++++++++++++++++---------- docs/releases/1.11.4.txt | 3 +++ tests/queries/test_qs_combinators.py | 8 ++++++++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index c705d33af8..84e240d1f4 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -399,7 +399,18 @@ class SQLCompiler: raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.') if compiler.get_order_by(): raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.') - parts = (compiler.as_sql() for compiler in compilers) + parts = () + for compiler in compilers: + try: + parts += (compiler.as_sql(),) + except EmptyResultSet: + # Omit the empty queryset with UNION and with DIFFERENCE if the + # first queryset is nonempty. + if combinator == 'union' or (combinator == 'difference' and parts): + continue + raise + if not parts: + return [], [] combinator_sql = self.connection.ops.set_operators[combinator] if all and combinator == 'union': combinator_sql += ' ALL' @@ -422,16 +433,7 @@ class SQLCompiler: refcounts_before = self.query.alias_refcount.copy() try: extra_select, order_by, group_by = self.pre_sql_setup() - distinct_fields = self.get_distinct() - - # This must come after 'select', 'ordering', and 'distinct' -- see - # docstring of get_from_clause() for details. - from_, f_params = self.get_from_clause() - for_update_part = None - where, w_params = self.compile(self.where) if self.where is not None else ("", []) - having, h_params = self.compile(self.having) if self.having is not None else ("", []) - combinator = self.query.combinator features = self.connection.features if combinator: @@ -439,6 +441,12 @@ class SQLCompiler: raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) result, params = self.get_combinator_sql(combinator, self.query.combinator_all) else: + distinct_fields = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' + # (see docstring of get_from_clause() for details). + from_, f_params = self.get_from_clause() + where, w_params = self.compile(self.where) if self.where is not None else ("", []) + having, h_params = self.compile(self.having) if self.having is not None else ("", []) result = ['SELECT'] params = [] diff --git a/docs/releases/1.11.4.txt b/docs/releases/1.11.4.txt index 28f06deb06..2383ad6a17 100644 --- a/docs/releases/1.11.4.txt +++ b/docs/releases/1.11.4.txt @@ -12,3 +12,6 @@ Bugfixes * Fixed a regression in 1.11.3 on Python 2 where non-ASCII ``format`` values for date/time widgets results in an empty ``value`` in the widget's HTML (:ticket:`28355`). + +* Fixed ``QuerySet.union()`` and ``difference()`` when combining with + a queryset raising ``EmptyResultSet`` (:ticket:`28378`). diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py index efa3a2c987..e5bdedba45 100644 --- a/tests/queries/test_qs_combinators.py +++ b/tests/queries/test_qs_combinators.py @@ -58,18 +58,26 @@ class QuerySetSetOperationTests(TestCase): def test_difference_with_empty_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.none() + qs3 = Number.objects.filter(pk__in=[]) self.assertEqual(len(qs1.difference(qs2)), 10) + self.assertEqual(len(qs1.difference(qs3)), 10) self.assertEqual(len(qs2.difference(qs1)), 0) + self.assertEqual(len(qs3.difference(qs1)), 0) self.assertEqual(len(qs2.difference(qs2)), 0) + self.assertEqual(len(qs3.difference(qs3)), 0) def test_union_with_empty_qs(self): qs1 = Number.objects.all() qs2 = Number.objects.none() + qs3 = Number.objects.filter(pk__in=[]) self.assertEqual(len(qs1.union(qs2)), 10) self.assertEqual(len(qs2.union(qs1)), 10) + self.assertEqual(len(qs1.union(qs3)), 10) + self.assertEqual(len(qs3.union(qs1)), 10) self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10) self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20) self.assertEqual(len(qs2.union(qs2)), 0) + self.assertEqual(len(qs3.union(qs3)), 0) def test_limits(self): qs1 = Number.objects.all()