mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #31347 -- Checked allow_migrate() in CreateExtension operation.
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							d88365708c
						
					
				
				
					commit
					ec292f261d
				
			| @@ -1,7 +1,7 @@ | |||||||
| from django.contrib.postgres.signals import ( | from django.contrib.postgres.signals import ( | ||||||
|     get_citext_oids, get_hstore_oids, register_type_handlers, |     get_citext_oids, get_hstore_oids, register_type_handlers, | ||||||
| ) | ) | ||||||
| from django.db import NotSupportedError | from django.db import NotSupportedError, router | ||||||
| from django.db.migrations import AddIndex, RemoveIndex | from django.db.migrations import AddIndex, RemoveIndex | ||||||
| from django.db.migrations.operations.base import Operation | from django.db.migrations.operations.base import Operation | ||||||
|  |  | ||||||
| @@ -16,7 +16,10 @@ class CreateExtension(Operation): | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def database_forwards(self, app_label, schema_editor, from_state, to_state): |     def database_forwards(self, app_label, schema_editor, from_state, to_state): | ||||||
|         if schema_editor.connection.vendor != 'postgresql': |         if ( | ||||||
|  |             schema_editor.connection.vendor != 'postgresql' or | ||||||
|  |             not router.allow_migrate(schema_editor.connection.alias, app_label) | ||||||
|  |         ): | ||||||
|             return |             return | ||||||
|         schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) |         schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) | ||||||
|         # Clear cached, stale oids. |         # Clear cached, stale oids. | ||||||
| @@ -28,6 +31,8 @@ class CreateExtension(Operation): | |||||||
|         register_type_handlers(schema_editor.connection) |         register_type_handlers(schema_editor.connection) | ||||||
|  |  | ||||||
|     def database_backwards(self, app_label, schema_editor, from_state, to_state): |     def database_backwards(self, app_label, schema_editor, from_state, to_state): | ||||||
|  |         if not router.allow_migrate(schema_editor.connection.alias, app_label): | ||||||
|  |             return | ||||||
|         schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) |         schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) | ||||||
|         # Clear cached, stale oids. |         # Clear cached, stale oids. | ||||||
|         get_hstore_oids.cache_clear() |         get_hstore_oids.cache_clear() | ||||||
|   | |||||||
| @@ -3,12 +3,16 @@ import unittest | |||||||
| from migrations.test_base import OperationTestBase | from migrations.test_base import OperationTestBase | ||||||
|  |  | ||||||
| from django.db import NotSupportedError, connection | from django.db import NotSupportedError, connection | ||||||
|  | from django.db.migrations.state import ProjectState | ||||||
| from django.db.models import Index | from django.db.models import Index | ||||||
| from django.test import modify_settings | from django.test import modify_settings, override_settings | ||||||
|  | from django.test.utils import CaptureQueriesContext | ||||||
|  |  | ||||||
|  | from . import PostgreSQLTestCase | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     from django.contrib.postgres.operations import ( |     from django.contrib.postgres.operations import ( | ||||||
|         AddIndexConcurrently, RemoveIndexConcurrently, |         AddIndexConcurrently, CreateExtension, RemoveIndexConcurrently, | ||||||
|     ) |     ) | ||||||
|     from django.contrib.postgres.indexes import BrinIndex, BTreeIndex |     from django.contrib.postgres.indexes import BrinIndex, BTreeIndex | ||||||
| except ImportError: | except ImportError: | ||||||
| @@ -141,3 +145,44 @@ class RemoveIndexConcurrentlyTests(OperationTestBase): | |||||||
|         self.assertEqual(name, 'RemoveIndexConcurrently') |         self.assertEqual(name, 'RemoveIndexConcurrently') | ||||||
|         self.assertEqual(args, []) |         self.assertEqual(args, []) | ||||||
|         self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'}) |         self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class NoExtensionRouter(): | ||||||
|  |     def allow_migrate(self, db, app_label, **hints): | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.') | ||||||
|  | class CreateExtensionTests(PostgreSQLTestCase): | ||||||
|  |     app_label = 'test_allow_create_extention' | ||||||
|  |  | ||||||
|  |     @override_settings(DATABASE_ROUTERS=[NoExtensionRouter()]) | ||||||
|  |     def test_no_allow_migrate(self): | ||||||
|  |         operation = CreateExtension('uuid-ossp') | ||||||
|  |         project_state = ProjectState() | ||||||
|  |         new_state = project_state.clone() | ||||||
|  |         # Don't create an extension. | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             with connection.schema_editor(atomic=False) as editor: | ||||||
|  |                 operation.database_forwards(self.app_label, editor, project_state, new_state) | ||||||
|  |         self.assertEqual(len(captured_queries), 0) | ||||||
|  |         # Reversal. | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             with connection.schema_editor(atomic=False) as editor: | ||||||
|  |                 operation.database_backwards(self.app_label, editor, new_state, project_state) | ||||||
|  |         self.assertEqual(len(captured_queries), 0) | ||||||
|  |  | ||||||
|  |     def test_allow_migrate(self): | ||||||
|  |         operation = CreateExtension('uuid-ossp') | ||||||
|  |         project_state = ProjectState() | ||||||
|  |         new_state = project_state.clone() | ||||||
|  |         # Create an extension. | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             with connection.schema_editor(atomic=False) as editor: | ||||||
|  |                 operation.database_forwards(self.app_label, editor, project_state, new_state) | ||||||
|  |         self.assertIn('CREATE EXTENSION', captured_queries[0]['sql']) | ||||||
|  |         # Reversal. | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             with connection.schema_editor(atomic=False) as editor: | ||||||
|  |                 operation.database_backwards(self.app_label, editor, new_state, project_state) | ||||||
|  |         self.assertIn('DROP EXTENSION', captured_queries[0]['sql']) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user