mirror of
https://github.com/django/django.git
synced 2025-10-30 17:16:10 +00:00
Fix migration planner to fully understand squashed migrations. And test.
This commit is contained in:
@@ -11,7 +11,6 @@ class MigrationExecutor(object):
|
||||
def __init__(self, connection, progress_callback=None):
|
||||
self.connection = connection
|
||||
self.loader = MigrationLoader(self.connection)
|
||||
self.loader.load_disk()
|
||||
self.recorder = MigrationRecorder(self.connection)
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
@@ -20,7 +19,7 @@ class MigrationExecutor(object):
|
||||
Given a set of targets, returns a list of (Migration instance, backwards?).
|
||||
"""
|
||||
plan = []
|
||||
applied = self.recorder.applied_migrations()
|
||||
applied = set(self.loader.applied_migrations)
|
||||
for target in targets:
|
||||
# If the target is (appname, None), that means unmigrate everything
|
||||
if target[1] is None:
|
||||
@@ -87,7 +86,13 @@ class MigrationExecutor(object):
|
||||
with self.connection.schema_editor() as schema_editor:
|
||||
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||
migration.apply(project_state, schema_editor)
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_applied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
# Report prgress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_success", migration)
|
||||
|
||||
@@ -101,6 +106,12 @@ class MigrationExecutor(object):
|
||||
with self.connection.schema_editor() as schema_editor:
|
||||
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||
migration.unapply(project_state, schema_editor)
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_unapplied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_success", migration)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from django.utils.functional import cached_property
|
||||
from django.db.models.loading import cache
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
from django.db.migrations.graph import MigrationGraph
|
||||
from django.utils import six
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
@@ -32,10 +33,12 @@ class MigrationLoader(object):
|
||||
in memory.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection, load=True):
|
||||
self.connection = connection
|
||||
self.disk_migrations = None
|
||||
self.applied_migrations = None
|
||||
if load:
|
||||
self.build_graph()
|
||||
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label):
|
||||
@@ -55,6 +58,7 @@ class MigrationLoader(object):
|
||||
# Get the migrations module directory
|
||||
app_label = app.__name__.split(".")[-2]
|
||||
module_name = self.migrations_module(app_label)
|
||||
was_loaded = module_name in sys.modules
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ImportError as e:
|
||||
@@ -71,6 +75,9 @@ class MigrationLoader(object):
|
||||
# Module is not a package (e.g. migrations.py).
|
||||
if not hasattr(module, '__path__'):
|
||||
continue
|
||||
# Force a reload if it's already loaded (tests need this)
|
||||
if was_loaded:
|
||||
six.moves.reload_module(module)
|
||||
self.migrated_apps.add(app_label)
|
||||
directory = os.path.dirname(module.__file__)
|
||||
# Scan for .py[c|o] files
|
||||
@@ -107,9 +114,6 @@ class MigrationLoader(object):
|
||||
|
||||
def get_migration_by_prefix(self, app_label, name_prefix):
|
||||
"Returns the migration(s) which match the given app label and name _prefix_"
|
||||
# Make sure we have the disk data
|
||||
if self.disk_migrations is None:
|
||||
self.load_disk()
|
||||
# Do the search
|
||||
results = []
|
||||
for l, n in self.disk_migrations:
|
||||
@@ -122,18 +126,17 @@ class MigrationLoader(object):
|
||||
else:
|
||||
return self.disk_migrations[results[0]]
|
||||
|
||||
@cached_property
|
||||
def graph(self):
|
||||
def build_graph(self):
|
||||
"""
|
||||
Builds a migration dependency graph using both the disk and database.
|
||||
You'll need to rebuild the graph if you apply migrations. This isn't
|
||||
usually a problem as generally migration stuff runs in a one-shot process.
|
||||
"""
|
||||
# Make sure we have the disk data
|
||||
if self.disk_migrations is None:
|
||||
self.load_disk()
|
||||
# And the database data
|
||||
if self.applied_migrations is None:
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# Load disk data
|
||||
self.load_disk()
|
||||
# Load database data
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# Do a first pass to separate out replacing and non-replacing migrations
|
||||
normal = {}
|
||||
replacing = {}
|
||||
@@ -152,12 +155,12 @@ class MigrationLoader(object):
|
||||
# Carry out replacements if we can - that is, if all replaced migrations
|
||||
# are either unapplied or missing.
|
||||
for key, migration in replacing.items():
|
||||
# Do the check
|
||||
can_replace = True
|
||||
for target in migration.replaces:
|
||||
if target in self.applied_migrations:
|
||||
can_replace = False
|
||||
break
|
||||
# Ensure this replacement migration is not in applied_migrations
|
||||
self.applied_migrations.discard(key)
|
||||
# Do the check. We can replace if all our replace targets are
|
||||
# applied, or if all of them are unapplied.
|
||||
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
|
||||
can_replace = all(applied_statuses) or (not any(applied_statuses))
|
||||
if not can_replace:
|
||||
continue
|
||||
# Alright, time to replace. Step through the replaced migrations
|
||||
@@ -171,14 +174,16 @@ class MigrationLoader(object):
|
||||
normal[child_key].dependencies.remove(replaced)
|
||||
normal[child_key].dependencies.append(key)
|
||||
normal[key] = migration
|
||||
# Mark the replacement as applied if all its replaced ones are
|
||||
if all(applied_statuses):
|
||||
self.applied_migrations.add(key)
|
||||
# Finally, make a graph and load everything into it
|
||||
graph = MigrationGraph()
|
||||
self.graph = MigrationGraph()
|
||||
for key, migration in normal.items():
|
||||
graph.add_node(key, migration)
|
||||
self.graph.add_node(key, migration)
|
||||
for key, migration in normal.items():
|
||||
for parent in migration.dependencies:
|
||||
graph.add_dependency(key, parent)
|
||||
return graph
|
||||
self.graph.add_dependency(key, parent)
|
||||
|
||||
|
||||
class BadMigrationError(Exception):
|
||||
|
||||
@@ -39,6 +39,11 @@ class Migration(object):
|
||||
def __init__(self, name, app_label):
|
||||
self.name = name
|
||||
self.app_label = app_label
|
||||
# Copy dependencies & other attrs as we might mutate them at runtime
|
||||
self.operations = list(self.__class__.operations)
|
||||
self.dependencies = list(self.__class__.dependencies)
|
||||
self.run_before = list(self.__class__.run_before)
|
||||
self.replaces = list(self.__class__.replaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Migration):
|
||||
|
||||
Reference in New Issue
Block a user