From e7afef13f594eb667f2709c0ef7bca98452ab32b Mon Sep 17 00:00:00 2001
From: Sergey Fedoseev <fedoseev.sergey@gmail.com>
Date: Mon, 10 Apr 2017 22:26:26 +0500
Subject: [PATCH] Fixed #26788 -- Fixed QuerySet.update() crash when updating a
 geometry to another one.

---
 .../gis/db/backends/base/operations.py        | 24 ++++++++++++++-
 .../gis/db/backends/mysql/operations.py       | 12 --------
 .../gis/db/backends/oracle/operations.py      | 26 +---------------
 .../gis/db/backends/postgis/operations.py     | 13 ++++----
 .../gis/db/backends/spatialite/operations.py  | 26 ----------------
 django/contrib/gis/db/models/lookups.py       |  2 ++
 django/db/backends/mysql/operations.py        |  2 +-
 django/db/models/sql/compiler.py              |  2 +-
 tests/gis_tests/geoapp/models.py              |  6 ++++
 tests/gis_tests/geoapp/test_expressions.py    | 30 +++++++++++++++++--
 10 files changed, 68 insertions(+), 75 deletions(-)

diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py
index 0732d517c6..1b314a9d2b 100644
--- a/django/contrib/gis/db/backends/base/operations.py
+++ b/django/contrib/gis/db/backends/base/operations.py
@@ -71,7 +71,29 @@ class BaseSpatialOperations:
         stored procedure call to the transformation function of the spatial
         backend.
         """
-        raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
+        def transform_value(value, field):
+            return (
+                not (value is None or value.srid == field.srid) and
+                self.connection.features.supports_transform
+            )
+
+        if hasattr(value, 'as_sql'):
+            return (
+                '%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid)
+                if transform_value(value.output_field, f)
+                else '%s'
+            )
+        if transform_value(value, f):
+            # Add Transform() to the SQL placeholder.
+            return '%s(%s(%%s,%s), %s)' % (
+                self.spatial_function_name('Transform'),
+                self.from_text, value.srid, f.srid,
+            )
+        elif self.connection.features.has_spatialrefsys_table:
+            return '%s(%%s,%s)' % (self.from_text, f.srid)
+        else:
+            # For backwards compatibility on MySQL (#27464).
+            return '%s(%%s)' % self.from_text
 
     def check_expression_support(self, expression):
         if isinstance(expression, self.disallowed_aggregates):
diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py
index 0ea9f8f274..1c850b39d1 100644
--- a/django/contrib/gis/db/backends/mysql/operations.py
+++ b/django/contrib/gis/db/backends/mysql/operations.py
@@ -86,18 +86,6 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
     def geo_db_type(self, f):
         return f.geom_type
 
-    def get_geom_placeholder(self, f, value, compiler):
-        """
-        The placeholder here has to include MySQL's WKT constructor.  Because
-        MySQL does not support spatial transformations, there is no need to
-        modify the placeholder based on the contents of the given value.
-        """
-        if hasattr(value, 'as_sql'):
-            placeholder, _ = compiler.compile(value)
-        else:
-            placeholder = '%s(%%s)' % self.from_text
-        return placeholder
-
     def get_db_converters(self, expression):
         converters = super().get_db_converters(expression)
         if isinstance(expression.output_field, GeometryField) and self.uses_invalid_empty_geometry_collection:
diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py
index 926209ac40..50f5bed8b9 100644
--- a/django/contrib/gis/db/backends/oracle/operations.py
+++ b/django/contrib/gis/db/backends/oracle/operations.py
@@ -187,33 +187,9 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
         return [dist_param]
 
     def get_geom_placeholder(self, f, value, compiler):
-        """
-        Provide a proper substitution value for Geometries that are not in the
-        SRID of the field.  Specifically, this routine will substitute in the
-        SDO_CS.TRANSFORM() function call.
-        """
-        tranform_func = self.spatial_function_name('Transform')
-
         if value is None:
             return 'NULL'
-
-        def transform_value(val, srid):
-            return val.srid != srid
-
-        if hasattr(value, 'as_sql'):
-            if transform_value(value, f.srid):
-                placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
-            else:
-                placeholder = '%s'
-            # No geometry value used for F expression, substitute in
-            # the column name instead.
-            sql, _ = compiler.compile(value)
-            return placeholder % sql
-        else:
-            if transform_value(value, f.srid):
-                return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (tranform_func, value.srid, f.srid)
-            else:
-                return 'SDO_GEOMETRY(%%s, %s)' % f.srid
+        return super().get_geom_placeholder(f, value, compiler)
 
     def spatial_aggregate_name(self, agg_name):
         """
diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py
index 49662236c0..71302d9eb7 100644
--- a/django/contrib/gis/db/backends/postgis/operations.py
+++ b/django/contrib/gis/db/backends/postgis/operations.py
@@ -292,6 +292,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
         substitute in the ST_Transform() function call.
         """
         tranform_func = self.spatial_function_name('Transform')
+        if hasattr(value, 'as_sql'):
+            if value.field.srid == f.srid:
+                placeholder = '%s'
+            else:
+                placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
+            return placeholder
 
         # Get the srid for this object
         if value is None:
@@ -310,13 +316,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
         else:
             placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
 
-        if hasattr(value, 'as_sql'):
-            # If this is an F expression, then we don't really want
-            # a placeholder and instead substitute in the column
-            # of the expression.
-            sql, _ = compiler.compile(value)
-            placeholder = placeholder % sql
-
         return placeholder
 
     def _get_postgis_func(self, func):
diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py
index 42b079bbfa..087a879d07 100644
--- a/django/contrib/gis/db/backends/spatialite/operations.py
+++ b/django/contrib/gis/db/backends/spatialite/operations.py
@@ -152,32 +152,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
             dist_param = value
         return [dist_param]
 
-    def get_geom_placeholder(self, f, value, compiler):
-        """
-        Provide a proper substitution value for Geometries that are not in the
-        SRID of the field.  Specifically, this routine will substitute in the
-        Transform() and GeomFromText() function call(s).
-        """
-        tranform_func = self.spatial_function_name('Transform')
-
-        def transform_value(value, srid):
-            return not (value is None or value.srid == srid)
-        if hasattr(value, 'as_sql'):
-            if transform_value(value, f.srid):
-                placeholder = '%s(%%s, %s)' % (tranform_func, f.srid)
-            else:
-                placeholder = '%s'
-            # No geometry value used for F expression, substitute in
-            # the column name instead.
-            sql, _ = compiler.compile(value)
-            return placeholder % sql
-        else:
-            if transform_value(value, f.srid):
-                # Adding Transform() to the SQL placeholder.
-                return '%s(%s(%%s,%s), %s)' % (tranform_func, self.from_text, value.srid, f.srid)
-            else:
-                return '%s(%%s,%s)' % (self.from_text, f.srid)
-
     def _get_spatialite_func(self, func):
         """
         Helper routine for calling SpatiaLite functions and returning
diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py
index 2024df45b0..c34e3391e1 100644
--- a/django/contrib/gis/db/models/lookups.py
+++ b/django/contrib/gis/db/models/lookups.py
@@ -68,6 +68,8 @@ class GISLookup(Lookup):
             if not hasattr(geo_fld, 'srid'):
                 raise ValueError('No geographic field found in expression.')
             self.rhs.srid = geo_fld.srid
+            sql, _ = compiler.compile(geom)
+            return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
         elif isinstance(self.rhs, Expression):
             raise ValueError('Complex expressions not supported for spatial fields.')
         elif isinstance(self.rhs, (list, tuple)):
diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
index b47136df26..c1d0451a54 100644
--- a/django/db/backends/mysql/operations.py
+++ b/django/db/backends/mysql/operations.py
@@ -233,7 +233,7 @@ class DatabaseOperations(BaseDatabaseOperations):
         return value
 
     def binary_placeholder_sql(self, value):
-        return '_binary %s' if value is not None else '%s'
+        return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
 
     def subtract_temporals(self, internal_type, lhs, rhs):
         lhs_sql, lhs_params = lhs
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index f32d106def..14a727e998 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1167,7 +1167,7 @@ class SQLUpdateCompiler(SQLCompiler):
             name = field.column
             if hasattr(val, 'as_sql'):
                 sql, params = self.compile(val)
-                values.append('%s = %s' % (qn(name), sql))
+                values.append('%s = %s' % (qn(name), placeholder % sql))
                 update_params.extend(params)
             elif val is not None:
                 values.append('%s = %s' % (qn(name), placeholder))
diff --git a/tests/gis_tests/geoapp/models.py b/tests/gis_tests/geoapp/models.py
index 363f3deaf0..b555165e56 100644
--- a/tests/gis_tests/geoapp/models.py
+++ b/tests/gis_tests/geoapp/models.py
@@ -101,3 +101,9 @@ class NonConcreteField(models.IntegerField):
 class NonConcreteModel(NamedModel):
     non_concrete = NonConcreteField()
     point = models.PointField(geography=True)
+
+
+class ManyPointModel(NamedModel):
+    point1 = models.PointField()
+    point2 = models.PointField()
+    point3 = models.PointField(srid=3857)
diff --git a/tests/gis_tests/geoapp/test_expressions.py b/tests/gis_tests/geoapp/test_expressions.py
index c18d07f0e8..72f9a37dc4 100644
--- a/tests/gis_tests/geoapp/test_expressions.py
+++ b/tests/gis_tests/geoapp/test_expressions.py
@@ -1,11 +1,12 @@
 from unittest import skipUnless
 
-from django.contrib.gis.db.models import GeometryField, Value, functions
+from django.contrib.gis.db.models import F, GeometryField, Value, functions
 from django.contrib.gis.geos import Point, Polygon
+from django.db import connection
 from django.test import TestCase, skipUnlessDBFeature
 
 from ..utils import postgis
-from .models import City
+from .models import City, ManyPointModel
 
 
 @skipUnlessDBFeature('gis_enabled')
@@ -29,3 +30,28 @@ class GeoExpressionsTests(TestCase):
         p = Polygon(((1, 1), (1, 2), (2, 2), (2, 1), (1, 1)))
         area = City.objects.annotate(a=functions.Area(Value(p, GeometryField(srid=4326, geography=True)))).first().a
         self.assertAlmostEqual(area.sq_km, 12305.1, 0)
+
+    def test_update_from_other_field(self):
+        p1 = Point(1, 1, srid=4326)
+        p2 = Point(2, 2, srid=4326)
+        obj = ManyPointModel.objects.create(
+            point1=p1,
+            point2=p2,
+            point3=p2.transform(3857, clone=True),
+        )
+        # Updating a point to a point of the same SRID.
+        ManyPointModel.objects.filter(pk=obj.pk).update(point2=F('point1'))
+        obj.refresh_from_db()
+        self.assertEqual(obj.point2, p1)
+        # Updating a point to a point with a different SRID.
+        if connection.features.supports_transform:
+            ManyPointModel.objects.filter(pk=obj.pk).update(point3=F('point1'))
+            obj.refresh_from_db()
+            self.assertTrue(obj.point3.equals_exact(p1.transform(3857, clone=True), 0.1))
+
+    @skipUnlessDBFeature('has_Translate_function')
+    def test_update_with_expression(self):
+        city = City.objects.create(point=Point(1, 1, srid=4326))
+        City.objects.filter(pk=city.pk).update(point=functions.Translate('point', 1, 1))
+        city.refresh_from_db()
+        self.assertEqual(city.point, Point(2, 2, srid=4326))