mirror of
				https://github.com/django/django.git
				synced 2025-10-26 15:16:09 +00:00 
			
		
		
		
	Fixed #29048 -- Added **extra_context to database function as_vendor() methods.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							08f360355a
						
					
				
				
					commit
					83b04d4f88
				
			| @@ -26,10 +26,10 @@ class GeoAggregate(Aggregate): | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) | ||||
|         template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' | ||||
|         return self.as_sql(compiler, connection, template=template, tolerance=tolerance) | ||||
|         return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context) | ||||
|  | ||||
|     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): | ||||
|         c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) | ||||
|   | ||||
| @@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin: | ||||
|     By default, Decimal values are converted to str by the SQLite backend, which | ||||
|     is not acceptable by the GIS functions expecting numeric values. | ||||
|     """ | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         for expr in self.get_source_expressions(): | ||||
|             if hasattr(expr, 'value') and isinstance(expr.value, Decimal): | ||||
|                 expr.value = float(expr.value) | ||||
|         return super().as_sql(compiler, connection) | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class OracleToleranceMixin: | ||||
|     tolerance = 0.05 | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         tol = self.extra.get('tolerance', self.tolerance) | ||||
|         return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) | ||||
|         return self.as_sql( | ||||
|             compiler, connection, | ||||
|             template="%%(function)s(%%(expressions)s, %s)" % tol, | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Area(OracleToleranceMixin, GeoFunc): | ||||
| @@ -181,11 +185,11 @@ class AsGML(GeoFunc): | ||||
|  | ||||
|  | ||||
| class AsKML(AsGML): | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         # No version parameter | ||||
|         clone = self.copy() | ||||
|         clone.set_source_expressions(self.get_source_expressions()[1:]) | ||||
|         return clone.as_sql(compiler, connection) | ||||
|         return clone.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class AsSVG(GeoFunc): | ||||
| @@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc): | ||||
|     def __init__(self, expression, num_seg=48, **extra): | ||||
|         super().__init__(expression, num_seg, **extra) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         clone = self.copy() | ||||
|         clone.set_source_expressions([self.get_source_expressions()[0]]) | ||||
|         return super(BoundingCircle, clone).as_oracle(compiler, connection) | ||||
|         return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
| @@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|             self.spheroid = self._handle_param(spheroid, 'spheroid', bool) | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         clone = self.copy() | ||||
|         function = None | ||||
|         expr2 = clone.source_expressions[1] | ||||
| @@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|                 clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) | ||||
|             else: | ||||
|                 function = connection.ops.spatial_function_name('DistanceSphere') | ||||
|         return super(Distance, clone).as_sql(compiler, connection, function=function) | ||||
|         return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         if self.geo_field.geodetic(connection): | ||||
| @@ -300,12 +304,12 @@ class GeoHash(GeoFunc): | ||||
|             expressions.append(self._handle_param(precision, 'precision', int)) | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         clone = self.copy() | ||||
|         # If no precision is provided, set it to the maximum. | ||||
|         if len(clone.source_expressions) < 2: | ||||
|             clone.source_expressions.append(Value(100)) | ||||
|         return clone.as_sql(compiler, connection) | ||||
|         return clone.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
| @@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|             raise NotSupportedError("This backend doesn't support Length on geodetic fields") | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         clone = self.copy() | ||||
|         function = None | ||||
|         if self.source_is_geography(): | ||||
| @@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|             dim = min(f.dim for f in self.get_source_fields() if f) | ||||
|             if dim > 2: | ||||
|                 function = connection.ops.length3d | ||||
|         return super(Length, clone).as_sql(compiler, connection, function=function) | ||||
|         return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         function = None | ||||
|         if self.geo_field.geodetic(connection): | ||||
|             function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' | ||||
|         return super().as_sql(compiler, connection, function=function) | ||||
|         return super().as_sql(compiler, connection, function=function, **extra_context) | ||||
|  | ||||
|  | ||||
| class LineLocatePoint(GeoFunc): | ||||
| @@ -383,19 +387,19 @@ class NumPoints(GeoFunc): | ||||
| class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): | ||||
|     arity = 1 | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         function = None | ||||
|         if self.geo_field.geodetic(connection) and not self.source_is_geography(): | ||||
|             raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") | ||||
|         dim = min(f.dim for f in self.get_source_fields()) | ||||
|         if dim > 2: | ||||
|             function = connection.ops.perimeter3d | ||||
|         return super().as_sql(compiler, connection, function=function) | ||||
|         return super().as_sql(compiler, connection, function=function, **extra_context) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         if self.geo_field.geodetic(connection): | ||||
|             raise NotSupportedError("Perimeter cannot use a non-projected field.") | ||||
|         return super().as_sql(compiler, connection) | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
| @@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc): | ||||
|  | ||||
|  | ||||
| class Translate(Scale): | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         clone = self.copy() | ||||
|         if len(self.source_expressions) < 4: | ||||
|             # Always provide the z parameter for ST_Translate | ||||
|             clone.source_expressions.append(Value(0)) | ||||
|         return super(Translate, clone).as_sqlite(compiler, connection) | ||||
|         return super(Translate, clone).as_sqlite(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Union(OracleToleranceMixin, GeomOutputGeoFunc): | ||||
|   | ||||
| @@ -64,7 +64,10 @@ class Aggregate(Func): | ||||
|             if connection.features.supports_aggregate_filter_clause: | ||||
|                 filter_sql, filter_params = self.filter.as_sql(compiler, connection) | ||||
|                 template = self.filter_template % extra_context.get('template', self.template) | ||||
|                 sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql) | ||||
|                 sql, params = super().as_sql( | ||||
|                     compiler, connection, template=template, filter=filter_sql, | ||||
|                     **extra_context | ||||
|                 ) | ||||
|                 return sql, params + filter_params | ||||
|             else: | ||||
|                 copy = self.copy() | ||||
| @@ -92,20 +95,20 @@ class Avg(Aggregate): | ||||
|             return FloatField() | ||||
|         return super()._resolve_output_field() | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|         sql, params = super().as_sql(compiler, connection) | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         sql, params = super().as_sql(compiler, connection, **extra_context) | ||||
|         if self.output_field.get_internal_type() == 'DurationField': | ||||
|             sql = 'CAST(%s as SIGNED)' % sql | ||||
|         return sql, params | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         if self.output_field.get_internal_type() == 'DurationField': | ||||
|             expression = self.get_source_expressions()[0] | ||||
|             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval | ||||
|             return compiler.compile( | ||||
|                 SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) | ||||
|             ) | ||||
|         return super().as_sql(compiler, connection) | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Count(Aggregate): | ||||
| @@ -157,20 +160,20 @@ class Sum(Aggregate): | ||||
|     function = 'SUM' | ||||
|     name = 'Sum' | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|         sql, params = super().as_sql(compiler, connection) | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         sql, params = super().as_sql(compiler, connection, **extra_context) | ||||
|         if self.output_field.get_internal_type() == 'DurationField': | ||||
|             sql = 'CAST(%s as SIGNED)' % sql | ||||
|         return sql, params | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         if self.output_field.get_internal_type() == 'DurationField': | ||||
|             expression = self.get_source_expressions()[0] | ||||
|             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval | ||||
|             return compiler.compile( | ||||
|                 SecondsToInterval(Sum(IntervalToSeconds(expression))) | ||||
|             ) | ||||
|         return super().as_sql(compiler, connection) | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Variance(Aggregate): | ||||
|   | ||||
| @@ -14,16 +14,16 @@ class Cast(Func): | ||||
|         extra_context['db_type'] = self.output_field.cast_db_type(connection) | ||||
|         return super().as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         # MySQL doesn't support explicit cast to float. | ||||
|         template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None | ||||
|         return self.as_sql(compiler, connection, template=template) | ||||
|         return self.as_sql(compiler, connection, template=template, **extra_context) | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         # CAST would be valid too, but the :: shortcut syntax is more readable. | ||||
|         # 'expressions' is wrapped in parentheses in case it's a complex | ||||
|         # expression. | ||||
|         return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s') | ||||
|         return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) | ||||
|  | ||||
|  | ||||
| class Coalesce(Func): | ||||
| @@ -35,7 +35,7 @@ class Coalesce(Func): | ||||
|             raise ValueError('Coalesce must take at least two expressions') | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), | ||||
|         # so convert all fields to NCLOB when that type is expected. | ||||
|         if self.output_field.get_internal_type() == 'TextField': | ||||
| @@ -47,8 +47,8 @@ class Coalesce(Func): | ||||
|             ] | ||||
|             clone = self.copy() | ||||
|             clone.set_source_expressions(expressions) | ||||
|             return super(Coalesce, clone).as_sql(compiler, connection) | ||||
|         return self.as_sql(compiler, connection) | ||||
|             return super(Coalesce, clone).as_sql(compiler, connection, **extra_context) | ||||
|         return self.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Greatest(Func): | ||||
| @@ -66,9 +66,9 @@ class Greatest(Func): | ||||
|             raise ValueError('Greatest must take at least two expressions') | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         """Use the MAX function on SQLite.""" | ||||
|         return super().as_sqlite(compiler, connection, function='MAX') | ||||
|         return super().as_sqlite(compiler, connection, function='MAX', **extra_context) | ||||
|  | ||||
|  | ||||
| class Least(Func): | ||||
| @@ -86,6 +86,6 @@ class Least(Func): | ||||
|             raise ValueError('Least must take at least two expressions') | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         """Use the MIN function on SQLite.""" | ||||
|         return super().as_sqlite(compiler, connection, function='MIN') | ||||
|         return super().as_sqlite(compiler, connection, function='MIN', **extra_context) | ||||
|   | ||||
| @@ -159,11 +159,11 @@ class Now(Func): | ||||
|     template = 'CURRENT_TIMESTAMP' | ||||
|     output_field = fields.DateTimeField() | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the | ||||
|         # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with | ||||
|         # other databases. | ||||
|         return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()') | ||||
|         return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context) | ||||
|  | ||||
|  | ||||
| class TruncBase(TimezoneMixin, Transform): | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from django.db.models.functions import Cast | ||||
|  | ||||
| class DecimalInputMixin: | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         # Cast FloatField to DecimalField as PostgreSQL doesn't support the | ||||
|         # following function signatures: | ||||
|         # - LOG(double, double) | ||||
| @@ -20,7 +20,7 @@ class DecimalInputMixin: | ||||
|             Cast(expression, output_field) if isinstance(expression.output_field, FloatField) | ||||
|             else expression for expression in self.get_source_expressions() | ||||
|         ]) | ||||
|         return clone.as_sql(compiler, connection) | ||||
|         return clone.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class OutputFieldMixin: | ||||
| @@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func): | ||||
|     function = 'ATAN2' | ||||
|     arity = 2 | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0): | ||||
|             return self.as_sql(compiler, connection) | ||||
|         # This function is usually ATan2(y, x), returning the inverse tangent | ||||
| @@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func): | ||||
|             Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) | ||||
|             else expression for expression in self.get_source_expressions()[::-1] | ||||
|         ]) | ||||
|         return clone.as_sql(compiler, connection) | ||||
|         return clone.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Ceil(Transform): | ||||
|     function = 'CEILING' | ||||
|     lookup_name = 'ceil' | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, function='CEIL') | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='CEIL', **extra_context) | ||||
|  | ||||
|  | ||||
| class Cos(OutputFieldMixin, Transform): | ||||
| @@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform): | ||||
|     function = 'COT' | ||||
|     lookup_name = 'cot' | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))') | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) | ||||
|  | ||||
|  | ||||
| class Degrees(OutputFieldMixin, Transform): | ||||
|     function = 'DEGREES' | ||||
|     lookup_name = 'degrees' | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi) | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql( | ||||
|             compiler, connection, | ||||
|             template='((%%(expressions)s) * 180 / %s)' % math.pi, | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Exp(OutputFieldMixin, Transform): | ||||
| @@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func): | ||||
|     function = 'LOG' | ||||
|     arity = 2 | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         if not getattr(connection.ops, 'spatialite', False): | ||||
|             return self.as_sql(compiler, connection) | ||||
|         # This function is usually Log(b, x) returning the logarithm of x to | ||||
|         # the base b, but on SpatiaLite it's Log(x, b). | ||||
|         clone = self.copy() | ||||
|         clone.set_source_expressions(self.get_source_expressions()[::-1]) | ||||
|         return clone.as_sql(compiler, connection) | ||||
|         return clone.as_sql(compiler, connection, **extra_context) | ||||
|  | ||||
|  | ||||
| class Mod(DecimalInputMixin, OutputFieldMixin, Func): | ||||
| @@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func): | ||||
|     function = 'PI' | ||||
|     arity = 0 | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, template=str(math.pi)) | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) | ||||
|  | ||||
|  | ||||
| class Power(OutputFieldMixin, Func): | ||||
| @@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform): | ||||
|     function = 'RADIANS' | ||||
|     lookup_name = 'radians' | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi) | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql( | ||||
|             compiler, connection, | ||||
|             template='((%%(expressions)s) * %s / 180)' % math.pi, | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Round(Transform): | ||||
|   | ||||
| @@ -22,13 +22,19 @@ class Chr(Transform): | ||||
|     function = 'CHR' | ||||
|     lookup_name = 'chr' | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql( | ||||
|             compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)' | ||||
|             compiler, connection, function='CHAR', | ||||
|             template='%(function)s(%(expressions)s USING utf16)', | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)') | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql( | ||||
|             compiler, connection, | ||||
|             template='%(function)s(%(expressions)s USING NCHAR_CS)', | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='CHAR', **extra_context) | ||||
| @@ -41,16 +47,19 @@ class ConcatPair(Func): | ||||
|     """ | ||||
|     function = 'CONCAT' | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         coalesced = self.coalesce() | ||||
|         return super(ConcatPair, coalesced).as_sql( | ||||
|             compiler, connection, template='%(expressions)s', arg_joiner=' || ' | ||||
|             compiler, connection, template='%(expressions)s', arg_joiner=' || ', | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         # Use CONCAT_WS with an empty separator so that NULLs are ignored. | ||||
|         return super().as_sql( | ||||
|             compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)" | ||||
|             compiler, connection, function='CONCAT_WS', | ||||
|             template="%(function)s('', %(expressions)s)", | ||||
|             **extra_context | ||||
|         ) | ||||
|  | ||||
|     def coalesce(self): | ||||
| @@ -117,8 +126,8 @@ class Length(Transform): | ||||
|     lookup_name = 'length' | ||||
|     output_field = fields.IntegerField() | ||||
|  | ||||
|     def as_mysql(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, function='CHAR_LENGTH') | ||||
|     def as_mysql(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context) | ||||
|  | ||||
|  | ||||
| class Lower(Transform): | ||||
| @@ -199,8 +208,8 @@ class StrIndex(Func): | ||||
|     arity = 2 | ||||
|     output_field = fields.IntegerField() | ||||
|  | ||||
|     def as_postgresql(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, function='STRPOS') | ||||
|     def as_postgresql(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='STRPOS', **extra_context) | ||||
|  | ||||
|  | ||||
| class Substr(Func): | ||||
| @@ -220,11 +229,11 @@ class Substr(Func): | ||||
|             expressions.append(length) | ||||
|         super().__init__(*expressions, **extra) | ||||
|  | ||||
|     def as_sqlite(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, function='SUBSTR') | ||||
|     def as_sqlite(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|         return super().as_sql(compiler, connection, function='SUBSTR') | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) | ||||
|  | ||||
|  | ||||
| class Trim(Transform): | ||||
|   | ||||
| @@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of | ||||
| ``NotEqual`` with an ``as_mysql`` method:: | ||||
|  | ||||
|   class MySQLNotEqual(NotEqual): | ||||
|       def as_mysql(self, compiler, connection): | ||||
|       def as_mysql(self, compiler, connection, **extra_context): | ||||
|           lhs, lhs_params = self.process_lhs(compiler, connection) | ||||
|           rhs, rhs_params = self.process_rhs(compiler, connection) | ||||
|           params = lhs_params + rhs_params | ||||
|   | ||||
| @@ -322,11 +322,12 @@ The ``Func`` API is as follows: | ||||
|                 function = 'CONCAT' | ||||
|                 ... | ||||
|  | ||||
|                 def as_mysql(self, compiler, connection): | ||||
|                 def as_mysql(self, compiler, connection, **extra_context): | ||||
|                     return super().as_sql( | ||||
|                         compiler, connection, | ||||
|                         function='CONCAT_WS', | ||||
|                         template="%(function)s('', %(expressions)s)", | ||||
|                         **extra_context | ||||
|                     ) | ||||
|  | ||||
|         To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must | ||||
|   | ||||
| @@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase): | ||||
|         class Greatest(Func): | ||||
|             function = 'GREATEST' | ||||
|  | ||||
|             def as_sqlite(self, compiler, connection): | ||||
|                 return super().as_sql(compiler, connection, function='MAX') | ||||
|             def as_sqlite(self, compiler, connection, **extra_context): | ||||
|                 return super().as_sql(compiler, connection, function='MAX', **extra_context) | ||||
|  | ||||
|         qs = Publisher.objects.annotate( | ||||
|             price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) | ||||
|   | ||||
| @@ -34,7 +34,7 @@ class Div3Transform(models.Transform): | ||||
|         lhs, lhs_params = compiler.compile(self.lhs) | ||||
|         return '(%s) %%%% 3' % lhs, lhs_params | ||||
|  | ||||
|     def as_oracle(self, compiler, connection): | ||||
|     def as_oracle(self, compiler, connection, **extra_context): | ||||
|         lhs, lhs_params = compiler.compile(self.lhs) | ||||
|         return 'mod(%s, 3)' % lhs, lhs_params | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user