mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #24509 -- Added Expression support to SQLInsertCompiler
This commit is contained in:
		| @@ -576,7 +576,7 @@ class BaseDatabaseOperations(object): | ||||
|     def combine_duration_expression(self, connector, sub_expressions): | ||||
|         return self.combine_expression(connector, sub_expressions) | ||||
|  | ||||
|     def modify_insert_params(self, placeholders, params): | ||||
|     def modify_insert_params(self, placeholder, params): | ||||
|         """Allow modification of insert parameters. Needed for Oracle Spatial | ||||
|         backend due to #10888. | ||||
|         """ | ||||
|   | ||||
| @@ -166,9 +166,10 @@ class DatabaseOperations(BaseDatabaseOperations): | ||||
|     def max_name_length(self): | ||||
|         return 64 | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, num_values): | ||||
|         items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) | ||||
|         return "VALUES " + ", ".join([items_sql] * num_values) | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) | ||||
|         values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) | ||||
|         return "VALUES " + values_sql | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         """ | ||||
|   | ||||
| @@ -439,6 +439,8 @@ WHEN (new.%(col_name)s IS NULL) | ||||
|         name_length = self.max_name_length() - 3 | ||||
|         return '%s_TR' % truncate_name(table, name_length).upper() | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, num_values): | ||||
|         items_sql = "SELECT %s FROM DUAL" % ", ".join(["%s"] * len(fields)) | ||||
|         return " UNION ALL ".join([items_sql] * num_values) | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         return " UNION ALL ".join( | ||||
|             "SELECT %s FROM DUAL" % ", ".join(row) | ||||
|             for row in placeholder_rows | ||||
|         ) | ||||
|   | ||||
| @@ -221,9 +221,10 @@ class DatabaseOperations(BaseDatabaseOperations): | ||||
|     def return_insert_id(self): | ||||
|         return "RETURNING %s", () | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, num_values): | ||||
|         items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) | ||||
|         return "VALUES " + ", ".join([items_sql] * num_values) | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) | ||||
|         values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) | ||||
|         return "VALUES " + values_sql | ||||
|  | ||||
|     def adapt_datefield_value(self, value): | ||||
|         return value | ||||
|   | ||||
| @@ -226,13 +226,11 @@ class DatabaseOperations(BaseDatabaseOperations): | ||||
|             value = uuid.UUID(value) | ||||
|         return value | ||||
|  | ||||
|     def bulk_insert_sql(self, fields, num_values): | ||||
|         res = [] | ||||
|         res.append("SELECT %s" % ", ".join( | ||||
|             "%%s AS %s" % self.quote_name(f.column) for f in fields | ||||
|         )) | ||||
|         res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1)) | ||||
|         return " ".join(res) | ||||
|     def bulk_insert_sql(self, fields, placeholder_rows): | ||||
|         return " UNION ALL ".join( | ||||
|             "SELECT %s" % ", ".join(row) | ||||
|             for row in placeholder_rows | ||||
|         ) | ||||
|  | ||||
|     def combine_expression(self, connector, sub_expressions): | ||||
|         # SQLite doesn't have a power function, so we fake it with a | ||||
|   | ||||
| @@ -180,6 +180,13 @@ class BaseExpression(object): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     @cached_property | ||||
|     def contains_column_references(self): | ||||
|         for expr in self.get_source_expressions(): | ||||
|             if expr and expr.contains_column_references: | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): | ||||
|         """ | ||||
|         Provides the chance to do any preprocessing or validation before being | ||||
| @@ -339,6 +346,17 @@ class BaseExpression(object): | ||||
|     def reverse_ordering(self): | ||||
|         return self | ||||
|  | ||||
|     def flatten(self): | ||||
|         """ | ||||
|         Recursively yield this expression and all subexpressions, in | ||||
|         depth-first order. | ||||
|         """ | ||||
|         yield self | ||||
|         for expr in self.get_source_expressions(): | ||||
|             if expr: | ||||
|                 for inner_expr in expr.flatten(): | ||||
|                     yield inner_expr | ||||
|  | ||||
|  | ||||
| class Expression(BaseExpression, Combinable): | ||||
|     """ | ||||
| @@ -613,6 +631,9 @@ class Random(Expression): | ||||
|  | ||||
|  | ||||
| class Col(Expression): | ||||
|  | ||||
|     contains_column_references = True | ||||
|  | ||||
|     def __init__(self, alias, target, output_field=None): | ||||
|         if output_field is None: | ||||
|             output_field = target | ||||
|   | ||||
| @@ -458,6 +458,8 @@ class QuerySet(object): | ||||
|         specifying whether an object was created. | ||||
|         """ | ||||
|         lookup, params = self._extract_model_params(defaults, **kwargs) | ||||
|         # The get() needs to be targeted at the write database in order | ||||
|         # to avoid potential transaction consistency problems. | ||||
|         self._for_write = True | ||||
|         try: | ||||
|             return self.get(**lookup), False | ||||
|   | ||||
| @@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler): | ||||
|         self.return_id = False | ||||
|         super(SQLInsertCompiler, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def placeholder(self, field, val): | ||||
|     def field_as_sql(self, field, val): | ||||
|         """ | ||||
|         Take a field and a value intended to be saved on that field, and | ||||
|         return placeholder SQL and accompanying params. Checks for raw values, | ||||
|         expressions and fields with get_placeholder() defined in that order. | ||||
|  | ||||
|         When field is None, the value is considered raw and is used as the | ||||
|         placeholder, with no corresponding parameters returned. | ||||
|         """ | ||||
|         if field is None: | ||||
|             # A field value of None means the value is raw. | ||||
|             return val | ||||
|             sql, params = val, [] | ||||
|         elif hasattr(val, 'as_sql'): | ||||
|             # This is an expression, let's compile it. | ||||
|             sql, params = self.compile(val) | ||||
|         elif hasattr(field, 'get_placeholder'): | ||||
|             # Some fields (e.g. geo fields) need special munging before | ||||
|             # they can be inserted. | ||||
|             return field.get_placeholder(val, self, self.connection) | ||||
|             sql, params = field.get_placeholder(val, self, self.connection), [val] | ||||
|         else: | ||||
|             # Return the common case for the placeholder | ||||
|             return '%s' | ||||
|             sql, params = '%s', [val] | ||||
|  | ||||
|         # The following hook is only used by Oracle Spatial, which sometimes | ||||
|         # needs to yield 'NULL' and [] as its placeholder and params instead | ||||
|         # of '%s' and [None]. The 'NULL' placeholder is produced earlier by | ||||
|         # OracleOperations.get_geom_placeholder(). The following line removes | ||||
|         # the corresponding None parameter. See ticket #10888. | ||||
|         params = self.connection.ops.modify_insert_params(sql, params) | ||||
|  | ||||
|         return sql, params | ||||
|  | ||||
|     def prepare_value(self, field, value): | ||||
|         """ | ||||
|         Prepare a value to be used in a query by resolving it if it is an | ||||
|         expression and otherwise calling the field's get_db_prep_save(). | ||||
|         """ | ||||
|         if hasattr(value, 'resolve_expression'): | ||||
|             value = value.resolve_expression(self.query, allow_joins=False, for_save=True) | ||||
|             # Don't allow values containing Col expressions. They refer to | ||||
|             # existing columns on a row, but in the case of insert the row | ||||
|             # doesn't exist yet. | ||||
|             if value.contains_column_references: | ||||
|                 raise ValueError( | ||||
|                     'Failed to insert expression "%s" on %s. F() expressions ' | ||||
|                     'can only be used to update, not to insert.' % (value, field) | ||||
|                 ) | ||||
|             if value.contains_aggregate: | ||||
|                 raise FieldError("Aggregate functions are not allowed in this query") | ||||
|         else: | ||||
|             value = field.get_db_prep_save(value, connection=self.connection) | ||||
|         return value | ||||
|  | ||||
|     def pre_save_val(self, field, obj): | ||||
|         """ | ||||
|         Get the given field's value off the given obj. pre_save() is used for | ||||
|         things like auto_now on DateTimeField. Skip it if this is a raw query. | ||||
|         """ | ||||
|         if self.query.raw: | ||||
|             return getattr(obj, field.attname) | ||||
|         return field.pre_save(obj, add=True) | ||||
|  | ||||
|     def assemble_as_sql(self, fields, value_rows): | ||||
|         """ | ||||
|         Take a sequence of N fields and a sequence of M rows of values, | ||||
|         generate placeholder SQL and parameters for each field and value, and | ||||
|         return a pair containing: | ||||
|          * a sequence of M rows of N SQL placeholder strings, and | ||||
|          * a sequence of M rows of corresponding parameter values. | ||||
|  | ||||
|         Each placeholder string may contain any number of '%s' interpolation | ||||
|         strings, and each parameter row will contain exactly as many params | ||||
|         as the total number of '%s's in the corresponding placeholder row. | ||||
|         """ | ||||
|         if not value_rows: | ||||
|             return [], [] | ||||
|  | ||||
|         # list of (sql, [params]) tuples for each object to be saved | ||||
|         # Shape: [n_objs][n_fields][2] | ||||
|         rows_of_fields_as_sql = ( | ||||
|             (self.field_as_sql(field, v) for field, v in zip(fields, row)) | ||||
|             for row in value_rows | ||||
|         ) | ||||
|  | ||||
|         # tuple like ([sqls], [[params]s]) for each object to be saved | ||||
|         # Shape: [n_objs][2][n_fields] | ||||
|         sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql) | ||||
|  | ||||
|         # Extract separate lists for placeholders and params. | ||||
|         # Each of these has shape [n_objs][n_fields] | ||||
|         placeholder_rows, param_rows = zip(*sql_and_param_pair_rows) | ||||
|  | ||||
|         # Params for each field are still lists, and need to be flattened. | ||||
|         param_rows = [[p for ps in row for p in ps] for row in param_rows] | ||||
|  | ||||
|         return placeholder_rows, param_rows | ||||
|  | ||||
|     def as_sql(self): | ||||
|         # We don't need quote_name_unless_alias() here, since these are all | ||||
| @@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler): | ||||
|         result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) | ||||
|  | ||||
|         if has_fields: | ||||
|             params = values = [ | ||||
|                 [ | ||||
|                     f.get_db_prep_save( | ||||
|                         getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), | ||||
|                         connection=self.connection | ||||
|                     ) for f in fields | ||||
|                 ] | ||||
|             value_rows = [ | ||||
|                 [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields] | ||||
|                 for obj in self.query.objs | ||||
|             ] | ||||
|         else: | ||||
|             values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs] | ||||
|             params = [[]] | ||||
|             # An empty object. | ||||
|             value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs] | ||||
|             fields = [None] | ||||
|         can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and | ||||
|             not self.return_id and self.connection.features.has_bulk_insert) | ||||
|  | ||||
|         if can_bulk: | ||||
|             placeholders = [["%s"] * len(fields)] | ||||
|         else: | ||||
|             placeholders = [ | ||||
|                 [self.placeholder(field, v) for field, v in zip(fields, val)] | ||||
|                 for val in values | ||||
|             ] | ||||
|             # Oracle Spatial needs to remove some values due to #10888 | ||||
|             params = self.connection.ops.modify_insert_params(placeholders, params) | ||||
|         # Currently the backends just accept values when generating bulk | ||||
|         # queries and generate their own placeholders. Doing that isn't | ||||
|         # necessary and it should be possible to use placeholders and | ||||
|         # expressions in bulk inserts too. | ||||
|         can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) | ||||
|  | ||||
|         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) | ||||
|  | ||||
|         if self.return_id and self.connection.features.can_return_id_from_insert: | ||||
|             params = params[0] | ||||
|             params = param_rows[0] | ||||
|             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) | ||||
|             result.append("VALUES (%s)" % ", ".join(placeholders[0])) | ||||
|             result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) | ||||
|             r_fmt, r_params = self.connection.ops.return_insert_id() | ||||
|             # Skip empty r_fmt to allow subclasses to customize behavior for | ||||
|             # 3rd party backends. Refs #19096. | ||||
| @@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler): | ||||
|                 result.append(r_fmt % col) | ||||
|                 params += r_params | ||||
|             return [(" ".join(result), tuple(params))] | ||||
|  | ||||
|         if can_bulk: | ||||
|             result.append(self.connection.ops.bulk_insert_sql(fields, len(values))) | ||||
|             return [(" ".join(result), tuple(v for val in values for v in val))] | ||||
|             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) | ||||
|             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] | ||||
|         else: | ||||
|             return [ | ||||
|                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) | ||||
|                 for p, vals in zip(placeholders, params) | ||||
|                 for p, vals in zip(placeholder_rows, param_rows) | ||||
|             ] | ||||
|  | ||||
|     def execute_sql(self, return_id=False): | ||||
| @@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler): | ||||
|                         connection=self.connection, | ||||
|                     ) | ||||
|                 else: | ||||
|                     raise TypeError("Database is trying to update a relational field " | ||||
|                                     "of type %s with a value of type %s. Make sure " | ||||
|                                     "you are setting the correct relations" % | ||||
|                                     (field.__class__.__name__, val.__class__.__name__)) | ||||
|                     raise TypeError( | ||||
|                         "Tried to update field %s with a model instance, %r. " | ||||
|                         "Use a value compatible with %s." | ||||
|                         % (field, val, field.__class__.__name__) | ||||
|                     ) | ||||
|             else: | ||||
|                 val = field.get_db_prep_save(val, connection=self.connection) | ||||
|  | ||||
|   | ||||
| @@ -139,9 +139,9 @@ class UpdateQuery(Query): | ||||
|  | ||||
|     def add_update_fields(self, values_seq): | ||||
|         """ | ||||
|         Turn a sequence of (field, model, value) triples into an update query. | ||||
|         Used by add_update_values() as well as the "fast" update path when | ||||
|         saving models. | ||||
|         Append a sequence of (field, model, value) triples to the internal list | ||||
|         that will be used to generate the UPDATE query. Might be more usefully | ||||
|         called add_update_targets() to hint at the extra information here. | ||||
|         """ | ||||
|         self.values.extend(values_seq) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user