mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Refs #28478 -- Prevented connection attempts against disallowed databases in tests.
Mocking connect as well as cursor methods makes sure an appropriate error message is surfaced when running a subset of test attempting to access a a disallowed database.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							a96b901932
						
					
				
				
					commit
					f5b635086a
				
			| @@ -135,7 +135,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): | |||||||
|         return '%s was rendered.' % self.template_name |         return '%s was rendered.' % self.template_name | ||||||
|  |  | ||||||
|  |  | ||||||
| class _CursorFailure: | class _DatabaseFailure: | ||||||
|     def __init__(self, wrapped, message): |     def __init__(self, wrapped, message): | ||||||
|         self.wrapped = wrapped |         self.wrapped = wrapped | ||||||
|         self.message = message |         self.message = message | ||||||
| @@ -173,11 +173,17 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|  |  | ||||||
|     databases = _SimpleTestCaseDatabasesDescriptor() |     databases = _SimpleTestCaseDatabasesDescriptor() | ||||||
|     _disallowed_database_msg = ( |     _disallowed_database_msg = ( | ||||||
|         'Database queries are not allowed in SimpleTestCase subclasses. ' |         'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase ' | ||||||
|         'Either subclass TestCase or TransactionTestCase to ensure proper ' |         'subclasses. Either subclass TestCase or TransactionTestCase to ensure ' | ||||||
|         'test isolation or add %(alias)r to %(test)s.databases to silence ' |         'proper test isolation or add %(alias)r to %(test)s.databases to silence ' | ||||||
|         'this failure.' |         'this failure.' | ||||||
|     ) |     ) | ||||||
|  |     _disallowed_connection_methods = [ | ||||||
|  |         ('connect', 'connections'), | ||||||
|  |         ('temporary_connection', 'connections'), | ||||||
|  |         ('cursor', 'queries'), | ||||||
|  |         ('chunked_cursor', 'queries'), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def setUpClass(cls): |     def setUpClass(cls): | ||||||
| @@ -188,7 +194,7 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|         if cls._modified_settings: |         if cls._modified_settings: | ||||||
|             cls._cls_modified_context = modify_settings(cls._modified_settings) |             cls._cls_modified_context = modify_settings(cls._modified_settings) | ||||||
|             cls._cls_modified_context.enable() |             cls._cls_modified_context.enable() | ||||||
|         cls._add_cursor_failures() |         cls._add_databases_failures() | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _validate_databases(cls): |     def _validate_databases(cls): | ||||||
| @@ -208,31 +214,34 @@ class SimpleTestCase(unittest.TestCase): | |||||||
|         return frozenset(cls.databases) |         return frozenset(cls.databases) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _add_cursor_failures(cls): |     def _add_databases_failures(cls): | ||||||
|         cls.databases = cls._validate_databases() |         cls.databases = cls._validate_databases() | ||||||
|         for alias in connections: |         for alias in connections: | ||||||
|             if alias in cls.databases: |             if alias in cls.databases: | ||||||
|                 continue |                 continue | ||||||
|             connection = connections[alias] |             connection = connections[alias] | ||||||
|             message = cls._disallowed_database_msg % { |             for name, operation in cls._disallowed_connection_methods: | ||||||
|                 'test': '%s.%s' % (cls.__module__, cls.__qualname__), |                 message = cls._disallowed_database_msg % { | ||||||
|                 'alias': alias, |                     'test': '%s.%s' % (cls.__module__, cls.__qualname__), | ||||||
|             } |                     'alias': alias, | ||||||
|             connection.cursor = _CursorFailure(connection.cursor, message) |                     'operation': operation, | ||||||
|             connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message) |                 } | ||||||
|  |                 method = getattr(connection, name) | ||||||
|  |                 setattr(connection, name, _DatabaseFailure(method, message)) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _remove_cursor_failures(cls): |     def _remove_databases_failures(cls): | ||||||
|         for alias in connections: |         for alias in connections: | ||||||
|             if alias in cls.databases: |             if alias in cls.databases: | ||||||
|                 continue |                 continue | ||||||
|             connection = connections[alias] |             connection = connections[alias] | ||||||
|             connection.cursor = connection.cursor.wrapped |             for name, _ in cls._disallowed_connection_methods: | ||||||
|             connection.chunked_cursor = connection.chunked_cursor.wrapped |                 method = getattr(connection, name) | ||||||
|  |                 setattr(connection, name, method.wrapped) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def tearDownClass(cls): |     def tearDownClass(cls): | ||||||
|         cls._remove_cursor_failures() |         cls._remove_databases_failures() | ||||||
|         if hasattr(cls, '_cls_modified_context'): |         if hasattr(cls, '_cls_modified_context'): | ||||||
|             cls._cls_modified_context.disable() |             cls._cls_modified_context.disable() | ||||||
|             delattr(cls, '_cls_modified_context') |             delattr(cls, '_cls_modified_context') | ||||||
| @@ -894,8 +903,8 @@ class TransactionTestCase(SimpleTestCase): | |||||||
|  |  | ||||||
|     databases = _TransactionTestCaseDatabasesDescriptor() |     databases = _TransactionTestCaseDatabasesDescriptor() | ||||||
|     _disallowed_database_msg = ( |     _disallowed_database_msg = ( | ||||||
|         'Database queries to %(alias)r are not allowed in this test. Add ' |         'Database %(operation)s to %(alias)r are not allowed in this test. ' | ||||||
|         '%(alias)r to %(test)s.databases to ensure proper test isolation ' |         'Add %(alias)r to %(test)s.databases to ensure proper test isolation ' | ||||||
|         'and silence this failure.' |         'and silence this failure.' | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
| @@ -1121,13 +1130,13 @@ class TestCase(TransactionTestCase): | |||||||
|                     call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) |                     call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) | ||||||
|                 except Exception: |                 except Exception: | ||||||
|                     cls._rollback_atomics(cls.cls_atomics) |                     cls._rollback_atomics(cls.cls_atomics) | ||||||
|                     cls._remove_cursor_failures() |                     cls._remove_databases_failures() | ||||||
|                     raise |                     raise | ||||||
|         try: |         try: | ||||||
|             cls.setUpTestData() |             cls.setUpTestData() | ||||||
|         except Exception: |         except Exception: | ||||||
|             cls._rollback_atomics(cls.cls_atomics) |             cls._rollback_atomics(cls.cls_atomics) | ||||||
|             cls._remove_cursor_failures() |             cls._remove_databases_failures() | ||||||
|             raise |             raise | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from django.db import IntegrityError, transaction | from django.db import IntegrityError, connections, transaction | ||||||
| from django.test import TestCase, skipUnlessDBFeature | from django.test import TestCase, skipUnlessDBFeature | ||||||
|  |  | ||||||
| from .models import Car, PossessedCar | from .models import Car, PossessedCar | ||||||
| @@ -19,6 +19,17 @@ class TestTestCase(TestCase): | |||||||
|         finally: |         finally: | ||||||
|             self._rollback_atomics = rollback_atomics |             self._rollback_atomics = rollback_atomics | ||||||
|  |  | ||||||
|  |     def test_disallowed_database_connection(self): | ||||||
|  |         message = ( | ||||||
|  |             "Database connections to 'other' are not allowed in this test. " | ||||||
|  |             "Add 'other' to test_utils.test_testcase.TestTestCase.databases to " | ||||||
|  |             "ensure proper test isolation and silence this failure." | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(AssertionError, message): | ||||||
|  |             connections['other'].connect() | ||||||
|  |         with self.assertRaisesMessage(AssertionError, message): | ||||||
|  |             connections['other'].temporary_connection() | ||||||
|  |  | ||||||
|     def test_disallowed_database_queries(self): |     def test_disallowed_database_queries(self): | ||||||
|         message = ( |         message = ( | ||||||
|             "Database queries to 'other' are not allowed in this test. " |             "Database queries to 'other' are not allowed in this test. " | ||||||
|   | |||||||
| @@ -1159,11 +1159,24 @@ class TestBadSetUpTestData(TestCase): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DisallowedDatabaseQueriesTests(SimpleTestCase): | class DisallowedDatabaseQueriesTests(SimpleTestCase): | ||||||
|  |     def test_disallowed_database_connections(self): | ||||||
|  |         expected_message = ( | ||||||
|  |             "Database connections to 'default' are not allowed in SimpleTestCase " | ||||||
|  |             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||||
|  |             "ensure proper test isolation or add 'default' to " | ||||||
|  |             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||||
|  |             "silence this failure." | ||||||
|  |         ) | ||||||
|  |         with self.assertRaisesMessage(AssertionError, expected_message): | ||||||
|  |             connection.connect() | ||||||
|  |         with self.assertRaisesMessage(AssertionError, expected_message): | ||||||
|  |             connection.temporary_connection() | ||||||
|  |  | ||||||
|     def test_disallowed_database_queries(self): |     def test_disallowed_database_queries(self): | ||||||
|         expected_message = ( |         expected_message = ( | ||||||
|             "Database queries are not allowed in SimpleTestCase subclasses. " |             "Database queries to 'default' are not allowed in SimpleTestCase " | ||||||
|             "Either subclass TestCase or TransactionTestCase to ensure proper " |             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||||
|             "test isolation or add 'default' to " |             "ensure proper test isolation or add 'default' to " | ||||||
|             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " |             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||||
|             "silence this failure." |             "silence this failure." | ||||||
|         ) |         ) | ||||||
| @@ -1172,9 +1185,9 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase): | |||||||
|  |  | ||||||
|     def test_disallowed_database_chunked_cursor_queries(self): |     def test_disallowed_database_chunked_cursor_queries(self): | ||||||
|         expected_message = ( |         expected_message = ( | ||||||
|             "Database queries are not allowed in SimpleTestCase subclasses. " |             "Database queries to 'default' are not allowed in SimpleTestCase " | ||||||
|             "Either subclass TestCase or TransactionTestCase to ensure proper " |             "subclasses. Either subclass TestCase or TransactionTestCase to " | ||||||
|             "test isolation or add 'default' to " |             "ensure proper test isolation or add 'default' to " | ||||||
|             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " |             "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " | ||||||
|             "silence this failure." |             "silence this failure." | ||||||
|         ) |         ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user