mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #30171 -- Fixed DatabaseError in servers tests.
Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.
The error appeared after 8c775391b7.
			
			
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| import copy | ||||
| import threading | ||||
| import time | ||||
| import warnings | ||||
| from collections import deque | ||||
| @@ -43,8 +44,7 @@ class BaseDatabaseWrapper: | ||||
|  | ||||
|     queries_limit = 9000 | ||||
|  | ||||
|     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, | ||||
|                  allow_thread_sharing=False): | ||||
|     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): | ||||
|         # Connection related attributes. | ||||
|         # The underlying database connection. | ||||
|         self.connection = None | ||||
| @@ -80,7 +80,8 @@ class BaseDatabaseWrapper: | ||||
|         self.errors_occurred = False | ||||
|  | ||||
|         # Thread-safety related attributes. | ||||
|         self.allow_thread_sharing = allow_thread_sharing | ||||
|         self._thread_sharing_lock = threading.Lock() | ||||
|         self._thread_sharing_count = 0 | ||||
|         self._thread_ident = _thread.get_ident() | ||||
|  | ||||
|         # A list of no-argument functions to run when the transaction commits. | ||||
| @@ -515,12 +516,27 @@ class BaseDatabaseWrapper: | ||||
|  | ||||
|     # ##### Thread safety handling ##### | ||||
|  | ||||
|     @property | ||||
|     def allow_thread_sharing(self): | ||||
|         with self._thread_sharing_lock: | ||||
|             return self._thread_sharing_count > 0 | ||||
|  | ||||
|     def inc_thread_sharing(self): | ||||
|         with self._thread_sharing_lock: | ||||
|             self._thread_sharing_count += 1 | ||||
|  | ||||
|     def dec_thread_sharing(self): | ||||
|         with self._thread_sharing_lock: | ||||
|             if self._thread_sharing_count <= 0: | ||||
|                 raise RuntimeError('Cannot decrement the thread sharing count below zero.') | ||||
|             self._thread_sharing_count -= 1 | ||||
|  | ||||
|     def validate_thread_sharing(self): | ||||
|         """ | ||||
|         Validate that the connection isn't accessed by another thread than the | ||||
|         one which originally created it, unless the connection was explicitly | ||||
|         authorized to be shared between threads (via the `allow_thread_sharing` | ||||
|         property). Raise an exception if the validation fails. | ||||
|         authorized to be shared between threads (via the `inc_thread_sharing()` | ||||
|         method). Raise an exception if the validation fails. | ||||
|         """ | ||||
|         if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): | ||||
|             raise DatabaseError( | ||||
| @@ -589,11 +605,7 @@ class BaseDatabaseWrapper: | ||||
|         potential child threads while (or after) the test database is destroyed. | ||||
|         Refs #10868, #17786, #16969. | ||||
|         """ | ||||
|         return self.__class__( | ||||
|             {**self.settings_dict, 'NAME': None}, | ||||
|             alias=NO_DB_ALIAS, | ||||
|             allow_thread_sharing=False, | ||||
|         ) | ||||
|         return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS) | ||||
|  | ||||
|     def schema_editor(self, *args, **kwargs): | ||||
|         """ | ||||
| @@ -635,7 +647,7 @@ class BaseDatabaseWrapper: | ||||
|         finally: | ||||
|             self.execute_wrappers.pop() | ||||
|  | ||||
|     def copy(self, alias=None, allow_thread_sharing=None): | ||||
|     def copy(self, alias=None): | ||||
|         """ | ||||
|         Return a copy of this connection. | ||||
|  | ||||
| @@ -644,6 +656,4 @@ class BaseDatabaseWrapper: | ||||
|         settings_dict = copy.deepcopy(self.settings_dict) | ||||
|         if alias is None: | ||||
|             alias = self.alias | ||||
|         if allow_thread_sharing is None: | ||||
|             allow_thread_sharing = self.allow_thread_sharing | ||||
|         return type(self)(settings_dict, alias, allow_thread_sharing) | ||||
|         return type(self)(settings_dict, alias) | ||||
|   | ||||
| @@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|                     return self.__class__( | ||||
|                         {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, | ||||
|                         alias=self.alias, | ||||
|                         allow_thread_sharing=False, | ||||
|                     ) | ||||
|         return nodb_connection | ||||
|  | ||||
|   | ||||
| @@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase): | ||||
|             # the server thread. | ||||
|             if conn.vendor == 'sqlite' and conn.is_in_memory_db(): | ||||
|                 # Explicitly enable thread-shareability for this connection | ||||
|                 conn.allow_thread_sharing = True | ||||
|                 conn.inc_thread_sharing() | ||||
|                 connections_override[conn.alias] = conn | ||||
|  | ||||
|         cls._live_server_modified_settings = modify_settings( | ||||
| @@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase): | ||||
|             # Terminate the live server's thread | ||||
|             cls.server_thread.terminate() | ||||
|  | ||||
|         # Restore sqlite in-memory database connections' non-shareability | ||||
|         for conn in connections.all(): | ||||
|             if conn.vendor == 'sqlite' and conn.is_in_memory_db(): | ||||
|                 conn.allow_thread_sharing = False | ||||
|             # Restore sqlite in-memory database connections' non-shareability. | ||||
|             for conn in cls.server_thread.connections_override.values(): | ||||
|                 conn.dec_thread_sharing() | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|   | ||||
| @@ -286,6 +286,9 @@ backends. | ||||
|   * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) | ||||
|   * ``_create_check_sql()`` and ``_delete_check_sql()`` | ||||
|  | ||||
| * The third argument of ``DatabaseWrapper.__init__()``, | ||||
|   ``allow_thread_sharing``, is removed. | ||||
|  | ||||
| Admin actions are no longer collected from base ``ModelAdmin`` classes | ||||
| ---------------------------------------------------------------------- | ||||
|  | ||||
|   | ||||
| @@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase): | ||||
|             connection = connections[DEFAULT_DB_ALIAS] | ||||
|             # Allow thread sharing so the connection can be closed by the | ||||
|             # main thread. | ||||
|             connection.allow_thread_sharing = True | ||||
|             connection.inc_thread_sharing() | ||||
|             connection.cursor() | ||||
|             connections_dict[id(connection)] = connection | ||||
|         try: | ||||
|             for x in range(2): | ||||
|                 t = threading.Thread(target=runner) | ||||
|                 t.start() | ||||
|                 t.join() | ||||
|             # Each created connection got different inner connection. | ||||
|             self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) | ||||
|         # Finish by closing the connections opened by the other threads (the | ||||
|         # connection opened in the main thread will automatically be closed on | ||||
|         # teardown). | ||||
|         finally: | ||||
|             # Finish by closing the connections opened by the other threads | ||||
|             # (the connection opened in the main thread will automatically be | ||||
|             # closed on teardown). | ||||
|             for conn in connections_dict.values(): | ||||
|                 if conn is not connection: | ||||
|                     if conn.allow_thread_sharing: | ||||
|                         conn.close() | ||||
|                         conn.dec_thread_sharing() | ||||
|  | ||||
|     def test_connections_thread_local(self): | ||||
|         """ | ||||
| @@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase): | ||||
|             for conn in connections.all(): | ||||
|                 # Allow thread sharing so the connection can be closed by the | ||||
|                 # main thread. | ||||
|                 conn.allow_thread_sharing = True | ||||
|                 conn.inc_thread_sharing() | ||||
|                 connections_dict[id(conn)] = conn | ||||
|         try: | ||||
|             for x in range(2): | ||||
|                 t = threading.Thread(target=runner) | ||||
|                 t.start() | ||||
|                 t.join() | ||||
|             self.assertEqual(len(connections_dict), 6) | ||||
|         # Finish by closing the connections opened by the other threads (the | ||||
|         # connection opened in the main thread will automatically be closed on | ||||
|         # teardown). | ||||
|         finally: | ||||
|             # Finish by closing the connections opened by the other threads | ||||
|             # (the connection opened in the main thread will automatically be | ||||
|             # closed on teardown). | ||||
|             for conn in connections_dict.values(): | ||||
|                 if conn is not connection: | ||||
|                     if conn.allow_thread_sharing: | ||||
|                         conn.close() | ||||
|                         conn.dec_thread_sharing() | ||||
|  | ||||
|     def test_pass_connection_between_threads(self): | ||||
|         """ | ||||
| @@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase): | ||||
|             t.start() | ||||
|             t.join() | ||||
|  | ||||
|         # Without touching allow_thread_sharing, which should be False by default. | ||||
|         # Without touching thread sharing, which should be False by default. | ||||
|         exceptions = [] | ||||
|         do_thread() | ||||
|         # Forbidden! | ||||
|         self.assertIsInstance(exceptions[0], DatabaseError) | ||||
|  | ||||
|         # If explicitly setting allow_thread_sharing to False | ||||
|         connections['default'].allow_thread_sharing = False | ||||
|         exceptions = [] | ||||
|         do_thread() | ||||
|         # Forbidden! | ||||
|         self.assertIsInstance(exceptions[0], DatabaseError) | ||||
|  | ||||
|         # If explicitly setting allow_thread_sharing to True | ||||
|         connections['default'].allow_thread_sharing = True | ||||
|         # After calling inc_thread_sharing() on the connection. | ||||
|         connections['default'].inc_thread_sharing() | ||||
|         try: | ||||
|             exceptions = [] | ||||
|             do_thread() | ||||
|             # All good | ||||
|             self.assertEqual(exceptions, []) | ||||
|         finally: | ||||
|             connections['default'].dec_thread_sharing() | ||||
|  | ||||
|     def test_closing_non_shared_connections(self): | ||||
|         """ | ||||
| @@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase): | ||||
|                 except DatabaseError as e: | ||||
|                     exceptions.add(e) | ||||
|             # Enable thread sharing | ||||
|             connections['default'].allow_thread_sharing = True | ||||
|             connections['default'].inc_thread_sharing() | ||||
|             try: | ||||
|                 t2 = threading.Thread(target=runner2, args=[connections['default']]) | ||||
|                 t2.start() | ||||
|                 t2.join() | ||||
|             finally: | ||||
|                 connections['default'].dec_thread_sharing() | ||||
|         t1 = threading.Thread(target=runner1) | ||||
|         t1.start() | ||||
|         t1.join() | ||||
|         # No exception was raised | ||||
|         self.assertEqual(len(exceptions), 0) | ||||
|  | ||||
|     def test_thread_sharing_count(self): | ||||
|         self.assertIs(connection.allow_thread_sharing, False) | ||||
|         connection.inc_thread_sharing() | ||||
|         self.assertIs(connection.allow_thread_sharing, True) | ||||
|         connection.inc_thread_sharing() | ||||
|         self.assertIs(connection.allow_thread_sharing, True) | ||||
|         connection.dec_thread_sharing() | ||||
|         self.assertIs(connection.allow_thread_sharing, True) | ||||
|         connection.dec_thread_sharing() | ||||
|         self.assertIs(connection.allow_thread_sharing, False) | ||||
|         msg = 'Cannot decrement the thread sharing count below zero.' | ||||
|         with self.assertRaisesMessage(RuntimeError, msg): | ||||
|             connection.dec_thread_sharing() | ||||
|  | ||||
|  | ||||
| class MySQLPKZeroTests(TestCase): | ||||
|     """ | ||||
|   | ||||
| @@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase): | ||||
|         # Pass a connection to the thread to check they are being closed. | ||||
|         connections_override = {DEFAULT_DB_ALIAS: conn} | ||||
|  | ||||
|         saved_sharing = conn.allow_thread_sharing | ||||
|         conn.inc_thread_sharing() | ||||
|         try: | ||||
|             conn.allow_thread_sharing = True | ||||
|             self.assertTrue(conn.is_usable()) | ||||
|             self.run_live_server_thread(connections_override) | ||||
|             self.assertFalse(conn.is_usable()) | ||||
|         finally: | ||||
|             conn.allow_thread_sharing = saved_sharing | ||||
|             conn.dec_thread_sharing() | ||||
|   | ||||
| @@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase): | ||||
|             # app without having set the required STATIC_URL setting.") | ||||
|             pass | ||||
|         finally: | ||||
|             # Use del to avoid decrementing the database thread sharing count a | ||||
|             # second time. | ||||
|             del cls.server_thread | ||||
|             super().tearDownClass() | ||||
|  | ||||
|     def test_test_test(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user