mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Refs #28478 -- Prevented connection attempts against disallowed databases in tests.
Mocking connect as well as cursor methods makes sure an appropriate error message is surfaced when running a subset of test attempting to access a a disallowed database.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							a96b901932
						
					
				
				
					commit
					f5b635086a
				
			| @@ -135,7 +135,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): | ||||
|         return '%s was rendered.' % self.template_name | ||||
|  | ||||
|  | ||||
| class _CursorFailure: | ||||
| class _DatabaseFailure: | ||||
|     def __init__(self, wrapped, message): | ||||
|         self.wrapped = wrapped | ||||
|         self.message = message | ||||
| @@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase): | ||||
|  | ||||
|     databases = _SimpleTestCaseDatabasesDescriptor() | ||||
|     _disallowed_database_msg = ( | ||||
|         'Database queries are not allowed in SimpleTestCase subclasses. ' | ||||
|         'Either subclass TestCase or TransactionTestCase to ensure proper ' | ||||
|         'test isolation or add %(alias)r to %(test)s.databases to silence ' | ||||
|         'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase ' | ||||
|         'subclasses. Either subclass TestCase or TransactionTestCase to ensure ' | ||||
|         'proper test isolation or add %(alias)r to %(test)s.databases to silence ' | ||||
|         'this failure.' | ||||
|     ) | ||||
|     _disallowed_connection_methods = [ | ||||
|         ('connect', 'connections'), | ||||
|         ('temporary_connection', 'connections'), | ||||
|         ('cursor', 'queries'), | ||||
|         ('chunked_cursor', 'queries'), | ||||
|     ] | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
| @@ -188,7 +194,7 @@ class SimpleTestCase(unittest.TestCase): | ||||
|         if cls._modified_settings: | ||||
|             cls._cls_modified_context = modify_settings(cls._modified_settings) | ||||
|             cls._cls_modified_context.enable() | ||||
|         cls._add_cursor_failures() | ||||
|         cls._add_databases_failures() | ||||
|  | ||||
|     @classmethod | ||||
|     def _validate_databases(cls): | ||||
| @@ -208,31 +214,34 @@ class SimpleTestCase(unittest.TestCase): | ||||
|         return frozenset(cls.databases) | ||||
|  | ||||
|     @classmethod | ||||
|     def _add_cursor_failures(cls): | ||||
|     def _add_databases_failures(cls): | ||||
|         cls.databases = cls._validate_databases() | ||||
|         for alias in connections: | ||||
|             if alias in cls.databases: | ||||
|                 continue | ||||
|             connection = connections[alias] | ||||
|             message = cls._disallowed_database_msg % { | ||||
|                 'test': '%s.%s' % (cls.__module__, cls.__qualname__), | ||||
|                 'alias': alias, | ||||
|             } | ||||
|             connection.cursor = _CursorFailure(connection.cursor, message) | ||||
|             connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message) | ||||
|             for name, operation in cls._disallowed_connection_methods: | ||||
|                 message = cls._disallowed_database_msg % { | ||||
|                     'test': '%s.%s' % (cls.__module__, cls.__qualname__), | ||||
|                     'alias': alias, | ||||
|                     'operation': operation, | ||||
|                 } | ||||
|                 method = getattr(connection, name) | ||||
|                 setattr(connection, name, _DatabaseFailure(method, message)) | ||||
|  | ||||
|     @classmethod | ||||
|     def _remove_cursor_failures(cls): | ||||
|     def _remove_databases_failures(cls): | ||||
|         for alias in connections: | ||||
|             if alias in cls.databases: | ||||
|                 continue | ||||
|             connection = connections[alias] | ||||
|             connection.cursor = connection.cursor.wrapped | ||||
|             connection.chunked_cursor = connection.chunked_cursor.wrapped | ||||
|             for name, _ in cls._disallowed_connection_methods: | ||||
|                 method = getattr(connection, name) | ||||
|                 setattr(connection, name, method.wrapped) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         cls._remove_cursor_failures() | ||||
|         cls._remove_databases_failures() | ||||
|         if hasattr(cls, '_cls_modified_context'): | ||||
|             cls._cls_modified_context.disable() | ||||
|             delattr(cls, '_cls_modified_context') | ||||
| @@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase): | ||||
|  | ||||
|     databases = _TransactionTestCaseDatabasesDescriptor() | ||||
|     _disallowed_database_msg = ( | ||||
|         'Database queries to %(alias)r are not allowed in this test. Add ' | ||||
|         '%(alias)r to %(test)s.databases to ensure proper test isolation ' | ||||
|         'Database %(operation)s to %(alias)r are not allowed in this test. ' | ||||
|         'Add %(alias)r to %(test)s.databases to ensure proper test isolation ' | ||||
|         'and silence this failure.' | ||||
|     ) | ||||
|  | ||||
| @@ -1121,13 +1130,13 @@ class TestCase(TransactionTestCase): | ||||
|                     call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) | ||||
|                 except Exception: | ||||
|                     cls._rollback_atomics(cls.cls_atomics) | ||||
|                     cls._remove_cursor_failures() | ||||
|                     cls._remove_databases_failures() | ||||
|                     raise | ||||
|         try: | ||||
|             cls.setUpTestData() | ||||
|         except Exception: | ||||
|             cls._rollback_atomics(cls.cls_atomics) | ||||
|             cls._remove_cursor_failures() | ||||
|             cls._remove_databases_failures() | ||||
|             raise | ||||
|  | ||||
|     @classmethod | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from django.db import IntegrityError, transaction | ||||
| from django.db import IntegrityError, connections, transaction | ||||
| from django.test import TestCase, skipUnlessDBFeature | ||||
|  | ||||
| from .models import Car, PossessedCar | ||||
| @@ -19,6 +19,17 @@ class TestTestCase(TestCase): | ||||
|         finally: | ||||
|             self._rollback_atomics = rollback_atomics | ||||
|  | ||||
|     def test_disallowed_database_connection(self): | ||||
|         message = ( | ||||
|             "Database connections to 'other' are not allowed in this test. " | ||||
|             "Add 'other' to test_utils.test_testcase.TestTestCase.databases to " | ||||
|             "ensure proper test isolation and silence this failure." | ||||
|         ) | ||||
|         with self.assertRaisesMessage(AssertionError, message): | ||||
|             connections['other'].connect() | ||||
|         with self.assertRaisesMessage(AssertionError, message): | ||||
|             connections['other'].temporary_connection() | ||||
|  | ||||
|     def test_disallowed_database_queries(self): | ||||
|         message = ( | ||||
|             "Database queries to 'other' are not allowed in this test. " | ||||
|   | ||||
| @@ -1159,11 +1159,24 @@ class TestBadSetUpTestData(TestCase): | ||||
|  | ||||
|  | ||||
| class DisallowedDatabaseQueriesTests(SimpleTestCase): | ||||
|     def test_disallowed_database_connections(self): | ||||
|         expected_message = ( | ||||
|             "Database connections to 'default' are not allowed in SimpleTestCase " | ||||
|             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||
|             "ensure proper test isolation or add 'default' to " | ||||
|             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||
|             "silence this failure." | ||||
|         ) | ||||
|         with self.assertRaisesMessage(AssertionError, expected_message): | ||||
|             connection.connect() | ||||
|         with self.assertRaisesMessage(AssertionError, expected_message): | ||||
|             connection.temporary_connection() | ||||
|  | ||||
|     def test_disallowed_database_queries(self): | ||||
|         expected_message = ( | ||||
|             "Database queries are not allowed in SimpleTestCase subclasses. " | ||||
|             "Either subclass TestCase or TransactionTestCase to ensure proper " | ||||
|             "test isolation or add 'default' to " | ||||
|             "Database queries to 'default' are not allowed in SimpleTestCase " | ||||
|             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||
|             "ensure proper test isolation or add 'default' to " | ||||
|             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||
|             "silence this failure." | ||||
|         ) | ||||
| @@ -1172,9 +1185,9 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase): | ||||
|  | ||||
|     def test_disallowed_database_chunked_cursor_queries(self): | ||||
|         expected_message = ( | ||||
|             "Database queries are not allowed in SimpleTestCase subclasses. " | ||||
|             "Either subclass TestCase or TransactionTestCase to ensure proper " | ||||
|             "test isolation or add 'default' to " | ||||
|             "Database queries to 'default' are not allowed in SimpleTestCase " | ||||
|             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||
|             "ensure proper test isolation or add 'default' to " | ||||
|             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||
|             "silence this failure." | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user