1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #25588 -- Added spatial lookups to RasterField.

Thanks Tim Graham for the review.
This commit is contained in:
Daniel Wiesmann
2016-04-21 17:03:14 +01:00
committed by Tim Graham
parent 03efa304bc
commit bbfad84dd9
11 changed files with 743 additions and 158 deletions

View File

@@ -6,16 +6,27 @@ from __future__ import unicode_literals
from psycopg2 import Binary
from psycopg2.extensions import ISQLQuote
from django.contrib.gis.db.backends.postgis.pgraster import to_pgraster
from django.contrib.gis.geometry.backend import Geometry
class PostGISAdapter(object):
def __init__(self, geom, geography=False):
"Initializes on the geometry."
def __init__(self, obj, geography=False):
"""
Initialize on the spatial object.
"""
self.is_geometry = isinstance(obj, Geometry)
# Getting the WKB (in string form, to allow easy pickling of
# the adaptor) and the SRID from the geometry.
self.ewkb = bytes(geom.ewkb)
self.srid = geom.srid
# the adaptor) and the SRID from the geometry or raster.
if self.is_geometry:
self.ewkb = bytes(obj.ewkb)
self._adapter = Binary(self.ewkb)
else:
self.ewkb = to_pgraster(obj)
self.srid = obj.srid
self.geography = geography
self._adapter = Binary(self.ewkb)
def __conform__(self, proto):
# Does the given protocol conform to what Psycopg2 expects?
@@ -40,12 +51,19 @@ class PostGISAdapter(object):
This method allows escaping the binary in the style required by the
server's `standard_conforming_string` setting.
"""
self._adapter.prepare(conn)
if self.is_geometry:
self._adapter.prepare(conn)
def getquoted(self):
"Returns a properly quoted string for use in PostgreSQL/PostGIS."
# psycopg will figure out whether to use E'\\000' or '\000'
return str('%s(%s)' % (
'ST_GeogFromWKB' if self.geography else 'ST_GeomFromEWKB',
self._adapter.getquoted().decode())
)
"""
Return a properly quoted string for use in PostgreSQL/PostGIS.
"""
if self.is_geometry:
# Psycopg will figure out whether to use E'\\000' or '\000'.
return str('%s(%s)' % (
'ST_GeogFromWKB' if self.geography else 'ST_GeomFromEWKB',
self._adapter.getquoted().decode())
)
else:
# For rasters, add explicit type cast to WKB string.
return "'%s'::raster" % self.ewkb

View File

@@ -4,30 +4,83 @@ from django.conf import settings
from django.contrib.gis.db.backends.base.operations import \
BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.gdal import GDALRaster
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.utils import ProgrammingError
from django.utils import six
from django.utils.functional import cached_property
from .adapter import PostGISAdapter
from .models import PostGISGeometryColumns, PostGISSpatialRefSys
from .pgraster import from_pgraster, get_pgraster_srid, to_pgraster
# Identifier to mark raster lookups as bilateral.
BILATERAL = 'bilateral'
class PostGISOperator(SpatialOperator):
def __init__(self, geography=False, **kwargs):
# Only a subset of the operators and functions are available
# for the geography type.
def __init__(self, geography=False, raster=False, **kwargs):
# Only a subset of the operators and functions are available for the
# geography type.
self.geography = geography
# Only a subset of the operators and functions are available for the
# raster type. Lookups that don't suport raster will be converted to
# polygons. If the raster argument is set to BILATERAL, then the
# operator cannot handle mixed geom-raster lookups.
self.raster = raster
super(PostGISOperator, self).__init__(**kwargs)
def as_sql(self, connection, lookup, *args):
def as_sql(self, connection, lookup, template_params, *args):
if lookup.lhs.output_field.geography and not self.geography:
raise ValueError('PostGIS geography does not support the "%s" '
'function/operator.' % (self.func or self.op,))
return super(PostGISOperator, self).as_sql(connection, lookup, *args)
template_params = self.check_raster(lookup, template_params)
return super(PostGISOperator, self).as_sql(connection, lookup, template_params, *args)
def check_raster(self, lookup, template_params):
# Get rhs value.
if isinstance(lookup.rhs, (tuple, list)):
rhs_val = lookup.rhs[0]
spheroid = lookup.rhs[-1] == 'spheroid'
else:
rhs_val = lookup.rhs
spheroid = False
# Check which input is a raster.
lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER'
rhs_is_raster = isinstance(rhs_val, GDALRaster)
# Look for band indices and inject them if provided.
if lookup.band_lhs is not None and lhs_is_raster:
if not self.func:
raise ValueError('Band indices are not allowed for this operator, it works on bbox only.')
template_params['lhs'] = '%s, %s' % (template_params['lhs'], lookup.band_lhs)
if lookup.band_rhs is not None and rhs_is_raster:
if not self.func:
raise ValueError('Band indices are not allowed for this operator, it works on bbox only.')
template_params['rhs'] = '%s, %s' % (template_params['rhs'], lookup.band_rhs)
# Convert rasters to polygons if necessary.
if not self.raster or spheroid:
# Operators without raster support.
if lhs_is_raster:
template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs']
if rhs_is_raster:
template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs']
elif self.raster == BILATERAL:
# Operators with raster support but don't support mixed (rast-geom)
# lookups.
if lhs_is_raster and not rhs_is_raster:
template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs']
elif rhs_is_raster and not lhs_is_raster:
template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs']
return template_params
class PostGISDistanceOperator(PostGISOperator):
@@ -35,6 +88,7 @@ class PostGISDistanceOperator(PostGISOperator):
def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
template_params = self.check_raster(lookup, template_params)
sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
@@ -58,33 +112,33 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
Adapter = PostGISAdapter
gis_operators = {
'bbcontains': PostGISOperator(op='~'),
'bboverlaps': PostGISOperator(op='&&', geography=True),
'contained': PostGISOperator(op='@'),
'contains': PostGISOperator(func='ST_Contains'),
'overlaps_left': PostGISOperator(op='&<'),
'overlaps_right': PostGISOperator(op='&>'),
'bbcontains': PostGISOperator(op='~', raster=True),
'bboverlaps': PostGISOperator(op='&&', geography=True, raster=True),
'contained': PostGISOperator(op='@', raster=True),
'overlaps_left': PostGISOperator(op='&<', raster=BILATERAL),
'overlaps_right': PostGISOperator(op='&>', raster=BILATERAL),
'overlaps_below': PostGISOperator(op='&<|'),
'overlaps_above': PostGISOperator(op='|&>'),
'left': PostGISOperator(op='<<'),
'right': PostGISOperator(op='>>'),
'strictly_below': PostGISOperator(op='<<|'),
'strictly_above': PostGISOperator(op='|>>'),
'same_as': PostGISOperator(op='~='),
'exact': PostGISOperator(op='~='), # alias of same_as
'contains_properly': PostGISOperator(func='ST_ContainsProperly'),
'coveredby': PostGISOperator(func='ST_CoveredBy', geography=True),
'covers': PostGISOperator(func='ST_Covers', geography=True),
'same_as': PostGISOperator(op='~=', raster=BILATERAL),
'exact': PostGISOperator(op='~=', raster=BILATERAL), # alias of same_as
'contains': PostGISOperator(func='ST_Contains', raster=BILATERAL),
'contains_properly': PostGISOperator(func='ST_ContainsProperly', raster=BILATERAL),
'coveredby': PostGISOperator(func='ST_CoveredBy', geography=True, raster=BILATERAL),
'covers': PostGISOperator(func='ST_Covers', geography=True, raster=BILATERAL),
'crosses': PostGISOperator(func='ST_Crosses'),
'disjoint': PostGISOperator(func='ST_Disjoint'),
'disjoint': PostGISOperator(func='ST_Disjoint', raster=BILATERAL),
'equals': PostGISOperator(func='ST_Equals'),
'intersects': PostGISOperator(func='ST_Intersects', geography=True),
'intersects': PostGISOperator(func='ST_Intersects', geography=True, raster=BILATERAL),
'isvalid': PostGISOperator(func='ST_IsValid'),
'overlaps': PostGISOperator(func='ST_Overlaps'),
'overlaps': PostGISOperator(func='ST_Overlaps', raster=BILATERAL),
'relate': PostGISOperator(func='ST_Relate'),
'touches': PostGISOperator(func='ST_Touches'),
'within': PostGISOperator(func='ST_Within'),
'dwithin': PostGISOperator(func='ST_DWithin', geography=True),
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
@@ -272,14 +326,14 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
def get_geom_placeholder(self, f, value, compiler):
"""
Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the
ST_Transform() function call.
Provide a proper substitution value for Geometries or rasters that are
not in the SRID of the field. Specifically, this routine will
substitute in the ST_Transform() function call.
"""
# Get the srid for this object
if value is None:
value_srid = None
elif f.geom_type == 'RASTER':
elif f.geom_type == 'RASTER' and isinstance(value, six.string_types):
value_srid = get_pgraster_srid(value)
else:
value_srid = value.srid
@@ -288,7 +342,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
# is not equal to the field srid.
if value_srid is None or value_srid == f.srid:
placeholder = '%s'
elif f.geom_type == 'RASTER':
elif f.geom_type == 'RASTER' and isinstance(value, six.string_types):
placeholder = '%s((%%s)::raster, %s)' % (self.transform, f.srid)
else:
placeholder = '%s(%%s, %s)' % (self.transform, f.srid)

View File

@@ -1,7 +1,10 @@
from django.contrib.gis import forms
from django.contrib.gis.db.models.lookups import gis_lookups
from django.contrib.gis.db.models.lookups import (
RasterBandTransform, gis_lookups,
)
from django.contrib.gis.db.models.proxy import SpatialProxy
from django.contrib.gis.gdal import HAS_GDAL
from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.geometry.backend import Geometry, GeometryException
from django.core.exceptions import ImproperlyConfigured
from django.db.models.expressions import Expression
@@ -157,6 +160,82 @@ class BaseSpatialField(Field):
"""
return connection.ops.get_geom_placeholder(self, value, compiler)
def get_srid(self, obj):
"""
Return the default SRID for the given geometry or raster, taking into
account the SRID set for the field. For example, if the input geometry
or raster doesn't have an SRID, then the SRID of the field will be
returned.
"""
srid = obj.srid # SRID of given geometry.
if srid is None or self.srid == -1 or (srid == -1 and self.srid != -1):
return self.srid
else:
return srid
def get_db_prep_save(self, value, connection):
"""
Prepare the value for saving in the database.
"""
if not value:
return None
else:
return connection.ops.Adapter(self.get_prep_value(value))
def get_prep_value(self, value):
"""
Spatial lookup values are either a parameter that is (or may be
converted to) a geometry or raster, or a sequence of lookup values
that begins with a geometry or raster. This routine sets up the
geometry or raster value properly and preserves any other lookup
parameters.
"""
from django.contrib.gis.gdal import GDALRaster
value = super(BaseSpatialField, self).get_prep_value(value)
# For IsValid lookups, boolean values are allowed.
if isinstance(value, (Expression, bool)):
return value
elif isinstance(value, (tuple, list)):
obj = value[0]
seq_value = True
else:
obj = value
seq_value = False
# When the input is not a geometry or raster, attempt to construct one
# from the given string input.
if isinstance(obj, (Geometry, GDALRaster)):
pass
elif isinstance(obj, (bytes, six.string_types)) or hasattr(obj, '__geo_interface__'):
try:
obj = Geometry(obj)
except (GeometryException, GDALException):
try:
obj = GDALRaster(obj)
except GDALException:
raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj)
elif isinstance(obj, dict):
try:
obj = GDALRaster(obj)
except GDALException:
raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj)
else:
raise ValueError('Cannot use object with type %s for a spatial lookup parameter.' % type(obj).__name__)
# Assigning the SRID value.
obj.srid = self.get_srid(obj)
if seq_value:
lookup_val = [obj]
lookup_val.extend(value[1:])
return tuple(lookup_val)
else:
return obj
for klass in gis_lookups.values():
BaseSpatialField.register_lookup(klass)
class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
"""
@@ -224,6 +303,8 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
value properly, and preserve any other lookup parameters before
returning to the caller.
"""
from django.contrib.gis.gdal import GDALRaster
value = super(GeometryField, self).get_prep_value(value)
if isinstance(value, (Expression, bool)):
return value
@@ -236,7 +317,7 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
# When the input is not a GEOS geometry, attempt to construct one
# from the given string input.
if isinstance(geom, Geometry):
if isinstance(geom, (Geometry, GDALRaster)):
pass
elif isinstance(geom, (bytes, six.string_types)) or hasattr(geom, '__geo_interface__'):
try:
@@ -265,18 +346,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
value.srid = self.srid
return value
def get_srid(self, geom):
"""
Returns the default SRID for the given geometry, taking into account
the SRID set for the field. For example, if the input geometry
has no SRID, then that of the field will be returned.
"""
gsrid = geom.srid # SRID of given geometry.
if gsrid is None or self.srid == -1 or (gsrid == -1 and self.srid != -1):
return self.srid
else:
return gsrid
# ### Routines overloaded from Field ###
def contribute_to_class(self, cls, name, **kwargs):
super(GeometryField, self).contribute_to_class(cls, name, **kwargs)
@@ -316,17 +385,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField):
params = [connection.ops.Adapter(value)]
return params
def get_db_prep_save(self, value, connection):
"Prepares the value for saving in the database."
if not value:
return None
else:
return connection.ops.Adapter(self.get_prep_value(value))
for klass in gis_lookups.values():
GeometryField.register_lookup(klass)
# The OpenGIS Geometry Type Fields
class PointField(GeometryField):
@@ -387,6 +445,7 @@ class RasterField(BaseSpatialField):
description = _("Raster Field")
geom_type = 'RASTER'
geography = False
def __init__(self, *args, **kwargs):
if not HAS_GDAL:
@@ -421,3 +480,15 @@ class RasterField(BaseSpatialField):
# delays the instantiation of the objects to the moment of evaluation
# of the raster attribute.
setattr(cls, self.attname, SpatialProxy(GDALRaster, self))
def get_transform(self, name):
try:
band_index = int(name)
return type(
'SpecificRasterBandTransform',
(RasterBandTransform, ),
{'band_index': band_index}
)
except ValueError:
pass
return super(RasterField, self).get_transform(name)

View File

@@ -5,16 +5,23 @@ import re
from django.core.exceptions import FieldDoesNotExist
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Expression
from django.db.models.lookups import BuiltinLookup, Lookup
from django.db.models.lookups import BuiltinLookup, Lookup, Transform
from django.utils import six
gis_lookups = {}
class RasterBandTransform(Transform):
def as_sql(self, compiler, connection):
return compiler.compile(self.lhs)
class GISLookup(Lookup):
sql_template = None
transform_func = None
distance = False
band_rhs = None
band_lhs = None
def __init__(self, *args, **kwargs):
super(GISLookup, self).__init__(*args, **kwargs)
@@ -28,10 +35,10 @@ class GISLookup(Lookup):
'point, 'the_geom', or a related lookup on a geographic field like
'address__point'.
If a GeometryField exists according to the given lookup on the model
options, it will be returned. Otherwise returns None.
If a BaseSpatialField exists according to the given lookup on the model
options, it will be returned. Otherwise return None.
"""
from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.fields import BaseSpatialField
# This takes into account the situation where the lookup is a
# lookup to a related geographic field, e.g., 'address__point'.
field_list = lookup.split(LOOKUP_SEP)
@@ -55,11 +62,34 @@ class GISLookup(Lookup):
return False
# Finally, make sure we got a Geographic field and return.
if isinstance(geo_fld, GeometryField):
if isinstance(geo_fld, BaseSpatialField):
return geo_fld
else:
return False
def process_band_indices(self, only_lhs=False):
"""
Extract the lhs band index from the band transform class and the rhs
band index from the input tuple.
"""
# PostGIS band indices are 1-based, so the band index needs to be
# increased to be consistent with the GDALRaster band indices.
if only_lhs:
self.band_rhs = 1
self.band_lhs = self.lhs.band_index + 1
return
if isinstance(self.lhs, RasterBandTransform):
self.band_lhs = self.lhs.band_index + 1
else:
self.band_lhs = 1
self.band_rhs = self.rhs[1]
if len(self.rhs) == 1:
self.rhs = self.rhs[0]
else:
self.rhs = (self.rhs[0], ) + self.rhs[2:]
def get_db_prep_lookup(self, value, connection):
# get_db_prep_lookup is called by process_rhs from super class
if isinstance(value, (tuple, list)):
@@ -70,10 +100,9 @@ class GISLookup(Lookup):
return ('%s', params)
def process_rhs(self, compiler, connection):
rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection)
if hasattr(self.rhs, '_as_sql'):
# If rhs is some QuerySet, don't touch it
return rhs, rhs_params
return super(GISLookup, self).process_rhs(compiler, connection)
geom = self.rhs
if isinstance(self.rhs, Col):
@@ -85,9 +114,19 @@ class GISLookup(Lookup):
raise ValueError('No geographic field found in expression.')
self.rhs.srid = geo_fld.srid
elif isinstance(self.rhs, Expression):
raise ValueError('Complex expressions not supported for GeometryField')
raise ValueError('Complex expressions not supported for spatial fields.')
elif isinstance(self.rhs, (list, tuple)):
geom = self.rhs[0]
# Check if a band index was passed in the query argument.
if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or
(len(self.rhs) == 3 and self.lookup_name == 'relate')):
self.process_band_indices()
elif len(self.rhs) > 2:
raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
elif isinstance(self.lhs, RasterBandTransform):
self.process_band_indices(only_lhs=True)
rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
return rhs, rhs_params
@@ -274,6 +313,8 @@ class IsValidLookup(BuiltinLookup):
lookup_name = 'isvalid'
def as_sql(self, compiler, connection):
if self.lhs.field.geom_type == 'RASTER':
raise ValueError('The isvalid lookup is only available on geometry fields.')
gis_op = connection.ops.gis_operators[self.lookup_name]
sql, params = self.process_lhs(compiler, connection)
sql = '%(func)s(%(lhs)s)' % {'func': gis_op.func, 'lhs': sql}
@@ -323,9 +364,17 @@ class DistanceLookupBase(GISLookup):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def process_rhs(self, compiler, connection):
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4:
raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid':
raise ValueError("For 4-element tuples the last argument must be the 'speroid' directive.")
# Check if the second parameter is a band index.
if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid':
self.process_band_indices()
params = [connection.ops.Adapter(self.rhs[0])]
# Getting the distance parameter in the units of the field.
dist_param = self.rhs[1]
if hasattr(dist_param, 'resolve_expression'):