From 928e1f1f337814f3a442827a3d6370ea1ca7628c Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 21 Dec 2009 19:30:49 +0000 Subject: [PATCH] [soc2009/multidb] A couple of cleanups of multi-db support for raw queries, including a using() call on raw query sets. Patch from Russell Keith-Magee. git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11930 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/manager.py | 9 +++--- django/db/models/query.py | 29 ++++++++++++++----- django/db/models/sql/query.py | 9 ++++-- docs/topics/db/sql.txt | 2 +- tests/modeltests/raw_query/tests.py | 4 +-- .../multiple_database/tests.py | 5 +++- 6 files changed, 39 insertions(+), 19 deletions(-) diff --git a/django/db/models/manager.py b/django/db/models/manager.py index cd03ee9146..7f96daaa4e 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -91,7 +91,7 @@ class Manager(object): obj = copy.copy(self) obj._db = alias return obj - + @property def db(self): return self._db or DEFAULT_DB_ALIAS @@ -189,7 +189,7 @@ class Manager(object): def using(self, *args, **kwargs): return self.get_query_set().using(*args, **kwargs) - + def exists(self, *args, **kwargs): return self.get_query_set().exists(*args, **kwargs) @@ -199,9 +199,8 @@ class Manager(object): def _update(self, values, **kwargs): return self.get_query_set()._update(values, **kwargs) - def raw(self, query, params=None, *args, **kwargs): - kwargs["using"] = self.db - return RawQuerySet(model=self.model, query=query, params=params, *args, **kwargs) + def raw(self, raw_query, params=None, *args, **kwargs): + return RawQuerySet(raw_query=raw_query, model=self.model, params=params, using=self.db, *args, **kwargs) class ManagerDescriptor(object): # This class ensures managers aren't accessible via model instances. diff --git a/django/db/models/query.py b/django/db/models/query.py index 73b514c13a..8799b4a93b 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -711,10 +711,10 @@ class QuerySet(object): return False ordered = property(ordered) + @property def db(self): "Return the database that will be used if this query is executed now" return self._db or DEFAULT_DB_ALIAS - db = property(db) ################### # PRIVATE METHODS # @@ -1154,11 +1154,12 @@ class RawQuerySet(object): Provides an iterator which converts the results of raw SQL queries into annotated model instances. """ - def __init__(self, query, model=None, query_obj=None, params=None, + def __init__(self, raw_query, model=None, query=None, params=None, translations=None, using=None): + self.raw_query = raw_query self.model = model - self.using = using - self.query = query_obj or sql.RawQuery(sql=query, connection=connections[using], params=params) + self._db = using + self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params) self.params = params or () self.translations = translations or {} @@ -1167,7 +1168,21 @@ class RawQuerySet(object): yield self.transform_results(row) def __repr__(self): - return "" % (self.query.sql % self.params) + return "" % (self.raw_query % self.params) + + @property + def db(self): + "Return the database that will be used if this query is executed now" + return self._db or DEFAULT_DB_ALIAS + + def using(self, alias): + """ + Selects which database this Raw QuerySet should excecute it's query against. + """ + return RawQuerySet(self.raw_query, model=self.model, + query=self.query.clone(using=alias), + params=self.params, translations=self.translations, + using=alias) @property def columns(self): @@ -1232,8 +1247,8 @@ class RawQuerySet(object): for field, value in annotations: setattr(instance, field, value) - - instance._state.db = self.using + + instance._state.db = self.query.using return instance diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dec39451a6..d821c0ee02 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -29,13 +29,16 @@ class RawQuery(object): A single raw SQL query """ - def __init__(self, sql, connection, params=None): + def __init__(self, sql, using, params=None): self.validate_sql(sql) self.params = params or () self.sql = sql - self.connection = connection + self.using = using self.cursor = None + def clone(self, using): + return RawQuery(self.sql, using, params=self.params) + def get_columns(self): if self.cursor is None: self._execute_query() @@ -56,7 +59,7 @@ class RawQuery(object): return "" % (self.sql % self.params) def _execute_query(self): - self.cursor = self.connection.cursor() + self.cursor = connections[self.using].cursor() self.cursor.execute(self.sql, self.params) diff --git a/docs/topics/db/sql.txt b/docs/topics/db/sql.txt index 45aa4f950e..987fcb091f 100644 --- a/docs/topics/db/sql.txt +++ b/docs/topics/db/sql.txt @@ -23,7 +23,7 @@ Performing raw queries The ``raw()`` manager method can be used to perform raw SQL queries that return model instances: -.. method:: Manager.raw(query, params=None, translations=None) +.. method:: Manager.raw(raw_query, params=None, translations=None) This method method takes a raw SQL query, executes it, and returns model instances. diff --git a/tests/modeltests/raw_query/tests.py b/tests/modeltests/raw_query/tests.py index b132605da5..688df21598 100644 --- a/tests/modeltests/raw_query/tests.py +++ b/tests/modeltests/raw_query/tests.py @@ -10,7 +10,7 @@ class RawQueryTests(TestCase): """ Execute the passed query against the passed model and check the output """ - results = list(model.objects.raw(query=query, params=params, translations=translations)) + results = list(model.objects.raw(query, params=params, translations=translations)) self.assertProcessed(results, expected_results, expected_annotations) self.assertAnnotations(results, expected_annotations) @@ -111,7 +111,7 @@ class RawQueryTests(TestCase): query = "SELECT * FROM raw_query_author WHERE first_name = %s" author = Author.objects.all()[2] params = [author.first_name] - results = list(Author.objects.raw(query=query, params=params)) + results = list(Author.objects.raw(query, params=params)) self.assertProcessed(results, [author]) self.assertNoAnnotations(results) self.assertEqual(len(results), 1) diff --git a/tests/regressiontests/multiple_database/tests.py b/tests/regressiontests/multiple_database/tests.py index d75f598c9f..300ed5e0a6 100644 --- a/tests/regressiontests/multiple_database/tests.py +++ b/tests/regressiontests/multiple_database/tests.py @@ -619,7 +619,7 @@ class QueryTestCase(TestCase): self.assertEquals(learn.get_next_by_published().title, "Dive into Python") self.assertEquals(dive.get_previous_by_published().title, "Learning Python") - + def test_raw(self): "test the raw() method across databases" dive = Book.objects.using('other').create(title="Dive into Python", @@ -627,6 +627,9 @@ class QueryTestCase(TestCase): val = Book.objects.db_manager("other").raw('SELECT id FROM "multiple_database_book"') self.assertEqual(map(lambda o: o.pk, val), [dive.pk]) + val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other') + self.assertEqual(map(lambda o: o.pk, val), [dive.pk]) + class UserProfileTestCase(TestCase): def setUp(self):