From 04240b23658f8935bbfebacccc23b5e47a1d6c22 Mon Sep 17 00:00:00 2001
From: acrefoot <acrefoot@dropbox.com>
Date: Thu, 20 Aug 2015 22:38:58 -0700
Subject: [PATCH] Refs #19527 -- Allowed QuerySet.bulk_create() to set the
 primary key of its objects.

PostgreSQL support only.
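
A rough sketch of the effect (the Country model is borrowed from the
bulk_create test suite for illustration):

    # On PostgreSQL the objects returned by bulk_create() now have their
    # primary keys set, fetched via a single INSERT ... RETURNING query;
    # on other backends the primary keys remain unset, as before.
    countries = Country.objects.bulk_create([
        Country(name='Netherlands', iso_two_letter='NL'),
        Country(name='Germany', iso_two_letter='DE'),
    ])
    assert all(country.pk is not None for country in countries)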

Thanks Vladislav Manchev and alesasnouski for working on the patch.
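
For third-party backends whose databases can also return IDs from a bulk
INSERT, opting in would look roughly like this sketch (not part of this
patch; only the PostgreSQL backend is wired up here):

    from django.db.backends.base.features import BaseDatabaseFeatures
    from django.db.backends.base.operations import BaseDatabaseOperations

    class DatabaseFeatures(BaseDatabaseFeatures):
        can_return_ids_from_bulk_insert = True

    class DatabaseOperations(BaseDatabaseOperations):
        def fetch_returned_insert_ids(self, cursor):
            # Each returned row carries one newly created primary key.
            return [row[0] for row in cursor.fetchall()]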
---
 django/db/backends/base/features.py         |  1 +
 django/db/backends/postgresql/features.py   |  1 +
 django/db/backends/postgresql/operations.py |  8 ++++
 django/db/models/query.py                   | 47 ++++++++++++++-------
 django/db/models/sql/compiler.py            | 20 +++++++--
 docs/ref/models/querysets.txt               |  8 +++-
 docs/releases/1.10.txt                      |  8 ++++
 tests/bulk_create/tests.py                  | 19 +++++++++
 8 files changed, 91 insertions(+), 21 deletions(-)

diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
index 8005b6e6b4..91065ecc99 100644
--- a/django/db/backends/base/features.py
+++ b/django/db/backends/base/features.py
@@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object):
 
     can_use_chunked_reads = True
     can_return_id_from_insert = False
+    can_return_ids_from_bulk_insert = False
     has_bulk_insert = False
     uses_savepoints = False
     can_release_savepoints = False
diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
index 4465ab1cd1..218a7645de 100644
--- a/django/db/backends/postgresql/features.py
+++ b/django/db/backends/postgresql/features.py
@@ -5,6 +5,7 @@ from django.db.utils import InterfaceError
 class DatabaseFeatures(BaseDatabaseFeatures):
     allows_group_by_selected_pks = True
     can_return_id_from_insert = True
+    can_return_ids_from_bulk_insert = True
     has_real_datatype = True
     has_native_uuid_field = True
     has_native_duration_field = True
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 5bd433c639..41d3a7a230 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations):
     def deferrable_sql(self):
         return " DEFERRABLE INITIALLY DEFERRED"
 
+    def fetch_returned_insert_ids(self, cursor):
+        """
+        Given a cursor object that has just performed an INSERT...RETURNING
+        statement into a table that has an auto-incrementing ID, return the
+        list of newly created IDs.
+        """
+        return [item[0] for item in cursor.fetchall()]
+
     def lookup_cast(self, lookup_type, internal_type=None):
         lookup = '%s'
 
diff --git a/django/db/models/query.py b/django/db/models/query.py
index ace55a5da2..71e9ba9bda 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -411,17 +411,21 @@ class QuerySet(object):
         Inserts each of the instances into the database. This does *not* call
         save() on each of the instances, does not send any pre/post save
         signals, and does not set the primary key attribute if it is an
-        autoincrement field. Multi-table models are not supported.
+        autoincrement field (unless features.can_return_ids_from_bulk_insert=True).
+        Multi-table models are not supported.
         """
-        # So this case is fun. When you bulk insert you don't get the primary
-        # keys back (if it's an autoincrement), so you can't insert into the
-        # child tables which references this. There are two workarounds, 1)
-        # this could be implemented if you didn't have an autoincrement pk,
-        # and 2) you could do it by doing O(n) normal inserts into the parent
-        # tables to get the primary keys back, and then doing a single bulk
-        # insert into the childmost table. Some databases might allow doing
-        # this by using RETURNING clause for the insert query. We're punting
-        # on these for now because they are relatively rare cases.
+        # When you bulk insert you don't get the primary keys back (if it's an
+        # autoincrement, except if can_return_ids_from_bulk_insert=True), so
+        # you can't insert into the child tables which reference this. There
+        # are two workarounds:
+        # 1) This could be implemented if you didn't have an autoincrement pk.
+        # 2) You could do it by doing O(n) normal inserts into the parent
+        #    tables to get the primary keys back and then doing a single bulk
+        #    insert into the childmost table.
+        # We currently set the primary keys on the objects when using
+        # PostgreSQL via the RETURNING ID clause. It should be possible for
+        # Oracle as well, but the semantics for extracting the primary keys are
+        # trickier so it's not done yet.
         assert batch_size is None or batch_size > 0
         # Check that the parents share the same concrete model with the our
         # model to detect the inheritance pattern ConcreteGrandParent ->
@@ -447,7 +451,11 @@ class QuerySet(object):
                     self._batched_insert(objs_with_pk, fields, batch_size)
                 if objs_without_pk:
                     fields = [f for f in fields if not isinstance(f, AutoField)]
-                    self._batched_insert(objs_without_pk, fields, batch_size)
+                    ids = self._batched_insert(objs_without_pk, fields, batch_size)
+                    if connection.features.can_return_ids_from_bulk_insert:
+                        assert len(ids) == len(objs_without_pk)
+                    for obj_without_pk, pk in zip(objs_without_pk, ids):
+                        obj_without_pk.pk = pk
 
         return objs
 
@@ -1051,10 +1059,19 @@ class QuerySet(object):
             return
         ops = connections[self.db].ops
         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
-        for batch in [objs[i:i + batch_size]
-                      for i in range(0, len(objs), batch_size)]:
-            self.model._base_manager._insert(batch, fields=fields,
-                                             using=self.db)
+        inserted_ids = []
+        for batch in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
+            if connections[self.db].features.can_return_ids_from_bulk_insert:
+                inserted_id = self.model._base_manager._insert(
+                    batch, fields=fields, using=self.db, return_id=True
+                )
+                if len(batch) > 1:  # _insert() returns a list of ids for a multi-row batch
+                    inserted_ids.extend(inserted_id)
+                else:  # and a single id for a one-row batch
+                    inserted_ids.append(inserted_id)
+            else:
+                self.model._base_manager._insert(batch, fields=fields, using=self.db)
+        return inserted_ids
 
     def _clone(self, **kwargs):
         query = self.query.clone()
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 3b804dc0b0..5eddcf5723 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler):
         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
 
         if self.return_id and self.connection.features.can_return_id_from_insert:
-            params = param_rows[0]
+            if self.connection.features.can_return_ids_from_bulk_insert:
+                result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+                params = param_rows
+            else:
+                result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
+                params = [param_rows[0]]
             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
-            result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
             r_fmt, r_params = self.connection.ops.return_insert_id()
             # Skip empty r_fmt to allow subclasses to customize behavior for
             # 3rd party backends. Refs #19096.
             if r_fmt:
                 result.append(r_fmt % col)
-                params += r_params
-            return [(" ".join(result), tuple(params))]
+                params += [r_params]
+            return [(" ".join(result), tuple(chain.from_iterable(params)))]
 
         if can_bulk:
             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
@@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler):
             ]
 
     def execute_sql(self, return_id=False):
-        assert not (return_id and len(self.query.objs) != 1)
+        assert not (
+            return_id and len(self.query.objs) != 1 and
+            not self.connection.features.can_return_ids_from_bulk_insert
+        )
         self.return_id = return_id
         with self.connection.cursor() as cursor:
             for sql, params in self.as_sql():
                 cursor.execute(sql, params)
             if not (return_id and cursor):
                 return
+            if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
+                return self.connection.ops.fetch_returned_insert_ids(cursor)
             if self.connection.features.can_return_id_from_insert:
+                assert len(self.query.objs) == 1
                 return self.connection.ops.fetch_returned_insert_id(cursor)
             return self.connection.ops.last_insert_id(cursor,
                     self.query.get_meta().db_table, self.query.get_meta().pk.column)
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index f626391dec..8781c4cdd4 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -1794,13 +1794,19 @@ This has a number of caveats though:
   ``post_save`` signals will not be sent.
 * It does not work with child models in a multi-table inheritance scenario.
 * If the model's primary key is an :class:`~django.db.models.AutoField` it
-  does not retrieve and set the primary key attribute, as ``save()`` does.
+  does not retrieve and set the primary key attribute, as ``save()`` does,
+  unless the database backend supports it (currently PostgreSQL).
 * It does not work with many-to-many relationships.
 
 .. versionchanged:: 1.9
 
     Support for using ``bulk_create()`` with proxy models was added.
 
+.. versionchanged:: 1.10
+
+    Support for setting primary keys on objects created using ``bulk_create()``
+    when using PostgreSQL was added.
+
 The ``batch_size`` parameter controls how many objects are created in a single
 query. The default is to create all objects in one batch, except for SQLite
 where the default is such that at most 999 variables per query are used.
diff --git a/docs/releases/1.10.txt b/docs/releases/1.10.txt
index 29700f4393..6df47d1f82 100644
--- a/docs/releases/1.10.txt
+++ b/docs/releases/1.10.txt
@@ -203,6 +203,11 @@ Database backends
 
 * Temporal data subtraction was unified on all backends.
 
+* If the database supports it, backends can set
+  ``DatabaseFeatures.can_return_ids_from_bulk_insert=True`` and implement
+  ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
+  on objects created using ``QuerySet.bulk_create()``.
+
 Email
 ~~~~~
 
@@ -315,6 +320,9 @@ Models
 * The :func:`~django.db.models.prefetch_related_objects` function is now a
   public API.
 
+* :meth:`QuerySet.bulk_create() <django.db.models.query.QuerySet.bulk_create>`
+  now sets the primary key on objects when using PostgreSQL.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
index f59f335ce0..a7eb725f55 100644
--- a/tests/bulk_create/tests.py
+++ b/tests/bulk_create/tests.py
@@ -198,3 +198,22 @@ class BulkCreateTests(TestCase):
         ])
         bbb = Restaurant.objects.filter(name="betty's beetroot bar")
         self.assertEqual(bbb.count(), 1)
+
+    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
+    def test_set_pk_and_insert_single_item(self):
+        countries = []
+        with self.assertNumQueries(1):
+            countries = Country.objects.bulk_create([self.data[0]])
+        self.assertEqual(len(countries), 1)
+        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
+
+    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
+    def test_set_pk_and_query_efficiency(self):
+        countries = []
+        with self.assertNumQueries(1):
+            countries = Country.objects.bulk_create(self.data)
+        self.assertEqual(len(countries), 4)
+        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
+        self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1])
+        self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2])
+        self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3])