1
0
# Mirror of https://github.com/django/django.git (synced 2025-04-29 19:54:37 +00:00)
# django/django/db/models/constraints.py
#
# 725 lines
# 29 KiB
# Python

import warnings
from enum import Enum
from types import NoneType
from django.core import checks
from django.core.exceptions import FieldDoesNotExist, FieldError, ValidationError
from django.db import connections
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Exists, ExpressionList, F, RawSQL
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import Exact, IsNull
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.translation import gettext_lazy as _
# Public names exported by ``from django.db.models.constraints import *``.
__all__ = [
    "BaseConstraint",
    "CheckConstraint",
    "Deferrable",
    "UniqueConstraint",
]
class BaseConstraint:
    """
    Common base for model constraints.

    Stores the constraint ``name`` plus an optional violation error
    message/code, and provides shared helpers for system checks, reference
    checking and migration serialization (``deconstruct()`` / ``clone()``).
    Subclasses implement the SQL hooks (``constraint_sql()``,
    ``create_sql()``, ``remove_sql()``) and ``validate()``.
    """

    default_violation_error_message = _("Constraint “%(name)s” is violated.")
    # Class-level defaults; __init__ overrides them per instance when given.
    violation_error_code = None
    violation_error_message = None

    # RemovedInDjango60Warning: When the deprecation ends, replace with:
    # def __init__(
    #     self, *, name, violation_error_code=None, violation_error_message=None
    # ):
    def __init__(
        self, *args, name=None, violation_error_code=None, violation_error_message=None
    ):
        """
        Store the name and the violation error message/code.

        ``name`` is required. Passing ``name`` and ``violation_error_message``
        positionally is deprecated and emits RemovedInDjango60Warning.

        Raises TypeError when ``name`` is None and no positional arguments
        were given.
        """
        # RemovedInDjango60Warning.
        if name is None and not args:
            raise TypeError(
                f"{self.__class__.__name__}.__init__() missing 1 required keyword-only "
                f"argument: 'name'"
            )
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        if violation_error_message is not None:
            self.violation_error_message = violation_error_message
        else:
            self.violation_error_message = self.default_violation_error_message
        # RemovedInDjango60Warning.
        if args:
            warnings.warn(
                f"Passing positional arguments to {self.__class__.__name__} is "
                f"deprecated.",
                RemovedInDjango60Warning,
                stacklevel=2,
            )
            # Positional arguments map onto (name, violation_error_message) in
            # order; falsy values are skipped and any extras are ignored.
            for arg, attr in zip(args, ["name", "violation_error_message"]):
                if arg:
                    setattr(self, attr, arg)

    @property
    def contains_expressions(self):
        """Whether the constraint is defined on expressions (never, here)."""
        return False

    def constraint_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    @classmethod
    def _expression_refs_exclude(cls, model, expression, exclude):
        """
        Return True if ``expression`` references any field named in ``exclude``.

        References to generated fields are followed recursively into the
        generated field's own ``expression``. The ``pk`` alias is resolved to
        the model's primary key field, since ``opts.get_field()`` does not
        know that alias and would raise FieldDoesNotExist.
        """
        get_field = model._meta.get_field
        for field_name, *__ in model._get_expr_references(expression):
            if field_name in exclude:
                return True
            # pk is an alias that won't be found by opts.get_field.
            if field_name == "pk":
                field = model._meta.pk
            else:
                field = get_field(field_name)
            if field.generated and cls._expression_refs_exclude(
                model, field.expression, exclude
            ):
                return True
        return False

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` filled in."""
        return self.violation_error_message % {"name": self.name}

    def _check(self, model, connection):
        """Return system-check messages for ``model``; none by default."""
        return []

    def _check_references(self, model, references):
        """
        Check the field references (tuples of name parts) used by a constraint.

        Reports models.E041 for a reference whose first lookup after a
        relational field is neither a transform nor a lookup (i.e. it would
        require a JOIN). Also runs ``model._check_local_fields()`` on the
        referenced field names, skipping the ``pk`` alias.
        """
        errors = []
        fields = set()
        for field_name, *lookups in references:
            # pk is an alias that won't be found by opts.get_field.
            if field_name != "pk":
                fields.add(field_name)
            if not lookups:
                # If it has no lookups it cannot result in a JOIN.
                continue
            try:
                if field_name == "pk":
                    field = model._meta.pk
                else:
                    field = model._meta.get_field(field_name)
                if not field.is_relation or field.many_to_many or field.one_to_many:
                    continue
            except FieldDoesNotExist:
                continue
            # JOIN must happen at the first lookup.
            first_lookup = lookups[0]
            if (
                hasattr(field, "get_transform")
                and hasattr(field, "get_lookup")
                and field.get_transform(first_lookup) is None
                and field.get_lookup(first_lookup) is None
            ):
                errors.append(
                    checks.Error(
                        "'constraints' refers to the joined field '%s'."
                        % LOOKUP_SEP.join([field_name] + lookups),
                        obj=model,
                        id="models.E041",
                    )
                )
        errors.extend(model._check_local_fields(fields, "constraints"))
        return errors

    def deconstruct(self):
        """
        Return ``(path, args, kwargs)`` for serializing this constraint.

        The import path is shortened to ``django.db.models``. The violation
        message is included only when it differs from the default, and the
        code only when set.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        path = path.replace("django.db.models.constraints", "django.db.models")
        kwargs = {"name": self.name}
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            kwargs["violation_error_message"] = self.violation_error_message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return (path, (), kwargs)

    def clone(self):
        """Return a new instance rebuilt from ``deconstruct()``."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
    """
    Constraint requiring every row to satisfy a boolean ``condition``.

    ``condition`` is a ``Q`` object or any object with a truthy
    ``conditional`` attribute. For DDL it is compiled by
    ``_get_check_sql()`` and handed to the schema editor. In Python it is
    evaluated against an instance by ``validate()``.
    """

    # RemovedInDjango60Warning: when the deprecation ends, replace with
    # def __init__(
    #  self, *, condition, name, violation_error_code=None, violation_error_message=None
    # )
    def __init__(
        self,
        *,
        name,
        condition=None,
        check=None,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """
        Store ``condition`` (or the deprecated ``check`` alias) and the name.

        When ``check`` is given it replaces ``condition`` and a
        RemovedInDjango60Warning is emitted. Raises TypeError if the
        resulting condition is not conditional.
        """
        if check is not None:
            warnings.warn(
                "CheckConstraint.check is deprecated in favor of `.condition`.",
                RemovedInDjango60Warning,
                stacklevel=2,
            )
            condition = check
        self.condition = condition
        is_conditional = getattr(condition, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.condition must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check(self):
        # RemovedInDjango60Warning: legacy read access to ``condition``.
        warnings.warn(
            "CheckConstraint.check is deprecated in favor of `.condition`.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        return self.condition

    def _set_check(self, value):
        # RemovedInDjango60Warning: legacy write access to ``condition``.
        warnings.warn(
            "CheckConstraint.check is deprecated in favor of `.condition`.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        self.condition = value

    check = property(_get_check, _set_check)

    def _check(self, model, connection):
        """
        Return system-check messages for this constraint on ``connection``.

        Emits models.W027 when the backend lacks table check constraints and
        the model does not require that feature. Otherwise, unless the model
        requires the feature but the backend lacks it, it reports
        models.W045 for RawSQL() usage plus any reference errors.
        """
        feature_name = "supports_table_check_constraints"
        backend_supports = connection.features.supports_table_check_constraints
        required_features = model._meta.required_db_features
        if not backend_supports and feature_name not in required_features:
            return [
                checks.Warning(
                    f"{connection.display_name} does not support check constraints.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W027",
                )
            ]
        if not backend_supports and feature_name in required_features:
            # Feature is required but unsupported here: nothing else to check.
            return []
        condition = self.condition
        referenced = set()
        if isinstance(condition, Q):
            referenced.update(model._get_expr_references(condition))
        messages = []
        if any(isinstance(node, RawSQL) for node in condition.flatten()):
            messages.append(
                checks.Warning(
                    f"Check constraint {self.name!r} contains RawSQL() expression "
                    "and won't be validated during the model full_clean().",
                    hint="Silence this warning if you don't care about it.",
                    obj=model,
                    id="models.W045",
                ),
            )
        messages.extend(self._check_references(model, referenced))
        return messages

    def _get_check_sql(self, model, schema_editor):
        """
        Compile ``condition`` to SQL for ``model``.

        Parameters are interpolated into the SQL as values quoted by
        ``schema_editor.quote_value`` rather than returned separately.
        """
        connection = schema_editor.connection
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.condition)
        compiler = query.get_compiler(connection=connection)
        sql, params = where_node.as_sql(compiler, connection)
        quoted_params = tuple(map(schema_editor.quote_value, params))
        return sql % quoted_params

    def constraint_sql(self, model, schema_editor):
        """Return the inline CHECK definition for a CREATE TABLE statement."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(self, model, schema_editor):
        """Return the statement adding this check constraint to ``model``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(self, model, schema_editor):
        """Return the statement dropping this check constraint."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError if ``instance`` does not satisfy ``condition``.

        The condition is checked against the instance's field expression
        map (omitting ``exclude``) on database ``using``. A FieldError
        raised while checking silently skips validation.
        """
        field_values = instance._get_field_expression_map(
            meta=model._meta, exclude=exclude
        )
        try:
            satisfied = Q(self.condition).check(field_values, using=using)
        except FieldError:
            return
        if not satisfied:
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )

    def __repr__(self):
        suffix = ""
        if self.violation_error_code is not None:
            suffix += " violation_error_code=%r" % self.violation_error_code
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            suffix += " violation_error_message=%r" % self.violation_error_message
        return "<%s: condition=%s name=%s%s>" % (
            self.__class__.__qualname__,
            self.condition,
            repr(self.name),
            suffix,
        )

    def __eq__(self, other):
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.condition == other.condition
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        """Extend the base deconstruction with the ``condition`` kwarg."""
        path, args, base_kwargs = super().deconstruct()
        return path, args, {**base_kwargs, "condition": self.condition}
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self):
        # Render as ``Deferrable.DEFERRED`` instead of Enum's default repr.
        return "%s.%s" % (type(self).__qualname__, self.name)
class UniqueConstraint(BaseConstraint):
    """
    Constraint requiring uniqueness of ``fields`` or of ``expressions``.

    Exactly one of ``fields`` (field names) or positional ``expressions``
    (strings are wrapped in ``F()``) must be given. Optional settings are a
    partial ``condition`` (a ``Q``), a ``deferrable`` mode, covering
    ``include`` fields, per-field ``opclasses`` and ``nulls_distinct``.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        nulls_distinct=None,
        violation_error_code=None,
        violation_error_message=None,
    ):
        """
        Validate the option combination and store the attributes.

        ``fields`` and ``include`` are stored as tuples and string
        ``expressions`` are wrapped in ``F()``; ``opclasses`` is stored as
        given (list or tuple).

        Raises ValueError when:
        - ``name`` is missing;
        - neither or both of ``fields``/``expressions`` are given;
        - ``condition`` is not a ``Q``;
        - ``deferrable`` is combined with a condition, include, opclasses
          or expressions;
        - ``opclasses`` are combined with expressions;
        - ``opclasses`` and ``fields`` differ in length.

        Raises TypeError when ``deferrable``, ``include``, ``opclasses`` or
        ``nulls_distinct`` has the wrong type.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, (NoneType, Q)):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use django.contrib.postgres.indexes.OpClass() instead."
            )
        if not isinstance(deferrable, (NoneType, Deferrable)):
            raise TypeError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, (NoneType, list, tuple)):
            raise TypeError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise TypeError("UniqueConstraint.opclasses must be a list or tuple.")
        if not isinstance(nulls_distinct, (NoneType, bool)):
            raise TypeError("UniqueConstraint.nulls_distinct must be a bool.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.nulls_distinct = nulls_distinct
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """Whether the constraint is defined on expressions."""
        return bool(self.expressions)

    def _check(self, model, connection):
        """
        Return system-check messages for this constraint on ``connection``.

        Checks the local fields in ``fields``/``include``. It warns
        (W036/W038/W039/W044/W047) for each used feature the backend lacks,
        unless the model lists it in ``required_db_features``. It then
        checks references from the condition and expressions.
        """
        errors = model._check_local_fields({*self.fields, *self.include}, "constraints")
        required_db_features = model._meta.required_db_features
        if self.condition is not None and not (
            connection.features.supports_partial_indexes
            or "supports_partial_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with conditions.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W036",
                )
            )
        if self.deferrable is not None and not (
            connection.features.supports_deferrable_unique_constraints
            or "supports_deferrable_unique_constraints" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support deferrable unique "
                    "constraints.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W038",
                )
            )
        if self.include and not (
            connection.features.supports_covering_indexes
            or "supports_covering_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with non-key columns.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W039",
                )
            )
        if self.contains_expressions and not (
            connection.features.supports_expression_indexes
            or "supports_expression_indexes" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints on "
                    "expressions.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W044",
                )
            )
        if self.nulls_distinct is not None and not (
            connection.features.supports_nulls_distinct_unique_constraints
            or "supports_nulls_distinct_unique_constraints" in required_db_features
        ):
            errors.append(
                checks.Warning(
                    f"{connection.display_name} does not support unique constraints "
                    "with nulls distinct.",
                    hint=(
                        "A constraint won't be created. Silence this warning if you "
                        "don't care about it."
                    ),
                    obj=model,
                    id="models.W047",
                )
            )
        references = set()
        if (
            connection.features.supports_partial_indexes
            or "supports_partial_indexes" not in required_db_features
        ) and isinstance(self.condition, Q):
            references.update(model._get_expr_references(self.condition))
        if self.contains_expressions and (
            connection.features.supports_expression_indexes
            or "supports_expression_indexes" not in required_db_features
        ):
            for expression in self.expressions:
                references.update(model._get_expr_references(expression))
        errors.extend(self._check_references(model, references))
        return errors

    def _get_condition_sql(self, model, schema_editor):
        """
        Compile ``condition`` to SQL, or return None if there is none.

        Parameters are interpolated as values quoted by
        ``schema_editor.quote_value``.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """
        Return the resolved ExpressionList of IndexExpression wrappers.

        Returns None for a field-based constraint.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        """Return the inline unique definition via ``schema_editor._unique_sql``."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def create_sql(self, model, schema_editor):
        """Return the statement creating this constraint on ``model``."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def remove_sql(self, model, schema_editor):
        """Return the statement dropping this constraint from ``model``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
            nulls_distinct=self.nulls_distinct,
        )

    def __repr__(self):
        return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            " name=%s" % repr(self.name),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
            (
                ""
                if self.nulls_distinct is None
                else " nulls_distinct=%r" % self.nulls_distinct
            ),
            (
                ""
                if self.violation_error_code is None
                else " violation_error_code=%r" % self.violation_error_code
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else " violation_error_message=%r" % self.violation_error_message
            ),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                # __init__ normalizes fields/include to tuples but keeps
                # opclasses as given, so compare as tuples: ["x"] and ("x",)
                # describe the same constraint.
                and tuple(self.opclasses) == tuple(other.opclasses)
                and self.expressions == other.expressions
                and self.nulls_distinct is other.nulls_distinct
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """
        Return ``(path, expressions, kwargs)`` for serialization.

        The expressions are returned as the positional args. Optional
        settings appear in kwargs only when set.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        if self.nulls_distinct is not None:
            kwargs["nulls_distinct"] = self.nulls_distinct
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Raise ValidationError if saving ``instance`` would violate uniqueness.

        Queries ``model._default_manager`` on database ``using`` for
        conflicting rows. When the instance is not being added and has a pk,
        its own row is excluded. Validation is skipped (returns None) when:
        - a constrained field, or a field referenced by a generated field or
          expression, is in ``exclude``;
        - a non-generated field value makes a conflict impossible (see the
          NULL comment below).

        With a ``condition``, a violation is reported only if the instance
        satisfies the condition and a conflicting row also satisfies it. A
        FieldError raised during that check is ignored.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            generated_field_names = []
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                if field.generated:
                    if exclude and self._expression_refs_exclude(
                        model, field.expression, exclude
                    ):
                        return
                    generated_field_names.append(field.name)
                else:
                    lookup_value = getattr(instance, field.attname)
                    # Parsed as (nulls distinct AND value is None) OR
                    # (value == "" on a backend that interprets "" as NULL).
                    if (
                        self.nulls_distinct is not False
                        and lookup_value is None
                        or (
                            lookup_value == ""
                            and connections[
                                using
                            ].features.interprets_empty_strings_as_nulls
                        )
                    ):
                        # A composite constraint containing NULL value cannot cause
                        # a violation since NULL != NULL in SQL.
                        return
                    lookup_kwargs[field.name] = lookup_value
            lookup_args = []
            if generated_field_names:
                field_expression_map = instance._get_field_expression_map(
                    meta=model._meta, exclude=exclude
                )
                for field_name in generated_field_names:
                    expression = field_expression_map[field_name]
                    if self.nulls_distinct is False:
                        # With nulls not distinct, NULL on both sides also
                        # counts as a match.
                        lhs = F(field_name)
                        condition = Q(Exact(lhs, expression)) | Q(
                            IsNull(lhs, True), IsNull(expression, True)
                        )
                        lookup_args.append(condition)
                    else:
                        lookup_kwargs[field_name] = expression
            queryset = queryset.filter(*lookup_args, **lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude and any(
                self._expression_refs_exclude(model, expression, exclude)
                for expression in self.expressions
            ):
                return
            # Map each F(field) to the instance's value so every expression
            # can be compared against its instance-evaluated counterpart.
            replacements = {
                F(field): value
                for field, value in instance._get_field_expression_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            filters = []
            for expr in self.expressions:
                if hasattr(expr, "get_expression_for_validation"):
                    expr = expr.get_expression_for_validation()
                rhs = expr.replace_expressions(replacements)
                condition = Exact(expr, rhs)
                if self.nulls_distinct is False:
                    condition = Q(condition) | Q(IsNull(expr, True), IsNull(rhs, True))
                filters.append(condition)
            queryset = queryset.filter(*filters)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and instance._is_pk_set(model._meta):
            # An existing row must not conflict with itself.
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if (
                    self.fields
                    and self.violation_error_message
                    == self.default_violation_error_message
                ):
                    # When fields are defined, use the unique_error_message() as
                    # a default for backward compatibility.
                    validation_error_message = instance.unique_error_message(
                        model, self.fields
                    )
                    raise ValidationError(
                        validation_error_message,
                        code=validation_error_message.code,
                    )
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_expression_map(
                meta=model._meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass