diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py index 3bd9546bd7..87410fc650 100644 --- a/django/db/migrations/operations/base.py +++ b/django/db/migrations/operations/base.py @@ -1,6 +1,7 @@ import enum from django.db import router +from django.utils.inspect import get_func_args class OperationCategory(str, enum.Enum): @@ -52,6 +53,16 @@ class Operation: self._constructor_args = (args, kwargs) return self + def __replace__(self, /, **changes): + args = [ + changes.pop(name, value) + for name, value in zip( + get_func_args(self.__class__), + self._constructor_args[0], + ) + ] + return self.__class__(*args, **(self._constructor_args[1] | changes)) + def deconstruct(self): """ Return a 3-tuple of class import path (or just name if it lives diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 34b441a247..732cccd12c 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -1,5 +1,6 @@ from django.db.migrations.utils import field_references from django.db.models import NOT_PROVIDED +from django.utils.copy import replace from django.utils.functional import cached_property from .base import Operation, OperationCategory @@ -134,8 +135,8 @@ class AddField(FieldOperation): ): if isinstance(operation, AlterField): return [ - AddField( - model_name=self.model_name, + replace( + self, name=operation.name, field=operation.field, ), @@ -143,13 +144,7 @@ class AddField(FieldOperation): elif isinstance(operation, RemoveField): return [] elif isinstance(operation, RenameField): - return [ - AddField( - model_name=self.model_name, - name=operation.new_name, - field=self.field, - ), - ] + return [replace(self, name=operation.new_name)] return super().reduce(operation, app_label) @@ -264,11 +259,7 @@ class AlterField(FieldOperation): ): return [ operation, - AlterField( - model_name=self.model_name, - name=operation.new_name, - field=self.field, - ), + replace(self, name=operation.new_name), ] return super().reduce(operation, app_label) @@ -350,13 +341,7 @@ class RenameField(FieldOperation): and self.is_same_model_operation(operation) and self.new_name_lower == operation.old_name_lower ): - return [ - RenameField( - self.model_name, - self.old_name, - operation.new_name, - ), - ] + return [replace(self, new_name=operation.new_name)] # Skip `FieldOperation.reduce` as we want to run `references_field` # against self.old_name and self.new_name. return super(FieldOperation, self).reduce(operation, app_label) or not ( diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 266a3efadc..366fe58047 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -1,8 +1,11 @@ +from copy import copy + from django.db import models from django.db.migrations.operations.base import Operation, OperationCategory from django.db.migrations.state import ModelState from django.db.migrations.utils import field_references, resolve_relation from django.db.models.options import normalize_together +from django.utils.copy import replace from django.utils.functional import cached_property from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField @@ -156,15 +159,7 @@ class CreateModel(ModelOperation): isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower ): - return [ - CreateModel( - operation.new_name, - fields=self.fields, - options=self.options, - bases=self.bases, - managers=self.managers, - ), - ] + return [replace(self, name=operation.new_name)] elif ( isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower @@ -173,42 +168,20 @@ class CreateModel(ModelOperation): for key in operation.ALTER_OPTION_KEYS: if key not in operation.options: options.pop(key, None) - return [ - CreateModel( - self.name, - fields=self.fields, - options=options, - bases=self.bases, - managers=self.managers, - ), - ] + return [replace(self, options=options)] elif ( isinstance(operation, AlterModelManagers) and self.name_lower == operation.name_lower ): - return [ - CreateModel( - self.name, - fields=self.fields, - options=self.options, - bases=self.bases, - managers=operation.managers, - ), - ] + return [replace(self, managers=operation.managers)] elif ( isinstance(operation, AlterModelTable) and self.name_lower == operation.name_lower ): return [ - CreateModel( - self.name, - fields=self.fields, - options={ - **self.options, - "db_table": operation.table, - }, - bases=self.bases, - managers=self.managers, + replace( + self, + options={**self.options, "db_table": operation.table}, ), ] elif ( @@ -216,15 +189,12 @@ class CreateModel(ModelOperation): and self.name_lower == operation.name_lower ): return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "db_table_comment": operation.table_comment, }, - bases=self.bases, - managers=self.managers, ), ] elif ( @@ -232,15 +202,12 @@ class CreateModel(ModelOperation): and self.name_lower == operation.name_lower ): return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, **{operation.option_name: operation.option_value}, }, - bases=self.bases, - managers=self.managers, ), ] elif ( @@ -248,15 +215,12 @@ class CreateModel(ModelOperation): and self.name_lower == operation.name_lower ): return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "order_with_respect_to": operation.order_with_respect_to, }, - bases=self.bases, - managers=self.managers, ), ] elif ( @@ -265,25 +229,19 @@ class CreateModel(ModelOperation): ): if isinstance(operation, AddField): return [ - CreateModel( - self.name, + replace( + self, fields=self.fields + [(operation.name, operation.field)], - options=self.options, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, AlterField): return [ - CreateModel( - self.name, + replace( + self, fields=[ (n, operation.field if n == operation.name else v) for n, v in self.fields ], - options=self.options, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, RemoveField): @@ -308,16 +266,14 @@ class CreateModel(ModelOperation): if order_with_respect_to == operation.name_lower: del options["order_with_respect_to"] return [ - CreateModel( - self.name, + replace( + self, fields=[ (n, v) for n, v in self.fields if n.lower() != operation.name_lower ], options=options, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, RenameField): @@ -336,15 +292,13 @@ class CreateModel(ModelOperation): if order_with_respect_to == operation.old_name: options["order_with_respect_to"] = operation.new_name return [ - CreateModel( - self.name, + replace( + self, fields=[ (operation.new_name if n == operation.old_name else n, v) for n, v in self.fields ], options=options, - bases=self.bases, - managers=self.managers, ), ] elif ( @@ -353,9 +307,8 @@ class CreateModel(ModelOperation): ): if isinstance(operation, AddIndex): return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "indexes": [ @@ -363,8 +316,6 @@ class CreateModel(ModelOperation): operation.index, ], }, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, RemoveIndex): @@ -374,22 +325,18 @@ class CreateModel(ModelOperation): if index.name != operation.name ] return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "indexes": options_indexes, }, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, AddConstraint): return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "constraints": [ @@ -397,8 +344,6 @@ class CreateModel(ModelOperation): operation.constraint, ], }, - bases=self.bases, - managers=self.managers, ), ] elif isinstance(operation, RemoveConstraint): @@ -408,15 +353,12 @@ class CreateModel(ModelOperation): if constraint.name != operation.name ] return [ - CreateModel( - self.name, - fields=self.fields, + replace( + self, options={ **self.options, "constraints": options_constraints, }, - bases=self.bases, - managers=self.managers, ), ] return super().reduce(operation, app_label) @@ -567,12 +509,7 @@ class RenameModel(ModelOperation): isinstance(operation, RenameModel) and self.new_name_lower == operation.old_name_lower ): - return [ - RenameModel( - self.old_name, - operation.new_name, - ), - ] + return [replace(self, new_name=operation.new_name)] # Skip `ModelOperation.reduce` as we want to run `references_model` # against self.new_name. return super(ModelOperation, self).reduce( @@ -990,8 +927,9 @@ class AddIndex(IndexOperation): if isinstance(operation, RemoveIndex) and self.index.name == operation.name: return [] if isinstance(operation, RenameIndex) and self.index.name == operation.old_name: - self.index.name = operation.new_name - return [self.__class__(model_name=self.model_name, index=self.index)] + index = copy(self.index) + index.name = operation.new_name + return [replace(self, index=index)] return super().reduce(operation, app_label) @@ -1183,14 +1121,7 @@ class RenameIndex(IndexOperation): and operation.old_name and self.new_name_lower == operation.old_name_lower ): - return [ - RenameIndex( - self.model_name, - new_name=operation.new_name, - old_name=self.old_name, - old_fields=self.old_fields, - ) - ] + return [replace(self, new_name=operation.new_name)] return super().reduce(operation, app_label) @@ -1247,7 +1178,7 @@ class AddConstraint(IndexOperation): and self.model_name_lower == operation.model_name_lower and self.constraint.name == operation.name ): - return [AddConstraint(self.model_name, operation.constraint)] + return [replace(self, constraint=operation.constraint)] return super().reduce(operation, app_label) diff --git a/django/utils/copy.py b/django/utils/copy.py new file mode 100644 index 0000000000..dd0cd729b0 --- /dev/null +++ b/django/utils/copy.py @@ -0,0 +1,17 @@ +from django.utils.version import PY313 + +if PY313: + from copy import replace +else: + # Backport of copy.replace() from Python 3.13. + def replace(obj, /, **changes): + """Return a new object replacing specified fields with new values. + + This is especially useful for immutable objects, like named tuples or + frozen dataclasses. + """ + cls = obj.__class__ + func = getattr(cls, "__replace__", None) + if func is None: + raise TypeError(f"replace() does not support {cls.__name__} objects") + return func(obj, **changes)