1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Change FKs when what they point to changes

This commit is contained in:
Andrew Godwin
2013-12-11 13:16:29 +00:00
parent f3582a0594
commit 248fdb1110
3 changed files with 107 additions and 78 deletions

View File

@@ -615,6 +615,11 @@ class BaseDatabaseSchemaEditor(object):
"extra": "", "extra": "",
} }
) )
# Type alteration on primary key? Then we need to alter the column
# referring to us.
rels_to_update = []
if old_field.primary_key and new_field.primary_key and old_type != new_type:
rels_to_update.extend(model._meta.get_all_related_objects())
# Changed to become primary key? # Changed to become primary key?
# Note that we don't detect unsetting of a PK, as we assume another field # Note that we don't detect unsetting of a PK, as we assume another field
# will always come along and replace it. # will always come along and replace it.
@@ -641,6 +646,21 @@ class BaseDatabaseSchemaEditor(object):
"columns": self.quote_name(new_field.column), "columns": self.quote_name(new_field.column),
} }
) )
# Update all referencing columns
rels_to_update.extend(model._meta.get_all_related_objects())
# Handle out type alters on the other end of rels from the PK stuff above
for rel in rels_to_update:
rel_db_params = rel.field.db_parameters(connection=self.connection)
rel_type = rel_db_params['type']
self.execute(
self.sql_alter_column % {
"table": self.quote_name(rel.model._meta.db_table),
"changes": self.sql_alter_column_type % {
"column": self.quote_name(rel.field.column),
"type": rel_type,
}
}
)
# Does it have a foreign key? # Does it have a foreign key?
if new_field.rel: if new_field.rel:
self.execute( self.execute(

View File

@@ -153,80 +153,16 @@ class MigrationAutodetector(object):
) )
# Changes within models # Changes within models
kept_models = set(old_model_keys).intersection(new_model_keys) kept_models = set(old_model_keys).intersection(new_model_keys)
old_fields = set()
new_fields = set()
for app_label, model_name in kept_models: for app_label, model_name in kept_models:
old_model_state = self.from_state.models[app_label, model_name] old_model_state = self.from_state.models[app_label, model_name]
new_model_state = self.to_state.models[app_label, model_name] new_model_state = self.to_state.models[app_label, model_name]
# New fields # Collect field changes for later global dealing with (so AddFields
old_field_names = set(x for x, y in old_model_state.fields) # always come before AlterFields even on separate models)
new_field_names = set(x for x, y in new_model_state.fields) old_fields.update((app_label, model_name, x) for x, y in old_model_state.fields)
for field_name in new_field_names - old_field_names: new_fields.update((app_label, model_name, x) for x, y in new_model_state.fields)
field = new_model_state.get_field_by_name(field_name) # Unique_together changes
# Scan to see if this is actually a rename!
field_dec = field.deconstruct()[1:]
found_rename = False
for removed_field_name in (old_field_names - new_field_names):
if old_model_state.get_field_by_name(removed_field_name).deconstruct()[1:] == field_dec:
if self.questioner.ask_rename(model_name, removed_field_name, field_name, field):
self.add_to_migration(
app_label,
operations.RenameField(
model_name=model_name,
old_name=removed_field_name,
new_name=field_name,
)
)
old_field_names.remove(removed_field_name)
new_field_names.remove(field_name)
found_rename = True
break
if found_rename:
continue
# You can't just add NOT NULL fields with no default
if not field.null and not field.has_default():
field = field.clone()
field.default = self.questioner.ask_not_null_addition(field_name, model_name)
self.add_to_migration(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
preserve_default=False,
)
)
else:
self.add_to_migration(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
)
)
# Old fields
for field_name in old_field_names - new_field_names:
self.add_to_migration(
app_label,
operations.RemoveField(
model_name=model_name,
name=field_name,
)
)
# The same fields
for field_name in old_field_names.intersection(new_field_names):
# Did the field change?
old_field_dec = old_model_state.get_field_by_name(field_name).deconstruct()
new_field_dec = new_model_state.get_field_by_name(field_name).deconstruct()
if old_field_dec != new_field_dec:
self.add_to_migration(
app_label,
operations.AlterField(
model_name=model_name,
name=field_name,
field=new_model_state.get_field_by_name(field_name),
)
)
# unique_together changes
if old_model_state.options.get("unique_together", set()) != new_model_state.options.get("unique_together", set()): if old_model_state.options.get("unique_together", set()) != new_model_state.options.get("unique_together", set()):
self.add_to_migration( self.add_to_migration(
app_label, app_label,
@@ -235,6 +171,81 @@ class MigrationAutodetector(object):
unique_together=new_model_state.options.get("unique_together", set()), unique_together=new_model_state.options.get("unique_together", set()),
) )
) )
# New fields
for app_label, model_name, field_name in new_fields - old_fields:
old_model_state = self.from_state.models[app_label, model_name]
new_model_state = self.to_state.models[app_label, model_name]
field = new_model_state.get_field_by_name(field_name)
# Scan to see if this is actually a rename!
field_dec = field.deconstruct()[1:]
found_rename = False
for rem_app_label, rem_model_name, rem_field_name in (old_fields - new_fields):
if rem_app_label == app_label and rem_model_name == model_name:
if old_model_state.get_field_by_name(rem_field_name).deconstruct()[1:] == field_dec:
if self.questioner.ask_rename(model_name, rem_field_name, field_name, field):
self.add_to_migration(
app_label,
operations.RenameField(
model_name=model_name,
old_name=rem_field_name,
new_name=field_name,
)
)
old_fields.remove((rem_app_label, rem_model_name, rem_field_name))
new_fields.remove((app_label, model_name, field_name))
found_rename = True
break
if found_rename:
continue
# You can't just add NOT NULL fields with no default
if not field.null and not field.has_default():
field = field.clone()
field.default = self.questioner.ask_not_null_addition(field_name, model_name)
self.add_to_migration(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
preserve_default=False,
)
)
else:
self.add_to_migration(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
)
)
# Old fields
for app_label, model_name, field_name in old_fields - new_fields:
old_model_state = self.from_state.models[app_label, model_name]
new_model_state = self.to_state.models[app_label, model_name]
self.add_to_migration(
app_label,
operations.RemoveField(
model_name=model_name,
name=field_name,
)
)
# The same fields
for app_label, model_name, field_name in old_fields.intersection(new_fields):
# Did the field change?
old_model_state = self.from_state.models[app_label, model_name]
new_model_state = self.to_state.models[app_label, model_name]
old_field_dec = old_model_state.get_field_by_name(field_name).deconstruct()
new_field_dec = new_model_state.get_field_by_name(field_name).deconstruct()
if old_field_dec != new_field_dec:
self.add_to_migration(
app_label,
operations.AlterField(
model_name=model_name,
name=field_name,
field=new_model_state.get_field_by_name(field_name),
)
)
# Alright, now add internal dependencies # Alright, now add internal dependencies
for app_label, migrations in self.migrations.items(): for app_label, migrations in self.migrations.items():
for m1, m2 in zip(migrations, migrations[1:]): for m1, m2 in zip(migrations, migrations[1:]):

View File

@@ -24,10 +24,9 @@ class AddField(Operation):
state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name)
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name) from_model = from_state.render().get_model(app_label, self.model_name)
@@ -74,10 +73,9 @@ class RemoveField(Operation):
schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name)
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
def describe(self): def describe(self):
return "Remove field %s from %s" % (self.name, self.model_name) return "Remove field %s from %s" % (self.name, self.model_name)
@@ -109,7 +107,7 @@ class AlterField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
from_model, to_model,
from_model._meta.get_field_by_name(self.name)[0], from_model._meta.get_field_by_name(self.name)[0],
to_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0],
) )
@@ -155,7 +153,7 @@ class RenameField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
from_model, to_model,
from_model._meta.get_field_by_name(self.old_name)[0], from_model._meta.get_field_by_name(self.old_name)[0],
to_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.new_name)[0],
) )
@@ -165,7 +163,7 @@ class RenameField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
from_model, to_model,
from_model._meta.get_field_by_name(self.new_name)[0], from_model._meta.get_field_by_name(self.new_name)[0],
to_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.old_name)[0],
) )