From 5a36c81f58b8ff45d8dac052343722c54e3fa521 Mon Sep 17 00:00:00 2001
From: Vinay Karanam <vinayinvicible@gmail.com>
Date: Mon, 4 Feb 2019 04:57:19 +0530
Subject: [PATCH] Fixed #29391 -- Made PostgresSimpleLookup respect
 Field.get_db_prep_value().

---
 django/contrib/postgres/lookups.py              |  4 ++--
 tests/postgres_tests/fields.py                  |  7 +++++++
 .../migrations/0002_create_test_models.py       | 15 +++++++++++++--
 tests/postgres_tests/models.py                  |  8 ++++++--
 tests/postgres_tests/test_array.py              | 17 ++++++++++++++---
 5 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py
index c2b3d2b569..f0a523d849 100644
--- a/django/contrib/postgres/lookups.py
+++ b/django/contrib/postgres/lookups.py
@@ -1,10 +1,10 @@
 from django.db.models import Lookup, Transform
-from django.db.models.lookups import Exact
+from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin
 
 from .search import SearchVector, SearchVectorExact, SearchVectorField
 
 
-class PostgresSimpleLookup(Lookup):
+class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup):
     def as_sql(self, qn, connection):
         lhs, lhs_params = self.process_lhs(qn, connection)
         rhs, rhs_params = self.process_rhs(qn, connection)
diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py
index 2275eb2ab2..4ebc0ce7dc 100644
--- a/tests/postgres_tests/fields.py
+++ b/tests/postgres_tests/fields.py
@@ -2,6 +2,8 @@
 Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
 run with a backend other than PostgreSQL.
 """
+import enum
+
 from django.db import models
 
 try:
@@ -40,3 +42,8 @@ except ImportError:
     IntegerRangeField = models.Field
     JSONField = DummyJSONField
     SearchVectorField = models.Field
+
+
+class EnumField(models.CharField):
+    def get_prep_value(self, value):
+        return value.value if isinstance(value, enum.Enum) else value
diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py
index 5db8a71385..dc941de139 100644
--- a/tests/postgres_tests/migrations/0002_create_test_models.py
+++ b/tests/postgres_tests/migrations/0002_create_test_models.py
@@ -3,8 +3,8 @@ from django.db import migrations, models
 
 from ..fields import (
     ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
-    DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
-    IntegerRangeField, JSONField, SearchVectorField,
+    DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
+    HStoreField, IntegerRangeField, JSONField, SearchVectorField,
 )
 from ..models import TagField
 
@@ -249,4 +249,15 @@ class Migration(migrations.Migration):
             },
             bases=(models.Model,),
         ),
+        migrations.CreateModel(
+            name='ArrayEnumModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
+            ],
+            options={
+                'required_db_vendor': 'postgresql',
+            },
+            bases=(models.Model,),
+        ),
     ]
diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py
index cbe477e402..2bb6e6fcdf 100644
--- a/tests/postgres_tests/models.py
+++ b/tests/postgres_tests/models.py
@@ -3,8 +3,8 @@ from django.db import models
 
 from .fields import (
     ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
-    DateRangeField, DateTimeRangeField, DecimalRangeField, HStoreField,
-    IntegerRangeField, JSONField, SearchVectorField,
+    DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
+    HStoreField, IntegerRangeField, JSONField, SearchVectorField,
 )
 
 
@@ -77,6 +77,10 @@ class HStoreModel(PostgreSQLModel):
     array_field = ArrayField(HStoreField(), null=True)
 
 
+class ArrayEnumModel(PostgreSQLModel):
+    array_of_enums = ArrayField(EnumField(max_length=20))
+
+
 class CharFieldModel(models.Model):
     field = models.CharField(max_length=16)
 
diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py
index 447d511c9f..465eac1785 100644
--- a/tests/postgres_tests/test_array.py
+++ b/tests/postgres_tests/test_array.py
@@ -1,4 +1,5 @@
 import decimal
+import enum
 import json
 import unittest
 import uuid
@@ -16,9 +17,9 @@ from . import (
     PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase,
 )
 from .models import (
-    ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
-    NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
-    PostgreSQLModel, Tag,
+    ArrayEnumModel, ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel,
+    IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel,
+    OtherTypesArrayModel, PostgreSQLModel, Tag,
 )
 
 try:
@@ -357,6 +358,16 @@ class TestQuerying(PostgreSQLTestCase):
             [self.objs[3]]
         )
 
+    def test_enum_lookup(self):
+        class TestEnum(enum.Enum):
+            VALUE_1 = 'value_1'
+
+        instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
+        self.assertSequenceEqual(
+            ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
+            [instance]
+        )
+
     def test_unsupported_lookup(self):
         msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
         with self.assertRaisesMessage(FieldError, msg):