1
0
mirror of https://github.com/django/django.git synced 2025-01-19 14:52:54 +00:00

Fixed #32442 -- Used converters on returning fields from INSERT statements.

This commit is contained in:
Adam Johnson 2021-02-13 08:58:24 +00:00 committed by Mariusz Felisiak
parent 619f26d289
commit d9de74141e
4 changed files with 47 additions and 11 deletions

View File

@ -1405,6 +1405,7 @@ class SQLInsertCompiler(SQLCompiler):
returning_fields and len(self.query.objs) != 1 and
not self.connection.features.can_return_rows_from_bulk_insert
)
opts = self.query.get_meta()
self.returning_fields = returning_fields
with self.connection.cursor() as cursor:
for sql, params in self.as_sql():
@ -1412,13 +1413,21 @@ class SQLInsertCompiler(SQLCompiler):
if not self.returning_fields:
return []
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
return self.connection.ops.fetch_returned_insert_rows(cursor)
if self.connection.features.can_return_columns_from_insert:
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
return [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)]
return [(self.connection.ops.last_insert_id(
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
),)]
rows = [self.connection.ops.fetch_returned_insert_columns(
cursor, self.returning_params,
)]
else:
rows = [(self.connection.ops.last_insert_id(
cursor, opts.db_table, opts.pk.column,
),)]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
rows = list(self.apply_converters(rows, converters))
return rows
class SQLDeleteCompiler(SQLCompiler):

View File

@ -20,7 +20,7 @@ class MyWrapper:
return self.value == other
class MyAutoField(models.CharField):
class MyWrapperField(models.CharField):
def __init__(self, *args, **kwargs):
kwargs['max_length'] = 10
@ -58,3 +58,15 @@ class MyAutoField(models.CharField):
if isinstance(value, MyWrapper):
return str(value)
return value
class MyAutoField(models.BigAutoField):
def from_db_value(self, value, expression, connection):
if value is None:
return None
return MyWrapper(value)
def get_prep_value(self, value):
if value is None:
return None
return int(value)

View File

@ -7,7 +7,7 @@ this behavior by explicitly adding ``primary_key=True`` to a field.
from django.db import models
from .fields import MyAutoField
from .fields import MyAutoField, MyWrapperField
class Employee(models.Model):
@ -31,8 +31,12 @@ class Business(models.Model):
class Bar(models.Model):
id = MyAutoField(primary_key=True, db_index=True)
id = MyWrapperField(primary_key=True, db_index=True)
class Foo(models.Model):
bar = models.ForeignKey(Bar, models.CASCADE)
class CustomAutoFieldModel(models.Model):
id = MyAutoField(primary_key=True)

View File

@ -1,7 +1,8 @@
from django.db import IntegrityError, transaction
from django.test import TestCase, skipIfDBFeature
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from .models import Bar, Business, Employee, Foo
from .fields import MyWrapper
from .models import Bar, Business, CustomAutoFieldModel, Employee, Foo
class BasicCustomPKTests(TestCase):
@ -230,3 +231,13 @@ class CustomPKTests(TestCase):
with self.assertRaises(IntegrityError):
with transaction.atomic():
Employee.objects.create(first_name="Tom", last_name="Smith")
def test_auto_field_subclass_create(self):
obj = CustomAutoFieldModel.objects.create()
self.assertIsInstance(obj.id, MyWrapper)
@skipUnlessDBFeature('can_return_rows_from_bulk_insert')
def test_auto_field_subclass_bulk_create(self):
obj = CustomAutoFieldModel()
CustomAutoFieldModel.objects.bulk_create([obj])
self.assertIsInstance(obj.id, MyWrapper)