mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Fixed #13630 -- Made __init__ methods of all DB backends' DatabaseOperations classes take a connection argument. Thanks calexium for the report.
				
					
				
			git-svn-id: http://code.djangoproject.com/svn/django/trunk@16016 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -133,8 +133,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): | |||||||
|     truncate_params = {'relate' : None} |     truncate_params = {'relate' : None} | ||||||
|  |  | ||||||
|     def __init__(self, connection): |     def __init__(self, connection): | ||||||
|         super(OracleOperations, self).__init__() |         super(OracleOperations, self).__init__(connection) | ||||||
|         self.connection = connection |  | ||||||
|  |  | ||||||
|     def convert_extent(self, clob): |     def convert_extent(self, clob): | ||||||
|         if clob: |         if clob: | ||||||
|   | |||||||
| @@ -110,8 +110,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): | |||||||
|     geometry_functions.update(distance_functions) |     geometry_functions.update(distance_functions) | ||||||
|  |  | ||||||
|     def __init__(self, connection): |     def __init__(self, connection): | ||||||
|         super(DatabaseOperations, self).__init__() |         super(DatabaseOperations, self).__init__(connection) | ||||||
|         self.connection = connection |  | ||||||
|  |  | ||||||
|         # Determine the version of the SpatiaLite library. |         # Determine the version of the SpatiaLite library. | ||||||
|         try: |         try: | ||||||
|   | |||||||
| @@ -388,7 +388,8 @@ class BaseDatabaseOperations(object): | |||||||
|     """ |     """ | ||||||
|     compiler_module = "django.db.models.sql.compiler" |     compiler_module = "django.db.models.sql.compiler" | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self, connection): | ||||||
|  |         self.connection = connection | ||||||
|         self._cache = None |         self._cache = None | ||||||
|  |  | ||||||
|     def autoinc_sql(self, table, column): |     def autoinc_sql(self, table, column): | ||||||
|   | |||||||
| @@ -59,7 +59,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|         super(DatabaseWrapper, self).__init__(*args, **kwargs) |         super(DatabaseWrapper, self).__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|         self.features = BaseDatabaseFeatures(self) |         self.features = BaseDatabaseFeatures(self) | ||||||
|         self.ops = DatabaseOperations() |         self.ops = DatabaseOperations(self) | ||||||
|         self.client = DatabaseClient(self) |         self.client = DatabaseClient(self) | ||||||
|         self.creation = BaseDatabaseCreation(self) |         self.creation = BaseDatabaseCreation(self) | ||||||
|         self.introspection = DatabaseIntrospection(self) |         self.introspection = DatabaseIntrospection(self) | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ if (version < (1,2,1) or (version[:3] == (1, 2, 1) and | |||||||
|     raise ImproperlyConfigured("MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__) |     raise ImproperlyConfigured("MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__) | ||||||
|  |  | ||||||
| from MySQLdb.converters import conversions | from MySQLdb.converters import conversions | ||||||
| from MySQLdb.constants import FIELD_TYPE, FLAG, CLIENT | from MySQLdb.constants import FIELD_TYPE, CLIENT | ||||||
|  |  | ||||||
| from django.db import utils | from django.db import utils | ||||||
| from django.db.backends import * | from django.db.backends import * | ||||||
| @@ -279,7 +279,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|  |  | ||||||
|         self.server_version = None |         self.server_version = None | ||||||
|         self.features = DatabaseFeatures(self) |         self.features = DatabaseFeatures(self) | ||||||
|         self.ops = DatabaseOperations() |         self.ops = DatabaseOperations(self) | ||||||
|         self.client = DatabaseClient(self) |         self.client = DatabaseClient(self) | ||||||
|         self.creation = DatabaseCreation(self) |         self.creation = DatabaseCreation(self) | ||||||
|         self.introspection = DatabaseIntrospection(self) |         self.introspection = DatabaseIntrospection(self) | ||||||
|   | |||||||
| @@ -84,8 +84,8 @@ class DatabaseOperations(BaseDatabaseOperations): | |||||||
|     def autoinc_sql(self, table, column): |     def autoinc_sql(self, table, column): | ||||||
|         # To simulate auto-incrementing primary keys in Oracle, we have to |         # To simulate auto-incrementing primary keys in Oracle, we have to | ||||||
|         # create a sequence and a trigger. |         # create a sequence and a trigger. | ||||||
|         sq_name = get_sequence_name(table) |         sq_name = self._get_sequence_name(table) | ||||||
|         tr_name = get_trigger_name(table) |         tr_name = self._get_trigger_name(table) | ||||||
|         tbl_name = self.quote_name(table) |         tbl_name = self.quote_name(table) | ||||||
|         col_name = self.quote_name(column) |         col_name = self.quote_name(column) | ||||||
|         sequence_sql = """ |         sequence_sql = """ | ||||||
| @@ -197,7 +197,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|         return " DEFERRABLE INITIALLY DEFERRED" |         return " DEFERRABLE INITIALLY DEFERRED" | ||||||
|  |  | ||||||
|     def drop_sequence_sql(self, table): |     def drop_sequence_sql(self, table): | ||||||
|         return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table)) |         return "DROP SEQUENCE %s;" % self.quote_name(self._get_sequence_name(table)) | ||||||
|  |  | ||||||
|     def fetch_returned_insert_id(self, cursor): |     def fetch_returned_insert_id(self, cursor): | ||||||
|         return long(cursor._insert_id_var.getvalue()) |         return long(cursor._insert_id_var.getvalue()) | ||||||
| @@ -209,7 +209,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|             return "%s" |             return "%s" | ||||||
|  |  | ||||||
|     def last_insert_id(self, cursor, table_name, pk_name): |     def last_insert_id(self, cursor, table_name, pk_name): | ||||||
|         sq_name = get_sequence_name(table_name) |         sq_name = self._get_sequence_name(table_name) | ||||||
|         cursor.execute('SELECT "%s".currval FROM dual' % sq_name) |         cursor.execute('SELECT "%s".currval FROM dual' % sq_name) | ||||||
|         return cursor.fetchone()[0] |         return cursor.fetchone()[0] | ||||||
|  |  | ||||||
| @@ -285,7 +285,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|             # Since we've just deleted all the rows, running our sequence |             # Since we've just deleted all the rows, running our sequence | ||||||
|             # ALTER code will reset the sequence to 0. |             # ALTER code will reset the sequence to 0. | ||||||
|             for sequence_info in sequences: |             for sequence_info in sequences: | ||||||
|                 sequence_name = get_sequence_name(sequence_info['table']) |                 sequence_name = self._get_sequence_name(sequence_info['table']) | ||||||
|                 table_name = self.quote_name(sequence_info['table']) |                 table_name = self.quote_name(sequence_info['table']) | ||||||
|                 column_name = self.quote_name(sequence_info['column'] or 'id') |                 column_name = self.quote_name(sequence_info['column'] or 'id') | ||||||
|                 query = _get_sequence_reset_sql() % {'sequence': sequence_name, |                 query = _get_sequence_reset_sql() % {'sequence': sequence_name, | ||||||
| @@ -304,7 +304,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|             for f in model._meta.local_fields: |             for f in model._meta.local_fields: | ||||||
|                 if isinstance(f, models.AutoField): |                 if isinstance(f, models.AutoField): | ||||||
|                     table_name = self.quote_name(model._meta.db_table) |                     table_name = self.quote_name(model._meta.db_table) | ||||||
|                     sequence_name = get_sequence_name(model._meta.db_table) |                     sequence_name = self._get_sequence_name(model._meta.db_table) | ||||||
|                     column_name = self.quote_name(f.column) |                     column_name = self.quote_name(f.column) | ||||||
|                     output.append(query % {'sequence': sequence_name, |                     output.append(query % {'sequence': sequence_name, | ||||||
|                                            'table': table_name, |                                            'table': table_name, | ||||||
| @@ -315,7 +315,7 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|             for f in model._meta.many_to_many: |             for f in model._meta.many_to_many: | ||||||
|                 if not f.rel.through: |                 if not f.rel.through: | ||||||
|                     table_name = self.quote_name(f.m2m_db_table()) |                     table_name = self.quote_name(f.m2m_db_table()) | ||||||
|                     sequence_name = get_sequence_name(f.m2m_db_table()) |                     sequence_name = self._get_sequence_name(f.m2m_db_table()) | ||||||
|                     column_name = self.quote_name('id') |                     column_name = self.quote_name('id') | ||||||
|                     output.append(query % {'sequence': sequence_name, |                     output.append(query % {'sequence': sequence_name, | ||||||
|                                            'table': table_name, |                                            'table': table_name, | ||||||
| @@ -365,6 +365,14 @@ WHEN (new.%(col_name)s IS NULL) | |||||||
|             raise NotImplementedError("Bit-wise or is not supported in Oracle.") |             raise NotImplementedError("Bit-wise or is not supported in Oracle.") | ||||||
|         return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) |         return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) | ||||||
|  |  | ||||||
|  |     def _get_sequence_name(self, table): | ||||||
|  |         name_length = self.max_name_length() - 3 | ||||||
|  |         return '%s_SQ' % util.truncate_name(table, name_length).upper() | ||||||
|  |  | ||||||
|  |     def _get_trigger_name(self, table): | ||||||
|  |         name_length = self.max_name_length() - 3 | ||||||
|  |         return '%s_TR' % util.truncate_name(table, name_length).upper() | ||||||
|  |  | ||||||
|  |  | ||||||
| class _UninitializedOperatorsDescriptor(object): | class _UninitializedOperatorsDescriptor(object): | ||||||
|  |  | ||||||
| @@ -415,7 +423,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|         self.features = DatabaseFeatures(self) |         self.features = DatabaseFeatures(self) | ||||||
|         use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) |         use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) | ||||||
|         self.features.can_return_id_from_insert = use_returning_into |         self.features.can_return_id_from_insert = use_returning_into | ||||||
|         self.ops = DatabaseOperations() |         self.ops = DatabaseOperations(self) | ||||||
|         self.client = DatabaseClient(self) |         self.client = DatabaseClient(self) | ||||||
|         self.creation = DatabaseCreation(self) |         self.creation = DatabaseCreation(self) | ||||||
|         self.introspection = DatabaseIntrospection(self) |         self.introspection = DatabaseIntrospection(self) | ||||||
| @@ -776,13 +784,3 @@ BEGIN | |||||||
|     END LOOP; |     END LOOP; | ||||||
| END; | END; | ||||||
| /""" | /""" | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_sequence_name(table): |  | ||||||
|     name_length = DatabaseOperations().max_name_length() - 3 |  | ||||||
|     return '%s_SQ' % util.truncate_name(table, name_length).upper() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_trigger_name(table): |  | ||||||
|     name_length = DatabaseOperations().max_name_length() - 3 |  | ||||||
|     return '%s_TR' % util.truncate_name(table, name_length).upper() |  | ||||||
|   | |||||||
| @@ -5,9 +5,8 @@ from django.db.backends import BaseDatabaseOperations | |||||||
|  |  | ||||||
| class DatabaseOperations(BaseDatabaseOperations): | class DatabaseOperations(BaseDatabaseOperations): | ||||||
|     def __init__(self, connection): |     def __init__(self, connection): | ||||||
|         super(DatabaseOperations, self).__init__() |         super(DatabaseOperations, self).__init__(connection) | ||||||
|         self._postgres_version = None |         self._postgres_version = None | ||||||
|         self.connection = connection |  | ||||||
|  |  | ||||||
|     def _get_postgres_version(self): |     def _get_postgres_version(self): | ||||||
|         if self._postgres_version is None: |         if self._postgres_version is None: | ||||||
|   | |||||||
| @@ -179,7 +179,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): | |||||||
|         super(DatabaseWrapper, self).__init__(*args, **kwargs) |         super(DatabaseWrapper, self).__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|         self.features = DatabaseFeatures(self) |         self.features = DatabaseFeatures(self) | ||||||
|         self.ops = DatabaseOperations() |         self.ops = DatabaseOperations(self) | ||||||
|         self.client = DatabaseClient(self) |         self.client = DatabaseClient(self) | ||||||
|         self.creation = DatabaseCreation(self) |         self.creation = DatabaseCreation(self) | ||||||
|         self.introspection = DatabaseIntrospection(self) |         self.introspection = DatabaseIntrospection(self) | ||||||
|   | |||||||
| @@ -232,6 +232,12 @@ class BackendTestCase(TestCase): | |||||||
|         self.assertEqual(list(cursor.fetchmany(2)), [(u'Jane', u'Doe'), (u'John', u'Doe')]) |         self.assertEqual(list(cursor.fetchmany(2)), [(u'Jane', u'Doe'), (u'John', u'Doe')]) | ||||||
|         self.assertEqual(list(cursor.fetchall()), [(u'Mary', u'Agnelline'), (u'Peter', u'Parker')]) |         self.assertEqual(list(cursor.fetchall()), [(u'Mary', u'Agnelline'), (u'Peter', u'Parker')]) | ||||||
|  |  | ||||||
|  |     def test_database_operations_helper_class(self): | ||||||
|  |         # Ticket #13630 | ||||||
|  |         self.assertTrue(hasattr(connection, 'ops')) | ||||||
|  |         self.assertTrue(hasattr(connection.ops, 'connection')) | ||||||
|  |         self.assertEqual(connection, connection.ops.connection) | ||||||
|  |  | ||||||
|  |  | ||||||
| # We don't make these tests conditional because that means we would need to | # We don't make these tests conditional because that means we would need to | ||||||
| # check and differentiate between: | # check and differentiate between: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user