From 4608d34b346c28d5d227363c881d3279378f40b3 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Mon, 9 Dec 2024 18:38:18 -0500 Subject: [PATCH] Fixed #36088 -- Avoided unnecessary DEFAULT usage on bulk_create(). When all values of a field with a db_default are DatabaseDefault, which is the case most of the time, there is no point in specifying explicit DEFAULT for all INSERT VALUES as that's what the database will do anyway if not specified. In the case of PostgreSQL doing so can even be harmful as it prevents the usage of the UNNEST strategy and in the case of Oracle, which doesn't support the usage of the DEFAULT keyword, it unnecessarily requires providing literal db defaults. Thanks Lily Foote for the review. --- django/db/models/query.py | 11 ---- django/db/models/sql/compiler.py | 62 ++++++++++++++++--- tests/backends/models.py | 2 +- tests/backends/postgresql/test_compilation.py | 6 ++ tests/bulk_create/models.py | 6 ++ tests/bulk_create/tests.py | 26 ++++++++ 6 files changed, 91 insertions(+), 22 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 4aa7f03a5f..84806a5f72 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -670,21 +670,10 @@ class QuerySet(AltersData): acreate.alters_data = True def _prepare_for_bulk_create(self, objs): - from django.db.models.expressions import DatabaseDefault - - connection = connections[self.db] for obj in objs: if not obj._is_pk_set(): # Populate new PK values. 
obj.pk = obj._meta.pk.get_pk_value_on_save(obj) - if not connection.features.supports_default_keyword_in_bulk_insert: - for field in obj._meta.fields: - if field.generated: - continue - value = getattr(obj, field.attname) - if isinstance(value, DatabaseDefault): - setattr(obj, field.attname, field.db_default) - obj._prepare_related_fields_for_save(operation_name="bulk_create") def _check_bulk_create_options( diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 04372c509e..3bfb3bd631 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1810,23 +1810,65 @@ class SQLInsertCompiler(SQLCompiler): on_conflict=self.query.on_conflict, ) result = ["%s %s" % (insert_statement, qn(opts.db_table))] - fields = self.query.fields or [opts.pk] - result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) - if self.query.fields: - value_rows = [ - [ - self.prepare_value(field, self.pre_save_val(field, obj)) - for field in fields + if fields := list(self.query.fields): + from django.db.models.expressions import DatabaseDefault + + supports_default_keyword_in_bulk_insert = ( + self.connection.features.supports_default_keyword_in_bulk_insert + ) + value_cols = [] + for field in list(fields): + field_prepare = partial(self.prepare_value, field) + field_pre_save = partial(self.pre_save_val, field) + field_values = [ + field_prepare(field_pre_save(obj)) for obj in self.query.objs ] - for obj in self.query.objs - ] + + if not field.has_db_default(): + value_cols.append(field_values) + continue + + # If all values are DEFAULT don't include the field and its + # values in the query as they are redundant and could prevent + # optimizations. This cannot be done if we're dealing with the + # last field as INSERT statements require at least one. 
+ if len(fields) > 1 and all( + isinstance(value, DatabaseDefault) for value in field_values + ): + fields.remove(field) + continue + + if supports_default_keyword_in_bulk_insert: + value_cols.append(field_values) + continue + + # If the field cannot be excluded from the INSERT for the + # reasons listed above and the backend doesn't support the + # DEFAULT keyword, each value must be expanded into its + # underlying expression. + prepared_db_default = field_prepare(field.db_default) + field_values = [ + ( + prepared_db_default + if isinstance(value, DatabaseDefault) + else value + ) + for value in field_values + ] + value_cols.append(field_values) + value_rows = list(zip(*value_cols)) + result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) else: - # An empty object. + # No fields were specified but an INSERT statement must include at + # least one column. This can only happen when the model's primary + # key is composed of a single auto-field so default to including it + # as a placeholder to generate a valid INSERT statement. value_rows = [ [self.connection.ops.pk_default_value()] for _ in self.query.objs ] fields = [None] + result.append("(%s)" % qn(opts.pk.column)) # Currently the backends just accept values when generating bulk # queries and generate their own placeholders. 
Doing that isn't diff --git a/tests/backends/models.py b/tests/backends/models.py index 1ed108c2b8..afb6ebe303 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -5,7 +5,7 @@ from django.db import models class Square(models.Model): root = models.IntegerField() - square = models.PositiveIntegerField() + square = models.PositiveIntegerField(db_default=9) def __str__(self): return "%s ** 2 == %s" % (self.root, self.square) diff --git a/tests/backends/postgresql/test_compilation.py b/tests/backends/postgresql/test_compilation.py index 67fe893e35..5a86a427ff 100644 --- a/tests/backends/postgresql/test_compilation.py +++ b/tests/backends/postgresql/test_compilation.py @@ -27,3 +27,9 @@ class BulkCreateUnnestTests(TestCase): [Square(root=2, square=4), Square(root=3, square=9)] ) self.assertIn("UNNEST", ctx[0]["sql"]) + + def test_unnest_eligible_db_default(self): + with self.assertNumQueries(1) as ctx: + squares = Square.objects.bulk_create([Square(root=3), Square(root=3)]) + self.assertIn("UNNEST", ctx[0]["sql"]) + self.assertEqual([square.square for square in squares], [9, 9]) diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py index 8a21c7dfa1..f0df9da66e 100644 --- a/tests/bulk_create/models.py +++ b/tests/bulk_create/models.py @@ -3,6 +3,7 @@ import uuid from decimal import Decimal from django.db import models +from django.db.models.functions import Now from django.utils import timezone try: @@ -141,3 +142,8 @@ class RelatedModel(models.Model): name = models.CharField(max_length=15, null=True) country = models.OneToOneField(Country, models.CASCADE, primary_key=True) big_auto_fields = models.ManyToManyField(BigAutoFieldModel) + + +class DbDefaultModel(models.Model): + name = models.CharField(max_length=10) + created_at = models.DateTimeField(db_default=Now()) diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 7b86a2def5..83ff8e4514 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py 
@@ -17,10 +17,12 @@ from django.test import ( skipIfDBFeature, skipUnlessDBFeature, ) +from django.utils import timezone from .models import ( BigAutoFieldModel, Country, + DbDefaultModel, FieldsWithDbColumns, NoFields, NullableFields, @@ -840,3 +842,27 @@ class BulkCreateTests(TestCase): {"rank": 2, "name": "d"}, ], ) + + def test_db_default_field_excluded(self): + # created_at is excluded when no db_default override is provided. + with self.assertNumQueries(1) as ctx: + DbDefaultModel.objects.bulk_create( + [DbDefaultModel(name="foo"), DbDefaultModel(name="bar")] + ) + created_at_quoted_name = connection.ops.quote_name("created_at") + self.assertEqual( + ctx[0]["sql"].count(created_at_quoted_name), + 1 if connection.features.can_return_rows_from_bulk_insert else 0, + ) + # created_at is included when a db_default override is provided. + with self.assertNumQueries(1) as ctx: + DbDefaultModel.objects.bulk_create( + [ + DbDefaultModel(name="foo", created_at=timezone.now()), + DbDefaultModel(name="bar"), + ] + ) + self.assertEqual( + ctx[0]["sql"].count(created_at_quoted_name), + 2 if connection.features.can_return_rows_from_bulk_insert else 1, + )