1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Fixed #36118 -- Accounted for multiple primary keys in bulk_update max_batch_size.

Co-authored-by: Simon Charette <charette.s@gmail.com>
This commit is contained in:
Sarah Boyce
2025-01-27 10:28:21 +01:00
parent 0671a461c4
commit 5a2c1bc07d
6 changed files with 69 additions and 11 deletions

View File

@@ -1,12 +1,19 @@
import datetime import datetime
import uuid import uuid
from functools import lru_cache from functools import lru_cache
from itertools import chain
from django.conf import settings from django.conf import settings
from django.db import NotSupportedError from django.db import NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup from django.db.models import (
AutoField,
CompositePrimaryKey,
Exists,
ExpressionWrapper,
Lookup,
)
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.utils import timezone from django.utils import timezone
@@ -699,6 +706,12 @@ END;
def bulk_batch_size(self, fields, objs):
    """
    Oracle restricts the number of parameters in a query.

    Return how many of ``objs`` can be handled in a single query given
    that every field contributes one parameter per object.

    A ``CompositePrimaryKey`` entry in ``fields`` is expanded into its
    component fields before counting, because each component needs its
    own parameter. With no fields at all there is nothing to limit, so
    the whole of ``objs`` fits in one batch.
    """
    # Flatten composite primary keys so len(fields) reflects the real
    # number of parameters consumed per object.
    fields = list(
        chain.from_iterable(
            field.fields if isinstance(field, CompositePrimaryKey) else [field]
            for field in fields
        )
    )
    if fields:
        return self.connection.features.max_query_params // len(fields)
    return len(objs)

View File

@@ -36,6 +36,16 @@ class DatabaseOperations(BaseDatabaseOperations):
If there's only a single field to insert, the limit is 500 If there's only a single field to insert, the limit is 500
(SQLITE_MAX_COMPOUND_SELECT). (SQLITE_MAX_COMPOUND_SELECT).
""" """
fields = list(
chain.from_iterable(
(
field.fields
if isinstance(field, models.CompositePrimaryKey)
else [field]
)
for field in fields
)
)
if len(fields) == 1: if len(fields) == 1:
return 500 return 500
elif len(fields) > 1: elif len(fields) > 1:

View File

@@ -230,9 +230,8 @@ class Collector:
""" """
Return the objs in suitably sized batches for the used connection. Return the objs in suitably sized batches for the used connection.
""" """
field_names = [field.name for field in fields]
conn_batch_size = max( conn_batch_size = max(
connections[self.using].ops.bulk_batch_size(field_names, objs), 1 connections[self.using].ops.bulk_batch_size(fields, objs), 1
) )
if len(objs) > conn_batch_size: if len(objs) > conn_batch_size:
return [ return [

View File

@@ -874,11 +874,12 @@ class QuerySet(AltersData):
objs = tuple(objs) objs = tuple(objs)
if not all(obj._is_pk_set() for obj in objs): if not all(obj._is_pk_set() for obj in objs):
raise ValueError("All bulk_update() objects must have a primary key set.") raise ValueError("All bulk_update() objects must have a primary key set.")
fields = [self.model._meta.get_field(name) for name in fields] opts = self.model._meta
fields = [opts.get_field(name) for name in fields]
if any(not f.concrete or f.many_to_many for f in fields): if any(not f.concrete or f.many_to_many for f in fields):
raise ValueError("bulk_update() can only be used with concrete fields.") raise ValueError("bulk_update() can only be used with concrete fields.")
all_pk_fields = set(self.model._meta.pk_fields) all_pk_fields = set(opts.pk_fields)
for parent in self.model._meta.all_parents: for parent in opts.all_parents:
all_pk_fields.update(parent._meta.pk_fields) all_pk_fields.update(parent._meta.pk_fields)
if any(f in all_pk_fields for f in fields): if any(f in all_pk_fields for f in fields):
raise ValueError("bulk_update() cannot be used with primary key fields.") raise ValueError("bulk_update() cannot be used with primary key fields.")
@@ -892,7 +893,9 @@ class QuerySet(AltersData):
# and once in the WHEN. Each field will also have one CAST. # and once in the WHEN. Each field will also have one CAST.
self._for_write = True self._for_write = True
connection = connections[self.db] connection = connections[self.db]
max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) max_batch_size = connection.ops.bulk_batch_size(
[opts.pk, opts.pk] + fields, objs
)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
requires_casting = connection.features.requires_casted_case_in_updates requires_casting = connection.features.requires_casted_case_in_updates
batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from django.core.management.color import no_style from django.core.management.color import no_style
from django.db import connection from django.db import connection, models
from django.test import TransactionTestCase from django.test import TransactionTestCase
from ..models import Person, Tag from ..models import Person, Tag
@@ -22,14 +22,25 @@ class OperationsTests(TransactionTestCase):
objects = range(2**16) objects = range(2**16)
self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects)) self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects))
# Each field is a parameter for each object. # Each field is a parameter for each object.
first_name_field = Person._meta.get_field("first_name")
last_name_field = Person._meta.get_field("last_name")
self.assertEqual( self.assertEqual(
connection.ops.bulk_batch_size(["id"], objects), connection.ops.bulk_batch_size([first_name_field], objects),
connection.features.max_query_params, connection.features.max_query_params,
) )
self.assertEqual( self.assertEqual(
connection.ops.bulk_batch_size(["id", "other"], objects), connection.ops.bulk_batch_size(
[first_name_field, last_name_field],
objects,
),
connection.features.max_query_params // 2, connection.features.max_query_params // 2,
) )
composite_pk = models.CompositePrimaryKey("first_name", "last_name")
composite_pk.fields = [first_name_field, last_name_field]
self.assertEqual(
connection.ops.bulk_batch_size([composite_pk, first_name_field], objects),
connection.features.max_query_params // 3,
)
def test_sql_flush(self): def test_sql_flush(self):
statements = connection.ops.sql_flush( statements = connection.ops.sql_flush(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from django.core.management.color import no_style from django.core.management.color import no_style
from django.db import connection from django.db import connection, models
from django.test import TestCase from django.test import TestCase
from ..models import Person, Tag from ..models import Person, Tag
@@ -86,3 +86,25 @@ class SQLiteOperationsTests(TestCase):
"zzz'", "zzz'",
statements[-1], statements[-1],
) )
def test_bulk_batch_size(self):
    """
    bulk_batch_size() counts one query parameter per concrete field and
    expands a CompositePrimaryKey into its component fields.
    """
    objs = [Person()]
    first_name_field = Person._meta.get_field("first_name")
    last_name_field = Person._meta.get_field("last_name")

    # No fields: the batch size is simply the number of objects.
    self.assertEqual(connection.ops.bulk_batch_size([], objs), 1)

    # A single field is expected to yield a batch size of 500.
    self.assertEqual(connection.ops.bulk_batch_size([first_name_field], objs), 500)

    # Several fields share the parameter budget evenly.
    self.assertEqual(
        connection.ops.bulk_batch_size([first_name_field, last_name_field], objs),
        connection.features.max_query_params // 2,
    )

    # A composite primary key made of two fields plus one regular field
    # must be counted as three parameters per object.
    composite_pk = models.CompositePrimaryKey("first_name", "last_name")
    composite_pk.fields = [first_name_field, last_name_field]
    self.assertEqual(
        connection.ops.bulk_batch_size([composite_pk, first_name_field], objs),
        connection.features.max_query_params // 3,
    )