diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py index d395313ff4..ace0a860c4 100644 --- a/django/db/migrations/serializer.py +++ b/django/db/migrations/serializer.py @@ -8,6 +8,7 @@ import math import re import types import uuid +from collections import OrderedDict from django.conf import SettingsReference from django.db import models @@ -271,6 +272,38 @@ class UUIDSerializer(BaseSerializer): return "uuid.%s" % repr(self.value), {"import uuid"} +class Serializer: + _registry = OrderedDict([ + (frozenset, FrozensetSerializer), + (list, SequenceSerializer), + (set, SetSerializer), + (tuple, TupleSerializer), + (dict, DictionarySerializer), + (enum.Enum, EnumSerializer), + (datetime.datetime, DatetimeDatetimeSerializer), + ((datetime.date, datetime.timedelta, datetime.time), DateTimeSerializer), + (SettingsReference, SettingsReferenceSerializer), + (float, FloatSerializer), + ((bool, int, type(None), bytes, str), BaseSimpleSerializer), + (decimal.Decimal, DecimalSerializer), + ((functools.partial, functools.partialmethod), FunctoolsPartialSerializer), + ((types.FunctionType, types.BuiltinFunctionType, types.MethodType), FunctionTypeSerializer), + (collections.abc.Iterable, IterableSerializer), + ((COMPILED_REGEX_TYPE, RegexObject), RegexSerializer), + (uuid.UUID, UUIDSerializer), + ]) + + @classmethod + def register(cls, type_, serializer): + if not issubclass(serializer, BaseSerializer): + raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__) + cls._registry[type_] = serializer + + @classmethod + def unregister(cls, type_): + cls._registry.pop(type_) + + def serializer_factory(value): if isinstance(value, Promise): value = str(value) @@ -290,42 +323,9 @@ def serializer_factory(value): # Anything that knows how to deconstruct itself. if hasattr(value, 'deconstruct'): return DeconstructableSerializer(value) - - # Unfortunately some of these are order-dependent. - if isinstance(value, frozenset): - return FrozensetSerializer(value) - if isinstance(value, list): - return SequenceSerializer(value) - if isinstance(value, set): - return SetSerializer(value) - if isinstance(value, tuple): - return TupleSerializer(value) - if isinstance(value, dict): - return DictionarySerializer(value) - if isinstance(value, enum.Enum): - return EnumSerializer(value) - if isinstance(value, datetime.datetime): - return DatetimeDatetimeSerializer(value) - if isinstance(value, (datetime.date, datetime.timedelta, datetime.time)): - return DateTimeSerializer(value) - if isinstance(value, SettingsReference): - return SettingsReferenceSerializer(value) - if isinstance(value, float): - return FloatSerializer(value) - if isinstance(value, (bool, int, type(None), bytes, str)): - return BaseSimpleSerializer(value) - if isinstance(value, decimal.Decimal): - return DecimalSerializer(value) - if isinstance(value, (functools.partial, functools.partialmethod)): - return FunctoolsPartialSerializer(value) - if isinstance(value, (types.FunctionType, types.BuiltinFunctionType, types.MethodType)): - return FunctionTypeSerializer(value) - if isinstance(value, collections.abc.Iterable): - return IterableSerializer(value) - if isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)): - return RegexSerializer(value) - if isinstance(value, uuid.UUID): - return UUIDSerializer(value) + for type_, serializer_cls in Serializer._registry.items(): + if isinstance(value, type_): + return serializer_cls(value) raise ValueError( "Cannot serialize: %r\nThere are some values Django cannot serialize into " "migration files.\nFor more, see https://docs.djangoproject.com/en/%s/" diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 1e001da4e6..047436ffab 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -8,7 +8,7 @@ from django.apps import apps from django.conf import SettingsReference # NOQA from django.db import migrations from django.db.migrations.loader import MigrationLoader -from django.db.migrations.serializer import serializer_factory +from django.db.migrations.serializer import Serializer, serializer_factory from django.utils.inspect import get_func_args from django.utils.module_loading import module_dir from django.utils.timezone import now @@ -270,6 +270,14 @@ class MigrationWriter: def serialize(cls, value): return serializer_factory(value).serialize() + @classmethod + def register_serializer(cls, type_, serializer): + Serializer.register(type_, serializer) + + @classmethod + def unregister_serializer(cls, type_): + Serializer.unregister(type_) + MIGRATION_HEADER_TEMPLATE = """\ # Generated by Django %(version)s on %(timestamp)s diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 87a7ef4931..c371d50281 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -211,6 +211,9 @@ Migrations * ``NoneType`` can now be serialized in migrations. +* You can now :ref:`register custom serializers ` + for migrations. + Models ~~~~~~ diff --git a/docs/topics/migrations.txt b/docs/topics/migrations.txt index b44f78cc69..2f33b97878 100644 --- a/docs/topics/migrations.txt +++ b/docs/topics/migrations.txt @@ -697,6 +697,35 @@ Django cannot serialize: - Arbitrary class instances (e.g. ``MyClass(4.3, 5.7)``) - Lambdas +.. _custom-migration-serializers: + +Custom serializers +------------------ + +.. versionadded:: 2.2 + +You can serialize other types by writing a custom serializer. For example, if +Django didn't serialize :class:`~decimal.Decimal` by default, you could do +this:: + + from decimal import Decimal + + from django.db.migrations.serializer import BaseSerializer + from django.db.migrations.writer import MigrationWriter + + class DecimalSerializer(BaseSerializer): + def serialize(self): + return repr(self.value), {'from decimal import Decimal'} + + MigrationWriter.register_serializer(Decimal, DecimalSerializer) + +The first argument of ``MigrationWriter.register_serializer()`` is a type or +iterable of types that should use the serializer. + +The ``serialize()`` method of your serializer must return a string of how the +value should appear in migrations and a set of any imports that are needed in +the migration. + .. _custom-deconstruct-method: Adding a ``deconstruct()`` method diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 8e30342763..abeeaf5182 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -15,6 +15,7 @@ from django import get_version from django.conf import SettingsReference, settings from django.core.validators import EmailValidator, RegexValidator from django.db import migrations, models +from django.db.migrations.serializer import BaseSerializer from django.db.migrations.writer import MigrationWriter, OperationWriter from django.test import SimpleTestCase from django.utils.deconstruct import deconstructible @@ -653,3 +654,18 @@ class WriterTests(SimpleTestCase): string = MigrationWriter.serialize(models.CharField(default=DeconstructibleInstances))[0] self.assertEqual(string, "models.CharField(default=migrations.test_writer.DeconstructibleInstances)") + + def test_register_serializer(self): + class ComplexSerializer(BaseSerializer): + def serialize(self): + return 'complex(%r)' % self.value, {} + + MigrationWriter.register_serializer(complex, ComplexSerializer) + self.assertSerializedEqual(complex(1, 2)) + MigrationWriter.unregister_serializer(complex) + with self.assertRaisesMessage(ValueError, 'Cannot serialize: (1+2j)'): + self.assertSerializedEqual(complex(1, 2)) + + def test_register_non_serializer(self): + with self.assertRaisesMessage(ValueError, "'TestModel1' must inherit from 'BaseSerializer'."): + MigrationWriter.register_serializer(complex, TestModel1)