mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #24343 -- Ensure db converters are used for foreign keys.
Joint effort between myself, Josh, Anssi and Shai.
This commit is contained in:
@@ -585,10 +585,10 @@ class Random(ExpressionNode):
|
||||
|
||||
|
||||
class Col(ExpressionNode):
|
||||
def __init__(self, alias, target, source=None):
|
||||
if source is None:
|
||||
source = target
|
||||
super(Col, self).__init__(output_field=source)
|
||||
def __init__(self, alias, target, output_field=None):
|
||||
if output_field is None:
|
||||
output_field = target
|
||||
super(Col, self).__init__(output_field=output_field)
|
||||
self.alias, self.target = alias, target
|
||||
|
||||
def __repr__(self):
|
||||
@@ -606,7 +606,10 @@ class Col(ExpressionNode):
|
||||
return [self]
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return self.output_field.get_db_converters(connection)
|
||||
if self.target == self.output_field:
|
||||
return self.output_field.get_db_converters(connection)
|
||||
return (self.output_field.get_db_converters(connection) +
|
||||
self.target.get_db_converters(connection))
|
||||
|
||||
|
||||
class Ref(ExpressionNode):
|
||||
|
@@ -330,12 +330,12 @@ class Field(RegisterLookupMixin):
|
||||
]
|
||||
return []
|
||||
|
||||
def get_col(self, alias, source=None):
|
||||
if source is None:
|
||||
source = self
|
||||
if alias != self.model._meta.db_table or source != self:
|
||||
def get_col(self, alias, output_field=None):
|
||||
if output_field is None:
|
||||
output_field = self
|
||||
if alias != self.model._meta.db_table or output_field != self:
|
||||
from django.db.models.expressions import Col
|
||||
return Col(alias, self, source)
|
||||
return Col(alias, self, output_field)
|
||||
else:
|
||||
return self.cached_col
|
||||
|
||||
|
@@ -2064,6 +2064,20 @@ class ForeignKey(ForeignObject):
|
||||
def db_parameters(self, connection):
|
||||
return {"type": self.db_type(connection), "check": []}
|
||||
|
||||
def convert_empty_strings(self, value, connection, context):
|
||||
if (not value) and isinstance(value, six.string_types):
|
||||
return None
|
||||
return value
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
converters = super(ForeignKey, self).get_db_converters(connection)
|
||||
if connection.features.interprets_empty_strings_as_nulls:
|
||||
converters += [self.convert_empty_strings]
|
||||
return converters
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
return super(ForeignKey, self).get_col(alias, output_field or self.related_field)
|
||||
|
||||
|
||||
class OneToOneField(ForeignKey):
|
||||
"""
|
||||
|
@@ -57,7 +57,7 @@ class ModelIterator(BaseIterator):
|
||||
model_cls = klass_info['model']
|
||||
select_fields = klass_info['select_fields']
|
||||
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
|
||||
init_list = [f[0].output_field.attname
|
||||
init_list = [f[0].target.attname
|
||||
for f in select[model_fields_start:model_fields_end]]
|
||||
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||
init_set = set(init_list)
|
||||
@@ -1618,7 +1618,7 @@ class RelatedPopulator(object):
|
||||
self.cols_start = select_fields[0]
|
||||
self.cols_end = select_fields[-1] + 1
|
||||
self.init_list = [
|
||||
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
|
||||
f[0].target.attname for f in select[self.cols_start:self.cols_end]
|
||||
]
|
||||
self.reorder_for_init = None
|
||||
else:
|
||||
@@ -1627,7 +1627,7 @@ class RelatedPopulator(object):
|
||||
]
|
||||
reorder_map = []
|
||||
for idx in select_fields:
|
||||
field = select[idx][0].output_field
|
||||
field = select[idx][0].target
|
||||
init_pos = model_init_attnames.index(field.attname)
|
||||
reorder_map.append((init_pos, field.attname, idx))
|
||||
reorder_map.sort()
|
||||
|
@@ -1458,7 +1458,7 @@ class Query(object):
|
||||
# database from tripping over IN (...,NULL,...) selects and returning
|
||||
# nothing
|
||||
col = query.select[0]
|
||||
select_field = col.field
|
||||
select_field = col.target
|
||||
alias = col.alias
|
||||
if self.is_nullable(select_field):
|
||||
lookup_class = select_field.get_lookup('isnull')
|
||||
|
@@ -369,6 +369,10 @@ class PrimaryKeyUUIDModel(models.Model):
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
|
||||
|
||||
|
||||
class RelatedToUUIDModel(models.Model):
|
||||
uuid_fk = models.ForeignKey('PrimaryKeyUUIDModel')
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
||||
# See ticket #24215.
|
||||
|
@@ -5,7 +5,9 @@ from django.core import exceptions, serializers
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import NullableUUIDModel, PrimaryKeyUUIDModel, UUIDModel
|
||||
from .models import (
|
||||
NullableUUIDModel, PrimaryKeyUUIDModel, RelatedToUUIDModel, UUIDModel,
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoad(TestCase):
|
||||
@@ -121,3 +123,9 @@ class TestAsPrimaryKey(TestCase):
|
||||
self.assertTrue(u1_found)
|
||||
self.assertTrue(u2_found)
|
||||
self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2)
|
||||
|
||||
def test_underlying_field(self):
|
||||
pk_model = PrimaryKeyUUIDModel.objects.create()
|
||||
RelatedToUUIDModel.objects.create(uuid_fk=pk_model)
|
||||
related = RelatedToUUIDModel.objects.get()
|
||||
self.assertEqual(related.uuid_fk.pk, related.uuid_fk_id)
|
||||
|
Reference in New Issue
Block a user