1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Refs #27222 -- Refreshed GeneratedFields values on save() initiated update.

This required implementing UPDATE RETURNING machinery that heavily
borrows from the INSERT one.
This commit is contained in:
Simon Charette
2025-03-19 01:11:34 -04:00
committed by Mariusz Felisiak
parent c48904a225
commit 55a0073b3b
12 changed files with 213 additions and 59 deletions

View File

@@ -173,11 +173,6 @@ class BaseGeneratedFieldTests(SimpleTestCase):
class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m):
if not connection.features.can_return_columns_from_insert:
m.refresh_from_db()
return m
def test_unsaved_error(self):
m = self.base_model(a=1, b=2)
msg = "Cannot retrieve deferred field 'field' from an unsaved model."
@@ -189,8 +184,11 @@ class GeneratedFieldTestMixin:
# full_clean() ignores GeneratedFields.
m.full_clean()
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 3)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
@skipUnlessDBFeature("supports_table_check_constraints")
def test_full_clean_with_check_constraint(self):
@@ -199,8 +197,11 @@ class GeneratedFieldTestMixin:
m = self.check_constraint_model(a=2)
m.full_clean()
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.a_squared, 4)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.a_squared, 4)
m = self.check_constraint_model(a=-1)
with self.assertRaises(ValidationError) as cm:
@@ -217,8 +218,11 @@ class GeneratedFieldTestMixin:
m = self.unique_constraint_model(a=2)
m.full_clean()
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.a_squared, 4)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.a_squared, 4)
m = self.unique_constraint_model(a=2)
with self.assertRaises(ValidationError) as cm:
@@ -230,8 +234,11 @@ class GeneratedFieldTestMixin:
def test_create(self):
m = self.base_model.objects.create(a=1, b=2)
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 3)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_non_nullable_create(self):
with self.assertRaises(IntegrityError):
@@ -241,26 +248,52 @@ class GeneratedFieldTestMixin:
# Insert.
m = self.base_model(a=2, b=4)
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 6)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 6)
# Update.
m.a = 4
m.save()
m.refresh_from_db()
self.assertEqual(m.field, 8)
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 8)
# Update non-dependent field.
self.base_model.objects.filter(pk=m.pk).update(a=6)
m.save(update_fields=["fk"])
with self.assertNumQueries(0):
self.assertEqual(m.field, 8)
# Update dependent field without persisting local changes.
m.save(update_fields=["b"])
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 10)
# Update dependent field while persisting local changes.
m.a = 8
m.save(update_fields=["a"])
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 12)
def test_save_model_with_pk(self):
m = self.base_model(pk=1, a=1, b=2)
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 3)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_save_model_with_foreign_key(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
m = self.base_model(a=1, b=2, fk=fk_object)
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 3)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, 3)
def test_generated_fields_can_be_deferred(self):
fk_object = Foo.objects.create(a="abc", d=Decimal("12.34"))
@@ -330,17 +363,23 @@ class GeneratedFieldTestMixin:
def test_model_with_params(self):
m = self.params_model.objects.create()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, "Constant")
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m.field, "Constant")
def test_nullable(self):
m1 = self.nullable_model.objects.create()
m1 = self._refresh_if_needed(m1)
none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
self.assertEqual(m1.lower_name, none_val)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m1.lower_name, none_val)
m2 = self.nullable_model.objects.create(name="NaMe")
m2 = self._refresh_if_needed(m2)
self.assertEqual(m2.lower_name, "name")
with self.assertNumQueries(expected_num_queries):
self.assertEqual(m2.lower_name, "name")
@skipUnlessDBFeature("supports_stored_generated_columns")
@@ -354,8 +393,21 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
def test_create_field_with_db_converters(self):
obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
obj = self._refresh_if_needed(obj)
self.assertEqual(obj.field, obj.field_copy)
expected_num_queries = (
0 if connection.features.can_return_columns_from_insert else 1
)
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.field, obj.field_copy)
def test_save_field_with_db_converters(self):
obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4())
obj.field = uuid.uuid4()
expected_num_queries = (
0 if connection.features.can_return_rows_from_update else 1
)
obj.save(update_fields={"field"})
with self.assertNumQueries(expected_num_queries):
self.assertEqual(obj.field, obj.field_copy)
def test_create_with_non_auto_pk(self):
obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2)