1
0
mirror of https://github.com/django/django.git synced 2025-10-26 07:06:08 +00:00

Fixed #33379 -- Added minimum database version checks.

Thanks Tim Graham for the review.
This commit is contained in:
Hasan Ramezani
2021-12-27 19:04:59 +01:00
committed by Mariusz Felisiak
parent 737542390a
commit 9ac3ef59f9
16 changed files with 166 additions and 41 deletions

View File

@@ -13,7 +13,7 @@ except ImportError:
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError
from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
@@ -24,6 +24,7 @@ from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()
# RemovedInDjango50Warning
@@ -185,6 +186,29 @@ class BaseDatabaseWrapper:
)
return list(self.queries_log)
def get_database_version(self):
"""Return a tuple of the database's version."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_database_version() "
"method."
)
def check_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
if (
self.features.minimum_database_version is not None
and self.get_database_version() < self.features.minimum_database_version
):
db_version = ".".join(map(str, self.get_database_version()))
min_db_version = ".".join(map(str, self.features.minimum_database_version))
raise NotSupportedError(
f"{self.display_name} {min_db_version} or later is required "
f"(found {db_version})."
)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
@@ -203,10 +227,10 @@ class BaseDatabaseWrapper:
def init_connection_state(self):
"""Initialize the database connection settings."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an init_connection_state() "
"method"
)
global RAN_DB_VERSION_CHECK
if self.alias not in RAN_DB_VERSION_CHECK:
self.check_database_version_supported()
RAN_DB_VERSION_CHECK.add(self.alias)
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""

View File

@@ -3,6 +3,8 @@ from django.utils.functional import cached_property
class BaseDatabaseFeatures:
# An optional tuple indicating the minimum supported database version.
minimum_database_version = None
gis_enabled = False
# Oracle can't group by LOB (large object) data types.
allows_group_by_lob = True

View File

@@ -200,6 +200,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def get_database_version(self):
return self.mysql_version
def get_connection_params(self):
kwargs = {
"conv": django_conversions,
@@ -251,6 +254,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return connection
def init_connection_state(self):
super().init_connection_state()
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on

View File

@@ -48,6 +48,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
@cached_property
def minimum_database_version(self):
if self.connection.mysql_is_mariadb:
return (10, 2)
else:
return (5, 7)
@cached_property
def test_collations(self):
charset = "utf8"

View File

@@ -239,6 +239,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
self.features.can_return_columns_from_insert = use_returning_into
def get_database_version(self):
return self.oracle_version
def get_connection_params(self):
conn_params = self.settings_dict["OPTIONS"].copy()
if "use_returning_into" in conn_params:
@@ -255,6 +258,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
def init_connection_state(self):
super().init_connection_state()
cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of

View File

@@ -4,6 +4,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (19,)
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
# BLOB" when grouping by LOBs (#24096).
allows_group_by_lob = False

View File

@@ -153,6 +153,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
def get_database_version(self):
"""
Return a tuple of the database's version.
E.g. for pg_version 120004, return (12, 4).
"""
return divmod(self.pg_version, 10000)
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
@@ -236,6 +243,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
def init_connection_state(self):
super().init_connection_state()
self.connection.set_client_encoding("UTF8")
timezone_changed = self.ensure_timezone()

View File

@@ -6,6 +6,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (10,)
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True

View File

@@ -29,15 +29,6 @@ def decoder(conv_func):
return lambda s: conv_func(s.decode())
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
)
check_sqlite_version()
Database.register_converter("bool", b"1".__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
@@ -168,6 +159,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
kwargs.update({"check_same_thread": False, "uri": True})
return kwargs
def get_database_version(self):
return self.Database.sqlite_version_info
@async_unsafe
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
@@ -179,9 +173,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.execute("PRAGMA legacy_alter_table = OFF")
return conn
def init_connection_state(self):
pass
def create_cursor(self, name=None):
return self.connection.cursor(factory=SQLiteCursorWrapper)

View File

@@ -9,6 +9,7 @@ from .base import Database
class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (3, 9)
test_db_allows_multiple_connections = False
supports_unspecified_pk = True
supports_timezones = False

View File

@@ -157,6 +157,18 @@ CSRF
* ...
Database backends
~~~~~~~~~~~~~~~~~
* Third-party database backends can now specify the minimum required version of
the database using the ``DatabaseFeatures.minimum_database_version``
attribute which is a tuple (e.g. ``(10, 0)`` means "10.0"). If a minimum
version is specified, backends must also implement
``DatabaseWrapper.get_database_version()``, which returns a tuple of the
current database version. The backend's
``DatabaseWrapper.init_connection_state()`` method must call ``super()`` in
order for the check to run.
Decorators
~~~~~~~~~~

View File

@@ -41,6 +41,19 @@ class DatabaseWrapperTests(SimpleTestCase):
self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
self.assertNotEqual(connection.display_name, "unknown")
def test_get_database_version(self):
with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
msg = (
"subclasses of BaseDatabaseWrapper may require a "
"get_database_version() method."
)
with self.assertRaisesMessage(NotImplementedError, msg):
BaseDatabaseWrapper().get_database_version()
def test_check_database_version_supported_with_none_as_database_version(self):
with patch.object(connection.features, "minimum_database_version", None):
connection.check_database_version_supported()
class ExecuteWrapperTests(TestCase):
@staticmethod
@@ -297,3 +310,25 @@ class ConnectionHealthChecksTests(SimpleTestCase):
connection.commit()
connection.set_autocommit(True)
self.assertIs(new_connection, connection.connection)
class MultiDatabaseTests(TestCase):
databases = {"default", "other"}
def test_multi_database_init_connection_state_called_once(self):
for db in self.databases:
with self.subTest(database=db):
with patch.object(connections[db], "commit", return_value=None):
with patch.object(
connections[db],
"check_database_version_supported",
) as mocked_check_database_version_supported:
connections[db].init_connection_state()
after_first_calls = len(
mocked_check_database_version_supported.mock_calls
)
connections[db].init_connection_state()
self.assertEqual(
len(mocked_check_database_version_supported.mock_calls),
after_first_calls,
)

View File

@@ -1,8 +1,9 @@
import unittest
from contextlib import contextmanager
from unittest import mock
from django.core.exceptions import ImproperlyConfigured
from django.db import connection
from django.db import NotSupportedError, connection
from django.test import TestCase, override_settings
@@ -99,3 +100,19 @@ class IsolationLevelTests(TestCase):
)
with self.assertRaisesMessage(ImproperlyConfigured, msg):
new_connection.cursor()
@unittest.skipUnless(connection.vendor == "mysql", "MySQL tests")
class Tests(TestCase):
@mock.patch.object(connection, "get_database_version")
def test_check_database_version_supported(self, mocked_get_database_version):
if connection.mysql_is_mariadb:
mocked_get_database_version.return_value = (10, 1)
msg = "MariaDB 10.2 or later is required (found 10.1)."
else:
mocked_get_database_version.return_value = (5, 6)
msg = "MySQL 5.7 or later is required (found 5.6)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)

View File

@@ -1,14 +1,15 @@
import unittest
from unittest import mock
from django.db import DatabaseError, connection
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import BooleanField
from django.test import TransactionTestCase
from django.test import TestCase, TransactionTestCase
from ..models import Square, VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
class Tests(unittest.TestCase):
class Tests(TestCase):
def test_quote_name(self):
"""'%' chars are escaped for query execution."""
name = '"SOME%NAME"'
@@ -56,6 +57,17 @@ class Tests(unittest.TestCase):
field.set_attributes_from_name("is_nice")
self.assertIn('"IS_NICE" IN (0,1)', field.db_check(connection))
@mock.patch.object(
connection,
"get_database_version",
return_value=(18, 1),
)
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "Oracle 19 or later is required (found 18.1)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)
@unittest.skipUnless(connection.vendor == "oracle", "Oracle tests")
class TransactionalTests(TransactionTestCase):

View File

@@ -4,7 +4,13 @@ from io import StringIO
from unittest import mock
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connection, connections
from django.db import (
DEFAULT_DB_ALIAS,
DatabaseError,
NotSupportedError,
connection,
connections,
)
from django.db.backends.base.base import BaseDatabaseWrapper
from django.test import TestCase, override_settings
@@ -303,3 +309,15 @@ class Tests(TestCase):
[q["sql"] for q in connection.queries],
[copy_expert_sql, "COPY django_session TO STDOUT"],
)
def test_get_database_version(self):
new_connection = connection.copy()
new_connection.pg_version = 110009
self.assertEqual(new_connection.get_database_version(), (11, 9))
@mock.patch.object(connection, "get_database_version", return_value=(9, 6))
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "PostgreSQL 10 or later is required (found 9.6)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)

View File

@@ -4,10 +4,8 @@ import tempfile
import threading
import unittest
from pathlib import Path
from sqlite3 import dbapi2
from unittest import mock
from django.core.exceptions import ImproperlyConfigured
from django.db import NotSupportedError, connection, transaction
from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance
from django.db.utils import ConnectionHandler
@@ -21,28 +19,11 @@ from django.test.utils import isolate_apps
from ..models import Author, Item, Object, Square
try:
from django.db.backends.sqlite3.base import check_sqlite_version
except ImproperlyConfigured:
# Ignore "SQLite is too old" when running tests on another database.
pass
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
class Tests(TestCase):
longMessage = True
def test_check_sqlite_version(self):
msg = "SQLite 3.9.0 or later is required (found 3.8.11.1)."
with mock.patch.object(
dbapi2, "sqlite_version_info", (3, 8, 11, 1)
), mock.patch.object(
dbapi2, "sqlite_version", "3.8.11.1"
), self.assertRaisesMessage(
ImproperlyConfigured, msg
):
check_sqlite_version()
def test_aggregation(self):
"""Raise NotSupportedError when aggregating on date/time fields."""
for aggregate in (Sum, Avg, Variance, StdDev):
@@ -125,6 +106,13 @@ class Tests(TestCase):
connections["default"].close()
self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db")))
@mock.patch.object(connection, "get_database_version", return_value=(3, 8))
def test_check_database_version_supported(self, mocked_get_database_version):
msg = "SQLite 3.9 or later is required (found 3.8)."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.check_database_version_supported()
self.assertTrue(mocked_get_database_version.called)
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
@isolate_apps("backends")