#!/usr/bin/env python
import argparse
import atexit
import copy
import gc
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import warnings
from pathlib import Path

try:
    import django
except ImportError as e:
    raise RuntimeError(
        "Django module not found, reference tests/README.rst for instructions."
    ) from e
else:
    from django.apps import apps
    from django.conf import settings
    from django.core.exceptions import ImproperlyConfigured
    from django.db import connection, connections
    from django.test import TestCase, TransactionTestCase
    from django.test.runner import get_max_test_processes, parallel_type
    from django.test.selenium import SeleniumTestCaseBase
    from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner
    from django.utils.deprecation import RemovedInDjango50Warning
    from django.utils.log import DEFAULT_LOGGING

try:
    import MySQLdb
except ImportError:
    # MySQLdb is optional; without it there is no MySQLdb.Warning to filter.
    pass
else:
    # Ignore informational warnings from QuerySet.explain().
    warnings.filterwarnings("ignore", r"\(1003, *", category=MySQLdb.Warning)

# Turn deprecation warnings into errors to ensure no usage of deprecated
# features.
warnings.simplefilter("error", RemovedInDjango50Warning)
# Turn resource and runtime warnings into errors to ensure no usage of
# error-prone patterns.
warnings.simplefilter("error", ResourceWarning)
warnings.simplefilter("error", RuntimeWarning)
# Ignore known warnings in test dependencies.
warnings.filterwarnings(
    "ignore", "'U' mode is deprecated", DeprecationWarning, module="docutils.io"
)

# Reduce garbage collection frequency to improve performance. Since CPython
# uses refcounting, garbage collection only collects objects with cyclic
# references, which are a minority, so the garbage collection threshold can be
# larger than the default threshold of 700 allocations + deallocations without
# much increase in memory usage.
gc.set_threshold(100_000)

# Absolute path of the directory containing this script.
RUNTESTS_DIR = os.path.abspath(os.path.dirname(__file__))

TEMPLATE_DIR = os.path.join(RUNTESTS_DIR, "templates")

# Create a dedicated temporary directory for the duration of the test suite.
TMPDIR = tempfile.mkdtemp(prefix="django_")
# Set the TMPDIR environment variable in addition to tempfile.tempdir
# so that child processes inherit it.
tempfile.tempdir = os.environ["TMPDIR"] = TMPDIR

# Remove the temporary directory when the interpreter exits.
atexit.register(shutil.rmtree, TMPDIR)


# Maps a subdirectory of RUNTESTS_DIR ("" meaning RUNTESTS_DIR itself) to the
# names of its own subdirectories to skip when searching for test modules.
SUBDIRS_TO_SKIP = {
    "": {"import_error_package", "test_runner_apps"},
    "gis_tests": {"data"},
}

# Baseline value used for settings.INSTALLED_APPS while collecting and running
# tests.
ALWAYS_INSTALLED_APPS = [
    "django.contrib.contenttypes",
    "django.contrib.auth",
    "django.contrib.sites",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.admin.apps.SimpleAdminConfig",
    "django.contrib.staticfiles",
]

# Baseline value used for settings.MIDDLEWARE while collecting and running
# tests.
ALWAYS_MIDDLEWARE = [
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
]

# Need to add the associated contrib app to INSTALLED_APPS in some cases to
# avoid "RuntimeError: Model class X doesn't declare an explicit app_label
# and isn't in an application in INSTALLED_APPS."
CONTRIB_TESTS_TO_APPS = {
    "deprecation": ["django.contrib.flatpages", "django.contrib.redirects"],
    "flatpages_tests": ["django.contrib.flatpages"],
    "redirects_tests": ["django.contrib.redirects"],
}


def get_test_modules(gis_enabled):
    """
    Scan the tests directory and yield the names of all test modules.

    The yielded names have either one dotted part like "test_runner" or, in
    the case of GIS tests, two dotted parts like "gis_tests.gdal_tests".

    A directory entry is yielded only if its name contains no ".", it isn't
    listed in SUBDIRS_TO_SKIP for the directory being scanned, it isn't a
    regular file, and it contains an "__init__.py".

    When ``gis_enabled`` is false, the "gis_tests" package is skipped. The
    module-level SUBDIRS_TO_SKIP mapping is only read, never modified, so the
    result of one call doesn't depend on the arguments of earlier calls.
    """
    # Work on per-call copies of the skip sets rather than mutating the
    # shared module-level mapping.
    skip_by_dir = {
        parent: set(names) for parent, names in SUBDIRS_TO_SKIP.items()
    }
    discovery_dirs = [""]
    if gis_enabled:
        # GIS tests are in nested apps.
        discovery_dirs.append("gis_tests")
    else:
        skip_by_dir[""].add("gis_tests")

    for dirname in discovery_dirs:
        dirpath = os.path.join(RUNTESTS_DIR, dirname)
        subdirs_to_skip = skip_by_dir[dirname]
        with os.scandir(dirpath) as entries:
            for entry in entries:
                # DirEntry.name is already the base name of the entry.
                if (
                    "." in entry.name
                    or entry.name in subdirs_to_skip
                    or entry.is_file()
                    or not os.path.exists(os.path.join(entry.path, "__init__.py"))
                ):
                    continue
                if dirname:
                    yield dirname + "." + entry.name
                else:
                    yield entry.name


def get_label_module(label):
    """
    Return the top-level module part for a test label.

    A label made of a single path component is treated as a dotted module
    name and its first dotted part is returned. Any other label is treated as
    a filesystem path: it must exist (RuntimeError otherwise) and the first
    component of its resolved location relative to RUNTESTS_DIR is returned.
    """
    label_path = Path(label)
    if len(label_path.parts) != 1:
        # The label is a path. Check existence first to provide a better
        # error message than relative_to() would if it doesn't exist.
        if not label_path.exists():
            raise RuntimeError(f"Test label path {label} does not exist")
        resolved = label_path.resolve()
        return resolved.relative_to(RUNTESTS_DIR).parts[0]

    # The label is a dotted module name; keep what precedes the first dot.
    top_level, _, _ = label.partition(".")
    return top_level


def get_filtered_test_modules(start_at, start_after, gis_enabled, test_labels=None):
    """
    Yield the discovered test modules selected by the given options.

    ``test_labels`` are reduced to their top-level module; a discovered module
    is yielded if it equals, or is a dotted descendant of, one of them, or if
    no labels were given at all. When ``start_at`` or ``start_after`` is set,
    modules are ignored until the first one matching it; that module itself is
    also ignored for ``start_after``.

    Print a message and exit with status 1 if "gis_tests" is requested while
    ``gis_enabled`` is false.
    """
    labels = [] if test_labels is None else test_labels
    # Only the top-level module part of each label matters for filtering.
    requested_modules = {get_label_module(label) for label in labels}

    # It would be nice to put this validation earlier but it must come after
    # django.setup() so that connection.features.gis_enabled can be accessed.
    if not gis_enabled and "gis_tests" in requested_modules:
        print("Aborting: A GIS database backend is required to run gis_tests.")
        sys.exit(1)

    def _matches(module_name, label):
        # Exact match first, then ancestor match.
        if module_name == label:
            return True
        return module_name.startswith(label + ".")

    run_everything = not labels
    pending_start = start_at or start_after
    for candidate in get_test_modules(gis_enabled):
        if pending_start:
            if not _matches(candidate, pending_start):
                continue
            # The starting point has been reached; stop looking for it.
            pending_start = ""
            if not start_at:
                assert start_after
                # --start-after excludes the matching module itself.
                continue
        # Include the module if nothing was named on the command line (run
        # all) or if it, or one of its ancestors, was named.
        if run_everything or any(
            _matches(candidate, wanted) for wanted in requested_modules
        ):
            yield candidate


def setup_collect_tests(start_at, start_after, test_labels=None):
    """
    Configure settings for test collection and return the selected modules.

    Save the current value of the settings that teardown_collect_tests()
    restores, override settings for the test run, call django.setup(), and
    collect the test modules matching ``start_at``, ``start_after`` and
    ``test_labels``.

    Return a ``(test_modules, state)`` tuple where ``test_modules`` is a list
    of module names and ``state`` maps setting names to their saved values.
    """
    state = {
        "INSTALLED_APPS": settings.INSTALLED_APPS,
        "ROOT_URLCONF": getattr(settings, "ROOT_URLCONF", ""),
        "TEMPLATES": settings.TEMPLATES,
        "LANGUAGE_CODE": settings.LANGUAGE_CODE,
        "STATIC_URL": settings.STATIC_URL,
        "STATIC_ROOT": settings.STATIC_ROOT,
        "MIDDLEWARE": settings.MIDDLEWARE,
    }

    # Redirect some settings for the duration of these tests. Copies of the
    # module-level lists are assigned so that later in-place changes to the
    # settings (e.g. appending to INSTALLED_APPS) don't alter the constants.
    settings.INSTALLED_APPS = list(ALWAYS_INSTALLED_APPS)
    settings.ROOT_URLCONF = "urls"
    settings.STATIC_URL = "static/"
    settings.STATIC_ROOT = os.path.join(TMPDIR, "static")
    settings.TEMPLATES = [
        {
            "BACKEND": "django.template.backends.django.DjangoTemplates",
            "DIRS": [TEMPLATE_DIR],
            "APP_DIRS": True,
            "OPTIONS": {
                "context_processors": [
                    "django.template.context_processors.debug",
                    "django.template.context_processors.request",
                    "django.contrib.auth.context_processors.auth",
                    "django.contrib.messages.context_processors.messages",
                ],
            },
        }
    ]
    settings.LANGUAGE_CODE = "en"
    settings.SITE_ID = 1
    settings.MIDDLEWARE = list(ALWAYS_MIDDLEWARE)
    settings.MIGRATION_MODULES = {
        # This lets us skip creating migrations for the test models as many of
        # them depend on one of the following contrib applications.
        "auth": None,
        "contenttypes": None,
        "sessions": None,
    }
    log_config = copy.deepcopy(DEFAULT_LOGGING)
    # Filter out non-error logging so we don't have to capture it in lots of
    # tests.
    log_config["loggers"]["django"]["level"] = "ERROR"
    settings.LOGGING = log_config
    settings.SILENCED_SYSTEM_CHECKS = [
        "fields.W342",  # ForeignKey(unique=True) -> OneToOneField
    ]

    # Load all the ALWAYS_INSTALLED_APPS.
    django.setup()

    # This flag must be evaluated after django.setup() because otherwise it can
    # raise AppRegistryNotReady when running gis_tests in isolation on some
    # backends (e.g. PostGIS).
    gis_enabled = connection.features.gis_enabled

    test_modules = list(
        get_filtered_test_modules(
            start_at,
            start_after,
            gis_enabled,
            test_labels=test_labels,
        )
    )
    return test_modules, state


def teardown_collect_tests(state):
    """Set each setting named in ``state`` back to its saved value."""
    for setting_name in state:
        setattr(settings, setting_name, state[setting_name])


def get_installed():
    """Return a list with the name of every app config in the app registry."""
    installed_names = []
    for config in apps.get_app_configs():
        installed_names.append(config.name)
    return installed_names


# This function should be called only after calling django.setup(),
# since it calls connection.features.gis_enabled.
def get_apps_to_install(test_modules):
    """
    Yield the apps to add to INSTALLED_APPS for the given test modules.

    Each test module is yielded after any contrib apps CONTRIB_TESTS_TO_APPS
    associates with it; "django.contrib.gis" is yielded last when the default
    connection reports GIS support.
    """
    for module_name in test_modules:
        # Contrib apps required by this module (if any) come first.
        yield from CONTRIB_TESTS_TO_APPS.get(module_name, ())
        yield module_name

    # Add contrib.gis to INSTALLED_APPS if needed (rather than requiring
    # @override_settings(INSTALLED_APPS=...) on all test cases).
    if connection.features.gis_enabled:
        yield "django.contrib.gis"


def setup_run_tests(verbosity, start_at, start_after, test_labels=None):
    """
    Prepare settings, installed apps and test case classes for a test run.

    Return a ``(test_labels, state)`` tuple: the labels to run (the collected
    test modules when no labels were given) and the saved settings to pass to
    teardown_run_tests().
    """
    collected_modules, state = setup_collect_tests(
        start_at, start_after, test_labels=test_labels
    )

    # Append every required app that isn't installed yet, once.
    already_installed = set(get_installed())
    for app_name in get_apps_to_install(collected_modules):
        if app_name not in already_installed:
            if verbosity >= 2:
                print(f"Importing application {app_name}")
            settings.INSTALLED_APPS.append(app_name)
            already_installed.add(app_name)

    apps.set_installed_apps(settings.INSTALLED_APPS)

    # Force declaring available_apps in TransactionTestCase for faster tests.
    def no_available_apps(self):
        raise Exception(
            "Please define available_apps in TransactionTestCase and its subclasses."
        )

    TransactionTestCase.available_apps = property(no_available_apps)
    TestCase.available_apps = None

    # Set an environment variable that other code may consult to see if
    # Django's own test suite is running.
    os.environ["RUNNING_DJANGOS_TEST_SUITE"] = "true"

    # Without explicit labels, run everything that was collected.
    labels_to_run = test_labels or collected_modules
    return labels_to_run, state


def teardown_run_tests(state):
    """
    Undo setup_run_tests(): restore the saved settings, drop the
    multiprocessing temporary-directory finalizer, and unset the
    RUNNING_DJANGOS_TEST_SUITE environment variable.
    """
    teardown_collect_tests(state)
    # Discard the multiprocessing.util finalizer that tries to remove a
    # temporary directory that's already removed by this script's
    # atexit.register(shutil.rmtree, TMPDIR) handler. Prevents
    # FileNotFoundError at the end of a test run (#27890).
    from multiprocessing import util as mp_util

    finalizer_key = (-100, 0)
    mp_util._finalizer_registry.pop(finalizer_key, None)
    del os.environ["RUNNING_DJANGOS_TEST_SUITE"]


class ActionSelenium(argparse.Action):
    """
    Validate the comma-separated list of requested browsers.

    The parsed list of browser names is stored on the namespace under this
    action's destination.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Selenium itself must be importable before any browser is checked.
        try:
            import selenium  # NOQA
        except ImportError as e:
            raise ImproperlyConfigured(f"Error loading selenium module: {e}")

        requested = values.split(",")
        for name in requested:
            try:
                SeleniumTestCaseBase.import_webdriver(name)
            except ImportError:
                message = "Selenium browser specification '%s' is not valid." % name
                raise argparse.ArgumentError(self, message)
        setattr(namespace, self.dest, requested)


def django_tests(
    verbosity,
    interactive,
    failfast,
    keepdb,
    reverse,
    test_labels,
    debug_sql,
    parallel,
    tags,
    exclude_tags,
    test_name_patterns,
    start_at,
    start_after,
    pdb,
    buffer,
    timing,
    shuffle,
):
    """
    Set up the environment, run the selected tests with the configured test
    runner, tear the environment down, and return the runner's result.

    A ``parallel`` value of 0 or "auto" is resolved to the maximum number of
    test processes if every configured database connection can clone
    databases, and to 1 otherwise.
    """
    auto_parallel = parallel in {0, "auto"}
    max_parallel = get_max_test_processes() if auto_parallel else parallel

    if verbosity >= 1:
        django_dir = os.path.dirname(django.__file__)
        msg = "Testing against Django installed in '%s'" % django_dir
        if max_parallel > 1:
            msg += " with up to %d processes" % max_parallel
        print(msg)

    test_labels, state = setup_run_tests(verbosity, start_at, start_after, test_labels)
    # Run the test suite, including the extra validation tests.
    if not hasattr(settings, "TEST_RUNNER"):
        settings.TEST_RUNNER = "django.test.runner.DiscoverRunner"

    if auto_parallel:
        # This doesn't work before django.setup() on some databases.
        can_clone = all(
            conn.features.can_clone_databases for conn in connections.all()
        )
        parallel = max_parallel if can_clone else 1

    runner_options = {
        "verbosity": verbosity,
        "interactive": interactive,
        "failfast": failfast,
        "keepdb": keepdb,
        "reverse": reverse,
        "debug_sql": debug_sql,
        "parallel": parallel,
        "tags": tags,
        "exclude_tags": exclude_tags,
        "test_name_patterns": test_name_patterns,
        "pdb": pdb,
        "buffer": buffer,
        "timing": timing,
        "shuffle": shuffle,
    }
    runner_class = get_runner(settings)
    test_runner = runner_class(**runner_options)
    failures = test_runner.run_tests(test_labels)
    teardown_run_tests(state)
    return failures


def collect_test_modules(start_at, start_after):
    """
    Return the list of test modules selected by ``start_at``/``start_after``,
    restoring the overridden settings before returning.
    """
    collected = setup_collect_tests(start_at, start_after)
    modules, saved_state = collected
    teardown_collect_tests(saved_state)
    return modules


def get_subprocess_args(options):
    """
    Return the argv prefix used to re-run this script in a subprocess.

    The list starts with the current interpreter and this file, followed by
    the options forwarded from ``options`` (the parsed command-line
    namespace): settings, failfast, verbosity, noinput, tags, excluded tags
    and shuffle. Callers append the test labels to run.
    """
    subprocess_args = [sys.executable, __file__, "--settings=%s" % options.settings]
    if options.failfast:
        subprocess_args.append("--failfast")
    # Compare against None rather than truthiness so that an explicit
    # verbosity of 0 is forwarded instead of reverting to the default.
    if options.verbosity is not None:
        subprocess_args.append("--verbosity=%s" % options.verbosity)
    if not options.interactive:
        subprocess_args.append("--noinput")
    # tags and exclude_tags are lists (or None); forward one option per tag
    # using the option names the parser actually defines.
    subprocess_args.extend("--tag=%s" % tag for tag in options.tags or ())
    subprocess_args.extend(
        "--exclude-tag=%s" % tag for tag in options.exclude_tags or ()
    )
    # shuffle is False when not requested, None when requested without a
    # seed, and the seed otherwise.
    if options.shuffle is not False:
        if options.shuffle is None:
            subprocess_args.append("--shuffle")
        else:
            subprocess_args.append("--shuffle=%s" % options.shuffle)
    return subprocess_args


def bisect_tests(bisection_label, options, test_labels, start_at, start_after):
    """
    Repeatedly halve the test labels, running each half together with
    ``bisection_label`` in a subprocess, to narrow down a label that fails
    when combined with it.

    When ``test_labels`` is empty, all collected test modules are used.
    Progress and the outcome are printed; nothing is returned.
    """
    if not test_labels:
        test_labels = collect_test_modules(start_at, start_after)

    print("***** Bisecting test suite: %s" % " ".join(test_labels))

    # Make sure the bisection point isn't in the test list
    # Also remove tests that need to be run in specific combinations
    for excluded in (bisection_label, "model_inheritance_same_model_name"):
        if excluded in test_labels:
            test_labels.remove(excluded)

    base_command = get_subprocess_args(options)

    pass_number = 1
    while len(test_labels) > 1:
        split_at = len(test_labels) // 2
        first_half = test_labels[:split_at]
        second_half = test_labels[split_at:]
        labels_a = first_half + [bisection_label]
        labels_b = second_half + [bisection_label]

        print("***** Pass %da: Running the first half of the test suite" % pass_number)
        print("***** Test labels: %s" % " ".join(labels_a))
        failed_a = subprocess.run(base_command + labels_a).returncode

        print("***** Pass %db: Running the second half of the test suite" % pass_number)
        print("***** Test labels: %s" % " ".join(labels_b))
        print("")
        failed_b = subprocess.run(base_command + labels_b).returncode

        # Only a failure in exactly one half allows bisecting further.
        if failed_a and failed_b:
            print("***** Multiple sources of failure found")
            break
        if not failed_a and not failed_b:
            print("***** No source of failure found... try pair execution (--pair)")
            break
        if failed_a:
            print("***** Problem found in first half. Bisecting again...")
            test_labels = first_half
        else:
            print("***** Problem found in second half. Bisecting again...")
            test_labels = second_half
        pass_number += 1

    if len(test_labels) == 1:
        print("***** Source of error: %s" % test_labels[0])


def paired_tests(paired_test, options, test_labels, start_at, start_after):
    """
    Run each test label together with ``paired_test`` in a subprocess and
    stop at the first pair whose run exits with a non-zero status.

    When ``test_labels`` is empty, all collected test modules are used.
    Progress and the outcome are printed; nothing is returned.
    """
    if not test_labels:
        test_labels = collect_test_modules(start_at, start_after)

    print("***** Trying paired execution")

    # Make sure the constant member of the pair isn't in the test list
    # Also remove tests that need to be run in specific combinations
    for excluded in (paired_test, "model_inheritance_same_model_name"):
        if excluded in test_labels:
            test_labels.remove(excluded)

    base_command = get_subprocess_args(options)
    total = len(test_labels)

    for position, label in enumerate(test_labels, start=1):
        print(
            "***** %d of %d: Check test pairing with %s" % (position, total, label)
        )
        exit_status = subprocess.call(base_command + [label, paired_test])
        if exit_status:
            print("***** Found problem pair with %s" % label)
            return

    print("***** No problem pair found")


# Script entry point: parse the command line, validate option combinations,
# select the settings module, and dispatch to bisect, pair, or a normal run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the Django test suite.")
    parser.add_argument(
        "modules",
        nargs="*",
        metavar="module",
        help='Optional path(s) to test modules; e.g. "i18n" or '
        '"i18n.tests.TranslationTests.test_lazy_objects".',
    )
    # NOTE(review): choices allow 3, but the help text only describes 0-2.
    parser.add_argument(
        "-v",
        "--verbosity",
        default=1,
        type=int,
        choices=[0, 1, 2, 3],
        help="Verbosity level; 0=minimal output, 1=normal output, 2=all output",
    )
    parser.add_argument(
        "--noinput",
        action="store_false",
        dest="interactive",
        help="Tells Django to NOT prompt the user for input of any kind.",
    )
    parser.add_argument(
        "--failfast",
        action="store_true",
        help="Tells Django to stop running the test suite after first failed test.",
    )
    parser.add_argument(
        "--keepdb",
        action="store_true",
        help="Tells Django to preserve the test database between runs.",
    )
    parser.add_argument(
        "--settings",
        help='Python path to settings module, e.g. "myproject.settings". If '
        "this isn't provided, either the DJANGO_SETTINGS_MODULE "
        'environment variable or "test_sqlite" will be used.',
    )
    parser.add_argument(
        "--bisect",
        help="Bisect the test suite to discover a test that causes a test "
        "failure when combined with the named test.",
    )
    parser.add_argument(
        "--pair",
        help="Run the test suite in pairs with the named test to find problem pairs.",
    )
    # The value is False when the option is absent, None when it is given
    # without a seed, and the integer seed otherwise.
    parser.add_argument(
        "--shuffle",
        nargs="?",
        default=False,
        type=int,
        metavar="SEED",
        help=(
            "Shuffle the order of test cases to help check that tests are "
            "properly isolated."
        ),
    )
    parser.add_argument(
        "--reverse",
        action="store_true",
        help="Sort test suites and test cases in opposite order to debug "
        "test side effects not apparent with normal execution lineup.",
    )
    parser.add_argument(
        "--selenium",
        action=ActionSelenium,
        metavar="BROWSERS",
        help="A comma-separated list of browsers to run the Selenium tests against.",
    )
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run selenium tests in headless mode, if the browser supports the option.",
    )
    parser.add_argument(
        "--selenium-hub",
        help="A URL for a selenium hub instance to use in combination with --selenium.",
    )
    parser.add_argument(
        "--external-host",
        default=socket.gethostname(),
        help=(
            "The external host that can be reached by the selenium hub instance when "
            "running Selenium tests via Selenium Hub."
        ),
    )
    parser.add_argument(
        "--debug-sql",
        action="store_true",
        help="Turn on the SQL query logger within tests.",
    )
    # 0 is converted to "auto" or 1 later on, depending on a method used by
    # multiprocessing to start subprocesses and on the backend support for
    # cloning databases.
    parser.add_argument(
        "--parallel",
        nargs="?",
        const="auto",
        default=0,
        type=parallel_type,
        metavar="N",
        help=(
            'Run tests using up to N parallel processes. Use the value "auto" '
            "to run one test process for each processor core."
        ),
    )
    parser.add_argument(
        "--tag",
        dest="tags",
        action="append",
        help="Run only tests with the specified tags. Can be used multiple times.",
    )
    parser.add_argument(
        "--exclude-tag",
        dest="exclude_tags",
        action="append",
        help="Do not run tests with the specified tag. Can be used multiple times.",
    )
    parser.add_argument(
        "--start-after",
        dest="start_after",
        help="Run tests starting after the specified top-level module.",
    )
    parser.add_argument(
        "--start-at",
        dest="start_at",
        help="Run tests starting at the specified top-level module.",
    )
    parser.add_argument(
        "--pdb", action="store_true", help="Runs the PDB debugger on error or failure."
    )
    parser.add_argument(
        "-b",
        "--buffer",
        action="store_true",
        help="Discard output of passing tests.",
    )
    parser.add_argument(
        "--timing",
        action="store_true",
        help="Output timings, including database set up and total run time.",
    )
    parser.add_argument(
        "-k",
        dest="test_name_patterns",
        action="append",
        help=(
            "Only run test methods and classes matching test name pattern. "
            "Same as unittest -k option. Can be used multiple times."
        ),
    )

    options = parser.parse_args()

    # Validate the Selenium hub options against --selenium.
    using_selenium_hub = options.selenium and options.selenium_hub
    if options.selenium_hub and not options.selenium:
        parser.error(
            "--selenium-hub and --external-host require --selenium to be used."
        )
    if using_selenium_hub and not options.external_host:
        parser.error("--selenium-hub and --external-host must be used together.")

    # Allow including a trailing slash on app_labels for tab completion convenience
    options.modules = [os.path.normpath(labels) for labels in options.modules]

    # At most one of --start-at, --start-after, and explicit test labels may
    # be given.
    mutually_exclusive_options = [
        options.start_at,
        options.start_after,
        options.modules,
    ]
    enabled_module_options = [
        bool(option) for option in mutually_exclusive_options
    ].count(True)
    if enabled_module_options > 1:
        print(
            "Aborting: --start-at, --start-after, and test labels are mutually "
            "exclusive."
        )
        sys.exit(1)
    # --start-at/--start-after must not contain a dot; the accepted value is
    # then normalized like the test labels above.
    for opt_name in ["start_at", "start_after"]:
        opt_val = getattr(options, opt_name)
        if opt_val:
            if "." in opt_val:
                print(
                    "Aborting: --%s must be a top-level module."
                    % opt_name.replace("_", "-")
                )
                sys.exit(1)
            setattr(options, opt_name, os.path.normpath(opt_val))
    # Settings precedence: --settings, then an existing DJANGO_SETTINGS_MODULE
    # environment variable, then "test_sqlite". options.settings always ends
    # up holding the effective value.
    if options.settings:
        os.environ["DJANGO_SETTINGS_MODULE"] = options.settings
    else:
        os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_sqlite")
        options.settings = os.environ["DJANGO_SETTINGS_MODULE"]

    if options.selenium:
        # Make sure the "selenium" tag is among the requested tags.
        if not options.tags:
            options.tags = ["selenium"]
        elif "selenium" not in options.tags:
            options.tags.append("selenium")
        if options.selenium_hub:
            SeleniumTestCaseBase.selenium_hub = options.selenium_hub
            SeleniumTestCaseBase.external_host = options.external_host
        SeleniumTestCaseBase.headless = options.headless
        SeleniumTestCaseBase.browsers = options.selenium

    # Dispatch: --bisect takes precedence over --pair; otherwise run the
    # tests in this process.
    if options.bisect:
        bisect_tests(
            options.bisect,
            options,
            options.modules,
            options.start_at,
            options.start_after,
        )
    elif options.pair:
        paired_tests(
            options.pair,
            options,
            options.modules,
            options.start_at,
            options.start_after,
        )
    else:
        time_keeper = TimeKeeper() if options.timing else NullTimeKeeper()
        with time_keeper.timed("Total run"):
            failures = django_tests(
                options.verbosity,
                options.interactive,
                options.failfast,
                options.keepdb,
                options.reverse,
                options.modules,
                options.debug_sql,
                options.parallel,
                options.tags,
                options.exclude_tags,
                getattr(options, "test_name_patterns", None),
                options.start_at,
                options.start_after,
                options.pdb,
                options.buffer,
                options.timing,
                options.shuffle,
            )
        time_keeper.print_results()
        # Exit with a non-zero status when the run reported failures.
        if failures:
            sys.exit(1)