1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Refs #19527 -- Allowed QuerySet.bulk_create() to set the primary key of its objects.

PostgreSQL support only.

Thanks Vladislav Manchev and alesasnouski for working on the patch.
This commit is contained in:
acrefoot
2015-08-20 22:38:58 -07:00
committed by Tim Graham
parent 60633ef3de
commit 04240b2365
8 changed files with 90 additions and 20 deletions

View File

@@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object):
can_use_chunked_reads = True
can_return_id_from_insert = False
can_return_ids_from_bulk_insert = False
has_bulk_insert = False
uses_savepoints = False
can_release_savepoints = False

View File

@@ -5,6 +5,7 @@ from django.db.utils import InterfaceError
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
can_return_id_from_insert = True
can_return_ids_from_bulk_insert = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True

View File

@@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_ids(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, return the
list of newly created IDs.
"""
return [item[0] for item in cursor.fetchall()]
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'

View File

@@ -411,17 +411,21 @@ class QuerySet(object):
Inserts each of the instances into the database. This does *not* call
save() on each of the instances, does not send any pre/post save
signals, and does not set the primary key attribute if it is an
autoincrement field. Multi-table models are not supported.
autoincrement field (except if features.can_return_ids_from_bulk_insert=True).
Multi-table models are not supported.
"""
# So this case is fun. When you bulk insert you don't get the primary
# keys back (if it's an autoincrement), so you can't insert into the
# child tables which references this. There are two workarounds, 1)
# this could be implemented if you didn't have an autoincrement pk,
# and 2) you could do it by doing O(n) normal inserts into the parent
# tables to get the primary keys back, and then doing a single bulk
# insert into the childmost table. Some databases might allow doing
# this by using RETURNING clause for the insert query. We're punting
# on these for now because they are relatively rare cases.
# When you bulk insert you don't get the primary keys back (if it's an
# autoincrement, except if can_return_ids_from_bulk_insert=True), so
# you can't insert into the child tables which references this. There
# are two workarounds:
# 1) This could be implemented if you didn't have an autoincrement pk
# 2) You could do it by doing O(n) normal inserts into the parent
# tables to get the primary keys back and then doing a single bulk
# insert into the childmost table.
# We currently set the primary keys on the objects when using
# PostgreSQL via the RETURNING ID clause. It should be possible for
# Oracle as well, but the semantics for extracting the primary keys is
# trickier so it's not done yet.
assert batch_size is None or batch_size > 0
# Check that the parents share the same concrete model with the our
# model to detect the inheritance pattern ConcreteGrandParent ->
@@ -447,7 +451,11 @@ class QuerySet(object):
self._batched_insert(objs_with_pk, fields, batch_size)
if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)]
self._batched_insert(objs_without_pk, fields, batch_size)
ids = self._batched_insert(objs_without_pk, fields, batch_size)
if connection.features.can_return_ids_from_bulk_insert:
assert len(ids) == len(objs_without_pk)
for i in range(len(ids)):
objs_without_pk[i].pk = ids[i]
return objs
@@ -1051,10 +1059,19 @@ class QuerySet(object):
return
ops = connections[self.db].ops
batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
for batch in [objs[i:i + batch_size]
for i in range(0, len(objs), batch_size)]:
self.model._base_manager._insert(batch, fields=fields,
using=self.db)
inserted_ids = []
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
if connections[self.db].features.can_return_ids_from_bulk_insert:
inserted_id = self.model._base_manager._insert(
item, fields=fields, using=self.db, return_id=True
)
if len(objs) > 1:
inserted_ids.extend(inserted_id)
if len(objs) == 1:
inserted_ids.append(inserted_id)
else:
self.model._base_manager._insert(item, fields=fields, using=self.db)
return inserted_ids
def _clone(self, **kwargs):
query = self.query.clone()

View File

@@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler):
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
if self.return_id and self.connection.features.can_return_id_from_insert:
params = param_rows[0]
if self.connection.features.can_return_ids_from_bulk_insert:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
params = param_rows
else:
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
params = param_rows[0]
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
r_fmt, r_params = self.connection.ops.return_insert_id()
# Skip empty r_fmt to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
if r_fmt:
result.append(r_fmt % col)
params += r_params
return [(" ".join(result), tuple(params))]
return [(" ".join(result), tuple(chain.from_iterable(params)))]
if can_bulk:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
@@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler):
]
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
assert not (
return_id and len(self.query.objs) != 1 and
not self.connection.features.can_return_ids_from_bulk_insert
)
self.return_id = return_id
with self.connection.cursor() as cursor:
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not (return_id and cursor):
return
if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
return self.connection.ops.fetch_returned_insert_ids(cursor)
if self.connection.features.can_return_id_from_insert:
assert len(self.query.objs) == 1
return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor,
self.query.get_meta().db_table, self.query.get_meta().pk.column)