From 1efea11808f7b8a3b31445e0c1c7d270f832d965 Mon Sep 17 00:00:00 2001
From: Luke Plant <L.Plant.98@cantab.net>
Date: Thu, 31 Mar 2022 08:10:22 +0200
Subject: [PATCH] Refs #33397 -- Added register_combinable_fields().

---
 django/db/models/expressions.py | 65 ++++++++++++++++++++++++++++-----
 1 file changed, 55 insertions(+), 10 deletions(-)

diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py
index edd644da54..32982777ef 100644
--- a/django/db/models/expressions.py
+++ b/django/db/models/expressions.py
@@ -2,6 +2,7 @@ import copy
 import datetime
 import functools
 import inspect
+from collections import defaultdict
 from decimal import Decimal
 from uuid import UUID
 
@@ -465,16 +466,60 @@ class Expression(BaseExpression, Combinable):
         return hash(self.identity)
 
 
-_connector_combinators = {
-    connector: [
-        (fields.IntegerField, fields.IntegerField, fields.IntegerField),
-        (fields.IntegerField, fields.DecimalField, fields.DecimalField),
-        (fields.DecimalField, fields.IntegerField, fields.DecimalField),
-        (fields.IntegerField, fields.FloatField, fields.FloatField),
-        (fields.FloatField, fields.IntegerField, fields.FloatField),
-    ]
-    for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
-}
+# Type inference for CombinedExpression.output_field.
+_connector_combinations = [
+    # Numeric operations - operands of same type.
+    {
+        connector: [
+            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
+            (fields.FloatField, fields.FloatField, fields.FloatField),
+            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
+        ]
+        for connector in (
+            Combinable.ADD,
+            Combinable.SUB,
+            Combinable.MUL,
+            # Behavior for DIV with integer arguments follows Postgres/SQLite,
+            # not MySQL/Oracle.
+            Combinable.DIV,
+        )
+    },
+    # Numeric operations - operands of different type.
+    {
+        connector: [
+            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
+            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
+            (fields.IntegerField, fields.FloatField, fields.FloatField),
+            (fields.FloatField, fields.IntegerField, fields.FloatField),
+        ]
+        for connector in (
+            Combinable.ADD,
+            Combinable.SUB,
+            Combinable.MUL,
+            Combinable.DIV,
+        )
+    },
+]
+
+_connector_combinators = defaultdict(list)
+
+
+def register_combinable_fields(lhs, connector, rhs, result):
+    """
+    Register combinable types:
+        lhs <connector> rhs -> result
+    e.g.
+        register_combinable_fields(
+            IntegerField, Combinable.ADD, FloatField, FloatField
+        )
+    """
+    _connector_combinators[connector].append((lhs, rhs, result))
+
+
+for d in _connector_combinations:
+    for connector, field_types in d.items():
+        for lhs, rhs, result in field_types:
+            register_combinable_fields(lhs, connector, rhs, result)
 
 
 @functools.lru_cache(maxsize=128)