1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Fixed #24509 -- Added Expression support to SQLInsertCompiler

This commit is contained in:
Alex Hill
2015-08-04 00:34:19 +10:00
committed by Josh Smeaton
parent 6e51d5d0e5
commit 134ca4d438
15 changed files with 247 additions and 67 deletions

View File

@@ -576,7 +576,7 @@ class BaseDatabaseOperations(object):
def combine_duration_expression(self, connector, sub_expressions):
return self.combine_expression(connector, sub_expressions)
def modify_insert_params(self, placeholders, params):
def modify_insert_params(self, placeholder, params):
"""Allow modification of insert parameters. Needed for Oracle Spatial
backend due to #10888.
"""

View File

@@ -166,9 +166,10 @@ class DatabaseOperations(BaseDatabaseOperations):
def max_name_length(self):
return 64
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
"""

View File

@@ -439,6 +439,8 @@ WHEN (new.%(col_name)s IS NULL)
name_length = self.max_name_length() - 3
return '%s_TR' % truncate_name(table, name_length).upper()
def bulk_insert_sql(self, fields, num_values):
items_sql = "SELECT %s FROM DUAL" % ", ".join(["%s"] * len(fields))
return " UNION ALL ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s FROM DUAL" % ", ".join(row)
for row in placeholder_rows
)

View File

@@ -221,9 +221,10 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_id(self):
return "RETURNING %s", ()
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def adapt_datefield_value(self, value):
return value

View File

@@ -226,13 +226,11 @@ class DatabaseOperations(BaseDatabaseOperations):
value = uuid.UUID(value)
return value
def bulk_insert_sql(self, fields, num_values):
res = []
res.append("SELECT %s" % ", ".join(
"%%s AS %s" % self.quote_name(f.column) for f in fields
))
res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
return " ".join(res)
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s" % ", ".join(row)
for row in placeholder_rows
)
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a power function, so we fake it with a

View File

@@ -180,6 +180,13 @@ class BaseExpression(object):
return True
return False
@cached_property
def contains_column_references(self):
for expr in self.get_source_expressions():
if expr and expr.contains_column_references:
return True
return False
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
"""
Provides the chance to do any preprocessing or validation before being
@@ -339,6 +346,17 @@ class BaseExpression(object):
def reverse_ordering(self):
return self
def flatten(self):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
"""
yield self
for expr in self.get_source_expressions():
if expr:
for inner_expr in expr.flatten():
yield inner_expr
class Expression(BaseExpression, Combinable):
"""
@@ -613,6 +631,9 @@ class Random(Expression):
class Col(Expression):
contains_column_references = True
def __init__(self, alias, target, output_field=None):
if output_field is None:
output_field = target

View File

@@ -458,6 +458,8 @@ class QuerySet(object):
specifying whether an object was created.
"""
lookup, params = self._extract_model_params(defaults, **kwargs)
# The get() needs to be targeted at the write database in order
# to avoid potential transaction consistency problems.
self._for_write = True
try:
return self.get(**lookup), False

View File

@@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler):
self.return_id = False
super(SQLInsertCompiler, self).__init__(*args, **kwargs)
def placeholder(self, field, val):
def field_as_sql(self, field, val):
"""
Take a field and a value intended to be saved on that field, and
return placeholder SQL and accompanying params. Checks for raw values,
expressions and fields with get_placeholder() defined in that order.
When field is None, the value is considered raw and is used as the
placeholder, with no corresponding parameters returned.
"""
if field is None:
# A field value of None means the value is raw.
return val
sql, params = val, []
elif hasattr(val, 'as_sql'):
# This is an expression, let's compile it.
sql, params = self.compile(val)
elif hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
return field.get_placeholder(val, self, self.connection)
sql, params = field.get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
return '%s'
sql, params = '%s', [val]
# The following hook is only used by Oracle Spatial, which sometimes
# needs to yield 'NULL' and [] as its placeholder and params instead
# of '%s' and [None]. The 'NULL' placeholder is produced earlier by
# OracleOperations.get_geom_placeholder(). The following line removes
# the corresponding None parameter. See ticket #10888.
params = self.connection.ops.modify_insert_params(sql, params)
return sql, params
def prepare_value(self, field, value):
"""
Prepare a value to be used in a query by resolving it if it is an
expression and otherwise calling the field's get_db_prep_save().
"""
if hasattr(value, 'resolve_expression'):
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
# Don't allow values containing Col expressions. They refer to
# existing columns on a row, but in the case of insert the row
# doesn't exist yet.
if value.contains_column_references:
raise ValueError(
'Failed to insert expression "%s" on %s. F() expressions '
'can only be used to update, not to insert.' % (value, field)
)
if value.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
else:
value = field.get_db_prep_save(value, connection=self.connection)
return value
def pre_save_val(self, field, obj):
"""
Get the given field's value off the given obj. pre_save() is used for
things like auto_now on DateTimeField. Skip it if this is a raw query.
"""
if self.query.raw:
return getattr(obj, field.attname)
return field.pre_save(obj, add=True)
def assemble_as_sql(self, fields, value_rows):
"""
Take a sequence of N fields and a sequence of M rows of values,
generate placeholder SQL and parameters for each field and value, and
return a pair containing:
* a sequence of M rows of N SQL placeholder strings, and
* a sequence of M rows of corresponding parameter values.
Each placeholder string may contain any number of '%s' interpolation
strings, and each parameter row will contain exactly as many params
as the total number of '%s's in the corresponding placeholder row.
"""
if not value_rows:
return [], []
# list of (sql, [params]) tuples for each object to be saved
# Shape: [n_objs][n_fields][2]
rows_of_fields_as_sql = (
(self.field_as_sql(field, v) for field, v in zip(fields, row))
for row in value_rows
)
# tuple like ([sqls], [[params]s]) for each object to be saved
# Shape: [n_objs][2][n_fields]
sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)
# Extract separate lists for placeholders and params.
# Each of these has shape [n_objs][n_fields]
placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)
# Params for each field are still lists, and need to be flattened.
param_rows = [[p for ps in row for p in ps] for row in param_rows]
return placeholder_rows, param_rows
def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
@@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler):
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
if has_fields:
params = values = [
[
f.get_db_prep_save(
getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True),
connection=self.connection
) for f in fields
]
value_rows = [
[self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
for obj in self.query.objs
]
else:
values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
params = [[]]
# An empty object.
value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
fields = [None]
can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
not self.return_id and self.connection.features.has_bulk_insert)
if can_bulk:
placeholders = [["%s"] * len(fields)]
else:
placeholders = [
[self.placeholder(field, v) for field, v in zip(fields, val)]
for val in values
]
# Oracle Spatial needs to remove some values due to #10888
params = self.connection.ops.modify_insert_params(placeholders, params)
# Currently the backends just accept values when generating bulk
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
if self.return_id and self.connection.features.can_return_id_from_insert:
params = params[0]
params = param_rows[0]
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append("VALUES (%s)" % ", ".join(placeholders[0]))
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
r_fmt, r_params = self.connection.ops.return_insert_id()
# Skip empty r_fmt to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
@@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler):
result.append(r_fmt % col)
params += r_params
return [(" ".join(result), tuple(params))]
if can_bulk:
result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
return [(" ".join(result), tuple(v for val in values for v in val))]
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
else:
return [
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
for p, vals in zip(placeholders, params)
for p, vals in zip(placeholder_rows, param_rows)
]
def execute_sql(self, return_id=False):
@@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler):
connection=self.connection,
)
else:
raise TypeError("Database is trying to update a relational field "
"of type %s with a value of type %s. Make sure "
"you are setting the correct relations" %
(field.__class__.__name__, val.__class__.__name__))
raise TypeError(
"Tried to update field %s with a model instance, %r. "
"Use a value compatible with %s."
% (field, val, field.__class__.__name__)
)
else:
val = field.get_db_prep_save(val, connection=self.connection)

View File

@@ -139,9 +139,9 @@ class UpdateQuery(Query):
def add_update_fields(self, values_seq):
"""
Turn a sequence of (field, model, value) triples into an update query.
Used by add_update_values() as well as the "fast" update path when
saving models.
Append a sequence of (field, model, value) triples to the internal list
that will be used to generate the UPDATE query. Might be more usefully
called add_update_targets() to hint at the extra information here.
"""
self.values.extend(values_seq)