From 36f514f06553ef299001b4e9a5f63ec806a50581 Mon Sep 17 00:00:00 2001 From: Marc Tamlyn <marc.tamlyn@gmail.com> Date: Fri, 14 Mar 2014 17:34:49 +0000 Subject: [PATCH] Added HStoreField. Thanks to `django-hstore` for inspiration in some areas, and many people for reviews. --- django/contrib/postgres/__init__.py | 1 + django/contrib/postgres/apps.py | 13 ++ django/contrib/postgres/fields/__init__.py | 1 + django/contrib/postgres/fields/array.py | 2 +- django/contrib/postgres/fields/hstore.py | 145 ++++++++++++ django/contrib/postgres/forms/__init__.py | 1 + django/contrib/postgres/forms/hstore.py | 37 +++ django/contrib/postgres/operations.py | 34 +++ django/contrib/postgres/signals.py | 25 ++ django/contrib/postgres/validators.py | 51 +++- docs/ref/contrib/postgres/fields.txt | 166 ++++++++++++- docs/ref/contrib/postgres/forms.txt | 20 ++ docs/ref/contrib/postgres/index.txt | 2 + docs/ref/contrib/postgres/operations.txt | 27 +++ docs/ref/contrib/postgres/validators.txt | 20 ++ docs/releases/1.8.txt | 8 + .../migrations/0001_setup_extensions.py | 15 ++ .../migrations/0002_create_test_models.py | 76 ++++++ tests/postgres_tests/migrations/__init__.py | 0 tests/postgres_tests/models.py | 6 +- tests/postgres_tests/test_hstore.py | 218 ++++++++++++++++++ tests/runtests.py | 2 +- 22 files changed, 864 insertions(+), 6 deletions(-) create mode 100644 django/contrib/postgres/apps.py create mode 100644 django/contrib/postgres/fields/hstore.py create mode 100644 django/contrib/postgres/forms/hstore.py create mode 100644 django/contrib/postgres/operations.py create mode 100644 django/contrib/postgres/signals.py create mode 100644 docs/ref/contrib/postgres/operations.txt create mode 100644 docs/ref/contrib/postgres/validators.txt create mode 100644 tests/postgres_tests/migrations/0001_setup_extensions.py create mode 100644 tests/postgres_tests/migrations/0002_create_test_models.py create mode 100644 tests/postgres_tests/migrations/__init__.py create mode 100644 tests/postgres_tests/test_hstore.py diff --git a/django/contrib/postgres/__init__.py b/django/contrib/postgres/__init__.py index e69de29bb2..8723eea638 100644 --- a/django/contrib/postgres/__init__.py +++ b/django/contrib/postgres/__init__.py @@ -0,0 +1 @@ +default_app_config = 'django.contrib.postgres.apps.PostgresConfig' diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py new file mode 100644 index 0000000000..f6fa263a84 --- /dev/null +++ b/django/contrib/postgres/apps.py @@ -0,0 +1,13 @@ +from django.apps import AppConfig +from django.db.backends.signals import connection_created +from django.utils.translation import ugettext_lazy as _ + +from .signals import register_hstore_handler + + +class PostgresConfig(AppConfig): + name = 'django.contrib.postgres' + verbose_name = _('PostgreSQL extensions') + + def ready(self): + connection_created.connect(register_hstore_handler) diff --git a/django/contrib/postgres/fields/__init__.py b/django/contrib/postgres/fields/__init__.py index e3ceebd62c..63a0c2c952 100644 --- a/django/contrib/postgres/fields/__init__.py +++ b/django/contrib/postgres/fields/__init__.py @@ -1 +1,2 @@ from .array import * # NOQA +from .hstore import * # NOQA diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 1a969d7c0e..ab57ccb110 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -168,7 +168,7 @@ class ArrayContainsLookup(Lookup): lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) params = lhs_params + rhs_params - type_cast = self.lhs.source.db_type(connection) + type_cast = self.lhs.output_field.db_type(connection) return '%s @> %s::%s' % (lhs, rhs, type_cast), params diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py new file mode 100644 index 0000000000..ae51ce02ce --- /dev/null +++ b/django/contrib/postgres/fields/hstore.py @@ -0,0 +1,145 @@ +import json + +from django.contrib.postgres import forms +from django.contrib.postgres.fields.array import ArrayField +from django.core import exceptions +from django.db.models import Field, Lookup, Transform, TextField +from django.utils import six +from django.utils.translation import ugettext_lazy as _ + + +__all__ = ['HStoreField'] + + +class HStoreField(Field): + empty_strings_allowed = False + description = _('Map of strings to strings') + default_error_messages = { + 'not_a_string': _('The value of "%(key)s" is not a string.'), + } + + def db_type(self, connection): + return 'hstore' + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): + if lookup_type == 'contains': + return [self.get_prep_value(value)] + return super(HStoreField, self).get_db_prep_lookup(lookup_type, value, + connection, prepared=False) + + def get_transform(self, name): + transform = super(HStoreField, self).get_transform(name) + if transform: + return transform + return KeyTransformFactory(name) + + def validate(self, value, model_instance): + super(HStoreField, self).validate(value, model_instance) + for key, val in value.items(): + if not isinstance(val, six.string_types): + raise exceptions.ValidationError( + self.error_messages['not_a_string'], + code='not_a_string', + params={'key': key}, + ) + + def to_python(self, value): + if isinstance(value, six.string_types): + value = json.loads(value) + return value + + def value_to_string(self, obj): + value = self._get_val_from_obj(obj) + return json.dumps(value) + + def formfield(self, **kwargs): + defaults = { + 'form_class': forms.HStoreField, + } + defaults.update(kwargs) + return super(HStoreField, self).formfield(**defaults) + + +@HStoreField.register_lookup +class HStoreContainsLookup(Lookup): + lookup_name = 'contains' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s @> %s' % (lhs, rhs), params + + +@HStoreField.register_lookup +class HStoreContainedByLookup(Lookup): + lookup_name = 'contained_by' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s <@ %s' % (lhs, rhs), params + + +@HStoreField.register_lookup +class HasKeyLookup(Lookup): + lookup_name = 'has_key' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s ? %s' % (lhs, rhs), params + + +@HStoreField.register_lookup +class HasKeysLookup(Lookup): + lookup_name = 'has_keys' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s ?& %s' % (lhs, rhs), params + + +class KeyTransform(Transform): + output_field = TextField() + + def __init__(self, key_name, *args, **kwargs): + super(KeyTransform, self).__init__(*args, **kwargs) + self.key_name = key_name + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return "%s -> '%s'" % (lhs, self.key_name), params + + +class KeyTransformFactory(object): + + def __init__(self, key_name): + self.key_name = key_name + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, *args, **kwargs) + + +@HStoreField.register_lookup +class KeysTransform(Transform): + lookup_name = 'keys' + output_field = ArrayField(TextField()) + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return 'akeys(%s)' % lhs, params + + +@HStoreField.register_lookup +class ValuesTransform(Transform): + lookup_name = 'values' + output_field = ArrayField(TextField()) + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return 'avals(%s)' % lhs, params diff --git a/django/contrib/postgres/forms/__init__.py b/django/contrib/postgres/forms/__init__.py index e3ceebd62c..63a0c2c952 100644 --- a/django/contrib/postgres/forms/__init__.py +++ b/django/contrib/postgres/forms/__init__.py @@ -1 +1,2 @@ from .array import * # NOQA +from .hstore import * # NOQA diff --git a/django/contrib/postgres/forms/hstore.py b/django/contrib/postgres/forms/hstore.py new file mode 100644 index 0000000000..548be0f570 --- /dev/null +++ b/django/contrib/postgres/forms/hstore.py @@ -0,0 +1,37 @@ +import json + +from django import forms +from django.core.exceptions import ValidationError +from django.utils import six +from django.utils.translation import ugettext_lazy as _ + + +__all__ = ['HStoreField'] + + +class HStoreField(forms.CharField): + """A field for HStore data which accepts JSON input.""" + widget = forms.Textarea + default_error_messages = { + 'invalid_json': _('Could not load JSON data.'), + } + + def prepare_value(self, value): + if isinstance(value, dict): + return json.dumps(value) + return value + + def to_python(self, value): + if not value: + return {} + try: + value = json.loads(value) + except ValueError: + raise ValidationError( + self.error_messages['invalid_json'], + code='invalid_json', + ) + # Cast everything to strings for ease. + for key, val in value.items(): + value[key] = six.text_type(val) + return value diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py new file mode 100644 index 0000000000..e39d63ffa0 --- /dev/null +++ b/django/contrib/postgres/operations.py @@ -0,0 +1,34 @@ +from django.contrib.postgres.signals import register_hstore_handler +from django.db.migrations.operations.base import Operation + + +class CreateExtension(Operation): + reversible = True + + def __init__(self, name): + self.name = name + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % self.name) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + schema_editor.execute("DROP EXTENSION %s" % self.name) + + def describe(self): + return "Creates extension %s" % self.name + + +class HStoreExtension(CreateExtension): + + def __init__(self): + self.name = 'hstore' + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + super(HStoreExtension, self).database_forwards(app_label, schema_editor, from_state, to_state) + # Register hstore straight away as it cannot be done before the + # extension is installed, a subsequent data migration would use the + # same connection + register_hstore_handler(schema_editor.connection) diff --git a/django/contrib/postgres/signals.py b/django/contrib/postgres/signals.py new file mode 100644 index 0000000000..602dd08700 --- /dev/null +++ b/django/contrib/postgres/signals.py @@ -0,0 +1,25 @@ +from django.utils import six + +from psycopg2 import ProgrammingError +from psycopg2.extras import register_hstore + + +def register_hstore_handler(connection, **kwargs): + if connection.vendor != 'postgresql': + return + + try: + if six.PY2: + register_hstore(connection.connection, globally=True, unicode=True) + else: + register_hstore(connection.connection, globally=True) + except ProgrammingError: + # Hstore is not available on the database. + # + # If someone tries to create an hstore field it will error there. + # This is necessary as someone may be using PSQL without extensions + # installed but be using other features of contrib.postgres. + # + # This is also needed in order to create the connection in order to + # install the hstore extension. + pass diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py index 353305949e..19d0a69765 100644 --- a/django/contrib/postgres/validators.py +++ b/django/contrib/postgres/validators.py @@ -1,5 +1,9 @@ +import copy + +from django.core.exceptions import ValidationError from django.core.validators import MaxLengthValidator, MinLengthValidator -from django.utils.translation import ungettext_lazy +from django.utils.deconstruct import deconstructible +from django.utils.translation import ungettext_lazy, ugettext_lazy as _ class ArrayMaxLengthValidator(MaxLengthValidator): @@ -14,3 +18,48 @@ class ArrayMinLengthValidator(MinLengthValidator): 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.', 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.', 'limit_value') + + +@deconstructible +class KeysValidator(object): + """A validator designed for HStore to require/restrict keys.""" + + messages = { + 'missing_keys': _('Some keys were missing: %(keys)s'), + 'extra_keys': _('Some unknown keys were provided: %(keys)s'), + } + strict = False + + def __init__(self, keys, strict=False, messages=None): + self.keys = set(keys) + self.strict = strict + if messages is not None: + self.messages = copy.copy(self.messages) + self.messages.update(messages) + + def __call__(self, value): + keys = set(value.keys()) + missing_keys = self.keys - keys + if missing_keys: + raise ValidationError(self.messages['missing_keys'], + code='missing_keys', + params={'keys': ', '.join(missing_keys)}, + ) + if self.strict: + extra_keys = keys - self.keys + if extra_keys: + raise ValidationError(self.messages['extra_keys'], + code='extra_keys', + params={'keys': ', '.join(extra_keys)}, + ) + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and (self.keys == other.keys) + and (self.messages == other.messages) + and (self.strict == other.strict) + ) + + def __ne__(self, other): + return not (self == other) diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt index 49ce5368c8..3fba239e74 100644 --- a/docs/ref/contrib/postgres/fields.txt +++ b/docs/ref/contrib/postgres/fields.txt @@ -61,8 +61,8 @@ ArrayField When nesting ``ArrayField``, whether you use the `size` parameter or not, PostgreSQL requires that the arrays are rectangular:: - from django.db import models from django.contrib.postgres.fields import ArrayField + from django.db import models class Board(models.Model): pieces = ArrayField(ArrayField(models.IntegerField())) @@ -95,7 +95,7 @@ We will use the following example model:: name = models.CharField(max_length=200) tags = ArrayField(models.CharField(max_length=200), blank=True) - def __str__(self): # __unicode__ on python 2 + def __str__(self): # __unicode__ on Python 2 return self.name .. fieldlookup:: arrayfield.contains @@ -240,3 +240,165 @@ At present using :attr:`~django.db.models.Field.db_index` will create a ``btree`` index. This does not offer particularly significant help to querying. A more useful index is a ``GIN`` index, which you should create using a :class:`~django.db.migrations.operations.RunSQL` operation. + +HStoreField +----------- + +.. class:: HStoreField(**options) + + A field for storing mappings of strings to strings. The Python data type + used is a ``dict``. + +.. note:: + + On occasions it may be useful to require or restrict the keys which are + valid for a given field. This can be done using the + :class:`~django.contrib.postgres.validators.KeysValidator`. + +Querying HStoreField +^^^^^^^^^^^^^^^^^^^^ + +In addition to the ability to query by key, there are a number of custom +lookups available for ``HStoreField``. + +We will use the following example model:: + + from django.contrib.postgres.fields import HStoreField + from django.db import models + + class Dog(models.Model): + name = models.CharField(max_length=200) + data = HStoreField() + + def __str__(self): # __unicode__ on Python 2 + return self.name + +.. fieldlookup:: hstorefield.key + +Key lookups +~~~~~~~~~~~ + +To query based on a given key, you simply use that key as the lookup name:: + + >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie'}) + + >>> Dog.objects.filter(data__breed='collie') + [<Dog: Meg>] + +You can chain other lookups after key lookups:: + + >>> Dog.objects.filter(data__breed__contains='l') + [<Dog: Rufus>, Dog: Meg>] + +If the key you wish to query by clashes with the name of another lookup, you +need to use the :lookup:`hstorefield.contains` lookup instead. + +.. warning:: + + Since any string could be a key in a hstore value, any lookup other than + those listed below will be interpreted as a key lookup. No errors are + raised. Be extra careful for typing mistakes, and always check your queries + work as you intend. + +.. fieldlookup:: hstorefield.contains + +contains +~~~~~~~~ + +The :lookup:`contains` lookup is overridden on +:class:`~django.contrib.postgres.fields.HStoreField`. The returned objects are +those where the given ``dict`` of key-value pairs are all contained in the +field. It uses the SQL operator ``@>``. For example:: + + >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador', 'owner': 'Bob'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + >>> Dog.objects.create(name='Fred', data={}) + + >>> Dog.objects.filter(data__contains={'owner': 'Bob'}) + [<Dog: Rufus>, <Dog: Meg>] + + >>> Dog.objects.filter(data__contains={'breed': 'collie'}) + [<Dog: Meg>] + +.. fieldlookup:: hstorefield.contained_by + +contained_by +~~~~~~~~~~~~ + +This is the inverse of the :lookup:`contains <hstorefield.contains>` lookup - +the objects returned will be those where the key-value pairs on the object are +a subset of those in the value passed. It uses the SQL operator ``<@``. For +example:: + + >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador', 'owner': 'Bob'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + >>> Dog.objects.create(name='Fred', data={}) + + >>> Dog.objects.filter(data__contained_by={'breed': 'collie', 'owner': 'Bob'}) + [<Dog: Meg>, <Dog: Fred>] + + >>> Dog.objects.filter(data__contained_by={'breed': 'collie'}) + [<Dog: Fred>] + +.. fieldlookup:: hstorefield.has_key + +has_key +~~~~~~~ + +Returns objects where the given key is in the data. Uses the SQL operator +``?``. For example:: + + >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + + >>> Dog.objects.filter(data__has_key='owner') + [<Dog: Meg>] + +.. fieldlookup:: hstorefield.has_keys + +has_keys +~~~~~~~~ + +Returns objects where all of the given keys are in the data. Uses the SQL operator +``?&``. For example:: + + >>> Dog.objects.create(name='Rufus', data={}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + + >>> Dog.objects.filter(data__has_keys=['breed', 'owner']) + [<Dog: Meg>] + +.. fieldlookup:: hstorefield.keys + +keys +~~~~ + +Returns objects where the array of keys is the given value. Note that the order +is not guaranteed to be reliable, so this transform is mainly useful for using +in conjunction with lookups on +:class:`~django.contrib.postgres.fields.ArrayField`. Uses the SQL function +``akeys()``. For example:: + + >>> Dog.objects.create(name='Rufus', data={'toy': 'bone'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + + >>> Dog.objects.filter(data__keys__overlap=['breed', 'toy']) + [<Dog: Rufus>, <Dog: Meg>] + +.. fieldlookup:: hstorefield.values + +values +~~~~~~ + +Returns objects where the array of values is the given value. Note that the +order is not guaranteed to be reliable, so this transform is mainly useful for +using in conjunction with lookups on +:class:`~django.contrib.postgres.fields.ArrayField`. Uses the SQL function +``avalues()``. For example:: + + >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'}) + >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'}) + + >>> Dog.objects.filter(data__values__contains=['collie']) + [<Dog: Meg>] diff --git a/docs/ref/contrib/postgres/forms.txt b/docs/ref/contrib/postgres/forms.txt index 6cad537f3b..261cc7f4f4 100644 --- a/docs/ref/contrib/postgres/forms.txt +++ b/docs/ref/contrib/postgres/forms.txt @@ -133,3 +133,23 @@ SplitArrayField ['1', '2', ''] # -> [1, 2] ['1', '', '3'] # -> [1, None, 3] ['', '2', ''] # -> [None, 2] + +HStoreField +----------- + +.. class:: HStoreField + + A field which accepts JSON encoded data for an + :class:`~django.contrib.postgres.fields.HStoreField`. It will cast all the + values to strings. It is represented by an HTML ``<textarea>``. + + .. admonition:: User friendly forms + + ``HStoreField`` is not particularly user friendly in most cases, + however it is a useful way to format data from a client-side widget for + submission to the server. + + .. note:: + On occasions it may be useful to require or restrict the keys which are + valid for a given field. This can be done using the + :class:`~django.contrib.postgres.validators.KeysValidator`. diff --git a/docs/ref/contrib/postgres/index.txt b/docs/ref/contrib/postgres/index.txt index 31969222ca..b23db125f2 100644 --- a/docs/ref/contrib/postgres/index.txt +++ b/docs/ref/contrib/postgres/index.txt @@ -26,3 +26,5 @@ a number of PostgreSQL specific data types. fields forms + operations + validators diff --git a/docs/ref/contrib/postgres/operations.txt b/docs/ref/contrib/postgres/operations.txt new file mode 100644 index 0000000000..4b9b7f5c44 --- /dev/null +++ b/docs/ref/contrib/postgres/operations.txt @@ -0,0 +1,27 @@ +Database migration operations +============================= + +All of these :doc:`operations </ref/migration-operations>` are available from +the ``django.contrib.postgres.operations`` module. + +.. currentmodule:: django.contrib.postgres.operations + +CreateExtension +--------------- + +.. class:: CreateExtension(name) + + An ``Operation`` subclass which installs PostgreSQL extensions. + + .. attribute:: name + + This is a required argument. The name of the extension to be installed. + +HStoreExtension +--------------- + +.. class:: HStoreExtension() + + A subclass of :class:`~django.contrib.postgres.operations.CreateExtension` + which will install the ``hstore`` extension and also immediately set up the + connection to interpret hstore data. diff --git a/docs/ref/contrib/postgres/validators.txt b/docs/ref/contrib/postgres/validators.txt new file mode 100644 index 0000000000..76cd52510a --- /dev/null +++ b/docs/ref/contrib/postgres/validators.txt @@ -0,0 +1,20 @@ +========== +Validators +========== + +.. module:: django.contrib.postgres.validators + +``KeysValidator`` +----------------- + +.. class:: KeysValidator(keys, strict=False, messages=None) + + Validates that the given keys are contained in the value. If ``strict`` is + ``True``, then it also checks that there are no other keys present. + + The ``messages`` passed should be a dict containing the keys + ``missing_keys`` and/or ``extra_keys``. + + .. note:: + Note that this checks only for the existence of a given key, not that + the value of a key is non-empty. diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 3cf7e77038..da2e5ae466 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -35,6 +35,14 @@ site. .. _django-secure: https://pypi.python.org/pypi/django-secure +New PostgreSQL specific functionality +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Django now has a module with extensions for PostgreSQL specific features, such +as :class:`~django.contrib.postgres.fields.ArrayField` and +:class:`~django.contrib.postgres.fields.HStoreField`. A full breakdown of the +features is available :doc:`in the documentation</ref/contrib/postgres/index>`. + New data types ~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/migrations/0001_setup_extensions.py b/tests/postgres_tests/migrations/0001_setup_extensions.py new file mode 100644 index 0000000000..e69a9a00df --- /dev/null +++ b/tests/postgres_tests/migrations/0001_setup_extensions.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.contrib.postgres.operations import HStoreExtension +from django.db import models, migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ] + + operations = [ + HStoreExtension(), + ] diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py new file mode 100644 index 0000000000..073e62b1d3 --- /dev/null +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations +import django.contrib.postgres.fields +import django.contrib.postgres.fields.hstore + + +class Migration(migrations.Migration): + + dependencies = [ + ('postgres_tests', '0001_setup_extensions'), + ] + + operations = [ + migrations.CreateModel( + name='CharArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.ArrayField(models.CharField(max_length=10), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='DateTimeArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='HStoreModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.hstore.HStoreField(blank=True, null=True)), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='IntegerArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='NestedIntegerArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.ArrayField(django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='NullableIntegerArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None, null=True, blank=True)), + ], + options={ + }, + bases=(models.Model,), + ), + ] diff --git a/tests/postgres_tests/migrations/__init__.py b/tests/postgres_tests/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 6420ebe1cd..cadab474af 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -1,4 +1,4 @@ -from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.fields import ArrayField, HStoreField from django.db import models @@ -20,3 +20,7 @@ class DateTimeArrayModel(models.Model): class NestedIntegerArrayModel(models.Model): field = ArrayField(ArrayField(models.IntegerField())) + + +class HStoreModel(models.Model): + field = HStoreField(blank=True, null=True) diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py new file mode 100644 index 0000000000..ac22f322b5 --- /dev/null +++ b/tests/postgres_tests/test_hstore.py @@ -0,0 +1,218 @@ +import json +import unittest + +from django.contrib.postgres import forms +from django.contrib.postgres.fields import HStoreField +from django.contrib.postgres.validators import KeysValidator +from django.core import exceptions, serializers +from django.db import connection +from django.test import TestCase + +from .models import HStoreModel + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class SimpleTests(TestCase): + apps = ['django.contrib.postgres'] + + def test_save_load_success(self): + value = {'a': 'b'} + instance = HStoreModel(field=value) + instance.save() + reloaded = HStoreModel.objects.get() + self.assertEqual(reloaded.field, value) + + def test_null(self): + instance = HStoreModel(field=None) + instance.save() + reloaded = HStoreModel.objects.get() + self.assertEqual(reloaded.field, None) + + def test_value_null(self): + value = {'a': None} + instance = HStoreModel(field=value) + instance.save() + reloaded = HStoreModel.objects.get() + self.assertEqual(reloaded.field, value) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestQuerying(TestCase): + + def setUp(self): + self.objs = [ + HStoreModel.objects.create(field={'a': 'b'}), + HStoreModel.objects.create(field={'a': 'b', 'c': 'd'}), + HStoreModel.objects.create(field={'c': 'd'}), + HStoreModel.objects.create(field={}), + HStoreModel.objects.create(field=None), + ] + + def test_exact(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__exact={'a': 'b'}), + self.objs[:1] + ) + + def test_contained_by(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__contained_by={'a': 'b', 'c': 'd'}), + self.objs[:4] + ) + + def test_contains(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__contains={'a': 'b'}), + self.objs[:2] + ) + + def test_has_key(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__has_key='c'), + self.objs[1:3] + ) + + def test_has_keys(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__has_keys=['a', 'c']), + self.objs[1:2] + ) + + def test_key_transform(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__a='b'), + self.objs[:2] + ) + + def test_keys(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__keys=['a']), + self.objs[:1] + ) + + def test_values(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__values=['b']), + self.objs[:1] + ) + + def test_field_chaining(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__a__contains='b'), + self.objs[:2] + ) + + def test_keys_contains(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__keys__contains=['a']), + self.objs[:2] + ) + + def test_values_overlap(self): + self.assertSequenceEqual( + HStoreModel.objects.filter(field__values__overlap=['b', 'd']), + self.objs[:3] + ) + + +@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') +class TestSerialization(TestCase): + test_data = '[{"fields": {"field": "{\\"a\\": \\"b\\"}"}, "model": "postgres_tests.hstoremodel", "pk": null}]' + + def test_dumping(self): + instance = HStoreModel(field={'a': 'b'}) + data = serializers.serialize('json', [instance]) + self.assertEqual(json.loads(data), json.loads(self.test_data)) + + def test_loading(self): + instance = list(serializers.deserialize('json', self.test_data))[0].object + self.assertEqual(instance.field, {'a': 'b'}) + + +class TestValidation(TestCase): + + def test_not_a_string(self): + field = HStoreField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean({'a': 1}, None) + self.assertEqual(cm.exception.code, 'not_a_string') + self.assertEqual(cm.exception.message % cm.exception.params, 'The value of "a" is not a string.') + + +class TestFormField(TestCase): + + def test_valid(self): + field = forms.HStoreField() + value = field.clean('{"a": "b"}') + self.assertEqual(value, {'a': 'b'}) + + def test_invalid_json(self): + field = forms.HStoreField() + with self.assertRaises(exceptions.ValidationError) as cm: + field.clean('{"a": "b"') + self.assertEqual(cm.exception.messages[0], 'Could not load JSON data.') + self.assertEqual(cm.exception.code, 'invalid_json') + + def test_not_string_values(self): + field = forms.HStoreField() + value = field.clean('{"a": 1}') + self.assertEqual(value, {'a': '1'}) + + def test_empty(self): + field = forms.HStoreField(required=False) + value = field.clean('') + self.assertEqual(value, {}) + + def test_model_field_formfield(self): + model_field = HStoreField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, forms.HStoreField) + + +class TestValidator(TestCase): + + def test_simple_valid(self): + validator = KeysValidator(keys=['a', 'b']) + validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) + + def test_missing_keys(self): + validator = KeysValidator(keys=['a', 'b']) + with self.assertRaises(exceptions.ValidationError) as cm: + validator({'a': 'foo', 'c': 'baz'}) + self.assertEqual(cm.exception.messages[0], 'Some keys were missing: b') + self.assertEqual(cm.exception.code, 'missing_keys') + + def test_strict_valid(self): + validator = KeysValidator(keys=['a', 'b'], strict=True) + validator({'a': 'foo', 'b': 'bar'}) + + def test_extra_keys(self): + validator = KeysValidator(keys=['a', 'b'], strict=True) + with self.assertRaises(exceptions.ValidationError) as cm: + validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) + self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c') + self.assertEqual(cm.exception.code, 'extra_keys') + + def test_custom_messages(self): + messages = { + 'missing_keys': 'Foobar', + } + validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages) + with self.assertRaises(exceptions.ValidationError) as cm: + validator({'a': 'foo', 'c': 'baz'}) + self.assertEqual(cm.exception.messages[0], 'Foobar') + self.assertEqual(cm.exception.code, 'missing_keys') + with self.assertRaises(exceptions.ValidationError) as cm: + validator({'a': 'foo', 'b': 'bar', 'c': 'baz'}) + self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c') + self.assertEqual(cm.exception.code, 'extra_keys') + + def test_deconstruct(self): + messages = { + 'missing_keys': 'Foobar', + } + validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages) + path, args, kwargs = validator.deconstruct() + self.assertEqual(path, 'django.contrib.postgres.validators.KeysValidator') + self.assertEqual(args, ()) + self.assertEqual(kwargs, {'keys': ['a', 'b'], 'strict': True, 'messages': messages}) diff --git a/tests/runtests.py b/tests/runtests.py index ed14bdb666..8f612c80ec 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -78,7 +78,7 @@ def get_test_modules(): os.path.isfile(f) or not os.path.exists(os.path.join(dirpath, f, '__init__.py'))): continue - if not connection.vendor == 'postgresql' and f == 'postgres_tests': + if not connection.vendor == 'postgresql' and f == 'postgres_tests' or f == 'postgres': continue modules.append((modpath, f)) return modules