mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #29048 -- Added **extra_context to database function as_vendor() methods.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							08f360355a
						
					
				
				
					commit
					83b04d4f88
				
			| @@ -26,10 +26,10 @@ class GeoAggregate(Aggregate): | |||||||
|             **extra_context |             **extra_context | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) |         tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) | ||||||
|         template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' |         template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' | ||||||
|         return self.as_sql(compiler, connection, template=template, tolerance=tolerance) |         return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context) | ||||||
|  |  | ||||||
|     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): |     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): | ||||||
|         c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) |         c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) | ||||||
|   | |||||||
| @@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin: | |||||||
|     By default, Decimal values are converted to str by the SQLite backend, which |     By default, Decimal values are converted to str by the SQLite backend, which | ||||||
|     is not acceptable by the GIS functions expecting numeric values. |     is not acceptable by the GIS functions expecting numeric values. | ||||||
|     """ |     """ | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         for expr in self.get_source_expressions(): |         for expr in self.get_source_expressions(): | ||||||
|             if hasattr(expr, 'value') and isinstance(expr.value, Decimal): |             if hasattr(expr, 'value') and isinstance(expr.value, Decimal): | ||||||
|                 expr.value = float(expr.value) |                 expr.value = float(expr.value) | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OracleToleranceMixin: | class OracleToleranceMixin: | ||||||
|     tolerance = 0.05 |     tolerance = 0.05 | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         tol = self.extra.get('tolerance', self.tolerance) |         tol = self.extra.get('tolerance', self.tolerance) | ||||||
|         return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) |         return self.as_sql( | ||||||
|  |             compiler, connection, | ||||||
|  |             template="%%(function)s(%%(expressions)s, %s)" % tol, | ||||||
|  |             **extra_context | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Area(OracleToleranceMixin, GeoFunc): | class Area(OracleToleranceMixin, GeoFunc): | ||||||
| @@ -181,11 +185,11 @@ class AsGML(GeoFunc): | |||||||
|  |  | ||||||
|  |  | ||||||
| class AsKML(AsGML): | class AsKML(AsGML): | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         # No version parameter |         # No version parameter | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         clone.set_source_expressions(self.get_source_expressions()[1:]) |         clone.set_source_expressions(self.get_source_expressions()[1:]) | ||||||
|         return clone.as_sql(compiler, connection) |         return clone.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AsSVG(GeoFunc): | class AsSVG(GeoFunc): | ||||||
| @@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc): | |||||||
|     def __init__(self, expression, num_seg=48, **extra): |     def __init__(self, expression, num_seg=48, **extra): | ||||||
|         super().__init__(expression, num_seg, **extra) |         super().__init__(expression, num_seg, **extra) | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         clone.set_source_expressions([self.get_source_expressions()[0]]) |         clone.set_source_expressions([self.get_source_expressions()[0]]) | ||||||
|         return super(BoundingCircle, clone).as_oracle(compiler, connection) |         return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): | class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): | ||||||
| @@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | |||||||
|             self.spheroid = self._handle_param(spheroid, 'spheroid', bool) |             self.spheroid = self._handle_param(spheroid, 'spheroid', bool) | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         function = None |         function = None | ||||||
|         expr2 = clone.source_expressions[1] |         expr2 = clone.source_expressions[1] | ||||||
| @@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | |||||||
|                 clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) |                 clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) | ||||||
|             else: |             else: | ||||||
|                 function = connection.ops.spatial_function_name('DistanceSphere') |                 function = connection.ops.spatial_function_name('DistanceSphere') | ||||||
|         return super(Distance, clone).as_sql(compiler, connection, function=function) |         return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection, **extra_context): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         if self.geo_field.geodetic(connection): |         if self.geo_field.geodetic(connection): | ||||||
| @@ -300,12 +304,12 @@ class GeoHash(GeoFunc): | |||||||
|             expressions.append(self._handle_param(precision, 'precision', int)) |             expressions.append(self._handle_param(precision, 'precision', int)) | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         # If no precision is provided, set it to the maximum. |         # If no precision is provided, set it to the maximum. | ||||||
|         if len(clone.source_expressions) < 2: |         if len(clone.source_expressions) < 2: | ||||||
|             clone.source_expressions.append(Value(100)) |             clone.source_expressions.append(Value(100)) | ||||||
|         return clone.as_sql(compiler, connection) |         return clone.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): | class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): | ||||||
| @@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | |||||||
|             raise NotSupportedError("This backend doesn't support Length on geodetic fields") |             raise NotSupportedError("This backend doesn't support Length on geodetic fields") | ||||||
|         return super().as_sql(compiler, connection, **extra_context) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         function = None |         function = None | ||||||
|         if self.source_is_geography(): |         if self.source_is_geography(): | ||||||
| @@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | |||||||
|             dim = min(f.dim for f in self.get_source_fields() if f) |             dim = min(f.dim for f in self.get_source_fields() if f) | ||||||
|             if dim > 2: |             if dim > 2: | ||||||
|                 function = connection.ops.length3d |                 function = connection.ops.length3d | ||||||
|         return super(Length, clone).as_sql(compiler, connection, function=function) |         return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         function = None |         function = None | ||||||
|         if self.geo_field.geodetic(connection): |         if self.geo_field.geodetic(connection): | ||||||
|             function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' |             function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' | ||||||
|         return super().as_sql(compiler, connection, function=function) |         return super().as_sql(compiler, connection, function=function, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class LineLocatePoint(GeoFunc): | class LineLocatePoint(GeoFunc): | ||||||
| @@ -383,19 +387,19 @@ class NumPoints(GeoFunc): | |||||||
| class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||||
|     arity = 1 |     arity = 1 | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         function = None |         function = None | ||||||
|         if self.geo_field.geodetic(connection) and not self.source_is_geography(): |         if self.geo_field.geodetic(connection) and not self.source_is_geography(): | ||||||
|             raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") |             raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") | ||||||
|         dim = min(f.dim for f in self.get_source_fields()) |         dim = min(f.dim for f in self.get_source_fields()) | ||||||
|         if dim > 2: |         if dim > 2: | ||||||
|             function = connection.ops.perimeter3d |             function = connection.ops.perimeter3d | ||||||
|         return super().as_sql(compiler, connection, function=function) |         return super().as_sql(compiler, connection, function=function, **extra_context) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         if self.geo_field.geodetic(connection): |         if self.geo_field.geodetic(connection): | ||||||
|             raise NotSupportedError("Perimeter cannot use a non-projected field.") |             raise NotSupportedError("Perimeter cannot use a non-projected field.") | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): | class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): | ||||||
| @@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc): | |||||||
|  |  | ||||||
|  |  | ||||||
| class Translate(Scale): | class Translate(Scale): | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         if len(self.source_expressions) < 4: |         if len(self.source_expressions) < 4: | ||||||
|             # Always provide the z parameter for ST_Translate |             # Always provide the z parameter for ST_Translate | ||||||
|             clone.source_expressions.append(Value(0)) |             clone.source_expressions.append(Value(0)) | ||||||
|         return super(Translate, clone).as_sqlite(compiler, connection) |         return super(Translate, clone).as_sqlite(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Union(OracleToleranceMixin, GeomOutputGeoFunc): | class Union(OracleToleranceMixin, GeomOutputGeoFunc): | ||||||
|   | |||||||
| @@ -64,7 +64,10 @@ class Aggregate(Func): | |||||||
|             if connection.features.supports_aggregate_filter_clause: |             if connection.features.supports_aggregate_filter_clause: | ||||||
|                 filter_sql, filter_params = self.filter.as_sql(compiler, connection) |                 filter_sql, filter_params = self.filter.as_sql(compiler, connection) | ||||||
|                 template = self.filter_template % extra_context.get('template', self.template) |                 template = self.filter_template % extra_context.get('template', self.template) | ||||||
|                 sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql) |                 sql, params = super().as_sql( | ||||||
|  |                     compiler, connection, template=template, filter=filter_sql, | ||||||
|  |                     **extra_context | ||||||
|  |                 ) | ||||||
|                 return sql, params + filter_params |                 return sql, params + filter_params | ||||||
|             else: |             else: | ||||||
|                 copy = self.copy() |                 copy = self.copy() | ||||||
| @@ -92,20 +95,20 @@ class Avg(Aggregate): | |||||||
|             return FloatField() |             return FloatField() | ||||||
|         return super()._resolve_output_field() |         return super()._resolve_output_field() | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         sql, params = super().as_sql(compiler, connection) |         sql, params = super().as_sql(compiler, connection, **extra_context) | ||||||
|         if self.output_field.get_internal_type() == 'DurationField': |         if self.output_field.get_internal_type() == 'DurationField': | ||||||
|             sql = 'CAST(%s as SIGNED)' % sql |             sql = 'CAST(%s as SIGNED)' % sql | ||||||
|         return sql, params |         return sql, params | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         if self.output_field.get_internal_type() == 'DurationField': |         if self.output_field.get_internal_type() == 'DurationField': | ||||||
|             expression = self.get_source_expressions()[0] |             expression = self.get_source_expressions()[0] | ||||||
|             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval |             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval | ||||||
|             return compiler.compile( |             return compiler.compile( | ||||||
|                 SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) |                 SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) | ||||||
|             ) |             ) | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Count(Aggregate): | class Count(Aggregate): | ||||||
| @@ -157,20 +160,20 @@ class Sum(Aggregate): | |||||||
|     function = 'SUM' |     function = 'SUM' | ||||||
|     name = 'Sum' |     name = 'Sum' | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         sql, params = super().as_sql(compiler, connection) |         sql, params = super().as_sql(compiler, connection, **extra_context) | ||||||
|         if self.output_field.get_internal_type() == 'DurationField': |         if self.output_field.get_internal_type() == 'DurationField': | ||||||
|             sql = 'CAST(%s as SIGNED)' % sql |             sql = 'CAST(%s as SIGNED)' % sql | ||||||
|         return sql, params |         return sql, params | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         if self.output_field.get_internal_type() == 'DurationField': |         if self.output_field.get_internal_type() == 'DurationField': | ||||||
|             expression = self.get_source_expressions()[0] |             expression = self.get_source_expressions()[0] | ||||||
|             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval |             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval | ||||||
|             return compiler.compile( |             return compiler.compile( | ||||||
|                 SecondsToInterval(Sum(IntervalToSeconds(expression))) |                 SecondsToInterval(Sum(IntervalToSeconds(expression))) | ||||||
|             ) |             ) | ||||||
|         return super().as_sql(compiler, connection) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Variance(Aggregate): | class Variance(Aggregate): | ||||||
|   | |||||||
| @@ -14,16 +14,16 @@ class Cast(Func): | |||||||
|         extra_context['db_type'] = self.output_field.cast_db_type(connection) |         extra_context['db_type'] = self.output_field.cast_db_type(connection) | ||||||
|         return super().as_sql(compiler, connection, **extra_context) |         return super().as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         # MySQL doesn't support explicit cast to float. |         # MySQL doesn't support explicit cast to float. | ||||||
|         template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None |         template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None | ||||||
|         return self.as_sql(compiler, connection, template=template) |         return self.as_sql(compiler, connection, template=template, **extra_context) | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         # CAST would be valid too, but the :: shortcut syntax is more readable. |         # CAST would be valid too, but the :: shortcut syntax is more readable. | ||||||
|         # 'expressions' is wrapped in parentheses in case it's a complex |         # 'expressions' is wrapped in parentheses in case it's a complex | ||||||
|         # expression. |         # expression. | ||||||
|         return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s') |         return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Coalesce(Func): | class Coalesce(Func): | ||||||
| @@ -35,7 +35,7 @@ class Coalesce(Func): | |||||||
|             raise ValueError('Coalesce must take at least two expressions') |             raise ValueError('Coalesce must take at least two expressions') | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), |         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), | ||||||
|         # so convert all fields to NCLOB when that type is expected. |         # so convert all fields to NCLOB when that type is expected. | ||||||
|         if self.output_field.get_internal_type() == 'TextField': |         if self.output_field.get_internal_type() == 'TextField': | ||||||
| @@ -47,8 +47,8 @@ class Coalesce(Func): | |||||||
|             ] |             ] | ||||||
|             clone = self.copy() |             clone = self.copy() | ||||||
|             clone.set_source_expressions(expressions) |             clone.set_source_expressions(expressions) | ||||||
|             return super(Coalesce, clone).as_sql(compiler, connection) |             return super(Coalesce, clone).as_sql(compiler, connection, **extra_context) | ||||||
|         return self.as_sql(compiler, connection) |         return self.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Greatest(Func): | class Greatest(Func): | ||||||
| @@ -66,9 +66,9 @@ class Greatest(Func): | |||||||
|             raise ValueError('Greatest must take at least two expressions') |             raise ValueError('Greatest must take at least two expressions') | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         """Use the MAX function on SQLite.""" |         """Use the MAX function on SQLite.""" | ||||||
|         return super().as_sqlite(compiler, connection, function='MAX') |         return super().as_sqlite(compiler, connection, function='MAX', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Least(Func): | class Least(Func): | ||||||
| @@ -86,6 +86,6 @@ class Least(Func): | |||||||
|             raise ValueError('Least must take at least two expressions') |             raise ValueError('Least must take at least two expressions') | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         """Use the MIN function on SQLite.""" |         """Use the MIN function on SQLite.""" | ||||||
|         return super().as_sqlite(compiler, connection, function='MIN') |         return super().as_sqlite(compiler, connection, function='MIN', **extra_context) | ||||||
|   | |||||||
| @@ -159,11 +159,11 @@ class Now(Func): | |||||||
|     template = 'CURRENT_TIMESTAMP' |     template = 'CURRENT_TIMESTAMP' | ||||||
|     output_field = fields.DateTimeField() |     output_field = fields.DateTimeField() | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the |         # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the | ||||||
|         # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with |         # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with | ||||||
|         # other databases. |         # other databases. | ||||||
|         return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()') |         return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TruncBase(TimezoneMixin, Transform): | class TruncBase(TimezoneMixin, Transform): | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ from django.db.models.functions import Cast | |||||||
|  |  | ||||||
| class DecimalInputMixin: | class DecimalInputMixin: | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         # Cast FloatField to DecimalField as PostgreSQL doesn't support the |         # Cast FloatField to DecimalField as PostgreSQL doesn't support the | ||||||
|         # following function signatures: |         # following function signatures: | ||||||
|         # - LOG(double, double) |         # - LOG(double, double) | ||||||
| @@ -20,7 +20,7 @@ class DecimalInputMixin: | |||||||
|             Cast(expression, output_field) if isinstance(expression.output_field, FloatField) |             Cast(expression, output_field) if isinstance(expression.output_field, FloatField) | ||||||
|             else expression for expression in self.get_source_expressions() |             else expression for expression in self.get_source_expressions() | ||||||
|         ]) |         ]) | ||||||
|         return clone.as_sql(compiler, connection) |         return clone.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutputFieldMixin: | class OutputFieldMixin: | ||||||
| @@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func): | |||||||
|     function = 'ATAN2' |     function = 'ATAN2' | ||||||
|     arity = 2 |     arity = 2 | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0): |         if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0): | ||||||
|             return self.as_sql(compiler, connection) |             return self.as_sql(compiler, connection) | ||||||
|         # This function is usually ATan2(y, x), returning the inverse tangent |         # This function is usually ATan2(y, x), returning the inverse tangent | ||||||
| @@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func): | |||||||
|             Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) |             Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) | ||||||
|             else expression for expression in self.get_source_expressions()[::-1] |             else expression for expression in self.get_source_expressions()[::-1] | ||||||
|         ]) |         ]) | ||||||
|         return clone.as_sql(compiler, connection) |         return clone.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Ceil(Transform): | class Ceil(Transform): | ||||||
|     function = 'CEILING' |     function = 'CEILING' | ||||||
|     lookup_name = 'ceil' |     lookup_name = 'ceil' | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='CEIL') |         return super().as_sql(compiler, connection, function='CEIL', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Cos(OutputFieldMixin, Transform): | class Cos(OutputFieldMixin, Transform): | ||||||
| @@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform): | |||||||
|     function = 'COT' |     function = 'COT' | ||||||
|     lookup_name = 'cot' |     lookup_name = 'cot' | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))') |         return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Degrees(OutputFieldMixin, Transform): | class Degrees(OutputFieldMixin, Transform): | ||||||
|     function = 'DEGREES' |     function = 'DEGREES' | ||||||
|     lookup_name = 'degrees' |     lookup_name = 'degrees' | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi) |         return super().as_sql( | ||||||
|  |             compiler, connection, | ||||||
|  |             template='((%%(expressions)s) * 180 / %s)' % math.pi, | ||||||
|  |             **extra_context | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Exp(OutputFieldMixin, Transform): | class Exp(OutputFieldMixin, Transform): | ||||||
| @@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func): | |||||||
|     function = 'LOG' |     function = 'LOG' | ||||||
|     arity = 2 |     arity = 2 | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         if not getattr(connection.ops, 'spatialite', False): |         if not getattr(connection.ops, 'spatialite', False): | ||||||
|             return self.as_sql(compiler, connection) |             return self.as_sql(compiler, connection) | ||||||
|         # This function is usually Log(b, x) returning the logarithm of x to |         # This function is usually Log(b, x) returning the logarithm of x to | ||||||
|         # the base b, but on SpatiaLite it's Log(x, b). |         # the base b, but on SpatiaLite it's Log(x, b). | ||||||
|         clone = self.copy() |         clone = self.copy() | ||||||
|         clone.set_source_expressions(self.get_source_expressions()[::-1]) |         clone.set_source_expressions(self.get_source_expressions()[::-1]) | ||||||
|         return clone.as_sql(compiler, connection) |         return clone.as_sql(compiler, connection, **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Mod(DecimalInputMixin, OutputFieldMixin, Func): | class Mod(DecimalInputMixin, OutputFieldMixin, Func): | ||||||
| @@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func): | |||||||
|     function = 'PI' |     function = 'PI' | ||||||
|     arity = 0 |     arity = 0 | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, template=str(math.pi)) |         return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Power(OutputFieldMixin, Func): | class Power(OutputFieldMixin, Func): | ||||||
| @@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform): | |||||||
|     function = 'RADIANS' |     function = 'RADIANS' | ||||||
|     lookup_name = 'radians' |     lookup_name = 'radians' | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi) |         return super().as_sql( | ||||||
|  |             compiler, connection, | ||||||
|  |             template='((%%(expressions)s) * %s / 180)' % math.pi, | ||||||
|  |             **extra_context | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Round(Transform): | class Round(Transform): | ||||||
|   | |||||||
| @@ -22,13 +22,19 @@ class Chr(Transform): | |||||||
|     function = 'CHR' |     function = 'CHR' | ||||||
|     lookup_name = 'chr' |     lookup_name = 'chr' | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql( |         return super().as_sql( | ||||||
|             compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)' |             compiler, connection, function='CHAR', | ||||||
|  |             template='%(function)s(%(expressions)s USING utf16)', | ||||||
|  |             **extra_context | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)') |         return super().as_sql( | ||||||
|  |             compiler, connection, | ||||||
|  |             template='%(function)s(%(expressions)s USING NCHAR_CS)', | ||||||
|  |             **extra_context | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection, **extra_context): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='CHAR', **extra_context) |         return super().as_sql(compiler, connection, function='CHAR', **extra_context) | ||||||
| @@ -41,16 +47,19 @@ class ConcatPair(Func): | |||||||
|     """ |     """ | ||||||
|     function = 'CONCAT' |     function = 'CONCAT' | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         coalesced = self.coalesce() |         coalesced = self.coalesce() | ||||||
|         return super(ConcatPair, coalesced).as_sql( |         return super(ConcatPair, coalesced).as_sql( | ||||||
|             compiler, connection, template='%(expressions)s', arg_joiner=' || ' |             compiler, connection, template='%(expressions)s', arg_joiner=' || ', | ||||||
|  |             **extra_context | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         # Use CONCAT_WS with an empty separator so that NULLs are ignored. |         # Use CONCAT_WS with an empty separator so that NULLs are ignored. | ||||||
|         return super().as_sql( |         return super().as_sql( | ||||||
|             compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)" |             compiler, connection, function='CONCAT_WS', | ||||||
|  |             template="%(function)s('', %(expressions)s)", | ||||||
|  |             **extra_context | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def coalesce(self): |     def coalesce(self): | ||||||
| @@ -117,8 +126,8 @@ class Length(Transform): | |||||||
|     lookup_name = 'length' |     lookup_name = 'length' | ||||||
|     output_field = fields.IntegerField() |     output_field = fields.IntegerField() | ||||||
|  |  | ||||||
|     def as_mysql(self, compiler, connection): |     def as_mysql(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='CHAR_LENGTH') |         return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lower(Transform): | class Lower(Transform): | ||||||
| @@ -199,8 +208,8 @@ class StrIndex(Func): | |||||||
|     arity = 2 |     arity = 2 | ||||||
|     output_field = fields.IntegerField() |     output_field = fields.IntegerField() | ||||||
|  |  | ||||||
|     def as_postgresql(self, compiler, connection): |     def as_postgresql(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='STRPOS') |         return super().as_sql(compiler, connection, function='STRPOS', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Substr(Func): | class Substr(Func): | ||||||
| @@ -220,11 +229,11 @@ class Substr(Func): | |||||||
|             expressions.append(length) |             expressions.append(length) | ||||||
|         super().__init__(*expressions, **extra) |         super().__init__(*expressions, **extra) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='SUBSTR') |         return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         return super().as_sql(compiler, connection, function='SUBSTR') |         return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Trim(Transform): | class Trim(Transform): | ||||||
|   | |||||||
| @@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of | |||||||
| ``NotEqual`` with an ``as_mysql`` method:: | ``NotEqual`` with an ``as_mysql`` method:: | ||||||
|  |  | ||||||
|   class MySQLNotEqual(NotEqual): |   class MySQLNotEqual(NotEqual): | ||||||
|       def as_mysql(self, compiler, connection): |       def as_mysql(self, compiler, connection, **extra_context): | ||||||
|           lhs, lhs_params = self.process_lhs(compiler, connection) |           lhs, lhs_params = self.process_lhs(compiler, connection) | ||||||
|           rhs, rhs_params = self.process_rhs(compiler, connection) |           rhs, rhs_params = self.process_rhs(compiler, connection) | ||||||
|           params = lhs_params + rhs_params |           params = lhs_params + rhs_params | ||||||
|   | |||||||
| @@ -322,11 +322,12 @@ The ``Func`` API is as follows: | |||||||
|                 function = 'CONCAT' |                 function = 'CONCAT' | ||||||
|                 ... |                 ... | ||||||
|  |  | ||||||
|                 def as_mysql(self, compiler, connection): |                 def as_mysql(self, compiler, connection, **extra_context): | ||||||
|                     return super().as_sql( |                     return super().as_sql( | ||||||
|                         compiler, connection, |                         compiler, connection, | ||||||
|                         function='CONCAT_WS', |                         function='CONCAT_WS', | ||||||
|                         template="%(function)s('', %(expressions)s)", |                         template="%(function)s('', %(expressions)s)", | ||||||
|  |                         **extra_context | ||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|         To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must |         To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must | ||||||
|   | |||||||
| @@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase): | |||||||
|         class Greatest(Func): |         class Greatest(Func): | ||||||
|             function = 'GREATEST' |             function = 'GREATEST' | ||||||
|  |  | ||||||
|             def as_sqlite(self, compiler, connection): |             def as_sqlite(self, compiler, connection, **extra_context): | ||||||
|                 return super().as_sql(compiler, connection, function='MAX') |                 return super().as_sql(compiler, connection, function='MAX', **extra_context) | ||||||
|  |  | ||||||
|         qs = Publisher.objects.annotate( |         qs = Publisher.objects.annotate( | ||||||
|             price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) |             price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ class Div3Transform(models.Transform): | |||||||
|         lhs, lhs_params = compiler.compile(self.lhs) |         lhs, lhs_params = compiler.compile(self.lhs) | ||||||
|         return '(%s) %%%% 3' % lhs, lhs_params |         return '(%s) %%%% 3' % lhs, lhs_params | ||||||
|  |  | ||||||
|     def as_oracle(self, compiler, connection): |     def as_oracle(self, compiler, connection, **extra_context): | ||||||
|         lhs, lhs_params = compiler.compile(self.lhs) |         lhs, lhs_params = compiler.compile(self.lhs) | ||||||
|         return 'mod(%s, 3)' % lhs, lhs_params |         return 'mod(%s, 3)' % lhs, lhs_params | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user