From e6f7f4533c183800c2a9ac526d8ee8887e96ac5d Mon Sep 17 00:00:00 2001
From: Andrew Godwin <andrew@aeracode.org>
Date: Thu, 30 May 2013 18:08:58 +0100
Subject: [PATCH] Add an Executor for end-to-end running

---
 django/db/migrations/executor.py           | 68 ++++++++++++++++++++++
 django/db/migrations/migration.py          | 48 +++++++++++++++
 django/db/migrations/operations/fields.py  | 16 ++---
 tests/migrations/migrations/0002_second.py |  2 +-
 tests/migrations/test_executor.py          | 35 +++++++++++
 tests/migrations/test_loader.py            |  2 +-
 tests/migrations/test_operations.py        | 28 ++++++++-
 7 files changed, 188 insertions(+), 11 deletions(-)
 create mode 100644 django/db/migrations/executor.py
 create mode 100644 tests/migrations/test_executor.py

diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py
new file mode 100644
index 0000000000..e9e98d41fd
--- /dev/null
+++ b/django/db/migrations/executor.py
@@ -0,0 +1,68 @@
+from .loader import MigrationLoader
+from .recorder import MigrationRecorder
+
+
+class MigrationExecutor(object):
+    """
+    End-to-end migration execution - loads migrations, and runs them
+    up or down to a specified set of targets.
+    """
+
+    def __init__(self, connection):
+        self.connection = connection
+        self.loader = MigrationLoader(self.connection)
+        self.recorder = MigrationRecorder(self.connection)
+
+    def migration_plan(self, targets):
+        """
+        Given a set of targets, returns a list of (Migration instance, backwards?).
+        """
+        plan = []
+        applied = self.recorder.applied_migrations()
+        for target in targets:
+            # If the migration is already applied, do backwards mode,
+            # otherwise do forwards mode.
+            if target in applied:
+                for migration in self.loader.graph.backwards_plan(target)[:-1]:
+                    if migration in applied:
+                        plan.append((self.loader.graph.nodes[migration], True))
+                        applied.remove(migration)
+            else:
+                for migration in self.loader.graph.forwards_plan(target):
+                    if migration not in applied:
+                        plan.append((self.loader.graph.nodes[migration], False))
+                        applied.add(migration)
+        return plan
+
+    def migrate(self, targets):
+        """
+        Migrates the database up to the given targets.
+        """
+        plan = self.migration_plan(targets)
+        for migration, backwards in plan:
+            if not backwards:
+                self.apply_migration(migration)
+            else:
+                self.unapply_migration(migration)
+
+    def apply_migration(self, migration):
+        """
+        Runs a migration forwards.
+        """
+        print "Applying %s" % migration
+        with self.connection.schema_editor() as schema_editor:
+            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
+            migration.apply(project_state, schema_editor)
+        self.recorder.record_applied(migration.app_label, migration.name)
+        print "Finished %s" % migration
+
+    def unapply_migration(self, migration):
+        """
+        Runs a migration backwards.
+        """
+        print "Unapplying %s" % migration
+        with self.connection.schema_editor() as schema_editor:
+            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
+            migration.unapply(project_state, schema_editor)
+        self.recorder.record_unapplied(migration.app_label, migration.name)
+        print "Finished %s" % migration
diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py
index a8b744a9b4..672e7440ad 100644
--- a/django/db/migrations/migration.py
+++ b/django/db/migrations/migration.py
@@ -36,6 +36,17 @@ class Migration(object):
         self.name = name
         self.app_label = app_label
 
+    def __eq__(self, other):
+        if not isinstance(other, Migration):
+            return False
+        return (self.name == other.name) and (self.app_label == other.app_label)
+
+    def __ne__(self, other):
+        return not (self == other)
+
+    def __repr__(self):
+        return "<Migration %s.%s>" % (self.app_label, self.name)
+
     def mutate_state(self, project_state):
         """
         Takes a ProjectState and returns a new one with the migration's
@@ -45,3 +56,40 @@ class Migration(object):
         for operation in self.operations:
             operation.state_forwards(self.app_label, new_state)
         return new_state
+
+    def apply(self, project_state, schema_editor):
+        """
+        Takes a project_state representing all migrations prior to this one
+        and a schema_editor for a live database and applies the migration
+        in a forwards order.
+
+        Returns the resulting project state for efficient re-use by following
+        Migrations.
+        """
+        for operation in self.operations:
+            # Get the state after the operation has run
+            new_state = project_state.clone()
+            operation.state_forwards(self.app_label, new_state)
+            # Run the operation
+            operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
+            # Switch states
+            project_state = new_state
+        return project_state
+
+    def unapply(self, project_state, schema_editor):
+        """
+        Takes a project_state representing all migrations prior to this one
+        and a schema_editor for a live database and applies the migration
+        in a reverse order.
+        """
+        # We need to pre-calculate the stack of project states
+        to_run = []
+        for operation in self.operations:
+            new_state = project_state.clone()
+            operation.state_forwards(self.app_label, new_state)
+            to_run.append((operation, project_state, new_state))
+            project_state = new_state
+        # Now run them in reverse
+        to_run.reverse()
+        for operation, to_state, from_state in to_run:
+            operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py
index 2ecf77f7ef..efb12b22c3 100644
--- a/django/db/migrations/operations/fields.py
+++ b/django/db/migrations/operations/fields.py
@@ -16,13 +16,13 @@ class AddField(Operation):
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
         app_cache = to_state.render()
-        model = app_cache.get_model(app_label, self.name)
-        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
+        model = app_cache.get_model(app_label, self.model_name)
+        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         app_cache = from_state.render()
-        model = app_cache.get_model(app_label, self.name)
-        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
+        model = app_cache.get_model(app_label, self.model_name)
+        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
 
 
 class RemoveField(Operation):
@@ -43,10 +43,10 @@ class RemoveField(Operation):
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
         app_cache = from_state.render()
-        model = app_cache.get_model(app_label, self.name)
-        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
+        model = app_cache.get_model(app_label, self.model_name)
+        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         app_cache = to_state.render()
-        model = app_cache.get_model(app_label, self.name)
-        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
+        model = app_cache.get_model(app_label, self.model_name)
+        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
diff --git a/tests/migrations/migrations/0002_second.py b/tests/migrations/migrations/0002_second.py
index fbaef11f71..ace9a83347 100644
--- a/tests/migrations/migrations/0002_second.py
+++ b/tests/migrations/migrations/0002_second.py
@@ -11,7 +11,7 @@ class Migration(migrations.Migration):
 
         migrations.RemoveField("Author", "silly_field"),
 
-        migrations.AddField("Author", "important", models.BooleanField()),
+        migrations.AddField("Author", "rating", models.IntegerField(default=0)),
 
         migrations.CreateModel(
             "Book",
diff --git a/tests/migrations/test_executor.py b/tests/migrations/test_executor.py
new file mode 100644
index 0000000000..629c47de56
--- /dev/null
+++ b/tests/migrations/test_executor.py
@@ -0,0 +1,35 @@
+from django.test import TransactionTestCase
+from django.db import connection
+from django.db.migrations.executor import MigrationExecutor
+
+
+class ExecutorTests(TransactionTestCase):
+    """
+    Tests the migration executor (full end-to-end running).
+
+    Bear in mind that if these are failing you should fix the other
+    test failures first, as they may be propagating into here.
+    """
+
+    def test_run(self):
+        """
+        Tests running a simple set of migrations.
+        """
+        executor = MigrationExecutor(connection)
+        # Let's look at the plan first and make sure it's up to scratch
+        plan = executor.migration_plan([("migrations", "0002_second")])
+        self.assertEqual(
+            plan,
+            [
+                (executor.loader.graph.nodes["migrations", "0001_initial"], False),
+                (executor.loader.graph.nodes["migrations", "0002_second"], False),
+            ],
+        )
+        # Were the tables there before?
+        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
+        self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
+        # Alright, let's try running it
+        executor.migrate([("migrations", "0002_second")])
+        # Are the tables there now?
+        self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
+        self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
diff --git a/tests/migrations/test_loader.py b/tests/migrations/test_loader.py
index badace57cc..9318f77004 100644
--- a/tests/migrations/test_loader.py
+++ b/tests/migrations/test_loader.py
@@ -54,7 +54,7 @@ class LoaderTests(TransactionTestCase):
         author_state = project_state.models["migrations", "author"]
         self.assertEqual(
             [x for x, y in author_state.fields],
-            ["id", "name", "slug", "age", "important"]
+            ["id", "name", "slug", "age", "rating"]
         )
 
         book_state = project_state.models["migrations", "book"]
diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py
index ea6dea0302..bf8549e092 100644
--- a/tests/migrations/test_operations.py
+++ b/tests/migrations/test_operations.py
@@ -1,6 +1,6 @@
 from django.test import TransactionTestCase
 from django.db import connection, models, migrations
-from django.db.migrations.state import ProjectState, ModelState
+from django.db.migrations.state import ProjectState
 
 
 class OperationTests(TransactionTestCase):
@@ -16,6 +16,12 @@ class OperationTests(TransactionTestCase):
     def assertTableNotExists(self, table):
         self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
 
+    def assertColumnExists(self, table, column):
+        self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+
+    def assertColumnNotExists(self, table, column):
+        self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+
     def set_up_test_model(self, app_label):
         """
         Creates a test model state and database table.
@@ -82,3 +88,23 @@ class OperationTests(TransactionTestCase):
         with connection.schema_editor() as editor:
             operation.database_backwards("test_dlmo", editor, new_state, project_state)
         self.assertTableExists("test_dlmo_pony")
+
+    def test_add_field(self):
+        """
+        Tests the AddField operation.
+        """
+        project_state = self.set_up_test_model("test_adfl")
+        # Test the state alteration
+        operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
+        new_state = project_state.clone()
+        operation.state_forwards("test_adfl", new_state)
+        self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3)
+        # Test the database alteration
+        self.assertColumnNotExists("test_adfl_pony", "height")
+        with connection.schema_editor() as editor:
+            operation.database_forwards("test_adfl", editor, project_state, new_state)
+        self.assertColumnExists("test_adfl_pony", "height")
+        # And test reversal
+        with connection.schema_editor() as editor:
+            operation.database_backwards("test_adfl", editor, new_state, project_state)
+        self.assertColumnNotExists("test_adfl_pony", "height")