mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Added a context manager to capture queries while testing.
Also made some import cleanups while I was there. Refs #10399.
This commit is contained in:
		| @@ -24,31 +24,30 @@ from django.core.exceptions import ValidationError, ImproperlyConfigured | |||||||
| from django.core.handlers.wsgi import WSGIHandler | from django.core.handlers.wsgi import WSGIHandler | ||||||
| from django.core.management import call_command | from django.core.management import call_command | ||||||
| from django.core.management.color import no_style | from django.core.management.color import no_style | ||||||
| from django.core.signals import request_started |  | ||||||
| from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer, | from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer, | ||||||
|     WSGIServerException) |     WSGIServerException) | ||||||
| from django.core.urlresolvers import clear_url_caches | from django.core.urlresolvers import clear_url_caches | ||||||
| from django.core.validators import EMPTY_VALUES | from django.core.validators import EMPTY_VALUES | ||||||
| from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS, | from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction | ||||||
|     reset_queries) |  | ||||||
| from django.forms.fields import CharField | from django.forms.fields import CharField | ||||||
| from django.http import QueryDict | from django.http import QueryDict | ||||||
| from django.test import _doctest as doctest | from django.test import _doctest as doctest | ||||||
| from django.test.client import Client | from django.test.client import Client | ||||||
| from django.test.html import HTMLParseError, parse_html | from django.test.html import HTMLParseError, parse_html | ||||||
| from django.test.signals import template_rendered | from django.test.signals import template_rendered | ||||||
| from django.test.utils import (override_settings, compare_xml, strip_quotes) | from django.test.utils import (CaptureQueriesContext, ContextList, | ||||||
| from django.test.utils import ContextList |     override_settings, compare_xml, strip_quotes) | ||||||
| from django.utils import unittest as ut2 | from django.utils import six, unittest as ut2 | ||||||
| from django.utils.encoding import force_text | from django.utils.encoding import force_text | ||||||
| from django.utils import six | from django.utils.unittest import skipIf # Imported here for backward compatibility | ||||||
| from django.utils.unittest.util import safe_repr | from django.utils.unittest.util import safe_repr | ||||||
| from django.utils.unittest import skipIf |  | ||||||
| from django.views.static import serve | from django.views.static import serve | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase', | __all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase', | ||||||
|            'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') |            'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') | ||||||
|  |  | ||||||
|  |  | ||||||
| normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) | normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) | ||||||
| normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", | normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", | ||||||
|                                 lambda m: "Decimal(\"%s\")" % m.groups()[0], s) |                                 lambda m: "Decimal(\"%s\")" % m.groups()[0], s) | ||||||
| @@ -168,28 +167,17 @@ class DocTestRunner(doctest.DocTestRunner): | |||||||
|             transaction.rollback_unless_managed(using=conn) |             transaction.rollback_unless_managed(using=conn) | ||||||
|  |  | ||||||
|  |  | ||||||
| class _AssertNumQueriesContext(object): | class _AssertNumQueriesContext(CaptureQueriesContext): | ||||||
|     def __init__(self, test_case, num, connection): |     def __init__(self, test_case, num, connection): | ||||||
|         self.test_case = test_case |         self.test_case = test_case | ||||||
|         self.num = num |         self.num = num | ||||||
|         self.connection = connection |         super(_AssertNumQueriesContext, self).__init__(connection) | ||||||
|  |  | ||||||
|     def __enter__(self): |  | ||||||
|         self.old_debug_cursor = self.connection.use_debug_cursor |  | ||||||
|         self.connection.use_debug_cursor = True |  | ||||||
|         self.starting_queries = len(self.connection.queries) |  | ||||||
|         request_started.disconnect(reset_queries) |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|     def __exit__(self, exc_type, exc_value, traceback): |     def __exit__(self, exc_type, exc_value, traceback): | ||||||
|         self.connection.use_debug_cursor = self.old_debug_cursor |  | ||||||
|         request_started.connect(reset_queries) |  | ||||||
|         if exc_type is not None: |         if exc_type is not None: | ||||||
|             return |             return | ||||||
|  |         super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback) | ||||||
|         final_queries = len(self.connection.queries) |         executed = len(self) | ||||||
|         executed = final_queries - self.starting_queries |  | ||||||
|  |  | ||||||
|         self.test_case.assertEqual( |         self.test_case.assertEqual( | ||||||
|             executed, self.num, "%d queries executed, %d expected" % ( |             executed, self.num, "%d queries executed, %d expected" % ( | ||||||
|                 executed, self.num |                 executed, self.num | ||||||
| @@ -1051,7 +1039,6 @@ class LiveServerThread(threading.Thread): | |||||||
|         http requests. |         http requests. | ||||||
|         """ |         """ | ||||||
|         if self.connections_override: |         if self.connections_override: | ||||||
|             from django.db import connections |  | ||||||
|             # Override this thread's database connections with the ones |             # Override this thread's database connections with the ones | ||||||
|             # provided by the main thread. |             # provided by the main thread. | ||||||
|             for alias, conn in self.connections_override.items(): |             for alias, conn in self.connections_override.items(): | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ from xml.dom.minidom import parseString, Node | |||||||
|  |  | ||||||
| from django.conf import settings, UserSettingsHolder | from django.conf import settings, UserSettingsHolder | ||||||
| from django.core import mail | from django.core import mail | ||||||
|  | from django.core.signals import request_started | ||||||
|  | from django.db import reset_queries | ||||||
| from django.template import Template, loader, TemplateDoesNotExist | from django.template import Template, loader, TemplateDoesNotExist | ||||||
| from django.template.loaders import cached | from django.template.loaders import cached | ||||||
| from django.test.signals import template_rendered, setting_changed | from django.test.signals import template_rendered, setting_changed | ||||||
| @@ -339,5 +341,42 @@ def strip_quotes(want, got): | |||||||
|         got = got.strip()[2:-1] |         got = got.strip()[2:-1] | ||||||
|     return want, got |     return want, got | ||||||
|  |  | ||||||
|  |  | ||||||
| def str_prefix(s): | def str_prefix(s): | ||||||
|     return s % {'_': '' if six.PY3 else 'u'} |     return s % {'_': '' if six.PY3 else 'u'} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CaptureQueriesContext(object): | ||||||
|  |     """ | ||||||
|  |     Context manager that captures queries executed by the specified connection. | ||||||
|  |     """ | ||||||
|  |     def __init__(self, connection): | ||||||
|  |         self.connection = connection | ||||||
|  |  | ||||||
|  |     def __iter__(self): | ||||||
|  |         return iter(self.captured_queries) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         return self.captured_queries[index] | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.captured_queries) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def captured_queries(self): | ||||||
|  |         return self.connection.queries[self.initial_queries:self.final_queries] | ||||||
|  |  | ||||||
|  |     def __enter__(self): | ||||||
|  |         self.use_debug_cursor = self.connection.use_debug_cursor | ||||||
|  |         self.connection.use_debug_cursor = True | ||||||
|  |         self.initial_queries = len(self.connection.queries) | ||||||
|  |         self.final_queries = None | ||||||
|  |         request_started.disconnect(reset_queries) | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def __exit__(self, exc_type, exc_value, traceback): | ||||||
|  |         self.connection.use_debug_cursor = self.use_debug_cursor | ||||||
|  |         request_started.connect(reset_queries) | ||||||
|  |         if exc_type is not None: | ||||||
|  |             return | ||||||
|  |         self.final_queries = len(self.connection.queries) | ||||||
|   | |||||||
| @@ -1,10 +1,14 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| from __future__ import absolute_import, unicode_literals | from __future__ import absolute_import, unicode_literals | ||||||
|  | import warnings | ||||||
|  |  | ||||||
|  | from django.db import connection | ||||||
| from django.forms import EmailField, IntegerField | from django.forms import EmailField, IntegerField | ||||||
| from django.http import HttpResponse | from django.http import HttpResponse | ||||||
| from django.template.loader import render_to_string | from django.template.loader import render_to_string | ||||||
| from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature | from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature | ||||||
|  | from django.test.html import HTMLParseError, parse_html | ||||||
|  | from django.test.utils import CaptureQueriesContext | ||||||
| from django.utils import six | from django.utils import six | ||||||
| from django.utils.unittest import skip | from django.utils.unittest import skip | ||||||
|  |  | ||||||
| @@ -94,6 +98,60 @@ class AssertQuerysetEqualTests(TestCase): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CaptureQueriesContextManagerTests(TestCase): | ||||||
|  |     urls = 'test_utils.urls' | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         self.person_pk = six.text_type(Person.objects.create(name='test').pk) | ||||||
|  |  | ||||||
|  |     def test_simple(self): | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             Person.objects.get(pk=self.person_pk) | ||||||
|  |         self.assertEqual(len(captured_queries), 1) | ||||||
|  |         self.assertIn(self.person_pk, captured_queries[0]['sql']) | ||||||
|  |  | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             pass | ||||||
|  |         self.assertEqual(0, len(captured_queries)) | ||||||
|  |  | ||||||
|  |     def test_within(self): | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             Person.objects.get(pk=self.person_pk) | ||||||
|  |             self.assertEqual(len(captured_queries), 1) | ||||||
|  |             self.assertIn(self.person_pk, captured_queries[0]['sql']) | ||||||
|  |  | ||||||
|  |     def test_nested(self): | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             Person.objects.count() | ||||||
|  |             with CaptureQueriesContext(connection) as nested_captured_queries: | ||||||
|  |                 Person.objects.count() | ||||||
|  |         self.assertEqual(1, len(nested_captured_queries)) | ||||||
|  |         self.assertEqual(2, len(captured_queries)) | ||||||
|  |  | ||||||
|  |     def test_failure(self): | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             with CaptureQueriesContext(connection): | ||||||
|  |                 raise TypeError | ||||||
|  |  | ||||||
|  |     def test_with_client(self): | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             self.client.get("/test_utils/get_person/%s/" % self.person_pk) | ||||||
|  |         self.assertEqual(len(captured_queries), 1) | ||||||
|  |         self.assertIn(self.person_pk, captured_queries[0]['sql']) | ||||||
|  |  | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             self.client.get("/test_utils/get_person/%s/" % self.person_pk) | ||||||
|  |         self.assertEqual(len(captured_queries), 1) | ||||||
|  |         self.assertIn(self.person_pk, captured_queries[0]['sql']) | ||||||
|  |  | ||||||
|  |         with CaptureQueriesContext(connection) as captured_queries: | ||||||
|  |             self.client.get("/test_utils/get_person/%s/" % self.person_pk) | ||||||
|  |             self.client.get("/test_utils/get_person/%s/" % self.person_pk) | ||||||
|  |         self.assertEqual(len(captured_queries), 2) | ||||||
|  |         self.assertIn(self.person_pk, captured_queries[0]['sql']) | ||||||
|  |         self.assertIn(self.person_pk, captured_queries[1]['sql']) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AssertNumQueriesContextManagerTests(TestCase): | class AssertNumQueriesContextManagerTests(TestCase): | ||||||
|     urls = 'test_utils.urls' |     urls = 'test_utils.urls' | ||||||
|  |  | ||||||
| @@ -219,7 +277,6 @@ class SaveRestoreWarningState(TestCase): | |||||||
|         # In reality this test could be satisfied by many broken implementations |         # In reality this test could be satisfied by many broken implementations | ||||||
|         # of save_warnings_state/restore_warnings_state (e.g. just |         # of save_warnings_state/restore_warnings_state (e.g. just | ||||||
|         # warnings.resetwarnings()) , but it is difficult to test more. |         # warnings.resetwarnings()) , but it is difficult to test more. | ||||||
|         import warnings |  | ||||||
|         with warnings.catch_warnings(): |         with warnings.catch_warnings(): | ||||||
|             warnings.simplefilter("ignore", DeprecationWarning) |             warnings.simplefilter("ignore", DeprecationWarning) | ||||||
|  |  | ||||||
| @@ -245,7 +302,6 @@ class SaveRestoreWarningState(TestCase): | |||||||
|  |  | ||||||
| class HTMLEqualTests(TestCase): | class HTMLEqualTests(TestCase): | ||||||
|     def test_html_parser(self): |     def test_html_parser(self): | ||||||
|         from django.test.html import parse_html |  | ||||||
|         element = parse_html('<div><p>Hello</p></div>') |         element = parse_html('<div><p>Hello</p></div>') | ||||||
|         self.assertEqual(len(element.children), 1) |         self.assertEqual(len(element.children), 1) | ||||||
|         self.assertEqual(element.children[0].name, 'p') |         self.assertEqual(element.children[0].name, 'p') | ||||||
| @@ -259,7 +315,6 @@ class HTMLEqualTests(TestCase): | |||||||
|         self.assertEqual(dom[0], 'foo') |         self.assertEqual(dom[0], 'foo') | ||||||
|  |  | ||||||
|     def test_parse_html_in_script(self): |     def test_parse_html_in_script(self): | ||||||
|         from django.test.html import parse_html |  | ||||||
|         parse_html('<script>var a = "<p" + ">";</script>'); |         parse_html('<script>var a = "<p" + ">";</script>'); | ||||||
|         parse_html(''' |         parse_html(''' | ||||||
|             <script> |             <script> | ||||||
| @@ -275,8 +330,6 @@ class HTMLEqualTests(TestCase): | |||||||
|         self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>") |         self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>") | ||||||
|  |  | ||||||
|     def test_self_closing_tags(self): |     def test_self_closing_tags(self): | ||||||
|         from django.test.html import parse_html |  | ||||||
|  |  | ||||||
|         self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer', |         self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer', | ||||||
|             'link', 'frame', 'base', 'col') |             'link', 'frame', 'base', 'col') | ||||||
|         for tag in self_closing_tags: |         for tag in self_closing_tags: | ||||||
| @@ -400,7 +453,6 @@ class HTMLEqualTests(TestCase): | |||||||
|         </html>""") |         </html>""") | ||||||
|  |  | ||||||
|     def test_html_contain(self): |     def test_html_contain(self): | ||||||
|         from django.test.html import parse_html |  | ||||||
|         # equal html contains each other |         # equal html contains each other | ||||||
|         dom1 = parse_html('<p>foo') |         dom1 = parse_html('<p>foo') | ||||||
|         dom2 = parse_html('<p>foo</p>') |         dom2 = parse_html('<p>foo</p>') | ||||||
| @@ -424,7 +476,6 @@ class HTMLEqualTests(TestCase): | |||||||
|         self.assertTrue(dom1 in dom2) |         self.assertTrue(dom1 in dom2) | ||||||
|  |  | ||||||
|     def test_count(self): |     def test_count(self): | ||||||
|         from django.test.html import parse_html |  | ||||||
|         # equal html contains each other one time |         # equal html contains each other one time | ||||||
|         dom1 = parse_html('<p>foo') |         dom1 = parse_html('<p>foo') | ||||||
|         dom2 = parse_html('<p>foo</p>') |         dom2 = parse_html('<p>foo</p>') | ||||||
| @@ -459,7 +510,6 @@ class HTMLEqualTests(TestCase): | |||||||
|         self.assertEqual(dom2.count(dom1), 0) |         self.assertEqual(dom2.count(dom1), 0) | ||||||
|  |  | ||||||
|     def test_parsing_errors(self): |     def test_parsing_errors(self): | ||||||
|         from django.test.html import HTMLParseError, parse_html |  | ||||||
|         with self.assertRaises(AssertionError): |         with self.assertRaises(AssertionError): | ||||||
|             self.assertHTMLEqual('<p>', '') |             self.assertHTMLEqual('<p>', '') | ||||||
|         with self.assertRaises(AssertionError): |         with self.assertRaises(AssertionError): | ||||||
| @@ -488,7 +538,6 @@ class HTMLEqualTests(TestCase): | |||||||
|             self.assertContains(response, '<p "whats" that>') |             self.assertContains(response, '<p "whats" that>') | ||||||
|  |  | ||||||
|     def test_unicode_handling(self): |     def test_unicode_handling(self): | ||||||
|         from django.http import HttpResponse |  | ||||||
|         response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>') |         response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>') | ||||||
|         self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True) |         self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user