1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Fixed #21612 -- Made QuerySet.update() respect to_field

This commit is contained in:
Karen Tracey
2014-11-15 14:36:41 -05:00
committed by Tim Graham
parent de912495ab
commit dec93d8991
3 changed files with 24 additions and 3 deletions

View File

@@ -826,10 +826,10 @@ class Model(six.with_metaclass(ModelBase)):
setattr(self, cachename, obj) setattr(self, cachename, obj)
return getattr(self, cachename) return getattr(self, cachename)
def prepare_database_save(self, unused): def prepare_database_save(self, field):
if self.pk is None: if self.pk is None:
raise ValueError("Unsaved model instance %r cannot be used in an ORM query." % self) raise ValueError("Unsaved model instance %r cannot be used in an ORM query." % self)
return self.pk return getattr(self, field.rel.field_name)
def clean(self): def clean(self):
""" """

View File

@@ -42,3 +42,11 @@ class C(models.Model):
class D(C): class D(C):
a = models.ForeignKey(A) a = models.ForeignKey(A)
class Foo(models.Model):
target = models.CharField(max_length=10, unique=True)
class Bar(models.Model):
foo = models.ForeignKey(Foo, to_field='target')

View File

@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from .models import A, B, D, DataPoint, RelatedPoint from .models import A, B, D, DataPoint, RelatedPoint, Foo, Bar
class SimpleTest(TestCase): class SimpleTest(TestCase):
@@ -125,3 +125,16 @@ class AdvancedTests(TestCase):
method = DataPoint.objects.all()[:2].update method = DataPoint.objects.all()[:2].update
self.assertRaises(AssertionError, method, self.assertRaises(AssertionError, method,
another_value='another thing') another_value='another thing')
def test_update_respects_to_field(self):
"""
Update of an FK field which specifies a to_field works.
"""
a_foo = Foo.objects.create(target='aaa')
b_foo = Foo.objects.create(target='bbb')
bar = Bar.objects.create(foo=a_foo)
self.assertEqual(bar.foo_id, a_foo.target)
bar_qs = Bar.objects.filter(pk=bar.pk)
self.assertEqual(bar_qs[0].foo_id, a_foo.target)
bar_qs.update(foo=b_foo)
self.assertEqual(bar_qs[0].foo_id, b_foo.target)