1
0
mirror of https://github.com/django/django.git synced 2025-10-25 14:46:09 +00:00

Fixed #10790 -- Refactored sql.Query.setup_joins()

This is a rather large refactoring. The "lookup traversal" code was
split out from setup_joins(). There is now a names_to_path() method
which does the lookup traversal; the actual work of setup_joins() is
calling names_to_path() and then adding the joins found into the query.

As a side effect it was possible to remove the "process_extra"
functionality used by generic relations. This never worked for left
joins. Now the extra restriction is appended directly to the join
condition instead of the where clause.

To generate the extra condition we need to have the join field
available in the compiler. This has the side-effect that we need more
ugly code in Query.__getstate__ and __setstate__ as Field objects
aren't pickleable.

The join trimming code got a big change - now we trim all direct joins
and never trim reverse joins. This also fixes the problem in #10790
which was join trimming in null filter cases.
This commit is contained in:
Anssi Kääriäinen
2012-08-25 16:33:07 +03:00
parent f811649710
commit 69597e5bcc
8 changed files with 562 additions and 266 deletions

View File

@@ -205,17 +205,16 @@ class GenericRelation(RelatedField, Field):
# same db_type as well. # same db_type as well.
return None return None
def extra_filters(self, pieces, pos, negate): def get_content_type(self):
""" """
Return an extra filter to the queryset so that the results are filtered Returns the content type associated with this field's model.
on the appropriate content type.
""" """
if negate: return ContentType.objects.get_for_model(self.model)
return []
content_type = ContentType.objects.get_for_model(self.model) def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
prefix = "__".join(pieces[:pos + 1]) extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
return [("%s__%s" % (prefix, self.content_type_field_name), contenttype = self.get_content_type().pk
content_type)] return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
""" """
@@ -246,9 +245,6 @@ class ReverseGenericRelatedObjectsDescriptor(object):
if instance is None: if instance is None:
return self return self
# This import is done here to avoid circular import importing this module
from django.contrib.contenttypes.models import ContentType
# Dynamically create a class that subclasses the related model's # Dynamically create a class that subclasses the related model's
# default manager. # default manager.
rel_model = self.field.rel.to rel_model = self.field.rel.to
@@ -379,8 +375,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
def __init__(self, data=None, files=None, instance=None, save_as_new=None, def __init__(self, data=None, files=None, instance=None, save_as_new=None,
prefix=None, queryset=None): prefix=None, queryset=None):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
opts = self.model._meta opts = self.model._meta
self.instance = instance self.instance = instance
self.rel_name = '-'.join(( self.rel_name = '-'.join((
@@ -409,8 +403,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
)) ))
def save_new(self, form, commit=True): def save_new(self, form, commit=True):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
kwargs = { kwargs = {
self.ct_field.get_attname(): ContentType.objects.get_for_model(self.instance).pk, self.ct_field.get_attname(): ContentType.objects.get_for_model(self.instance).pk,
self.ct_fk_field.get_attname(): self.instance.pk, self.ct_fk_field.get_attname(): self.instance.pk,
@@ -432,8 +424,6 @@ def generic_inlineformset_factory(model, form=ModelForm,
defaults ``content_type`` and ``object_id`` respectively. defaults ``content_type`` and ``object_id`` respectively.
""" """
opts = model._meta opts = model._meta
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
# if there is no field called `ct_field` let the exception propagate # if there is no field called `ct_field` let the exception propagate
ct_field = opts.get_field(ct_field) ct_field = opts.get_field(ct_field)
if not isinstance(ct_field, models.ForeignKey) or ct_field.rel.to != ContentType: if not isinstance(ct_field, models.ForeignKey) or ct_field.rel.to != ContentType:

View File

@@ -274,7 +274,8 @@ class SQLCompiler(object):
except KeyError: except KeyError:
link_field = opts.get_ancestor_link(model) link_field = opts.get_ancestor_link(model)
alias = self.query.join((start_alias, model._meta.db_table, alias = self.query.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column)) link_field.column, model._meta.pk.column),
join_field=link_field)
seen[model] = alias seen[model] = alias
else: else:
# If we're starting from the base model of the queryset, the # If we're starting from the base model of the queryset, the
@@ -448,8 +449,8 @@ class SQLCompiler(object):
""" """
if not alias: if not alias:
alias = self.query.get_initial_alias() alias = self.query.get_initial_alias()
field, target, opts, joins, _, _ = self.query.setup_joins(pieces, field, target, opts, joins, _ = self.query.setup_joins(
opts, alias, REUSE_ALL) pieces, opts, alias, REUSE_ALL)
# We will later on need to promote those joins that were added to the # We will later on need to promote those joins that were added to the
# query afresh above. # query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
@@ -501,20 +502,27 @@ class SQLCompiler(object):
qn = self.quote_name_unless_alias qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name qn2 = self.connection.ops.quote_name
first = True first = True
from_params = []
for alias in self.query.tables: for alias in self.query.tables:
if not self.query.alias_refcount[alias]: if not self.query.alias_refcount[alias]:
continue continue
try: try:
name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias] name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
except KeyError: except KeyError:
# Extra tables can end up in self.tables, but not in the # Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them. # alias_map if they aren't in a join. That's OK. We skip them.
continue continue
alias_str = (alias != name and ' %s' % alias or '') alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first: if join_type and not first:
result.append('%s %s%s ON (%s.%s = %s.%s)' if join_field and hasattr(join_field, 'get_extra_join_sql'):
% (join_type, qn(name), alias_str, qn(lhs), extra_cond, extra_params = join_field.get_extra_join_sql(
qn2(lhs_col), qn(alias), qn2(col))) self.connection, qn, lhs, alias)
from_params.extend(extra_params)
else:
extra_cond = ""
result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
(join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col), extra_cond))
else: else:
connector = not first and ', ' or '' connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str)) result.append('%s%s%s' % (connector, qn(name), alias_str))
@@ -528,7 +536,7 @@ class SQLCompiler(object):
connector = not first and ', ' or '' connector = not first and ', ' or ''
result.append('%s%s' % (connector, qn(alias))) result.append('%s%s' % (connector, qn(alias)))
first = False first = False
return result, [] return result, from_params
def get_grouping(self, ordering_group_by): def get_grouping(self, ordering_group_by):
""" """
@@ -638,7 +646,7 @@ class SQLCompiler(object):
alias = self.query.join((alias, table, f.column, alias = self.query.join((alias, table, f.column,
f.rel.get_related_field().column), f.rel.get_related_field().column),
promote=promote) promote=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias, columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True) opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend( self.query.related_select_cols.extend(
@@ -685,7 +693,7 @@ class SQLCompiler(object):
alias_chain.append(alias) alias_chain.append(alias)
alias = self.query.join( alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column), (alias, table, f.rel.get_related_field().column, f.column),
promote=True promote=True, join_field=f
) )
from_parent = (opts.model if issubclass(model, opts.model) from_parent = (opts.model if issubclass(model, opts.model)
else None) else None)

View File

@@ -18,12 +18,19 @@ QUERY_TERMS = set([
# Larger values are slightly faster at the expense of more storage space. # Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100 GET_ITERATOR_CHUNK_SIZE = 100
# Constants to make looking up tuple values clearer. # Namedtuples for sql.* internal use.
# Join lists (indexes into the tuples that are values in the alias_map # Join lists (indexes into the tuples that are values in the alias_map
# dictionary in the Query class). # dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo', JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias ' 'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable') 'lhs_join_col rhs_join_col nullable join_field')
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the join in Model terms (model Options and Fields for both
# sides of the join). The join_field is the field we are joining along.
PathInfo = namedtuple('PathInfo',
'from_field to_field from_opts to_opts join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause. # Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field') SelectInfo = namedtuple('SelectInfo', 'col field')

View File

@@ -50,10 +50,10 @@ class SQLEvaluator(object):
self.cols.append((node, query.aggregate_select[node.name])) self.cols.append((node, query.aggregate_select[node.name]))
else: else:
try: try:
field, source, opts, join_list, last, _ = query.setup_joins( field, source, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(), field_list, query.get_meta(),
query.get_initial_alias(), self.reuse) query.get_initial_alias(), self.reuse)
col, _, join_list = query.trim_joins(source, join_list, last, False) col, _, join_list = query.trim_joins(source, join_list, path)
if self.reuse is not None and self.reuse != REUSE_ALL: if self.reuse is not None and self.reuse != REUSE_ALL:
self.reuse.update(join_list) self.reuse.update(join_list)
self.cols.append((node, (join_list[-1], col))) self.cols.append((node, (join_list[-1], col)))

View File

@@ -14,13 +14,13 @@ from django.utils.encoding import force_text
from django.utils.tree import Node from django.utils.tree import Node
from django.utils import six from django.utils import six
from django.db import connections, DEFAULT_DB_ALIAS from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import ExpressionNode from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.db.models.loading import get_model
from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo) ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -119,7 +119,7 @@ class Query(object):
self.filter_is_sticky = False self.filter_is_sticky = False
self.included_inherited_models = {} self.included_inherited_models = {}
# SQL-related attributes # SQL-related attributes
# Select and related select clauses as SelectInfo instances. # Select and related select clauses as SelectInfo instances.
# The select is used for cases where we want to set up the select # The select is used for cases where we want to set up the select
# clause to contain other than default fields (values(), annotate(), # clause to contain other than default fields (values(), annotate(),
@@ -201,6 +201,16 @@ class Query(object):
(s.col, s.field is not None and s.field.name or None) (s.col, s.field is not None and s.field.name or None)
for s in obj_dict['select'] for s in obj_dict['select']
] ]
# alias_map can also contain references to fields.
new_alias_map = {}
for alias, join_info in obj_dict['alias_map'].items():
if join_info.join_field is None:
new_alias_map[alias] = join_info
else:
model = join_info.join_field.model._meta
field_id = (model.app_label, model.object_name, join_info.join_field.name)
new_alias_map[alias] = join_info._replace(join_field=field_id)
obj_dict['alias_map'] = new_alias_map
return obj_dict return obj_dict
def __setstate__(self, obj_dict): def __setstate__(self, obj_dict):
@@ -213,6 +223,15 @@ class Query(object):
SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None) SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None)
for tpl in obj_dict['select'] for tpl in obj_dict['select']
] ]
new_alias_map = {}
for alias, join_info in obj_dict['alias_map'].items():
if join_info.join_field is None:
new_alias_map[alias] = join_info
else:
field_id = join_info.join_field
new_alias_map[alias] = join_info._replace(
join_field=get_model(field_id[0], field_id[1])._meta.get_field(field_id[2]))
obj_dict['alias_map'] = new_alias_map
self.__dict__.update(obj_dict) self.__dict__.update(obj_dict)
@@ -479,21 +498,26 @@ class Query(object):
# Now, add the joins from rhs query into the new query (skipping base # Now, add the joins from rhs query into the new query (skipping base
# table). # table).
for alias in rhs.tables[1:]: for alias in rhs.tables[1:]:
if not rhs.alias_refcount[alias]: table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
continue
table, _, join_type, lhs, lhs_col, col, nullable = rhs.alias_map[alias]
promote = (join_type == self.LOUTER) promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the # If the left side of the join was already relabeled, use the
# updated alias. # updated alias.
lhs = change_map.get(lhs, lhs) lhs = change_map.get(lhs, lhs)
new_alias = self.join( new_alias = self.join(
(lhs, table, lhs_col, col), reuse=reuse, promote=promote, (lhs, table, lhs_col, col), reuse=reuse, promote=promote,
outer_if_first=not conjunction, nullable=nullable) outer_if_first=not conjunction, nullable=nullable,
join_field=join_field)
# We can't reuse the same join again in the query. If we have two # We can't reuse the same join again in the query. If we have two
# distinct joins for the same connection in rhs query, then the # distinct joins for the same connection in rhs query, then the
# combined query must have two joins, too. # combined query must have two joins, too.
reuse.discard(new_alias) reuse.discard(new_alias)
change_map[alias] = new_alias change_map[alias] = new_alias
if not rhs.alias_refcount[alias]:
# The alias was unused in the rhs query. Unref it so that it
# will be unused in the new query, too. We have to add and
# unref the alias so that join promotion has information of
# the join type for the unused alias.
self.unref_alias(new_alias)
# So that we don't exclude valid results in an "or" query combination, # So that we don't exclude valid results in an "or" query combination,
# all joins exclusive to either the lhs or the rhs must be converted # all joins exclusive to either the lhs or the rhs must be converted
@@ -868,7 +892,7 @@ class Query(object):
return len([1 for count in self.alias_refcount.values() if count]) return len([1 for count in self.alias_refcount.values() if count])
def join(self, connection, reuse=REUSE_ALL, promote=False, def join(self, connection, reuse=REUSE_ALL, promote=False,
outer_if_first=False, nullable=False): outer_if_first=False, nullable=False, join_field=None):
""" """
Returns an alias for the join in 'connection', either reusing an Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a existing alias for that join or creating a new one. 'connection' is a
@@ -897,6 +921,8 @@ class Query(object):
If 'nullable' is True, the join can potentially involve NULL values and If 'nullable' is True, the join can potentially involve NULL values and
is a candidate for promotion (to "left outer") when combining querysets. is a candidate for promotion (to "left outer") when combining querysets.
The 'join_field' is the field we are joining along (if any).
""" """
lhs, table, lhs_col, col = connection lhs, table, lhs_col, col = connection
existing = self.join_map.get(connection, ()) existing = self.join_map.get(connection, ())
@@ -906,8 +932,13 @@ class Query(object):
reuse = set() reuse = set()
else: else:
reuse = [a for a in existing if a in reuse] reuse = [a for a in existing if a in reuse]
if reuse: for alias in reuse:
alias = reuse[0] if join_field and self.alias_map[alias].join_field != join_field:
# The join_map doesn't contain join_field (mainly because
# fields in Query structs are problematic in pickling), so
# check that the existing join is created using the same
# join_field used for the under work join.
continue
self.ref_alias(alias) self.ref_alias(alias)
if promote or (lhs and self.alias_map[lhs].join_type == self.LOUTER): if promote or (lhs and self.alias_map[lhs].join_type == self.LOUTER):
self.promote_joins([alias]) self.promote_joins([alias])
@@ -926,7 +957,8 @@ class Query(object):
join_type = self.LOUTER join_type = self.LOUTER
else: else:
join_type = self.INNER join_type = self.INNER
join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable) join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
join_field)
self.alias_map[alias] = join self.alias_map[alias] = join
if connection in self.join_map: if connection in self.join_map:
self.join_map[connection] += (alias,) self.join_map[connection] += (alias,)
@@ -1007,11 +1039,11 @@ class Query(object):
# - this is an annotation over a model field # - this is an annotation over a model field
# then we need to explore the joins that are required. # then we need to explore the joins that are required.
field, source, opts, join_list, last, _ = self.setup_joins( field, source, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias(), REUSE_ALL) field_list, opts, self.get_initial_alias(), REUSE_ALL)
# Process the join chain to see if it can be trimmed # Process the join chain to see if it can be trimmed
col, _, join_list = self.trim_joins(source, join_list, last, False) col, _, join_list = self.trim_joins(source, join_list, path)
# If the aggregate references a model or field that requires a join, # If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned # those joins must be LEFT OUTER - empty join rows must be returned
@@ -1030,7 +1062,7 @@ class Query(object):
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
def add_filter(self, filter_expr, connector=AND, negate=False, def add_filter(self, filter_expr, connector=AND, negate=False,
can_reuse=None, process_extras=True, force_having=False): can_reuse=None, force_having=False):
""" """
Add a single filter to the query. The 'filter_expr' is a pair: Add a single filter to the query. The 'filter_expr' is a pair:
(filter_string, value). E.g. ('name__contains', 'fred') (filter_string, value). E.g. ('name__contains', 'fred')
@@ -1047,10 +1079,6 @@ class Query(object):
will be a set of table aliases that can be reused in this filter, even will be a set of table aliases that can be reused in this filter, even
if we would otherwise force the creation of new aliases for a join if we would otherwise force the creation of new aliases for a join
(needed for nested Q-filters). The set is updated by this method. (needed for nested Q-filters). The set is updated by this method.
If 'process_extras' is set, any extra filters returned from the table
joining process will be processed. This parameter is set to False
during the processing of extra filters to avoid infinite recursion.
""" """
arg, value = filter_expr arg, value = filter_expr
parts = arg.split(LOOKUP_SEP) parts = arg.split(LOOKUP_SEP)
@@ -1115,10 +1143,11 @@ class Query(object):
allow_many = not negate allow_many = not negate
try: try:
field, target, opts, join_list, last, extra_filters = self.setup_joins( field, target, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many, parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True, negate=negate, allow_explicit_fk=True)
process_extras=process_extras) if can_reuse is not None:
can_reuse.update(join_list)
except MultiJoin as e: except MultiJoin as e:
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse) can_reuse)
@@ -1136,10 +1165,10 @@ class Query(object):
join_promote = True join_promote = True
# Process the join list to see if we can remove any inner joins from # Process the join list to see if we can remove any inner joins from
# the far end (fewer tables in a query is better). # the far end (fewer tables in a query is better). Note that join
nonnull_comparison = (lookup_type == 'isnull' and value is False) # promotion must happen before join trimming to have the join type
col, alias, join_list = self.trim_joins(target, join_list, last, # information available when reusing joins.
nonnull_comparison) col, alias, join_list = self.trim_joins(target, join_list, path)
if connector == OR: if connector == OR:
# Some joins may need to be promoted when adding a new filter to a # Some joins may need to be promoted when adding a new filter to a
@@ -1212,12 +1241,6 @@ class Query(object):
# is added in upper layers of the code. # is added in upper layers of the code.
self.where.add((Constraint(alias, col, None), 'isnull', False), AND) self.where.add((Constraint(alias, col, None), 'isnull', False), AND)
if can_reuse is not None:
can_reuse.update(join_list)
if process_extras:
for filter in extra_filters:
self.add_filter(filter, negate=negate, can_reuse=can_reuse,
process_extras=False)
def add_q(self, q_object, used_aliases=None, force_having=False): def add_q(self, q_object, used_aliases=None, force_having=False):
""" """
@@ -1270,37 +1293,24 @@ class Query(object):
if self.filter_is_sticky: if self.filter_is_sticky:
self.used_aliases = used_aliases self.used_aliases = used_aliases
def setup_joins(self, names, opts, alias, can_reuse, allow_many=True, def names_to_path(self, names, opts, allow_many=False,
allow_explicit_fk=False, negate=False, process_extras=True): allow_explicit_fk=True):
""" """
Compute the necessary table joins for the passage through the fields Walks the names path and turns them PathInfo tuples. Note that a
given in 'names'. 'opts' is the Options class for the current model single name in 'names' can generate multiple PathInfos (m2m for
(which gives the table we are joining to), 'alias' is the alias for the example).
table we are joining to.
The 'can_reuse' defines the reverse foreign key joins we can reuse. It 'names' is the path of names to travle, 'opts' is the model Options we
can be either sql.constants.REUSE_ALL in which case all joins are start the name resolving from, 'allow_many' and 'allow_explicit_fk'
reusable or a set of aliases that can be reused. Non-reverse foreign are as for setup_joins().
keys are always reusable.
The 'allow_explicit_fk' controls if field.attname is allowed in the Returns a list of PathInfo tuples. In addition returns the final field
lookups. (the last used join field), and target (which is a field guaranteed to
contain the same value as the final field).
Finally, 'negate' is used in the same sense as for add_filter()
-- it indicates an exclude() filter, or something similar. It is only
passed in here so that it can be passed to a field's extra_filter() for
customized behavior.
Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value and the
list of tables joined.
""" """
joins = [alias] path = []
last = [0] multijoin_pos = None
extra_filters = []
int_alias = None
for pos, name in enumerate(names): for pos, name in enumerate(names):
last.append(len(joins))
if name == 'pk': if name == 'pk':
name = opts.pk.name name = opts.pk.name
try: try:
@@ -1314,14 +1324,12 @@ class Query(object):
field, model, direct, m2m = opts.get_field_by_name(f.name) field, model, direct, m2m = opts.get_field_by_name(f.name)
break break
else: else:
names = opts.get_all_field_names() + list(self.aggregate_select) available = opts.get_all_field_names() + list(self.aggregate_select)
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(available)))
# Check if we need any joins for concrete inheritance cases (the
if not allow_many and (m2m or not direct): # field lives in parent, but we are currently in one of its
for alias in joins: # children)
self.unref_alias(alias)
raise MultiJoin(pos + 1)
if model: if model:
# The field lives on a base class of the current model. # The field lives on a base class of the current model.
# Skip the chain of proxy to the concrete proxied model # Skip the chain of proxy to the concrete proxied model
@@ -1331,172 +1339,179 @@ class Query(object):
if int_model is proxied_model: if int_model is proxied_model:
opts = int_model._meta opts = int_model._meta
else: else:
lhs_col = opts.parents[int_model].column final_field = opts.parents[int_model]
target = final_field.rel.get_related_field()
opts = int_model._meta opts = int_model._meta
alias = self.join((alias, opts.db_table, lhs_col, path.append(PathInfo(final_field, target, final_field.model._meta,
opts.pk.column)) opts, final_field))
joins.append(alias) # We have five different cases to solve: foreign keys, reverse
cached_data = opts._join_cache.get(name) # foreign keys, m2m fields (also reverse) and non-relational
orig_opts = opts # fields. We are mostly just using the related field API to
# fetch the from and to fields. The m2m fields are handled as
if process_extras and hasattr(field, 'extra_filters'): # two foreign keys, first one reverse, the second one direct.
extra_filters.extend(field.extra_filters(names, pos, negate)) if direct and not field.rel and not m2m:
if direct: # Local non-relational field.
if m2m: final_field = target = field
# Many-to-many field defined on the current model. break
if cached_data: elif direct and not m2m:
(table1, from_col1, to_col1, table2, from_col2, # Foreign Key
to_col2, opts, target) = cached_data opts = field.rel.to._meta
else: target = field.rel.get_related_field()
table1 = field.m2m_db_table() final_field = field
from_col1 = opts.get_field_by_name( from_opts = field.model._meta
field.m2m_target_field_name())[0].column path.append(PathInfo(field, target, from_opts, opts, field))
to_col1 = field.m2m_column_name() elif not direct and not m2m:
opts = field.rel.to._meta # Revere foreign key
table2 = opts.db_table final_field = to_field = field.field
from_col2 = field.m2m_reverse_name() opts = to_field.model._meta
to_col2 = opts.get_field_by_name( from_field = to_field.rel.get_related_field()
field.m2m_reverse_target_field_name())[0].column from_opts = from_field.model._meta
target = opts.pk path.append(
orig_opts._join_cache[name] = (table1, from_col1, PathInfo(from_field, to_field, from_opts, opts, to_field))
to_col1, table2, from_col2, to_col2, opts, if from_field.model is to_field.model:
target) # Recursive foreign key to self.
target = opts.get_field_by_name(
int_alias = self.join((alias, table1, from_col1, to_col1), field.field.rel.field_name)[0]
reuse=can_reuse, nullable=True)
if int_alias == table2 and from_col2 == to_col2:
joins.append(int_alias)
alias = int_alias
else:
alias = self.join(
(int_alias, table2, from_col2, to_col2),
reuse=can_reuse, nullable=True)
joins.extend([int_alias, alias])
elif field.rel:
# One-to-one or many-to-one field
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
opts = field.rel.to._meta
target = field.rel.get_related_field()
table = opts.db_table
from_col = field.column
to_col = target.column
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
alias = self.join((alias, table, from_col, to_col),
nullable=self.is_nullable(field))
joins.append(alias)
else: else:
# Non-relation fields. target = opts.pk
target = field elif direct and m2m:
break if not field.rel.through:
else: # Gotcha! This is just a fake m2m field - a generic relation
orig_field = field # field).
from_field = opts.pk
opts = field.rel.to._meta
target = opts.get_field_by_name(field.object_id_field_name)[0]
final_field = field
# Note that we are using different field for the join_field
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
path.append(PathInfo(from_field, target, field.model._meta, opts,
field))
else:
# m2m field. We are travelling first to the m2m table along a
# reverse relation, then from m2m table to the target table.
from_field1 = opts.get_field_by_name(
field.m2m_target_field_name())[0]
opts = field.rel.through._meta
to_field1 = opts.get_field_by_name(field.m2m_field_name())[0]
path.append(
PathInfo(from_field1, to_field1, from_field1.model._meta,
opts, to_field1))
final_field = from_field2 = opts.get_field_by_name(
field.m2m_reverse_field_name())[0]
opts = field.rel.to._meta
target = to_field2 = opts.get_field_by_name(
field.m2m_reverse_target_field_name())[0]
path.append(
PathInfo(from_field2, to_field2, from_field2.model._meta,
opts, from_field2))
elif not direct and m2m:
# This one is just like above, except we are travelling the
# fields in opposite direction.
field = field.field field = field.field
if m2m: from_field1 = opts.get_field_by_name(
# Many-to-many field defined on the target model. field.m2m_reverse_target_field_name())[0]
if cached_data: int_opts = field.rel.through._meta
(table1, from_col1, to_col1, table2, from_col2, to_field1 = int_opts.get_field_by_name(
to_col2, opts, target) = cached_data field.m2m_reverse_field_name())[0]
else: path.append(
table1 = field.m2m_db_table() PathInfo(from_field1, to_field1, from_field1.model._meta,
from_col1 = opts.get_field_by_name( int_opts, to_field1))
field.m2m_reverse_target_field_name())[0].column final_field = from_field2 = int_opts.get_field_by_name(
to_col1 = field.m2m_reverse_name() field.m2m_field_name())[0]
opts = orig_field.opts opts = field.opts
table2 = opts.db_table target = to_field2 = opts.get_field_by_name(
from_col2 = field.m2m_column_name() field.m2m_target_field_name())[0]
to_col2 = opts.get_field_by_name( path.append(PathInfo(from_field2, to_field2, from_field2.model._meta,
field.m2m_target_field_name())[0].column opts, from_field2))
target = opts.pk
orig_opts._join_cache[name] = (table1, from_col1,
to_col1, table2, from_col2, to_col2, opts,
target)
int_alias = self.join((alias, table1, from_col1, to_col1), if m2m and multijoin_pos is None:
reuse=can_reuse, nullable=True) multijoin_pos = pos
alias = self.join((int_alias, table2, from_col2, to_col2), if not direct and not path[-1].to_field.unique and multijoin_pos is None:
reuse=can_reuse, nullable=True) multijoin_pos = pos
joins.extend([int_alias, alias])
else:
# One-to-many field (ForeignKey defined on the target model)
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
local_field = opts.get_field_by_name(
field.rel.field_name)[0]
opts = orig_field.opts
table = opts.db_table
from_col = local_field.column
to_col = field.column
# In case of a recursive FK, use the to_field for
# reverse lookups as well
if orig_field.model is local_field.model:
target = opts.get_field_by_name(
field.rel.field_name)[0]
else:
target = opts.pk
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
alias = self.join((alias, table, from_col, to_col),
reuse=can_reuse, nullable=True)
joins.append(alias)
if pos != len(names) - 1: if pos != len(names) - 1:
if pos == len(names) - 2: if pos == len(names) - 2:
raise FieldError("Join on field %r not permitted. Did you misspell %r for the lookup type?" % (name, names[pos + 1])) raise FieldError(
"Join on field %r not permitted. Did you misspell %r for "
"the lookup type?" % (name, names[pos + 1]))
else: else:
raise FieldError("Join on field %r not permitted." % name) raise FieldError("Join on field %r not permitted." % name)
if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many:
raise MultiJoin(multijoin_pos + 1)
return path, final_field, target
return field, target, opts, joins, last, extra_filters def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
allow_explicit_fk=False):
def trim_joins(self, target, join_list, last, nonnull_check=False):
""" """
Sometimes joins at the end of a multi-table sequence can be trimmed. If Compute the necessary table joins for the passage through the fields
the final join is against the same column as we are comparing against, given in 'names'. 'opts' is the Options class for the current model
and is an inner join, we can go back one step in a join chain and (which gives the table we are starting from), 'alias' is the alias for
compare against the LHS of the join instead (and then repeat the the table to start the joining from.
optimization). The result, potentially, involves fewer table joins.
The 'target' parameter is the final field being joined to, 'join_list' The 'can_reuse' defines the reverse foreign key joins we can reuse. It
is the full list of join aliases. can be sql.constants.REUSE_ALL in which case all joins are reusable
or a set of aliases that can be reused. Note that non-reverse foreign
keys are always reusable.
The 'last' list contains offsets into 'join_list', corresponding to If 'allow_many' is False, then any reverse foreign key seen will
each component of the filter. Many-to-many relations, for example, add generate a MultiJoin exception.
two tables to the join list and we want to deal with both tables the
same way, so 'last' has an entry for the first of the two tables and
then the table immediately after the second table, in that case.
The 'nonnull_check' parameter is True when we are using inner joins The 'allow_explicit_fk' controls if field.attname is allowed in the
between tables explicitly to exclude NULL entries. In that case, the lookups.
tables shouldn't be trimmed, because the very action of joining to them
alters the result set. Returns the final field involved in the joins, the target field (used
for any 'where' constraint), the final 'opts' value, the joins and the
field path travelled to generate the joins.
The target field is the field containing the concrete value. Final
field can be something different, for example foreign key pointing to
that value. Final field is needed for example in some value
conversions (convert 'obj' in fk__id=obj to pk val using the foreign
key field for example).
"""
joins = [alias]
# First, generate the path for the names
path, final_field, target = self.names_to_path(
names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
for pos, join in enumerate(path):
from_field, to_field, from_opts, opts, join_field = join
direct = join_field == from_field
if direct:
nullable = self.is_nullable(from_field)
else:
nullable = True
connection = alias, opts.db_table, from_field.column, to_field.column
alias = self.join(connection, reuse=can_reuse, nullable=nullable,
join_field=join_field)
joins.append(alias)
return final_field, target, opts, joins, path
def trim_joins(self, target, joins, path):
"""
The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contains the PathInfos
used to create the joins.
Returns the final active column and table alias and the new active Returns the final active column and table alias and the new active
join_list. joins.
We will always trim any direct join if we have the target column
available already in the previous table. Reverse joins can't be
trimmed as we don't know if there is anything on the other side of
the join.
""" """
final = len(join_list) for info in reversed(path):
penultimate = last.pop() direct = info.join_field == info.from_field
if penultimate == final: if info.to_field == target and direct:
penultimate = last.pop() target = info.from_field
col = target.column self.unref_alias(joins.pop())
alias = join_list[-1] else:
while final > 1:
join = self.alias_map[alias]
if (col != join.rhs_join_col or join.join_type != self.INNER or
nonnull_check):
break break
self.unref_alias(alias) return target.column, joins[-1], joins
alias = join.lhs_alias
col = join.lhs_join_col
join_list.pop()
final -= 1
if final == penultimate:
penultimate = last.pop()
return col, alias, join_list
def split_exclude(self, filter_expr, prefix, can_reuse): def split_exclude(self, filter_expr, prefix, can_reuse):
""" """
@@ -1627,9 +1642,9 @@ class Query(object):
try: try:
for name in field_names: for name in field_names:
field, target, u2, joins, u3, u4 = self.setup_joins( field, target, u2, joins, u3 = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m,
allow_m2m, True) True)
final_alias = joins[-1] final_alias = joins[-1]
col = target.column col = target.column
if len(joins) > 1: if len(joins) > 1:
@@ -1918,7 +1933,7 @@ class Query(object):
""" """
opts = self.model._meta opts = self.model._meta
alias = self.get_initial_alias() alias = self.get_initial_alias()
field, col, opts, joins, last, extra = self.setup_joins( field, col, opts, joins, extra = self.setup_joins(
start.split(LOOKUP_SEP), opts, alias, REUSE_ALL) start.split(LOOKUP_SEP), opts, alias, REUSE_ALL)
select_col = self.alias_map[joins[1]].lhs_join_col select_col = self.alias_map[joins[1]].lhs_join_col
select_alias = alias select_alias = alias
@@ -1975,18 +1990,6 @@ def get_order_dir(field, default='ASC'):
return field, dirn[0] return field, dirn[0]
def setup_join_cache(sender, **kwargs):
"""
The information needed to join between model fields is something that is
invariant over the life of the model, so we cache it in the model's Options
class, rather than recomputing it all the time.
This method initialises the (empty) cache when the model is created.
"""
sender._meta._join_cache = {}
signals.class_prepared.connect(setup_join_cache)
def add_to_dict(data, key, value): def add_to_dict(data, key, value):
""" """
A helper function to add "value" to the set of values for "key", whether or A helper function to add "value" to the set of values for "key", whether or

View File

@@ -978,3 +978,7 @@ class AggregationTests(TestCase):
('The Definitive Guide to Django: Web Development Done Right', 2) ('The Definitive Guide to Django: Web Development Done Right', 2)
] ]
) )
def test_reverse_join_trimming(self):
    """
    The SQL generated for a Count() over 'book_contact_set__contact' must
    still contain a JOIN, i.e. the join is not trimmed away.
    """
    annotated = Author.objects.annotate(Count('book_contact_set__contact'))
    generated_sql = str(annotated.query)
    self.assertIn(' JOIN ', generated_sql)

View File

@@ -283,6 +283,7 @@ class SingleObject(models.Model):
class RelatedObject(models.Model): class RelatedObject(models.Model):
single = models.ForeignKey(SingleObject, null=True) single = models.ForeignKey(SingleObject, null=True)
f = models.IntegerField(null=True)
class Meta: class Meta:
ordering = ['single'] ordering = ['single']
@@ -311,7 +312,7 @@ class Food(models.Model):
@python_2_unicode_compatible @python_2_unicode_compatible
class Eaten(models.Model): class Eaten(models.Model):
food = models.ForeignKey(Food, to_field="name") food = models.ForeignKey(Food, to_field="name", null=True)
meal = models.CharField(max_length=20) meal = models.CharField(max_length=20)
def __str__(self): def __str__(self):
@@ -400,3 +401,23 @@ class ModelA(models.Model):
name = models.TextField() name = models.TextField()
b = models.ForeignKey(ModelB, null=True) b = models.ForeignKey(ModelB, null=True)
d = models.ForeignKey(ModelD) d = models.ForeignKey(ModelD)
@python_2_unicode_compatible
class Job(models.Model):
    """
    A job identified by a unique name.

    JobResponsibilities points at ``name`` through ``to_field``, which is
    why the column is declared unique.
    """
    name = models.CharField(max_length=20, unique=True)

    def __str__(self):
        return self.name
class JobResponsibilities(models.Model):
    """
    Link rows between Job and Responsibility.

    Both foreign keys reference a non-primary-key column via ``to_field``
    (Job.name and Responsibility.description respectively).
    """
    job = models.ForeignKey(Job, to_field='name')
    responsibility = models.ForeignKey(
        'Responsibility',
        to_field='description',
    )
@python_2_unicode_compatible
class Responsibility(models.Model):
    """
    A responsibility identified by a unique description.

    It is related to Job through the explicit JobResponsibilities model;
    the reverse accessor on Job is ``responsibilities``.
    """
    description = models.CharField(max_length=20, unique=True)
    jobs = models.ManyToManyField(
        Job,
        through=JobResponsibilities,
        related_name='responsibilities',
    )

    def __str__(self):
        return self.description

View File

@@ -23,7 +23,8 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten,
Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory,
SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
SingleObject, RelatedObject, ModelA, ModelD) SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job,
JobResponsibilities)
class BaseQuerysetTest(TestCase): class BaseQuerysetTest(TestCase):
@@ -243,7 +244,10 @@ class Queries1Tests(BaseQuerysetTest):
q1 = Item.objects.order_by('name') q1 = Item.objects.order_by('name')
q2 = Item.objects.filter(id=self.i1.id) q2 = Item.objects.filter(id=self.i1.id)
list(q2) list(q2)
self.assertEqual(len((q1 & q2).order_by('name').query.tables), 1) combined_query = (q1 & q2).order_by('name').query
self.assertEqual(len([
t for t in combined_query.tables if combined_query.alias_refcount[t]
]), 1)
def test_order_by_join_unref(self): def test_order_by_join_unref(self):
""" """
@@ -883,6 +887,225 @@ class Queries1Tests(BaseQuerysetTest):
Item.objects.filter(Q(tags__name__in=['t4', 't3'])), Item.objects.filter(Q(tags__name__in=['t4', 't3'])),
[repr(i) for i in Item.objects.filter(~~Q(tags__name__in=['t4', 't3']))]) [repr(i) for i in Item.objects.filter(~~Q(tags__name__in=['t4', 't3']))])
def test_ticket_10790_1(self):
    """
    isnull lookups on a direct foreign key must not add any join.

    Only the two-step ``parent__parent`` lookup at the end is expected to
    keep exactly one LEFT OUTER JOIN and no INNER JOIN.
    """
    # Querying direct fields with isnull should trim the left outer join.
    # It also should not create INNER JOIN.
    q = Tag.objects.filter(parent__isnull=True)
    self.assertQuerysetEqual(q, ['<Tag: t1>'])
    self.assertNotIn('JOIN', str(q.query))

    q = Tag.objects.filter(parent__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
    )
    self.assertNotIn('JOIN', str(q.query))

    q = Tag.objects.exclude(parent__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
    )
    self.assertNotIn('JOIN', str(q.query))

    q = Tag.objects.exclude(parent__isnull=False)
    self.assertQuerysetEqual(q, ['<Tag: t1>'])
    self.assertNotIn('JOIN', str(q.query))

    q = Tag.objects.exclude(parent__parent__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    # Dedicated assertions report the offending value/SQL on failure,
    # unlike assertTrue(<comparison>).
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    self.assertNotIn('INNER JOIN', sql)
def test_ticket_10790_2(self):
    """
    A two-step ``parent__parent`` lookup is expected to produce exactly one
    INNER JOIN and no LEFT OUTER JOIN, with or without isnull.
    """
    # Querying across several tables should strip only the last outer join,
    # while preserving the preceding inner joins.
    q = Tag.objects.filter(parent__parent__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Tag: t4>', '<Tag: t5>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 1)

    # Querying without isnull should not convert anything to left outer join.
    q = Tag.objects.filter(parent__parent=self.t1)
    self.assertQuerysetEqual(
        q,
        ['<Tag: t4>', '<Tag: t5>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 1)
def test_ticket_10790_3(self):
    """
    isnull lookups starting from NamedCategory via ``tag`` must keep one
    LEFT OUTER JOIN next to one INNER JOIN.
    """
    # Querying via indirect fields should populate the left outer join
    q = NamedCategory.objects.filter(tag__isnull=True)
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    # join to dumbcategory ptr_id
    self.assertEqual(sql.count('INNER JOIN'), 1)
    self.assertQuerysetEqual(q, [])

    # Querying across several tables should strip only the last join, while
    # preserving the preceding left outer joins.
    q = NamedCategory.objects.filter(tag__parent__isnull=True)
    sql = str(q.query)
    self.assertEqual(sql.count('INNER JOIN'), 1)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    self.assertQuerysetEqual(q, ['<NamedCategory: NamedCategory object>'])
def test_ticket_10790_4(self):
    """
    isnull=True lookups through ``item__tags`` must use only LEFT OUTER
    JOINs (two, or three when continuing to ``parent``).
    """
    # Querying across m2m field should not strip the m2m table from join.
    q = Author.objects.filter(item__tags__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Author: a2>', '<Author: a3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 2)
    self.assertNotIn('INNER JOIN', sql)

    q = Author.objects.filter(item__tags__parent__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 3)
    self.assertNotIn('INNER JOIN', sql)
def test_ticket_10790_5(self):
    """
    isnull=False lookups through ``item__tags`` must use only INNER JOINs;
    the expected count grows by one per additional ``parent`` step.
    """
    # Querying with isnull=False across m2m field should not create outer joins
    q = Author.objects.filter(item__tags__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a1>', '<Author: a2>', '<Author: a2>',
         '<Author: a4>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 2)

    q = Author.objects.filter(item__tags__parent__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a2>', '<Author: a4>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 3)

    q = Author.objects.filter(item__tags__parent__parent__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Author: a4>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 4)
def test_ticket_10790_6(self):
    """
    isnull=True lookups through ``item__tags__parent...`` must use only
    LEFT OUTER JOINs (four for the two-parent chain, three for one parent).
    """
    # Querying with isnull=True across m2m field should not create inner joins
    # and strip last outer join
    q = Author.objects.filter(item__tags__parent__parent__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a1>', '<Author: a2>', '<Author: a2>',
         '<Author: a2>', '<Author: a3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 4)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q = Author.objects.filter(item__tags__parent__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 3)
    self.assertEqual(sql.count('INNER JOIN'), 0)
def test_ticket_10790_7(self):
    """
    ``item__isnull`` on Author must keep its single join: LEFT OUTER for
    isnull=True, INNER for isnull=False.
    """
    # Reverse querying with isnull should not strip the join
    q = Author.objects.filter(item__isnull=True)
    self.assertQuerysetEqual(
        q,
        ['<Author: a3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q = Author.objects.filter(item__isnull=False)
    self.assertQuerysetEqual(
        q,
        ['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a4>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 1)
def test_ticket_10790_8(self):
    """
    OR-ing ``parent__isnull=True`` with ``parent=<tag>`` in Q objects must
    not leave any join in the generated SQL.
    """
    # Querying with combined q-objects should also strip the left outer join
    q = Tag.objects.filter(Q(parent__isnull=True) | Q(parent=self.t1))
    self.assertQuerysetEqual(
        q,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    sql = str(q.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 0)
def test_ticket_10790_combine(self):
    """
    Combining querysets with ``|`` and ``&`` must yield the same join
    counts regardless of operand order.

    Combinations of single-step ``parent`` filters are expected to have no
    joins; combinations involving ``parent__parent__isnull`` are expected
    to have exactly one LEFT OUTER JOIN.
    """
    # Combining queries should not re-populate the left outer join
    q1 = Tag.objects.filter(parent__isnull=True)
    q2 = Tag.objects.filter(parent__isnull=False)

    q3 = q1 | q2
    self.assertQuerysetEqual(
        q3,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
    )
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q3 = q1 & q2
    self.assertQuerysetEqual(q3, [])
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q2 = Tag.objects.filter(parent=self.t1)
    q3 = q1 | q2
    self.assertQuerysetEqual(
        q3,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q3 = q2 | q1
    self.assertQuerysetEqual(
        q3,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 0)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q1 = Tag.objects.filter(parent__isnull=True)
    q2 = Tag.objects.filter(parent__parent__isnull=True)

    q3 = q1 | q2
    self.assertQuerysetEqual(
        q3,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    self.assertEqual(sql.count('INNER JOIN'), 0)

    q3 = q2 | q1
    self.assertQuerysetEqual(
        q3,
        ['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
    )
    sql = str(q3.query)
    self.assertEqual(sql.count('LEFT OUTER JOIN'), 1)
    self.assertEqual(sql.count('INNER JOIN'), 0)
class Queries2Tests(TestCase): class Queries2Tests(TestCase):
def setUp(self): def setUp(self):
Number.objects.create(num=4) Number.objects.create(num=4)
@@ -1037,6 +1260,10 @@ class Queries4Tests(BaseQuerysetTest):
Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3) Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3)
def test_ticket14876(self): def test_ticket14876(self):
# Note: when combining the query we need to have information available
# about the join type of the trimmed "creator__isnull" join. If we
# don't have that information, then the join is created as INNER JOIN
# and results will be incorrect.
q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1')) q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1'))
q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1')) q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1'))
self.assertQuerysetEqual(q1, ["<Report: r1>", "<Report: r3>"], ordered=False) self.assertQuerysetEqual(q1, ["<Report: r1>", "<Report: r3>"], ordered=False)
@@ -1405,17 +1632,19 @@ class NullableRelOrderingTests(TestCase):
# the join type of already existing joins. # the join type of already existing joins.
Plaything.objects.create(name="p1") Plaything.objects.create(name="p1")
s = SingleObject.objects.create(name='s') s = SingleObject.objects.create(name='s')
r = RelatedObject.objects.create(single=s) r = RelatedObject.objects.create(single=s, f=1)
Plaything.objects.create(name="p2", others=r) Plaything.objects.create(name="p2", others=r)
qs = Plaything.objects.all().filter(others__isnull=False).order_by('pk') qs = Plaything.objects.all().filter(others__isnull=False).order_by('pk')
self.assertTrue('JOIN' not in str(qs.query))
qs = Plaything.objects.all().filter(others__f__isnull=False).order_by('pk')
self.assertTrue('INNER' in str(qs.query)) self.assertTrue('INNER' in str(qs.query))
qs = qs.order_by('others__single__name') qs = qs.order_by('others__single__name')
# The ordering by others__single__pk will add one new join (to single) # The ordering by others__single__pk will add one new join (to single)
# and that join must be LEFT join. The already existing join to related # and that join must be LEFT join. The already existing join to related
# objects must be kept INNER. So, we have both a INNER and a LEFT join # objects must be kept INNER. So, we have both a INNER and a LEFT join
# in the query. # in the query.
self.assertTrue('LEFT' in str(qs.query)) self.assertEquals(str(qs.query).count('LEFT'), 1)
self.assertTrue('INNER' in str(qs.query)) self.assertEquals(str(qs.query).count('INNER'), 1)
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, qs,
['<Plaything: p2>'] ['<Plaything: p2>']
@@ -1466,6 +1695,7 @@ class Queries6Tests(TestCase):
# This next test used to cause really weird PostgreSQL behavior, but it was # This next test used to cause really weird PostgreSQL behavior, but it was
# only apparent much later when the full test suite ran. # only apparent much later when the full test suite ran.
# - Yeah, it leaves global ITER_CHUNK_SIZE to 2 instead of 100...
#@unittest.expectedFailure #@unittest.expectedFailure
def test_slicing_and_cache_interaction(self): def test_slicing_and_cache_interaction(self):
# We can do slicing beyond what is currently in the result cache, # We can do slicing beyond what is currently in the result cache,
@@ -1993,6 +2223,29 @@ class DefaultValuesInsertTest(TestCase):
except TypeError: except TypeError:
self.fail("Creation of an instance of a model with only the PK field shouldn't error out after bulk insert refactoring (#17056)") self.fail("Creation of an instance of a model with only the PK field shouldn't error out after bulk insert refactoring (#17056)")
class ExcludeTest(TestCase):
    """exclude() across relations whose foreign keys use ``to_field``."""

    def setUp(self):
        # One eaten food and one untouched food.
        apples = Food.objects.create(name='apples')
        Food.objects.create(name='oranges')
        Eaten.objects.create(food=apples, meal='dinner')

        # Two disjoint job/responsibility pairs, created in the same
        # order as before so that primary keys are unchanged.
        manager = Job.objects.create(name='Manager')
        golf = Responsibility.objects.create(description='Playing golf')
        programmer = Job.objects.create(name='Programmer')
        programming = Responsibility.objects.create(description='Programming')
        JobResponsibilities.objects.create(job=manager, responsibility=golf)
        JobResponsibilities.objects.create(
            job=programmer, responsibility=programming)

    def test_to_field(self):
        """Each exclude() must leave only the unrelated row behind."""
        uneaten = Food.objects.exclude(eaten__meal='dinner')
        self.assertQuerysetEqual(uneaten, ['<Food: oranges>'])

        non_golf_jobs = Job.objects.exclude(
            responsibilities__description='Playing golf')
        self.assertQuerysetEqual(non_golf_jobs, ['<Job: Programmer>'])

        non_manager_duties = Responsibility.objects.exclude(
            jobs__name='Manager')
        self.assertQuerysetEqual(
            non_manager_duties, ['<Responsibility: Programming>'])
class NullInExcludeTest(TestCase): class NullInExcludeTest(TestCase):
def setUp(self): def setUp(self):
NullableName.objects.create(name='i1') NullableName.objects.create(name='i1')
@@ -2155,3 +2408,13 @@ class NullJoinPromotionOrTest(TestCase):
# so we can use INNER JOIN for it. However, we can NOT use INNER JOIN # so we can use INNER JOIN for it. However, we can NOT use INNER JOIN
# for the b->c join, as a->b is nullable. # for the b->c join, as a->b is nullable.
self.assertEqual(str(qset.query).count('INNER JOIN'), 1) self.assertEqual(str(qset.query).count('INNER JOIN'), 1)
class ReverseJoinTrimmingTest(TestCase):
    """Guards against join trimming being applied to ``annotation__tag``."""

    def test_reverse_trimming(self):
        """
        Filtering Tag on ``annotation__tag`` must keep an INNER JOIN and,
        with no Annotation rows created here, return nothing.
        """
        # Check that we don't accidentally trim reverse joins - we can't know
        # if there is anything on the other side of the join, so trimming
        # reverse joins can't be done, ever.
        t = Tag.objects.create()
        qs = Tag.objects.filter(annotation__tag=t.pk)
        self.assertIn('INNER JOIN', str(qs.query))
        # assertEquals is a deprecated alias; assertEqual is the
        # supported spelling.
        self.assertEqual(list(qs), [])