From a34fba5e596a3ec95bf284fd77b1609e71a65019 Mon Sep 17 00:00:00 2001 From: Claude Paroz Date: Thu, 15 Jan 2015 11:35:35 +0100 Subject: [PATCH] Simplified a bit GeoAggregate classes Thanks Josh Smeaton for the review. Refs #24152. --- .../contrib/gis/db/backends/base/features.py | 8 ++- .../gis/db/backends/base/operations.py | 17 ++--- .../gis/db/backends/mysql/operations.py | 3 + .../gis/db/backends/oracle/operations.py | 19 ++---- .../gis/db/backends/postgis/operations.py | 20 ++---- .../gis/db/backends/spatialite/operations.py | 28 ++++---- django/contrib/gis/db/models/aggregates.py | 21 +++--- .../contrib/gis/db/models/sql/aggregates.py | 64 ------------------- 8 files changed, 46 insertions(+), 134 deletions(-) diff --git a/django/contrib/gis/db/backends/base/features.py b/django/contrib/gis/db/backends/base/features.py index faf471b6aa..f43724dcf7 100644 --- a/django/contrib/gis/db/backends/base/features.py +++ b/django/contrib/gis/db/backends/base/features.py @@ -1,5 +1,7 @@ from functools import partial +from django.contrib.gis.db.models import aggregates + class BaseSpatialFeatures(object): gis_enabled = True @@ -61,15 +63,15 @@ class BaseSpatialFeatures(object): # Specifies whether the Collect and Extent aggregates are supported by the database @property def supports_collect_aggr(self): - return 'Collect' in self.connection.ops.valid_aggregates + return aggregates.Collect not in self.connection.ops.disallowed_aggregates @property def supports_extent_aggr(self): - return 'Extent' in self.connection.ops.valid_aggregates + return aggregates.Extent not in self.connection.ops.disallowed_aggregates @property def supports_make_line_aggr(self): - return 'MakeLine' in self.connection.ops.valid_aggregates + return aggregates.MakeLine not in self.connection.ops.disallowed_aggregates def __init__(self, *args): super(BaseSpatialFeatures, self).__init__(*args) diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index 4365b64542..dc2fad025b 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -46,11 +46,7 @@ class BaseSpatialOperations(object): union = False # Aggregates - collect = False - extent = False - extent3d = False - make_line = False - unionagg = False + disallowed_aggregates = () # Serialization geohash = False @@ -103,12 +99,13 @@ class BaseSpatialOperations(object): raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') def check_aggregate_support(self, aggregate): - if aggregate.contains_aggregate == 'gis': - return aggregate.name in self.valid_aggregates - return super(BaseSpatialOperations, self).check_aggregate_support(aggregate) + if isinstance(aggregate, self.disallowed_aggregates): + raise NotImplementedError( + "%s spatial aggregation is not supported by this database backend." % aggregate.name + ) + super(BaseSpatialOperations, self).check_aggregate_support(aggregate) - # Spatial SQL Construction - def spatial_aggregate_sql(self, agg): + def spatial_aggregate_name(self, agg_name): raise NotImplementedError('Aggregate support not implemented for this spatial backend.') # Routines for getting the OGC-compliant models. diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 6e76674364..9d0d2a9928 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -1,6 +1,7 @@ from django.contrib.gis.db.backends.base.adapter import WKTAdapter from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.utils import SpatialOperator +from django.contrib.gis.db.models import aggregates from django.db.backends.mysql.operations import DatabaseOperations @@ -30,6 +31,8 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): 'within': SpatialOperator(func='MBRWithin'), } + disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union) + def geo_db_type(self, f): return f.geom_type diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index f6ef5415e2..277126a77e 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -12,6 +12,7 @@ import re from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.utils import SpatialOperator +from django.contrib.gis.db.models import aggregates from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Distance from django.db.backends.oracle.base import Database @@ -56,7 +57,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): name = 'oracle' oracle = True - valid_aggregates = {'Union', 'Extent'} + disallowed_aggregates = (aggregates.Collect, aggregates.Extent3D, aggregates.MakeLine) Adapter = OracleSpatialAdapter Adaptor = Adapter # Backwards-compatibility alias. @@ -223,20 +224,12 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): else: return 'SDO_GEOMETRY(%%s, %s)' % f.srid - def spatial_aggregate_sql(self, agg): + def spatial_aggregate_name(self, agg_name): """ - Returns the spatial aggregate SQL template and function for the - given Aggregate instance. + Returns the spatial aggregate SQL name. """ - agg_name = agg.__class__.__name__.lower() - if agg_name == 'union': - agg_name += 'agg' - if agg.is_extent: - sql_template = '%(function)s(%(expressions)s)' - else: - sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' - sql_function = getattr(self, agg_name) - return sql_template, sql_function + agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower() + return getattr(self, agg_name) # Routines for getting the OGC-compliant models. def geometry_columns(self): diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 4958aee315..44cab5d12a 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -49,7 +49,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): geography = True geom_func_prefix = 'ST_' version_regex = re.compile(r'^(?P\d)\.(?P\d)\.(?P\d+)') - valid_aggregates = {'Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'} Adapter = PostGISAdapter Adaptor = Adapter # Backwards-compatibility alias. @@ -360,20 +359,11 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): else: raise Exception('Could not determine PROJ.4 version from PostGIS.') - def spatial_aggregate_sql(self, agg): - """ - Returns the spatial aggregate SQL template and function for the - given Aggregate instance. - """ - agg_name = agg.__class__.__name__ - if not self.check_aggregate_support(agg): - raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name) - agg_name = agg_name.lower() - if agg_name == 'union': - agg_name += 'agg' - sql_template = '%(function)s(%(expressions)s)' - sql_function = getattr(self, agg_name) - return sql_template, sql_function + def spatial_aggregate_name(self, agg_name): + if agg_name == 'Extent3D': + return self.extent3d + else: + return self.geom_func_prefix + agg_name # Routines for getting the OGC-compliant models. def geometry_columns(self): diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 0d774c9a6d..c4700e7aef 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -4,6 +4,7 @@ import sys from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter +from django.contrib.gis.db.models import aggregates from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Distance from django.core.exceptions import ImproperlyConfigured @@ -18,13 +19,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): spatialite = True version_regex = re.compile(r'^(?P\d)\.(?P\d)\.(?P\d+)') - @property - def valid_aggregates(self): - if self.spatial_version >= (3, 0, 0): - return {'Collect', 'Extent', 'Union'} - else: - return {'Union'} - Adapter = SpatiaLiteAdapter Adaptor = Adapter # Backwards-compatibility alias. @@ -109,6 +103,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): return False return True + @cached_property + def disallowed_aggregates(self): + disallowed = (aggregates.Extent3D, aggregates.MakeLine) + if self.spatial_version < (3, 0, 0): + disallowed += (aggregates.Collect, aggregates.Extent) + return disallowed + @cached_property def gml(self): return 'AsGML' if self._version_greater_2_4_0_rc4 else None @@ -237,20 +238,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): return (version, major, minor1, minor2) - def spatial_aggregate_sql(self, agg): + def spatial_aggregate_name(self, agg_name): """ Returns the spatial aggregate SQL template and function for the given Aggregate instance. """ - agg_name = agg.__class__.__name__ - if not self.check_aggregate_support(agg): - raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name) - agg_name = agg_name.lower() - if agg_name == 'union': - agg_name += 'agg' - sql_template = '%(function)s(%(expressions)s)' - sql_function = getattr(self, agg_name) - return sql_template, sql_function + agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower() + return getattr(self, agg_name) # Routines for getting the OGC-compliant models. def geometry_columns(self): diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index b775774c19..d7db0a8dd8 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -5,24 +5,21 @@ __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] class GeoAggregate(Aggregate): - template = None function = None - contains_aggregate = 'gis' is_extent = False def as_sql(self, compiler, connection): - if connection.ops.oracle: - if not hasattr(self, 'tolerance'): - self.tolerance = 0.05 - self.extra['tolerance'] = self.tolerance - - template, function = connection.ops.spatial_aggregate_sql(self) - if template is None: - template = '%(function)s(%(expressions)s)' - self.extra['template'] = self.extra.get('template', template) - self.extra['function'] = self.extra.get('function', function) + self.function = connection.ops.spatial_aggregate_name(self.name) return super(GeoAggregate, self).as_sql(compiler, connection) + def as_oracle(self, compiler, connection): + if not hasattr(self, 'tolerance'): + self.tolerance = 0.05 + self.extra['tolerance'] = self.tolerance + if not self.is_extent: + self.template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' + return self.as_sql(compiler, connection) + def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False): c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize) if not isinstance(self.expressions[0].output_field, GeometryField): diff --git a/django/contrib/gis/db/models/sql/aggregates.py b/django/contrib/gis/db/models/sql/aggregates.py index 65ccc960df..fe0e396f2f 100644 --- a/django/contrib/gis/db/models/sql/aggregates.py +++ b/django/contrib/gis/db/models/sql/aggregates.py @@ -1,6 +1,5 @@ from django.db.models.sql import aggregates from django.db.models.sql.aggregates import * # NOQA -from django.contrib.gis.db.models.fields import GeometryField __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__ @@ -10,66 +9,3 @@ warnings.warn( "django.contrib.gis.db.models.sql.aggregates is deprecated. Use " "django.contrib.gis.db.models.aggregates instead.", RemovedInDjango20Warning, stacklevel=2) - - -class GeoAggregate(Aggregate): - # Default SQL template for spatial aggregates. - sql_template = '%(function)s(%(expressions)s)' - - # Flags for indicating the type of the aggregate. - is_extent = False - - def __init__(self, col, source=None, is_summary=False, tolerance=0.05, **extra): - super(GeoAggregate, self).__init__(col, source, is_summary, **extra) - - # Required by some Oracle aggregates. - self.tolerance = tolerance - - # Can't use geographic aggregates on non-geometry fields. - if not isinstance(self.source, GeometryField): - raise ValueError('Geospatial aggregates only allowed on geometry fields.') - - def as_sql(self, compiler, connection): - "Return the aggregate, rendered as SQL with parameters." - - if connection.ops.oracle: - self.extra['tolerance'] = self.tolerance - - params = [] - - if hasattr(self.col, 'as_sql'): - field_name, params = self.col.as_sql(compiler, connection) - elif isinstance(self.col, (list, tuple)): - field_name = '.'.join(compiler.quote_name_unless_alias(c) for c in self.col) - else: - field_name = self.col - - sql_template, sql_function = connection.ops.spatial_aggregate_sql(self) - - substitutions = { - 'function': sql_function, - 'expressions': field_name - } - substitutions.update(self.extra) - - return sql_template % substitutions, params - - -class Collect(GeoAggregate): - pass - - -class Extent(GeoAggregate): - is_extent = '2D' - - -class Extent3D(GeoAggregate): - is_extent = '3D' - - -class MakeLine(GeoAggregate): - pass - - -class Union(GeoAggregate): - pass