1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #32489 -- Added iter_test_cases() to iterate over a TestSuite.

This also makes partition_suite_by_type(), partition_suite_by_case(),
filter_tests_by_tags(), and DiscoverRunner._get_databases() to use
iter_test_cases().
This commit is contained in:
Chris Jerdonek
2021-03-04 01:53:45 -08:00
committed by Mariusz Felisiak
parent b190419278
commit 22c9af0eae
3 changed files with 128 additions and 48 deletions

View File

@@ -16,9 +16,9 @@ from django.core.management import call_command
from django.db import connections
from django.test import SimpleTestCase, TestCase
from django.test.utils import (
NullTimeKeeper, TimeKeeper, setup_databases as _setup_databases,
setup_test_environment, teardown_databases as _teardown_databases,
teardown_test_environment,
NullTimeKeeper, TimeKeeper, iter_test_cases,
setup_databases as _setup_databases, setup_test_environment,
teardown_databases as _teardown_databases, teardown_test_environment,
)
from django.utils.datastructures import OrderedSet
@@ -683,19 +683,16 @@ class DiscoverRunner:
def _get_databases(self, suite):
databases = {}
for test in suite:
if isinstance(test, unittest.TestCase):
test_databases = getattr(test, 'databases', None)
if test_databases == '__all__':
test_databases = connections
if test_databases:
serialized_rollback = getattr(test, 'serialized_rollback', False)
databases.update(
(alias, serialized_rollback or databases.get(alias, False))
for alias in test_databases
)
else:
databases.update(self._get_databases(test))
for test in iter_test_cases(suite):
test_databases = getattr(test, 'databases', None)
if test_databases == '__all__':
test_databases = connections
if test_databases:
serialized_rollback = getattr(test, 'serialized_rollback', False)
databases.update(
(alias, serialized_rollback or databases.get(alias, False))
for alias in test_databases
)
return databases
def get_databases(self, suite):
@@ -800,49 +797,39 @@ def partition_suite_by_type(suite, classes, bins, reverse=False):
Tests of type classes[i] are added to bins[i],
tests with no match found in classes are place in bins[-1]
"""
suite_class = type(suite)
if reverse:
suite = reversed(tuple(suite))
for test in suite:
if isinstance(test, suite_class):
partition_suite_by_type(test, classes, bins, reverse=reverse)
for test in iter_test_cases(suite, reverse=reverse):
for i in range(len(classes)):
if isinstance(test, classes[i]):
bins[i].add(test)
break
else:
for i in range(len(classes)):
if isinstance(test, classes[i]):
bins[i].add(test)
break
else:
bins[-1].add(test)
bins[-1].add(test)
def partition_suite_by_case(suite):
"""Partition a test suite by test case, preserving the order of tests."""
groups = []
subsuites = []
suite_class = type(suite)
for test_type, test_group in itertools.groupby(suite, type):
if issubclass(test_type, unittest.TestCase):
groups.append(suite_class(test_group))
else:
for item in test_group:
groups.extend(partition_suite_by_case(item))
return groups
tests = iter_test_cases(suite)
for test_type, test_group in itertools.groupby(tests, type):
subsuite = suite_class(test_group)
subsuites.append(subsuite)
return subsuites
def filter_tests_by_tags(suite, tags, exclude_tags):
suite_class = type(suite)
filtered_suite = suite_class()
for test in suite:
if isinstance(test, suite_class):
filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags))
else:
test_tags = set(getattr(test, 'tags', set()))
test_fn_name = getattr(test, '_testMethodName', str(test))
test_fn = getattr(test, test_fn_name, test)
test_fn_tags = set(getattr(test_fn, 'tags', set()))
all_tags = test_tags.union(test_fn_tags)
matched_tags = all_tags.intersection(tags)
if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
filtered_suite.addTest(test)
for test in iter_test_cases(suite):
test_tags = set(getattr(test, 'tags', set()))
test_fn_name = getattr(test, '_testMethodName', str(test))
test_fn = getattr(test, test_fn_name, test)
test_fn_tags = set(getattr(test_fn, 'tags', set()))
all_tags = test_tags.union(test_fn_tags)
matched_tags = all_tags.intersection(tags)
if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
filtered_suite.addTest(test)
return filtered_suite

View File

@@ -235,6 +235,18 @@ def setup_databases(
return old_names
def iter_test_cases(suite, reverse=False):
"""Return an iterator over a test suite's unittest.TestCase objects."""
if reverse:
suite = reversed(tuple(suite))
for test in suite:
if isinstance(test, TestCase):
yield test
else:
# Otherwise, assume it is a test suite.
yield from iter_test_cases(test, reverse=reverse)
def dependency_ordered(test_databases, dependencies):
"""
Reorder test_databases into an order that honors the dependencies