mirror of
				https://github.com/django/django.git
				synced 2025-10-31 01:25:32 +00:00 
			
		
		
		
	Fixed #32489 -- Added iter_test_cases() to iterate over a TestSuite.
This also makes partition_suite_by_type(), partition_suite_by_case(), filter_tests_by_tags(), and DiscoverRunner._get_databases() to use iter_test_cases().
This commit is contained in:
		
				
					committed by
					
						 Mariusz Felisiak
						Mariusz Felisiak
					
				
			
			
				
	
			
			
			
						parent
						
							b190419278
						
					
				
				
					commit
					22c9af0eae
				
			| @@ -16,9 +16,9 @@ from django.core.management import call_command | ||||
| from django.db import connections | ||||
| from django.test import SimpleTestCase, TestCase | ||||
| from django.test.utils import ( | ||||
|     NullTimeKeeper, TimeKeeper, setup_databases as _setup_databases, | ||||
|     setup_test_environment, teardown_databases as _teardown_databases, | ||||
|     teardown_test_environment, | ||||
|     NullTimeKeeper, TimeKeeper, iter_test_cases, | ||||
|     setup_databases as _setup_databases, setup_test_environment, | ||||
|     teardown_databases as _teardown_databases, teardown_test_environment, | ||||
| ) | ||||
| from django.utils.datastructures import OrderedSet | ||||
|  | ||||
| @@ -683,19 +683,16 @@ class DiscoverRunner: | ||||
|  | ||||
|     def _get_databases(self, suite): | ||||
|         databases = {} | ||||
|         for test in suite: | ||||
|             if isinstance(test, unittest.TestCase): | ||||
|                 test_databases = getattr(test, 'databases', None) | ||||
|                 if test_databases == '__all__': | ||||
|                     test_databases = connections | ||||
|                 if test_databases: | ||||
|                     serialized_rollback = getattr(test, 'serialized_rollback', False) | ||||
|                     databases.update( | ||||
|                         (alias, serialized_rollback or databases.get(alias, False)) | ||||
|                         for alias in test_databases | ||||
|                     ) | ||||
|             else: | ||||
|                 databases.update(self._get_databases(test)) | ||||
|         for test in iter_test_cases(suite): | ||||
|             test_databases = getattr(test, 'databases', None) | ||||
|             if test_databases == '__all__': | ||||
|                 test_databases = connections | ||||
|             if test_databases: | ||||
|                 serialized_rollback = getattr(test, 'serialized_rollback', False) | ||||
|                 databases.update( | ||||
|                     (alias, serialized_rollback or databases.get(alias, False)) | ||||
|                     for alias in test_databases | ||||
|                 ) | ||||
|         return databases | ||||
|  | ||||
|     def get_databases(self, suite): | ||||
| @@ -800,49 +797,39 @@ def partition_suite_by_type(suite, classes, bins, reverse=False): | ||||
|     Tests of type classes[i] are added to bins[i], | ||||
|     tests with no match found in classes are place in bins[-1] | ||||
|     """ | ||||
|     suite_class = type(suite) | ||||
|     if reverse: | ||||
|         suite = reversed(tuple(suite)) | ||||
|     for test in suite: | ||||
|         if isinstance(test, suite_class): | ||||
|             partition_suite_by_type(test, classes, bins, reverse=reverse) | ||||
|     for test in iter_test_cases(suite, reverse=reverse): | ||||
|         for i in range(len(classes)): | ||||
|             if isinstance(test, classes[i]): | ||||
|                 bins[i].add(test) | ||||
|                 break | ||||
|         else: | ||||
|             for i in range(len(classes)): | ||||
|                 if isinstance(test, classes[i]): | ||||
|                     bins[i].add(test) | ||||
|                     break | ||||
|             else: | ||||
|                 bins[-1].add(test) | ||||
|             bins[-1].add(test) | ||||
|  | ||||
|  | ||||
| def partition_suite_by_case(suite): | ||||
|     """Partition a test suite by test case, preserving the order of tests.""" | ||||
|     groups = [] | ||||
|     subsuites = [] | ||||
|     suite_class = type(suite) | ||||
|     for test_type, test_group in itertools.groupby(suite, type): | ||||
|         if issubclass(test_type, unittest.TestCase): | ||||
|             groups.append(suite_class(test_group)) | ||||
|         else: | ||||
|             for item in test_group: | ||||
|                 groups.extend(partition_suite_by_case(item)) | ||||
|     return groups | ||||
|     tests = iter_test_cases(suite) | ||||
|     for test_type, test_group in itertools.groupby(tests, type): | ||||
|         subsuite = suite_class(test_group) | ||||
|         subsuites.append(subsuite) | ||||
|  | ||||
|     return subsuites | ||||
|  | ||||
|  | ||||
| def filter_tests_by_tags(suite, tags, exclude_tags): | ||||
|     suite_class = type(suite) | ||||
|     filtered_suite = suite_class() | ||||
|  | ||||
|     for test in suite: | ||||
|         if isinstance(test, suite_class): | ||||
|             filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags)) | ||||
|         else: | ||||
|             test_tags = set(getattr(test, 'tags', set())) | ||||
|             test_fn_name = getattr(test, '_testMethodName', str(test)) | ||||
|             test_fn = getattr(test, test_fn_name, test) | ||||
|             test_fn_tags = set(getattr(test_fn, 'tags', set())) | ||||
|             all_tags = test_tags.union(test_fn_tags) | ||||
|             matched_tags = all_tags.intersection(tags) | ||||
|             if (matched_tags or not tags) and not all_tags.intersection(exclude_tags): | ||||
|                 filtered_suite.addTest(test) | ||||
|     for test in iter_test_cases(suite): | ||||
|         test_tags = set(getattr(test, 'tags', set())) | ||||
|         test_fn_name = getattr(test, '_testMethodName', str(test)) | ||||
|         test_fn = getattr(test, test_fn_name, test) | ||||
|         test_fn_tags = set(getattr(test_fn, 'tags', set())) | ||||
|         all_tags = test_tags.union(test_fn_tags) | ||||
|         matched_tags = all_tags.intersection(tags) | ||||
|         if (matched_tags or not tags) and not all_tags.intersection(exclude_tags): | ||||
|             filtered_suite.addTest(test) | ||||
|  | ||||
|     return filtered_suite | ||||
|   | ||||
| @@ -235,6 +235,18 @@ def setup_databases( | ||||
|     return old_names | ||||
|  | ||||
|  | ||||
| def iter_test_cases(suite, reverse=False): | ||||
|     """Return an iterator over a test suite's unittest.TestCase objects.""" | ||||
|     if reverse: | ||||
|         suite = reversed(tuple(suite)) | ||||
|     for test in suite: | ||||
|         if isinstance(test, TestCase): | ||||
|             yield test | ||||
|         else: | ||||
|             # Otherwise, assume it is a test suite. | ||||
|             yield from iter_test_cases(test, reverse=reverse) | ||||
|  | ||||
|  | ||||
| def dependency_ordered(test_databases, dependencies): | ||||
|     """ | ||||
|     Reorder test_databases into an order that honors the dependencies | ||||
|   | ||||
| @@ -18,12 +18,93 @@ from django.test.runner import DiscoverRunner | ||||
| from django.test.testcases import connections_support_transactions | ||||
| from django.test.utils import ( | ||||
|     captured_stderr, dependency_ordered, get_unique_databases_and_mirrors, | ||||
|     iter_test_cases, | ||||
| ) | ||||
| from django.utils.deprecation import RemovedInDjango50Warning | ||||
|  | ||||
| from .models import B, Person, Through | ||||
|  | ||||
|  | ||||
| class MySuite: | ||||
|     def __init__(self): | ||||
|         self.tests = [] | ||||
|  | ||||
|     def addTest(self, test): | ||||
|         self.tests.append(test) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         yield from self.tests | ||||
|  | ||||
|  | ||||
| class IterTestCasesTests(unittest.TestCase): | ||||
|     def make_test_suite(self, suite=None, suite_class=None): | ||||
|         if suite_class is None: | ||||
|             suite_class = unittest.TestSuite | ||||
|         if suite is None: | ||||
|             suite = suite_class() | ||||
|  | ||||
|         class Tests1(unittest.TestCase): | ||||
|             def test1(self): | ||||
|                 pass | ||||
|  | ||||
|             def test2(self): | ||||
|                 pass | ||||
|  | ||||
|         class Tests2(unittest.TestCase): | ||||
|             def test1(self): | ||||
|                 pass | ||||
|  | ||||
|             def test2(self): | ||||
|                 pass | ||||
|  | ||||
|         loader = unittest.defaultTestLoader | ||||
|         for test_cls in (Tests1, Tests2): | ||||
|             tests = loader.loadTestsFromTestCase(test_cls) | ||||
|             subsuite = suite_class() | ||||
|             # Only use addTest() to simplify testing a custom TestSuite. | ||||
|             for test in tests: | ||||
|                 subsuite.addTest(test) | ||||
|             suite.addTest(subsuite) | ||||
|  | ||||
|         return suite | ||||
|  | ||||
|     def assertTestNames(self, tests, expected): | ||||
|         # Each test.id() has a form like the following: | ||||
|         # "test_runner.tests.IterTestCasesTests.test_iter_test_cases.<locals>.Tests1.test1". | ||||
|         # It suffices to check only the last two parts. | ||||
|         names = ['.'.join(test.id().split('.')[-2:]) for test in tests] | ||||
|         self.assertEqual(names, expected) | ||||
|  | ||||
|     def test_basic(self): | ||||
|         suite = self.make_test_suite() | ||||
|         tests = iter_test_cases(suite) | ||||
|         self.assertTestNames(tests, expected=[ | ||||
|             'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2', | ||||
|         ]) | ||||
|  | ||||
|     def test_reverse(self): | ||||
|         suite = self.make_test_suite() | ||||
|         tests = iter_test_cases(suite, reverse=True) | ||||
|         self.assertTestNames(tests, expected=[ | ||||
|             'Tests2.test2', 'Tests2.test1', 'Tests1.test2', 'Tests1.test1', | ||||
|         ]) | ||||
|  | ||||
|     def test_custom_test_suite_class(self): | ||||
|         suite = self.make_test_suite(suite_class=MySuite) | ||||
|         tests = iter_test_cases(suite) | ||||
|         self.assertTestNames(tests, expected=[ | ||||
|             'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2', | ||||
|         ]) | ||||
|  | ||||
|     def test_mixed_test_suite_classes(self): | ||||
|         suite = self.make_test_suite(suite=MySuite()) | ||||
|         child_suite = list(suite)[0] | ||||
|         self.assertNotIsInstance(child_suite, MySuite) | ||||
|         tests = list(iter_test_cases(suite)) | ||||
|         self.assertEqual(len(tests), 4) | ||||
|         self.assertNotIsInstance(tests[0], unittest.TestSuite) | ||||
|  | ||||
|  | ||||
| class DependencyOrderingTests(unittest.TestCase): | ||||
|  | ||||
|     def test_simple_dependencies(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user