From e39dd618085bd437bcd800b5a0a7e29751ab6274 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 Jun 2015 13:49:32 +0200 Subject: [PATCH] Adjusted tests that were messing with database connections too heavily. The previous implementation would result in tests hitting the wrong database when running tests in parallel on multiple databases. --- tests/timezones/tests.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/timezones/tests.py b/tests/timezones/tests.py index 18539687a8..65a052f2cb 100644 --- a/tests/timezones/tests.py +++ b/tests/timezones/tests.py @@ -4,6 +4,7 @@ import datetime import re import sys import warnings +from contextlib import contextmanager from unittest import SkipTest, skipIf from xml.dom.minidom import parseString @@ -611,27 +612,41 @@ class ForcedTimeZoneDatabaseTests(TransactionTestCase): raise SkipTest("Database doesn't support feature(s): test_db_allows_multiple_connections") super(ForcedTimeZoneDatabaseTests, cls).setUpClass() - connections.databases['tz'] = connections.databases['default'].copy() - connections.databases['tz']['TIME_ZONE'] = 'Asia/Bangkok' - @classmethod - def tearDownClass(cls): - connections['tz'].close() - del connections['tz'] - del connections.databases['tz'] - super(ForcedTimeZoneDatabaseTests, cls).tearDownClass() + @contextmanager + def override_database_connection_timezone(self, timezone): + try: + orig_timezone = connection.settings_dict['TIME_ZONE'] + connection.settings_dict['TIME_ZONE'] = timezone + # Clear cached properties, after first accessing them to ensure they exist. + connection.timezone + del connection.timezone + connection.timezone_name + del connection.timezone_name + + yield + + finally: + connection.settings_dict['TIME_ZONE'] = orig_timezone + # Clear cached properties, after first accessing them to ensure they exist. + connection.timezone + del connection.timezone + connection.timezone_name + del connection.timezone_name def test_read_datetime(self): fake_dt = datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=UTC) Event.objects.create(dt=fake_dt) - event = Event.objects.using('tz').get() - dt = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC) + with self.override_database_connection_timezone('Asia/Bangkok'): + event = Event.objects.get() + dt = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC) self.assertEqual(event.dt, dt) def test_write_datetime(self): dt = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC) - Event.objects.using('tz').create(dt=dt) + with self.override_database_connection_timezone('Asia/Bangkok'): + Event.objects.create(dt=dt) event = Event.objects.get() fake_dt = datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=UTC)