mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Refs #25367 -- Moved select_format hook to BaseExpression.
This will expose an intermediary hook for expressions that need special formatting when used in a SELECT clause.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							4f7328ce8a
						
					
				
				
					commit
					fff5186d32
				
			| @@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField): | |||||||
|         of the spatial backend. For example, Oracle and MySQL require custom |         of the spatial backend. For example, Oracle and MySQL require custom | ||||||
|         selection formats in order to retrieve geometries in OGC WKB. |         selection formats in order to retrieve geometries in OGC WKB. | ||||||
|         """ |         """ | ||||||
|  |         if not compiler.query.subquery: | ||||||
|             return compiler.connection.ops.select % sql, params |             return compiler.connection.ops.select % sql, params | ||||||
|  |         return sql, params | ||||||
|  |  | ||||||
|  |  | ||||||
| # The OpenGIS Geometry Type Fields | # The OpenGIS Geometry Type Fields | ||||||
|   | |||||||
| @@ -366,6 +366,13 @@ class BaseExpression: | |||||||
|             if expr: |             if expr: | ||||||
|                 yield from expr.flatten() |                 yield from expr.flatten() | ||||||
|  |  | ||||||
|  |     def select_format(self, compiler, sql, params): | ||||||
|  |         """ | ||||||
|  |         Custom format for select clauses. For example, EXISTS expressions need | ||||||
|  |         to be wrapped in CASE WHEN on Oracle. | ||||||
|  |         """ | ||||||
|  |         return self.output_field.select_format(compiler, sql, params) | ||||||
|  |  | ||||||
|     @cached_property |     @cached_property | ||||||
|     def identity(self): |     def identity(self): | ||||||
|         constructor_signature = inspect.signature(self.__init__) |         constructor_signature = inspect.signature(self.__init__) | ||||||
|   | |||||||
| @@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError | |||||||
| from django.utils.deprecation import RemovedInDjango31Warning | from django.utils.deprecation import RemovedInDjango31Warning | ||||||
| from django.utils.hashable import make_hashable | from django.utils.hashable import make_hashable | ||||||
|  |  | ||||||
| FORCE = object() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SQLCompiler: | class SQLCompiler: | ||||||
|     def __init__(self, query, connection, using): |     def __init__(self, query, connection, using): | ||||||
| @@ -244,10 +242,12 @@ class SQLCompiler: | |||||||
|         ret = [] |         ret = [] | ||||||
|         for col, alias in select: |         for col, alias in select: | ||||||
|             try: |             try: | ||||||
|                 sql, params = self.compile(col, select_format=True) |                 sql, params = self.compile(col) | ||||||
|             except EmptyResultSet: |             except EmptyResultSet: | ||||||
|                 # Select a predicate that's always False. |                 # Select a predicate that's always False. | ||||||
|                 sql, params = '0', () |                 sql, params = '0', () | ||||||
|  |             else: | ||||||
|  |                 sql, params = col.select_format(self, sql, params) | ||||||
|             ret.append((col, (sql, params), alias)) |             ret.append((col, (sql, params), alias)) | ||||||
|         return ret, klass_info, annotations |         return ret, klass_info, annotations | ||||||
|  |  | ||||||
| @@ -402,14 +402,12 @@ class SQLCompiler: | |||||||
|         self.quote_cache[name] = r |         self.quote_cache[name] = r | ||||||
|         return r |         return r | ||||||
|  |  | ||||||
|     def compile(self, node, select_format=False): |     def compile(self, node): | ||||||
|         vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) |         vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) | ||||||
|         if vendor_impl: |         if vendor_impl: | ||||||
|             sql, params = vendor_impl(self, self.connection) |             sql, params = vendor_impl(self, self.connection) | ||||||
|         else: |         else: | ||||||
|             sql, params = node.as_sql(self, self.connection) |             sql, params = node.as_sql(self, self.connection) | ||||||
|         if select_format is FORCE or (select_format and not self.query.subquery): |  | ||||||
|             return node.output_field.select_format(self, sql, params) |  | ||||||
|         return sql, params |         return sql, params | ||||||
|  |  | ||||||
|     def get_combinator_sql(self, combinator, all): |     def get_combinator_sql(self, combinator, all): | ||||||
| @@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler): | |||||||
|         """ |         """ | ||||||
|         sql, params = [], [] |         sql, params = [], [] | ||||||
|         for annotation in self.query.annotation_select.values(): |         for annotation in self.query.annotation_select.values(): | ||||||
|             ann_sql, ann_params = self.compile(annotation, select_format=FORCE) |             ann_sql, ann_params = self.compile(annotation) | ||||||
|  |             ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params) | ||||||
|             sql.append(ann_sql) |             sql.append(ann_sql) | ||||||
|             params.extend(ann_params) |             params.extend(ann_params) | ||||||
|         self.col_count = len(self.query.annotation_select) |         self.col_count = len(self.query.annotation_select) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user