mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #25143 -- Added ArrayField.from_db_value().
Thanks Karan Lyons for contributing to the patch.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							f8d20da047
						
					
				
				
					commit
					2495023a4c
				
			| @@ -28,6 +28,10 @@ class ArrayField(Field): | |||||||
|         if self.size: |         if self.size: | ||||||
|             self.default_validators = self.default_validators[:] |             self.default_validators = self.default_validators[:] | ||||||
|             self.default_validators.append(ArrayMaxLengthValidator(self.size)) |             self.default_validators.append(ArrayMaxLengthValidator(self.size)) | ||||||
|  |         # For performance, only add a from_db_value() method if the base field | ||||||
|  |         # implements it. | ||||||
|  |         if hasattr(self.base_field, 'from_db_value'): | ||||||
|  |             self.from_db_value = self._from_db_value | ||||||
|         super(ArrayField, self).__init__(**kwargs) |         super(ArrayField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -100,6 +104,14 @@ class ArrayField(Field): | |||||||
|             value = [self.base_field.to_python(val) for val in vals] |             value = [self.base_field.to_python(val) for val in vals] | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|  |     def _from_db_value(self, value, expression, connection, context): | ||||||
|  |         if value is None: | ||||||
|  |             return value | ||||||
|  |         return [ | ||||||
|  |             self.base_field.from_db_value(item, expression, connection, context) | ||||||
|  |             for item in value | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     def value_to_string(self, obj): |     def value_to_string(self, obj): | ||||||
|         values = [] |         values = [] | ||||||
|         vals = self.value_from_object(obj) |         vals = self.value_from_object(obj) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ from __future__ import unicode_literals | |||||||
| from django.db import migrations, models | from django.db import migrations, models | ||||||
|  |  | ||||||
| from ..fields import *  # NOQA | from ..fields import *  # NOQA | ||||||
|  | from ..models import TagField | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): | class Migration(migrations.Migration): | ||||||
| @@ -55,6 +56,7 @@ class Migration(migrations.Migration): | |||||||
|                 ('ips', ArrayField(models.GenericIPAddressField(), size=None)), |                 ('ips', ArrayField(models.GenericIPAddressField(), size=None)), | ||||||
|                 ('uuids', ArrayField(models.UUIDField(), size=None)), |                 ('uuids', ArrayField(models.UUIDField(), size=None)), | ||||||
|                 ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), |                 ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), | ||||||
|  |                 ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), | ||||||
|             ], |             ], | ||||||
|             options={ |             options={ | ||||||
|                 'required_db_vendor': 'postgresql', |                 'required_db_vendor': 'postgresql', | ||||||
|   | |||||||
| @@ -6,6 +6,35 @@ from .fields import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Tag(object): | ||||||
|  |     def __init__(self, tag_id): | ||||||
|  |         self.tag_id = tag_id | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return isinstance(other, Tag) and self.tag_id == other.tag_id | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TagField(models.SmallIntegerField): | ||||||
|  |  | ||||||
|  |     def from_db_value(self, value, expression, connection, context): | ||||||
|  |         if value is None: | ||||||
|  |             return value | ||||||
|  |         return Tag(int(value)) | ||||||
|  |  | ||||||
|  |     def to_python(self, value): | ||||||
|  |         if isinstance(value, Tag): | ||||||
|  |             return value | ||||||
|  |         if value is None: | ||||||
|  |             return value | ||||||
|  |         return Tag(int(value)) | ||||||
|  |  | ||||||
|  |     def get_prep_value(self, value): | ||||||
|  |         return value.tag_id | ||||||
|  |  | ||||||
|  |     def get_db_prep_value(self, value, connection, prepared=False): | ||||||
|  |         return self.get_prep_value(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PostgreSQLModel(models.Model): | class PostgreSQLModel(models.Model): | ||||||
|     class Meta: |     class Meta: | ||||||
|         abstract = True |         abstract = True | ||||||
| @@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel): | |||||||
|     ips = ArrayField(models.GenericIPAddressField()) |     ips = ArrayField(models.GenericIPAddressField()) | ||||||
|     uuids = ArrayField(models.UUIDField()) |     uuids = ArrayField(models.UUIDField()) | ||||||
|     decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) |     decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) | ||||||
|  |     tags = ArrayField(TagField(), blank=True, null=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class HStoreModel(PostgreSQLModel): | class HStoreModel(PostgreSQLModel): | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ from . import PostgreSQLTestCase | |||||||
| from .models import ( | from .models import ( | ||||||
|     ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, |     ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, | ||||||
|     NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, |     NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, | ||||||
|     PostgreSQLModel, |     PostgreSQLModel, Tag, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| try: | try: | ||||||
| @@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase): | |||||||
|             ips=['192.168.0.1', '::1'], |             ips=['192.168.0.1', '::1'], | ||||||
|             uuids=[uuid.uuid4()], |             uuids=[uuid.uuid4()], | ||||||
|             decimals=[decimal.Decimal(1.25), 1.75], |             decimals=[decimal.Decimal(1.25), 1.75], | ||||||
|  |             tags=[Tag(1), Tag(2), Tag(3)], | ||||||
|         ) |         ) | ||||||
|         instance.save() |         instance.save() | ||||||
|         loaded = OtherTypesArrayModel.objects.get() |         loaded = OtherTypesArrayModel.objects.get() | ||||||
|         self.assertEqual(instance.ips, loaded.ips) |         self.assertEqual(instance.ips, loaded.ips) | ||||||
|         self.assertEqual(instance.uuids, loaded.uuids) |         self.assertEqual(instance.uuids, loaded.uuids) | ||||||
|         self.assertEqual(instance.decimals, loaded.decimals) |         self.assertEqual(instance.decimals, loaded.decimals) | ||||||
|  |         self.assertEqual(instance.tags, loaded.tags) | ||||||
|  |  | ||||||
|  |     def test_null_from_db_value_handling(self): | ||||||
|  |         instance = OtherTypesArrayModel.objects.create( | ||||||
|  |             ips=['192.168.0.1', '::1'], | ||||||
|  |             uuids=[uuid.uuid4()], | ||||||
|  |             decimals=[decimal.Decimal(1.25), 1.75], | ||||||
|  |             tags=None, | ||||||
|  |         ) | ||||||
|  |         instance.refresh_from_db() | ||||||
|  |         self.assertIsNone(instance.tags) | ||||||
|  |  | ||||||
|     def test_model_set_on_base_field(self): |     def test_model_set_on_base_field(self): | ||||||
|         instance = IntegerArrayModel() |         instance = IntegerArrayModel() | ||||||
| @@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): | |||||||
|         self.ips = ['192.168.0.1', '::1'] |         self.ips = ['192.168.0.1', '::1'] | ||||||
|         self.uuids = [uuid.uuid4()] |         self.uuids = [uuid.uuid4()] | ||||||
|         self.decimals = [decimal.Decimal(1.25), 1.75] |         self.decimals = [decimal.Decimal(1.25), 1.75] | ||||||
|  |         self.tags = [Tag(1), Tag(2), Tag(3)] | ||||||
|         self.objs = [ |         self.objs = [ | ||||||
|             OtherTypesArrayModel.objects.create( |             OtherTypesArrayModel.objects.create( | ||||||
|                 ips=self.ips, |                 ips=self.ips, | ||||||
|                 uuids=self.uuids, |                 uuids=self.uuids, | ||||||
|                 decimals=self.decimals, |                 decimals=self.decimals, | ||||||
|  |                 tags=self.tags, | ||||||
|             ) |             ) | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
| @@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): | |||||||
|             self.objs |             self.objs | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_exact_tags(self): | ||||||
|  |         self.assertSequenceEqual( | ||||||
|  |             OtherTypesArrayModel.objects.filter(tags=self.tags), | ||||||
|  |             self.objs | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @isolate_apps('postgres_tests') | @isolate_apps('postgres_tests') | ||||||
| class TestChecks(PostgreSQLTestCase): | class TestChecks(PostgreSQLTestCase): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user