diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py index 9f29d72f49..97475de6f7 100644 --- a/django/contrib/postgres/apps.py +++ b/django/contrib/postgres/apps.py @@ -1,13 +1,21 @@ +from psycopg2.extras import ( + DateRange, DateTimeRange, DateTimeTZRange, NumericRange, +) + from django.apps import AppConfig from django.db import connections from django.db.backends.signals import connection_created +from django.db.migrations.writer import MigrationWriter from django.db.models import CharField, TextField from django.test.signals import setting_changed from django.utils.translation import gettext_lazy as _ from .lookups import SearchLookup, TrigramSimilar, Unaccent +from .serializers import RangeSerializer from .signals import register_type_handlers +RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange) + def uninstall_if_needed(setting, value, enter, **kwargs): """ @@ -26,6 +34,7 @@ def uninstall_if_needed(setting, value, enter, **kwargs): # and ready() connects it again to prevent unnecessary processing on # each setting change. setting_changed.disconnect(uninstall_if_needed) + MigrationWriter.unregister_serializer(RANGE_TYPES) class PostgresConfig(AppConfig): @@ -54,3 +63,4 @@ class PostgresConfig(AppConfig): TextField.register_lookup(SearchLookup) CharField.register_lookup(TrigramSimilar) TextField.register_lookup(TrigramSimilar) + MigrationWriter.register_serializer(RANGE_TYPES, RangeSerializer) diff --git a/django/contrib/postgres/serializers.py b/django/contrib/postgres/serializers.py new file mode 100644 index 0000000000..1b1c2f1112 --- /dev/null +++ b/django/contrib/postgres/serializers.py @@ -0,0 +1,10 @@ +from django.db.migrations.serializer import BaseSerializer + + +class RangeSerializer(BaseSerializer): + def serialize(self): + module = self.value.__class__.__module__ + # Ranges are implemented in psycopg2._range but the public import + # location is psycopg2.extras. + module = 'psycopg2.extras' if module == 'psycopg2._range' else module + return '%s.%r' % (module, self.value), {'import %s' % module} diff --git a/tests/postgres_tests/test_apps.py b/tests/postgres_tests/test_apps.py index a5740f9d15..7b56c8f716 100644 --- a/tests/postgres_tests/test_apps.py +++ b/tests/postgres_tests/test_apps.py @@ -1,8 +1,19 @@ from django.db.backends.signals import connection_created +from django.db.migrations.writer import MigrationWriter from django.test.utils import modify_settings from . import PostgreSQLTestCase +try: + from psycopg2.extras import ( + DateRange, DateTimeRange, DateTimeTZRange, NumericRange, + ) + from django.contrib.postgres.fields import ( + DateRangeField, DateTimeRangeField, IntegerRangeField, + ) +except ImportError: + pass + class PostgresConfigTests(PostgreSQLTestCase): def test_register_type_handlers_connection(self): @@ -11,3 +22,38 @@ class PostgresConfigTests(PostgreSQLTestCase): with modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}): self.assertIn(register_type_handlers, connection_created._live_receivers(None)) self.assertNotIn(register_type_handlers, connection_created._live_receivers(None)) + + def test_register_serializer_for_migrations(self): + tests = ( + (DateRange(empty=True), DateRangeField), + (DateTimeRange(empty=True), DateRangeField), + (DateTimeTZRange(None, None, '[]'), DateTimeRangeField), + (NumericRange(1, 10), IntegerRangeField), + ) + + def assertNotSerializable(): + for default, test_field in tests: + with self.subTest(default=default): + field = test_field(default=default) + with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__): + MigrationWriter.serialize(field) + + assertNotSerializable() + with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}): + for default, test_field in tests: + with self.subTest(default=default): + field = test_field(default=default) + serialized_field, imports = MigrationWriter.serialize(field) + self.assertEqual(imports, { + 'import django.contrib.postgres.fields.ranges', + 'import psycopg2.extras', + }) + self.assertIn( + '%s.%s(default=psycopg2.extras.%r)' % ( + field.__module__, + field.__class__.__name__, + default, + ), + serialized_field + ) + assertNotSerializable()