From 877c800f255ccaa7abde1fb944de45d1616f5cc9 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sun, 19 Jun 2022 23:46:22 -0400 Subject: [PATCH] Refs CVE-2022-34265 -- Properly escaped Extract() and Trunc() parameters. Co-authored-by: Mariusz Felisiak --- django/db/backends/base/operations.py | 21 ++- django/db/backends/mysql/operations.py | 117 +++++++++-------- django/db/backends/oracle/operations.py | 123 ++++++++++-------- django/db/backends/postgresql/operations.py | 82 ++++++------ django/db/backends/sqlite3/operations.py | 48 +++---- django/db/models/functions/datetime.py | 46 ++++--- docs/releases/4.1.txt | 14 ++ tests/backends/base/test_operations.py | 16 +-- tests/custom_lookups/tests.py | 2 +- .../datetime/test_extract_trunc.py | 14 +- 10 files changed, 263 insertions(+), 220 deletions(-) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 680ea1fc50..dd29068495 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -9,7 +9,6 @@ from django.db import NotSupportedError, transaction from django.db.backends import utils from django.utils import timezone from django.utils.encoding import force_str -from django.utils.regex_helper import _lazy_re_compile class BaseDatabaseOperations: @@ -55,8 +54,6 @@ class BaseDatabaseOperations: # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported. explain_prefix = None - extract_trunc_lookup_pattern = _lazy_re_compile(r"[\w\-_()]+") - def __init__(self, connection): self.connection = connection self._cache = None @@ -103,7 +100,7 @@ class BaseDatabaseOperations: """ return "%s" - def date_extract_sql(self, lookup_type, field_name): + def date_extract_sql(self, lookup_type, sql, params): """ Given a lookup_type of 'year', 'month', or 'day', return the SQL that extracts a value from the given date field field_name. @@ -113,7 +110,7 @@ class BaseDatabaseOperations: "method" ) - def date_trunc_sql(self, lookup_type, field_name, tzname=None): + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): """ Given a lookup_type of 'year', 'month', or 'day', return the SQL that truncates the given date or datetime field field_name to a date object @@ -127,7 +124,7 @@ class BaseDatabaseOperations: "method." ) - def datetime_cast_date_sql(self, field_name, tzname): + def datetime_cast_date_sql(self, sql, params, tzname): """ Return the SQL to cast a datetime value to date value. """ @@ -136,7 +133,7 @@ class BaseDatabaseOperations: "datetime_cast_date_sql() method." ) - def datetime_cast_time_sql(self, field_name, tzname): + def datetime_cast_time_sql(self, sql, params, tzname): """ Return the SQL to cast a datetime value to time value. """ @@ -145,7 +142,7 @@ class BaseDatabaseOperations: "datetime_cast_time_sql() method" ) - def datetime_extract_sql(self, lookup_type, field_name, tzname): + def datetime_extract_sql(self, lookup_type, sql, params, tzname): """ Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or 'second', return the SQL that extracts a value from the given @@ -156,7 +153,7 @@ class BaseDatabaseOperations: "method" ) - def datetime_trunc_sql(self, lookup_type, field_name, tzname): + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): """ Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or 'second', return the SQL that truncates the given datetime field @@ -167,7 +164,7 @@ class BaseDatabaseOperations: "method" ) - def time_trunc_sql(self, lookup_type, field_name, tzname=None): + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): """ Given a lookup_type of 'hour', 'minute' or 'second', return the SQL that truncates the given time or datetime field field_name to a time @@ -180,12 +177,12 @@ class BaseDatabaseOperations: "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method" ) - def time_extract_sql(self, lookup_type, field_name): + def time_extract_sql(self, lookup_type, sql, params): """ Given a lookup_type of 'hour', 'minute', or 'second', return the SQL that extracts a value from the given time field field_name. """ - return self.date_extract_sql(lookup_type, field_name) + return self.date_extract_sql(lookup_type, sql, params) def deferrable_sql(self): """ diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 7c4e21671b..34cdfc0292 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -7,6 +7,7 @@ from django.db.models import Exists, ExpressionWrapper, Lookup from django.db.models.constants import OnConflict from django.utils import timezone from django.utils.encoding import force_str +from django.utils.regex_helper import _lazy_re_compile class DatabaseOperations(BaseDatabaseOperations): @@ -37,117 +38,115 @@ class DatabaseOperations(BaseDatabaseOperations): cast_char_field_without_max_length = "char" explain_prefix = "EXPLAIN" - def date_extract_sql(self, lookup_type, field_name): + # EXTRACT format cannot be passed in parameters. + _extract_format_re = _lazy_re_compile(r"[A-Z_]+") + + def date_extract_sql(self, lookup_type, sql, params): # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html if lookup_type == "week_day": # DAYOFWEEK() returns an integer, 1-7, Sunday=1. - return "DAYOFWEEK(%s)" % field_name + return f"DAYOFWEEK({sql})", params elif lookup_type == "iso_week_day": # WEEKDAY() returns an integer, 0-6, Monday=0. - return "WEEKDAY(%s) + 1" % field_name + return f"WEEKDAY({sql}) + 1", params elif lookup_type == "week": # Override the value of default_week_format for consistency with # other database backends. # Mode 3: Monday, 1-53, with 4 or more days this year. - return "WEEK(%s, 3)" % field_name + return f"WEEK({sql}, 3)", params elif lookup_type == "iso_year": # Get the year part from the YEARWEEK function, which returns a # number as year * 100 + week. - return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name + return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params else: # EXTRACT returns 1-53 based on ISO-8601 for the week number. - return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) + lookup_type = lookup_type.upper() + if not self._extract_format_re.fullmatch(lookup_type): + raise ValueError(f"Invalid loookup type: {lookup_type!r}") + return f"EXTRACT({lookup_type} FROM {sql})", params - def date_trunc_sql(self, lookup_type, field_name, tzname=None): - field_name = self._convert_field_to_tz(field_name, tzname) + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_field_to_tz(sql, params, tzname) fields = { - "year": "%%Y-01-01", - "month": "%%Y-%%m-01", - } # Use double percents to escape. + "year": "%Y-01-01", + "month": "%Y-%m-01", + } if lookup_type in fields: format_str = fields[lookup_type] - return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str) + return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str) elif lookup_type == "quarter": return ( - "MAKEDATE(YEAR(%s), 1) + " - "INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" - % (field_name, field_name) + f"MAKEDATE(YEAR({sql}), 1) + " + f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER", + (*params, *params), ) elif lookup_type == "week": - return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name) + return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params) else: - return "DATE(%s)" % (field_name) + return f"DATE({sql})", params def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) return f"{sign}{offset}" if offset else tzname - def _convert_field_to_tz(self, field_name, tzname): + def _convert_field_to_tz(self, sql, params, tzname): if tzname and settings.USE_TZ and self.connection.timezone_name != tzname: - field_name = "CONVERT_TZ(%s, '%s', '%s')" % ( - field_name, + return f"CONVERT_TZ({sql}, %s, %s)", ( + *params, self.connection.timezone_name, self._prepare_tzname_delta(tzname), ) - return field_name + return sql, params - def datetime_cast_date_sql(self, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return "DATE(%s)" % field_name + def datetime_cast_date_sql(self, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) + return f"DATE({sql})", params - def datetime_cast_time_sql(self, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return "TIME(%s)" % field_name + def datetime_cast_time_sql(self, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) + return f"TIME({sql})", params - def datetime_extract_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return self.date_extract_sql(lookup_type, field_name) + def datetime_extract_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) + return self.date_extract_sql(lookup_type, sql, params) - def datetime_trunc_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) fields = ["year", "month", "day", "hour", "minute", "second"] - format = ( - "%%Y-", - "%%m", - "-%%d", - " %%H:", - "%%i", - ":%%s", - ) # Use double percents to escape. + format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s") format_def = ("0000-", "01", "-01", " 00:", "00", ":00") if lookup_type == "quarter": return ( - "CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + " - "INTERVAL QUARTER({field_name}) QUARTER - " - + "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)" - ).format(field_name=field_name) + f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + " + f"INTERVAL QUARTER({sql}) QUARTER - " + f"INTERVAL 1 QUARTER, %s) AS DATETIME)" + ), (*params, *params, "%Y-%m-01 00:00:00") if lookup_type == "week": return ( - "CAST(DATE_FORMAT(DATE_SUB({field_name}, " - "INTERVAL WEEKDAY({field_name}) DAY), " - "'%%Y-%%m-%%d 00:00:00') AS DATETIME)" - ).format(field_name=field_name) + f"CAST(DATE_FORMAT(" + f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)" + ), (*params, *params, "%Y-%m-%d 00:00:00") try: i = fields.index(lookup_type) + 1 except ValueError: - sql = field_name + pass else: format_str = "".join(format[:i] + format_def[i:]) - sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) - return sql + return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str) + return sql, params - def time_trunc_sql(self, lookup_type, field_name, tzname=None): - field_name = self._convert_field_to_tz(field_name, tzname) + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_field_to_tz(sql, params, tzname) fields = { - "hour": "%%H:00:00", - "minute": "%%H:%%i:00", - "second": "%%H:%%i:%%s", - } # Use double percents to escape. + "hour": "%H:00:00", + "minute": "%H:%i:00", + "second": "%H:%i:%s", + } if lookup_type in fields: format_str = fields[lookup_type] - return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str) + return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str) else: - return "TIME(%s)" % (field_name) + return f"TIME({sql})", params def fetch_returned_insert_rows(self, cursor): """ diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index b044adadda..70548e358f 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -77,34 +77,46 @@ END; f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY" ) - def date_extract_sql(self, lookup_type, field_name): + # EXTRACT format cannot be passed in parameters. + _extract_format_re = _lazy_re_compile(r"[A-Z_]+") + + def date_extract_sql(self, lookup_type, sql, params): + extract_sql = f"TO_CHAR({sql}, %s)" + extract_param = None if lookup_type == "week_day": # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday. - return "TO_CHAR(%s, 'D')" % field_name + extract_param = "D" elif lookup_type == "iso_week_day": - return "TO_CHAR(%s - 1, 'D')" % field_name + extract_sql = f"TO_CHAR({sql} - 1, %s)" + extract_param = "D" elif lookup_type == "week": # IW = ISO week number - return "TO_CHAR(%s, 'IW')" % field_name + extract_param = "IW" elif lookup_type == "quarter": - return "TO_CHAR(%s, 'Q')" % field_name + extract_param = "Q" elif lookup_type == "iso_year": - return "TO_CHAR(%s, 'IYYY')" % field_name + extract_param = "IYYY" else: + lookup_type = lookup_type.upper() + if not self._extract_format_re.fullmatch(lookup_type): + raise ValueError(f"Invalid loookup type: {lookup_type!r}") # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html - return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) + return f"EXTRACT({lookup_type} FROM {sql})", params + return extract_sql, (*params, extract_param) - def date_trunc_sql(self, lookup_type, field_name, tzname=None): - field_name = self._convert_field_to_tz(field_name, tzname) + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_field_to_tz(sql, params, tzname) # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html + trunc_param = None if lookup_type in ("year", "month"): - return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) + trunc_param = lookup_type.upper() elif lookup_type == "quarter": - return "TRUNC(%s, 'Q')" % field_name + trunc_param = "Q" elif lookup_type == "week": - return "TRUNC(%s, 'IW')" % field_name + trunc_param = "IW" else: - return "TRUNC(%s)" % field_name + return f"TRUNC({sql})", params + return f"TRUNC({sql}, %s)", (*params, trunc_param) # Oracle crashes with "ORA-03113: end-of-file on communication channel" # if the time zone name is passed in parameter. Use interpolation instead. @@ -116,77 +128,80 @@ END; tzname, sign, offset = split_tzname_delta(tzname) return f"{sign}{offset}" if offset else tzname - def _convert_field_to_tz(self, field_name, tzname): + def _convert_field_to_tz(self, sql, params, tzname): if not (settings.USE_TZ and tzname): - return field_name + return sql, params if not self._tzname_re.match(tzname): raise ValueError("Invalid time zone name: %s" % tzname) # Convert from connection timezone to the local time, returning # TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the # TIME ZONE details. if self.connection.timezone_name != tzname: - return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % ( - field_name, - self.connection.timezone_name, - self._prepare_tzname_delta(tzname), + from_timezone_name = self.connection.timezone_name + to_timezone_name = self._prepare_tzname_delta(tzname) + return ( + f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE " + f"'{to_timezone_name}') AS TIMESTAMP)", + params, ) - return field_name + return sql, params - def datetime_cast_date_sql(self, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return "TRUNC(%s)" % field_name + def datetime_cast_date_sql(self, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) + return f"TRUNC({sql})", params - def datetime_cast_time_sql(self, field_name, tzname): + def datetime_cast_time_sql(self, sql, params, tzname): # Since `TimeField` values are stored as TIMESTAMP change to the # default date and convert the field to the specified timezone. + sql, params = self._convert_field_to_tz(sql, params, tzname) convert_datetime_sql = ( - "TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), " - "'YYYY-MM-DD HH24:MI:SS.FF')" - ) % self._convert_field_to_tz(field_name, tzname) - return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % ( - field_name, - convert_datetime_sql, + f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), " + f"'YYYY-MM-DD HH24:MI:SS.FF')" + ) + return ( + f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END", + (*params, *params), ) - def datetime_extract_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return self.date_extract_sql(lookup_type, field_name) + def datetime_extract_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) + return self.date_extract_sql(lookup_type, sql, params) - def datetime_trunc_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_field_to_tz(sql, params, tzname) # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html + trunc_param = None if lookup_type in ("year", "month"): - sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) + trunc_param = lookup_type.upper() elif lookup_type == "quarter": - sql = "TRUNC(%s, 'Q')" % field_name + trunc_param = "Q" elif lookup_type == "week": - sql = "TRUNC(%s, 'IW')" % field_name - elif lookup_type == "day": - sql = "TRUNC(%s)" % field_name + trunc_param = "IW" elif lookup_type == "hour": - sql = "TRUNC(%s, 'HH24')" % field_name + trunc_param = "HH24" elif lookup_type == "minute": - sql = "TRUNC(%s, 'MI')" % field_name + trunc_param = "MI" + elif lookup_type == "day": + return f"TRUNC({sql})", params else: - sql = ( - "CAST(%s AS DATE)" % field_name - ) # Cast to DATE removes sub-second precision. - return sql + # Cast to DATE removes sub-second precision. + return f"CAST({sql} AS DATE)", params + return f"TRUNC({sql}, %s)", (*params, trunc_param) - def time_trunc_sql(self, lookup_type, field_name, tzname=None): + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): # The implementation is similar to `datetime_trunc_sql` as both # `DateTimeField` and `TimeField` are stored as TIMESTAMP where # the date part of the later is ignored. - field_name = self._convert_field_to_tz(field_name, tzname) + sql, params = self._convert_field_to_tz(sql, params, tzname) + trunc_param = None if lookup_type == "hour": - sql = "TRUNC(%s, 'HH24')" % field_name + trunc_param = "HH24" elif lookup_type == "minute": - sql = "TRUNC(%s, 'MI')" % field_name + trunc_param = "MI" elif lookup_type == "second": - sql = ( - "CAST(%s AS DATE)" % field_name - ) # Cast to DATE removes sub-second precision. - return sql + # Cast to DATE removes sub-second precision. + return f"CAST({sql} AS DATE)", params + return f"TRUNC({sql}, %s)", (*params, trunc_param) def get_db_converters(self, expression): converters = super().get_db_converters(expression) diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index ec162d53f4..e8eb06c9e2 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -47,22 +47,24 @@ class DatabaseOperations(BaseDatabaseOperations): ) return "%s" - def date_extract_sql(self, lookup_type, field_name): + def date_extract_sql(self, lookup_type, sql, params): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT + extract_sql = f"EXTRACT(%s FROM {sql})" + extract_param = lookup_type if lookup_type == "week_day": # For consistency across backends, we return Sunday=1, Saturday=7. - return "EXTRACT('dow' FROM %s) + 1" % field_name + extract_sql = f"EXTRACT(%s FROM {sql}) + 1" + extract_param = "dow" elif lookup_type == "iso_week_day": - return "EXTRACT('isodow' FROM %s)" % field_name + extract_param = "isodow" elif lookup_type == "iso_year": - return "EXTRACT('isoyear' FROM %s)" % field_name - else: - return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) + extract_param = "isoyear" + return extract_sql, (extract_param, *params) - def date_trunc_sql(self, lookup_type, field_name, tzname=None): - field_name = self._convert_field_to_tz(field_name, tzname) + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_sql_to_tz(sql, params, tzname) # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC - return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) + return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) @@ -71,43 +73,47 @@ class DatabaseOperations(BaseDatabaseOperations): return f"{tzname}{sign}{offset}" return tzname - def _convert_field_to_tz(self, field_name, tzname): + def _convert_sql_to_tz(self, sql, params, tzname): if tzname and settings.USE_TZ: - field_name = "%s AT TIME ZONE '%s'" % ( - field_name, - self._prepare_tzname_delta(tzname), + tzname_param = self._prepare_tzname_delta(tzname) + return f"{sql} AT TIME ZONE %s", (*params, tzname_param) + return sql, params + + def datetime_cast_date_sql(self, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + return f"({sql})::date", params + + def datetime_cast_time_sql(self, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + return f"({sql})::time", params + + def datetime_extract_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + if lookup_type == "second": + # Truncate fractional seconds. + return ( + f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))", + ("second", "second", *params), ) - return field_name + return self.date_extract_sql(lookup_type, sql, params) - def datetime_cast_date_sql(self, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return "(%s)::date" % field_name - - def datetime_cast_time_sql(self, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - return "(%s)::time" % field_name - - def datetime_extract_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) - if lookup_type == "second": - # Truncate fractional seconds. - return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))" - return self.date_extract_sql(lookup_type, field_name) - - def datetime_trunc_sql(self, lookup_type, field_name, tzname): - field_name = self._convert_field_to_tz(field_name, tzname) + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): + sql, params = self._convert_sql_to_tz(sql, params, tzname) # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC - return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) + return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params) - def time_extract_sql(self, lookup_type, field_name): + def time_extract_sql(self, lookup_type, sql, params): if lookup_type == "second": # Truncate fractional seconds. - return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))" - return self.date_extract_sql(lookup_type, field_name) + return ( + f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))", + ("second", "second", *params), + ) + return self.date_extract_sql(lookup_type, sql, params) - def time_trunc_sql(self, lookup_type, field_name, tzname=None): - field_name = self._convert_field_to_tz(field_name, tzname) - return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): + sql, params = self._convert_sql_to_tz(sql, params, tzname) + return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params) def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 7c7cfce1ba..0d3a4060ac 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -69,13 +69,13 @@ class DatabaseOperations(BaseDatabaseOperations): "accepting multiple arguments." ) - def date_extract_sql(self, lookup_type, field_name): + def date_extract_sql(self, lookup_type, sql, params): """ Support EXTRACT with a user-defined function django_date_extract() that's registered in connect(). Use single quotes because this is a string and could otherwise cause a collision with a field name. """ - return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name) + return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params) def fetch_returned_insert_rows(self, cursor): """ @@ -88,53 +88,53 @@ class DatabaseOperations(BaseDatabaseOperations): """Do nothing since formatting is handled in the custom function.""" return sql - def date_trunc_sql(self, lookup_type, field_name, tzname=None): - return "django_date_trunc('%s', %s, %s, %s)" % ( + def date_trunc_sql(self, lookup_type, sql, params, tzname=None): + return f"django_date_trunc(%s, {sql}, %s, %s)", ( lookup_type.lower(), - field_name, + *params, *self._convert_tznames_to_sql(tzname), ) - def time_trunc_sql(self, lookup_type, field_name, tzname=None): - return "django_time_trunc('%s', %s, %s, %s)" % ( + def time_trunc_sql(self, lookup_type, sql, params, tzname=None): + return f"django_time_trunc(%s, {sql}, %s, %s)", ( lookup_type.lower(), - field_name, + *params, *self._convert_tznames_to_sql(tzname), ) def _convert_tznames_to_sql(self, tzname): if tzname and settings.USE_TZ: - return "'%s'" % tzname, "'%s'" % self.connection.timezone_name - return "NULL", "NULL" + return tzname, self.connection.timezone_name + return None, None - def datetime_cast_date_sql(self, field_name, tzname): - return "django_datetime_cast_date(%s, %s, %s)" % ( - field_name, + def datetime_cast_date_sql(self, sql, params, tzname): + return f"django_datetime_cast_date({sql}, %s, %s)", ( + *params, *self._convert_tznames_to_sql(tzname), ) - def datetime_cast_time_sql(self, field_name, tzname): - return "django_datetime_cast_time(%s, %s, %s)" % ( - field_name, + def datetime_cast_time_sql(self, sql, params, tzname): + return f"django_datetime_cast_time({sql}, %s, %s)", ( + *params, *self._convert_tznames_to_sql(tzname), ) - def datetime_extract_sql(self, lookup_type, field_name, tzname): - return "django_datetime_extract('%s', %s, %s, %s)" % ( + def datetime_extract_sql(self, lookup_type, sql, params, tzname): + return f"django_datetime_extract(%s, {sql}, %s, %s)", ( lookup_type.lower(), - field_name, + *params, *self._convert_tznames_to_sql(tzname), ) - def datetime_trunc_sql(self, lookup_type, field_name, tzname): - return "django_datetime_trunc('%s', %s, %s, %s)" % ( + def datetime_trunc_sql(self, lookup_type, sql, params, tzname): + return f"django_datetime_trunc(%s, {sql}, %s, %s)", ( lookup_type.lower(), - field_name, + *params, *self._convert_tznames_to_sql(tzname), ) - def time_extract_sql(self, lookup_type, field_name): - return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name) + def time_extract_sql(self, lookup_type, sql, params): + return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params) def pk_default_value(self): return "NULL" diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index 5f98e6bba1..f833c09973 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -51,25 +51,31 @@ class Extract(TimezoneMixin, Transform): super().__init__(expression, **extra) def as_sql(self, compiler, connection): - if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.lookup_name): - raise ValueError("Invalid lookup_name: %s" % self.lookup_name) sql, params = compiler.compile(self.lhs) lhs_output_field = self.lhs.output_field if isinstance(lhs_output_field, DateTimeField): tzname = self.get_tzname() - sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) + sql, params = connection.ops.datetime_extract_sql( + self.lookup_name, sql, tuple(params), tzname + ) elif self.tzinfo is not None: raise ValueError("tzinfo can only be used with DateTimeField.") elif isinstance(lhs_output_field, DateField): - sql = connection.ops.date_extract_sql(self.lookup_name, sql) + sql, params = connection.ops.date_extract_sql( + self.lookup_name, sql, tuple(params) + ) elif isinstance(lhs_output_field, TimeField): - sql = connection.ops.time_extract_sql(self.lookup_name, sql) + sql, params = connection.ops.time_extract_sql( + self.lookup_name, sql, tuple(params) + ) elif isinstance(lhs_output_field, DurationField): if not connection.features.has_native_duration_field: raise ValueError( "Extract requires native DurationField database support." ) - sql = connection.ops.time_extract_sql(self.lookup_name, sql) + sql, params = connection.ops.time_extract_sql( + self.lookup_name, sql, tuple(params) + ) else: # resolve_expression has already validated the output_field so this # assert should never be hit. @@ -237,25 +243,29 @@ class TruncBase(TimezoneMixin, Transform): super().__init__(expression, output_field=output_field, **extra) def as_sql(self, compiler, connection): - if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.kind): - raise ValueError("Invalid kind: %s" % self.kind) - inner_sql, inner_params = compiler.compile(self.lhs) + sql, params = compiler.compile(self.lhs) tzname = None if isinstance(self.lhs.output_field, DateTimeField): tzname = self.get_tzname() elif self.tzinfo is not None: raise ValueError("tzinfo can only be used with DateTimeField.") if isinstance(self.output_field, DateTimeField): - sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) + sql, params = connection.ops.datetime_trunc_sql( + self.kind, sql, tuple(params), tzname + ) elif isinstance(self.output_field, DateField): - sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname) + sql, params = connection.ops.date_trunc_sql( + self.kind, sql, tuple(params), tzname + ) elif isinstance(self.output_field, TimeField): - sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname) + sql, params = connection.ops.time_trunc_sql( + self.kind, sql, tuple(params), tzname + ) else: raise ValueError( "Trunc only valid on DateField, TimeField, or DateTimeField." ) - return sql, inner_params + return sql, params def resolve_expression( self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False @@ -384,10 +394,9 @@ class TruncDate(TruncBase): def as_sql(self, compiler, connection): # Cast to date rather than truncate to date. - lhs, lhs_params = compiler.compile(self.lhs) + sql, params = compiler.compile(self.lhs) tzname = self.get_tzname() - sql = connection.ops.datetime_cast_date_sql(lhs, tzname) - return sql, lhs_params + return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname) class TruncTime(TruncBase): @@ -397,10 +406,9 @@ class TruncTime(TruncBase): def as_sql(self, compiler, connection): # Cast to time rather than truncate to time. - lhs, lhs_params = compiler.compile(self.lhs) + sql, params = compiler.compile(self.lhs) tzname = self.get_tzname() - sql = connection.ops.datetime_cast_time_sql(lhs, tzname) - return sql, lhs_params + return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname) class TruncHour(TruncBase): diff --git a/docs/releases/4.1.txt b/docs/releases/4.1.txt index ad6400c665..49bbf2dec2 100644 --- a/docs/releases/4.1.txt +++ b/docs/releases/4.1.txt @@ -459,6 +459,20 @@ backends. ``DatabaseOperations.insert_statement()`` method is replaced by ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``. +* Several date and time methods on ``DatabaseOperations`` now take ``sql`` and + ``params`` arguments instead of ``field_name`` and return 2-tuple containing + some SQL and the parameters to be interpolated into that SQL. The changed + methods have these new signatures: + + * ``DatabaseOperations.date_extract_sql(lookup_type, sql, params)`` + * ``DatabaseOperations.datetime_extract_sql(lookup_type, sql, params, tzname)`` + * ``DatabaseOperations.time_extract_sql(lookup_type, sql, params)`` + * ``DatabaseOperations.date_trunc_sql(lookup_type, sql, params, tzname=None)`` + * ``DatabaseOperations.datetime_trunc_sql(self, lookup_type, sql, params, tzname)`` + * ``DatabaseOperations.time_trunc_sql(lookup_type, sql, params, tzname=None)`` + * ``DatabaseOperations.datetime_cast_date_sql(sql, params, tzname)`` + * ``DatabaseOperations.datetime_cast_time_sql(sql, params, tzname)`` + :mod:`django.contrib.gis` ------------------------- diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index b19b7ee558..5260344da7 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -115,49 +115,49 @@ class SimpleDatabaseOperationTests(SimpleTestCase): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "date_extract_sql" ): - self.ops.date_extract_sql(None, None) + self.ops.date_extract_sql(None, None, None) def test_time_extract_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "date_extract_sql" ): - self.ops.time_extract_sql(None, None) + self.ops.time_extract_sql(None, None, None) def test_date_trunc_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "date_trunc_sql" ): - self.ops.date_trunc_sql(None, None) + self.ops.date_trunc_sql(None, None, None) def test_time_trunc_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "time_trunc_sql" ): - self.ops.time_trunc_sql(None, None) + self.ops.time_trunc_sql(None, None, None) def test_datetime_trunc_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "datetime_trunc_sql" ): - self.ops.datetime_trunc_sql(None, None, None) + self.ops.datetime_trunc_sql(None, None, None, None) def test_datetime_cast_date_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "datetime_cast_date_sql" ): - self.ops.datetime_cast_date_sql(None, None) + self.ops.datetime_cast_date_sql(None, None, None) def test_datetime_cast_time_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "datetime_cast_time_sql" ): - self.ops.datetime_cast_time_sql(None, None) + self.ops.datetime_cast_time_sql(None, None, None) def test_datetime_extract_sql(self): with self.assertRaisesMessage( NotImplementedError, self.may_require_msg % "datetime_extract_sql" ): - self.ops.datetime_extract_sql(None, None, None) + self.ops.datetime_extract_sql(None, None, None, None) class DatabaseOperationTests(TestCase): diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index ece67a46a5..b730d7b5e1 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -75,7 +75,7 @@ class YearTransform(models.Transform): def as_sql(self, compiler, connection): lhs_sql, params = compiler.compile(self.lhs) - return connection.ops.date_extract_sql("year", lhs_sql), params + return connection.ops.date_extract_sql("year", lhs_sql, params) @property def output_field(self): diff --git a/tests/db_functions/datetime/test_extract_trunc.py b/tests/db_functions/datetime/test_extract_trunc.py index bb70ed6094..00e3897e68 100644 --- a/tests/db_functions/datetime/test_extract_trunc.py +++ b/tests/db_functions/datetime/test_extract_trunc.py @@ -13,6 +13,7 @@ except ImportError: pytz = None from django.conf import settings +from django.db import DataError, OperationalError from django.db.models import ( DateField, DateTimeField, @@ -244,8 +245,7 @@ class DateFunctionTests(TestCase): self.create_model(start_datetime, end_datetime) self.create_model(end_datetime, start_datetime) - msg = "Invalid lookup_name: " - with self.assertRaisesMessage(ValueError, msg): + with self.assertRaises((DataError, OperationalError, ValueError)): DTModel.objects.filter( start_datetime__year=Extract( "start_datetime", "day' FROM start_datetime)) OR 1=1;--" @@ -940,14 +940,18 @@ class DateFunctionTests(TestCase): end_datetime = timezone.make_aware(end_datetime) self.create_model(start_datetime, end_datetime) self.create_model(end_datetime, start_datetime) - msg = "Invalid kind: " - with self.assertRaisesMessage(ValueError, msg): - DTModel.objects.filter( + # Database backends raise an exception or don't return any results. + try: + exists = DTModel.objects.filter( start_datetime__date=Trunc( "start_datetime", "year', start_datetime)) OR 1=1;--", ) ).exists() + except (DataError, OperationalError): + pass + else: + self.assertIs(exists, False) def test_trunc_func(self): start_datetime = datetime(999, 6, 15, 14, 30, 50, 321)