mirror of
				https://github.com/django/django.git
				synced 2025-10-31 01:25:32 +00:00 
			
		
		
		
	Fixed #31169 -- Adapted the parallel test runner to use spawn.
Co-authored-by: Valz <ahmadahussein0@gmail.com> Co-authored-by: Nick Pope <nick@nickpope.me.uk>
This commit is contained in:
		
				
					committed by
					
						 Carlton Gibson
						Carlton Gibson
					
				
			
			
				
	
			
			
			
						parent
						
							3eaba13a47
						
					
				
				
					commit
					3b3f38b3b0
				
			| @@ -1,8 +1,11 @@ | ||||
| import multiprocessing | ||||
| import os | ||||
| import shutil | ||||
| import sqlite3 | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from django.db import NotSupportedError | ||||
| from django.db.backends.base.creation import BaseDatabaseCreation | ||||
|  | ||||
|  | ||||
| @@ -51,16 +54,26 @@ class DatabaseCreation(BaseDatabaseCreation): | ||||
|     def get_test_db_clone_settings(self, suffix): | ||||
|         orig_settings_dict = self.connection.settings_dict | ||||
|         source_database_name = orig_settings_dict["NAME"] | ||||
|         if self.is_in_memory_db(source_database_name): | ||||
|  | ||||
|         if not self.is_in_memory_db(source_database_name): | ||||
|             root, ext = os.path.splitext(source_database_name) | ||||
|             return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"} | ||||
|  | ||||
|         start_method = multiprocessing.get_start_method() | ||||
|         if start_method == "fork": | ||||
|             return orig_settings_dict | ||||
|         else: | ||||
|             root, ext = os.path.splitext(orig_settings_dict["NAME"]) | ||||
|             return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)} | ||||
|         if start_method == "spawn": | ||||
|             return { | ||||
|                 **orig_settings_dict, | ||||
|                 "NAME": f"{self.connection.alias}_{suffix}.sqlite3", | ||||
|             } | ||||
|         raise NotSupportedError( | ||||
|             f"Cloning with start method {start_method!r} is not supported." | ||||
|         ) | ||||
|  | ||||
|     def _clone_test_db(self, suffix, verbosity, keepdb=False): | ||||
|         source_database_name = self.connection.settings_dict["NAME"] | ||||
|         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] | ||||
|         # Forking automatically makes a copy of an in-memory database. | ||||
|         if not self.is_in_memory_db(source_database_name): | ||||
|             # Erase the old test database | ||||
|             if os.access(target_database_name, os.F_OK): | ||||
| @@ -85,6 +98,12 @@ class DatabaseCreation(BaseDatabaseCreation): | ||||
|             except Exception as e: | ||||
|                 self.log("Got an error cloning the test database: %s" % e) | ||||
|                 sys.exit(2) | ||||
|         # Forking automatically makes a copy of an in-memory database. | ||||
|         # Spawn requires migrating to disk which will be re-opened in | ||||
|         # setup_worker_connection. | ||||
|         elif multiprocessing.get_start_method() == "spawn": | ||||
|             ondisk_db = sqlite3.connect(target_database_name, uri=True) | ||||
|             self.connection.connection.backup(ondisk_db) | ||||
|  | ||||
|     def _destroy_test_db(self, test_database_name, verbosity): | ||||
|         if test_database_name and not self.is_in_memory_db(test_database_name): | ||||
| @@ -106,3 +125,34 @@ class DatabaseCreation(BaseDatabaseCreation): | ||||
|         else: | ||||
|             sig.append(test_database_name) | ||||
|         return tuple(sig) | ||||
|  | ||||
|     def setup_worker_connection(self, _worker_id): | ||||
|         settings_dict = self.get_test_db_clone_settings(_worker_id) | ||||
|         # connection.settings_dict must be updated in place for changes to be | ||||
|         # reflected in django.db.connections. Otherwise new threads would | ||||
|         # connect to the default database instead of the appropriate clone. | ||||
|         start_method = multiprocessing.get_start_method() | ||||
|         if start_method == "fork": | ||||
|             # Update settings_dict in place. | ||||
|             self.connection.settings_dict.update(settings_dict) | ||||
|             self.connection.close() | ||||
|         elif start_method == "spawn": | ||||
|             alias = self.connection.alias | ||||
|             connection_str = ( | ||||
|                 f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared" | ||||
|             ) | ||||
|             source_db = self.connection.Database.connect( | ||||
|                 f"file:{alias}_{_worker_id}.sqlite3", uri=True | ||||
|             ) | ||||
|             target_db = sqlite3.connect(connection_str, uri=True) | ||||
|             source_db.backup(target_db) | ||||
|             source_db.close() | ||||
|             # Update settings_dict in place. | ||||
|             self.connection.settings_dict.update(settings_dict) | ||||
|             self.connection.settings_dict["NAME"] = connection_str | ||||
|             # Re-open connection to in-memory database before closing copy | ||||
|             # connection. | ||||
|             self.connection.connect() | ||||
|             target_db.close() | ||||
|             if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true": | ||||
|                 self.mark_expected_failures_and_skips() | ||||
|   | ||||
| @@ -20,7 +20,12 @@ from io import StringIO | ||||
| from django.core.management import call_command | ||||
| from django.db import connections | ||||
| from django.test import SimpleTestCase, TestCase | ||||
| from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases | ||||
| from django.test.utils import ( | ||||
|     NullTimeKeeper, | ||||
|     TimeKeeper, | ||||
|     captured_stdout, | ||||
|     iter_test_cases, | ||||
| ) | ||||
| from django.test.utils import setup_databases as _setup_databases | ||||
| from django.test.utils import setup_test_environment | ||||
| from django.test.utils import teardown_databases as _teardown_databases | ||||
| @@ -367,8 +372,8 @@ def get_max_test_processes(): | ||||
|     The maximum number of test processes when using the --parallel option. | ||||
|     """ | ||||
|     # The current implementation of the parallel test runner requires | ||||
|     # multiprocessing to start subprocesses with fork(). | ||||
|     if multiprocessing.get_start_method() != "fork": | ||||
|     # multiprocessing to start subprocesses with fork() or spawn(). | ||||
|     if multiprocessing.get_start_method() not in {"fork", "spawn"}: | ||||
|         return 1 | ||||
|     try: | ||||
|         return int(os.environ["DJANGO_TEST_PROCESSES"]) | ||||
| @@ -391,7 +396,13 @@ def parallel_type(value): | ||||
| _worker_id = 0 | ||||
|  | ||||
|  | ||||
| def _init_worker(counter): | ||||
| def _init_worker( | ||||
|     counter, | ||||
|     initial_settings=None, | ||||
|     serialized_contents=None, | ||||
|     process_setup=None, | ||||
|     process_setup_args=None, | ||||
| ): | ||||
|     """ | ||||
|     Switch to databases dedicated to this worker. | ||||
|  | ||||
| @@ -405,9 +416,22 @@ def _init_worker(counter): | ||||
|         counter.value += 1 | ||||
|         _worker_id = counter.value | ||||
|  | ||||
|     start_method = multiprocessing.get_start_method() | ||||
|  | ||||
|     if start_method == "spawn": | ||||
|         process_setup(*process_setup_args) | ||||
|         setup_test_environment() | ||||
|  | ||||
|     for alias in connections: | ||||
|         connection = connections[alias] | ||||
|         if start_method == "spawn": | ||||
|             # Restore initial settings in spawned processes. | ||||
|             connection.settings_dict.update(initial_settings[alias]) | ||||
|             if value := serialized_contents.get(alias): | ||||
|                 connection._test_serialized_contents = value | ||||
|         connection.creation.setup_worker_connection(_worker_id) | ||||
|         with captured_stdout(): | ||||
|             call_command("check", databases=connections) | ||||
|  | ||||
|  | ||||
| def _run_subsuite(args): | ||||
| @@ -449,6 +473,8 @@ class ParallelTestSuite(unittest.TestSuite): | ||||
|         self.processes = processes | ||||
|         self.failfast = failfast | ||||
|         self.buffer = buffer | ||||
|         self.initial_settings = None | ||||
|         self.serialized_contents = None | ||||
|         super().__init__() | ||||
|  | ||||
|     def run(self, result): | ||||
| @@ -469,8 +495,12 @@ class ParallelTestSuite(unittest.TestSuite): | ||||
|         counter = multiprocessing.Value(ctypes.c_int, 0) | ||||
|         pool = multiprocessing.Pool( | ||||
|             processes=self.processes, | ||||
|             initializer=self.init_worker.__func__, | ||||
|             initargs=[counter], | ||||
|             initializer=self.init_worker, | ||||
|             initargs=[ | ||||
|                 counter, | ||||
|                 self.initial_settings, | ||||
|                 self.serialized_contents, | ||||
|             ], | ||||
|         ) | ||||
|         args = [ | ||||
|             (self.runner_class, index, subsuite, self.failfast, self.buffer) | ||||
| @@ -508,6 +538,17 @@ class ParallelTestSuite(unittest.TestSuite): | ||||
|     def __iter__(self): | ||||
|         return iter(self.subsuites) | ||||
|  | ||||
|     def initialize_suite(self): | ||||
|         if multiprocessing.get_start_method() == "spawn": | ||||
|             self.initial_settings = { | ||||
|                 alias: connections[alias].settings_dict for alias in connections | ||||
|             } | ||||
|             self.serialized_contents = { | ||||
|                 alias: connections[alias]._test_serialized_contents | ||||
|                 for alias in connections | ||||
|                 if alias in self.serialized_aliases | ||||
|             } | ||||
|  | ||||
|  | ||||
| class Shuffler: | ||||
|     """ | ||||
| @@ -921,6 +962,8 @@ class DiscoverRunner: | ||||
|     def run_suite(self, suite, **kwargs): | ||||
|         kwargs = self.get_test_runner_kwargs() | ||||
|         runner = self.test_runner(**kwargs) | ||||
|         if hasattr(suite, "initialize_suite"): | ||||
|             suite.initialize_suite() | ||||
|         try: | ||||
|             return runner.run(suite) | ||||
|         finally: | ||||
| @@ -989,13 +1032,13 @@ class DiscoverRunner: | ||||
|         self.setup_test_environment() | ||||
|         suite = self.build_suite(test_labels, extra_tests) | ||||
|         databases = self.get_databases(suite) | ||||
|         serialized_aliases = set( | ||||
|         suite.serialized_aliases = set( | ||||
|             alias for alias, serialize in databases.items() if serialize | ||||
|         ) | ||||
|         with self.time_keeper.timed("Total database setup"): | ||||
|             old_config = self.setup_databases( | ||||
|                 aliases=databases, | ||||
|                 serialized_aliases=serialized_aliases, | ||||
|                 serialized_aliases=suite.serialized_aliases, | ||||
|             ) | ||||
|         run_failed = False | ||||
|         try: | ||||
|   | ||||
| @@ -130,7 +130,7 @@ def iter_modules_and_files(modules, extra_files): | ||||
|         # cause issues here. | ||||
|         if not isinstance(module, ModuleType): | ||||
|             continue | ||||
|         if module.__name__ == "__main__": | ||||
|         if module.__name__ in ("__main__", "__mp_main__"): | ||||
|             # __main__ (usually manage.py) doesn't always have a __spec__ set. | ||||
|             # Handle this by falling back to using __file__, resolved below. | ||||
|             # See https://docs.python.org/reference/import.html#main-spec | ||||
|   | ||||
| @@ -70,6 +70,8 @@ class SessionMiddlewareSubclass(SessionMiddleware): | ||||
|     ], | ||||
| ) | ||||
| class SystemChecksTestCase(SimpleTestCase): | ||||
|     databases = "__all__" | ||||
|  | ||||
|     def test_checks_are_performed(self): | ||||
|         admin.site.register(Song, MyAdmin) | ||||
|         try: | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| import copy | ||||
| import multiprocessing | ||||
| import unittest | ||||
| from unittest import mock | ||||
|  | ||||
| from django.db import DEFAULT_DB_ALIAS, connection, connections | ||||
| from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connection, connections | ||||
| from django.test import SimpleTestCase | ||||
|  | ||||
|  | ||||
| @@ -33,3 +35,9 @@ class TestDbSignatureTests(SimpleTestCase): | ||||
|                 creation_class = test_connection.creation_class(test_connection) | ||||
|                 clone_settings_dict = creation_class.get_test_db_clone_settings("1") | ||||
|                 self.assertEqual(clone_settings_dict["NAME"], expected_clone_name) | ||||
|  | ||||
|     @mock.patch.object(multiprocessing, "get_start_method", return_value="forkserver") | ||||
|     def test_get_test_db_clone_settings_not_supported(self, *mocked_objects): | ||||
|         msg = "Cloning with start method 'forkserver' is not supported." | ||||
|         with self.assertRaisesMessage(NotSupportedError, msg): | ||||
|             connection.creation.get_test_db_clone_settings(1) | ||||
|   | ||||
| @@ -362,5 +362,7 @@ class CheckFrameworkReservedNamesTests(SimpleTestCase): | ||||
|  | ||||
|  | ||||
| class ChecksRunDuringTests(SimpleTestCase): | ||||
|     databases = "__all__" | ||||
|  | ||||
|     def test_registered_check_did_run(self): | ||||
|         self.assertTrue(my_check.did_run) | ||||
|   | ||||
| @@ -11,6 +11,8 @@ from django.test.utils import isolate_apps | ||||
|  | ||||
| @isolate_apps("contenttypes_tests", attr_name="apps") | ||||
| class GenericForeignKeyTests(SimpleTestCase): | ||||
|     databases = "__all__" | ||||
|  | ||||
|     def test_missing_content_type_field(self): | ||||
|         class TaggedItem(models.Model): | ||||
|             # no content_type field | ||||
|   | ||||
| @@ -22,6 +22,13 @@ class RemoveStaleContentTypesTests(TestCase): | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpTestData(cls): | ||||
|         with captured_stdout(): | ||||
|             call_command( | ||||
|                 "remove_stale_contenttypes", | ||||
|                 interactive=False, | ||||
|                 include_stale_apps=True, | ||||
|                 verbosity=2, | ||||
|             ) | ||||
|         cls.before_count = ContentType.objects.count() | ||||
|         cls.content_type = ContentType.objects.create( | ||||
|             app_label="contenttypes_tests", model="Fake" | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| from datetime import date | ||||
|  | ||||
| from django.test import modify_settings | ||||
|  | ||||
| from . import PostgreSQLTestCase | ||||
| from .models import ( | ||||
|     HStoreModel, | ||||
| @@ -16,6 +18,7 @@ except ImportError: | ||||
|     pass  # psycopg2 isn't installed. | ||||
|  | ||||
|  | ||||
| @modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}) | ||||
| class BulkSaveTests(PostgreSQLTestCase): | ||||
|     def test_bulk_update(self): | ||||
|         test_data = [ | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import argparse | ||||
| import atexit | ||||
| import copy | ||||
| import gc | ||||
| import multiprocessing | ||||
| import os | ||||
| import shutil | ||||
| import socket | ||||
| @@ -10,6 +11,7 @@ import subprocess | ||||
| import sys | ||||
| import tempfile | ||||
| import warnings | ||||
| from functools import partial | ||||
| from pathlib import Path | ||||
|  | ||||
| try: | ||||
| @@ -24,7 +26,7 @@ else: | ||||
|     from django.core.exceptions import ImproperlyConfigured | ||||
|     from django.db import connection, connections | ||||
|     from django.test import TestCase, TransactionTestCase | ||||
|     from django.test.runner import get_max_test_processes, parallel_type | ||||
|     from django.test.runner import _init_worker, get_max_test_processes, parallel_type | ||||
|     from django.test.selenium import SeleniumTestCaseBase | ||||
|     from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner | ||||
|     from django.utils.deprecation import RemovedInDjango50Warning | ||||
| @@ -382,7 +384,8 @@ def django_tests( | ||||
|             msg += " with up to %d processes" % max_parallel | ||||
|         print(msg) | ||||
|  | ||||
|     test_labels, state = setup_run_tests(verbosity, start_at, start_after, test_labels) | ||||
|     process_setup_args = (verbosity, start_at, start_after, test_labels) | ||||
|     test_labels, state = setup_run_tests(*process_setup_args) | ||||
|     # Run the test suite, including the extra validation tests. | ||||
|     if not hasattr(settings, "TEST_RUNNER"): | ||||
|         settings.TEST_RUNNER = "django.test.runner.DiscoverRunner" | ||||
| @@ -395,6 +398,11 @@ def django_tests( | ||||
|             parallel = 1 | ||||
|  | ||||
|     TestRunner = get_runner(settings) | ||||
|     TestRunner.parallel_test_suite.init_worker = partial( | ||||
|         _init_worker, | ||||
|         process_setup=setup_run_tests, | ||||
|         process_setup_args=process_setup_args, | ||||
|     ) | ||||
|     test_runner = TestRunner( | ||||
|         verbosity=verbosity, | ||||
|         interactive=interactive, | ||||
| @@ -718,6 +726,11 @@ if __name__ == "__main__": | ||||
|         options.settings = os.environ["DJANGO_SETTINGS_MODULE"] | ||||
|  | ||||
|     if options.selenium: | ||||
|         if multiprocessing.get_start_method() == "spawn" and options.parallel != 1: | ||||
|             parser.error( | ||||
|                 "You cannot use --selenium with parallel tests on this system. " | ||||
|                 "Pass --parallel=1 to use --selenium." | ||||
|             ) | ||||
|         if not options.tags: | ||||
|             options.tags = ["selenium"] | ||||
|         elif "selenium" not in options.tags: | ||||
|   | ||||
| @@ -86,6 +86,16 @@ class DiscoverRunnerParallelArgumentTests(SimpleTestCase): | ||||
|         mocked_cpu_count, | ||||
|     ): | ||||
|         mocked_get_start_method.return_value = "spawn" | ||||
|         self.assertEqual(get_max_test_processes(), 12) | ||||
|         with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}): | ||||
|             self.assertEqual(get_max_test_processes(), 7) | ||||
|  | ||||
|     def test_get_max_test_processes_forkserver( | ||||
|         self, | ||||
|         mocked_get_start_method, | ||||
|         mocked_cpu_count, | ||||
|     ): | ||||
|         mocked_get_start_method.return_value = "forkserver" | ||||
|         self.assertEqual(get_max_test_processes(), 1) | ||||
|         with mock.patch.dict(os.environ, {"DJANGO_TEST_PROCESSES": "7"}): | ||||
|             self.assertEqual(get_max_test_processes(), 1) | ||||
|   | ||||
| @@ -480,8 +480,6 @@ class ManageCommandTests(unittest.TestCase): | ||||
| # Isolate from the real environment. | ||||
| @mock.patch.dict(os.environ, {}, clear=True) | ||||
| @mock.patch.object(multiprocessing, "cpu_count", return_value=12) | ||||
| # Python 3.8 on macOS defaults to 'spawn' mode. | ||||
| @mock.patch.object(multiprocessing, "get_start_method", return_value="fork") | ||||
| class ManageCommandParallelTests(SimpleTestCase): | ||||
|     def test_parallel_default(self, *mocked_objects): | ||||
|         with captured_stderr() as stderr: | ||||
| @@ -507,8 +505,8 @@ class ManageCommandParallelTests(SimpleTestCase): | ||||
|         # Parallel is disabled by default. | ||||
|         self.assertEqual(stderr.getvalue(), "") | ||||
|  | ||||
|     def test_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count): | ||||
|         mocked_get_start_method.return_value = "spawn" | ||||
|     @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn") | ||||
|     def test_parallel_spawn(self, *mocked_objects): | ||||
|         with captured_stderr() as stderr: | ||||
|             call_command( | ||||
|                 "test", | ||||
| @@ -517,8 +515,8 @@ class ManageCommandParallelTests(SimpleTestCase): | ||||
|             ) | ||||
|         self.assertIn("parallel=1", stderr.getvalue()) | ||||
|  | ||||
|     def test_no_parallel_spawn(self, mocked_get_start_method, mocked_cpu_count): | ||||
|         mocked_get_start_method.return_value = "spawn" | ||||
|     @mock.patch.object(multiprocessing, "get_start_method", return_value="spawn") | ||||
|     def test_no_parallel_spawn(self, *mocked_objects): | ||||
|         with captured_stderr() as stderr: | ||||
|             call_command( | ||||
|                 "test", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user