mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Start adding operations that work and tests for them
This commit is contained in:
		| @@ -272,7 +272,7 @@ class BaseDatabaseSchemaEditor(object): | ||||
|             "new_tablespace": self.quote_name(new_db_tablespace), | ||||
|         }) | ||||
|  | ||||
|     def create_field(self, model, field, keep_default=False): | ||||
|     def add_field(self, model, field, keep_default=False): | ||||
|         """ | ||||
|         Creates a field on a model. | ||||
|         Usually involves adding a column, but may involve adding a | ||||
| @@ -325,7 +325,7 @@ class BaseDatabaseSchemaEditor(object): | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|     def delete_field(self, model, field): | ||||
|     def remove_field(self, model, field): | ||||
|         """ | ||||
|         Removes a field from a model. Usually involves deleting a column, | ||||
|         but for M2Ms may involve deleting a table. | ||||
|   | ||||
| @@ -73,7 +73,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|         if restore_pk_field: | ||||
|             restore_pk_field.primary_key = True | ||||
|  | ||||
|     def create_field(self, model, field): | ||||
|     def add_field(self, model, field): | ||||
|         """ | ||||
|         Creates a field on a model. | ||||
|         Usually involves adding a column, but may involve adding a | ||||
| @@ -89,7 +89,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): | ||||
|             raise ValueError("You cannot add a null=False column without a default value on SQLite.") | ||||
|         self._remake_table(model, create_fields=[field]) | ||||
|  | ||||
|     def delete_field(self, model, field): | ||||
|     def remove_field(self, model, field): | ||||
|         """ | ||||
|         Removes a field from a model. Usually involves deleting a column, | ||||
|         but for M2Ms may involve deleting a table. | ||||
|   | ||||
| @@ -1 +1,2 @@ | ||||
| from .migration import Migration | ||||
| from .operations import * | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| from django.utils.datastructures import SortedSet | ||||
| from django.db.migrations.state import ProjectState | ||||
|  | ||||
|  | ||||
| class MigrationGraph(object): | ||||
| @@ -33,8 +34,10 @@ class MigrationGraph(object): | ||||
|         self.nodes[node] = implementation | ||||
|  | ||||
|     def add_dependency(self, child, parent): | ||||
|         self.nodes[child] = None | ||||
|         self.nodes[parent] = None | ||||
|         if child not in self.nodes: | ||||
|             raise KeyError("Dependency references nonexistent child node %r" % (child,)) | ||||
|         if parent not in self.nodes: | ||||
|             raise KeyError("Dependency references nonexistent parent node %r" % (parent,)) | ||||
|         self.dependencies.setdefault(child, set()).add(parent) | ||||
|         self.dependents.setdefault(parent, set()).add(child) | ||||
|  | ||||
| @@ -117,6 +120,16 @@ class MigrationGraph(object): | ||||
|     def __str__(self): | ||||
|         return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values())) | ||||
|  | ||||
|     def project_state(self, node): | ||||
|         """ | ||||
|         Given a migration node, returns a complete ProjectState for it. | ||||
|         """ | ||||
|         plan = self.forwards_plan(node) | ||||
|         project_state = ProjectState() | ||||
|         for node in plan: | ||||
|             project_state = self.nodes[node].mutate_state(project_state) | ||||
|         return project_state | ||||
|  | ||||
|  | ||||
| class CircularDependencyError(Exception): | ||||
|     """ | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import os | ||||
| from django.utils.importlib import import_module | ||||
| from django.utils.functional import cached_property | ||||
| from django.db.models.loading import cache | ||||
| from django.db.migrations.recorder import MigrationRecorder | ||||
| from django.db.migrations.graph import MigrationGraph | ||||
| @@ -64,9 +65,10 @@ class MigrationLoader(object): | ||||
|                 migration_module = import_module("%s.%s" % (module_name, migration_name)) | ||||
|                 if not hasattr(migration_module, "Migration"): | ||||
|                     raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label)) | ||||
|                 self.disk_migrations[app_label, migration_name] = migration_module.Migration | ||||
|                 self.disk_migrations[app_label, migration_name] = migration_module.Migration(migration_name, app_label) | ||||
|  | ||||
|     def build_graph(self): | ||||
|     @cached_property | ||||
|     def graph(self): | ||||
|         """ | ||||
|         Builds a migration dependency graph using both the disk and database. | ||||
|         """ | ||||
| @@ -116,6 +118,7 @@ class MigrationLoader(object): | ||||
|         graph = MigrationGraph() | ||||
|         for key, migration in normal.items(): | ||||
|             graph.add_node(key, migration) | ||||
|         for key, migration in normal.items(): | ||||
|             for parent in migration.dependencies: | ||||
|                 graph.add_dependency(key, parent) | ||||
|         return graph | ||||
|   | ||||
| @@ -10,6 +10,9 @@ class Migration(object): | ||||
|      - dependencies: A list of tuples of (app_path, migration_name) | ||||
|      - run_before: A list of tuples of (app_path, migration_name) | ||||
|      - replaces: A list of migration_names | ||||
|  | ||||
|     Note that all migrations come out of migrations and into the Loader or | ||||
|     Graph as instances, having been initialised with their app label and name. | ||||
|     """ | ||||
|  | ||||
|     # Operations to apply during this migration, in order. | ||||
| @@ -28,3 +31,17 @@ class Migration(object): | ||||
|     # non-empty, this migration will only be applied if all these migrations | ||||
|     # are not applied. | ||||
|     replaces = [] | ||||
|  | ||||
|     def __init__(self, name, app_label): | ||||
|         self.name = name | ||||
|         self.app_label = app_label | ||||
|  | ||||
|     def mutate_state(self, project_state): | ||||
|         """ | ||||
|         Takes a ProjectState and returns a new one with the migration's | ||||
|         operations applied to it. | ||||
|         """ | ||||
|         new_state = project_state.clone() | ||||
|         for operation in self.operations: | ||||
|             operation.state_forwards(self.app_label, new_state) | ||||
|         return new_state | ||||
|   | ||||
| @@ -1 +1,2 @@ | ||||
| from .models import CreateModel, DeleteModel | ||||
| from .fields import AddField, RemoveField | ||||
|   | ||||
| @@ -15,21 +15,21 @@ class Operation(object): | ||||
|     # Some operations are impossible to reverse, like deleting data. | ||||
|     reversible = True | ||||
|  | ||||
|     def state_forwards(self, app, state): | ||||
|     def state_forwards(self, app_label, state): | ||||
|         """ | ||||
|         Takes the state from the previous migration, and mutates it | ||||
|         so that it matches what this migration would perform. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def database_forwards(self, app, schema_editor, from_state, to_state): | ||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         """ | ||||
|         Performs the mutation on the database schema in the normal | ||||
|         (forwards) direction. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def database_backwards(self, app, schema_editor, from_state, to_state): | ||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         """ | ||||
|         Performs the mutation on the database schema in the reverse | ||||
|         direction - e.g. if this were CreateModel, it would in fact | ||||
|   | ||||
							
								
								
									
										52
									
								
								django/db/migrations/operations/fields.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								django/db/migrations/operations/fields.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| from .base import Operation | ||||
|  | ||||
|  | ||||
| class AddField(Operation): | ||||
|     """ | ||||
|     Adds a field to a model. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, model_name, name, instance): | ||||
|         self.model_name = model_name | ||||
|         self.name = name | ||||
|         self.instance = instance | ||||
|  | ||||
|     def state_forwards(self, app_label, state): | ||||
|         state.models[app_label, self.model_name.lower()].fields.append((self.name, self.instance)) | ||||
|  | ||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = to_state.render() | ||||
|         model = app_cache.get_model(app_label, self.name) | ||||
|         schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) | ||||
|  | ||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = from_state.render() | ||||
|         model = app_cache.get_model(app_label, self.name) | ||||
|         schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) | ||||
|  | ||||
|  | ||||
| class RemoveField(Operation): | ||||
|     """ | ||||
|     Removes a field from a model. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, model_name, name): | ||||
|         self.model_name = model_name | ||||
|         self.name = name | ||||
|  | ||||
|     def state_forwards(self, app_label, state): | ||||
|         new_fields = [] | ||||
|         for name, instance in state.models[app_label, self.model_name.lower()].fields: | ||||
|             if name != self.name: | ||||
|                 new_fields.append((name, instance)) | ||||
|         state.models[app_label, self.model_name.lower()].fields = new_fields | ||||
|  | ||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = from_state.render() | ||||
|         model = app_cache.get_model(app_label, self.name) | ||||
|         schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) | ||||
|  | ||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = to_state.render() | ||||
|         model = app_cache.get_model(app_label, self.name) | ||||
|         schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) | ||||
| @@ -1,4 +1,5 @@ | ||||
| from .base import Operation | ||||
| from django.db import models | ||||
| from django.db.migrations.state import ModelState | ||||
|  | ||||
|  | ||||
| @@ -7,20 +8,39 @@ class CreateModel(Operation): | ||||
|     Create a model's table. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, name): | ||||
|     def __init__(self, name, fields, options=None, bases=None): | ||||
|         self.name = name | ||||
|         self.fields = fields | ||||
|         self.options = options or {} | ||||
|         self.bases = bases or (models.Model,) | ||||
|  | ||||
|     def state_forwards(self, app, state): | ||||
|         state.models[app, self.name.lower()] = ModelState(state, app, self.name) | ||||
|     def state_forwards(self, app_label, state): | ||||
|         state.models[app_label, self.name.lower()] = ModelState(app_label, self.name, self.fields, self.options, self.bases) | ||||
|  | ||||
|     def database_forwards(self, app, schema_editor, from_state, to_state): | ||||
|         app_cache = to_state.render() | ||||
|         schema_editor.create_model(app_cache.get_model(app, self.name)) | ||||
|  | ||||
|     def database_backwards(self, app, schema_editor, from_state, to_state): | ||||
|         """ | ||||
|         Performs the mutation on the database schema in the reverse | ||||
|         direction - e.g. if this were CreateModel, it would in fact | ||||
|         drop the model's table. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|         app_cache = from_state.render() | ||||
|         schema_editor.delete_model(app_cache.get_model(app, self.name)) | ||||
|  | ||||
|  | ||||
| class DeleteModel(Operation): | ||||
|     """ | ||||
|     Drops a model's table. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|  | ||||
|     def state_forwards(self, app_label, state): | ||||
|         del state.models[app_label, self.name.lower()] | ||||
|  | ||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = from_state.render() | ||||
|         schema_editor.delete_model(app_cache.get_model(app_label, self.name)) | ||||
|  | ||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||
|         app_cache = to_state.render() | ||||
|         schema_editor.create_model(app_cache.get_model(app_label, self.name)) | ||||
|   | ||||
| @@ -21,7 +21,7 @@ class ProjectState(object): | ||||
|     def clone(self): | ||||
|         "Returns an exact copy of this ProjectState" | ||||
|         return ProjectState( | ||||
|             models = dict((k, v.copy()) for k, v in self.models.items()) | ||||
|             models = dict((k, v.clone()) for k, v in self.models.items()) | ||||
|         ) | ||||
|  | ||||
|     def render(self): | ||||
| @@ -49,12 +49,15 @@ class ModelState(object): | ||||
|     mutate this one and then render it into a Model as required. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, app_label, name, fields=None, options=None, bases=None): | ||||
|     def __init__(self, app_label, name, fields, options=None, bases=None): | ||||
|         self.app_label = app_label | ||||
|         self.name = name | ||||
|         self.fields = fields or [] | ||||
|         self.fields = fields | ||||
|         self.options = options or {} | ||||
|         self.bases = bases or (models.Model, ) | ||||
|         # Sanity-check that fields is NOT a dict. It must be ordered. | ||||
|         if isinstance(self.fields, dict): | ||||
|             raise ValueError("ModelState.fields cannot be a dict - it must be a list of 2-tuples.") | ||||
|  | ||||
|     @classmethod | ||||
|     def from_model(cls, model): | ||||
|   | ||||
| @@ -1,5 +1,27 @@ | ||||
| from django.db import migrations | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|     pass | ||||
|      | ||||
|     operations = [ | ||||
|  | ||||
|         migrations.CreateModel( | ||||
|             "Author", | ||||
|             [ | ||||
|                 ("id", models.AutoField(primary_key=True)), | ||||
|                 ("name", models.CharField(max_length=255)), | ||||
|                 ("slug", models.SlugField(null=True)), | ||||
|                 ("age", models.IntegerField(default=0)), | ||||
|                 ("silly_field", models.BooleanField()), | ||||
|             ], | ||||
|         ), | ||||
|  | ||||
|         migrations.CreateModel( | ||||
|             "Tribble", | ||||
|             [ | ||||
|                 ("id", models.AutoField(primary_key=True)), | ||||
|                 ("fluffy", models.BooleanField(default=True)), | ||||
|             ], | ||||
|         ) | ||||
|          | ||||
|     ] | ||||
|   | ||||
| @@ -1,6 +1,24 @@ | ||||
| from django.db import migrations | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [("migrations", "0001_initial")] | ||||
|  | ||||
|     operations = [ | ||||
|  | ||||
|         migrations.DeleteModel("Tribble"), | ||||
|  | ||||
|         migrations.RemoveField("Author", "silly_field"), | ||||
|  | ||||
|         migrations.AddField("Author", "important", models.BooleanField()), | ||||
|  | ||||
|         migrations.CreateModel( | ||||
|             "Book", | ||||
|             [ | ||||
|                 ("id", models.AutoField(primary_key=True)), | ||||
|                 ("author", models.ForeignKey("migrations.Author", null=True)), | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     ] | ||||
|   | ||||
| @@ -1,11 +1,8 @@ | ||||
| from django.test import TransactionTestCase, TestCase | ||||
| from django.db import connection | ||||
| from django.test import TestCase | ||||
| from django.db.migrations.graph import MigrationGraph, CircularDependencyError | ||||
| from django.db.migrations.loader import MigrationLoader | ||||
| from django.db.migrations.recorder import MigrationRecorder | ||||
|  | ||||
|  | ||||
| class GraphTests(TransactionTestCase): | ||||
| class GraphTests(TestCase): | ||||
|     """ | ||||
|     Tests the digraph structure. | ||||
|     """ | ||||
| @@ -117,20 +114,3 @@ class GraphTests(TransactionTestCase): | ||||
|             CircularDependencyError, | ||||
|             graph.forwards_plan, ("app_a", "0003"), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LoaderTests(TransactionTestCase): | ||||
|     """ | ||||
|     Tests the disk and database loader. | ||||
|     """ | ||||
|  | ||||
|     def test_load(self): | ||||
|         """ | ||||
|         Makes sure the loader can load the migrations for the test apps. | ||||
|         """ | ||||
|         migration_loader = MigrationLoader(connection) | ||||
|         graph = migration_loader.build_graph() | ||||
|         self.assertEqual( | ||||
|             graph.forwards_plan(("migrations", "0002_second")), | ||||
|             [("migrations", "0001_initial"), ("migrations", "0002_second")], | ||||
|         ) | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from django.test import TestCase | ||||
| from django.test import TestCase, TransactionTestCase | ||||
| from django.db import connection | ||||
| from django.db.migrations.loader import MigrationLoader | ||||
| from django.db.migrations.recorder import MigrationRecorder | ||||
|  | ||||
|  | ||||
| class RecorderTests(TestCase): | ||||
|     """ | ||||
|     Tests the disk and database loader. | ||||
|     Tests recording migrations as applied or not. | ||||
|     """ | ||||
|  | ||||
|     def test_apply(self): | ||||
| @@ -27,3 +28,37 @@ class RecorderTests(TestCase): | ||||
|             recorder.applied_migrations(), | ||||
|             set(), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LoaderTests(TransactionTestCase): | ||||
|     """ | ||||
|     Tests the disk and database loader, and running through migrations | ||||
|     in memory. | ||||
|     """ | ||||
|  | ||||
|     def test_load(self): | ||||
|         """ | ||||
|         Makes sure the loader can load the migrations for the test apps, | ||||
|         and then render them out to a new AppCache. | ||||
|         """ | ||||
|         # Load and test the plan | ||||
|         migration_loader = MigrationLoader(connection) | ||||
|         self.assertEqual( | ||||
|             migration_loader.graph.forwards_plan(("migrations", "0002_second")), | ||||
|             [("migrations", "0001_initial"), ("migrations", "0002_second")], | ||||
|         ) | ||||
|         # Now render it out! | ||||
|         project_state = migration_loader.graph.project_state(("migrations", "0002_second")) | ||||
|         self.assertEqual(len(project_state.models), 2) | ||||
|  | ||||
|         author_state = project_state.models["migrations", "author"] | ||||
|         self.assertEqual( | ||||
|             [x for x, y in author_state.fields], | ||||
|             ["id", "name", "slug", "age", "important"] | ||||
|         ) | ||||
|  | ||||
|         book_state = project_state.models["migrations", "book"] | ||||
|         self.assertEqual( | ||||
|             [x for x, y in book_state.fields], | ||||
|             ["id", "author"] | ||||
|         ) | ||||
|   | ||||
| @@ -132,7 +132,7 @@ class SchemaTests(TransactionTestCase): | ||||
|         else: | ||||
|             self.fail("No FK constraint for author_id found") | ||||
|  | ||||
|     def test_create_field(self): | ||||
|     def test_add_field(self): | ||||
|         """ | ||||
|         Tests adding fields to models | ||||
|         """ | ||||
| @@ -146,7 +146,7 @@ class SchemaTests(TransactionTestCase): | ||||
|         new_field = IntegerField(null=True) | ||||
|         new_field.set_attributes_from_name("age") | ||||
|         with connection.schema_editor() as editor: | ||||
|             editor.create_field( | ||||
|             editor.add_field( | ||||
|                 Author, | ||||
|                 new_field, | ||||
|             ) | ||||
| @@ -251,7 +251,7 @@ class SchemaTests(TransactionTestCase): | ||||
|             connection.rollback() | ||||
|             # Add the field | ||||
|             with connection.schema_editor() as editor: | ||||
|                 editor.create_field( | ||||
|                 editor.add_field( | ||||
|                     Author, | ||||
|                     new_field, | ||||
|                 ) | ||||
| @@ -260,7 +260,7 @@ class SchemaTests(TransactionTestCase): | ||||
|             self.assertEqual(columns['tag_id'][0], "IntegerField") | ||||
|             # Remove the M2M table again | ||||
|             with connection.schema_editor() as editor: | ||||
|                 editor.delete_field( | ||||
|                 editor.remove_field( | ||||
|                     Author, | ||||
|                     new_field, | ||||
|                 ) | ||||
| @@ -530,7 +530,7 @@ class SchemaTests(TransactionTestCase): | ||||
|         ) | ||||
|         # Add a unique column, verify that creates an implicit index | ||||
|         with connection.schema_editor() as editor: | ||||
|             editor.create_field( | ||||
|             editor.add_field( | ||||
|                 Book, | ||||
|                 BookWithSlug._meta.get_field_by_name("slug")[0], | ||||
|             ) | ||||
| @@ -568,7 +568,7 @@ class SchemaTests(TransactionTestCase): | ||||
|         new_field = SlugField(primary_key=True) | ||||
|         new_field.set_attributes_from_name("slug") | ||||
|         with connection.schema_editor() as editor: | ||||
|             editor.delete_field(Tag, Tag._meta.get_field_by_name("id")[0]) | ||||
|             editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) | ||||
|             editor.alter_field( | ||||
|                 Tag, | ||||
|                 Tag._meta.get_field_by_name("slug")[0], | ||||
|   | ||||
		Reference in New Issue
	
	Block a user