mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	Fixed #17258 -- Moved threading.local from DatabaseWrapper to the django.db.connections dictionary. This allows connections to be explicitly shared between multiple threads and is particularly useful for enabling the sharing of in-memory SQLite connections. Many thanks to Anssi Kääriäinen for the excellent suggestions and feedback, and to Alex Gaynor for the reviews. Refs #2879.
				
					
				
			git-svn-id: http://code.djangoproject.com/svn/django/trunk@17205 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -22,9 +22,21 @@ router = ConnectionRouter(settings.DATABASE_ROUTERS) | ||||
| # we manually create the dictionary from the settings, passing only the | ||||
| # settings that the database backends care about. Note that TIME_ZONE is used | ||||
| # by the PostgreSQL backends. | ||||
| # we load all these up for backwards compatibility, you should use | ||||
| # We load all these up for backwards compatibility, you should use | ||||
| # connections['default'] instead. | ||||
| connection = connections[DEFAULT_DB_ALIAS] | ||||
| class DefaultConnectionProxy(object): | ||||
|     """ | ||||
|     Proxy for accessing the default DatabaseWrapper object's attributes. If you | ||||
|     need to access the DatabaseWrapper object itself, use | ||||
|     connections[DEFAULT_DB_ALIAS] instead. | ||||
|     """ | ||||
|     def __getattr__(self, item): | ||||
|         return getattr(connections[DEFAULT_DB_ALIAS], item) | ||||
|  | ||||
|     def __setattr__(self, name, value): | ||||
|         return setattr(connections[DEFAULT_DB_ALIAS], name, value) | ||||
|  | ||||
| connection = DefaultConnectionProxy() | ||||
| backend = load_backend(connection.settings_dict['ENGINE']) | ||||
|  | ||||
| # Register an event that closes the database connection | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| from django.db.utils import DatabaseError | ||||
|  | ||||
| try: | ||||
|     import thread | ||||
| except ImportError: | ||||
|     import dummy_thread as thread | ||||
| from threading import local | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| from django.conf import settings | ||||
| @@ -13,14 +14,15 @@ from django.utils.importlib import import_module | ||||
| from django.utils.timezone import is_aware | ||||
|  | ||||
|  | ||||
| class BaseDatabaseWrapper(local): | ||||
| class BaseDatabaseWrapper(object): | ||||
|     """ | ||||
|     Represents a database connection. | ||||
|     """ | ||||
|     ops = None | ||||
|     vendor = 'unknown' | ||||
|  | ||||
|     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): | ||||
|     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, | ||||
|                  allow_thread_sharing=False): | ||||
|         # `settings_dict` should be a dictionary containing keys such as | ||||
|         # NAME, USER, etc. It's called `settings_dict` instead of `settings` | ||||
|         # to disambiguate it from Django settings modules. | ||||
| @@ -34,6 +36,8 @@ class BaseDatabaseWrapper(local): | ||||
|         self.transaction_state = [] | ||||
|         self.savepoint_state = 0 | ||||
|         self._dirty = None | ||||
|         self._thread_ident = thread.get_ident() | ||||
|         self.allow_thread_sharing = allow_thread_sharing | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return self.alias == other.alias | ||||
| @@ -116,6 +120,21 @@ class BaseDatabaseWrapper(local): | ||||
|                 "pending COMMIT/ROLLBACK") | ||||
|         self._dirty = False | ||||
|  | ||||
|     def validate_thread_sharing(self): | ||||
|         """ | ||||
|         Validates that the connection isn't accessed by another thread than the | ||||
|         one which originally created it, unless the connection was explicitly | ||||
|         authorized to be shared between threads (via the `allow_thread_sharing` | ||||
|         property). Raises an exception if the validation fails. | ||||
|         """ | ||||
|         if (not self.allow_thread_sharing | ||||
|             and self._thread_ident != thread.get_ident()): | ||||
|                 raise DatabaseError("DatabaseWrapper objects created in a " | ||||
|                     "thread can only be used in that same thread. The object" | ||||
|                     "with alias '%s' was created in thread id %s and this is " | ||||
|                     "thread id %s." | ||||
|                     % (self.alias, self._thread_ident, thread.get_ident())) | ||||
|  | ||||
|     def is_dirty(self): | ||||
|         """ | ||||
|         Returns True if the current transaction requires a commit for changes to | ||||
| @@ -179,6 +198,7 @@ class BaseDatabaseWrapper(local): | ||||
|         """ | ||||
|         Commits changes if the system is not in managed transaction mode. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         if not self.is_managed(): | ||||
|             self._commit() | ||||
|             self.clean_savepoints() | ||||
| @@ -189,6 +209,7 @@ class BaseDatabaseWrapper(local): | ||||
|         """ | ||||
|         Rolls back changes if the system is not in managed transaction mode. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         if not self.is_managed(): | ||||
|             self._rollback() | ||||
|         else: | ||||
| @@ -198,6 +219,7 @@ class BaseDatabaseWrapper(local): | ||||
|         """ | ||||
|         Does the commit itself and resets the dirty flag. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         self._commit() | ||||
|         self.set_clean() | ||||
|  | ||||
| @@ -205,6 +227,7 @@ class BaseDatabaseWrapper(local): | ||||
|         """ | ||||
|         This function does the rollback itself and resets the dirty flag. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         self._rollback() | ||||
|         self.set_clean() | ||||
|  | ||||
| @@ -228,6 +251,7 @@ class BaseDatabaseWrapper(local): | ||||
|         Rolls back the most recent savepoint (if one exists). Does nothing if | ||||
|         savepoints are not supported. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         if self.savepoint_state: | ||||
|             self._savepoint_rollback(sid) | ||||
|  | ||||
| @@ -236,6 +260,7 @@ class BaseDatabaseWrapper(local): | ||||
|         Commits the most recent savepoint (if one exists). Does nothing if | ||||
|         savepoints are not supported. | ||||
|         """ | ||||
|         self.validate_thread_sharing() | ||||
|         if self.savepoint_state: | ||||
|             self._savepoint_commit(sid) | ||||
|  | ||||
| @@ -269,11 +294,13 @@ class BaseDatabaseWrapper(local): | ||||
|         pass | ||||
|  | ||||
|     def close(self): | ||||
|         self.validate_thread_sharing() | ||||
|         if self.connection is not None: | ||||
|             self.connection.close() | ||||
|             self.connection = None | ||||
|  | ||||
|     def cursor(self): | ||||
|         self.validate_thread_sharing() | ||||
|         if (self.use_debug_cursor or | ||||
|             (self.use_debug_cursor is None and settings.DEBUG)): | ||||
|             cursor = self.make_debug_cursor(self._cursor()) | ||||
|   | ||||
| @@ -7,10 +7,10 @@ standard library. | ||||
|  | ||||
| import datetime | ||||
| import decimal | ||||
| import warnings | ||||
| import re | ||||
| import sys | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.db import utils | ||||
| from django.db.backends import * | ||||
| from django.db.backends.signals import connection_created | ||||
| @@ -241,6 +241,21 @@ class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|                 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, | ||||
|             } | ||||
|             kwargs.update(settings_dict['OPTIONS']) | ||||
|             # Always allow the underlying SQLite connection to be shareable | ||||
|             # between multiple threads. The safe-guarding will be handled at a | ||||
|             # higher level by the `BaseDatabaseWrapper.allow_thread_sharing` | ||||
|             # property. This is necessary as the shareability is disabled by | ||||
|             # default in pysqlite and it cannot be changed once a connection is | ||||
|             # opened. | ||||
|             if 'check_same_thread' in kwargs and kwargs['check_same_thread']: | ||||
|                 warnings.warn( | ||||
|                     'The `check_same_thread` option was provided and set to ' | ||||
|                     'True. It will be overriden with False. Use the ' | ||||
|                     '`DatabaseWrapper.allow_thread_sharing` property instead ' | ||||
|                     'for controlling thread shareability.', | ||||
|                     RuntimeWarning | ||||
|                 ) | ||||
|             kwargs.update({'check_same_thread': False}) | ||||
|             self.connection = Database.connect(**kwargs) | ||||
|             # Register extract, date_trunc, and regexp functions. | ||||
|             self.connection.create_function("django_extract", 2, _sqlite_extract) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import os | ||||
| from threading import local | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| @@ -50,7 +51,7 @@ class ConnectionDoesNotExist(Exception): | ||||
| class ConnectionHandler(object): | ||||
|     def __init__(self, databases): | ||||
|         self.databases = databases | ||||
|         self._connections = {} | ||||
|         self._connections = local() | ||||
|  | ||||
|     def ensure_defaults(self, alias): | ||||
|         """ | ||||
| @@ -73,16 +74,19 @@ class ConnectionHandler(object): | ||||
|             conn.setdefault(setting, None) | ||||
|  | ||||
|     def __getitem__(self, alias): | ||||
|         if alias in self._connections: | ||||
|             return self._connections[alias] | ||||
|         if hasattr(self._connections, alias): | ||||
|             return getattr(self._connections, alias) | ||||
|  | ||||
|         self.ensure_defaults(alias) | ||||
|         db = self.databases[alias] | ||||
|         backend = load_backend(db['ENGINE']) | ||||
|         conn = backend.DatabaseWrapper(db, alias) | ||||
|         self._connections[alias] = conn | ||||
|         setattr(self._connections, alias, conn) | ||||
|         return conn | ||||
|  | ||||
|     def __setitem__(self, key, value): | ||||
|         setattr(self._connections, key, value) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self.databases) | ||||
|  | ||||
|   | ||||
| @@ -673,6 +673,32 @@ datetimes are now stored without time zone information in SQLite. When | ||||
| :setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime | ||||
| object, Django raises an exception. | ||||
|  | ||||
| Database connection's thread-locality | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| ``DatabaseWrapper`` objects (i.e. the connection objects referenced by | ||||
| ``django.db.connection`` and ``django.db.connections["some_alias"]``) used to | ||||
| be thread-local. They are now global objects in order to be potentially shared | ||||
| between multiple threads. While the individual connection objects are now | ||||
| global, the ``django.db.connections`` dictionary referencing those objects is | ||||
| still thread-local. Therefore if you just use the ORM or | ||||
| ``DatabaseWrapper.cursor()`` then the behavior is still the same as before. | ||||
| Note, however, that ``django.db.connection`` does not directly reference the | ||||
| default ``DatabaseWrapper`` object any more and is now a proxy to access that | ||||
| object's attributes. If you need to access the actual ``DatabaseWrapper`` | ||||
| object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead. | ||||
|  | ||||
| As part of this change, all underlying SQLite connections are now enabled for | ||||
| potential thread-sharing (by passing the ``check_same_thread=False`` attribute | ||||
| to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by | ||||
| disabling thread-sharing by default, so this does not affect any existing | ||||
| code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``. | ||||
|  | ||||
| Finally, while it is now possible to pass connections between threads, Django | ||||
| does not make any effort to synchronize access to the underlying backend. | ||||
| Concurrency behavior is defined by the underlying backend implementation. | ||||
| Check their documentation for details. | ||||
|  | ||||
| `COMMENTS_BANNED_USERS_GROUP` setting | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| from __future__ import with_statement, absolute_import | ||||
|  | ||||
| import datetime | ||||
| import threading | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.management.color import no_style | ||||
| @@ -283,7 +284,7 @@ class ConnectionCreatedSignalTest(TestCase): | ||||
|         connection_created.connect(receiver) | ||||
|         connection.close() | ||||
|         cursor = connection.cursor() | ||||
|         self.assertTrue(data["connection"] is connection) | ||||
|         self.assertTrue(data["connection"].connection is connection.connection) | ||||
|  | ||||
|         connection_created.disconnect(receiver) | ||||
|         data.clear() | ||||
| @@ -446,3 +447,94 @@ class FkConstraintsTests(TransactionTestCase): | ||||
|                         connection.check_constraints() | ||||
|             finally: | ||||
|                 transaction.rollback() | ||||
|  | ||||
|  | ||||
| class ThreadTests(TestCase): | ||||
|  | ||||
|     def test_default_connection_thread_local(self): | ||||
|         """ | ||||
|         Ensure that the default connection (i.e. django.db.connection) is | ||||
|         different for each thread. | ||||
|         Refs #17258. | ||||
|         """ | ||||
|         connections_set = set() | ||||
|         connection.cursor() | ||||
|         connections_set.add(connection.connection) | ||||
|         def runner(): | ||||
|             from django.db import connection | ||||
|             connection.cursor() | ||||
|             connections_set.add(connection.connection) | ||||
|         for x in xrange(2): | ||||
|             t = threading.Thread(target=runner) | ||||
|             t.start() | ||||
|             t.join() | ||||
|         self.assertEquals(len(connections_set), 3) | ||||
|         # Finish by closing the connections opened by the other threads (the | ||||
|         # connection opened in the main thread will automatically be closed on | ||||
|         # teardown). | ||||
|         for conn in connections_set: | ||||
|             if conn != connection.connection: | ||||
|                 conn.close() | ||||
|  | ||||
|     def test_connections_thread_local(self): | ||||
|         """ | ||||
|         Ensure that the connections are different for each thread. | ||||
|         Refs #17258. | ||||
|         """ | ||||
|         connections_set = set() | ||||
|         for conn in connections.all(): | ||||
|             connections_set.add(conn) | ||||
|         def runner(): | ||||
|             from django.db import connections | ||||
|             for conn in connections.all(): | ||||
|                 connections_set.add(conn) | ||||
|         for x in xrange(2): | ||||
|             t = threading.Thread(target=runner) | ||||
|             t.start() | ||||
|             t.join() | ||||
|         self.assertEquals(len(connections_set), 6) | ||||
|         # Finish by closing the connections opened by the other threads (the | ||||
|         # connection opened in the main thread will automatically be closed on | ||||
|         # teardown). | ||||
|         for conn in connections_set: | ||||
|             if conn != connection: | ||||
|                 conn.close() | ||||
|  | ||||
|     def test_pass_connection_between_threads(self): | ||||
|         """ | ||||
|         Ensure that a connection can be passed from one thread to the other. | ||||
|         Refs #17258. | ||||
|         """ | ||||
|         models.Person.objects.create(first_name="John", last_name="Doe") | ||||
|  | ||||
|         def do_thread(): | ||||
|             def runner(main_thread_connection): | ||||
|                 from django.db import connections | ||||
|                 connections['default'] = main_thread_connection | ||||
|                 try: | ||||
|                     models.Person.objects.get(first_name="John", last_name="Doe") | ||||
|                 except DatabaseError, e: | ||||
|                     exceptions.append(e) | ||||
|             t = threading.Thread(target=runner, args=[connections['default']]) | ||||
|             t.start() | ||||
|             t.join() | ||||
|  | ||||
|         # Without touching allow_thread_sharing, which should be False by default. | ||||
|         exceptions = [] | ||||
|         do_thread() | ||||
|         # Forbidden! | ||||
|         self.assertTrue(isinstance(exceptions[0], DatabaseError)) | ||||
|  | ||||
|         # If explicitly setting allow_thread_sharing to False | ||||
|         connections['default'].allow_thread_sharing = False | ||||
|         exceptions = [] | ||||
|         do_thread() | ||||
|         # Forbidden! | ||||
|         self.assertTrue(isinstance(exceptions[0], DatabaseError)) | ||||
|  | ||||
|         # If explicitly setting allow_thread_sharing to True | ||||
|         connections['default'].allow_thread_sharing = True | ||||
|         exceptions = [] | ||||
|         do_thread() | ||||
|         # All good | ||||
|         self.assertEqual(len(exceptions), 0) | ||||
		Reference in New Issue
	
	Block a user