diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 34f354341e..952b1cb20a 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -1,3 +1,7 @@ +# This is necessary in Python 2.5 to enable the with statement, in 2.6 +# and up it is no longer necessary. +from __future__ import with_statement + import sys import os import gzip @@ -166,12 +170,20 @@ class Command(BaseCommand): (format, fixture_name, humanize(fixture_dir))) try: objects = serializers.deserialize(format, fixture, using=using) - for obj in objects: - objects_in_fixture += 1 - if router.allow_syncdb(using, obj.object.__class__): - loaded_objects_in_fixture += 1 - models.add(obj.object.__class__) - obj.save(using=using) + + with connection.constraint_checks_disabled(): + for obj in objects: + objects_in_fixture += 1 + if router.allow_syncdb(using, obj.object.__class__): + loaded_objects_in_fixture += 1 + models.add(obj.object.__class__) + obj.save(using=using) + + # Since we disabled constraint checks, we must manually check for + # any invalid keys that might have been added + table_names = [model._meta.db_table for model in models] + connection.check_constraints(table_names=table_names) + loaded_object_count += loaded_objects_in_fixture fixture_object_count += objects_in_fixture label_found = True diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 1c3bc7e919..23ddedb4c6 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -3,6 +3,7 @@ try: except ImportError: import dummy_thread as thread from threading import local +from contextlib import contextmanager from django.conf import settings from django.db import DEFAULT_DB_ALIAS @@ -238,6 +239,35 @@ class BaseDatabaseWrapper(local): if self.savepoint_state: self._savepoint_commit(sid) + @contextmanager + def constraint_checks_disabled(self): + disabled = self.disable_constraint_checking() + try: + yield + finally: + if disabled: + self.enable_constraint_checking() + + def disable_constraint_checking(self): + """ + Backends can implement as needed to temporarily disable foreign key constraint + checking. + """ + pass + + def enable_constraint_checking(self): + """ + Backends can implement as needed to re-enable foreign key constraint checking. + """ + pass + + def check_constraints(self, table_names=None): + """ + Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS + ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered. + """ + pass + def close(self): if self.connection is not None: self.connection.close() @@ -869,6 +899,19 @@ class BaseDatabaseIntrospection(object): return sequence_list + def get_key_columns(self, cursor, table_name): + """ + Backends can override this to return a list of (column_name, referenced_table_name, + referenced_column_name) for all key columns in given table. + """ + raise NotImplementedError + + def get_primary_key_column(self, cursor, table_name): + """ + Backends can override this to return the column name of the primary key for the given table. + """ + raise NotImplementedError + class BaseDatabaseClient(object): """ This class encapsulates all backend-specific methods for opening a diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index 7de48c8b00..746f26bacc 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -34,6 +34,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): get_table_description = complain get_relations = complain get_indexes = complain + get_key_columns = complain class DatabaseWrapper(BaseDatabaseWrapper): operators = {} diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 6d02aa771c..f4523e4e82 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -349,3 +349,52 @@ class DatabaseWrapper(BaseDatabaseWrapper): raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info()) self.server_version = tuple([int(x) for x in m.groups()]) return self.server_version + + def disable_constraint_checking(self): + """ + Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True, + to indicate constraint checks need to be re-enabled. + """ + self.cursor().execute('SET foreign_key_checks=0') + return True + + def enable_constraint_checking(self): + """ + Re-enable foreign key checks after they have been disabled. + """ + self.cursor().execute('SET foreign_key_checks=1') + + def check_constraints(self, table_names=None): + """ + Checks each table name in table-names for rows with invalid foreign key references. This method is + intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to + determine if rows with invalid references were entered while constraint checks were off. + + Raises an IntegrityError on the first invalid foreign key reference encountered (if any) and provides + detailed information about the invalid reference in the error message. + + Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS + ALL IMMEDIATE") + """ + cursor = self.cursor() + if table_names is None: + table_names = self.introspection.get_table_list(cursor) + for table_name in table_names: + primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + if not primary_key_column_name: + continue + key_columns = self.introspection.get_key_columns(cursor, table_name) + for column_name, referenced_table_name, referenced_column_name in key_columns: + cursor.execute(""" + SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING + LEFT JOIN `%s` as REFERRED + ON (REFERRING.`%s` = REFERRED.`%s`) + WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL""" + % (primary_key_column_name, column_name, table_name, referenced_table_name, + column_name, referenced_column_name, column_name, referenced_column_name)) + for bad_row in cursor.fetchall(): + raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid " + "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s." + % (table_name, bad_row[0], + table_name, column_name, bad_row[1], + referenced_table_name, referenced_column_name)) diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 9e1518b06e..ab4eebea90 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -51,10 +51,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): representing all relationships to the given table. Indexes are 0-based. """ my_field_dict = self._name_to_index(cursor, table_name) - constraints = [] + constraints = self.get_key_columns(cursor, table_name) relations = {} + for my_fieldname, other_table, other_field in constraints: + other_field_index = self._name_to_index(cursor, other_table)[other_field] + my_field_index = my_field_dict[my_fieldname] + relations[my_field_index] = (other_field_index, other_table) + return relations + + def get_key_columns(self, cursor, table_name): + """ + Returns a list of (column_name, referenced_table_name, referenced_column_name) for all + key columns in given table. + """ + key_columns = [] try: - # This should work for MySQL 5.0. cursor.execute(""" SELECT column_name, referenced_table_name, referenced_column_name FROM information_schema.key_column_usage @@ -62,7 +73,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): AND table_schema = DATABASE() AND referenced_table_name IS NOT NULL AND referenced_column_name IS NOT NULL""", [table_name]) - constraints.extend(cursor.fetchall()) + key_columns.extend(cursor.fetchall()) except (ProgrammingError, OperationalError): # Fall back to "SHOW CREATE TABLE", for previous MySQL versions. # Go through all constraints and save the equal matches. @@ -74,14 +85,17 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): if match == None: break pos = match.end() - constraints.append(match.groups()) + key_columns.append(match.groups()) + return key_columns - for my_fieldname, other_table, other_field in constraints: - other_field_index = self._name_to_index(cursor, other_table)[other_field] - my_field_index = my_field_dict[my_fieldname] - relations[my_field_index] = (other_field_index, other_table) - - return relations + def get_primary_key_column(self, cursor, table_name): + """ + Returns the name of the primary key column for the given table + """ + for column in self.get_indexes(cursor, table_name).iteritems(): + if column[1]['primary_key']: + return column[0] + return None def get_indexes(self, cursor, table_name): """ diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 930b1bb0f1..3cadb6617d 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -428,6 +428,14 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation(self) + def check_constraints(self, table_names=None): + """ + To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they + are returned to deferred. + """ + self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') + self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + def _valid_connection(self): return self.connection is not None diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 6ed59a66f4..37aa072f6b 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -106,6 +106,14 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.validation = BaseDatabaseValidation(self) self._pg_version = None + def check_constraints(self, table_names=None): + """ + To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they + are returned to deferred. + """ + self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') + self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + def _get_pg_version(self): if self._pg_version is None: self._pg_version = get_version(self.connection) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 79c5eded34..f922d38638 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -206,6 +206,40 @@ class DatabaseWrapper(BaseDatabaseWrapper): connection_created.send(sender=self.__class__, connection=self) return self.connection.cursor(factory=SQLiteCursorWrapper) + def check_constraints(self, table_names=None): + """ + Checks each table name in table-names for rows with invalid foreign key references. This method is + intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to + determine if rows with invalid references were entered while constraint checks were off. + + Raises an IntegrityError on the first invalid foreign key reference encountered (if any) and provides + detailed information about the invalid reference in the error message. + + Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS + ALL IMMEDIATE") + """ + cursor = self.cursor() + if table_names is None: + table_names = self.introspection.get_table_list(cursor) + for table_name in table_names: + primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + if not primary_key_column_name: + continue + key_columns = self.introspection.get_key_columns(cursor, table_name) + for column_name, referenced_table_name, referenced_column_name in key_columns: + cursor.execute(""" + SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING + LEFT JOIN `%s` as REFERRED + ON (REFERRING.`%s` = REFERRED.`%s`) + WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL""" + % (primary_key_column_name, column_name, table_name, referenced_table_name, + column_name, referenced_column_name, column_name, referenced_column_name)) + for bad_row in cursor.fetchall(): + raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid " + "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s." + % (table_name, bad_row[0], table_name, column_name, bad_row[1], + referenced_table_name, referenced_column_name)) + def close(self): # If database is in memory, closing the connection destroys the # database. To prevent accidental data loss, ignore close requests on diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 5ee7b64bcd..9652a4da6a 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -103,6 +103,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return relations + def get_key_columns(self, cursor, table_name): + """ + Returns a list of (column_name, referenced_table_name, referenced_column_name) for all + key columns in given table. + """ + key_columns = [] + + # Schema for this table + cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"]) + results = cursor.fetchone()[0].strip() + results = results[results.index('(')+1:results.rindex(')')] + + # Walk through and look for references to other tables. SQLite doesn't + # really have enforced references, but since it echoes out the SQL used + # to create the table we can look for REFERENCES statements used there. + for field_index, field_desc in enumerate(results.split(',')): + field_desc = field_desc.strip() + if field_desc.startswith("UNIQUE"): + continue + + m = re.search('"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I) + if not m: + continue + + # This will append (column_name, referenced_table_name, referenced_column_name) to key_columns + key_columns.append(tuple([s.strip('"') for s in m.groups()])) + + return key_columns + def get_indexes(self, cursor, table_name): """ Returns a dictionary of fieldname -> infodict for the given table, @@ -128,6 +157,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): indexes[name]['unique'] = True return indexes + def get_primary_key_column(self, cursor, table_name): + """ + Get the column name of the primary key for the given table. + """ + # Don't use PRAGMA because that causes issues with some transactions + cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"]) + results = cursor.fetchone()[0].strip() + results = results[results.index('(')+1:results.rindex(')')] + for field_desc in results.split(','): + field_desc = field_desc.strip() + m = re.search('"(.*)".*PRIMARY KEY$', field_desc) + if m: + return m.groups()[0] + return None + def _table_info(self, cursor, name): cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name)) # cid, name, type, notnull, dflt_value, pk diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt index 2f55b9c8c6..5a2042a02a 100644 --- a/docs/ref/databases.txt +++ b/docs/ref/databases.txt @@ -142,6 +142,18 @@ currently the only engine that supports full-text indexing and searching. The InnoDB_ engine is fully transactional and supports foreign key references and is probably the best choice at this point in time. +.. versionchanged:: 1.4 + +In previous versions of Django, fixtures with forward references (i.e. +relations to rows that have not yet been inserted into the database) would fail +to load when using the InnoDB storage engine. This was due to the fact that InnoDB +deviates from the SQL standard by checking foreign key constraints immediately +instead of deferring the check until the transaction is committed. This +problem has been resolved in Django 1.4. Fixture data is now loaded with foreign key +checks turned off; foreign key checks are then re-enabled when the data has +finished loading, at which point the entire table is checked for invalid foreign +key references and an `IntegrityError` is raised if any are found. + .. _storage engines: http://dev.mysql.com/doc/refman/5.5/en/storage-engines.html .. _MyISAM: http://dev.mysql.com/doc/refman/5.5/en/myisam-storage-engine.html .. _InnoDB: http://dev.mysql.com/doc/refman/5.5/en/innodb.html diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt index 2723b42195..ec4eae09af 100644 --- a/docs/releases/1.4.txt +++ b/docs/releases/1.4.txt @@ -235,6 +235,9 @@ Django 1.4 also includes several smaller improvements worth noting: to delete all files at the destination before copying or linking the static files. +* It is now possible to load fixtures containing forward references when using + MySQL with the InnoDB database engine. + .. _backwards-incompatible-changes-1.4: Backwards incompatible changes in 1.4 diff --git a/tests/modeltests/serializers/tests.py b/tests/modeltests/serializers/tests.py index 4a7e0a2086..def0254a9f 100644 --- a/tests/modeltests/serializers/tests.py +++ b/tests/modeltests/serializers/tests.py @@ -1,3 +1,7 @@ +# This is necessary in Python 2.5 to enable the with statement, in 2.6 +# and up it is no longer necessary. +from __future__ import with_statement + # -*- coding: utf-8 -*- from datetime import datetime from StringIO import StringIO @@ -5,7 +9,7 @@ from xml.dom import minidom from django.conf import settings from django.core import serializers -from django.db import transaction +from django.db import transaction, connection from django.test import TestCase, TransactionTestCase, Approximate from django.utils import simplejson, unittest @@ -252,8 +256,9 @@ class SerializersTransactionTestBase(object): transaction.enter_transaction_management() transaction.managed(True) objs = serializers.deserialize(self.serializer_name, self.fwd_ref_str) - for obj in objs: - obj.save() + with connection.constraint_checks_disabled(): + for obj in objs: + obj.save() transaction.commit() transaction.leave_transaction_management() diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py index 29db6a705a..27d3dfdddc 100644 --- a/tests/regressiontests/backends/tests.py +++ b/tests/regressiontests/backends/tests.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- # Unit and doctests for specific database backends. +from __future__ import with_statement import datetime from django.conf import settings from django.core.management.color import no_style -from django.db import backend, connection, connections, DEFAULT_DB_ALIAS, IntegrityError +from django.db import backend, connection, connections, DEFAULT_DB_ALIAS, IntegrityError, transaction from django.db.backends.signals import connection_created from django.db.backends.postgresql_psycopg2 import version as pg_version from django.test import TestCase, skipUnlessDBFeature, TransactionTestCase @@ -328,7 +329,8 @@ class FkConstraintsTests(TransactionTestCase): try: a.save() except IntegrityError: - pass + return + self.skipTest("This backend does not support integrity checks.") def test_integrity_checks_on_update(self): """ @@ -343,4 +345,60 @@ class FkConstraintsTests(TransactionTestCase): try: a.save() except IntegrityError: - pass + return + self.skipTest("This backend does not support integrity checks.") + + def test_disable_constraint_checks_manually(self): + """ + When constraint checks are disabled, should be able to write bad data without IntegrityErrors. + """ + with transaction.commit_manually(): + # Create an Article. + models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r) + # Retrive it from the DB + a = models.Article.objects.get(headline="Test article") + a.reporter_id = 30 + try: + connection.disable_constraint_checking() + a.save() + connection.enable_constraint_checking() + except IntegrityError: + self.fail("IntegrityError should not have occurred.") + finally: + transaction.rollback() + + def test_disable_constraint_checks_context_manager(self): + """ + When constraint checks are disabled (using context manager), should be able to write bad data without IntegrityErrors. + """ + with transaction.commit_manually(): + # Create an Article. + models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r) + # Retrive it from the DB + a = models.Article.objects.get(headline="Test article") + a.reporter_id = 30 + try: + with connection.constraint_checks_disabled(): + a.save() + except IntegrityError: + self.fail("IntegrityError should not have occurred.") + finally: + transaction.rollback() + + def test_check_constraints(self): + """ + Constraint checks should raise an IntegrityError when bad data is in the DB. + """ + with transaction.commit_manually(): + # Create an Article. + models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r) + # Retrive it from the DB + a = models.Article.objects.get(headline="Test article") + a.reporter_id = 30 + try: + with connection.constraint_checks_disabled(): + a.save() + with self.assertRaises(IntegrityError): + connection.check_constraints() + finally: + transaction.rollback() diff --git a/tests/regressiontests/fixtures_regress/tests.py b/tests/regressiontests/fixtures_regress/tests.py index a565ec96e0..b67155d426 100644 --- a/tests/regressiontests/fixtures_regress/tests.py +++ b/tests/regressiontests/fixtures_regress/tests.py @@ -362,6 +362,35 @@ class TestFixtures(TestCase): % widget.pk ) + def test_loaddata_works_when_fixture_has_forward_refs(self): + """ + Regression for #3615 - Forward references cause fixtures not to load in MySQL (InnoDB) + """ + management.call_command( + 'loaddata', + 'forward_ref.json', + verbosity=0, + commit=False + ) + self.assertEqual(Book.objects.all()[0].id, 1) + self.assertEqual(Person.objects.all()[0].id, 4) + + def test_loaddata_raises_error_when_fixture_has_invalid_foreign_key(self): + """ + Regression for #3615 - Ensure data with nonexistent child key references raises error + """ + stderr = StringIO() + management.call_command( + 'loaddata', + 'forward_ref_bad_data.json', + verbosity=0, + commit=False, + stderr=stderr, + ) + self.assertTrue( + stderr.getvalue().startswith('Problem installing fixture') + ) + class NaturalKeyFixtureTests(TestCase): def assertRaisesMessage(self, exc, msg, func, *args, **kwargs): diff --git a/tests/regressiontests/introspection/tests.py b/tests/regressiontests/introspection/tests.py index 4f5fb09193..fdf30126cd 100644 --- a/tests/regressiontests/introspection/tests.py +++ b/tests/regressiontests/introspection/tests.py @@ -95,6 +95,16 @@ class IntrospectionTests(TestCase): # That's {field_index: (field_index_other_table, other_table)} self.assertEqual(relations, {3: (0, Reporter._meta.db_table)}) + def test_get_key_columns(self): + cursor = connection.cursor() + key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table) + self.assertEqual(key_columns, [(u'reporter_id', Reporter._meta.db_table, u'id')]) + + def test_get_primary_key_column(self): + cursor = connection.cursor() + primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table) + self.assertEqual(primary_key_column, u'id') + def test_get_indexes(self): cursor = connection.cursor() indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) diff --git a/tests/regressiontests/serializers_regress/tests.py b/tests/regressiontests/serializers_regress/tests.py index cd2ce3cc9a..bb6f598719 100644 --- a/tests/regressiontests/serializers_regress/tests.py +++ b/tests/regressiontests/serializers_regress/tests.py @@ -6,6 +6,8 @@ test case that is capable of testing the capabilities of the serializers. This includes all valid data values, plus forward, backwards and self references. """ +# This is necessary in Python 2.5 to enable the with statement, in 2.6 +# and up it is no longer necessary. from __future__ import with_statement import datetime @@ -382,7 +384,8 @@ def serializerTest(format, self): objects = [] instance_count = {} for (func, pk, klass, datum) in test_data: - objects.extend(func[0](pk, klass, datum)) + with connection.constraint_checks_disabled(): + objects.extend(func[0](pk, klass, datum)) # Get a count of the number of objects created for each class for klass in instance_count: