From c58a8acd413ccc992dd30afd98ed900897e1f719 Mon Sep 17 00:00:00 2001
From: Simon Charette <charette.s@gmail.com>
Date: Mon, 4 Jul 2022 21:51:07 +0100
Subject: [PATCH] Fixed #33768 -- Fixed ordering compound queries by
 nulls_first/nulls_last on MySQL.

Columns of the left outer most select statement in a combined query
can be referenced by alias just like by index.

This removes combined query ordering by column index and avoids an
unnecessary usage of RawSQL which causes issues for backends that
specialize the treatment of null ordering.
---
 django/db/backends/base/features.py   |  1 +
 django/db/backends/oracle/features.py |  1 +
 django/db/models/sql/compiler.py      | 21 +++++++++++----------
 tests/queries/test_qs_combinators.py  | 18 ++++++++++++++++++
 4 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 4fd21beee3..5c99736f22 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -235,6 +235,7 @@ class BaseDatabaseFeatures:
     supports_select_difference = True
     supports_slicing_ordering_in_compound = False
     supports_parentheses_in_compound = True
+    requires_compound_order_by_subquery = False
 
     # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
     # expressions?
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
index 49e58ff59d..9a98616dc2 100644
--- a/django/db/backends/oracle/features.py
+++ b/django/db/backends/oracle/features.py
@@ -69,6 +69,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_partial_indexes = False
     can_rename_index = True
     supports_slicing_ordering_in_compound = True
+    requires_compound_order_by_subquery = True
     allows_multiple_constraints_on_same_fields = False
     supports_boolean_expr_in_select_clause = False
     supports_comparing_boolean_expr = False
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 4668a820fb..f566546307 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -435,21 +435,18 @@ class SQLCompiler:
 
         for expr, is_ref in self._order_by_pairs():
             resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
-            if self.query.combinator and self.select:
+            if not is_ref and self.query.combinator and self.select:
                 src = resolved.expression
                 expr_src = expr.expression
-                # Relabel order by columns to raw numbers if this is a combined
-                # query; necessary since the columns can't be referenced by the
-                # fully qualified name and the simple column names may collide.
-                for idx, (sel_expr, _, col_alias) in enumerate(self.select):
-                    if is_ref and col_alias == src.refs:
-                        src = src.source
-                    elif col_alias and not (
+                for sel_expr, _, col_alias in self.select:
+                    if col_alias and not (
                         isinstance(expr_src, F) and col_alias == expr_src.name
                     ):
                         continue
                     if src == sel_expr:
-                        resolved.set_source_expressions([RawSQL("%d" % (idx + 1), ())])
+                        resolved.set_source_expressions(
+                            [Ref(col_alias if col_alias else src.target.column, src)]
+                        )
                         break
                 else:
                     if col_alias:
@@ -853,7 +850,11 @@ class SQLCompiler:
                 for _, (o_sql, o_params, _) in order_by:
                     ordering.append(o_sql)
                     params.extend(o_params)
-                result.append("ORDER BY %s" % ", ".join(ordering))
+                order_by_sql = "ORDER BY %s" % ", ".join(ordering)
+                if combinator and features.requires_compound_order_by_subquery:
+                    result = ["SELECT * FROM (", *result, ")", order_by_sql]
+                else:
+                    result.append(order_by_sql)
 
             if with_limit_offset:
                 result.append(
diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py
index 3cd19d5f31..5fc09ca922 100644
--- a/tests/queries/test_qs_combinators.py
+++ b/tests/queries/test_qs_combinators.py
@@ -61,6 +61,24 @@ class QuerySetSetOperationTests(TestCase):
         self.assertSequenceEqual(qs3.none(), [])
         self.assertNumbersEqual(qs3, [0, 1, 8, 9], ordered=False)
 
+    def test_union_order_with_null_first_last(self):
+        Number.objects.filter(other_num=5).update(other_num=None)
+        qs1 = Number.objects.filter(num__lte=1)
+        qs2 = Number.objects.filter(num__gte=2)
+        qs3 = qs1.union(qs2)
+        self.assertSequenceEqual(
+            qs3.order_by(
+                F("other_num").asc(nulls_first=True),
+            ).values_list("other_num", flat=True),
+            [None, 1, 2, 3, 4, 6, 7, 8, 9, 10],
+        )
+        self.assertSequenceEqual(
+            qs3.order_by(
+                F("other_num").asc(nulls_last=True),
+            ).values_list("other_num", flat=True),
+            [1, 2, 3, 4, 6, 7, 8, 9, 10, None],
+        )
+
     @skipUnlessDBFeature("supports_select_intersection")
     def test_intersection_with_empty_qs(self):
         qs1 = Number.objects.all()