From 44f3ee77166bd5c0e8a4604f2d96015268dce100 Mon Sep 17 00:00:00 2001
From: Jon Dufresne <jon.dufresne@gmail.com>
Date: Sat, 7 Mar 2015 13:20:29 -0800
Subject: [PATCH] Fixed #9596 -- Added date transform for DateTimeField.

---
 django/db/backends/base/operations.py         |  6 +++
 django/db/backends/mysql/operations.py        | 25 +++++------
 django/db/backends/oracle/operations.py       | 20 ++++-----
 .../postgresql_psycopg2/operations.py         | 24 +++++-----
 django/db/backends/sqlite3/base.py            | 24 +++++++---
 django/db/backends/sqlite3/operations.py      | 18 ++++----
 django/db/models/fields/__init__.py           | 16 +++++++
 docs/ref/models/querysets.txt                 | 21 +++++++++
 docs/releases/1.9.txt                         |  6 +++
 tests/model_fields/tests.py                   | 45 ++++++++++++++++++-
 10 files changed, 155 insertions(+), 50 deletions(-)

diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index 198d62eb14..2125fa2efe 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -99,6 +99,12 @@ class BaseDatabaseOperations(object):
         """
         return "%s"
 
+    def datetime_cast_date_sql(self, field_name, tzname):
+        """
+        Returns the SQL necessary to cast a datetime value to date value.
+        """
+        raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_date() method')
+
     def datetime_extract_sql(self, lookup_type, field_name, tzname):
         """
         Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute' or
diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
index fa0e1edcab..a496079932 100644
--- a/django/db/backends/mysql/operations.py
+++ b/django/db/backends/mysql/operations.py
@@ -39,27 +39,26 @@ class DatabaseOperations(BaseDatabaseOperations):
             sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
         return sql
 
-    def datetime_extract_sql(self, lookup_type, field_name, tzname):
+    def _convert_field_to_tz(self, field_name, tzname):
         if settings.USE_TZ:
             field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name
             params = [tzname]
         else:
             params = []
-        # http://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
-        if lookup_type == 'week_day':
-            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
-            # Note: WEEKDAY() returns 0-6, Monday=0.
-            sql = "DAYOFWEEK(%s)" % field_name
-        else:
-            sql = "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
+        return field_name, params
+
+    def datetime_cast_date_sql(self, field_name, tzname):
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
+        sql = "DATE(%s)" % field_name
+        return sql, params
+
+    def datetime_extract_sql(self, lookup_type, field_name, tzname):
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
+        sql = self.date_extract_sql(lookup_type, field_name)
         return sql, params
 
     def datetime_trunc_sql(self, lookup_type, field_name, tzname):
-        if settings.USE_TZ:
-            field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name
-            params = [tzname]
-        else:
-            params = []
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
         fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
         format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s')  # Use double percents to escape.
         format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py
index 1d3945227c..ed42fcb3cf 100644
--- a/django/db/backends/oracle/operations.py
+++ b/django/db/backends/oracle/operations.py
@@ -114,6 +114,8 @@ WHEN (new.%(col_name)s IS NULL)
     _tzname_re = re.compile(r'^[\w/:+-]+$')
 
     def _convert_field_to_tz(self, field_name, tzname):
+        if not settings.USE_TZ:
+            return field_name
         if not self._tzname_re.match(tzname):
             raise ValueError("Invalid time zone name: %s" % tzname)
         # Convert from UTC to local time, returning TIMESTAMP WITH TIME ZONE.
@@ -127,20 +129,18 @@ WHEN (new.%(col_name)s IS NULL)
         # on DATE values, even though they actually store the time part.
         return "CAST(%s AS TIMESTAMP)" % result
 
+    def datetime_cast_date_sql(self, field_name, tzname):
+        field_name = self._convert_field_to_tz(field_name, tzname)
+        sql = 'TRUNC(%s)' % field_name
+        return sql, []
+
     def datetime_extract_sql(self, lookup_type, field_name, tzname):
-        if settings.USE_TZ:
-            field_name = self._convert_field_to_tz(field_name, tzname)
-        if lookup_type == 'week_day':
-            # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
-            sql = "TO_CHAR(%s, 'D')" % field_name
-        else:
-            # http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions050.htm
-            sql = "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
+        field_name = self._convert_field_to_tz(field_name, tzname)
+        sql = self.date_extract_sql(lookup_type, field_name)
         return sql, []
 
     def datetime_trunc_sql(self, lookup_type, field_name, tzname):
-        if settings.USE_TZ:
-            field_name = self._convert_field_to_tz(field_name, tzname)
+        field_name = self._convert_field_to_tz(field_name, tzname)
         # http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions230.htm#i1002084
         if lookup_type in ('year', 'month'):
             sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py
index 65d6c9154b..866e2ca38b 100644
--- a/django/db/backends/postgresql_psycopg2/operations.py
+++ b/django/db/backends/postgresql_psycopg2/operations.py
@@ -32,26 +32,26 @@ class DatabaseOperations(BaseDatabaseOperations):
         # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
         return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
 
-    def datetime_extract_sql(self, lookup_type, field_name, tzname):
+    def _convert_field_to_tz(self, field_name, tzname):
         if settings.USE_TZ:
             field_name = "%s AT TIME ZONE %%s" % field_name
             params = [tzname]
         else:
             params = []
-        # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
-        if lookup_type == 'week_day':
-            # For consistency across backends, we return Sunday=1, Saturday=7.
-            sql = "EXTRACT('dow' FROM %s) + 1" % field_name
-        else:
-            sql = "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
+        return field_name, params
+
+    def datetime_cast_date_sql(self, field_name, tzname):
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
+        sql = '(%s)::date' % field_name
+        return sql, params
+
+    def datetime_extract_sql(self, lookup_type, field_name, tzname):
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
+        sql = self.date_extract_sql(lookup_type, field_name)
         return sql, params
 
     def datetime_trunc_sql(self, lookup_type, field_name, tzname):
-        if settings.USE_TZ:
-            field_name = "%s AT TIME ZONE %%s" % field_name
-            params = [tzname]
-        else:
-            params = []
+        field_name, params = self._convert_field_to_tz(field_name, tzname)
         # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
         sql = "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
         return sql, params
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 5da6bbd1fb..4bd9609e9d 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -207,6 +207,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         conn = Database.connect(**conn_params)
         conn.create_function("django_date_extract", 2, _sqlite_date_extract)
         conn.create_function("django_date_trunc", 2, _sqlite_date_trunc)
+        conn.create_function("django_datetime_cast_date", 2, _sqlite_datetime_cast_date)
         conn.create_function("django_datetime_extract", 3, _sqlite_datetime_extract)
         conn.create_function("django_datetime_trunc", 3, _sqlite_datetime_trunc)
         conn.create_function("regexp", 2, _sqlite_regexp)
@@ -354,7 +355,7 @@ def _sqlite_date_trunc(lookup_type, dt):
         return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
 
 
-def _sqlite_datetime_extract(lookup_type, dt, tzname):
+def _sqlite_datetime_parse(dt, tzname):
     if dt is None:
         return None
     try:
@@ -363,6 +364,20 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname):
         return None
     if tzname is not None:
         dt = timezone.localtime(dt, pytz.timezone(tzname))
+    return dt
+
+
+def _sqlite_datetime_cast_date(dt, tzname):
+    dt = _sqlite_datetime_parse(dt, tzname)
+    if dt is None:
+        return None
+    return dt.date().isoformat()
+
+
+def _sqlite_datetime_extract(lookup_type, dt, tzname):
+    dt = _sqlite_datetime_parse(dt, tzname)
+    if dt is None:
+        return None
     if lookup_type == 'week_day':
         return (dt.isoweekday() % 7) + 1
     else:
@@ -370,12 +385,9 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname):
 
 
 def _sqlite_datetime_trunc(lookup_type, dt, tzname):
-    try:
-        dt = backend_utils.typecast_timestamp(dt)
-    except (ValueError, TypeError):
+    dt = _sqlite_datetime_parse(dt, tzname)
+    if dt is None:
         return None
-    if tzname is not None:
-        dt = timezone.localtime(dt, pytz.timezone(tzname))
     if lookup_type == 'year':
         return "%i-01-01 00:00:00" % dt.year
     elif lookup_type == 'month':
diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
index 02d688b63d..fe4a75fb37 100644
--- a/django/db/backends/sqlite3/operations.py
+++ b/django/db/backends/sqlite3/operations.py
@@ -68,21 +68,23 @@ class DatabaseOperations(BaseDatabaseOperations):
         # cause a collision with a field name).
         return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)
 
+    def _require_pytz(self):
+        if settings.USE_TZ and pytz is None:
+            raise ImproperlyConfigured("This query requires pytz, but it isn't installed.")
+
+    def datetime_cast_date_sql(self, field_name, tzname):
+        self._require_pytz()
+        return "django_datetime_cast_date(%s, %%s)" % field_name, [tzname]
+
     def datetime_extract_sql(self, lookup_type, field_name, tzname):
         # Same comment as in date_extract_sql.
-        if settings.USE_TZ:
-            if pytz is None:
-                raise ImproperlyConfigured("This query requires pytz, "
-                                           "but it isn't installed.")
+        self._require_pytz()
         return "django_datetime_extract('%s', %s, %%s)" % (
             lookup_type.lower(), field_name), [tzname]
 
     def datetime_trunc_sql(self, lookup_type, field_name, tzname):
         # Same comment as in date_trunc_sql.
-        if settings.USE_TZ:
-            if pytz is None:
-                raise ImproperlyConfigured("This query requires pytz, "
-                                           "but it isn't installed.")
+        self._require_pytz()
         return "django_datetime_trunc('%s', %s, %%s)" % (
             lookup_type.lower(), field_name), [tzname]
 
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index 3351b4ad35..e6fa29a565 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -1463,6 +1463,22 @@ class DateTimeField(DateField):
         return super(DateTimeField, self).formfield(**defaults)
 
 
+@DateTimeField.register_lookup
+class DateTimeDateTransform(Transform):
+    lookup_name = 'date'
+
+    @cached_property
+    def output_field(self):
+        return DateField()
+
+    def as_sql(self, compiler, connection):
+        lhs, lhs_params = compiler.compile(self.lhs)
+        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
+        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
+        lhs_params.extend(tz_params)
+        return sql, lhs_params
+
+
 class DecimalField(Field):
     empty_strings_allowed = False
     default_error_messages = {
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index 0ad944fe3a..0a0749d0c9 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -2463,6 +2463,27 @@ numbers and even characters.
 
     Generally speaking, you can't mix dates and datetimes.
 
+.. fieldlookup:: date
+
+date
+~~~~
+
+.. versionadded:: 1.9
+
+For datetime fields, casts the value as date. Allows chaining additional field
+lookups. Takes a date value.
+
+Example::
+
+    Entry.objects.filter(pub_date__date=datetime.date(2005, 1, 1))
+    Entry.objects.filter(pub_date__date__gt=datetime.date(2005, 1, 1))
+
+(No equivalent SQL code fragment is included for this lookup because
+implementation of the relevant query varies among different database engines.)
+
+When :setting:`USE_TZ` is ``True``, fields are converted to the current time
+zone before filtering.
+
 .. fieldlookup:: year
 
 year
diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt
index 71960c5c3b..72e7fce2dc 100644
--- a/docs/releases/1.9.txt
+++ b/docs/releases/1.9.txt
@@ -233,6 +233,9 @@ Models
   :class:`~django.db.models.Avg` aggregate in order to aggregate over
   non-numeric columns, such as ``DurationField``.
 
+* Added the :lookup:`date` lookup to :class:`~django.db.models.DateTimeField`
+  to allow querying the field by only the date portion.
+
 CSRF
 ^^^^
 
@@ -346,6 +349,9 @@ Database backend API
   ``adapt_<type>field_value()`` to mirror the ``convert_<type>field_value()``
   methods.
 
+* To use the new ``date`` lookup, third-party database backends may need to
+  implement the ``DatabaseOperations.datetime_cast_date_sql()`` method.
+
 Default settings that were tuples are now lists
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py
index 8238aeaa93..3e59f6465b 100644
--- a/tests/model_fields/tests.py
+++ b/tests/model_fields/tests.py
@@ -18,7 +18,8 @@ from django.db.models.fields import (
     TimeField, URLField,
 )
 from django.db.models.fields.files import FileField, ImageField
-from django.utils import six
+from django.test.utils import requires_tz_support
+from django.utils import six, timezone
 from django.utils.functional import lazy
 
 from .models import (
@@ -274,6 +275,48 @@ class DateTimeFieldTests(test.TestCase):
         self.assertEqual(obj.dt, datetim)
         self.assertEqual(obj.t, tim)
 
+    @test.override_settings(USE_TZ=False)
+    def test_lookup_date_without_use_tz(self):
+        d = datetime.date(2014, 3, 12)
+        dt1 = datetime.datetime(2014, 3, 12, 21, 22, 23, 240000)
+        dt2 = datetime.datetime(2014, 3, 11, 21, 22, 23, 240000)
+        t = datetime.time(21, 22, 23, 240000)
+        m = DateTimeModel.objects.create(d=d, dt=dt1, t=t)
+        # Other model with different datetime.
+        DateTimeModel.objects.create(d=d, dt=dt2, t=t)
+        self.assertEqual(m, DateTimeModel.objects.get(dt__date=d))
+
+    @requires_tz_support
+    @test.skipUnlessDBFeature('has_zoneinfo_database')
+    @test.override_settings(USE_TZ=True, TIME_ZONE='America/Vancouver')
+    def test_lookup_date_with_use_tz(self):
+        d = datetime.date(2014, 3, 12)
+        # The following is equivalent to UTC 2014-03-12 18:34:23.24000.
+        dt1 = datetime.datetime(
+            2014, 3, 12, 10, 22, 23, 240000,
+            tzinfo=timezone.get_current_timezone()
+        )
+        # The following is equivalent to UTC 2014-03-13 05:34:23.24000.
+        dt2 = datetime.datetime(
+            2014, 3, 12, 21, 22, 23, 240000,
+            tzinfo=timezone.get_current_timezone()
+        )
+        t = datetime.time(21, 22, 23, 240000)
+        m1 = DateTimeModel.objects.create(d=d, dt=dt1, t=t)
+        m2 = DateTimeModel.objects.create(d=d, dt=dt2, t=t)
+        # In Vancouver, we expect both results.
+        self.assertQuerysetEqual(
+            DateTimeModel.objects.filter(dt__date=d),
+            [repr(m1), repr(m2)],
+            ordered=False
+        )
+        with self.settings(TIME_ZONE='UTC'):
+            # But in UTC, the __date only matches one of them.
+            self.assertQuerysetEqual(
+                DateTimeModel.objects.filter(dt__date=d),
+                [repr(m1)]
+            )
+
 
 class BooleanFieldTests(test.TestCase):
     def _test_get_db_prep_lookup(self, f):