mirror of
				https://github.com/django/django.git
				synced 2025-10-26 15:16:09 +00:00 
			
		
		
		
	Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local): | |||||||
|         self.settings_dict = settings_dict |         self.settings_dict = settings_dict | ||||||
|         self.alias = alias |         self.alias = alias | ||||||
|         self.vendor = 'unknown' |         self.vendor = 'unknown' | ||||||
|  |         self.use_debug_cursor = None | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         return self.settings_dict == other.settings_dict |         return self.settings_dict == other.settings_dict | ||||||
| @@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local): | |||||||
|     def cursor(self): |     def cursor(self): | ||||||
|         from django.conf import settings |         from django.conf import settings | ||||||
|         cursor = self._cursor() |         cursor = self._cursor() | ||||||
|         if settings.DEBUG: |         if (self.use_debug_cursor or | ||||||
|  |             (self.use_debug_cursor is None and settings.DEBUG)): | ||||||
|             return self.make_debug_cursor(cursor) |             return self.make_debug_cursor(cursor) | ||||||
|         return cursor |         return cursor | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| import re | import re | ||||||
|  | import sys | ||||||
| from urlparse import urlsplit, urlunsplit | from urlparse import urlsplit, urlunsplit | ||||||
| from xml.dom.minidom import parseString, Node | from xml.dom.minidom import parseString, Node | ||||||
|  |  | ||||||
| @@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner): | |||||||
|         for conn in connections: |         for conn in connections: | ||||||
|             transaction.rollback_unless_managed(using=conn) |             transaction.rollback_unless_managed(using=conn) | ||||||
|  |  | ||||||
|  | class _AssertNumQueriesContext(object): | ||||||
|  |     def __init__(self, test_case, num, connection): | ||||||
|  |         self.test_case = test_case | ||||||
|  |         self.num = num | ||||||
|  |         self.connection = connection | ||||||
|  |  | ||||||
|  |     def __enter__(self): | ||||||
|  |         self.old_debug_cursor = self.connection.use_debug_cursor | ||||||
|  |         self.connection.use_debug_cursor = True | ||||||
|  |         self.starting_queries = len(self.connection.queries) | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def __exit__(self, exc_type, exc_value, traceback): | ||||||
|  |         if exc_type is not None: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         self.connection.use_debug_cursor = self.old_debug_cursor | ||||||
|  |         final_queries = len(self.connection.queries) | ||||||
|  |         executed = final_queries - self.starting_queries | ||||||
|  |  | ||||||
|  |         self.test_case.assertEqual( | ||||||
|  |             executed, self.num, "%d queries executed, %d expected" % ( | ||||||
|  |                 executed, self.num | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TransactionTestCase(unittest.TestCase): | class TransactionTestCase(unittest.TestCase): | ||||||
|     # The class we'll use for the test client self.client. |     # The class we'll use for the test client self.client. | ||||||
|     # Can be overridden in derived classes. |     # Can be overridden in derived classes. | ||||||
| @@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase): | |||||||
|     def assertQuerysetEqual(self, qs, values, transform=repr): |     def assertQuerysetEqual(self, qs, values, transform=repr): | ||||||
|         return self.assertEqual(map(transform, qs), values) |         return self.assertEqual(map(transform, qs), values) | ||||||
|  |  | ||||||
|  |     def assertNumQueries(self, num, func=None, *args, **kwargs): | ||||||
|  |         using = kwargs.pop("using", DEFAULT_DB_ALIAS) | ||||||
|  |         connection = connections[using] | ||||||
|  |  | ||||||
|  |         context = _AssertNumQueriesContext(self, num, connection) | ||||||
|  |         if func is None: | ||||||
|  |             return context | ||||||
|  |  | ||||||
|  |         # Basically emulate the `with` statement here. | ||||||
|  |  | ||||||
|  |         context.__enter__() | ||||||
|  |         try: | ||||||
|  |             func(*args, **kwargs) | ||||||
|  |         finally: | ||||||
|  |             context.__exit__(*sys.exc_info()) | ||||||
|  |  | ||||||
| def connections_support_transactions(): | def connections_support_transactions(): | ||||||
|     """ |     """ | ||||||
|     Returns True if all connections support transactions.  This is messy |     Returns True if all connections support transactions.  This is messy | ||||||
|   | |||||||
| @@ -1372,6 +1372,32 @@ cause of an failure in your test suite. | |||||||
|     implicit ordering, you will need to apply a ``order_by()`` clause to your |     implicit ordering, you will need to apply a ``order_by()`` clause to your | ||||||
|     queryset to ensure that the test will pass reliably. |     queryset to ensure that the test will pass reliably. | ||||||
|  |  | ||||||
|  | .. method:: TestCase.assertNumQueries(num, func, *args, **kwargs): | ||||||
|  |  | ||||||
|  |     .. versionadded:: 1.3 | ||||||
|  |  | ||||||
|  |     Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that | ||||||
|  |     ``num`` database queries are executed. | ||||||
|  |  | ||||||
|  |     If a ``"using"`` key is present in ``kwargs`` it is used as the database | ||||||
|  |     alias for which to check the number of queries.  If you wish to call a | ||||||
|  |     function with a ``using`` parameter you can do it by wrapping the call with | ||||||
|  |     a ``lambda`` to add an extra parameter:: | ||||||
|  |  | ||||||
|  |         self.assertNumQueries(7, lambda: my_function(using=7)) | ||||||
|  |  | ||||||
|  |     If you're using Python 2.5 or greater you can also use this as a context | ||||||
|  |     manager:: | ||||||
|  |  | ||||||
|  |         # This is necessary in Python 2.5 to enable the with statement, in 2.6 | ||||||
|  |         # and up it is no longer necessary. | ||||||
|  |         from __future__ import with_statement | ||||||
|  |  | ||||||
|  |         with self.assertNumQueries(2): | ||||||
|  |             Person.objects.create(name="Aaron") | ||||||
|  |             Person.objects.create(name="Daniel") | ||||||
|  |  | ||||||
|  |  | ||||||
| .. _topics-testing-email: | .. _topics-testing-email: | ||||||
|  |  | ||||||
| E-mail services | E-mail services | ||||||
|   | |||||||
| @@ -1,6 +1,4 @@ | |||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.conf import settings |  | ||||||
| from django import db |  | ||||||
|  |  | ||||||
| from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species | from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species | ||||||
|  |  | ||||||
| @@ -36,36 +34,34 @@ class SelectRelatedTests(TestCase): | |||||||
|         # queries so we'll set it to True here and reset it at the end of the |         # queries so we'll set it to True here and reset it at the end of the | ||||||
|         # test case. |         # test case. | ||||||
|         self.create_base_data() |         self.create_base_data() | ||||||
|         settings.DEBUG = True |  | ||||||
|         db.reset_queries() |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         settings.DEBUG = False |  | ||||||
|  |  | ||||||
|     def test_access_fks_without_select_related(self): |     def test_access_fks_without_select_related(self): | ||||||
|         """ |         """ | ||||||
|         Normally, accessing FKs doesn't fill in related objects |         Normally, accessing FKs doesn't fill in related objects | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             fly = Species.objects.get(name="melanogaster") |             fly = Species.objects.get(name="melanogaster") | ||||||
|             domain = fly.genus.family.order.klass.phylum.kingdom.domain |             domain = fly.genus.family.order.klass.phylum.kingdom.domain | ||||||
|             self.assertEqual(domain.name, 'Eukaryota') |             self.assertEqual(domain.name, 'Eukaryota') | ||||||
|         self.assertEqual(len(db.connection.queries), 8) |         self.assertNumQueries(8, test) | ||||||
|  |  | ||||||
|     def test_access_fks_with_select_related(self): |     def test_access_fks_with_select_related(self): | ||||||
|         """ |         """ | ||||||
|         A select_related() call will fill in those related objects without any |         A select_related() call will fill in those related objects without any | ||||||
|         extra queries |         extra queries | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             person = Species.objects.select_related(depth=10).get(name="sapiens") |             person = Species.objects.select_related(depth=10).get(name="sapiens") | ||||||
|             domain = person.genus.family.order.klass.phylum.kingdom.domain |             domain = person.genus.family.order.klass.phylum.kingdom.domain | ||||||
|             self.assertEqual(domain.name, 'Eukaryota') |             self.assertEqual(domain.name, 'Eukaryota') | ||||||
|         self.assertEqual(len(db.connection.queries), 1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_list_without_select_related(self): |     def test_list_without_select_related(self): | ||||||
|         """ |         """ | ||||||
|         select_related() also of course applies to entire lists, not just |         select_related() also of course applies to entire lists, not just | ||||||
|         items. This test verifies the expected behavior without select_related. |         items. This test verifies the expected behavior without select_related. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             world = Species.objects.all() |             world = Species.objects.all() | ||||||
|             families = [o.genus.family.name for o in world] |             families = [o.genus.family.name for o in world] | ||||||
|             self.assertEqual(families, [ |             self.assertEqual(families, [ | ||||||
| @@ -74,13 +70,14 @@ class SelectRelatedTests(TestCase): | |||||||
|                 'Fabaceae', |                 'Fabaceae', | ||||||
|                 'Amanitacae', |                 'Amanitacae', | ||||||
|             ]) |             ]) | ||||||
|         self.assertEqual(len(db.connection.queries), 9) |         self.assertNumQueries(9, test) | ||||||
|  |  | ||||||
|     def test_list_with_select_related(self): |     def test_list_with_select_related(self): | ||||||
|         """ |         """ | ||||||
|         select_related() also of course applies to entire lists, not just |         select_related() also of course applies to entire lists, not just | ||||||
|         items. This test verifies the expected behavior with select_related. |         items. This test verifies the expected behavior with select_related. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             world = Species.objects.all().select_related() |             world = Species.objects.all().select_related() | ||||||
|             families = [o.genus.family.name for o in world] |             families = [o.genus.family.name for o in world] | ||||||
|             self.assertEqual(families, [ |             self.assertEqual(families, [ | ||||||
| @@ -89,20 +86,21 @@ class SelectRelatedTests(TestCase): | |||||||
|                 'Fabaceae', |                 'Fabaceae', | ||||||
|                 'Amanitacae', |                 'Amanitacae', | ||||||
|             ]) |             ]) | ||||||
|         self.assertEqual(len(db.connection.queries), 1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_depth(self, depth=1, expected=7): |     def test_depth(self, depth=1, expected=7): | ||||||
|         """ |         """ | ||||||
|         The "depth" argument to select_related() will stop the descent at a |         The "depth" argument to select_related() will stop the descent at a | ||||||
|         particular level. |         particular level. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             pea = Species.objects.select_related(depth=depth).get(name="sativum") |             pea = Species.objects.select_related(depth=depth).get(name="sativum") | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 pea.genus.family.order.klass.phylum.kingdom.domain.name, |                 pea.genus.family.order.klass.phylum.kingdom.domain.name, | ||||||
|                 'Eukaryota' |                 'Eukaryota' | ||||||
|             ) |             ) | ||||||
|         # Notice: one fewer queries than above because of depth=1 |         # Notice: one fewer queries than above because of depth=1 | ||||||
|         self.assertEqual(len(db.connection.queries), expected) |         self.assertNumQueries(expected, test) | ||||||
|  |  | ||||||
|     def test_larger_depth(self): |     def test_larger_depth(self): | ||||||
|         """ |         """ | ||||||
| @@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase): | |||||||
|         The "depth" argument to select_related() will stop the descent at a |         The "depth" argument to select_related() will stop the descent at a | ||||||
|         particular level. This can be used on lists as well. |         particular level. This can be used on lists as well. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             world = Species.objects.all().select_related(depth=2) |             world = Species.objects.all().select_related(depth=2) | ||||||
|             orders = [o.genus.family.order.name for o in world] |             orders = [o.genus.family.order.name for o in world] | ||||||
|             self.assertEqual(orders, |             self.assertEqual(orders, | ||||||
|                 ['Diptera', 'Primates', 'Fabales', 'Agaricales']) |                 ['Diptera', 'Primates', 'Fabales', 'Agaricales']) | ||||||
|         self.assertEqual(len(db.connection.queries), 5) |         self.assertNumQueries(5, test) | ||||||
|  |  | ||||||
|     def test_select_related_with_extra(self): |     def test_select_related_with_extra(self): | ||||||
|         s = Species.objects.all().select_related(depth=1)\ |         s = Species.objects.all().select_related(depth=1)\ | ||||||
| @@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase): | |||||||
|         In this case, we explicitly say to select the 'genus' and |         In this case, we explicitly say to select the 'genus' and | ||||||
|         'genus.family' models, leading to the same number of queries as before. |         'genus.family' models, leading to the same number of queries as before. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             world = Species.objects.select_related('genus__family') |             world = Species.objects.select_related('genus__family') | ||||||
|             families = [o.genus.family.name for o in world] |             families = [o.genus.family.name for o in world] | ||||||
|             self.assertEqual(families, |             self.assertEqual(families, | ||||||
|                 ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) |                 ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) | ||||||
|         self.assertEqual(len(db.connection.queries), 1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_more_certain_fields(self): |     def test_more_certain_fields(self): | ||||||
|         """ |         """ | ||||||
|         In this case, we explicitly say to select the 'genus' and |         In this case, we explicitly say to select the 'genus' and | ||||||
|         'genus.family' models, leading to the same number of queries as before. |         'genus.family' models, leading to the same number of queries as before. | ||||||
|         """ |         """ | ||||||
|  |         def test(): | ||||||
|             world = Species.objects.filter(genus__name='Amanita')\ |             world = Species.objects.filter(genus__name='Amanita')\ | ||||||
|                 .select_related('genus__family') |                 .select_related('genus__family') | ||||||
|             orders = [o.genus.family.order.name for o in world] |             orders = [o.genus.family.order.name for o in world] | ||||||
|             self.assertEqual(orders, [u'Agaricales']) |             self.assertEqual(orders, [u'Agaricales']) | ||||||
|         self.assertEqual(len(db.connection.queries), 2) |         self.assertNumQueries(2, test) | ||||||
|  |  | ||||||
|     def test_field_traversal(self): |     def test_field_traversal(self): | ||||||
|  |         def test(): | ||||||
|             s = Species.objects.all().select_related('genus__family__order' |             s = Species.objects.all().select_related('genus__family__order' | ||||||
|                 ).order_by('id')[0:1].get().genus.family.order.name |                 ).order_by('id')[0:1].get().genus.family.order.name | ||||||
|             self.assertEqual(s, u'Diptera') |             self.assertEqual(s, u'Diptera') | ||||||
|         self.assertEqual(len(db.connection.queries), 1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_depth_fields_fails(self): |     def test_depth_fields_fails(self): | ||||||
|         self.assertRaises(TypeError, |         self.assertRaises(TypeError, | ||||||
|   | |||||||
| @@ -2,9 +2,11 @@ import datetime | |||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import connection | from django.db import connection | ||||||
|  | from django.test import TestCase | ||||||
| from django.utils import unittest | from django.utils import unittest | ||||||
|  |  | ||||||
| from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate | from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, | ||||||
|  |     UniqueForDateModel, ModelToValidate) | ||||||
|  |  | ||||||
|  |  | ||||||
| class GetUniqueCheckTests(unittest.TestCase): | class GetUniqueCheckTests(unittest.TestCase): | ||||||
| @@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase): | |||||||
|             ), m._get_unique_checks(exclude='start_date') |             ), m._get_unique_checks(exclude='start_date') | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| class PerformUniqueChecksTest(unittest.TestCase): | class PerformUniqueChecksTest(TestCase): | ||||||
|     def setUp(self): |  | ||||||
|         # Set debug to True to gain access to connection.queries. |  | ||||||
|         self._old_debug, settings.DEBUG = settings.DEBUG, True |  | ||||||
|         super(PerformUniqueChecksTest, self).setUp() |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         # Restore old debug value. |  | ||||||
|         settings.DEBUG = self._old_debug |  | ||||||
|         super(PerformUniqueChecksTest, self).tearDown() |  | ||||||
|  |  | ||||||
|     def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): |     def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): | ||||||
|         # Regression test for #12560 |         # Regression test for #12560 | ||||||
|         query_count = len(connection.queries) |         def test(): | ||||||
|             mtv = ModelToValidate(number=10, name='Some Name') |             mtv = ModelToValidate(number=10, name='Some Name') | ||||||
|             setattr(mtv, '_adding', True) |             setattr(mtv, '_adding', True) | ||||||
|             mtv.full_clean() |             mtv.full_clean() | ||||||
|         self.assertEqual(query_count, len(connection.queries)) |         self.assertNumQueries(0, test) | ||||||
|  |  | ||||||
|     def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): |     def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): | ||||||
|         # Regression test for #12560 |         # Regression test for #12560 | ||||||
|         query_count = len(connection.queries) |         def test(): | ||||||
|             mtv = ModelToValidate(number=10, name='Some Name', id=123) |             mtv = ModelToValidate(number=10, name='Some Name', id=123) | ||||||
|             setattr(mtv, '_adding', True) |             setattr(mtv, '_adding', True) | ||||||
|             mtv.full_clean() |             mtv.full_clean() | ||||||
|         self.assertEqual(query_count + 1, len(connection.queries)) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_primary_key_unique_check_not_performed_when_not_adding(self): |     def test_primary_key_unique_check_not_performed_when_not_adding(self): | ||||||
|         # Regression test for #12132 |         # Regression test for #12132 | ||||||
|         query_count= len(connection.queries) |         def test(): | ||||||
|             mtv = ModelToValidate(number=10, name='Some Name') |             mtv = ModelToValidate(number=10, name='Some Name') | ||||||
|             mtv.full_clean() |             mtv.full_clean() | ||||||
|         self.assertEqual(query_count, len(connection.queries)) |         self.assertNumQueries(0, test) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate | |||||||
|  |  | ||||||
| # Import other tests for this package. | # Import other tests for this package. | ||||||
| from modeltests.validation.validators import TestModelsWithValidators | from modeltests.validation.validators import TestModelsWithValidators | ||||||
| from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest | from modeltests.validation.test_unique import (GetUniqueCheckTests, | ||||||
|  |     PerformUniqueChecksTest) | ||||||
| from modeltests.validation.test_custom_messages import CustomMessagesTest | from modeltests.validation.test_custom_messages import CustomMessagesTest | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -111,4 +112,3 @@ class ModelFormsTests(TestCase): | |||||||
|         article = Article(author_id=self.author.id) |         article = Article(author_id=self.author.id) | ||||||
|         form = ArticleForm(data, instance=article) |         form = ArticleForm(data, instance=article) | ||||||
|         self.assertEqual(form.errors.keys(), ['pub_date']) |         self.assertEqual(form.errors.keys(), ['pub_date']) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf | |||||||
|  |  | ||||||
|  |  | ||||||
| class DeferRegressionTest(TestCase): | class DeferRegressionTest(TestCase): | ||||||
|     def assert_num_queries(self, n, func, *args, **kwargs): |  | ||||||
|         old_DEBUG = settings.DEBUG |  | ||||||
|         settings.DEBUG = True |  | ||||||
|         starting_queries = len(connection.queries) |  | ||||||
|         try: |  | ||||||
|             func(*args, **kwargs) |  | ||||||
|         finally: |  | ||||||
|             settings.DEBUG = old_DEBUG |  | ||||||
|         self.assertEqual(starting_queries + n, len(connection.queries)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_basic(self): |     def test_basic(self): | ||||||
|         # Deferred fields should really be deferred and not accidentally use |         # Deferred fields should really be deferred and not accidentally use | ||||||
|         # the field's default value just because they aren't passed to __init__ |         # the field's default value just because they aren't passed to __init__ | ||||||
| @@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase): | |||||||
|         def test(): |         def test(): | ||||||
|             self.assertEqual(obj.name, "first") |             self.assertEqual(obj.name, "first") | ||||||
|             self.assertEqual(obj.other_value, 0) |             self.assertEqual(obj.other_value, 0) | ||||||
|         self.assert_num_queries(0, test) |         self.assertNumQueries(0, test) | ||||||
|  |  | ||||||
|         def test(): |         def test(): | ||||||
|             self.assertEqual(obj.value, 42) |             self.assertEqual(obj.value, 42) | ||||||
|         self.assert_num_queries(1, test) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|         def test(): |         def test(): | ||||||
|             self.assertEqual(obj.text, "xyzzy") |             self.assertEqual(obj.text, "xyzzy") | ||||||
|         self.assert_num_queries(1, test) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|         def test(): |         def test(): | ||||||
|             self.assertEqual(obj.text, "xyzzy") |             self.assertEqual(obj.text, "xyzzy") | ||||||
|         self.assert_num_queries(0, test) |         self.assertNumQueries(0, test) | ||||||
|  |  | ||||||
|         # Regression test for #10695. Make sure different instances don't |         # Regression test for #10695. Make sure different instances don't | ||||||
|         # inadvertently share data in the deferred descriptor objects. |         # inadvertently share data in the deferred descriptor objects. | ||||||
|   | |||||||
| @@ -1,10 +1,9 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| import datetime | import datetime | ||||||
| import tempfile |  | ||||||
| import shutil | import shutil | ||||||
|  | import tempfile | ||||||
|  |  | ||||||
| from django.db import models, connection | from django.db import models | ||||||
| from django.conf import settings |  | ||||||
| # Can't import as "forms" due to implementation details in the test suite (the | # Can't import as "forms" due to implementation details in the test suite (the | ||||||
| # current file is called "forms" and is already imported). | # current file is called "forms" and is already imported). | ||||||
| from django import forms as django_forms | from django import forms as django_forms | ||||||
| @@ -77,19 +76,13 @@ class TestTicket12510(TestCase): | |||||||
|     ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). ''' |     ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). ''' | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         self.groups = [Group.objects.create(name=name) for name in 'abc'] |         self.groups = [Group.objects.create(name=name) for name in 'abc'] | ||||||
|         self.old_debug = settings.DEBUG |  | ||||||
|         # turn debug on to get access to connection.queries |  | ||||||
|         settings.DEBUG = True |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         settings.DEBUG = self.old_debug |  | ||||||
|  |  | ||||||
|     def test_choices_not_fetched_when_not_rendering(self): |     def test_choices_not_fetched_when_not_rendering(self): | ||||||
|         initial_queries = len(connection.queries) |         def test(): | ||||||
|             field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) |             field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) | ||||||
|             self.assertEqual('a', field.clean(self.groups[0].pk).name) |             self.assertEqual('a', field.clean(self.groups[0].pk).name) | ||||||
|         # only one query is required to pull the model from DB |         # only one query is required to pull the model from DB | ||||||
|         self.assertEqual(initial_queries+1, len(connection.queries)) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
| class ModelFormCallableModelDefault(TestCase): | class ModelFormCallableModelDefault(TestCase): | ||||||
|     def test_no_empty_option(self): |     def test_no_empty_option(self): | ||||||
|   | |||||||
| @@ -1,10 +1,8 @@ | |||||||
| import unittest | import unittest | ||||||
| from datetime import date | from datetime import date | ||||||
|  |  | ||||||
| from django import db |  | ||||||
| from django import forms | from django import forms | ||||||
| from django.forms.models import modelform_factory, ModelChoiceField | from django.forms.models import modelform_factory, ModelChoiceField | ||||||
| from django.conf import settings |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.core.exceptions import FieldError, ValidationError | from django.core.exceptions import FieldError, ValidationError | ||||||
| from django.core.files.uploadedfile import SimpleUploadedFile | from django.core.files.uploadedfile import SimpleUploadedFile | ||||||
| @@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \ | |||||||
|  |  | ||||||
|  |  | ||||||
| class ModelMultipleChoiceFieldTests(TestCase): | class ModelMultipleChoiceFieldTests(TestCase): | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         self.old_debug = settings.DEBUG |  | ||||||
|         settings.DEBUG = True |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         settings.DEBUG = self.old_debug |  | ||||||
|  |  | ||||||
|     def test_model_multiple_choice_number_of_queries(self): |     def test_model_multiple_choice_number_of_queries(self): | ||||||
|         """ |         """ | ||||||
|         Test that ModelMultipleChoiceField does O(1) queries instead of |         Test that ModelMultipleChoiceField does O(1) queries instead of | ||||||
| @@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase): | |||||||
|         for i in range(30): |         for i in range(30): | ||||||
|             Person.objects.create(name="Person %s" % i) |             Person.objects.create(name="Person %s" % i) | ||||||
|  |  | ||||||
|         db.reset_queries() |  | ||||||
|         f = forms.ModelMultipleChoiceField(queryset=Person.objects.all()) |         f = forms.ModelMultipleChoiceField(queryset=Person.objects.all()) | ||||||
|         selected = f.clean([1, 3, 5, 7, 9]) |         self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9]) | ||||||
|         self.assertEquals(len(db.connection.queries), 1) |  | ||||||
|  |  | ||||||
| class TripleForm(forms.ModelForm): | class TripleForm(forms.ModelForm): | ||||||
|     class Meta: |     class Meta: | ||||||
|   | |||||||
| @@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, | |||||||
|  |  | ||||||
| class ReverseSelectRelatedTestCase(TestCase): | class ReverseSelectRelatedTestCase(TestCase): | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         # Explicitly enable debug for these tests - we need to count |  | ||||||
|         # the queries that have been issued. |  | ||||||
|         self.old_debug = settings.DEBUG |  | ||||||
|         settings.DEBUG = True |  | ||||||
|  |  | ||||||
|         user = User.objects.create(username="test") |         user = User.objects.create(username="test") | ||||||
|         userprofile = UserProfile.objects.create(user=user, state="KS", |         userprofile = UserProfile.objects.create(user=user, state="KS", | ||||||
|                                                  city="Lawrence") |                                                  city="Lawrence") | ||||||
| @@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase): | |||||||
|                                                   results=results2) |                                                   results=results2) | ||||||
|         StatDetails.objects.create(base_stats=advstat, comments=250) |         StatDetails.objects.create(base_stats=advstat, comments=250) | ||||||
|  |  | ||||||
|         db.reset_queries() |  | ||||||
|  |  | ||||||
|     def assertQueries(self, queries): |  | ||||||
|         self.assertEqual(len(db.connection.queries), queries) |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         settings.DEBUG = self.old_debug |  | ||||||
|  |  | ||||||
|     def test_basic(self): |     def test_basic(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related("userprofile").get(username="test") |             u = User.objects.select_related("userprofile").get(username="test") | ||||||
|             self.assertEqual(u.userprofile.state, "KS") |             self.assertEqual(u.userprofile.state, "KS") | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_follow_next_level(self): |     def test_follow_next_level(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related("userstat__results").get(username="test") |             u = User.objects.select_related("userstat__results").get(username="test") | ||||||
|             self.assertEqual(u.userstat.posts, 150) |             self.assertEqual(u.userstat.posts, 150) | ||||||
|             self.assertEqual(u.userstat.results.results, 'first results') |             self.assertEqual(u.userstat.results.results, 'first results') | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_follow_two(self): |     def test_follow_two(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related("userprofile", "userstat").get(username="test") |             u = User.objects.select_related("userprofile", "userstat").get(username="test") | ||||||
|             self.assertEqual(u.userprofile.state, "KS") |             self.assertEqual(u.userprofile.state, "KS") | ||||||
|             self.assertEqual(u.userstat.posts, 150) |             self.assertEqual(u.userstat.posts, 150) | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_follow_two_next_level(self): |     def test_follow_two_next_level(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") |             u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") | ||||||
|             self.assertEqual(u.userstat.results.results, 'first results') |             self.assertEqual(u.userstat.results.results, 'first results') | ||||||
|             self.assertEqual(u.userstat.statdetails.comments, 259) |             self.assertEqual(u.userstat.statdetails.comments, 259) | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_forward_and_back(self): |     def test_forward_and_back(self): | ||||||
|  |         def test(): | ||||||
|             stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") |             stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") | ||||||
|             self.assertEqual(stat.user.userprofile.state, 'KS') |             self.assertEqual(stat.user.userprofile.state, 'KS') | ||||||
|             self.assertEqual(stat.user.userstat.posts, 150) |             self.assertEqual(stat.user.userstat.posts, 150) | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_back_and_forward(self): |     def test_back_and_forward(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related("userstat").get(username="test") |             u = User.objects.select_related("userstat").get(username="test") | ||||||
|             self.assertEqual(u.userstat.user.username, 'test') |             self.assertEqual(u.userstat.user.username, 'test') | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_not_followed_by_default(self): |     def test_not_followed_by_default(self): | ||||||
|  |         def test(): | ||||||
|             u = User.objects.select_related().get(username="test") |             u = User.objects.select_related().get(username="test") | ||||||
|             self.assertEqual(u.userstat.posts, 150) |             self.assertEqual(u.userstat.posts, 150) | ||||||
|         self.assertQueries(2) |         self.assertNumQueries(2, test) | ||||||
|  |  | ||||||
|     def test_follow_from_child_class(self): |     def test_follow_from_child_class(self): | ||||||
|  |         def test(): | ||||||
|             stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) |             stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) | ||||||
|             self.assertEqual(stat.statdetails.comments, 250) |             self.assertEqual(stat.statdetails.comments, 250) | ||||||
|             self.assertEqual(stat.user.username, 'bob') |             self.assertEqual(stat.user.username, 'bob') | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_follow_inheritance(self): |     def test_follow_inheritance(self): | ||||||
|  |         def test(): | ||||||
|             stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) |             stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) | ||||||
|             self.assertEqual(stat.advanceduserstat.posts, 200) |             self.assertEqual(stat.advanceduserstat.posts, 200) | ||||||
|             self.assertEqual(stat.user.username, 'bob') |             self.assertEqual(stat.user.username, 'bob') | ||||||
|             self.assertEqual(stat.advanceduserstat.user.username, 'bob') |             self.assertEqual(stat.advanceduserstat.user.username, 'bob') | ||||||
|         self.assertQueries(1) |         self.assertNumQueries(1, test) | ||||||
|  |  | ||||||
|     def test_nullable_relation(self): |     def test_nullable_relation(self): | ||||||
|         im = Image.objects.create(name="imag1") |         im = Image.objects.create(name="imag1") | ||||||
|   | |||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | from django.db import models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Person(models.Model): | ||||||
|  |     name = models.CharField(max_length=100) | ||||||
|   | |||||||
							
								
								
									
										30
									
								
								tests/regressiontests/test_utils/python_25.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tests/regressiontests/test_utils/python_25.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | from __future__ import with_statement | ||||||
|  |  | ||||||
|  | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from models import Person | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AssertNumQueriesTests(TestCase): | ||||||
|  |     def test_simple(self): | ||||||
|  |         with self.assertNumQueries(0): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         with self.assertNumQueries(1): | ||||||
|  |             # Guy who wrote Linux | ||||||
|  |             Person.objects.create(name="Linus Torvalds") | ||||||
|  |  | ||||||
|  |         with self.assertNumQueries(2): | ||||||
|  |             # Guy who owns the bagel place I like | ||||||
|  |             Person.objects.create(name="Uncle Ricky") | ||||||
|  |             self.assertEqual(Person.objects.count(), 2) | ||||||
|  |  | ||||||
|  |     def test_failure(self): | ||||||
|  |         with self.assertRaises(AssertionError) as exc_info: | ||||||
|  |             with self.assertNumQueries(2): | ||||||
|  |                 Person.objects.count() | ||||||
|  |         self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected") | ||||||
|  |  | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             with self.assertNumQueries(4000): | ||||||
|  |                 raise TypeError | ||||||
| @@ -1,4 +1,10 @@ | |||||||
| r""" | import sys | ||||||
|  |  | ||||||
|  | if sys.version_info >= (2, 5): | ||||||
|  |     from python_25 import AssertNumQueriesTests | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __test__ = {"API_TEST": r""" | ||||||
| # Some checks of the doctest output normalizer. | # Some checks of the doctest output normalizer. | ||||||
| # Standard doctests do fairly | # Standard doctests do fairly | ||||||
| >>> from django.utils import simplejson | >>> from django.utils import simplejson | ||||||
| @@ -69,4 +75,4 @@ r""" | |||||||
| >>> produce_xml_fragment() | >>> produce_xml_fragment() | ||||||
| '<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>' | '<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>' | ||||||
|  |  | ||||||
| """ | """} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user