diff --git a/django/contrib/postgres/__init__.py b/django/contrib/postgres/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/django/contrib/postgres/fields/__init__.py b/django/contrib/postgres/fields/__init__.py
new file mode 100644
index 0000000000..e3ceebd62c
--- /dev/null
+++ b/django/contrib/postgres/fields/__init__.py
@@ -0,0 +1 @@
+from .array import * # NOQA
diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py
new file mode 100644
index 0000000000..7a37267400
--- /dev/null
+++ b/django/contrib/postgres/fields/array.py
@@ -0,0 +1,254 @@
+import json
+
+from django.contrib.postgres.forms import SimpleArrayField
+from django.contrib.postgres.validators import ArrayMaxLengthValidator
+from django.core import checks, exceptions
+from django.db.models import Field, Lookup, Transform, IntegerField
+from django.utils import six
+from django.utils.translation import string_concat, ugettext_lazy as _
+
+
+__all__ = ['ArrayField']
+
+
+class AttributeSetter(object):
+ def __init__(self, name, value):
+ setattr(self, name, value)
+
+
+class ArrayField(Field):
+ empty_strings_allowed = False
+ default_error_messages = {
+ 'item_invalid': _('Item %(nth)s in the array did not validate: '),
+ 'nested_array_mismatch': _('Nested arrays must have the same length.'),
+ }
+
+ def __init__(self, base_field, size=None, **kwargs):
+ self.base_field = base_field
+ self.size = size
+ if self.size:
+ self.default_validators = self.default_validators[:]
+ self.default_validators.append(ArrayMaxLengthValidator(self.size))
+ super(ArrayField, self).__init__(**kwargs)
+
+ def check(self, **kwargs):
+ errors = super(ArrayField, self).check(**kwargs)
+ if self.base_field.rel:
+ errors.append(
+ checks.Error(
+ 'Base field for array cannot be a related field.',
+ hint=None,
+ obj=self,
+ id='postgres.E002'
+ )
+ )
+ else:
+ # Remove the field name checks as they are not needed here.
+ base_errors = self.base_field.check()
+ if base_errors:
+ messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
+ errors.append(
+ checks.Error(
+ 'Base field for array has errors:\n %s' % messages,
+ hint=None,
+ obj=self,
+ id='postgres.E001'
+ )
+ )
+ return errors
+
+ def set_attributes_from_name(self, name):
+ super(ArrayField, self).set_attributes_from_name(name)
+ self.base_field.set_attributes_from_name(name)
+
+ @property
+ def description(self):
+ return 'Array of %s' % self.base_field.description
+
+ def db_type(self, connection):
+ size = self.size or ''
+ return '%s[%s]' % (self.base_field.db_type(connection), size)
+
+ def get_prep_value(self, value):
+ if isinstance(value, list) or isinstance(value, tuple):
+ return [self.base_field.get_prep_value(i) for i in value]
+ return value
+
+ def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
+ if lookup_type == 'contains':
+ return [self.get_prep_value(value)]
+ return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
+ connection, prepared=False)
+
+ def deconstruct(self):
+ name, path, args, kwargs = super(ArrayField, self).deconstruct()
+ path = 'django.contrib.postgres.fields.ArrayField'
+ args.insert(0, self.base_field)
+ kwargs['size'] = self.size
+ return name, path, args, kwargs
+
+ def to_python(self, value):
+ if isinstance(value, six.string_types):
+ # Assume we're deserializing
+ vals = json.loads(value)
+ value = [self.base_field.to_python(val) for val in vals]
+ return value
+
+ def value_to_string(self, obj):
+ values = []
+ vals = self._get_val_from_obj(obj)
+ base_field = self.base_field
+
+ for val in vals:
+ obj = AttributeSetter(base_field.attname, val)
+ values.append(base_field.value_to_string(obj))
+ return json.dumps(values)
+
+ def get_transform(self, name):
+ transform = super(ArrayField, self).get_transform(name)
+ if transform:
+ return transform
+ try:
+ index = int(name)
+ except ValueError:
+ pass
+ else:
+ index += 1 # postgres uses 1-indexing
+ return IndexTransformFactory(index, self.base_field)
+ try:
+ start, end = name.split('_')
+ start = int(start) + 1
+ end = int(end) # don't add one here because postgres slices are weird
+ except ValueError:
+ pass
+ else:
+ return SliceTransformFactory(start, end)
+
+ def validate(self, value, model_instance):
+ super(ArrayField, self).validate(value, model_instance)
+ for i, part in enumerate(value):
+ try:
+ self.base_field.validate(part, model_instance)
+ except exceptions.ValidationError as e:
+ raise exceptions.ValidationError(
+ string_concat(self.error_messages['item_invalid'], e.message),
+ code='item_invalid',
+ params={'nth': i},
+ )
+ if isinstance(self.base_field, ArrayField):
+ if len({len(i) for i in value}) > 1:
+ raise exceptions.ValidationError(
+ self.error_messages['nested_array_mismatch'],
+ code='nested_array_mismatch',
+ )
+
+ def formfield(self, **kwargs):
+ defaults = {
+ 'form_class': SimpleArrayField,
+ 'base_field': self.base_field.formfield(),
+ 'max_length': self.size,
+ }
+ defaults.update(kwargs)
+ return super(ArrayField, self).formfield(**defaults)
+
+
+class ArrayContainsLookup(Lookup):
+ lookup_name = 'contains'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params = lhs_params + rhs_params
+ return '%s @> %s' % (lhs, rhs), params
+
+
+ArrayField.register_lookup(ArrayContainsLookup)
+
+
+class ArrayContainedByLookup(Lookup):
+ lookup_name = 'contained_by'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params = lhs_params + rhs_params
+ return '%s <@ %s' % (lhs, rhs), params
+
+
+ArrayField.register_lookup(ArrayContainedByLookup)
+
+
+class ArrayOverlapLookup(Lookup):
+ lookup_name = 'overlap'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params = lhs_params + rhs_params
+ return '%s && %s' % (lhs, rhs), params
+
+
+ArrayField.register_lookup(ArrayOverlapLookup)
+
+
+class ArrayLenTransform(Transform):
+ lookup_name = 'len'
+
+ @property
+ def output_type(self):
+ return IntegerField()
+
+ def as_sql(self, qn, connection):
+ lhs, params = qn.compile(self.lhs)
+ return 'array_length(%s, 1)' % lhs, params
+
+
+ArrayField.register_lookup(ArrayLenTransform)
+
+
+class IndexTransform(Transform):
+
+ def __init__(self, index, base_field, *args, **kwargs):
+ super(IndexTransform, self).__init__(*args, **kwargs)
+ self.index = index
+ self.base_field = base_field
+
+ def as_sql(self, qn, connection):
+ lhs, params = qn.compile(self.lhs)
+ return '%s[%s]' % (lhs, self.index), params
+
+ @property
+ def output_type(self):
+ return self.base_field
+
+
+class IndexTransformFactory(object):
+
+ def __init__(self, index, base_field):
+ self.index = index
+ self.base_field = base_field
+
+ def __call__(self, *args, **kwargs):
+ return IndexTransform(self.index, self.base_field, *args, **kwargs)
+
+
+class SliceTransform(Transform):
+
+ def __init__(self, start, end, *args, **kwargs):
+ super(SliceTransform, self).__init__(*args, **kwargs)
+ self.start = start
+ self.end = end
+
+ def as_sql(self, qn, connection):
+ lhs, params = qn.compile(self.lhs)
+ return '%s[%s:%s]' % (lhs, self.start, self.end), params
+
+
+class SliceTransformFactory(object):
+
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ def __call__(self, *args, **kwargs):
+ return SliceTransform(self.start, self.end, *args, **kwargs)
diff --git a/django/contrib/postgres/forms/__init__.py b/django/contrib/postgres/forms/__init__.py
new file mode 100644
index 0000000000..e3ceebd62c
--- /dev/null
+++ b/django/contrib/postgres/forms/__init__.py
@@ -0,0 +1 @@
+from .array import * # NOQA
diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py
new file mode 100644
index 0000000000..620c7c7b6e
--- /dev/null
+++ b/django/contrib/postgres/forms/array.py
@@ -0,0 +1,185 @@
+import copy
+
+from django.contrib.postgres.validators import ArrayMinLengthValidator, ArrayMaxLengthValidator
+from django.core.exceptions import ValidationError
+from django import forms
+from django.utils.safestring import mark_safe
+from django.utils import six
+from django.utils.translation import string_concat, ugettext_lazy as _
+
+
+class SimpleArrayField(forms.CharField):
+ default_error_messages = {
+ 'item_invalid': _('Item %(nth)s in the array did not validate: '),
+ }
+
+ def __init__(self, base_field, delimiter=',', max_length=None, min_length=None, *args, **kwargs):
+ self.base_field = base_field
+ self.delimiter = delimiter
+ super(SimpleArrayField, self).__init__(*args, **kwargs)
+ if min_length is not None:
+ self.min_length = min_length
+ self.validators.append(ArrayMinLengthValidator(int(min_length)))
+ if max_length is not None:
+ self.max_length = max_length
+ self.validators.append(ArrayMaxLengthValidator(int(max_length)))
+
+ def prepare_value(self, value):
+ if isinstance(value, list):
+ return self.delimiter.join([six.text_type(self.base_field.prepare_value(v)) for v in value])
+ return value
+
+ def to_python(self, value):
+ if value:
+ items = value.split(self.delimiter)
+ else:
+ items = []
+ errors = []
+ values = []
+ for i, item in enumerate(items):
+ try:
+ values.append(self.base_field.to_python(item))
+ except ValidationError as e:
+ for error in e.error_list:
+ errors.append(ValidationError(
+ string_concat(self.error_messages['item_invalid'], error.message),
+ code='item_invalid',
+ params={'nth': i},
+ ))
+ if errors:
+ raise ValidationError(errors)
+ return values
+
+ def validate(self, value):
+ super(SimpleArrayField, self).validate(value)
+ errors = []
+ for i, item in enumerate(value):
+ try:
+ self.base_field.validate(item)
+ except ValidationError as e:
+ for error in e.error_list:
+ errors.append(ValidationError(
+ string_concat(self.error_messages['item_invalid'], error.message),
+ code='item_invalid',
+ params={'nth': i},
+ ))
+ if errors:
+ raise ValidationError(errors)
+
+ def run_validators(self, value):
+ super(SimpleArrayField, self).run_validators(value)
+ errors = []
+ for i, item in enumerate(value):
+ try:
+ self.base_field.run_validators(item)
+ except ValidationError as e:
+ for error in e.error_list:
+ errors.append(ValidationError(
+ string_concat(self.error_messages['item_invalid'], error.message),
+ code='item_invalid',
+ params={'nth': i},
+ ))
+ if errors:
+ raise ValidationError(errors)
+
+
+class SplitArrayWidget(forms.Widget):
+
+ def __init__(self, widget, size, **kwargs):
+ self.widget = widget() if isinstance(widget, type) else widget
+ self.size = size
+ super(SplitArrayWidget, self).__init__(**kwargs)
+
+ @property
+ def is_hidden(self):
+ return self.widget.is_hidden
+
+ def value_from_datadict(self, data, files, name):
+ return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index))
+ for index in range(self.size)]
+
+ def id_for_label(self, id_):
+ # See the comment for RadioSelect.id_for_label()
+ if id_:
+ id_ += '_0'
+ return id_
+
+ def render(self, name, value, attrs=None):
+ if self.is_localized:
+ self.widget.is_localized = self.is_localized
+ value = value or []
+ output = []
+ final_attrs = self.build_attrs(attrs)
+ id_ = final_attrs.get('id', None)
+ for i in range(max(len(value), self.size)):
+ try:
+ widget_value = value[i]
+ except IndexError:
+ widget_value = None
+ if id_:
+ final_attrs = dict(final_attrs, id='%s_%s' % (id_, i))
+ output.append(self.widget.render(name + '_%s' % i, widget_value, final_attrs))
+ return mark_safe(self.format_output(output))
+
+ def format_output(self, rendered_widgets):
+ return ''.join(rendered_widgets)
+
+ @property
+ def media(self):
+ return self.widget.media
+
+ def __deepcopy__(self, memo):
+ obj = super(SplitArrayWidget, self).__deepcopy__(memo)
+ obj.widget = copy.deepcopy(self.widget)
+ return obj
+
+ @property
+ def needs_multipart_form(self):
+ return self.widget.needs_multipart_form
+
+
+class SplitArrayField(forms.Field):
+ default_error_messages = {
+ 'item_invalid': _('Item %(nth)s in the array did not validate: '),
+ }
+
+ def __init__(self, base_field, size, remove_trailing_nulls=False, **kwargs):
+ self.base_field = base_field
+ self.size = size
+ self.remove_trailing_nulls = remove_trailing_nulls
+ widget = SplitArrayWidget(widget=base_field.widget, size=size)
+ kwargs.setdefault('widget', widget)
+ super(SplitArrayField, self).__init__(**kwargs)
+
+ def clean(self, value):
+ cleaned_data = []
+ errors = []
+ if not any(value) and self.required:
+ raise ValidationError(self.error_messages['required'])
+ max_size = max(self.size, len(value))
+ for i in range(max_size):
+ item = value[i]
+ try:
+ cleaned_data.append(self.base_field.clean(item))
+ errors.append(None)
+ except ValidationError as error:
+ errors.append(ValidationError(
+ string_concat(self.error_messages['item_invalid'], error.message),
+ code='item_invalid',
+ params={'nth': i},
+ ))
+ cleaned_data.append(None)
+ if self.remove_trailing_nulls:
+ null_index = None
+ for i, value in reversed(list(enumerate(cleaned_data))):
+ if value in self.base_field.empty_values:
+ null_index = i
+ else:
+ break
+ if null_index:
+ cleaned_data = cleaned_data[:null_index]
+ errors = errors[:null_index]
+ errors = list(filter(None, errors))
+ if errors:
+ raise ValidationError(errors)
+ return cleaned_data
diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py
new file mode 100644
index 0000000000..353305949e
--- /dev/null
+++ b/django/contrib/postgres/validators.py
@@ -0,0 +1,16 @@
+from django.core.validators import MaxLengthValidator, MinLengthValidator
+from django.utils.translation import ungettext_lazy
+
+
+class ArrayMaxLengthValidator(MaxLengthValidator):
+ message = ungettext_lazy(
+ 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.',
+ 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.',
+ 'limit_value')
+
+
+class ArrayMinLengthValidator(MinLengthValidator):
+ message = ungettext_lazy(
+ 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.',
+ 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.',
+ 'limit_value')
diff --git a/docs/index.txt b/docs/index.txt
index e0529c4503..b41f5c0ecb 100644
--- a/docs/index.txt
+++ b/docs/index.txt
@@ -91,7 +91,8 @@ manipulating the data of your Web application. Learn more about it below:
:doc:`Supported databases [` |
:doc:`Legacy databases ` |
:doc:`Providing initial data ` |
- :doc:`Optimize database access `
+ :doc:`Optimize database access ` |
+ :doc:`PostgreSQL specific features ][`
The view layer
==============
diff --git a/docs/ref/contrib/index.txt b/docs/ref/contrib/index.txt
index 533680659e..ebfc2874b4 100644
--- a/docs/ref/contrib/index.txt
+++ b/docs/ref/contrib/index.txt
@@ -31,6 +31,7 @@ those packages have.
gis/index
humanize
messages
+ postgres/index
redirects
sitemaps
sites
@@ -122,6 +123,13 @@ messages
See the :doc:`messages documentation ]`.
+postgres
+========
+
+A collection of PostgreSQL specific features.
+
+See the :doc:`contrib.postgres documentation `.
+
redirects
=========
diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt
new file mode 100644
index 0000000000..dcde84d2ec
--- /dev/null
+++ b/docs/ref/contrib/postgres/fields.txt
@@ -0,0 +1,228 @@
+PostgreSQL specific model fields
+================================
+
+All of these fields are available from the ``django.contrib.postgres.fields``
+module.
+
+.. currentmodule:: django.contrib.postgres.fields
+
+ArrayField
+----------
+
+.. class:: ArrayField(base_field, size=None, **options)
+
+ A field for storing lists of data. Most field types can be used, you simply
+ pass another field instance as the :attr:`base_field
+ `. You may also specify a :attr:`size
+ `. ``ArrayField`` can be nested to store multi-dimensional
+ arrays.
+
+ .. attribute:: base_field
+
+ This is a required argument.
+
+ Specifies the underlying data type and behaviour for the array. It
+ should be an instance of a subclass of
+ :class:`~django.db.models.Field`. For example, it could be an
+ :class:`~django.db.models.IntegerField` or a
+ :class:`~django.db.models.CharField`. Most field types are permitted,
+ with the exception of those handling relational data
+ (:class:`~django.db.models.ForeignKey`,
+ :class:`~django.db.models.OneToOneField` and
+ :class:`~django.db.models.ManyToManyField`).
+
+ It is possible to nest array fields - you can specify an instance of
+ ``ArrayField`` as the ``base_field``. For example::
+
+ from django.db import models
+ from django.contrib.postgres.fields import ArrayField
+
+ class ChessBoard(models.Model):
+ board = ArrayField(
+ ArrayField(
+ CharField(max_length=10, blank=True, null=True),
+ size=8),
+ size=8)
+
+ Transformation of values between the database and the model, validation
+ of data and configuration, and serialization are all delegated to the
+ underlying base field.
+
+ .. attribute:: size
+
+ This is an optional argument.
+
+ If passed, the array will have a maximum size as specified. This will
+ be passed to the database, although PostgreSQL at present does not
+ enforce the restriction.
+
+.. note::
+
+ When nesting ``ArrayField``, whether you use the `size` parameter or not,
+ PostgreSQL requires that the arrays are rectangular::
+
+ from django.db import models
+ from django.contrib.postgres.fields import ArrayField
+
+ class Board(models.Model):
+ pieces = ArrayField(ArrayField(models.IntegerField()))
+
+ # Valid
+ Board(pieces=[
+ [2, 3],
+ [2, 1],
+ ])
+
+ # Not valid
+ Board(pieces=[
+ [2, 3],
+ [2],
+ ])
+
+ If irregular shapes are required, then the underlying field should be made
+ nullable and the values padded with ``None``.
+
+Querying ArrayField
+^^^^^^^^^^^^^^^^^^^
+
+There are a number of custom lookups and transforms for :class:`ArrayField`.
+We will use the following example model::
+
+ from django.db import models
+ from django.contrib.postgres.fields import ArrayField
+
+ class Post(models.Model):
+ name = models.CharField(max_length=200)
+ tags = ArrayField(models.CharField(max_length=200), blank=True)
+
+ def __str__(self): # __unicode__ on python 2
+ return self.name
+
+.. fieldlookup:: arrayfield.contains
+
+contains
+~~~~~~~~
+
+The :lookup:`contains` lookup is overridden on :class:`ArrayField`. The
+returned objects will be those where the values passed are a subset of the
+data. It uses the SQL operator ``@>``. For example::
+
+ >>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
+ >>> Post.objects.create(name='Second post', tags=['thoughts'])
+ >>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
+
+ >>> Post.objects.filter(tags__contains=['thoughts'])
+ [, ]
+
+ >>> Post.objects.filter(tags__contains=['django'])
+ [, ]
+
+ >>> Post.objects.filter(tags__contains=['django', 'thoughts'])
+ []
+
+.. fieldlookup:: arrayfield.contained_by
+
+contained_by
+~~~~~~~~~~~~
+
+This is the inverse of the :lookup:`contains ` lookup -
+the objects returned will be those where the data is a subset of the values
+passed. It uses the SQL operator ``<@``. For example::
+
+ >>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
+ >>> Post.objects.create(name='Second post', tags=['thoughts'])
+ >>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
+
+ >>> Post.objects.filter(tags__contained_by=['thoughts', 'django'])
+ []
+
+ >>> Post.objects.filter(tags__contained_by=['thoughts', 'django', 'tutorial'])
+ [, , ]
+
+.. fieldlookup:: arrayfield.overlap
+
+overlap
+~~~~~~~
+
+Returns objects where the data shares any results with the values passed. Uses
+the SQL operator ``&&``. For example::
+
+ >>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
+ >>> Post.objects.create(name='Second post', tags=['thoughts'])
+ >>> Post.objects.create(name='Third post', tags=['tutorial', 'django'])
+
+ >>> Post.objects.filter(tags__overlap=['thoughts'])
+ [, ]
+
+ >>> Post.objects.filter(tags__overlap=['thoughts', 'tutorial'])
+ [, , ]
+
+.. fieldlookup:: arrayfield.index
+
+Index transforms
+~~~~~~~~~~~~~~~~
+
+This class of transforms allows you to index into the array in queries. Any
+non-negative integer can be used. There are no errors if it exceeds the
+:attr:`size ` of the array. The lookups available after the
+transform are those from the :attr:`base_field `. For
+example::
+
+ >>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
+ >>> Post.objects.create(name='Second post', tags=['thoughts'])
+
+ >>> Post.objects.filter(tags__0='thoughts')
+ [, ]
+
+ >>> Post.objects.filter(tags__1__iexact='Django')
+ []
+
+ >>> Post.objects.filter(tags__276='javascript')
+ []
+
+.. note::
+
+ PostgreSQL uses 1-based indexing for array fields when writing raw SQL.
+ However these indexes and those used in :lookup:`slices `
+ use 0-based indexing to be consistent with Python.
+
+.. fieldlookup:: arrayfield.slice
+
+Slice transforms
+~~~~~~~~~~~~~~~~
+
+This class of transforms allow you to take a slice of the array. Any two
+non-negative integers can be used, separated by a single underscore. The
+lookups available after the transform do not change. For example::
+
+ >>> Post.objects.create(name='First post', tags=['thoughts', 'django'])
+ >>> Post.objects.create(name='Second post', tags=['thoughts'])
+ >>> Post.objects.create(name='Third post', tags=['django', 'python', 'thoughts'])
+
+ >>> Post.objects.filter(tags__0_1=['thoughts'])
+ []
+
+ >>> Post.objects.filter(tags__0_2__contains='thoughts')
+ [, ]
+
+.. note::
+
+ PostgreSQL uses 1-based indexing for array fields when writing raw SQL.
+ However these slices and those used in :lookup:`indexes `
+ use 0-based indexing to be consistent with Python.
+
+.. admonition:: Multidimensional arrays with indexes and slices
+
+ PostgreSQL has some rather esoteric behaviour when using indexes and slices
+ on multidimensional arrays. It will always work to use indexes to reach
+ down to the final underlying data, but most other slices behave strangely
+ at the database level and cannot be supported in a logical, consistent
+ fashion by Django.
+
+Indexing ArrayField
+^^^^^^^^^^^^^^^^^^^
+
+At present using :attr:`~django.db.models.Field.db_index` will create a
+``btree`` index. This does not offer particularly significant help to querying.
+A more useful index is a ``GIN`` index, which you should create using a
+:class:`~django.db.migrations.operations.RunSQL` operation.
diff --git a/docs/ref/contrib/postgres/forms.txt b/docs/ref/contrib/postgres/forms.txt
new file mode 100644
index 0000000000..6cad537f3b
--- /dev/null
+++ b/docs/ref/contrib/postgres/forms.txt
@@ -0,0 +1,135 @@
+PostgreSQL specific form fields and widgets
+===========================================
+
+All of these fields and widgets are available from the
+``django.contrib.postgres.forms`` module.
+
+.. currentmodule:: django.contrib.postgres.forms
+
+SimpleArrayField
+----------------
+
+.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None)
+
+ A simple field which maps to an array. It is represented by an HTML
+ ````.
+
+ .. attribute:: base_field
+
+ This is a required argument.
+
+ It specifies the underlying form field for the array. This is not used
+ to render any HTML, but it is used to process the submitted data and
+ validate it. For example::
+
+ >>> from django.contrib.postgres.forms import SimpleArrayField
+ >>> from django import forms
+
+ >>> class NumberListForm(forms.Form):
+ ... numbers = SimpleArrayField(forms.IntegerField())
+
+ >>> form = NumberListForm({'numbers': '1,2,3'})
+ >>> form.is_valid()
+ True
+ >>> form.cleaned_data
+ {'numbers': [1, 2, 3]}
+
+ >>> form = NumberListForm({'numbers': '1,2,a'})
+ >>> form.is_valid()
+ False
+
+ .. attribute:: delimiter
+
+ This is an optional argument which defaults to a comma: ``,``. This
+ value is used to split the submitted data. It allows you to chain
+ ``SimpleArrayField`` for multidimensional data::
+
+ >>> from django.contrib.postgres.forms import SimpleArrayField
+ >>> from django import forms
+
+ >>> class GridForm(forms.Form):
+ ... places = SimpleArrayField(SimpleArrayField(IntegerField()), delimiter='|')
+
+ >>> form = GridForm({'places': '1,2|2,1|4,3'})
+ >>> form.is_valid()
+ True
+ >>> form.cleaned_data
+ {'places': [[1, 2], [2, 1], [4, 3]]}
+
+ .. note::
+
+ The field does not support escaping of the delimiter, so be careful
+ in cases where the delimiter is a valid character in the underlying
+ field. The delimiter does not need to be only one character.
+
+ .. attribute:: max_length
+
+ This is an optional argument which validates that the array does not
+ exceed the stated length.
+
+ .. attribute:: min_length
+
+ This is an optional argument which validates that the array reaches at
+ least the stated length.
+
+ .. admonition:: User friendly forms
+
+ ``SimpleArrayField`` is not particularly user friendly in most cases,
+ however it is a useful way to format data from a client-side widget for
+ submission to the server.
+
+SplitArrayField
+---------------
+
+.. class:: SplitArrayField(base_field, size, remove_trailing_nulls=False)
+
+ This field handles arrays by reproducing the underlying field a fixed
+ number of times.
+
+ .. attribute:: base_field
+
+ This is a required argument. It specifies the form field to be
+ repeated.
+
+ .. attribute:: size
+
+ This is the fixed number of times the underlying field will be used.
+
+ .. attribute:: remove_trailing_nulls
+
+ By default, this is set to ``False``. When ``False``, each value from
+ the repeated fields is stored. When set to ``True``, any trailing
+ values which are blank will be stripped from the result. If the
+ underlying field has ``required=True``, but ``remove_trailing_nulls``
+ is ``True``, then null values are only allowed at the end, and will be
+ stripped.
+
+ Some examples::
+
+ SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=False)
+
+ ['1', '2', '3'] # -> [1, 2, 3]
+ ['1', '2', ''] # -> ValidationError - third entry required.
+ ['1', '', '3'] # -> ValidationError - second entry required.
+ ['', '2', ''] # -> ValidationError - first and third entries required.
+
+ SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=False)
+
+ ['1', '2', '3'] # -> [1, 2, 3]
+ ['1', '2', ''] # -> [1, 2, None]
+ ['1', '', '3'] # -> [1, None, 3]
+ ['', '2', ''] # -> [None, 2, None]
+
+ SplitArrayField(IntegerField(required=True), size=3, remove_trailing_nulls=True)
+
+ ['1', '2', '3'] # -> [1, 2, 3]
+ ['1', '2', ''] # -> [1, 2]
+ ['1', '', '3'] # -> ValidationError - second entry required.
+ ['', '2', ''] # -> ValidationError - first entry required.
+
+ SplitArrayField(IntegerField(required=False), size=3, remove_trailing_nulls=True)
+
+ ['1', '2', '3'] # -> [1, 2, 3]
+ ['1', '2', ''] # -> [1, 2]
+ ['1', '', '3'] # -> [1, None, 3]
+ ['', '2', ''] # -> [None, 2]
diff --git a/docs/ref/contrib/postgres/index.txt b/docs/ref/contrib/postgres/index.txt
new file mode 100644
index 0000000000..5db4ab80ed
--- /dev/null
+++ b/docs/ref/contrib/postgres/index.txt
@@ -0,0 +1,28 @@
+``django.contrib.postgres``
+===========================
+
+PostgreSQL has a number of features which are not shared by the other databases
+Django supports. This optional module contains model fields and form fields for
+a number of PostgreSQL specific data types.
+
+.. note::
+ Django is, and will continue to be, a database-agnostic web framework. We
+ would encourage those writing reusable applications for the Django
+ community to write database-agnostic code where practical. However, we
+ recognise that real world projects written using Django need not be
+ database-agnostic. In fact, once a project reaches a given size changing
+ the underlying data store is already a significant challenge and is likely
+ to require changing the code base in some ways to handle differences
+ between the data stores.
+
+ Django provides support for a number of data types which will
+ only work with PostgreSQL. There is no fundamental reason why (for example)
+ a ``contrib.mysql`` module does not exist, except that PostgreSQL has the
+ richest feature set of the supported databases so its users have the most
+ to gain.
+
+.. toctree::
+ :maxdepth: 2
+
+ fields
+ forms
diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
new file mode 100644
index 0000000000..6420ebe1cd
--- /dev/null
+++ b/tests/postgres_tests/models.py
@@ -0,0 +1,22 @@
+from django.contrib.postgres.fields import ArrayField
+from django.db import models
+
+
+class IntegerArrayModel(models.Model):
+ field = ArrayField(models.IntegerField())
+
+
+class NullableIntegerArrayModel(models.Model):
+ field = ArrayField(models.IntegerField(), blank=True, null=True)
+
+
+class CharArrayModel(models.Model):
+ field = ArrayField(models.CharField(max_length=10))
+
+
+class DateTimeArrayModel(models.Model):
+ field = ArrayField(models.DateTimeField())
+
+
+class NestedIntegerArrayModel(models.Model):
+ field = ArrayField(ArrayField(models.IntegerField()))
diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py
new file mode 100644
index 0000000000..35ea65480a
--- /dev/null
+++ b/tests/postgres_tests/test_array.py
@@ -0,0 +1,389 @@
+import unittest
+
+from django.contrib.postgres.fields import ArrayField
+from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
+from django.core import exceptions, serializers
+from django.db import models, IntegrityError, connection
+from django.db.migrations.writer import MigrationWriter
+from django import forms
+from django.test import TestCase
+from django.utils import timezone
+
+from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class TestSaveLoad(TestCase):
+
+ def test_integer(self):
+ instance = IntegerArrayModel(field=[1, 2, 3])
+ instance.save()
+ loaded = IntegerArrayModel.objects.get()
+ self.assertEqual(instance.field, loaded.field)
+
+ def test_char(self):
+ instance = CharArrayModel(field=['hello', 'goodbye'])
+ instance.save()
+ loaded = CharArrayModel.objects.get()
+ self.assertEqual(instance.field, loaded.field)
+
+ def test_dates(self):
+ instance = DateTimeArrayModel(field=[timezone.now()])
+ instance.save()
+ loaded = DateTimeArrayModel.objects.get()
+ self.assertEqual(instance.field, loaded.field)
+
+ def test_tuples(self):
+ instance = IntegerArrayModel(field=(1,))
+ instance.save()
+ loaded = IntegerArrayModel.objects.get()
+ self.assertSequenceEqual(instance.field, loaded.field)
+
+ def test_integers_passed_as_strings(self):
+ # This checks that get_prep_value is deferred properly
+ instance = IntegerArrayModel(field=['1'])
+ instance.save()
+ loaded = IntegerArrayModel.objects.get()
+ self.assertEqual(loaded.field, [1])
+
+ def test_null_handling(self):
+ instance = NullableIntegerArrayModel(field=None)
+ instance.save()
+ loaded = NullableIntegerArrayModel.objects.get()
+ self.assertEqual(instance.field, loaded.field)
+
+ instance = IntegerArrayModel(field=None)
+ with self.assertRaises(IntegrityError):
+ instance.save()
+
+ def test_nested(self):
+ instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
+ instance.save()
+ loaded = NestedIntegerArrayModel.objects.get()
+ self.assertEqual(instance.field, loaded.field)
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class TestQuerying(TestCase):
+
+ def setUp(self):
+ self.objs = [
+ NullableIntegerArrayModel.objects.create(field=[1]),
+ NullableIntegerArrayModel.objects.create(field=[2]),
+ NullableIntegerArrayModel.objects.create(field=[2, 3]),
+ NullableIntegerArrayModel.objects.create(field=[20, 30, 40]),
+ NullableIntegerArrayModel.objects.create(field=None),
+ ]
+
+ def test_exact(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__exact=[1]),
+ self.objs[:1]
+ )
+
+ def test_isnull(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__isnull=True),
+ self.objs[-1:]
+ )
+
+ def test_gt(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__gt=[0]),
+ self.objs[:4]
+ )
+
+ def test_lt(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__lt=[2]),
+ self.objs[:1]
+ )
+
+ def test_in(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
+ self.objs[:2]
+ )
+
+ def test_contained_by(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
+ self.objs[:2]
+ )
+
+ def test_contains(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__contains=[2]),
+ self.objs[1:3]
+ )
+
+ def test_index(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__0=2),
+ self.objs[1:3]
+ )
+
+ def test_index_chained(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__0__lt=3),
+ self.objs[0:3]
+ )
+
+ def test_index_nested(self):
+ instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
+ self.assertSequenceEqual(
+ NestedIntegerArrayModel.objects.filter(field__0__0=1),
+ [instance]
+ )
+
+ @unittest.expectedFailure
+ def test_index_used_on_nested_data(self):
+ instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
+ self.assertSequenceEqual(
+ NestedIntegerArrayModel.objects.filter(field__0=[1, 2]),
+ [instance]
+ )
+
+ def test_overlap(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
+ self.objs[0:3]
+ )
+
+ def test_len(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__len__lte=2),
+ self.objs[0:3]
+ )
+
+ def test_slice(self):
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__0_1=[2]),
+ self.objs[1:3]
+ )
+
+ self.assertSequenceEqual(
+ NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]),
+ self.objs[2:3]
+ )
+
+ @unittest.expectedFailure
+ def test_slice_nested(self):
+ instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
+ self.assertSequenceEqual(
+ NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]),
+ [instance]
+ )
+
+
+class TestChecks(TestCase):
+
+ def test_field_checks(self):
+ field = ArrayField(models.CharField())
+ field.set_attributes_from_name('field')
+ errors = field.check()
+ self.assertEqual(len(errors), 1)
+ self.assertEqual(errors[0].id, 'postgres.E001')
+
+ def test_invalid_base_fields(self):
+ field = ArrayField(models.ManyToManyField('postgres_tests.IntegerArrayModel'))
+ field.set_attributes_from_name('field')
+ errors = field.check()
+ self.assertEqual(len(errors), 1)
+ self.assertEqual(errors[0].id, 'postgres.E002')
+
+
+class TestMigrations(TestCase):
+
+ def test_deconstruct(self):
+ field = ArrayField(models.IntegerField())
+ name, path, args, kwargs = field.deconstruct()
+ new = ArrayField(*args, **kwargs)
+ self.assertEqual(type(new.base_field), type(field.base_field))
+
+ def test_deconstruct_with_size(self):
+ field = ArrayField(models.IntegerField(), size=3)
+ name, path, args, kwargs = field.deconstruct()
+ new = ArrayField(*args, **kwargs)
+ self.assertEqual(new.size, field.size)
+
+ def test_deconstruct_args(self):
+ field = ArrayField(models.CharField(max_length=20))
+ name, path, args, kwargs = field.deconstruct()
+ new = ArrayField(*args, **kwargs)
+ self.assertEqual(new.base_field.max_length, field.base_field.max_length)
+
+ def test_makemigrations(self):
+ field = ArrayField(models.CharField(max_length=20))
+ statement, imports = MigrationWriter.serialize(field)
+ self.assertEqual(statement, 'django.contrib.postgres.fields.ArrayField(models.CharField(max_length=20), size=None)')
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class TestSerialization(TestCase):
+ test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
+
+ def test_dumping(self):
+ instance = IntegerArrayModel(field=[1, 2])
+ data = serializers.serialize('json', [instance])
+ self.assertEqual(data, self.test_data)
+
+ def test_loading(self):
+ instance = list(serializers.deserialize('json', self.test_data))[0].object
+ self.assertEqual(instance.field, [1, 2])
+
+
+class TestValidation(TestCase):
+
+ def test_unbounded(self):
+ field = ArrayField(models.IntegerField())
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean([1, None], None)
+ self.assertEqual(cm.exception.code, 'item_invalid')
+ self.assertEqual(cm.exception.message % cm.exception.params, 'Item 1 in the array did not validate: This field cannot be null.')
+
+ def test_blank_true(self):
+ field = ArrayField(models.IntegerField(blank=True, null=True))
+ # This should not raise a validation error
+ field.clean([1, None], None)
+
+ def test_with_size(self):
+ field = ArrayField(models.IntegerField(), size=3)
+ field.clean([1, 2, 3], None)
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean([1, 2, 3, 4], None)
+ self.assertEqual(cm.exception.messages[0], 'List contains 4 items, it should contain no more than 3.')
+
+ def test_nested_array_mismatch(self):
+ field = ArrayField(ArrayField(models.IntegerField()))
+ field.clean([[1, 2], [3, 4]], None)
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean([[1, 2], [3, 4, 5]], None)
+ self.assertEqual(cm.exception.code, 'nested_array_mismatch')
+ self.assertEqual(cm.exception.messages[0], 'Nested arrays must have the same length.')
+
+
+class TestSimpleFormField(TestCase):
+
+ def test_valid(self):
+ field = SimpleArrayField(forms.CharField())
+ value = field.clean('a,b,c')
+ self.assertEqual(value, ['a', 'b', 'c'])
+
+ def test_to_python_fail(self):
+ field = SimpleArrayField(forms.IntegerField())
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('a,b,9')
+ self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a whole number.')
+
+ def test_validate_fail(self):
+ field = SimpleArrayField(forms.CharField(required=True))
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('a,b,')
+ self.assertEqual(cm.exception.messages[0], 'Item 2 in the array did not validate: This field is required.')
+
+ def test_validators_fail(self):
+ field = SimpleArrayField(forms.RegexField('[a-e]{2}'))
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('a,bc,de')
+ self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Enter a valid value.')
+
+ def test_delimiter(self):
+ field = SimpleArrayField(forms.CharField(), delimiter='|')
+ value = field.clean('a|b|c')
+ self.assertEqual(value, ['a', 'b', 'c'])
+
+ def test_delimiter_with_nesting(self):
+ field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter='|')
+ value = field.clean('a,b|c,d')
+ self.assertEqual(value, [['a', 'b'], ['c', 'd']])
+
+ def test_prepare_value(self):
+ field = SimpleArrayField(forms.CharField())
+ value = field.prepare_value(['a', 'b', 'c'])
+ self.assertEqual(value, 'a,b,c')
+
+ def test_max_length(self):
+ field = SimpleArrayField(forms.CharField(), max_length=2)
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('a,b,c')
+ self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no more than 2.')
+
+ def test_min_length(self):
+ field = SimpleArrayField(forms.CharField(), min_length=4)
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('a,b,c')
+ self.assertEqual(cm.exception.messages[0], 'List contains 3 items, it should contain no fewer than 4.')
+
+ def test_required(self):
+ field = SimpleArrayField(forms.CharField(), required=True)
+ with self.assertRaises(exceptions.ValidationError) as cm:
+ field.clean('')
+ self.assertEqual(cm.exception.messages[0], 'This field is required.')
+
+ def test_model_field_formfield(self):
+ model_field = ArrayField(models.CharField(max_length=27))
+ form_field = model_field.formfield()
+ self.assertIsInstance(form_field, SimpleArrayField)
+ self.assertIsInstance(form_field.base_field, forms.CharField)
+ self.assertEqual(form_field.base_field.max_length, 27)
+
+ def test_model_field_formfield_size(self):
+ model_field = ArrayField(models.CharField(max_length=27), size=4)
+ form_field = model_field.formfield()
+ self.assertIsInstance(form_field, SimpleArrayField)
+ self.assertEqual(form_field.max_length, 4)
+
+
+class TestSplitFormField(TestCase):
+
+ def test_valid(self):
+ class SplitForm(forms.Form):
+ array = SplitArrayField(forms.CharField(), size=3)
+
+ data = {'array_0': 'a', 'array_1': 'b', 'array_2': 'c'}
+ form = SplitForm(data)
+ self.assertTrue(form.is_valid())
+ self.assertEqual(form.cleaned_data, {'array': ['a', 'b', 'c']})
+
+ def test_required(self):
+ class SplitForm(forms.Form):
+ array = SplitArrayField(forms.CharField(), required=True, size=3)
+
+ data = {'array_0': '', 'array_1': '', 'array_2': ''}
+ form = SplitForm(data)
+ self.assertFalse(form.is_valid())
+ self.assertEqual(form.errors, {'array': ['This field is required.']})
+
+ def test_remove_trailing_nulls(self):
+ class SplitForm(forms.Form):
+ array = SplitArrayField(forms.CharField(required=False), size=5, remove_trailing_nulls=True)
+
+ data = {'array_0': 'a', 'array_1': '', 'array_2': 'b', 'array_3': '', 'array_4': ''}
+ form = SplitForm(data)
+ self.assertTrue(form.is_valid(), form.errors)
+ self.assertEqual(form.cleaned_data, {'array': ['a', '', 'b']})
+
+ def test_required_field(self):
+ class SplitForm(forms.Form):
+ array = SplitArrayField(forms.CharField(), size=3)
+
+ data = {'array_0': 'a', 'array_1': 'b', 'array_2': ''}
+ form = SplitForm(data)
+ self.assertFalse(form.is_valid())
+ self.assertEqual(form.errors, {'array': ['Item 2 in the array did not validate: This field is required.']})
+
+ def test_rendering(self):
+ class SplitForm(forms.Form):
+ array = SplitArrayField(forms.CharField(), size=3)
+
+ self.assertHTMLEqual(str(SplitForm()), '''
+
+ |
+
+
+
+
+ |
+
+ ''')
diff --git a/tests/runtests.py b/tests/runtests.py
index 787b83e7a5..14014f4b01 100755
--- a/tests/runtests.py
+++ b/tests/runtests.py
@@ -57,6 +57,7 @@ ALWAYS_INSTALLED_APPS = [
def get_test_modules():
from django.contrib.gis.tests.utils import HAS_SPATIAL_DB
+ from django.db import connection
modules = []
discovery_paths = [
(None, RUNTESTS_DIR),
@@ -75,6 +76,8 @@ def get_test_modules():
os.path.isfile(f) or
not os.path.exists(os.path.join(dirpath, f, '__init__.py'))):
continue
+ if not connection.vendor == 'postgresql' and f == 'postgres_tests':
+ continue
modules.append((modpath, f))
return modules