1
0
mirror of https://github.com/django/django.git synced 2025-10-23 21:59:11 +00:00

Fixed #19385 again, now with real code changes

The commit of 266de5f9ae included only
tests, this time also code changes included...
This commit is contained in:
Anssi Kääriäinen
2013-03-24 18:40:40 +02:00
parent 266de5f9ae
commit 97774429ae
17 changed files with 653 additions and 426 deletions

View File

@@ -8,10 +8,11 @@ from functools import partial
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.models import signals
from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
from django.db.models import signals
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo
from django.db.models.sql.where import Constraint
from django.forms import ModelForm
from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@@ -149,17 +150,14 @@ class GenericForeignKey(six.with_metaclass(RenameGenericForeignKeyMethods)):
setattr(instance, self.fk_field, fk)
setattr(instance, self.cache_attr, value)
class GenericRelation(RelatedField, Field):
class GenericRelation(ForeignObject):
"""Provides an accessor to generic related objects (e.g. comments)"""
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
kwargs['rel'] = GenericRel(to,
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
symmetrical=kwargs.pop('symmetrical', True))
kwargs['rel'] = GenericRel(
self, to, related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),)
# Override content-type/object-id field names on the related class
self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
@@ -167,47 +165,44 @@ class GenericRelation(RelatedField, Field):
kwargs['blank'] = True
kwargs['editable'] = False
kwargs['serialize'] = False
Field.__init__(self, **kwargs)
# This construct is somewhat of an abuse of ForeignObject. This field
# represents a relation from pk to object_id field. But, this relation
# isn't direct, the join is generated reverse along foreign key. So,
# the from_field is object_id field, to_field is pk because of the
# reverse join.
super(GenericRelation, self).__init__(
to, to_fields=[],
from_fields=[self.object_id_field_name], **kwargs)
def get_path_info(self):
from_field = self.model._meta.pk
def resolve_related_fields(self):
self.to_fields = [self.model._meta.pk.name]
return [(self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0],
self.model._meta.pk)]
def get_reverse_path_info(self):
opts = self.rel.to._meta
target = opts.get_field_by_name(self.object_id_field_name)[0]
# Note that we are using different field for the join_field
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
opts, target, self)
return [PathInfo(self.model._meta, opts, (target,), self.rel, True, False)]
def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
return super(GenericRelation, self).get_choices(include_blank=False)
def value_to_string(self, obj):
qs = getattr(obj, self.name).all()
return smart_text([instance._get_pk_val() for instance in qs])
def m2m_db_table(self):
return self.rel.to._meta.db_table
def m2m_column_name(self):
return self.object_id_field_name
def m2m_reverse_name(self):
return self.rel.to._meta.pk.column
def m2m_target_field_name(self):
return self.model._meta.pk.name
def m2m_reverse_target_field_name(self):
return self.rel.to._meta.pk.name
def get_joining_columns(self, reverse_join=False):
if not reverse_join:
# This error message is meant for the user, and from user
# perspective this is a reverse join along the GenericRelation.
raise ValueError('Joining in reverse direction not allowed.')
return super(GenericRelation, self).get_joining_columns(reverse_join)
def contribute_to_class(self, cls, name):
super(GenericRelation, self).contribute_to_class(cls, name)
super(GenericRelation, self).contribute_to_class(cls, name, virtual_only=True)
# Save a reference to which model this class is on for future use
self.model = cls
# Add the descriptor for the m2m relation
# Add the descriptor for the relation
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))
def contribute_to_related_class(self, cls, related):
@@ -219,21 +214,18 @@ class GenericRelation(RelatedField, Field):
def get_internal_type(self):
return "ManyToManyField"
def db_type(self, connection):
# Since we're simulating a ManyToManyField, in effect, best return the
# same db_type as well.
return None
def get_content_type(self):
"""
Returns the content type associated with this field's model.
"""
return ContentType.objects.get_for_model(self.model)
def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
contenttype = self.get_content_type().pk
return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
def get_extra_restriction(self, where_class, alias, remote_alias):
field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0]
contenttype_pk = self.get_content_type().pk
cond = where_class()
cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND')
return cond
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
@@ -273,12 +265,12 @@ class ReverseGenericRelatedObjectsDescriptor(object):
qn = connection.ops.quote_name
content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance)
join_cols = self.field.get_joining_columns(reverse_join=True)[0]
manager = RelatedManager(
model = rel_model,
instance = instance,
symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model),
source_col_name = qn(self.field.m2m_column_name()),
target_col_name = qn(self.field.m2m_reverse_name()),
source_col_name = qn(join_cols[0]),
target_col_name = qn(join_cols[1]),
content_type = content_type,
content_type_field_name = self.field.content_type_field_name,
object_id_field_name = self.field.object_id_field_name,
@@ -378,14 +370,10 @@ def create_generic_related_manager(superclass):
return GenericRelatedObjectManager
class GenericRel(ManyToManyRel):
def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
self.to = to
self.related_name = related_name
self.limit_choices_to = limit_choices_to or {}
self.symmetrical = symmetrical
self.multiple = True
self.through = None
class GenericRel(ForeignObjectRel):
def __init__(self, field, to, related_name=None, limit_choices_to=None):
super(GenericRel, self).__init__(field, to, related_name, limit_choices_to)
class BaseGenericInlineFormSet(BaseModelFormSet):
"""

View File

@@ -153,8 +153,16 @@ def get_validation_errors(outfile, app=None):
continue
# Make sure the related field specified by a ForeignKey is unique
if not f.rel.to._meta.get_field(f.rel.field_name).unique:
e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__))
if f.requires_unique_target:
if len(f.foreign_related_fields) > 1:
has_unique_field = False
for rel_field in f.foreign_related_fields:
has_unique_field = has_unique_field or rel_field.unique
if not has_unique_field:
e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__))
else:
if not f.foreign_related_fields[0].unique:
e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__))
rel_opts = f.rel.to._meta
rel_name = f.related.get_accessor_name()

View File

@@ -17,6 +17,12 @@ class SQLCompiler(compiler.SQLCompiler):
values.append(value)
return row[:index_extra_select] + tuple(values)
def as_subquery_condition(self, alias, columns):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass

View File

@@ -8,7 +8,7 @@ from django.db.models.aggregates import *
from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
from django.db.models import signals
from django.utils.decorators import wraps

View File

@@ -10,7 +10,7 @@ from django.conf import settings
from django.core.exceptions import (ObjectDoesNotExist,
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (ManyToOneRel,
from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
OneToOneField, add_lazy_relation)
from django.db import (router, transaction, DatabaseError,
DEFAULT_DB_ALIAS)
@@ -333,12 +333,12 @@ class Model(six.with_metaclass(ModelBase)):
# The reason for the kwargs check is that standard iterator passes in by
# args, and instantiation for iteration is 33% faster.
args_len = len(args)
if args_len > len(self._meta.fields):
if args_len > len(self._meta.concrete_fields):
# Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields")
fields_iter = iter(self._meta.fields)
if not kwargs:
fields_iter = iter(self._meta.concrete_fields)
# The ordering of the zip calls matter - zip throws StopIteration
# when an iter throws it. So if the first iter throws it, the second
# is *not* consumed. We rely on this, so don't change the order
@@ -347,6 +347,7 @@ class Model(six.with_metaclass(ModelBase)):
setattr(self, field.attname, val)
else:
# Slower, kwargs-ready version.
fields_iter = iter(self._meta.fields)
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
kwargs.pop(field.name, None)
@@ -363,11 +364,12 @@ class Model(six.with_metaclass(ModelBase)):
# data-descriptor object (DeferredAttribute) without triggering its
# __get__ method.
if (field.attname not in kwargs and
isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)):
(isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)
or field.column is None)):
# This field will be populated on request.
continue
if kwargs:
if isinstance(field.rel, ManyToOneRel):
if isinstance(field.rel, ForeignObjectRel):
try:
# Assume object instance was passed in.
rel_obj = kwargs.pop(field.name)
@@ -394,6 +396,7 @@ class Model(six.with_metaclass(ModelBase)):
val = field.get_default()
else:
val = field.get_default()
if is_related_object:
# If we are passed a related instance, set it using the
# field.name instead of field.attname (e.g. "user" instead of
@@ -528,7 +531,7 @@ class Model(six.with_metaclass(ModelBase)):
# automatically do a "update_fields" save on the loaded fields.
elif not force_insert and self._deferred and using == self._state.db:
field_names = set()
for field in self._meta.fields:
for field in self._meta.concrete_fields:
if not field.primary_key and not hasattr(field, 'through'):
field_names.add(field.attname)
deferred_fields = [
@@ -614,7 +617,7 @@ class Model(six.with_metaclass(ModelBase)):
for a single table.
"""
meta = cls._meta
non_pks = [f for f in meta.local_fields if not f.primary_key]
non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]
if update_fields:
non_pks = [f for f in non_pks
@@ -652,7 +655,7 @@ class Model(six.with_metaclass(ModelBase)):
**{field.name: getattr(self, field.attname)}).count()
self._order = order_value
fields = meta.local_fields
fields = meta.local_concrete_fields
if not pk_set:
fields = [f for f in fields if not isinstance(f, AutoField)]

View File

@@ -1,4 +1,3 @@
from functools import wraps
from operator import attrgetter
from django.db import connections, transaction, IntegrityError
@@ -196,17 +195,13 @@ class Collector(object):
self.fast_deletes.append(sub_objs)
elif sub_objs:
field.rel.on_delete(self, field, sub_objs, self.using)
# TODO This entire block is only needed as a special case to
# support cascade-deletes for GenericRelation. It should be
# removed/fixed when the ORM gains a proper abstraction for virtual
# or composite fields, and GFKs are reworked to fit into that.
for relation in model._meta.many_to_many:
if not relation.rel.through:
sub_objs = relation.bulk_related_objects(new_objs, self.using)
for field in model._meta.virtual_fields:
if hasattr(field, 'bulk_related_objects'):
# Its something like generic foreign key.
sub_objs = field.bulk_related_objects(new_objs, self.using)
self.collect(sub_objs,
source=model,
source_attr=relation.rel.related_name,
source_attr=field.rel.related_name,
nullable=True)
def related_objects(self, related, objs):

View File

@@ -292,10 +292,13 @@ class Field(object):
if self.verbose_name is None and self.name:
self.verbose_name = self.name.replace('_', ' ')
def contribute_to_class(self, cls, name):
def contribute_to_class(self, cls, name, virtual_only=False):
self.set_attributes_from_name(name)
self.model = cls
cls._meta.add_field(self)
if virtual_only:
cls._meta.add_virtual_field(self)
else:
cls._meta.add_field(self)
if self.choices:
setattr(cls, 'get_%s_display' % self.name,
curry(cls._get_FIELD_display, field=self))

View File

@@ -7,7 +7,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE
from django.utils.encoding import smart_text
from django.utils import six
@@ -93,22 +92,27 @@ signals.class_prepared.connect(do_pending_lookups)
#HACK
class RelatedField(object):
def contribute_to_class(self, cls, name):
class RelatedField(Field):
def db_type(self, connection):
'''By default related field will not have a column
as it relates columns to another table'''
return None
def contribute_to_class(self, cls, name, virtual_only=False):
sup = super(RelatedField, self)
# Store the opts for related_query_name()
self.opts = cls._meta
if hasattr(sup, 'contribute_to_class'):
sup.contribute_to_class(cls, name)
sup.contribute_to_class(cls, name, virtual_only=virtual_only)
if not cls._meta.abstract and self.rel.related_name:
self.rel.related_name = self.rel.related_name % {
'class': cls.__name__.lower(),
'app_label': cls._meta.app_label.lower(),
}
related_name = self.rel.related_name % {
'class': cls.__name__.lower(),
'app_label': cls._meta.app_label.lower()
}
self.rel.related_name = related_name
other = self.rel.to
if isinstance(other, six.string_types) or other._meta.pk is None:
def resolve_related_class(field, model, cls):
@@ -122,7 +126,6 @@ class RelatedField(object):
self.name = self.name or (self.rel.to._meta.model_name + '_' + self.rel.to._meta.pk.name)
if self.verbose_name is None:
self.verbose_name = self.rel.to._meta.verbose_name
self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name
def do_related_class(self, other, cls):
self.set_attributes_from_rel()
@@ -130,94 +133,6 @@ class RelatedField(object):
if not cls._meta.abstract:
self.contribute_to_related_class(other, self.related)
def get_prep_lookup(self, lookup_type, value):
if hasattr(value, 'prepare'):
return value.prepare()
if hasattr(value, '_prepare'):
return value._prepare()
# FIXME: lt and gt are explicitly allowed to make
# get_(next/prev)_by_date work; other lookups are not allowed since that
# gets messy pretty quick. This is a good candidate for some refactoring
# in the future.
if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
return self._pk_trace(value, 'get_prep_lookup', lookup_type)
if lookup_type in ('range', 'in'):
return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value]
elif lookup_type == 'isnull':
return []
raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if not prepared:
value = self.get_prep_lookup(lookup_type, value)
if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabeled_clone method it means the
# value will be handled later on.
if hasattr(value, 'relabeled_clone'):
return value
if hasattr(value, 'as_sql'):
sql, params = value.as_sql()
else:
sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
# FIXME: lt and gt are explicitly allowed to make
# get_(next/prev)_by_date work; other lookups are not allowed since that
# gets messy pretty quick. This is a good candidate for some refactoring
# in the future.
if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type,
connection=connection, prepared=prepared)]
if lookup_type in ('range', 'in'):
return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type,
connection=connection, prepared=prepared)
for v in value]
elif lookup_type == 'isnull':
return []
raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
def _pk_trace(self, value, prep_func, lookup_type, **kwargs):
# Value may be a primary key, or an object held in a relation.
# If it is an object, then we need to get the primary key value for
# that object. In certain conditions (especially one-to-one relations),
# the primary key may itself be an object - so we need to keep drilling
# down until we hit a value that can be used for a comparison.
v = value
# In the case of an FK to 'self', this check allows to_field to be used
# for both forwards and reverse lookups across the FK. (For normal FKs,
# it's only relevant for forward lookups).
if isinstance(v, self.rel.to):
field_name = getattr(self.rel, "field_name", None)
else:
field_name = None
try:
while True:
if field_name is None:
field_name = v._meta.pk.name
v = getattr(v, field_name)
field_name = None
except AttributeError:
pass
except exceptions.ObjectDoesNotExist:
v = None
field = self
while field.rel:
if hasattr(field.rel, 'field_name'):
field = field.rel.to._meta.get_field(field.rel.field_name)
else:
field = field.rel.to._meta.pk
if lookup_type in ('range', 'in'):
v = [v]
v = getattr(field, prep_func)(lookup_type, v, **kwargs)
if isinstance(v, list):
v = v[0]
return v
def related_query_name(self):
# This method defines the name that can be used to identify this
# related object in a table-spanning query. It uses the lower-cased
@@ -254,8 +169,8 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
rel_obj_attr = attrgetter(self.related.field.attname)
instance_attr = lambda obj: obj._get_pk_val()
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
params = {'%s__pk__in' % self.related.field.name: list(instances_dict)}
qs = self.get_queryset(instance=instances[0]).filter(**params)
query = {'%s__in' % self.related.field.name: instances}
qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
rel_obj_cache_name = self.related.field.get_cache_name()
@@ -274,7 +189,9 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
if related_pk is None:
rel_obj = None
else:
params = {'%s__pk' % self.related.field.name: related_pk}
params = {}
for lh_field, rh_field in self.related.field.related_fields:
params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
try:
rel_obj = self.get_queryset(instance=instance).get(**params)
except self.related.model.DoesNotExist:
@@ -314,13 +231,14 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
(value, instance._state.db, value._state.db))
related_pk = getattr(instance, self.related.field.rel.get_related_field().attname)
if related_pk is None:
related_pk = tuple([getattr(instance, field.attname) for field in self.related.field.foreign_related_fields])
if None in related_pk:
raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' %
(value, instance._meta.object_name))
# Set the value of the related field to the value of the related object's related field
setattr(value, self.related.field.attname, related_pk)
for index, field in enumerate(self.related.field.local_related_fields):
setattr(value, field.attname, related_pk[index])
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -352,16 +270,12 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
else:
return QuerySet(self.field.rel.to).using(db)
def get_prefetch_queryset(self, instances):
other_field = self.field.rel.get_related_field()
rel_obj_attr = attrgetter(other_field.attname)
instance_attr = attrgetter(self.field.attname)
def get_prefetch_query_set(self, instances):
rel_obj_attr = self.field.get_foreign_related_value
instance_attr = self.field.get_local_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
if other_field.rel:
params = {'%s__pk__in' % self.field.rel.field_name: list(instances_dict)}
else:
params = {'%s__in' % self.field.rel.field_name: list(instances_dict)}
qs = self.get_queryset(instance=instances[0]).filter(**params)
query = {'%s__in' % self.field.related_query_name(): instances}
qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
if not self.field.rel.multiple:
@@ -377,16 +291,14 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
try:
rel_obj = getattr(instance, self.cache_name)
except AttributeError:
val = getattr(instance, self.field.attname)
if val is None:
val = self.field.get_local_related_value(instance)
if None in val:
rel_obj = None
else:
other_field = self.field.rel.get_related_field()
if other_field.rel:
params = {'%s__%s' % (self.field.rel.field_name, other_field.rel.field_name): val}
else:
params = {'%s__exact' % self.field.rel.field_name: val}
qs = self.get_queryset(instance=instance)
params = {rh_field.attname: getattr(instance, lh_field.attname)
for lh_field, rh_field in self.field.related_fields}
params.update(self.field.get_extra_descriptor_filter(instance))
qs = self.get_query_set(instance=instance)
# Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get(**params)
if not self.field.rel.multiple:
@@ -440,11 +352,11 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
setattr(related, self.field.related.get_cache_name(), None)
# Set the value of the related field
try:
val = getattr(value, self.field.rel.get_related_field().attname)
except AttributeError:
val = None
setattr(instance, self.field.attname, val)
for lh_field, rh_field in self.field.related_fields:
try:
setattr(instance, lh_field.attname, getattr(value, rh_field.attname))
except AttributeError:
setattr(instance, lh_field.attname, None)
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -487,15 +399,12 @@ class ForeignRelatedObjectsDescriptor(object):
superclass = self.related.model._default_manager.__class__
rel_field = self.related.field
rel_model = self.related.model
attname = rel_field.rel.get_related_field().attname
class RelatedManager(superclass):
def __init__(self, instance):
super(RelatedManager, self).__init__()
self.instance = instance
self.core_filters = {
'%s__%s' % (rel_field.name, attname): getattr(instance, attname)
}
self.core_filters= {'%s__exact' % rel_field.name: instance}
self.model = rel_model
def get_queryset(self):
@@ -504,20 +413,22 @@ class ForeignRelatedObjectsDescriptor(object):
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters)
val = getattr(self.instance, attname)
if val is None or val == '' and connections[db].features.interprets_empty_strings_as_nulls:
return qs.none()
empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
for field in rel_field.foreign_related_fields:
val = getattr(self.instance, field.attname)
if val is None or (val == '' and empty_strings_as_null):
return qs.none()
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_queryset(self, instances):
rel_obj_attr = attrgetter(rel_field.attname)
instance_attr = attrgetter(attname)
rel_obj_attr = rel_field.get_local_related_value
instance_attr = rel_field.get_foreign_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
db = self._db or router.db_for_read(self.model, instance=instances[0])
query = {'%s__%s__in' % (rel_field.name, attname): list(instances_dict)}
qs = super(RelatedManager, self).get_queryset().using(db).filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
query = {'%s__in' % rel_field.name: instances}
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
# Since we just bypassed this class' get_query_set(), we must manage
# the reverse relation manually.
for rel_obj in qs:
instance = instances_dict[rel_obj_attr(rel_obj)]
@@ -550,10 +461,10 @@ class ForeignRelatedObjectsDescriptor(object):
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null:
def remove(self, *objs):
val = getattr(self.instance, attname)
val = rel_field.get_foreign_related_value(self.instance)
for obj in objs:
# Is obj actually part of this descriptor set?
if getattr(obj, rel_field.attname) == val:
if rel_field.get_local_related_value(obj) == val:
setattr(obj, rel_field.name, None)
obj.save()
else:
@@ -577,16 +488,26 @@ def create_many_related_manager(superclass, rel):
super(ManyRelatedManager, self).__init__()
self.model = model
self.query_field_name = query_field_name
self.core_filters = {'%s__pk' % query_field_name: instance._get_pk_val()}
source_field = through._meta.get_field(source_field_name)
source_related_fields = source_field.related_fields
self.core_filters = {}
for lh_field, rh_field in source_related_fields:
self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
self.instance = instance
self.symmetrical = symmetrical
self.source_field = source_field
self.source_field_name = source_field_name
self.target_field_name = target_field_name
self.reverse = reverse
self.through = through
self.prefetch_cache_name = prefetch_cache_name
self._fk_val = self._get_fk_val(instance, source_field_name)
if self._fk_val is None:
self.related_val = source_field.get_foreign_related_value(instance)
# Used for single column related auto created models
self._fk_val = self.related_val[0]
if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, source_field_name))
@@ -620,11 +541,9 @@ def create_many_related_manager(superclass, rel):
def get_prefetch_queryset(self, instances):
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
query = {'%s__pk__in' % self.query_field_name:
set(obj._get_pk_val() for obj in instances)}
qs = super(ManyRelatedManager, self).get_queryset().using(db)._next_is_sticky().filter(**query)
query = {'%s__in' % self.query_field_name: instances}
qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
# that the secondary model was actually related to. We know that
@@ -634,16 +553,14 @@ def create_many_related_manager(superclass, rel):
# For non-autocreated 'through' models, can't assume we are
# dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name)
source_col = fk.column
join_table = self.through._meta.db_table
connection = connections[db]
qn = connection.ops.quote_name
qs = qs.extra(select={'_prefetch_related_val':
'%s.%s' % (qn(join_table), qn(source_col))})
select_attname = fk.rel.get_related_field().get_attname()
qs = qs.extra(select={'_prefetch_related_val_%s' % f.attname:
'%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(select_attname),
lambda result: tuple([getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields]),
lambda inst: tuple([getattr(inst, f.attname) for f in fk.foreign_related_fields]),
False,
self.prefetch_cache_name)
@@ -795,7 +712,7 @@ def create_many_related_manager(superclass, rel):
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{
source_field_name: self._fk_val
source_field_name: self.related_val
}).delete()
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the
@@ -918,19 +835,18 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager.clear()
manager.add(*value)
class ManyToOneRel(object):
def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
class ForeignObjectRel(object):
def __init__(self, field, to, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
try:
to._meta
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
self.to, self.field_name = to, field_name
self.field = field
self.to = to
self.related_name = related_name
if limit_choices_to is None:
limit_choices_to = {}
self.limit_choices_to = limit_choices_to
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
self.multiple = True
self.parent_link = parent_link
self.on_delete = on_delete
@@ -939,6 +855,20 @@ class ManyToOneRel(object):
"Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+'
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, where_class, alias, related_alias):
return self.field.get_extra_restriction(where_class, related_alias, alias)
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
super(ManyToOneRel, self).__init__(
field, to, related_name=related_name, limit_choices_to=limit_choices_to,
parent_link=parent_link, on_delete=on_delete)
self.field_name = field_name
def get_related_field(self):
"""
Returns the Field in the 'to' object to which this relationship is
@@ -952,9 +882,9 @@ class ManyToOneRel(object):
class OneToOneRel(ManyToOneRel):
def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
super(OneToOneRel, self).__init__(to, field_name,
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
super(OneToOneRel, self).__init__(field, to, field_name,
related_name=related_name, limit_choices_to=limit_choices_to,
parent_link=parent_link, on_delete=on_delete
)
@@ -963,7 +893,7 @@ class OneToOneRel(ManyToOneRel):
class ManyToManyRel(object):
def __init__(self, to, related_name=None, limit_choices_to=None,
symmetrical=True, through=None, db_constraint=True):
symmetrical=True, through=None, db_constraint=True):
if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False")
self.to = to
@@ -989,7 +919,199 @@ class ManyToManyRel(object):
return self.to._meta.pk
class ForeignKey(RelatedField, Field):
class ForeignObject(RelatedField):
requires_unique_target = True
generate_reverse_relation = True
def __init__(self, to, from_fields, to_fields, **kwargs):
self.from_fields = from_fields
self.to_fields = to_fields
if 'rel' not in kwargs:
kwargs['rel'] = ForeignObjectRel(
self, to,
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE),
)
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
super(ForeignObject, self).__init__(**kwargs)
def resolve_related_fields(self):
if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields):
raise ValueError('Foreign Object from and to fields must be the same non-zero length')
related_fields = []
for index in range(len(self.from_fields)):
from_field_name = self.from_fields[index]
to_field_name = self.to_fields[index]
from_field = (self if from_field_name == 'self'
else self.opts.get_field_by_name(from_field_name)[0])
to_field = (self.rel.to._meta.pk if to_field_name is None
else self.rel.to._meta.get_field_by_name(to_field_name)[0])
related_fields.append((from_field, to_field))
return related_fields
@property
def related_fields(self):
if not hasattr(self, '_related_fields'):
self._related_fields = self.resolve_related_fields()
return self._related_fields
@property
def reverse_related_fields(self):
return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields]
@property
def local_related_fields(self):
return tuple([lhs_field for lhs_field, rhs_field in self.related_fields])
@property
def foreign_related_fields(self):
return tuple([rhs_field for lhs_field, rhs_field in self.related_fields])
def get_local_related_value(self, instance):
return self.get_instance_value_for_fields(instance, self.local_related_fields)
def get_foreign_related_value(self, instance):
return self.get_instance_value_for_fields(instance, self.foreign_related_fields)
@staticmethod
def get_instance_value_for_fields(instance, fields):
return tuple([getattr(instance, field.attname) for field in fields])
def get_attname_column(self):
attname, column = super(ForeignObject, self).get_attname_column()
return attname, None
def get_joining_columns(self, reverse_join=False):
source = self.reverse_related_fields if reverse_join else self.related_fields
return tuple([(lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source])
def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True)
def get_extra_descriptor_filter(self, instance):
"""
Returns an extra filter condition for related object fetching when
user does 'instance.fieldname', that is the extra filter is used in
the descriptor of the field.
The filter should be something usable in .filter(**kwargs) call, and
will be ANDed together with the joining columns condition.
A parallel method is get_extra_relation_restriction() which is used in
JOIN and subquery conditions.
"""
return {}
def get_extra_restriction(self, where_class, alias, related_alias):
"""
Returns a pair condition used for joining and subquery pushdown. The
condition is something that responds to as_sql(qn, connection) method.
Note that currently referring both the 'alias' and 'related_alias'
will not work in some conditions, like subquery pushdown.
A parallel method is get_extra_descriptor_filter() which is used in
instance.fieldname related object fetching.
"""
return None
def get_path_info(self):
"""
Get path from this field to the related model.
"""
opts = self.rel.to._meta
from_opts = self.model._meta
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
def get_reverse_path_info(self):
"""
Get path from the related model to this field's model.
"""
opts = self.model._meta
from_opts = self.rel.to._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
raw_value):
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
root_constraint = constraint_class()
assert len(targets) == len(sources)
def get_normalized_value(value):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
for source in sources:
# Account for one-to-one relations when sent a different model
while not isinstance(value, source.model):
source = source.rel.to._meta.get_field(source.rel.field_name)
value_list.append(getattr(value, source.attname))
return tuple(value_list)
elif not isinstance(value, tuple):
return (value,)
return value
is_multicolumn = len(self.related_fields) > 1
if (hasattr(raw_value, '_as_sql') or
hasattr(raw_value, 'get_compiler')):
root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
[source.name for source in sources], raw_value),
AND)
elif lookup_type == 'isnull':
root_constraint.add(
(Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND)
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
and not is_multicolumn)):
value = get_normalized_value(raw_value)
for index, source in enumerate(sources):
root_constraint.add(
(Constraint(alias, targets[index].column, sources[index]), lookup_type,
value[index]), AND)
elif lookup_type in ['range', 'in'] and not is_multicolumn:
values = [get_normalized_value(value) for value in raw_value]
value = [val[0] for val in values]
root_constraint.add(
(Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND)
elif lookup_type == 'in':
values = [get_normalized_value(value) for value in raw_value]
for value in values:
value_constraint = constraint_class()
for index, target in enumerate(targets):
value_constraint.add(
(Constraint(alias, target.column, sources[index]), 'exact', value[index]),
AND)
root_constraint.add(value_constraint, OR)
else:
raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
return root_constraint
@property
def attnames(self):
return tuple([field.attname for field in self.local_related_fields])
def get_defaults(self):
return tuple([field.get_default() for field in self.local_related_fields])
def contribute_to_class(self, cls, name, virtual_only=False):
super(ForeignObject, self).contribute_to_class(cls, name, virtual_only=virtual_only)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
def contribute_to_related_class(self, cls, related):
# Internal FK's - i.e., those with a related name ending with '+' -
# and swapped models don't get a related descriptor.
if not self.rel.is_hidden() and not related.model._meta.swapped:
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
if self.rel.limit_choices_to:
cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
class ForeignKey(ForeignObject):
empty_strings_allowed = False
default_error_messages = {
'invalid': _('Model %(model)s with pk %(pk)r does not exist.')
@@ -999,7 +1121,7 @@ class ForeignKey(RelatedField, Field):
def __init__(self, to, to_field=None, rel_class=ManyToOneRel,
db_constraint=True, **kwargs):
try:
to._meta.model_name
to_name = to._meta.object_name.lower()
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else:
@@ -1008,44 +1130,33 @@ class ForeignKey(RelatedField, Field):
# the to_field during FK construction. It won't be guaranteed to
# be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name)
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
if 'db_index' not in kwargs:
kwargs['db_index'] = True
self.db_constraint = db_constraint
kwargs['rel'] = rel_class(to, to_field,
kwargs['rel'] = rel_class(
self, to, to_field,
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE),
)
super(ForeignKey, self).__init__(**kwargs)
super(ForeignKey, self).__init__(to, ['self'], [to_field], **kwargs)
def get_path_info(self):
"""
Get path from this field to the related model.
"""
opts = self.rel.to._meta
target = self.rel.get_related_field()
from_opts = self.model._meta
return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
@property
def related_field(self):
return self.foreign_related_fields[0]
def get_reverse_path_info(self):
"""
Get path from the related model to this field's model.
"""
opts = self.model._meta
from_field = self.rel.get_related_field()
from_opts = from_field.model._meta
pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
if from_field.model is self.model:
# Recursive foreign key to self.
target = opts.get_field_by_name(
self.rel.field_name)[0]
else:
target = opts.pk
return pathinfos, opts, target, self
from_opts = self.rel.to._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
def validate(self, value, model_instance):
if self.rel.parent_link:
@@ -1066,21 +1177,26 @@ class ForeignKey(RelatedField, Field):
def get_attname(self):
return '%s_id' % self.name
def get_attname_column(self):
attname = self.get_attname()
column = self.db_column or attname
return attname, column
def get_validator_unique_lookup_type(self):
return '%s__%s__exact' % (self.name, self.rel.get_related_field().name)
return '%s__%s__exact' % (self.name, self.related_field.name)
def get_default(self):
"Here we check if the default value is an object and return the to_field if so."
field_default = super(ForeignKey, self).get_default()
if isinstance(field_default, self.rel.to):
return getattr(field_default, self.rel.get_related_field().attname)
return getattr(field_default, self.related_field.attname)
return field_default
def get_db_prep_save(self, value, connection):
if value == '' or value == None:
return None
else:
return self.rel.get_related_field().get_db_prep_save(value,
return self.related_field.get_db_prep_save(value,
connection=connection)
def value_to_string(self, obj):
@@ -1093,19 +1209,10 @@ class ForeignKey(RelatedField, Field):
choice_list = self.get_choices_default()
if len(choice_list) == 2:
return smart_text(choice_list[1][0])
return Field.value_to_string(self, obj)
def contribute_to_class(self, cls, name):
super(ForeignKey, self).contribute_to_class(cls, name)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
return super(ForeignKey, self).value_to_string(obj)
def contribute_to_related_class(self, cls, related):
# Internal FK's - i.e., those with a related name ending with '+' -
# and swapped models don't get a related descriptor.
if not self.rel.is_hidden() and not related.model._meta.swapped:
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
if self.rel.limit_choices_to:
cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
super(ForeignKey, self).contribute_to_related_class(cls, related)
if self.rel.field_name is None:
self.rel.field_name = cls._meta.pk.name
@@ -1130,7 +1237,7 @@ class ForeignKey(RelatedField, Field):
# in which case the column type is simply that of an IntegerField.
# If the database needs similar types for key fields however, the only
# thing we can do is making AutoField an IntegerField.
rel_field = self.rel.get_related_field()
rel_field = self.related_field
if (isinstance(rel_field, AutoField) or
(not connection.features.related_fields_match_type and
isinstance(rel_field, (PositiveIntegerField,
@@ -1212,7 +1319,7 @@ def create_many_to_many_intermediary_model(field, klass):
})
class ManyToManyField(RelatedField, Field):
class ManyToManyField(RelatedField):
description = _("Many-to-many relationship")
def __init__(self, to, db_constraint=True, **kwargs):
@@ -1252,14 +1359,14 @@ class ManyToManyField(RelatedField, Field):
linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
if direct:
join1infos, _, _, _ = linkfield1.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield2.get_path_info()
join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info()
else:
join1infos, _, _, _ = linkfield2.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield1.get_path_info()
join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info()
pathinfos.extend(join1infos)
pathinfos.extend(join2infos)
return pathinfos, opts, target, final_field
return pathinfos
def get_path_info(self):
return self._get_path_info(direct=True)
@@ -1402,8 +1509,3 @@ class ManyToManyField(RelatedField, Field):
initial = initial()
defaults['initial'] = [i._get_pk_val() for i in initial]
return super(ManyToManyField, self).formfield(**defaults)
def db_type(self, connection):
# A ManyToManyField is not represented by a single column,
# so return None.
return None

View File

@@ -10,6 +10,7 @@ from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.proxy import OrderWrt
from django.db.models.loading import get_models, app_cache_ready
from django.utils import six
from django.utils.functional import cached_property
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_text, smart_text, python_2_unicode_compatible
from django.utils.translation import activate, deactivate_all, get_language, string_concat
@@ -173,6 +174,22 @@ class Options(object):
if hasattr(self, '_field_cache'):
del self._field_cache
del self._field_name_cache
# The fields, concrete_fields and local_concrete_fields are
# implemented as cached properties for performance reasons.
# The attrs will not exists if the cached property isn't
# accessed yet, hence the try-excepts.
try:
del self.fields
except AttributeError:
pass
try:
del self.concrete_fields
except AttributeError:
pass
try:
del self.local_concrete_fields
except AttributeError:
pass
if hasattr(self, '_name_map'):
del self._name_map
@@ -245,7 +262,8 @@ class Options(object):
return None
swapped = property(_swapped)
def _fields(self):
@cached_property
def fields(self):
"""
The getter for self.fields. This returns the list of field objects
available to this model (including through parent models).
@@ -258,7 +276,14 @@ class Options(object):
except AttributeError:
self._fill_fields_cache()
return self._field_name_cache
fields = property(_fields)
@cached_property
def concrete_fields(self):
return [f for f in self.fields if f.column is not None]
@cached_property
def local_concrete_fields(self):
return [f for f in self.local_fields if f.column is not None]
def get_fields_with_model(self):
"""
@@ -272,6 +297,10 @@ class Options(object):
self._fill_fields_cache()
return self._field_cache
def get_concrete_fields_with_model(self):
return [(field, model) for field, model in self.get_fields_with_model() if
field.column is not None]
def _fill_fields_cache(self):
cache = []
for parent in self.parents:
@@ -377,6 +406,9 @@ class Options(object):
cache[f.name] = (f, model, True, True)
for f, model in self.get_fields_with_model():
cache[f.name] = (f, model, True, False)
for f in self.virtual_fields:
if hasattr(f, 'related'):
cache[f.name] = (f.related, None if f.model == self.model else f.model, True, False)
if app_cache_ready():
self._name_map = cache
return cache
@@ -432,7 +464,7 @@ class Options(object):
for klass in get_models(include_auto_created=True, only_installed=False):
if not klass._meta.swapped:
for f in klass._meta.local_fields:
if f.rel and not isinstance(f.rel.to, six.string_types):
if f.rel and not isinstance(f.rel.to, six.string_types) and f.generate_reverse_relation:
if self == f.rel.to._meta:
cache[f.related] = None
proxy_cache[f.related] = None

View File

@@ -261,13 +261,13 @@ class QuerySet(object):
only_load = self.query.get_loaded_field_names()
if not fill_cache:
fields = self.model._meta.fields
fields = self.model._meta.concrete_fields
load_fields = []
# If only/defer clauses have been specified,
# build the list of fields that are to be loaded.
if only_load:
for field, model in self.model._meta.get_fields_with_model():
for field, model in self.model._meta.get_concrete_fields_with_model():
if model is None:
model = self.model
try:
@@ -280,7 +280,7 @@ class QuerySet(object):
load_fields.append(field.name)
index_start = len(extra_select)
aggregate_start = index_start + len(load_fields or self.model._meta.fields)
aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields)
skip = None
if load_fields and not fill_cache:
@@ -312,7 +312,11 @@ class QuerySet(object):
if skip:
obj = model_cls(**dict(zip(init_list, row_data)))
else:
obj = model(*row_data)
try:
obj = model(*row_data)
except IndexError:
import ipdb; ipdb.set_trace()
pass
# Store the source database of the object
obj._state.db = db
@@ -962,7 +966,7 @@ class QuerySet(object):
"""
opts = self.model._meta
if self.query.group_by is None:
field_names = [f.attname for f in opts.fields]
field_names = [f.attname for f in opts.concrete_fields]
self.query.add_fields(field_names, False)
self.query.set_group_by()
@@ -1055,7 +1059,7 @@ class ValuesQuerySet(QuerySet):
else:
# Default to all fields.
self.extra_names = None
self.field_names = [f.attname for f in self.model._meta.fields]
self.field_names = [f.attname for f in self.model._meta.concrete_fields]
self.aggregate_names = None
self.query.select = []
@@ -1266,7 +1270,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
for field, model in klass._meta.get_fields_with_model():
for field, model in klass._meta.get_concrete_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
elif from_parent and issubclass(from_parent, model.__class__):
@@ -1285,22 +1289,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
field_count = len(klass._meta.fields)
field_count = len(klass._meta.concrete_fields)
# Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.fields
field_names = [f.attname for f in klass._meta.concrete_fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
if field_count == len(klass._meta.fields):
if field_count == len(klass._meta.concrete_fields):
field_names = ()
restricted = requested is not None

View File

@@ -7,7 +7,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo',
'from_field to_field from_opts to_opts join_field '
'from_opts to_opts target_fields join_field '
'm2m direct')
class RelatedObject(object):

View File

@@ -2,10 +2,9 @@ import datetime
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import transaction
from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend
from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
@@ -33,7 +32,7 @@ class SQLCompiler(object):
# cleaned. We are not using a clone() of the query here.
"""
if not self.query.tables:
self.query.join((None, self.query.model._meta.db_table, None, None))
self.query.join((None, self.query.model._meta.db_table, None))
if (not self.query.select and self.query.default_cols and not
self.query.included_inherited_models):
self.query.setup_inherited_models()
@@ -273,7 +272,7 @@ class SQLCompiler(object):
# be used by local fields.
seen_models = {None: start_alias}
for field, model in opts.get_fields_with_model():
for field, model in opts.get_concrete_fields_with_model():
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
@@ -314,9 +313,10 @@ class SQLCompiler(object):
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
field, col, alias, _, _ = self._setup_joins(parts, opts, None)
col, alias = self._final_join_removal(col, alias)
result.append("%s.%s" % (qn(alias), qn2(col)))
field, cols, alias, _, _ = self._setup_joins(parts, opts, None)
cols, alias = self._final_join_removal(cols, alias)
for col in cols:
result.append("%s.%s" % (qn(alias), qn2(col)))
return result
@@ -387,15 +387,16 @@ class SQLCompiler(object):
elif get_order_dir(field)[0] not in self.query.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
for table, col, order in self.find_ordering_name(field,
for table, cols, order in self.find_ordering_name(field,
self.query.model._meta, default_order=asc):
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), qn2(col))
processed_pairs.add((table, col))
if distinct and elt not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
for col in cols:
if (table, col) not in processed_pairs:
elt = '%s.%s' % (qn(table), qn2(col))
processed_pairs.add((table, col))
if distinct and elt not in select_aliases:
ordering_aliases.append(elt)
result.append('%s %s' % (elt, order))
group_by.append((elt, []))
else:
elt = qn2(col)
if distinct and col not in select_aliases:
@@ -414,7 +415,7 @@ class SQLCompiler(object):
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias)
field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
@@ -432,8 +433,8 @@ class SQLCompiler(object):
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
col, alias = self._final_join_removal(col, alias)
return [(alias, col, order)]
cols, alias = self._final_join_removal(cols, alias)
return [(alias, cols, order)]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -446,13 +447,13 @@ class SQLCompiler(object):
"""
if not alias:
alias = self.query.get_initial_alias()
field, target, opts, joins, _ = self.query.setup_joins(
field, targets, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias)
# We will later on need to promote those joins that were added to the
# query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
alias = joins[-1]
col = target.column
cols = [target.column for target in targets]
if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on
@@ -463,9 +464,9 @@ class SQLCompiler(object):
# Ordering or distinct must not affect the returned set, and INNER
# JOINS for nullable fields could do this.
self.query.promote_joins(joins_to_promote)
return field, col, alias, joins, opts
return field, cols, alias, joins, opts
def _final_join_removal(self, col, alias):
def _final_join_removal(self, cols, alias):
"""
A helper method for get_distinct and get_ordering. This method will
trim extra not-needed joins from the tail of the join chain.
@@ -477,12 +478,14 @@ class SQLCompiler(object):
if alias:
while 1:
join = self.query.alias_map[alias]
if col != join.rhs_join_col:
lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols])
if set(cols) != set(rhs_cols):
break
cols = [lhs_cols[rhs_cols.index(col)] for col in cols]
self.query.unref_alias(alias)
alias = join.lhs_alias
col = join.lhs_join_col
return col, alias
return cols, alias
def get_from_clause(self):
"""
@@ -504,22 +507,30 @@ class SQLCompiler(object):
if not self.query.alias_refcount[alias]:
continue
try:
name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
if join_field and hasattr(join_field, 'get_extra_join_sql'):
extra_cond, extra_params = join_field.get_extra_join_sql(
self.connection, qn, lhs, alias)
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
extra_sql, extra_params = extra_cond.as_sql(
qn, self.connection)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
extra_cond = ""
result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
(join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col), extra_cond))
extra_sql = ""
result.append('%s %s%s ON ('
% (join_type, qn(name), alias_str))
for index, (lhs_col, rhs_col) in enumerate(join_cols):
if index != 0:
result.append(' AND ')
result.append('%s.%s = %s.%s' %
(qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col)))
result.append('%s)' % extra_sql)
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@@ -545,7 +556,7 @@ class SQLCompiler(object):
select_cols = self.query.select + self.query.related_select_cols
# Just the column, not the fields.
select_cols = [s[0] for s in select_cols]
if (len(self.query.model._meta.fields) == len(self.query.select)
if (len(self.query.model._meta.concrete_fields) == len(self.query.select)
and self.connection.features.allows_group_by_pk):
self.query.group_by = [
(self.query.model._meta.db_table, self.query.model._meta.pk.column)
@@ -623,14 +634,13 @@ class SQLCompiler(object):
table = f.rel.to._meta.db_table
promote = nullable or f.null
alias = self.query.join_parent_model(opts, model, root_alias, {})
alias = self.query.join((alias, table, f.column,
f.rel.get_related_field().column),
join_cols = f.get_joining_columns()
alias = self.query.join((alias, table, join_cols),
outer_if_first=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.concrete_fields))
if restricted:
next = requested.get(f.name, {})
else:
@@ -653,7 +663,7 @@ class SQLCompiler(object):
alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {})
table = model._meta.db_table
alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column),
(alias, table, f.get_joining_columns(reverse_join=True)),
outer_if_first=True, join_field=f
)
from_parent = (opts.model if issubclass(model, opts.model)
@@ -662,7 +672,7 @@ class SQLCompiler(object):
opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field
in zip(columns, model._meta.fields))
in zip(columns, model._meta.concrete_fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable.
@@ -706,7 +716,7 @@ class SQLCompiler(object):
if self.query.select:
fields = [f.field for f in self.query.select]
else:
fields = self.query.model._meta.fields
fields = self.query.model._meta.concrete_fields
fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed
@@ -776,6 +786,22 @@ class SQLCompiler(object):
return list(result)
return result
def as_subquery_condition(self, alias, columns):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params
for index, select_col in enumerate(self.query.select):
lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1]))
rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
self.query.where.add(
QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND')
sql, params = self.as_sql()
return 'EXISTS (%s)' % sql, params
class SQLInsertCompiler(SQLCompiler):
def placeholder(self, field, val):

View File

@@ -25,7 +25,7 @@ GET_ITERATOR_CHUNK_SIZE = 100
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable join_field')
'join_cols nullable join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')

View File

@@ -55,13 +55,14 @@ class SQLEvaluator(object):
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
field, source, opts, join_list, path = query.setup_joins(
field, sources, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(),
query.get_initial_alias(), self.reuse)
target, _, join_list = query.trim_joins(source, join_list, path)
targets, _, join_list = query.trim_joins(sources, join_list, path)
if self.reuse is not None:
self.reuse.update(join_list)
self.cols.append((node, (join_list[-1], target.column)))
for t in targets:
self.cols.append((node, (join_list[-1], t.column)))
except FieldDoesNotExist:
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (self.name,

View File

@@ -452,13 +452,13 @@ class Query(object):
# Now, add the joins from rhs query into the new query (skipping base
# table).
for alias in rhs.tables[1:]:
table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias]
promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the
# updated alias.
lhs = change_map.get(lhs, lhs)
new_alias = self.join(
(lhs, table, lhs_col, col), reuse=reuse,
(lhs, table, join_cols), reuse=reuse,
outer_if_first=not conjunction, nullable=nullable,
join_field=join_field)
if promote:
@@ -682,7 +682,7 @@ class Query(object):
aliases = list(aliases)
while aliases:
alias = aliases.pop(0)
if self.alias_map[alias].rhs_join_col is None:
if self.alias_map[alias].join_cols[0][1] is None:
# This is the base table (first FROM entry) - this table
# isn't really joined at all in the query, so we should not
# alter its join type.
@@ -818,7 +818,7 @@ class Query(object):
alias = self.tables[0]
self.ref_alias(alias)
else:
alias = self.join((None, self.model._meta.db_table, None, None))
alias = self.join((None, self.model._meta.db_table, None))
return alias
def count_active_tables(self):
@@ -834,11 +834,12 @@ class Query(object):
"""
Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a
tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing
table alias or a table name. The join correspods to the SQL equivalent
of::
tuple (lhs, table, join_cols) where 'lhs' is either an existing
table alias or a table name. 'join_cols' is a tuple of tuples containing
columns to join on ((l_id1, r_id1), (l_id2, r_id2)). The join corresponds
to the SQL equivalent of::
lhs.lhs_col = table.col
lhs.l_id1 = table.r_id1 AND lhs.l_id2 = table.r_id2
The 'reuse' parameter can be either None which means all joins
(matching the connection) are reusable, or it can be a set containing
@@ -855,7 +856,7 @@ class Query(object):
The 'join_field' is the field we are joining along (if any).
"""
lhs, table, lhs_col, col = connection
lhs, table, join_cols = connection
assert lhs is None or join_field is not None
existing = self.join_map.get(connection, ())
if reuse is None:
@@ -884,7 +885,7 @@ class Query(object):
join_type = self.LOUTER
else:
join_type = self.INNER
join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
join_field)
self.alias_map[alias] = join
if connection in self.join_map:
@@ -941,7 +942,7 @@ class Query(object):
continue
link_field = int_opts.get_ancestor_link(int_model)
int_opts = int_model._meta
connection = (alias, int_opts.db_table, link_field.column, int_opts.pk.column)
connection = (alias, int_opts.db_table, link_field.get_joining_columns())
alias = seen[int_model] = self.join(connection, nullable=False,
join_field=link_field)
return alias or seen[None]
@@ -982,18 +983,20 @@ class Query(object):
# - this is an annotation over a model field
# then we need to explore the joins that are required.
field, source, opts, join_list, path = self.setup_joins(
field, sources, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias())
# Process the join chain to see if it can be trimmed
target, _, join_list = self.trim_joins(source, join_list, path)
targets, _, join_list = self.trim_joins(sources, join_list, path)
# If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned
# in order for zeros to be returned for those aggregates.
self.promote_joins(join_list, True)
col = (join_list[-1], target.column)
col = targets[0].column
source = sources[0]
col = (join_list[-1], col)
else:
# The simplest cases. No joins required -
# just reference the provided column alias.
@@ -1086,7 +1089,7 @@ class Query(object):
allow_many = not branch_negated
try:
field, target, opts, join_list, path = self.setup_joins(
field, sources, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True)
if can_reuse is not None:
@@ -1106,13 +1109,19 @@ class Query(object):
# the far end (fewer tables in a query is better). Note that join
# promotion must happen before join trimming to have the join type
# information available when reusing joins.
target, alias, join_list = self.trim_joins(target, join_list, path)
clause.add((Constraint(alias, target.column, field), lookup_type, value),
AND)
targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'):
constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
lookup_type, value)
else:
constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
clause.add(constraint, AND)
if current_negated and (lookup_type != 'isnull' or value is False):
self.promote_joins(join_list)
if (lookup_type != 'isnull' and (
self.is_nullable(target) or self.alias_map[join_list[-1]].join_type == self.LOUTER)):
self.is_nullable(targets[0]) or
self.alias_map[join_list[-1]].join_type == self.LOUTER)):
# The condition added here will be SQL like this:
# NOT (col IS NOT NULL), where the first NOT is added in
# upper layers of code. The reason for addition is that if col
@@ -1122,7 +1131,7 @@ class Query(object):
# (col IS NULL OR col != someval)
# <=>
# NOT (col IS NOT NULL AND col = someval).
clause.add((Constraint(alias, target.column, None), 'isnull', False), AND)
clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND)
return clause
def add_filter(self, filter_clause):
@@ -1272,22 +1281,26 @@ class Query(object):
opts = int_model._meta
else:
final_field = opts.parents[int_model]
target = final_field.rel.get_related_field()
targets = (final_field.rel.get_related_field(),)
opts = int_model._meta
path.append(PathInfo(final_field, target, final_field.model._meta,
opts, final_field, False, True))
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
if hasattr(field, 'get_path_info'):
pathinfos, opts, target, final_field = field.get_path_info()
pathinfos = field.get_path_info()
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
names_with_path.append((name, pathinfos[0:inner_pos + 1]))
raise MultiJoin(pos + 1, names_with_path)
last = pathinfos[-1]
path.extend(pathinfos)
final_field = last.join_field
opts = last.to_opts
targets = last.target_fields
names_with_path.append((name, pathinfos))
else:
# Local non-relational field.
final_field = target = field
final_field = field
targets = (field,)
break
if pos != len(names) - 1:
@@ -1297,7 +1310,7 @@ class Query(object):
"the lookup type?" % (name, names[pos + 1]))
else:
raise FieldError("Join on field %r not permitted." % name)
return path, final_field, target
return path, final_field, targets
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
allow_explicit_fk=False):
@@ -1330,7 +1343,7 @@ class Query(object):
"""
joins = [alias]
# First, generate the path for the names
path, final_field, target = self.names_to_path(
path, final_field, targets = self.names_to_path(
names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
@@ -1338,17 +1351,19 @@ class Query(object):
for pos, join in enumerate(path):
opts = join.to_opts
if join.direct:
nullable = self.is_nullable(join.from_field)
nullable = self.is_nullable(join.join_field)
else:
nullable = True
connection = alias, opts.db_table, join.from_field.column, join.to_field.column
connection = alias, opts.db_table, join.join_field.get_joining_columns()
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse,
nullable=nullable, join_field=join.join_field)
joins.append(alias)
return final_field, target, opts, joins, path
if hasattr(final_field, 'field'):
final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, target, joins, path):
def trim_joins(self, targets, joins, path):
"""
The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contain the PathInfos
@@ -1362,13 +1377,16 @@ class Query(object):
trimmed as we don't know if there is anything on the other side of
the join.
"""
for info in reversed(path):
if info.to_field == target and info.direct:
target = info.from_field
self.unref_alias(joins.pop())
else:
for pos, info in enumerate(reversed(path)):
if len(joins) == 1 or not info.direct:
break
return target, joins[-1], joins
join_targets = set(t.column for t in info.join_field.foreign_related_fields)
cur_targets = set(t.column for t in targets)
if not cur_targets.issubset(join_targets):
break
targets = tuple(r[0] for r in info.join_field.related_fields if r[1].column in cur_targets)
self.unref_alias(joins.pop())
return targets, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
"""
@@ -1413,17 +1431,31 @@ class Query(object):
trimmed_prefix = []
paths_in_prefix = trimmed_joins
for name, path in names_with_path:
if paths_in_prefix - len(path) > 0:
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
else:
trimmed_prefix.append(
path[paths_in_prefix - len(path)].from_field.name)
if paths_in_prefix - len(path) < 0:
break
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
join_field = path[paths_in_prefix].join_field
# TODO: This should be made properly multicolumn
# join aware. It is likely better to not use build_filter
# at all, instead construct joins up to the correct point,
# then construct the needed equality constraint manually,
# or maybe using SubqueryConstraint would work, too.
# The foreign_related_fields attribute is right here, we
# don't ever split joins for direct case.
trimmed_prefix.append(
join_field.field.foreign_related_fields[0].name)
trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
return self.build_filter(
condition = self.build_filter(
('%s__in' % trimmed_prefix, query),
current_negated=True, branch_negated=True, can_reuse=can_reuse)
# Intentionally leave the other alias as blank, if the condition
# refers it, things will break here.
extra_restriction = join_field.get_extra_restriction(
self.where_class, None, [t for t in query.tables if query.alias_refcount[t]][0])
if extra_restriction:
query.where.add(extra_restriction, 'AND')
return condition
def set_empty(self):
self.where = EmptyWhere()
@@ -1502,20 +1534,17 @@ class Query(object):
try:
for name in field_names:
field, target, u2, joins, u3 = self.setup_joins(
field, targets, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
True)
final_alias = joins[-1]
col = target.column
if len(joins) > 1:
join = self.alias_map[final_alias]
if col == join.rhs_join_col:
self.unref_alias(final_alias)
final_alias = join.lhs_alias
col = join.lhs_join_col
joins = joins[:-1]
# Trim last join if possible
targets, final_alias, remaining_joins = self.trim_joins(targets, joins[-2:], path)
joins = joins[:-2] + remaining_joins
self.promote_joins(joins[1:])
self.select.append(SelectInfo((final_alias, col), field))
for target in targets:
self.select.append(SelectInfo((final_alias, target.column), target))
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
@@ -1590,7 +1619,7 @@ class Query(object):
opts = self.model._meta
if not self.select:
count = self.aggregates_module.Count(
(self.join((None, opts.db_table, None, None)), opts.pk.column),
(self.join((None, opts.db_table, None)), opts.pk.column),
is_summary=True, distinct=True)
else:
# Because of SQL portability issues, multi-column, distinct
@@ -1792,22 +1821,27 @@ class Query(object):
in "WHERE somecol IN (subquery)". This construct is needed by
split_exclude().
_"""
join_pos = 0
all_paths = []
for _, paths in names_with_path:
for path in paths:
peek = self.tables[join_pos + 1]
if self.alias_map[peek].join_type == self.LOUTER:
# Back up one level and break
select_alias = self.tables[join_pos]
select_field = path.from_field
break
select_alias = self.tables[join_pos + 1]
select_field = path.to_field
self.unref_alias(self.tables[join_pos])
join_pos += 1
self.select = [SelectInfo((select_alias, select_field.column), select_field)]
all_paths.extend(paths)
direct_join = True
for pos, path in enumerate(all_paths):
if self.alias_map[self.tables[pos + 1]].join_type == self.LOUTER:
direct_join = False
pos -= 1
break
self.unref_alias(self.tables[pos])
if path.direct:
direct_join = not direct_join
join_side = 0 if direct_join else 1
select_alias = self.tables[pos + 1]
join_field = path.join_field
if hasattr(join_field, 'field'):
join_field = join_field.field
select_fields = [r[join_side] for r in join_field.related_fields]
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
self.remove_inherited_models()
return join_pos
return pos
def is_nullable(self, field):
"""

View File

@@ -382,3 +382,28 @@ class Constraint(object):
new.__class__ = self.__class__
new.alias, new.col, new.field = change_map[self.alias], self.col, self.field
return new
class SubqueryConstraint(object):
def __init__(self, alias, columns, targets, query_object):
self.alias = alias
self.columns = columns
self.targets = targets
self.query_object = query_object
def as_sql(self, qn, connection):
query = self.query_object
# QuerySet was sent
if hasattr(query, 'values'):
# as_sql should throw if we are using a
# connection on another database
query._as_sql(connection=connection)
query = query.values(*self.targets).query
query_compiler = query.get_compiler(connection=connection)
return query_compiler.as_subquery_condition(self.alias, self.columns)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias),
self.columns, self.query_object)

View File

@@ -110,7 +110,7 @@ def model_to_dict(instance, fields=None, exclude=None):
from django.db.models.fields.related import ManyToManyField
opts = instance._meta
data = {}
for f in opts.fields + opts.many_to_many:
for f in opts.concrete_fields + opts.many_to_many:
if not f.editable:
continue
if fields and not f.name in fields:
@@ -149,7 +149,7 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, formfield_c
field_list = []
ignored = []
opts = model._meta
for f in sorted(opts.fields + opts.many_to_many):
for f in sorted(opts.concrete_fields + opts.many_to_many):
if not f.editable:
continue
if fields is not None and not f.name in fields: