mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Added support for parameters in SELECT clauses.
This commit is contained in:
		| @@ -56,12 +56,13 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): | ||||
|  | ||||
|         lookup_info = self.geometry_functions.get(lookup_type, False) | ||||
|         if lookup_info: | ||||
|             return "%s(%s, %s)" % (lookup_info, geo_col, | ||||
|             sql = "%s(%s, %s)" % (lookup_info, geo_col, | ||||
|                                   self.get_geom_placeholder(value, field.srid)) | ||||
|             return sql, [] | ||||
|  | ||||
|         # TODO: Is this really necessary? MySQL can't handle NULL geometries | ||||
|         #  in its spatial indexes anyways. | ||||
|         if lookup_type == 'isnull': | ||||
|             return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) | ||||
|             return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] | ||||
|  | ||||
|         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) | ||||
|   | ||||
| @@ -262,7 +262,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): | ||||
|                 return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value)) | ||||
|         elif lookup_type == 'isnull': | ||||
|             # Handling 'isnull' lookup type | ||||
|             return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) | ||||
|             return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] | ||||
|  | ||||
|         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) | ||||
|  | ||||
|   | ||||
| @@ -560,7 +560,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): | ||||
|  | ||||
|         elif lookup_type == 'isnull': | ||||
|             # Handling 'isnull' lookup type | ||||
|             return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) | ||||
|             return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] | ||||
|  | ||||
|         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) | ||||
|  | ||||
|   | ||||
| @@ -358,7 +358,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): | ||||
|             return op.as_sql(geo_col, self.get_geom_placeholder(field, geom)) | ||||
|         elif lookup_type == 'isnull': | ||||
|             # Handling 'isnull' lookup type | ||||
|             return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) | ||||
|             return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] | ||||
|  | ||||
|         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) | ||||
|  | ||||
|   | ||||
| @@ -16,7 +16,7 @@ class SpatialOperation(object): | ||||
|         self.extra = kwargs | ||||
|  | ||||
|     def as_sql(self, geo_col, geometry='%s'): | ||||
|         return self.sql_template % self.params(geo_col, geometry) | ||||
|         return self.sql_template % self.params(geo_col, geometry), [] | ||||
|  | ||||
|     def params(self, geo_col, geometry): | ||||
|         params = {'function' : self.function, | ||||
|   | ||||
| @@ -22,13 +22,15 @@ class GeoAggregate(Aggregate): | ||||
|             raise ValueError('Geospatial aggregates only allowed on geometry fields.') | ||||
|  | ||||
|     def as_sql(self, qn, connection): | ||||
|         "Return the aggregate, rendered as SQL." | ||||
|         "Return the aggregate, rendered as SQL with parameters." | ||||
|  | ||||
|         if connection.ops.oracle: | ||||
|             self.extra['tolerance'] = self.tolerance | ||||
|  | ||||
|         params = [] | ||||
|  | ||||
|         if hasattr(self.col, 'as_sql'): | ||||
|             field_name = self.col.as_sql(qn, connection) | ||||
|             field_name, params = self.col.as_sql(qn, connection) | ||||
|         elif isinstance(self.col, (list, tuple)): | ||||
|             field_name = '.'.join([qn(c) for c in self.col]) | ||||
|         else: | ||||
| @@ -36,13 +38,13 @@ class GeoAggregate(Aggregate): | ||||
|  | ||||
|         sql_template, sql_function = connection.ops.spatial_aggregate_sql(self) | ||||
|  | ||||
|         params = { | ||||
|         substitutions = { | ||||
|             'function': sql_function, | ||||
|             'field': field_name | ||||
|         } | ||||
|         params.update(self.extra) | ||||
|         substitutions.update(self.extra) | ||||
|  | ||||
|         return sql_template % params | ||||
|         return sql_template % substitutions, params | ||||
|  | ||||
| class Collect(GeoAggregate): | ||||
|     pass | ||||
|   | ||||
| @@ -33,6 +33,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): | ||||
|         qn2 = self.connection.ops.quote_name | ||||
|         result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias)) | ||||
|                   for alias, col in six.iteritems(self.query.extra_select)] | ||||
|         params = [] | ||||
|         aliases = set(self.query.extra_select.keys()) | ||||
|         if with_aliases: | ||||
|             col_aliases = aliases.copy() | ||||
| @@ -63,7 +64,9 @@ class GeoSQLCompiler(compiler.SQLCompiler): | ||||
|                         aliases.add(r) | ||||
|                         col_aliases.add(col[1]) | ||||
|                 else: | ||||
|                     result.append(col.as_sql(qn, self.connection)) | ||||
|                     col_sql, col_params = col.as_sql(qn, self.connection) | ||||
|                     result.append(col_sql) | ||||
|                     params.extend(col_params) | ||||
|  | ||||
|                     if hasattr(col, 'alias'): | ||||
|                         aliases.add(col.alias) | ||||
| @@ -76,15 +79,13 @@ class GeoSQLCompiler(compiler.SQLCompiler): | ||||
|             aliases.update(new_aliases) | ||||
|  | ||||
|         max_name_length = self.connection.ops.max_name_length() | ||||
|         result.extend([ | ||||
|                 '%s%s' % ( | ||||
|                     self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection), | ||||
|                     alias is not None | ||||
|                         and ' AS %s' % qn(truncate_name(alias, max_name_length)) | ||||
|                         or '' | ||||
|                     ) | ||||
|                 for alias, aggregate in self.query.aggregate_select.items() | ||||
|         ]) | ||||
|         for alias, aggregate in self.query.aggregate_select.items(): | ||||
|             agg_sql, agg_params = aggregate.as_sql(qn, self.connection) | ||||
|             if alias is None: | ||||
|                 result.append(agg_sql) | ||||
|             else: | ||||
|                 result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) | ||||
|             params.extend(agg_params) | ||||
|  | ||||
|         # This loop customized for GeoQuery. | ||||
|         for (table, col), field in self.query.related_select_cols: | ||||
| @@ -100,7 +101,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): | ||||
|                 col_aliases.add(col) | ||||
|  | ||||
|         self._select_aliases = aliases | ||||
|         return result | ||||
|         return result, params | ||||
|  | ||||
|     def get_default_columns(self, with_aliases=False, col_aliases=None, | ||||
|             start_alias=None, opts=None, as_pairs=False, from_parent=None): | ||||
|   | ||||
| @@ -44,8 +44,9 @@ class GeoWhereNode(WhereNode): | ||||
|         lvalue, lookup_type, value_annot, params_or_value = child | ||||
|         if isinstance(lvalue, GeoConstraint): | ||||
|             data, params = lvalue.process(lookup_type, params_or_value, connection) | ||||
|             spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn) | ||||
|             return spatial_sql, params | ||||
|             spatial_sql, spatial_params = connection.ops.spatial_lookup_sql( | ||||
|                     data, lookup_type, params_or_value, lvalue.field, qn) | ||||
|             return spatial_sql, spatial_params + params | ||||
|         else: | ||||
|             return super(GeoWhereNode, self).make_atom(child, qn, connection) | ||||
|  | ||||
|   | ||||
| @@ -25,7 +25,7 @@ class QueryWrapper(object): | ||||
|     parameters. Can be used to pass opaque data to a where-clause, for example. | ||||
|     """ | ||||
|     def __init__(self, sql, params): | ||||
|         self.data = sql, params | ||||
|         self.data = sql, list(params) | ||||
|  | ||||
|     def as_sql(self, qn=None, connection=None): | ||||
|         return self.data | ||||
|   | ||||
| @@ -73,22 +73,23 @@ class Aggregate(object): | ||||
|             self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) | ||||
|  | ||||
|     def as_sql(self, qn, connection): | ||||
|         "Return the aggregate, rendered as SQL." | ||||
|         "Return the aggregate, rendered as SQL with parameters." | ||||
|         params = [] | ||||
|  | ||||
|         if hasattr(self.col, 'as_sql'): | ||||
|             field_name = self.col.as_sql(qn, connection) | ||||
|             field_name, params = self.col.as_sql(qn, connection) | ||||
|         elif isinstance(self.col, (list, tuple)): | ||||
|             field_name = '.'.join([qn(c) for c in self.col]) | ||||
|         else: | ||||
|             field_name = self.col | ||||
|  | ||||
|         params = { | ||||
|         substitutions = { | ||||
|             'function': self.sql_function, | ||||
|             'field': field_name | ||||
|         } | ||||
|         params.update(self.extra) | ||||
|         substitutions.update(self.extra) | ||||
|  | ||||
|         return self.sql_template % params | ||||
|         return self.sql_template % substitutions, params | ||||
|  | ||||
|  | ||||
| class Avg(Aggregate): | ||||
|   | ||||
| @@ -74,7 +74,7 @@ class SQLCompiler(object): | ||||
|         # as the pre_sql_setup will modify query state in a way that forbids | ||||
|         # another run of it. | ||||
|         self.refcounts_before = self.query.alias_refcount.copy() | ||||
|         out_cols = self.get_columns(with_col_aliases) | ||||
|         out_cols, s_params = self.get_columns(with_col_aliases) | ||||
|         ordering, ordering_group_by = self.get_ordering() | ||||
|  | ||||
|         distinct_fields = self.get_distinct() | ||||
| @@ -97,6 +97,7 @@ class SQLCompiler(object): | ||||
|             result.append(self.connection.ops.distinct_sql(distinct_fields)) | ||||
|  | ||||
|         result.append(', '.join(out_cols + self.query.ordering_aliases)) | ||||
|         params.extend(s_params) | ||||
|  | ||||
|         result.append('FROM') | ||||
|         result.extend(from_) | ||||
| @@ -164,9 +165,10 @@ class SQLCompiler(object): | ||||
|  | ||||
|     def get_columns(self, with_aliases=False): | ||||
|         """ | ||||
|         Returns the list of columns to use in the select statement. If no | ||||
|         columns have been specified, returns all columns relating to fields in | ||||
|         the model. | ||||
|         Returns the list of columns to use in the select statement, as well as | ||||
|         a list any extra parameters that need to be included. If no columns | ||||
|         have been specified, returns all columns relating to fields in the | ||||
|         model. | ||||
|  | ||||
|         If 'with_aliases' is true, any column names that are duplicated | ||||
|         (without the table names) are given unique aliases. This is needed in | ||||
| @@ -175,6 +177,7 @@ class SQLCompiler(object): | ||||
|         qn = self.quote_name_unless_alias | ||||
|         qn2 = self.connection.ops.quote_name | ||||
|         result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] | ||||
|         params = [] | ||||
|         aliases = set(self.query.extra_select.keys()) | ||||
|         if with_aliases: | ||||
|             col_aliases = aliases.copy() | ||||
| @@ -204,7 +207,9 @@ class SQLCompiler(object): | ||||
|                         aliases.add(r) | ||||
|                         col_aliases.add(col[1]) | ||||
|                 else: | ||||
|                     result.append(col.as_sql(qn, self.connection)) | ||||
|                     col_sql, col_params = col.as_sql(qn, self.connection) | ||||
|                     result.append(col_sql) | ||||
|                     params.extend(col_params) | ||||
|  | ||||
|                     if hasattr(col, 'alias'): | ||||
|                         aliases.add(col.alias) | ||||
| @@ -217,15 +222,13 @@ class SQLCompiler(object): | ||||
|             aliases.update(new_aliases) | ||||
|  | ||||
|         max_name_length = self.connection.ops.max_name_length() | ||||
|         result.extend([ | ||||
|             '%s%s' % ( | ||||
|                 aggregate.as_sql(qn, self.connection), | ||||
|                 alias is not None | ||||
|                     and ' AS %s' % qn(truncate_name(alias, max_name_length)) | ||||
|                     or '' | ||||
|             ) | ||||
|             for alias, aggregate in self.query.aggregate_select.items() | ||||
|         ]) | ||||
|         for alias, aggregate in self.query.aggregate_select.items(): | ||||
|             agg_sql, agg_params = aggregate.as_sql(qn, self.connection) | ||||
|             if alias is None: | ||||
|                 result.append(agg_sql) | ||||
|             else: | ||||
|                 result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) | ||||
|             params.extend(agg_params) | ||||
|  | ||||
|         for (table, col), _ in self.query.related_select_cols: | ||||
|             r = '%s.%s' % (qn(table), qn(col)) | ||||
| @@ -240,7 +243,7 @@ class SQLCompiler(object): | ||||
|                 col_aliases.add(col) | ||||
|  | ||||
|         self._select_aliases = aliases | ||||
|         return result | ||||
|         return result, params | ||||
|  | ||||
|     def get_default_columns(self, with_aliases=False, col_aliases=None, | ||||
|             start_alias=None, opts=None, as_pairs=False, from_parent=None): | ||||
| @@ -545,14 +548,16 @@ class SQLCompiler(object): | ||||
|             seen = set() | ||||
|             cols = self.query.group_by + select_cols | ||||
|             for col in cols: | ||||
|                 col_params = () | ||||
|                 if isinstance(col, (list, tuple)): | ||||
|                     sql = '%s.%s' % (qn(col[0]), qn(col[1])) | ||||
|                 elif hasattr(col, 'as_sql'): | ||||
|                     sql = col.as_sql(qn, self.connection) | ||||
|                     sql, col_params = col.as_sql(qn, self.connection) | ||||
|                 else: | ||||
|                     sql = '(%s)' % str(col) | ||||
|                 if sql not in seen: | ||||
|                     result.append(sql) | ||||
|                     params.extend(col_params) | ||||
|                     seen.add(sql) | ||||
|  | ||||
|             # Still, we need to add all stuff in ordering (except if the backend can | ||||
| @@ -991,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler): | ||||
|         if qn is None: | ||||
|             qn = self.quote_name_unless_alias | ||||
|  | ||||
|         sql = ('SELECT %s FROM (%s) subquery' % ( | ||||
|             ', '.join([ | ||||
|                 aggregate.as_sql(qn, self.connection) | ||||
|                 for aggregate in self.query.aggregate_select.values() | ||||
|             ]), | ||||
|             self.query.subquery) | ||||
|         ) | ||||
|         params = self.query.sub_params | ||||
|         return (sql, params) | ||||
|         sql, params = [], [] | ||||
|         for aggregate in self.query.aggregate_select.values(): | ||||
|             agg_sql, agg_params = aggregate.as_sql(qn, self.connection) | ||||
|             sql.append(agg_sql) | ||||
|             params.extend(agg_params) | ||||
|         sql = ', '.join(sql) | ||||
|         params = tuple(params) | ||||
|  | ||||
|         sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery) | ||||
|         params = params + self.query.sub_params | ||||
|         return sql, params | ||||
|  | ||||
| class SQLDateCompiler(SQLCompiler): | ||||
|     def results_iter(self): | ||||
|   | ||||
| @@ -42,7 +42,7 @@ class Date(object): | ||||
|             col = '%s.%s' % tuple([qn(c) for c in self.col]) | ||||
|         else: | ||||
|             col = self.col | ||||
|         return getattr(connection.ops, self.trunc_func)(self.lookup_type, col) | ||||
|         return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), [] | ||||
|  | ||||
| class DateTime(Date): | ||||
|     """ | ||||
|   | ||||
| @@ -94,9 +94,9 @@ class SQLEvaluator(object): | ||||
|         if col is None: | ||||
|             raise ValueError("Given node not found") | ||||
|         if hasattr(col, 'as_sql'): | ||||
|             return col.as_sql(qn, connection), () | ||||
|             return col.as_sql(qn, connection) | ||||
|         else: | ||||
|             return '%s.%s' % (qn(col[0]), qn(col[1])), () | ||||
|             return '%s.%s' % (qn(col[0]), qn(col[1])), [] | ||||
|  | ||||
|     def evaluate_date_modifier_node(self, node, qn, connection): | ||||
|         timedelta = node.children.pop() | ||||
|   | ||||
| @@ -172,10 +172,10 @@ class WhereNode(tree.Node): | ||||
|  | ||||
|         if isinstance(lvalue, tuple): | ||||
|             # A direct database column lookup. | ||||
|             field_sql = self.sql_for_columns(lvalue, qn, connection) | ||||
|             field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), [] | ||||
|         else: | ||||
|             # A smart object with an as_sql() method. | ||||
|             field_sql = lvalue.as_sql(qn, connection) | ||||
|             field_sql, field_params = lvalue.as_sql(qn, connection) | ||||
|  | ||||
|         is_datetime_field = value_annotation is datetime.datetime | ||||
|         cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' | ||||
| @@ -186,6 +186,8 @@ class WhereNode(tree.Node): | ||||
|         else: | ||||
|             extra = '' | ||||
|  | ||||
|         params = field_params + params | ||||
|  | ||||
|         if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' | ||||
|             and connection.features.interprets_empty_strings_as_nulls): | ||||
|             lookup_type = 'isnull' | ||||
| @@ -245,7 +247,7 @@ class WhereNode(tree.Node): | ||||
|         """ | ||||
|         Returns the SQL fragment used for the left-hand side of a column | ||||
|         constraint (for example, the "T1.foo" portion in the clause | ||||
|         "WHERE ... T1.foo = 6"). | ||||
|         "WHERE ... T1.foo = 6") and a list of parameters. | ||||
|         """ | ||||
|         table_alias, name, db_type = data | ||||
|         if table_alias: | ||||
| @@ -338,7 +340,7 @@ class ExtraWhere(object): | ||||
|  | ||||
|     def as_sql(self, qn=None, connection=None): | ||||
|         sqls = ["(%s)" % sql for sql in self.sqls] | ||||
|         return " AND ".join(sqls), tuple(self.params or ()) | ||||
|         return " AND ".join(sqls), list(self.params or ()) | ||||
|  | ||||
|     def clone(self): | ||||
|         return self | ||||
|   | ||||
		Reference in New Issue
	
	Block a user