1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Refs #29641 -- Refactored database schema constraint creation.

Added a test for constraint names in the database.

Updated SQLite introspection to use sqlparse to allow reading the
constraint name for table check and unique constraints.

Co-authored-by: Ian Foote <python@ian.feete.org>
This commit is contained in:
Simon Charette
2018-08-05 21:06:52 -04:00
committed by Tim Graham
parent 2f120ac517
commit dba4a634ba
7 changed files with 147 additions and 82 deletions

View File

@@ -61,25 +61,24 @@ class BaseDatabaseSchemaEditor:
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)" sql_foreign_key_constraint = "FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s"
sql_create_check = "ALTER TABLE %(table)s ADD %(check)s" sql_unique_constraint = "UNIQUE (%(columns)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_check_constraint = "CHECK (%(check)s)"
sql_create_constraint = "ALTER TABLE %(table)s ADD %(constraint)s"
sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" sql_create_unique = None
sql_delete_unique = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_unique = sql_delete_constraint
sql_create_fk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s (%(to_column)s)%(deferrable)s"
)
sql_create_inline_fk = None sql_create_inline_fk = None
sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_fk = sql_delete_constraint
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s%(condition)s" sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s%(condition)s"
sql_delete_index = "DROP INDEX %(name)s" sql_delete_index = "DROP INDEX %(name)s"
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_pk = sql_delete_constraint
sql_delete_procedure = 'DROP PROCEDURE %(procedure)s' sql_delete_procedure = 'DROP PROCEDURE %(procedure)s'
@@ -254,7 +253,7 @@ class BaseDatabaseSchemaEditor:
# Check constraints can go on the column SQL here # Check constraints can go on the column SQL here
db_params = field.db_parameters(connection=self.connection) db_params = field.db_parameters(connection=self.connection)
if db_params['check']: if db_params['check']:
definition += " CHECK (%s)" % db_params['check'] definition += " " + self.sql_check_constraint % db_params
# Autoincrement SQL (for backends with inline variant) # Autoincrement SQL (for backends with inline variant)
col_type_suffix = field.db_type_suffix(connection=self.connection) col_type_suffix = field.db_type_suffix(connection=self.connection)
if col_type_suffix: if col_type_suffix:
@@ -287,7 +286,7 @@ class BaseDatabaseSchemaEditor:
for fields in model._meta.unique_together: for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields] columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns)) self.deferred_sql.append(self._create_unique_sql(model, columns))
constraints = [check.constraint_sql(model, self) for check in model._meta.constraints] constraints = [check.full_constraint_sql(model, self) for check in model._meta.constraints]
# Make the table # Make the table
sql = self.sql_create_table % { sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
@@ -596,7 +595,7 @@ class BaseDatabaseSchemaEditor:
old_field.column, old_field.column,
)) ))
for constraint_name in constraint_names: for constraint_name in constraint_names:
self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name)) self.execute(self._delete_constraint_sql(self.sql_delete_constraint, model, constraint_name))
# Have they renamed the column? # Have they renamed the column?
if old_field.column != new_field.column: if old_field.column != new_field.column:
self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
@@ -746,15 +745,16 @@ class BaseDatabaseSchemaEditor:
self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk")) self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk"))
# Does it have check constraints we need to add? # Does it have check constraints we need to add?
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
self.execute( constraint = self.sql_constraint % {
self.sql_create_check % {
"table": self.quote_name(model._meta.db_table),
"check": self.sql_check % {
'name': self.quote_name( 'name': self.quote_name(
self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'), self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'),
), ),
'check': new_db_params['check'], 'constraint': self.sql_check_constraint % new_db_params,
}, }
self.execute(
self.sql_create_constraint % {
'table': self.quote_name(model._meta.db_table),
'constraint': constraint,
} }
) )
# Drop the default if we need to # Drop the default if we need to
@@ -983,35 +983,57 @@ class BaseDatabaseSchemaEditor:
"type": new_type, "type": new_type,
} }
def _create_fk_sql(self, model, field, suffix): def _create_constraint_sql(self, table, name, constraint):
from_table = model._meta.db_table constraint = Statement(self.sql_constraint, name=name, constraint=constraint)
from_column = field.column return Statement(self.sql_create_constraint, table=table, constraint=constraint)
_, to_table = split_identifier(field.target_field.model._meta.db_table)
to_column = field.target_field.column
def _create_fk_sql(self, model, field, suffix):
def create_fk_name(*args, **kwargs): def create_fk_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs)) return self.quote_name(self._create_index_name(*args, **kwargs))
return Statement( table = Table(model._meta.db_table, self.quote_name)
self.sql_create_fk, name = ForeignKeyName(
table=Table(from_table, self.quote_name), model._meta.db_table,
name=ForeignKeyName(from_table, [from_column], to_table, [to_column], suffix, create_fk_name), [field.column],
column=Columns(from_table, [from_column], self.quote_name), split_identifier(field.target_field.model._meta.db_table)[1],
to_table=Table(field.target_field.model._meta.db_table, self.quote_name), [field.target_field.column],
to_column=Columns(field.target_field.model._meta.db_table, [to_column], self.quote_name), suffix,
deferrable=self.connection.ops.deferrable_sql(), create_fk_name,
) )
column = Columns(model._meta.db_table, [field.column], self.quote_name)
to_table = Table(field.target_field.model._meta.db_table, self.quote_name)
to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name)
deferrable = self.connection.ops.deferrable_sql()
constraint = Statement(
self.sql_foreign_key_constraint,
column=column,
to_table=to_table,
to_column=to_column,
deferrable=deferrable,
)
return self._create_constraint_sql(table, name, constraint)
def _create_unique_sql(self, model, columns): def _create_unique_sql(self, model, columns, name=None):
def create_unique_name(*args, **kwargs): def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs)) return self.quote_name(self._create_index_name(*args, **kwargs))
table = model._meta.db_table
table = Table(model._meta.db_table, self.quote_name)
if name is None:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else:
name = self.quote_name(name)
columns = Columns(table, columns, self.quote_name)
if self.sql_create_unique:
# Some databases use a different syntax for unique constraint
# creation.
return Statement( return Statement(
self.sql_create_unique, self.sql_create_unique,
table=Table(table, self.quote_name), table=table,
name=IndexName(table, columns, '_uniq', create_unique_name), name=name,
columns=Columns(table, columns, self.quote_name), columns=columns,
) )
constraint = Statement(self.sql_unique_constraint, columns=columns)
return self._create_constraint_sql(table, name, constraint)
def _delete_constraint_sql(self, template, model, name): def _delete_constraint_sql(self, template, model, name):
return template % { return template % {

View File

@@ -1,5 +1,7 @@
import re import re
import sqlparse
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo, BaseDatabaseIntrospection, FieldInfo, TableInfo,
) )
@@ -242,19 +244,37 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# table_name is a view. # table_name is a view.
pass pass
else: else:
fields_with_check_constraints = [ # Check constraint parsing is based of SQLite syntax diagram.
schema_row.strip().split(' ')[0][1:-1] # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
for schema_row in table_schema.split(',') def next_ttype(ttype):
if schema_row.find('CHECK') >= 0 for token in tokens:
] if token.ttype == ttype:
for field_name in fields_with_check_constraints: return token
# An arbitrary made up name.
constraints['__check__%s' % field_name] = { statement = sqlparse.parse(table_schema)[0]
'columns': [field_name], tokens = statement.flatten()
for token in tokens:
name = None
if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
# Table constraint
name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
name = name_token.value[1:-1]
token = next_ttype(sqlparse.tokens.Keyword)
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
# Column check constraint
if name is None:
column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
column = column_token.value[1:-1]
name = '__check__%s' % column
columns = [column]
else:
columns = []
constraints[name] = {
'check': True,
'columns': columns,
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'foreign_key': False, 'foreign_key': False,
'check': True,
'index': False, 'index': False,
} }
# Get the index info # Get the index info

View File

@@ -12,10 +12,10 @@ from django.db.utils import NotSupportedError
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s" sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)" sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s" sql_delete_unique = "DROP INDEX %(name)s"
sql_foreign_key_constraint = None
def __enter__(self): def __enter__(self):
# Some SQLite schema alterations need foreign key constraints to be # Some SQLite schema alterations need foreign key constraints to be

View File

@@ -10,16 +10,22 @@ class BaseConstraint:
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.') raise NotImplementedError('This method must be implemented by a subclass.')
def full_constraint_sql(self, model, schema_editor):
return schema_editor.sql_constraint % {
'name': schema_editor.quote_name(self.name),
'constraint': self.constraint_sql(model, schema_editor),
}
def create_sql(self, model, schema_editor): def create_sql(self, model, schema_editor):
sql = self.constraint_sql(model, schema_editor) sql = self.full_constraint_sql(model, schema_editor)
return schema_editor.sql_create_check % { return schema_editor.sql_create_constraint % {
'table': schema_editor.quote_name(model._meta.db_table), 'table': schema_editor.quote_name(model._meta.db_table),
'check': sql, 'constraint': sql,
} }
def remove_sql(self, model, schema_editor): def remove_sql(self, model, schema_editor):
quote_name = schema_editor.quote_name quote_name = schema_editor.quote_name
return schema_editor.sql_delete_check % { return schema_editor.sql_delete_constraint % {
'table': quote_name(model._meta.db_table), 'table': quote_name(model._meta.db_table),
'name': quote_name(self.name), 'name': quote_name(self.name),
} }
@@ -46,10 +52,7 @@ class CheckConstraint(BaseConstraint):
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection) sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params) params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % { return schema_editor.sql_check_constraint % {'check': sql % params}
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
def __repr__(self): def __repr__(self):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)

View File

@@ -293,6 +293,13 @@ Database backend API
* Third party database backends must implement support for partial indexes or * Third party database backends must implement support for partial indexes or
set ``DatabaseFeatures.supports_partial_indexes`` to ``False``. set ``DatabaseFeatures.supports_partial_indexes`` to ``False``.
* Several ``SchemaEditor`` attributes are changed:
* ``sql_create_check`` is replaced with ``sql_create_constraint``.
* ``sql_delete_check`` is replaced with ``sql_delete_constraint``.
* ``sql_create_fk`` is replaced with ``sql_foreign_key_constraint``,
``sql_constraint``, and ``sql_create_constraint``.
Admin actions are no longer collected from base ``ModelAdmin`` classes Admin actions are no longer collected from base ``ModelAdmin`` classes
---------------------------------------------------------------------- ----------------------------------------------------------------------

View File

@@ -1,10 +1,15 @@
from django.db import IntegrityError, models from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import Product from .models import Product
def get_constraints(table):
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
class BaseConstraintTests(SimpleTestCase): class BaseConstraintTests(SimpleTestCase):
def test_constraint_sql(self): def test_constraint_sql(self):
c = BaseConstraint('name') c = BaseConstraint('name')
@@ -37,3 +42,11 @@ class CheckConstraintTests(TestCase):
Product.objects.create(name='Valid', price=10, discounted_price=5) Product.objects.create(name='Valid', price=10, discounted_price=5)
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
Product.objects.create(name='Invalid', price=10, discounted_price=20) Product.objects.create(name='Invalid', price=10, discounted_price=20)
@skipUnlessDBFeature('supports_table_check_constraints')
def test_name(self):
constraints = get_constraints(Product._meta.db_table)
expected_name = 'price_gt_discounted_price'
if connection.features.uppercases_column_names:
expected_name = expected_name.upper()
self.assertIn(expected_name, constraints)

View File

@@ -2145,30 +2145,30 @@ class SchemaTests(TransactionTestCase):
self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table))
constraint_name = "CamelCaseUniqConstraint" constraint_name = "CamelCaseUniqConstraint"
editor.execute( editor.execute(editor._create_unique_sql(model, [field.column], constraint_name))
editor.sql_create_unique % {
"table": editor.quote_name(table),
"name": editor.quote_name(constraint_name),
"columns": editor.quote_name(field.column),
}
)
if connection.features.uppercases_column_names: if connection.features.uppercases_column_names:
constraint_name = constraint_name.upper() constraint_name = constraint_name.upper()
self.assertIn(constraint_name, self.get_constraints(model._meta.db_table)) self.assertIn(constraint_name, self.get_constraints(model._meta.db_table))
editor.alter_field(model, get_field(unique=True), field, strict=True) editor.alter_field(model, get_field(unique=True), field, strict=True)
self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(model._meta.db_table))
if editor.sql_create_fk: if editor.sql_foreign_key_constraint:
constraint_name = "CamelCaseFKConstraint" constraint_name = "CamelCaseFKConstraint"
editor.execute( fk_sql = editor.sql_foreign_key_constraint % {
editor.sql_create_fk % {
"table": editor.quote_name(table),
"name": editor.quote_name(constraint_name),
"column": editor.quote_name(column), "column": editor.quote_name(column),
"to_table": editor.quote_name(table), "to_table": editor.quote_name(table),
"to_column": editor.quote_name(model._meta.auto_field.column), "to_column": editor.quote_name(model._meta.auto_field.column),
"deferrable": connection.ops.deferrable_sql(), "deferrable": connection.ops.deferrable_sql(),
} }
constraint_sql = editor.sql_constraint % {
"name": editor.quote_name(constraint_name),
"constraint": fk_sql,
}
editor.execute(
editor.sql_create_constraint % {
"table": editor.quote_name(table),
"constraint": constraint_sql,
}
) )
if connection.features.uppercases_column_names: if connection.features.uppercases_column_names:
constraint_name = constraint_name.upper() constraint_name = constraint_name.upper()