1
0
mirror of https://github.com/django/django.git synced 2025-05-07 15:36:29 +00:00

Refs #36075 -- Used field in pk_fields over field.primary_key.

This commit is contained in:
Sarah Boyce 2025-01-09 13:37:08 +01:00
parent d66137b39b
commit bf7b17d16d
8 changed files with 53 additions and 7 deletions

View File

@ -891,8 +891,9 @@ class Model(AltersData, metaclass=ModelBase):
and using == self._state.db and using == self._state.db
): ):
field_names = set() field_names = set()
pk_fields = self._meta.pk_fields
for field in self._meta.concrete_fields: for field in self._meta.concrete_fields:
if not field.primary_key and not hasattr(field, "through"): if field not in pk_fields and not hasattr(field, "through"):
field_names.add(field.attname) field_names.add(field.attname)
loaded_fields = field_names.difference(deferred_non_generated_fields) loaded_fields = field_names.difference(deferred_non_generated_fields)
if loaded_fields: if loaded_fields:
@ -1492,7 +1493,7 @@ class Model(AltersData, metaclass=ModelBase):
): ):
# no value, skip the lookup # no value, skip the lookup
continue continue
if f.primary_key and not self._state.adding: if f in self._meta.pk_fields and not self._state.adding:
# no need to check for unique primary key when editing # no need to check for unique primary key when editing
continue continue
lookup_kwargs[str(field_name)] = lookup_value lookup_kwargs[str(field_name)] = lookup_value

View File

@ -1009,7 +1009,7 @@ class Options:
""" """
names = [] names = []
for field in self.concrete_fields: for field in self.concrete_fields:
if not field.primary_key: if field not in self.pk_fields:
names.append(field.name) names.append(field.name)
if field.name != field.attname: if field.name != field.attname:
names.append(field.attname) names.append(field.attname)

View File

@ -878,7 +878,7 @@ class QuerySet(AltersData):
fields = [self.model._meta.get_field(name) for name in fields] fields = [self.model._meta.get_field(name) for name in fields]
if any(not f.concrete or f.many_to_many for f in fields): if any(not f.concrete or f.many_to_many for f in fields):
raise ValueError("bulk_update() can only be used with concrete fields.") raise ValueError("bulk_update() can only be used with concrete fields.")
if any(f.primary_key for f in fields): if any(f in self.model._meta.pk_fields for f in fields):
raise ValueError("bulk_update() cannot be used with primary key fields.") raise ValueError("bulk_update() cannot be used with primary key fields.")
if not objs: if not objs:
return 0 return 0
@ -995,9 +995,10 @@ class QuerySet(AltersData):
# This is to maintain backward compatibility as these fields # This is to maintain backward compatibility as these fields
# are not updated unless explicitly specified in the # are not updated unless explicitly specified in the
# update_fields list. # update_fields list.
pk_fields = self.model._meta.pk_fields
for field in self.model._meta.local_concrete_fields: for field in self.model._meta.local_concrete_fields:
if not ( if not (
field.primary_key or field.__class__.pre_save is Field.pre_save field in pk_fields or field.__class__.pre_save is Field.pre_save
): ):
update_fields.add(field.name) update_fields.add(field.name)
if field.name != field.attname: if field.name != field.attname:

View File

@ -77,7 +77,8 @@
"model": "composite_pk.timestamped", "model": "composite_pk.timestamped",
"fields": { "fields": {
"id": 1, "id": 1,
"created": "2022-01-12T05:55:14.956" "created": "2022-01-12T05:55:14.956",
"text": ""
} }
} }
] ]

View File

@ -56,3 +56,4 @@ class TimeStamped(models.Model):
pk = models.CompositePrimaryKey("id", "created") pk = models.CompositePrimaryKey("id", "created")
id = models.SmallIntegerField(unique=True) id = models.SmallIntegerField(unique=True)
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
text = models.TextField(default="", blank=True)

View File

@ -118,6 +118,10 @@ class CompositePKModelsTests(TestCase):
self.assertSequenceEqual(ctx.exception.messages, messages) self.assertSequenceEqual(ctx.exception.messages, messages)
def test_full_clean_update(self):
with self.assertNumQueries(1):
self.comment_1.full_clean()
def test_field_conflicts(self): def test_field_conflicts(self):
test_cases = ( test_cases = (
({"pk": (1, 1), "id": 2}, (1, 1)), ({"pk": (1, 1), "id": 2}, (1, 1)),

View File

@ -1,7 +1,7 @@
from django.db import connection from django.db import connection
from django.test import TestCase from django.test import TestCase
from .models import Comment, Tenant, Token, User from .models import Comment, Tenant, TimeStamped, Token, User
class CompositePKUpdateTests(TestCase): class CompositePKUpdateTests(TestCase):
@ -57,6 +57,28 @@ class CompositePKUpdateTests(TestCase):
self.assertEqual(user.email, email) self.assertEqual(user.email, email)
self.assertEqual(count, User.objects.count()) self.assertEqual(count, User.objects.count())
def test_update_fields_deferred(self):
c = Comment.objects.defer("text", "user_id").get(pk=self.comment_1.pk)
c.text = "Hello"
with self.assertNumQueries(1) as ctx:
c.save()
sql = ctx[0]["sql"]
self.assertEqual(sql.count(connection.ops.quote_name("tenant_id")), 1)
self.assertEqual(sql.count(connection.ops.quote_name("comment_id")), 1)
c = Comment.objects.get(pk=self.comment_1.pk)
self.assertEqual(c.text, "Hello")
def test_update_fields_pk_field(self):
msg = (
"The following fields do not exist in this model, are m2m fields, "
"or are non-concrete fields: id"
)
with self.assertRaisesMessage(ValueError, msg):
self.user_1.save(update_fields=["id"])
def test_bulk_update_comments(self): def test_bulk_update_comments(self):
comment_1 = Comment.objects.get(pk=self.comment_1.pk) comment_1 = Comment.objects.get(pk=self.comment_1.pk)
comment_2 = Comment.objects.get(pk=self.comment_2.pk) comment_2 = Comment.objects.get(pk=self.comment_2.pk)
@ -77,6 +99,11 @@ class CompositePKUpdateTests(TestCase):
self.assertEqual(comment_2.text, "bar") self.assertEqual(comment_2.text, "bar")
self.assertEqual(comment_3.text, "baz") self.assertEqual(comment_3.text, "baz")
def test_bulk_update_primary_key_fields(self):
message = "bulk_update() cannot be used with primary key fields."
with self.assertRaisesMessage(ValueError, message):
Comment.objects.bulk_update([self.comment_1, self.comment_2], ["id"])
def test_update_or_create_user(self): def test_update_or_create_user(self):
test_cases = ( test_cases = (
{ {
@ -110,6 +137,16 @@ class CompositePKUpdateTests(TestCase):
self.assertEqual(user.email, fields["defaults"]["email"]) self.assertEqual(user.email, fields["defaults"]["email"])
self.assertEqual(count, User.objects.count()) self.assertEqual(count, User.objects.count())
def test_update_or_create_with_pre_save_pk_field(self):
t = TimeStamped.objects.create(id=1)
self.assertEqual(TimeStamped.objects.count(), 1)
t, created = TimeStamped.objects.update_or_create(
pk=t.pk, defaults={"text": "new text"}
)
self.assertIs(created, False)
self.assertEqual(TimeStamped.objects.count(), 1)
self.assertEqual(t.text, "new text")
def test_update_comment_by_user_email(self): def test_update_comment_by_user_email(self):
result = Comment.objects.filter(user__email=self.user_1.email).update( result = Comment.objects.filter(user__email=self.user_1.email).update(
text="foo" text="foo"

View File

@ -340,6 +340,7 @@ class CompositePKFixturesTests(TestCase):
"fields": { "fields": {
"id": 1, "id": 1,
"created": "2022-01-12T05:55:14.956", "created": "2022-01-12T05:55:14.956",
"text": "",
}, },
}, },
], ],