mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #33277 -- Disallowed database connections in threads in SimpleTestCase.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							45f778eded
						
					
				
				
					commit
					8fb0be3500
				
			| @@ -10,6 +10,7 @@ from contextlib import contextmanager | |||||||
| from copy import copy, deepcopy | from copy import copy, deepcopy | ||||||
| from difflib import get_close_matches | from difflib import get_close_matches | ||||||
| from functools import wraps | from functools import wraps | ||||||
|  | from unittest import mock | ||||||
| from unittest.suite import _DebugResult | from unittest.suite import _DebugResult | ||||||
| from unittest.util import safe_repr | from unittest.util import safe_repr | ||||||
| from urllib.parse import ( | from urllib.parse import ( | ||||||
| @@ -37,6 +38,7 @@ from django.core.management.sql import emit_post_migrate_signal | |||||||
| from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler | from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler | ||||||
| from django.core.signals import setting_changed | from django.core.signals import setting_changed | ||||||
| from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction | from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction | ||||||
|  | from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper | ||||||
| from django.forms.fields import CharField | from django.forms.fields import CharField | ||||||
| from django.http import QueryDict | from django.http import QueryDict | ||||||
| from django.http.request import split_domain_port, validate_host | from django.http.request import split_domain_port, validate_host | ||||||
| @@ -255,6 +257,13 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|                 } |                 } | ||||||
|                 method = getattr(connection, name) |                 method = getattr(connection, name) | ||||||
|                 setattr(connection, name, _DatabaseFailure(method, message)) |                 setattr(connection, name, _DatabaseFailure(method, message)) | ||||||
|  |         cls.enterClassContext( | ||||||
|  |             mock.patch.object( | ||||||
|  |                 BaseDatabaseWrapper, | ||||||
|  |                 "ensure_connection", | ||||||
|  |                 new=cls.ensure_connection_patch_method(), | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _remove_databases_failures(cls): |     def _remove_databases_failures(cls): | ||||||
| @@ -266,6 +275,28 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|                 method = getattr(connection, name) |                 method = getattr(connection, name) | ||||||
|                 setattr(connection, name, method.wrapped) |                 setattr(connection, name, method.wrapped) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def ensure_connection_patch_method(cls): | ||||||
|  |         real_ensure_connection = BaseDatabaseWrapper.ensure_connection | ||||||
|  |  | ||||||
|  |         def patched_ensure_connection(self, *args, **kwargs): | ||||||
|  |             if ( | ||||||
|  |                 self.connection is None | ||||||
|  |                 and self.alias not in cls.databases | ||||||
|  |                 and self.alias != NO_DB_ALIAS | ||||||
|  |             ): | ||||||
|  |                 # Connection has not yet been established, but the alias is not allowed. | ||||||
|  |                 message = cls._disallowed_database_msg % { | ||||||
|  |                     "test": f"{cls.__module__}.{cls.__qualname__}", | ||||||
|  |                     "alias": self.alias, | ||||||
|  |                     "operation": "threaded connections", | ||||||
|  |                 } | ||||||
|  |                 return _DatabaseFailure(self.ensure_connection, message)() | ||||||
|  |  | ||||||
|  |             real_ensure_connection(self, *args, **kwargs) | ||||||
|  |  | ||||||
|  |         return patched_ensure_connection | ||||||
|  |  | ||||||
|     def __call__(self, result=None): |     def __call__(self, result=None): | ||||||
|         """ |         """ | ||||||
|         Wrapper around default __call__ method to perform common Django test |         Wrapper around default __call__ method to perform common Django test | ||||||
|   | |||||||
| @@ -250,6 +250,9 @@ Tests | |||||||
| * The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that | * The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that | ||||||
|   an HTML fragment is not contained in the given HTML haystack. |   an HTML fragment is not contained in the given HTML haystack. | ||||||
|  |  | ||||||
|  | * In order to enforce test isolation, database connections inside threads are | ||||||
|  |   no longer allowed in :class:`~django.test.SimpleTestCase`. | ||||||
|  |  | ||||||
| URLs | URLs | ||||||
| ~~~~ | ~~~~ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
|  | import threading | ||||||
| import unittest | import unittest | ||||||
| import warnings | import warnings | ||||||
| from io import StringIO | from io import StringIO | ||||||
| @@ -2093,6 +2094,29 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase): | |||||||
|         with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message): |         with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message): | ||||||
|             next(Car.objects.iterator()) |             next(Car.objects.iterator()) | ||||||
|  |  | ||||||
|  |     def test_disallowed_thread_database_connection(self): | ||||||
|  |         expected_message = ( | ||||||
|  |             "Database threaded connections to 'default' are not allowed in " | ||||||
|  |             "SimpleTestCase subclasses. Either subclass TestCase or TransactionTestCase" | ||||||
|  |             " to ensure proper test isolation or add 'default' to " | ||||||
|  |             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||||
|  |             "silence this failure." | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         exceptions = [] | ||||||
|  |  | ||||||
|  |         def thread_func(): | ||||||
|  |             try: | ||||||
|  |                 Car.objects.first() | ||||||
|  |             except DatabaseOperationForbidden as e: | ||||||
|  |                 exceptions.append(e) | ||||||
|  |  | ||||||
|  |         t = threading.Thread(target=thread_func) | ||||||
|  |         t.start() | ||||||
|  |         t.join() | ||||||
|  |         self.assertEqual(len(exceptions), 1) | ||||||
|  |         self.assertEqual(exceptions[0].args[0], expected_message) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AllowedDatabaseQueriesTests(SimpleTestCase): | class AllowedDatabaseQueriesTests(SimpleTestCase): | ||||||
|     databases = {"default"} |     databases = {"default"} | ||||||
| @@ -2103,6 +2127,14 @@ class AllowedDatabaseQueriesTests(SimpleTestCase): | |||||||
|     def test_allowed_database_chunked_cursor_queries(self): |     def test_allowed_database_chunked_cursor_queries(self): | ||||||
|         next(Car.objects.iterator(), None) |         next(Car.objects.iterator(), None) | ||||||
|  |  | ||||||
|  |     def test_allowed_threaded_database_queries(self): | ||||||
|  |         def thread_func(): | ||||||
|  |             next(Car.objects.iterator(), None) | ||||||
|  |  | ||||||
|  |         t = threading.Thread(target=thread_func) | ||||||
|  |         t.start() | ||||||
|  |         t.join() | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseAliasTests(SimpleTestCase): | class DatabaseAliasTests(SimpleTestCase): | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user