import time
import unittest
from datetime import date, datetime

from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models.fields.related_lookups import RelatedGreaterThan
from django.db.models.lookups import EndsWith, StartsWith
from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup
from django.utils import timezone

from .models import Article, Author, MySQLUnixTimestamp


class Div3Lookup(models.Lookup):
    """Match rows whose left-hand value modulo 3 equals the right-hand value."""

    lookup_name = "div3"

    def as_sql(self, compiler, connection):
        lhs_sql, sql_params = self.process_lhs(compiler, connection)
        rhs_sql, extra_params = self.process_rhs(compiler, connection)
        sql_params.extend(extra_params)
        # "%%%%" becomes "%%" after this %-format; the backend's parameter
        # interpolation later reduces that to a literal "%".
        template = "(%s) %%%% 3 = %s"
        return template % (lhs_sql, rhs_sql), sql_params

    def as_oracle(self, compiler, connection):
        # Oracle variant: spelled with mod() instead of the % operator.
        lhs_sql, sql_params = self.process_lhs(compiler, connection)
        rhs_sql, extra_params = self.process_rhs(compiler, connection)
        sql_params.extend(extra_params)
        template = "mod(%s, 3) = %s"
        return template % (lhs_sql, rhs_sql), sql_params


class Div3Transform(models.Transform):
    """Transform an expression into its remainder modulo 3."""

    lookup_name = "div3"

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        # "%%%%" is escaped twice: once for this %-format, once for the backend.
        sql = "(%s) %%%% 3" % inner_sql
        return sql, inner_params

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle variant: spelled with mod() instead of the % operator.
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "mod(%s, 3)" % inner_sql
        return sql, inner_params


class Div3BilateralTransform(Div3Transform):
    """Div3Transform that is applied to the lookup value as well as the column."""

    bilateral = True


class Mult3BilateralTransform(models.Transform):
    """Bilateral transform that multiplies both sides of a lookup by 3."""

    lookup_name = "mult3"
    bilateral = True

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "3 * (%s)" % inner_sql
        return sql, inner_params


class LastDigitTransform(models.Transform):
    """Return the second character of the value cast to CHAR(2).

    This is the last digit only for two-digit values, which is all the tests
    that use it store.
    """

    lookup_name = "lastdigit"

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % inner_sql
        return sql, inner_params


class UpperBilateralTransform(models.Transform):
    """Bilateral UPPER() transform, making comparisons case-insensitive."""

    lookup_name = "upper"
    bilateral = True

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "UPPER(%s)" % inner_sql
        return sql, inner_params


class YearTransform(models.Transform):
    """Extract the year of a date expression as an integer."""

    # "testyear" rather than "year" so it can't clash with the built-in
    # year transform.
    lookup_name = "testyear"

    def as_sql(self, compiler, connection):
        date_sql, date_params = compiler.compile(self.lhs)
        # The backend builds the extraction SQL and returns (sql, params).
        return connection.ops.date_extract_sql("year", date_sql, date_params)

    @property
    def output_field(self):
        return models.IntegerField()


@YearTransform.register_lookup
class YearExact(models.lookups.Lookup):
    """``testyear`` exact lookup that compares the raw date column to a range.

    Rather than extracting the year, it checks that the date lies between
    January 1st and December 31st of the requested year. The SQL uses
    PostgreSQL-specific ``::date`` casts.
    """

    lookup_name = "exact"

    def as_sql(self, compiler, connection):
        # Skip the YearTransform and compile its source expression
        # (self.lhs.lhs) instead.
        date_sql, date_params = self.process_lhs(compiler, connection, self.lhs.lhs)
        year_sql, year_params = self.process_rhs(compiler, connection)
        # The bounds are built in SQL rather than Python so F() references work.
        sql = (
            "%(lhs)s >= (%(rhs)s || '-01-01')::date "
            "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
        ) % {"lhs": date_sql, "rhs": year_sql}
        # Placeholders appear as lhs, rhs, lhs, rhs, so params repeat that way.
        return sql, (date_params + year_params) * 2


@YearTransform.register_lookup
class YearLte(models.lookups.LessThanOrEqual):
    """
    The purpose of this lookup is to efficiently compare the year of the field.
    """

    def as_sql(self, compiler, connection):
        # Compile the column underneath the YearTransform directly; comparing
        # the extracted year would not allow an efficient lookup.
        date_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        year_sql, year_params = self.process_rhs(compiler, connection)
        params.extend(year_params)
        # Produces SQL like ``somecol <= '2013-12-31'``; concatenating in SQL
        # keeps it working when the right-hand side is a field reference.
        sql = "%s <= (%s || '-12-31')::date" % (date_sql, year_sql)
        return sql, params


class Exactly(models.lookups.Exact):
    """Exact lookup under a different name, used to test lookup registration."""

    lookup_name = "exactly"

    def get_rhs_op(self, connection, rhs):
        # Reuse the backend's operator template for the built-in "exact".
        template = connection.operators["exact"]
        return template % rhs


class SQLFuncMixin:
    """Render as a zero-argument SQL function call named after ``self.name``."""

    def as_sql(self, compiler, connection):
        sql = "%s()" % self.name
        return sql, []

    @property
    def output_field(self):
        # Returning CustomField lets lookups/transforms chain onto the result.
        return CustomField()


class SQLFuncLookup(SQLFuncMixin, models.Lookup):
    """Lookup rendered as ``<name>()``; remaining args go to ``Lookup``."""

    def __init__(self, name, *lookup_args, **lookup_kwargs):
        super().__init__(*lookup_args, **lookup_kwargs)
        self.name = name


class SQLFuncTransform(SQLFuncMixin, models.Transform):
    """Transform rendered as ``<name>()``; remaining args go to ``Transform``."""

    def __init__(self, name, *transform_args, **transform_kwargs):
        super().__init__(*transform_args, **transform_kwargs)
        self.name = name


class SQLFuncFactory:
    """Callable returned in place of a lookup/transform class by CustomField.

    Calling it builds an SQLFuncLookup when ``key`` is "lookupfunc" and an
    SQLFuncTransform otherwise, passing ``name`` as the SQL function name.
    """

    def __init__(self, key, name):
        self.key = key
        self.name = name

    def __call__(self, *args, **kwargs):
        target = SQLFuncLookup if self.key == "lookupfunc" else SQLFuncTransform
        return target(self.name, *args, **kwargs)


class CustomField(models.TextField):
    """TextField that resolves ``lookupfunc_<name>`` / ``transformfunc_<name>``.

    Names with those prefixes are answered with an SQLFuncFactory instead of a
    registered class; everything else falls back to the default resolution.
    """

    def get_lookup(self, lookup_name):
        if not lookup_name.startswith("lookupfunc_"):
            return super().get_lookup(lookup_name)
        key, _, name = lookup_name.partition("_")
        return SQLFuncFactory(key, name)

    def get_transform(self, lookup_name):
        if not lookup_name.startswith("transformfunc_"):
            return super().get_transform(lookup_name)
        key, _, name = lookup_name.partition("_")
        return SQLFuncFactory(key, name)


class CustomModel(models.Model):
    # Single CustomField, used by the overridden get_lookup/get_transform tests.
    field = CustomField()


# InMonth is not registered at import time; test_birthdate_month registers it
# temporarily with register_lookup().


class InMonth(models.lookups.Lookup):
    """
    InMonth matches if the column's month is the same as value's month.
    """

    lookup_name = "inmonth"

    def as_sql(self, compiler, connection):
        col_sql, col_params = self.process_lhs(compiler, connection)
        val_sql, val_params = self.process_rhs(compiler, connection)
        # Placeholders appear as lhs, rhs, lhs, rhs, so params repeat that way.
        sql = (
            f"{col_sql} >= date_trunc('month', {val_sql}) and "
            f"{col_sql} < date_trunc('month', {val_sql}) + interval '1 months'"
        )
        return sql, (col_params + val_params) * 2


class DateTimeTransform(models.Transform):
    """Convert a Unix timestamp into a datetime via MySQL's from_unixtime()."""

    lookup_name = "as_datetime"

    def as_sql(self, compiler, connection):
        timestamp_sql, timestamp_params = compiler.compile(self.lhs)
        return f"from_unixtime({timestamp_sql})", timestamp_params

    @property
    def output_field(self):
        return models.DateTimeField()


class CustomStartsWith(StartsWith):
    """Built-in startswith lookup under the name "sw"."""

    lookup_name = "sw"


class CustomEndsWith(EndsWith):
    """Built-in endswith lookup under the name "ew"."""

    lookup_name = "ew"


class RelatedMoreThan(RelatedGreaterThan):
    """Related-field greater-than lookup under the name "rmt"."""

    lookup_name = "rmt"


class LookupTests(TestCase):
    """Registering and using custom lookups and transforms on model fields."""

    def test_custom_name_lookup(self):
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        with (
            register_lookup(models.DateField, YearTransform),
            register_lookup(models.DateField, YearTransform, lookup_name="justtheyear"),
            register_lookup(YearTransform, Exactly),
            register_lookup(YearTransform, Exactly, lookup_name="isactually"),
        ):
            qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
            qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
            self.assertSequenceEqual(qs1, [a1])
            self.assertSequenceEqual(qs2, [a1])

    def test_custom_exact_lookup_none_rhs(self):
        """
        __exact=None is transformed to __isnull=True if a custom lookup class
        with lookup_name != 'exact' is registered as the `exact` lookup.
        """
        field = Author._meta.get_field("birthdate")
        author = Author.objects.create(name="author", birthdate=None)
        # The context manager unregisters the instance-level lookup on exit.
        # Re-registering the previous class would instead leave a stray entry
        # in field.instance_lookups for the rest of the test run.
        with register_lookup(field, Exactly, lookup_name="exact"):
            self.assertEqual(Author.objects.get(birthdate__exact=None), author)

    def test_basic_lookup(self):
        a1 = Author.objects.create(name="a1", age=1)
        a2 = Author.objects.create(name="a2", age=2)
        a3 = Author.objects.create(name="a3", age=3)
        a4 = Author.objects.create(name="a4", age=4)
        with register_lookup(models.IntegerField, Div3Lookup):
            self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
            self.assertSequenceEqual(
                Author.objects.filter(age__div3=1).order_by("age"), [a1, a4]
            )
            self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
            self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_birthdate_month(self):
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
        with register_lookup(models.DateField, InMonth):
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), []
            )

    def test_div3_extract(self):
        with register_lookup(models.IntegerField, Div3Transform):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
            self.assertSequenceEqual(
                baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_foreignobject_lookup_registration(self):
        field = Article._meta.get_field("author")

        with register_lookup(models.ForeignObject, Exactly):
            self.assertIs(field.get_lookup("exactly"), Exactly)

        # ForeignObject should ignore regular Field lookups
        with register_lookup(models.Field, Exactly):
            self.assertIsNone(field.get_lookup("exactly"))

    def test_lookups_caching(self):
        field = Article._meta.get_field("author")

        # clear and re-cache
        field.get_class_lookups.cache_clear()
        self.assertNotIn("exactly", field.get_lookups())

        # registration should bust the cache
        with register_lookup(models.ForeignObject, Exactly):
            # getting the lookups again should re-cache
            self.assertIn("exactly", field.get_lookups())
        # Unregistration should bust the cache.
        self.assertNotIn("exactly", field.get_lookups())


class BilateralTransformTests(TestCase):
    """Bilateral transforms apply to the lookup value as well as the column."""

    def test_bilateral_upper(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            doe_title = Author.objects.create(name="Doe")
            doe_lower = Author.objects.create(name="doe")
            foo = Author.objects.create(name="Foo")
            # UPPER() runs on both sides, so matching ignores case.
            self.assertCountEqual(
                Author.objects.filter(name__upper="doe"), [doe_title, doe_lower]
            )
            self.assertSequenceEqual(
                Author.objects.filter(name__upper__contains="f"), [foo]
            )

    def test_bilateral_inner_qs(self):
        msg = "Bilateral transformations on nested querysets are not implemented."
        with register_lookup(models.CharField, UpperBilateralTransform):
            with self.assertRaisesMessage(NotImplementedError, msg):
                Author.objects.filter(
                    name__upper__in=Author.objects.values_list("name")
                )

    def test_bilateral_multi_value(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            Author.objects.bulk_create(
                [Author(name=name) for name in ("Foo", "Bar", "Ray")]
            )
            matches = Author.objects.filter(
                name__upper__in=["foo", "bar", "doe"]
            ).order_by("name")
            self.assertQuerySetEqual(matches, ["Bar", "Foo"], lambda a: a.name)

    def test_div3_bilateral_extract(self):
        with register_lookup(models.IntegerField, Div3BilateralTransform):
            a1, a2, a3, a4 = [
                Author.objects.create(name=f"a{age}", age=age) for age in range(1, 5)
            ]
            baseqs = Author.objects.order_by("name")
            # The lookup value is reduced modulo 3 too, e.g. 4 -> 1 and 3 -> 0.
            expectations = [
                ({"age__div3": 2}, [a2]),
                ({"age__div3__lte": 3}, [a3]),
                ({"age__div3__in": [0, 2]}, [a2, a3]),
                ({"age__div3__in": [2, 4]}, [a1, a2, a4]),
                ({"age__div3__gte": 3}, [a1, a2, a3, a4]),
                ({"age__div3__range": (1, 2)}, [a1, a2, a4]),
            ]
            for filters, expected in expectations:
                self.assertSequenceEqual(baseqs.filter(**filters), expected)

    def test_bilateral_order(self):
        with register_lookup(
            models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
        ):
            a1, a2, a3, a4 = [
                Author.objects.create(name=f"a{age}", age=age) for age in range(1, 5)
            ]
            baseqs = Author.objects.order_by("name")

            # mult3__div3: (3 * x) % 3 is always 0 on both sides.
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]
            )
            # div3__mult3: 3 * (x % 3) == 3 * (42 % 3) == 0 only when age % 3 == 0.
            self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])

    def test_transform_order_by(self):
        with register_lookup(models.IntegerField, LastDigitTransform):
            a1, a2, a3, a4 = [
                Author.objects.create(name=name, age=age)
                for name, age in (("a1", 11), ("a2", 23), ("a3", 32), ("a4", 40))
            ]
            # Last digits are 1, 3, 2, 0 respectively.
            self.assertSequenceEqual(
                Author.objects.order_by("age__lastdigit"), [a4, a1, a3, a2]
            )

    def test_bilateral_fexpr(self):
        with register_lookup(models.IntegerField, Mult3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
            a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
            a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            # 3 * age == 3 * age holds for every row.
            self.assertSequenceEqual(
                baseqs.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
            )
            # 3 * age >= 3 * average_rating, i.e. the same as age >= average_rating.
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
            )


@override_settings(USE_TZ=True)
class DateTimeLookupTests(TestCase):
    """Transforms whose output_field is a DateTimeField."""

    @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
    def test_datetime_output_field(self):
        y2k = timezone.make_aware(datetime(2000, 1, 1))
        with register_lookup(models.PositiveIntegerField, DateTimeTransform):
            ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
            after_y2k = MySQLUnixTimestamp.objects.filter(
                timestamp__as_datetime__gt=y2k
            )
            self.assertSequenceEqual(after_y2k, [ut])


class YearLteTests(TestCase):
    """
    Tests for the YearExact and YearLte lookups registered on YearTransform
    (``testyear``). The SQL-inspection tests run on any backend; the others
    need PostgreSQL.
    """

    @classmethod
    def setUpTestData(cls):
        cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))

    def setUp(self):
        # Register YearTransform on DateField for each test; the cleanup
        # unregisters it again afterwards.
        models.DateField.register_lookup(YearTransform)
        self.addCleanup(models.DateField._unregister_lookup, YearTransform)

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte(self):
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2012),
            [self.a1, self.a2, self.a3, self.a4],
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
        )

        # YearExact emits two explicit comparisons rather than BETWEEN.
        self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
        )
        # The non-optimized version works, too.
        self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte_fexpr(self):
        # All three have birth year 2012; only the age (the F() rhs) varies.
        # a1 keeps whatever age setUpTestData gave it (none is set there).
        self.a2.age = 2011
        self.a2.save()
        self.a3.age = 2012
        self.a3.save()
        self.a4.age = 2013
        self.a4.save()
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
        )

    def test_year_lte_sql(self):
        # This test will just check the generated SQL for __lte. This
        # doesn't require running on PostgreSQL and spots the most likely
        # error - not running YearLte SQL at all.
        baseqs = Author.objects.order_by("name")
        self.assertIn(
            "<= (2011 || ", str(baseqs.filter(birthdate__testyear__lte=2011).query)
        )
        self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear__lte=2011).query))

    def test_postgres_year_exact(self):
        # Despite the name, this only inspects the SQL string produced by
        # YearExact and therefore runs on every backend.
        baseqs = Author.objects.order_by("name")
        self.assertIn("= (2011 || ", str(baseqs.filter(birthdate__testyear=2011).query))
        self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear=2011).query))

    def test_custom_implementation_year_exact(self):
        try:
            # Two ways to add a customized implementation for different backends:
            # First is MonkeyPatch of the class.
            def as_custom_sql(self, compiler, connection):
                lhs_sql, lhs_params = self.process_lhs(
                    compiler, connection, self.lhs.lhs
                )
                rhs_sql, rhs_params = self.process_rhs(compiler, connection)
                # Placeholders appear as lhs, rhs, lhs, rhs.
                params = lhs_params + rhs_params + lhs_params + rhs_params
                return (
                    "%(lhs)s >= "
                    "str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                    "AND %(lhs)s <= "
                    "str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                    % {"lhs": lhs_sql, "rhs": rhs_sql},
                    params,
                )

            # Attach the override as ``as_<vendor>`` for the current backend, so
            # it is used in place of YearExact.as_sql.
            setattr(YearExact, "as_" + connection.vendor, as_custom_sql)
            self.assertIn(
                "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            # Remove the monkeypatch so later tests see the original class.
            delattr(YearExact, "as_" + connection.vendor)
        try:
            # The other way is to subclass the original lookup and register the
            # subclassed lookup instead of the original.
            class CustomYearExact(YearExact):
                # This method should be named "as_mysql" for MySQL,
                # "as_postgresql" for postgres and so on, but as we don't know
                # which DB we are running on, we need to use setattr.
                def as_custom_sql(self, compiler, connection):
                    lhs_sql, lhs_params = self.process_lhs(
                        compiler, connection, self.lhs.lhs
                    )
                    rhs_sql, rhs_params = self.process_rhs(compiler, connection)
                    params = lhs_params + rhs_params + lhs_params + rhs_params
                    return (
                        "%(lhs)s >= "
                        "str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                        "AND %(lhs)s <= "
                        "str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                        % {"lhs": lhs_sql, "rhs": rhs_sql},
                        params,
                    )

            setattr(
                CustomYearExact,
                "as_" + connection.vendor,
                CustomYearExact.as_custom_sql,
            )
            # CustomYearExact inherits lookup_name "exact", so this replaces
            # YearExact on YearTransform.
            YearTransform.register_lookup(CustomYearExact)
            self.assertIn(
                "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            # Unregistering removes the "exact" entry entirely, so the original
            # YearExact has to be registered again.
            YearTransform._unregister_lookup(CustomYearExact)
            YearTransform.register_lookup(YearExact)


class TrackCallsYearTransform(YearTransform):
    """YearTransform that records each get_lookup()/get_transform() call.

    ``call_order`` is a class-level list shared by every instance; tests that
    inspect it are responsible for resetting it.
    """

    # Use a name that avoids collision with the built-in year lookup.
    lookup_name = "testyear"
    call_order = []

    def as_sql(self, compiler, connection):
        lhs_sql, params = compiler.compile(self.lhs)
        # date_extract_sql() takes the params and returns an (sql, params)
        # pair, the same way YearTransform.as_sql() uses it. Don't drop the
        # params or wrap the result in another tuple.
        return connection.ops.date_extract_sql("year", lhs_sql, params)

    @property
    def output_field(self):
        return models.IntegerField()

    def get_lookup(self, lookup_name):
        self.call_order.append("lookup")
        return super().get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        self.call_order.append("transform")
        return super().get_transform(lookup_name)


class LookupTransformCallOrderTests(SimpleTestCase):
    """Order in which get_lookup() and get_transform() are consulted."""

    def test_call_order(self):
        # call_order is shared class-level state. Start empty so entries left
        # by any earlier use of TrackCallsYearTransform can't affect the
        # first assertion.
        TrackCallsYearTransform.call_order = []
        with register_lookup(models.DateField, TrackCallsYearTransform):
            # junk lookup - tries lookup, then transform, then fails
            msg = (
                "Unsupported lookup 'junk' for IntegerField or join on the field not "
                "permitted."
            )
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk=2012)
            self.assertEqual(
                TrackCallsYearTransform.call_order, ["lookup", "transform"]
            )
            TrackCallsYearTransform.call_order = []
            # junk transform - tries transform only, then fails
            msg = (
                "Unsupported lookup 'junk__more_junk' for IntegerField or join"
                " on the field not permitted."
            )
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
            TrackCallsYearTransform.call_order = []
            # Just getting the year (implied __exact) - lookup only
            Author.objects.filter(birthdate__testyear=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
            TrackCallsYearTransform.call_order = []
            # Just getting the year (explicit __exact) - lookup only
            Author.objects.filter(birthdate__testyear__exact=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])


class CustomisedMethodsTests(SimpleTestCase):
    """CustomField's overridden get_lookup()/get_transform() hooks."""

    def _assert_query_contains(self, fragment, **filters):
        # Build (but don't execute) the query and inspect its SQL text.
        query = CustomModel.objects.filter(**filters).query
        self.assertIn(fragment, str(query))

    def test_overridden_get_lookup(self):
        self._assert_query_contains("monkeys()", field__lookupfunc_monkeys=3)

    def test_overridden_get_transform(self):
        self._assert_query_contains("banana()", field__transformfunc_banana=3)

    def test_overridden_get_lookup_chain(self):
        self._assert_query_contains(
            "elephants()", field__transformfunc_banana__lookupfunc_elephants=3
        )

    def test_overridden_get_transform_chain(self):
        self._assert_query_contains(
            "pear()", field__transformfunc_banana__transformfunc_pear=3
        )


class SubqueryTransformTests(TestCase):
    """A transform used inside a subquery filter."""

    def test_subquery_usage(self):
        with register_lookup(models.IntegerField, Div3Transform):
            authors = [
                Author.objects.create(name=f"a{age}", age=age) for age in range(1, 5)
            ]
            remainder_two = Author.objects.filter(age__div3=2)
            qs = Author.objects.order_by("name").filter(id__in=remainder_two)
            # Only age 2 leaves remainder 2 when divided by 3.
            self.assertSequenceEqual(qs, [authors[1]])


class RegisterLookupTests(SimpleTestCase):
    """Class-level vs. instance-level lookup/transform registration."""

    def test_class_lookup(self):
        name_field = Author._meta.get_field("name")
        with register_lookup(models.CharField, CustomStartsWith):
            self.assertEqual(name_field.get_lookup("sw"), CustomStartsWith)
        # Leaving the context manager removes the registration.
        self.assertIsNone(name_field.get_lookup("sw"))

    def test_instance_lookup(self):
        name_field = Author._meta.get_field("name")
        alias_field = Author._meta.get_field("alias")
        with register_lookup(name_field, CustomStartsWith):
            self.assertEqual(name_field.instance_lookups, {"sw": CustomStartsWith})
            self.assertEqual(name_field.get_lookup("sw"), CustomStartsWith)
            # Other fields of the same class are unaffected.
            self.assertIsNone(alias_field.get_lookup("sw"))
        self.assertIsNone(name_field.get_lookup("sw"))
        self.assertEqual(name_field.instance_lookups, {})
        self.assertIsNone(alias_field.get_lookup("sw"))

    def test_instance_lookup_override_class_lookups(self):
        name_field = Author._meta.get_field("name")
        alias_field = Author._meta.get_field("alias")
        with register_lookup(models.CharField, CustomStartsWith, lookup_name="st_end"):
            with register_lookup(alias_field, CustomEndsWith, lookup_name="st_end"):
                # The instance registration wins on alias only.
                self.assertEqual(name_field.get_lookup("st_end"), CustomStartsWith)
                self.assertEqual(alias_field.get_lookup("st_end"), CustomEndsWith)
            # With the instance lookup gone, alias falls back to the class one.
            self.assertEqual(name_field.get_lookup("st_end"), CustomStartsWith)
            self.assertEqual(alias_field.get_lookup("st_end"), CustomStartsWith)
        self.assertIsNone(name_field.get_lookup("st_end"))
        self.assertIsNone(alias_field.get_lookup("st_end"))

    def test_instance_lookup_override(self):
        name_field = Author._meta.get_field("name")
        with register_lookup(name_field, CustomStartsWith, lookup_name="st_end"):
            self.assertEqual(name_field.get_lookup("st_end"), CustomStartsWith)
            # Re-registering under the same name replaces the previous entry.
            name_field.register_lookup(CustomEndsWith, lookup_name="st_end")
            self.assertEqual(name_field.get_lookup("st_end"), CustomEndsWith)
        self.assertIsNone(name_field.get_lookup("st_end"))

    def test_lookup_on_transform(self):
        with register_lookup(Div3Transform, CustomStartsWith):
            with register_lookup(Div3Transform, CustomEndsWith):
                self.assertEqual(
                    Div3Transform.get_lookups(),
                    {"sw": CustomStartsWith, "ew": CustomEndsWith},
                )
            self.assertEqual(Div3Transform.get_lookups(), {"sw": CustomStartsWith})
        self.assertEqual(Div3Transform.get_lookups(), {})

    def test_transform_on_field(self):
        name_field = Author._meta.get_field("name")
        alias_field = Author._meta.get_field("alias")
        with register_lookup(models.CharField, Div3Transform):
            self.assertEqual(alias_field.get_transform("div3"), Div3Transform)
            self.assertEqual(name_field.get_transform("div3"), Div3Transform)
        with register_lookup(alias_field, Div3Transform):
            self.assertEqual(alias_field.get_transform("div3"), Div3Transform)
            self.assertIsNone(name_field.get_transform("div3"))
        self.assertIsNone(alias_field.get_transform("div3"))
        self.assertIsNone(name_field.get_transform("div3"))

    def test_related_lookup(self):
        author_fk = Article._meta.get_field("author")
        # Plain Field lookups are not visible on related fields...
        with register_lookup(models.Field, CustomStartsWith):
            self.assertIsNone(author_fk.get_lookup("sw"))
        # ...but lookups registered on ForeignKey are.
        with register_lookup(models.ForeignKey, RelatedMoreThan):
            self.assertEqual(author_fk.get_lookup("rmt"), RelatedMoreThan)

    def test_instance_related_lookup(self):
        author_fk = Article._meta.get_field("author")
        with register_lookup(author_fk, RelatedMoreThan):
            self.assertEqual(author_fk.get_lookup("rmt"), RelatedMoreThan)
        self.assertIsNone(author_fk.get_lookup("rmt"))