diff --git a/django/__init__.py b/django/__init__.py index 2d5ebe0074..7c6a889907 100644 --- a/django/__init__.py +++ b/django/__init__.py @@ -1,6 +1,6 @@ from django.utils.version import get_version -VERSION = (4, 1, 0, 'alpha', 0) +VERSION = (4, 1, 0, "alpha", 0) __version__ = get_version(VERSION) @@ -19,6 +19,6 @@ def setup(set_prefix=True): configure_logging(settings.LOGGING_CONFIG, settings.LOGGING) if set_prefix: set_script_prefix( - '/' if settings.FORCE_SCRIPT_NAME is None else settings.FORCE_SCRIPT_NAME + "/" if settings.FORCE_SCRIPT_NAME is None else settings.FORCE_SCRIPT_NAME ) apps.populate(settings.INSTALLED_APPS) diff --git a/django/apps/__init__.py b/django/apps/__init__.py index 79091dc535..96674be73c 100644 --- a/django/apps/__init__.py +++ b/django/apps/__init__.py @@ -1,4 +1,4 @@ from .config import AppConfig from .registry import apps -__all__ = ['AppConfig', 'apps'] +__all__ = ["AppConfig", "apps"] diff --git a/django/apps/config.py b/django/apps/config.py index 671db98200..28e50e5225 100644 --- a/django/apps/config.py +++ b/django/apps/config.py @@ -6,8 +6,8 @@ from django.core.exceptions import ImproperlyConfigured from django.utils.functional import cached_property from django.utils.module_loading import import_string, module_has_submodule -APPS_MODULE_NAME = 'apps' -MODELS_MODULE_NAME = 'models' +APPS_MODULE_NAME = "apps" +MODELS_MODULE_NAME = "models" class AppConfig: @@ -30,7 +30,7 @@ class AppConfig: # Last component of the Python path to the application e.g. 'admin'. # This value must be unique across a Django project. - if not hasattr(self, 'label'): + if not hasattr(self, "label"): self.label = app_name.rpartition(".")[2] if not self.label.isidentifier(): raise ImproperlyConfigured( @@ -38,12 +38,12 @@ class AppConfig: ) # Human-readable name for the application e.g. "Admin". 
- if not hasattr(self, 'verbose_name'): + if not hasattr(self, "verbose_name"): self.verbose_name = self.label.title() # Filesystem path to the application directory e.g. # '/path/to/django/contrib/admin'. - if not hasattr(self, 'path'): + if not hasattr(self, "path"): self.path = self._path_from_module(app_module) # Module containing models e.g. ' % (self.__class__.__name__, self.label) + return "<%s: %s>" % (self.__class__.__name__, self.label) @cached_property def default_auto_field(self): from django.conf import settings + return settings.DEFAULT_AUTO_FIELD @property @@ -72,9 +73,9 @@ class AppConfig: # See #21874 for extended discussion of the behavior of this method in # various cases. # Convert to list because __path__ may not support indexing. - paths = list(getattr(module, '__path__', [])) + paths = list(getattr(module, "__path__", [])) if len(paths) != 1: - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename is not None: paths = [os.path.dirname(filename)] else: @@ -85,12 +86,14 @@ class AppConfig: raise ImproperlyConfigured( "The app module %r has multiple filesystem locations (%r); " "you must configure this app with an AppConfig subclass " - "with a 'path' class attribute." % (module, paths)) + "with a 'path' class attribute." % (module, paths) + ) elif not paths: raise ImproperlyConfigured( "The app module %r has no filesystem location, " "you must configure this app with an AppConfig subclass " - "with a 'path' class attribute." % module) + "with a 'path' class attribute." % module + ) return paths[0] @classmethod @@ -116,7 +119,7 @@ class AppConfig: # If the apps module defines more than one AppConfig subclass, # the default one can declare default = True. 
if module_has_submodule(app_module, APPS_MODULE_NAME): - mod_path = '%s.%s' % (entry, APPS_MODULE_NAME) + mod_path = "%s.%s" % (entry, APPS_MODULE_NAME) mod = import_module(mod_path) # Check if there's exactly one AppConfig candidate, # excluding those that explicitly define default = False. @@ -124,9 +127,9 @@ class AppConfig: (name, candidate) for name, candidate in inspect.getmembers(mod, inspect.isclass) if ( - issubclass(candidate, cls) and - candidate is not cls and - getattr(candidate, 'default', True) + issubclass(candidate, cls) + and candidate is not cls + and getattr(candidate, "default", True) ) ] if len(app_configs) == 1: @@ -137,13 +140,13 @@ class AppConfig: app_configs = [ (name, candidate) for name, candidate in app_configs - if getattr(candidate, 'default', False) + if getattr(candidate, "default", False) ] if len(app_configs) > 1: candidates = [repr(name) for name, _ in app_configs] raise RuntimeError( - '%r declares more than one default AppConfig: ' - '%s.' % (mod_path, ', '.join(candidates)) + "%r declares more than one default AppConfig: " + "%s." % (mod_path, ", ".join(candidates)) ) elif len(app_configs) == 1: app_config_class = app_configs[0][1] @@ -165,7 +168,7 @@ class AppConfig: # If the last component of entry starts with an uppercase letter, # then it was likely intended to be an app config class; if not, # an app module. Provide a nice error message in both cases. - mod_path, _, cls_name = entry.rpartition('.') + mod_path, _, cls_name = entry.rpartition(".") if mod_path and cls_name[0].isupper(): # We could simply re-trigger the string import exception, but # we're going the extra mile and providing a better error @@ -178,9 +181,12 @@ class AppConfig: for name, candidate in inspect.getmembers(mod, inspect.isclass) if issubclass(candidate, cls) and candidate is not cls ] - msg = "Module '%s' does not contain a '%s' class." % (mod_path, cls_name) + msg = "Module '%s' does not contain a '%s' class." 
% ( + mod_path, + cls_name, + ) if candidates: - msg += ' Choices are: %s.' % ', '.join(candidates) + msg += " Choices are: %s." % ", ".join(candidates) raise ImportError(msg) else: # Re-trigger the module import exception. @@ -189,8 +195,7 @@ class AppConfig: # Check for obvious errors. (This check prevents duck typing, but # it could be removed if it became a problem in practice.) if not issubclass(app_config_class, AppConfig): - raise ImproperlyConfigured( - "'%s' isn't a subclass of AppConfig." % entry) + raise ImproperlyConfigured("'%s' isn't a subclass of AppConfig." % entry) # Obtain app name here rather than in AppClass.__init__ to keep # all error checking for entries in INSTALLED_APPS in one place. @@ -198,16 +203,15 @@ class AppConfig: try: app_name = app_config_class.name except AttributeError: - raise ImproperlyConfigured( - "'%s' must supply a name attribute." % entry - ) + raise ImproperlyConfigured("'%s' must supply a name attribute." % entry) # Ensure app_name points to a valid module. try: app_module = import_module(app_name) except ImportError: raise ImproperlyConfigured( - "Cannot import '%s'. Check that '%s.%s.name' is correct." % ( + "Cannot import '%s'. Check that '%s.%s.name' is correct." + % ( app_name, app_config_class.__module__, app_config_class.__qualname__, @@ -231,7 +235,8 @@ class AppConfig: return self.models[model_name.lower()] except KeyError: raise LookupError( - "App '%s' doesn't have a '%s' model." % (self.label, model_name)) + "App '%s' doesn't have a '%s' model." 
% (self.label, model_name) + ) def get_models(self, include_auto_created=False, include_swapped=False): """ @@ -260,7 +265,7 @@ class AppConfig: self.models = self.apps.all_models[self.label] if module_has_submodule(self.module, MODELS_MODULE_NAME): - models_module_name = '%s.%s' % (self.name, MODELS_MODULE_NAME) + models_module_name = "%s.%s" % (self.name, MODELS_MODULE_NAME) self.models_module = import_module(models_module_name) def ready(self): diff --git a/django/apps/registry.py b/django/apps/registry.py index 268e1a6af7..06d5bca060 100644 --- a/django/apps/registry.py +++ b/django/apps/registry.py @@ -21,7 +21,7 @@ class Apps: # installed_apps is set to None when creating the master registry # because it cannot be populated at that point. Other registries must # provide a list of installed apps and are populated immediately. - if installed_apps is None and hasattr(sys.modules[__name__], 'apps'): + if installed_apps is None and hasattr(sys.modules[__name__], "apps"): raise RuntimeError("You must supply an installed_apps argument.") # Mapping of app labels => model names => model classes. Every time a @@ -92,20 +92,22 @@ class Apps: if app_config.label in self.app_configs: raise ImproperlyConfigured( "Application labels aren't unique, " - "duplicates: %s" % app_config.label) + "duplicates: %s" % app_config.label + ) self.app_configs[app_config.label] = app_config app_config.apps = self # Check for duplicate app names. 
counts = Counter( - app_config.name for app_config in self.app_configs.values()) - duplicates = [ - name for name, count in counts.most_common() if count > 1] + app_config.name for app_config in self.app_configs.values() + ) + duplicates = [name for name, count in counts.most_common() if count > 1] if duplicates: raise ImproperlyConfigured( "Application names aren't unique, " - "duplicates: %s" % ", ".join(duplicates)) + "duplicates: %s" % ", ".join(duplicates) + ) self.apps_ready = True @@ -201,7 +203,7 @@ class Apps: self.check_apps_ready() if model_name is None: - app_label, model_name = app_label.split('.') + app_label, model_name = app_label.split(".") app_config = self.get_app_config(app_label) @@ -217,17 +219,22 @@ class Apps: model_name = model._meta.model_name app_models = self.all_models[app_label] if model_name in app_models: - if (model.__name__ == app_models[model_name].__name__ and - model.__module__ == app_models[model_name].__module__): + if ( + model.__name__ == app_models[model_name].__name__ + and model.__module__ == app_models[model_name].__module__ + ): warnings.warn( "Model '%s.%s' was already registered. " "Reloading models is not advised as it can lead to inconsistencies, " "most notably with related models." % (app_label, model_name), - RuntimeWarning, stacklevel=2) + RuntimeWarning, + stacklevel=2, + ) else: raise RuntimeError( - "Conflicting '%s' models in application '%s': %s and %s." % - (model_name, app_label, app_models[model_name], model)) + "Conflicting '%s' models in application '%s': %s and %s." 
+ % (model_name, app_label, app_models[model_name], model) + ) app_models[model_name] = model self.do_pending_operations(model) self.clear_cache() @@ -254,8 +261,8 @@ class Apps: candidates = [] for app_config in self.app_configs.values(): if object_name.startswith(app_config.name): - subpath = object_name[len(app_config.name):] - if subpath == '' or subpath[0] == '.': + subpath = object_name[len(app_config.name) :] + if subpath == "" or subpath[0] == ".": candidates.append(app_config) if candidates: return sorted(candidates, key=lambda ac: -len(ac.name))[0] @@ -270,8 +277,7 @@ class Apps: """ model = self.all_models[app_label].get(model_name.lower()) if model is None: - raise LookupError( - "Model '%s.%s' not registered." % (app_label, model_name)) + raise LookupError("Model '%s.%s' not registered." % (app_label, model_name)) return model @functools.lru_cache(maxsize=None) @@ -403,6 +409,7 @@ class Apps: def apply_next_model(model): next_function = partial(apply_next_model.func, model) self.lazy_model_operation(next_function, *more_models) + apply_next_model.func = function # If the model has already been imported and registered, partially diff --git a/django/conf/__init__.py b/django/conf/__init__.py index 05345cb6e9..e1fb7f973f 100644 --- a/django/conf/__init__.py +++ b/django/conf/__init__.py @@ -23,20 +23,20 @@ ENVIRONMENT_VARIABLE = "DJANGO_SETTINGS_MODULE" # RemovedInDjango50Warning USE_DEPRECATED_PYTZ_DEPRECATED_MSG = ( - 'The USE_DEPRECATED_PYTZ setting, and support for pytz timezones is ' - 'deprecated in favor of the stdlib zoneinfo module. Please update your ' - 'code to use zoneinfo and remove the USE_DEPRECATED_PYTZ setting.' + "The USE_DEPRECATED_PYTZ setting, and support for pytz timezones is " + "deprecated in favor of the stdlib zoneinfo module. Please update your " + "code to use zoneinfo and remove the USE_DEPRECATED_PYTZ setting." ) USE_L10N_DEPRECATED_MSG = ( - 'The USE_L10N setting is deprecated. 
Starting with Django 5.0, localized ' - 'formatting of data will always be enabled. For example Django will ' - 'display numbers and dates using the format of the current locale.' + "The USE_L10N setting is deprecated. Starting with Django 5.0, localized " + "formatting of data will always be enabled. For example Django will " + "display numbers and dates using the format of the current locale." ) CSRF_COOKIE_MASKED_DEPRECATED_MSG = ( - 'The CSRF_COOKIE_MASKED transitional setting is deprecated. Support for ' - 'it will be removed in Django 5.0.' + "The CSRF_COOKIE_MASKED transitional setting is deprecated. Support for " + "it will be removed in Django 5.0." ) @@ -45,6 +45,7 @@ class SettingsReference(str): String subclass which references a current settings value. It's treated as the value in memory but serializes to a settings.NAME attribute reference. """ + def __new__(self, value, setting_name): return str.__new__(self, value) @@ -58,6 +59,7 @@ class LazySettings(LazyObject): The user can manually configure settings prior to using them. Otherwise, Django uses the settings module pointed to by DJANGO_SETTINGS_MODULE. """ + def _setup(self, name=None): """ Load the settings module pointed to by the environment variable. This @@ -71,16 +73,17 @@ class LazySettings(LazyObject): "Requested %s, but settings are not configured. " "You must either define the environment variable %s " "or call settings.configure() before accessing settings." - % (desc, ENVIRONMENT_VARIABLE)) + % (desc, ENVIRONMENT_VARIABLE) + ) self._wrapped = Settings(settings_module) def __repr__(self): # Hardcode the class name as otherwise it yields 'Settings'. if self._wrapped is empty: - return '' + return "" return '' % { - 'settings_module': self._wrapped.SETTINGS_MODULE, + "settings_module": self._wrapped.SETTINGS_MODULE, } def __getattr__(self, name): @@ -91,9 +94,9 @@ class LazySettings(LazyObject): # Special case some settings which require further modification. 
# This is done here for performance reasons so the modified value is cached. - if name in {'MEDIA_URL', 'STATIC_URL'} and val is not None: + if name in {"MEDIA_URL", "STATIC_URL"} and val is not None: val = self._add_script_prefix(val) - elif name == 'SECRET_KEY' and not val: + elif name == "SECRET_KEY" and not val: raise ImproperlyConfigured("The SECRET_KEY setting must not be empty.") self.__dict__[name] = val @@ -104,7 +107,7 @@ class LazySettings(LazyObject): Set the value of setting. Clear all cached values if _wrapped changes (@override_settings does this) or clear single values when set. """ - if name == '_wrapped': + if name == "_wrapped": self.__dict__.clear() else: self.__dict__.pop(name, None) @@ -122,11 +125,11 @@ class LazySettings(LazyObject): argument must support attribute access (__getattr__)). """ if self._wrapped is not empty: - raise RuntimeError('Settings already configured.') + raise RuntimeError("Settings already configured.") holder = UserSettingsHolder(default_settings) for name, value in options.items(): if not name.isupper(): - raise TypeError('Setting %r must be uppercase.' % name) + raise TypeError("Setting %r must be uppercase." % name) setattr(holder, name, value) self._wrapped = holder @@ -139,10 +142,11 @@ class LazySettings(LazyObject): subpath to STATIC_URL and MEDIA_URL in settings is inconvenient. """ # Don't apply prefix to absolute paths and URLs. - if value.startswith(('http://', 'https://', '/')): + if value.startswith(("http://", "https://", "/")): return value from django.urls import get_script_prefix - return '%s%s' % (get_script_prefix(), value) + + return "%s%s" % (get_script_prefix(), value) @property def configured(self): @@ -161,14 +165,14 @@ class LazySettings(LazyObject): RemovedInDjango50Warning, stacklevel=2, ) - return self.__getattr__('USE_L10N') + return self.__getattr__("USE_L10N") # RemovedInDjango50Warning. 
@property def _USE_L10N_INTERNAL(self): # Special hook to avoid checking a traceback in internal use on hot # paths. - return self.__getattr__('USE_L10N') + return self.__getattr__("USE_L10N") class Settings: @@ -184,7 +188,7 @@ class Settings: mod = importlib.import_module(self.SETTINGS_MODULE) tuple_settings = ( - 'ALLOWED_HOSTS', + "ALLOWED_HOSTS", "INSTALLED_APPS", "TEMPLATE_DIRS", "LOCALE_PATHS", @@ -195,39 +199,42 @@ class Settings: if setting.isupper(): setting_value = getattr(mod, setting) - if (setting in tuple_settings and - not isinstance(setting_value, (list, tuple))): - raise ImproperlyConfigured("The %s setting must be a list or a tuple." % setting) + if setting in tuple_settings and not isinstance( + setting_value, (list, tuple) + ): + raise ImproperlyConfigured( + "The %s setting must be a list or a tuple." % setting + ) setattr(self, setting, setting_value) self._explicit_settings.add(setting) - if self.USE_TZ is False and not self.is_overridden('USE_TZ'): + if self.USE_TZ is False and not self.is_overridden("USE_TZ"): warnings.warn( - 'The default value of USE_TZ will change from False to True ' - 'in Django 5.0. Set USE_TZ to False in your project settings ' - 'if you want to keep the current default behavior.', + "The default value of USE_TZ will change from False to True " + "in Django 5.0. Set USE_TZ to False in your project settings " + "if you want to keep the current default behavior.", category=RemovedInDjango50Warning, ) - if self.is_overridden('USE_DEPRECATED_PYTZ'): + if self.is_overridden("USE_DEPRECATED_PYTZ"): warnings.warn(USE_DEPRECATED_PYTZ_DEPRECATED_MSG, RemovedInDjango50Warning) - if self.is_overridden('CSRF_COOKIE_MASKED'): + if self.is_overridden("CSRF_COOKIE_MASKED"): warnings.warn(CSRF_COOKIE_MASKED_DEPRECATED_MSG, RemovedInDjango50Warning) - if hasattr(time, 'tzset') and self.TIME_ZONE: + if hasattr(time, "tzset") and self.TIME_ZONE: # When we can, attempt to validate the timezone. 
If we can't find # this file, no check happens and it's harmless. - zoneinfo_root = Path('/usr/share/zoneinfo') - zone_info_file = zoneinfo_root.joinpath(*self.TIME_ZONE.split('/')) + zoneinfo_root = Path("/usr/share/zoneinfo") + zone_info_file = zoneinfo_root.joinpath(*self.TIME_ZONE.split("/")) if zoneinfo_root.exists() and not zone_info_file.exists(): raise ValueError("Incorrect timezone setting: %s" % self.TIME_ZONE) # Move the time zone info into os.environ. See ticket #2315 for why # we don't do this unconditionally (breaks Windows). - os.environ['TZ'] = self.TIME_ZONE + os.environ["TZ"] = self.TIME_ZONE time.tzset() - if self.is_overridden('USE_L10N'): + if self.is_overridden("USE_L10N"): warnings.warn(USE_L10N_DEPRECATED_MSG, RemovedInDjango50Warning) def is_overridden(self, setting): @@ -235,13 +242,14 @@ class Settings: def __repr__(self): return '<%(cls)s "%(settings_module)s">' % { - 'cls': self.__class__.__name__, - 'settings_module': self.SETTINGS_MODULE, + "cls": self.__class__.__name__, + "settings_module": self.SETTINGS_MODULE, } class UserSettingsHolder: """Holder for user configured settings.""" + # SETTINGS_MODULE doesn't make much sense in the manually configured # (standalone) case. SETTINGS_MODULE = None @@ -251,7 +259,7 @@ class UserSettingsHolder: Requests for configuration variables not in this class are satisfied from the module specified in default_settings (if possible). 
""" - self.__dict__['_deleted'] = set() + self.__dict__["_deleted"] = set() self.default_settings = default_settings def __getattr__(self, name): @@ -261,12 +269,12 @@ class UserSettingsHolder: def __setattr__(self, name, value): self._deleted.discard(name) - if name == 'USE_L10N': + if name == "USE_L10N": warnings.warn(USE_L10N_DEPRECATED_MSG, RemovedInDjango50Warning) - if name == 'CSRF_COOKIE_MASKED': + if name == "CSRF_COOKIE_MASKED": warnings.warn(CSRF_COOKIE_MASKED_DEPRECATED_MSG, RemovedInDjango50Warning) super().__setattr__(name, value) - if name == 'USE_DEPRECATED_PYTZ': + if name == "USE_DEPRECATED_PYTZ": warnings.warn(USE_DEPRECATED_PYTZ_DEPRECATED_MSG, RemovedInDjango50Warning) def __delattr__(self, name): @@ -276,19 +284,22 @@ class UserSettingsHolder: def __dir__(self): return sorted( - s for s in [*self.__dict__, *dir(self.default_settings)] + s + for s in [*self.__dict__, *dir(self.default_settings)] if s not in self._deleted ) def is_overridden(self, setting): - deleted = (setting in self._deleted) - set_locally = (setting in self.__dict__) - set_on_default = getattr(self.default_settings, 'is_overridden', lambda s: False)(setting) + deleted = setting in self._deleted + set_locally = setting in self.__dict__ + set_on_default = getattr( + self.default_settings, "is_overridden", lambda s: False + )(setting) return deleted or set_locally or set_on_default def __repr__(self): - return '<%(cls)s>' % { - 'cls': self.__class__.__name__, + return "<%(cls)s>" % { + "cls": self.__class__.__name__, } diff --git a/django/conf/global_settings.py b/django/conf/global_settings.py index 82acca010a..2a5f95d6fb 100644 --- a/django/conf/global_settings.py +++ b/django/conf/global_settings.py @@ -38,7 +38,7 @@ ALLOWED_HOSTS = [] # https://en.wikipedia.org/wiki/List_of_tz_zones_by_name (although not all # systems may support all possibilities). When USE_TZ is True, this is # interpreted as the default user time zone. 
-TIME_ZONE = 'America/Chicago' +TIME_ZONE = "America/Chicago" # If you set this to True, Django will use timezone-aware datetimes. USE_TZ = False @@ -50,107 +50,107 @@ USE_DEPRECATED_PYTZ = False # Language code for this installation. All choices can be found here: # http://www.i18nguy.com/unicode/language-identifiers.html -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" # Languages we provide translations for, out of the box. LANGUAGES = [ - ('af', gettext_noop('Afrikaans')), - ('ar', gettext_noop('Arabic')), - ('ar-dz', gettext_noop('Algerian Arabic')), - ('ast', gettext_noop('Asturian')), - ('az', gettext_noop('Azerbaijani')), - ('bg', gettext_noop('Bulgarian')), - ('be', gettext_noop('Belarusian')), - ('bn', gettext_noop('Bengali')), - ('br', gettext_noop('Breton')), - ('bs', gettext_noop('Bosnian')), - ('ca', gettext_noop('Catalan')), - ('cs', gettext_noop('Czech')), - ('cy', gettext_noop('Welsh')), - ('da', gettext_noop('Danish')), - ('de', gettext_noop('German')), - ('dsb', gettext_noop('Lower Sorbian')), - ('el', gettext_noop('Greek')), - ('en', gettext_noop('English')), - ('en-au', gettext_noop('Australian English')), - ('en-gb', gettext_noop('British English')), - ('eo', gettext_noop('Esperanto')), - ('es', gettext_noop('Spanish')), - ('es-ar', gettext_noop('Argentinian Spanish')), - ('es-co', gettext_noop('Colombian Spanish')), - ('es-mx', gettext_noop('Mexican Spanish')), - ('es-ni', gettext_noop('Nicaraguan Spanish')), - ('es-ve', gettext_noop('Venezuelan Spanish')), - ('et', gettext_noop('Estonian')), - ('eu', gettext_noop('Basque')), - ('fa', gettext_noop('Persian')), - ('fi', gettext_noop('Finnish')), - ('fr', gettext_noop('French')), - ('fy', gettext_noop('Frisian')), - ('ga', gettext_noop('Irish')), - ('gd', gettext_noop('Scottish Gaelic')), - ('gl', gettext_noop('Galician')), - ('he', gettext_noop('Hebrew')), - ('hi', gettext_noop('Hindi')), - ('hr', gettext_noop('Croatian')), - ('hsb', gettext_noop('Upper Sorbian')), - ('hu', 
gettext_noop('Hungarian')), - ('hy', gettext_noop('Armenian')), - ('ia', gettext_noop('Interlingua')), - ('id', gettext_noop('Indonesian')), - ('ig', gettext_noop('Igbo')), - ('io', gettext_noop('Ido')), - ('is', gettext_noop('Icelandic')), - ('it', gettext_noop('Italian')), - ('ja', gettext_noop('Japanese')), - ('ka', gettext_noop('Georgian')), - ('kab', gettext_noop('Kabyle')), - ('kk', gettext_noop('Kazakh')), - ('km', gettext_noop('Khmer')), - ('kn', gettext_noop('Kannada')), - ('ko', gettext_noop('Korean')), - ('ky', gettext_noop('Kyrgyz')), - ('lb', gettext_noop('Luxembourgish')), - ('lt', gettext_noop('Lithuanian')), - ('lv', gettext_noop('Latvian')), - ('mk', gettext_noop('Macedonian')), - ('ml', gettext_noop('Malayalam')), - ('mn', gettext_noop('Mongolian')), - ('mr', gettext_noop('Marathi')), - ('ms', gettext_noop('Malay')), - ('my', gettext_noop('Burmese')), - ('nb', gettext_noop('Norwegian Bokmål')), - ('ne', gettext_noop('Nepali')), - ('nl', gettext_noop('Dutch')), - ('nn', gettext_noop('Norwegian Nynorsk')), - ('os', gettext_noop('Ossetic')), - ('pa', gettext_noop('Punjabi')), - ('pl', gettext_noop('Polish')), - ('pt', gettext_noop('Portuguese')), - ('pt-br', gettext_noop('Brazilian Portuguese')), - ('ro', gettext_noop('Romanian')), - ('ru', gettext_noop('Russian')), - ('sk', gettext_noop('Slovak')), - ('sl', gettext_noop('Slovenian')), - ('sq', gettext_noop('Albanian')), - ('sr', gettext_noop('Serbian')), - ('sr-latn', gettext_noop('Serbian Latin')), - ('sv', gettext_noop('Swedish')), - ('sw', gettext_noop('Swahili')), - ('ta', gettext_noop('Tamil')), - ('te', gettext_noop('Telugu')), - ('tg', gettext_noop('Tajik')), - ('th', gettext_noop('Thai')), - ('tk', gettext_noop('Turkmen')), - ('tr', gettext_noop('Turkish')), - ('tt', gettext_noop('Tatar')), - ('udm', gettext_noop('Udmurt')), - ('uk', gettext_noop('Ukrainian')), - ('ur', gettext_noop('Urdu')), - ('uz', gettext_noop('Uzbek')), - ('vi', gettext_noop('Vietnamese')), - ('zh-hans', 
gettext_noop('Simplified Chinese')), - ('zh-hant', gettext_noop('Traditional Chinese')), + ("af", gettext_noop("Afrikaans")), + ("ar", gettext_noop("Arabic")), + ("ar-dz", gettext_noop("Algerian Arabic")), + ("ast", gettext_noop("Asturian")), + ("az", gettext_noop("Azerbaijani")), + ("bg", gettext_noop("Bulgarian")), + ("be", gettext_noop("Belarusian")), + ("bn", gettext_noop("Bengali")), + ("br", gettext_noop("Breton")), + ("bs", gettext_noop("Bosnian")), + ("ca", gettext_noop("Catalan")), + ("cs", gettext_noop("Czech")), + ("cy", gettext_noop("Welsh")), + ("da", gettext_noop("Danish")), + ("de", gettext_noop("German")), + ("dsb", gettext_noop("Lower Sorbian")), + ("el", gettext_noop("Greek")), + ("en", gettext_noop("English")), + ("en-au", gettext_noop("Australian English")), + ("en-gb", gettext_noop("British English")), + ("eo", gettext_noop("Esperanto")), + ("es", gettext_noop("Spanish")), + ("es-ar", gettext_noop("Argentinian Spanish")), + ("es-co", gettext_noop("Colombian Spanish")), + ("es-mx", gettext_noop("Mexican Spanish")), + ("es-ni", gettext_noop("Nicaraguan Spanish")), + ("es-ve", gettext_noop("Venezuelan Spanish")), + ("et", gettext_noop("Estonian")), + ("eu", gettext_noop("Basque")), + ("fa", gettext_noop("Persian")), + ("fi", gettext_noop("Finnish")), + ("fr", gettext_noop("French")), + ("fy", gettext_noop("Frisian")), + ("ga", gettext_noop("Irish")), + ("gd", gettext_noop("Scottish Gaelic")), + ("gl", gettext_noop("Galician")), + ("he", gettext_noop("Hebrew")), + ("hi", gettext_noop("Hindi")), + ("hr", gettext_noop("Croatian")), + ("hsb", gettext_noop("Upper Sorbian")), + ("hu", gettext_noop("Hungarian")), + ("hy", gettext_noop("Armenian")), + ("ia", gettext_noop("Interlingua")), + ("id", gettext_noop("Indonesian")), + ("ig", gettext_noop("Igbo")), + ("io", gettext_noop("Ido")), + ("is", gettext_noop("Icelandic")), + ("it", gettext_noop("Italian")), + ("ja", gettext_noop("Japanese")), + ("ka", gettext_noop("Georgian")), + ("kab", 
gettext_noop("Kabyle")), + ("kk", gettext_noop("Kazakh")), + ("km", gettext_noop("Khmer")), + ("kn", gettext_noop("Kannada")), + ("ko", gettext_noop("Korean")), + ("ky", gettext_noop("Kyrgyz")), + ("lb", gettext_noop("Luxembourgish")), + ("lt", gettext_noop("Lithuanian")), + ("lv", gettext_noop("Latvian")), + ("mk", gettext_noop("Macedonian")), + ("ml", gettext_noop("Malayalam")), + ("mn", gettext_noop("Mongolian")), + ("mr", gettext_noop("Marathi")), + ("ms", gettext_noop("Malay")), + ("my", gettext_noop("Burmese")), + ("nb", gettext_noop("Norwegian Bokmål")), + ("ne", gettext_noop("Nepali")), + ("nl", gettext_noop("Dutch")), + ("nn", gettext_noop("Norwegian Nynorsk")), + ("os", gettext_noop("Ossetic")), + ("pa", gettext_noop("Punjabi")), + ("pl", gettext_noop("Polish")), + ("pt", gettext_noop("Portuguese")), + ("pt-br", gettext_noop("Brazilian Portuguese")), + ("ro", gettext_noop("Romanian")), + ("ru", gettext_noop("Russian")), + ("sk", gettext_noop("Slovak")), + ("sl", gettext_noop("Slovenian")), + ("sq", gettext_noop("Albanian")), + ("sr", gettext_noop("Serbian")), + ("sr-latn", gettext_noop("Serbian Latin")), + ("sv", gettext_noop("Swedish")), + ("sw", gettext_noop("Swahili")), + ("ta", gettext_noop("Tamil")), + ("te", gettext_noop("Telugu")), + ("tg", gettext_noop("Tajik")), + ("th", gettext_noop("Thai")), + ("tk", gettext_noop("Turkmen")), + ("tr", gettext_noop("Turkish")), + ("tt", gettext_noop("Tatar")), + ("udm", gettext_noop("Udmurt")), + ("uk", gettext_noop("Ukrainian")), + ("ur", gettext_noop("Urdu")), + ("uz", gettext_noop("Uzbek")), + ("vi", gettext_noop("Vietnamese")), + ("zh-hans", gettext_noop("Simplified Chinese")), + ("zh-hant", gettext_noop("Traditional Chinese")), ] # Languages using BiDi (right-to-left) layout @@ -162,10 +162,10 @@ USE_I18N = True LOCALE_PATHS = [] # Settings for language cookie -LANGUAGE_COOKIE_NAME = 'django_language' +LANGUAGE_COOKIE_NAME = "django_language" LANGUAGE_COOKIE_AGE = None LANGUAGE_COOKIE_DOMAIN = None 
-LANGUAGE_COOKIE_PATH = '/' +LANGUAGE_COOKIE_PATH = "/" LANGUAGE_COOKIE_SECURE = False LANGUAGE_COOKIE_HTTPONLY = False LANGUAGE_COOKIE_SAMESITE = None @@ -181,10 +181,10 @@ MANAGERS = ADMINS # Default charset to use for all HttpResponse objects, if a MIME type isn't # manually specified. It's used to construct the Content-Type header. -DEFAULT_CHARSET = 'utf-8' +DEFAULT_CHARSET = "utf-8" # Email address that error messages come from. -SERVER_EMAIL = 'root@localhost' +SERVER_EMAIL = "root@localhost" # Database connection info. If left empty, will default to the dummy backend. DATABASES = {} @@ -196,10 +196,10 @@ DATABASE_ROUTERS = [] # The default is to use the SMTP backend. # Third-party backends can be specified by providing a Python path # to a module that defines an EmailBackend class. -EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' +EMAIL_BACKEND = "django.core.mail.backends.smtp.EmailBackend" # Host for sending email. -EMAIL_HOST = 'localhost' +EMAIL_HOST = "localhost" # Port for sending email. EMAIL_PORT = 25 @@ -208,8 +208,8 @@ EMAIL_PORT = 25 EMAIL_USE_LOCALTIME = False # Optional SMTP authentication information for EMAIL_HOST. -EMAIL_HOST_USER = '' -EMAIL_HOST_PASSWORD = '' +EMAIL_HOST_USER = "" +EMAIL_HOST_PASSWORD = "" EMAIL_USE_TLS = False EMAIL_USE_SSL = False EMAIL_SSL_CERTFILE = None @@ -222,15 +222,15 @@ INSTALLED_APPS = [] TEMPLATES = [] # Default form rendering class. -FORM_RENDERER = 'django.forms.renderers.DjangoTemplates' +FORM_RENDERER = "django.forms.renderers.DjangoTemplates" # Default email address to use for various automated correspondence from # the site managers. -DEFAULT_FROM_EMAIL = 'webmaster@localhost' +DEFAULT_FROM_EMAIL = "webmaster@localhost" # Subject-line prefix for email messages send with django.core.mail.mail_admins # or ...mail_managers. Make sure to include the trailing space. -EMAIL_SUBJECT_PREFIX = '[Django] ' +EMAIL_SUBJECT_PREFIX = "[Django] " # Whether to append trailing slashes to URLs. 
APPEND_SLASH = True @@ -270,22 +270,22 @@ IGNORABLE_404_URLS = [] # A secret key for this particular Django installation. Used in secret-key # hashing algorithms. Set this in your settings, or Django will complain # loudly. -SECRET_KEY = '' +SECRET_KEY = "" # List of secret keys used to verify the validity of signatures. This allows # secret key rotation. SECRET_KEY_FALLBACKS = [] # Default file storage mechanism that holds media. -DEFAULT_FILE_STORAGE = 'django.core.files.storage.FileSystemStorage' +DEFAULT_FILE_STORAGE = "django.core.files.storage.FileSystemStorage" # Absolute filesystem path to the directory that will hold user-uploaded files. # Example: "/var/www/example.com/media/" -MEDIA_ROOT = '' +MEDIA_ROOT = "" # URL that handles the media served from MEDIA_ROOT. # Examples: "http://example.com/media/", "http://media.example.com/" -MEDIA_URL = '' +MEDIA_URL = "" # Absolute path to the directory static files should be collected to. # Example: "/var/www/example.com/static/" @@ -297,8 +297,8 @@ STATIC_URL = None # List of upload handler classes to be applied in order. FILE_UPLOAD_HANDLERS = [ - 'django.core.files.uploadhandler.MemoryFileUploadHandler', - 'django.core.files.uploadhandler.TemporaryFileUploadHandler', + "django.core.files.uploadhandler.MemoryFileUploadHandler", + "django.core.files.uploadhandler.TemporaryFileUploadHandler", ] # Maximum size, in bytes, of a request before it will be streamed to the @@ -335,51 +335,51 @@ FORMAT_MODULE_PATH = None # Default formatting for date objects. See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'N j, Y' +DATE_FORMAT = "N j, Y" # Default formatting for datetime objects. See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATETIME_FORMAT = 'N j, Y, P' +DATETIME_FORMAT = "N j, Y, P" # Default formatting for time objects. 
See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -TIME_FORMAT = 'P' +TIME_FORMAT = "P" # Default formatting for date objects when only the year and month are relevant. # See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -YEAR_MONTH_FORMAT = 'F Y' +YEAR_MONTH_FORMAT = "F Y" # Default formatting for date objects when only the month and day are relevant. # See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -MONTH_DAY_FORMAT = 'F j' +MONTH_DAY_FORMAT = "F j" # Default short formatting for date objects. See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -SHORT_DATE_FORMAT = 'm/d/Y' +SHORT_DATE_FORMAT = "m/d/Y" # Default short formatting for datetime objects. # See all available format strings here: # https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -SHORT_DATETIME_FORMAT = 'm/d/Y P' +SHORT_DATETIME_FORMAT = "m/d/Y P" # Default formats to be used when parsing dates from input boxes, in order # See all available format string here: # https://docs.python.org/library/datetime.html#strftime-behavior # * Note that these format strings are different from the ones to display dates DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' - '%b %d %Y', # 'Oct 25 2006' - '%b %d, %Y', # 'Oct 25, 2006' - '%d %b %Y', # '25 Oct 2006' - '%d %b, %Y', # '25 Oct, 2006' - '%B %d %Y', # 'October 25 2006' - '%B %d, %Y', # 'October 25, 2006' - '%d %B %Y', # '25 October 2006' - '%d %B, %Y', # '25 October, 2006' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # '10/25/06' + "%b %d %Y", # 'Oct 25 2006' + "%b %d, %Y", # 'Oct 25, 2006' + "%d %b %Y", # '25 Oct 2006' + "%d %b, %Y", # '25 Oct, 2006' + "%B %d %Y", # 'October 25 2006' + "%B %d, %Y", # 'October 25, 2006' + "%d %B %Y", 
# '25 October 2006' + "%d %B, %Y", # '25 October, 2006' ] # Default formats to be used when parsing times from input boxes, in order @@ -387,9 +387,9 @@ DATE_INPUT_FORMATS = [ # https://docs.python.org/library/datetime.html#strftime-behavior # * Note that these format strings are different from the ones to display dates TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' ] # Default formats to be used when parsing dates and times from input boxes, @@ -398,15 +398,15 @@ TIME_INPUT_FORMATS = [ # https://docs.python.org/library/datetime.html#strftime-behavior # * Note that these format strings are different from the ones to display dates DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + "%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y %H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' ] # First day of week, to be used on calendars @@ -414,7 +414,7 @@ DATETIME_INPUT_FORMATS = [ FIRST_DAY_OF_WEEK = 0 # Decimal separator symbol -DECIMAL_SEPARATOR = '.' +DECIMAL_SEPARATOR = "." 
# Boolean that sets whether to add thousand separator when formatting numbers USE_THOUSAND_SEPARATOR = False @@ -424,17 +424,17 @@ USE_THOUSAND_SEPARATOR = False NUMBER_GROUPING = 0 # Thousand separator symbol -THOUSAND_SEPARATOR = ',' +THOUSAND_SEPARATOR = "," # The tablespaces to use for each model when not specified otherwise. -DEFAULT_TABLESPACE = '' -DEFAULT_INDEX_TABLESPACE = '' +DEFAULT_TABLESPACE = "" +DEFAULT_INDEX_TABLESPACE = "" # Default primary key field type. -DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" # Default X-Frame-Options header value -X_FRAME_OPTIONS = 'DENY' +X_FRAME_OPTIONS = "DENY" USE_X_FORWARDED_HOST = False USE_X_FORWARDED_PORT = False @@ -469,9 +469,9 @@ MIDDLEWARE = [] ############ # Cache to store session data if using the cache session backend. -SESSION_CACHE_ALIAS = 'default' +SESSION_CACHE_ALIAS = "default" # Cookie name. This can be whatever you want. -SESSION_COOKIE_NAME = 'sessionid' +SESSION_COOKIE_NAME = "sessionid" # Age of cookie, in seconds (default: 2 weeks). SESSION_COOKIE_AGE = 60 * 60 * 24 * 7 * 2 # A string like "example.com", or None for standard domain cookie. @@ -479,23 +479,23 @@ SESSION_COOKIE_DOMAIN = None # Whether the session cookie should be secure (https:// only). SESSION_COOKIE_SECURE = False # The path of the session cookie. -SESSION_COOKIE_PATH = '/' +SESSION_COOKIE_PATH = "/" # Whether to use the HttpOnly flag. SESSION_COOKIE_HTTPONLY = True # Whether to set the flag restricting cookie leaks on cross-site requests. # This can be 'Lax', 'Strict', 'None', or False to disable the flag. -SESSION_COOKIE_SAMESITE = 'Lax' +SESSION_COOKIE_SAMESITE = "Lax" # Whether to save the session data on every request. SESSION_SAVE_EVERY_REQUEST = False # Whether a user's session cookie expires when the web browser is closed. 
SESSION_EXPIRE_AT_BROWSER_CLOSE = False # The module to store session data -SESSION_ENGINE = 'django.contrib.sessions.backends.db' +SESSION_ENGINE = "django.contrib.sessions.backends.db" # Directory to store session files if using the file session module. If None, # the backend will use a sensible default. SESSION_FILE_PATH = None # class to serialize session data -SESSION_SERIALIZER = 'django.contrib.sessions.serializers.JSONSerializer' +SESSION_SERIALIZER = "django.contrib.sessions.serializers.JSONSerializer" ######### # CACHE # @@ -503,25 +503,25 @@ SESSION_SERIALIZER = 'django.contrib.sessions.serializers.JSONSerializer' # The cache backends to use. CACHES = { - 'default': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", } } -CACHE_MIDDLEWARE_KEY_PREFIX = '' +CACHE_MIDDLEWARE_KEY_PREFIX = "" CACHE_MIDDLEWARE_SECONDS = 600 -CACHE_MIDDLEWARE_ALIAS = 'default' +CACHE_MIDDLEWARE_ALIAS = "default" ################## # AUTHENTICATION # ################## -AUTH_USER_MODEL = 'auth.User' +AUTH_USER_MODEL = "auth.User" -AUTHENTICATION_BACKENDS = ['django.contrib.auth.backends.ModelBackend'] +AUTHENTICATION_BACKENDS = ["django.contrib.auth.backends.ModelBackend"] -LOGIN_URL = '/accounts/login/' +LOGIN_URL = "/accounts/login/" -LOGIN_REDIRECT_URL = '/accounts/profile/' +LOGIN_REDIRECT_URL = "/accounts/profile/" LOGOUT_REDIRECT_URL = None @@ -532,11 +532,11 @@ PASSWORD_RESET_TIMEOUT = 60 * 60 * 24 * 3 # password using different algorithms will be converted automatically # upon login PASSWORD_HASHERS = [ - 'django.contrib.auth.hashers.PBKDF2PasswordHasher', - 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', - 'django.contrib.auth.hashers.Argon2PasswordHasher', - 'django.contrib.auth.hashers.BCryptSHA256PasswordHasher', - 'django.contrib.auth.hashers.ScryptPasswordHasher', + "django.contrib.auth.hashers.PBKDF2PasswordHasher", + 
"django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher", + "django.contrib.auth.hashers.Argon2PasswordHasher", + "django.contrib.auth.hashers.BCryptSHA256PasswordHasher", + "django.contrib.auth.hashers.ScryptPasswordHasher", ] AUTH_PASSWORD_VALIDATORS = [] @@ -545,7 +545,7 @@ AUTH_PASSWORD_VALIDATORS = [] # SIGNING # ########### -SIGNING_BACKEND = 'django.core.signing.TimestampSigner' +SIGNING_BACKEND = "django.core.signing.TimestampSigner" ######## # CSRF # @@ -553,17 +553,17 @@ SIGNING_BACKEND = 'django.core.signing.TimestampSigner' # Dotted path to callable to be used as view when a request is # rejected by the CSRF middleware. -CSRF_FAILURE_VIEW = 'django.views.csrf.csrf_failure' +CSRF_FAILURE_VIEW = "django.views.csrf.csrf_failure" # Settings for CSRF cookie. -CSRF_COOKIE_NAME = 'csrftoken' +CSRF_COOKIE_NAME = "csrftoken" CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52 CSRF_COOKIE_DOMAIN = None -CSRF_COOKIE_PATH = '/' +CSRF_COOKIE_PATH = "/" CSRF_COOKIE_SECURE = False CSRF_COOKIE_HTTPONLY = False -CSRF_COOKIE_SAMESITE = 'Lax' -CSRF_HEADER_NAME = 'HTTP_X_CSRFTOKEN' +CSRF_COOKIE_SAMESITE = "Lax" +CSRF_HEADER_NAME = "HTTP_X_CSRFTOKEN" CSRF_TRUSTED_ORIGINS = [] CSRF_USE_SESSIONS = False @@ -576,7 +576,7 @@ CSRF_COOKIE_MASKED = False ############ # Class to use as messages backend -MESSAGE_STORAGE = 'django.contrib.messages.storage.fallback.FallbackStorage' +MESSAGE_STORAGE = "django.contrib.messages.storage.fallback.FallbackStorage" # Default values of MESSAGE_LEVEL and MESSAGE_TAGS are defined within # django.contrib.messages to avoid imports in this settings file. @@ -586,25 +586,25 @@ MESSAGE_STORAGE = 'django.contrib.messages.storage.fallback.FallbackStorage' ########### # The callable to use to configure logging -LOGGING_CONFIG = 'logging.config.dictConfig' +LOGGING_CONFIG = "logging.config.dictConfig" # Custom logging configuration. LOGGING = {} # Default exception reporter class used in case none has been # specifically assigned to the HttpRequest instance. 
-DEFAULT_EXCEPTION_REPORTER = 'django.views.debug.ExceptionReporter' +DEFAULT_EXCEPTION_REPORTER = "django.views.debug.ExceptionReporter" # Default exception reporter filter class used in case none has been # specifically assigned to the HttpRequest instance. -DEFAULT_EXCEPTION_REPORTER_FILTER = 'django.views.debug.SafeExceptionReporterFilter' +DEFAULT_EXCEPTION_REPORTER_FILTER = "django.views.debug.SafeExceptionReporterFilter" ########### # TESTING # ########### # The name of the class to use to run the test suite -TEST_RUNNER = 'django.test.runner.DiscoverRunner' +TEST_RUNNER = "django.test.runner.DiscoverRunner" # Apps that don't need to be serialized at test database creation time # (only apps with migrations are to start with) @@ -625,13 +625,13 @@ FIXTURE_DIRS = [] STATICFILES_DIRS = [] # The default file storage backend used during the build process -STATICFILES_STORAGE = 'django.contrib.staticfiles.storage.StaticFilesStorage' +STATICFILES_STORAGE = "django.contrib.staticfiles.storage.StaticFilesStorage" # List of finder classes that know how to find static files in # various locations. 
STATICFILES_FINDERS = [ - 'django.contrib.staticfiles.finders.FileSystemFinder', - 'django.contrib.staticfiles.finders.AppDirectoriesFinder', + "django.contrib.staticfiles.finders.FileSystemFinder", + "django.contrib.staticfiles.finders.AppDirectoriesFinder", # 'django.contrib.staticfiles.finders.DefaultStorageFinder', ] @@ -656,11 +656,11 @@ SILENCED_SYSTEM_CHECKS = [] # SECURITY MIDDLEWARE # ####################### SECURE_CONTENT_TYPE_NOSNIFF = True -SECURE_CROSS_ORIGIN_OPENER_POLICY = 'same-origin' +SECURE_CROSS_ORIGIN_OPENER_POLICY = "same-origin" SECURE_HSTS_INCLUDE_SUBDOMAINS = False SECURE_HSTS_PRELOAD = False SECURE_HSTS_SECONDS = 0 SECURE_REDIRECT_EXEMPT = [] -SECURE_REFERRER_POLICY = 'same-origin' +SECURE_REFERRER_POLICY = "same-origin" SECURE_SSL_HOST = None SECURE_SSL_REDIRECT = False diff --git a/django/conf/locale/__init__.py b/django/conf/locale/__init__.py index c57ad00440..a18d5775d5 100644 --- a/django/conf/locale/__init__.py +++ b/django/conf/locale/__init__.py @@ -8,610 +8,610 @@ follow the traditional 'fr-ca' -> 'fr' fallback logic. 
""" LANG_INFO = { - 'af': { - 'bidi': False, - 'code': 'af', - 'name': 'Afrikaans', - 'name_local': 'Afrikaans', - }, - 'ar': { - 'bidi': True, - 'code': 'ar', - 'name': 'Arabic', - 'name_local': 'العربيّة', - }, - 'ar-dz': { - 'bidi': True, - 'code': 'ar-dz', - 'name': 'Algerian Arabic', - 'name_local': 'العربية الجزائرية', - }, - 'ast': { - 'bidi': False, - 'code': 'ast', - 'name': 'Asturian', - 'name_local': 'asturianu', - }, - 'az': { - 'bidi': True, - 'code': 'az', - 'name': 'Azerbaijani', - 'name_local': 'Azərbaycanca', - }, - 'be': { - 'bidi': False, - 'code': 'be', - 'name': 'Belarusian', - 'name_local': 'беларуская', - }, - 'bg': { - 'bidi': False, - 'code': 'bg', - 'name': 'Bulgarian', - 'name_local': 'български', - }, - 'bn': { - 'bidi': False, - 'code': 'bn', - 'name': 'Bengali', - 'name_local': 'বাংলা', - }, - 'br': { - 'bidi': False, - 'code': 'br', - 'name': 'Breton', - 'name_local': 'brezhoneg', - }, - 'bs': { - 'bidi': False, - 'code': 'bs', - 'name': 'Bosnian', - 'name_local': 'bosanski', - }, - 'ca': { - 'bidi': False, - 'code': 'ca', - 'name': 'Catalan', - 'name_local': 'català', - }, - 'cs': { - 'bidi': False, - 'code': 'cs', - 'name': 'Czech', - 'name_local': 'česky', - }, - 'cy': { - 'bidi': False, - 'code': 'cy', - 'name': 'Welsh', - 'name_local': 'Cymraeg', - }, - 'da': { - 'bidi': False, - 'code': 'da', - 'name': 'Danish', - 'name_local': 'dansk', - }, - 'de': { - 'bidi': False, - 'code': 'de', - 'name': 'German', - 'name_local': 'Deutsch', - }, - 'dsb': { - 'bidi': False, - 'code': 'dsb', - 'name': 'Lower Sorbian', - 'name_local': 'dolnoserbski', - }, - 'el': { - 'bidi': False, - 'code': 'el', - 'name': 'Greek', - 'name_local': 'Ελληνικά', - }, - 'en': { - 'bidi': False, - 'code': 'en', - 'name': 'English', - 'name_local': 'English', - }, - 'en-au': { - 'bidi': False, - 'code': 'en-au', - 'name': 'Australian English', - 'name_local': 'Australian English', - }, - 'en-gb': { - 'bidi': False, - 'code': 'en-gb', - 'name': 'British English', - 
'name_local': 'British English', - }, - 'eo': { - 'bidi': False, - 'code': 'eo', - 'name': 'Esperanto', - 'name_local': 'Esperanto', - }, - 'es': { - 'bidi': False, - 'code': 'es', - 'name': 'Spanish', - 'name_local': 'español', - }, - 'es-ar': { - 'bidi': False, - 'code': 'es-ar', - 'name': 'Argentinian Spanish', - 'name_local': 'español de Argentina', - }, - 'es-co': { - 'bidi': False, - 'code': 'es-co', - 'name': 'Colombian Spanish', - 'name_local': 'español de Colombia', - }, - 'es-mx': { - 'bidi': False, - 'code': 'es-mx', - 'name': 'Mexican Spanish', - 'name_local': 'español de Mexico', - }, - 'es-ni': { - 'bidi': False, - 'code': 'es-ni', - 'name': 'Nicaraguan Spanish', - 'name_local': 'español de Nicaragua', - }, - 'es-ve': { - 'bidi': False, - 'code': 'es-ve', - 'name': 'Venezuelan Spanish', - 'name_local': 'español de Venezuela', - }, - 'et': { - 'bidi': False, - 'code': 'et', - 'name': 'Estonian', - 'name_local': 'eesti', - }, - 'eu': { - 'bidi': False, - 'code': 'eu', - 'name': 'Basque', - 'name_local': 'Basque', - }, - 'fa': { - 'bidi': True, - 'code': 'fa', - 'name': 'Persian', - 'name_local': 'فارسی', - }, - 'fi': { - 'bidi': False, - 'code': 'fi', - 'name': 'Finnish', - 'name_local': 'suomi', - }, - 'fr': { - 'bidi': False, - 'code': 'fr', - 'name': 'French', - 'name_local': 'français', - }, - 'fy': { - 'bidi': False, - 'code': 'fy', - 'name': 'Frisian', - 'name_local': 'frysk', - }, - 'ga': { - 'bidi': False, - 'code': 'ga', - 'name': 'Irish', - 'name_local': 'Gaeilge', - }, - 'gd': { - 'bidi': False, - 'code': 'gd', - 'name': 'Scottish Gaelic', - 'name_local': 'Gàidhlig', - }, - 'gl': { - 'bidi': False, - 'code': 'gl', - 'name': 'Galician', - 'name_local': 'galego', - }, - 'he': { - 'bidi': True, - 'code': 'he', - 'name': 'Hebrew', - 'name_local': 'עברית', - }, - 'hi': { - 'bidi': False, - 'code': 'hi', - 'name': 'Hindi', - 'name_local': 'हिंदी', - }, - 'hr': { - 'bidi': False, - 'code': 'hr', - 'name': 'Croatian', - 'name_local': 'Hrvatski', - }, 
- 'hsb': { - 'bidi': False, - 'code': 'hsb', - 'name': 'Upper Sorbian', - 'name_local': 'hornjoserbsce', - }, - 'hu': { - 'bidi': False, - 'code': 'hu', - 'name': 'Hungarian', - 'name_local': 'Magyar', - }, - 'hy': { - 'bidi': False, - 'code': 'hy', - 'name': 'Armenian', - 'name_local': 'հայերեն', - }, - 'ia': { - 'bidi': False, - 'code': 'ia', - 'name': 'Interlingua', - 'name_local': 'Interlingua', - }, - 'io': { - 'bidi': False, - 'code': 'io', - 'name': 'Ido', - 'name_local': 'ido', - }, - 'id': { - 'bidi': False, - 'code': 'id', - 'name': 'Indonesian', - 'name_local': 'Bahasa Indonesia', - }, - 'ig': { - 'bidi': False, - 'code': 'ig', - 'name': 'Igbo', - 'name_local': 'Asụsụ Ìgbò', - }, - 'is': { - 'bidi': False, - 'code': 'is', - 'name': 'Icelandic', - 'name_local': 'Íslenska', - }, - 'it': { - 'bidi': False, - 'code': 'it', - 'name': 'Italian', - 'name_local': 'italiano', - }, - 'ja': { - 'bidi': False, - 'code': 'ja', - 'name': 'Japanese', - 'name_local': '日本語', - }, - 'ka': { - 'bidi': False, - 'code': 'ka', - 'name': 'Georgian', - 'name_local': 'ქართული', - }, - 'kab': { - 'bidi': False, - 'code': 'kab', - 'name': 'Kabyle', - 'name_local': 'taqbaylit', - }, - 'kk': { - 'bidi': False, - 'code': 'kk', - 'name': 'Kazakh', - 'name_local': 'Қазақ', - }, - 'km': { - 'bidi': False, - 'code': 'km', - 'name': 'Khmer', - 'name_local': 'Khmer', - }, - 'kn': { - 'bidi': False, - 'code': 'kn', - 'name': 'Kannada', - 'name_local': 'Kannada', - }, - 'ko': { - 'bidi': False, - 'code': 'ko', - 'name': 'Korean', - 'name_local': '한국어', - }, - 'ky': { - 'bidi': False, - 'code': 'ky', - 'name': 'Kyrgyz', - 'name_local': 'Кыргызча', - }, - 'lb': { - 'bidi': False, - 'code': 'lb', - 'name': 'Luxembourgish', - 'name_local': 'Lëtzebuergesch', - }, - 'lt': { - 'bidi': False, - 'code': 'lt', - 'name': 'Lithuanian', - 'name_local': 'Lietuviškai', - }, - 'lv': { - 'bidi': False, - 'code': 'lv', - 'name': 'Latvian', - 'name_local': 'latviešu', - }, - 'mk': { - 'bidi': False, - 'code': 
'mk', - 'name': 'Macedonian', - 'name_local': 'Македонски', - }, - 'ml': { - 'bidi': False, - 'code': 'ml', - 'name': 'Malayalam', - 'name_local': 'മലയാളം', - }, - 'mn': { - 'bidi': False, - 'code': 'mn', - 'name': 'Mongolian', - 'name_local': 'Mongolian', - }, - 'mr': { - 'bidi': False, - 'code': 'mr', - 'name': 'Marathi', - 'name_local': 'मराठी', - }, - 'ms': { - 'bidi': False, - 'code': 'ms', - 'name': 'Malay', - 'name_local': 'Bahasa Melayu', - }, - 'my': { - 'bidi': False, - 'code': 'my', - 'name': 'Burmese', - 'name_local': 'မြန်မာဘာသာ', - }, - 'nb': { - 'bidi': False, - 'code': 'nb', - 'name': 'Norwegian Bokmal', - 'name_local': 'norsk (bokmål)', - }, - 'ne': { - 'bidi': False, - 'code': 'ne', - 'name': 'Nepali', - 'name_local': 'नेपाली', - }, - 'nl': { - 'bidi': False, - 'code': 'nl', - 'name': 'Dutch', - 'name_local': 'Nederlands', - }, - 'nn': { - 'bidi': False, - 'code': 'nn', - 'name': 'Norwegian Nynorsk', - 'name_local': 'norsk (nynorsk)', - }, - 'no': { - 'bidi': False, - 'code': 'no', - 'name': 'Norwegian', - 'name_local': 'norsk', - }, - 'os': { - 'bidi': False, - 'code': 'os', - 'name': 'Ossetic', - 'name_local': 'Ирон', - }, - 'pa': { - 'bidi': False, - 'code': 'pa', - 'name': 'Punjabi', - 'name_local': 'Punjabi', - }, - 'pl': { - 'bidi': False, - 'code': 'pl', - 'name': 'Polish', - 'name_local': 'polski', - }, - 'pt': { - 'bidi': False, - 'code': 'pt', - 'name': 'Portuguese', - 'name_local': 'Português', - }, - 'pt-br': { - 'bidi': False, - 'code': 'pt-br', - 'name': 'Brazilian Portuguese', - 'name_local': 'Português Brasileiro', - }, - 'ro': { - 'bidi': False, - 'code': 'ro', - 'name': 'Romanian', - 'name_local': 'Română', - }, - 'ru': { - 'bidi': False, - 'code': 'ru', - 'name': 'Russian', - 'name_local': 'Русский', - }, - 'sk': { - 'bidi': False, - 'code': 'sk', - 'name': 'Slovak', - 'name_local': 'Slovensky', - }, - 'sl': { - 'bidi': False, - 'code': 'sl', - 'name': 'Slovenian', - 'name_local': 'Slovenščina', - }, - 'sq': { - 'bidi': False, - 
'code': 'sq', - 'name': 'Albanian', - 'name_local': 'shqip', - }, - 'sr': { - 'bidi': False, - 'code': 'sr', - 'name': 'Serbian', - 'name_local': 'српски', - }, - 'sr-latn': { - 'bidi': False, - 'code': 'sr-latn', - 'name': 'Serbian Latin', - 'name_local': 'srpski (latinica)', - }, - 'sv': { - 'bidi': False, - 'code': 'sv', - 'name': 'Swedish', - 'name_local': 'svenska', - }, - 'sw': { - 'bidi': False, - 'code': 'sw', - 'name': 'Swahili', - 'name_local': 'Kiswahili', - }, - 'ta': { - 'bidi': False, - 'code': 'ta', - 'name': 'Tamil', - 'name_local': 'தமிழ்', - }, - 'te': { - 'bidi': False, - 'code': 'te', - 'name': 'Telugu', - 'name_local': 'తెలుగు', - }, - 'tg': { - 'bidi': False, - 'code': 'tg', - 'name': 'Tajik', - 'name_local': 'тоҷикӣ', - }, - 'th': { - 'bidi': False, - 'code': 'th', - 'name': 'Thai', - 'name_local': 'ภาษาไทย', - }, - 'tk': { - 'bidi': False, - 'code': 'tk', - 'name': 'Turkmen', - 'name_local': 'Türkmençe', - }, - 'tr': { - 'bidi': False, - 'code': 'tr', - 'name': 'Turkish', - 'name_local': 'Türkçe', - }, - 'tt': { - 'bidi': False, - 'code': 'tt', - 'name': 'Tatar', - 'name_local': 'Татарча', - }, - 'udm': { - 'bidi': False, - 'code': 'udm', - 'name': 'Udmurt', - 'name_local': 'Удмурт', - }, - 'uk': { - 'bidi': False, - 'code': 'uk', - 'name': 'Ukrainian', - 'name_local': 'Українська', - }, - 'ur': { - 'bidi': True, - 'code': 'ur', - 'name': 'Urdu', - 'name_local': 'اردو', - }, - 'uz': { - 'bidi': False, - 'code': 'uz', - 'name': 'Uzbek', - 'name_local': 'oʻzbek tili', - }, - 'vi': { - 'bidi': False, - 'code': 'vi', - 'name': 'Vietnamese', - 'name_local': 'Tiếng Việt', - }, - 'zh-cn': { - 'fallback': ['zh-hans'], - }, - 'zh-hans': { - 'bidi': False, - 'code': 'zh-hans', - 'name': 'Simplified Chinese', - 'name_local': '简体中文', - }, - 'zh-hant': { - 'bidi': False, - 'code': 'zh-hant', - 'name': 'Traditional Chinese', - 'name_local': '繁體中文', - }, - 'zh-hk': { - 'fallback': ['zh-hant'], - }, - 'zh-mo': { - 'fallback': ['zh-hant'], - }, - 'zh-my': { 
- 'fallback': ['zh-hans'], - }, - 'zh-sg': { - 'fallback': ['zh-hans'], - }, - 'zh-tw': { - 'fallback': ['zh-hant'], + "af": { + "bidi": False, + "code": "af", + "name": "Afrikaans", + "name_local": "Afrikaans", + }, + "ar": { + "bidi": True, + "code": "ar", + "name": "Arabic", + "name_local": "العربيّة", + }, + "ar-dz": { + "bidi": True, + "code": "ar-dz", + "name": "Algerian Arabic", + "name_local": "العربية الجزائرية", + }, + "ast": { + "bidi": False, + "code": "ast", + "name": "Asturian", + "name_local": "asturianu", + }, + "az": { + "bidi": True, + "code": "az", + "name": "Azerbaijani", + "name_local": "Azərbaycanca", + }, + "be": { + "bidi": False, + "code": "be", + "name": "Belarusian", + "name_local": "беларуская", + }, + "bg": { + "bidi": False, + "code": "bg", + "name": "Bulgarian", + "name_local": "български", + }, + "bn": { + "bidi": False, + "code": "bn", + "name": "Bengali", + "name_local": "বাংলা", + }, + "br": { + "bidi": False, + "code": "br", + "name": "Breton", + "name_local": "brezhoneg", + }, + "bs": { + "bidi": False, + "code": "bs", + "name": "Bosnian", + "name_local": "bosanski", + }, + "ca": { + "bidi": False, + "code": "ca", + "name": "Catalan", + "name_local": "català", + }, + "cs": { + "bidi": False, + "code": "cs", + "name": "Czech", + "name_local": "česky", + }, + "cy": { + "bidi": False, + "code": "cy", + "name": "Welsh", + "name_local": "Cymraeg", + }, + "da": { + "bidi": False, + "code": "da", + "name": "Danish", + "name_local": "dansk", + }, + "de": { + "bidi": False, + "code": "de", + "name": "German", + "name_local": "Deutsch", + }, + "dsb": { + "bidi": False, + "code": "dsb", + "name": "Lower Sorbian", + "name_local": "dolnoserbski", + }, + "el": { + "bidi": False, + "code": "el", + "name": "Greek", + "name_local": "Ελληνικά", + }, + "en": { + "bidi": False, + "code": "en", + "name": "English", + "name_local": "English", + }, + "en-au": { + "bidi": False, + "code": "en-au", + "name": "Australian English", + "name_local": 
"Australian English", + }, + "en-gb": { + "bidi": False, + "code": "en-gb", + "name": "British English", + "name_local": "British English", + }, + "eo": { + "bidi": False, + "code": "eo", + "name": "Esperanto", + "name_local": "Esperanto", + }, + "es": { + "bidi": False, + "code": "es", + "name": "Spanish", + "name_local": "español", + }, + "es-ar": { + "bidi": False, + "code": "es-ar", + "name": "Argentinian Spanish", + "name_local": "español de Argentina", + }, + "es-co": { + "bidi": False, + "code": "es-co", + "name": "Colombian Spanish", + "name_local": "español de Colombia", + }, + "es-mx": { + "bidi": False, + "code": "es-mx", + "name": "Mexican Spanish", + "name_local": "español de Mexico", + }, + "es-ni": { + "bidi": False, + "code": "es-ni", + "name": "Nicaraguan Spanish", + "name_local": "español de Nicaragua", + }, + "es-ve": { + "bidi": False, + "code": "es-ve", + "name": "Venezuelan Spanish", + "name_local": "español de Venezuela", + }, + "et": { + "bidi": False, + "code": "et", + "name": "Estonian", + "name_local": "eesti", + }, + "eu": { + "bidi": False, + "code": "eu", + "name": "Basque", + "name_local": "Basque", + }, + "fa": { + "bidi": True, + "code": "fa", + "name": "Persian", + "name_local": "فارسی", + }, + "fi": { + "bidi": False, + "code": "fi", + "name": "Finnish", + "name_local": "suomi", + }, + "fr": { + "bidi": False, + "code": "fr", + "name": "French", + "name_local": "français", + }, + "fy": { + "bidi": False, + "code": "fy", + "name": "Frisian", + "name_local": "frysk", + }, + "ga": { + "bidi": False, + "code": "ga", + "name": "Irish", + "name_local": "Gaeilge", + }, + "gd": { + "bidi": False, + "code": "gd", + "name": "Scottish Gaelic", + "name_local": "Gàidhlig", + }, + "gl": { + "bidi": False, + "code": "gl", + "name": "Galician", + "name_local": "galego", + }, + "he": { + "bidi": True, + "code": "he", + "name": "Hebrew", + "name_local": "עברית", + }, + "hi": { + "bidi": False, + "code": "hi", + "name": "Hindi", + "name_local": 
"हिंदी", + }, + "hr": { + "bidi": False, + "code": "hr", + "name": "Croatian", + "name_local": "Hrvatski", + }, + "hsb": { + "bidi": False, + "code": "hsb", + "name": "Upper Sorbian", + "name_local": "hornjoserbsce", + }, + "hu": { + "bidi": False, + "code": "hu", + "name": "Hungarian", + "name_local": "Magyar", + }, + "hy": { + "bidi": False, + "code": "hy", + "name": "Armenian", + "name_local": "հայերեն", + }, + "ia": { + "bidi": False, + "code": "ia", + "name": "Interlingua", + "name_local": "Interlingua", + }, + "io": { + "bidi": False, + "code": "io", + "name": "Ido", + "name_local": "ido", + }, + "id": { + "bidi": False, + "code": "id", + "name": "Indonesian", + "name_local": "Bahasa Indonesia", + }, + "ig": { + "bidi": False, + "code": "ig", + "name": "Igbo", + "name_local": "Asụsụ Ìgbò", + }, + "is": { + "bidi": False, + "code": "is", + "name": "Icelandic", + "name_local": "Íslenska", + }, + "it": { + "bidi": False, + "code": "it", + "name": "Italian", + "name_local": "italiano", + }, + "ja": { + "bidi": False, + "code": "ja", + "name": "Japanese", + "name_local": "日本語", + }, + "ka": { + "bidi": False, + "code": "ka", + "name": "Georgian", + "name_local": "ქართული", + }, + "kab": { + "bidi": False, + "code": "kab", + "name": "Kabyle", + "name_local": "taqbaylit", + }, + "kk": { + "bidi": False, + "code": "kk", + "name": "Kazakh", + "name_local": "Қазақ", + }, + "km": { + "bidi": False, + "code": "km", + "name": "Khmer", + "name_local": "Khmer", + }, + "kn": { + "bidi": False, + "code": "kn", + "name": "Kannada", + "name_local": "Kannada", + }, + "ko": { + "bidi": False, + "code": "ko", + "name": "Korean", + "name_local": "한국어", + }, + "ky": { + "bidi": False, + "code": "ky", + "name": "Kyrgyz", + "name_local": "Кыргызча", + }, + "lb": { + "bidi": False, + "code": "lb", + "name": "Luxembourgish", + "name_local": "Lëtzebuergesch", + }, + "lt": { + "bidi": False, + "code": "lt", + "name": "Lithuanian", + "name_local": "Lietuviškai", + }, + "lv": { + "bidi": 
False, + "code": "lv", + "name": "Latvian", + "name_local": "latviešu", + }, + "mk": { + "bidi": False, + "code": "mk", + "name": "Macedonian", + "name_local": "Македонски", + }, + "ml": { + "bidi": False, + "code": "ml", + "name": "Malayalam", + "name_local": "മലയാളം", + }, + "mn": { + "bidi": False, + "code": "mn", + "name": "Mongolian", + "name_local": "Mongolian", + }, + "mr": { + "bidi": False, + "code": "mr", + "name": "Marathi", + "name_local": "मराठी", + }, + "ms": { + "bidi": False, + "code": "ms", + "name": "Malay", + "name_local": "Bahasa Melayu", + }, + "my": { + "bidi": False, + "code": "my", + "name": "Burmese", + "name_local": "မြန်မာဘာသာ", + }, + "nb": { + "bidi": False, + "code": "nb", + "name": "Norwegian Bokmal", + "name_local": "norsk (bokmål)", + }, + "ne": { + "bidi": False, + "code": "ne", + "name": "Nepali", + "name_local": "नेपाली", + }, + "nl": { + "bidi": False, + "code": "nl", + "name": "Dutch", + "name_local": "Nederlands", + }, + "nn": { + "bidi": False, + "code": "nn", + "name": "Norwegian Nynorsk", + "name_local": "norsk (nynorsk)", + }, + "no": { + "bidi": False, + "code": "no", + "name": "Norwegian", + "name_local": "norsk", + }, + "os": { + "bidi": False, + "code": "os", + "name": "Ossetic", + "name_local": "Ирон", + }, + "pa": { + "bidi": False, + "code": "pa", + "name": "Punjabi", + "name_local": "Punjabi", + }, + "pl": { + "bidi": False, + "code": "pl", + "name": "Polish", + "name_local": "polski", + }, + "pt": { + "bidi": False, + "code": "pt", + "name": "Portuguese", + "name_local": "Português", + }, + "pt-br": { + "bidi": False, + "code": "pt-br", + "name": "Brazilian Portuguese", + "name_local": "Português Brasileiro", + }, + "ro": { + "bidi": False, + "code": "ro", + "name": "Romanian", + "name_local": "Română", + }, + "ru": { + "bidi": False, + "code": "ru", + "name": "Russian", + "name_local": "Русский", + }, + "sk": { + "bidi": False, + "code": "sk", + "name": "Slovak", + "name_local": "Slovensky", + }, + "sl": { + 
"bidi": False, + "code": "sl", + "name": "Slovenian", + "name_local": "Slovenščina", + }, + "sq": { + "bidi": False, + "code": "sq", + "name": "Albanian", + "name_local": "shqip", + }, + "sr": { + "bidi": False, + "code": "sr", + "name": "Serbian", + "name_local": "српски", + }, + "sr-latn": { + "bidi": False, + "code": "sr-latn", + "name": "Serbian Latin", + "name_local": "srpski (latinica)", + }, + "sv": { + "bidi": False, + "code": "sv", + "name": "Swedish", + "name_local": "svenska", + }, + "sw": { + "bidi": False, + "code": "sw", + "name": "Swahili", + "name_local": "Kiswahili", + }, + "ta": { + "bidi": False, + "code": "ta", + "name": "Tamil", + "name_local": "தமிழ்", + }, + "te": { + "bidi": False, + "code": "te", + "name": "Telugu", + "name_local": "తెలుగు", + }, + "tg": { + "bidi": False, + "code": "tg", + "name": "Tajik", + "name_local": "тоҷикӣ", + }, + "th": { + "bidi": False, + "code": "th", + "name": "Thai", + "name_local": "ภาษาไทย", + }, + "tk": { + "bidi": False, + "code": "tk", + "name": "Turkmen", + "name_local": "Türkmençe", + }, + "tr": { + "bidi": False, + "code": "tr", + "name": "Turkish", + "name_local": "Türkçe", + }, + "tt": { + "bidi": False, + "code": "tt", + "name": "Tatar", + "name_local": "Татарча", + }, + "udm": { + "bidi": False, + "code": "udm", + "name": "Udmurt", + "name_local": "Удмурт", + }, + "uk": { + "bidi": False, + "code": "uk", + "name": "Ukrainian", + "name_local": "Українська", + }, + "ur": { + "bidi": True, + "code": "ur", + "name": "Urdu", + "name_local": "اردو", + }, + "uz": { + "bidi": False, + "code": "uz", + "name": "Uzbek", + "name_local": "oʻzbek tili", + }, + "vi": { + "bidi": False, + "code": "vi", + "name": "Vietnamese", + "name_local": "Tiếng Việt", + }, + "zh-cn": { + "fallback": ["zh-hans"], + }, + "zh-hans": { + "bidi": False, + "code": "zh-hans", + "name": "Simplified Chinese", + "name_local": "简体中文", + }, + "zh-hant": { + "bidi": False, + "code": "zh-hant", + "name": "Traditional Chinese", + 
"name_local": "繁體中文", + }, + "zh-hk": { + "fallback": ["zh-hant"], + }, + "zh-mo": { + "fallback": ["zh-hant"], + }, + "zh-my": { + "fallback": ["zh-hans"], + }, + "zh-sg": { + "fallback": ["zh-hans"], + }, + "zh-tw": { + "fallback": ["zh-hant"], }, } diff --git a/django/conf/locale/ar/formats.py b/django/conf/locale/ar/formats.py index 19cc8601b7..8008ce6ec4 100644 --- a/django/conf/locale/ar/formats.py +++ b/django/conf/locale/ar/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F، Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "j F، Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd‏/m‏/Y' +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d‏/m‏/Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'd‏/m‏/Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
# NUMBER_GROUPING = diff --git a/django/conf/locale/ar_DZ/formats.py b/django/conf/locale/ar_DZ/formats.py index e091e1788d..cbd361d62e 100644 --- a/django/conf/locale/ar_DZ/formats.py +++ b/django/conf/locale/ar_DZ/formats.py @@ -2,28 +2,28 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j F Y' -SHORT_DATETIME_FORMAT = 'j F Y H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j F Y" +SHORT_DATETIME_FORMAT = "j F Y H:i" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y/%m/%d', # '2006/10/25' + "%Y/%m/%d", # '2006/10/25' ] TIME_INPUT_FORMATS = [ - '%H:%M', # '14:30 - '%H:%M:%S', # '14:30:59' + "%H:%M", # '14:30 + "%H:%M:%S", # '14:30:59' ] DATETIME_INPUT_FORMATS = [ - '%Y/%m/%d %H:%M', # '2006/10/25 14:30' - '%Y/%m/%d %H:%M:%S', # '2006/10/25 14:30:59' + "%Y/%m/%d %H:%M", # '2006/10/25 14:30' + "%Y/%m/%d %H:%M:%S", # '2006/10/25 14:30:59' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
NUMBER_GROUPING = 3 diff --git a/django/conf/locale/az/formats.py b/django/conf/locale/az/formats.py index 6f655d18ef..253b6dddf5 100644 --- a/django/conf/locale/az/formats.py +++ b/django/conf/locale/az/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j E Y, G:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j E Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j E Y, G:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/bg/formats.py b/django/conf/locale/bg/formats.py index b7d0c3b53d..ee90c5b08f 100644 --- a/django/conf/locale/bg/formats.py +++ 
b/django/conf/locale/bg/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' -TIME_FORMAT = 'H:i' +DATE_FORMAT = "d F Y" +TIME_FORMAT = "H:i" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'd.m.Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = ' ' # Non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = " " # Non-breaking space # NUMBER_GROUPING = diff --git a/django/conf/locale/bn/formats.py b/django/conf/locale/bn/formats.py index 6205fb95cb..9d1bb09d13 100644 --- a/django/conf/locale/bn/formats.py +++ b/django/conf/locale/bn/formats.py @@ -2,31 +2,31 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F, Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "j F, Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M, Y' +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M, Y" # SHORT_DATETIME_FORMAT = FIRST_DAY_OF_WEEK = 6 # Saturday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # 25/10/2016 - '%d/%m/%y', # 25/10/16 - '%d-%m-%Y', # 25-10-2016 - '%d-%m-%y', # 25-10-16 + "%d/%m/%Y", # 25/10/2016 + "%d/%m/%y", # 25/10/16 + "%d-%m-%Y", # 25-10-2016 + "%d-%m-%y", # 25-10-16 ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # 14:30:59 - '%H:%M', # 14:30 + "%H:%M:%S", # 14:30:59 + "%H:%M", # 14:30 ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # 25/10/2006 
14:30:59 - '%d/%m/%Y %H:%M', # 25/10/2006 14:30 + "%d/%m/%Y %H:%M:%S", # 25/10/2006 14:30:59 + "%d/%m/%Y %H:%M", # 25/10/2006 14:30 ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/bs/formats.py b/django/conf/locale/bs/formats.py index 25d9b40e45..a15e7099e4 100644 --- a/django/conf/locale/bs/formats.py +++ b/django/conf/locale/bs/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. N Y.' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j. N. Y. G:i T' -YEAR_MONTH_FORMAT = 'F Y.' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'Y M j' +DATE_FORMAT = "j. N Y." +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j. N. Y. G:i T" +YEAR_MONTH_FORMAT = "F Y." +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "Y M j" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'Y M j' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
# NUMBER_GROUPING = diff --git a/django/conf/locale/ca/formats.py b/django/conf/locale/ca/formats.py index f6b2957d11..2f91009119 100644 --- a/django/conf/locale/ca/formats.py +++ b/django/conf/locale/ca/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\e\s G:i' -YEAR_MONTH_FORMAT = r'F \d\e\l Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y G:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\e\s G:i" +YEAR_MONTH_FORMAT = r"F \d\e\l Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y G:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '31/12/2009' - '%d/%m/%y', # '31/12/09' + "%d/%m/%Y", # '31/12/2009' + "%d/%m/%y", # '31/12/09' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/cs/formats.py b/django/conf/locale/cs/formats.py index 5c6d8f9c0f..e4a7ab9946 100644 --- a/django/conf/locale/cs/formats.py +++ b/django/conf/locale/cs/formats.py @@ -2,42 +2,42 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. 
E Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j. E Y G:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y G:i' +DATE_FORMAT = "j. E Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j. E Y G:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y G:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '05.01.2006' - '%d.%m.%y', # '05.01.06' - '%d. %m. %Y', # '5. 1. 2006' - '%d. %m. %y', # '5. 1. 06' + "%d.%m.%Y", # '05.01.2006' + "%d.%m.%y", # '05.01.06' + "%d. %m. %Y", # '5. 1. 2006' + "%d. %m. %y", # '5. 1. 06' # "%d. %B %Y", # '25. October 2006' # "%d. %b. %Y", # '25. Oct. 2006' ] # Kept ISO formats as one is in first position TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '04:30:59' - '%H.%M', # '04.30' - '%H:%M', # '04:30' + "%H:%M:%S", # '04:30:59' + "%H.%M", # '04.30' + "%H:%M", # '04:30' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '05.01.2006 04:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '05.01.2006 04:30:59.000200' - '%d.%m.%Y %H.%M', # '05.01.2006 04.30' - '%d.%m.%Y %H:%M', # '05.01.2006 04:30' - '%d. %m. %Y %H:%M:%S', # '05. 01. 2006 04:30:59' - '%d. %m. %Y %H:%M:%S.%f', # '05. 01. 2006 04:30:59.000200' - '%d. %m. %Y %H.%M', # '05. 01. 2006 04.30' - '%d. %m. %Y %H:%M', # '05. 01. 2006 04:30' - '%Y-%m-%d %H.%M', # '2006-01-05 04.30' + "%d.%m.%Y %H:%M:%S", # '05.01.2006 04:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '05.01.2006 04:30:59.000200' + "%d.%m.%Y %H.%M", # '05.01.2006 04.30' + "%d.%m.%Y %H:%M", # '05.01.2006 04:30' + "%d. %m. %Y %H:%M:%S", # '05. 01. 2006 04:30:59' + "%d. %m. %Y %H:%M:%S.%f", # '05. 01. 2006 04:30:59.000200' + "%d. %m. %Y %H.%M", # '05. 01. 2006 04.30' + "%d. %m. %Y %H:%M", # '05. 01. 
2006 04:30' + "%Y-%m-%d %H.%M", # '2006-01-05 04.30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/cy/formats.py b/django/conf/locale/cy/formats.py index bcdcb111df..eaef6a618f 100644 --- a/django/conf/locale/cy/formats.py +++ b/django/conf/locale/cy/formats.py @@ -2,32 +2,32 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' # '25 Hydref 2006' -TIME_FORMAT = 'P' # '2:30 y.b.' -DATETIME_FORMAT = 'j F Y, P' # '25 Hydref 2006, 2:30 y.b.' -YEAR_MONTH_FORMAT = 'F Y' # 'Hydref 2006' -MONTH_DAY_FORMAT = 'j F' # '25 Hydref' -SHORT_DATE_FORMAT = 'd/m/Y' # '25/10/2006' -SHORT_DATETIME_FORMAT = 'd/m/Y P' # '25/10/2006 2:30 y.b.' -FIRST_DAY_OF_WEEK = 1 # 'Dydd Llun' +DATE_FORMAT = "j F Y" # '25 Hydref 2006' +TIME_FORMAT = "P" # '2:30 y.b.' +DATETIME_FORMAT = "j F Y, P" # '25 Hydref 2006, 2:30 y.b.' +YEAR_MONTH_FORMAT = "F Y" # 'Hydref 2006' +MONTH_DAY_FORMAT = "j F" # '25 Hydref' +SHORT_DATE_FORMAT = "d/m/Y" # '25/10/2006' +SHORT_DATETIME_FORMAT = "d/m/Y P" # '25/10/2006 2:30 y.b.' 
+FIRST_DAY_OF_WEEK = 1 # 'Dydd Llun' # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/da/formats.py b/django/conf/locale/da/formats.py index 6237a7209d..58292084fb 100644 --- a/django/conf/locale/da/formats.py +++ b/django/conf/locale/da/formats.py @@ -2,25 +2,25 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. 
F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' + "%d.%m.%Y", # '25.10.2006' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/de/formats.py b/django/conf/locale/de/formats.py index 65d58b5f2b..45953ce238 100644 --- a/django/conf/locale/de/formats.py +++ b/django/conf/locale/de/formats.py @@ -2,28 +2,28 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' # "%d. %B %Y", # '25. October 2006' # "%d. %b. %Y", # '25. Oct. 
2006' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/de_CH/formats.py b/django/conf/locale/de_CH/formats.py index 9ee69120c8..f42dd48739 100644 --- a/django/conf/locale/de_CH/formats.py +++ b/django/conf/locale/de_CH/formats.py @@ -2,27 +2,27 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' # "%d. %B %Y", # '25. October 2006' # "%d. %b. %Y", # '25. Oct. 2006' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' ] # these are the separators for non-monetary numbers. 
For monetary numbers, @@ -30,6 +30,6 @@ DATETIME_INPUT_FORMATS = [ # ' (single quote). # For details, please refer to the documentation and the following link: # https://www.bk.admin.ch/bk/de/home/dokumentation/sprachen/hilfsmittel-textredaktion/schreibweisungen.html -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/el/formats.py b/django/conf/locale/el/formats.py index 7aa53a0f1a..25c8ef7d37 100644 --- a/django/conf/locale/el/formats.py +++ b/django/conf/locale/el/formats.py @@ -2,33 +2,33 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd/m/Y' -TIME_FORMAT = 'P' -DATETIME_FORMAT = 'd/m/Y P' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y P' +DATE_FORMAT = "d/m/Y" +TIME_FORMAT = "P" +DATETIME_FORMAT = "d/m/Y P" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y P" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%Y-%m-%d', # '2006-10-25' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%Y-%m-%d", # '2006-10-25' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' + 
"%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/en/formats.py b/django/conf/locale/en/formats.py index 884deff17e..f9d143b717 100644 --- a/django/conf/locale/en/formats.py +++ b/django/conf/locale/en/formats.py @@ -4,19 +4,19 @@ # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date # Formatting for date objects. -DATE_FORMAT = 'N j, Y' +DATE_FORMAT = "N j, Y" # Formatting for time objects. -TIME_FORMAT = 'P' +TIME_FORMAT = "P" # Formatting for datetime objects. -DATETIME_FORMAT = 'N j, Y, P' +DATETIME_FORMAT = "N j, Y, P" # Formatting for date objects when only the year and month are relevant. -YEAR_MONTH_FORMAT = 'F Y' +YEAR_MONTH_FORMAT = "F Y" # Formatting for date objects when only the month and day are relevant. -MONTH_DAY_FORMAT = 'F j' +MONTH_DAY_FORMAT = "F j" # Short formatting for date objects. -SHORT_DATE_FORMAT = 'm/d/Y' +SHORT_DATE_FORMAT = "m/d/Y" # Short formatting for datetime objects. -SHORT_DATETIME_FORMAT = 'm/d/Y P' +SHORT_DATETIME_FORMAT = "m/d/Y P" # First day of week, to be used on calendars. # 0 means Sunday, 1 means Monday... FIRST_DAY_OF_WEEK = 0 @@ -27,39 +27,39 @@ FIRST_DAY_OF_WEEK = 0 # Note that these format strings are different from the ones to display dates. 
# Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' - '%b %d %Y', # 'Oct 25 2006' - '%b %d, %Y', # 'Oct 25, 2006' - '%d %b %Y', # '25 Oct 2006' - '%d %b, %Y', # '25 Oct, 2006' - '%B %d %Y', # 'October 25 2006' - '%B %d, %Y', # 'October 25, 2006' - '%d %B %Y', # '25 October 2006' - '%d %B, %Y', # '25 October, 2006' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # '10/25/06' + "%b %d %Y", # 'Oct 25 2006' + "%b %d, %Y", # 'Oct 25, 2006' + "%d %b %Y", # '25 Oct 2006' + "%d %b, %Y", # '25 Oct, 2006' + "%B %d %Y", # 'October 25 2006' + "%B %d, %Y", # 'October 25, 2006' + "%d %B %Y", # '25 October 2006' + "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + "%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y %H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' ] # Decimal separator symbol. -DECIMAL_SEPARATOR = '.' +DECIMAL_SEPARATOR = "." # Thousand separator symbol. 
-THOUSAND_SEPARATOR = ',' +THOUSAND_SEPARATOR = "," # Number of digits that will be together, when splitting them by # THOUSAND_SEPARATOR. 0 means no grouping, 3 means splitting by thousands. NUMBER_GROUPING = 3 diff --git a/django/conf/locale/en_AU/formats.py b/django/conf/locale/en_AU/formats.py index d6a3b7235d..caa6f7201c 100644 --- a/django/conf/locale/en_AU/formats.py +++ b/django/conf/locale/en_AU/formats.py @@ -2,20 +2,20 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j M Y' # '25 Oct 2006' -TIME_FORMAT = 'P' # '2:30 p.m.' -DATETIME_FORMAT = 'j M Y, P' # '25 Oct 2006, 2:30 p.m.' -YEAR_MONTH_FORMAT = 'F Y' # 'October 2006' -MONTH_DAY_FORMAT = 'j F' # '25 October' -SHORT_DATE_FORMAT = 'd/m/Y' # '25/10/2006' -SHORT_DATETIME_FORMAT = 'd/m/Y P' # '25/10/2006 2:30 p.m.' -FIRST_DAY_OF_WEEK = 0 # Sunday +DATE_FORMAT = "j M Y" # '25 Oct 2006' +TIME_FORMAT = "P" # '2:30 p.m.' +DATETIME_FORMAT = "j M Y, P" # '25 Oct 2006, 2:30 p.m.' +YEAR_MONTH_FORMAT = "F Y" # 'October 2006' +MONTH_DAY_FORMAT = "j F" # '25 October' +SHORT_DATE_FORMAT = "d/m/Y" # '25/10/2006' +SHORT_DATETIME_FORMAT = "d/m/Y P" # '25/10/2006 2:30 p.m.' 
+FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' # "%b %d %Y", # 'Oct 25 2006' # "%b %d, %Y", # 'Oct 25, 2006' # "%d %b %Y", # '25 Oct 2006' @@ -26,16 +26,16 @@ DATE_INPUT_FORMATS = [ # "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/en_GB/formats.py b/django/conf/locale/en_GB/formats.py index 84ab5a8a93..bc90da59bc 100644 --- a/django/conf/locale/en_GB/formats.py +++ b/django/conf/locale/en_GB/formats.py @@ -2,20 +2,20 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j M Y' # '25 Oct 2006' -TIME_FORMAT = 'P' # '2:30 p.m.' -DATETIME_FORMAT = 'j M Y, P' # '25 Oct 2006, 2:30 p.m.' 
-YEAR_MONTH_FORMAT = 'F Y' # 'October 2006' -MONTH_DAY_FORMAT = 'j F' # '25 October' -SHORT_DATE_FORMAT = 'd/m/Y' # '25/10/2006' -SHORT_DATETIME_FORMAT = 'd/m/Y P' # '25/10/2006 2:30 p.m.' -FIRST_DAY_OF_WEEK = 1 # Monday +DATE_FORMAT = "j M Y" # '25 Oct 2006' +TIME_FORMAT = "P" # '2:30 p.m.' +DATETIME_FORMAT = "j M Y, P" # '25 Oct 2006, 2:30 p.m.' +YEAR_MONTH_FORMAT = "F Y" # 'October 2006' +MONTH_DAY_FORMAT = "j F" # '25 October' +SHORT_DATE_FORMAT = "d/m/Y" # '25/10/2006' +SHORT_DATETIME_FORMAT = "d/m/Y P" # '25/10/2006 2:30 p.m.' +FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' # "%b %d %Y", # 'Oct 25 2006' # "%b %d, %Y", # 'Oct 25, 2006' # "%d %b %Y", # '25 Oct 2006' @@ -26,16 +26,16 @@ DATE_INPUT_FORMATS = [ # "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/eo/formats.py b/django/conf/locale/eo/formats.py index 604e5f5bfa..d1346d1c36 100644 --- a/django/conf/locale/eo/formats.py +++ b/django/conf/locale/eo/formats.py @@ -2,46 +2,43 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j\-\a \d\e F Y' # '26-a de julio 1887' -TIME_FORMAT = 'H:i' # '18:59' -DATETIME_FORMAT = r'j\-\a \d\e F Y\, \j\e H:i' # '26-a de julio 1887, je 18:59' -YEAR_MONTH_FORMAT = r'F \d\e Y' # 'julio de 1887' -MONTH_DAY_FORMAT = r'j\-\a \d\e F' # '26-a de julio' -SHORT_DATE_FORMAT = 'Y-m-d' # '1887-07-26' -SHORT_DATETIME_FORMAT = 'Y-m-d H:i' # '1887-07-26 18:59' +DATE_FORMAT = r"j\-\a \d\e F Y" # '26-a de julio 1887' +TIME_FORMAT = "H:i" # '18:59' +DATETIME_FORMAT = r"j\-\a \d\e F Y\, \j\e H:i" # '26-a de julio 1887, je 18:59' +YEAR_MONTH_FORMAT = r"F \d\e Y" # 'julio de 1887' +MONTH_DAY_FORMAT = r"j\-\a \d\e F" # '26-a de julio' +SHORT_DATE_FORMAT = "Y-m-d" # '1887-07-26' +SHORT_DATETIME_FORMAT = "Y-m-d H:i" # '1887-07-26 18:59' FIRST_DAY_OF_WEEK = 1 # Monday (lundo) # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '1887-07-26' - '%y-%m-%d', # '87-07-26' - '%Y %m %d', # '1887 07 26' - '%Y.%m.%d', # '1887.07.26' - '%d-a de %b %Y', # '26-a de jul 1887' - '%d %b %Y', # '26 jul 1887' - '%d-a de %B %Y', # '26-a de julio 1887' - '%d %B %Y', # '26 julio 1887' - '%d %m %Y', # '26 07 1887' - '%d/%m/%Y', # '26/07/1887' + "%Y-%m-%d", # '1887-07-26' + "%y-%m-%d", # '87-07-26' + "%Y %m %d", # '1887 07 26' + "%Y.%m.%d", # '1887.07.26' + "%d-a de %b %Y", # '26-a de jul 1887' + "%d %b %Y", # '26 jul 1887' + "%d-a de %B %Y", # '26-a de julio 1887' + "%d %B %Y", # '26 julio 1887' + "%d %m %Y", # '26 07 1887' + "%d/%m/%Y", # '26/07/1887' ] 
TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '18:59:00' - '%H:%M', # '18:59' + "%H:%M:%S", # '18:59:00' + "%H:%M", # '18:59' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '1887-07-26 18:59:00' - '%Y-%m-%d %H:%M', # '1887-07-26 18:59' - - '%Y.%m.%d %H:%M:%S', # '1887.07.26 18:59:00' - '%Y.%m.%d %H:%M', # '1887.07.26 18:59' - - '%d/%m/%Y %H:%M:%S', # '26/07/1887 18:59:00' - '%d/%m/%Y %H:%M', # '26/07/1887 18:59' - - '%y-%m-%d %H:%M:%S', # '87-07-26 18:59:00' - '%y-%m-%d %H:%M', # '87-07-26 18:59' + "%Y-%m-%d %H:%M:%S", # '1887-07-26 18:59:00' + "%Y-%m-%d %H:%M", # '1887-07-26 18:59' + "%Y.%m.%d %H:%M:%S", # '1887.07.26 18:59:00' + "%Y.%m.%d %H:%M", # '1887.07.26 18:59' + "%d/%m/%Y %H:%M:%S", # '26/07/1887 18:59:00' + "%d/%m/%Y %H:%M", # '26/07/1887 18:59' + "%y-%m-%d %H:%M:%S", # '87-07-26 18:59:00' + "%y-%m-%d %H:%M", # '87-07-26 18:59' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es/formats.py b/django/conf/locale/es/formats.py index 6b017df49a..ff9690bb5e 100644 --- a/django/conf/locale/es/formats.py +++ b/django/conf/locale/es/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\a\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\a\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see 
https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '31/12/2009' - '%d/%m/%y', # '31/12/09' + "%d/%m/%Y", # '31/12/2009' + "%d/%m/%y", # '31/12/09' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es_AR/formats.py b/django/conf/locale/es_AR/formats.py index e856c4a265..601b45843f 100644 --- a/django/conf/locale/es_AR/formats.py +++ b/django/conf/locale/es_AR/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j N Y' -TIME_FORMAT = r'H:i' -DATETIME_FORMAT = r'j N Y H:i' -YEAR_MONTH_FORMAT = r'F Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = r'd/m/Y' -SHORT_DATETIME_FORMAT = r'd/m/Y H:i' +DATE_FORMAT = r"j N Y" +TIME_FORMAT = r"H:i" +DATETIME_FORMAT = r"j N Y H:i" +YEAR_MONTH_FORMAT = r"F Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = r"d/m/Y" +SHORT_DATETIME_FORMAT = r"d/m/Y H:i" FIRST_DAY_OF_WEEK = 0 # 0: Sunday, 1: Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '31/12/2009' - '%d/%m/%y', # '31/12/09' + "%d/%m/%Y", # '31/12/2009' + "%d/%m/%y", # '31/12/09' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + 
"%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es_CO/formats.py b/django/conf/locale/es_CO/formats.py index d3a7a35d5c..056d0adaf7 100644 --- a/django/conf/locale/es_CO/formats.py +++ b/django/conf/locale/es_CO/formats.py @@ -1,27 +1,26 @@ # This file is distributed under the same license as the Django package. # -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\a\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\a\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 1 DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%Y%m%d', # '20061025' - + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%Y%m%d", # '20061025' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es_MX/formats.py b/django/conf/locale/es_MX/formats.py index d5b94e6e89..d675d79bdf 100644 --- a/django/conf/locale/es_MX/formats.py +++ b/django/conf/locale/es_MX/formats.py @@ -1,26 +1,26 @@ # This file is distributed under the same license as the Django package. 
# -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\a\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\a\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday: ISO 8601 DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%Y%m%d', # '20061025' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%Y%m%d", # '20061025' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = '.' # ',' is also official (less common): NOM-008-SCFI-2002 -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." # ',' is also official (less common): NOM-008-SCFI-2002 +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es_NI/formats.py b/django/conf/locale/es_NI/formats.py index ca17c9a998..0c8112a62c 100644 --- a/django/conf/locale/es_NI/formats.py +++ b/django/conf/locale/es_NI/formats.py @@ -1,27 +1,26 @@ # This file is distributed under the same license as the Django package. 
# -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\a\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\a\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday: ISO 8601 DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%Y%m%d', # '20061025' - + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%Y%m%d", # '20061025' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/es_PR/formats.py b/django/conf/locale/es_PR/formats.py index 8f95484f63..d50fe5d657 100644 --- a/django/conf/locale/es_PR/formats.py +++ b/django/conf/locale/es_PR/formats.py @@ -1,27 +1,27 @@ # This file is distributed under the same license as the Django package. 
# -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \a \l\a\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \a \l\a\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 0 # Sunday DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '31/12/2009' - '%d/%m/%y', # '31/12/09' + "%d/%m/%Y", # '31/12/2009' + "%d/%m/%y", # '31/12/09' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', - '%d/%m/%Y %H:%M:%S.%f', - '%d/%m/%Y %H:%M', - '%d/%m/%y %H:%M:%S', - '%d/%m/%y %H:%M:%S.%f', - '%d/%m/%y %H:%M', + "%d/%m/%Y %H:%M:%S", + "%d/%m/%Y %H:%M:%S.%f", + "%d/%m/%Y %H:%M", + "%d/%m/%y %H:%M:%S", + "%d/%m/%y %H:%M:%S.%f", + "%d/%m/%y %H:%M", ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/et/formats.py b/django/conf/locale/et/formats.py index 1e1e458e75..3b2d9ba443 100644 --- a/django/conf/locale/et/formats.py +++ b/django/conf/locale/et/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'G:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "G:i" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' +MONTH_DAY_FORMAT = "j. 
F" +SHORT_DATE_FORMAT = "d.m.Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'd.m.Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = ' ' # Non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = " " # Non-breaking space # NUMBER_GROUPING = diff --git a/django/conf/locale/eu/formats.py b/django/conf/locale/eu/formats.py index 33e6305352..61b16fbc6f 100644 --- a/django/conf/locale/eu/formats.py +++ b/django/conf/locale/eu/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'Y\k\o N j\a' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'Y\k\o N j\a, H:i' -YEAR_MONTH_FORMAT = r'Y\k\o F' -MONTH_DAY_FORMAT = r'F\r\e\n j\a' -SHORT_DATE_FORMAT = 'Y-m-d' -SHORT_DATETIME_FORMAT = 'Y-m-d H:i' +DATE_FORMAT = r"Y\k\o N j\a" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"Y\k\o N j\a, H:i" +YEAR_MONTH_FORMAT = r"Y\k\o F" +MONTH_DAY_FORMAT = r"F\r\e\n j\a" +SHORT_DATE_FORMAT = "Y-m-d" +SHORT_DATETIME_FORMAT = "Y-m-d H:i" FIRST_DAY_OF_WEEK = 1 # Astelehena # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ FIRST_DAY_OF_WEEK = 1 # Astelehena # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
NUMBER_GROUPING = 3 diff --git a/django/conf/locale/fa/formats.py b/django/conf/locale/fa/formats.py index c8666f7a03..e7019bc7a6 100644 --- a/django/conf/locale/fa/formats.py +++ b/django/conf/locale/fa/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j F Y، ساعت G:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'Y/n/j' -SHORT_DATETIME_FORMAT = 'Y/n/j،‏ G:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j F Y، ساعت G:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "Y/n/j" +SHORT_DATETIME_FORMAT = "Y/n/j،‏ G:i" FIRST_DAY_OF_WEEK = 6 # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ FIRST_DAY_OF_WEEK = 6 # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/fi/formats.py b/django/conf/locale/fi/formats.py index 0a56b37185..d9fb6d2f48 100644 --- a/django/conf/locale/fi/formats.py +++ b/django/conf/locale/fi/formats.py @@ -2,36 +2,35 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. E Y' -TIME_FORMAT = 'G.i' -DATETIME_FORMAT = r'j. E Y \k\e\l\l\o G.i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.n.Y' -SHORT_DATETIME_FORMAT = 'j.n.Y G.i' +DATE_FORMAT = "j. E Y" +TIME_FORMAT = "G.i" +DATETIME_FORMAT = r"j. E Y \k\e\l\l\o G.i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. 
F" +SHORT_DATE_FORMAT = "j.n.Y" +SHORT_DATETIME_FORMAT = "j.n.Y G.i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '20.3.2014' - '%d.%m.%y', # '20.3.14' + "%d.%m.%Y", # '20.3.2014' + "%d.%m.%y", # '20.3.14' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H.%M.%S', # '20.3.2014 14.30.59' - '%d.%m.%Y %H.%M.%S.%f', # '20.3.2014 14.30.59.000200' - '%d.%m.%Y %H.%M', # '20.3.2014 14.30' - - '%d.%m.%y %H.%M.%S', # '20.3.14 14.30.59' - '%d.%m.%y %H.%M.%S.%f', # '20.3.14 14.30.59.000200' - '%d.%m.%y %H.%M', # '20.3.14 14.30' + "%d.%m.%Y %H.%M.%S", # '20.3.2014 14.30.59' + "%d.%m.%Y %H.%M.%S.%f", # '20.3.2014 14.30.59.000200' + "%d.%m.%Y %H.%M", # '20.3.2014 14.30' + "%d.%m.%y %H.%M.%S", # '20.3.14 14.30.59' + "%d.%m.%y %H.%M.%S.%f", # '20.3.14 14.30.59.000200' + "%d.%m.%y %H.%M", # '20.3.14 14.30' ] TIME_INPUT_FORMATS = [ - '%H.%M.%S', # '14.30.59' - '%H.%M.%S.%f', # '14.30.59.000200' - '%H.%M', # '14.30' + "%H.%M.%S", # '14.30.59' + "%H.%M.%S.%f", # '14.30.59.000200' + "%H.%M", # '14.30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # Non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # Non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/fr/formats.py b/django/conf/locale/fr/formats.py index a7c45598c8..5845e6aa73 100644 --- a/django/conf/locale/fr/formats.py +++ b/django/conf/locale/fr/formats.py @@ -2,32 +2,32 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j N Y' -SHORT_DATETIME_FORMAT = 'j N Y H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j 
F" +SHORT_DATE_FORMAT = "j N Y" +SHORT_DATETIME_FORMAT = "j N Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%d.%m.%Y', # Swiss [fr_CH] '25.10.2006' - '%d.%m.%y', # Swiss [fr_CH] '25.10.06' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%d.%m.%Y", # Swiss [fr_CH] '25.10.2006' + "%d.%m.%y", # Swiss [fr_CH] '25.10.06' # '%d %B %Y', '%d %b %Y', # '25 octobre 2006', '25 oct. 2006' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d.%m.%Y %H:%M:%S', # Swiss [fr_CH), '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # Swiss (fr_CH), '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # Swiss (fr_CH), '25.10.2006 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d.%m.%Y %H:%M:%S", # Swiss [fr_CH), '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # Swiss (fr_CH), '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # Swiss (fr_CH), '25.10.2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ga/formats.py b/django/conf/locale/ga/formats.py index eb3614abd9..7cde1a5689 100644 --- a/django/conf/locale/ga/formats.py +++ b/django/conf/locale/ga/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "H:i" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' 
-SHORT_DATE_FORMAT = 'j M Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'j M Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/gd/formats.py b/django/conf/locale/gd/formats.py index 19b42ee015..5ef6774462 100644 --- a/django/conf/locale/gd/formats.py +++ b/django/conf/locale/gd/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'h:ia' -DATETIME_FORMAT = 'j F Y h:ia' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "h:ia" +DATETIME_FORMAT = "j F Y h:ia" # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M Y' -SHORT_DATETIME_FORMAT = 'j M Y h:ia' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" +SHORT_DATETIME_FORMAT = "j M Y h:ia" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ FIRST_DAY_OF_WEEK = 1 # Monday # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/gl/formats.py b/django/conf/locale/gl/formats.py index 9f29c239df..73729355ff 100644 --- a/django/conf/locale/gl/formats.py +++ b/django/conf/locale/gl/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y \á\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd-m-Y' -SHORT_DATETIME_FORMAT = 'd-m-Y, H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y \á\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d-m-Y" +SHORT_DATETIME_FORMAT = "d-m-Y, H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ FIRST_DAY_OF_WEEK = 1 # Monday # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
# NUMBER_GROUPING = diff --git a/django/conf/locale/he/formats.py b/django/conf/locale/he/formats.py index 2314565442..2cf9286555 100644 --- a/django/conf/locale/he/formats.py +++ b/django/conf/locale/he/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j בF Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j בF Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j בF' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = "j בF Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j בF Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j בF" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" # FIRST_DAY_OF_WEEK = # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ SHORT_DATETIME_FORMAT = 'd/m/Y H:i' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/hi/formats.py b/django/conf/locale/hi/formats.py index 923967ac51..ac078ec6c3 100644 --- a/django/conf/locale/hi/formats.py +++ b/django/conf/locale/hi/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd-m-Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d-m-Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'd-m-Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/hr/formats.py b/django/conf/locale/hr/formats.py index 627c39764f..a2dc45730d 100644 --- a/django/conf/locale/hr/formats.py +++ b/django/conf/locale/hr/formats.py @@ -2,43 +2,43 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. E Y.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. E Y. H:i' -YEAR_MONTH_FORMAT = 'F Y.' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.m.Y.' -SHORT_DATETIME_FORMAT = 'j.m.Y. H:i' +DATE_FORMAT = "j. E Y." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. E Y. H:i" +YEAR_MONTH_FORMAT = "F Y." +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "j.m.Y." +SHORT_DATETIME_FORMAT = "j.m.Y. H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d.%m.%Y.', # '25.10.2006.' - '%d.%m.%y.', # '25.10.06.' - '%d. %m. %Y.', # '25. 10. 2006.' - '%d. %m. %y.', # '25. 10. 06.' + "%Y-%m-%d", # '2006-10-25' + "%d.%m.%Y.", # '25.10.2006.' + "%d.%m.%y.", # '25.10.06.' + "%d. %m. %Y.", # '25. 10. 2006.' + "%d. %m. %y.", # '25. 10. 06.' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y. %H:%M:%S', # '25.10.2006. 14:30:59' - '%d.%m.%Y. %H:%M:%S.%f', # '25.10.2006. 14:30:59.000200' - '%d.%m.%Y. %H:%M', # '25.10.2006. 14:30' - '%d.%m.%y. %H:%M:%S', # '25.10.06. 14:30:59' - '%d.%m.%y. %H:%M:%S.%f', # '25.10.06. 14:30:59.000200' - '%d.%m.%y. %H:%M', # '25.10.06. 14:30' - '%d. %m. %Y. %H:%M:%S', # '25. 10. 2006. 14:30:59' - '%d. %m. %Y. %H:%M:%S.%f', # '25. 10. 2006. 14:30:59.000200' - '%d. %m. %Y. %H:%M', # '25. 10. 2006. 14:30' - '%d. %m. %y. 
%H:%M:%S', # '25. 10. 06. 14:30:59' - '%d. %m. %y. %H:%M:%S.%f', # '25. 10. 06. 14:30:59.000200' - '%d. %m. %y. %H:%M', # '25. 10. 06. 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y. %H:%M:%S", # '25.10.2006. 14:30:59' + "%d.%m.%Y. %H:%M:%S.%f", # '25.10.2006. 14:30:59.000200' + "%d.%m.%Y. %H:%M", # '25.10.2006. 14:30' + "%d.%m.%y. %H:%M:%S", # '25.10.06. 14:30:59' + "%d.%m.%y. %H:%M:%S.%f", # '25.10.06. 14:30:59.000200' + "%d.%m.%y. %H:%M", # '25.10.06. 14:30' + "%d. %m. %Y. %H:%M:%S", # '25. 10. 2006. 14:30:59' + "%d. %m. %Y. %H:%M:%S.%f", # '25. 10. 2006. 14:30:59.000200' + "%d. %m. %Y. %H:%M", # '25. 10. 2006. 14:30' + "%d. %m. %y. %H:%M:%S", # '25. 10. 06. 14:30:59' + "%d. %m. %y. %H:%M:%S.%f", # '25. 10. 06. 14:30:59.000200' + "%d. %m. %y. %H:%M", # '25. 10. 06. 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/hu/formats.py b/django/conf/locale/hu/formats.py index f0bfa21810..c17f2c75f8 100644 --- a/django/conf/locale/hu/formats.py +++ b/django/conf/locale/hu/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'Y. F j.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'Y. F j. H:i' -YEAR_MONTH_FORMAT = 'Y. F' -MONTH_DAY_FORMAT = 'F j.' -SHORT_DATE_FORMAT = 'Y.m.d.' -SHORT_DATETIME_FORMAT = 'Y.m.d. H:i' +DATE_FORMAT = "Y. F j." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "Y. F j. H:i" +YEAR_MONTH_FORMAT = "Y. F" +MONTH_DAY_FORMAT = "F j." +SHORT_DATE_FORMAT = "Y.m.d." +SHORT_DATETIME_FORMAT = "Y.m.d. 
H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y.%m.%d.', # '2006.10.25.' + "%Y.%m.%d.", # '2006.10.25.' ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M', # '14:30' + "%H:%M:%S", # '14:30:59' + "%H:%M", # '14:30' ] DATETIME_INPUT_FORMATS = [ - '%Y.%m.%d. %H:%M:%S', # '2006.10.25. 14:30:59' - '%Y.%m.%d. %H:%M:%S.%f', # '2006.10.25. 14:30:59.000200' - '%Y.%m.%d. %H:%M', # '2006.10.25. 14:30' + "%Y.%m.%d. %H:%M:%S", # '2006.10.25. 14:30:59' + "%Y.%m.%d. %H:%M:%S.%f", # '2006.10.25. 14:30:59.000200' + "%Y.%m.%d. %H:%M", # '2006.10.25. 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = ' ' # Non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = " " # Non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/id/formats.py b/django/conf/locale/id/formats.py index a8cdc9589c..91a25590fa 100644 --- a/django/conf/locale/id/formats.py +++ b/django/conf/locale/id/formats.py @@ -2,48 +2,48 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j N Y' +DATE_FORMAT = "j N Y" DATETIME_FORMAT = "j N Y, G.i" -TIME_FORMAT = 'G.i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd-m-Y' -SHORT_DATETIME_FORMAT = 'd-m-Y G.i' +TIME_FORMAT = "G.i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d-m-Y" +SHORT_DATETIME_FORMAT = "d-m-Y G.i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d-%m-%Y', # '25-10-2009' - '%d/%m/%Y', # '25/10/2009' - '%d-%m-%y', # '25-10-09' - '%d/%m/%y', # '25/10/09' - '%d %b %Y', # '25 Oct 2006', - '%d %B %Y', # '25 October 2006' - '%m/%d/%y', # '10/25/06' - 
'%m/%d/%Y', # '10/25/2009' + "%d-%m-%Y", # '25-10-2009' + "%d/%m/%Y", # '25/10/2009' + "%d-%m-%y", # '25-10-09' + "%d/%m/%y", # '25/10/09' + "%d %b %Y", # '25 Oct 2006', + "%d %B %Y", # '25 October 2006' + "%m/%d/%y", # '10/25/06' + "%m/%d/%Y", # '10/25/2009' ] TIME_INPUT_FORMATS = [ - '%H.%M.%S', # '14.30.59' - '%H.%M', # '14.30' + "%H.%M.%S", # '14.30.59' + "%H.%M", # '14.30' ] DATETIME_INPUT_FORMATS = [ - '%d-%m-%Y %H.%M.%S', # '25-10-2009 14.30.59' - '%d-%m-%Y %H.%M.%S.%f', # '25-10-2009 14.30.59.000200' - '%d-%m-%Y %H.%M', # '25-10-2009 14.30' - '%d-%m-%y %H.%M.%S', # '25-10-09' 14.30.59' - '%d-%m-%y %H.%M.%S.%f', # '25-10-09' 14.30.59.000200' - '%d-%m-%y %H.%M', # '25-10-09' 14.30' - '%m/%d/%y %H.%M.%S', # '10/25/06 14.30.59' - '%m/%d/%y %H.%M.%S.%f', # '10/25/06 14.30.59.000200' - '%m/%d/%y %H.%M', # '10/25/06 14.30' - '%m/%d/%Y %H.%M.%S', # '25/10/2009 14.30.59' - '%m/%d/%Y %H.%M.%S.%f', # '25/10/2009 14.30.59.000200' - '%m/%d/%Y %H.%M', # '25/10/2009 14.30' + "%d-%m-%Y %H.%M.%S", # '25-10-2009 14.30.59' + "%d-%m-%Y %H.%M.%S.%f", # '25-10-2009 14.30.59.000200' + "%d-%m-%Y %H.%M", # '25-10-2009 14.30' + "%d-%m-%y %H.%M.%S", # '25-10-09' 14.30.59' + "%d-%m-%y %H.%M.%S.%f", # '25-10-09' 14.30.59.000200' + "%d-%m-%y %H.%M", # '25-10-09' 14.30' + "%m/%d/%y %H.%M.%S", # '10/25/06 14.30.59' + "%m/%d/%y %H.%M.%S.%f", # '10/25/06 14.30.59.000200' + "%m/%d/%y %H.%M", # '10/25/06 14.30' + "%m/%d/%Y %H.%M.%S", # '25/10/2009 14.30.59' + "%m/%d/%Y %H.%M.%S.%f", # '25/10/2009 14.30.59.000200' + "%m/%d/%Y %H.%M", # '25/10/2009 14.30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ig/formats.py b/django/conf/locale/ig/formats.py index 61fc2c0c77..cb0b4de5ee 100644 --- a/django/conf/locale/ig/formats.py +++ b/django/conf/locale/ig/formats.py @@ -2,31 +2,31 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'P' -DATETIME_FORMAT = 'j F Y P' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "P" +DATETIME_FORMAT = "j F Y P" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y", # '25.10.06' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/is/formats.py b/django/conf/locale/is/formats.py index e6cc7d51ed..d0f71cff70 100644 --- a/django/conf/locale/is/formats.py +++ b/django/conf/locale/is/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" # DATETIME_FORMAT = -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.n.Y' +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "j.n.Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'j.n.Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/it/formats.py b/django/conf/locale/it/formats.py index 703297af03..bb9e0270bc 100644 --- a/django/conf/locale/it/formats.py +++ b/django/conf/locale/it/formats.py @@ -2,42 +2,42 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' # 25 Ottobre 2006 -TIME_FORMAT = 'H:i' # 14:30 -DATETIME_FORMAT = 'l d F Y H:i' # Mercoledì 25 Ottobre 2006 14:30 -YEAR_MONTH_FORMAT = 'F Y' # Ottobre 2006 -MONTH_DAY_FORMAT = 'j F' # 25 Ottobre -SHORT_DATE_FORMAT = 'd/m/Y' # 25/12/2009 -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' # 25/10/2009 14:30 +DATE_FORMAT = "d F Y" # 25 Ottobre 2006 +TIME_FORMAT = "H:i" # 14:30 +DATETIME_FORMAT = "l d F Y H:i" # Mercoledì 25 Ottobre 2006 14:30 +YEAR_MONTH_FORMAT = "F Y" # Ottobre 2006 +MONTH_DAY_FORMAT = "j F" # 25 Ottobre +SHORT_DATE_FORMAT = "d/m/Y" # 25/12/2009 +SHORT_DATETIME_FORMAT = "d/m/Y H:i" # 25/10/2009 14:30 FIRST_DAY_OF_WEEK = 1 # Lunedì # The *_INPUT_FORMATS strings use 
the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%Y/%m/%d', # '2006/10/25' - '%d-%m-%Y', # '25-10-2006' - '%Y-%m-%d', # '2006-10-25' - '%d-%m-%y', # '25-10-06' - '%d/%m/%y', # '25/10/06' + "%d/%m/%Y", # '25/10/2006' + "%Y/%m/%d", # '2006/10/25' + "%d-%m-%Y", # '25-10-2006' + "%Y-%m-%d", # '2006-10-25' + "%d-%m-%y", # '25-10-06' + "%d/%m/%y", # '25/10/06' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d-%m-%Y %H:%M:%S', # '25-10-2006 14:30:59' - '%d-%m-%Y %H:%M:%S.%f', # '25-10-2006 14:30:59.000200' - '%d-%m-%Y %H:%M', # '25-10-2006 14:30' - '%d-%m-%y %H:%M:%S', # '25-10-06 14:30:59' - '%d-%m-%y %H:%M:%S.%f', # '25-10-06 14:30:59.000200' - '%d-%m-%y %H:%M', # '25-10-06 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d-%m-%Y %H:%M:%S", # '25-10-2006 14:30:59' + "%d-%m-%Y %H:%M:%S.%f", # '25-10-2006 14:30:59.000200' + "%d-%m-%Y %H:%M", # '25-10-2006 14:30' + "%d-%m-%y %H:%M:%S", # '25-10-06 14:30:59' + "%d-%m-%y %H:%M:%S.%f", # '25-10-06 14:30:59.000200' + "%d-%m-%y %H:%M", # '25-10-06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' 
+DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ja/formats.py b/django/conf/locale/ja/formats.py index 2f1faa69ad..aaf5f9838f 100644 --- a/django/conf/locale/ja/formats.py +++ b/django/conf/locale/ja/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'Y年n月j日' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'Y年n月j日G:i' -YEAR_MONTH_FORMAT = 'Y年n月' -MONTH_DAY_FORMAT = 'n月j日' -SHORT_DATE_FORMAT = 'Y/m/d' -SHORT_DATETIME_FORMAT = 'Y/m/d G:i' +DATE_FORMAT = "Y年n月j日" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "Y年n月j日G:i" +YEAR_MONTH_FORMAT = "Y年n月" +MONTH_DAY_FORMAT = "n月j日" +SHORT_DATE_FORMAT = "Y/m/d" +SHORT_DATETIME_FORMAT = "Y/m/d G:i" # FIRST_DAY_OF_WEEK = # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ SHORT_DATETIME_FORMAT = 'Y/m/d G:i' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," # NUMBER_GROUPING = diff --git a/django/conf/locale/ka/formats.py b/django/conf/locale/ka/formats.py index bc75ac7407..661b71e2c5 100644 --- a/django/conf/locale/ka/formats.py +++ b/django/conf/locale/ka/formats.py @@ -2,24 +2,24 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'l, j F, Y' -TIME_FORMAT = 'h:i a' -DATETIME_FORMAT = 'j F, Y h:i a' -YEAR_MONTH_FORMAT = 'F, Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j.M.Y' -SHORT_DATETIME_FORMAT = 'j.M.Y H:i' +DATE_FORMAT = "l, j F, Y" +TIME_FORMAT = "h:i a" +DATETIME_FORMAT = "j F, Y h:i a" +YEAR_MONTH_FORMAT = "F, Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j.M.Y" +SHORT_DATETIME_FORMAT = "j.M.Y H:i" FIRST_DAY_OF_WEEK = 1 # (Monday) # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # '10/25/06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' # "%d %b %Y", # '25 Oct 2006' # "%d %b, %Y", # '25 Oct, 2006' # "%d %b. %Y", # '25 Oct. 
2006' @@ -27,22 +27,22 @@ DATE_INPUT_FORMATS = [ # "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + "%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y %H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' ] -DECIMAL_SEPARATOR = '.' +DECIMAL_SEPARATOR = "." 
THOUSAND_SEPARATOR = " " NUMBER_GROUPING = 3 diff --git a/django/conf/locale/km/formats.py b/django/conf/locale/km/formats.py index b704e9c62d..5923437476 100644 --- a/django/conf/locale/km/formats.py +++ b/django/conf/locale/km/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j ខែ F ឆ្នាំ Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j ខែ F ឆ្នាំ Y, G:i' +DATE_FORMAT = "j ខែ F ឆ្នាំ Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j ខែ F ឆ្នាំ Y, G:i" # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M Y' -SHORT_DATETIME_FORMAT = 'j M Y, G:i' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" +SHORT_DATETIME_FORMAT = "j M Y, G:i" # FIRST_DAY_OF_WEEK = # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ SHORT_DATETIME_FORMAT = 'j M Y, G:i' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
# NUMBER_GROUPING = diff --git a/django/conf/locale/kn/formats.py b/django/conf/locale/kn/formats.py index 5003c6441b..d212fd52d1 100644 --- a/django/conf/locale/kn/formats.py +++ b/django/conf/locale/kn/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'h:i A' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "h:i A" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = diff --git a/django/conf/locale/ko/formats.py b/django/conf/locale/ko/formats.py index 78f9f10f17..1f3487c6f9 100644 --- a/django/conf/locale/ko/formats.py +++ b/django/conf/locale/ko/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'Y년 n월 j일' -TIME_FORMAT = 'A g:i' -DATETIME_FORMAT = 'Y년 n월 j일 g:i A' -YEAR_MONTH_FORMAT = 'Y년 n월' -MONTH_DAY_FORMAT = 'n월 j일' -SHORT_DATE_FORMAT = 'Y-n-j.' -SHORT_DATETIME_FORMAT = 'Y-n-j H:i' +DATE_FORMAT = "Y년 n월 j일" +TIME_FORMAT = "A g:i" +DATETIME_FORMAT = "Y년 n월 j일 g:i A" +YEAR_MONTH_FORMAT = "Y년 n월" +MONTH_DAY_FORMAT = "n월 j일" +SHORT_DATE_FORMAT = "Y-n-j." 
+SHORT_DATETIME_FORMAT = "Y-n-j H:i" # FIRST_DAY_OF_WEEK = # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # '10/25/06' # "%b %d %Y", # 'Oct 25 2006' # "%b %d, %Y", # 'Oct 25, 2006' # "%d %b %Y", # '25 Oct 2006' @@ -26,30 +26,29 @@ DATE_INPUT_FORMATS = [ # "%B %d, %Y", #'October 25, 2006' # "%d %B %Y", # '25 October 2006' # "%d %B, %Y", # '25 October, 2006' - '%Y년 %m월 %d일', # '2006년 10월 25일', with localized suffix. + "%Y년 %m월 %d일", # '2006년 10월 25일', with localized suffix. ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' - '%H시 %M분 %S초', # '14시 30분 59초' - '%H시 %M분', # '14시 30분' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' + "%H시 %M분 %S초", # '14시 30분 59초' + "%H시 %M분", # '14시 30분' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' - - '%Y년 %m월 %d일 %H시 %M분 %S초', # '2006년 10월 25일 14시 30분 59초' - '%Y년 %m월 %d일 %H시 %M분', # '2006년 10월 25일 14시 30분' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + "%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y 
%H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' + "%Y년 %m월 %d일 %H시 %M분 %S초", # '2006년 10월 25일 14시 30분 59초' + "%Y년 %m월 %d일 %H시 %M분", # '2006년 10월 25일 14시 30분' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ky/formats.py b/django/conf/locale/ky/formats.py index 1dc42c41e4..25a092872a 100644 --- a/django/conf/locale/ky/formats.py +++ b/django/conf/locale/ky/formats.py @@ -2,31 +2,31 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y ж.' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j E Y ж. G:i' -YEAR_MONTH_FORMAT = 'F Y ж.' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j E Y ж." +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j E Y ж. G:i" +YEAR_MONTH_FORMAT = "F Y ж." 
+MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Дүйшөмбү, Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y", # '25.10.06' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/lt/formats.py b/django/conf/locale/lt/formats.py index 94e4469057..a351b3c240 100644 --- a/django/conf/locale/lt/formats.py +++ b/django/conf/locale/lt/formats.py @@ -2,44 +2,44 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'Y \m. E j \d.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'Y \m. E j \d., H:i' -YEAR_MONTH_FORMAT = r'Y \m. F' -MONTH_DAY_FORMAT = r'E j \d.' -SHORT_DATE_FORMAT = 'Y-m-d' -SHORT_DATETIME_FORMAT = 'Y-m-d H:i' +DATE_FORMAT = r"Y \m. E j \d." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"Y \m. E j \d., H:i" +YEAR_MONTH_FORMAT = r"Y \m. 
F" +MONTH_DAY_FORMAT = r"E j \d." +SHORT_DATE_FORMAT = "Y-m-d" +SHORT_DATETIME_FORMAT = "Y-m-d H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%Y-%m-%d", # '2006-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' - '%H.%M.%S', # '14.30.59' - '%H.%M.%S.%f', # '14.30.59.000200' - '%H.%M', # '14.30' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' + "%H.%M.%S", # '14.30.59' + "%H.%M.%S.%f", # '14.30.59.000200' + "%H.%M", # '14.30' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y %H.%M.%S', # '25.10.06 14.30.59' - '%d.%m.%y %H.%M.%S.%f', # '25.10.06 14.30.59.000200' - '%d.%m.%y %H.%M', # '25.10.06 14.30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y %H.%M.%S", # '25.10.06 14.30.59' + "%d.%m.%y %H.%M.%S.%f", # '25.10.06 14.30.59.000200' + "%d.%m.%y %H.%M", # '25.10.06 
14.30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/lv/formats.py b/django/conf/locale/lv/formats.py index 9fac516cc5..bb34444338 100644 --- a/django/conf/locale/lv/formats.py +++ b/django/conf/locale/lv/formats.py @@ -2,45 +2,45 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'Y. \g\a\d\a j. F' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'Y. \g\a\d\a j. F, H:i' -YEAR_MONTH_FORMAT = r'Y. \g. F' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = r'j.m.Y' -SHORT_DATETIME_FORMAT = 'j.m.Y H:i' +DATE_FORMAT = r"Y. \g\a\d\a j. F" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"Y. \g\a\d\a j. F, H:i" +YEAR_MONTH_FORMAT = r"Y. \g. F" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = r"j.m.Y" +SHORT_DATETIME_FORMAT = "j.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%Y-%m-%d", # '2006-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' - '%H.%M.%S', # '14.30.59' - '%H.%M.%S.%f', # '14.30.59.000200' - '%H.%M', # '14.30' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' + "%H.%M.%S", # '14.30.59' + "%H.%M.%S.%f", # '14.30.59.000200' + "%H.%M", # '14.30' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - 
'%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y %H.%M.%S', # '25.10.06 14.30.59' - '%d.%m.%y %H.%M.%S.%f', # '25.10.06 14.30.59.000200' - '%d.%m.%y %H.%M', # '25.10.06 14.30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y %H.%M.%S", # '25.10.06 14.30.59' + "%d.%m.%y %H.%M.%S.%f", # '25.10.06 14.30.59.000200' + "%d.%m.%y %H.%M", # '25.10.06 14.30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = ' ' # Non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = " " # Non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/mk/formats.py b/django/conf/locale/mk/formats.py index 12e506163d..fbb577f772 100644 --- a/django/conf/locale/mk/formats.py +++ b/django/conf/locale/mk/formats.py @@ -2,39 +2,39 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.m.Y' -SHORT_DATETIME_FORMAT = 'j.m.Y H:i' +DATE_FORMAT = "d F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. 
F" +SHORT_DATE_FORMAT = "j.m.Y" +SHORT_DATETIME_FORMAT = "j.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' - '%d. %m. %Y', # '25. 10. 2006' - '%d. %m. %y', # '25. 10. 06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' + "%d. %m. %Y", # '25. 10. 2006' + "%d. %m. %y", # '25. 10. 06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d. %m. %Y %H:%M:%S', # '25. 10. 2006 14:30:59' - '%d. %m. %Y %H:%M:%S.%f', # '25. 10. 2006 14:30:59.000200' - '%d. %m. %Y %H:%M', # '25. 10. 2006 14:30' - '%d. %m. %y %H:%M:%S', # '25. 10. 06 14:30:59' - '%d. %m. %y %H:%M:%S.%f', # '25. 10. 06 14:30:59.000200' - '%d. %m. %y %H:%M', # '25. 10. 06 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d. %m. %Y %H:%M:%S", # '25. 10. 2006 14:30:59' + "%d. %m. %Y %H:%M:%S.%f", # '25. 10. 2006 14:30:59.000200' + "%d. %m. %Y %H:%M", # '25. 10. 2006 14:30' + "%d. %m. %y %H:%M:%S", # '25. 10. 06 14:30:59' + "%d. %m. %y %H:%M:%S.%f", # '25. 10. 06 14:30:59.000200' + "%d. %m. %y %H:%M", # '25. 10. 06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ml/formats.py b/django/conf/locale/ml/formats.py index 5f8b51d51d..b1ca2ee846 100644 --- a/django/conf/locale/ml/formats.py +++ b/django/conf/locale/ml/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'N j, Y' -TIME_FORMAT = 'P' -DATETIME_FORMAT = 'N j, Y, P' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'F j' -SHORT_DATE_FORMAT = 'm/d/Y' -SHORT_DATETIME_FORMAT = 'm/d/Y P' +DATE_FORMAT = "N j, Y" +TIME_FORMAT = "P" +DATETIME_FORMAT = "N j, Y, P" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "F j" +SHORT_DATE_FORMAT = "m/d/Y" +SHORT_DATETIME_FORMAT = "m/d/Y P" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # '10/25/06' # "%b %d %Y", # 'Oct 25 2006' # "%b %d, %Y", # 'Oct 25, 2006' # "%d %b %Y", # '25 Oct 2006' @@ -28,16 +28,16 @@ DATE_INPUT_FORMATS = [ # "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + 
"%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y %H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/mn/formats.py b/django/conf/locale/mn/formats.py index 24c7dec8a7..589c24cf66 100644 --- a/django/conf/locale/mn/formats.py +++ b/django/conf/locale/mn/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "d F Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = # MONTH_DAY_FORMAT = -SHORT_DATE_FORMAT = 'j M Y' +SHORT_DATE_FORMAT = "j M Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = diff --git a/django/conf/locale/ms/formats.py b/django/conf/locale/ms/formats.py index 8f46fdc144..d06719fee3 100644 --- a/django/conf/locale/ms/formats.py +++ b/django/conf/locale/ms/formats.py @@ -2,37 +2,37 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j M Y' # '25 Oct 2006' -TIME_FORMAT = 'P' # '2:30 p.m.' -DATETIME_FORMAT = 'j M Y, P' # '25 Oct 2006, 2:30 p.m.' -YEAR_MONTH_FORMAT = 'F Y' # 'October 2006' -MONTH_DAY_FORMAT = 'j F' # '25 October' -SHORT_DATE_FORMAT = 'd/m/Y' # '25/10/2006' -SHORT_DATETIME_FORMAT = 'd/m/Y P' # '25/10/2006 2:30 p.m.' -FIRST_DAY_OF_WEEK = 0 # Sunday +DATE_FORMAT = "j M Y" # '25 Oct 2006' +TIME_FORMAT = "P" # '2:30 p.m.' +DATETIME_FORMAT = "j M Y, P" # '25 Oct 2006, 2:30 p.m.' +YEAR_MONTH_FORMAT = "F Y" # 'October 2006' +MONTH_DAY_FORMAT = "j F" # '25 October' +SHORT_DATE_FORMAT = "d/m/Y" # '25/10/2006' +SHORT_DATETIME_FORMAT = "d/m/Y P" # '25/10/2006 2:30 p.m.' 
+FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%d %b %Y', # '25 Oct 2006' - '%d %b, %Y', # '25 Oct, 2006' - '%d %B %Y', # '25 October 2006' - '%d %B, %Y', # '25 October, 2006' + "%Y-%m-%d", # '2006-10-25' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%d %b %Y", # '25 Oct 2006' + "%d %b, %Y", # '25 Oct, 2006' + "%d %B %Y", # '25 October 2006' + "%d %B, %Y", # '25 October, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/nb/formats.py b/django/conf/locale/nb/formats.py index c34cb778ea..0ddb8fef60 100644 --- a/django/conf/locale/nb/formats.py +++ b/django/conf/locale/nb/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%Y-%m-%d", # '2006-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' # "%d. %b %Y", # '25. okt 2006' # "%d %b %Y", # '25 okt 2006' # "%d. %b. %Y", # '25. okt. 
2006' @@ -26,16 +26,16 @@ DATE_INPUT_FORMATS = [ # "%d %B %Y", # '25 oktober 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/nl/formats.py b/django/conf/locale/nl/formats.py index fbe937c3f4..e9f52b9bd3 100644 --- a/django/conf/locale/nl/formats.py +++ b/django/conf/locale/nl/formats.py @@ -2,23 +2,23 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' # '20 januari 2009' -TIME_FORMAT = 'H:i' # '15:23' -DATETIME_FORMAT = 'j F Y H:i' # '20 januari 2009 15:23' -YEAR_MONTH_FORMAT = 'F Y' # 'januari 2009' -MONTH_DAY_FORMAT = 'j F' # '20 januari' -SHORT_DATE_FORMAT = 'j-n-Y' # '20-1-2009' -SHORT_DATETIME_FORMAT = 'j-n-Y H:i' # '20-1-2009 15:23' -FIRST_DAY_OF_WEEK = 1 # Monday (in Dutch 'maandag') +DATE_FORMAT = "j F Y" # '20 januari 2009' +TIME_FORMAT = "H:i" # '15:23' +DATETIME_FORMAT = "j F Y H:i" # '20 januari 2009 15:23' 
+YEAR_MONTH_FORMAT = "F Y" # 'januari 2009' +MONTH_DAY_FORMAT = "j F" # '20 januari' +SHORT_DATE_FORMAT = "j-n-Y" # '20-1-2009' +SHORT_DATETIME_FORMAT = "j-n-Y H:i" # '20-1-2009 15:23' +FIRST_DAY_OF_WEEK = 1 # Monday (in Dutch 'maandag') # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d-%m-%Y', # '20-01-2009' - '%d-%m-%y', # '20-01-09' - '%d/%m/%Y', # '20/01/2009' - '%d/%m/%y', # '20/01/09' - '%Y/%m/%d', # '2009/01/20' + "%d-%m-%Y", # '20-01-2009' + "%d-%m-%y", # '20-01-09' + "%d/%m/%Y", # '20/01/2009' + "%d/%m/%y", # '20/01/09' + "%Y/%m/%d", # '2009/01/20' # "%d %b %Y", # '20 jan 2009' # "%d %b %y", # '20 jan 09' # "%d %B %Y", # '20 januari 2009' @@ -26,67 +26,67 @@ DATE_INPUT_FORMATS = [ ] # Kept ISO formats as one is in first position TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '15:23:35' - '%H:%M:%S.%f', # '15:23:35.000200' - '%H.%M:%S', # '15.23:35' - '%H.%M:%S.%f', # '15.23:35.000200' - '%H.%M', # '15.23' - '%H:%M', # '15:23' + "%H:%M:%S", # '15:23:35' + "%H:%M:%S.%f", # '15:23:35.000200' + "%H.%M:%S", # '15.23:35' + "%H.%M:%S.%f", # '15.23:35.000200' + "%H.%M", # '15.23' + "%H:%M", # '15:23' ] DATETIME_INPUT_FORMATS = [ # With time in %H:%M:%S : - '%d-%m-%Y %H:%M:%S', # '20-01-2009 15:23:35' - '%d-%m-%y %H:%M:%S', # '20-01-09 15:23:35' - '%Y-%m-%d %H:%M:%S', # '2009-01-20 15:23:35' - '%d/%m/%Y %H:%M:%S', # '20/01/2009 15:23:35' - '%d/%m/%y %H:%M:%S', # '20/01/09 15:23:35' - '%Y/%m/%d %H:%M:%S', # '2009/01/20 15:23:35' + "%d-%m-%Y %H:%M:%S", # '20-01-2009 15:23:35' + "%d-%m-%y %H:%M:%S", # '20-01-09 15:23:35' + "%Y-%m-%d %H:%M:%S", # '2009-01-20 15:23:35' + "%d/%m/%Y %H:%M:%S", # '20/01/2009 15:23:35' + "%d/%m/%y %H:%M:%S", # '20/01/09 15:23:35' + "%Y/%m/%d %H:%M:%S", # '2009/01/20 15:23:35' # "%d %b %Y %H:%M:%S", # '20 jan 2009 15:23:35' # "%d %b %y %H:%M:%S", # '20 jan 09 15:23:35' # "%d %B %Y %H:%M:%S", # '20 januari 2009 15:23:35' # 
"%d %B %y %H:%M:%S", # '20 januari 2009 15:23:35' # With time in %H:%M:%S.%f : - '%d-%m-%Y %H:%M:%S.%f', # '20-01-2009 15:23:35.000200' - '%d-%m-%y %H:%M:%S.%f', # '20-01-09 15:23:35.000200' - '%Y-%m-%d %H:%M:%S.%f', # '2009-01-20 15:23:35.000200' - '%d/%m/%Y %H:%M:%S.%f', # '20/01/2009 15:23:35.000200' - '%d/%m/%y %H:%M:%S.%f', # '20/01/09 15:23:35.000200' - '%Y/%m/%d %H:%M:%S.%f', # '2009/01/20 15:23:35.000200' + "%d-%m-%Y %H:%M:%S.%f", # '20-01-2009 15:23:35.000200' + "%d-%m-%y %H:%M:%S.%f", # '20-01-09 15:23:35.000200' + "%Y-%m-%d %H:%M:%S.%f", # '2009-01-20 15:23:35.000200' + "%d/%m/%Y %H:%M:%S.%f", # '20/01/2009 15:23:35.000200' + "%d/%m/%y %H:%M:%S.%f", # '20/01/09 15:23:35.000200' + "%Y/%m/%d %H:%M:%S.%f", # '2009/01/20 15:23:35.000200' # With time in %H.%M:%S : - '%d-%m-%Y %H.%M:%S', # '20-01-2009 15.23:35' - '%d-%m-%y %H.%M:%S', # '20-01-09 15.23:35' - '%d/%m/%Y %H.%M:%S', # '20/01/2009 15.23:35' - '%d/%m/%y %H.%M:%S', # '20/01/09 15.23:35' + "%d-%m-%Y %H.%M:%S", # '20-01-2009 15.23:35' + "%d-%m-%y %H.%M:%S", # '20-01-09 15.23:35' + "%d/%m/%Y %H.%M:%S", # '20/01/2009 15.23:35' + "%d/%m/%y %H.%M:%S", # '20/01/09 15.23:35' # "%d %b %Y %H.%M:%S", # '20 jan 2009 15.23:35' # "%d %b %y %H.%M:%S", # '20 jan 09 15.23:35' # "%d %B %Y %H.%M:%S", # '20 januari 2009 15.23:35' # "%d %B %y %H.%M:%S", # '20 januari 2009 15.23:35' # With time in %H.%M:%S.%f : - '%d-%m-%Y %H.%M:%S.%f', # '20-01-2009 15.23:35.000200' - '%d-%m-%y %H.%M:%S.%f', # '20-01-09 15.23:35.000200' - '%d/%m/%Y %H.%M:%S.%f', # '20/01/2009 15.23:35.000200' - '%d/%m/%y %H.%M:%S.%f', # '20/01/09 15.23:35.000200' + "%d-%m-%Y %H.%M:%S.%f", # '20-01-2009 15.23:35.000200' + "%d-%m-%y %H.%M:%S.%f", # '20-01-09 15.23:35.000200' + "%d/%m/%Y %H.%M:%S.%f", # '20/01/2009 15.23:35.000200' + "%d/%m/%y %H.%M:%S.%f", # '20/01/09 15.23:35.000200' # With time in %H:%M : - '%d-%m-%Y %H:%M', # '20-01-2009 15:23' - '%d-%m-%y %H:%M', # '20-01-09 15:23' - '%Y-%m-%d %H:%M', # '2009-01-20 15:23' - '%d/%m/%Y %H:%M', # 
'20/01/2009 15:23' - '%d/%m/%y %H:%M', # '20/01/09 15:23' - '%Y/%m/%d %H:%M', # '2009/01/20 15:23' + "%d-%m-%Y %H:%M", # '20-01-2009 15:23' + "%d-%m-%y %H:%M", # '20-01-09 15:23' + "%Y-%m-%d %H:%M", # '2009-01-20 15:23' + "%d/%m/%Y %H:%M", # '20/01/2009 15:23' + "%d/%m/%y %H:%M", # '20/01/09 15:23' + "%Y/%m/%d %H:%M", # '2009/01/20 15:23' # "%d %b %Y %H:%M", # '20 jan 2009 15:23' # "%d %b %y %H:%M", # '20 jan 09 15:23' # "%d %B %Y %H:%M", # '20 januari 2009 15:23' # "%d %B %y %H:%M", # '20 januari 2009 15:23' # With time in %H.%M : - '%d-%m-%Y %H.%M', # '20-01-2009 15.23' - '%d-%m-%y %H.%M', # '20-01-09 15.23' - '%d/%m/%Y %H.%M', # '20/01/2009 15.23' - '%d/%m/%y %H.%M', # '20/01/09 15.23' + "%d-%m-%Y %H.%M", # '20-01-2009 15.23' + "%d-%m-%y %H.%M", # '20-01-09 15.23' + "%d/%m/%Y %H.%M", # '20/01/2009 15.23' + "%d/%m/%y %H.%M", # '20/01/09 15.23' # "%d %b %Y %H.%M", # '20 jan 2009 15.23' # "%d %b %y %H.%M", # '20 jan 09 15.23' # "%d %B %Y %H.%M", # '20 januari 2009 15.23' # "%d %B %y %H.%M", # '20 januari 2009 15.23' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/nn/formats.py b/django/conf/locale/nn/formats.py index c34cb778ea..0ddb8fef60 100644 --- a/django/conf/locale/nn/formats.py +++ b/django/conf/locale/nn/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. 
F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%Y-%m-%d", # '2006-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' # "%d. %b %Y", # '25. okt 2006' # "%d %b %Y", # '25 okt 2006' # "%d. %b. %Y", # '25. okt. 2006' @@ -26,16 +26,16 @@ DATE_INPUT_FORMATS = [ # "%d %B %Y", # '25 oktober 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/pl/formats.py b/django/conf/locale/pl/formats.py index 7c9e30ad1f..2ad1bfee42 100644 --- a/django/conf/locale/pl/formats.py +++ b/django/conf/locale/pl/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # 
see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j E Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j E' -SHORT_DATE_FORMAT = 'd-m-Y' -SHORT_DATETIME_FORMAT = 'd-m-Y H:i' +DATE_FORMAT = "j E Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j E Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j E" +SHORT_DATE_FORMAT = "d-m-Y" +SHORT_DATETIME_FORMAT = "d-m-Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' - '%y-%m-%d', # '06-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' + "%y-%m-%d", # '06-10-25' # "%d. %B %Y", # '25. października 2006' # "%d. %b. %Y", # '25. paź. 2006' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = ' ' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = " " NUMBER_GROUPING = 3 diff --git a/django/conf/locale/pt/formats.py b/django/conf/locale/pt/formats.py index 39e06c7002..bb4b3f50fb 100644 --- a/django/conf/locale/pt/formats.py +++ b/django/conf/locale/pt/formats.py @@ -2,38 +2,38 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y à\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = 
r"j \d\e F \d\e Y à\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' + "%Y-%m-%d", # '2006-10-25' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' # "%d de %b de %Y", # '25 de Out de 2006' # "%d de %b, %Y", # '25 Out, 2006' # "%d de %B de %Y", # '25 de Outubro de 2006' # "%d de %B, %Y", # '25 de Outubro, 2006' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." 
NUMBER_GROUPING = 3 diff --git a/django/conf/locale/pt_BR/formats.py b/django/conf/locale/pt_BR/formats.py index 9f3244955d..96a49b48c7 100644 --- a/django/conf/locale/pt_BR/formats.py +++ b/django/conf/locale/pt_BR/formats.py @@ -2,33 +2,33 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j \d\e F \d\e Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'j \d\e F \d\e Y à\s H:i' -YEAR_MONTH_FORMAT = r'F \d\e Y' -MONTH_DAY_FORMAT = r'j \d\e F' -SHORT_DATE_FORMAT = 'd/m/Y' -SHORT_DATETIME_FORMAT = 'd/m/Y H:i' +DATE_FORMAT = r"j \d\e F \d\e Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"j \d\e F \d\e Y à\s H:i" +YEAR_MONTH_FORMAT = r"F \d\e Y" +MONTH_DAY_FORMAT = r"j \d\e F" +SHORT_DATE_FORMAT = "d/m/Y" +SHORT_DATETIME_FORMAT = "d/m/Y H:i" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' # "%d de %b de %Y", # '24 de Out de 2006' # "%d de %b, %Y", # '25 Out, 2006' # "%d de %B de %Y", # '25 de Outubro de 2006' # "%d de %B, %Y", # '25 de Outubro, 2006' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' - '%d/%m/%y %H:%M:%S', # '25/10/06 14:30:59' - '%d/%m/%y %H:%M:%S.%f', # '25/10/06 14:30:59.000200' - '%d/%m/%y %H:%M', # '25/10/06 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' + "%d/%m/%y %H:%M:%S", # '25/10/06 14:30:59' + "%d/%m/%y %H:%M:%S.%f", # '25/10/06 14:30:59.000200' + "%d/%m/%y %H:%M", # '25/10/06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' 
+DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ro/formats.py b/django/conf/locale/ro/formats.py index 8cefeb8395..5a0c173f0b 100644 --- a/django/conf/locale/ro/formats.py +++ b/django/conf/locale/ro/formats.py @@ -2,34 +2,34 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j F Y, H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y, H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j F Y, H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y, H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', - '%d.%b.%Y', - '%d %B %Y', - '%A, %d %B %Y', + "%d.%m.%Y", + "%d.%b.%Y", + "%d %B %Y", + "%A, %d %B %Y", ] TIME_INPUT_FORMATS = [ - '%H:%M', - '%H:%M:%S', - '%H:%M:%S.%f', + "%H:%M", + "%H:%M:%S", + "%H:%M:%S.%f", ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y, %H:%M', - '%d.%m.%Y, %H:%M:%S', - '%d.%B.%Y, %H:%M', - '%d.%B.%Y, %H:%M:%S', + "%d.%m.%Y, %H:%M", + "%d.%m.%Y, %H:%M:%S", + "%d.%B.%Y, %H:%M", + "%d.%B.%Y, %H:%M:%S", ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ru/formats.py b/django/conf/locale/ru/formats.py index c601c3e51a..212e5267d0 100644 --- a/django/conf/locale/ru/formats.py +++ b/django/conf/locale/ru/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y г.' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j E Y г. 
G:i' -YEAR_MONTH_FORMAT = 'F Y г.' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j E Y г." +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j E Y г. G:i" +YEAR_MONTH_FORMAT = "F Y г." +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/sk/formats.py b/django/conf/locale/sk/formats.py index 56e710dd44..31d4912256 100644 --- a/django/conf/locale/sk/formats.py +++ b/django/conf/locale/sk/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j. F Y G:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y G:i' +DATE_FORMAT = "j. 
F Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j. F Y G:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y G:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' - '%y-%m-%d', # '06-10-25' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' + "%y-%m-%d", # '06-10-25' # "%d. %B %Y", # '25. October 2006' # "%d. %b. %Y", # '25. Oct. 2006' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/sl/formats.py b/django/conf/locale/sl/formats.py index ab6d46d9f1..c3e96bb2fb 100644 --- a/django/conf/locale/sl/formats.py +++ b/django/conf/locale/sl/formats.py @@ -2,43 +2,43 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd. F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y. H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j. M. Y' -SHORT_DATETIME_FORMAT = 'j.n.Y. H:i' +DATE_FORMAT = "d. F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y. H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "j. M. Y" +SHORT_DATETIME_FORMAT = "j.n.Y. 
H:i" FIRST_DAY_OF_WEEK = 0 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' - '%d-%m-%Y', # '25-10-2006' - '%d. %m. %Y', # '25. 10. 2006' - '%d. %m. %y', # '25. 10. 06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' + "%d-%m-%Y", # '25-10-2006' + "%d. %m. %Y", # '25. 10. 2006' + "%d. %m. %y", # '25. 10. 06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d-%m-%Y %H:%M:%S', # '25-10-2006 14:30:59' - '%d-%m-%Y %H:%M:%S.%f', # '25-10-2006 14:30:59.000200' - '%d-%m-%Y %H:%M', # '25-10-2006 14:30' - '%d. %m. %Y %H:%M:%S', # '25. 10. 2006 14:30:59' - '%d. %m. %Y %H:%M:%S.%f', # '25. 10. 2006 14:30:59.000200' - '%d. %m. %Y %H:%M', # '25. 10. 2006 14:30' - '%d. %m. %y %H:%M:%S', # '25. 10. 06 14:30:59' - '%d. %m. %y %H:%M:%S.%f', # '25. 10. 06 14:30:59.000200' - '%d. %m. %y %H:%M', # '25. 10. 06 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d-%m-%Y %H:%M:%S", # '25-10-2006 14:30:59' + "%d-%m-%Y %H:%M:%S.%f", # '25-10-2006 14:30:59.000200' + "%d-%m-%Y %H:%M", # '25-10-2006 14:30' + "%d. %m. %Y %H:%M:%S", # '25. 10. 2006 14:30:59' + "%d. %m. %Y %H:%M:%S.%f", # '25. 10. 2006 14:30:59.000200' + "%d. %m. %Y %H:%M", # '25. 10. 2006 14:30' + "%d. %m. %y %H:%M:%S", # '25. 10. 06 14:30:59' + "%d. %m. %y %H:%M:%S.%f", # '25. 10. 06 14:30:59.000200' + "%d. %m. %y %H:%M", # '25. 10. 
06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/sq/formats.py b/django/conf/locale/sq/formats.py index 2f0da0d400..c7ed92e12f 100644 --- a/django/conf/locale/sq/formats.py +++ b/django/conf/locale/sq/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' -TIME_FORMAT = 'g.i.A' +DATE_FORMAT = "d F Y" +TIME_FORMAT = "g.i.A" # DATETIME_FORMAT = -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'Y-m-d' +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "Y-m-d" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = @@ -16,6 +16,6 @@ SHORT_DATE_FORMAT = 'Y-m-d' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." # NUMBER_GROUPING = diff --git a/django/conf/locale/sr/formats.py b/django/conf/locale/sr/formats.py index 624571cd95..423f86d75c 100644 --- a/django/conf/locale/sr/formats.py +++ b/django/conf/locale/sr/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y. H:i' -YEAR_MONTH_FORMAT = 'F Y.' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.m.Y.' -SHORT_DATETIME_FORMAT = 'j.m.Y. H:i' +DATE_FORMAT = "j. F Y." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y. H:i" +YEAR_MONTH_FORMAT = "F Y." +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "j.m.Y." +SHORT_DATETIME_FORMAT = "j.m.Y. 
H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y.', # '25.10.2006.' - '%d.%m.%y.', # '25.10.06.' - '%d. %m. %Y.', # '25. 10. 2006.' - '%d. %m. %y.', # '25. 10. 06.' + "%d.%m.%Y.", # '25.10.2006.' + "%d.%m.%y.", # '25.10.06.' + "%d. %m. %Y.", # '25. 10. 2006.' + "%d. %m. %y.", # '25. 10. 06.' # "%d. %b %y.", # '25. Oct 06.' # "%d. %B %y.", # '25. October 06.' # "%d. %b '%y.", # '25. Oct '06.' @@ -26,19 +26,19 @@ DATE_INPUT_FORMATS = [ # "%d. %B %Y.", # '25. October 2006.' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y. %H:%M:%S', # '25.10.2006. 14:30:59' - '%d.%m.%Y. %H:%M:%S.%f', # '25.10.2006. 14:30:59.000200' - '%d.%m.%Y. %H:%M', # '25.10.2006. 14:30' - '%d.%m.%y. %H:%M:%S', # '25.10.06. 14:30:59' - '%d.%m.%y. %H:%M:%S.%f', # '25.10.06. 14:30:59.000200' - '%d.%m.%y. %H:%M', # '25.10.06. 14:30' - '%d. %m. %Y. %H:%M:%S', # '25. 10. 2006. 14:30:59' - '%d. %m. %Y. %H:%M:%S.%f', # '25. 10. 2006. 14:30:59.000200' - '%d. %m. %Y. %H:%M', # '25. 10. 2006. 14:30' - '%d. %m. %y. %H:%M:%S', # '25. 10. 06. 14:30:59' - '%d. %m. %y. %H:%M:%S.%f', # '25. 10. 06. 14:30:59.000200' - '%d. %m. %y. %H:%M', # '25. 10. 06. 14:30' + "%d.%m.%Y. %H:%M:%S", # '25.10.2006. 14:30:59' + "%d.%m.%Y. %H:%M:%S.%f", # '25.10.2006. 14:30:59.000200' + "%d.%m.%Y. %H:%M", # '25.10.2006. 14:30' + "%d.%m.%y. %H:%M:%S", # '25.10.06. 14:30:59' + "%d.%m.%y. %H:%M:%S.%f", # '25.10.06. 14:30:59.000200' + "%d.%m.%y. %H:%M", # '25.10.06. 14:30' + "%d. %m. %Y. %H:%M:%S", # '25. 10. 2006. 14:30:59' + "%d. %m. %Y. %H:%M:%S.%f", # '25. 10. 2006. 14:30:59.000200' + "%d. %m. %Y. %H:%M", # '25. 10. 2006. 14:30' + "%d. %m. %y. %H:%M:%S", # '25. 10. 06. 14:30:59' + "%d. %m. %y. %H:%M:%S.%f", # '25. 10. 06. 14:30:59.000200' + "%d. %m. %y. %H:%M", # '25. 10. 06. 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' 
+DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/sr_Latn/formats.py b/django/conf/locale/sr_Latn/formats.py index 5c28783825..0078895923 100644 --- a/django/conf/locale/sr_Latn/formats.py +++ b/django/conf/locale/sr_Latn/formats.py @@ -2,22 +2,22 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j. F Y.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j. F Y. H:i' -YEAR_MONTH_FORMAT = 'F Y.' -MONTH_DAY_FORMAT = 'j. F' -SHORT_DATE_FORMAT = 'j.m.Y.' -SHORT_DATETIME_FORMAT = 'j.m.Y. H:i' +DATE_FORMAT = "j. F Y." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j. F Y. H:i" +YEAR_MONTH_FORMAT = "F Y." +MONTH_DAY_FORMAT = "j. F" +SHORT_DATE_FORMAT = "j.m.Y." +SHORT_DATETIME_FORMAT = "j.m.Y. H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y.', # '25.10.2006.' - '%d.%m.%y.', # '25.10.06.' - '%d. %m. %Y.', # '25. 10. 2006.' - '%d. %m. %y.', # '25. 10. 06.' + "%d.%m.%Y.", # '25.10.2006.' + "%d.%m.%y.", # '25.10.06.' + "%d. %m. %Y.", # '25. 10. 2006.' + "%d. %m. %y.", # '25. 10. 06.' # "%d. %b %y.", # '25. Oct 06.' # "%d. %B %y.", # '25. October 06.' # "%d. %b '%y.", # '25. Oct '06.' @@ -26,19 +26,19 @@ DATE_INPUT_FORMATS = [ # "%d. %B %Y.", # '25. October 2006.' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y. %H:%M:%S', # '25.10.2006. 14:30:59' - '%d.%m.%Y. %H:%M:%S.%f', # '25.10.2006. 14:30:59.000200' - '%d.%m.%Y. %H:%M', # '25.10.2006. 14:30' - '%d.%m.%y. %H:%M:%S', # '25.10.06. 14:30:59' - '%d.%m.%y. %H:%M:%S.%f', # '25.10.06. 14:30:59.000200' - '%d.%m.%y. %H:%M', # '25.10.06. 14:30' - '%d. %m. %Y. %H:%M:%S', # '25. 10. 2006. 14:30:59' - '%d. %m. %Y. %H:%M:%S.%f', # '25. 10. 2006. 14:30:59.000200' - '%d. %m. %Y. %H:%M', # '25. 10. 2006. 14:30' - '%d. %m. %y. %H:%M:%S', # '25. 10. 
06. 14:30:59' - '%d. %m. %y. %H:%M:%S.%f', # '25. 10. 06. 14:30:59.000200' - '%d. %m. %y. %H:%M', # '25. 10. 06. 14:30' + "%d.%m.%Y. %H:%M:%S", # '25.10.2006. 14:30:59' + "%d.%m.%Y. %H:%M:%S.%f", # '25.10.2006. 14:30:59.000200' + "%d.%m.%Y. %H:%M", # '25.10.2006. 14:30' + "%d.%m.%y. %H:%M:%S", # '25.10.06. 14:30:59' + "%d.%m.%y. %H:%M:%S.%f", # '25.10.06. 14:30:59.000200' + "%d.%m.%y. %H:%M", # '25.10.06. 14:30' + "%d. %m. %Y. %H:%M:%S", # '25. 10. 2006. 14:30:59' + "%d. %m. %Y. %H:%M:%S.%f", # '25. 10. 2006. 14:30:59.000200' + "%d. %m. %Y. %H:%M", # '25. 10. 2006. 14:30' + "%d. %m. %y. %H:%M:%S", # '25. 10. 06. 14:30:59' + "%d. %m. %y. %H:%M:%S.%f", # '25. 10. 06. 14:30:59.000200' + "%d. %m. %y. %H:%M", # '25. 10. 06. 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/sv/formats.py b/django/conf/locale/sv/formats.py index 9467526893..29e6317392 100644 --- a/django/conf/locale/sv/formats.py +++ b/django/conf/locale/sv/formats.py @@ -2,34 +2,34 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'j F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'Y-m-d' -SHORT_DATETIME_FORMAT = 'Y-m-d H:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "j F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "Y-m-d" +SHORT_DATETIME_FORMAT = "Y-m-d H:i" FIRST_DAY_OF_WEEK = 1 # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior # Kept ISO formats as they are in first position DATE_INPUT_FORMATS = [ - '%Y-%m-%d', # '2006-10-25' - '%m/%d/%Y', # '10/25/2006' - '%m/%d/%y', # '10/25/06' + "%Y-%m-%d", # '2006-10-25' + "%m/%d/%Y", # '10/25/2006' + "%m/%d/%y", # 
'10/25/06' ] DATETIME_INPUT_FORMATS = [ - '%Y-%m-%d %H:%M:%S', # '2006-10-25 14:30:59' - '%Y-%m-%d %H:%M:%S.%f', # '2006-10-25 14:30:59.000200' - '%Y-%m-%d %H:%M', # '2006-10-25 14:30' - '%m/%d/%Y %H:%M:%S', # '10/25/2006 14:30:59' - '%m/%d/%Y %H:%M:%S.%f', # '10/25/2006 14:30:59.000200' - '%m/%d/%Y %H:%M', # '10/25/2006 14:30' - '%m/%d/%y %H:%M:%S', # '10/25/06 14:30:59' - '%m/%d/%y %H:%M:%S.%f', # '10/25/06 14:30:59.000200' - '%m/%d/%y %H:%M', # '10/25/06 14:30' + "%Y-%m-%d %H:%M:%S", # '2006-10-25 14:30:59' + "%Y-%m-%d %H:%M:%S.%f", # '2006-10-25 14:30:59.000200' + "%Y-%m-%d %H:%M", # '2006-10-25 14:30' + "%m/%d/%Y %H:%M:%S", # '10/25/2006 14:30:59' + "%m/%d/%Y %H:%M:%S.%f", # '10/25/2006 14:30:59.000200' + "%m/%d/%Y %H:%M", # '10/25/2006 14:30' + "%m/%d/%y %H:%M:%S", # '10/25/06 14:30:59' + "%m/%d/%y %H:%M:%S.%f", # '10/25/06 14:30:59.000200' + "%m/%d/%y %H:%M", # '10/25/06 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/ta/formats.py b/django/conf/locale/ta/formats.py index 61810e3fa7..d023608ca2 100644 --- a/django/conf/locale/ta/formats.py +++ b/django/conf/locale/ta/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F, Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "j F, Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M, Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M, Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = diff --git a/django/conf/locale/te/formats.py b/django/conf/locale/te/formats.py index 8fb98cf720..bb7f2d13d2 100644 --- a/django/conf/locale/te/formats.py +++ b/django/conf/locale/te/formats.py @@ -2,12 +2,12 @@ # # The *_FORMAT strings use the Django date format syntax, # see 
https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'g:i A' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "g:i A" # DATETIME_FORMAT = # YEAR_MONTH_FORMAT = -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M Y' +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" # SHORT_DATETIME_FORMAT = # FIRST_DAY_OF_WEEK = diff --git a/django/conf/locale/tg/formats.py b/django/conf/locale/tg/formats.py index 3e7651d755..0ab7d49ae5 100644 --- a/django/conf/locale/tg/formats.py +++ b/django/conf/locale/tg/formats.py @@ -2,31 +2,31 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y г.' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j E Y г. G:i' -YEAR_MONTH_FORMAT = 'F Y г.' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j E Y г." +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j E Y г. G:i" +YEAR_MONTH_FORMAT = "F Y г." 
+MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y", # '25.10.06' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/th/formats.py b/django/conf/locale/th/formats.py index d7394eb69c..190e6d196c 100644 --- a/django/conf/locale/th/formats.py +++ b/django/conf/locale/th/formats.py @@ -2,32 +2,32 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j F Y' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j F Y, G:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'j M Y' -SHORT_DATETIME_FORMAT = 'j M Y, G:i' +DATE_FORMAT = "j F Y" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j F Y, G:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "j M Y" 
+SHORT_DATETIME_FORMAT = "j M Y, G:i" FIRST_DAY_OF_WEEK = 0 # Sunday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # 25/10/2006 - '%d %b %Y', # 25 ต.ค. 2006 - '%d %B %Y', # 25 ตุลาคม 2006 + "%d/%m/%Y", # 25/10/2006 + "%d %b %Y", # 25 ต.ค. 2006 + "%d %B %Y", # 25 ตุลาคม 2006 ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # 14:30:59 - '%H:%M:%S.%f', # 14:30:59.000200 - '%H:%M', # 14:30 + "%H:%M:%S", # 14:30:59 + "%H:%M:%S.%f", # 14:30:59.000200 + "%H:%M", # 14:30 ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # 25/10/2006 14:30:59 - '%d/%m/%Y %H:%M:%S.%f', # 25/10/2006 14:30:59.000200 - '%d/%m/%Y %H:%M', # 25/10/2006 14:30 + "%d/%m/%Y %H:%M:%S", # 25/10/2006 14:30:59 + "%d/%m/%Y %H:%M:%S.%f", # 25/10/2006 14:30:59.000200 + "%d/%m/%Y %H:%M", # 25/10/2006 14:30 ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = ',' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "," NUMBER_GROUPING = 3 diff --git a/django/conf/locale/tk/formats.py b/django/conf/locale/tk/formats.py index 3e7651d755..0ab7d49ae5 100644 --- a/django/conf/locale/tk/formats.py +++ b/django/conf/locale/tk/formats.py @@ -2,31 +2,31 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'j E Y г.' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = 'j E Y г. G:i' -YEAR_MONTH_FORMAT = 'F Y г.' -MONTH_DAY_FORMAT = 'j F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "j E Y г." +TIME_FORMAT = "G:i" +DATETIME_FORMAT = "j E Y г. G:i" +YEAR_MONTH_FORMAT = "F Y г." 
+MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y", # '25.10.06' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d.%m.%Y', # '25.10.2006' - '%d.%m.%y %H:%M:%S', # '25.10.06 14:30:59' - '%d.%m.%y %H:%M:%S.%f', # '25.10.06 14:30:59.000200' - '%d.%m.%y %H:%M', # '25.10.06 14:30' - '%d.%m.%y', # '25.10.06' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d.%m.%Y", # '25.10.2006' + "%d.%m.%y %H:%M:%S", # '25.10.06 14:30:59' + "%d.%m.%y %H:%M:%S.%f", # '25.10.06 14:30:59.000200' + "%d.%m.%y %H:%M", # '25.10.06 14:30' + "%d.%m.%y", # '25.10.06' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/tr/formats.py b/django/conf/locale/tr/formats.py index 74bdab729b..806f4428d1 100644 --- a/django/conf/locale/tr/formats.py +++ b/django/conf/locale/tr/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd F Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'd F Y H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'd F' -SHORT_DATE_FORMAT = 'd M Y' -SHORT_DATETIME_FORMAT = 'd M Y H:i' +DATE_FORMAT = "d F Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "d F Y H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "d F" +SHORT_DATE_FORMAT = "d M Y" 
+SHORT_DATETIME_FORMAT = "d M Y H:i" FIRST_DAY_OF_WEEK = 1 # Pazartesi # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d/%m/%Y', # '25/10/2006' - '%d/%m/%y', # '25/10/06' - '%y-%m-%d', # '06-10-25' + "%d/%m/%Y", # '25/10/2006' + "%d/%m/%y", # '25/10/06' + "%y-%m-%d", # '06-10-25' # "%d %B %Y", # '25 Ekim 2006' # "%d %b. %Y", # '25 Eki. 2006' ] DATETIME_INPUT_FORMATS = [ - '%d/%m/%Y %H:%M:%S', # '25/10/2006 14:30:59' - '%d/%m/%Y %H:%M:%S.%f', # '25/10/2006 14:30:59.000200' - '%d/%m/%Y %H:%M', # '25/10/2006 14:30' + "%d/%m/%Y %H:%M:%S", # '25/10/2006 14:30:59' + "%d/%m/%Y %H:%M:%S.%f", # '25/10/2006 14:30:59.000200' + "%d/%m/%Y %H:%M", # '25/10/2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." NUMBER_GROUPING = 3 diff --git a/django/conf/locale/uk/formats.py b/django/conf/locale/uk/formats.py index ca2593beba..0f28831af5 100644 --- a/django/conf/locale/uk/formats.py +++ b/django/conf/locale/uk/formats.py @@ -2,34 +2,34 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'd E Y р.' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = 'd E Y р. H:i' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'd F' -SHORT_DATE_FORMAT = 'd.m.Y' -SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = "d E Y р." +TIME_FORMAT = "H:i" +DATETIME_FORMAT = "d E Y р. 
H:i" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "d F" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d %B %Y', # '25 October 2006' + "%d.%m.%Y", # '25.10.2006' + "%d %B %Y", # '25 October 2006' ] TIME_INPUT_FORMATS = [ - '%H:%M:%S', # '14:30:59' - '%H:%M:%S.%f', # '14:30:59.000200' - '%H:%M', # '14:30' + "%H:%M:%S", # '14:30:59' + "%H:%M:%S.%f", # '14:30:59.000200' + "%H:%M", # '14:30' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d %B %Y %H:%M:%S', # '25 October 2006 14:30:59' - '%d %B %Y %H:%M:%S.%f', # '25 October 2006 14:30:59.000200' - '%d %B %Y %H:%M', # '25 October 2006 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d %B %Y %H:%M:%S", # '25 October 2006 14:30:59' + "%d %B %Y %H:%M:%S.%f", # '25 October 2006 14:30:59.000200' + "%d %B %Y %H:%M", # '25 October 2006 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/uz/formats.py b/django/conf/locale/uz/formats.py index 14af096f96..2c7ee73a17 100644 --- a/django/conf/locale/uz/formats.py +++ b/django/conf/locale/uz/formats.py @@ -2,29 +2,29 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'j-E, Y-\y\i\l' -TIME_FORMAT = 'G:i' -DATETIME_FORMAT = r'j-E, Y-\y\i\l G:i' -YEAR_MONTH_FORMAT = r'F Y-\y\i\l' -MONTH_DAY_FORMAT = 'j-E' -SHORT_DATE_FORMAT = 'd.m.Y' 
-SHORT_DATETIME_FORMAT = 'd.m.Y H:i' +DATE_FORMAT = r"j-E, Y-\y\i\l" +TIME_FORMAT = "G:i" +DATETIME_FORMAT = r"j-E, Y-\y\i\l G:i" +YEAR_MONTH_FORMAT = r"F Y-\y\i\l" +MONTH_DAY_FORMAT = "j-E" +SHORT_DATE_FORMAT = "d.m.Y" +SHORT_DATETIME_FORMAT = "d.m.Y H:i" FIRST_DAY_OF_WEEK = 1 # Monday # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%d.%m.%Y', # '25.10.2006' - '%d-%B, %Y-yil', # '25-Oktabr, 2006-yil' + "%d.%m.%Y", # '25.10.2006' + "%d-%B, %Y-yil", # '25-Oktabr, 2006-yil' ] DATETIME_INPUT_FORMATS = [ - '%d.%m.%Y %H:%M:%S', # '25.10.2006 14:30:59' - '%d.%m.%Y %H:%M:%S.%f', # '25.10.2006 14:30:59.000200' - '%d.%m.%Y %H:%M', # '25.10.2006 14:30' - '%d-%B, %Y-yil %H:%M:%S', # '25-Oktabr, 2006-yil 14:30:59' - '%d-%B, %Y-yil %H:%M:%S.%f', # '25-Oktabr, 2006-yil 14:30:59.000200' - '%d-%B, %Y-yil %H:%M', # '25-Oktabr, 2006-yil 14:30' + "%d.%m.%Y %H:%M:%S", # '25.10.2006 14:30:59' + "%d.%m.%Y %H:%M:%S.%f", # '25.10.2006 14:30:59.000200' + "%d.%m.%Y %H:%M", # '25.10.2006 14:30' + "%d-%B, %Y-yil %H:%M:%S", # '25-Oktabr, 2006-yil 14:30:59' + "%d-%B, %Y-yil %H:%M:%S.%f", # '25-Oktabr, 2006-yil 14:30:59.000200' + "%d-%B, %Y-yil %H:%M", # '25-Oktabr, 2006-yil 14:30' ] -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '\xa0' # non-breaking space +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "\xa0" # non-breaking space NUMBER_GROUPING = 3 diff --git a/django/conf/locale/vi/formats.py b/django/conf/locale/vi/formats.py index 495b6f7993..7b7602044a 100644 --- a/django/conf/locale/vi/formats.py +++ b/django/conf/locale/vi/formats.py @@ -2,13 +2,13 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = r'\N\gà\y d \t\há\n\g n \nă\m Y' -TIME_FORMAT = 'H:i' -DATETIME_FORMAT = r'H:i \N\gà\y d \t\há\n\g n \nă\m Y' -YEAR_MONTH_FORMAT = 'F Y' -MONTH_DAY_FORMAT = 'j F' 
-SHORT_DATE_FORMAT = 'd-m-Y' -SHORT_DATETIME_FORMAT = 'H:i d-m-Y' +DATE_FORMAT = r"\N\gà\y d \t\há\n\g n \nă\m Y" +TIME_FORMAT = "H:i" +DATETIME_FORMAT = r"H:i \N\gà\y d \t\há\n\g n \nă\m Y" +YEAR_MONTH_FORMAT = "F Y" +MONTH_DAY_FORMAT = "j F" +SHORT_DATE_FORMAT = "d-m-Y" +SHORT_DATETIME_FORMAT = "H:i d-m-Y" # FIRST_DAY_OF_WEEK = # The *_INPUT_FORMATS strings use the Python strftime format syntax, @@ -16,6 +16,6 @@ SHORT_DATETIME_FORMAT = 'H:i d-m-Y' # DATE_INPUT_FORMATS = # TIME_INPUT_FORMATS = # DATETIME_INPUT_FORMATS = -DECIMAL_SEPARATOR = ',' -THOUSAND_SEPARATOR = '.' +DECIMAL_SEPARATOR = "," +THOUSAND_SEPARATOR = "." # NUMBER_GROUPING = diff --git a/django/conf/locale/zh_Hans/formats.py b/django/conf/locale/zh_Hans/formats.py index 018b9b17f4..79936f8a34 100644 --- a/django/conf/locale/zh_Hans/formats.py +++ b/django/conf/locale/zh_Hans/formats.py @@ -2,41 +2,41 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'Y年n月j日' # 2016年9月5日 -TIME_FORMAT = 'H:i' # 20:45 -DATETIME_FORMAT = 'Y年n月j日 H:i' # 2016年9月5日 20:45 -YEAR_MONTH_FORMAT = 'Y年n月' # 2016年9月 -MONTH_DAY_FORMAT = 'm月j日' # 9月5日 -SHORT_DATE_FORMAT = 'Y年n月j日' # 2016年9月5日 -SHORT_DATETIME_FORMAT = 'Y年n月j日 H:i' # 2016年9月5日 20:45 -FIRST_DAY_OF_WEEK = 1 # 星期一 (Monday) +DATE_FORMAT = "Y年n月j日" # 2016年9月5日 +TIME_FORMAT = "H:i" # 20:45 +DATETIME_FORMAT = "Y年n月j日 H:i" # 2016年9月5日 20:45 +YEAR_MONTH_FORMAT = "Y年n月" # 2016年9月 +MONTH_DAY_FORMAT = "m月j日" # 9月5日 +SHORT_DATE_FORMAT = "Y年n月j日" # 2016年9月5日 +SHORT_DATETIME_FORMAT = "Y年n月j日 H:i" # 2016年9月5日 20:45 +FIRST_DAY_OF_WEEK = 1 # 星期一 (Monday) # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y/%m/%d', # '2016/09/05' - '%Y-%m-%d', # '2016-09-05' - '%Y年%n月%j日', # '2016年9月5日' + "%Y/%m/%d", # '2016/09/05' + "%Y-%m-%d", # '2016-09-05' + "%Y年%n月%j日", # 
'2016年9月5日' ] TIME_INPUT_FORMATS = [ - '%H:%M', # '20:45' - '%H:%M:%S', # '20:45:29' - '%H:%M:%S.%f', # '20:45:29.000200' + "%H:%M", # '20:45' + "%H:%M:%S", # '20:45:29' + "%H:%M:%S.%f", # '20:45:29.000200' ] DATETIME_INPUT_FORMATS = [ - '%Y/%m/%d %H:%M', # '2016/09/05 20:45' - '%Y-%m-%d %H:%M', # '2016-09-05 20:45' - '%Y年%n月%j日 %H:%M', # '2016年9月5日 14:45' - '%Y/%m/%d %H:%M:%S', # '2016/09/05 20:45:29' - '%Y-%m-%d %H:%M:%S', # '2016-09-05 20:45:29' - '%Y年%n月%j日 %H:%M:%S', # '2016年9月5日 20:45:29' - '%Y/%m/%d %H:%M:%S.%f', # '2016/09/05 20:45:29.000200' - '%Y-%m-%d %H:%M:%S.%f', # '2016-09-05 20:45:29.000200' - '%Y年%n月%j日 %H:%n:%S.%f', # '2016年9月5日 20:45:29.000200' + "%Y/%m/%d %H:%M", # '2016/09/05 20:45' + "%Y-%m-%d %H:%M", # '2016-09-05 20:45' + "%Y年%n月%j日 %H:%M", # '2016年9月5日 14:45' + "%Y/%m/%d %H:%M:%S", # '2016/09/05 20:45:29' + "%Y-%m-%d %H:%M:%S", # '2016-09-05 20:45:29' + "%Y年%n月%j日 %H:%M:%S", # '2016年9月5日 20:45:29' + "%Y/%m/%d %H:%M:%S.%f", # '2016/09/05 20:45:29.000200' + "%Y-%m-%d %H:%M:%S.%f", # '2016-09-05 20:45:29.000200' + "%Y年%n月%j日 %H:%n:%S.%f", # '2016年9月5日 20:45:29.000200' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = '' +DECIMAL_SEPARATOR = "." 
+THOUSAND_SEPARATOR = "" NUMBER_GROUPING = 4 diff --git a/django/conf/locale/zh_Hant/formats.py b/django/conf/locale/zh_Hant/formats.py index 018b9b17f4..79936f8a34 100644 --- a/django/conf/locale/zh_Hant/formats.py +++ b/django/conf/locale/zh_Hant/formats.py @@ -2,41 +2,41 @@ # # The *_FORMAT strings use the Django date format syntax, # see https://docs.djangoproject.com/en/dev/ref/templates/builtins/#date -DATE_FORMAT = 'Y年n月j日' # 2016年9月5日 -TIME_FORMAT = 'H:i' # 20:45 -DATETIME_FORMAT = 'Y年n月j日 H:i' # 2016年9月5日 20:45 -YEAR_MONTH_FORMAT = 'Y年n月' # 2016年9月 -MONTH_DAY_FORMAT = 'm月j日' # 9月5日 -SHORT_DATE_FORMAT = 'Y年n月j日' # 2016年9月5日 -SHORT_DATETIME_FORMAT = 'Y年n月j日 H:i' # 2016年9月5日 20:45 -FIRST_DAY_OF_WEEK = 1 # 星期一 (Monday) +DATE_FORMAT = "Y年n月j日" # 2016年9月5日 +TIME_FORMAT = "H:i" # 20:45 +DATETIME_FORMAT = "Y年n月j日 H:i" # 2016年9月5日 20:45 +YEAR_MONTH_FORMAT = "Y年n月" # 2016年9月 +MONTH_DAY_FORMAT = "m月j日" # 9月5日 +SHORT_DATE_FORMAT = "Y年n月j日" # 2016年9月5日 +SHORT_DATETIME_FORMAT = "Y年n月j日 H:i" # 2016年9月5日 20:45 +FIRST_DAY_OF_WEEK = 1 # 星期一 (Monday) # The *_INPUT_FORMATS strings use the Python strftime format syntax, # see https://docs.python.org/library/datetime.html#strftime-strptime-behavior DATE_INPUT_FORMATS = [ - '%Y/%m/%d', # '2016/09/05' - '%Y-%m-%d', # '2016-09-05' - '%Y年%n月%j日', # '2016年9月5日' + "%Y/%m/%d", # '2016/09/05' + "%Y-%m-%d", # '2016-09-05' + "%Y年%n月%j日", # '2016年9月5日' ] TIME_INPUT_FORMATS = [ - '%H:%M', # '20:45' - '%H:%M:%S', # '20:45:29' - '%H:%M:%S.%f', # '20:45:29.000200' + "%H:%M", # '20:45' + "%H:%M:%S", # '20:45:29' + "%H:%M:%S.%f", # '20:45:29.000200' ] DATETIME_INPUT_FORMATS = [ - '%Y/%m/%d %H:%M', # '2016/09/05 20:45' - '%Y-%m-%d %H:%M', # '2016-09-05 20:45' - '%Y年%n月%j日 %H:%M', # '2016年9月5日 14:45' - '%Y/%m/%d %H:%M:%S', # '2016/09/05 20:45:29' - '%Y-%m-%d %H:%M:%S', # '2016-09-05 20:45:29' - '%Y年%n月%j日 %H:%M:%S', # '2016年9月5日 20:45:29' - '%Y/%m/%d %H:%M:%S.%f', # '2016/09/05 20:45:29.000200' - '%Y-%m-%d %H:%M:%S.%f', # '2016-09-05 
20:45:29.000200' - '%Y年%n月%j日 %H:%n:%S.%f', # '2016年9月5日 20:45:29.000200' + "%Y/%m/%d %H:%M", # '2016/09/05 20:45' + "%Y-%m-%d %H:%M", # '2016-09-05 20:45' + "%Y年%n月%j日 %H:%M", # '2016年9月5日 14:45' + "%Y/%m/%d %H:%M:%S", # '2016/09/05 20:45:29' + "%Y-%m-%d %H:%M:%S", # '2016-09-05 20:45:29' + "%Y年%n月%j日 %H:%M:%S", # '2016年9月5日 20:45:29' + "%Y/%m/%d %H:%M:%S.%f", # '2016/09/05 20:45:29.000200' + "%Y-%m-%d %H:%M:%S.%f", # '2016-09-05 20:45:29.000200' + "%Y年%n月%j日 %H:%n:%S.%f", # '2016年9月5日 20:45:29.000200' ] -DECIMAL_SEPARATOR = '.' -THOUSAND_SEPARATOR = '' +DECIMAL_SEPARATOR = "." +THOUSAND_SEPARATOR = "" NUMBER_GROUPING = 4 diff --git a/django/conf/urls/__init__.py b/django/conf/urls/__init__.py index 1ec5da82ad..302f68dd07 100644 --- a/django/conf/urls/__init__.py +++ b/django/conf/urls/__init__.py @@ -1,7 +1,7 @@ from django.urls import include from django.views import defaults -__all__ = ['handler400', 'handler403', 'handler404', 'handler500', 'include'] +__all__ = ["handler400", "handler403", "handler404", "handler500", "include"] handler400 = defaults.bad_request handler403 = defaults.permission_denied diff --git a/django/conf/urls/i18n.py b/django/conf/urls/i18n.py index 256c247491..ebe5d51b14 100644 --- a/django/conf/urls/i18n.py +++ b/django/conf/urls/i18n.py @@ -35,5 +35,5 @@ def is_language_prefix_patterns_used(urlconf): urlpatterns = [ - path('setlang/', set_language, name='set_language'), + path("setlang/", set_language, name="set_language"), ] diff --git a/django/conf/urls/static.py b/django/conf/urls/static.py index fa83645b9d..8e7816e772 100644 --- a/django/conf/urls/static.py +++ b/django/conf/urls/static.py @@ -24,5 +24,7 @@ def static(prefix, view=serve, **kwargs): # No-op if not in debug mode or a non-local prefix. 
return [] return [ - re_path(r'^%s(?P.*)$' % re.escape(prefix.lstrip('/')), view, kwargs=kwargs), + re_path( + r"^%s(?P.*)$" % re.escape(prefix.lstrip("/")), view, kwargs=kwargs + ), ] diff --git a/django/contrib/admin/__init__.py b/django/contrib/admin/__init__.py index 975cf053aa..ef5c64ffef 100644 --- a/django/contrib/admin/__init__.py +++ b/django/contrib/admin/__init__.py @@ -1,24 +1,50 @@ from django.contrib.admin.decorators import action, display, register from django.contrib.admin.filters import ( - AllValuesFieldListFilter, BooleanFieldListFilter, ChoicesFieldListFilter, - DateFieldListFilter, EmptyFieldListFilter, FieldListFilter, ListFilter, - RelatedFieldListFilter, RelatedOnlyFieldListFilter, SimpleListFilter, + AllValuesFieldListFilter, + BooleanFieldListFilter, + ChoicesFieldListFilter, + DateFieldListFilter, + EmptyFieldListFilter, + FieldListFilter, + ListFilter, + RelatedFieldListFilter, + RelatedOnlyFieldListFilter, + SimpleListFilter, ) from django.contrib.admin.options import ( - HORIZONTAL, VERTICAL, ModelAdmin, StackedInline, TabularInline, + HORIZONTAL, + VERTICAL, + ModelAdmin, + StackedInline, + TabularInline, ) from django.contrib.admin.sites import AdminSite, site from django.utils.module_loading import autodiscover_modules __all__ = [ - "action", "display", "register", "ModelAdmin", "HORIZONTAL", "VERTICAL", - "StackedInline", "TabularInline", "AdminSite", "site", "ListFilter", - "SimpleListFilter", "FieldListFilter", "BooleanFieldListFilter", - "RelatedFieldListFilter", "ChoicesFieldListFilter", "DateFieldListFilter", - "AllValuesFieldListFilter", "EmptyFieldListFilter", - "RelatedOnlyFieldListFilter", "autodiscover", + "action", + "display", + "register", + "ModelAdmin", + "HORIZONTAL", + "VERTICAL", + "StackedInline", + "TabularInline", + "AdminSite", + "site", + "ListFilter", + "SimpleListFilter", + "FieldListFilter", + "BooleanFieldListFilter", + "RelatedFieldListFilter", + "ChoicesFieldListFilter", + "DateFieldListFilter", + 
"AllValuesFieldListFilter", + "EmptyFieldListFilter", + "RelatedOnlyFieldListFilter", + "autodiscover", ] def autodiscover(): - autodiscover_modules('admin', register_to=site) + autodiscover_modules("admin", register_to=site) diff --git a/django/contrib/admin/actions.py b/django/contrib/admin/actions.py index 8eab4ae5f7..e1e76cc31a 100644 --- a/django/contrib/admin/actions.py +++ b/django/contrib/admin/actions.py @@ -8,12 +8,13 @@ from django.contrib.admin.decorators import action from django.contrib.admin.utils import model_ngettext from django.core.exceptions import PermissionDenied from django.template.response import TemplateResponse -from django.utils.translation import gettext as _, gettext_lazy +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy @action( - permissions=['delete'], - description=gettext_lazy('Delete selected %(verbose_name_plural)s'), + permissions=["delete"], + description=gettext_lazy("Delete selected %(verbose_name_plural)s"), ) def delete_selected(modeladmin, request, queryset): """ @@ -30,11 +31,16 @@ def delete_selected(modeladmin, request, queryset): # Populate deletable_objects, a data structure of all related objects that # will also be deleted. - deletable_objects, model_count, perms_needed, protected = modeladmin.get_deleted_objects(queryset, request) + ( + deletable_objects, + model_count, + perms_needed, + protected, + ) = modeladmin.get_deleted_objects(queryset, request) # The user has already confirmed the deletion. # Do the deletion and return None to display the change list view again. 
- if request.POST.get('post') and not protected: + if request.POST.get("post") and not protected: if perms_needed: raise PermissionDenied n = queryset.count() @@ -43,9 +49,12 @@ def delete_selected(modeladmin, request, queryset): obj_display = str(obj) modeladmin.log_deletion(request, obj, obj_display) modeladmin.delete_queryset(request, queryset) - modeladmin.message_user(request, _("Successfully deleted %(count)d %(items)s.") % { - "count": n, "items": model_ngettext(modeladmin.opts, n) - }, messages.SUCCESS) + modeladmin.message_user( + request, + _("Successfully deleted %(count)d %(items)s.") + % {"count": n, "items": model_ngettext(modeladmin.opts, n)}, + messages.SUCCESS, + ) # Return None to display the change list page again. return None @@ -58,24 +67,30 @@ def delete_selected(modeladmin, request, queryset): context = { **modeladmin.admin_site.each_context(request), - 'title': title, - 'subtitle': None, - 'objects_name': str(objects_name), - 'deletable_objects': [deletable_objects], - 'model_count': dict(model_count).items(), - 'queryset': queryset, - 'perms_lacking': perms_needed, - 'protected': protected, - 'opts': opts, - 'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME, - 'media': modeladmin.media, + "title": title, + "subtitle": None, + "objects_name": str(objects_name), + "deletable_objects": [deletable_objects], + "model_count": dict(model_count).items(), + "queryset": queryset, + "perms_lacking": perms_needed, + "protected": protected, + "opts": opts, + "action_checkbox_name": helpers.ACTION_CHECKBOX_NAME, + "media": modeladmin.media, } request.current_app = modeladmin.admin_site.name # Display the confirmation page - return TemplateResponse(request, modeladmin.delete_selected_confirmation_template or [ - "admin/%s/%s/delete_selected_confirmation.html" % (app_label, opts.model_name), - "admin/%s/delete_selected_confirmation.html" % app_label, - "admin/delete_selected_confirmation.html" - ], context) + return TemplateResponse( + request, + 
modeladmin.delete_selected_confirmation_template + or [ + "admin/%s/%s/delete_selected_confirmation.html" + % (app_label, opts.model_name), + "admin/%s/delete_selected_confirmation.html" % app_label, + "admin/delete_selected_confirmation.html", + ], + context, + ) diff --git a/django/contrib/admin/apps.py b/django/contrib/admin/apps.py index c4fba8837c..08a9e0d832 100644 --- a/django/contrib/admin/apps.py +++ b/django/contrib/admin/apps.py @@ -7,9 +7,9 @@ from django.utils.translation import gettext_lazy as _ class SimpleAdminConfig(AppConfig): """Simple AppConfig which does not do automatic discovery.""" - default_auto_field = 'django.db.models.AutoField' - default_site = 'django.contrib.admin.sites.AdminSite' - name = 'django.contrib.admin' + default_auto_field = "django.db.models.AutoField" + default_site = "django.contrib.admin.sites.AdminSite" + name = "django.contrib.admin" verbose_name = _("Administration") def ready(self): diff --git a/django/contrib/admin/checks.py b/django/contrib/admin/checks.py index 045aaca346..7449503f5c 100644 --- a/django/contrib/admin/checks.py +++ b/django/contrib/admin/checks.py @@ -3,17 +3,13 @@ from itertools import chain from django.apps import apps from django.conf import settings -from django.contrib.admin.utils import ( - NotRelationField, flatten, get_fields_from_path, -) +from django.contrib.admin.utils import NotRelationField, flatten, get_fields_from_path from django.core import checks from django.core.exceptions import FieldDoesNotExist from django.db import models from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Combinable -from django.forms.models import ( - BaseModelForm, BaseModelFormSet, _get_foreign_key, -) +from django.forms.models import BaseModelForm, BaseModelFormSet, _get_foreign_key from django.template import engines from django.template.backends.django import DjangoTemplates from django.utils.module_loading import import_string @@ -49,6 +45,7 @@ def 
_contains_subclass(class_path, candidate_paths): def check_admin_app(app_configs, **kwargs): from django.contrib.admin.sites import all_sites + errors = [] for site in all_sites: errors.extend(site.check(app_configs)) @@ -60,21 +57,24 @@ def check_dependencies(**kwargs): Check that the admin's dependencies are correctly installed. """ from django.contrib.admin.sites import all_sites - if not apps.is_installed('django.contrib.admin'): + + if not apps.is_installed("django.contrib.admin"): return [] errors = [] app_dependencies = ( - ('django.contrib.contenttypes', 401), - ('django.contrib.auth', 405), - ('django.contrib.messages', 406), + ("django.contrib.contenttypes", 401), + ("django.contrib.auth", 405), + ("django.contrib.messages", 406), ) for app_name, error_code in app_dependencies: if not apps.is_installed(app_name): - errors.append(checks.Error( - "'%s' must be in INSTALLED_APPS in order to use the admin " - "application." % app_name, - id='admin.E%d' % error_code, - )) + errors.append( + checks.Error( + "'%s' must be in INSTALLED_APPS in order to use the admin " + "application." 
% app_name, + id="admin.E%d" % error_code, + ) + ) for engine in engines.all(): if isinstance(engine, DjangoTemplates): django_templates_instance = engine.engine @@ -82,69 +82,98 @@ def check_dependencies(**kwargs): else: django_templates_instance = None if not django_templates_instance: - errors.append(checks.Error( - "A 'django.template.backends.django.DjangoTemplates' instance " - "must be configured in TEMPLATES in order to use the admin " - "application.", - id='admin.E403', - )) + errors.append( + checks.Error( + "A 'django.template.backends.django.DjangoTemplates' instance " + "must be configured in TEMPLATES in order to use the admin " + "application.", + id="admin.E403", + ) + ) else: - if ('django.contrib.auth.context_processors.auth' - not in django_templates_instance.context_processors and - _contains_subclass('django.contrib.auth.backends.ModelBackend', settings.AUTHENTICATION_BACKENDS)): - errors.append(checks.Error( - "'django.contrib.auth.context_processors.auth' must be " - "enabled in DjangoTemplates (TEMPLATES) if using the default " - "auth backend in order to use the admin application.", - id='admin.E402', - )) - if ('django.contrib.messages.context_processors.messages' - not in django_templates_instance.context_processors): - errors.append(checks.Error( - "'django.contrib.messages.context_processors.messages' must " - "be enabled in DjangoTemplates (TEMPLATES) in order to use " - "the admin application.", - id='admin.E404', - )) + if ( + "django.contrib.auth.context_processors.auth" + not in django_templates_instance.context_processors + and _contains_subclass( + "django.contrib.auth.backends.ModelBackend", + settings.AUTHENTICATION_BACKENDS, + ) + ): + errors.append( + checks.Error( + "'django.contrib.auth.context_processors.auth' must be " + "enabled in DjangoTemplates (TEMPLATES) if using the default " + "auth backend in order to use the admin application.", + id="admin.E402", + ) + ) + if ( + 
"django.contrib.messages.context_processors.messages" + not in django_templates_instance.context_processors + ): + errors.append( + checks.Error( + "'django.contrib.messages.context_processors.messages' must " + "be enabled in DjangoTemplates (TEMPLATES) in order to use " + "the admin application.", + id="admin.E404", + ) + ) sidebar_enabled = any(site.enable_nav_sidebar for site in all_sites) - if (sidebar_enabled and 'django.template.context_processors.request' - not in django_templates_instance.context_processors): - errors.append(checks.Warning( - "'django.template.context_processors.request' must be enabled " - "in DjangoTemplates (TEMPLATES) in order to use the admin " - "navigation sidebar.", - id='admin.W411', - )) + if ( + sidebar_enabled + and "django.template.context_processors.request" + not in django_templates_instance.context_processors + ): + errors.append( + checks.Warning( + "'django.template.context_processors.request' must be enabled " + "in DjangoTemplates (TEMPLATES) in order to use the admin " + "navigation sidebar.", + id="admin.W411", + ) + ) - if not _contains_subclass('django.contrib.auth.middleware.AuthenticationMiddleware', settings.MIDDLEWARE): - errors.append(checks.Error( - "'django.contrib.auth.middleware.AuthenticationMiddleware' must " - "be in MIDDLEWARE in order to use the admin application.", - id='admin.E408', - )) - if not _contains_subclass('django.contrib.messages.middleware.MessageMiddleware', settings.MIDDLEWARE): - errors.append(checks.Error( - "'django.contrib.messages.middleware.MessageMiddleware' must " - "be in MIDDLEWARE in order to use the admin application.", - id='admin.E409', - )) - if not _contains_subclass('django.contrib.sessions.middleware.SessionMiddleware', settings.MIDDLEWARE): - errors.append(checks.Error( - "'django.contrib.sessions.middleware.SessionMiddleware' must " - "be in MIDDLEWARE in order to use the admin application.", - hint=( - "Insert " - 
"'django.contrib.sessions.middleware.SessionMiddleware' " - "before " - "'django.contrib.auth.middleware.AuthenticationMiddleware'." - ), - id='admin.E410', - )) + if not _contains_subclass( + "django.contrib.auth.middleware.AuthenticationMiddleware", settings.MIDDLEWARE + ): + errors.append( + checks.Error( + "'django.contrib.auth.middleware.AuthenticationMiddleware' must " + "be in MIDDLEWARE in order to use the admin application.", + id="admin.E408", + ) + ) + if not _contains_subclass( + "django.contrib.messages.middleware.MessageMiddleware", settings.MIDDLEWARE + ): + errors.append( + checks.Error( + "'django.contrib.messages.middleware.MessageMiddleware' must " + "be in MIDDLEWARE in order to use the admin application.", + id="admin.E409", + ) + ) + if not _contains_subclass( + "django.contrib.sessions.middleware.SessionMiddleware", settings.MIDDLEWARE + ): + errors.append( + checks.Error( + "'django.contrib.sessions.middleware.SessionMiddleware' must " + "be in MIDDLEWARE in order to use the admin application.", + hint=( + "Insert " + "'django.contrib.sessions.middleware.SessionMiddleware' " + "before " + "'django.contrib.auth.middleware.AuthenticationMiddleware'." + ), + id="admin.E410", + ) + ) return errors class BaseModelAdminChecks: - def check(self, admin_obj, **kwargs): return [ *self._check_autocomplete_fields(admin_obj), @@ -167,12 +196,23 @@ class BaseModelAdminChecks: Check that `autocomplete_fields` is a list or tuple of model fields. 
""" if not isinstance(obj.autocomplete_fields, (list, tuple)): - return must_be('a list or tuple', option='autocomplete_fields', obj=obj, id='admin.E036') + return must_be( + "a list or tuple", + option="autocomplete_fields", + obj=obj, + id="admin.E036", + ) else: - return list(chain.from_iterable([ - self._check_autocomplete_fields_item(obj, field_name, 'autocomplete_fields[%d]' % index) - for index, field_name in enumerate(obj.autocomplete_fields) - ])) + return list( + chain.from_iterable( + [ + self._check_autocomplete_fields_item( + obj, field_name, "autocomplete_fields[%d]" % index + ) + for index, field_name in enumerate(obj.autocomplete_fields) + ] + ) + ) def _check_autocomplete_fields_item(self, obj, field_name, label): """ @@ -183,61 +223,75 @@ class BaseModelAdminChecks: try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E037') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E037" + ) else: if not field.many_to_many and not isinstance(field, models.ForeignKey): return must_be( - 'a foreign key or a many-to-many field', - option=label, obj=obj, id='admin.E038' + "a foreign key or a many-to-many field", + option=label, + obj=obj, + id="admin.E038", ) related_admin = obj.admin_site._registry.get(field.remote_field.model) if related_admin is None: return [ checks.Error( 'An admin for model "%s" has to be registered ' - 'to be referenced by %s.autocomplete_fields.' % ( + "to be referenced by %s.autocomplete_fields." + % ( field.remote_field.model.__name__, type(obj).__name__, ), obj=obj.__class__, - id='admin.E039', + id="admin.E039", ) ] elif not related_admin.search_fields: return [ checks.Error( '%s must define "search_fields", because it\'s ' - 'referenced by %s.autocomplete_fields.' % ( + "referenced by %s.autocomplete_fields." 
+ % ( related_admin.__class__.__name__, type(obj).__name__, ), obj=obj.__class__, - id='admin.E040', + id="admin.E040", ) ] return [] def _check_raw_id_fields(self, obj): - """ Check that `raw_id_fields` only contains field names that are listed - on the model. """ + """Check that `raw_id_fields` only contains field names that are listed + on the model.""" if not isinstance(obj.raw_id_fields, (list, tuple)): - return must_be('a list or tuple', option='raw_id_fields', obj=obj, id='admin.E001') + return must_be( + "a list or tuple", option="raw_id_fields", obj=obj, id="admin.E001" + ) else: - return list(chain.from_iterable( - self._check_raw_id_fields_item(obj, field_name, 'raw_id_fields[%d]' % index) - for index, field_name in enumerate(obj.raw_id_fields) - )) + return list( + chain.from_iterable( + self._check_raw_id_fields_item( + obj, field_name, "raw_id_fields[%d]" % index + ) + for index, field_name in enumerate(obj.raw_id_fields) + ) + ) def _check_raw_id_fields_item(self, obj, field_name, label): - """ Check an item of `raw_id_fields`, i.e. check that field named + """Check an item of `raw_id_fields`, i.e. check that field named `field_name` exists in model `model` and is a ForeignKey or a - ManyToManyField. """ + ManyToManyField.""" try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E002') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E002" + ) else: # Using attname is not supported. 
if field.name != field_name: @@ -245,28 +299,33 @@ class BaseModelAdminChecks: field=field_name, option=label, obj=obj, - id='admin.E002', + id="admin.E002", ) if not field.many_to_many and not isinstance(field, models.ForeignKey): - return must_be('a foreign key or a many-to-many field', option=label, obj=obj, id='admin.E003') + return must_be( + "a foreign key or a many-to-many field", + option=label, + obj=obj, + id="admin.E003", + ) else: return [] def _check_fields(self, obj): - """ Check that `fields` only refer to existing fields, doesn't contain + """Check that `fields` only refer to existing fields, doesn't contain duplicates. Check if at most one of `fields` and `fieldsets` is defined. """ if obj.fields is None: return [] elif not isinstance(obj.fields, (list, tuple)): - return must_be('a list or tuple', option='fields', obj=obj, id='admin.E004') + return must_be("a list or tuple", option="fields", obj=obj, id="admin.E004") elif obj.fieldsets: return [ checks.Error( "Both 'fieldsets' and 'fields' are specified.", obj=obj.__class__, - id='admin.E005', + id="admin.E005", ) ] fields = flatten(obj.fields) @@ -275,75 +334,96 @@ class BaseModelAdminChecks: checks.Error( "The value of 'fields' contains duplicate field(s).", obj=obj.__class__, - id='admin.E006', + id="admin.E006", ) ] - return list(chain.from_iterable( - self._check_field_spec(obj, field_name, 'fields') - for field_name in obj.fields - )) + return list( + chain.from_iterable( + self._check_field_spec(obj, field_name, "fields") + for field_name in obj.fields + ) + ) def _check_fieldsets(self, obj): - """ Check that fieldsets is properly formatted and doesn't contain - duplicates. 
""" + """Check that fieldsets is properly formatted and doesn't contain + duplicates.""" if obj.fieldsets is None: return [] elif not isinstance(obj.fieldsets, (list, tuple)): - return must_be('a list or tuple', option='fieldsets', obj=obj, id='admin.E007') + return must_be( + "a list or tuple", option="fieldsets", obj=obj, id="admin.E007" + ) else: seen_fields = [] - return list(chain.from_iterable( - self._check_fieldsets_item(obj, fieldset, 'fieldsets[%d]' % index, seen_fields) - for index, fieldset in enumerate(obj.fieldsets) - )) + return list( + chain.from_iterable( + self._check_fieldsets_item( + obj, fieldset, "fieldsets[%d]" % index, seen_fields + ) + for index, fieldset in enumerate(obj.fieldsets) + ) + ) def _check_fieldsets_item(self, obj, fieldset, label, seen_fields): - """ Check an item of `fieldsets`, i.e. check that this is a pair of a - set name and a dictionary containing "fields" key. """ + """Check an item of `fieldsets`, i.e. check that this is a pair of a + set name and a dictionary containing "fields" key.""" if not isinstance(fieldset, (list, tuple)): - return must_be('a list or tuple', option=label, obj=obj, id='admin.E008') + return must_be("a list or tuple", option=label, obj=obj, id="admin.E008") elif len(fieldset) != 2: - return must_be('of length 2', option=label, obj=obj, id='admin.E009') + return must_be("of length 2", option=label, obj=obj, id="admin.E009") elif not isinstance(fieldset[1], dict): - return must_be('a dictionary', option='%s[1]' % label, obj=obj, id='admin.E010') - elif 'fields' not in fieldset[1]: + return must_be( + "a dictionary", option="%s[1]" % label, obj=obj, id="admin.E010" + ) + elif "fields" not in fieldset[1]: return [ checks.Error( "The value of '%s[1]' must contain the key 'fields'." 
% label, obj=obj.__class__, - id='admin.E011', + id="admin.E011", ) ] - elif not isinstance(fieldset[1]['fields'], (list, tuple)): - return must_be('a list or tuple', option="%s[1]['fields']" % label, obj=obj, id='admin.E008') + elif not isinstance(fieldset[1]["fields"], (list, tuple)): + return must_be( + "a list or tuple", + option="%s[1]['fields']" % label, + obj=obj, + id="admin.E008", + ) - seen_fields.extend(flatten(fieldset[1]['fields'])) + seen_fields.extend(flatten(fieldset[1]["fields"])) if len(seen_fields) != len(set(seen_fields)): return [ checks.Error( "There are duplicate field(s) in '%s[1]'." % label, obj=obj.__class__, - id='admin.E012', + id="admin.E012", ) ] - return list(chain.from_iterable( - self._check_field_spec(obj, fieldset_fields, '%s[1]["fields"]' % label) - for fieldset_fields in fieldset[1]['fields'] - )) + return list( + chain.from_iterable( + self._check_field_spec(obj, fieldset_fields, '%s[1]["fields"]' % label) + for fieldset_fields in fieldset[1]["fields"] + ) + ) def _check_field_spec(self, obj, fields, label): - """ `fields` should be an item of `fields` or an item of + """`fields` should be an item of `fields` or an item of fieldset[1]['fields'] for any `fieldset` in `fieldsets`. It should be a - field name or a tuple of field names. """ + field name or a tuple of field names.""" if isinstance(fields, tuple): - return list(chain.from_iterable( - self._check_field_spec_item(obj, field_name, "%s[%d]" % (label, index)) - for index, field_name in enumerate(fields) - )) + return list( + chain.from_iterable( + self._check_field_spec_item( + obj, field_name, "%s[%d]" % (label, index) + ) + for index, field_name in enumerate(fields) + ) + ) else: return self._check_field_spec_item(obj, fields, label) @@ -361,125 +441,154 @@ class BaseModelAdminChecks: # be an extra field on the form. 
return [] else: - if (isinstance(field, models.ManyToManyField) and - not field.remote_field.through._meta.auto_created): + if ( + isinstance(field, models.ManyToManyField) + and not field.remote_field.through._meta.auto_created + ): return [ checks.Error( "The value of '%s' cannot include the ManyToManyField '%s', " "because that field manually specifies a relationship model." % (label, field_name), obj=obj.__class__, - id='admin.E013', + id="admin.E013", ) ] else: return [] def _check_exclude(self, obj): - """ Check that exclude is a sequence without duplicates. """ + """Check that exclude is a sequence without duplicates.""" if obj.exclude is None: # default value is None return [] elif not isinstance(obj.exclude, (list, tuple)): - return must_be('a list or tuple', option='exclude', obj=obj, id='admin.E014') + return must_be( + "a list or tuple", option="exclude", obj=obj, id="admin.E014" + ) elif len(obj.exclude) > len(set(obj.exclude)): return [ checks.Error( "The value of 'exclude' contains duplicate field(s).", obj=obj.__class__, - id='admin.E015', + id="admin.E015", ) ] else: return [] def _check_form(self, obj): - """ Check that form subclasses BaseModelForm. """ + """Check that form subclasses BaseModelForm.""" if not _issubclass(obj.form, BaseModelForm): - return must_inherit_from(parent='BaseModelForm', option='form', - obj=obj, id='admin.E016') + return must_inherit_from( + parent="BaseModelForm", option="form", obj=obj, id="admin.E016" + ) else: return [] def _check_filter_vertical(self, obj): - """ Check that filter_vertical is a sequence of field names. 
""" + """Check that filter_vertical is a sequence of field names.""" if not isinstance(obj.filter_vertical, (list, tuple)): - return must_be('a list or tuple', option='filter_vertical', obj=obj, id='admin.E017') + return must_be( + "a list or tuple", option="filter_vertical", obj=obj, id="admin.E017" + ) else: - return list(chain.from_iterable( - self._check_filter_item(obj, field_name, "filter_vertical[%d]" % index) - for index, field_name in enumerate(obj.filter_vertical) - )) + return list( + chain.from_iterable( + self._check_filter_item( + obj, field_name, "filter_vertical[%d]" % index + ) + for index, field_name in enumerate(obj.filter_vertical) + ) + ) def _check_filter_horizontal(self, obj): - """ Check that filter_horizontal is a sequence of field names. """ + """Check that filter_horizontal is a sequence of field names.""" if not isinstance(obj.filter_horizontal, (list, tuple)): - return must_be('a list or tuple', option='filter_horizontal', obj=obj, id='admin.E018') + return must_be( + "a list or tuple", option="filter_horizontal", obj=obj, id="admin.E018" + ) else: - return list(chain.from_iterable( - self._check_filter_item(obj, field_name, "filter_horizontal[%d]" % index) - for index, field_name in enumerate(obj.filter_horizontal) - )) + return list( + chain.from_iterable( + self._check_filter_item( + obj, field_name, "filter_horizontal[%d]" % index + ) + for index, field_name in enumerate(obj.filter_horizontal) + ) + ) def _check_filter_item(self, obj, field_name, label): - """ Check one item of `filter_vertical` or `filter_horizontal`, i.e. - check that given field exists and is a ManyToManyField. """ + """Check one item of `filter_vertical` or `filter_horizontal`, i.e. 
+ check that given field exists and is a ManyToManyField.""" try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E019') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E019" + ) else: if not field.many_to_many: - return must_be('a many-to-many field', option=label, obj=obj, id='admin.E020') + return must_be( + "a many-to-many field", option=label, obj=obj, id="admin.E020" + ) else: return [] def _check_radio_fields(self, obj): - """ Check that `radio_fields` is a dictionary. """ + """Check that `radio_fields` is a dictionary.""" if not isinstance(obj.radio_fields, dict): - return must_be('a dictionary', option='radio_fields', obj=obj, id='admin.E021') + return must_be( + "a dictionary", option="radio_fields", obj=obj, id="admin.E021" + ) else: - return list(chain.from_iterable( - self._check_radio_fields_key(obj, field_name, 'radio_fields') + - self._check_radio_fields_value(obj, val, 'radio_fields["%s"]' % field_name) - for field_name, val in obj.radio_fields.items() - )) + return list( + chain.from_iterable( + self._check_radio_fields_key(obj, field_name, "radio_fields") + + self._check_radio_fields_value( + obj, val, 'radio_fields["%s"]' % field_name + ) + for field_name, val in obj.radio_fields.items() + ) + ) def _check_radio_fields_key(self, obj, field_name, label): - """ Check that a key of `radio_fields` dictionary is name of existing - field and that the field is a ForeignKey or has `choices` defined. 
""" + """Check that a key of `radio_fields` dictionary is name of existing + field and that the field is a ForeignKey or has `choices` defined.""" try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E022') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E022" + ) else: if not (isinstance(field, models.ForeignKey) or field.choices): return [ checks.Error( "The value of '%s' refers to '%s', which is not an " - "instance of ForeignKey, and does not have a 'choices' definition." % ( - label, field_name - ), + "instance of ForeignKey, and does not have a 'choices' definition." + % (label, field_name), obj=obj.__class__, - id='admin.E023', + id="admin.E023", ) ] else: return [] def _check_radio_fields_value(self, obj, val, label): - """ Check type of a value of `radio_fields` dictionary. """ + """Check type of a value of `radio_fields` dictionary.""" from django.contrib.admin.options import HORIZONTAL, VERTICAL if val not in (HORIZONTAL, VERTICAL): return [ checks.Error( - "The value of '%s' must be either admin.HORIZONTAL or admin.VERTICAL." % label, + "The value of '%s' must be either admin.HORIZONTAL or admin.VERTICAL." + % label, obj=obj.__class__, - id='admin.E024', + id="admin.E024", ) ] else: @@ -491,85 +600,108 @@ class BaseModelAdminChecks: checks.Error( "The value of 'view_on_site' must be a callable or a boolean value.", obj=obj.__class__, - id='admin.E025', + id="admin.E025", ) ] else: return [] def _check_prepopulated_fields(self, obj): - """ Check that `prepopulated_fields` is a dictionary containing allowed - field types. 
""" + """Check that `prepopulated_fields` is a dictionary containing allowed + field types.""" if not isinstance(obj.prepopulated_fields, dict): - return must_be('a dictionary', option='prepopulated_fields', obj=obj, id='admin.E026') + return must_be( + "a dictionary", option="prepopulated_fields", obj=obj, id="admin.E026" + ) else: - return list(chain.from_iterable( - self._check_prepopulated_fields_key(obj, field_name, 'prepopulated_fields') + - self._check_prepopulated_fields_value(obj, val, 'prepopulated_fields["%s"]' % field_name) - for field_name, val in obj.prepopulated_fields.items() - )) + return list( + chain.from_iterable( + self._check_prepopulated_fields_key( + obj, field_name, "prepopulated_fields" + ) + + self._check_prepopulated_fields_value( + obj, val, 'prepopulated_fields["%s"]' % field_name + ) + for field_name, val in obj.prepopulated_fields.items() + ) + ) def _check_prepopulated_fields_key(self, obj, field_name, label): - """ Check a key of `prepopulated_fields` dictionary, i.e. check that it + """Check a key of `prepopulated_fields` dictionary, i.e. check that it is a name of existing field and the field is one of the allowed types. """ try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E027') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E027" + ) else: - if isinstance(field, (models.DateTimeField, models.ForeignKey, models.ManyToManyField)): + if isinstance( + field, (models.DateTimeField, models.ForeignKey, models.ManyToManyField) + ): return [ checks.Error( "The value of '%s' refers to '%s', which must not be a DateTimeField, " - "a ForeignKey, a OneToOneField, or a ManyToManyField." % (label, field_name), + "a ForeignKey, a OneToOneField, or a ManyToManyField." 
+ % (label, field_name), obj=obj.__class__, - id='admin.E028', + id="admin.E028", ) ] else: return [] def _check_prepopulated_fields_value(self, obj, val, label): - """ Check a value of `prepopulated_fields` dictionary, i.e. it's an - iterable of existing fields. """ + """Check a value of `prepopulated_fields` dictionary, i.e. it's an + iterable of existing fields.""" if not isinstance(val, (list, tuple)): - return must_be('a list or tuple', option=label, obj=obj, id='admin.E029') + return must_be("a list or tuple", option=label, obj=obj, id="admin.E029") else: - return list(chain.from_iterable( - self._check_prepopulated_fields_value_item(obj, subfield_name, "%s[%r]" % (label, index)) - for index, subfield_name in enumerate(val) - )) + return list( + chain.from_iterable( + self._check_prepopulated_fields_value_item( + obj, subfield_name, "%s[%r]" % (label, index) + ) + for index, subfield_name in enumerate(val) + ) + ) def _check_prepopulated_fields_value_item(self, obj, field_name, label): - """ For `prepopulated_fields` equal to {"slug": ("title",)}, - `field_name` is "title". """ + """For `prepopulated_fields` equal to {"slug": ("title",)}, + `field_name` is "title".""" try: obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E030') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E030" + ) else: return [] def _check_ordering(self, obj): - """ Check that ordering refers to existing fields or is random. 
""" + """Check that ordering refers to existing fields or is random.""" # ordering = None if obj.ordering is None: # The default value is None return [] elif not isinstance(obj.ordering, (list, tuple)): - return must_be('a list or tuple', option='ordering', obj=obj, id='admin.E031') + return must_be( + "a list or tuple", option="ordering", obj=obj, id="admin.E031" + ) else: - return list(chain.from_iterable( - self._check_ordering_item(obj, field_name, 'ordering[%d]' % index) - for index, field_name in enumerate(obj.ordering) - )) + return list( + chain.from_iterable( + self._check_ordering_item(obj, field_name, "ordering[%d]" % index) + for index, field_name in enumerate(obj.ordering) + ) + ) def _check_ordering_item(self, obj, field_name, label): - """ Check that `ordering` refers to existing fields. """ + """Check that `ordering` refers to existing fields.""" if isinstance(field_name, (Combinable, models.OrderBy)): if not isinstance(field_name, models.OrderBy): field_name = field_name.asc() @@ -577,46 +709,54 @@ class BaseModelAdminChecks: field_name = field_name.expression.name else: return [] - if field_name == '?' and len(obj.ordering) != 1: + if field_name == "?" and len(obj.ordering) != 1: return [ checks.Error( "The value of 'ordering' has the random ordering marker '?', " "but contains other fields as well.", hint='Either remove the "?", or remove the other fields.', obj=obj.__class__, - id='admin.E032', + id="admin.E032", ) ] - elif field_name == '?': + elif field_name == "?": return [] elif LOOKUP_SEP in field_name: # Skip ordering in the format field1__field2 (FIXME: checking # this format would be nice, but it's a little fiddly). 
return [] else: - if field_name.startswith('-'): + if field_name.startswith("-"): field_name = field_name[1:] - if field_name == 'pk': + if field_name == "pk": return [] try: obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E033') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E033" + ) else: return [] def _check_readonly_fields(self, obj): - """ Check that readonly_fields refers to proper attribute or field. """ + """Check that readonly_fields refers to proper attribute or field.""" if obj.readonly_fields == (): return [] elif not isinstance(obj.readonly_fields, (list, tuple)): - return must_be('a list or tuple', option='readonly_fields', obj=obj, id='admin.E034') + return must_be( + "a list or tuple", option="readonly_fields", obj=obj, id="admin.E034" + ) else: - return list(chain.from_iterable( - self._check_readonly_fields_item(obj, field_name, "readonly_fields[%d]" % index) - for index, field_name in enumerate(obj.readonly_fields) - )) + return list( + chain.from_iterable( + self._check_readonly_fields_item( + obj, field_name, "readonly_fields[%d]" % index + ) + for index, field_name in enumerate(obj.readonly_fields) + ) + ) def _check_readonly_fields_item(self, obj, field_name, label): if callable(field_name): @@ -632,11 +772,14 @@ class BaseModelAdminChecks: return [ checks.Error( "The value of '%s' is not a callable, an attribute of " - "'%s', or an attribute of '%s'." % ( - label, obj.__class__.__name__, obj.model._meta.label, + "'%s', or an attribute of '%s'." 
+ % ( + label, + obj.__class__.__name__, + obj.model._meta.label, ), obj=obj.__class__, - id='admin.E035', + id="admin.E035", ) ] else: @@ -644,7 +787,6 @@ class BaseModelAdminChecks: class ModelAdminChecks(BaseModelAdminChecks): - def check(self, admin_obj, **kwargs): return [ *super().check(admin_obj), @@ -665,44 +807,46 @@ class ModelAdminChecks(BaseModelAdminChecks): ] def _check_save_as(self, obj): - """ Check save_as is a boolean. """ + """Check save_as is a boolean.""" if not isinstance(obj.save_as, bool): - return must_be('a boolean', option='save_as', - obj=obj, id='admin.E101') + return must_be("a boolean", option="save_as", obj=obj, id="admin.E101") else: return [] def _check_save_on_top(self, obj): - """ Check save_on_top is a boolean. """ + """Check save_on_top is a boolean.""" if not isinstance(obj.save_on_top, bool): - return must_be('a boolean', option='save_on_top', - obj=obj, id='admin.E102') + return must_be("a boolean", option="save_on_top", obj=obj, id="admin.E102") else: return [] def _check_inlines(self, obj): - """ Check all inline model admin classes. """ + """Check all inline model admin classes.""" if not isinstance(obj.inlines, (list, tuple)): - return must_be('a list or tuple', option='inlines', obj=obj, id='admin.E103') + return must_be( + "a list or tuple", option="inlines", obj=obj, id="admin.E103" + ) else: - return list(chain.from_iterable( - self._check_inlines_item(obj, item, "inlines[%d]" % index) - for index, item in enumerate(obj.inlines) - )) + return list( + chain.from_iterable( + self._check_inlines_item(obj, item, "inlines[%d]" % index) + for index, item in enumerate(obj.inlines) + ) + ) def _check_inlines_item(self, obj, inline, label): - """ Check one inline model admin. """ + """Check one inline model admin.""" try: - inline_label = inline.__module__ + '.' + inline.__name__ + inline_label = inline.__module__ + "." 
+ inline.__name__ except AttributeError: return [ checks.Error( "'%s' must inherit from 'InlineModelAdmin'." % obj, obj=obj.__class__, - id='admin.E104', + id="admin.E104", ) ] @@ -713,7 +857,7 @@ class ModelAdminChecks(BaseModelAdminChecks): checks.Error( "'%s' must inherit from 'InlineModelAdmin'." % inline_label, obj=obj.__class__, - id='admin.E104', + id="admin.E104", ) ] elif not inline.model: @@ -721,25 +865,30 @@ class ModelAdminChecks(BaseModelAdminChecks): checks.Error( "'%s' must have a 'model' attribute." % inline_label, obj=obj.__class__, - id='admin.E105', + id="admin.E105", ) ] elif not _issubclass(inline.model, models.Model): - return must_be('a Model', option='%s.model' % inline_label, obj=obj, id='admin.E106') + return must_be( + "a Model", option="%s.model" % inline_label, obj=obj, id="admin.E106" + ) else: return inline(obj.model, obj.admin_site).check() def _check_list_display(self, obj): - """ Check that list_display only contains fields or usable attributes. - """ + """Check that list_display only contains fields or usable attributes.""" if not isinstance(obj.list_display, (list, tuple)): - return must_be('a list or tuple', option='list_display', obj=obj, id='admin.E107') + return must_be( + "a list or tuple", option="list_display", obj=obj, id="admin.E107" + ) else: - return list(chain.from_iterable( - self._check_list_display_item(obj, item, "list_display[%d]" % index) - for index, item in enumerate(obj.list_display) - )) + return list( + chain.from_iterable( + self._check_list_display_item(obj, item, "list_display[%d]" % index) + for index, item in enumerate(obj.list_display) + ) + ) def _check_list_display_item(self, obj, item, label): if callable(item): @@ -756,12 +905,15 @@ class ModelAdminChecks(BaseModelAdminChecks): checks.Error( "The value of '%s' refers to '%s', which is not a " "callable, an attribute of '%s', or an attribute or " - "method on '%s'." % ( - label, item, obj.__class__.__name__, + "method on '%s'." 
+ % ( + label, + item, + obj.__class__.__name__, obj.model._meta.label, ), obj=obj.__class__, - id='admin.E108', + id="admin.E108", ) ] if isinstance(field, models.ManyToManyField): @@ -769,37 +921,44 @@ class ModelAdminChecks(BaseModelAdminChecks): checks.Error( "The value of '%s' must not be a ManyToManyField." % label, obj=obj.__class__, - id='admin.E109', + id="admin.E109", ) ] return [] def _check_list_display_links(self, obj): - """ Check that list_display_links is a unique subset of list_display. - """ + """Check that list_display_links is a unique subset of list_display.""" from django.contrib.admin.options import ModelAdmin if obj.list_display_links is None: return [] elif not isinstance(obj.list_display_links, (list, tuple)): - return must_be('a list, a tuple, or None', option='list_display_links', obj=obj, id='admin.E110') + return must_be( + "a list, a tuple, or None", + option="list_display_links", + obj=obj, + id="admin.E110", + ) # Check only if ModelAdmin.get_list_display() isn't overridden. elif obj.get_list_display.__func__ is ModelAdmin.get_list_display: - return list(chain.from_iterable( - self._check_list_display_links_item(obj, field_name, "list_display_links[%d]" % index) - for index, field_name in enumerate(obj.list_display_links) - )) + return list( + chain.from_iterable( + self._check_list_display_links_item( + obj, field_name, "list_display_links[%d]" % index + ) + for index, field_name in enumerate(obj.list_display_links) + ) + ) return [] def _check_list_display_links_item(self, obj, field_name, label): if field_name not in obj.list_display: return [ checks.Error( - "The value of '%s' refers to '%s', which is not defined in 'list_display'." % ( - label, field_name - ), + "The value of '%s' refers to '%s', which is not defined in 'list_display'." 
+ % (label, field_name), obj=obj.__class__, - id='admin.E111', + id="admin.E111", ) ] else: @@ -807,12 +966,16 @@ class ModelAdminChecks(BaseModelAdminChecks): def _check_list_filter(self, obj): if not isinstance(obj.list_filter, (list, tuple)): - return must_be('a list or tuple', option='list_filter', obj=obj, id='admin.E112') + return must_be( + "a list or tuple", option="list_filter", obj=obj, id="admin.E112" + ) else: - return list(chain.from_iterable( - self._check_list_filter_item(obj, item, "list_filter[%d]" % index) - for index, item in enumerate(obj.list_filter) - )) + return list( + chain.from_iterable( + self._check_list_filter_item(obj, item, "list_filter[%d]" % index) + for index, item in enumerate(obj.list_filter) + ) + ) def _check_list_filter_item(self, obj, item, label): """ @@ -827,15 +990,17 @@ class ModelAdminChecks(BaseModelAdminChecks): if callable(item) and not isinstance(item, models.Field): # If item is option 3, it should be a ListFilter... if not _issubclass(item, ListFilter): - return must_inherit_from(parent='ListFilter', option=label, - obj=obj, id='admin.E113') + return must_inherit_from( + parent="ListFilter", option=label, obj=obj, id="admin.E113" + ) # ... but not a FieldListFilter. elif issubclass(item, FieldListFilter): return [ checks.Error( - "The value of '%s' must not inherit from 'FieldListFilter'." % label, + "The value of '%s' must not inherit from 'FieldListFilter'." 
+ % label, obj=obj.__class__, - id='admin.E114', + id="admin.E114", ) ] else: @@ -844,7 +1009,12 @@ class ModelAdminChecks(BaseModelAdminChecks): # item is option #2 field, list_filter_class = item if not _issubclass(list_filter_class, FieldListFilter): - return must_inherit_from(parent='FieldListFilter', option='%s[1]' % label, obj=obj, id='admin.E115') + return must_inherit_from( + parent="FieldListFilter", + option="%s[1]" % label, + obj=obj, + id="admin.E115", + ) else: return [] else: @@ -857,55 +1027,73 @@ class ModelAdminChecks(BaseModelAdminChecks): except (NotRelationField, FieldDoesNotExist): return [ checks.Error( - "The value of '%s' refers to '%s', which does not refer to a Field." % (label, field), + "The value of '%s' refers to '%s', which does not refer to a Field." + % (label, field), obj=obj.__class__, - id='admin.E116', + id="admin.E116", ) ] else: return [] def _check_list_select_related(self, obj): - """ Check that list_select_related is a boolean, a list or a tuple. """ + """Check that list_select_related is a boolean, a list or a tuple.""" if not isinstance(obj.list_select_related, (bool, list, tuple)): - return must_be('a boolean, tuple or list', option='list_select_related', obj=obj, id='admin.E117') + return must_be( + "a boolean, tuple or list", + option="list_select_related", + obj=obj, + id="admin.E117", + ) else: return [] def _check_list_per_page(self, obj): - """ Check that list_per_page is an integer. """ + """Check that list_per_page is an integer.""" if not isinstance(obj.list_per_page, int): - return must_be('an integer', option='list_per_page', obj=obj, id='admin.E118') + return must_be( + "an integer", option="list_per_page", obj=obj, id="admin.E118" + ) else: return [] def _check_list_max_show_all(self, obj): - """ Check that list_max_show_all is an integer. 
""" + """Check that list_max_show_all is an integer.""" if not isinstance(obj.list_max_show_all, int): - return must_be('an integer', option='list_max_show_all', obj=obj, id='admin.E119') + return must_be( + "an integer", option="list_max_show_all", obj=obj, id="admin.E119" + ) else: return [] def _check_list_editable(self, obj): - """ Check that list_editable is a sequence of editable fields from - list_display without first element. """ + """Check that list_editable is a sequence of editable fields from + list_display without first element.""" if not isinstance(obj.list_editable, (list, tuple)): - return must_be('a list or tuple', option='list_editable', obj=obj, id='admin.E120') + return must_be( + "a list or tuple", option="list_editable", obj=obj, id="admin.E120" + ) else: - return list(chain.from_iterable( - self._check_list_editable_item(obj, item, "list_editable[%d]" % index) - for index, item in enumerate(obj.list_editable) - )) + return list( + chain.from_iterable( + self._check_list_editable_item( + obj, item, "list_editable[%d]" % index + ) + for index, item in enumerate(obj.list_editable) + ) + ) def _check_list_editable_item(self, obj, field_name, label): try: field = obj.model._meta.get_field(field_name) except FieldDoesNotExist: - return refer_to_missing_field(field=field_name, option=label, obj=obj, id='admin.E121') + return refer_to_missing_field( + field=field_name, option=label, obj=obj, id="admin.E121" + ) else: if field_name not in obj.list_display: return [ @@ -913,54 +1101,58 @@ class ModelAdminChecks(BaseModelAdminChecks): "The value of '%s' refers to '%s', which is not " "contained in 'list_display'." % (label, field_name), obj=obj.__class__, - id='admin.E122', + id="admin.E122", ) ] elif obj.list_display_links and field_name in obj.list_display_links: return [ checks.Error( - "The value of '%s' cannot be in both 'list_editable' and 'list_display_links'." 
% field_name, + "The value of '%s' cannot be in both 'list_editable' and 'list_display_links'." + % field_name, obj=obj.__class__, - id='admin.E123', + id="admin.E123", ) ] # If list_display[0] is in list_editable, check that # list_display_links is set. See #22792 and #26229 for use cases. - elif (obj.list_display[0] == field_name and not obj.list_display_links and - obj.list_display_links is not None): + elif ( + obj.list_display[0] == field_name + and not obj.list_display_links + and obj.list_display_links is not None + ): return [ checks.Error( "The value of '%s' refers to the first field in 'list_display' ('%s'), " - "which cannot be used unless 'list_display_links' is set." % ( - label, obj.list_display[0] - ), + "which cannot be used unless 'list_display_links' is set." + % (label, obj.list_display[0]), obj=obj.__class__, - id='admin.E124', + id="admin.E124", ) ] elif not field.editable: return [ checks.Error( - "The value of '%s' refers to '%s', which is not editable through the admin." % ( - label, field_name - ), + "The value of '%s' refers to '%s', which is not editable through the admin." + % (label, field_name), obj=obj.__class__, - id='admin.E125', + id="admin.E125", ) ] else: return [] def _check_search_fields(self, obj): - """ Check search_fields is a sequence. """ + """Check search_fields is a sequence.""" if not isinstance(obj.search_fields, (list, tuple)): - return must_be('a list or tuple', option='search_fields', obj=obj, id='admin.E126') + return must_be( + "a list or tuple", option="search_fields", obj=obj, id="admin.E126" + ) else: return [] def _check_date_hierarchy(self, obj): - """ Check that date_hierarchy refers to DateField or DateTimeField. """ + """Check that date_hierarchy refers to DateField or DateTimeField.""" if obj.date_hierarchy is None: return [] @@ -973,12 +1165,17 @@ class ModelAdminChecks(BaseModelAdminChecks): "The value of 'date_hierarchy' refers to '%s', which " "does not refer to a Field." 
% obj.date_hierarchy, obj=obj.__class__, - id='admin.E127', + id="admin.E127", ) ] else: if not isinstance(field, (models.DateField, models.DateTimeField)): - return must_be('a DateField or DateTimeField', option='date_hierarchy', obj=obj, id='admin.E128') + return must_be( + "a DateField or DateTimeField", + option="date_hierarchy", + obj=obj, + id="admin.E128", + ) else: return [] @@ -990,20 +1187,21 @@ class ModelAdminChecks(BaseModelAdminChecks): actions = obj._get_base_actions() errors = [] for func, name, _ in actions: - if not hasattr(func, 'allowed_permissions'): + if not hasattr(func, "allowed_permissions"): continue for permission in func.allowed_permissions: - method_name = 'has_%s_permission' % permission + method_name = "has_%s_permission" % permission if not hasattr(obj, method_name): errors.append( checks.Error( - '%s must define a %s() method for the %s action.' % ( + "%s must define a %s() method for the %s action." + % ( obj.__class__.__name__, method_name, func.__name__, ), obj=obj.__class__, - id='admin.E129', + id="admin.E129", ) ) return errors @@ -1014,20 +1212,22 @@ class ModelAdminChecks(BaseModelAdminChecks): names = collections.Counter(name for _, name, _ in obj._get_base_actions()) for name, count in names.items(): if count > 1: - errors.append(checks.Error( - '__name__ attributes of actions defined in %s must be ' - 'unique. Name %r is not unique.' % ( - obj.__class__.__name__, - name, - ), - obj=obj.__class__, - id='admin.E130', - )) + errors.append( + checks.Error( + "__name__ attributes of actions defined in %s must be " + "unique. Name %r is not unique." 
+ % ( + obj.__class__.__name__, + name, + ), + obj=obj.__class__, + id="admin.E130", + ) + ) return errors class InlineModelAdminChecks(BaseModelAdminChecks): - def check(self, inline_obj, **kwargs): parent_model = inline_obj.parent_model return [ @@ -1059,11 +1259,13 @@ class InlineModelAdminChecks(BaseModelAdminChecks): return [ checks.Error( "Cannot exclude the field '%s', because it is the foreign key " - "to the parent model '%s'." % ( - fk.name, parent_model._meta.label, + "to the parent model '%s'." + % ( + fk.name, + parent_model._meta.label, ), obj=obj.__class__, - id='admin.E201', + id="admin.E201", ) ] else: @@ -1073,43 +1275,45 @@ class InlineModelAdminChecks(BaseModelAdminChecks): try: _get_foreign_key(parent_model, obj.model, fk_name=obj.fk_name) except ValueError as e: - return [checks.Error(e.args[0], obj=obj.__class__, id='admin.E202')] + return [checks.Error(e.args[0], obj=obj.__class__, id="admin.E202")] else: return [] def _check_extra(self, obj): - """ Check that extra is an integer. """ + """Check that extra is an integer.""" if not isinstance(obj.extra, int): - return must_be('an integer', option='extra', obj=obj, id='admin.E203') + return must_be("an integer", option="extra", obj=obj, id="admin.E203") else: return [] def _check_max_num(self, obj): - """ Check that max_num is an integer. """ + """Check that max_num is an integer.""" if obj.max_num is None: return [] elif not isinstance(obj.max_num, int): - return must_be('an integer', option='max_num', obj=obj, id='admin.E204') + return must_be("an integer", option="max_num", obj=obj, id="admin.E204") else: return [] def _check_min_num(self, obj): - """ Check that min_num is an integer. 
""" + """Check that min_num is an integer.""" if obj.min_num is None: return [] elif not isinstance(obj.min_num, int): - return must_be('an integer', option='min_num', obj=obj, id='admin.E205') + return must_be("an integer", option="min_num", obj=obj, id="admin.E205") else: return [] def _check_formset(self, obj): - """ Check formset is a subclass of BaseModelFormSet. """ + """Check formset is a subclass of BaseModelFormSet.""" if not _issubclass(obj.formset, BaseModelFormSet): - return must_inherit_from(parent='BaseModelFormSet', option='formset', obj=obj, id='admin.E206') + return must_inherit_from( + parent="BaseModelFormSet", option="formset", obj=obj, id="admin.E206" + ) else: return [] diff --git a/django/contrib/admin/decorators.py b/django/contrib/admin/decorators.py index 4de99580ad..d3ff56a59a 100644 --- a/django/contrib/admin/decorators.py +++ b/django/contrib/admin/decorators.py @@ -17,19 +17,23 @@ def action(function=None, *, permissions=None, description=None): make_published.allowed_permissions = ['publish'] make_published.short_description = 'Mark selected stories as published' """ + def decorator(func): if permissions is not None: func.allowed_permissions = permissions if description is not None: func.short_description = description return func + if function is None: return decorator else: return decorator(function) -def display(function=None, *, boolean=None, ordering=None, description=None, empty_value=None): +def display( + function=None, *, boolean=None, ordering=None, description=None, empty_value=None +): """ Conveniently add attributes to a display function:: @@ -50,11 +54,12 @@ def display(function=None, *, boolean=None, ordering=None, description=None, emp is_published.admin_order_field = '-publish_date' is_published.short_description = 'Is Published?' 
""" + def decorator(func): if boolean is not None and empty_value is not None: raise ValueError( - 'The boolean and empty_value arguments to the @display ' - 'decorator are mutually exclusive.' + "The boolean and empty_value arguments to the @display " + "decorator are mutually exclusive." ) if boolean is not None: func.boolean = boolean @@ -65,6 +70,7 @@ def display(function=None, *, boolean=None, ordering=None, description=None, emp if empty_value is not None: func.empty_value_display = empty_value return func + if function is None: return decorator else: @@ -83,21 +89,23 @@ def register(*models, site=None): The `site` kwarg is an admin site to use instead of the default admin site. """ from django.contrib.admin import ModelAdmin - from django.contrib.admin.sites import AdminSite, site as default_site + from django.contrib.admin.sites import AdminSite + from django.contrib.admin.sites import site as default_site def _model_admin_wrapper(admin_class): if not models: - raise ValueError('At least one model must be passed to register.') + raise ValueError("At least one model must be passed to register.") admin_site = site or default_site if not isinstance(admin_site, AdminSite): - raise ValueError('site must subclass AdminSite') + raise ValueError("site must subclass AdminSite") if not issubclass(admin_class, ModelAdmin): - raise ValueError('Wrapped class must subclass ModelAdmin.') + raise ValueError("Wrapped class must subclass ModelAdmin.") admin_site.register(models, admin_class=admin_class) return admin_class + return _model_admin_wrapper diff --git a/django/contrib/admin/exceptions.py b/django/contrib/admin/exceptions.py index f619bc2252..2ee8f625ca 100644 --- a/django/contrib/admin/exceptions.py +++ b/django/contrib/admin/exceptions.py @@ -3,9 +3,11 @@ from django.core.exceptions import SuspiciousOperation class DisallowedModelAdminLookup(SuspiciousOperation): """Invalid filter was passed to admin view via URL querystring""" + pass class 
DisallowedModelAdminToField(SuspiciousOperation): """Invalid to_field was passed to admin view via URL query string""" + pass diff --git a/django/contrib/admin/filters.py b/django/contrib/admin/filters.py index f6833e57cb..ec97605f14 100644 --- a/django/contrib/admin/filters.py +++ b/django/contrib/admin/filters.py @@ -9,7 +9,9 @@ import datetime from django.contrib.admin.options import IncorrectLookupParameters from django.contrib.admin.utils import ( - get_model_from_relation, prepare_lookup_value, reverse_field_path, + get_model_from_relation, + prepare_lookup_value, + reverse_field_path, ) from django.core.exceptions import ImproperlyConfigured, ValidationError from django.db import models @@ -19,7 +21,7 @@ from django.utils.translation import gettext_lazy as _ class ListFilter: title = None # Human-readable title to appear in the right sidebar. - template = 'admin/filter.html' + template = "admin/filter.html" def __init__(self, request, params, model, model_admin): # This dictionary will eventually contain the request's query string @@ -35,7 +37,9 @@ class ListFilter: """ Return True if some choices would be output for this filter. """ - raise NotImplementedError('subclasses of ListFilter must provide a has_output() method') + raise NotImplementedError( + "subclasses of ListFilter must provide a has_output() method" + ) def choices(self, changelist): """ @@ -43,20 +47,26 @@ class ListFilter: `changelist` is the ChangeList to be displayed. """ - raise NotImplementedError('subclasses of ListFilter must provide a choices() method') + raise NotImplementedError( + "subclasses of ListFilter must provide a choices() method" + ) def queryset(self, request, queryset): """ Return the filtered queryset. 
""" - raise NotImplementedError('subclasses of ListFilter must provide a queryset() method') + raise NotImplementedError( + "subclasses of ListFilter must provide a queryset() method" + ) def expected_parameters(self): """ Return the list of parameter names that are expected from the request's query string and that will be used by this filter. """ - raise NotImplementedError('subclasses of ListFilter must provide an expected_parameters() method') + raise NotImplementedError( + "subclasses of ListFilter must provide an expected_parameters() method" + ) class SimpleListFilter(ListFilter): @@ -94,8 +104,8 @@ class SimpleListFilter(ListFilter): Must be overridden to return a list of tuples (value, verbose value) """ raise NotImplementedError( - 'The SimpleListFilter.lookups() method must be overridden to ' - 'return a list of tuples (value, verbose value).' + "The SimpleListFilter.lookups() method must be overridden to " + "return a list of tuples (value, verbose value)." ) def expected_parameters(self): @@ -103,32 +113,36 @@ class SimpleListFilter(ListFilter): def choices(self, changelist): yield { - 'selected': self.value() is None, - 'query_string': changelist.get_query_string(remove=[self.parameter_name]), - 'display': _('All'), + "selected": self.value() is None, + "query_string": changelist.get_query_string(remove=[self.parameter_name]), + "display": _("All"), } for lookup, title in self.lookup_choices: yield { - 'selected': self.value() == str(lookup), - 'query_string': changelist.get_query_string({self.parameter_name: lookup}), - 'display': title, + "selected": self.value() == str(lookup), + "query_string": changelist.get_query_string( + {self.parameter_name: lookup} + ), + "display": title, } class FieldListFilter(ListFilter): _field_list_filters = [] _take_priority_index = 0 - list_separator = ',' + list_separator = "," def __init__(self, field, request, params, model, model_admin, field_path): self.field = field self.field_path = field_path - self.title = 
getattr(field, 'verbose_name', field_path) + self.title = getattr(field, "verbose_name", field_path) super().__init__(request, params, model, model_admin) for p in self.expected_parameters(): if p in params: value = params.pop(p) - self.used_parameters[p] = prepare_lookup_value(p, value, self.list_separator) + self.used_parameters[p] = prepare_lookup_value( + p, value, self.list_separator + ) def has_output(self): return True @@ -148,7 +162,8 @@ class FieldListFilter(ListFilter): # of fields with some custom filters. The first found in the list # is used in priority. cls._field_list_filters.insert( - cls._take_priority_index, (test, list_filter_class)) + cls._take_priority_index, (test, list_filter_class) + ) cls._take_priority_index += 1 else: cls._field_list_filters.append((test, list_filter_class)) @@ -157,19 +172,21 @@ class FieldListFilter(ListFilter): def create(cls, field, request, params, model, model_admin, field_path): for test, list_filter_class in cls._field_list_filters: if test(field): - return list_filter_class(field, request, params, model, model_admin, field_path=field_path) + return list_filter_class( + field, request, params, model, model_admin, field_path=field_path + ) class RelatedFieldListFilter(FieldListFilter): def __init__(self, field, request, params, model, model_admin, field_path): other_model = get_model_from_relation(field) - self.lookup_kwarg = '%s__%s__exact' % (field_path, field.target_field.name) - self.lookup_kwarg_isnull = '%s__isnull' % field_path + self.lookup_kwarg = "%s__%s__exact" % (field_path, field.target_field.name) + self.lookup_kwarg_isnull = "%s__isnull" % field_path self.lookup_val = params.get(self.lookup_kwarg) self.lookup_val_isnull = params.get(self.lookup_kwarg_isnull) super().__init__(field, request, params, model, model_admin, field_path) self.lookup_choices = self.field_choices(field, request, model_admin) - if hasattr(field, 'verbose_name'): + if hasattr(field, "verbose_name"): self.lookup_title = 
field.verbose_name else: self.lookup_title = other_model._meta.verbose_name @@ -209,21 +226,27 @@ class RelatedFieldListFilter(FieldListFilter): def choices(self, changelist): yield { - 'selected': self.lookup_val is None and not self.lookup_val_isnull, - 'query_string': changelist.get_query_string(remove=[self.lookup_kwarg, self.lookup_kwarg_isnull]), - 'display': _('All'), + "selected": self.lookup_val is None and not self.lookup_val_isnull, + "query_string": changelist.get_query_string( + remove=[self.lookup_kwarg, self.lookup_kwarg_isnull] + ), + "display": _("All"), } for pk_val, val in self.lookup_choices: yield { - 'selected': self.lookup_val == str(pk_val), - 'query_string': changelist.get_query_string({self.lookup_kwarg: pk_val}, [self.lookup_kwarg_isnull]), - 'display': val, + "selected": self.lookup_val == str(pk_val), + "query_string": changelist.get_query_string( + {self.lookup_kwarg: pk_val}, [self.lookup_kwarg_isnull] + ), + "display": val, } if self.include_empty_choice: yield { - 'selected': bool(self.lookup_val_isnull), - 'query_string': changelist.get_query_string({self.lookup_kwarg_isnull: 'True'}, [self.lookup_kwarg]), - 'display': self.empty_value_display, + "selected": bool(self.lookup_val_isnull), + "query_string": changelist.get_query_string( + {self.lookup_kwarg_isnull: "True"}, [self.lookup_kwarg] + ), + "display": self.empty_value_display, } @@ -232,14 +255,19 @@ FieldListFilter.register(lambda f: f.remote_field, RelatedFieldListFilter) class BooleanFieldListFilter(FieldListFilter): def __init__(self, field, request, params, model, model_admin, field_path): - self.lookup_kwarg = '%s__exact' % field_path - self.lookup_kwarg2 = '%s__isnull' % field_path + self.lookup_kwarg = "%s__exact" % field_path + self.lookup_kwarg2 = "%s__isnull" % field_path self.lookup_val = params.get(self.lookup_kwarg) self.lookup_val2 = params.get(self.lookup_kwarg2) super().__init__(field, request, params, model, model_admin, field_path) - if 
(self.used_parameters and self.lookup_kwarg in self.used_parameters and - self.used_parameters[self.lookup_kwarg] in ('1', '0')): - self.used_parameters[self.lookup_kwarg] = bool(int(self.used_parameters[self.lookup_kwarg])) + if ( + self.used_parameters + and self.lookup_kwarg in self.used_parameters + and self.used_parameters[self.lookup_kwarg] in ("1", "0") + ): + self.used_parameters[self.lookup_kwarg] = bool( + int(self.used_parameters[self.lookup_kwarg]) + ) def expected_parameters(self): return [self.lookup_kwarg, self.lookup_kwarg2] @@ -247,30 +275,36 @@ class BooleanFieldListFilter(FieldListFilter): def choices(self, changelist): field_choices = dict(self.field.flatchoices) for lookup, title in ( - (None, _('All')), - ('1', field_choices.get(True, _('Yes'))), - ('0', field_choices.get(False, _('No'))), + (None, _("All")), + ("1", field_choices.get(True, _("Yes"))), + ("0", field_choices.get(False, _("No"))), ): yield { - 'selected': self.lookup_val == lookup and not self.lookup_val2, - 'query_string': changelist.get_query_string({self.lookup_kwarg: lookup}, [self.lookup_kwarg2]), - 'display': title, + "selected": self.lookup_val == lookup and not self.lookup_val2, + "query_string": changelist.get_query_string( + {self.lookup_kwarg: lookup}, [self.lookup_kwarg2] + ), + "display": title, } if self.field.null: yield { - 'selected': self.lookup_val2 == 'True', - 'query_string': changelist.get_query_string({self.lookup_kwarg2: 'True'}, [self.lookup_kwarg]), - 'display': field_choices.get(None, _('Unknown')), + "selected": self.lookup_val2 == "True", + "query_string": changelist.get_query_string( + {self.lookup_kwarg2: "True"}, [self.lookup_kwarg] + ), + "display": field_choices.get(None, _("Unknown")), } -FieldListFilter.register(lambda f: isinstance(f, models.BooleanField), BooleanFieldListFilter) +FieldListFilter.register( + lambda f: isinstance(f, models.BooleanField), BooleanFieldListFilter +) class ChoicesFieldListFilter(FieldListFilter): def 
__init__(self, field, request, params, model, model_admin, field_path): - self.lookup_kwarg = '%s__exact' % field_path - self.lookup_kwarg_isnull = '%s__isnull' % field_path + self.lookup_kwarg = "%s__exact" % field_path + self.lookup_kwarg_isnull = "%s__isnull" % field_path self.lookup_val = params.get(self.lookup_kwarg) self.lookup_val_isnull = params.get(self.lookup_kwarg_isnull) super().__init__(field, request, params, model, model_admin, field_path) @@ -280,25 +314,31 @@ class ChoicesFieldListFilter(FieldListFilter): def choices(self, changelist): yield { - 'selected': self.lookup_val is None, - 'query_string': changelist.get_query_string(remove=[self.lookup_kwarg, self.lookup_kwarg_isnull]), - 'display': _('All') + "selected": self.lookup_val is None, + "query_string": changelist.get_query_string( + remove=[self.lookup_kwarg, self.lookup_kwarg_isnull] + ), + "display": _("All"), } - none_title = '' + none_title = "" for lookup, title in self.field.flatchoices: if lookup is None: none_title = title continue yield { - 'selected': str(lookup) == self.lookup_val, - 'query_string': changelist.get_query_string({self.lookup_kwarg: lookup}, [self.lookup_kwarg_isnull]), - 'display': title, + "selected": str(lookup) == self.lookup_val, + "query_string": changelist.get_query_string( + {self.lookup_kwarg: lookup}, [self.lookup_kwarg_isnull] + ), + "display": title, } if none_title: yield { - 'selected': bool(self.lookup_val_isnull), - 'query_string': changelist.get_query_string({self.lookup_kwarg_isnull: 'True'}, [self.lookup_kwarg]), - 'display': none_title, + "selected": bool(self.lookup_val_isnull), + "query_string": changelist.get_query_string( + {self.lookup_kwarg_isnull: "True"}, [self.lookup_kwarg] + ), + "display": none_title, } @@ -307,8 +347,10 @@ FieldListFilter.register(lambda f: bool(f.choices), ChoicesFieldListFilter) class DateFieldListFilter(FieldListFilter): def __init__(self, field, request, params, model, model_admin, field_path): - self.field_generic 
= '%s__' % field_path - self.date_params = {k: v for k, v in params.items() if k.startswith(self.field_generic)} + self.field_generic = "%s__" % field_path + self.date_params = { + k: v for k, v in params.items() if k.startswith(self.field_generic) + } now = timezone.now() # When time zone support is enabled, convert "now" to the user's time @@ -318,7 +360,7 @@ class DateFieldListFilter(FieldListFilter): if isinstance(field, models.DateTimeField): today = now.replace(hour=0, minute=0, second=0, microsecond=0) - else: # field is a models.DateField + else: # field is a models.DateField today = now.date() tomorrow = today + datetime.timedelta(days=1) if today.month == 12: @@ -327,32 +369,44 @@ class DateFieldListFilter(FieldListFilter): next_month = today.replace(month=today.month + 1, day=1) next_year = today.replace(year=today.year + 1, month=1, day=1) - self.lookup_kwarg_since = '%s__gte' % field_path - self.lookup_kwarg_until = '%s__lt' % field_path + self.lookup_kwarg_since = "%s__gte" % field_path + self.lookup_kwarg_until = "%s__lt" % field_path self.links = ( - (_('Any date'), {}), - (_('Today'), { - self.lookup_kwarg_since: str(today), - self.lookup_kwarg_until: str(tomorrow), - }), - (_('Past 7 days'), { - self.lookup_kwarg_since: str(today - datetime.timedelta(days=7)), - self.lookup_kwarg_until: str(tomorrow), - }), - (_('This month'), { - self.lookup_kwarg_since: str(today.replace(day=1)), - self.lookup_kwarg_until: str(next_month), - }), - (_('This year'), { - self.lookup_kwarg_since: str(today.replace(month=1, day=1)), - self.lookup_kwarg_until: str(next_year), - }), + (_("Any date"), {}), + ( + _("Today"), + { + self.lookup_kwarg_since: str(today), + self.lookup_kwarg_until: str(tomorrow), + }, + ), + ( + _("Past 7 days"), + { + self.lookup_kwarg_since: str(today - datetime.timedelta(days=7)), + self.lookup_kwarg_until: str(tomorrow), + }, + ), + ( + _("This month"), + { + self.lookup_kwarg_since: str(today.replace(day=1)), + self.lookup_kwarg_until: 
str(next_month), + }, + ), + ( + _("This year"), + { + self.lookup_kwarg_since: str(today.replace(month=1, day=1)), + self.lookup_kwarg_until: str(next_year), + }, + ), ) if field.null: - self.lookup_kwarg_isnull = '%s__isnull' % field_path + self.lookup_kwarg_isnull = "%s__isnull" % field_path self.links += ( - (_('No date'), {self.field_generic + 'isnull': 'True'}), - (_('Has date'), {self.field_generic + 'isnull': 'False'}), + (_("No date"), {self.field_generic + "isnull": "True"}), + (_("Has date"), {self.field_generic + "isnull": "False"}), ) super().__init__(field, request, params, model, model_admin, field_path) @@ -365,14 +419,15 @@ class DateFieldListFilter(FieldListFilter): def choices(self, changelist): for title, param_dict in self.links: yield { - 'selected': self.date_params == param_dict, - 'query_string': changelist.get_query_string(param_dict, [self.field_generic]), - 'display': title, + "selected": self.date_params == param_dict, + "query_string": changelist.get_query_string( + param_dict, [self.field_generic] + ), + "display": title, } -FieldListFilter.register( - lambda f: isinstance(f, models.DateField), DateFieldListFilter) +FieldListFilter.register(lambda f: isinstance(f, models.DateField), DateFieldListFilter) # This should be registered last, because it's a last resort. 
For example, @@ -381,7 +436,7 @@ FieldListFilter.register( class AllValuesFieldListFilter(FieldListFilter): def __init__(self, field, request, params, model, model_admin, field_path): self.lookup_kwarg = field_path - self.lookup_kwarg_isnull = '%s__isnull' % field_path + self.lookup_kwarg_isnull = "%s__isnull" % field_path self.lookup_val = params.get(self.lookup_kwarg) self.lookup_val_isnull = params.get(self.lookup_kwarg_isnull) self.empty_value_display = model_admin.get_empty_value_display() @@ -391,7 +446,9 @@ class AllValuesFieldListFilter(FieldListFilter): queryset = model_admin.get_queryset(request) else: queryset = parent_model._default_manager.all() - self.lookup_choices = queryset.distinct().order_by(field.name).values_list(field.name, flat=True) + self.lookup_choices = ( + queryset.distinct().order_by(field.name).values_list(field.name, flat=True) + ) super().__init__(field, request, params, model, model_admin, field_path) def expected_parameters(self): @@ -399,9 +456,11 @@ class AllValuesFieldListFilter(FieldListFilter): def choices(self, changelist): yield { - 'selected': self.lookup_val is None and self.lookup_val_isnull is None, - 'query_string': changelist.get_query_string(remove=[self.lookup_kwarg, self.lookup_kwarg_isnull]), - 'display': _('All'), + "selected": self.lookup_val is None and self.lookup_val_isnull is None, + "query_string": changelist.get_query_string( + remove=[self.lookup_kwarg, self.lookup_kwarg_isnull] + ), + "display": _("All"), } include_none = False for val in self.lookup_choices: @@ -410,15 +469,19 @@ class AllValuesFieldListFilter(FieldListFilter): continue val = str(val) yield { - 'selected': self.lookup_val == val, - 'query_string': changelist.get_query_string({self.lookup_kwarg: val}, [self.lookup_kwarg_isnull]), - 'display': val, + "selected": self.lookup_val == val, + "query_string": changelist.get_query_string( + {self.lookup_kwarg: val}, [self.lookup_kwarg_isnull] + ), + "display": val, } if include_none: yield { - 
'selected': bool(self.lookup_val_isnull), - 'query_string': changelist.get_query_string({self.lookup_kwarg_isnull: 'True'}, [self.lookup_kwarg]), - 'display': self.empty_value_display, + "selected": bool(self.lookup_val_isnull), + "query_string": changelist.get_query_string( + {self.lookup_kwarg_isnull: "True"}, [self.lookup_kwarg] + ), + "display": self.empty_value_display, } @@ -427,9 +490,15 @@ FieldListFilter.register(lambda f: True, AllValuesFieldListFilter) class RelatedOnlyFieldListFilter(RelatedFieldListFilter): def field_choices(self, field, request, model_admin): - pk_qs = model_admin.get_queryset(request).distinct().values_list('%s__pk' % self.field_path, flat=True) + pk_qs = ( + model_admin.get_queryset(request) + .distinct() + .values_list("%s__pk" % self.field_path, flat=True) + ) ordering = self.field_admin_ordering(field, request, model_admin) - return field.get_choices(include_blank=False, limit_choices_to={'pk__in': pk_qs}, ordering=ordering) + return field.get_choices( + include_blank=False, limit_choices_to={"pk__in": pk_qs}, ordering=ordering + ) class EmptyFieldListFilter(FieldListFilter): @@ -437,28 +506,29 @@ class EmptyFieldListFilter(FieldListFilter): if not field.empty_strings_allowed and not field.null: raise ImproperlyConfigured( "The list filter '%s' cannot be used with field '%s' which " - "doesn't allow empty strings and nulls." % ( + "doesn't allow empty strings and nulls." 
+ % ( self.__class__.__name__, field.name, ) ) - self.lookup_kwarg = '%s__isempty' % field_path + self.lookup_kwarg = "%s__isempty" % field_path self.lookup_val = params.get(self.lookup_kwarg) super().__init__(field, request, params, model, model_admin, field_path) def queryset(self, request, queryset): if self.lookup_kwarg not in self.used_parameters: return queryset - if self.lookup_val not in ('0', '1'): + if self.lookup_val not in ("0", "1"): raise IncorrectLookupParameters lookup_conditions = [] if self.field.empty_strings_allowed: - lookup_conditions.append((self.field_path, '')) + lookup_conditions.append((self.field_path, "")) if self.field.null: - lookup_conditions.append((f'{self.field_path}__isnull', True)) + lookup_conditions.append((f"{self.field_path}__isnull", True)) lookup_condition = models.Q(*lookup_conditions, _connector=models.Q.OR) - if self.lookup_val == '1': + if self.lookup_val == "1": return queryset.filter(lookup_condition) return queryset.exclude(lookup_condition) @@ -467,12 +537,14 @@ class EmptyFieldListFilter(FieldListFilter): def choices(self, changelist): for lookup, title in ( - (None, _('All')), - ('1', _('Empty')), - ('0', _('Not empty')), + (None, _("All")), + ("1", _("Empty")), + ("0", _("Not empty")), ): yield { - 'selected': self.lookup_val == lookup, - 'query_string': changelist.get_query_string({self.lookup_kwarg: lookup}), - 'display': title, + "selected": self.lookup_val == lookup, + "query_string": changelist.get_query_string( + {self.lookup_kwarg: lookup} + ), + "display": title, } diff --git a/django/contrib/admin/forms.py b/django/contrib/admin/forms.py index ee275095e3..bbb072bdb2 100644 --- a/django/contrib/admin/forms.py +++ b/django/contrib/admin/forms.py @@ -7,24 +7,25 @@ class AdminAuthenticationForm(AuthenticationForm): """ A custom authentication form used in the admin app. 
""" + error_messages = { **AuthenticationForm.error_messages, - 'invalid_login': _( + "invalid_login": _( "Please enter the correct %(username)s and password for a staff " "account. Note that both fields may be case-sensitive." ), } - required_css_class = 'required' + required_css_class = "required" def confirm_login_allowed(self, user): super().confirm_login_allowed(user) if not user.is_staff: raise ValidationError( - self.error_messages['invalid_login'], - code='invalid_login', - params={'username': self.username_field.verbose_name} + self.error_messages["invalid_login"], + code="invalid_login", + params={"username": self.username_field.verbose_name}, ) class AdminPasswordChangeForm(PasswordChangeForm): - required_css_class = 'required' + required_css_class = "required" diff --git a/django/contrib/admin/helpers.py b/django/contrib/admin/helpers.py index dae626b550..2e7a20a49b 100644 --- a/django/contrib/admin/helpers.py +++ b/django/contrib/admin/helpers.py @@ -2,43 +2,57 @@ import json from django import forms from django.contrib.admin.utils import ( - display_for_field, flatten_fieldsets, help_text_for_field, label_for_field, - lookup_field, quote, + display_for_field, + flatten_fieldsets, + help_text_for_field, + label_for_field, + lookup_field, + quote, ) from django.core.exceptions import ObjectDoesNotExist from django.db.models.fields.related import ( - ForeignObjectRel, ManyToManyRel, OneToOneField, + ForeignObjectRel, + ManyToManyRel, + OneToOneField, ) from django.forms.utils import flatatt from django.template.defaultfilters import capfirst, linebreaksbr from django.urls import NoReverseMatch, reverse from django.utils.html import conditional_escape, format_html from django.utils.safestring import mark_safe -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ -ACTION_CHECKBOX_NAME = '_selected_action' +ACTION_CHECKBOX_NAME = 
"_selected_action" class ActionForm(forms.Form): - action = forms.ChoiceField(label=_('Action:')) + action = forms.ChoiceField(label=_("Action:")) select_across = forms.BooleanField( - label='', + label="", required=False, initial=0, - widget=forms.HiddenInput({'class': 'select-across'}), + widget=forms.HiddenInput({"class": "select-across"}), ) -checkbox = forms.CheckboxInput({'class': 'action-select'}, lambda value: False) +checkbox = forms.CheckboxInput({"class": "action-select"}, lambda value: False) class AdminForm: - def __init__(self, form, fieldsets, prepopulated_fields, readonly_fields=None, model_admin=None): + def __init__( + self, + form, + fieldsets, + prepopulated_fields, + readonly_fields=None, + model_admin=None, + ): self.form, self.fieldsets = form, fieldsets - self.prepopulated_fields = [{ - 'field': form[field_name], - 'dependencies': [form[f] for f in dependencies] - } for field_name, dependencies in prepopulated_fields.items()] + self.prepopulated_fields = [ + {"field": form[field_name], "dependencies": [form[f] for f in dependencies]} + for field_name, dependencies in prepopulated_fields.items() + ] self.model_admin = model_admin if readonly_fields is None: readonly_fields = () @@ -46,18 +60,19 @@ class AdminForm: def __repr__(self): return ( - f'<{self.__class__.__qualname__}: ' - f'form={self.form.__class__.__qualname__} ' - f'fieldsets={self.fieldsets!r}>' + f"<{self.__class__.__qualname__}: " + f"form={self.form.__class__.__qualname__} " + f"fieldsets={self.fieldsets!r}>" ) def __iter__(self): for name, options in self.fieldsets: yield Fieldset( - self.form, name, + self.form, + name, readonly_fields=self.readonly_fields, model_admin=self.model_admin, - **options + **options, ) @property @@ -77,24 +92,34 @@ class AdminForm: class Fieldset: - def __init__(self, form, name=None, readonly_fields=(), fields=(), classes=(), - description=None, model_admin=None): + def __init__( + self, + form, + name=None, + readonly_fields=(), + fields=(), + 
classes=(), + description=None, + model_admin=None, + ): self.form = form self.name, self.fields = name, fields - self.classes = ' '.join(classes) + self.classes = " ".join(classes) self.description = description self.model_admin = model_admin self.readonly_fields = readonly_fields @property def media(self): - if 'collapse' in self.classes: - return forms.Media(js=['admin/js/collapse.js']) + if "collapse" in self.classes: + return forms.Media(js=["admin/js/collapse.js"]) return forms.Media() def __iter__(self): for field in self.fields: - yield Fieldline(self.form, field, self.readonly_fields, model_admin=self.model_admin) + yield Fieldline( + self.form, field, self.readonly_fields, model_admin=self.model_admin + ) class Fieldline: @@ -116,15 +141,19 @@ class Fieldline: def __iter__(self): for i, field in enumerate(self.fields): if field in self.readonly_fields: - yield AdminReadonlyField(self.form, field, is_first=(i == 0), model_admin=self.model_admin) + yield AdminReadonlyField( + self.form, field, is_first=(i == 0), model_admin=self.model_admin + ) else: yield AdminField(self.form, field, is_first=(i == 0)) def errors(self): return mark_safe( - '\n'.join( - self.form[f].errors.as_ul() for f in self.fields if f not in self.readonly_fields - ).strip('\n') + "\n".join( + self.form[f].errors.as_ul() + for f in self.fields + if f not in self.readonly_fields + ).strip("\n") ) @@ -139,18 +168,19 @@ class AdminField: classes = [] contents = conditional_escape(self.field.label) if self.is_checkbox: - classes.append('vCheckboxLabel') + classes.append("vCheckboxLabel") if self.field.field.required: - classes.append('required') + classes.append("required") if not self.is_first: - classes.append('inline') - attrs = {'class': ' '.join(classes)} if classes else {} + classes.append("inline") + attrs = {"class": " ".join(classes)} if classes else {} # checkboxes should not have a label suffix as the checkbox appears # to the left of the label. 
return self.field.label_tag( - contents=mark_safe(contents), attrs=attrs, - label_suffix='' if self.is_checkbox else None, + contents=mark_safe(contents), + attrs=attrs, + label_suffix="" if self.is_checkbox else None, ) def errors(self): @@ -163,7 +193,7 @@ class AdminReadonlyField: # {{ field.name }} must be a useful class name to identify the field. # For convenience, store other field-related data here too. if callable(field): - class_name = field.__name__ if field.__name__ != '' else '' + class_name = field.__name__ if field.__name__ != "" else "" else: class_name = field @@ -183,11 +213,11 @@ class AdminReadonlyField: is_hidden = False self.field = { - 'name': class_name, - 'label': label, - 'help_text': help_text, - 'field': field, - 'is_hidden': is_hidden, + "name": class_name, + "label": label, + "help_text": help_text, + "field": field, + "is_hidden": is_hidden, } self.form = form self.model_admin = model_admin @@ -200,11 +230,16 @@ class AdminReadonlyField: attrs = {} if not self.is_first: attrs["class"] = "inline" - label = self.field['label'] - return format_html('{}{}', flatatt(attrs), capfirst(label), self.form.label_suffix) + label = self.field["label"] + return format_html( + "{}{}", + flatatt(attrs), + capfirst(label), + self.form.label_suffix, + ) def get_admin_url(self, remote_field, remote_obj): - url_name = 'admin:%s_%s_change' % ( + url_name = "admin:%s_%s_change" % ( remote_field.model._meta.app_label, remote_field.model._meta.model_name, ) @@ -220,7 +255,12 @@ class AdminReadonlyField: def contents(self): from django.contrib.admin.templatetags.admin_list import _boolean_icon - field, obj, model_admin = self.field['field'], self.form.instance, self.model_admin + + field, obj, model_admin = ( + self.field["field"], + self.form.instance, + self.model_admin, + ) try: f, attr, value = lookup_field(field, obj, model_admin) except (AttributeError, ValueError, ObjectDoesNotExist): @@ -230,10 +270,10 @@ class AdminReadonlyField: widget = 
self.form[field].field.widget # This isn't elegant but suffices for contrib.auth's # ReadOnlyPasswordHashWidget. - if getattr(widget, 'read_only', False): + if getattr(widget, "read_only", False): return widget.render(field, value) if f is None: - if getattr(attr, 'boolean', False): + if getattr(attr, "boolean", False): result_repr = _boolean_icon(value) else: if hasattr(value, "__html__"): @@ -244,8 +284,8 @@ class AdminReadonlyField: if isinstance(f.remote_field, ManyToManyRel) and value is not None: result_repr = ", ".join(map(str, value.all())) elif ( - isinstance(f.remote_field, (ForeignObjectRel, OneToOneField)) and - value is not None + isinstance(f.remote_field, (ForeignObjectRel, OneToOneField)) + and value is not None ): result_repr = self.get_admin_url(f.remote_field, value) else: @@ -258,10 +298,20 @@ class InlineAdminFormSet: """ A wrapper around an inline formset for use in the admin system. """ - def __init__(self, inline, formset, fieldsets, prepopulated_fields=None, - readonly_fields=None, model_admin=None, has_add_permission=True, - has_change_permission=True, has_delete_permission=True, - has_view_permission=True): + + def __init__( + self, + inline, + formset, + fieldsets, + prepopulated_fields=None, + readonly_fields=None, + model_admin=None, + has_add_permission=True, + has_change_permission=True, + has_delete_permission=True, + has_view_permission=True, + ): self.opts = inline self.formset = formset self.fieldsets = fieldsets @@ -272,7 +322,7 @@ class InlineAdminFormSet: if prepopulated_fields is None: prepopulated_fields = {} self.prepopulated_fields = prepopulated_fields - self.classes = ' '.join(inline.classes) if inline.classes else '' + self.classes = " ".join(inline.classes) if inline.classes else "" self.has_add_permission = has_add_permission self.has_change_permission = has_change_permission self.has_delete_permission = has_delete_permission @@ -282,25 +332,43 @@ class InlineAdminFormSet: if self.has_change_permission: 
readonly_fields_for_editing = self.readonly_fields else: - readonly_fields_for_editing = self.readonly_fields + flatten_fieldsets(self.fieldsets) + readonly_fields_for_editing = self.readonly_fields + flatten_fieldsets( + self.fieldsets + ) - for form, original in zip(self.formset.initial_forms, self.formset.get_queryset()): + for form, original in zip( + self.formset.initial_forms, self.formset.get_queryset() + ): view_on_site_url = self.opts.get_view_on_site_url(original) yield InlineAdminForm( - self.formset, form, self.fieldsets, self.prepopulated_fields, - original, readonly_fields_for_editing, model_admin=self.opts, + self.formset, + form, + self.fieldsets, + self.prepopulated_fields, + original, + readonly_fields_for_editing, + model_admin=self.opts, view_on_site_url=view_on_site_url, ) for form in self.formset.extra_forms: yield InlineAdminForm( - self.formset, form, self.fieldsets, self.prepopulated_fields, - None, self.readonly_fields, model_admin=self.opts, + self.formset, + form, + self.fieldsets, + self.prepopulated_fields, + None, + self.readonly_fields, + model_admin=self.opts, ) if self.has_add_permission: yield InlineAdminForm( - self.formset, self.formset.empty_form, - self.fieldsets, self.prepopulated_fields, None, - self.readonly_fields, model_admin=self.opts, + self.formset, + self.formset.empty_form, + self.fieldsets, + self.prepopulated_fields, + None, + self.readonly_fields, + model_admin=self.opts, ) def fields(self): @@ -317,42 +385,49 @@ class InlineAdminFormSet: if form_field is not None: widget_is_hidden = form_field.widget.is_hidden yield { - 'name': field_name, - 'label': meta_labels.get(field_name) or label_for_field( + "name": field_name, + "label": meta_labels.get(field_name) + or label_for_field( field_name, self.opts.model, self.opts, form=empty_form, ), - 'widget': {'is_hidden': widget_is_hidden}, - 'required': False, - 'help_text': meta_help_texts.get(field_name) or help_text_for_field(field_name, self.opts.model), + "widget": 
{"is_hidden": widget_is_hidden}, + "required": False, + "help_text": meta_help_texts.get(field_name) + or help_text_for_field(field_name, self.opts.model), } else: form_field = empty_form.fields[field_name] label = form_field.label if label is None: - label = label_for_field(field_name, self.opts.model, self.opts, form=empty_form) + label = label_for_field( + field_name, self.opts.model, self.opts, form=empty_form + ) yield { - 'name': field_name, - 'label': label, - 'widget': form_field.widget, - 'required': form_field.required, - 'help_text': form_field.help_text, + "name": field_name, + "label": label, + "widget": form_field.widget, + "required": form_field.required, + "help_text": form_field.help_text, } def inline_formset_data(self): verbose_name = self.opts.verbose_name - return json.dumps({ - 'name': '#%s' % self.formset.prefix, - 'options': { - 'prefix': self.formset.prefix, - 'addText': gettext('Add another %(verbose_name)s') % { - 'verbose_name': capfirst(verbose_name), + return json.dumps( + { + "name": "#%s" % self.formset.prefix, + "options": { + "prefix": self.formset.prefix, + "addText": gettext("Add another %(verbose_name)s") + % { + "verbose_name": capfirst(verbose_name), + }, + "deleteText": gettext("Remove"), }, - 'deleteText': gettext('Remove'), } - }) + ) @property def forms(self): @@ -374,31 +449,51 @@ class InlineAdminForm(AdminForm): """ A wrapper around an inline form for use in the admin system. 
""" - def __init__(self, formset, form, fieldsets, prepopulated_fields, original, - readonly_fields=None, model_admin=None, view_on_site_url=None): + + def __init__( + self, + formset, + form, + fieldsets, + prepopulated_fields, + original, + readonly_fields=None, + model_admin=None, + view_on_site_url=None, + ): self.formset = formset self.model_admin = model_admin self.original = original self.show_url = original and view_on_site_url is not None self.absolute_url = view_on_site_url - super().__init__(form, fieldsets, prepopulated_fields, readonly_fields, model_admin) + super().__init__( + form, fieldsets, prepopulated_fields, readonly_fields, model_admin + ) def __iter__(self): for name, options in self.fieldsets: yield InlineFieldset( - self.formset, self.form, name, self.readonly_fields, - model_admin=self.model_admin, **options + self.formset, + self.form, + name, + self.readonly_fields, + model_admin=self.model_admin, + **options, ) def needs_explicit_pk_field(self): return ( # Auto fields are editable, so check for auto or non-editable pk. - self.form._meta.model._meta.auto_field or not self.form._meta.model._meta.pk.editable or + self.form._meta.model._meta.auto_field + or not self.form._meta.model._meta.pk.editable + or # Also search any parents for an auto field. (The pk info is # propagated to child models so that does not need to be checked # in parents.) 
- any(parent._meta.auto_field or not parent._meta.model._meta.pk.editable - for parent in self.form._meta.model._meta.get_parent_list()) + any( + parent._meta.auto_field or not parent._meta.model._meta.pk.editable + for parent in self.form._meta.model._meta.get_parent_list() + ) ) def pk_field(self): @@ -413,10 +508,12 @@ class InlineAdminForm(AdminForm): def deletion_field(self): from django.forms.formsets import DELETION_FIELD_NAME + return AdminField(self.form, DELETION_FIELD_NAME, False) def ordering_field(self): from django.forms.formsets import ORDERING_FIELD_NAME + return AdminField(self.form, ORDERING_FIELD_NAME, False) @@ -429,11 +526,14 @@ class InlineFieldset(Fieldset): fk = getattr(self.formset, "fk", None) for field in self.fields: if not fk or fk.name != field: - yield Fieldline(self.form, field, self.readonly_fields, model_admin=self.model_admin) + yield Fieldline( + self.form, field, self.readonly_fields, model_admin=self.model_admin + ) class AdminErrorList(forms.utils.ErrorList): """Store errors for the form/formsets in an add/change view.""" + def __init__(self, form, inline_formsets): super().__init__() diff --git a/django/contrib/admin/migrations/0001_initial.py b/django/contrib/admin/migrations/0001_initial.py index d6d35bdd72..d02e128497 100644 --- a/django/contrib/admin/migrations/0001_initial.py +++ b/django/contrib/admin/migrations/0001_initial.py @@ -7,40 +7,70 @@ class Migration(migrations.Migration): dependencies = [ migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('contenttypes', '__first__'), + ("contenttypes", "__first__"), ] operations = [ migrations.CreateModel( - name='LogEntry', + name="LogEntry", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('action_time', models.DateTimeField(auto_now=True, verbose_name='action time')), - ('object_id', models.TextField(null=True, verbose_name='object id', blank=True)), - ('object_repr', 
models.CharField(max_length=200, verbose_name='object repr')), - ('action_flag', models.PositiveSmallIntegerField(verbose_name='action flag')), - ('change_message', models.TextField(verbose_name='change message', blank=True)), - ('content_type', models.ForeignKey( - on_delete=models.SET_NULL, - blank=True, null=True, - to='contenttypes.ContentType', - verbose_name='content type', - )), - ('user', models.ForeignKey( - to=settings.AUTH_USER_MODEL, - on_delete=models.CASCADE, - verbose_name='user', - )), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "action_time", + models.DateTimeField(auto_now=True, verbose_name="action time"), + ), + ( + "object_id", + models.TextField(null=True, verbose_name="object id", blank=True), + ), + ( + "object_repr", + models.CharField(max_length=200, verbose_name="object repr"), + ), + ( + "action_flag", + models.PositiveSmallIntegerField(verbose_name="action flag"), + ), + ( + "change_message", + models.TextField(verbose_name="change message", blank=True), + ), + ( + "content_type", + models.ForeignKey( + on_delete=models.SET_NULL, + blank=True, + null=True, + to="contenttypes.ContentType", + verbose_name="content type", + ), + ), + ( + "user", + models.ForeignKey( + to=settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + verbose_name="user", + ), + ), ], options={ - 'ordering': ['-action_time'], - 'db_table': 'django_admin_log', - 'verbose_name': 'log entry', - 'verbose_name_plural': 'log entries', + "ordering": ["-action_time"], + "db_table": "django_admin_log", + "verbose_name": "log entry", + "verbose_name_plural": "log entries", }, bases=(models.Model,), managers=[ - ('objects', django.contrib.admin.models.LogEntryManager()), + ("objects", django.contrib.admin.models.LogEntryManager()), ], ), ] diff --git a/django/contrib/admin/migrations/0002_logentry_remove_auto_add.py b/django/contrib/admin/migrations/0002_logentry_remove_auto_add.py index 
a2b19162f2..4e83978e21 100644 --- a/django/contrib/admin/migrations/0002_logentry_remove_auto_add.py +++ b/django/contrib/admin/migrations/0002_logentry_remove_auto_add.py @@ -5,16 +5,16 @@ from django.utils import timezone class Migration(migrations.Migration): dependencies = [ - ('admin', '0001_initial'), + ("admin", "0001_initial"), ] # No database changes; removes auto_add and adds default/editable. operations = [ migrations.AlterField( - model_name='logentry', - name='action_time', + model_name="logentry", + name="action_time", field=models.DateTimeField( - verbose_name='action time', + verbose_name="action time", default=timezone.now, editable=False, ), diff --git a/django/contrib/admin/migrations/0003_logentry_add_action_flag_choices.py b/django/contrib/admin/migrations/0003_logentry_add_action_flag_choices.py index a041a9de0e..59b22314d4 100644 --- a/django/contrib/admin/migrations/0003_logentry_add_action_flag_choices.py +++ b/django/contrib/admin/migrations/0003_logentry_add_action_flag_choices.py @@ -4,17 +4,17 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('admin', '0002_logentry_remove_auto_add'), + ("admin", "0002_logentry_remove_auto_add"), ] # No database changes; adds choices to action_flag. 
operations = [ migrations.AlterField( - model_name='logentry', - name='action_flag', + model_name="logentry", + name="action_flag", field=models.PositiveSmallIntegerField( - choices=[(1, 'Addition'), (2, 'Change'), (3, 'Deletion')], - verbose_name='action flag', + choices=[(1, "Addition"), (2, "Change"), (3, "Deletion")], + verbose_name="action flag", ), ), ] diff --git a/django/contrib/admin/models.py b/django/contrib/admin/models.py index a0fbb02afd..917fc77dba 100644 --- a/django/contrib/admin/models.py +++ b/django/contrib/admin/models.py @@ -7,23 +7,32 @@ from django.db import models from django.urls import NoReverseMatch, reverse from django.utils import timezone from django.utils.text import get_text_list -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ ADDITION = 1 CHANGE = 2 DELETION = 3 ACTION_FLAG_CHOICES = ( - (ADDITION, _('Addition')), - (CHANGE, _('Change')), - (DELETION, _('Deletion')), + (ADDITION, _("Addition")), + (CHANGE, _("Change")), + (DELETION, _("Deletion")), ) class LogEntryManager(models.Manager): use_in_migrations = True - def log_action(self, user_id, content_type_id, object_id, object_repr, action_flag, change_message=''): + def log_action( + self, + user_id, + content_type_id, + object_id, + object_repr, + action_flag, + change_message="", + ): if isinstance(change_message, list): change_message = json.dumps(change_message) return self.model.objects.create( @@ -38,51 +47,54 @@ class LogEntryManager(models.Manager): class LogEntry(models.Model): action_time = models.DateTimeField( - _('action time'), + _("action time"), default=timezone.now, editable=False, ) user = models.ForeignKey( settings.AUTH_USER_MODEL, models.CASCADE, - verbose_name=_('user'), + verbose_name=_("user"), ) content_type = models.ForeignKey( ContentType, models.SET_NULL, - verbose_name=_('content type'), - blank=True, null=True, + 
verbose_name=_("content type"), + blank=True, + null=True, ) - object_id = models.TextField(_('object id'), blank=True, null=True) + object_id = models.TextField(_("object id"), blank=True, null=True) # Translators: 'repr' means representation (https://docs.python.org/library/functions.html#repr) - object_repr = models.CharField(_('object repr'), max_length=200) - action_flag = models.PositiveSmallIntegerField(_('action flag'), choices=ACTION_FLAG_CHOICES) + object_repr = models.CharField(_("object repr"), max_length=200) + action_flag = models.PositiveSmallIntegerField( + _("action flag"), choices=ACTION_FLAG_CHOICES + ) # change_message is either a string or a JSON structure - change_message = models.TextField(_('change message'), blank=True) + change_message = models.TextField(_("change message"), blank=True) objects = LogEntryManager() class Meta: - verbose_name = _('log entry') - verbose_name_plural = _('log entries') - db_table = 'django_admin_log' - ordering = ['-action_time'] + verbose_name = _("log entry") + verbose_name_plural = _("log entries") + db_table = "django_admin_log" + ordering = ["-action_time"] def __repr__(self): return str(self.action_time) def __str__(self): if self.is_addition(): - return gettext('Added “%(object)s”.') % {'object': self.object_repr} + return gettext("Added “%(object)s”.") % {"object": self.object_repr} elif self.is_change(): - return gettext('Changed “%(object)s” — %(changes)s') % { - 'object': self.object_repr, - 'changes': self.get_change_message(), + return gettext("Changed “%(object)s” — %(changes)s") % { + "object": self.object_repr, + "changes": self.get_change_message(), } elif self.is_deletion(): - return gettext('Deleted “%(object)s.”') % {'object': self.object_repr} + return gettext("Deleted “%(object)s.”") % {"object": self.object_repr} - return gettext('LogEntry Object') + return gettext("LogEntry Object") def is_addition(self): return self.action_flag == ADDITION @@ -98,38 +110,62 @@ class 
LogEntry(models.Model): If self.change_message is a JSON structure, interpret it as a change string, properly translated. """ - if self.change_message and self.change_message[0] == '[': + if self.change_message and self.change_message[0] == "[": try: change_message = json.loads(self.change_message) except json.JSONDecodeError: return self.change_message messages = [] for sub_message in change_message: - if 'added' in sub_message: - if sub_message['added']: - sub_message['added']['name'] = gettext(sub_message['added']['name']) - messages.append(gettext('Added {name} “{object}”.').format(**sub_message['added'])) + if "added" in sub_message: + if sub_message["added"]: + sub_message["added"]["name"] = gettext( + sub_message["added"]["name"] + ) + messages.append( + gettext("Added {name} “{object}”.").format( + **sub_message["added"] + ) + ) else: - messages.append(gettext('Added.')) + messages.append(gettext("Added.")) - elif 'changed' in sub_message: - sub_message['changed']['fields'] = get_text_list( - [gettext(field_name) for field_name in sub_message['changed']['fields']], gettext('and') + elif "changed" in sub_message: + sub_message["changed"]["fields"] = get_text_list( + [ + gettext(field_name) + for field_name in sub_message["changed"]["fields"] + ], + gettext("and"), ) - if 'name' in sub_message['changed']: - sub_message['changed']['name'] = gettext(sub_message['changed']['name']) - messages.append(gettext('Changed {fields} for {name} “{object}”.').format( - **sub_message['changed'] - )) + if "name" in sub_message["changed"]: + sub_message["changed"]["name"] = gettext( + sub_message["changed"]["name"] + ) + messages.append( + gettext("Changed {fields} for {name} “{object}”.").format( + **sub_message["changed"] + ) + ) else: - messages.append(gettext('Changed {fields}.').format(**sub_message['changed'])) + messages.append( + gettext("Changed {fields}.").format( + **sub_message["changed"] + ) + ) - elif 'deleted' in sub_message: - sub_message['deleted']['name'] = 
gettext(sub_message['deleted']['name']) - messages.append(gettext('Deleted {name} “{object}”.').format(**sub_message['deleted'])) + elif "deleted" in sub_message: + sub_message["deleted"]["name"] = gettext( + sub_message["deleted"]["name"] + ) + messages.append( + gettext("Deleted {name} “{object}”.").format( + **sub_message["deleted"] + ) + ) - change_message = ' '.join(msg[0].upper() + msg[1:] for msg in messages) - return change_message or gettext('No fields changed.') + change_message = " ".join(msg[0].upper() + msg[1:] for msg in messages) + return change_message or gettext("No fields changed.") else: return self.change_message @@ -142,7 +178,10 @@ class LogEntry(models.Model): Return the admin URL to edit the object represented by this log entry. """ if self.content_type and self.object_id: - url_name = 'admin:%s_%s_change' % (self.content_type.app_label, self.content_type.model) + url_name = "admin:%s_%s_change" % ( + self.content_type.app_label, + self.content_type.model, + ) try: return reverse(url_name, args=(quote(self.object_id),)) except NoReverseMatch: diff --git a/django/contrib/admin/options.py b/django/contrib/admin/options.py index 12cd0948ca..f051c79ac9 100644 --- a/django/contrib/admin/options.py +++ b/django/contrib/admin/options.py @@ -9,30 +9,42 @@ from django.conf import settings from django.contrib import messages from django.contrib.admin import helpers, widgets from django.contrib.admin.checks import ( - BaseModelAdminChecks, InlineModelAdminChecks, ModelAdminChecks, + BaseModelAdminChecks, + InlineModelAdminChecks, + ModelAdminChecks, ) from django.contrib.admin.decorators import display from django.contrib.admin.exceptions import DisallowedModelAdminToField from django.contrib.admin.templatetags.admin_urls import add_preserved_filters from django.contrib.admin.utils import ( - NestedObjects, construct_change_message, flatten_fieldsets, - get_deleted_objects, lookup_spawns_duplicates, model_format_dict, - model_ngettext, quote, unquote, 
-) -from django.contrib.admin.widgets import ( - AutocompleteSelect, AutocompleteSelectMultiple, + NestedObjects, + construct_change_message, + flatten_fieldsets, + get_deleted_objects, + lookup_spawns_duplicates, + model_format_dict, + model_ngettext, + quote, + unquote, ) +from django.contrib.admin.widgets import AutocompleteSelect, AutocompleteSelectMultiple from django.contrib.auth import get_permission_codename from django.core.exceptions import ( - FieldDoesNotExist, FieldError, PermissionDenied, ValidationError, + FieldDoesNotExist, + FieldError, + PermissionDenied, + ValidationError, ) from django.core.paginator import Paginator from django.db import models, router, transaction from django.db.models.constants import LOOKUP_SEP from django.forms.formsets import DELETION_FIELD_NAME, all_valid from django.forms.models import ( - BaseInlineFormSet, inlineformset_factory, modelform_defines_fields, - modelform_factory, modelformset_factory, + BaseInlineFormSet, + inlineformset_factory, + modelform_defines_fields, + modelform_factory, + modelformset_factory, ) from django.forms.widgets import CheckboxSelectMultiple, SelectMultiple from django.http import HttpResponseRedirect @@ -44,14 +56,19 @@ from django.utils.html import format_html from django.utils.http import urlencode from django.utils.safestring import mark_safe from django.utils.text import ( - capfirst, format_lazy, get_text_list, smart_split, unescape_string_literal, + capfirst, + format_lazy, + get_text_list, + smart_split, + unescape_string_literal, ) -from django.utils.translation import gettext as _, ngettext +from django.utils.translation import gettext as _ +from django.utils.translation import ngettext from django.views.decorators.csrf import csrf_protect from django.views.generic import RedirectView -IS_POPUP_VAR = '_popup' -TO_FIELD_VAR = '_to_field' +IS_POPUP_VAR = "_popup" +TO_FIELD_VAR = "_to_field" HORIZONTAL, VERTICAL = 1, 2 @@ -61,11 +78,12 @@ def get_content_type_for_model(obj): # Since 
this module gets imported in the application's root package, # it cannot import models from other applications at the module level. from django.contrib.contenttypes.models import ContentType + return ContentType.objects.get_for_model(obj, for_concrete_model=False) def get_ul_class(radio_style): - return 'radiolist' if radio_style == VERTICAL else 'radiolist inline' + return "radiolist" if radio_style == VERTICAL else "radiolist inline" class IncorrectLookupParameters(Exception): @@ -77,20 +95,20 @@ class IncorrectLookupParameters(Exception): FORMFIELD_FOR_DBFIELD_DEFAULTS = { models.DateTimeField: { - 'form_class': forms.SplitDateTimeField, - 'widget': widgets.AdminSplitDateTime + "form_class": forms.SplitDateTimeField, + "widget": widgets.AdminSplitDateTime, }, - models.DateField: {'widget': widgets.AdminDateWidget}, - models.TimeField: {'widget': widgets.AdminTimeWidget}, - models.TextField: {'widget': widgets.AdminTextareaWidget}, - models.URLField: {'widget': widgets.AdminURLFieldWidget}, - models.IntegerField: {'widget': widgets.AdminIntegerFieldWidget}, - models.BigIntegerField: {'widget': widgets.AdminBigIntegerFieldWidget}, - models.CharField: {'widget': widgets.AdminTextInputWidget}, - models.ImageField: {'widget': widgets.AdminFileWidget}, - models.FileField: {'widget': widgets.AdminFileWidget}, - models.EmailField: {'widget': widgets.AdminEmailInputWidget}, - models.UUIDField: {'widget': widgets.AdminUUIDInputWidget}, + models.DateField: {"widget": widgets.AdminDateWidget}, + models.TimeField: {"widget": widgets.AdminTimeWidget}, + models.TextField: {"widget": widgets.AdminTextareaWidget}, + models.URLField: {"widget": widgets.AdminURLFieldWidget}, + models.IntegerField: {"widget": widgets.AdminIntegerFieldWidget}, + models.BigIntegerField: {"widget": widgets.AdminBigIntegerFieldWidget}, + models.CharField: {"widget": widgets.AdminTextInputWidget}, + models.ImageField: {"widget": widgets.AdminFileWidget}, + models.FileField: {"widget": 
widgets.AdminFileWidget}, + models.EmailField: {"widget": widgets.AdminEmailInputWidget}, + models.UUIDField: {"widget": widgets.AdminUUIDInputWidget}, } csrf_protect_m = method_decorator(csrf_protect) @@ -160,17 +178,28 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # rendered output. formfield can be None if it came from a # OneToOneField with parent_link=True or a M2M intermediary. if formfield and db_field.name not in self.raw_id_fields: - related_modeladmin = self.admin_site._registry.get(db_field.remote_field.model) + related_modeladmin = self.admin_site._registry.get( + db_field.remote_field.model + ) wrapper_kwargs = {} if related_modeladmin: wrapper_kwargs.update( can_add_related=related_modeladmin.has_add_permission(request), - can_change_related=related_modeladmin.has_change_permission(request), - can_delete_related=related_modeladmin.has_delete_permission(request), - can_view_related=related_modeladmin.has_view_permission(request), + can_change_related=related_modeladmin.has_change_permission( + request + ), + can_delete_related=related_modeladmin.has_delete_permission( + request + ), + can_view_related=related_modeladmin.has_view_permission( + request + ), ) formfield.widget = widgets.RelatedFieldWidgetWrapper( - formfield.widget, db_field.remote_field, self.admin_site, **wrapper_kwargs + formfield.widget, + db_field.remote_field, + self.admin_site, + **wrapper_kwargs, ) return formfield @@ -192,14 +221,15 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # If the field is named as a radio_field, use a RadioSelect if db_field.name in self.radio_fields: # Avoid stomping on custom widget/choices arguments. 
- if 'widget' not in kwargs: - kwargs['widget'] = widgets.AdminRadioSelect(attrs={ - 'class': get_ul_class(self.radio_fields[db_field.name]), - }) - if 'choices' not in kwargs: - kwargs['choices'] = db_field.get_choices( - include_blank=db_field.blank, - blank_choice=[('', _('None'))] + if "widget" not in kwargs: + kwargs["widget"] = widgets.AdminRadioSelect( + attrs={ + "class": get_ul_class(self.radio_fields[db_field.name]), + } + ) + if "choices" not in kwargs: + kwargs["choices"] = db_field.get_choices( + include_blank=db_field.blank, blank_choice=[("", _("None"))] ) return db_field.formfield(**kwargs) @@ -213,30 +243,38 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): if related_admin is not None: ordering = related_admin.get_ordering(request) if ordering is not None and ordering != (): - return db_field.remote_field.model._default_manager.using(db).order_by(*ordering) + return db_field.remote_field.model._default_manager.using(db).order_by( + *ordering + ) return None def formfield_for_foreignkey(self, db_field, request, **kwargs): """ Get a form Field for a ForeignKey. 
""" - db = kwargs.get('using') + db = kwargs.get("using") - if 'widget' not in kwargs: + if "widget" not in kwargs: if db_field.name in self.get_autocomplete_fields(request): - kwargs['widget'] = AutocompleteSelect(db_field, self.admin_site, using=db) + kwargs["widget"] = AutocompleteSelect( + db_field, self.admin_site, using=db + ) elif db_field.name in self.raw_id_fields: - kwargs['widget'] = widgets.ForeignKeyRawIdWidget(db_field.remote_field, self.admin_site, using=db) + kwargs["widget"] = widgets.ForeignKeyRawIdWidget( + db_field.remote_field, self.admin_site, using=db + ) elif db_field.name in self.radio_fields: - kwargs['widget'] = widgets.AdminRadioSelect(attrs={ - 'class': get_ul_class(self.radio_fields[db_field.name]), - }) - kwargs['empty_label'] = _('None') if db_field.blank else None + kwargs["widget"] = widgets.AdminRadioSelect( + attrs={ + "class": get_ul_class(self.radio_fields[db_field.name]), + } + ) + kwargs["empty_label"] = _("None") if db_field.blank else None - if 'queryset' not in kwargs: + if "queryset" not in kwargs: queryset = self.get_field_queryset(db, db_field, request) if queryset is not None: - kwargs['queryset'] = queryset + kwargs["queryset"] = queryset return db_field.formfield(**kwargs) @@ -248,38 +286,42 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # a field in admin. 
if not db_field.remote_field.through._meta.auto_created: return None - db = kwargs.get('using') + db = kwargs.get("using") - if 'widget' not in kwargs: + if "widget" not in kwargs: autocomplete_fields = self.get_autocomplete_fields(request) if db_field.name in autocomplete_fields: - kwargs['widget'] = AutocompleteSelectMultiple( + kwargs["widget"] = AutocompleteSelectMultiple( db_field, self.admin_site, using=db, ) elif db_field.name in self.raw_id_fields: - kwargs['widget'] = widgets.ManyToManyRawIdWidget( + kwargs["widget"] = widgets.ManyToManyRawIdWidget( db_field.remote_field, self.admin_site, using=db, ) elif db_field.name in [*self.filter_vertical, *self.filter_horizontal]: - kwargs['widget'] = widgets.FilteredSelectMultiple( - db_field.verbose_name, - db_field.name in self.filter_vertical + kwargs["widget"] = widgets.FilteredSelectMultiple( + db_field.verbose_name, db_field.name in self.filter_vertical ) - if 'queryset' not in kwargs: + if "queryset" not in kwargs: queryset = self.get_field_queryset(db, db_field, request) if queryset is not None: - kwargs['queryset'] = queryset + kwargs["queryset"] = queryset form_field = db_field.formfield(**kwargs) - if (isinstance(form_field.widget, SelectMultiple) and - not isinstance(form_field.widget, (CheckboxSelectMultiple, AutocompleteSelectMultiple))): - msg = _('Hold down “Control”, or “Command” on a Mac, to select more than one.') + if isinstance(form_field.widget, SelectMultiple) and not isinstance( + form_field.widget, (CheckboxSelectMultiple, AutocompleteSelectMultiple) + ): + msg = _( + "Hold down “Control”, or “Command” on a Mac, to select more than one." 
+ ) help_text = form_field.help_text - form_field.help_text = format_lazy('{} {}', help_text, msg) if help_text else msg + form_field.help_text = ( + format_lazy("{} {}", help_text, msg) if help_text else msg + ) return form_field def get_autocomplete_fields(self, request): @@ -295,12 +337,15 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): if callable(self.view_on_site): return self.view_on_site(obj) - elif hasattr(obj, 'get_absolute_url'): + elif hasattr(obj, "get_absolute_url"): # use the ContentType lookup if view_on_site is True - return reverse('admin:view_on_site', kwargs={ - 'content_type_id': get_content_type_for_model(obj).pk, - 'object_id': obj.pk - }) + return reverse( + "admin:view_on_site", + kwargs={ + "content_type_id": get_content_type_for_model(obj).pk, + "object_id": obj.pk, + }, + ) def get_empty_value_display(self): """ @@ -333,7 +378,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): """ if self.fieldsets: return self.fieldsets - return [(None, {'fields': self.get_fields(request, obj)})] + return [(None, {"fields": self.get_fields(request, obj)})] def get_inlines(self, request, obj): """Hook for specifying custom inlines.""" @@ -371,7 +416,11 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): def get_sortable_by(self, request): """Hook for specifying which fields can be sorted in the changelist.""" - return self.sortable_by if self.sortable_by is not None else self.get_list_display(request) + return ( + self.sortable_by + if self.sortable_by is not None + else self.get_list_display(request) + ) def lookup_allowed(self, lookup, value): from django.contrib.admin.filters import SimpleListFilter @@ -384,7 +433,9 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # As ``limit_choices_to`` can be a callable, invoke it here. 
if callable(fk_lookup): fk_lookup = fk_lookup() - if (lookup, value) in widgets.url_params_from_lookup_dict(fk_lookup).items(): + if (lookup, value) in widgets.url_params_from_lookup_dict( + fk_lookup + ).items(): return True relation_parts = [] @@ -399,10 +450,12 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # It is allowed to filter on values that would be found from local # model anyways. For example, if you filter on employee__department__id, # then the id value would be found already from employee__department_id. - if not prev_field or (prev_field.is_relation and - field not in prev_field.path_infos[-1].target_fields): + if not prev_field or ( + prev_field.is_relation + and field not in prev_field.path_infos[-1].target_fields + ): relation_parts.append(part) - if not getattr(field, 'path_infos', None): + if not getattr(field, "path_infos", None): # This is not a relational field, so further parts # must be transforms. break @@ -414,7 +467,9 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): return True valid_lookups = {self.date_hierarchy} for filter_item in self.list_filter: - if isinstance(filter_item, type) and issubclass(filter_item, SimpleListFilter): + if isinstance(filter_item, type) and issubclass( + filter_item, SimpleListFilter + ): valid_lookups.add(filter_item.parameter_name) elif isinstance(filter_item, (list, tuple)): valid_lookups.add(filter_item[0]) @@ -424,7 +479,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): # Is it a valid relational lookup? 
return not { LOOKUP_SEP.join(relation_parts), - LOOKUP_SEP.join(relation_parts + [part]) + LOOKUP_SEP.join(relation_parts + [part]), }.isdisjoint(valid_lookups) def to_field_allowed(self, request, to_field): @@ -459,15 +514,18 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): registered_models.add(inline.model) related_objects = ( - f for f in opts.get_fields(include_hidden=True) + f + for f in opts.get_fields(include_hidden=True) if (f.auto_created and not f.concrete) ) for related_object in related_objects: related_model = related_object.related_model remote_field = related_object.field.remote_field - if (any(issubclass(model, related_model) for model in registered_models) and - hasattr(remote_field, 'get_related_field') and - remote_field.get_related_field() == field): + if ( + any(issubclass(model, related_model) for model in registered_models) + and hasattr(remote_field, "get_related_field") + and remote_field.get_related_field() == field + ): return True return False @@ -478,7 +536,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): Can be overridden by the user in subclasses. """ opts = self.opts - codename = get_permission_codename('add', opts) + codename = get_permission_codename("add", opts) return request.user.has_perm("%s.%s" % (opts.app_label, codename)) def has_change_permission(self, request, obj=None): @@ -493,7 +551,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): request has permission to change *any* object of the given type. """ opts = self.opts - codename = get_permission_codename('change', opts) + codename = get_permission_codename("change", opts) return request.user.has_perm("%s.%s" % (opts.app_label, codename)) def has_delete_permission(self, request, obj=None): @@ -508,7 +566,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): request has permission to delete *any* object of the given type. 
""" opts = self.opts - codename = get_permission_codename('delete', opts) + codename = get_permission_codename("delete", opts) return request.user.has_perm("%s.%s" % (opts.app_label, codename)) def has_view_permission(self, request, obj=None): @@ -523,15 +581,16 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): any object of the given type. """ opts = self.opts - codename_view = get_permission_codename('view', opts) - codename_change = get_permission_codename('change', opts) - return ( - request.user.has_perm('%s.%s' % (opts.app_label, codename_view)) or - request.user.has_perm('%s.%s' % (opts.app_label, codename_change)) - ) + codename_view = get_permission_codename("view", opts) + codename_change = get_permission_codename("change", opts) + return request.user.has_perm( + "%s.%s" % (opts.app_label, codename_view) + ) or request.user.has_perm("%s.%s" % (opts.app_label, codename_change)) def has_view_or_change_permission(self, request, obj=None): - return self.has_view_permission(request, obj) or self.has_change_permission(request, obj) + return self.has_view_permission(request, obj) or self.has_change_permission( + request, obj + ) def has_module_permission(self, request): """ @@ -550,7 +609,7 @@ class BaseModelAdmin(metaclass=forms.MediaDefiningClass): class ModelAdmin(BaseModelAdmin): """Encapsulate all admin options and functionality for a given model.""" - list_display = ('__str__',) + list_display = ("__str__",) list_display_links = () list_filter = () list_select_related = False @@ -595,8 +654,8 @@ class ModelAdmin(BaseModelAdmin): def __repr__(self): return ( - f'<{self.__class__.__qualname__}: model={self.model.__qualname__} ' - f'site={self.admin_site!r}>' + f"<{self.__class__.__qualname__}: model={self.model.__qualname__} " + f"site={self.admin_site!r}>" ) def get_inline_instances(self, request, obj=None): @@ -604,9 +663,11 @@ class ModelAdmin(BaseModelAdmin): for inline_class in self.get_inlines(request, obj): inline = inline_class(self.model, 
self.admin_site) if request: - if not (inline.has_view_or_change_permission(request, obj) or - inline.has_add_permission(request, obj) or - inline.has_delete_permission(request, obj)): + if not ( + inline.has_view_or_change_permission(request, obj) + or inline.has_add_permission(request, obj) + or inline.has_delete_permission(request, obj) + ): continue if not inline.has_add_permission(request, obj): inline.max_num = 0 @@ -620,21 +681,40 @@ class ModelAdmin(BaseModelAdmin): def wrap(view): def wrapper(*args, **kwargs): return self.admin_site.admin_view(view)(*args, **kwargs) + wrapper.model_admin = self return update_wrapper(wrapper, view) info = self.model._meta.app_label, self.model._meta.model_name return [ - path('', wrap(self.changelist_view), name='%s_%s_changelist' % info), - path('add/', wrap(self.add_view), name='%s_%s_add' % info), - path('/history/', wrap(self.history_view), name='%s_%s_history' % info), - path('/delete/', wrap(self.delete_view), name='%s_%s_delete' % info), - path('/change/', wrap(self.change_view), name='%s_%s_change' % info), + path("", wrap(self.changelist_view), name="%s_%s_changelist" % info), + path("add/", wrap(self.add_view), name="%s_%s_add" % info), + path( + "/history/", + wrap(self.history_view), + name="%s_%s_history" % info, + ), + path( + "/delete/", + wrap(self.delete_view), + name="%s_%s_delete" % info, + ), + path( + "/change/", + wrap(self.change_view), + name="%s_%s_change" % info, + ), # For backwards compatibility (was the change url before 1.9) - path('/', wrap(RedirectView.as_view( - pattern_name='%s:%s_%s_change' % ((self.admin_site.name,) + info) - ))), + path( + "/", + wrap( + RedirectView.as_view( + pattern_name="%s:%s_%s_change" + % ((self.admin_site.name,) + info) + ) + ), + ), ] @property @@ -643,18 +723,18 @@ class ModelAdmin(BaseModelAdmin): @property def media(self): - extra = '' if settings.DEBUG else '.min' + extra = "" if settings.DEBUG else ".min" js = [ - 'vendor/jquery/jquery%s.js' % extra, - 
'jquery.init.js', - 'core.js', - 'admin/RelatedObjectLookups.js', - 'actions.js', - 'urlify.js', - 'prepopulate.js', - 'vendor/xregexp/xregexp%s.js' % extra, + "vendor/jquery/jquery%s.js" % extra, + "jquery.init.js", + "core.js", + "admin/RelatedObjectLookups.js", + "actions.js", + "urlify.js", + "prepopulate.js", + "vendor/xregexp/xregexp%s.js" % extra, ] - return forms.Media(js=['admin/js/%s' % url for url in js]) + return forms.Media(js=["admin/js/%s" % url for url in js]) def get_model_perms(self, request): """ @@ -663,10 +743,10 @@ class ModelAdmin(BaseModelAdmin): for each of those actions. """ return { - 'add': self.has_add_permission(request), - 'change': self.has_change_permission(request), - 'delete': self.has_delete_permission(request), - 'view': self.has_view_permission(request), + "add": self.has_add_permission(request), + "change": self.has_change_permission(request), + "delete": self.has_delete_permission(request), + "view": self.has_view_permission(request), } def _get_form_for_get_fields(self, request, obj): @@ -677,8 +757,8 @@ class ModelAdmin(BaseModelAdmin): Return a Form class for use in the admin add view. This is used by add_view and change_view. """ - if 'fields' in kwargs: - fields = kwargs.pop('fields') + if "fields" in kwargs: + fields = kwargs.pop("fields") else: fields = flatten_fieldsets(self.get_fieldsets(request, obj)) excluded = self.get_exclude(request, obj) @@ -687,9 +767,13 @@ class ModelAdmin(BaseModelAdmin): exclude.extend(readonly_fields) # Exclude all fields if it's a change form and the user doesn't have # the change permission. 
- if change and hasattr(request, 'user') and not self.has_change_permission(request, obj): + if ( + change + and hasattr(request, "user") + and not self.has_change_permission(request, obj) + ): exclude.extend(fields) - if excluded is None and hasattr(self.form, '_meta') and self.form._meta.exclude: + if excluded is None and hasattr(self.form, "_meta") and self.form._meta.exclude: # Take the custom ModelForm's Meta.exclude into account only if the # ModelAdmin doesn't define its own. exclude.extend(self.form._meta.exclude) @@ -698,25 +782,29 @@ class ModelAdmin(BaseModelAdmin): exclude = exclude or None # Remove declared form fields which are in readonly_fields. - new_attrs = dict.fromkeys(f for f in readonly_fields if f in self.form.declared_fields) + new_attrs = dict.fromkeys( + f for f in readonly_fields if f in self.form.declared_fields + ) form = type(self.form.__name__, (self.form,), new_attrs) defaults = { - 'form': form, - 'fields': fields, - 'exclude': exclude, - 'formfield_callback': partial(self.formfield_for_dbfield, request=request), + "form": form, + "fields": fields, + "exclude": exclude, + "formfield_callback": partial(self.formfield_for_dbfield, request=request), **kwargs, } - if defaults['fields'] is None and not modelform_defines_fields(defaults['form']): - defaults['fields'] = forms.ALL_FIELDS + if defaults["fields"] is None and not modelform_defines_fields( + defaults["form"] + ): + defaults["fields"] = forms.ALL_FIELDS try: return modelform_factory(self.model, **defaults) except FieldError as e: raise FieldError( - '%s. Check fields/fieldsets/exclude attributes of class %s.' + "%s. Check fields/fieldsets/exclude attributes of class %s." % (e, self.__class__.__name__) ) @@ -725,6 +813,7 @@ class ModelAdmin(BaseModelAdmin): Return the ChangeList class for use on the changelist page. 
""" from django.contrib.admin.views.main import ChangeList + return ChangeList def get_changelist_instance(self, request): @@ -736,7 +825,7 @@ class ModelAdmin(BaseModelAdmin): list_display_links = self.get_list_display_links(request, list_display) # Add the action checkboxes if any actions are available. if self.get_actions(request): - list_display = ['action_checkbox', *list_display] + list_display = ["action_checkbox", *list_display] sortable_by = self.get_sortable_by(request) ChangeList = self.get_changelist(request) return ChangeList( @@ -764,7 +853,9 @@ class ModelAdmin(BaseModelAdmin): """ queryset = self.get_queryset(request) model = queryset.model - field = model._meta.pk if from_field is None else model._meta.get_field(from_field) + field = ( + model._meta.pk if from_field is None else model._meta.get_field(from_field) + ) try: object_id = field.to_python(object_id) return queryset.get(**{field.name: object_id}) @@ -776,11 +867,13 @@ class ModelAdmin(BaseModelAdmin): Return a Form class for use in the Formset on the changelist page. """ defaults = { - 'formfield_callback': partial(self.formfield_for_dbfield, request=request), + "formfield_callback": partial(self.formfield_for_dbfield, request=request), **kwargs, } - if defaults.get('fields') is None and not modelform_defines_fields(defaults.get('form')): - defaults['fields'] = forms.ALL_FIELDS + if defaults.get("fields") is None and not modelform_defines_fields( + defaults.get("form") + ): + defaults["fields"] = forms.ALL_FIELDS return modelform_factory(self.model, **defaults) @@ -790,12 +883,15 @@ class ModelAdmin(BaseModelAdmin): is used. 
""" defaults = { - 'formfield_callback': partial(self.formfield_for_dbfield, request=request), + "formfield_callback": partial(self.formfield_for_dbfield, request=request), **kwargs, } return modelformset_factory( - self.model, self.get_changelist_form(request), extra=0, - fields=self.list_editable, **defaults + self.model, + self.get_changelist_form(request), + extra=0, + fields=self.list_editable, + **defaults, ) def get_formsets_with_inlines(self, request, obj=None): @@ -805,7 +901,9 @@ class ModelAdmin(BaseModelAdmin): for inline in self.get_inline_instances(request, obj): yield inline.get_formset(request, obj), inline - def get_paginator(self, request, queryset, per_page, orphans=0, allow_empty_first_page=True): + def get_paginator( + self, request, queryset, per_page, orphans=0, allow_empty_first_page=True + ): return self.paginator(queryset, per_page, orphans, allow_empty_first_page) def log_addition(self, request, obj, message): @@ -815,6 +913,7 @@ class ModelAdmin(BaseModelAdmin): The default implementation creates an admin LogEntry object. """ from django.contrib.admin.models import ADDITION, LogEntry + return LogEntry.objects.log_action( user_id=request.user.pk, content_type_id=get_content_type_for_model(obj).pk, @@ -831,6 +930,7 @@ class ModelAdmin(BaseModelAdmin): The default implementation creates an admin LogEntry object. """ from django.contrib.admin.models import CHANGE, LogEntry + return LogEntry.objects.log_action( user_id=request.user.pk, content_type_id=get_content_type_for_model(obj).pk, @@ -848,6 +948,7 @@ class ModelAdmin(BaseModelAdmin): The default implementation creates an admin LogEntry object. 
""" from django.contrib.admin.models import DELETION, LogEntry + return LogEntry.objects.log_action( user_id=request.user.pk, content_type_id=get_content_type_for_model(obj).pk, @@ -865,7 +966,7 @@ class ModelAdmin(BaseModelAdmin): @staticmethod def _get_action_description(func, name): - return getattr(func, 'short_description', capfirst(name.replace('_', ' '))) + return getattr(func, "short_description", capfirst(name.replace("_", " "))) def _get_base_actions(self): """Return the list of actions, prior to any request-based filtering.""" @@ -890,11 +991,11 @@ class ModelAdmin(BaseModelAdmin): filtered_actions = [] for action in actions: callable = action[0] - if not hasattr(callable, 'allowed_permissions'): + if not hasattr(callable, "allowed_permissions"): filtered_actions.append(action) continue permission_checks = ( - getattr(self, 'has_%s_permission' % permission) + getattr(self, "has_%s_permission" % permission) for permission in callable.allowed_permissions ) if any(has_permission(request) for has_permission in permission_checks): @@ -964,7 +1065,11 @@ class ModelAdmin(BaseModelAdmin): on the changelist. The list_display parameter is the list of fields returned by get_list_display(). """ - if self.list_display_links or self.list_display_links is None or not list_display: + if ( + self.list_display_links + or self.list_display_links is None + or not list_display + ): return self.list_display_links else: # Use only the first item in list_display as link @@ -998,11 +1103,11 @@ class ModelAdmin(BaseModelAdmin): """ # Apply keyword searches. def construct_search(field_name): - if field_name.startswith('^'): + if field_name.startswith("^"): return "%s__istartswith" % field_name[1:] - elif field_name.startswith('='): + elif field_name.startswith("="): return "%s__iexact" % field_name[1:] - elif field_name.startswith('@'): + elif field_name.startswith("@"): return "%s__search" % field_name[1:] # Use field_name if it includes a lookup. 
opts = queryset.model._meta @@ -1010,7 +1115,7 @@ class ModelAdmin(BaseModelAdmin): # Go through the fields, following all relations. prev_field = None for path_part in lookup_fields: - if path_part == 'pk': + if path_part == "pk": path_part = opts.pk.name try: field = opts.get_field(path_part) @@ -1020,7 +1125,7 @@ class ModelAdmin(BaseModelAdmin): return field_name else: prev_field = field - if hasattr(field, 'path_infos'): + if hasattr(field, "path_infos"): # Update opts to follow the relation. opts = field.path_infos[-1].to_opts # Otherwise, use the field with icontains. @@ -1029,8 +1134,9 @@ class ModelAdmin(BaseModelAdmin): may_have_duplicates = False search_fields = self.get_search_fields(request) if search_fields and search_term: - orm_lookups = [construct_search(str(search_field)) - for search_field in search_fields] + orm_lookups = [ + construct_search(str(search_field)) for search_field in search_fields + ] term_queries = [] for bit in smart_split(search_term): if bit.startswith(('"', "'")) and bit[0] == bit[-1]: @@ -1054,16 +1160,19 @@ class ModelAdmin(BaseModelAdmin): match = request.resolver_match if self.preserve_filters and match: opts = self.model._meta - current_url = '%s:%s' % (match.app_name, match.url_name) - changelist_url = 'admin:%s_%s_changelist' % (opts.app_label, opts.model_name) + current_url = "%s:%s" % (match.app_name, match.url_name) + changelist_url = "admin:%s_%s_changelist" % ( + opts.app_label, + opts.model_name, + ) if current_url == changelist_url: preserved_filters = request.GET.urlencode() else: - preserved_filters = request.GET.get('_changelist_filters') + preserved_filters = request.GET.get("_changelist_filters") if preserved_filters: - return urlencode({'_changelist_filters': preserved_filters}) - return '' + return urlencode({"_changelist_filters": preserved_filters}) + return "" def construct_change_message(self, request, form, formsets, add=False): """ @@ -1071,8 +1180,9 @@ class ModelAdmin(BaseModelAdmin): """ return 
construct_change_message(form, formsets, add) - def message_user(self, request, message, level=messages.INFO, extra_tags='', - fail_silently=False): + def message_user( + self, request, message, level=messages.INFO, extra_tags="", fail_silently=False + ): """ Send a message to the user. The default implementation posts a message using the django.contrib.messages backend. @@ -1088,13 +1198,15 @@ class ModelAdmin(BaseModelAdmin): level = getattr(messages.constants, level.upper()) except AttributeError: levels = messages.constants.DEFAULT_TAGS.values() - levels_repr = ', '.join('`%s`' % level for level in levels) + levels_repr = ", ".join("`%s`" % level for level in levels) raise ValueError( - 'Bad message level string: `%s`. Possible values are: %s' + "Bad message level string: `%s`. Possible values are: %s" % (level, levels_repr) ) - messages.add_message(request, level, message, extra_tags=extra_tags, fail_silently=fail_silently) + messages.add_message( + request, level, message, extra_tags=extra_tags, fail_silently=fail_silently + ) def save_form(self, request, form, change): """ @@ -1137,40 +1249,51 @@ class ModelAdmin(BaseModelAdmin): for formset in formsets: self.save_formset(request, form, formset, change=change) - def render_change_form(self, request, context, add=False, change=False, form_url='', obj=None): + def render_change_form( + self, request, context, add=False, change=False, form_url="", obj=None + ): opts = self.model._meta app_label = opts.app_label preserved_filters = self.get_preserved_filters(request) - form_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, form_url) + form_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, form_url + ) view_on_site_url = self.get_view_on_site_url(obj) has_editable_inline_admin_formsets = False - for inline in context['inline_admin_formsets']: - if inline.has_add_permission or inline.has_change_permission or inline.has_delete_permission: + for 
inline in context["inline_admin_formsets"]: + if ( + inline.has_add_permission + or inline.has_change_permission + or inline.has_delete_permission + ): has_editable_inline_admin_formsets = True break - context.update({ - 'add': add, - 'change': change, - 'has_view_permission': self.has_view_permission(request, obj), - 'has_add_permission': self.has_add_permission(request), - 'has_change_permission': self.has_change_permission(request, obj), - 'has_delete_permission': self.has_delete_permission(request, obj), - 'has_editable_inline_admin_formsets': has_editable_inline_admin_formsets, - 'has_file_field': context['adminform'].form.is_multipart() or any( - admin_formset.formset.is_multipart() - for admin_formset in context['inline_admin_formsets'] - ), - 'has_absolute_url': view_on_site_url is not None, - 'absolute_url': view_on_site_url, - 'form_url': form_url, - 'opts': opts, - 'content_type_id': get_content_type_for_model(self.model).pk, - 'save_as': self.save_as, - 'save_on_top': self.save_on_top, - 'to_field_var': TO_FIELD_VAR, - 'is_popup_var': IS_POPUP_VAR, - 'app_label': app_label, - }) + context.update( + { + "add": add, + "change": change, + "has_view_permission": self.has_view_permission(request, obj), + "has_add_permission": self.has_add_permission(request), + "has_change_permission": self.has_change_permission(request, obj), + "has_delete_permission": self.has_delete_permission(request, obj), + "has_editable_inline_admin_formsets": has_editable_inline_admin_formsets, + "has_file_field": context["adminform"].form.is_multipart() + or any( + admin_formset.formset.is_multipart() + for admin_formset in context["inline_admin_formsets"] + ), + "has_absolute_url": view_on_site_url is not None, + "absolute_url": view_on_site_url, + "form_url": form_url, + "opts": opts, + "content_type_id": get_content_type_for_model(self.model).pk, + "save_as": self.save_as, + "save_on_top": self.save_on_top, + "to_field_var": TO_FIELD_VAR, + "is_popup_var": IS_POPUP_VAR, + 
"app_label": app_label, + } + ) if add and self.add_form_template is not None: form_template = self.add_form_template else: @@ -1178,11 +1301,16 @@ class ModelAdmin(BaseModelAdmin): request.current_app = self.admin_site.name - return TemplateResponse(request, form_template or [ - "admin/%s/%s/change_form.html" % (app_label, opts.model_name), - "admin/%s/change_form.html" % app_label, - "admin/change_form.html" - ], context) + return TemplateResponse( + request, + form_template + or [ + "admin/%s/%s/change_form.html" % (app_label, opts.model_name), + "admin/%s/change_form.html" % app_label, + "admin/change_form.html", + ], + context, + ) def response_add(self, request, obj, post_url_continue=None): """ @@ -1191,7 +1319,7 @@ class ModelAdmin(BaseModelAdmin): opts = obj._meta preserved_filters = self.get_preserved_filters(request) obj_url = reverse( - 'admin:%s_%s_change' % (opts.app_label, opts.model_name), + "admin:%s_%s_change" % (opts.app_label, opts.model_name), args=(quote(obj.pk),), current_app=self.admin_site.name, ) @@ -1201,8 +1329,8 @@ class ModelAdmin(BaseModelAdmin): else: obj_repr = str(obj) msg_dict = { - 'name': opts.verbose_name, - 'obj': obj_repr, + "name": opts.verbose_name, + "obj": obj_repr, } # Here, we distinguish between different save types by checking for # the presence of keys in request.POST. 
@@ -1214,49 +1342,61 @@ class ModelAdmin(BaseModelAdmin): else: attr = obj._meta.pk.attname value = obj.serializable_value(attr) - popup_response_data = json.dumps({ - 'value': str(value), - 'obj': str(obj), - }) - return TemplateResponse(request, self.popup_response_template or [ - 'admin/%s/%s/popup_response.html' % (opts.app_label, opts.model_name), - 'admin/%s/popup_response.html' % opts.app_label, - 'admin/popup_response.html', - ], { - 'popup_response_data': popup_response_data, - }) + popup_response_data = json.dumps( + { + "value": str(value), + "obj": str(obj), + } + ) + return TemplateResponse( + request, + self.popup_response_template + or [ + "admin/%s/%s/popup_response.html" + % (opts.app_label, opts.model_name), + "admin/%s/popup_response.html" % opts.app_label, + "admin/popup_response.html", + ], + { + "popup_response_data": popup_response_data, + }, + ) elif "_continue" in request.POST or ( - # Redirecting after "Save as new". - "_saveasnew" in request.POST and self.save_as_continue and - self.has_change_permission(request, obj) + # Redirecting after "Save as new". + "_saveasnew" in request.POST + and self.save_as_continue + and self.has_change_permission(request, obj) ): - msg = _('The {name} “{obj}” was added successfully.') + msg = _("The {name} “{obj}” was added successfully.") if self.has_change_permission(request, obj): - msg += ' ' + _('You may edit it again below.') + msg += " " + _("You may edit it again below.") self.message_user(request, format_html(msg, **msg_dict), messages.SUCCESS) if post_url_continue is None: post_url_continue = obj_url post_url_continue = add_preserved_filters( - {'preserved_filters': preserved_filters, 'opts': opts}, - post_url_continue + {"preserved_filters": preserved_filters, "opts": opts}, + post_url_continue, ) return HttpResponseRedirect(post_url_continue) elif "_addanother" in request.POST: msg = format_html( - _('The {name} “{obj}” was added successfully. 
You may add another {name} below.'), - **msg_dict + _( + "The {name} “{obj}” was added successfully. You may add another {name} below." + ), + **msg_dict, ) self.message_user(request, msg, messages.SUCCESS) redirect_url = request.path - redirect_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, redirect_url) + redirect_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, redirect_url + ) return HttpResponseRedirect(redirect_url) else: msg = format_html( - _('The {name} “{obj}” was added successfully.'), - **msg_dict + _("The {name} “{obj}” was added successfully."), **msg_dict ) self.message_user(request, msg, messages.SUCCESS) return self.response_post_save_add(request, obj) @@ -1270,68 +1410,89 @@ class ModelAdmin(BaseModelAdmin): opts = obj._meta to_field = request.POST.get(TO_FIELD_VAR) attr = str(to_field) if to_field else opts.pk.attname - value = request.resolver_match.kwargs['object_id'] + value = request.resolver_match.kwargs["object_id"] new_value = obj.serializable_value(attr) - popup_response_data = json.dumps({ - 'action': 'change', - 'value': str(value), - 'obj': str(obj), - 'new_value': str(new_value), - }) - return TemplateResponse(request, self.popup_response_template or [ - 'admin/%s/%s/popup_response.html' % (opts.app_label, opts.model_name), - 'admin/%s/popup_response.html' % opts.app_label, - 'admin/popup_response.html', - ], { - 'popup_response_data': popup_response_data, - }) + popup_response_data = json.dumps( + { + "action": "change", + "value": str(value), + "obj": str(obj), + "new_value": str(new_value), + } + ) + return TemplateResponse( + request, + self.popup_response_template + or [ + "admin/%s/%s/popup_response.html" + % (opts.app_label, opts.model_name), + "admin/%s/popup_response.html" % opts.app_label, + "admin/popup_response.html", + ], + { + "popup_response_data": popup_response_data, + }, + ) opts = self.model._meta preserved_filters = 
self.get_preserved_filters(request) msg_dict = { - 'name': opts.verbose_name, - 'obj': format_html('{}', urlquote(request.path), obj), + "name": opts.verbose_name, + "obj": format_html('{}', urlquote(request.path), obj), } if "_continue" in request.POST: msg = format_html( - _('The {name} “{obj}” was changed successfully. You may edit it again below.'), - **msg_dict + _( + "The {name} “{obj}” was changed successfully. You may edit it again below." + ), + **msg_dict, ) self.message_user(request, msg, messages.SUCCESS) redirect_url = request.path - redirect_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, redirect_url) + redirect_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, redirect_url + ) return HttpResponseRedirect(redirect_url) elif "_saveasnew" in request.POST: msg = format_html( - _('The {name} “{obj}” was added successfully. You may edit it again below.'), - **msg_dict + _( + "The {name} “{obj}” was added successfully. You may edit it again below." + ), + **msg_dict, ) self.message_user(request, msg, messages.SUCCESS) - redirect_url = reverse('admin:%s_%s_change' % - (opts.app_label, opts.model_name), - args=(obj.pk,), - current_app=self.admin_site.name) - redirect_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, redirect_url) + redirect_url = reverse( + "admin:%s_%s_change" % (opts.app_label, opts.model_name), + args=(obj.pk,), + current_app=self.admin_site.name, + ) + redirect_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, redirect_url + ) return HttpResponseRedirect(redirect_url) elif "_addanother" in request.POST: msg = format_html( - _('The {name} “{obj}” was changed successfully. You may add another {name} below.'), - **msg_dict + _( + "The {name} “{obj}” was changed successfully. You may add another {name} below." 
+ ), + **msg_dict, ) self.message_user(request, msg, messages.SUCCESS) - redirect_url = reverse('admin:%s_%s_add' % - (opts.app_label, opts.model_name), - current_app=self.admin_site.name) - redirect_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, redirect_url) + redirect_url = reverse( + "admin:%s_%s_add" % (opts.app_label, opts.model_name), + current_app=self.admin_site.name, + ) + redirect_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, redirect_url + ) return HttpResponseRedirect(redirect_url) else: msg = format_html( - _('The {name} “{obj}” was changed successfully.'), - **msg_dict + _("The {name} “{obj}” was changed successfully."), **msg_dict ) self.message_user(request, msg, messages.SUCCESS) return self.response_post_save_change(request, obj) @@ -1339,14 +1500,16 @@ class ModelAdmin(BaseModelAdmin): def _response_post_save(self, request, obj): opts = self.model._meta if self.has_view_or_change_permission(request): - post_url = reverse('admin:%s_%s_changelist' % - (opts.app_label, opts.model_name), - current_app=self.admin_site.name) + post_url = reverse( + "admin:%s_%s_changelist" % (opts.app_label, opts.model_name), + current_app=self.admin_site.name, + ) preserved_filters = self.get_preserved_filters(request) - post_url = add_preserved_filters({'preserved_filters': preserved_filters, 'opts': opts}, post_url) + post_url = add_preserved_filters( + {"preserved_filters": preserved_filters, "opts": opts}, post_url + ) else: - post_url = reverse('admin:index', - current_app=self.admin_site.name) + post_url = reverse("admin:index", current_app=self.admin_site.name) return HttpResponseRedirect(post_url) def response_post_save_add(self, request, obj): @@ -1374,7 +1537,7 @@ class ModelAdmin(BaseModelAdmin): # and bottom of the change list, for example). Get the action # whose button was pushed. 
try: - action_index = int(request.POST.get('index', 0)) + action_index = int(request.POST.get("index", 0)) except ValueError: action_index = 0 @@ -1385,7 +1548,7 @@ class ModelAdmin(BaseModelAdmin): # Use the action whose button was pushed try: - data.update({'action': data.getlist('action')[action_index]}) + data.update({"action": data.getlist("action")[action_index]}) except IndexError: # If we didn't get an action from the chosen form that's invalid # POST data, so by deleting action it'll fail the validation check @@ -1393,12 +1556,12 @@ class ModelAdmin(BaseModelAdmin): pass action_form = self.action_form(data, auto_id=None) - action_form.fields['action'].choices = self.get_action_choices(request) + action_form.fields["action"].choices = self.get_action_choices(request) # If the form's valid we can handle the action. if action_form.is_valid(): - action = action_form.cleaned_data['action'] - select_across = action_form.cleaned_data['select_across'] + action = action_form.cleaned_data["action"] + select_across = action_form.cleaned_data["select_across"] func = self.get_actions(request)[action][0] # Get the list of selected PKs. If nothing's selected, we can't @@ -1407,8 +1570,10 @@ class ModelAdmin(BaseModelAdmin): selected = request.POST.getlist(helpers.ACTION_CHECKBOX_NAME) if not selected and not select_across: # Reminder that something needs to be selected or nothing will happen - msg = _("Items must be selected in order to perform " - "actions on them. No items have been changed.") + msg = _( + "Items must be selected in order to perform " + "actions on them. No items have been changed." 
+ ) self.message_user(request, msg, messages.WARNING) return None @@ -1437,38 +1602,47 @@ class ModelAdmin(BaseModelAdmin): opts = self.model._meta if IS_POPUP_VAR in request.POST: - popup_response_data = json.dumps({ - 'action': 'delete', - 'value': str(obj_id), - }) - return TemplateResponse(request, self.popup_response_template or [ - 'admin/%s/%s/popup_response.html' % (opts.app_label, opts.model_name), - 'admin/%s/popup_response.html' % opts.app_label, - 'admin/popup_response.html', - ], { - 'popup_response_data': popup_response_data, - }) + popup_response_data = json.dumps( + { + "action": "delete", + "value": str(obj_id), + } + ) + return TemplateResponse( + request, + self.popup_response_template + or [ + "admin/%s/%s/popup_response.html" + % (opts.app_label, opts.model_name), + "admin/%s/popup_response.html" % opts.app_label, + "admin/popup_response.html", + ], + { + "popup_response_data": popup_response_data, + }, + ) self.message_user( request, - _('The %(name)s “%(obj)s” was deleted successfully.') % { - 'name': opts.verbose_name, - 'obj': obj_display, + _("The %(name)s “%(obj)s” was deleted successfully.") + % { + "name": opts.verbose_name, + "obj": obj_display, }, messages.SUCCESS, ) if self.has_change_permission(request, None): post_url = reverse( - 'admin:%s_%s_changelist' % (opts.app_label, opts.model_name), + "admin:%s_%s_changelist" % (opts.app_label, opts.model_name), current_app=self.admin_site.name, ) preserved_filters = self.get_preserved_filters(request) post_url = add_preserved_filters( - {'preserved_filters': preserved_filters, 'opts': opts}, post_url + {"preserved_filters": preserved_filters, "opts": opts}, post_url ) else: - post_url = reverse('admin:index', current_app=self.admin_site.name) + post_url = reverse("admin:index", current_app=self.admin_site.name) return HttpResponseRedirect(post_url) def render_delete_form(self, request, context): @@ -1484,8 +1658,11 @@ class ModelAdmin(BaseModelAdmin): return TemplateResponse( request, - 
self.delete_confirmation_template or [ - "admin/{}/{}/delete_confirmation.html".format(app_label, opts.model_name), + self.delete_confirmation_template + or [ + "admin/{}/{}/delete_confirmation.html".format( + app_label, opts.model_name + ), "admin/{}/delete_confirmation.html".format(app_label), "admin/delete_confirmation.html", ], @@ -1494,7 +1671,11 @@ class ModelAdmin(BaseModelAdmin): def get_inline_formsets(self, request, formsets, inline_instances, obj=None): # Edit permissions on parent model are required for editable inlines. - can_edit_parent = self.has_change_permission(request, obj) if obj else self.has_add_permission(request) + can_edit_parent = ( + self.has_change_permission(request, obj) + if obj + else self.has_add_permission(request) + ) inline_admin_formsets = [] for inline, formset in zip(inline_instances, formsets): fieldsets = list(inline.get_fieldsets(request, obj)) @@ -1505,14 +1686,23 @@ class ModelAdmin(BaseModelAdmin): has_delete_permission = inline.has_delete_permission(request, obj) else: # Disable all edit-permissions, and overide formset settings. 
- has_add_permission = has_change_permission = has_delete_permission = False + has_add_permission = ( + has_change_permission + ) = has_delete_permission = False formset.extra = formset.max_num = 0 has_view_permission = inline.has_view_permission(request, obj) prepopulated = dict(inline.get_prepopulated_fields(request, obj)) inline_admin_formset = helpers.InlineAdminFormSet( - inline, formset, fieldsets, prepopulated, readonly, model_admin=self, - has_add_permission=has_add_permission, has_change_permission=has_change_permission, - has_delete_permission=has_delete_permission, has_view_permission=has_view_permission, + inline, + formset, + fieldsets, + prepopulated, + readonly, + model_admin=self, + has_add_permission=has_add_permission, + has_change_permission=has_change_permission, + has_delete_permission=has_delete_permission, + has_view_permission=has_view_permission, ) inline_admin_formsets.append(inline_admin_formset) return inline_admin_formsets @@ -1537,28 +1727,30 @@ class ModelAdmin(BaseModelAdmin): Create a message informing the user that the object doesn't exist and return a redirect to the admin index page. """ - msg = _('%(name)s with ID “%(key)s” doesn’t exist. Perhaps it was deleted?') % { - 'name': opts.verbose_name, - 'key': unquote(object_id), + msg = _("%(name)s with ID “%(key)s” doesn’t exist. 
Perhaps it was deleted?") % { + "name": opts.verbose_name, + "key": unquote(object_id), } self.message_user(request, msg, messages.WARNING) - url = reverse('admin:index', current_app=self.admin_site.name) + url = reverse("admin:index", current_app=self.admin_site.name) return HttpResponseRedirect(url) @csrf_protect_m - def changeform_view(self, request, object_id=None, form_url='', extra_context=None): + def changeform_view(self, request, object_id=None, form_url="", extra_context=None): with transaction.atomic(using=router.db_for_write(self.model)): return self._changeform_view(request, object_id, form_url, extra_context) def _changeform_view(self, request, object_id, form_url, extra_context): to_field = request.POST.get(TO_FIELD_VAR, request.GET.get(TO_FIELD_VAR)) if to_field and not self.to_field_allowed(request, to_field): - raise DisallowedModelAdminToField("The field %s cannot be referenced." % to_field) + raise DisallowedModelAdminToField( + "The field %s cannot be referenced." % to_field + ) model = self.model opts = model._meta - if request.method == 'POST' and '_saveasnew' in request.POST: + if request.method == "POST" and "_saveasnew" in request.POST: object_id = None add = object_id is None @@ -1571,7 +1763,7 @@ class ModelAdmin(BaseModelAdmin): else: obj = self.get_object(request, unquote(object_id), to_field) - if request.method == 'POST': + if request.method == "POST": if not self.has_change_permission(request, obj): raise PermissionDenied else: @@ -1585,7 +1777,7 @@ class ModelAdmin(BaseModelAdmin): ModelForm = self.get_form( request, obj, change=not add, fields=flatten_fieldsets(fieldsets) ) - if request.method == 'POST': + if request.method == "POST": form = ModelForm(request.POST, request.FILES, instance=obj) formsets, inline_instances = self._create_formsets( request, @@ -1600,7 +1792,9 @@ class ModelAdmin(BaseModelAdmin): if all_valid(formsets) and form_validated: self.save_model(request, new_object, form, not add) self.save_related(request, 
form, formsets, not add) - change_message = self.construct_change_message(request, form, formsets, add) + change_message = self.construct_change_message( + request, form, formsets, add + ) if add: self.log_addition(request, new_object, change_message) return self.response_add(request, new_object) @@ -1613,10 +1807,14 @@ class ModelAdmin(BaseModelAdmin): if add: initial = self.get_changeform_initial_data(request) form = ModelForm(initial=initial) - formsets, inline_instances = self._create_formsets(request, form.instance, change=False) + formsets, inline_instances = self._create_formsets( + request, form.instance, change=False + ) else: form = ModelForm(instance=obj) - formsets, inline_instances = self._create_formsets(request, obj, change=True) + formsets, inline_instances = self._create_formsets( + request, obj, change=True + ) if not add and not self.has_change_permission(request, obj): readonly_fields = flatten_fieldsets(fieldsets) @@ -1626,58 +1824,69 @@ class ModelAdmin(BaseModelAdmin): form, list(fieldsets), # Clear prepopulated fields on a view-only form to avoid a crash. 
- self.get_prepopulated_fields(request, obj) if add or self.has_change_permission(request, obj) else {}, + self.get_prepopulated_fields(request, obj) + if add or self.has_change_permission(request, obj) + else {}, readonly_fields, - model_admin=self) + model_admin=self, + ) media = self.media + adminForm.media - inline_formsets = self.get_inline_formsets(request, formsets, inline_instances, obj) + inline_formsets = self.get_inline_formsets( + request, formsets, inline_instances, obj + ) for inline_formset in inline_formsets: media = media + inline_formset.media if add: - title = _('Add %s') + title = _("Add %s") elif self.has_change_permission(request, obj): - title = _('Change %s') + title = _("Change %s") else: - title = _('View %s') + title = _("View %s") context = { **self.admin_site.each_context(request), - 'title': title % opts.verbose_name, - 'subtitle': str(obj) if obj else None, - 'adminform': adminForm, - 'object_id': object_id, - 'original': obj, - 'is_popup': IS_POPUP_VAR in request.POST or IS_POPUP_VAR in request.GET, - 'to_field': to_field, - 'media': media, - 'inline_admin_formsets': inline_formsets, - 'errors': helpers.AdminErrorList(form, formsets), - 'preserved_filters': self.get_preserved_filters(request), + "title": title % opts.verbose_name, + "subtitle": str(obj) if obj else None, + "adminform": adminForm, + "object_id": object_id, + "original": obj, + "is_popup": IS_POPUP_VAR in request.POST or IS_POPUP_VAR in request.GET, + "to_field": to_field, + "media": media, + "inline_admin_formsets": inline_formsets, + "errors": helpers.AdminErrorList(form, formsets), + "preserved_filters": self.get_preserved_filters(request), } # Hide the "Save" and "Save and continue" buttons if "Save as New" was # previously chosen to prevent the interface from getting confusing. 
- if request.method == 'POST' and not form_validated and "_saveasnew" in request.POST: - context['show_save'] = False - context['show_save_and_continue'] = False + if ( + request.method == "POST" + and not form_validated + and "_saveasnew" in request.POST + ): + context["show_save"] = False + context["show_save_and_continue"] = False # Use the change template instead of the add template. add = False context.update(extra_context or {}) - return self.render_change_form(request, context, add=add, change=not add, obj=obj, form_url=form_url) + return self.render_change_form( + request, context, add=add, change=not add, obj=obj, form_url=form_url + ) - def add_view(self, request, form_url='', extra_context=None): + def add_view(self, request, form_url="", extra_context=None): return self.changeform_view(request, None, form_url, extra_context) - def change_view(self, request, object_id, form_url='', extra_context=None): + def change_view(self, request, object_id, form_url="", extra_context=None): return self.changeform_view(request, object_id, form_url, extra_context) def _get_edited_object_pks(self, request, prefix): """Return POST data values of list_editable primary keys.""" pk_pattern = re.compile( - r'{}-\d+-{}$'.format(re.escape(prefix), self.model._meta.pk.name) + r"{}-\d+-{}$".format(re.escape(prefix), self.model._meta.pk.name) ) return [value for key, value in request.POST.items() if pk_pattern.match(key)] @@ -1703,6 +1912,7 @@ class ModelAdmin(BaseModelAdmin): The 'change list' admin view for this model. """ from django.contrib.admin.views.main import ERROR_FLAG + opts = self.model._meta app_label = opts.app_label if not self.has_view_or_change_permission(request): @@ -1718,10 +1928,13 @@ class ModelAdmin(BaseModelAdmin): # something is screwed up with the database, so display an error # page. 
if ERROR_FLAG in request.GET: - return SimpleTemplateResponse('admin/invalid_setup.html', { - 'title': _('Database error'), - }) - return HttpResponseRedirect(request.path + '?' + ERROR_FLAG + '=1') + return SimpleTemplateResponse( + "admin/invalid_setup.html", + { + "title": _("Database error"), + }, + ) + return HttpResponseRedirect(request.path + "?" + ERROR_FLAG + "=1") # If the request was POSTed, this might be a bulk action or a bulk # edit. Try to look up an action or confirmation first, but if this @@ -1732,26 +1945,40 @@ class ModelAdmin(BaseModelAdmin): actions = self.get_actions(request) # Actions with no confirmation - if (actions and request.method == 'POST' and - 'index' in request.POST and '_save' not in request.POST): + if ( + actions + and request.method == "POST" + and "index" in request.POST + and "_save" not in request.POST + ): if selected: - response = self.response_action(request, queryset=cl.get_queryset(request)) + response = self.response_action( + request, queryset=cl.get_queryset(request) + ) if response: return response else: action_failed = True else: - msg = _("Items must be selected in order to perform " - "actions on them. No items have been changed.") + msg = _( + "Items must be selected in order to perform " + "actions on them. No items have been changed." 
+ ) self.message_user(request, msg, messages.WARNING) action_failed = True # Actions with confirmation - if (actions and request.method == 'POST' and - helpers.ACTION_CHECKBOX_NAME in request.POST and - 'index' not in request.POST and '_save' not in request.POST): + if ( + actions + and request.method == "POST" + and helpers.ACTION_CHECKBOX_NAME in request.POST + and "index" not in request.POST + and "_save" not in request.POST + ): if selected: - response = self.response_action(request, queryset=cl.get_queryset(request)) + response = self.response_action( + request, queryset=cl.get_queryset(request) + ) if response: return response else: @@ -1769,12 +1996,16 @@ class ModelAdmin(BaseModelAdmin): formset = cl.formset = None # Handle POSTed bulk-edit data. - if request.method == 'POST' and cl.list_editable and '_save' in request.POST: + if request.method == "POST" and cl.list_editable and "_save" in request.POST: if not self.has_change_permission(request): raise PermissionDenied FormSet = self.get_changelist_formset(request) - modified_objects = self._get_list_editable_queryset(request, FormSet.get_default_prefix()) - formset = cl.formset = FormSet(request.POST, request.FILES, queryset=modified_objects) + modified_objects = self._get_list_editable_queryset( + request, FormSet.get_default_prefix() + ) + formset = cl.formset = FormSet( + request.POST, request.FILES, queryset=modified_objects + ) if formset.is_valid(): changecount = 0 for form in formset.forms: @@ -1790,10 +2021,10 @@ class ModelAdmin(BaseModelAdmin): msg = ngettext( "%(count)s %(name)s was changed successfully.", "%(count)s %(name)s were changed successfully.", - changecount + changecount, ) % { - 'count': changecount, - 'name': model_ngettext(opts, changecount), + "count": changecount, + "name": model_ngettext(opts, changecount), } self.message_user(request, msg, messages.SUCCESS) @@ -1813,45 +2044,48 @@ class ModelAdmin(BaseModelAdmin): # Build the action form and populate it with available actions. 
if actions: action_form = self.action_form(auto_id=None) - action_form.fields['action'].choices = self.get_action_choices(request) + action_form.fields["action"].choices = self.get_action_choices(request) media += action_form.media else: action_form = None selection_note_all = ngettext( - '%(total_count)s selected', - 'All %(total_count)s selected', - cl.result_count + "%(total_count)s selected", "All %(total_count)s selected", cl.result_count ) context = { **self.admin_site.each_context(request), - 'module_name': str(opts.verbose_name_plural), - 'selection_note': _('0 of %(cnt)s selected') % {'cnt': len(cl.result_list)}, - 'selection_note_all': selection_note_all % {'total_count': cl.result_count}, - 'title': cl.title, - 'subtitle': None, - 'is_popup': cl.is_popup, - 'to_field': cl.to_field, - 'cl': cl, - 'media': media, - 'has_add_permission': self.has_add_permission(request), - 'opts': cl.opts, - 'action_form': action_form, - 'actions_on_top': self.actions_on_top, - 'actions_on_bottom': self.actions_on_bottom, - 'actions_selection_counter': self.actions_selection_counter, - 'preserved_filters': self.get_preserved_filters(request), + "module_name": str(opts.verbose_name_plural), + "selection_note": _("0 of %(cnt)s selected") % {"cnt": len(cl.result_list)}, + "selection_note_all": selection_note_all % {"total_count": cl.result_count}, + "title": cl.title, + "subtitle": None, + "is_popup": cl.is_popup, + "to_field": cl.to_field, + "cl": cl, + "media": media, + "has_add_permission": self.has_add_permission(request), + "opts": cl.opts, + "action_form": action_form, + "actions_on_top": self.actions_on_top, + "actions_on_bottom": self.actions_on_bottom, + "actions_selection_counter": self.actions_selection_counter, + "preserved_filters": self.get_preserved_filters(request), **(extra_context or {}), } request.current_app = self.admin_site.name - return TemplateResponse(request, self.change_list_template or [ - 'admin/%s/%s/change_list.html' % (app_label, 
opts.model_name), - 'admin/%s/change_list.html' % app_label, - 'admin/change_list.html' - ], context) + return TemplateResponse( + request, + self.change_list_template + or [ + "admin/%s/%s/change_list.html" % (app_label, opts.model_name), + "admin/%s/change_list.html" % app_label, + "admin/change_list.html", + ], + context, + ) def get_deleted_objects(self, objs, request): """ @@ -1872,7 +2106,9 @@ class ModelAdmin(BaseModelAdmin): to_field = request.POST.get(TO_FIELD_VAR, request.GET.get(TO_FIELD_VAR)) if to_field and not self.to_field_allowed(request, to_field): - raise DisallowedModelAdminToField("The field %s cannot be referenced." % to_field) + raise DisallowedModelAdminToField( + "The field %s cannot be referenced." % to_field + ) obj = self.get_object(request, unquote(object_id), to_field) @@ -1884,7 +2120,12 @@ class ModelAdmin(BaseModelAdmin): # Populate deleted_objects, a data structure of all related objects that # will also be deleted. - deleted_objects, model_count, perms_needed, protected = self.get_deleted_objects([obj], request) + ( + deleted_objects, + model_count, + perms_needed, + protected, + ) = self.get_deleted_objects([obj], request) if request.POST and not protected: # The user has confirmed the deletion. 
if perms_needed: @@ -1906,19 +2147,19 @@ class ModelAdmin(BaseModelAdmin): context = { **self.admin_site.each_context(request), - 'title': title, - 'subtitle': None, - 'object_name': object_name, - 'object': obj, - 'deleted_objects': deleted_objects, - 'model_count': dict(model_count).items(), - 'perms_lacking': perms_needed, - 'protected': protected, - 'opts': opts, - 'app_label': app_label, - 'preserved_filters': self.get_preserved_filters(request), - 'is_popup': IS_POPUP_VAR in request.POST or IS_POPUP_VAR in request.GET, - 'to_field': to_field, + "title": title, + "subtitle": None, + "object_name": object_name, + "object": obj, + "deleted_objects": deleted_objects, + "model_count": dict(model_count).items(), + "perms_lacking": perms_needed, + "protected": protected, + "opts": opts, + "app_label": app_label, + "preserved_filters": self.get_preserved_filters(request), + "is_popup": IS_POPUP_VAR in request.POST or IS_POPUP_VAR in request.GET, + "to_field": to_field, **(extra_context or {}), } @@ -1933,7 +2174,9 @@ class ModelAdmin(BaseModelAdmin): model = self.model obj = self.get_object(request, unquote(object_id)) if obj is None: - return self._get_obj_does_not_exist_redirect(request, model._meta, object_id) + return self._get_obj_does_not_exist_redirect( + request, model._meta, object_id + ) if not self.has_view_or_change_permission(request, obj): raise PermissionDenied @@ -1941,10 +2184,14 @@ class ModelAdmin(BaseModelAdmin): # Then get the history for this object. 
opts = model._meta app_label = opts.app_label - action_list = LogEntry.objects.filter( - object_id=unquote(object_id), - content_type=get_content_type_for_model(model) - ).select_related().order_by('action_time') + action_list = ( + LogEntry.objects.filter( + object_id=unquote(object_id), + content_type=get_content_type_for_model(model), + ) + .select_related() + .order_by("action_time") + ) paginator = self.get_paginator(request, action_list, 100) page_number = request.GET.get(PAGE_VAR, 1) @@ -1953,39 +2200,46 @@ class ModelAdmin(BaseModelAdmin): context = { **self.admin_site.each_context(request), - 'title': _('Change history: %s') % obj, - 'subtitle': None, - 'action_list': page_obj, - 'page_range': page_range, - 'page_var': PAGE_VAR, - 'pagination_required': paginator.count > 100, - 'module_name': str(capfirst(opts.verbose_name_plural)), - 'object': obj, - 'opts': opts, - 'preserved_filters': self.get_preserved_filters(request), + "title": _("Change history: %s") % obj, + "subtitle": None, + "action_list": page_obj, + "page_range": page_range, + "page_var": PAGE_VAR, + "pagination_required": paginator.count > 100, + "module_name": str(capfirst(opts.verbose_name_plural)), + "object": obj, + "opts": opts, + "preserved_filters": self.get_preserved_filters(request), **(extra_context or {}), } request.current_app = self.admin_site.name - return TemplateResponse(request, self.object_history_template or [ - "admin/%s/%s/object_history.html" % (app_label, opts.model_name), - "admin/%s/object_history.html" % app_label, - "admin/object_history.html" - ], context) + return TemplateResponse( + request, + self.object_history_template + or [ + "admin/%s/%s/object_history.html" % (app_label, opts.model_name), + "admin/%s/object_history.html" % app_label, + "admin/object_history.html", + ], + context, + ) def get_formset_kwargs(self, request, obj, inline, prefix): formset_params = { - 'instance': obj, - 'prefix': prefix, - 'queryset': inline.get_queryset(request), + 
"instance": obj, + "prefix": prefix, + "queryset": inline.get_queryset(request), } - if request.method == 'POST': - formset_params.update({ - 'data': request.POST.copy(), - 'files': request.FILES, - 'save_as_new': '_saveasnew' in request.POST - }) + if request.method == "POST": + formset_params.update( + { + "data": request.POST.copy(), + "files": request.FILES, + "save_as_new": "_saveasnew" in request.POST, + } + ) return formset_params def _create_formsets(self, request, obj, change): @@ -2007,8 +2261,8 @@ class ModelAdmin(BaseModelAdmin): def user_deleted_form(request, obj, formset, index): """Return whether or not the user deleted the form.""" return ( - inline.has_delete_permission(request, obj) and - '{}-{}-DELETE'.format(formset.prefix, index) in request.POST + inline.has_delete_permission(request, obj) + and "{}-{}-DELETE".format(formset.prefix, index) in request.POST ) # Bypass validation of each view-only inline form (since the form's @@ -2032,6 +2286,7 @@ class InlineModelAdmin(BaseModelAdmin): from ``model`` to its parent. This is required if ``model`` has more than one ``ForeignKey`` to its parent. 
""" + model = None fk_name = None formset = BaseInlineFormSet @@ -2056,19 +2311,19 @@ class InlineModelAdmin(BaseModelAdmin): if self.verbose_name is None: self.verbose_name_plural = self.model._meta.verbose_name_plural else: - self.verbose_name_plural = format_lazy('{}s', self.verbose_name) + self.verbose_name_plural = format_lazy("{}s", self.verbose_name) if self.verbose_name is None: self.verbose_name = self.model._meta.verbose_name @property def media(self): - extra = '' if settings.DEBUG else '.min' - js = ['vendor/jquery/jquery%s.js' % extra, 'jquery.init.js', 'inlines.js'] + extra = "" if settings.DEBUG else ".min" + js = ["vendor/jquery/jquery%s.js" % extra, "jquery.init.js", "inlines.js"] if self.filter_vertical or self.filter_horizontal: - js.extend(['SelectBox.js', 'SelectFilter2.js']) - if self.classes and 'collapse' in self.classes: - js.append('collapse.js') - return forms.Media(js=['admin/js/%s' % url for url in js]) + js.extend(["SelectBox.js", "SelectFilter2.js"]) + if self.classes and "collapse" in self.classes: + js.append("collapse.js") + return forms.Media(js=["admin/js/%s" % url for url in js]) def get_extra(self, request, obj=None, **kwargs): """Hook for customizing the number of extra inline forms.""" @@ -2084,14 +2339,14 @@ class InlineModelAdmin(BaseModelAdmin): def get_formset(self, request, obj=None, **kwargs): """Return a BaseInlineFormSet class for use in admin add/change views.""" - if 'fields' in kwargs: - fields = kwargs.pop('fields') + if "fields" in kwargs: + fields = kwargs.pop("fields") else: fields = flatten_fieldsets(self.get_fieldsets(request, obj)) excluded = self.get_exclude(request, obj) exclude = [] if excluded is None else list(excluded) exclude.extend(self.get_readonly_fields(request, obj)) - if excluded is None and hasattr(self.form, '_meta') and self.form._meta.exclude: + if excluded is None and hasattr(self.form, "_meta") and self.form._meta.exclude: # Take the custom ModelForm's Meta.exclude into account only if the 
# InlineModelAdmin doesn't define its own. exclude.extend(self.form._meta.exclude) @@ -2100,25 +2355,24 @@ class InlineModelAdmin(BaseModelAdmin): exclude = exclude or None can_delete = self.can_delete and self.has_delete_permission(request, obj) defaults = { - 'form': self.form, - 'formset': self.formset, - 'fk_name': self.fk_name, - 'fields': fields, - 'exclude': exclude, - 'formfield_callback': partial(self.formfield_for_dbfield, request=request), - 'extra': self.get_extra(request, obj, **kwargs), - 'min_num': self.get_min_num(request, obj, **kwargs), - 'max_num': self.get_max_num(request, obj, **kwargs), - 'can_delete': can_delete, + "form": self.form, + "formset": self.formset, + "fk_name": self.fk_name, + "fields": fields, + "exclude": exclude, + "formfield_callback": partial(self.formfield_for_dbfield, request=request), + "extra": self.get_extra(request, obj, **kwargs), + "min_num": self.get_min_num(request, obj, **kwargs), + "max_num": self.get_max_num(request, obj, **kwargs), + "can_delete": can_delete, **kwargs, } - base_model_form = defaults['form'] + base_model_form = defaults["form"] can_change = self.has_change_permission(request, obj) if request else True can_add = self.has_add_permission(request, obj) if request else True class DeleteProtectedModelForm(base_model_form): - def hand_clean_DELETE(self): """ We don't validate the 'DELETE' field itself because on @@ -2137,19 +2391,22 @@ class InlineModelAdmin(BaseModelAdmin): objs.append( # Translators: Model verbose name and instance representation, # suitable to be an item in a list. 
- _('%(class_name)s %(instance)s') % { - 'class_name': p._meta.verbose_name, - 'instance': p} + _("%(class_name)s %(instance)s") + % {"class_name": p._meta.verbose_name, "instance": p} ) params = { - 'class_name': self._meta.model._meta.verbose_name, - 'instance': self.instance, - 'related_objects': get_text_list(objs, _('and')), + "class_name": self._meta.model._meta.verbose_name, + "instance": self.instance, + "related_objects": get_text_list(objs, _("and")), } - msg = _("Deleting %(class_name)s %(instance)s would require " - "deleting the following protected related objects: " - "%(related_objects)s") - raise ValidationError(msg, code='deleting_protected', params=params) + msg = _( + "Deleting %(class_name)s %(instance)s would require " + "deleting the following protected related objects: " + "%(related_objects)s" + ) + raise ValidationError( + msg, code="deleting_protected", params=params + ) def is_valid(self): result = super().is_valid() @@ -2164,10 +2421,12 @@ class InlineModelAdmin(BaseModelAdmin): return False return super().has_changed() - defaults['form'] = DeleteProtectedModelForm + defaults["form"] = DeleteProtectedModelForm - if defaults['fields'] is None and not modelform_defines_fields(defaults['form']): - defaults['fields'] = forms.ALL_FIELDS + if defaults["fields"] is None and not modelform_defines_fields( + defaults["form"] + ): + defaults["fields"] = forms.ALL_FIELDS return inlineformset_factory(self.parent_model, self.model, **defaults) @@ -2194,7 +2453,9 @@ class InlineModelAdmin(BaseModelAdmin): opts = field.remote_field.model._meta break return any( - request.user.has_perm('%s.%s' % (opts.app_label, get_permission_codename(perm, opts))) + request.user.has_perm( + "%s.%s" % (opts.app_label, get_permission_codename(perm, opts)) + ) for perm in perms ) @@ -2204,32 +2465,32 @@ class InlineModelAdmin(BaseModelAdmin): # permissions. 
The user needs to have the change permission for the # related model in order to be able to do anything with the # intermediate model. - return self._has_any_perms_for_target_model(request, ['change']) + return self._has_any_perms_for_target_model(request, ["change"]) return super().has_add_permission(request) def has_change_permission(self, request, obj=None): if self.opts.auto_created: # Same comment as has_add_permission(). - return self._has_any_perms_for_target_model(request, ['change']) + return self._has_any_perms_for_target_model(request, ["change"]) return super().has_change_permission(request) def has_delete_permission(self, request, obj=None): if self.opts.auto_created: # Same comment as has_add_permission(). - return self._has_any_perms_for_target_model(request, ['change']) + return self._has_any_perms_for_target_model(request, ["change"]) return super().has_delete_permission(request, obj) def has_view_permission(self, request, obj=None): if self.opts.auto_created: # Same comment as has_add_permission(). The 'change' permission # also implies the 'view' permission. 
- return self._has_any_perms_for_target_model(request, ['view', 'change']) + return self._has_any_perms_for_target_model(request, ["view", "change"]) return super().has_view_permission(request) class StackedInline(InlineModelAdmin): - template = 'admin/edit_inline/stacked.html' + template = "admin/edit_inline/stacked.html" class TabularInline(InlineModelAdmin): - template = 'admin/edit_inline/tabular.html' + template = "admin/edit_inline/tabular.html" diff --git a/django/contrib/admin/sites.py b/django/contrib/admin/sites.py index cff1dab829..c8870924e7 100644 --- a/django/contrib/admin/sites.py +++ b/django/contrib/admin/sites.py @@ -9,16 +9,15 @@ from django.contrib.admin.views.autocomplete import AutocompleteJsonView from django.contrib.auth import REDIRECT_FIELD_NAME from django.core.exceptions import ImproperlyConfigured from django.db.models.base import ModelBase -from django.http import ( - Http404, HttpResponsePermanentRedirect, HttpResponseRedirect, -) +from django.http import Http404, HttpResponsePermanentRedirect, HttpResponseRedirect from django.template.response import TemplateResponse from django.urls import NoReverseMatch, Resolver404, resolve, reverse from django.utils.decorators import method_decorator from django.utils.functional import LazyObject from django.utils.module_loading import import_string from django.utils.text import capfirst -from django.utils.translation import gettext as _, gettext_lazy +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy from django.views.decorators.cache import never_cache from django.views.decorators.common import no_append_slash from django.views.decorators.csrf import csrf_protect @@ -45,20 +44,20 @@ class AdminSite: """ # Text to put at the end of each page's . - site_title = gettext_lazy('Django site admin') + site_title = gettext_lazy("Django site admin") # Text to put in each page's <h1>. 
- site_header = gettext_lazy('Django administration') + site_header = gettext_lazy("Django administration") # Text to put at the top of the admin index page. - index_title = gettext_lazy('Site administration') + index_title = gettext_lazy("Site administration") # URL for the "View site" link at the top of each admin page. - site_url = '/' + site_url = "/" enable_nav_sidebar = True - empty_value_display = '-' + empty_value_display = "-" login_form = None index_template = None @@ -70,15 +69,15 @@ class AdminSite: final_catch_all_view = True - def __init__(self, name='admin'): + def __init__(self, name="admin"): self._registry = {} # model_class class -> admin_class instance self.name = name - self._actions = {'delete_selected': actions.delete_selected} + self._actions = {"delete_selected": actions.delete_selected} self._global_actions = self._actions.copy() all_sites.add(self) def __repr__(self): - return f'{self.__class__.__name__}(name={self.name!r})' + return f"{self.__class__.__name__}(name={self.name!r})" def check(self, app_configs): """ @@ -90,7 +89,9 @@ class AdminSite: app_configs = set(app_configs) # Speed up lookups below errors = [] - modeladmins = (o for o in self._registry.values() if o.__class__ is not ModelAdmin) + modeladmins = ( + o for o in self._registry.values() if o.__class__ is not ModelAdmin + ) for modeladmin in modeladmins: if modeladmin.model._meta.app_config in app_configs: errors.extend(modeladmin.check()) @@ -116,17 +117,18 @@ class AdminSite: for model in model_or_iterable: if model._meta.abstract: raise ImproperlyConfigured( - 'The model %s is abstract, so it cannot be registered with admin.' % model.__name__ + "The model %s is abstract, so it cannot be registered with admin." 
+ % model.__name__ ) if model in self._registry: registered_admin = str(self._registry[model]) - msg = 'The model %s is already registered ' % model.__name__ - if registered_admin.endswith('.ModelAdmin'): + msg = "The model %s is already registered " % model.__name__ + if registered_admin.endswith(".ModelAdmin"): # Most likely registered without a ModelAdmin subclass. - msg += 'in app %r.' % re.sub(r'\.ModelAdmin$', '', registered_admin) + msg += "in app %r." % re.sub(r"\.ModelAdmin$", "", registered_admin) else: - msg += 'with %r.' % registered_admin + msg += "with %r." % registered_admin raise AlreadyRegistered(msg) # Ignore the registration if the model has been @@ -138,8 +140,10 @@ class AdminSite: # For reasons I don't quite understand, without a __module__ # the created class appears to "live" in the wrong place, # which causes issues later on. - options['__module__'] = __name__ - admin_class = type("%sAdmin" % model.__name__, (admin_class,), options) + options["__module__"] = __name__ + admin_class = type( + "%sAdmin" % model.__name__, (admin_class,), options + ) # Instantiate the admin class to save in the registry self._registry[model] = admin_class(model, self) @@ -154,7 +158,7 @@ class AdminSite: model_or_iterable = [model_or_iterable] for model in model_or_iterable: if model not in self._registry: - raise NotRegistered('The model %s is not registered' % model.__name__) + raise NotRegistered("The model %s is not registered" % model.__name__) del self._registry[model] def is_registered(self, model): @@ -221,24 +225,27 @@ class AdminSite: ``never_cache`` decorator. If the view can be safely cached, set cacheable=True. 
""" + def inner(request, *args, **kwargs): if not self.has_permission(request): - if request.path == reverse('admin:logout', current_app=self.name): - index_path = reverse('admin:index', current_app=self.name) + if request.path == reverse("admin:logout", current_app=self.name): + index_path = reverse("admin:index", current_app=self.name) return HttpResponseRedirect(index_path) # Inner import to prevent django.contrib.admin (app) from # importing django.contrib.auth.models.User (unrelated model). from django.contrib.auth.views import redirect_to_login + return redirect_to_login( request.get_full_path(), - reverse('admin:login', current_app=self.name) + reverse("admin:login", current_app=self.name), ) return view(request, *args, **kwargs) + if not cacheable: inner = never_cache(inner) # We add csrf_protect here so this function can be used as a utility # function for any view, without having to repeat 'csrf_protect'. - if not getattr(view, 'csrf_exempt', False): + if not getattr(view, "csrf_exempt", False): inner = csrf_protect(inner) return update_wrapper(inner, view) @@ -252,26 +259,31 @@ class AdminSite: def wrap(view, cacheable=False): def wrapper(*args, **kwargs): return self.admin_view(view, cacheable)(*args, **kwargs) + wrapper.admin_site = self return update_wrapper(wrapper, view) # Admin-site-wide views. 
urlpatterns = [ - path('', wrap(self.index), name='index'), - path('login/', self.login, name='login'), - path('logout/', wrap(self.logout), name='logout'), - path('password_change/', wrap(self.password_change, cacheable=True), name='password_change'), + path("", wrap(self.index), name="index"), + path("login/", self.login, name="login"), + path("logout/", wrap(self.logout), name="logout"), path( - 'password_change/done/', - wrap(self.password_change_done, cacheable=True), - name='password_change_done', + "password_change/", + wrap(self.password_change, cacheable=True), + name="password_change", ), - path('autocomplete/', wrap(self.autocomplete_view), name='autocomplete'), - path('jsi18n/', wrap(self.i18n_javascript, cacheable=True), name='jsi18n'), path( - 'r/<int:content_type_id>/<path:object_id>/', + "password_change/done/", + wrap(self.password_change_done, cacheable=True), + name="password_change_done", + ), + path("autocomplete/", wrap(self.autocomplete_view), name="autocomplete"), + path("jsi18n/", wrap(self.i18n_javascript, cacheable=True), name="jsi18n"), + path( + "r/<int:content_type_id>/<path:object_id>/", wrap(contenttype_views.shortcut), - name='view_on_site', + name="view_on_site", ), ] @@ -280,7 +292,10 @@ class AdminSite: valid_app_labels = [] for model, model_admin in self._registry.items(): urlpatterns += [ - path('%s/%s/' % (model._meta.app_label, model._meta.model_name), include(model_admin.urls)), + path( + "%s/%s/" % (model._meta.app_label, model._meta.model_name), + include(model_admin.urls), + ), ] if model._meta.app_label not in valid_app_labels: valid_app_labels.append(model._meta.app_label) @@ -288,19 +303,19 @@ class AdminSite: # If there were ModelAdmins registered, we should have a list of app # labels for which we need to allow access to the app_index view, if valid_app_labels: - regex = r'^(?P<app_label>' + '|'.join(valid_app_labels) + ')/$' + regex = r"^(?P<app_label>" + "|".join(valid_app_labels) + ")/$" urlpatterns += [ - 
re_path(regex, wrap(self.app_index), name='app_list'), + re_path(regex, wrap(self.app_index), name="app_list"), ] if self.final_catch_all_view: - urlpatterns.append(re_path(r'(?P<url>.*)$', wrap(self.catch_all_view))) + urlpatterns.append(re_path(r"(?P<url>.*)$", wrap(self.catch_all_view))) return urlpatterns @property def urls(self): - return self.get_urls(), 'admin', self.name + return self.get_urls(), "admin", self.name def each_context(self, request): """ @@ -310,16 +325,18 @@ class AdminSite: For sites running on a subpath, use the SCRIPT_NAME value if site_url hasn't been customized. """ - script_name = request.META['SCRIPT_NAME'] - site_url = script_name if self.site_url == '/' and script_name else self.site_url + script_name = request.META["SCRIPT_NAME"] + site_url = ( + script_name if self.site_url == "/" and script_name else self.site_url + ) return { - 'site_title': self.site_title, - 'site_header': self.site_header, - 'site_url': site_url, - 'has_permission': self.has_permission(request), - 'available_apps': self.get_app_list(request), - 'is_popup': False, - 'is_nav_sidebar_enabled': self.enable_nav_sidebar, + "site_title": self.site_title, + "site_header": self.site_header, + "site_url": site_url, + "has_permission": self.has_permission(request), + "available_apps": self.get_app_list(request), + "is_popup": False, + "is_nav_sidebar_enabled": self.enable_nav_sidebar, } def password_change(self, request, extra_context=None): @@ -328,14 +345,15 @@ class AdminSite: """ from django.contrib.admin.forms import AdminPasswordChangeForm from django.contrib.auth.views import PasswordChangeView - url = reverse('admin:password_change_done', current_app=self.name) + + url = reverse("admin:password_change_done", current_app=self.name) defaults = { - 'form_class': AdminPasswordChangeForm, - 'success_url': url, - 'extra_context': {**self.each_context(request), **(extra_context or {})}, + "form_class": AdminPasswordChangeForm, + "success_url": url, + "extra_context": 
{**self.each_context(request), **(extra_context or {})}, } if self.password_change_template is not None: - defaults['template_name'] = self.password_change_template + defaults["template_name"] = self.password_change_template request.current_app = self.name return PasswordChangeView.as_view(**defaults)(request) @@ -344,11 +362,12 @@ class AdminSite: Display the "success" page after a password change. """ from django.contrib.auth.views import PasswordChangeDoneView + defaults = { - 'extra_context': {**self.each_context(request), **(extra_context or {})}, + "extra_context": {**self.each_context(request), **(extra_context or {})}, } if self.password_change_done_template is not None: - defaults['template_name'] = self.password_change_done_template + defaults["template_name"] = self.password_change_done_template request.current_app = self.name return PasswordChangeDoneView.as_view(**defaults)(request) @@ -359,7 +378,7 @@ class AdminSite: `extra_context` is unused but present for consistency with the other admin views. """ - return JavaScriptCatalog.as_view(packages=['django.contrib.admin'])(request) + return JavaScriptCatalog.as_view(packages=["django.contrib.admin"])(request) def logout(self, request, extra_context=None): """ @@ -368,17 +387,18 @@ class AdminSite: This should *not* assume the user is already logged in. """ from django.contrib.auth.views import LogoutView + defaults = { - 'extra_context': { + "extra_context": { **self.each_context(request), # Since the user isn't logged out at this point, the value of # has_permission must be overridden. - 'has_permission': False, - **(extra_context or {}) + "has_permission": False, + **(extra_context or {}), }, } if self.logout_template is not None: - defaults['template_name'] = self.logout_template + defaults["template_name"] = self.logout_template request.current_app = self.name return LogoutView.as_view(**defaults)(request) @@ -387,9 +407,9 @@ class AdminSite: """ Display the login form for the given HttpRequest. 
""" - if request.method == 'GET' and self.has_permission(request): + if request.method == "GET" and self.has_permission(request): # Already logged-in, redirect to admin index - index_path = reverse('admin:index', current_app=self.name) + index_path = reverse("admin:index", current_app=self.name) return HttpResponseRedirect(index_path) # Since this module gets imported in the application's root package, @@ -397,22 +417,25 @@ class AdminSite: # and django.contrib.admin.forms eventually imports User. from django.contrib.admin.forms import AdminAuthenticationForm from django.contrib.auth.views import LoginView + context = { **self.each_context(request), - 'title': _('Log in'), - 'subtitle': None, - 'app_path': request.get_full_path(), - 'username': request.user.get_username(), + "title": _("Log in"), + "subtitle": None, + "app_path": request.get_full_path(), + "username": request.user.get_username(), } - if (REDIRECT_FIELD_NAME not in request.GET and - REDIRECT_FIELD_NAME not in request.POST): - context[REDIRECT_FIELD_NAME] = reverse('admin:index', current_app=self.name) + if ( + REDIRECT_FIELD_NAME not in request.GET + and REDIRECT_FIELD_NAME not in request.POST + ): + context[REDIRECT_FIELD_NAME] = reverse("admin:index", current_app=self.name) context.update(extra_context or {}) defaults = { - 'extra_context': context, - 'authentication_form': self.login_form or AdminAuthenticationForm, - 'template_name': self.login_template or 'admin/login.html', + "extra_context": context, + "authentication_form": self.login_form or AdminAuthenticationForm, + "template_name": self.login_template or "admin/login.html", } request.current_app = self.name return LoginView.as_view(**defaults)(request) @@ -422,15 +445,15 @@ class AdminSite: @no_append_slash def catch_all_view(self, request, url): - if settings.APPEND_SLASH and not url.endswith('/'): - urlconf = getattr(request, 'urlconf', None) + if settings.APPEND_SLASH and not url.endswith("/"): + urlconf = getattr(request, "urlconf", 
None) try: - match = resolve('%s/' % request.path_info, urlconf) + match = resolve("%s/" % request.path_info, urlconf) except Resolver404: pass else: - if getattr(match.func, 'should_append_slash', True): - return HttpResponsePermanentRedirect('%s/' % request.path) + if getattr(match.func, "should_append_slash", True): + return HttpResponsePermanentRedirect("%s/" % request.path) raise Http404 def _build_app_dict(self, request, label=None): @@ -442,7 +465,8 @@ class AdminSite: if label: models = { - m: m_a for m, m_a in self._registry.items() + m: m_a + for m, m_a in self._registry.items() if m._meta.app_label == label } else: @@ -464,38 +488,42 @@ class AdminSite: info = (app_label, model._meta.model_name) model_dict = { - 'model': model, - 'name': capfirst(model._meta.verbose_name_plural), - 'object_name': model._meta.object_name, - 'perms': perms, - 'admin_url': None, - 'add_url': None, + "model": model, + "name": capfirst(model._meta.verbose_name_plural), + "object_name": model._meta.object_name, + "perms": perms, + "admin_url": None, + "add_url": None, } - if perms.get('change') or perms.get('view'): - model_dict['view_only'] = not perms.get('change') + if perms.get("change") or perms.get("view"): + model_dict["view_only"] = not perms.get("change") try: - model_dict['admin_url'] = reverse('admin:%s_%s_changelist' % info, current_app=self.name) + model_dict["admin_url"] = reverse( + "admin:%s_%s_changelist" % info, current_app=self.name + ) except NoReverseMatch: pass - if perms.get('add'): + if perms.get("add"): try: - model_dict['add_url'] = reverse('admin:%s_%s_add' % info, current_app=self.name) + model_dict["add_url"] = reverse( + "admin:%s_%s_add" % info, current_app=self.name + ) except NoReverseMatch: pass if app_label in app_dict: - app_dict[app_label]['models'].append(model_dict) + app_dict[app_label]["models"].append(model_dict) else: app_dict[app_label] = { - 'name': apps.get_app_config(app_label).verbose_name, - 'app_label': app_label, - 'app_url': 
reverse( - 'admin:app_list', - kwargs={'app_label': app_label}, + "name": apps.get_app_config(app_label).verbose_name, + "app_label": app_label, + "app_url": reverse( + "admin:app_list", + kwargs={"app_label": app_label}, current_app=self.name, ), - 'has_module_perms': has_module_perms, - 'models': [model_dict], + "has_module_perms": has_module_perms, + "models": [model_dict], } if label: @@ -510,11 +538,11 @@ class AdminSite: app_dict = self._build_app_dict(request) # Sort the apps alphabetically. - app_list = sorted(app_dict.values(), key=lambda x: x['name'].lower()) + app_list = sorted(app_dict.values(), key=lambda x: x["name"].lower()) # Sort the models alphabetically within each app. for app in app_list: - app['models'].sort(key=lambda x: x['name']) + app["models"].sort(key=lambda x: x["name"]) return app_list @@ -527,42 +555,46 @@ class AdminSite: context = { **self.each_context(request), - 'title': self.index_title, - 'subtitle': None, - 'app_list': app_list, + "title": self.index_title, + "subtitle": None, + "app_list": app_list, **(extra_context or {}), } request.current_app = self.name - return TemplateResponse(request, self.index_template or 'admin/index.html', context) + return TemplateResponse( + request, self.index_template or "admin/index.html", context + ) def app_index(self, request, app_label, extra_context=None): app_dict = self._build_app_dict(request, app_label) if not app_dict: - raise Http404('The requested admin page does not exist.') + raise Http404("The requested admin page does not exist.") # Sort the models alphabetically within each app. 
- app_dict['models'].sort(key=lambda x: x['name']) + app_dict["models"].sort(key=lambda x: x["name"]) context = { **self.each_context(request), - 'title': _('%(app)s administration') % {'app': app_dict['name']}, - 'subtitle': None, - 'app_list': [app_dict], - 'app_label': app_label, + "title": _("%(app)s administration") % {"app": app_dict["name"]}, + "subtitle": None, + "app_list": [app_dict], + "app_label": app_label, **(extra_context or {}), } request.current_app = self.name - return TemplateResponse(request, self.app_index_template or [ - 'admin/%s/app_index.html' % app_label, - 'admin/app_index.html' - ], context) + return TemplateResponse( + request, + self.app_index_template + or ["admin/%s/app_index.html" % app_label, "admin/app_index.html"], + context, + ) class DefaultAdminSite(LazyObject): def _setup(self): - AdminSiteClass = import_string(apps.get_app_config('admin').default_site) + AdminSiteClass = import_string(apps.get_app_config("admin").default_site) self._wrapped = AdminSiteClass() def __repr__(self): diff --git a/django/contrib/admin/templatetags/admin_list.py b/django/contrib/admin/templatetags/admin_list.py index bc74a2a3d6..5865843dce 100644 --- a/django/contrib/admin/templatetags/admin_list.py +++ b/django/contrib/admin/templatetags/admin_list.py @@ -3,11 +3,18 @@ import datetime from django.conf import settings from django.contrib.admin.templatetags.admin_urls import add_preserved_filters from django.contrib.admin.utils import ( - display_for_field, display_for_value, get_fields_from_path, - label_for_field, lookup_field, + display_for_field, + display_for_value, + get_fields_from_path, + label_for_field, + lookup_field, ) from django.contrib.admin.views.main import ( - ALL_VAR, IS_POPUP_VAR, ORDER_VAR, PAGE_VAR, SEARCH_VAR, + ALL_VAR, + IS_POPUP_VAR, + ORDER_VAR, + PAGE_VAR, + SEARCH_VAR, ) from django.core.exceptions import ObjectDoesNotExist from django.db import models @@ -32,14 +39,14 @@ def paginator_number(cl, i): Generate an 
individual page index link in a paginated list. """ if i == cl.paginator.ELLIPSIS: - return format_html('{} ', cl.paginator.ELLIPSIS) + return format_html("{} ", cl.paginator.ELLIPSIS) elif i == cl.page_num: return format_html('<span class="this-page">{}</span> ', i) else: return format_html( '<a href="{}"{}>{}</a> ', cl.get_query_string({PAGE_VAR: i}), - mark_safe(' class="end"' if i == cl.paginator.num_pages else ''), + mark_safe(' class="end"' if i == cl.paginator.num_pages else ""), i, ) @@ -49,24 +56,27 @@ def pagination(cl): Generate the series of links to the pages in a paginated list. """ pagination_required = (not cl.show_all or not cl.can_show_all) and cl.multi_page - page_range = cl.paginator.get_elided_page_range(cl.page_num) if pagination_required else [] + page_range = ( + cl.paginator.get_elided_page_range(cl.page_num) if pagination_required else [] + ) need_show_all_link = cl.can_show_all and not cl.show_all and cl.multi_page return { - 'cl': cl, - 'pagination_required': pagination_required, - 'show_all_url': need_show_all_link and cl.get_query_string({ALL_VAR: ''}), - 'page_range': page_range, - 'ALL_VAR': ALL_VAR, - '1': 1, + "cl": cl, + "pagination_required": pagination_required, + "show_all_url": need_show_all_link and cl.get_query_string({ALL_VAR: ""}), + "page_range": page_range, + "ALL_VAR": ALL_VAR, + "1": 1, } -@register.tag(name='pagination') +@register.tag(name="pagination") def pagination_tag(parser, token): return InclusionAdminNode( - parser, token, + parser, + token, func=pagination, - template_name='pagination.html', + template_name="pagination.html", takes_context=False, ) @@ -78,9 +88,7 @@ def result_headers(cl): ordering_field_columns = cl.get_ordering_field_columns() for i, field_name in enumerate(cl.list_display): text, attr = label_for_field( - field_name, cl.model, - model_admin=cl.model_admin, - return_attr=True + field_name, cl.model, model_admin=cl.model_admin, return_attr=True ) is_field_sortable = cl.sortable_by is None 
or field_name in cl.sortable_by if attr: @@ -88,7 +96,7 @@ def result_headers(cl): # Potentially not sortable # if the field is the action checkbox: no sorting and special class - if field_name == 'action_checkbox': + if field_name == "action_checkbox": yield { "text": text, "class_attrib": mark_safe(' class="action-checkbox-column"'), @@ -98,32 +106,32 @@ def result_headers(cl): admin_order_field = getattr(attr, "admin_order_field", None) # Set ordering for attr that is a property, if defined. - if isinstance(attr, property) and hasattr(attr, 'fget'): - admin_order_field = getattr(attr.fget, 'admin_order_field', None) + if isinstance(attr, property) and hasattr(attr, "fget"): + admin_order_field = getattr(attr.fget, "admin_order_field", None) if not admin_order_field: is_field_sortable = False if not is_field_sortable: # Not sortable yield { - 'text': text, - 'class_attrib': format_html(' class="column-{}"', field_name), - 'sortable': False, + "text": text, + "class_attrib": format_html(' class="column-{}"', field_name), + "sortable": False, } continue # OK, it is sortable if we got this far - th_classes = ['sortable', 'column-{}'.format(field_name)] - order_type = '' - new_order_type = 'asc' + th_classes = ["sortable", "column-{}".format(field_name)] + order_type = "" + new_order_type = "asc" sort_priority = 0 # Is it currently being sorted on? 
is_sorted = i in ordering_field_columns if is_sorted: order_type = ordering_field_columns.get(i).lower() sort_priority = list(ordering_field_columns).index(i) + 1 - th_classes.append('sorted %sending' % order_type) - new_order_type = {'asc': 'desc', 'desc': 'asc'}[order_type] + th_classes.append("sorted %sending" % order_type) + new_order_type = {"asc": "desc", "desc": "asc"}[order_type] # build new ordering param o_list_primary = [] # URL for making this field the primary sort @@ -131,7 +139,7 @@ def result_headers(cl): o_list_toggle = [] # URL for toggling order type for this field def make_qs_param(t, n): - return ('-' if t == 'desc' else '') + str(n) + return ("-" if t == "desc" else "") + str(n) for j, ot in ordering_field_columns.items(): if j == i: # Same column @@ -156,15 +164,19 @@ def result_headers(cl): "sorted": is_sorted, "ascending": order_type == "asc", "sort_priority": sort_priority, - "url_primary": cl.get_query_string({ORDER_VAR: '.'.join(o_list_primary)}), - "url_remove": cl.get_query_string({ORDER_VAR: '.'.join(o_list_remove)}), - "url_toggle": cl.get_query_string({ORDER_VAR: '.'.join(o_list_toggle)}), - "class_attrib": format_html(' class="{}"', ' '.join(th_classes)) if th_classes else '', + "url_primary": cl.get_query_string({ORDER_VAR: ".".join(o_list_primary)}), + "url_remove": cl.get_query_string({ORDER_VAR: ".".join(o_list_remove)}), + "url_toggle": cl.get_query_string({ORDER_VAR: ".".join(o_list_toggle)}), + "class_attrib": format_html(' class="{}"', " ".join(th_classes)) + if th_classes + else "", } def _boolean_icon(field_val): - icon_url = static('admin/img/icon-%s.svg' % {True: 'yes', False: 'no', None: 'unknown'}[field_val]) + icon_url = static( + "admin/img/icon-%s.svg" % {True: "yes", False: "no", None: "unknown"}[field_val] + ) return format_html('<img src="{}" alt="{}">', icon_url, field_val) @@ -173,8 +185,8 @@ def _coerce_field_name(field_name, field_index): Coerce a field_name (which may be a callable) to a string. 
""" if callable(field_name): - if field_name.__name__ == '<lambda>': - return 'lambda' + str(field_index) + if field_name.__name__ == "<lambda>": + return "lambda" + str(field_index) else: return field_name.__name__ return field_name @@ -196,20 +208,22 @@ def items_for_result(cl, result, form): pk = cl.lookup_opts.pk.attname for field_index, field_name in enumerate(cl.list_display): empty_value_display = cl.model_admin.get_empty_value_display() - row_classes = ['field-%s' % _coerce_field_name(field_name, field_index)] + row_classes = ["field-%s" % _coerce_field_name(field_name, field_index)] try: f, attr, value = lookup_field(field_name, result, cl.model_admin) except ObjectDoesNotExist: result_repr = empty_value_display else: - empty_value_display = getattr(attr, 'empty_value_display', empty_value_display) + empty_value_display = getattr( + attr, "empty_value_display", empty_value_display + ) if f is None or f.auto_created: - if field_name == 'action_checkbox': - row_classes = ['action-checkbox'] - boolean = getattr(attr, 'boolean', False) + if field_name == "action_checkbox": + row_classes = ["action-checkbox"] + boolean = getattr(attr, "boolean", False) result_repr = display_for_value(value, empty_value_display, boolean) if isinstance(value, (datetime.date, datetime.time)): - row_classes.append('nowrap') + row_classes.append("nowrap") else: if isinstance(f.remote_field, models.ManyToOneRel): field_val = getattr(result, f.name) @@ -219,12 +233,14 @@ def items_for_result(cl, result, form): result_repr = field_val else: result_repr = display_for_field(value, f, empty_value_display) - if isinstance(f, (models.DateField, models.TimeField, models.ForeignKey)): - row_classes.append('nowrap') - row_class = mark_safe(' class="%s"' % ' '.join(row_classes)) + if isinstance( + f, (models.DateField, models.TimeField, models.ForeignKey) + ): + row_classes.append("nowrap") + row_class = mark_safe(' class="%s"' % " ".join(row_classes)) # If list_display_links not defined, add 
the link tag to the first field if link_in_col(first, field_name, cl): - table_tag = 'th' if first else 'td' + table_tag = "th" if first else "td" first = False # Display link to the result's change_view if the url exists, else @@ -234,7 +250,9 @@ def items_for_result(cl, result, form): except NoReverseMatch: link_or_text = result_repr else: - url = add_preserved_filters({'preserved_filters': cl.preserved_filters, 'opts': cl.opts}, url) + url = add_preserved_filters( + {"preserved_filters": cl.preserved_filters, "opts": cl.opts}, url + ) # Convert the pk to something that can be used in JavaScript. # Problem cases are non-ASCII strings. if cl.to_field: @@ -245,24 +263,32 @@ def items_for_result(cl, result, form): link_or_text = format_html( '<a href="{}"{}>{}</a>', url, - format_html( - ' data-popup-opener="{}"', value - ) if cl.is_popup else '', - result_repr) + format_html(' data-popup-opener="{}"', value) + if cl.is_popup + else "", + result_repr, + ) - yield format_html('<{}{}>{}</{}>', table_tag, row_class, link_or_text, table_tag) + yield format_html( + "<{}{}>{}</{}>", table_tag, row_class, link_or_text, table_tag + ) else: # By default the fields come from ModelAdmin.list_editable, but if we pull # the fields out of the form instead of list_editable custom admins # can provide fields on a per request basis - if (form and field_name in form.fields and not ( - field_name == cl.model._meta.pk.name and - form[cl.model._meta.pk.name].is_hidden)): + if ( + form + and field_name in form.fields + and not ( + field_name == cl.model._meta.pk.name + and form[cl.model._meta.pk.name].is_hidden + ) + ): bf = form[field_name] result_repr = mark_safe(str(bf.errors) + str(bf)) - yield format_html('<td{}>{}</td>', row_class, result_repr) + yield format_html("<td{}>{}</td>", row_class, result_repr) if form and not form[cl.model._meta.pk.name].is_hidden: - yield format_html('<td>{}</td>', form[cl.model._meta.pk.name]) + yield format_html("<td>{}</td>", 
form[cl.model._meta.pk.name]) class ResultList(list): @@ -271,6 +297,7 @@ class ResultList(list): with the form object for error reporting purposes. Needed to maintain backwards compatibility with existing admin templates. """ + def __init__(self, form, *items): self.form = form super().__init__(*items) @@ -299,23 +326,24 @@ def result_list(cl): headers = list(result_headers(cl)) num_sorted_fields = 0 for h in headers: - if h['sortable'] and h['sorted']: + if h["sortable"] and h["sorted"]: num_sorted_fields += 1 return { - 'cl': cl, - 'result_hidden_fields': list(result_hidden_fields(cl)), - 'result_headers': headers, - 'num_sorted_fields': num_sorted_fields, - 'results': list(results(cl)), + "cl": cl, + "result_hidden_fields": list(result_hidden_fields(cl)), + "result_headers": headers, + "num_sorted_fields": num_sorted_fields, + "results": list(results(cl)), } -@register.tag(name='result_list') +@register.tag(name="result_list") def result_list_tag(parser, token): return InclusionAdminNode( - parser, token, + parser, + token, func=result_list, - template_name='change_list_results.html', + template_name="change_list_results.html", takes_context=False, ) @@ -328,15 +356,15 @@ def date_hierarchy(cl): field_name = cl.date_hierarchy field = get_fields_from_path(cl.model, field_name)[-1] if isinstance(field, models.DateTimeField): - dates_or_datetimes = 'datetimes' - qs_kwargs = {'is_dst': True} if settings.USE_DEPRECATED_PYTZ else {} + dates_or_datetimes = "datetimes" + qs_kwargs = {"is_dst": True} if settings.USE_DEPRECATED_PYTZ else {} else: - dates_or_datetimes = 'dates' + dates_or_datetimes = "dates" qs_kwargs = {} - year_field = '%s__year' % field_name - month_field = '%s__month' % field_name - day_field = '%s__day' % field_name - field_generic = '%s__' % field_name + year_field = "%s__year" % field_name + month_field = "%s__month" % field_name + day_field = "%s__day" % field_name + field_generic = "%s__" % field_name year_lookup = cl.params.get(year_field) 
month_lookup = cl.params.get(month_field) day_lookup = cl.params.get(day_field) @@ -346,73 +374,99 @@ def date_hierarchy(cl): if not (year_lookup or month_lookup or day_lookup): # select appropriate start level - date_range = cl.queryset.aggregate(first=models.Min(field_name), - last=models.Max(field_name)) - if date_range['first'] and date_range['last']: - if dates_or_datetimes == 'datetimes': + date_range = cl.queryset.aggregate( + first=models.Min(field_name), last=models.Max(field_name) + ) + if date_range["first"] and date_range["last"]: + if dates_or_datetimes == "datetimes": date_range = { k: timezone.localtime(v) if timezone.is_aware(v) else v for k, v in date_range.items() } - if date_range['first'].year == date_range['last'].year: - year_lookup = date_range['first'].year - if date_range['first'].month == date_range['last'].month: - month_lookup = date_range['first'].month + if date_range["first"].year == date_range["last"].year: + year_lookup = date_range["first"].year + if date_range["first"].month == date_range["last"].month: + month_lookup = date_range["first"].month if year_lookup and month_lookup and day_lookup: day = datetime.date(int(year_lookup), int(month_lookup), int(day_lookup)) return { - 'show': True, - 'back': { - 'link': link({year_field: year_lookup, month_field: month_lookup}), - 'title': capfirst(formats.date_format(day, 'YEAR_MONTH_FORMAT')) + "show": True, + "back": { + "link": link({year_field: year_lookup, month_field: month_lookup}), + "title": capfirst(formats.date_format(day, "YEAR_MONTH_FORMAT")), }, - 'choices': [{'title': capfirst(formats.date_format(day, 'MONTH_DAY_FORMAT'))}] + "choices": [ + {"title": capfirst(formats.date_format(day, "MONTH_DAY_FORMAT"))} + ], } elif year_lookup and month_lookup: - days = getattr(cl.queryset, dates_or_datetimes)(field_name, 'day', **qs_kwargs) + days = getattr(cl.queryset, dates_or_datetimes)( + field_name, "day", **qs_kwargs + ) return { - 'show': True, - 'back': { - 'link': 
link({year_field: year_lookup}), - 'title': str(year_lookup) + "show": True, + "back": { + "link": link({year_field: year_lookup}), + "title": str(year_lookup), }, - 'choices': [{ - 'link': link({year_field: year_lookup, month_field: month_lookup, day_field: day.day}), - 'title': capfirst(formats.date_format(day, 'MONTH_DAY_FORMAT')) - } for day in days] + "choices": [ + { + "link": link( + { + year_field: year_lookup, + month_field: month_lookup, + day_field: day.day, + } + ), + "title": capfirst(formats.date_format(day, "MONTH_DAY_FORMAT")), + } + for day in days + ], } elif year_lookup: - months = getattr(cl.queryset, dates_or_datetimes)(field_name, 'month', **qs_kwargs) + months = getattr(cl.queryset, dates_or_datetimes)( + field_name, "month", **qs_kwargs + ) return { - 'show': True, - 'back': { - 'link': link({}), - 'title': _('All dates') - }, - 'choices': [{ - 'link': link({year_field: year_lookup, month_field: month.month}), - 'title': capfirst(formats.date_format(month, 'YEAR_MONTH_FORMAT')) - } for month in months] + "show": True, + "back": {"link": link({}), "title": _("All dates")}, + "choices": [ + { + "link": link( + {year_field: year_lookup, month_field: month.month} + ), + "title": capfirst( + formats.date_format(month, "YEAR_MONTH_FORMAT") + ), + } + for month in months + ], } else: - years = getattr(cl.queryset, dates_or_datetimes)(field_name, 'year', **qs_kwargs) + years = getattr(cl.queryset, dates_or_datetimes)( + field_name, "year", **qs_kwargs + ) return { - 'show': True, - 'back': None, - 'choices': [{ - 'link': link({year_field: str(year.year)}), - 'title': str(year.year), - } for year in years] + "show": True, + "back": None, + "choices": [ + { + "link": link({year_field: str(year.year)}), + "title": str(year.year), + } + for year in years + ], } -@register.tag(name='date_hierarchy') +@register.tag(name="date_hierarchy") def date_hierarchy_tag(parser, token): return InclusionAdminNode( - parser, token, + parser, + token, 
func=date_hierarchy, - template_name='date_hierarchy.html', + template_name="date_hierarchy.html", takes_context=False, ) @@ -422,26 +476,34 @@ def search_form(cl): Display a search form for searching the list. """ return { - 'cl': cl, - 'show_result_count': cl.result_count != cl.full_result_count, - 'search_var': SEARCH_VAR, - 'is_popup_var': IS_POPUP_VAR, + "cl": cl, + "show_result_count": cl.result_count != cl.full_result_count, + "search_var": SEARCH_VAR, + "is_popup_var": IS_POPUP_VAR, } -@register.tag(name='search_form') +@register.tag(name="search_form") def search_form_tag(parser, token): - return InclusionAdminNode(parser, token, func=search_form, template_name='search_form.html', takes_context=False) + return InclusionAdminNode( + parser, + token, + func=search_form, + template_name="search_form.html", + takes_context=False, + ) @register.simple_tag def admin_list_filter(cl, spec): tpl = get_template(spec.template) - return tpl.render({ - 'title': spec.title, - 'choices': list(spec.choices(cl)), - 'spec': spec, - }) + return tpl.render( + { + "title": spec.title, + "choices": list(spec.choices(cl)), + "spec": spec, + } + ) def admin_actions(context): @@ -449,20 +511,23 @@ def admin_actions(context): Track the number of times the action field has been rendered on the page, so we know which value to use. 
""" - context['action_index'] = context.get('action_index', -1) + 1 + context["action_index"] = context.get("action_index", -1) + 1 return context -@register.tag(name='admin_actions') +@register.tag(name="admin_actions") def admin_actions_tag(parser, token): - return InclusionAdminNode(parser, token, func=admin_actions, template_name='actions.html') + return InclusionAdminNode( + parser, token, func=admin_actions, template_name="actions.html" + ) -@register.tag(name='change_list_object_tools') +@register.tag(name="change_list_object_tools") def change_list_object_tools_tag(parser, token): """Display the row of change list object tools.""" return InclusionAdminNode( - parser, token, + parser, + token, func=lambda context: context, - template_name='change_list_object_tools.html', + template_name="change_list_object_tools.html", ) diff --git a/django/contrib/admin/templatetags/admin_modify.py b/django/contrib/admin/templatetags/admin_modify.py index 583b7639e8..910e6b68b9 100644 --- a/django/contrib/admin/templatetags/admin_modify.py +++ b/django/contrib/admin/templatetags/admin_modify.py @@ -14,90 +14,119 @@ def prepopulated_fields_js(context): the prepopulated fields for both the admin form and inlines. 
""" prepopulated_fields = [] - if 'adminform' in context: - prepopulated_fields.extend(context['adminform'].prepopulated_fields) - if 'inline_admin_formsets' in context: - for inline_admin_formset in context['inline_admin_formsets']: + if "adminform" in context: + prepopulated_fields.extend(context["adminform"].prepopulated_fields) + if "inline_admin_formsets" in context: + for inline_admin_formset in context["inline_admin_formsets"]: for inline_admin_form in inline_admin_formset: if inline_admin_form.original is None: prepopulated_fields.extend(inline_admin_form.prepopulated_fields) prepopulated_fields_json = [] for field in prepopulated_fields: - prepopulated_fields_json.append({ - "id": "#%s" % field["field"].auto_id, - "name": field["field"].name, - "dependency_ids": ["#%s" % dependency.auto_id for dependency in field["dependencies"]], - "dependency_list": [dependency.name for dependency in field["dependencies"]], - "maxLength": field["field"].field.max_length or 50, - "allowUnicode": getattr(field["field"].field, "allow_unicode", False) - }) + prepopulated_fields_json.append( + { + "id": "#%s" % field["field"].auto_id, + "name": field["field"].name, + "dependency_ids": [ + "#%s" % dependency.auto_id for dependency in field["dependencies"] + ], + "dependency_list": [ + dependency.name for dependency in field["dependencies"] + ], + "maxLength": field["field"].field.max_length or 50, + "allowUnicode": getattr(field["field"].field, "allow_unicode", False), + } + ) - context.update({ - 'prepopulated_fields': prepopulated_fields, - 'prepopulated_fields_json': json.dumps(prepopulated_fields_json), - }) + context.update( + { + "prepopulated_fields": prepopulated_fields, + "prepopulated_fields_json": json.dumps(prepopulated_fields_json), + } + ) return context -@register.tag(name='prepopulated_fields_js') +@register.tag(name="prepopulated_fields_js") def prepopulated_fields_js_tag(parser, token): - return InclusionAdminNode(parser, token, func=prepopulated_fields_js, 
template_name="prepopulated_fields_js.html") + return InclusionAdminNode( + parser, + token, + func=prepopulated_fields_js, + template_name="prepopulated_fields_js.html", + ) def submit_row(context): """ Display the row of buttons for delete and save. """ - add = context['add'] - change = context['change'] - is_popup = context['is_popup'] - save_as = context['save_as'] - show_save = context.get('show_save', True) - show_save_and_add_another = context.get('show_save_and_add_another', True) - show_save_and_continue = context.get('show_save_and_continue', True) - has_add_permission = context['has_add_permission'] - has_change_permission = context['has_change_permission'] - has_view_permission = context['has_view_permission'] - has_editable_inline_admin_formsets = context['has_editable_inline_admin_formsets'] - can_save = (has_change_permission and change) or (has_add_permission and add) or has_editable_inline_admin_formsets - can_save_and_add_another = ( - has_add_permission and - not is_popup and - (not save_as or add) and - can_save and - show_save_and_add_another + add = context["add"] + change = context["change"] + is_popup = context["is_popup"] + save_as = context["save_as"] + show_save = context.get("show_save", True) + show_save_and_add_another = context.get("show_save_and_add_another", True) + show_save_and_continue = context.get("show_save_and_continue", True) + has_add_permission = context["has_add_permission"] + has_change_permission = context["has_change_permission"] + has_view_permission = context["has_view_permission"] + has_editable_inline_admin_formsets = context["has_editable_inline_admin_formsets"] + can_save = ( + (has_change_permission and change) + or (has_add_permission and add) + or has_editable_inline_admin_formsets + ) + can_save_and_add_another = ( + has_add_permission + and not is_popup + and (not save_as or add) + and can_save + and show_save_and_add_another + ) + can_save_and_continue = ( + not is_popup and can_save and has_view_permission 
and show_save_and_continue ) - can_save_and_continue = not is_popup and can_save and has_view_permission and show_save_and_continue can_change = has_change_permission or has_editable_inline_admin_formsets ctx = Context(context) - ctx.update({ - 'can_change': can_change, - 'show_delete_link': ( - not is_popup and context['has_delete_permission'] and - change and context.get('show_delete', True) - ), - 'show_save_as_new': not is_popup and has_change_permission and change and save_as, - 'show_save_and_add_another': can_save_and_add_another, - 'show_save_and_continue': can_save_and_continue, - 'show_save': show_save and can_save, - 'show_close': not(show_save and can_save) - }) + ctx.update( + { + "can_change": can_change, + "show_delete_link": ( + not is_popup + and context["has_delete_permission"] + and change + and context.get("show_delete", True) + ), + "show_save_as_new": not is_popup + and has_change_permission + and change + and save_as, + "show_save_and_add_another": can_save_and_add_another, + "show_save_and_continue": can_save_and_continue, + "show_save": show_save and can_save, + "show_close": not (show_save and can_save), + } + ) return ctx -@register.tag(name='submit_row') +@register.tag(name="submit_row") def submit_row_tag(parser, token): - return InclusionAdminNode(parser, token, func=submit_row, template_name='submit_line.html') + return InclusionAdminNode( + parser, token, func=submit_row, template_name="submit_line.html" + ) -@register.tag(name='change_form_object_tools') +@register.tag(name="change_form_object_tools") def change_form_object_tools_tag(parser, token): """Display the row of change form object tools.""" return InclusionAdminNode( - parser, token, + parser, + token, func=lambda context: context, - template_name='change_form_object_tools.html', + template_name="change_form_object_tools.html", ) diff --git a/django/contrib/admin/templatetags/admin_urls.py b/django/contrib/admin/templatetags/admin_urls.py index f817c254eb..13ded03361 100644 
--- a/django/contrib/admin/templatetags/admin_urls.py +++ b/django/contrib/admin/templatetags/admin_urls.py @@ -10,7 +10,7 @@ register = template.Library() @register.filter def admin_urlname(value, arg): - return 'admin:%s_%s_%s' % (value.app_label, value.model_name, arg) + return "admin:%s_%s_%s" % (value.app_label, value.model_name, arg) @register.filter @@ -20,8 +20,8 @@ def admin_urlquote(value): @register.simple_tag(takes_context=True) def add_preserved_filters(context, url, popup=False, to_field=None): - opts = context.get('opts') - preserved_filters = context.get('preserved_filters') + opts = context.get("opts") + preserved_filters = context.get("preserved_filters") parsed_url = list(urlparse(url)) parsed_qs = dict(parse_qsl(parsed_url[4])) @@ -30,24 +30,34 @@ def add_preserved_filters(context, url, popup=False, to_field=None): if opts and preserved_filters: preserved_filters = dict(parse_qsl(preserved_filters)) - match_url = '/%s' % unquote(url).partition(get_script_prefix())[2] + match_url = "/%s" % unquote(url).partition(get_script_prefix())[2] try: match = resolve(match_url) except Resolver404: pass else: - current_url = '%s:%s' % (match.app_name, match.url_name) - changelist_url = 'admin:%s_%s_changelist' % (opts.app_label, opts.model_name) - if changelist_url == current_url and '_changelist_filters' in preserved_filters: - preserved_filters = dict(parse_qsl(preserved_filters['_changelist_filters'])) + current_url = "%s:%s" % (match.app_name, match.url_name) + changelist_url = "admin:%s_%s_changelist" % ( + opts.app_label, + opts.model_name, + ) + if ( + changelist_url == current_url + and "_changelist_filters" in preserved_filters + ): + preserved_filters = dict( + parse_qsl(preserved_filters["_changelist_filters"]) + ) merged_qs.update(preserved_filters) if popup: from django.contrib.admin.options import IS_POPUP_VAR + merged_qs[IS_POPUP_VAR] = 1 if to_field: from django.contrib.admin.options import TO_FIELD_VAR + merged_qs[TO_FIELD_VAR] = to_field 
merged_qs.update(parsed_qs) diff --git a/django/contrib/admin/templatetags/base.py b/django/contrib/admin/templatetags/base.py index e98604ac5a..23e4cfbe86 100644 --- a/django/contrib/admin/templatetags/base.py +++ b/django/contrib/admin/templatetags/base.py @@ -11,23 +11,35 @@ class InclusionAdminNode(InclusionNode): def __init__(self, parser, token, func, template_name, takes_context=True): self.template_name = template_name - params, varargs, varkw, defaults, kwonly, kwonly_defaults, _ = getfullargspec(func) + params, varargs, varkw, defaults, kwonly, kwonly_defaults, _ = getfullargspec( + func + ) bits = token.split_contents() args, kwargs = parse_bits( - parser, bits[1:], params, varargs, varkw, defaults, kwonly, - kwonly_defaults, takes_context, bits[0], + parser, + bits[1:], + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + takes_context, + bits[0], ) super().__init__(func, takes_context, args, kwargs, filename=None) def render(self, context): - opts = context['opts'] + opts = context["opts"] app_label = opts.app_label.lower() object_name = opts.object_name.lower() # Load template for this render call. (Setting self.filename isn't # thread-safe.) 
- context.render_context[self] = context.template.engine.select_template([ - 'admin/%s/%s/%s' % (app_label, object_name, self.template_name), - 'admin/%s/%s' % (app_label, self.template_name), - 'admin/%s' % self.template_name, - ]) + context.render_context[self] = context.template.engine.select_template( + [ + "admin/%s/%s/%s" % (app_label, object_name, self.template_name), + "admin/%s/%s" % (app_label, self.template_name), + "admin/%s" % self.template_name, + ] + ) return super().render(context) diff --git a/django/contrib/admin/templatetags/log.py b/django/contrib/admin/templatetags/log.py index 08c2345e7c..04a375013d 100644 --- a/django/contrib/admin/templatetags/log.py +++ b/django/contrib/admin/templatetags/log.py @@ -19,8 +19,10 @@ class AdminLogNode(template.Node): if not user_id.isdigit(): user_id = context[self.user].pk entries = LogEntry.objects.filter(user__pk=user_id) - context[self.varname] = entries.select_related('content_type', 'user')[:int(self.limit)] - return '' + context[self.varname] = entries.select_related("content_type", "user")[ + : int(self.limit) + ] + return "" @register.tag @@ -45,15 +47,23 @@ def get_admin_log(parser, token): tokens = token.contents.split() if len(tokens) < 4: raise template.TemplateSyntaxError( - "'get_admin_log' statements require two arguments") + "'get_admin_log' statements require two arguments" + ) if not tokens[1].isdigit(): raise template.TemplateSyntaxError( - "First argument to 'get_admin_log' must be an integer") - if tokens[2] != 'as': + "First argument to 'get_admin_log' must be an integer" + ) + if tokens[2] != "as": raise template.TemplateSyntaxError( - "Second argument to 'get_admin_log' must be 'as'") + "Second argument to 'get_admin_log' must be 'as'" + ) if len(tokens) > 4: - if tokens[4] != 'for_user': + if tokens[4] != "for_user": raise template.TemplateSyntaxError( - "Fourth argument to 'get_admin_log' must be 'for_user'") - return AdminLogNode(limit=tokens[1], varname=tokens[3], user=(tokens[5] 
if len(tokens) > 5 else None)) + "Fourth argument to 'get_admin_log' must be 'for_user'" + ) + return AdminLogNode( + limit=tokens[1], + varname=tokens[3], + user=(tokens[5] if len(tokens) > 5 else None), + ) diff --git a/django/contrib/admin/tests.py b/django/contrib/admin/tests.py index 36e2fe5c90..70c3949fef 100644 --- a/django/contrib/admin/tests.py +++ b/django/contrib/admin/tests.py @@ -9,20 +9,21 @@ from django.utils.translation import gettext as _ class CSPMiddleware(MiddlewareMixin): """The admin's JavaScript should be compatible with CSP.""" + def process_response(self, request, response): - response.headers['Content-Security-Policy'] = "default-src 'self'" + response.headers["Content-Security-Policy"] = "default-src 'self'" return response -@modify_settings(MIDDLEWARE={'append': 'django.contrib.admin.tests.CSPMiddleware'}) +@modify_settings(MIDDLEWARE={"append": "django.contrib.admin.tests.CSPMiddleware"}) class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): available_apps = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.sites', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.sites", ] def wait_until(self, callback, timeout=10): @@ -33,6 +34,7 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): call this function for more details. 
""" from selenium.webdriver.support.wait import WebDriverWait + WebDriverWait(self.selenium, timeout).until(callback) def wait_for_and_switch_to_popup(self, num_windows=2, timeout=10): @@ -51,9 +53,9 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec + self.wait_until( - ec.presence_of_element_located((By.CSS_SELECTOR, css_selector)), - timeout + ec.presence_of_element_located((By.CSS_SELECTOR, css_selector)), timeout ) def wait_for_text(self, css_selector, text, timeout=10): @@ -62,10 +64,10 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec + self.wait_until( - ec.text_to_be_present_in_element( - (By.CSS_SELECTOR, css_selector), text), - timeout + ec.text_to_be_present_in_element((By.CSS_SELECTOR, css_selector), text), + timeout, ) def wait_for_value(self, css_selector, text, timeout=10): @@ -74,10 +76,12 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec + self.wait_until( ec.text_to_be_present_in_element_value( - (By.CSS_SELECTOR, css_selector), text), - timeout + (By.CSS_SELECTOR, css_selector), text + ), + timeout, ) def wait_until_visible(self, css_selector, timeout=10): @@ -86,9 +90,9 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec + self.wait_until( - ec.visibility_of_element_located((By.CSS_SELECTOR, css_selector)), - timeout + ec.visibility_of_element_located((By.CSS_SELECTOR, css_selector)), timeout ) def wait_until_invisible(self, css_selector, timeout=10): @@ -97,9 +101,9 @@ class 
AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec + self.wait_until( - ec.invisibility_of_element_located((By.CSS_SELECTOR, css_selector)), - timeout + ec.invisibility_of_element_located((By.CSS_SELECTOR, css_selector)), timeout ) def wait_page_ready(self, timeout=10): @@ -107,7 +111,8 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): Block until the page is ready. """ self.wait_until( - lambda driver: driver.execute_script('return document.readyState;') == 'complete', + lambda driver: driver.execute_script("return document.readyState;") + == "complete", timeout, ) @@ -118,25 +123,29 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec - old_page = self.selenium.find_element(By.TAG_NAME, 'html') + + old_page = self.selenium.find_element(By.TAG_NAME, "html") yield # Wait for the next page to be loaded self.wait_until(ec.staleness_of(old_page), timeout=timeout) self.wait_page_ready(timeout=timeout) - def admin_login(self, username, password, login_url='/admin/'): + def admin_login(self, username, password, login_url="/admin/"): """ Log in to the admin. 
""" from selenium.webdriver.common.by import By - self.selenium.get('%s%s' % (self.live_server_url, login_url)) - username_input = self.selenium.find_element(By.NAME, 'username') + + self.selenium.get("%s%s" % (self.live_server_url, login_url)) + username_input = self.selenium.find_element(By.NAME, "username") username_input.send_keys(username) - password_input = self.selenium.find_element(By.NAME, 'password') + password_input = self.selenium.find_element(By.NAME, "password") password_input.send_keys(password) - login_text = _('Log in') + login_text = _("Log in") with self.wait_page_loaded(): - self.selenium.find_element(By.XPATH, '//input[@value="%s"]' % login_text).click() + self.selenium.find_element( + By.XPATH, '//input[@value="%s"]' % login_text + ).click() def select_option(self, selector, value): """ @@ -145,6 +154,7 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import Select + select = Select(self.selenium.find_element(By.CSS_SELECTOR, selector)) select.select_by_value(value) @@ -155,6 +165,7 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): """ from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import Select + select = Select(self.selenium.find_element(By.CSS_SELECTOR, selector)) select.deselect_by_value(value) @@ -165,16 +176,20 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): `root_element` allow restriction to a pre-selected node. 
""" from selenium.webdriver.common.by import By + root_element = root_element or self.selenium - self.assertEqual(len(root_element.find_elements(By.CSS_SELECTOR, selector)), count) + self.assertEqual( + len(root_element.find_elements(By.CSS_SELECTOR, selector)), count + ) def _assertOptionsValues(self, options_selector, values): from selenium.webdriver.common.by import By + if values: options = self.selenium.find_elements(By.CSS_SELECTOR, options_selector) actual_values = [] for option in options: - actual_values.append(option.get_attribute('value')) + actual_values.append(option.get_attribute("value")) self.assertEqual(values, actual_values) else: # Prevent the `find_elements(By.CSS_SELECTOR, …)` call from blocking @@ -182,7 +197,9 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): # to be the case. with self.disable_implicit_wait(): self.wait_until( - lambda driver: not driver.find_elements(By.CSS_SELECTOR, options_selector) + lambda driver: not driver.find_elements( + By.CSS_SELECTOR, options_selector + ) ) def assertSelectOptions(self, selector, values): @@ -205,7 +222,13 @@ class AdminSeleniumTestCase(SeleniumTestCase, StaticLiveServerTestCase): `klass`. 
""" from selenium.webdriver.common.by import By - return self.selenium.find_element( - By.CSS_SELECTOR, - selector, - ).get_attribute('class').find(klass) != -1 + + return ( + self.selenium.find_element( + By.CSS_SELECTOR, + selector, + ) + .get_attribute("class") + .find(klass) + != -1 + ) diff --git a/django/contrib/admin/utils.py b/django/contrib/admin/utils.py index 4e3d025a56..6cb549fb10 100644 --- a/django/contrib/admin/utils.py +++ b/django/contrib/admin/utils.py @@ -13,15 +13,17 @@ from django.utils import formats, timezone from django.utils.html import format_html from django.utils.regex_helper import _lazy_re_compile from django.utils.text import capfirst -from django.utils.translation import ngettext, override as translation_override +from django.utils.translation import ngettext +from django.utils.translation import override as translation_override -QUOTE_MAP = {i: '_%02X' % i for i in b'":/_#?;@&=+$,"[]<>%\n\\'} +QUOTE_MAP = {i: "_%02X" % i for i in b'":/_#?;@&=+$,"[]<>%\n\\'} UNQUOTE_MAP = {v: chr(k) for k, v in QUOTE_MAP.items()} -UNQUOTE_RE = _lazy_re_compile('_(?:%s)' % '|'.join([x[1:] for x in UNQUOTE_MAP])) +UNQUOTE_RE = _lazy_re_compile("_(?:%s)" % "|".join([x[1:] for x in UNQUOTE_MAP])) class FieldIsAForeignKeyColumnName(Exception): """A field is a foreign key attname, i.e. <FK>_id.""" + pass @@ -32,7 +34,7 @@ def lookup_spawns_duplicates(opts, lookup_path): lookup_fields = lookup_path.split(LOOKUP_SEP) # Go through the fields (following all relations) and look for an m2m. for field_name in lookup_fields: - if field_name == 'pk': + if field_name == "pk": field_name = opts.pk.name try: field = opts.get_field(field_name) @@ -40,7 +42,7 @@ def lookup_spawns_duplicates(opts, lookup_path): # Ignore query lookups. continue else: - if hasattr(field, 'path_infos'): + if hasattr(field, "path_infos"): # This field is a relation; update opts to follow the relation. 
path_info = field.path_infos opts = path_info[-1].to_opts @@ -51,16 +53,16 @@ def lookup_spawns_duplicates(opts, lookup_path): return False -def prepare_lookup_value(key, value, separator=','): +def prepare_lookup_value(key, value, separator=","): """ Return a lookup value prepared to be used in queryset filtering. """ # if key ends with __in, split parameter into separate values - if key.endswith('__in'): + if key.endswith("__in"): value = value.split(separator) # if key ends with __isnull, special case '' and the string literals 'false' and '0' - elif key.endswith('__isnull'): - value = value.lower() not in ('', 'false', '0') + elif key.endswith("__isnull"): + value = value.lower() not in ("", "false", "0") return value @@ -96,9 +98,7 @@ def flatten_fieldsets(fieldsets): """Return a list of field names from an admin fieldsets structure.""" field_names = [] for name, opts in fieldsets: - field_names.extend( - flatten(opts['fields']) - ) + field_names.extend(flatten(opts["fields"])) return field_names @@ -125,26 +125,26 @@ def get_deleted_objects(objs, request, admin_site): has_admin = model in admin_site._registry opts = obj._meta - no_edit_link = '%s: %s' % (capfirst(opts.verbose_name), obj) + no_edit_link = "%s: %s" % (capfirst(opts.verbose_name), obj) if has_admin: if not admin_site._registry[model].has_delete_permission(request, obj): perms_needed.add(opts.verbose_name) try: - admin_url = reverse('%s:%s_%s_change' - % (admin_site.name, - opts.app_label, - opts.model_name), - None, (quote(obj.pk),)) + admin_url = reverse( + "%s:%s_%s_change" + % (admin_site.name, opts.app_label, opts.model_name), + None, + (quote(obj.pk),), + ) except NoReverseMatch: # Change url doesn't exist -- don't display link to edit return no_edit_link # Display a link to the admin page. 
- return format_html('{}: <a href="{}">{}</a>', - capfirst(opts.verbose_name), - admin_url, - obj) + return format_html( + '{}: <a href="{}">{}</a>', capfirst(opts.verbose_name), admin_url, obj + ) else: # Don't display link to edit, because it either has no # admin or is edited inline. @@ -153,7 +153,10 @@ def get_deleted_objects(objs, request, admin_site): to_delete = collector.nested(format_callback) protected = [format_callback(obj) for obj in collector.protected] - model_count = {model._meta.verbose_name_plural: len(objs) for model, objs in collector.model_objs.items()} + model_count = { + model._meta.verbose_name_plural: len(objs) + for model, objs in collector.model_objs.items() + } return to_delete, model_count, perms_needed, protected @@ -170,10 +173,10 @@ class NestedObjects(Collector): def collect(self, objs, source=None, source_attr=None, **kwargs): for obj in objs: - if source_attr and not source_attr.endswith('+'): + if source_attr and not source_attr.endswith("+"): related_name = source_attr % { - 'class': source._meta.model_name, - 'app_label': source._meta.app_label, + "class": source._meta.model_name, + "app_label": source._meta.app_label, } self.add_edge(getattr(obj, related_name), obj) else: @@ -188,7 +191,9 @@ class NestedObjects(Collector): def related_objects(self, related_model, related_fields, objs): qs = super().related_objects(related_model, related_fields, objs) - return qs.select_related(*[related_field.name for related_field in related_fields]) + return qs.select_related( + *[related_field.name for related_field in related_fields] + ) def _nested(self, obj, seen, format_callback): if obj in seen: @@ -237,8 +242,8 @@ def model_format_dict(obj): else: opts = obj return { - 'verbose_name': opts.verbose_name, - 'verbose_name_plural': opts.verbose_name_plural, + "verbose_name": opts.verbose_name, + "verbose_name_plural": opts.verbose_name_plural, } @@ -270,7 +275,7 @@ def lookup_field(name, obj, model_admin=None): if callable(name): attr = 
name value = attr(obj) - elif hasattr(model_admin, name) and name != '__str__': + elif hasattr(model_admin, name) and name != "__str__": attr = getattr(model_admin, name) value = attr(obj) else: @@ -295,13 +300,21 @@ def _get_non_gfk_field(opts, name): model (rather something like `foo_set`). """ field = opts.get_field(name) - if (field.is_relation and - # Generic foreign keys OR reverse relations - ((field.many_to_one and not field.related_model) or field.one_to_many)): + if ( + field.is_relation + and + # Generic foreign keys OR reverse relations + ((field.many_to_one and not field.related_model) or field.one_to_many) + ): raise FieldDoesNotExist() # Avoid coercing <FK>_id fields to FK - if field.is_relation and not field.many_to_many and hasattr(field, 'attname') and field.attname == name: + if ( + field.is_relation + and not field.many_to_many + and hasattr(field, "attname") + and field.attname == name + ): raise FieldIsAForeignKeyColumnName() return field @@ -337,7 +350,10 @@ def label_for_field(name, model, model_admin=None, return_attr=False, form=None) elif form and name in form.fields: attr = form.fields[name] else: - message = "Unable to lookup '%s' on %s" % (name, model._meta.object_name) + message = "Unable to lookup '%s' on %s" % ( + name, + model._meta.object_name, + ) if model_admin: message += " or %s" % model_admin.__class__.__name__ if form: @@ -346,9 +362,11 @@ def label_for_field(name, model, model_admin=None, return_attr=False, form=None) if hasattr(attr, "short_description"): label = attr.short_description - elif (isinstance(attr, property) and - hasattr(attr, "fget") and - hasattr(attr.fget, "short_description")): + elif ( + isinstance(attr, property) + and hasattr(attr, "fget") + and hasattr(attr.fget, "short_description") + ): label = attr.fget.short_description elif callable(attr): if attr.__name__ == "<lambda>": @@ -374,7 +392,7 @@ def help_text_for_field(name, model): except (FieldDoesNotExist, FieldIsAForeignKeyColumnName): pass else: - 
if hasattr(field, 'help_text'): + if hasattr(field, "help_text"): help_text = field.help_text return help_text @@ -382,7 +400,7 @@ def help_text_for_field(name, model): def display_for_field(value, field, empty_value_display): from django.contrib.admin.templatetags.admin_list import _boolean_icon - if getattr(field, 'flatchoices', None): + if getattr(field, "flatchoices", None): return dict(field.flatchoices).get(value, empty_value_display) # BooleanField needs special-case null-handling, so it comes before the # general null test. @@ -425,7 +443,7 @@ def display_for_value(value, empty_value_display, boolean=False): elif isinstance(value, (int, decimal.Decimal, float)): return formats.number_format(value) elif isinstance(value, (list, tuple)): - return ', '.join(str(v) for v in value) + return ", ".join(str(v) for v in value) else: return str(value) @@ -435,14 +453,14 @@ class NotRelationField(Exception): def get_model_from_relation(field): - if hasattr(field, 'path_infos'): + if hasattr(field, "path_infos"): return field.path_infos[-1].to_opts.model else: raise NotRelationField def reverse_field_path(model, path): - """ Create a reversed field path. + """Create a reversed field path. E.g. Given (Order, "user__groups"), return (Group, "user__order"). @@ -473,7 +491,7 @@ def reverse_field_path(model, path): def get_fields_from_path(model, path): - """ Return list of Fields given path relative to model. + """Return list of Fields given path relative to model. e.g. 
(ModelX, "user__groups__name") -> [ <django.db.models.fields.related.ForeignKey object at 0x...>, @@ -510,34 +528,42 @@ def construct_change_message(form, formsets, add): change_message = [] if add: - change_message.append({'added': {}}) + change_message.append({"added": {}}) elif form.changed_data: - change_message.append({'changed': {'fields': changed_field_labels}}) + change_message.append({"changed": {"fields": changed_field_labels}}) if formsets: with translation_override(None): for formset in formsets: for added_object in formset.new_objects: - change_message.append({ - 'added': { - 'name': str(added_object._meta.verbose_name), - 'object': str(added_object), + change_message.append( + { + "added": { + "name": str(added_object._meta.verbose_name), + "object": str(added_object), + } } - }) + ) for changed_object, changed_fields in formset.changed_objects: - change_message.append({ - 'changed': { - 'name': str(changed_object._meta.verbose_name), - 'object': str(changed_object), - 'fields': _get_changed_field_labels_from_form(formset.forms[0], changed_fields), + change_message.append( + { + "changed": { + "name": str(changed_object._meta.verbose_name), + "object": str(changed_object), + "fields": _get_changed_field_labels_from_form( + formset.forms[0], changed_fields + ), + } } - }) + ) for deleted_object in formset.deleted_objects: - change_message.append({ - 'deleted': { - 'name': str(deleted_object._meta.verbose_name), - 'object': str(deleted_object), + change_message.append( + { + "deleted": { + "name": str(deleted_object._meta.verbose_name), + "object": str(deleted_object), + } } - }) + ) return change_message diff --git a/django/contrib/admin/views/autocomplete.py b/django/contrib/admin/views/autocomplete.py index 26aff083b6..130848b551 100644 --- a/django/contrib/admin/views/autocomplete.py +++ b/django/contrib/admin/views/autocomplete.py @@ -6,6 +6,7 @@ from django.views.generic.list import BaseListView class AutocompleteJsonView(BaseListView): """Handle 
AutocompleteWidget's AJAX requests for data.""" + paginate_by = 20 admin_site = None @@ -18,27 +19,34 @@ class AutocompleteJsonView(BaseListView): pagination: {more: true} } """ - self.term, self.model_admin, self.source_field, to_field_name = self.process_request(request) + ( + self.term, + self.model_admin, + self.source_field, + to_field_name, + ) = self.process_request(request) if not self.has_perm(request): raise PermissionDenied self.object_list = self.get_queryset() context = self.get_context_data() - return JsonResponse({ - 'results': [ - self.serialize_result(obj, to_field_name) - for obj in context['object_list'] - ], - 'pagination': {'more': context['page_obj'].has_next()}, - }) + return JsonResponse( + { + "results": [ + self.serialize_result(obj, to_field_name) + for obj in context["object_list"] + ], + "pagination": {"more": context["page_obj"].has_next()}, + } + ) def serialize_result(self, obj, to_field_name): """ Convert the provided model object to a dictionary that is added to the results list. """ - return {'id': str(getattr(obj, to_field_name)), 'text': str(obj)} + return {"id": str(getattr(obj, to_field_name)), "text": str(obj)} def get_paginator(self, *args, **kwargs): """Use the ModelAdmin's paginator.""" @@ -48,7 +56,9 @@ class AutocompleteJsonView(BaseListView): """Return queryset based on ModelAdmin.get_search_results().""" qs = self.model_admin.get_queryset(self.request) qs = qs.complex_filter(self.source_field.get_limit_choices_to()) - qs, search_use_distinct = self.model_admin.get_search_results(self.request, qs, self.term) + qs, search_use_distinct = self.model_admin.get_search_results( + self.request, qs, self.term + ) if search_use_distinct: qs = qs.distinct() return qs @@ -64,11 +74,11 @@ class AutocompleteJsonView(BaseListView): Raise Http404 if the target model admin is not configured properly with search_fields. 
""" - term = request.GET.get('term', '') + term = request.GET.get("term", "") try: - app_label = request.GET['app_label'] - model_name = request.GET['model_name'] - field_name = request.GET['field_name'] + app_label = request.GET["app_label"] + model_name = request.GET["model_name"] + field_name = request.GET["field_name"] except KeyError as e: raise PermissionDenied from e @@ -94,11 +104,13 @@ class AutocompleteJsonView(BaseListView): # Validate suitability of objects. if not model_admin.get_search_fields(request): raise Http404( - '%s must have search_fields for the autocomplete_view.' % - type(model_admin).__qualname__ + "%s must have search_fields for the autocomplete_view." + % type(model_admin).__qualname__ ) - to_field_name = getattr(source_field.remote_field, 'field_name', remote_model._meta.pk.attname) + to_field_name = getattr( + source_field.remote_field, "field_name", remote_model._meta.pk.attname + ) to_field_name = remote_model._meta.get_field(to_field_name).attname if not model_admin.to_field_allowed(request, to_field_name): raise PermissionDenied diff --git a/django/contrib/admin/views/decorators.py b/django/contrib/admin/views/decorators.py index f14570c24d..c1b63ba64d 100644 --- a/django/contrib/admin/views/decorators.py +++ b/django/contrib/admin/views/decorators.py @@ -2,8 +2,9 @@ from django.contrib.auth import REDIRECT_FIELD_NAME from django.contrib.auth.decorators import user_passes_test -def staff_member_required(view_func=None, redirect_field_name=REDIRECT_FIELD_NAME, - login_url='admin:login'): +def staff_member_required( + view_func=None, redirect_field_name=REDIRECT_FIELD_NAME, login_url="admin:login" +): """ Decorator for views that checks that the user is logged in and is a staff member, redirecting to the login page if necessary. 
@@ -11,7 +12,7 @@ def staff_member_required(view_func=None, redirect_field_name=REDIRECT_FIELD_NAM actual_decorator = user_passes_test( lambda u: u.is_active and u.is_staff, login_url=login_url, - redirect_field_name=redirect_field_name + redirect_field_name=redirect_field_name, ) if view_func: return actual_decorator(view_func) diff --git a/django/contrib/admin/views/main.py b/django/contrib/admin/views/main.py index 43583c81e9..ace4b34ce5 100644 --- a/django/contrib/admin/views/main.py +++ b/django/contrib/admin/views/main.py @@ -5,17 +5,24 @@ from django.conf import settings from django.contrib import messages from django.contrib.admin import FieldListFilter from django.contrib.admin.exceptions import ( - DisallowedModelAdminLookup, DisallowedModelAdminToField, + DisallowedModelAdminLookup, + DisallowedModelAdminToField, ) from django.contrib.admin.options import ( - IS_POPUP_VAR, TO_FIELD_VAR, IncorrectLookupParameters, + IS_POPUP_VAR, + TO_FIELD_VAR, + IncorrectLookupParameters, ) from django.contrib.admin.utils import ( - get_fields_from_path, lookup_spawns_duplicates, prepare_lookup_value, + get_fields_from_path, + lookup_spawns_duplicates, + prepare_lookup_value, quote, ) from django.core.exceptions import ( - FieldDoesNotExist, ImproperlyConfigured, SuspiciousOperation, + FieldDoesNotExist, + ImproperlyConfigured, + SuspiciousOperation, ) from django.core.paginator import InvalidPage from django.db.models import Exists, F, Field, ManyToOneRel, OrderBy, OuterRef @@ -26,11 +33,11 @@ from django.utils.timezone import make_aware from django.utils.translation import gettext # Changelist settings -ALL_VAR = 'all' -ORDER_VAR = 'o' -PAGE_VAR = 'p' -SEARCH_VAR = 'q' -ERROR_FLAG = 'e' +ALL_VAR = "all" +ORDER_VAR = "o" +PAGE_VAR = "p" +SEARCH_VAR = "q" +ERROR_FLAG = "e" IGNORED_PARAMS = (ALL_VAR, ORDER_VAR, SEARCH_VAR, IS_POPUP_VAR, TO_FIELD_VAR) @@ -47,10 +54,23 @@ class ChangeListSearchForm(forms.Form): class ChangeList: search_form_class = ChangeListSearchForm - 
def __init__(self, request, model, list_display, list_display_links, - list_filter, date_hierarchy, search_fields, list_select_related, - list_per_page, list_max_show_all, list_editable, model_admin, sortable_by, - search_help_text): + def __init__( + self, + request, + model, + list_display, + list_display_links, + list_filter, + date_hierarchy, + search_fields, + list_select_related, + list_per_page, + list_max_show_all, + list_editable, + model_admin, + sortable_by, + search_help_text, + ): self.model = model self.opts = model._meta self.lookup_opts = self.opts @@ -75,8 +95,8 @@ class ChangeList: _search_form = self.search_form_class(request.GET) if not _search_form.is_valid(): for error in _search_form.errors.values(): - messages.error(request, ', '.join(error)) - self.query = _search_form.cleaned_data.get(SEARCH_VAR) or '' + messages.error(request, ", ".join(error)) + self.query = _search_form.cleaned_data.get(SEARCH_VAR) or "" try: self.page_num = int(request.GET.get(PAGE_VAR, 1)) except ValueError: @@ -85,7 +105,9 @@ class ChangeList: self.is_popup = IS_POPUP_VAR in request.GET to_field = request.GET.get(TO_FIELD_VAR) if to_field and not model_admin.to_field_allowed(request, to_field): - raise DisallowedModelAdminToField("The field %s cannot be referenced." % to_field) + raise DisallowedModelAdminToField( + "The field %s cannot be referenced." 
% to_field + ) self.to_field = to_field self.params = dict(request.GET.items()) if PAGE_VAR in self.params: @@ -100,16 +122,16 @@ class ChangeList: self.queryset = self.get_queryset(request) self.get_results(request) if self.is_popup: - title = gettext('Select %s') + title = gettext("Select %s") elif self.model_admin.has_change_permission(request): - title = gettext('Select %s to change') + title = gettext("Select %s to change") else: - title = gettext('Select %s to view') + title = gettext("Select %s to view") self.title = title % self.opts.verbose_name self.pk_attname = self.lookup_opts.pk.attname def __repr__(self): - return '<%s: model=%s model_admin=%s>' % ( + return "<%s: model=%s model_admin=%s>" % ( self.__class__.__qualname__, self.model.__qualname__, self.model_admin.__class__.__qualname__, @@ -158,15 +180,20 @@ class ChangeList: field = get_fields_from_path(self.model, field_path)[-1] spec = field_list_filter_class( - field, request, lookup_params, - self.model, self.model_admin, field_path=field_path, + field, + request, + lookup_params, + self.model, + self.model_admin, + field_path=field_path, ) # field_list_filter_class removes any lookup_params it # processes. If that happened, check if duplicates should be # removed. if lookup_params_count > len(lookup_params): may_have_duplicates |= lookup_spawns_duplicates( - self.lookup_opts, field_path, + self.lookup_opts, + field_path, ) if spec and spec.has_output(): filter_specs.append(spec) @@ -176,10 +203,10 @@ class ChangeList: if self.date_hierarchy: # Create bounded lookup parameters so that the query is more # efficient. 
- year = lookup_params.pop('%s__year' % self.date_hierarchy, None) + year = lookup_params.pop("%s__year" % self.date_hierarchy, None) if year is not None: - month = lookup_params.pop('%s__month' % self.date_hierarchy, None) - day = lookup_params.pop('%s__day' % self.date_hierarchy, None) + month = lookup_params.pop("%s__month" % self.date_hierarchy, None) + day = lookup_params.pop("%s__day" % self.date_hierarchy, None) try: from_date = datetime( int(year), @@ -199,10 +226,12 @@ class ChangeList: if settings.USE_TZ: from_date = make_aware(from_date) to_date = make_aware(to_date) - lookup_params.update({ - '%s__gte' % self.date_hierarchy: from_date, - '%s__lt' % self.date_hierarchy: to_date, - }) + lookup_params.update( + { + "%s__gte" % self.date_hierarchy: from_date, + "%s__lt" % self.date_hierarchy: to_date, + } + ) # At this point, all the parameters used by the various ListFilters # have been removed from lookup_params, which now only contains other @@ -215,7 +244,10 @@ class ChangeList: lookup_params[key] = prepare_lookup_value(key, value) may_have_duplicates |= lookup_spawns_duplicates(self.lookup_opts, key) return ( - filter_specs, bool(filter_specs), lookup_params, may_have_duplicates, + filter_specs, + bool(filter_specs), + lookup_params, + may_have_duplicates, has_active_filters, ) except FieldDoesNotExist as e: @@ -237,10 +269,12 @@ class ChangeList: del p[k] else: p[k] = v - return '?%s' % urlencode(sorted(p.items())) + return "?%s" % urlencode(sorted(p.items())) def get_results(self, request): - paginator = self.model_admin.get_paginator(request, self.queryset, self.list_per_page) + paginator = self.model_admin.get_paginator( + request, self.queryset, self.list_per_page + ) # Get the number of objects, with admin filters applied. 
result_count = paginator.count @@ -265,7 +299,9 @@ class ChangeList: self.show_full_result_count = self.model_admin.show_full_result_count # Admin actions are shown if there is at least one entry # or if entries are not counted because show_full_result_count is disabled - self.show_admin_actions = not self.show_full_result_count or bool(full_result_count) + self.show_admin_actions = not self.show_full_result_count or bool( + full_result_count + ) self.full_result_count = full_result_count self.result_list = result_list self.can_show_all = can_show_all @@ -300,9 +336,9 @@ class ChangeList: attr = getattr(self.model_admin, field_name) else: attr = getattr(self.model, field_name) - if isinstance(attr, property) and hasattr(attr, 'fget'): + if isinstance(attr, property) and hasattr(attr, "fget"): attr = attr.fget - return getattr(attr, 'admin_order_field', None) + return getattr(attr, "admin_order_field", None) def get_ordering(self, request, queryset): """ @@ -314,28 +350,32 @@ class ChangeList: constructed ordering. """ params = self.params - ordering = list(self.model_admin.get_ordering(request) or self._get_default_ordering()) + ordering = list( + self.model_admin.get_ordering(request) or self._get_default_ordering() + ) if ORDER_VAR in params: # Clear ordering and used params ordering = [] - order_params = params[ORDER_VAR].split('.') + order_params = params[ORDER_VAR].split(".") for p in order_params: try: - none, pfx, idx = p.rpartition('-') + none, pfx, idx = p.rpartition("-") field_name = self.list_display[int(idx)] order_field = self.get_ordering_field(field_name) if not order_field: continue # No 'admin_order_field', skip it if isinstance(order_field, OrderBy): - if pfx == '-': + if pfx == "-": order_field = order_field.copy() order_field.reverse_ordering() ordering.append(order_field) - elif hasattr(order_field, 'resolve_expression'): + elif hasattr(order_field, "resolve_expression"): # order_field is an expression. 
- ordering.append(order_field.desc() if pfx == '-' else order_field.asc()) + ordering.append( + order_field.desc() if pfx == "-" else order_field.asc() + ) # reverse order if order_field has already "-" as prefix - elif order_field.startswith('-') and pfx == '-': + elif order_field.startswith("-") and pfx == "-": ordering.append(order_field[1:]) else: ordering.append(pfx + order_field) @@ -356,15 +396,16 @@ class ChangeList: """ ordering = list(ordering) ordering_fields = set() - total_ordering_fields = {'pk'} | { - field.attname for field in self.lookup_opts.fields + total_ordering_fields = {"pk"} | { + field.attname + for field in self.lookup_opts.fields if field.unique and not field.null } for part in ordering: # Search for single field providing a total ordering. field_name = None if isinstance(part, str): - field_name = part.lstrip('-') + field_name = part.lstrip("-") elif isinstance(part, F): field_name = part.name elif isinstance(part, OrderBy) and isinstance(part.expression, F): @@ -396,7 +437,9 @@ class ChangeList: ) for field_names in constraint_field_names: # Normalize attname references by using get_field(). - fields = [self.lookup_opts.get_field(field_name) for field_name in field_names] + fields = [ + self.lookup_opts.get_field(field_name) for field_name in field_names + ] # Composite unique constraints containing a nullable column # cannot ensure total ordering. if any(field.null for field in fields): @@ -406,7 +449,7 @@ class ChangeList: else: # If no set of unique fields is present in the ordering, rely # on the primary key to provide total ordering. 
- ordering.append('-pk') + ordering.append("-pk") return ordering def get_ordering_field_columns(self): @@ -426,27 +469,27 @@ class ChangeList: if not isinstance(field, OrderBy): field = field.asc() if isinstance(field.expression, F): - order_type = 'desc' if field.descending else 'asc' + order_type = "desc" if field.descending else "asc" field = field.expression.name else: continue - elif field.startswith('-'): + elif field.startswith("-"): field = field[1:] - order_type = 'desc' + order_type = "desc" else: - order_type = 'asc' + order_type = "asc" for index, attr in enumerate(self.list_display): if self.get_ordering_field(attr) == field: ordering_fields[index] = order_type break else: - for p in self.params[ORDER_VAR].split('.'): - none, pfx, idx = p.rpartition('-') + for p in self.params[ORDER_VAR].split("."): + none, pfx, idx = p.rpartition("-") try: idx = int(idx) except ValueError: continue # skip it - ordering_fields[idx] = 'desc' if pfx == '-' else 'asc' + ordering_fields[idx] = "desc" if pfx == "-" else "asc" return ordering_fields def get_queryset(self, request): @@ -484,7 +527,9 @@ class ChangeList: # Apply search results qs, search_may_have_duplicates = self.model_admin.get_search_results( - request, qs, self.query, + request, + qs, + self.query, ) # Set query string for clearing all filters. @@ -494,7 +539,7 @@ class ChangeList: ) # Remove duplicates from results, if necessary if filters_may_have_duplicates | search_may_have_duplicates: - qs = qs.filter(pk=OuterRef('pk')) + qs = qs.filter(pk=OuterRef("pk")) qs = self.root_queryset.filter(Exists(qs)) # Set ordering. 
@@ -533,7 +578,8 @@ class ChangeList: def url_for_result(self, result): pk = getattr(result, self.pk_attname) - return reverse('admin:%s_%s_change' % (self.opts.app_label, - self.opts.model_name), - args=(quote(pk),), - current_app=self.model_admin.admin_site.name) + return reverse( + "admin:%s_%s_change" % (self.opts.app_label, self.opts.model_name), + args=(quote(pk),), + current_app=self.model_admin.admin_site.name, + ) diff --git a/django/contrib/admin/widgets.py b/django/contrib/admin/widgets.py index ec0423b284..b81ea7de3e 100644 --- a/django/contrib/admin/widgets.py +++ b/django/contrib/admin/widgets.py @@ -14,7 +14,8 @@ from django.urls.exceptions import NoReverseMatch from django.utils.html import smart_urlquote from django.utils.http import urlencode from django.utils.text import Truncator -from django.utils.translation import get_language, gettext as _ +from django.utils.translation import get_language +from django.utils.translation import gettext as _ class FilteredSelectMultiple(forms.SelectMultiple): @@ -24,11 +25,12 @@ class FilteredSelectMultiple(forms.SelectMultiple): Note that the resulting JavaScript assumes that the jsi18n catalog has been loaded in the page """ + class Media: js = [ - 'admin/js/core.js', - 'admin/js/SelectBox.js', - 'admin/js/SelectFilter2.js', + "admin/js/core.js", + "admin/js/SelectBox.js", + "admin/js/SelectFilter2.js", ] def __init__(self, verbose_name, is_stacked, attrs=None, choices=()): @@ -38,35 +40,35 @@ class FilteredSelectMultiple(forms.SelectMultiple): def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) - context['widget']['attrs']['class'] = 'selectfilter' + context["widget"]["attrs"]["class"] = "selectfilter" if self.is_stacked: - context['widget']['attrs']['class'] += 'stacked' - context['widget']['attrs']['data-field-name'] = self.verbose_name - context['widget']['attrs']['data-is-stacked'] = int(self.is_stacked) + context["widget"]["attrs"]["class"] += "stacked" + 
context["widget"]["attrs"]["data-field-name"] = self.verbose_name + context["widget"]["attrs"]["data-is-stacked"] = int(self.is_stacked) return context class AdminDateWidget(forms.DateInput): class Media: js = [ - 'admin/js/calendar.js', - 'admin/js/admin/DateTimeShortcuts.js', + "admin/js/calendar.js", + "admin/js/admin/DateTimeShortcuts.js", ] def __init__(self, attrs=None, format=None): - attrs = {'class': 'vDateField', 'size': '10', **(attrs or {})} + attrs = {"class": "vDateField", "size": "10", **(attrs or {})} super().__init__(attrs=attrs, format=format) class AdminTimeWidget(forms.TimeInput): class Media: js = [ - 'admin/js/calendar.js', - 'admin/js/admin/DateTimeShortcuts.js', + "admin/js/calendar.js", + "admin/js/admin/DateTimeShortcuts.js", ] def __init__(self, attrs=None, format=None): - attrs = {'class': 'vTimeField', 'size': '8', **(attrs or {})} + attrs = {"class": "vTimeField", "size": "8", **(attrs or {})} super().__init__(attrs=attrs, format=format) @@ -74,7 +76,8 @@ class AdminSplitDateTime(forms.SplitDateTimeWidget): """ A SplitDateTime Widget that has some admin-specific styling. 
""" - template_name = 'admin/widgets/split_datetime.html' + + template_name = "admin/widgets/split_datetime.html" def __init__(self, attrs=None): widgets = [AdminDateWidget, AdminTimeWidget] @@ -84,17 +87,17 @@ class AdminSplitDateTime(forms.SplitDateTimeWidget): def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) - context['date_label'] = _('Date:') - context['time_label'] = _('Time:') + context["date_label"] = _("Date:") + context["time_label"] = _("Time:") return context class AdminRadioSelect(forms.RadioSelect): - template_name = 'admin/widgets/radio.html' + template_name = "admin/widgets/radio.html" class AdminFileWidget(forms.ClearableFileInput): - template_name = 'admin/widgets/clearable_file_input.html' + template_name = "admin/widgets/clearable_file_input.html" def url_params_from_lookup_dict(lookups): @@ -103,14 +106,14 @@ def url_params_from_lookup_dict(lookups): attribute to a dictionary of query parameters """ params = {} - if lookups and hasattr(lookups, 'items'): + if lookups and hasattr(lookups, "items"): for k, v in lookups.items(): if callable(v): v = v() if isinstance(v, (tuple, list)): - v = ','.join(str(x) for x in v) + v = ",".join(str(x) for x in v) elif isinstance(v, bool): - v = ('0', '1')[v] + v = ("0", "1")[v] else: v = str(v) params[k] = v @@ -122,7 +125,8 @@ class ForeignKeyRawIdWidget(forms.TextInput): A Widget for displaying ForeignKeys in the "raw_id" interface rather than in a <select> box. 
""" - template_name = 'admin/widgets/foreign_key_raw_id.html' + + template_name = "admin/widgets/foreign_key_raw_id.html" def __init__(self, rel, admin_site, attrs=None, using=None): self.rel = rel @@ -136,7 +140,8 @@ class ForeignKeyRawIdWidget(forms.TextInput): if rel_to in self.admin_site._registry: # The related object is registered with the same AdminSite related_url = reverse( - 'admin:%s_%s_changelist' % ( + "admin:%s_%s_changelist" + % ( rel_to._meta.app_label, rel_to._meta.model_name, ), @@ -145,20 +150,22 @@ class ForeignKeyRawIdWidget(forms.TextInput): params = self.url_parameters() if params: - related_url += '?' + urlencode(params) - context['related_url'] = related_url - context['link_title'] = _('Lookup') + related_url += "?" + urlencode(params) + context["related_url"] = related_url + context["link_title"] = _("Lookup") # The JavaScript code looks for this class. - css_class = 'vForeignKeyRawIdAdminField' + css_class = "vForeignKeyRawIdAdminField" if isinstance(self.rel.get_related_field(), UUIDField): - css_class += ' vUUIDField' - context['widget']['attrs'].setdefault('class', css_class) + css_class += " vUUIDField" + context["widget"]["attrs"].setdefault("class", css_class) else: - context['related_url'] = None - if context['widget']['value']: - context['link_label'], context['link_url'] = self.label_and_url_for_value(value) + context["related_url"] = None + if context["widget"]["value"]: + context["link_label"], context["link_url"] = self.label_and_url_for_value( + value + ) else: - context['link_label'] = None + context["link_label"] = None return context def base_url_parameters(self): @@ -169,6 +176,7 @@ class ForeignKeyRawIdWidget(forms.TextInput): def url_parameters(self): from django.contrib.admin.views.main import TO_FIELD_VAR + params = self.base_url_parameters() params.update({TO_FIELD_VAR: self.rel.get_related_field().name}) return params @@ -178,19 +186,20 @@ class ForeignKeyRawIdWidget(forms.TextInput): try: obj = 
self.rel.model._default_manager.using(self.db).get(**{key: value}) except (ValueError, self.rel.model.DoesNotExist, ValidationError): - return '', '' + return "", "" try: url = reverse( - '%s:%s_%s_change' % ( + "%s:%s_%s_change" + % ( self.admin_site.name, obj._meta.app_label, obj._meta.object_name.lower(), ), - args=(obj.pk,) + args=(obj.pk,), ) except NoReverseMatch: - url = '' # Admin not registered for target model. + url = "" # Admin not registered for target model. return Truncator(obj).words(14), url @@ -200,28 +209,29 @@ class ManyToManyRawIdWidget(ForeignKeyRawIdWidget): A Widget for displaying ManyToMany ids in the "raw_id" interface rather than in a <select multiple> box. """ - template_name = 'admin/widgets/many_to_many_raw_id.html' + + template_name = "admin/widgets/many_to_many_raw_id.html" def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) if self.rel.model in self.admin_site._registry: # The related object is registered with the same AdminSite - context['widget']['attrs']['class'] = 'vManyToManyRawIdAdminField' + context["widget"]["attrs"]["class"] = "vManyToManyRawIdAdminField" return context def url_parameters(self): return self.base_url_parameters() def label_and_url_for_value(self, value): - return '', '' + return "", "" def value_from_datadict(self, data, files, name): value = data.get(name) if value: - return value.split(',') + return value.split(",") def format_value(self, value): - return ','.join(str(v) for v in value) if value else '' + return ",".join(str(v) for v in value) if value else "" class RelatedFieldWidgetWrapper(forms.Widget): @@ -229,11 +239,19 @@ class RelatedFieldWidgetWrapper(forms.Widget): This class is a wrapper to a given widget to add the add icon for the admin interface. 
""" - template_name = 'admin/widgets/related_widget_wrapper.html' - def __init__(self, widget, rel, admin_site, can_add_related=None, - can_change_related=False, can_delete_related=False, - can_view_related=False): + template_name = "admin/widgets/related_widget_wrapper.html" + + def __init__( + self, + widget, + rel, + admin_site, + can_add_related=None, + can_change_related=False, + can_delete_related=False, + can_view_related=False, + ): self.needs_multipart_form = widget.needs_multipart_form self.attrs = widget.attrs self.choices = widget.choices @@ -245,10 +263,10 @@ class RelatedFieldWidgetWrapper(forms.Widget): can_add_related = rel.model in admin_site._registry self.can_add_related = can_add_related # XXX: The UX does not support multiple selected values. - multiple = getattr(widget, 'allow_multiple_selected', False) + multiple = getattr(widget, "allow_multiple_selected", False) self.can_change_related = not multiple and can_change_related # XXX: The deletion UX can be confusing when dealing with cascading deletion. 
- cascade = getattr(rel, 'on_delete', None) is CASCADE + cascade = getattr(rel, "on_delete", None) is CASCADE self.can_delete_related = not multiple and not cascade and can_delete_related self.can_view_related = not multiple and can_view_related # so we can check if the related object is registered with this AdminSite @@ -270,35 +288,46 @@ class RelatedFieldWidgetWrapper(forms.Widget): return self.widget.media def get_related_url(self, info, action, *args): - return reverse("admin:%s_%s_%s" % (info + (action,)), - current_app=self.admin_site.name, args=args) + return reverse( + "admin:%s_%s_%s" % (info + (action,)), + current_app=self.admin_site.name, + args=args, + ) def get_context(self, name, value, attrs): from django.contrib.admin.views.main import IS_POPUP_VAR, TO_FIELD_VAR + rel_opts = self.rel.model._meta info = (rel_opts.app_label, rel_opts.model_name) self.widget.choices = self.choices - url_params = '&'.join("%s=%s" % param for param in [ - (TO_FIELD_VAR, self.rel.get_related_field().name), - (IS_POPUP_VAR, 1), - ]) + url_params = "&".join( + "%s=%s" % param + for param in [ + (TO_FIELD_VAR, self.rel.get_related_field().name), + (IS_POPUP_VAR, 1), + ] + ) context = { - 'rendered_widget': self.widget.render(name, value, attrs), - 'is_hidden': self.is_hidden, - 'name': name, - 'url_params': url_params, - 'model': rel_opts.verbose_name, - 'can_add_related': self.can_add_related, - 'can_change_related': self.can_change_related, - 'can_delete_related': self.can_delete_related, - 'can_view_related': self.can_view_related, + "rendered_widget": self.widget.render(name, value, attrs), + "is_hidden": self.is_hidden, + "name": name, + "url_params": url_params, + "model": rel_opts.verbose_name, + "can_add_related": self.can_add_related, + "can_change_related": self.can_change_related, + "can_delete_related": self.can_delete_related, + "can_view_related": self.can_view_related, } if self.can_add_related: - context['add_related_url'] = self.get_related_url(info, 
'add') + context["add_related_url"] = self.get_related_url(info, "add") if self.can_delete_related: - context['delete_related_template_url'] = self.get_related_url(info, 'delete', '__fk__') + context["delete_related_template_url"] = self.get_related_url( + info, "delete", "__fk__" + ) if self.can_view_related or self.can_change_related: - context['change_related_template_url'] = self.get_related_url(info, 'change', '__fk__') + context["change_related_template_url"] = self.get_related_url( + info, "change", "__fk__" + ) return context def value_from_datadict(self, data, files, name): @@ -313,67 +342,112 @@ class RelatedFieldWidgetWrapper(forms.Widget): class AdminTextareaWidget(forms.Textarea): def __init__(self, attrs=None): - super().__init__(attrs={'class': 'vLargeTextField', **(attrs or {})}) + super().__init__(attrs={"class": "vLargeTextField", **(attrs or {})}) class AdminTextInputWidget(forms.TextInput): def __init__(self, attrs=None): - super().__init__(attrs={'class': 'vTextField', **(attrs or {})}) + super().__init__(attrs={"class": "vTextField", **(attrs or {})}) class AdminEmailInputWidget(forms.EmailInput): def __init__(self, attrs=None): - super().__init__(attrs={'class': 'vTextField', **(attrs or {})}) + super().__init__(attrs={"class": "vTextField", **(attrs or {})}) class AdminURLFieldWidget(forms.URLInput): - template_name = 'admin/widgets/url.html' + template_name = "admin/widgets/url.html" def __init__(self, attrs=None, validator_class=URLValidator): - super().__init__(attrs={'class': 'vURLField', **(attrs or {})}) + super().__init__(attrs={"class": "vURLField", **(attrs or {})}) self.validator = validator_class() def get_context(self, name, value, attrs): try: - self.validator(value if value else '') + self.validator(value if value else "") url_valid = True except ValidationError: url_valid = False context = super().get_context(name, value, attrs) - context['current_label'] = _('Currently:') - context['change_label'] = _('Change:') - 
context['widget']['href'] = smart_urlquote(context['widget']['value']) if value else '' - context['url_valid'] = url_valid + context["current_label"] = _("Currently:") + context["change_label"] = _("Change:") + context["widget"]["href"] = ( + smart_urlquote(context["widget"]["value"]) if value else "" + ) + context["url_valid"] = url_valid return context class AdminIntegerFieldWidget(forms.NumberInput): - class_name = 'vIntegerField' + class_name = "vIntegerField" def __init__(self, attrs=None): - super().__init__(attrs={'class': self.class_name, **(attrs or {})}) + super().__init__(attrs={"class": self.class_name, **(attrs or {})}) class AdminBigIntegerFieldWidget(AdminIntegerFieldWidget): - class_name = 'vBigIntegerField' + class_name = "vBigIntegerField" class AdminUUIDInputWidget(forms.TextInput): def __init__(self, attrs=None): - super().__init__(attrs={'class': 'vUUIDField', **(attrs or {})}) + super().__init__(attrs={"class": "vUUIDField", **(attrs or {})}) # Mapping of lowercase language codes [returned by Django's get_language()] to # language codes supported by select2. 
# See django/contrib/admin/static/admin/js/vendor/select2/i18n/* -SELECT2_TRANSLATIONS = {x.lower(): x for x in [ - 'ar', 'az', 'bg', 'ca', 'cs', 'da', 'de', 'el', 'en', 'es', 'et', - 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hi', 'hr', 'hu', 'id', 'is', - 'it', 'ja', 'km', 'ko', 'lt', 'lv', 'mk', 'ms', 'nb', 'nl', 'pl', - 'pt-BR', 'pt', 'ro', 'ru', 'sk', 'sr-Cyrl', 'sr', 'sv', 'th', - 'tr', 'uk', 'vi', -]} -SELECT2_TRANSLATIONS.update({'zh-hans': 'zh-CN', 'zh-hant': 'zh-TW'}) +SELECT2_TRANSLATIONS = { + x.lower(): x + for x in [ + "ar", + "az", + "bg", + "ca", + "cs", + "da", + "de", + "el", + "en", + "es", + "et", + "eu", + "fa", + "fi", + "fr", + "gl", + "he", + "hi", + "hr", + "hu", + "id", + "is", + "it", + "ja", + "km", + "ko", + "lt", + "lv", + "mk", + "ms", + "nb", + "nl", + "pl", + "pt-BR", + "pt", + "ro", + "ru", + "sk", + "sr-Cyrl", + "sr", + "sv", + "th", + "tr", + "uk", + "vi", + ] +} +SELECT2_TRANSLATIONS.update({"zh-hans": "zh-CN", "zh-hant": "zh-TW"}) class AutocompleteMixin: @@ -383,7 +457,8 @@ class AutocompleteMixin: Renders the necessary data attributes for select2 and adds the static form media. """ - url_name = '%s:autocomplete' + + url_name = "%s:autocomplete" def __init__(self, field, admin_site, attrs=None, choices=(), using=None): self.field = field @@ -405,21 +480,25 @@ class AutocompleteMixin: https://select2.org/configuration/data-attributes#nested-subkey-options """ attrs = super().build_attrs(base_attrs, extra_attrs=extra_attrs) - attrs.setdefault('class', '') - attrs.update({ - 'data-ajax--cache': 'true', - 'data-ajax--delay': 250, - 'data-ajax--type': 'GET', - 'data-ajax--url': self.get_url(), - 'data-app-label': self.field.model._meta.app_label, - 'data-model-name': self.field.model._meta.model_name, - 'data-field-name': self.field.name, - 'data-theme': 'admin-autocomplete', - 'data-allow-clear': json.dumps(not self.is_required), - 'data-placeholder': '', # Allows clearing of the input. 
- 'lang': self.i18n_name, - 'class': attrs['class'] + (' ' if attrs['class'] else '') + 'admin-autocomplete', - }) + attrs.setdefault("class", "") + attrs.update( + { + "data-ajax--cache": "true", + "data-ajax--delay": 250, + "data-ajax--type": "GET", + "data-ajax--url": self.get_url(), + "data-app-label": self.field.model._meta.app_label, + "data-model-name": self.field.model._meta.model_name, + "data-field-name": self.field.name, + "data-theme": "admin-autocomplete", + "data-allow-clear": json.dumps(not self.is_required), + "data-placeholder": "", # Allows clearing of the input. + "lang": self.i18n_name, + "class": attrs["class"] + + (" " if attrs["class"] else "") + + "admin-autocomplete", + } + ) return attrs def optgroups(self, name, value, attr=None): @@ -428,45 +507,57 @@ class AutocompleteMixin: groups = [default] has_selected = False selected_choices = { - str(v) for v in value - if str(v) not in self.choices.field.empty_values + str(v) for v in value if str(v) not in self.choices.field.empty_values } if not self.is_required and not self.allow_multiple_selected: - default[1].append(self.create_option(name, '', '', False, 0)) + default[1].append(self.create_option(name, "", "", False, 0)) remote_model_opts = self.field.remote_field.model._meta - to_field_name = getattr(self.field.remote_field, 'field_name', remote_model_opts.pk.attname) + to_field_name = getattr( + self.field.remote_field, "field_name", remote_model_opts.pk.attname + ) to_field_name = remote_model_opts.get_field(to_field_name).attname choices = ( (getattr(obj, to_field_name), self.choices.field.label_from_instance(obj)) - for obj in self.choices.queryset.using(self.db).filter(**{'%s__in' % to_field_name: selected_choices}) + for obj in self.choices.queryset.using(self.db).filter( + **{"%s__in" % to_field_name: selected_choices} + ) ) for option_value, option_label in choices: - selected = ( - str(option_value) in value and - (has_selected is False or self.allow_multiple_selected) + selected 
= str(option_value) in value and ( + has_selected is False or self.allow_multiple_selected ) has_selected |= selected index = len(default[1]) subgroup = default[1] - subgroup.append(self.create_option(name, option_value, option_label, selected_choices, index)) + subgroup.append( + self.create_option( + name, option_value, option_label, selected_choices, index + ) + ) return groups @property def media(self): - extra = '' if settings.DEBUG else '.min' - i18n_file = ('admin/js/vendor/select2/i18n/%s.js' % self.i18n_name,) if self.i18n_name else () + extra = "" if settings.DEBUG else ".min" + i18n_file = ( + ("admin/js/vendor/select2/i18n/%s.js" % self.i18n_name,) + if self.i18n_name + else () + ) return forms.Media( js=( - 'admin/js/vendor/jquery/jquery%s.js' % extra, - 'admin/js/vendor/select2/select2.full%s.js' % extra, - ) + i18n_file + ( - 'admin/js/jquery.init.js', - 'admin/js/autocomplete.js', + "admin/js/vendor/jquery/jquery%s.js" % extra, + "admin/js/vendor/select2/select2.full%s.js" % extra, + ) + + i18n_file + + ( + "admin/js/jquery.init.js", + "admin/js/autocomplete.js", ), css={ - 'screen': ( - 'admin/css/vendor/select2/select2%s.css' % extra, - 'admin/css/autocomplete.css', + "screen": ( + "admin/css/vendor/select2/select2%s.css" % extra, + "admin/css/autocomplete.css", ), }, ) diff --git a/django/contrib/admindocs/apps.py b/django/contrib/admindocs/apps.py index 1a502688f7..e79dc892cb 100644 --- a/django/contrib/admindocs/apps.py +++ b/django/contrib/admindocs/apps.py @@ -3,5 +3,5 @@ from django.utils.translation import gettext_lazy as _ class AdminDocsConfig(AppConfig): - name = 'django.contrib.admindocs' + name = "django.contrib.admindocs" verbose_name = _("Administrative Documentation") diff --git a/django/contrib/admindocs/middleware.py b/django/contrib/admindocs/middleware.py index cf8879d75a..5c9f08d0cc 100644 --- a/django/contrib/admindocs/middleware.py +++ b/django/contrib/admindocs/middleware.py @@ -10,6 +10,7 @@ class 
XViewMiddleware(MiddlewareMixin): """ Add an X-View header to internal HEAD requests. """ + def process_view(self, request, view_func, view_args, view_kwargs): """ If the request method is HEAD and either the IP is internal or the @@ -17,14 +18,16 @@ class XViewMiddleware(MiddlewareMixin): header indicating the view function. This is used to lookup the view function for an arbitrary page. """ - if not hasattr(request, 'user'): + if not hasattr(request, "user"): raise ImproperlyConfigured( "The XView middleware requires authentication middleware to " "be installed. Edit your MIDDLEWARE setting to insert " "'django.contrib.auth.middleware.AuthenticationMiddleware'." ) - if request.method == 'HEAD' and (request.META.get('REMOTE_ADDR') in settings.INTERNAL_IPS or - (request.user.is_active and request.user.is_staff)): + if request.method == "HEAD" and ( + request.META.get("REMOTE_ADDR") in settings.INTERNAL_IPS + or (request.user.is_active and request.user.is_staff) + ): response = HttpResponse() - response.headers['X-View'] = get_view_name(view_func) + response.headers["X-View"] = get_view_name(view_func) return response diff --git a/django/contrib/admindocs/urls.py b/django/contrib/admindocs/urls.py index bc9c3df7cf..6ee14ac67b 100644 --- a/django/contrib/admindocs/urls.py +++ b/django/contrib/admindocs/urls.py @@ -3,48 +3,48 @@ from django.urls import path, re_path urlpatterns = [ path( - '', - views.BaseAdminDocsView.as_view(template_name='admin_doc/index.html'), - name='django-admindocs-docroot', + "", + views.BaseAdminDocsView.as_view(template_name="admin_doc/index.html"), + name="django-admindocs-docroot", ), path( - 'bookmarklets/', + "bookmarklets/", views.BookmarkletsView.as_view(), - name='django-admindocs-bookmarklets', + name="django-admindocs-bookmarklets", ), path( - 'tags/', + "tags/", views.TemplateTagIndexView.as_view(), - name='django-admindocs-tags', + name="django-admindocs-tags", ), path( - 'filters/', + "filters/", 
views.TemplateFilterIndexView.as_view(), - name='django-admindocs-filters', + name="django-admindocs-filters", ), path( - 'views/', + "views/", views.ViewIndexView.as_view(), - name='django-admindocs-views-index', + name="django-admindocs-views-index", ), path( - 'views/<view>/', + "views/<view>/", views.ViewDetailView.as_view(), - name='django-admindocs-views-detail', + name="django-admindocs-views-detail", ), path( - 'models/', + "models/", views.ModelIndexView.as_view(), - name='django-admindocs-models-index', + name="django-admindocs-models-index", ), re_path( - r'^models/(?P<app_label>[^\.]+)\.(?P<model_name>[^/]+)/$', + r"^models/(?P<app_label>[^\.]+)\.(?P<model_name>[^/]+)/$", views.ModelDetailView.as_view(), - name='django-admindocs-models-detail', + name="django-admindocs-models-detail", ), path( - 'templates/<path:template>/', + "templates/<path:template>/", views.TemplateDetailView.as_view(), - name='django-admindocs-templates', + name="django-admindocs-templates", ), ] diff --git a/django/contrib/admindocs/utils.py b/django/contrib/admindocs/utils.py index e956913c78..6edff502ec 100644 --- a/django/contrib/admindocs/utils.py +++ b/django/contrib/admindocs/utils.py @@ -20,12 +20,12 @@ else: def get_view_name(view_func): - if hasattr(view_func, 'view_class'): + if hasattr(view_func, "view_class"): klass = view_func.view_class - return f'{klass.__module__}.{klass.__qualname__}' + return f"{klass.__module__}.{klass.__qualname__}" mod_name = view_func.__module__ - view_name = getattr(view_func, '__qualname__', view_func.__class__.__name__) - return mod_name + '.' + view_name + view_name = getattr(view_func, "__qualname__", view_func.__class__.__name__) + return mod_name + "." + view_name def parse_docstring(docstring): @@ -33,12 +33,12 @@ def parse_docstring(docstring): Parse out the parts of a docstring. Return (title, body, metadata). 
""" if not docstring: - return '', '', {} + return "", "", {} docstring = cleandoc(docstring) - parts = re.split(r'\n{2,}', docstring) + parts = re.split(r"\n{2,}", docstring) title = parts[0] if len(parts) == 1: - body = '' + body = "" metadata = {} else: parser = HeaderParser() @@ -61,14 +61,14 @@ def parse_rst(text, default_reference_context, thing_being_parsed=None): Convert the string from reST to an XHTML fragment. """ overrides = { - 'doctitle_xform': True, - 'initial_header_level': 3, + "doctitle_xform": True, + "initial_header_level": 3, "default_reference_context": default_reference_context, - "link_base": reverse('django-admindocs-docroot').rstrip('/'), - 'raw_enabled': False, - 'file_insertion_enabled': False, + "link_base": reverse("django-admindocs-docroot").rstrip("/"), + "raw_enabled": False, + "file_insertion_enabled": False, } - thing_being_parsed = thing_being_parsed and '<%s>' % thing_being_parsed + thing_being_parsed = thing_being_parsed and "<%s>" % thing_being_parsed # Wrap ``text`` in some reST that sets the default role to ``cmsreference``, # then restores it. 
source = """ @@ -80,21 +80,23 @@ def parse_rst(text, default_reference_context, thing_being_parsed=None): """ parts = docutils.core.publish_parts( source % text, - source_path=thing_being_parsed, destination_path=None, - writer_name='html', settings_overrides=overrides, + source_path=thing_being_parsed, + destination_path=None, + writer_name="html", + settings_overrides=overrides, ) - return mark_safe(parts['fragment']) + return mark_safe(parts["fragment"]) # # reST roles # ROLES = { - 'model': '%s/models/%s/', - 'view': '%s/views/%s/', - 'template': '%s/templates/%s/', - 'filter': '%s/filters/#%s', - 'tag': '%s/tags/#%s', + "model": "%s/models/%s/", + "view": "%s/views/%s/", + "template": "%s/templates/%s/", + "filter": "%s/filters/#%s", + "tag": "%s/tags/#%s", } @@ -105,48 +107,59 @@ def create_reference_role(rolename, urlbase): node = docutils.nodes.reference( rawtext, text, - refuri=(urlbase % ( - inliner.document.settings.link_base, - text.lower(), - )), - **options + refuri=( + urlbase + % ( + inliner.document.settings.link_base, + text.lower(), + ) + ), + **options, ) return [node], [] + docutils.parsers.rst.roles.register_canonical_role(rolename, _role) -def default_reference_role(name, rawtext, text, lineno, inliner, options=None, content=None): +def default_reference_role( + name, rawtext, text, lineno, inliner, options=None, content=None +): if options is None: options = {} context = inliner.document.settings.default_reference_context node = docutils.nodes.reference( rawtext, text, - refuri=(ROLES[context] % ( - inliner.document.settings.link_base, - text.lower(), - )), - **options + refuri=( + ROLES[context] + % ( + inliner.document.settings.link_base, + text.lower(), + ) + ), + **options, ) return [node], [] if docutils_is_available: - docutils.parsers.rst.roles.register_canonical_role('cmsreference', default_reference_role) + docutils.parsers.rst.roles.register_canonical_role( + "cmsreference", default_reference_role + ) for name, urlbase in 
ROLES.items(): create_reference_role(name, urlbase) # Match the beginning of a named, unnamed, or non-capturing groups. -named_group_matcher = _lazy_re_compile(r'\(\?P(<\w+>)') -unnamed_group_matcher = _lazy_re_compile(r'\(') -non_capturing_group_matcher = _lazy_re_compile(r'\(\?\:') +named_group_matcher = _lazy_re_compile(r"\(\?P(<\w+>)") +unnamed_group_matcher = _lazy_re_compile(r"\(") +non_capturing_group_matcher = _lazy_re_compile(r"\(\?\:") def replace_metacharacters(pattern): """Remove unescaped metacharacters from the pattern.""" return re.sub( - r'((?:^|(?<!\\))(?:\\\\)*)(\\?)([?*+^$]|\\[bBAZ])', + r"((?:^|(?<!\\))(?:\\\\)*)(\\?)([?*+^$]|\\[bBAZ])", lambda m: m[1] + m[3] if m[2] else m[1], pattern, ) @@ -158,9 +171,9 @@ def _get_group_start_end(start, end, pattern): for idx, val in enumerate(pattern[end:]): # Check for unescaped `(` and `)`. They mark the start and end of a # nested group. - if val == '(' and prev_char != '\\': + if val == "(" and prev_char != "\\": unmatched_open_brackets += 1 - elif val == ')' and prev_char != '\\': + elif val == ")" and prev_char != "\\": unmatched_open_brackets -= 1 prev_char = val # If brackets are balanced, the end of the string for the current named @@ -204,11 +217,11 @@ def replace_unnamed_groups(pattern): 3. ^(?P<a>\w+)/b/(\w+) ==> ^(?P<a>\w+)/b/<var> 4. ^(?P<a>\w+)/b/((x|y)\w+) ==> ^(?P<a>\w+)/b/<var> """ - final_pattern, prev_end = '', None + final_pattern, prev_end = "", None for start, end, _ in _find_groups(pattern, unnamed_group_matcher): if prev_end: final_pattern += pattern[prev_end:start] - final_pattern += pattern[:start] + '<var>' + final_pattern += pattern[:start] + "<var>" prev_end = end return final_pattern + pattern[prev_end:] @@ -221,7 +234,7 @@ def remove_non_capturing_groups(pattern): 3. 
^a(?:\w+)/b(?:\w+) => ^a/b """ group_start_end_indices = _find_groups(pattern, non_capturing_group_matcher) - final_pattern, prev_end = '', None + final_pattern, prev_end = "", None for start, end, _ in group_start_end_indices: final_pattern += pattern[prev_end:start] prev_end = end diff --git a/django/contrib/admindocs/views.py b/django/contrib/admindocs/views.py index a1ad47626b..468c69d43f 100644 --- a/django/contrib/admindocs/views.py +++ b/django/contrib/admindocs/views.py @@ -8,7 +8,9 @@ from django.contrib import admin from django.contrib.admin.views.decorators import staff_member_required from django.contrib.admindocs import utils from django.contrib.admindocs.utils import ( - remove_non_capturing_groups, replace_metacharacters, replace_named_groups, + remove_non_capturing_groups, + replace_metacharacters, + replace_named_groups, replace_unnamed_groups, ) from django.core.exceptions import ImproperlyConfigured, ViewDoesNotExist @@ -20,7 +22,9 @@ from django.utils._os import safe_join from django.utils.decorators import method_decorator from django.utils.functional import cached_property from django.utils.inspect import ( - func_accepts_kwargs, func_accepts_var_args, get_func_full_args, + func_accepts_kwargs, + func_accepts_var_args, + get_func_full_args, method_has_no_args, ) from django.utils.translation import gettext as _ @@ -29,34 +33,37 @@ from django.views.generic import TemplateView from .utils import get_view_name # Exclude methods starting with these strings from documentation -MODEL_METHODS_EXCLUDE = ('_', 'add_', 'delete', 'save', 'set_') +MODEL_METHODS_EXCLUDE = ("_", "add_", "delete", "save", "set_") class BaseAdminDocsView(TemplateView): """ Base view for admindocs views. 
""" + @method_decorator(staff_member_required) def dispatch(self, request, *args, **kwargs): if not utils.docutils_is_available: # Display an error message for people without docutils - self.template_name = 'admin_doc/missing_docutils.html' + self.template_name = "admin_doc/missing_docutils.html" return self.render_to_response(admin.site.each_context(request)) return super().dispatch(request, *args, **kwargs) def get_context_data(self, **kwargs): - return super().get_context_data(**{ - **kwargs, - **admin.site.each_context(self.request), - }) + return super().get_context_data( + **{ + **kwargs, + **admin.site.each_context(self.request), + } + ) class BookmarkletsView(BaseAdminDocsView): - template_name = 'admin_doc/bookmarklets.html' + template_name = "admin_doc/bookmarklets.html" class TemplateTagIndexView(BaseAdminDocsView): - template_name = 'admin_doc/template_tag_index.html' + template_name = "admin_doc/template_tag_index.html" def get_context_data(self, **kwargs): tags = [] @@ -67,27 +74,33 @@ class TemplateTagIndexView(BaseAdminDocsView): pass else: app_libs = sorted(engine.template_libraries.items()) - builtin_libs = [('', lib) for lib in engine.template_builtins] + builtin_libs = [("", lib) for lib in engine.template_builtins] for module_name, library in builtin_libs + app_libs: for tag_name, tag_func in library.tags.items(): title, body, metadata = utils.parse_docstring(tag_func.__doc__) - title = title and utils.parse_rst(title, 'tag', _('tag:') + tag_name) - body = body and utils.parse_rst(body, 'tag', _('tag:') + tag_name) + title = title and utils.parse_rst( + title, "tag", _("tag:") + tag_name + ) + body = body and utils.parse_rst(body, "tag", _("tag:") + tag_name) for key in metadata: - metadata[key] = utils.parse_rst(metadata[key], 'tag', _('tag:') + tag_name) - tag_library = module_name.split('.')[-1] - tags.append({ - 'name': tag_name, - 'title': title, - 'body': body, - 'meta': metadata, - 'library': tag_library, - }) - return 
super().get_context_data(**{**kwargs, 'tags': tags}) + metadata[key] = utils.parse_rst( + metadata[key], "tag", _("tag:") + tag_name + ) + tag_library = module_name.split(".")[-1] + tags.append( + { + "name": tag_name, + "title": title, + "body": body, + "meta": metadata, + "library": tag_library, + } + ) + return super().get_context_data(**{**kwargs, "tags": tags}) class TemplateFilterIndexView(BaseAdminDocsView): - template_name = 'admin_doc/template_filter_index.html' + template_name = "admin_doc/template_filter_index.html" def get_context_data(self, **kwargs): filters = [] @@ -98,27 +111,35 @@ class TemplateFilterIndexView(BaseAdminDocsView): pass else: app_libs = sorted(engine.template_libraries.items()) - builtin_libs = [('', lib) for lib in engine.template_builtins] + builtin_libs = [("", lib) for lib in engine.template_builtins] for module_name, library in builtin_libs + app_libs: for filter_name, filter_func in library.filters.items(): title, body, metadata = utils.parse_docstring(filter_func.__doc__) - title = title and utils.parse_rst(title, 'filter', _('filter:') + filter_name) - body = body and utils.parse_rst(body, 'filter', _('filter:') + filter_name) + title = title and utils.parse_rst( + title, "filter", _("filter:") + filter_name + ) + body = body and utils.parse_rst( + body, "filter", _("filter:") + filter_name + ) for key in metadata: - metadata[key] = utils.parse_rst(metadata[key], 'filter', _('filter:') + filter_name) - tag_library = module_name.split('.')[-1] - filters.append({ - 'name': filter_name, - 'title': title, - 'body': body, - 'meta': metadata, - 'library': tag_library, - }) - return super().get_context_data(**{**kwargs, 'filters': filters}) + metadata[key] = utils.parse_rst( + metadata[key], "filter", _("filter:") + filter_name + ) + tag_library = module_name.split(".")[-1] + filters.append( + { + "name": filter_name, + "title": title, + "body": body, + "meta": metadata, + "library": tag_library, + } + ) + return 
super().get_context_data(**{**kwargs, "filters": filters}) class ViewIndexView(BaseAdminDocsView): - template_name = 'admin_doc/view_index.html' + template_name = "admin_doc/view_index.html" def get_context_data(self, **kwargs): views = [] @@ -128,18 +149,20 @@ class ViewIndexView(BaseAdminDocsView): except ImproperlyConfigured: view_functions = [] for (func, regex, namespace, name) in view_functions: - views.append({ - 'full_name': get_view_name(func), - 'url': simplify_regex(regex), - 'url_name': ':'.join((namespace or []) + (name and [name] or [])), - 'namespace': ':'.join(namespace or []), - 'name': name, - }) - return super().get_context_data(**{**kwargs, 'views': views}) + views.append( + { + "full_name": get_view_name(func), + "url": simplify_regex(regex), + "url_name": ":".join((namespace or []) + (name and [name] or [])), + "namespace": ":".join(namespace or []), + "name": name, + } + ) + return super().get_context_data(**{**kwargs, "views": views}) class ViewDetailView(BaseAdminDocsView): - template_name = 'admin_doc/view_detail.html' + template_name = "admin_doc/view_detail.html" @staticmethod def _get_view_func(view): @@ -159,52 +182,56 @@ class ViewDetailView(BaseAdminDocsView): return getattr(getattr(import_module(mod), klass), func) def get_context_data(self, **kwargs): - view = self.kwargs['view'] + view = self.kwargs["view"] view_func = self._get_view_func(view) if view_func is None: raise Http404 title, body, metadata = utils.parse_docstring(view_func.__doc__) - title = title and utils.parse_rst(title, 'view', _('view:') + view) - body = body and utils.parse_rst(body, 'view', _('view:') + view) + title = title and utils.parse_rst(title, "view", _("view:") + view) + body = body and utils.parse_rst(body, "view", _("view:") + view) for key in metadata: - metadata[key] = utils.parse_rst(metadata[key], 'model', _('view:') + view) - return super().get_context_data(**{ - **kwargs, - 'name': view, - 'summary': title, - 'body': body, - 'meta': metadata, - 
}) + metadata[key] = utils.parse_rst(metadata[key], "model", _("view:") + view) + return super().get_context_data( + **{ + **kwargs, + "name": view, + "summary": title, + "body": body, + "meta": metadata, + } + ) class ModelIndexView(BaseAdminDocsView): - template_name = 'admin_doc/model_index.html' + template_name = "admin_doc/model_index.html" def get_context_data(self, **kwargs): m_list = [m._meta for m in apps.get_models()] - return super().get_context_data(**{**kwargs, 'models': m_list}) + return super().get_context_data(**{**kwargs, "models": m_list}) class ModelDetailView(BaseAdminDocsView): - template_name = 'admin_doc/model_detail.html' + template_name = "admin_doc/model_detail.html" def get_context_data(self, **kwargs): - model_name = self.kwargs['model_name'] + model_name = self.kwargs["model_name"] # Get the model class. try: - app_config = apps.get_app_config(self.kwargs['app_label']) + app_config = apps.get_app_config(self.kwargs["app_label"]) except LookupError: raise Http404(_("App %(app_label)r not found") % self.kwargs) try: model = app_config.get_model(model_name) except LookupError: - raise Http404(_("Model %(model_name)r not found in app %(app_label)r") % self.kwargs) + raise Http404( + _("Model %(model_name)r not found in app %(app_label)r") % self.kwargs + ) opts = model._meta title, body, metadata = utils.parse_docstring(model.__doc__) - title = title and utils.parse_rst(title, 'model', _('model:') + model_name) - body = body and utils.parse_rst(body, 'model', _('model:') + model_name) + title = title and utils.parse_rst(title, "model", _("model:") + model_name) + body = body and utils.parse_rst(body, "model", _("model:") + model_name) # Gather fields/field descriptions. 
fields = [] @@ -215,45 +242,63 @@ class ModelDetailView(BaseAdminDocsView): data_type = field.remote_field.model.__name__ app_label = field.remote_field.model._meta.app_label verbose = utils.parse_rst( - (_("the related `%(app_label)s.%(data_type)s` object") % { - 'app_label': app_label, 'data_type': data_type, - }), - 'model', - _('model:') + data_type, + ( + _("the related `%(app_label)s.%(data_type)s` object") + % { + "app_label": app_label, + "data_type": data_type, + } + ), + "model", + _("model:") + data_type, ) else: data_type = get_readable_field_data_type(field) verbose = field.verbose_name - fields.append({ - 'name': field.name, - 'data_type': data_type, - 'verbose': verbose or '', - 'help_text': field.help_text, - }) + fields.append( + { + "name": field.name, + "data_type": data_type, + "verbose": verbose or "", + "help_text": field.help_text, + } + ) # Gather many-to-many fields. for field in opts.many_to_many: data_type = field.remote_field.model.__name__ app_label = field.remote_field.model._meta.app_label verbose = _("related `%(app_label)s.%(object_name)s` objects") % { - 'app_label': app_label, - 'object_name': data_type, + "app_label": app_label, + "object_name": data_type, } - fields.append({ - 'name': "%s.all" % field.name, - "data_type": 'List', - 'verbose': utils.parse_rst(_("all %s") % verbose, 'model', _('model:') + opts.model_name), - }) - fields.append({ - 'name': "%s.count" % field.name, - 'data_type': 'Integer', - 'verbose': utils.parse_rst(_("number of %s") % verbose, 'model', _('model:') + opts.model_name), - }) + fields.append( + { + "name": "%s.all" % field.name, + "data_type": "List", + "verbose": utils.parse_rst( + _("all %s") % verbose, "model", _("model:") + opts.model_name + ), + } + ) + fields.append( + { + "name": "%s.count" % field.name, + "data_type": "Integer", + "verbose": utils.parse_rst( + _("number of %s") % verbose, + "model", + _("model:") + opts.model_name, + ), + } + ) methods = [] # Gather model methods. 
for func_name, func in model.__dict__.items(): - if inspect.isfunction(func) or isinstance(func, (cached_property, property)): + if inspect.isfunction(func) or isinstance( + func, (cached_property, property) + ): try: for exclude in MODEL_METHODS_EXCLUDE: if func_name.startswith(exclude): @@ -262,70 +307,96 @@ class ModelDetailView(BaseAdminDocsView): continue verbose = func.__doc__ verbose = verbose and ( - utils.parse_rst(cleandoc(verbose), 'model', _('model:') + opts.model_name) + utils.parse_rst( + cleandoc(verbose), "model", _("model:") + opts.model_name + ) ) # Show properties, cached_properties, and methods without # arguments as fields. Otherwise, show as a 'method with # arguments'. if isinstance(func, (cached_property, property)): - fields.append({ - 'name': func_name, - 'data_type': get_return_data_type(func_name), - 'verbose': verbose or '' - }) - elif method_has_no_args(func) and not func_accepts_kwargs(func) and not func_accepts_var_args(func): - fields.append({ - 'name': func_name, - 'data_type': get_return_data_type(func_name), - 'verbose': verbose or '', - }) + fields.append( + { + "name": func_name, + "data_type": get_return_data_type(func_name), + "verbose": verbose or "", + } + ) + elif ( + method_has_no_args(func) + and not func_accepts_kwargs(func) + and not func_accepts_var_args(func) + ): + fields.append( + { + "name": func_name, + "data_type": get_return_data_type(func_name), + "verbose": verbose or "", + } + ) else: arguments = get_func_full_args(func) # Join arguments with ', ' and in case of default value, # join it with '='. Use repr() so that strings will be # correctly displayed. 
- print_arguments = ', '.join([ - '='.join([arg_el[0], *map(repr, arg_el[1:])]) - for arg_el in arguments - ]) - methods.append({ - 'name': func_name, - 'arguments': print_arguments, - 'verbose': verbose or '', - }) + print_arguments = ", ".join( + [ + "=".join([arg_el[0], *map(repr, arg_el[1:])]) + for arg_el in arguments + ] + ) + methods.append( + { + "name": func_name, + "arguments": print_arguments, + "verbose": verbose or "", + } + ) # Gather related objects for rel in opts.related_objects: verbose = _("related `%(app_label)s.%(object_name)s` objects") % { - 'app_label': rel.related_model._meta.app_label, - 'object_name': rel.related_model._meta.object_name, + "app_label": rel.related_model._meta.app_label, + "object_name": rel.related_model._meta.object_name, } accessor = rel.get_accessor_name() - fields.append({ - 'name': "%s.all" % accessor, - 'data_type': 'List', - 'verbose': utils.parse_rst(_("all %s") % verbose, 'model', _('model:') + opts.model_name), - }) - fields.append({ - 'name': "%s.count" % accessor, - 'data_type': 'Integer', - 'verbose': utils.parse_rst(_("number of %s") % verbose, 'model', _('model:') + opts.model_name), - }) - return super().get_context_data(**{ - **kwargs, - 'name': opts.label, - 'summary': title, - 'description': body, - 'fields': fields, - 'methods': methods, - }) + fields.append( + { + "name": "%s.all" % accessor, + "data_type": "List", + "verbose": utils.parse_rst( + _("all %s") % verbose, "model", _("model:") + opts.model_name + ), + } + ) + fields.append( + { + "name": "%s.count" % accessor, + "data_type": "Integer", + "verbose": utils.parse_rst( + _("number of %s") % verbose, + "model", + _("model:") + opts.model_name, + ), + } + ) + return super().get_context_data( + **{ + **kwargs, + "name": opts.label, + "summary": title, + "description": body, + "fields": fields, + "methods": methods, + } + ) class TemplateDetailView(BaseAdminDocsView): - template_name = 'admin_doc/template_detail.html' + template_name = 
"admin_doc/template_detail.html" def get_context_data(self, **kwargs): - template = self.kwargs['template'] + template = self.kwargs["template"] templates = [] try: default_engine = Engine.get_default() @@ -339,18 +410,22 @@ class TemplateDetailView(BaseAdminDocsView): if template_file.exists(): template_contents = template_file.read_text() else: - template_contents = '' - templates.append({ - 'file': template_file, - 'exists': template_file.exists(), - 'contents': template_contents, - 'order': index, - }) - return super().get_context_data(**{ - **kwargs, - 'name': template, - 'templates': templates, - }) + template_contents = "" + templates.append( + { + "file": template_file, + "exists": template_file.exists(), + "contents": template_contents, + "order": index, + } + ) + return super().get_context_data( + **{ + **kwargs, + "name": template, + "templates": templates, + } + ) #################### @@ -360,12 +435,12 @@ class TemplateDetailView(BaseAdminDocsView): def get_return_data_type(func_name): """Return a somewhat-helpful data type given a function name""" - if func_name.startswith('get_'): - if func_name.endswith('_list'): - return 'List' - elif func_name.endswith('_count'): - return 'Integer' - return '' + if func_name.startswith("get_"): + if func_name.endswith("_list"): + return "List" + elif func_name.endswith("_count"): + return "Integer" + return "" def get_readable_field_data_type(field): @@ -377,7 +452,7 @@ def get_readable_field_data_type(field): return field.description % field.__dict__ -def extract_views_from_urlpatterns(urlpatterns, base='', namespace=None): +def extract_views_from_urlpatterns(urlpatterns, base="", namespace=None): """ Return a list of views from a list of urlpatterns. 
@@ -385,17 +460,19 @@ def extract_views_from_urlpatterns(urlpatterns, base='', namespace=None): """ views = [] for p in urlpatterns: - if hasattr(p, 'url_patterns'): + if hasattr(p, "url_patterns"): try: patterns = p.url_patterns except ImportError: continue - views.extend(extract_views_from_urlpatterns( - patterns, - base + str(p.pattern), - (namespace or []) + (p.namespace and [p.namespace] or []) - )) - elif hasattr(p, 'callback'): + views.extend( + extract_views_from_urlpatterns( + patterns, + base + str(p.pattern), + (namespace or []) + (p.namespace and [p.namespace] or []), + ) + ) + elif hasattr(p, "callback"): try: views.append((p.callback, base + str(p.pattern), namespace, p.name)) except ViewDoesNotExist: @@ -415,6 +492,6 @@ def simplify_regex(pattern): pattern = replace_named_groups(pattern) pattern = replace_unnamed_groups(pattern) pattern = replace_metacharacters(pattern) - if not pattern.startswith('/'): - pattern = '/' + pattern + if not pattern.startswith("/"): + pattern = "/" + pattern return pattern diff --git a/django/contrib/auth/__init__.py b/django/contrib/auth/__init__.py index 1e15665ced..00a2940344 100644 --- a/django/contrib/auth/__init__.py +++ b/django/contrib/auth/__init__.py @@ -11,10 +11,10 @@ from django.views.decorators.debug import sensitive_variables from .signals import user_logged_in, user_logged_out, user_login_failed -SESSION_KEY = '_auth_user_id' -BACKEND_SESSION_KEY = '_auth_user_backend' -HASH_SESSION_KEY = '_auth_user_hash' -REDIRECT_FIELD_NAME = 'next' +SESSION_KEY = "_auth_user_id" +BACKEND_SESSION_KEY = "_auth_user_backend" +HASH_SESSION_KEY = "_auth_user_hash" +REDIRECT_FIELD_NAME = "next" def load_backend(path): @@ -28,8 +28,8 @@ def _get_backends(return_tuples=False): backends.append((backend, backend_path) if return_tuples else backend) if not backends: raise ImproperlyConfigured( - 'No authentication backends have been defined. Does ' - 'AUTHENTICATION_BACKENDS contain anything?' 
+ "No authentication backends have been defined. Does " + "AUTHENTICATION_BACKENDS contain anything?" ) return backends @@ -38,7 +38,7 @@ def get_backends(): return _get_backends(return_tuples=False) -@sensitive_variables('credentials') +@sensitive_variables("credentials") def _clean_credentials(credentials): """ Clean a dictionary of credentials of potentially sensitive info before @@ -46,8 +46,8 @@ def _clean_credentials(credentials): Not comprehensive - intended for user_login_failed signal """ - SENSITIVE_CREDENTIALS = re.compile('api|token|key|secret|password|signature', re.I) - CLEANSED_SUBSTITUTE = '********************' + SENSITIVE_CREDENTIALS = re.compile("api|token|key|secret|password|signature", re.I) + CLEANSED_SUBSTITUTE = "********************" for key in credentials: if SENSITIVE_CREDENTIALS.search(key): credentials[key] = CLEANSED_SUBSTITUTE @@ -60,7 +60,7 @@ def _get_user_session_key(request): return get_user_model()._meta.pk.to_python(request.session[SESSION_KEY]) -@sensitive_variables('credentials') +@sensitive_variables("credentials") def authenticate(request=None, **credentials): """ If the given credentials are valid, return a User object. @@ -84,7 +84,9 @@ def authenticate(request=None, **credentials): return user # The credentials supplied are invalid to all backends, fire signal - user_login_failed.send(sender=__name__, credentials=_clean_credentials(credentials), request=request) + user_login_failed.send( + sender=__name__, credentials=_clean_credentials(credentials), request=request + ) def login(request, user, backend=None): @@ -93,16 +95,19 @@ def login(request, user, backend=None): have to reauthenticate on every request. Note that data set during the anonymous session is retained when the user logs in. 
""" - session_auth_hash = '' + session_auth_hash = "" if user is None: user = request.user - if hasattr(user, 'get_session_auth_hash'): + if hasattr(user, "get_session_auth_hash"): session_auth_hash = user.get_session_auth_hash() if SESSION_KEY in request.session: if _get_user_session_key(request) != user.pk or ( - session_auth_hash and - not constant_time_compare(request.session.get(HASH_SESSION_KEY, ''), session_auth_hash)): + session_auth_hash + and not constant_time_compare( + request.session.get(HASH_SESSION_KEY, ""), session_auth_hash + ) + ): # To avoid reusing another user's session, create a new, empty # session if the existing session corresponds to a different # authenticated user. @@ -118,18 +123,20 @@ def login(request, user, backend=None): _, backend = backends[0] else: raise ValueError( - 'You have multiple authentication backends configured and ' - 'therefore must provide the `backend` argument or set the ' - '`backend` attribute on the user.' + "You have multiple authentication backends configured and " + "therefore must provide the `backend` argument or set the " + "`backend` attribute on the user." ) else: if not isinstance(backend, str): - raise TypeError('backend must be a dotted import path string (got %r).' % backend) + raise TypeError( + "backend must be a dotted import path string (got %r)." % backend + ) request.session[SESSION_KEY] = user._meta.pk.value_to_string(user) request.session[BACKEND_SESSION_KEY] = backend request.session[HASH_SESSION_KEY] = session_auth_hash - if hasattr(request, 'user'): + if hasattr(request, "user"): request.user = user rotate_token(request) user_logged_in.send(sender=user.__class__, request=request, user=user) @@ -142,13 +149,14 @@ def logout(request): """ # Dispatch the signal before the user is logged out so the receivers have a # chance to find out *who* logged out. 
- user = getattr(request, 'user', None) - if not getattr(user, 'is_authenticated', True): + user = getattr(request, "user", None) + if not getattr(user, "is_authenticated", True): user = None user_logged_out.send(sender=user.__class__, request=request, user=user) request.session.flush() - if hasattr(request, 'user'): + if hasattr(request, "user"): from django.contrib.auth.models import AnonymousUser + request.user = AnonymousUser() @@ -159,10 +167,13 @@ def get_user_model(): try: return django_apps.get_model(settings.AUTH_USER_MODEL, require_ready=False) except ValueError: - raise ImproperlyConfigured("AUTH_USER_MODEL must be of the form 'app_label.model_name'") + raise ImproperlyConfigured( + "AUTH_USER_MODEL must be of the form 'app_label.model_name'" + ) except LookupError: raise ImproperlyConfigured( - "AUTH_USER_MODEL refers to model '%s' that has not been installed" % settings.AUTH_USER_MODEL + "AUTH_USER_MODEL refers to model '%s' that has not been installed" + % settings.AUTH_USER_MODEL ) @@ -172,6 +183,7 @@ def get_user(request): If no user is retrieved, return an instance of `AnonymousUser`. """ from .models import AnonymousUser + user = None try: user_id = _get_user_session_key(request) @@ -183,11 +195,10 @@ def get_user(request): backend = load_backend(backend_path) user = backend.get_user(user_id) # Verify the session - if hasattr(user, 'get_session_auth_hash'): + if hasattr(user, "get_session_auth_hash"): session_hash = request.session.get(HASH_SESSION_KEY) session_hash_verified = session_hash and constant_time_compare( - session_hash, - user.get_session_auth_hash() + session_hash, user.get_session_auth_hash() ) if not session_hash_verified: request.session.flush() @@ -200,7 +211,7 @@ def get_permission_codename(action, opts): """ Return the codename of the permission for the specified action. 
""" - return '%s_%s' % (action, opts.model_name) + return "%s_%s" % (action, opts.model_name) def update_session_auth_hash(request, user): @@ -213,5 +224,5 @@ def update_session_auth_hash(request, user): password was changed. """ request.session.cycle_key() - if hasattr(user, 'get_session_auth_hash') and request.user == user: + if hasattr(user, "get_session_auth_hash") and request.user == user: request.session[HASH_SESSION_KEY] = user.get_session_auth_hash() diff --git a/django/contrib/auth/admin.py b/django/contrib/auth/admin.py index a63a4e28a7..636e620eb1 100644 --- a/django/contrib/auth/admin.py +++ b/django/contrib/auth/admin.py @@ -4,7 +4,9 @@ from django.contrib.admin.options import IS_POPUP_VAR from django.contrib.admin.utils import unquote from django.contrib.auth import update_session_auth_hash from django.contrib.auth.forms import ( - AdminPasswordChangeForm, UserChangeForm, UserCreationForm, + AdminPasswordChangeForm, + UserChangeForm, + UserCreationForm, ) from django.contrib.auth.models import Group, User from django.core.exceptions import PermissionDenied @@ -14,7 +16,8 @@ from django.template.response import TemplateResponse from django.urls import path, reverse from django.utils.decorators import method_decorator from django.utils.html import escape -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ from django.views.decorators.csrf import csrf_protect from django.views.decorators.debug import sensitive_post_parameters @@ -24,45 +27,60 @@ sensitive_post_parameters_m = method_decorator(sensitive_post_parameters()) @admin.register(Group) class GroupAdmin(admin.ModelAdmin): - search_fields = ('name',) - ordering = ('name',) - filter_horizontal = ('permissions',) + search_fields = ("name",) + ordering = ("name",) + filter_horizontal = ("permissions",) def formfield_for_manytomany(self, db_field, request=None, **kwargs): - if db_field.name 
== 'permissions': - qs = kwargs.get('queryset', db_field.remote_field.model.objects) + if db_field.name == "permissions": + qs = kwargs.get("queryset", db_field.remote_field.model.objects) # Avoid a major performance hit resolving permission names which # triggers a content_type load: - kwargs['queryset'] = qs.select_related('content_type') + kwargs["queryset"] = qs.select_related("content_type") return super().formfield_for_manytomany(db_field, request=request, **kwargs) @admin.register(User) class UserAdmin(admin.ModelAdmin): - add_form_template = 'admin/auth/user/add_form.html' + add_form_template = "admin/auth/user/add_form.html" change_user_password_template = None fieldsets = ( - (None, {'fields': ('username', 'password')}), - (_('Personal info'), {'fields': ('first_name', 'last_name', 'email')}), - (_('Permissions'), { - 'fields': ('is_active', 'is_staff', 'is_superuser', 'groups', 'user_permissions'), - }), - (_('Important dates'), {'fields': ('last_login', 'date_joined')}), + (None, {"fields": ("username", "password")}), + (_("Personal info"), {"fields": ("first_name", "last_name", "email")}), + ( + _("Permissions"), + { + "fields": ( + "is_active", + "is_staff", + "is_superuser", + "groups", + "user_permissions", + ), + }, + ), + (_("Important dates"), {"fields": ("last_login", "date_joined")}), ) add_fieldsets = ( - (None, { - 'classes': ('wide',), - 'fields': ('username', 'password1', 'password2'), - }), + ( + None, + { + "classes": ("wide",), + "fields": ("username", "password1", "password2"), + }, + ), ) form = UserChangeForm add_form = UserCreationForm change_password_form = AdminPasswordChangeForm - list_display = ('username', 'email', 'first_name', 'last_name', 'is_staff') - list_filter = ('is_staff', 'is_superuser', 'is_active', 'groups') - search_fields = ('username', 'first_name', 'last_name', 'email') - ordering = ('username',) - filter_horizontal = ('groups', 'user_permissions',) + list_display = ("username", "email", "first_name", 
"last_name", "is_staff") + list_filter = ("is_staff", "is_superuser", "is_active", "groups") + search_fields = ("username", "first_name", "last_name", "email") + ordering = ("username",) + filter_horizontal = ( + "groups", + "user_permissions", + ) def get_fieldsets(self, request, obj=None): if not obj: @@ -75,30 +93,32 @@ class UserAdmin(admin.ModelAdmin): """ defaults = {} if obj is None: - defaults['form'] = self.add_form + defaults["form"] = self.add_form defaults.update(kwargs) return super().get_form(request, obj, **defaults) def get_urls(self): return [ path( - '<id>/password/', + "<id>/password/", self.admin_site.admin_view(self.user_change_password), - name='auth_user_password_change', + name="auth_user_password_change", ), ] + super().get_urls() def lookup_allowed(self, lookup, value): # Don't allow lookups involving passwords. - return not lookup.startswith('password') and super().lookup_allowed(lookup, value) + return not lookup.startswith("password") and super().lookup_allowed( + lookup, value + ) @sensitive_post_parameters_m @csrf_protect_m - def add_view(self, request, form_url='', extra_context=None): + def add_view(self, request, form_url="", extra_context=None): with transaction.atomic(using=router.db_for_write(self.model)): return self._add_view(request, form_url, extra_context) - def _add_view(self, request, form_url='', extra_context=None): + def _add_view(self, request, form_url="", extra_context=None): # It's an error for a user to have add permission but NOT change # permission for users. If we allowed such users to add users, they # could create superusers, which would mean they would essentially have @@ -111,42 +131,47 @@ class UserAdmin(admin.ModelAdmin): # error message. raise Http404( 'Your user does not have the "Change user" permission. 
In ' - 'order to add users, Django requires that your user ' + "order to add users, Django requires that your user " 'account have both the "Add user" and "Change user" ' - 'permissions set.') + "permissions set." + ) raise PermissionDenied if extra_context is None: extra_context = {} username_field = self.model._meta.get_field(self.model.USERNAME_FIELD) defaults = { - 'auto_populated_fields': (), - 'username_help_text': username_field.help_text, + "auto_populated_fields": (), + "username_help_text": username_field.help_text, } extra_context.update(defaults) return super().add_view(request, form_url, extra_context) @sensitive_post_parameters_m - def user_change_password(self, request, id, form_url=''): + def user_change_password(self, request, id, form_url=""): user = self.get_object(request, unquote(id)) if not self.has_change_permission(request, user): raise PermissionDenied if user is None: - raise Http404(_('%(name)s object with primary key %(key)r does not exist.') % { - 'name': self.model._meta.verbose_name, - 'key': escape(id), - }) - if request.method == 'POST': + raise Http404( + _("%(name)s object with primary key %(key)r does not exist.") + % { + "name": self.model._meta.verbose_name, + "key": escape(id), + } + ) + if request.method == "POST": form = self.change_password_form(user, request.POST) if form.is_valid(): form.save() change_message = self.construct_change_message(request, form, None) self.log_change(request, user, change_message) - msg = gettext('Password changed successfully.') + msg = gettext("Password changed successfully.") messages.success(request, msg) update_session_auth_hash(request, form.user) return HttpResponseRedirect( reverse( - '%s:%s_%s_change' % ( + "%s:%s_%s_change" + % ( self.admin_site.name, user._meta.app_label, user._meta.model_name, @@ -157,26 +182,25 @@ class UserAdmin(admin.ModelAdmin): else: form = self.change_password_form(user) - fieldsets = [(None, {'fields': list(form.base_fields)})] + fieldsets = [(None, {"fields": 
list(form.base_fields)})] adminForm = admin.helpers.AdminForm(form, fieldsets, {}) context = { - 'title': _('Change password: %s') % escape(user.get_username()), - 'adminForm': adminForm, - 'form_url': form_url, - 'form': form, - 'is_popup': (IS_POPUP_VAR in request.POST or - IS_POPUP_VAR in request.GET), - 'is_popup_var': IS_POPUP_VAR, - 'add': True, - 'change': False, - 'has_delete_permission': False, - 'has_change_permission': True, - 'has_absolute_url': False, - 'opts': self.model._meta, - 'original': user, - 'save_as': False, - 'show_save': True, + "title": _("Change password: %s") % escape(user.get_username()), + "adminForm": adminForm, + "form_url": form_url, + "form": form, + "is_popup": (IS_POPUP_VAR in request.POST or IS_POPUP_VAR in request.GET), + "is_popup_var": IS_POPUP_VAR, + "add": True, + "change": False, + "has_delete_permission": False, + "has_change_permission": True, + "has_absolute_url": False, + "opts": self.model._meta, + "original": user, + "save_as": False, + "show_save": True, **self.admin_site.each_context(request), } @@ -184,8 +208,8 @@ class UserAdmin(admin.ModelAdmin): return TemplateResponse( request, - self.change_user_password_template or - 'admin/auth/user/change_password.html', + self.change_user_password_template + or "admin/auth/user/change_password.html", context, ) @@ -200,7 +224,7 @@ class UserAdmin(admin.ModelAdmin): # button except in two scenarios: # * The user has pressed the 'Save and add another' button # * We are adding a user in a popup - if '_addanother' not in request.POST and IS_POPUP_VAR not in request.POST: + if "_addanother" not in request.POST and IS_POPUP_VAR not in request.POST: request.POST = request.POST.copy() - request.POST['_continue'] = 1 + request.POST["_continue"] = 1 return super().response_add(request, obj, post_url_continue) diff --git a/django/contrib/auth/apps.py b/django/contrib/auth/apps.py index 4e4ef06d27..4882a27c42 100644 --- a/django/contrib/auth/apps.py +++ b/django/contrib/auth/apps.py 
@@ -11,19 +11,20 @@ from .signals import user_logged_in class AuthConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.auth' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.auth" verbose_name = _("Authentication and Authorization") def ready(self): post_migrate.connect( create_permissions, - dispatch_uid="django.contrib.auth.management.create_permissions" + dispatch_uid="django.contrib.auth.management.create_permissions", ) - last_login_field = getattr(get_user_model(), 'last_login', None) + last_login_field = getattr(get_user_model(), "last_login", None) # Register the handler only if UserModel.last_login is a field. if isinstance(last_login_field, DeferredAttribute): from .models import update_last_login - user_logged_in.connect(update_last_login, dispatch_uid='update_last_login') + + user_logged_in.connect(update_last_login, dispatch_uid="update_last_login") checks.register(check_user_model, checks.Tags.models) checks.register(check_models_permissions, checks.Tags.models) diff --git a/django/contrib/auth/backends.py b/django/contrib/auth/backends.py index 0ddbd9ca77..7cf405713d 100644 --- a/django/contrib/auth/backends.py +++ b/django/contrib/auth/backends.py @@ -53,15 +53,15 @@ class ModelBackend(BaseBackend): Reject users with is_active=False. Custom user models that don't have that attribute are allowed. 
""" - is_active = getattr(user, 'is_active', None) + is_active = getattr(user, "is_active", None) return is_active or is_active is None def _get_user_permissions(self, user_obj): return user_obj.user_permissions.all() def _get_group_permissions(self, user_obj): - user_groups_field = get_user_model()._meta.get_field('groups') - user_groups_query = 'group__%s' % user_groups_field.related_query_name() + user_groups_field = get_user_model()._meta.get_field("groups") + user_groups_query = "group__%s" % user_groups_field.related_query_name() return Permission.objects.filter(**{user_groups_query: user_obj}) def _get_permissions(self, user_obj, obj, from_name): @@ -73,14 +73,16 @@ class ModelBackend(BaseBackend): if not user_obj.is_active or user_obj.is_anonymous or obj is not None: return set() - perm_cache_name = '_%s_perm_cache' % from_name + perm_cache_name = "_%s_perm_cache" % from_name if not hasattr(user_obj, perm_cache_name): if user_obj.is_superuser: perms = Permission.objects.all() else: - perms = getattr(self, '_get_%s_permissions' % from_name)(user_obj) - perms = perms.values_list('content_type__app_label', 'codename').order_by() - setattr(user_obj, perm_cache_name, {"%s.%s" % (ct, name) for ct, name in perms}) + perms = getattr(self, "_get_%s_permissions" % from_name)(user_obj) + perms = perms.values_list("content_type__app_label", "codename").order_by() + setattr( + user_obj, perm_cache_name, {"%s.%s" % (ct, name) for ct, name in perms} + ) return getattr(user_obj, perm_cache_name) def get_user_permissions(self, user_obj, obj=None): @@ -88,19 +90,19 @@ class ModelBackend(BaseBackend): Return a set of permission strings the user `user_obj` has from their `user_permissions`. """ - return self._get_permissions(user_obj, obj, 'user') + return self._get_permissions(user_obj, obj, "user") def get_group_permissions(self, user_obj, obj=None): """ Return a set of permission strings the user `user_obj` has from the groups they belong. 
""" - return self._get_permissions(user_obj, obj, 'group') + return self._get_permissions(user_obj, obj, "group") def get_all_permissions(self, user_obj, obj=None): if not user_obj.is_active or user_obj.is_anonymous or obj is not None: return set() - if not hasattr(user_obj, '_perm_cache'): + if not hasattr(user_obj, "_perm_cache"): user_obj._perm_cache = super().get_all_permissions(user_obj) return user_obj._perm_cache @@ -112,7 +114,7 @@ class ModelBackend(BaseBackend): Return True if user_obj has any permissions in the given app_label. """ return user_obj.is_active and any( - perm[:perm.index('.')] == app_label + perm[: perm.index(".")] == app_label for perm in self.get_all_permissions(user_obj) ) @@ -123,21 +125,21 @@ class ModelBackend(BaseBackend): """ if isinstance(perm, str): try: - app_label, codename = perm.split('.') + app_label, codename = perm.split(".") except ValueError: raise ValueError( - 'Permission name should be in the form ' - 'app_label.permission_codename.' + "Permission name should be in the form " + "app_label.permission_codename." ) elif not isinstance(perm, Permission): raise TypeError( - 'The `perm` argument must be a string or a permission instance.' + "The `perm` argument must be a string or a permission instance." ) if obj is not None: return UserModel._default_manager.none() - permission_q = Q(group__user=OuterRef('pk')) | Q(user=OuterRef('pk')) + permission_q = Q(group__user=OuterRef("pk")) | Q(user=OuterRef("pk")) if isinstance(perm, Permission): permission_q &= Q(pk=perm.pk) else: @@ -197,9 +199,9 @@ class RemoteUserBackend(ModelBackend): # instead we use get_or_create when creating unknown users since it has # built-in safeguards for multiple threads. 
if self.create_unknown_user: - user, created = UserModel._default_manager.get_or_create(**{ - UserModel.USERNAME_FIELD: username - }) + user, created = UserModel._default_manager.get_or_create( + **{UserModel.USERNAME_FIELD: username} + ) if created: user = self.configure_user(request, user) else: diff --git a/django/contrib/auth/base_user.py b/django/contrib/auth/base_user.py index cbfe5d686a..f6de3b9317 100644 --- a/django/contrib/auth/base_user.py +++ b/django/contrib/auth/base_user.py @@ -6,7 +6,9 @@ import unicodedata from django.contrib.auth import password_validation from django.contrib.auth.hashers import ( - check_password, is_password_usable, make_password, + check_password, + is_password_usable, + make_password, ) from django.db import models from django.utils.crypto import get_random_string, salted_hmac @@ -14,25 +16,24 @@ from django.utils.translation import gettext_lazy as _ class BaseUserManager(models.Manager): - @classmethod def normalize_email(cls, email): """ Normalize the email address by lowercasing the domain part of it. 
""" - email = email or '' + email = email or "" try: - email_name, domain_part = email.strip().rsplit('@', 1) + email_name, domain_part = email.strip().rsplit("@", 1) except ValueError: pass else: - email = email_name + '@' + domain_part.lower() + email = email_name + "@" + domain_part.lower() return email def make_random_password( self, length=10, - allowed_chars='abcdefghjkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789', + allowed_chars="abcdefghjkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789", ): """ Generate a random password with the given length and given @@ -46,8 +47,8 @@ class BaseUserManager(models.Manager): class AbstractBaseUser(models.Model): - password = models.CharField(_('password'), max_length=128) - last_login = models.DateTimeField(_('last login'), blank=True, null=True) + password = models.CharField(_("password"), max_length=128) + last_login = models.DateTimeField(_("last login"), blank=True, null=True) is_active = True @@ -104,11 +105,13 @@ class AbstractBaseUser(models.Model): Return a boolean of whether the raw_password was correct. Handles hashing formats behind the scenes. """ + def setter(raw_password): self.set_password(raw_password) # Password hash upgrades shouldn't be considered password changes. 
self._password = None self.save(update_fields=["password"]) + return check_password(raw_password, self.password, setter) def set_unusable_password(self): @@ -129,7 +132,7 @@ class AbstractBaseUser(models.Model): return salted_hmac( key_salt, self.password, - algorithm='sha256', + algorithm="sha256", ).hexdigest() @classmethod @@ -137,8 +140,12 @@ class AbstractBaseUser(models.Model): try: return cls.EMAIL_FIELD except AttributeError: - return 'email' + return "email" @classmethod def normalize_username(cls, username): - return unicodedata.normalize('NFKC', username) if isinstance(username, str) else username + return ( + unicodedata.normalize("NFKC", username) + if isinstance(username, str) + else username + ) diff --git a/django/contrib/auth/checks.py b/django/contrib/auth/checks.py index c08ed8a49a..e1505d63fd 100644 --- a/django/contrib/auth/checks.py +++ b/django/contrib/auth/checks.py @@ -12,7 +12,7 @@ def check_user_model(app_configs=None, **kwargs): if app_configs is None: cls = apps.get_model(settings.AUTH_USER_MODEL) else: - app_label, model_name = settings.AUTH_USER_MODEL.split('.') + app_label, model_name = settings.AUTH_USER_MODEL.split(".") for app_config in app_configs: if app_config.label == app_label: cls = app_config.get_model(model_name) @@ -31,7 +31,7 @@ def check_user_model(app_configs=None, **kwargs): checks.Error( "'REQUIRED_FIELDS' must be a list or tuple.", obj=cls, - id='auth.E001', + id="auth.E001", ) ) @@ -47,7 +47,7 @@ def check_user_model(app_configs=None, **kwargs): % (cls.USERNAME_FIELD, cls.USERNAME_FIELD) ), obj=cls, - id='auth.E002', + id="auth.E002", ) ) @@ -56,47 +56,46 @@ def check_user_model(app_configs=None, **kwargs): constraint.fields == (cls.USERNAME_FIELD,) for constraint in cls._meta.total_unique_constraints ): - if (settings.AUTHENTICATION_BACKENDS == - ['django.contrib.auth.backends.ModelBackend']): + if settings.AUTHENTICATION_BACKENDS == [ + "django.contrib.auth.backends.ModelBackend" + ]: errors.append( checks.Error( 
- "'%s.%s' must be unique because it is named as the 'USERNAME_FIELD'." % ( - cls._meta.object_name, cls.USERNAME_FIELD - ), + "'%s.%s' must be unique because it is named as the 'USERNAME_FIELD'." + % (cls._meta.object_name, cls.USERNAME_FIELD), obj=cls, - id='auth.E003', + id="auth.E003", ) ) else: errors.append( checks.Warning( - "'%s.%s' is named as the 'USERNAME_FIELD', but it is not unique." % ( - cls._meta.object_name, cls.USERNAME_FIELD - ), - hint='Ensure that your authentication backend(s) can handle non-unique usernames.', + "'%s.%s' is named as the 'USERNAME_FIELD', but it is not unique." + % (cls._meta.object_name, cls.USERNAME_FIELD), + hint="Ensure that your authentication backend(s) can handle non-unique usernames.", obj=cls, - id='auth.W004', + id="auth.W004", ) ) if isinstance(cls().is_anonymous, MethodType): errors.append( checks.Critical( - '%s.is_anonymous must be an attribute or property rather than ' - 'a method. Ignoring this is a security issue as anonymous ' - 'users will be treated as authenticated!' % cls, + "%s.is_anonymous must be an attribute or property rather than " + "a method. Ignoring this is a security issue as anonymous " + "users will be treated as authenticated!" % cls, obj=cls, - id='auth.C009', + id="auth.C009", ) ) if isinstance(cls().is_authenticated, MethodType): errors.append( checks.Critical( - '%s.is_authenticated must be an attribute or property rather ' - 'than a method. Ignoring this is a security issue as anonymous ' - 'users will be treated as authenticated!' % cls, + "%s.is_authenticated must be an attribute or property rather " + "than a method. Ignoring this is a security issue as anonymous " + "users will be treated as authenticated!" 
% cls, obj=cls, - id='auth.C010', + id="auth.C010", ) ) return errors @@ -106,11 +105,13 @@ def check_models_permissions(app_configs=None, **kwargs): if app_configs is None: models = apps.get_models() else: - models = chain.from_iterable(app_config.get_models() for app_config in app_configs) + models = chain.from_iterable( + app_config.get_models() for app_config in app_configs + ) - Permission = apps.get_model('auth', 'Permission') - permission_name_max_length = Permission._meta.get_field('name').max_length - permission_codename_max_length = Permission._meta.get_field('codename').max_length + Permission = apps.get_model("auth", "Permission") + permission_name_max_length = Permission._meta.get_field("name").max_length + permission_codename_max_length = Permission._meta.get_field("codename").max_length errors = [] for model in models: @@ -119,27 +120,28 @@ def check_models_permissions(app_configs=None, **kwargs): # Check builtin permission name length. max_builtin_permission_name_length = ( max(len(name) for name in builtin_permissions.values()) - if builtin_permissions else 0 + if builtin_permissions + else 0 ) if max_builtin_permission_name_length > permission_name_max_length: - verbose_name_max_length = ( - permission_name_max_length - (max_builtin_permission_name_length - len(opts.verbose_name_raw)) + verbose_name_max_length = permission_name_max_length - ( + max_builtin_permission_name_length - len(opts.verbose_name_raw) ) errors.append( checks.Error( "The verbose_name of model '%s' must be at most %d " "characters for its builtin permission names to be at " - "most %d characters." % ( - opts.label, verbose_name_max_length, permission_name_max_length - ), + "most %d characters." + % (opts.label, verbose_name_max_length, permission_name_max_length), obj=model, - id='auth.E007', + id="auth.E007", ) ) # Check builtin permission codename length. 
max_builtin_permission_codename_length = ( max(len(codename) for codename in builtin_permissions.keys()) - if builtin_permissions else 0 + if builtin_permissions + else 0 ) if max_builtin_permission_codename_length > permission_codename_max_length: model_name_max_length = permission_codename_max_length - ( @@ -149,13 +151,14 @@ def check_models_permissions(app_configs=None, **kwargs): checks.Error( "The name of model '%s' must be at most %d characters " "for its builtin permission codenames to be at most %d " - "characters." % ( + "characters." + % ( opts.label, model_name_max_length, permission_codename_max_length, ), obj=model, - id='auth.E011', + id="auth.E011", ) ) codenames = set() @@ -165,11 +168,14 @@ def check_models_permissions(app_configs=None, **kwargs): errors.append( checks.Error( "The permission named '%s' of model '%s' is longer " - "than %d characters." % ( - name, opts.label, permission_name_max_length, + "than %d characters." + % ( + name, + opts.label, + permission_name_max_length, ), obj=model, - id='auth.E008', + id="auth.E008", ) ) # Check custom permission codename length. @@ -177,13 +183,14 @@ def check_models_permissions(app_configs=None, **kwargs): errors.append( checks.Error( "The permission codenamed '%s' of model '%s' is " - "longer than %d characters." % ( + "longer than %d characters." + % ( codename, opts.label, permission_codename_max_length, ), obj=model, - id='auth.E012', + id="auth.E012", ) ) # Check custom permissions codename clashing. @@ -193,7 +200,7 @@ def check_models_permissions(app_configs=None, **kwargs): "The permission codenamed '%s' clashes with a builtin permission " "for model '%s'." % (codename, opts.label), obj=model, - id='auth.E005', + id="auth.E005", ) ) elif codename in codenames: @@ -202,7 +209,7 @@ def check_models_permissions(app_configs=None, **kwargs): "The permission codenamed '%s' is duplicated for " "model '%s'." 
% (codename, opts.label), obj=model, - id='auth.E006', + id="auth.E006", ) ) codenames.add(codename) diff --git a/django/contrib/auth/context_processors.py b/django/contrib/auth/context_processors.py index 3dfd9177a9..0a88199227 100644 --- a/django/contrib/auth/context_processors.py +++ b/django/contrib/auth/context_processors.py @@ -26,7 +26,7 @@ class PermWrapper: self.user = user def __repr__(self): - return f'{self.__class__.__qualname__}({self.user!r})' + return f"{self.__class__.__qualname__}({self.user!r})" def __getitem__(self, app_label): return PermLookupDict(self.user, app_label) @@ -39,10 +39,10 @@ class PermWrapper: """ Lookup by "someapp" or "someapp.someperm" in perms. """ - if '.' not in perm_name: + if "." not in perm_name: # The name refers to module. return bool(self[perm_name]) - app_label, perm_name = perm_name.split('.', 1) + app_label, perm_name = perm_name.split(".", 1) return self[app_label][perm_name] @@ -54,13 +54,14 @@ def auth(request): If there is no 'user' attribute in the request, use AnonymousUser (from django.contrib.auth). """ - if hasattr(request, 'user'): + if hasattr(request, "user"): user = request.user else: from django.contrib.auth.models import AnonymousUser + user = AnonymousUser() return { - 'user': user, - 'perms': PermWrapper(user), + "user": user, + "perms": PermWrapper(user), } diff --git a/django/contrib/auth/decorators.py b/django/contrib/auth/decorators.py index 53f62e8371..a419068d49 100644 --- a/django/contrib/auth/decorators.py +++ b/django/contrib/auth/decorators.py @@ -7,7 +7,9 @@ from django.core.exceptions import PermissionDenied from django.shortcuts import resolve_url -def user_passes_test(test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME): +def user_passes_test( + test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME +): """ Decorator for views that checks that the user passes the given test, redirecting to the log-in page if necessary. 
The test should be a callable @@ -25,17 +27,22 @@ def user_passes_test(test_func, login_url=None, redirect_field_name=REDIRECT_FIE # use the path as the "next" url. login_scheme, login_netloc = urlparse(resolved_login_url)[:2] current_scheme, current_netloc = urlparse(path)[:2] - if ((not login_scheme or login_scheme == current_scheme) and - (not login_netloc or login_netloc == current_netloc)): + if (not login_scheme or login_scheme == current_scheme) and ( + not login_netloc or login_netloc == current_netloc + ): path = request.get_full_path() from django.contrib.auth.views import redirect_to_login - return redirect_to_login( - path, resolved_login_url, redirect_field_name) + + return redirect_to_login(path, resolved_login_url, redirect_field_name) + return _wrapped_view + return decorator -def login_required(function=None, redirect_field_name=REDIRECT_FIELD_NAME, login_url=None): +def login_required( + function=None, redirect_field_name=REDIRECT_FIELD_NAME, login_url=None +): """ Decorator for views that checks that the user is logged in, redirecting to the log-in page if necessary. @@ -43,7 +50,7 @@ def login_required(function=None, redirect_field_name=REDIRECT_FIELD_NAME, login actual_decorator = user_passes_test( lambda u: u.is_authenticated, login_url=login_url, - redirect_field_name=redirect_field_name + redirect_field_name=redirect_field_name, ) if function: return actual_decorator(function) @@ -57,6 +64,7 @@ def permission_required(perm, login_url=None, raise_exception=False): If the raise_exception parameter is given the PermissionDenied exception is raised. 
""" + def check_perms(user): if isinstance(perm, str): perms = (perm,) @@ -70,4 +78,5 @@ def permission_required(perm, login_url=None, raise_exception=False): raise PermissionDenied # As the last resort, show the login form return False + return user_passes_test(check_perms, login_url=login_url) diff --git a/django/contrib/auth/forms.py b/django/contrib/auth/forms.py index 25d500962b..9d6a9918bb 100644 --- a/django/contrib/auth/forms.py +++ b/django/contrib/auth/forms.py @@ -1,12 +1,8 @@ import unicodedata from django import forms -from django.contrib.auth import ( - authenticate, get_user_model, password_validation, -) -from django.contrib.auth.hashers import ( - UNUSABLE_PASSWORD_PREFIX, identify_hasher, -) +from django.contrib.auth import authenticate, get_user_model, password_validation +from django.contrib.auth.hashers import UNUSABLE_PASSWORD_PREFIX, identify_hasher from django.contrib.auth.models import User from django.contrib.auth.tokens import default_token_generator from django.contrib.sites.shortcuts import get_current_site @@ -16,7 +12,8 @@ from django.template import loader from django.utils.encoding import force_bytes from django.utils.http import urlsafe_base64_encode from django.utils.text import capfirst -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ UserModel = get_user_model() @@ -27,27 +24,36 @@ def _unicode_ci_compare(s1, s2): recommended algorithm from Unicode Technical Report 36, section 2.11.2(B)(2). 
""" - return unicodedata.normalize('NFKC', s1).casefold() == unicodedata.normalize('NFKC', s2).casefold() + return ( + unicodedata.normalize("NFKC", s1).casefold() + == unicodedata.normalize("NFKC", s2).casefold() + ) class ReadOnlyPasswordHashWidget(forms.Widget): - template_name = 'auth/widgets/read_only_password_hash.html' + template_name = "auth/widgets/read_only_password_hash.html" read_only = True def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) summary = [] if not value or value.startswith(UNUSABLE_PASSWORD_PREFIX): - summary.append({'label': gettext("No password set.")}) + summary.append({"label": gettext("No password set.")}) else: try: hasher = identify_hasher(value) except ValueError: - summary.append({'label': gettext("Invalid password format or unknown hashing algorithm.")}) + summary.append( + { + "label": gettext( + "Invalid password format or unknown hashing algorithm." + ) + } + ) else: for key, value_ in hasher.safe_summary(value).items(): - summary.append({'label': gettext(key), 'value': value_}) - context['summary'] = summary + summary.append({"label": gettext(key), "value": value_}) + context["summary"] = summary return context def id_for_label(self, id_): @@ -59,19 +65,19 @@ class ReadOnlyPasswordHashField(forms.Field): def __init__(self, *args, **kwargs): kwargs.setdefault("required", False) - kwargs.setdefault('disabled', True) + kwargs.setdefault("disabled", True) super().__init__(*args, **kwargs) class UsernameField(forms.CharField): def to_python(self, value): - return unicodedata.normalize('NFKC', super().to_python(value)) + return unicodedata.normalize("NFKC", super().to_python(value)) def widget_attrs(self, widget): return { **super().widget_attrs(widget), - 'autocapitalize': 'none', - 'autocomplete': 'username', + "autocapitalize": "none", + "autocomplete": "username", } @@ -80,18 +86,19 @@ class UserCreationForm(forms.ModelForm): A form that creates a user, with no privileges, from the 
given username and password. """ + error_messages = { - 'password_mismatch': _('The two password fields didn’t match.'), + "password_mismatch": _("The two password fields didn’t match."), } password1 = forms.CharField( label=_("Password"), strip=False, - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "new-password"}), help_text=password_validation.password_validators_help_text_html(), ) password2 = forms.CharField( label=_("Password confirmation"), - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "new-password"}), strip=False, help_text=_("Enter the same password as before, for verification."), ) @@ -99,20 +106,22 @@ class UserCreationForm(forms.ModelForm): class Meta: model = User fields = ("username",) - field_classes = {'username': UsernameField} + field_classes = {"username": UsernameField} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self._meta.model.USERNAME_FIELD in self.fields: - self.fields[self._meta.model.USERNAME_FIELD].widget.attrs['autofocus'] = True + self.fields[self._meta.model.USERNAME_FIELD].widget.attrs[ + "autofocus" + ] = True def clean_password2(self): password1 = self.cleaned_data.get("password1") password2 = self.cleaned_data.get("password2") if password1 and password2 and password1 != password2: raise ValidationError( - self.error_messages['password_mismatch'], - code='password_mismatch', + self.error_messages["password_mismatch"], + code="password_mismatch", ) return password2 @@ -120,12 +129,12 @@ class UserCreationForm(forms.ModelForm): super()._post_clean() # Validate the password after self.instance is updated with form data # by super(). 
- password = self.cleaned_data.get('password2') + password = self.cleaned_data.get("password2") if password: try: password_validation.validate_password(password, self.instance) except ValidationError as error: - self.add_error('password2', error) + self.add_error("password2", error) def save(self, commit=True): user = super().save(commit=False) @@ -139,25 +148,27 @@ class UserChangeForm(forms.ModelForm): password = ReadOnlyPasswordHashField( label=_("Password"), help_text=_( - 'Raw passwords are not stored, so there is no way to see this ' - 'user’s password, but you can change the password using ' + "Raw passwords are not stored, so there is no way to see this " + "user’s password, but you can change the password using " '<a href="{}">this form</a>.' ), ) class Meta: model = User - fields = '__all__' - field_classes = {'username': UsernameField} + fields = "__all__" + field_classes = {"username": UsernameField} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - password = self.fields.get('password') + password = self.fields.get("password") if password: - password.help_text = password.help_text.format('../password/') - user_permissions = self.fields.get('user_permissions') + password.help_text = password.help_text.format("../password/") + user_permissions = self.fields.get("user_permissions") if user_permissions: - user_permissions.queryset = user_permissions.queryset.select_related('content_type') + user_permissions.queryset = user_permissions.queryset.select_related( + "content_type" + ) class AuthenticationForm(forms.Form): @@ -165,19 +176,20 @@ class AuthenticationForm(forms.Form): Base class for authenticating users. Extend this to get a form that accepts username/password logins. 
""" - username = UsernameField(widget=forms.TextInput(attrs={'autofocus': True})) + + username = UsernameField(widget=forms.TextInput(attrs={"autofocus": True})) password = forms.CharField( label=_("Password"), strip=False, - widget=forms.PasswordInput(attrs={'autocomplete': 'current-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "current-password"}), ) error_messages = { - 'invalid_login': _( + "invalid_login": _( "Please enter a correct %(username)s and password. Note that both " "fields may be case-sensitive." ), - 'inactive': _("This account is inactive."), + "inactive": _("This account is inactive."), } def __init__(self, request=None, *args, **kwargs): @@ -192,17 +204,19 @@ class AuthenticationForm(forms.Form): # Set the max length and label for the "username" field. self.username_field = UserModel._meta.get_field(UserModel.USERNAME_FIELD) username_max_length = self.username_field.max_length or 254 - self.fields['username'].max_length = username_max_length - self.fields['username'].widget.attrs['maxlength'] = username_max_length - if self.fields['username'].label is None: - self.fields['username'].label = capfirst(self.username_field.verbose_name) + self.fields["username"].max_length = username_max_length + self.fields["username"].widget.attrs["maxlength"] = username_max_length + if self.fields["username"].label is None: + self.fields["username"].label = capfirst(self.username_field.verbose_name) def clean(self): - username = self.cleaned_data.get('username') - password = self.cleaned_data.get('password') + username = self.cleaned_data.get("username") + password = self.cleaned_data.get("password") if username is not None and password: - self.user_cache = authenticate(self.request, username=username, password=password) + self.user_cache = authenticate( + self.request, username=username, password=password + ) if self.user_cache is None: raise self.get_invalid_login_error() else: @@ -223,8 +237,8 @@ class AuthenticationForm(forms.Form): """ if 
not user.is_active: raise ValidationError( - self.error_messages['inactive'], - code='inactive', + self.error_messages["inactive"], + code="inactive", ) def get_user(self): @@ -232,9 +246,9 @@ class AuthenticationForm(forms.Form): def get_invalid_login_error(self): return ValidationError( - self.error_messages['invalid_login'], - code='invalid_login', - params={'username': self.username_field.verbose_name}, + self.error_messages["invalid_login"], + code="invalid_login", + params={"username": self.username_field.verbose_name}, ) @@ -242,23 +256,30 @@ class PasswordResetForm(forms.Form): email = forms.EmailField( label=_("Email"), max_length=254, - widget=forms.EmailInput(attrs={'autocomplete': 'email'}) + widget=forms.EmailInput(attrs={"autocomplete": "email"}), ) - def send_mail(self, subject_template_name, email_template_name, - context, from_email, to_email, html_email_template_name=None): + def send_mail( + self, + subject_template_name, + email_template_name, + context, + from_email, + to_email, + html_email_template_name=None, + ): """ Send a django.core.mail.EmailMultiAlternatives to `to_email`. """ subject = loader.render_to_string(subject_template_name, context) # Email subject *must not* contain newlines - subject = ''.join(subject.splitlines()) + subject = "".join(subject.splitlines()) body = loader.render_to_string(email_template_name, context) email_message = EmailMultiAlternatives(subject, body, from_email, [to_email]) if html_email_template_name is not None: html_email = loader.render_to_string(html_email_template_name, context) - email_message.attach_alternative(html_email, 'text/html') + email_message.attach_alternative(html_email, "text/html") email_message.send() @@ -270,22 +291,31 @@ class PasswordResetForm(forms.Form): resetting their password. 
""" email_field_name = UserModel.get_email_field_name() - active_users = UserModel._default_manager.filter(**{ - '%s__iexact' % email_field_name: email, - 'is_active': True, - }) + active_users = UserModel._default_manager.filter( + **{ + "%s__iexact" % email_field_name: email, + "is_active": True, + } + ) return ( - u for u in active_users - if u.has_usable_password() and - _unicode_ci_compare(email, getattr(u, email_field_name)) + u + for u in active_users + if u.has_usable_password() + and _unicode_ci_compare(email, getattr(u, email_field_name)) ) - def save(self, domain_override=None, - subject_template_name='registration/password_reset_subject.txt', - email_template_name='registration/password_reset_email.html', - use_https=False, token_generator=default_token_generator, - from_email=None, request=None, html_email_template_name=None, - extra_email_context=None): + def save( + self, + domain_override=None, + subject_template_name="registration/password_reset_subject.txt", + email_template_name="registration/password_reset_email.html", + use_https=False, + token_generator=default_token_generator, + from_email=None, + request=None, + html_email_template_name=None, + extra_email_context=None, + ): """ Generate a one-use only link for resetting password and send it to the user. 
@@ -301,18 +331,22 @@ class PasswordResetForm(forms.Form): for user in self.get_users(email): user_email = getattr(user, email_field_name) context = { - 'email': user_email, - 'domain': domain, - 'site_name': site_name, - 'uid': urlsafe_base64_encode(force_bytes(user.pk)), - 'user': user, - 'token': token_generator.make_token(user), - 'protocol': 'https' if use_https else 'http', + "email": user_email, + "domain": domain, + "site_name": site_name, + "uid": urlsafe_base64_encode(force_bytes(user.pk)), + "user": user, + "token": token_generator.make_token(user), + "protocol": "https" if use_https else "http", **(extra_email_context or {}), } self.send_mail( - subject_template_name, email_template_name, context, from_email, - user_email, html_email_template_name=html_email_template_name, + subject_template_name, + email_template_name, + context, + from_email, + user_email, + html_email_template_name=html_email_template_name, ) @@ -321,19 +355,20 @@ class SetPasswordForm(forms.Form): A form that lets a user change set their password without entering the old password """ + error_messages = { - 'password_mismatch': _('The two password fields didn’t match.'), + "password_mismatch": _("The two password fields didn’t match."), } new_password1 = forms.CharField( label=_("New password"), - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "new-password"}), strip=False, help_text=password_validation.password_validators_help_text_html(), ) new_password2 = forms.CharField( label=_("New password confirmation"), strip=False, - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "new-password"}), ) def __init__(self, user, *args, **kwargs): @@ -341,13 +376,13 @@ class SetPasswordForm(forms.Form): super().__init__(*args, **kwargs) def clean_new_password2(self): - password1 = self.cleaned_data.get('new_password1') - password2 = 
self.cleaned_data.get('new_password2') + password1 = self.cleaned_data.get("new_password1") + password2 = self.cleaned_data.get("new_password2") if password1 and password2: if password1 != password2: raise ValidationError( - self.error_messages['password_mismatch'], - code='password_mismatch', + self.error_messages["password_mismatch"], + code="password_mismatch", ) password_validation.validate_password(password2, self.user) return password2 @@ -365,17 +400,22 @@ class PasswordChangeForm(SetPasswordForm): A form that lets a user change their password by entering their old password. """ + error_messages = { **SetPasswordForm.error_messages, - 'password_incorrect': _("Your old password was entered incorrectly. Please enter it again."), + "password_incorrect": _( + "Your old password was entered incorrectly. Please enter it again." + ), } old_password = forms.CharField( label=_("Old password"), strip=False, - widget=forms.PasswordInput(attrs={'autocomplete': 'current-password', 'autofocus': True}), + widget=forms.PasswordInput( + attrs={"autocomplete": "current-password", "autofocus": True} + ), ) - field_order = ['old_password', 'new_password1', 'new_password2'] + field_order = ["old_password", "new_password1", "new_password2"] def clean_old_password(self): """ @@ -384,8 +424,8 @@ class PasswordChangeForm(SetPasswordForm): old_password = self.cleaned_data["old_password"] if not self.user.check_password(old_password): raise ValidationError( - self.error_messages['password_incorrect'], - code='password_incorrect', + self.error_messages["password_incorrect"], + code="password_incorrect", ) return old_password @@ -394,19 +434,22 @@ class AdminPasswordChangeForm(forms.Form): """ A form used to change the password of a user in the admin interface. 
""" + error_messages = { - 'password_mismatch': _('The two password fields didn’t match.'), + "password_mismatch": _("The two password fields didn’t match."), } - required_css_class = 'required' + required_css_class = "required" password1 = forms.CharField( label=_("Password"), - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password', 'autofocus': True}), + widget=forms.PasswordInput( + attrs={"autocomplete": "new-password", "autofocus": True} + ), strip=False, help_text=password_validation.password_validators_help_text_html(), ) password2 = forms.CharField( label=_("Password (again)"), - widget=forms.PasswordInput(attrs={'autocomplete': 'new-password'}), + widget=forms.PasswordInput(attrs={"autocomplete": "new-password"}), strip=False, help_text=_("Enter the same password as before, for verification."), ) @@ -416,12 +459,12 @@ class AdminPasswordChangeForm(forms.Form): super().__init__(*args, **kwargs) def clean_password2(self): - password1 = self.cleaned_data.get('password1') - password2 = self.cleaned_data.get('password2') + password1 = self.cleaned_data.get("password1") + password2 = self.cleaned_data.get("password2") if password1 and password2 and password1 != password2: raise ValidationError( - self.error_messages['password_mismatch'], - code='password_mismatch', + self.error_messages["password_mismatch"], + code="password_mismatch", ) password_validation.validate_password(password2, self.user) return password2 @@ -440,4 +483,4 @@ class AdminPasswordChangeForm(forms.Form): for name in self.fields: if name not in data: return [] - return ['password'] + return ["password"] diff --git a/django/contrib/auth/hashers.py b/django/contrib/auth/hashers.py index b0b3e730b8..3cdbaa75b0 100644 --- a/django/contrib/auth/hashers.py +++ b/django/contrib/auth/hashers.py @@ -11,13 +11,19 @@ from django.core.exceptions import ImproperlyConfigured from django.core.signals import setting_changed from django.dispatch import receiver from django.utils.crypto import ( - 
RANDOM_STRING_CHARS, constant_time_compare, get_random_string, md5, pbkdf2, + RANDOM_STRING_CHARS, + constant_time_compare, + get_random_string, + md5, + pbkdf2, ) from django.utils.module_loading import import_string from django.utils.translation import gettext_noop as _ -UNUSABLE_PASSWORD_PREFIX = '!' # This will never be a valid encoded hash -UNUSABLE_PASSWORD_SUFFIX_LENGTH = 40 # number of random chars to add after UNUSABLE_PASSWORD_PREFIX +UNUSABLE_PASSWORD_PREFIX = "!" # This will never be a valid encoded hash +UNUSABLE_PASSWORD_SUFFIX_LENGTH = ( + 40 # number of random chars to add after UNUSABLE_PASSWORD_PREFIX +) def is_password_usable(encoded): @@ -28,7 +34,7 @@ def is_password_usable(encoded): return encoded is None or not encoded.startswith(UNUSABLE_PASSWORD_PREFIX) -def check_password(password, encoded, setter=None, preferred='default'): +def check_password(password, encoded, setter=None, preferred="default"): """ Return a boolean of whether the raw password matches the three part encoded digest. @@ -62,7 +68,7 @@ def check_password(password, encoded, setter=None, preferred='default'): return is_correct -def make_password(password, salt=None, hasher='default'): +def make_password(password, salt=None, hasher="default"): """ Turn a plain-text password into a hash for database storage @@ -72,11 +78,12 @@ def make_password(password, salt=None, hasher='default'): access to staff or superuser accounts. See ticket #20079 for more info. """ if password is None: - return UNUSABLE_PASSWORD_PREFIX + get_random_string(UNUSABLE_PASSWORD_SUFFIX_LENGTH) + return UNUSABLE_PASSWORD_PREFIX + get_random_string( + UNUSABLE_PASSWORD_SUFFIX_LENGTH + ) if not isinstance(password, (bytes, str)): raise TypeError( - 'Password must be a string or bytes, got %s.' - % type(password).__qualname__ + "Password must be a string or bytes, got %s." 
% type(password).__qualname__ ) hasher = get_hasher(hasher) salt = salt or hasher.salt() @@ -89,7 +96,7 @@ def get_hashers(): for hasher_path in settings.PASSWORD_HASHERS: hasher_cls = import_string(hasher_path) hasher = hasher_cls() - if not getattr(hasher, 'algorithm'): + if not getattr(hasher, "algorithm"): raise ImproperlyConfigured( "hasher doesn't specify an algorithm name: %s" % hasher_path ) @@ -104,22 +111,22 @@ def get_hashers_by_algorithm(): @receiver(setting_changed) def reset_hashers(*, setting, **kwargs): - if setting == 'PASSWORD_HASHERS': + if setting == "PASSWORD_HASHERS": get_hashers.cache_clear() get_hashers_by_algorithm.cache_clear() -def get_hasher(algorithm='default'): +def get_hasher(algorithm="default"): """ Return an instance of a loaded password hasher. If algorithm is 'default', return the default hasher. Lazily import hashers specified in the project's settings file if needed. """ - if hasattr(algorithm, 'algorithm'): + if hasattr(algorithm, "algorithm"): return algorithm - elif algorithm == 'default': + elif algorithm == "default": return get_hashers()[0] else: @@ -127,9 +134,11 @@ def get_hasher(algorithm='default'): try: return hashers[algorithm] except KeyError: - raise ValueError("Unknown password hashing algorithm '%s'. " - "Did you specify it in the PASSWORD_HASHERS " - "setting?" % algorithm) + raise ValueError( + "Unknown password hashing algorithm '%s'. " + "Did you specify it in the PASSWORD_HASHERS " + "setting?" % algorithm + ) def identify_hasher(encoded): @@ -142,14 +151,15 @@ def identify_hasher(encoded): """ # Ancient versions of Django created plain MD5 passwords and accepted # MD5 passwords with an empty salt. 
- if ((len(encoded) == 32 and '$' not in encoded) or - (len(encoded) == 37 and encoded.startswith('md5$$'))): - algorithm = 'unsalted_md5' + if (len(encoded) == 32 and "$" not in encoded) or ( + len(encoded) == 37 and encoded.startswith("md5$$") + ): + algorithm = "unsalted_md5" # Ancient versions of Django accepted SHA1 passwords with an empty salt. - elif len(encoded) == 46 and encoded.startswith('sha1$$'): - algorithm = 'unsalted_sha1' + elif len(encoded) == 46 and encoded.startswith("sha1$$"): + algorithm = "unsalted_sha1" else: - algorithm = encoded.split('$', 1)[0] + algorithm = encoded.split("$", 1)[0] return get_hasher(algorithm) @@ -177,6 +187,7 @@ class BasePasswordHasher: PasswordHasher objects are immutable. """ + algorithm = None library = None salt_entropy = 128 @@ -190,11 +201,14 @@ class BasePasswordHasher: try: module = importlib.import_module(mod_path) except ImportError as e: - raise ValueError("Couldn't load %r algorithm library: %s" % - (self.__class__.__name__, e)) + raise ValueError( + "Couldn't load %r algorithm library: %s" + % (self.__class__.__name__, e) + ) return module - raise ValueError("Hasher %r doesn't specify a library attribute" % - self.__class__.__name__) + raise ValueError( + "Hasher %r doesn't specify a library attribute" % self.__class__.__name__ + ) def salt(self): """ @@ -208,13 +222,15 @@ class BasePasswordHasher: def verify(self, password, encoded): """Check if the given password is correct.""" - raise NotImplementedError('subclasses of BasePasswordHasher must provide a verify() method') + raise NotImplementedError( + "subclasses of BasePasswordHasher must provide a verify() method" + ) def _check_encode_args(self, password, salt): if password is None: - raise TypeError('password must be provided.') - if not salt or '$' in salt: - raise ValueError('salt must be provided and cannot contain $.') + raise TypeError("password must be provided.") + if not salt or "$" in salt: + raise ValueError("salt must be provided and 
cannot contain $.") def encode(self, password, salt): """ @@ -223,7 +239,9 @@ class BasePasswordHasher: The result is normally formatted as "algorithm$salt$hash" and must be fewer than 128 characters. """ - raise NotImplementedError('subclasses of BasePasswordHasher must provide an encode() method') + raise NotImplementedError( + "subclasses of BasePasswordHasher must provide an encode() method" + ) def decode(self, encoded): """ @@ -234,7 +252,7 @@ class BasePasswordHasher: `work_factor`. """ raise NotImplementedError( - 'subclasses of BasePasswordHasher must provide a decode() method.' + "subclasses of BasePasswordHasher must provide a decode() method." ) def safe_summary(self, encoded): @@ -244,7 +262,9 @@ class BasePasswordHasher: The result is a dictionary and will be used where the password field must be displayed to construct a safe representation of the password. """ - raise NotImplementedError('subclasses of BasePasswordHasher must provide a safe_summary() method') + raise NotImplementedError( + "subclasses of BasePasswordHasher must provide a safe_summary() method" + ) def must_update(self, encoded): return False @@ -260,7 +280,9 @@ class BasePasswordHasher: for any hasher that has a work factor. If not, this method should be defined as a no-op to silence the warning. """ - warnings.warn('subclasses of BasePasswordHasher should provide a harden_runtime() method') + warnings.warn( + "subclasses of BasePasswordHasher should provide a harden_runtime() method" + ) class PBKDF2PasswordHasher(BasePasswordHasher): @@ -271,6 +293,7 @@ class PBKDF2PasswordHasher(BasePasswordHasher): The result is a 64 byte binary string. Iterations may be changed safely but you must rename the algorithm if you change SHA256. 
""" + algorithm = "pbkdf2_sha256" iterations = 390000 digest = hashlib.sha256 @@ -279,43 +302,43 @@ class PBKDF2PasswordHasher(BasePasswordHasher): self._check_encode_args(password, salt) iterations = iterations or self.iterations hash = pbkdf2(password, salt, iterations, digest=self.digest) - hash = base64.b64encode(hash).decode('ascii').strip() + hash = base64.b64encode(hash).decode("ascii").strip() return "%s$%d$%s$%s" % (self.algorithm, iterations, salt, hash) def decode(self, encoded): - algorithm, iterations, salt, hash = encoded.split('$', 3) + algorithm, iterations, salt, hash = encoded.split("$", 3) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'hash': hash, - 'iterations': int(iterations), - 'salt': salt, + "algorithm": algorithm, + "hash": hash, + "iterations": int(iterations), + "salt": salt, } def verify(self, password, encoded): decoded = self.decode(encoded) - encoded_2 = self.encode(password, decoded['salt'], decoded['iterations']) + encoded_2 = self.encode(password, decoded["salt"], decoded["iterations"]) return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('iterations'): decoded['iterations'], - _('salt'): mask_hash(decoded['salt']), - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("iterations"): decoded["iterations"], + _("salt"): mask_hash(decoded["salt"]), + _("hash"): mask_hash(decoded["hash"]), } def must_update(self, encoded): decoded = self.decode(encoded) - update_salt = must_update_salt(decoded['salt'], self.salt_entropy) - return (decoded['iterations'] != self.iterations) or update_salt + update_salt = must_update_salt(decoded["salt"], self.salt_entropy) + return (decoded["iterations"] != self.iterations) or update_salt def harden_runtime(self, password, encoded): decoded = self.decode(encoded) - extra_iterations = self.iterations - decoded['iterations'] + 
extra_iterations = self.iterations - decoded["iterations"] if extra_iterations > 0: - self.encode(password, decoded['salt'], extra_iterations) + self.encode(password, decoded["salt"], extra_iterations) class PBKDF2SHA1PasswordHasher(PBKDF2PasswordHasher): @@ -325,6 +348,7 @@ class PBKDF2SHA1PasswordHasher(PBKDF2PasswordHasher): implementations of PBKDF2, such as openssl's PKCS5_PBKDF2_HMAC_SHA1(). """ + algorithm = "pbkdf2_sha1" digest = hashlib.sha1 @@ -337,8 +361,9 @@ class Argon2PasswordHasher(BasePasswordHasher): (https://password-hashing.net). It requires the argon2-cffi library which depends on native C code and might cause portability issues. """ - algorithm = 'argon2' - library = 'argon2' + + algorithm = "argon2" + library = "argon2" time_cost = 2 memory_cost = 102400 @@ -356,59 +381,59 @@ class Argon2PasswordHasher(BasePasswordHasher): hash_len=params.hash_len, type=params.type, ) - return self.algorithm + data.decode('ascii') + return self.algorithm + data.decode("ascii") def decode(self, encoded): argon2 = self._load_library() - algorithm, rest = encoded.split('$', 1) + algorithm, rest = encoded.split("$", 1) assert algorithm == self.algorithm - params = argon2.extract_parameters('$' + rest) - variety, *_, b64salt, hash = rest.split('$') + params = argon2.extract_parameters("$" + rest) + variety, *_, b64salt, hash = rest.split("$") # Add padding. 
- b64salt += '=' * (-len(b64salt) % 4) - salt = base64.b64decode(b64salt).decode('latin1') + b64salt += "=" * (-len(b64salt) % 4) + salt = base64.b64decode(b64salt).decode("latin1") return { - 'algorithm': algorithm, - 'hash': hash, - 'memory_cost': params.memory_cost, - 'parallelism': params.parallelism, - 'salt': salt, - 'time_cost': params.time_cost, - 'variety': variety, - 'version': params.version, - 'params': params, + "algorithm": algorithm, + "hash": hash, + "memory_cost": params.memory_cost, + "parallelism": params.parallelism, + "salt": salt, + "time_cost": params.time_cost, + "variety": variety, + "version": params.version, + "params": params, } def verify(self, password, encoded): argon2 = self._load_library() - algorithm, rest = encoded.split('$', 1) + algorithm, rest = encoded.split("$", 1) assert algorithm == self.algorithm try: - return argon2.PasswordHasher().verify('$' + rest, password) + return argon2.PasswordHasher().verify("$" + rest, password) except argon2.exceptions.VerificationError: return False def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('variety'): decoded['variety'], - _('version'): decoded['version'], - _('memory cost'): decoded['memory_cost'], - _('time cost'): decoded['time_cost'], - _('parallelism'): decoded['parallelism'], - _('salt'): mask_hash(decoded['salt']), - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("variety"): decoded["variety"], + _("version"): decoded["version"], + _("memory cost"): decoded["memory_cost"], + _("time cost"): decoded["time_cost"], + _("parallelism"): decoded["parallelism"], + _("salt"): mask_hash(decoded["salt"]), + _("hash"): mask_hash(decoded["hash"]), } def must_update(self, encoded): decoded = self.decode(encoded) - current_params = decoded['params'] + current_params = decoded["params"] new_params = self.params() # Set salt_len to the salt_len of the current parameters because salt # is 
explicitly passed to argon2. new_params.salt_len = current_params.salt_len - update_salt = must_update_salt(decoded['salt'], self.salt_entropy) + update_salt = must_update_salt(decoded["salt"], self.salt_entropy) return (current_params != new_params) or update_salt def harden_runtime(self, password, encoded): @@ -439,6 +464,7 @@ class BCryptSHA256PasswordHasher(BasePasswordHasher): this library depends on native C code and might cause portability issues. """ + algorithm = "bcrypt_sha256" digest = hashlib.sha256 library = ("bcrypt", "bcrypt") @@ -458,46 +484,46 @@ class BCryptSHA256PasswordHasher(BasePasswordHasher): password = binascii.hexlify(self.digest(password).digest()) data = bcrypt.hashpw(password, salt) - return "%s$%s" % (self.algorithm, data.decode('ascii')) + return "%s$%s" % (self.algorithm, data.decode("ascii")) def decode(self, encoded): - algorithm, empty, algostr, work_factor, data = encoded.split('$', 4) + algorithm, empty, algostr, work_factor, data = encoded.split("$", 4) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'algostr': algostr, - 'checksum': data[22:], - 'salt': data[:22], - 'work_factor': int(work_factor), + "algorithm": algorithm, + "algostr": algostr, + "checksum": data[22:], + "salt": data[:22], + "work_factor": int(work_factor), } def verify(self, password, encoded): - algorithm, data = encoded.split('$', 1) + algorithm, data = encoded.split("$", 1) assert algorithm == self.algorithm - encoded_2 = self.encode(password, data.encode('ascii')) + encoded_2 = self.encode(password, data.encode("ascii")) return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('work factor'): decoded['work_factor'], - _('salt'): mask_hash(decoded['salt']), - _('checksum'): mask_hash(decoded['checksum']), + _("algorithm"): decoded["algorithm"], + _("work factor"): decoded["work_factor"], + _("salt"): 
mask_hash(decoded["salt"]), + _("checksum"): mask_hash(decoded["checksum"]), } def must_update(self, encoded): decoded = self.decode(encoded) - return decoded['work_factor'] != self.rounds + return decoded["work_factor"] != self.rounds def harden_runtime(self, password, encoded): - _, data = encoded.split('$', 1) + _, data = encoded.split("$", 1) salt = data[:29] # Length of the salt in bcrypt. - rounds = data.split('$')[2] + rounds = data.split("$")[2] # work factor is logarithmic, adding one doubles the load. - diff = 2**(self.rounds - int(rounds)) - 1 + diff = 2 ** (self.rounds - int(rounds)) - 1 while diff > 0: - self.encode(password, salt.encode('ascii')) + self.encode(password, salt.encode("ascii")) diff -= 1 @@ -514,6 +540,7 @@ class BCryptPasswordHasher(BCryptSHA256PasswordHasher): bcrypt's 72 bytes password truncation. Most use cases should prefer the BCryptSHA256PasswordHasher. """ + algorithm = "bcrypt" digest = None @@ -522,11 +549,12 @@ class ScryptPasswordHasher(BasePasswordHasher): """ Secure password hashing using the Scrypt algorithm. 
""" - algorithm = 'scrypt' + + algorithm = "scrypt" block_size = 8 maxmem = 0 parallelism = 1 - work_factor = 2 ** 14 + work_factor = 2**14 def encode(self, password, salt, n=None, r=None, p=None): self._check_encode_args(password, salt) @@ -542,49 +570,51 @@ class ScryptPasswordHasher(BasePasswordHasher): maxmem=self.maxmem, dklen=64, ) - hash_ = base64.b64encode(hash_).decode('ascii').strip() - return '%s$%d$%s$%d$%d$%s' % (self.algorithm, n, salt, r, p, hash_) + hash_ = base64.b64encode(hash_).decode("ascii").strip() + return "%s$%d$%s$%d$%d$%s" % (self.algorithm, n, salt, r, p, hash_) def decode(self, encoded): - algorithm, work_factor, salt, block_size, parallelism, hash_ = encoded.split('$', 6) + algorithm, work_factor, salt, block_size, parallelism, hash_ = encoded.split( + "$", 6 + ) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'work_factor': int(work_factor), - 'salt': salt, - 'block_size': int(block_size), - 'parallelism': int(parallelism), - 'hash': hash_, + "algorithm": algorithm, + "work_factor": int(work_factor), + "salt": salt, + "block_size": int(block_size), + "parallelism": int(parallelism), + "hash": hash_, } def verify(self, password, encoded): decoded = self.decode(encoded) encoded_2 = self.encode( password, - decoded['salt'], - decoded['work_factor'], - decoded['block_size'], - decoded['parallelism'], + decoded["salt"], + decoded["work_factor"], + decoded["block_size"], + decoded["parallelism"], ) return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('work factor'): decoded['work_factor'], - _('block size'): decoded['block_size'], - _('parallelism'): decoded['parallelism'], - _('salt'): mask_hash(decoded['salt']), - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("work factor"): decoded["work_factor"], + _("block size"): decoded["block_size"], + _("parallelism"): 
decoded["parallelism"], + _("salt"): mask_hash(decoded["salt"]), + _("hash"): mask_hash(decoded["hash"]), } def must_update(self, encoded): decoded = self.decode(encoded) return ( - decoded['work_factor'] != self.work_factor or - decoded['block_size'] != self.block_size or - decoded['parallelism'] != self.parallelism + decoded["work_factor"] != self.work_factor + or decoded["block_size"] != self.block_size + or decoded["parallelism"] != self.parallelism ) def harden_runtime(self, password, encoded): @@ -597,6 +627,7 @@ class SHA1PasswordHasher(BasePasswordHasher): """ The SHA1 password hashing algorithm (not recommended) """ + algorithm = "sha1" def encode(self, password, salt): @@ -605,30 +636,30 @@ class SHA1PasswordHasher(BasePasswordHasher): return "%s$%s$%s" % (self.algorithm, salt, hash) def decode(self, encoded): - algorithm, salt, hash = encoded.split('$', 2) + algorithm, salt, hash = encoded.split("$", 2) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'hash': hash, - 'salt': salt, + "algorithm": algorithm, + "hash": hash, + "salt": salt, } def verify(self, password, encoded): decoded = self.decode(encoded) - encoded_2 = self.encode(password, decoded['salt']) + encoded_2 = self.encode(password, decoded["salt"]) return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('salt'): mask_hash(decoded['salt'], show=2), - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("salt"): mask_hash(decoded["salt"], show=2), + _("hash"): mask_hash(decoded["hash"]), } def must_update(self, encoded): decoded = self.decode(encoded) - return must_update_salt(decoded['salt'], self.salt_entropy) + return must_update_salt(decoded["salt"], self.salt_entropy) def harden_runtime(self, password, encoded): pass @@ -638,6 +669,7 @@ class MD5PasswordHasher(BasePasswordHasher): """ The Salted MD5 password hashing algorithm 
(not recommended) """ + algorithm = "md5" def encode(self, password, salt): @@ -646,30 +678,30 @@ class MD5PasswordHasher(BasePasswordHasher): return "%s$%s$%s" % (self.algorithm, salt, hash) def decode(self, encoded): - algorithm, salt, hash = encoded.split('$', 2) + algorithm, salt, hash = encoded.split("$", 2) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'hash': hash, - 'salt': salt, + "algorithm": algorithm, + "hash": hash, + "salt": salt, } def verify(self, password, encoded): decoded = self.decode(encoded) - encoded_2 = self.encode(password, decoded['salt']) + encoded_2 = self.encode(password, decoded["salt"]) return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('salt'): mask_hash(decoded['salt'], show=2), - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("salt"): mask_hash(decoded["salt"], show=2), + _("hash"): mask_hash(decoded["hash"]), } def must_update(self, encoded): decoded = self.decode(encoded) - return must_update_salt(decoded['salt'], self.salt_entropy) + return must_update_salt(decoded["salt"], self.salt_entropy) def harden_runtime(self, password, encoded): pass @@ -684,34 +716,35 @@ class UnsaltedSHA1PasswordHasher(BasePasswordHasher): hashes. Some older Django installs still have these values lingering around so we need to handle and upgrade them properly. 
""" + algorithm = "unsalted_sha1" def salt(self): - return '' + return "" def encode(self, password, salt): - if salt != '': - raise ValueError('salt must be empty.') + if salt != "": + raise ValueError("salt must be empty.") hash = hashlib.sha1(password.encode()).hexdigest() - return 'sha1$$%s' % hash + return "sha1$$%s" % hash def decode(self, encoded): - assert encoded.startswith('sha1$$') + assert encoded.startswith("sha1$$") return { - 'algorithm': self.algorithm, - 'hash': encoded[6:], - 'salt': None, + "algorithm": self.algorithm, + "hash": encoded[6:], + "salt": None, } def verify(self, password, encoded): - encoded_2 = self.encode(password, '') + encoded_2 = self.encode(password, "") return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('hash'): mask_hash(decoded['hash']), + _("algorithm"): decoded["algorithm"], + _("hash"): mask_hash(decoded["hash"]), } def harden_runtime(self, password, encoded): @@ -729,34 +762,35 @@ class UnsaltedMD5PasswordHasher(BasePasswordHasher): these values lingering around so we need to handle and upgrade them properly. 
""" + algorithm = "unsalted_md5" def salt(self): - return '' + return "" def encode(self, password, salt): - if salt != '': - raise ValueError('salt must be empty.') + if salt != "": + raise ValueError("salt must be empty.") return md5(password.encode()).hexdigest() def decode(self, encoded): return { - 'algorithm': self.algorithm, - 'hash': encoded, - 'salt': None, + "algorithm": self.algorithm, + "hash": encoded, + "salt": None, } def verify(self, password, encoded): - if len(encoded) == 37 and encoded.startswith('md5$$'): + if len(encoded) == 37 and encoded.startswith("md5$$"): encoded = encoded[5:] - encoded_2 = self.encode(password, '') + encoded_2 = self.encode(password, "") return constant_time_compare(encoded, encoded_2) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('hash'): mask_hash(decoded['hash'], show=3), + _("algorithm"): decoded["algorithm"], + _("hash"): mask_hash(decoded["hash"], show=3), } def harden_runtime(self, password, encoded): @@ -769,6 +803,7 @@ class CryptPasswordHasher(BasePasswordHasher): The crypt module is not supported on all platforms. """ + algorithm = "crypt" library = "crypt" @@ -778,34 +813,34 @@ class CryptPasswordHasher(BasePasswordHasher): def encode(self, password, salt): crypt = self._load_library() if len(salt) != 2: - raise ValueError('salt must be of length 2.') + raise ValueError("salt must be of length 2.") hash = crypt.crypt(password, salt) if hash is None: # A platform like OpenBSD with a dummy crypt module. 
- raise TypeError('hash must be provided.') + raise TypeError("hash must be provided.") # we don't need to store the salt, but Django used to do this - return '%s$%s$%s' % (self.algorithm, '', hash) + return "%s$%s$%s" % (self.algorithm, "", hash) def decode(self, encoded): - algorithm, salt, hash = encoded.split('$', 2) + algorithm, salt, hash = encoded.split("$", 2) assert algorithm == self.algorithm return { - 'algorithm': algorithm, - 'hash': hash, - 'salt': salt, + "algorithm": algorithm, + "hash": hash, + "salt": salt, } def verify(self, password, encoded): crypt = self._load_library() decoded = self.decode(encoded) - data = crypt.crypt(password, decoded['hash']) - return constant_time_compare(decoded['hash'], data) + data = crypt.crypt(password, decoded["hash"]) + return constant_time_compare(decoded["hash"], data) def safe_summary(self, encoded): decoded = self.decode(encoded) return { - _('algorithm'): decoded['algorithm'], - _('salt'): decoded['salt'], - _('hash'): mask_hash(decoded['hash'], show=3), + _("algorithm"): decoded["algorithm"], + _("salt"): decoded["salt"], + _("hash"): mask_hash(decoded["hash"], show=3), } def harden_runtime(self, password, encoded): diff --git a/django/contrib/auth/management/__init__.py b/django/contrib/auth/management/__init__.py index 1170cb0b20..0b5a982617 100644 --- a/django/contrib/auth/management/__init__.py +++ b/django/contrib/auth/management/__init__.py @@ -25,27 +25,43 @@ def _get_builtin_permissions(opts): """ perms = [] for action in opts.default_permissions: - perms.append(( - get_permission_codename(action, opts), - 'Can %s %s' % (action, opts.verbose_name_raw) - )) + perms.append( + ( + get_permission_codename(action, opts), + "Can %s %s" % (action, opts.verbose_name_raw), + ) + ) return perms -def create_permissions(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, apps=global_apps, **kwargs): +def create_permissions( + app_config, + verbosity=2, + interactive=True, + using=DEFAULT_DB_ALIAS, 
+ apps=global_apps, + **kwargs, +): if not app_config.models_module: return # Ensure that contenttypes are created for this app. Needed if # 'django.contrib.auth' is in INSTALLED_APPS before # 'django.contrib.contenttypes'. - create_contenttypes(app_config, verbosity=verbosity, interactive=interactive, using=using, apps=apps, **kwargs) + create_contenttypes( + app_config, + verbosity=verbosity, + interactive=interactive, + using=using, + apps=apps, + **kwargs, + ) app_label = app_config.label try: app_config = apps.get_app_config(app_label) - ContentType = apps.get_model('contenttypes', 'ContentType') - Permission = apps.get_model('auth', 'Permission') + ContentType = apps.get_model("contenttypes", "ContentType") + Permission = apps.get_model("auth", "Permission") except LookupError: return @@ -60,7 +76,9 @@ def create_permissions(app_config, verbosity=2, interactive=True, using=DEFAULT_ for klass in app_config.get_models(): # Force looking up the content types in the current database # before creating foreign keys to them. - ctype = ContentType.objects.db_manager(using).get_for_model(klass, for_concrete_model=False) + ctype = ContentType.objects.db_manager(using).get_for_model( + klass, for_concrete_model=False + ) ctypes.add(ctype) for perm in _get_all_permissions(klass._meta): @@ -69,11 +87,13 @@ def create_permissions(app_config, verbosity=2, interactive=True, using=DEFAULT_ # Find all the Permissions that have a content_type for a model we're # looking for. We don't need to check for codenames since we already have # a list of the ones we're going to create. 
- all_perms = set(Permission.objects.using(using).filter( - content_type__in=ctypes, - ).values_list( - "content_type", "codename" - )) + all_perms = set( + Permission.objects.using(using) + .filter( + content_type__in=ctypes, + ) + .values_list("content_type", "codename") + ) perms = [ Permission(codename=codename, name=name, content_type=ct) @@ -97,7 +117,7 @@ def get_system_username(): # KeyError will be raised by os.getpwuid() (called by getuser()) # if there is no corresponding entry in the /etc/passwd file # (a very restricted chroot environment, for example). - return '' + return "" return result @@ -117,23 +137,25 @@ def get_default_username(check_db=True, database=DEFAULT_DB_ALIAS): # If the User model has been swapped out, we can't make any assumptions # about the default user name. if auth_app.User._meta.swapped: - return '' + return "" default_username = get_system_username() try: default_username = ( - unicodedata.normalize('NFKD', default_username) - .encode('ascii', 'ignore').decode('ascii') - .replace(' ', '').lower() + unicodedata.normalize("NFKD", default_username) + .encode("ascii", "ignore") + .decode("ascii") + .replace(" ", "") + .lower() ) except UnicodeDecodeError: - return '' + return "" # Run the username validator try: - auth_app.User._meta.get_field('username').run_validators(default_username) + auth_app.User._meta.get_field("username").run_validators(default_username) except exceptions.ValidationError: - return '' + return "" # Don't return the default username if it is already taken. 
if check_db and default_username: @@ -144,5 +166,5 @@ def get_default_username(check_db=True, database=DEFAULT_DB_ALIAS): except auth_app.User.DoesNotExist: pass else: - return '' + return "" return default_username diff --git a/django/contrib/auth/management/commands/changepassword.py b/django/contrib/auth/management/commands/changepassword.py index b0c0a7f59a..dfe6cbe5f5 100644 --- a/django/contrib/auth/management/commands/changepassword.py +++ b/django/contrib/auth/management/commands/changepassword.py @@ -22,25 +22,26 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'username', nargs='?', - help='Username to change password for; by default, it\'s the current username.', + "username", + nargs="?", + help="Username to change password for; by default, it's the current username.", ) parser.add_argument( - '--database', + "--database", default=DEFAULT_DB_ALIAS, help='Specifies the database to use. Default is "default".', ) def handle(self, *args, **options): - if options['username']: - username = options['username'] + if options["username"]: + username = options["username"] else: username = getpass.getuser() try: - u = UserModel._default_manager.using(options['database']).get(**{ - UserModel.USERNAME_FIELD: username - }) + u = UserModel._default_manager.using(options["database"]).get( + **{UserModel.USERNAME_FIELD: username} + ) except UserModel.DoesNotExist: raise CommandError("user '%s' does not exist" % username) @@ -54,20 +55,22 @@ class Command(BaseCommand): p1 = self._get_pass() p2 = self._get_pass("Password (again): ") if p1 != p2: - self.stdout.write('Passwords do not match. Please try again.') + self.stdout.write("Passwords do not match. Please try again.") count += 1 # Don't validate passwords that don't match. 
continue try: validate_password(p2, u) except ValidationError as err: - self.stderr.write('\n'.join(err.messages)) + self.stderr.write("\n".join(err.messages)) count += 1 else: password_validated = True if count == MAX_TRIES: - raise CommandError("Aborting password change for user '%s' after %s attempts" % (u, count)) + raise CommandError( + "Aborting password change for user '%s' after %s attempts" % (u, count) + ) u.set_password(p1) u.save() diff --git a/django/contrib/auth/management/commands/createsuperuser.py b/django/contrib/auth/management/commands/createsuperuser.py index 42898e6744..5fffa55a22 100644 --- a/django/contrib/auth/management/commands/createsuperuser.py +++ b/django/contrib/auth/management/commands/createsuperuser.py @@ -18,69 +18,77 @@ class NotRunningInTTYException(Exception): pass -PASSWORD_FIELD = 'password' +PASSWORD_FIELD = "password" class Command(BaseCommand): - help = 'Used to create a superuser.' + help = "Used to create a superuser." requires_migrations_checks = True - stealth_options = ('stdin',) + stealth_options = ("stdin",) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.UserModel = get_user_model() - self.username_field = self.UserModel._meta.get_field(self.UserModel.USERNAME_FIELD) + self.username_field = self.UserModel._meta.get_field( + self.UserModel.USERNAME_FIELD + ) def add_arguments(self, parser): parser.add_argument( - '--%s' % self.UserModel.USERNAME_FIELD, - help='Specifies the login for the superuser.', + "--%s" % self.UserModel.USERNAME_FIELD, + help="Specifies the login for the superuser.", ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", help=( - 'Tells Django to NOT prompt the user for input of any kind. ' - 'You must use --%s with --noinput, along with an option for ' - 'any other required field. 
Superusers created with --noinput will ' - 'not be able to log in until they\'re given a valid password.' % - self.UserModel.USERNAME_FIELD + "Tells Django to NOT prompt the user for input of any kind. " + "You must use --%s with --noinput, along with an option for " + "any other required field. Superusers created with --noinput will " + "not be able to log in until they're given a valid password." + % self.UserModel.USERNAME_FIELD ), ) parser.add_argument( - '--database', + "--database", default=DEFAULT_DB_ALIAS, help='Specifies the database to use. Default is "default".', ) for field_name in self.UserModel.REQUIRED_FIELDS: field = self.UserModel._meta.get_field(field_name) if field.many_to_many: - if field.remote_field.through and not field.remote_field.through._meta.auto_created: + if ( + field.remote_field.through + and not field.remote_field.through._meta.auto_created + ): raise CommandError( "Required field '%s' specifies a many-to-many " - "relation through model, which is not supported." - % field_name + "relation through model, which is not supported." % field_name ) else: parser.add_argument( - '--%s' % field_name, action='append', + "--%s" % field_name, + action="append", help=( - 'Specifies the %s for the superuser. Can be used ' - 'multiple times.' % field_name, + "Specifies the %s for the superuser. Can be used " + "multiple times." % field_name, ), ) else: parser.add_argument( - '--%s' % field_name, - help='Specifies the %s for the superuser.' % field_name, + "--%s" % field_name, + help="Specifies the %s for the superuser." 
% field_name, ) def execute(self, *args, **options): - self.stdin = options.get('stdin', sys.stdin) # Used for testing + self.stdin = options.get("stdin", sys.stdin) # Used for testing return super().execute(*args, **options) def handle(self, *args, **options): username = options[self.UserModel.USERNAME_FIELD] - database = options['database'] + database = options["database"] user_data = {} verbose_field_name = self.username_field.verbose_name try: @@ -91,26 +99,36 @@ class Command(BaseCommand): # If not provided, create the user with an unusable password. user_data[PASSWORD_FIELD] = None try: - if options['interactive']: + if options["interactive"]: # Same as user_data but without many to many fields and with # foreign keys as fake model instances instead of raw IDs. fake_user_data = {} - if hasattr(self.stdin, 'isatty') and not self.stdin.isatty(): + if hasattr(self.stdin, "isatty") and not self.stdin.isatty(): raise NotRunningInTTYException default_username = get_default_username(database=database) if username: - error_msg = self._validate_username(username, verbose_field_name, database) + error_msg = self._validate_username( + username, verbose_field_name, database + ) if error_msg: self.stderr.write(error_msg) username = None - elif username == '': - raise CommandError('%s cannot be blank.' % capfirst(verbose_field_name)) + elif username == "": + raise CommandError( + "%s cannot be blank." % capfirst(verbose_field_name) + ) # Prompt for username. 
while username is None: - message = self._get_input_message(self.username_field, default_username) - username = self.get_input_data(self.username_field, message, default_username) + message = self._get_input_message( + self.username_field, default_username + ) + username = self.get_input_data( + self.username_field, message, default_username + ) if username: - error_msg = self._validate_username(username, verbose_field_name, database) + error_msg = self._validate_username( + username, verbose_field_name, database + ) if error_msg: self.stderr.write(error_msg) username = None @@ -118,7 +136,8 @@ class Command(BaseCommand): user_data[self.UserModel.USERNAME_FIELD] = username fake_user_data[self.UserModel.USERNAME_FIELD] = ( self.username_field.remote_field.model(username) - if self.username_field.remote_field else username + if self.username_field.remote_field + else username ) # Prompt for required fields. for field_name in self.UserModel.REQUIRED_FIELDS: @@ -133,78 +152,98 @@ class Command(BaseCommand): if field.many_to_many and input_value: if not input_value.strip(): user_data[field_name] = None - self.stderr.write('Error: This field cannot be blank.') + self.stderr.write("Error: This field cannot be blank.") continue - user_data[field_name] = [pk.strip() for pk in input_value.split(',')] + user_data[field_name] = [ + pk.strip() for pk in input_value.split(",") + ] if not field.many_to_many: fake_user_data[field_name] = user_data[field_name] # Wrap any foreign keys in fake model instances. if field.many_to_one: - fake_user_data[field_name] = field.remote_field.model(user_data[field_name]) + fake_user_data[field_name] = field.remote_field.model( + user_data[field_name] + ) # Prompt for a password if the model has one. 
while PASSWORD_FIELD in user_data and user_data[PASSWORD_FIELD] is None: password = getpass.getpass() - password2 = getpass.getpass('Password (again): ') + password2 = getpass.getpass("Password (again): ") if password != password2: self.stderr.write("Error: Your passwords didn't match.") # Don't validate passwords that don't match. continue - if password.strip() == '': + if password.strip() == "": self.stderr.write("Error: Blank passwords aren't allowed.") # Don't validate blank passwords. continue try: validate_password(password2, self.UserModel(**fake_user_data)) except exceptions.ValidationError as err: - self.stderr.write('\n'.join(err.messages)) - response = input('Bypass password validation and create user anyway? [y/N]: ') - if response.lower() != 'y': + self.stderr.write("\n".join(err.messages)) + response = input( + "Bypass password validation and create user anyway? [y/N]: " + ) + if response.lower() != "y": continue user_data[PASSWORD_FIELD] = password else: # Non-interactive mode. # Use password from environment variable, if provided. - if PASSWORD_FIELD in user_data and 'DJANGO_SUPERUSER_PASSWORD' in os.environ: - user_data[PASSWORD_FIELD] = os.environ['DJANGO_SUPERUSER_PASSWORD'] + if ( + PASSWORD_FIELD in user_data + and "DJANGO_SUPERUSER_PASSWORD" in os.environ + ): + user_data[PASSWORD_FIELD] = os.environ["DJANGO_SUPERUSER_PASSWORD"] # Use username from environment variable, if not provided in # options. if username is None: - username = os.environ.get('DJANGO_SUPERUSER_' + self.UserModel.USERNAME_FIELD.upper()) + username = os.environ.get( + "DJANGO_SUPERUSER_" + self.UserModel.USERNAME_FIELD.upper() + ) if username is None: - raise CommandError('You must use --%s with --noinput.' % self.UserModel.USERNAME_FIELD) + raise CommandError( + "You must use --%s with --noinput." 
+ % self.UserModel.USERNAME_FIELD + ) else: - error_msg = self._validate_username(username, verbose_field_name, database) + error_msg = self._validate_username( + username, verbose_field_name, database + ) if error_msg: raise CommandError(error_msg) user_data[self.UserModel.USERNAME_FIELD] = username for field_name in self.UserModel.REQUIRED_FIELDS: - env_var = 'DJANGO_SUPERUSER_' + field_name.upper() + env_var = "DJANGO_SUPERUSER_" + field_name.upper() value = options[field_name] or os.environ.get(env_var) if not value: - raise CommandError('You must use --%s with --noinput.' % field_name) + raise CommandError( + "You must use --%s with --noinput." % field_name + ) field = self.UserModel._meta.get_field(field_name) user_data[field_name] = field.clean(value, None) if field.many_to_many and isinstance(user_data[field_name], str): user_data[field_name] = [ - pk.strip() for pk in user_data[field_name].split(',') + pk.strip() for pk in user_data[field_name].split(",") ] - self.UserModel._default_manager.db_manager(database).create_superuser(**user_data) - if options['verbosity'] >= 1: + self.UserModel._default_manager.db_manager(database).create_superuser( + **user_data + ) + if options["verbosity"] >= 1: self.stdout.write("Superuser created successfully.") except KeyboardInterrupt: - self.stderr.write('\nOperation cancelled.') + self.stderr.write("\nOperation cancelled.") sys.exit(1) except exceptions.ValidationError as e: - raise CommandError('; '.join(e.messages)) + raise CommandError("; ".join(e.messages)) except NotRunningInTTYException: self.stdout.write( - 'Superuser creation skipped due to not running in a TTY. ' - 'You can run `manage.py createsuperuser` in your project ' - 'to create one manually.' + "Superuser creation skipped due to not running in a TTY. " + "You can run `manage.py createsuperuser` in your project " + "to create one manually." 
) def get_input_data(self, field, message, default=None): @@ -213,38 +252,45 @@ class Command(BaseCommand): validation exceptions. """ raw_value = input(message) - if default and raw_value == '': + if default and raw_value == "": raw_value = default try: val = field.clean(raw_value, None) except exceptions.ValidationError as e: - self.stderr.write("Error: %s" % '; '.join(e.messages)) + self.stderr.write("Error: %s" % "; ".join(e.messages)) val = None return val def _get_input_message(self, field, default=None): - return '%s%s%s: ' % ( + return "%s%s%s: " % ( capfirst(field.verbose_name), - " (leave blank to use '%s')" % default if default else '', - ' (%s.%s)' % ( + " (leave blank to use '%s')" % default if default else "", + " (%s.%s)" + % ( field.remote_field.model._meta.object_name, - field.m2m_target_field_name() if field.many_to_many else field.remote_field.field_name, - ) if field.remote_field else '', + field.m2m_target_field_name() + if field.many_to_many + else field.remote_field.field_name, + ) + if field.remote_field + else "", ) def _validate_username(self, username, verbose_field_name, database): """Validate username. If invalid, return a string error message.""" if self.username_field.unique: try: - self.UserModel._default_manager.db_manager(database).get_by_natural_key(username) + self.UserModel._default_manager.db_manager(database).get_by_natural_key( + username + ) except self.UserModel.DoesNotExist: pass else: - return 'Error: That %s is already taken.' % verbose_field_name + return "Error: That %s is already taken." % verbose_field_name if not username: - return '%s cannot be blank.' % capfirst(verbose_field_name) + return "%s cannot be blank." 
% capfirst(verbose_field_name) try: self.username_field.clean(username, None) except exceptions.ValidationError as e: - return '; '.join(e.messages) + return "; ".join(e.messages) diff --git a/django/contrib/auth/middleware.py b/django/contrib/auth/middleware.py index 89699a7f82..dcc482154c 100644 --- a/django/contrib/auth/middleware.py +++ b/django/contrib/auth/middleware.py @@ -7,14 +7,14 @@ from django.utils.functional import SimpleLazyObject def get_user(request): - if not hasattr(request, '_cached_user'): + if not hasattr(request, "_cached_user"): request._cached_user = auth.get_user(request) return request._cached_user class AuthenticationMiddleware(MiddlewareMixin): def process_request(self, request): - if not hasattr(request, 'session'): + if not hasattr(request, "session"): raise ImproperlyConfigured( "The Django authentication middleware requires session " "middleware to be installed. Edit your MIDDLEWARE setting to " @@ -47,13 +47,14 @@ class RemoteUserMiddleware(MiddlewareMixin): def process_request(self, request): # AuthenticationMiddleware is required so that request.user exists. - if not hasattr(request, 'user'): + if not hasattr(request, "user"): raise ImproperlyConfigured( "The Django remote user auth middleware requires the" " authentication middleware to be installed. Edit your" " MIDDLEWARE setting to insert" " 'django.contrib.auth.middleware.AuthenticationMiddleware'" - " before the RemoteUserMiddleware class.") + " before the RemoteUserMiddleware class." + ) try: username = request.META[self.header] except KeyError: @@ -102,7 +103,9 @@ class RemoteUserMiddleware(MiddlewareMixin): but only if the user is authenticated via the RemoteUserBackend. 
""" try: - stored_backend = load_backend(request.session.get(auth.BACKEND_SESSION_KEY, '')) + stored_backend = load_backend( + request.session.get(auth.BACKEND_SESSION_KEY, "") + ) except ImportError: # backend failed to load auth.logout(request) @@ -121,4 +124,5 @@ class PersistentRemoteUserMiddleware(RemoteUserMiddleware): is only expected to happen on some "logon" URL and the rest of the application wants to use Django's authentication mechanism. """ + force_logout_if_no_header = False diff --git a/django/contrib/auth/migrations/0001_initial.py b/django/contrib/auth/migrations/0001_initial.py index 166bc66dfb..87652308c1 100644 --- a/django/contrib/auth/migrations/0001_initial.py +++ b/django/contrib/auth/migrations/0001_initial.py @@ -7,97 +7,191 @@ from django.utils import timezone class Migration(migrations.Migration): dependencies = [ - ('contenttypes', '__first__'), + ("contenttypes", "__first__"), ] operations = [ migrations.CreateModel( - name='Permission', + name="Permission", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(max_length=50, verbose_name='name')), - ('content_type', models.ForeignKey( - to='contenttypes.ContentType', - on_delete=models.CASCADE, - verbose_name='content type', - )), - ('codename', models.CharField(max_length=100, verbose_name='codename')), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("name", models.CharField(max_length=50, verbose_name="name")), + ( + "content_type", + models.ForeignKey( + to="contenttypes.ContentType", + on_delete=models.CASCADE, + verbose_name="content type", + ), + ), + ("codename", models.CharField(max_length=100, verbose_name="codename")), ], options={ - 'ordering': ['content_type__app_label', 'content_type__model', 'codename'], - 'unique_together': {('content_type', 'codename')}, - 'verbose_name': 'permission', - 'verbose_name_plural': 
'permissions', + "ordering": [ + "content_type__app_label", + "content_type__model", + "codename", + ], + "unique_together": {("content_type", "codename")}, + "verbose_name": "permission", + "verbose_name_plural": "permissions", }, managers=[ - ('objects', django.contrib.auth.models.PermissionManager()), + ("objects", django.contrib.auth.models.PermissionManager()), ], ), migrations.CreateModel( - name='Group', + name="Group", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(unique=True, max_length=80, verbose_name='name')), - ('permissions', models.ManyToManyField(to='auth.Permission', verbose_name='permissions', blank=True)), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "name", + models.CharField(unique=True, max_length=80, verbose_name="name"), + ), + ( + "permissions", + models.ManyToManyField( + to="auth.Permission", verbose_name="permissions", blank=True + ), + ), ], options={ - 'verbose_name': 'group', - 'verbose_name_plural': 'groups', + "verbose_name": "group", + "verbose_name_plural": "groups", }, managers=[ - ('objects', django.contrib.auth.models.GroupManager()), + ("objects", django.contrib.auth.models.GroupManager()), ], ), migrations.CreateModel( - name='User', + name="User", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('password', models.CharField(max_length=128, verbose_name='password')), - ('last_login', models.DateTimeField(default=timezone.now, verbose_name='last login')), - ('is_superuser', models.BooleanField( - default=False, - help_text='Designates that this user has all permissions without explicitly assigning them.', - verbose_name='superuser status' - )), - ('username', models.CharField( - help_text='Required. 30 characters or fewer. 
Letters, digits and @/./+/-/_ only.', unique=True, - max_length=30, verbose_name='username', - validators=[validators.UnicodeUsernameValidator()], - )), - ('first_name', models.CharField(max_length=30, verbose_name='first name', blank=True)), - ('last_name', models.CharField(max_length=30, verbose_name='last name', blank=True)), - ('email', models.EmailField(max_length=75, verbose_name='email address', blank=True)), - ('is_staff', models.BooleanField( - default=False, help_text='Designates whether the user can log into this admin site.', - verbose_name='staff status' - )), - ('is_active', models.BooleanField( - default=True, verbose_name='active', help_text=( - 'Designates whether this user should be treated as active. Unselect this instead of deleting ' - 'accounts.' - ) - )), - ('date_joined', models.DateTimeField(default=timezone.now, verbose_name='date joined')), - ('groups', models.ManyToManyField( - to='auth.Group', verbose_name='groups', blank=True, related_name='user_set', - related_query_name='user', help_text=( - 'The groups this user belongs to. A user will get all permissions granted to each of their ' - 'groups.' - ) - )), - ('user_permissions', models.ManyToManyField( - to='auth.Permission', verbose_name='user permissions', blank=True, - help_text='Specific permissions for this user.', related_name='user_set', - related_query_name='user') - ), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("password", models.CharField(max_length=128, verbose_name="password")), + ( + "last_login", + models.DateTimeField( + default=timezone.now, verbose_name="last login" + ), + ), + ( + "is_superuser", + models.BooleanField( + default=False, + help_text="Designates that this user has all permissions without explicitly assigning them.", + verbose_name="superuser status", + ), + ), + ( + "username", + models.CharField( + help_text="Required. 30 characters or fewer. 
Letters, digits and @/./+/-/_ only.", + unique=True, + max_length=30, + verbose_name="username", + validators=[validators.UnicodeUsernameValidator()], + ), + ), + ( + "first_name", + models.CharField( + max_length=30, verbose_name="first name", blank=True + ), + ), + ( + "last_name", + models.CharField( + max_length=30, verbose_name="last name", blank=True + ), + ), + ( + "email", + models.EmailField( + max_length=75, verbose_name="email address", blank=True + ), + ), + ( + "is_staff", + models.BooleanField( + default=False, + help_text="Designates whether the user can log into this admin site.", + verbose_name="staff status", + ), + ), + ( + "is_active", + models.BooleanField( + default=True, + verbose_name="active", + help_text=( + "Designates whether this user should be treated as active. Unselect this instead of deleting " + "accounts." + ), + ), + ), + ( + "date_joined", + models.DateTimeField( + default=timezone.now, verbose_name="date joined" + ), + ), + ( + "groups", + models.ManyToManyField( + to="auth.Group", + verbose_name="groups", + blank=True, + related_name="user_set", + related_query_name="user", + help_text=( + "The groups this user belongs to. A user will get all permissions granted to each of their " + "groups." 
+ ), + ), + ), + ( + "user_permissions", + models.ManyToManyField( + to="auth.Permission", + verbose_name="user permissions", + blank=True, + help_text="Specific permissions for this user.", + related_name="user_set", + related_query_name="user", + ), + ), ], options={ - 'swappable': 'AUTH_USER_MODEL', - 'verbose_name': 'user', - 'verbose_name_plural': 'users', + "swappable": "AUTH_USER_MODEL", + "verbose_name": "user", + "verbose_name_plural": "users", }, managers=[ - ('objects', django.contrib.auth.models.UserManager()), + ("objects", django.contrib.auth.models.UserManager()), ], ), ] diff --git a/django/contrib/auth/migrations/0002_alter_permission_name_max_length.py b/django/contrib/auth/migrations/0002_alter_permission_name_max_length.py index 556c320409..a9ca6f51a9 100644 --- a/django/contrib/auth/migrations/0002_alter_permission_name_max_length.py +++ b/django/contrib/auth/migrations/0002_alter_permission_name_max_length.py @@ -4,13 +4,13 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0001_initial'), + ("auth", "0001_initial"), ] operations = [ migrations.AlterField( - model_name='permission', - name='name', - field=models.CharField(max_length=255, verbose_name='name'), + model_name="permission", + name="name", + field=models.CharField(max_length=255, verbose_name="name"), ), ] diff --git a/django/contrib/auth/migrations/0003_alter_user_email_max_length.py b/django/contrib/auth/migrations/0003_alter_user_email_max_length.py index ee8a9bd607..8a57548460 100644 --- a/django/contrib/auth/migrations/0003_alter_user_email_max_length.py +++ b/django/contrib/auth/migrations/0003_alter_user_email_max_length.py @@ -4,13 +4,15 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0002_alter_permission_name_max_length'), + ("auth", "0002_alter_permission_name_max_length"), ] operations = [ migrations.AlterField( - model_name='user', - name='email', 
- field=models.EmailField(max_length=254, verbose_name='email address', blank=True), + model_name="user", + name="email", + field=models.EmailField( + max_length=254, verbose_name="email address", blank=True + ), ), ] diff --git a/django/contrib/auth/migrations/0004_alter_user_username_opts.py b/django/contrib/auth/migrations/0004_alter_user_username_opts.py index a16083ee37..380737478d 100644 --- a/django/contrib/auth/migrations/0004_alter_user_username_opts.py +++ b/django/contrib/auth/migrations/0004_alter_user_username_opts.py @@ -5,19 +5,21 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0003_alter_user_email_max_length'), + ("auth", "0003_alter_user_email_max_length"), ] # No database changes; modifies validators and error_messages (#13147). operations = [ migrations.AlterField( - model_name='user', - name='username', + model_name="user", + name="username", field=models.CharField( - error_messages={'unique': 'A user with that username already exists.'}, max_length=30, + error_messages={"unique": "A user with that username already exists."}, + max_length=30, validators=[validators.UnicodeUsernameValidator()], - help_text='Required. 30 characters or fewer. Letters, digits and @/./+/-/_ only.', - unique=True, verbose_name='username' + help_text="Required. 30 characters or fewer. 
Letters, digits and @/./+/-/_ only.", + unique=True, + verbose_name="username", ), ), ] diff --git a/django/contrib/auth/migrations/0005_alter_user_last_login_null.py b/django/contrib/auth/migrations/0005_alter_user_last_login_null.py index 97cd105a0f..8407e2d822 100644 --- a/django/contrib/auth/migrations/0005_alter_user_last_login_null.py +++ b/django/contrib/auth/migrations/0005_alter_user_last_login_null.py @@ -4,13 +4,15 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0004_alter_user_username_opts'), + ("auth", "0004_alter_user_username_opts"), ] operations = [ migrations.AlterField( - model_name='user', - name='last_login', - field=models.DateTimeField(null=True, verbose_name='last login', blank=True), + model_name="user", + name="last_login", + field=models.DateTimeField( + null=True, verbose_name="last login", blank=True + ), ), ] diff --git a/django/contrib/auth/migrations/0006_require_contenttypes_0002.py b/django/contrib/auth/migrations/0006_require_contenttypes_0002.py index 48c26be011..b4e816a56f 100644 --- a/django/contrib/auth/migrations/0006_require_contenttypes_0002.py +++ b/django/contrib/auth/migrations/0006_require_contenttypes_0002.py @@ -4,8 +4,8 @@ from django.db import migrations class Migration(migrations.Migration): dependencies = [ - ('auth', '0005_alter_user_last_login_null'), - ('contenttypes', '0002_remove_content_type_name'), + ("auth", "0005_alter_user_last_login_null"), + ("contenttypes", "0002_remove_content_type_name"), ] operations = [ diff --git a/django/contrib/auth/migrations/0007_alter_validators_add_error_messages.py b/django/contrib/auth/migrations/0007_alter_validators_add_error_messages.py index 42f5087730..e0c835115e 100644 --- a/django/contrib/auth/migrations/0007_alter_validators_add_error_messages.py +++ b/django/contrib/auth/migrations/0007_alter_validators_add_error_messages.py @@ -5,20 +5,20 @@ from django.db import migrations, models class 
Migration(migrations.Migration): dependencies = [ - ('auth', '0006_require_contenttypes_0002'), + ("auth", "0006_require_contenttypes_0002"), ] operations = [ migrations.AlterField( - model_name='user', - name='username', + model_name="user", + name="username", field=models.CharField( - error_messages={'unique': 'A user with that username already exists.'}, - help_text='Required. 30 characters or fewer. Letters, digits and @/./+/-/_ only.', + error_messages={"unique": "A user with that username already exists."}, + help_text="Required. 30 characters or fewer. Letters, digits and @/./+/-/_ only.", max_length=30, unique=True, validators=[validators.UnicodeUsernameValidator()], - verbose_name='username', + verbose_name="username", ), ), ] diff --git a/django/contrib/auth/migrations/0008_alter_user_username_max_length.py b/django/contrib/auth/migrations/0008_alter_user_username_max_length.py index 7c9dae0950..4208dbcc7a 100644 --- a/django/contrib/auth/migrations/0008_alter_user_username_max_length.py +++ b/django/contrib/auth/migrations/0008_alter_user_username_max_length.py @@ -5,20 +5,20 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0007_alter_validators_add_error_messages'), + ("auth", "0007_alter_validators_add_error_messages"), ] operations = [ migrations.AlterField( - model_name='user', - name='username', + model_name="user", + name="username", field=models.CharField( - error_messages={'unique': 'A user with that username already exists.'}, - help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', + error_messages={"unique": "A user with that username already exists."}, + help_text="Required. 150 characters or fewer. 
Letters, digits and @/./+/-/_ only.", max_length=150, unique=True, validators=[validators.UnicodeUsernameValidator()], - verbose_name='username', + verbose_name="username", ), ), ] diff --git a/django/contrib/auth/migrations/0009_alter_user_last_name_max_length.py b/django/contrib/auth/migrations/0009_alter_user_last_name_max_length.py index b217359e48..e0665366d0 100644 --- a/django/contrib/auth/migrations/0009_alter_user_last_name_max_length.py +++ b/django/contrib/auth/migrations/0009_alter_user_last_name_max_length.py @@ -4,13 +4,15 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0008_alter_user_username_max_length'), + ("auth", "0008_alter_user_username_max_length"), ] operations = [ migrations.AlterField( - model_name='user', - name='last_name', - field=models.CharField(blank=True, max_length=150, verbose_name='last name'), + model_name="user", + name="last_name", + field=models.CharField( + blank=True, max_length=150, verbose_name="last name" + ), ), ] diff --git a/django/contrib/auth/migrations/0010_alter_group_name_max_length.py b/django/contrib/auth/migrations/0010_alter_group_name_max_length.py index 67ea0610ca..a58e11480f 100644 --- a/django/contrib/auth/migrations/0010_alter_group_name_max_length.py +++ b/django/contrib/auth/migrations/0010_alter_group_name_max_length.py @@ -4,13 +4,13 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0009_alter_user_last_name_max_length'), + ("auth", "0009_alter_user_last_name_max_length"), ] operations = [ migrations.AlterField( - model_name='group', - name='name', - field=models.CharField(max_length=150, unique=True, verbose_name='name'), + model_name="group", + name="name", + field=models.CharField(max_length=150, unique=True, verbose_name="name"), ), ] diff --git a/django/contrib/auth/migrations/0011_update_proxy_permissions.py 
b/django/contrib/auth/migrations/0011_update_proxy_permissions.py index a939244561..b792141d59 100644 --- a/django/contrib/auth/migrations/0011_update_proxy_permissions.py +++ b/django/contrib/auth/migrations/0011_update_proxy_permissions.py @@ -20,23 +20,26 @@ def update_proxy_model_permissions(apps, schema_editor, reverse=False): of the proxy model. """ style = color_style() - Permission = apps.get_model('auth', 'Permission') - ContentType = apps.get_model('contenttypes', 'ContentType') + Permission = apps.get_model("auth", "Permission") + ContentType = apps.get_model("contenttypes", "ContentType") alias = schema_editor.connection.alias for Model in apps.get_models(): opts = Model._meta if not opts.proxy: continue proxy_default_permissions_codenames = [ - '%s_%s' % (action, opts.model_name) - for action in opts.default_permissions + "%s_%s" % (action, opts.model_name) for action in opts.default_permissions ] permissions_query = Q(codename__in=proxy_default_permissions_codenames) for codename, name in opts.permissions: permissions_query = permissions_query | Q(codename=codename, name=name) content_type_manager = ContentType.objects.db_manager(alias) - concrete_content_type = content_type_manager.get_for_model(Model, for_concrete_model=True) - proxy_content_type = content_type_manager.get_for_model(Model, for_concrete_model=False) + concrete_content_type = content_type_manager.get_for_model( + Model, for_concrete_model=True + ) + proxy_content_type = content_type_manager.get_for_model( + Model, for_concrete_model=False + ) old_content_type = proxy_content_type if reverse else concrete_content_type new_content_type = concrete_content_type if reverse else proxy_content_type try: @@ -46,9 +49,11 @@ def update_proxy_model_permissions(apps, schema_editor, reverse=False): content_type=old_content_type, ).update(content_type=new_content_type) except IntegrityError: - old = '{}_{}'.format(old_content_type.app_label, old_content_type.model) - new = 
'{}_{}'.format(new_content_type.app_label, new_content_type.model) - sys.stdout.write(style.WARNING(WARNING.format(old=old, new=new, query=permissions_query))) + old = "{}_{}".format(old_content_type.app_label, old_content_type.model) + new = "{}_{}".format(new_content_type.app_label, new_content_type.model) + sys.stdout.write( + style.WARNING(WARNING.format(old=old, new=new, query=permissions_query)) + ) def revert_proxy_model_permissions(apps, schema_editor): @@ -61,9 +66,11 @@ def revert_proxy_model_permissions(apps, schema_editor): class Migration(migrations.Migration): dependencies = [ - ('auth', '0010_alter_group_name_max_length'), - ('contenttypes', '0002_remove_content_type_name'), + ("auth", "0010_alter_group_name_max_length"), + ("contenttypes", "0002_remove_content_type_name"), ] operations = [ - migrations.RunPython(update_proxy_model_permissions, revert_proxy_model_permissions), + migrations.RunPython( + update_proxy_model_permissions, revert_proxy_model_permissions + ), ] diff --git a/django/contrib/auth/migrations/0012_alter_user_first_name_max_length.py b/django/contrib/auth/migrations/0012_alter_user_first_name_max_length.py index 69f123ff32..839c950417 100644 --- a/django/contrib/auth/migrations/0012_alter_user_first_name_max_length.py +++ b/django/contrib/auth/migrations/0012_alter_user_first_name_max_length.py @@ -4,13 +4,15 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('auth', '0011_update_proxy_permissions'), + ("auth", "0011_update_proxy_permissions"), ] operations = [ migrations.AlterField( - model_name='user', - name='first_name', - field=models.CharField(blank=True, max_length=150, verbose_name='first name'), + model_name="user", + name="first_name", + field=models.CharField( + blank=True, max_length=150, verbose_name="first name" + ), ), ] diff --git a/django/contrib/auth/mixins.py b/django/contrib/auth/mixins.py index 02f9d23f07..09fd087f56 100644 --- 
a/django/contrib/auth/mixins.py +++ b/django/contrib/auth/mixins.py @@ -12,8 +12,9 @@ class AccessMixin: Abstract CBV mixin that gives access mixins the same customizable functionality. """ + login_url = None - permission_denied_message = '' + permission_denied_message = "" raise_exception = False redirect_field_name = REDIRECT_FIELD_NAME @@ -24,8 +25,8 @@ class AccessMixin: login_url = self.login_url or settings.LOGIN_URL if not login_url: raise ImproperlyConfigured( - '{0} is missing the login_url attribute. Define {0}.login_url, settings.LOGIN_URL, or override ' - '{0}.get_login_url().'.format(self.__class__.__name__) + "{0} is missing the login_url attribute. Define {0}.login_url, settings.LOGIN_URL, or override " + "{0}.get_login_url().".format(self.__class__.__name__) ) return str(login_url) @@ -51,9 +52,8 @@ class AccessMixin: # path as the "next" url. login_scheme, login_netloc = urlparse(resolved_login_url)[:2] current_scheme, current_netloc = urlparse(path)[:2] - if ( - (not login_scheme or login_scheme == current_scheme) and - (not login_netloc or login_netloc == current_netloc) + if (not login_scheme or login_scheme == current_scheme) and ( + not login_netloc or login_netloc == current_netloc ): path = self.request.get_full_path() return redirect_to_login( @@ -65,6 +65,7 @@ class AccessMixin: class LoginRequiredMixin(AccessMixin): """Verify that the current user is authenticated.""" + def dispatch(self, request, *args, **kwargs): if not request.user.is_authenticated: return self.handle_no_permission() @@ -73,6 +74,7 @@ class LoginRequiredMixin(AccessMixin): class PermissionRequiredMixin(AccessMixin): """Verify that the current user has all specified permissions.""" + permission_required = None def get_permission_required(self): @@ -82,8 +84,8 @@ class PermissionRequiredMixin(AccessMixin): """ if self.permission_required is None: raise ImproperlyConfigured( - '{0} is missing the permission_required attribute. 
Define {0}.permission_required, or override ' - '{0}.get_permission_required().'.format(self.__class__.__name__) + "{0} is missing the permission_required attribute. Define {0}.permission_required, or override " + "{0}.get_permission_required().".format(self.__class__.__name__) ) if isinstance(self.permission_required, str): perms = (self.permission_required,) @@ -112,7 +114,9 @@ class UserPassesTestMixin(AccessMixin): def test_func(self): raise NotImplementedError( - '{} is missing the implementation of the test_func() method.'.format(self.__class__.__name__) + "{} is missing the implementation of the test_func() method.".format( + self.__class__.__name__ + ) ) def get_test_func(self): diff --git a/django/contrib/auth/models.py b/django/contrib/auth/models.py index a9faef3517..5dd9c03a76 100644 --- a/django/contrib/auth/models.py +++ b/django/contrib/auth/models.py @@ -20,7 +20,7 @@ def update_last_login(sender, user, **kwargs): the user logging in. """ user.last_login = timezone.now() - user.save(update_fields=['last_login']) + user.save(update_fields=["last_login"]) class PermissionManager(models.Manager): @@ -29,7 +29,9 @@ class PermissionManager(models.Manager): def get_by_natural_key(self, codename, app_label, model): return self.get( codename=codename, - content_type=ContentType.objects.db_manager(self.db).get_by_natural_key(app_label, model), + content_type=ContentType.objects.db_manager(self.db).get_by_natural_key( + app_label, model + ), ) @@ -56,34 +58,37 @@ class Permission(models.Model): The permissions listed above are automatically created for each model. 
""" - name = models.CharField(_('name'), max_length=255) + + name = models.CharField(_("name"), max_length=255) content_type = models.ForeignKey( ContentType, models.CASCADE, - verbose_name=_('content type'), + verbose_name=_("content type"), ) - codename = models.CharField(_('codename'), max_length=100) + codename = models.CharField(_("codename"), max_length=100) objects = PermissionManager() class Meta: - verbose_name = _('permission') - verbose_name_plural = _('permissions') - unique_together = [['content_type', 'codename']] - ordering = ['content_type__app_label', 'content_type__model', 'codename'] + verbose_name = _("permission") + verbose_name_plural = _("permissions") + unique_together = [["content_type", "codename"]] + ordering = ["content_type__app_label", "content_type__model", "codename"] def __str__(self): - return '%s | %s' % (self.content_type, self.name) + return "%s | %s" % (self.content_type, self.name) def natural_key(self): return (self.codename,) + self.content_type.natural_key() - natural_key.dependencies = ['contenttypes.contenttype'] + + natural_key.dependencies = ["contenttypes.contenttype"] class GroupManager(models.Manager): """ The manager for the auth's Group model. """ + use_in_migrations = True def get_by_natural_key(self, name): @@ -107,18 +112,19 @@ class Group(models.Model): members-only portion of your site, or sending them members-only email messages. 
""" - name = models.CharField(_('name'), max_length=150, unique=True) + + name = models.CharField(_("name"), max_length=150, unique=True) permissions = models.ManyToManyField( Permission, - verbose_name=_('permissions'), + verbose_name=_("permissions"), blank=True, ) objects = GroupManager() class Meta: - verbose_name = _('group') - verbose_name_plural = _('groups') + verbose_name = _("group") + verbose_name_plural = _("groups") def __str__(self): return self.name @@ -135,12 +141,14 @@ class UserManager(BaseUserManager): Create and save a user with the given username, email, and password. """ if not username: - raise ValueError('The given username must be set') + raise ValueError("The given username must be set") email = self.normalize_email(email) # Lookup the real model class from the global app registry so this # manager method can be used in migrations. This is fine because # managers are by definition working on the real model. - GlobalUserModel = apps.get_model(self.model._meta.app_label, self.model._meta.object_name) + GlobalUserModel = apps.get_model( + self.model._meta.app_label, self.model._meta.object_name + ) username = GlobalUserModel.normalize_username(username) user = self.model(username=username, email=email, **extra_fields) user.password = make_password(password) @@ -148,39 +156,40 @@ class UserManager(BaseUserManager): return user def create_user(self, username, email=None, password=None, **extra_fields): - extra_fields.setdefault('is_staff', False) - extra_fields.setdefault('is_superuser', False) + extra_fields.setdefault("is_staff", False) + extra_fields.setdefault("is_superuser", False) return self._create_user(username, email, password, **extra_fields) def create_superuser(self, username, email=None, password=None, **extra_fields): - extra_fields.setdefault('is_staff', True) - extra_fields.setdefault('is_superuser', True) + extra_fields.setdefault("is_staff", True) + extra_fields.setdefault("is_superuser", True) - if 
extra_fields.get('is_staff') is not True: - raise ValueError('Superuser must have is_staff=True.') - if extra_fields.get('is_superuser') is not True: - raise ValueError('Superuser must have is_superuser=True.') + if extra_fields.get("is_staff") is not True: + raise ValueError("Superuser must have is_staff=True.") + if extra_fields.get("is_superuser") is not True: + raise ValueError("Superuser must have is_superuser=True.") return self._create_user(username, email, password, **extra_fields) - def with_perm(self, perm, is_active=True, include_superusers=True, backend=None, obj=None): + def with_perm( + self, perm, is_active=True, include_superusers=True, backend=None, obj=None + ): if backend is None: backends = auth._get_backends(return_tuples=True) if len(backends) == 1: backend, _ = backends[0] else: raise ValueError( - 'You have multiple authentication backends configured and ' - 'therefore must provide the `backend` argument.' + "You have multiple authentication backends configured and " + "therefore must provide the `backend` argument." ) elif not isinstance(backend, str): raise TypeError( - 'backend must be a dotted import path string (got %r).' - % backend + "backend must be a dotted import path string (got %r)." % backend ) else: backend = auth.load_backend(backend) - if hasattr(backend, 'with_perm'): + if hasattr(backend, "with_perm"): return backend.with_perm( perm, is_active=is_active, @@ -193,7 +202,7 @@ class UserManager(BaseUserManager): # A few helper functions for common logic between User and AnonymousUser. def _user_get_permissions(user, obj, from_name): permissions = set() - name = 'get_%s_permissions' % from_name + name = "get_%s_permissions" % from_name for backend in auth.get_backends(): if hasattr(backend, name): permissions.update(getattr(backend, name)(user, obj)) @@ -205,7 +214,7 @@ def _user_has_perm(user, perm, obj): A backend can raise `PermissionDenied` to short-circuit permission checking. 
""" for backend in auth.get_backends(): - if not hasattr(backend, 'has_perm'): + if not hasattr(backend, "has_perm"): continue try: if backend.has_perm(user, perm, obj): @@ -220,7 +229,7 @@ def _user_has_module_perms(user, app_label): A backend can raise `PermissionDenied` to short-circuit permission checking. """ for backend in auth.get_backends(): - if not hasattr(backend, 'has_module_perms'): + if not hasattr(backend, "has_module_perms"): continue try: if backend.has_module_perms(user, app_label): @@ -235,30 +244,31 @@ class PermissionsMixin(models.Model): Add the fields and methods necessary to support the Group and Permission models using the ModelBackend. """ + is_superuser = models.BooleanField( - _('superuser status'), + _("superuser status"), default=False, help_text=_( - 'Designates that this user has all permissions without ' - 'explicitly assigning them.' + "Designates that this user has all permissions without " + "explicitly assigning them." ), ) groups = models.ManyToManyField( Group, - verbose_name=_('groups'), + verbose_name=_("groups"), blank=True, help_text=_( - 'The groups this user belongs to. A user will get all permissions ' - 'granted to each of their groups.' + "The groups this user belongs to. A user will get all permissions " + "granted to each of their groups." ), related_name="user_set", related_query_name="user", ) user_permissions = models.ManyToManyField( Permission, - verbose_name=_('user permissions'), + verbose_name=_("user permissions"), blank=True, - help_text=_('Specific permissions for this user.'), + help_text=_("Specific permissions for this user."), related_name="user_set", related_query_name="user", ) @@ -272,7 +282,7 @@ class PermissionsMixin(models.Model): Query all available auth backends. If an object is passed in, return only permissions matching this object. 
""" - return _user_get_permissions(self, obj, 'user') + return _user_get_permissions(self, obj, "user") def get_group_permissions(self, obj=None): """ @@ -280,10 +290,10 @@ class PermissionsMixin(models.Model): groups. Query all available auth backends. If an object is passed in, return only permissions matching this object. """ - return _user_get_permissions(self, obj, 'group') + return _user_get_permissions(self, obj, "group") def get_all_permissions(self, obj=None): - return _user_get_permissions(self, obj, 'all') + return _user_get_permissions(self, obj, "all") def has_perm(self, perm, obj=None): """ @@ -306,7 +316,7 @@ class PermissionsMixin(models.Model): object is passed, check if the user has all required perms for it. """ if not is_iterable(perm_list) or isinstance(perm_list, str): - raise ValueError('perm_list must be an iterable of permissions.') + raise ValueError("perm_list must be an iterable of permissions.") return all(self.has_perm(perm, obj) for perm in perm_list) def has_module_perms(self, app_label): @@ -328,45 +338,48 @@ class AbstractUser(AbstractBaseUser, PermissionsMixin): Username and password are required. Other fields are optional. """ + username_validator = UnicodeUsernameValidator() username = models.CharField( - _('username'), + _("username"), max_length=150, unique=True, - help_text=_('Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.'), + help_text=_( + "Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only." 
+ ), validators=[username_validator], error_messages={ - 'unique': _("A user with that username already exists."), + "unique": _("A user with that username already exists."), }, ) - first_name = models.CharField(_('first name'), max_length=150, blank=True) - last_name = models.CharField(_('last name'), max_length=150, blank=True) - email = models.EmailField(_('email address'), blank=True) + first_name = models.CharField(_("first name"), max_length=150, blank=True) + last_name = models.CharField(_("last name"), max_length=150, blank=True) + email = models.EmailField(_("email address"), blank=True) is_staff = models.BooleanField( - _('staff status'), + _("staff status"), default=False, - help_text=_('Designates whether the user can log into this admin site.'), + help_text=_("Designates whether the user can log into this admin site."), ) is_active = models.BooleanField( - _('active'), + _("active"), default=True, help_text=_( - 'Designates whether this user should be treated as active. ' - 'Unselect this instead of deleting accounts.' + "Designates whether this user should be treated as active. " + "Unselect this instead of deleting accounts." ), ) - date_joined = models.DateTimeField(_('date joined'), default=timezone.now) + date_joined = models.DateTimeField(_("date joined"), default=timezone.now) objects = UserManager() - EMAIL_FIELD = 'email' - USERNAME_FIELD = 'username' - REQUIRED_FIELDS = ['email'] + EMAIL_FIELD = "email" + USERNAME_FIELD = "username" + REQUIRED_FIELDS = ["email"] class Meta: - verbose_name = _('user') - verbose_name_plural = _('users') + verbose_name = _("user") + verbose_name_plural = _("users") abstract = True def clean(self): @@ -377,7 +390,7 @@ class AbstractUser(AbstractBaseUser, PermissionsMixin): """ Return the first_name plus the last_name, with a space in between. 
""" - full_name = '%s %s' % (self.first_name, self.last_name) + full_name = "%s %s" % (self.first_name, self.last_name) return full_name.strip() def get_short_name(self): @@ -396,14 +409,15 @@ class User(AbstractUser): Username and password are required. Other fields are optional. """ + class Meta(AbstractUser.Meta): - swappable = 'AUTH_USER_MODEL' + swappable = "AUTH_USER_MODEL" class AnonymousUser: id = None pk = None - username = '' + username = "" is_staff = False is_active = False is_superuser = False @@ -411,7 +425,7 @@ class AnonymousUser: _user_permissions = EmptyManager(Permission) def __str__(self): - return 'AnonymousUser' + return "AnonymousUser" def __eq__(self, other): return isinstance(other, self.__class__) @@ -420,19 +434,29 @@ class AnonymousUser: return 1 # instances always return the same hash value def __int__(self): - raise TypeError('Cannot cast AnonymousUser to int. Are you trying to use it in place of User?') + raise TypeError( + "Cannot cast AnonymousUser to int. Are you trying to use it in place of User?" + ) def save(self): - raise NotImplementedError("Django doesn't provide a DB representation for AnonymousUser.") + raise NotImplementedError( + "Django doesn't provide a DB representation for AnonymousUser." + ) def delete(self): - raise NotImplementedError("Django doesn't provide a DB representation for AnonymousUser.") + raise NotImplementedError( + "Django doesn't provide a DB representation for AnonymousUser." + ) def set_password(self, raw_password): - raise NotImplementedError("Django doesn't provide a DB representation for AnonymousUser.") + raise NotImplementedError( + "Django doesn't provide a DB representation for AnonymousUser." + ) def check_password(self, raw_password): - raise NotImplementedError("Django doesn't provide a DB representation for AnonymousUser.") + raise NotImplementedError( + "Django doesn't provide a DB representation for AnonymousUser." 
+ ) @property def groups(self): @@ -443,20 +467,20 @@ class AnonymousUser: return self._user_permissions def get_user_permissions(self, obj=None): - return _user_get_permissions(self, obj, 'user') + return _user_get_permissions(self, obj, "user") def get_group_permissions(self, obj=None): return set() def get_all_permissions(self, obj=None): - return _user_get_permissions(self, obj, 'all') + return _user_get_permissions(self, obj, "all") def has_perm(self, perm, obj=None): return _user_has_perm(self, perm, obj=obj) def has_perms(self, perm_list, obj=None): if not is_iterable(perm_list) or isinstance(perm_list, str): - raise ValueError('perm_list must be an iterable of permissions.') + raise ValueError("perm_list must be an iterable of permissions.") return all(self.has_perm(perm, obj) for perm in perm_list) def has_module_perms(self, module): diff --git a/django/contrib/auth/password_validation.py b/django/contrib/auth/password_validation.py index 2e64bf1586..6af742eb6a 100644 --- a/django/contrib/auth/password_validation.py +++ b/django/contrib/auth/password_validation.py @@ -6,12 +6,15 @@ from pathlib import Path from django.conf import settings from django.core.exceptions import ( - FieldDoesNotExist, ImproperlyConfigured, ValidationError, + FieldDoesNotExist, + ImproperlyConfigured, + ValidationError, ) from django.utils.functional import cached_property, lazy from django.utils.html import format_html, format_html_join from django.utils.module_loading import import_string -from django.utils.translation import gettext as _, ngettext +from django.utils.translation import gettext as _ +from django.utils.translation import ngettext @functools.lru_cache(maxsize=None) @@ -23,11 +26,11 @@ def get_password_validators(validator_config): validators = [] for validator in validator_config: try: - klass = import_string(validator['NAME']) + klass = import_string(validator["NAME"]) except ImportError: msg = "The module in NAME could not be imported: %s. 
Check your AUTH_PASSWORD_VALIDATORS setting." - raise ImproperlyConfigured(msg % validator['NAME']) - validators.append(klass(**validator.get('OPTIONS', {}))) + raise ImproperlyConfigured(msg % validator["NAME"]) + validators.append(klass(**validator.get("OPTIONS", {}))) return validators @@ -59,7 +62,7 @@ def password_changed(password, user=None, password_validators=None): if password_validators is None: password_validators = get_default_password_validators() for validator in password_validators: - password_changed = getattr(validator, 'password_changed', lambda *a: None) + password_changed = getattr(validator, "password_changed", lambda *a: None) password_changed(password, user) @@ -81,8 +84,10 @@ def _password_validators_help_text_html(password_validators=None): in an <ul>. """ help_texts = password_validators_help_texts(password_validators) - help_items = format_html_join('', '<li>{}</li>', ((help_text,) for help_text in help_texts)) - return format_html('<ul>{}</ul>', help_items) if help_items else '' + help_items = format_html_join( + "", "<li>{}</li>", ((help_text,) for help_text in help_texts) + ) + return format_html("<ul>{}</ul>", help_items) if help_items else "" password_validators_help_text_html = lazy(_password_validators_help_text_html, str) @@ -92,6 +97,7 @@ class MinimumLengthValidator: """ Validate that the password is of a minimum length. """ + def __init__(self, min_length=8): self.min_length = min_length @@ -101,18 +107,18 @@ class MinimumLengthValidator: ngettext( "This password is too short. It must contain at least %(min_length)d character.", "This password is too short. 
It must contain at least %(min_length)d characters.", - self.min_length + self.min_length, ), - code='password_too_short', - params={'min_length': self.min_length}, + code="password_too_short", + params={"min_length": self.min_length}, ) def get_help_text(self): return ngettext( "Your password must contain at least %(min_length)d character.", "Your password must contain at least %(min_length)d characters.", - self.min_length - ) % {'min_length': self.min_length} + self.min_length, + ) % {"min_length": self.min_length} def exceeds_maximum_length_ratio(password, max_similarity, value): @@ -156,12 +162,13 @@ class UserAttributeSimilarityValidator: example, a password is validated against either part of an email address, as well as the full address. """ - DEFAULT_USER_ATTRIBUTES = ('username', 'first_name', 'last_name', 'email') + + DEFAULT_USER_ATTRIBUTES = ("username", "first_name", "last_name", "email") def __init__(self, user_attributes=DEFAULT_USER_ATTRIBUTES, max_similarity=0.7): self.user_attributes = user_attributes if max_similarity < 0.1: - raise ValueError('max_similarity must be at least 0.1') + raise ValueError("max_similarity must be at least 0.1") self.max_similarity = max_similarity def validate(self, password, user=None): @@ -174,23 +181,32 @@ class UserAttributeSimilarityValidator: if not value or not isinstance(value, str): continue value_lower = value.lower() - value_parts = re.split(r'\W+', value_lower) + [value_lower] + value_parts = re.split(r"\W+", value_lower) + [value_lower] for value_part in value_parts: - if exceeds_maximum_length_ratio(password, self.max_similarity, value_part): + if exceeds_maximum_length_ratio( + password, self.max_similarity, value_part + ): continue - if SequenceMatcher(a=password, b=value_part).quick_ratio() >= self.max_similarity: + if ( + SequenceMatcher(a=password, b=value_part).quick_ratio() + >= self.max_similarity + ): try: - verbose_name = str(user._meta.get_field(attribute_name).verbose_name) + verbose_name = 
str( + user._meta.get_field(attribute_name).verbose_name + ) except FieldDoesNotExist: verbose_name = attribute_name raise ValidationError( _("The password is too similar to the %(verbose_name)s."), - code='password_too_similar', - params={'verbose_name': verbose_name}, + code="password_too_similar", + params={"verbose_name": verbose_name}, ) def get_help_text(self): - return _('Your password can’t be too similar to your other personal information.') + return _( + "Your password can’t be too similar to your other personal information." + ) class CommonPasswordValidator: @@ -206,13 +222,13 @@ class CommonPasswordValidator: @cached_property def DEFAULT_PASSWORD_LIST_PATH(self): - return Path(__file__).resolve().parent / 'common-passwords.txt.gz' + return Path(__file__).resolve().parent / "common-passwords.txt.gz" def __init__(self, password_list_path=DEFAULT_PASSWORD_LIST_PATH): if password_list_path is CommonPasswordValidator.DEFAULT_PASSWORD_LIST_PATH: password_list_path = self.DEFAULT_PASSWORD_LIST_PATH try: - with gzip.open(password_list_path, 'rt', encoding='utf-8') as f: + with gzip.open(password_list_path, "rt", encoding="utf-8") as f: self.passwords = {x.strip() for x in f} except OSError: with open(password_list_path) as f: @@ -222,23 +238,24 @@ class CommonPasswordValidator: if password.lower().strip() in self.passwords: raise ValidationError( _("This password is too common."), - code='password_too_common', + code="password_too_common", ) def get_help_text(self): - return _('Your password can’t be a commonly used password.') + return _("Your password can’t be a commonly used password.") class NumericPasswordValidator: """ Validate that the password is not entirely numeric. 
""" + def validate(self, password, user=None): if password.isdigit(): raise ValidationError( _("This password is entirely numeric."), - code='password_entirely_numeric', + code="password_entirely_numeric", ) def get_help_text(self): - return _('Your password can’t be entirely numeric.') + return _("Your password can’t be entirely numeric.") diff --git a/django/contrib/auth/tokens.py b/django/contrib/auth/tokens.py index 97ecd06bc8..09cc2b5195 100644 --- a/django/contrib/auth/tokens.py +++ b/django/contrib/auth/tokens.py @@ -10,13 +10,14 @@ class PasswordResetTokenGenerator: Strategy object used to generate and check tokens for the password reset mechanism. """ + key_salt = "django.contrib.auth.tokens.PasswordResetTokenGenerator" algorithm = None _secret = None _secret_fallbacks = None def __init__(self): - self.algorithm = self.algorithm or 'sha256' + self.algorithm = self.algorithm or "sha256" def _get_secret(self): return self._secret or settings.SECRET_KEY @@ -89,7 +90,9 @@ class PasswordResetTokenGenerator: self._make_hash_value(user, timestamp), secret=secret, algorithm=self.algorithm, - ).hexdigest()[::2] # Limit to shorten the URL. + ).hexdigest()[ + ::2 + ] # Limit to shorten the URL. return "%s-%s" % (ts_b36, hash_string) def _make_hash_value(self, user, timestamp): @@ -109,10 +112,14 @@ class PasswordResetTokenGenerator: """ # Truncate microseconds so that tokens are consistent even if the # database doesn't support microseconds. 
- login_timestamp = '' if user.last_login is None else user.last_login.replace(microsecond=0, tzinfo=None) + login_timestamp = ( + "" + if user.last_login is None + else user.last_login.replace(microsecond=0, tzinfo=None) + ) email_field = user.get_email_field_name() - email = getattr(user, email_field, '') or '' - return f'{user.pk}{user.password}{login_timestamp}{timestamp}{email}' + email = getattr(user, email_field, "") or "" + return f"{user.pk}{user.password}{login_timestamp}{timestamp}{email}" def _num_seconds(self, dt): return int((dt - datetime(2001, 1, 1)).total_seconds()) diff --git a/django/contrib/auth/urls.py b/django/contrib/auth/urls.py index cd4af69a55..699ba6179a 100644 --- a/django/contrib/auth/urls.py +++ b/django/contrib/auth/urls.py @@ -7,14 +7,30 @@ from django.contrib.auth import views from django.urls import path urlpatterns = [ - path('login/', views.LoginView.as_view(), name='login'), - path('logout/', views.LogoutView.as_view(), name='logout'), - - path('password_change/', views.PasswordChangeView.as_view(), name='password_change'), - path('password_change/done/', views.PasswordChangeDoneView.as_view(), name='password_change_done'), - - path('password_reset/', views.PasswordResetView.as_view(), name='password_reset'), - path('password_reset/done/', views.PasswordResetDoneView.as_view(), name='password_reset_done'), - path('reset/<uidb64>/<token>/', views.PasswordResetConfirmView.as_view(), name='password_reset_confirm'), - path('reset/done/', views.PasswordResetCompleteView.as_view(), name='password_reset_complete'), + path("login/", views.LoginView.as_view(), name="login"), + path("logout/", views.LogoutView.as_view(), name="logout"), + path( + "password_change/", views.PasswordChangeView.as_view(), name="password_change" + ), + path( + "password_change/done/", + views.PasswordChangeDoneView.as_view(), + name="password_change_done", + ), + path("password_reset/", views.PasswordResetView.as_view(), name="password_reset"), + path( + 
"password_reset/done/", + views.PasswordResetDoneView.as_view(), + name="password_reset_done", + ), + path( + "reset/<uidb64>/<token>/", + views.PasswordResetConfirmView.as_view(), + name="password_reset_confirm", + ), + path( + "reset/done/", + views.PasswordResetCompleteView.as_view(), + name="password_reset_complete", + ), ] diff --git a/django/contrib/auth/validators.py b/django/contrib/auth/validators.py index 9345c5cb0b..55f70283cc 100644 --- a/django/contrib/auth/validators.py +++ b/django/contrib/auth/validators.py @@ -7,19 +7,19 @@ from django.utils.translation import gettext_lazy as _ @deconstructible class ASCIIUsernameValidator(validators.RegexValidator): - regex = r'^[\w.@+-]+\Z' + regex = r"^[\w.@+-]+\Z" message = _( - 'Enter a valid username. This value may contain only English letters, ' - 'numbers, and @/./+/-/_ characters.' + "Enter a valid username. This value may contain only English letters, " + "numbers, and @/./+/-/_ characters." ) flags = re.ASCII @deconstructible class UnicodeUsernameValidator(validators.RegexValidator): - regex = r'^[\w.@+-]+\Z' + regex = r"^[\w.@+-]+\Z" message = _( - 'Enter a valid username. This value may contain only letters, ' - 'numbers, and @/./+/-/_ characters.' + "Enter a valid username. This value may contain only letters, " + "numbers, and @/./+/-/_ characters." ) flags = 0 diff --git a/django/contrib/auth/views.py b/django/contrib/auth/views.py index 8790076dc4..6de053f492 100644 --- a/django/contrib/auth/views.py +++ b/django/contrib/auth/views.py @@ -1,14 +1,18 @@ from urllib.parse import urlparse, urlunparse from django.conf import settings + # Avoid shadowing the login() and logout() views below. 
-from django.contrib.auth import ( - REDIRECT_FIELD_NAME, get_user_model, login as auth_login, - logout as auth_logout, update_session_auth_hash, -) +from django.contrib.auth import REDIRECT_FIELD_NAME, get_user_model +from django.contrib.auth import login as auth_login +from django.contrib.auth import logout as auth_logout +from django.contrib.auth import update_session_auth_hash from django.contrib.auth.decorators import login_required from django.contrib.auth.forms import ( - AuthenticationForm, PasswordChangeForm, PasswordResetForm, SetPasswordForm, + AuthenticationForm, + PasswordChangeForm, + PasswordResetForm, + SetPasswordForm, ) from django.contrib.auth.tokens import default_token_generator from django.contrib.sites.shortcuts import get_current_site @@ -17,9 +21,7 @@ from django.http import HttpResponseRedirect, QueryDict from django.shortcuts import resolve_url from django.urls import reverse_lazy from django.utils.decorators import method_decorator -from django.utils.http import ( - url_has_allowed_host_and_scheme, urlsafe_base64_decode, -) +from django.utils.http import url_has_allowed_host_and_scheme, urlsafe_base64_decode from django.utils.translation import gettext_lazy as _ from django.views.decorators.cache import never_cache from django.views.decorators.csrf import csrf_protect @@ -41,11 +43,12 @@ class LoginView(SuccessURLAllowedHostsMixin, FormView): """ Display the login form and handle the login action. 
""" + form_class = AuthenticationForm authentication_form = None next_page = None redirect_field_name = REDIRECT_FIELD_NAME - template_name = 'registration/login.html' + template_name = "registration/login.html" redirect_authenticated_user = False extra_context = None @@ -69,15 +72,14 @@ class LoginView(SuccessURLAllowedHostsMixin, FormView): def get_redirect_url(self): """Return the user-originating redirect URL if it's safe.""" redirect_to = self.request.POST.get( - self.redirect_field_name, - self.request.GET.get(self.redirect_field_name, '') + self.redirect_field_name, self.request.GET.get(self.redirect_field_name, "") ) url_is_safe = url_has_allowed_host_and_scheme( url=redirect_to, allowed_hosts=self.get_success_url_allowed_hosts(), require_https=self.request.is_secure(), ) - return redirect_to if url_is_safe else '' + return redirect_to if url_is_safe else "" def get_default_redirect_url(self): """Return the default redirect URL.""" @@ -88,7 +90,7 @@ class LoginView(SuccessURLAllowedHostsMixin, FormView): def get_form_kwargs(self): kwargs = super().get_form_kwargs() - kwargs['request'] = self.request + kwargs["request"] = self.request return kwargs def form_valid(self, form): @@ -99,12 +101,14 @@ class LoginView(SuccessURLAllowedHostsMixin, FormView): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) current_site = get_current_site(self.request) - context.update({ - self.redirect_field_name: self.get_redirect_url(), - 'site': current_site, - 'site_name': current_site.name, - **(self.extra_context or {}) - }) + context.update( + { + self.redirect_field_name: self.get_redirect_url(), + "site": current_site, + "site_name": current_site.name, + **(self.extra_context or {}), + } + ) return context @@ -112,9 +116,10 @@ class LogoutView(SuccessURLAllowedHostsMixin, TemplateView): """ Log out the user and display the 'You are logged out' message. 
""" + next_page = None redirect_field_name = REDIRECT_FIELD_NAME - template_name = 'registration/logged_out.html' + template_name = "registration/logged_out.html" extra_context = None @method_decorator(never_cache) @@ -138,11 +143,12 @@ class LogoutView(SuccessURLAllowedHostsMixin, TemplateView): else: next_page = self.next_page - if (self.redirect_field_name in self.request.POST or - self.redirect_field_name in self.request.GET): + if ( + self.redirect_field_name in self.request.POST + or self.redirect_field_name in self.request.GET + ): next_page = self.request.POST.get( - self.redirect_field_name, - self.request.GET.get(self.redirect_field_name) + self.redirect_field_name, self.request.GET.get(self.redirect_field_name) ) url_is_safe = url_has_allowed_host_and_scheme( url=next_page, @@ -158,13 +164,15 @@ class LogoutView(SuccessURLAllowedHostsMixin, TemplateView): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) current_site = get_current_site(self.request) - context.update({ - 'site': current_site, - 'site_name': current_site.name, - 'title': _('Logged out'), - 'subtitle': None, - **(self.extra_context or {}) - }) + context.update( + { + "site": current_site, + "site_name": current_site.name, + "title": _("Logged out"), + "subtitle": None, + **(self.extra_context or {}), + } + ) return context @@ -186,7 +194,7 @@ def redirect_to_login(next, login_url=None, redirect_field_name=REDIRECT_FIELD_N if redirect_field_name: querystring = QueryDict(login_url_parts[4], mutable=True) querystring[redirect_field_name] = next - login_url_parts[4] = querystring.urlencode(safe='/') + login_url_parts[4] = querystring.urlencode(safe="/") return HttpResponseRedirect(urlunparse(login_url_parts)) @@ -198,29 +206,28 @@ def redirect_to_login(next, login_url=None, redirect_field_name=REDIRECT_FIELD_N # prompts for a new password # - PasswordResetCompleteView shows a success message for the above + class PasswordContextMixin: extra_context = None def 
get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - context.update({ - 'title': self.title, - 'subtitle': None, - **(self.extra_context or {}) - }) + context.update( + {"title": self.title, "subtitle": None, **(self.extra_context or {})} + ) return context class PasswordResetView(PasswordContextMixin, FormView): - email_template_name = 'registration/password_reset_email.html' + email_template_name = "registration/password_reset_email.html" extra_email_context = None form_class = PasswordResetForm from_email = None html_email_template_name = None - subject_template_name = 'registration/password_reset_subject.txt' - success_url = reverse_lazy('password_reset_done') - template_name = 'registration/password_reset_form.html' - title = _('Password reset') + subject_template_name = "registration/password_reset_subject.txt" + success_url = reverse_lazy("password_reset_done") + template_name = "registration/password_reset_form.html" + title = _("Password reset") token_generator = default_token_generator @method_decorator(csrf_protect) @@ -229,50 +236,50 @@ class PasswordResetView(PasswordContextMixin, FormView): def form_valid(self, form): opts = { - 'use_https': self.request.is_secure(), - 'token_generator': self.token_generator, - 'from_email': self.from_email, - 'email_template_name': self.email_template_name, - 'subject_template_name': self.subject_template_name, - 'request': self.request, - 'html_email_template_name': self.html_email_template_name, - 'extra_email_context': self.extra_email_context, + "use_https": self.request.is_secure(), + "token_generator": self.token_generator, + "from_email": self.from_email, + "email_template_name": self.email_template_name, + "subject_template_name": self.subject_template_name, + "request": self.request, + "html_email_template_name": self.html_email_template_name, + "extra_email_context": self.extra_email_context, } form.save(**opts) return super().form_valid(form) -INTERNAL_RESET_SESSION_TOKEN = 
'_password_reset_token' +INTERNAL_RESET_SESSION_TOKEN = "_password_reset_token" class PasswordResetDoneView(PasswordContextMixin, TemplateView): - template_name = 'registration/password_reset_done.html' - title = _('Password reset sent') + template_name = "registration/password_reset_done.html" + title = _("Password reset sent") class PasswordResetConfirmView(PasswordContextMixin, FormView): form_class = SetPasswordForm post_reset_login = False post_reset_login_backend = None - reset_url_token = 'set-password' - success_url = reverse_lazy('password_reset_complete') - template_name = 'registration/password_reset_confirm.html' - title = _('Enter new password') + reset_url_token = "set-password" + success_url = reverse_lazy("password_reset_complete") + template_name = "registration/password_reset_confirm.html" + title = _("Enter new password") token_generator = default_token_generator @method_decorator(sensitive_post_parameters()) @method_decorator(never_cache) def dispatch(self, *args, **kwargs): - if 'uidb64' not in kwargs or 'token' not in kwargs: + if "uidb64" not in kwargs or "token" not in kwargs: raise ImproperlyConfigured( "The URL path must contain 'uidb64' and 'token' parameters." ) self.validlink = False - self.user = self.get_user(kwargs['uidb64']) + self.user = self.get_user(kwargs["uidb64"]) if self.user is not None: - token = kwargs['token'] + token = kwargs["token"] if token == self.reset_url_token: session_token = self.request.session.get(INTERNAL_RESET_SESSION_TOKEN) if self.token_generator.check_token(self.user, session_token): @@ -286,7 +293,9 @@ class PasswordResetConfirmView(PasswordContextMixin, FormView): # avoids the possibility of leaking the token in the # HTTP Referer header. 
self.request.session[INTERNAL_RESET_SESSION_TOKEN] = token - redirect_url = self.request.path.replace(token, self.reset_url_token) + redirect_url = self.request.path.replace( + token, self.reset_url_token + ) return HttpResponseRedirect(redirect_url) # Display the "Password reset unsuccessful" page. @@ -297,13 +306,19 @@ class PasswordResetConfirmView(PasswordContextMixin, FormView): # urlsafe_base64_decode() decodes to bytestring uid = urlsafe_base64_decode(uidb64).decode() user = UserModel._default_manager.get(pk=uid) - except (TypeError, ValueError, OverflowError, UserModel.DoesNotExist, ValidationError): + except ( + TypeError, + ValueError, + OverflowError, + UserModel.DoesNotExist, + ValidationError, + ): user = None return user def get_form_kwargs(self): kwargs = super().get_form_kwargs() - kwargs['user'] = self.user + kwargs["user"] = self.user return kwargs def form_valid(self, form): @@ -316,31 +331,33 @@ class PasswordResetConfirmView(PasswordContextMixin, FormView): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) if self.validlink: - context['validlink'] = True + context["validlink"] = True else: - context.update({ - 'form': None, - 'title': _('Password reset unsuccessful'), - 'validlink': False, - }) + context.update( + { + "form": None, + "title": _("Password reset unsuccessful"), + "validlink": False, + } + ) return context class PasswordResetCompleteView(PasswordContextMixin, TemplateView): - template_name = 'registration/password_reset_complete.html' - title = _('Password reset complete') + template_name = "registration/password_reset_complete.html" + title = _("Password reset complete") def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - context['login_url'] = resolve_url(settings.LOGIN_URL) + context["login_url"] = resolve_url(settings.LOGIN_URL) return context class PasswordChangeView(PasswordContextMixin, FormView): form_class = PasswordChangeForm - success_url = 
reverse_lazy('password_change_done') - template_name = 'registration/password_change_form.html' - title = _('Password change') + success_url = reverse_lazy("password_change_done") + template_name = "registration/password_change_form.html" + title = _("Password change") @method_decorator(sensitive_post_parameters()) @method_decorator(csrf_protect) @@ -350,7 +367,7 @@ class PasswordChangeView(PasswordContextMixin, FormView): def get_form_kwargs(self): kwargs = super().get_form_kwargs() - kwargs['user'] = self.request.user + kwargs["user"] = self.request.user return kwargs def form_valid(self, form): @@ -362,8 +379,8 @@ class PasswordChangeView(PasswordContextMixin, FormView): class PasswordChangeDoneView(PasswordContextMixin, TemplateView): - template_name = 'registration/password_change_done.html' - title = _('Password change successful') + template_name = "registration/password_change_done.html" + title = _("Password change successful") @method_decorator(login_required) def dispatch(self, *args, **kwargs): diff --git a/django/contrib/contenttypes/admin.py b/django/contrib/contenttypes/admin.py index dd3e77c4c2..6f442549f1 100644 --- a/django/contrib/contenttypes/admin.py +++ b/django/contrib/contenttypes/admin.py @@ -4,7 +4,8 @@ from django.contrib.admin.checks import InlineModelAdminChecks from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.forms import ( - BaseGenericInlineFormSet, generic_inlineformset_factory, + BaseGenericInlineFormSet, + generic_inlineformset_factory, ) from django.core import checks from django.core.exceptions import FieldDoesNotExist @@ -22,7 +23,8 @@ class GenericInlineModelAdminChecks(InlineModelAdminChecks): # and that they are part of a GenericForeignKey. 
gfks = [ - f for f in obj.model._meta.private_fields + f + for f in obj.model._meta.private_fields if isinstance(f, GenericForeignKey) ] if not gfks: @@ -30,7 +32,7 @@ class GenericInlineModelAdminChecks(InlineModelAdminChecks): checks.Error( "'%s' has no GenericForeignKey." % obj.model._meta.label, obj=obj.__class__, - id='admin.E301' + id="admin.E301", ) ] else: @@ -40,11 +42,13 @@ class GenericInlineModelAdminChecks(InlineModelAdminChecks): except FieldDoesNotExist: return [ checks.Error( - "'ct_field' references '%s', which is not a field on '%s'." % ( - obj.ct_field, obj.model._meta.label, + "'ct_field' references '%s', which is not a field on '%s'." + % ( + obj.ct_field, + obj.model._meta.label, ), obj=obj.__class__, - id='admin.E302' + id="admin.E302", ) ] @@ -53,11 +57,13 @@ class GenericInlineModelAdminChecks(InlineModelAdminChecks): except FieldDoesNotExist: return [ checks.Error( - "'ct_fk_field' references '%s', which is not a field on '%s'." % ( - obj.ct_fk_field, obj.model._meta.label, + "'ct_fk_field' references '%s', which is not a field on '%s'." + % ( + obj.ct_fk_field, + obj.model._meta.label, ), obj=obj.__class__, - id='admin.E303' + id="admin.E303", ) ] @@ -69,11 +75,14 @@ class GenericInlineModelAdminChecks(InlineModelAdminChecks): return [ checks.Error( - "'%s' has no GenericForeignKey using content type field '%s' and object ID field '%s'." % ( - obj.model._meta.label, obj.ct_field, obj.ct_fk_field, + "'%s' has no GenericForeignKey using content type field '%s' and object ID field '%s'." 
+ % ( + obj.model._meta.label, + obj.ct_field, + obj.ct_fk_field, ), obj=obj.__class__, - id='admin.E304' + id="admin.E304", ) ] @@ -86,42 +95,48 @@ class GenericInlineModelAdmin(InlineModelAdmin): checks_class = GenericInlineModelAdminChecks def get_formset(self, request, obj=None, **kwargs): - if 'fields' in kwargs: - fields = kwargs.pop('fields') + if "fields" in kwargs: + fields = kwargs.pop("fields") else: fields = flatten_fieldsets(self.get_fieldsets(request, obj)) exclude = [*(self.exclude or []), *self.get_readonly_fields(request, obj)] - if self.exclude is None and hasattr(self.form, '_meta') and self.form._meta.exclude: + if ( + self.exclude is None + and hasattr(self.form, "_meta") + and self.form._meta.exclude + ): # Take the custom ModelForm's Meta.exclude into account only if the # GenericInlineModelAdmin doesn't define its own. exclude.extend(self.form._meta.exclude) exclude = exclude or None can_delete = self.can_delete and self.has_delete_permission(request, obj) defaults = { - 'ct_field': self.ct_field, - 'fk_field': self.ct_fk_field, - 'form': self.form, - 'formfield_callback': partial(self.formfield_for_dbfield, request=request), - 'formset': self.formset, - 'extra': self.get_extra(request, obj), - 'can_delete': can_delete, - 'can_order': False, - 'fields': fields, - 'min_num': self.get_min_num(request, obj), - 'max_num': self.get_max_num(request, obj), - 'exclude': exclude, + "ct_field": self.ct_field, + "fk_field": self.ct_fk_field, + "form": self.form, + "formfield_callback": partial(self.formfield_for_dbfield, request=request), + "formset": self.formset, + "extra": self.get_extra(request, obj), + "can_delete": can_delete, + "can_order": False, + "fields": fields, + "min_num": self.get_min_num(request, obj), + "max_num": self.get_max_num(request, obj), + "exclude": exclude, **kwargs, } - if defaults['fields'] is None and not modelform_defines_fields(defaults['form']): - defaults['fields'] = ALL_FIELDS + if defaults["fields"] is None and not 
modelform_defines_fields( + defaults["form"] + ): + defaults["fields"] = ALL_FIELDS return generic_inlineformset_factory(self.model, **defaults) class GenericStackedInline(GenericInlineModelAdmin): - template = 'admin/edit_inline/stacked.html' + template = "admin/edit_inline/stacked.html" class GenericTabularInline(GenericInlineModelAdmin): - template = 'admin/edit_inline/tabular.html' + template = "admin/edit_inline/tabular.html" diff --git a/django/contrib/contenttypes/apps.py b/django/contrib/contenttypes/apps.py index 390afb3fcf..11dfb91010 100644 --- a/django/contrib/contenttypes/apps.py +++ b/django/contrib/contenttypes/apps.py @@ -1,19 +1,18 @@ from django.apps import AppConfig from django.contrib.contenttypes.checks import ( - check_generic_foreign_keys, check_model_name_lengths, + check_generic_foreign_keys, + check_model_name_lengths, ) from django.core import checks from django.db.models.signals import post_migrate, pre_migrate from django.utils.translation import gettext_lazy as _ -from .management import ( - create_contenttypes, inject_rename_contenttypes_operations, -) +from .management import create_contenttypes, inject_rename_contenttypes_operations class ContentTypesConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.contenttypes' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.contenttypes" verbose_name = _("Content Types") def ready(self): diff --git a/django/contrib/contenttypes/checks.py b/django/contrib/contenttypes/checks.py index 3e802ea26b..753c5d22f8 100644 --- a/django/contrib/contenttypes/checks.py +++ b/django/contrib/contenttypes/checks.py @@ -10,10 +10,14 @@ def check_generic_foreign_keys(app_configs=None, **kwargs): if app_configs is None: models = apps.get_models() else: - models = chain.from_iterable(app_config.get_models() for app_config in app_configs) + models = chain.from_iterable( + app_config.get_models() for app_config in app_configs + ) errors = [] 
fields = ( - obj for model in models for obj in vars(model).values() + obj + for model in models + for obj in vars(model).values() if isinstance(obj, GenericForeignKey) ) for field in fields: @@ -25,17 +29,18 @@ def check_model_name_lengths(app_configs=None, **kwargs): if app_configs is None: models = apps.get_models() else: - models = chain.from_iterable(app_config.get_models() for app_config in app_configs) + models = chain.from_iterable( + app_config.get_models() for app_config in app_configs + ) errors = [] for model in models: if len(model._meta.model_name) > 100: errors.append( Error( - 'Model names must be at most 100 characters (got %d).' % ( - len(model._meta.model_name), - ), + "Model names must be at most 100 characters (got %d)." + % (len(model._meta.model_name),), obj=model, - id='contenttypes.E005', + id="contenttypes.E005", ) ) return errors diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 278b5bb585..193c2f2687 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -10,7 +10,8 @@ from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel from django.db.models.base import ModelBase, make_foreign_order_accessors from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.related import ( - ReverseManyToOneDescriptor, lazy_related_operation, + ReverseManyToOneDescriptor, + lazy_related_operation, ) from django.db.models.query_utils import PathInfo from django.db.models.sql import AND @@ -41,7 +42,9 @@ class GenericForeignKey(FieldCacheMixin): related_model = None remote_field = None - def __init__(self, ct_field='content_type', fk_field='object_id', for_concrete_model=True): + def __init__( + self, ct_field="content_type", fk_field="object_id", for_concrete_model=True + ): self.ct_field = ct_field self.fk_field = fk_field self.for_concrete_model = for_concrete_model @@ -71,7 +74,7 @@ class 
GenericForeignKey(FieldCacheMixin): def __str__(self): model = self.model - return '%s.%s' % (model._meta.label, self.name) + return "%s.%s" % (model._meta.label, self.name) def check(self, **kwargs): return [ @@ -84,9 +87,9 @@ class GenericForeignKey(FieldCacheMixin): if self.name.endswith("_"): return [ checks.Error( - 'Field names must not end with an underscore.', + "Field names must not end with an underscore.", obj=self, - id='fields.E001', + id="fields.E001", ) ] else: @@ -101,7 +104,7 @@ class GenericForeignKey(FieldCacheMixin): "The GenericForeignKey object ID references the " "nonexistent field '%s'." % self.fk_field, obj=self, - id='contenttypes.E001', + id="contenttypes.E001", ) ] else: @@ -118,40 +121,37 @@ class GenericForeignKey(FieldCacheMixin): return [ checks.Error( "The GenericForeignKey content type references the " - "nonexistent field '%s.%s'." % ( - self.model._meta.object_name, self.ct_field - ), + "nonexistent field '%s.%s'." + % (self.model._meta.object_name, self.ct_field), obj=self, - id='contenttypes.E002', + id="contenttypes.E002", ) ] else: if not isinstance(field, models.ForeignKey): return [ checks.Error( - "'%s.%s' is not a ForeignKey." % ( - self.model._meta.object_name, self.ct_field - ), + "'%s.%s' is not a ForeignKey." + % (self.model._meta.object_name, self.ct_field), hint=( "GenericForeignKeys must use a ForeignKey to " "'contenttypes.ContentType' as the 'content_type' field." ), obj=self, - id='contenttypes.E003', + id="contenttypes.E003", ) ] elif field.remote_field.model != ContentType: return [ checks.Error( - "'%s.%s' is not a ForeignKey to 'contenttypes.ContentType'." % ( - self.model._meta.object_name, self.ct_field - ), + "'%s.%s' is not a ForeignKey to 'contenttypes.ContentType'." + % (self.model._meta.object_name, self.ct_field), hint=( "GenericForeignKeys must use a ForeignKey to " "'contenttypes.ContentType' as the 'content_type' field." 
), obj=self, - id='contenttypes.E004', + id="contenttypes.E004", ) ] else: @@ -163,7 +163,8 @@ class GenericForeignKey(FieldCacheMixin): def get_content_type(self, obj=None, id=None, using=None): if obj is not None: return ContentType.objects.db_manager(obj._state.db).get_for_model( - obj, for_concrete_model=self.for_concrete_model) + obj, for_concrete_model=self.for_concrete_model + ) elif id is not None: return ContentType.objects.db_manager(using).get_for_id(id) else: @@ -202,10 +203,13 @@ class GenericForeignKey(FieldCacheMixin): if ct_id is None: return None else: - model = self.get_content_type(id=ct_id, - using=obj._state.db).model_class() - return (model._meta.pk.get_prep_value(getattr(obj, self.fk_field)), - model) + model = self.get_content_type( + id=ct_id, using=obj._state.db + ).model_class() + return ( + model._meta.pk.get_prep_value(getattr(obj, self.fk_field)), + model, + ) return ( ret_val, @@ -232,7 +236,9 @@ class GenericForeignKey(FieldCacheMixin): if rel_obj is None and self.is_cached(instance): return rel_obj if rel_obj is not None: - ct_match = ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id + ct_match = ( + ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id + ) pk_match = rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk if ct_match and pk_match: return rel_obj @@ -264,11 +270,21 @@ class GenericRel(ForeignObjectRel): Used by GenericRelation to store information about the relation. 
""" - def __init__(self, field, to, related_name=None, related_query_name=None, limit_choices_to=None): + def __init__( + self, + field, + to, + related_name=None, + related_query_name=None, + limit_choices_to=None, + ): super().__init__( - field, to, related_name=related_query_name or '+', + field, + to, + related_name=related_query_name or "+", related_query_name=related_query_name, - limit_choices_to=limit_choices_to, on_delete=DO_NOTHING, + limit_choices_to=limit_choices_to, + on_delete=DO_NOTHING, ) @@ -290,21 +306,30 @@ class GenericRelation(ForeignObject): mti_inherited = False - def __init__(self, to, object_id_field='object_id', content_type_field='content_type', - for_concrete_model=True, related_query_name=None, limit_choices_to=None, **kwargs): - kwargs['rel'] = self.rel_class( - self, to, + def __init__( + self, + to, + object_id_field="object_id", + content_type_field="content_type", + for_concrete_model=True, + related_query_name=None, + limit_choices_to=None, + **kwargs, + ): + kwargs["rel"] = self.rel_class( + self, + to, related_query_name=related_query_name, limit_choices_to=limit_choices_to, ) # Reverse relations are always nullable (Django can't enforce that a # foreign key on the related model points to this model). - kwargs['null'] = True - kwargs['blank'] = True - kwargs['on_delete'] = models.CASCADE - kwargs['editable'] = False - kwargs['serialize'] = False + kwargs["null"] = True + kwargs["blank"] = True + kwargs["on_delete"] = models.CASCADE + kwargs["editable"] = False + kwargs["serialize"] = False # This construct is somewhat of an abuse of ForeignObject. This field # represents a relation from pk to object_id field. But, this relation @@ -330,9 +355,9 @@ class GenericRelation(ForeignObject): GenericRelation. 
""" return ( - isinstance(field, GenericForeignKey) and - field.ct_field == self.content_type_field_name and - field.fk_field == self.object_id_field_name + isinstance(field, GenericForeignKey) + and field.ct_field == self.content_type_field_name + and field.fk_field == self.object_id_field_name ) def _check_generic_foreign_key_existence(self): @@ -348,7 +373,7 @@ class GenericRelation(ForeignObject): "'%s', but that model does not have a GenericForeignKey." % target._meta.label, obj=self, - id='contenttypes.E004', + id="contenttypes.E004", ) ] else: @@ -356,7 +381,12 @@ class GenericRelation(ForeignObject): def resolve_related_fields(self): self.to_fields = [self.model._meta.pk.name] - return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] + return [ + ( + self.remote_field.model._meta.get_field(self.object_id_field_name), + self.model._meta.pk, + ) + ] def _get_path_info_with_parent(self, filtered_relation): """ @@ -375,15 +405,17 @@ class GenericRelation(ForeignObject): opts = self.remote_field.model._meta.concrete_model._meta parent_opts = opts.get_field(self.object_id_field_name).model._meta target = parent_opts.pk - path.append(PathInfo( - from_opts=self.model._meta, - to_opts=parent_opts, - target_fields=(target,), - join_field=self.remote_field, - m2m=True, - direct=False, - filtered_relation=filtered_relation, - )) + path.append( + PathInfo( + from_opts=self.model._meta, + to_opts=parent_opts, + target_fields=(target,), + join_field=self.remote_field, + m2m=True, + direct=False, + filtered_relation=filtered_relation, + ) + ) # Collect joins needed for the parent -> child chain. 
This is easiest # to do if we collect joins for the child -> parent chain and then # reverse the direction (call to reverse() and use of @@ -405,42 +437,46 @@ class GenericRelation(ForeignObject): return self._get_path_info_with_parent(filtered_relation) else: target = opts.pk - return [PathInfo( - from_opts=self.model._meta, - to_opts=opts, - target_fields=(target,), - join_field=self.remote_field, - m2m=True, - direct=False, - filtered_relation=filtered_relation, - )] + return [ + PathInfo( + from_opts=self.model._meta, + to_opts=opts, + target_fields=(target,), + join_field=self.remote_field, + m2m=True, + direct=False, + filtered_relation=filtered_relation, + ) + ] def get_reverse_path_info(self, filtered_relation=None): opts = self.model._meta from_opts = self.remote_field.model._meta - return [PathInfo( - from_opts=from_opts, - to_opts=opts, - target_fields=(opts.pk,), - join_field=self, - m2m=not self.unique, - direct=False, - filtered_relation=filtered_relation, - )] + return [ + PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + ) + ] def value_to_string(self, obj): qs = getattr(obj, self.name).all() return str([instance.pk for instance in qs]) def contribute_to_class(self, cls, name, **kwargs): - kwargs['private_only'] = True + kwargs["private_only"] = True super().contribute_to_class(cls, name, **kwargs) self.model = cls # Disable the reverse relation for fields inherited by subclasses of a # model in multi-table inheritance. The reverse relation points to the # field of the base model. 
if self.mti_inherited: - self.remote_field.related_name = '+' + self.remote_field.related_name = "+" self.remote_field.related_query_name = None setattr(cls, self.name, ReverseGenericManyToOneDescriptor(self.remote_field)) @@ -450,10 +486,16 @@ class GenericRelation(ForeignObject): if not cls._meta.abstract: def make_generic_foreign_order_accessors(related_model, model): - if self._is_matching_generic_foreign_key(model._meta.order_with_respect_to): + if self._is_matching_generic_foreign_key( + model._meta.order_with_respect_to + ): make_foreign_order_accessors(model, related_model) - lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model) + lazy_related_operation( + make_generic_foreign_order_accessors, + self.model, + self.remote_field.model, + ) def set_attributes_from_rel(self): pass @@ -465,24 +507,29 @@ class GenericRelation(ForeignObject): """ Return the content type associated with this field's model. """ - return ContentType.objects.get_for_model(self.model, - for_concrete_model=self.for_concrete_model) + return ContentType.objects.get_for_model( + self.model, for_concrete_model=self.for_concrete_model + ) def get_extra_restriction(self, alias, remote_alias): field = self.remote_field.model._meta.get_field(self.content_type_field_name) contenttype_pk = self.get_content_type().pk - lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk) + lookup = field.get_lookup("exact")(field.get_col(remote_alias), contenttype_pk) return WhereNode([lookup], connector=AND) def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): """ Return all objects related to ``objs`` via this ``GenericRelation``. 
""" - return self.remote_field.model._base_manager.db_manager(using).filter(**{ - "%s__pk" % self.content_type_field_name: ContentType.objects.db_manager(using).get_for_model( - self.model, for_concrete_model=self.for_concrete_model).pk, - "%s__in" % self.object_id_field_name: [obj.pk for obj in objs] - }) + return self.remote_field.model._base_manager.db_manager(using).filter( + **{ + "%s__pk" + % self.content_type_field_name: ContentType.objects.db_manager(using) + .get_for_model(self.model, for_concrete_model=self.for_concrete_model) + .pk, + "%s__in" % self.object_id_field_name: [obj.pk for obj in objs], + } + ) class ReverseGenericManyToOneDescriptor(ReverseManyToOneDescriptor): @@ -539,7 +586,7 @@ def create_generic_related_manager(superclass, rel): self.pk_val = instance.pk self.core_filters = { - '%s__pk' % self.content_type_field_name: self.content_type.id, + "%s__pk" % self.content_type_field_name: self.content_type.id, self.object_id_field_name: self.pk_val, } @@ -547,6 +594,7 @@ def create_generic_related_manager(superclass, rel): manager = getattr(self.model, manager) manager_class = create_generic_related_manager(manager.__class__, rel) return manager_class(instance=self.instance) + do_not_call_in_templates = True def __str__(self): @@ -581,8 +629,8 @@ def create_generic_related_manager(superclass, rel): # Group instances by content types. 
content_type_queries = ( models.Q( - (f'{self.content_type_field_name}__pk', content_type_id), - (f'{self.object_id_field_name}__in', {obj.pk for obj in objs}), + (f"{self.content_type_field_name}__pk", content_type_id), + (f"{self.object_id_field_name}__in", {obj.pk for obj in objs}), ) for content_type_id, objs in itertools.groupby( sorted(instances, key=lambda obj: self.get_content_type(obj).pk), @@ -593,7 +641,7 @@ def create_generic_related_manager(superclass, rel): # We (possibly) need to convert object IDs to the type of the # instances' PK in order to match up instances: object_id_converter = instances[0]._meta.pk.to_python - content_type_id_field_name = '%s_id' % self.content_type_field_name + content_type_id_field_name = "%s_id" % self.content_type_field_name return ( queryset.filter(query), lambda relobj: ( @@ -612,9 +660,10 @@ def create_generic_related_manager(superclass, rel): def check_and_update_obj(obj): if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected, got %r" % ( - self.model._meta.object_name, obj - )) + raise TypeError( + "'%s' instance expected, got %r" + % (self.model._meta.object_name, obj) + ) setattr(obj, self.content_type_field_name, self.content_type) setattr(obj, self.object_id_field_name, self.pk_val) @@ -629,25 +678,30 @@ def create_generic_related_manager(superclass, rel): check_and_update_obj(obj) pks.append(obj.pk) - self.model._base_manager.using(db).filter(pk__in=pks).update(**{ - self.content_type_field_name: self.content_type, - self.object_id_field_name: self.pk_val, - }) + self.model._base_manager.using(db).filter(pk__in=pks).update( + **{ + self.content_type_field_name: self.content_type, + self.object_id_field_name: self.pk_val, + } + ) else: with transaction.atomic(using=db, savepoint=False): for obj in objs: check_and_update_obj(obj) obj.save() + add.alters_data = True def remove(self, *objs, bulk=True): if not objs: return self._clear(self.filter(pk__in=[o.pk for o in objs]), bulk) + 
remove.alters_data = True def clear(self, *, bulk=True): self._clear(self, bulk) + clear.alters_data = True def _clear(self, queryset, bulk): @@ -662,6 +716,7 @@ def create_generic_related_manager(superclass, rel): with transaction.atomic(using=db, savepoint=False): for obj in queryset: obj.delete() + _clear.alters_data = True def set(self, objs, *, bulk=True, clear=False): @@ -685,6 +740,7 @@ def create_generic_related_manager(superclass, rel): self.remove(*old_objs) self.add(*new_objs, bulk=bulk) + set.alters_data = True def create(self, **kwargs): @@ -693,6 +749,7 @@ def create_generic_related_manager(superclass, rel): kwargs[self.object_id_field_name] = self.pk_val db = router.db_for_write(self.model, instance=self.instance) return super().using(db).create(**kwargs) + create.alters_data = True def get_or_create(self, **kwargs): @@ -700,6 +757,7 @@ def create_generic_related_manager(superclass, rel): kwargs[self.object_id_field_name] = self.pk_val db = router.db_for_write(self.model, instance=self.instance) return super().using(db).get_or_create(**kwargs) + get_or_create.alters_data = True def update_or_create(self, **kwargs): @@ -707,6 +765,7 @@ def create_generic_related_manager(superclass, rel): kwargs[self.object_id_field_name] = self.pk_val db = router.db_for_write(self.model, instance=self.instance) return super().using(db).update_or_create(**kwargs) + update_or_create.alters_data = True return GenericRelatedObjectManager diff --git a/django/contrib/contenttypes/forms.py b/django/contrib/contenttypes/forms.py index 92a58d49f8..c0ff4f7257 100644 --- a/django/contrib/contenttypes/forms.py +++ b/django/contrib/contenttypes/forms.py @@ -9,13 +9,26 @@ class BaseGenericInlineFormSet(BaseModelFormSet): A formset for generic inline objects to a parent. 
""" - def __init__(self, data=None, files=None, instance=None, save_as_new=False, - prefix=None, queryset=None, **kwargs): + def __init__( + self, + data=None, + files=None, + instance=None, + save_as_new=False, + prefix=None, + queryset=None, + **kwargs, + ): opts = self.model._meta self.instance = instance self.rel_name = ( - opts.app_label + '-' + opts.model_name + '-' + - self.ct_field.name + '-' + self.ct_fk_field.name + opts.app_label + + "-" + + opts.model_name + + "-" + + self.ct_field.name + + "-" + + self.ct_fk_field.name ) self.save_as_new = save_as_new if self.instance is None or self.instance.pk is None: @@ -23,11 +36,14 @@ class BaseGenericInlineFormSet(BaseModelFormSet): else: if queryset is None: queryset = self.model._default_manager - qs = queryset.filter(**{ - self.ct_field.name: ContentType.objects.get_for_model( - self.instance, for_concrete_model=self.for_concrete_model), - self.ct_fk_field.name: self.instance.pk, - }) + qs = queryset.filter( + **{ + self.ct_field.name: ContentType.objects.get_for_model( + self.instance, for_concrete_model=self.for_concrete_model + ), + self.ct_fk_field.name: self.instance.pk, + } + ) super().__init__(queryset=qs, data=data, files=files, prefix=prefix, **kwargs) def initial_form_count(self): @@ -39,25 +55,45 @@ class BaseGenericInlineFormSet(BaseModelFormSet): def get_default_prefix(cls): opts = cls.model._meta return ( - opts.app_label + '-' + opts.model_name + '-' + - cls.ct_field.name + '-' + cls.ct_fk_field.name + opts.app_label + + "-" + + opts.model_name + + "-" + + cls.ct_field.name + + "-" + + cls.ct_fk_field.name ) def save_new(self, form, commit=True): - setattr(form.instance, self.ct_field.get_attname(), ContentType.objects.get_for_model(self.instance).pk) + setattr( + form.instance, + self.ct_field.get_attname(), + ContentType.objects.get_for_model(self.instance).pk, + ) setattr(form.instance, self.ct_fk_field.get_attname(), self.instance.pk) return form.save(commit=commit) -def 
generic_inlineformset_factory(model, form=ModelForm, - formset=BaseGenericInlineFormSet, - ct_field="content_type", fk_field="object_id", - fields=None, exclude=None, - extra=3, can_order=False, can_delete=True, - max_num=None, formfield_callback=None, - validate_max=False, for_concrete_model=True, - min_num=None, validate_min=False, - absolute_max=None, can_delete_extra=True): +def generic_inlineformset_factory( + model, + form=ModelForm, + formset=BaseGenericInlineFormSet, + ct_field="content_type", + fk_field="object_id", + fields=None, + exclude=None, + extra=3, + can_order=False, + can_delete=True, + max_num=None, + formfield_callback=None, + validate_max=False, + for_concrete_model=True, + min_num=None, + validate_min=False, + absolute_max=None, + can_delete_extra=True, +): """ Return a ``GenericInlineFormSet`` for the given kwargs. @@ -67,16 +103,29 @@ def generic_inlineformset_factory(model, form=ModelForm, opts = model._meta # if there is no field called `ct_field` let the exception propagate ct_field = opts.get_field(ct_field) - if not isinstance(ct_field, models.ForeignKey) or ct_field.remote_field.model != ContentType: + if ( + not isinstance(ct_field, models.ForeignKey) + or ct_field.remote_field.model != ContentType + ): raise Exception("fk_name '%s' is not a ForeignKey to ContentType" % ct_field) fk_field = opts.get_field(fk_field) # let the exception propagate exclude = [*(exclude or []), ct_field.name, fk_field.name] FormSet = modelformset_factory( - model, form=form, formfield_callback=formfield_callback, - formset=formset, extra=extra, can_delete=can_delete, - can_order=can_order, fields=fields, exclude=exclude, max_num=max_num, - validate_max=validate_max, min_num=min_num, validate_min=validate_min, - absolute_max=absolute_max, can_delete_extra=can_delete_extra, + model, + form=form, + formfield_callback=formfield_callback, + formset=formset, + extra=extra, + can_delete=can_delete, + can_order=can_order, + fields=fields, + exclude=exclude, + 
max_num=max_num, + validate_max=validate_max, + min_num=min_num, + validate_min=validate_min, + absolute_max=absolute_max, + can_delete_extra=can_delete_extra, ) FormSet.ct_field = ct_field FormSet.ct_fk_field = fk_field diff --git a/django/contrib/contenttypes/management/__init__.py b/django/contrib/contenttypes/management/__init__.py index 4971fa0c4b..903b9ab1a0 100644 --- a/django/contrib/contenttypes/management/__init__.py +++ b/django/contrib/contenttypes/management/__init__.py @@ -1,7 +1,5 @@ from django.apps import apps as global_apps -from django.db import ( - DEFAULT_DB_ALIAS, IntegrityError, migrations, router, transaction, -) +from django.db import DEFAULT_DB_ALIAS, IntegrityError, migrations, router, transaction class RenameContentType(migrations.RunPython): @@ -12,20 +10,22 @@ class RenameContentType(migrations.RunPython): super().__init__(self.rename_forward, self.rename_backward) def _rename(self, apps, schema_editor, old_model, new_model): - ContentType = apps.get_model('contenttypes', 'ContentType') + ContentType = apps.get_model("contenttypes", "ContentType") db = schema_editor.connection.alias if not router.allow_migrate_model(db, ContentType): return try: - content_type = ContentType.objects.db_manager(db).get_by_natural_key(self.app_label, old_model) + content_type = ContentType.objects.db_manager(db).get_by_natural_key( + self.app_label, old_model + ) except ContentType.DoesNotExist: pass else: content_type.model = new_model try: with transaction.atomic(using=db): - content_type.save(using=db, update_fields={'model'}) + content_type.save(using=db, update_fields={"model"}) except IntegrityError: # Gracefully fallback if a stale content type causes a # conflict as remove_stale_contenttypes will take care of @@ -43,7 +43,9 @@ class RenameContentType(migrations.RunPython): self._rename(apps, schema_editor, self.new_model, self.old_model) -def inject_rename_contenttypes_operations(plan=None, apps=global_apps, using=DEFAULT_DB_ALIAS, **kwargs): +def 
inject_rename_contenttypes_operations( + plan=None, apps=global_apps, using=DEFAULT_DB_ALIAS, **kwargs +): """ Insert a `RenameContentType` operation after every planned `RenameModel` operation. @@ -53,7 +55,7 @@ def inject_rename_contenttypes_operations(plan=None, apps=global_apps, using=DEF # Determine whether or not the ContentType model is available. try: - ContentType = apps.get_model('contenttypes', 'ContentType') + ContentType = apps.get_model("contenttypes", "ContentType") except LookupError: available = False else: @@ -62,7 +64,7 @@ def inject_rename_contenttypes_operations(plan=None, apps=global_apps, using=DEF available = True for migration, backward in plan: - if (migration.app_label, migration.name) == ('contenttypes', '0001_initial'): + if (migration.app_label, migration.name) == ("contenttypes", "0001_initial"): # There's no point in going forward if the initial contenttypes # migration is unapplied as the ContentType model will be # unavailable from this point. @@ -78,7 +80,9 @@ def inject_rename_contenttypes_operations(plan=None, apps=global_apps, using=DEF for index, operation in enumerate(migration.operations): if isinstance(operation, migrations.RenameModel): operation = RenameContentType( - migration.app_label, operation.old_name_lower, operation.new_name_lower + migration.app_label, + operation.old_name_lower, + operation.new_name_lower, ) inserts.append((index + 1, operation)) for inserted, (index, operation) in enumerate(inserts): @@ -95,14 +99,18 @@ def get_contenttypes_and_models(app_config, using, ContentType): ct.model: ct for ct in ContentType.objects.using(using).filter(app_label=app_config.label) } - app_models = { - model._meta.model_name: model - for model in app_config.get_models() - } + app_models = {model._meta.model_name: model for model in app_config.get_models()} return content_types, app_models -def create_contenttypes(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, apps=global_apps, **kwargs): +def 
create_contenttypes( + app_config, + verbosity=2, + interactive=True, + using=DEFAULT_DB_ALIAS, + apps=global_apps, + **kwargs, +): """ Create content types for models in the given app. """ @@ -112,11 +120,13 @@ def create_contenttypes(app_config, verbosity=2, interactive=True, using=DEFAULT app_label = app_config.label try: app_config = apps.get_app_config(app_label) - ContentType = apps.get_model('contenttypes', 'ContentType') + ContentType = apps.get_model("contenttypes", "ContentType") except LookupError: return - content_types, app_models = get_contenttypes_and_models(app_config, using, ContentType) + content_types, app_models = get_contenttypes_and_models( + app_config, using, ContentType + ) if not app_models: return diff --git a/django/contrib/contenttypes/management/commands/remove_stale_contenttypes.py b/django/contrib/contenttypes/management/commands/remove_stale_contenttypes.py index 5593ecb469..3002ceaf53 100644 --- a/django/contrib/contenttypes/management/commands/remove_stale_contenttypes.py +++ b/django/contrib/contenttypes/management/commands/remove_stale_contenttypes.py @@ -8,18 +8,23 @@ from django.db.models.deletion import Collector class Command(BaseCommand): - def add_arguments(self, parser): parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates the database to use. 
Defaults to the "default" database.', ) parser.add_argument( - '--include-stale-apps', action='store_true', default=False, + "--include-stale-apps", + action="store_true", + default=False, help=( "Deletes stale content types including ones from previously " "installed apps that have been removed from INSTALLED_APPS." @@ -27,17 +32,17 @@ class Command(BaseCommand): ) def handle(self, **options): - db = options['database'] - include_stale_apps = options['include_stale_apps'] - interactive = options['interactive'] - verbosity = options['verbosity'] + db = options["database"] + include_stale_apps = options["include_stale_apps"] + interactive = options["interactive"] + verbosity = options["verbosity"] if not router.allow_migrate_model(db, ContentType): return ContentType.objects.clear_cache() apps_content_types = itertools.groupby( - ContentType.objects.using(db).order_by('app_label', 'model'), + ContentType.objects.using(db).order_by("app_label", "model"), lambda obj: obj.app_label, ) for app_label, content_types in apps_content_types: @@ -50,18 +55,24 @@ class Command(BaseCommand): if interactive: ct_info = [] for ct in to_remove: - ct_info.append(' - Content type for %s.%s' % (ct.app_label, ct.model)) + ct_info.append( + " - Content type for %s.%s" % (ct.app_label, ct.model) + ) collector = NoFastDeleteCollector(using=using, origin=ct) collector.collect([ct]) for obj_type, objs in collector.data.items(): if objs != {ct}: - ct_info.append(' - %s %s object(s)' % ( - len(objs), - obj_type._meta.label, - )) - content_type_display = '\n'.join(ct_info) - self.stdout.write("""Some content types in your database are stale and can be deleted. + ct_info.append( + " - %s %s object(s)" + % ( + len(objs), + obj_type._meta.label, + ) + ) + content_type_display = "\n".join(ct_info) + self.stdout.write( + """Some content types in your database are stale and can be deleted. Any objects that depend on these content types will also be deleted. 
The content types and dependent objects that would be deleted are: @@ -71,15 +82,20 @@ This list doesn't include any cascade deletions to data outside of Django's models (uncommon). Are you sure you want to delete these content types? -If you're unsure, answer 'no'.""" % content_type_display) +If you're unsure, answer 'no'.""" + % content_type_display + ) ok_to_delete = input("Type 'yes' to continue, or 'no' to cancel: ") else: - ok_to_delete = 'yes' + ok_to_delete = "yes" - if ok_to_delete == 'yes': + if ok_to_delete == "yes": for ct in to_remove: if verbosity >= 2: - self.stdout.write("Deleting stale content type '%s | %s'" % (ct.app_label, ct.model)) + self.stdout.write( + "Deleting stale content type '%s | %s'" + % (ct.app_label, ct.model) + ) ct.delete() else: if verbosity >= 2: diff --git a/django/contrib/contenttypes/migrations/0001_initial.py b/django/contrib/contenttypes/migrations/0001_initial.py index e55c320d80..5468fb6a0d 100644 --- a/django/contrib/contenttypes/migrations/0001_initial.py +++ b/django/contrib/contenttypes/migrations/0001_initial.py @@ -4,31 +4,43 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( - name='ContentType', + name="ContentType", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(max_length=100)), - ('app_label', models.CharField(max_length=100)), - ('model', models.CharField(max_length=100, verbose_name='python model class name')), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ("name", models.CharField(max_length=100)), + ("app_label", models.CharField(max_length=100)), + ( + "model", + models.CharField( + max_length=100, verbose_name="python model class name" + ), + ), ], options={ - 'ordering': ('name',), - 'db_table': 'django_content_type', - 
'verbose_name': 'content type', - 'verbose_name_plural': 'content types', + "ordering": ("name",), + "db_table": "django_content_type", + "verbose_name": "content type", + "verbose_name_plural": "content types", }, bases=(models.Model,), managers=[ - ('objects', django.contrib.contenttypes.models.ContentTypeManager()), + ("objects", django.contrib.contenttypes.models.ContentTypeManager()), ], ), migrations.AlterUniqueTogether( - name='contenttype', - unique_together={('app_label', 'model')}, + name="contenttype", + unique_together={("app_label", "model")}, ), ] diff --git a/django/contrib/contenttypes/migrations/0002_remove_content_type_name.py b/django/contrib/contenttypes/migrations/0002_remove_content_type_name.py index c88e603147..3bee3a864f 100644 --- a/django/contrib/contenttypes/migrations/0002_remove_content_type_name.py +++ b/django/contrib/contenttypes/migrations/0002_remove_content_type_name.py @@ -2,7 +2,7 @@ from django.db import migrations, models def add_legacy_name(apps, schema_editor): - ContentType = apps.get_model('contenttypes', 'ContentType') + ContentType = apps.get_model("contenttypes", "ContentType") for ct in ContentType.objects.all(): try: ct.name = apps.get_model(ct.app_label, ct.model)._meta.object_name @@ -14,26 +14,29 @@ def add_legacy_name(apps, schema_editor): class Migration(migrations.Migration): dependencies = [ - ('contenttypes', '0001_initial'), + ("contenttypes", "0001_initial"), ] operations = [ migrations.AlterModelOptions( - name='contenttype', - options={'verbose_name': 'content type', 'verbose_name_plural': 'content types'}, + name="contenttype", + options={ + "verbose_name": "content type", + "verbose_name_plural": "content types", + }, ), migrations.AlterField( - model_name='contenttype', - name='name', + model_name="contenttype", + name="name", field=models.CharField(max_length=100, null=True), ), migrations.RunPython( migrations.RunPython.noop, add_legacy_name, - hints={'model_name': 'contenttype'}, + 
hints={"model_name": "contenttype"}, ), migrations.RemoveField( - model_name='contenttype', - name='name', + model_name="contenttype", + name="name", ), ] diff --git a/django/contrib/contenttypes/models.py b/django/contrib/contenttypes/models.py index edc488728a..e4734cc893 100644 --- a/django/contrib/contenttypes/models.py +++ b/django/contrib/contenttypes/models.py @@ -81,10 +81,7 @@ class ContentTypeManager(models.Manager): results[model] = ct if needed_opts: # Lookup required content types from the DB. - cts = self.filter( - app_label__in=needed_app_labels, - model__in=needed_models - ) + cts = self.filter(app_label__in=needed_app_labels, model__in=needed_models) for ct in cts: opts_models = needed_opts.pop(ct.model_class()._meta, []) for model in opts_models: @@ -132,14 +129,14 @@ class ContentTypeManager(models.Manager): class ContentType(models.Model): app_label = models.CharField(max_length=100) - model = models.CharField(_('python model class name'), max_length=100) + model = models.CharField(_("python model class name"), max_length=100) objects = ContentTypeManager() class Meta: - verbose_name = _('content type') - verbose_name_plural = _('content types') - db_table = 'django_content_type' - unique_together = [['app_label', 'model']] + verbose_name = _("content type") + verbose_name_plural = _("content types") + db_table = "django_content_type" + unique_together = [["app_label", "model"]] def __str__(self): return self.app_labeled_name @@ -156,7 +153,7 @@ class ContentType(models.Model): model = self.model_class() if not model: return self.model - return '%s | %s' % (model._meta.app_label, model._meta.verbose_name) + return "%s | %s" % (model._meta.app_label, model._meta.verbose_name) def model_class(self): """Return the model class for this type of content.""" diff --git a/django/contrib/contenttypes/views.py b/django/contrib/contenttypes/views.py index 3231e6bfe1..bfde73c567 100644 --- a/django/contrib/contenttypes/views.py +++ 
b/django/contrib/contenttypes/views.py @@ -15,22 +15,22 @@ def shortcut(request, content_type_id, object_id): content_type = ContentType.objects.get(pk=content_type_id) if not content_type.model_class(): raise Http404( - _("Content type %(ct_id)s object has no associated model") % - {'ct_id': content_type_id} + _("Content type %(ct_id)s object has no associated model") + % {"ct_id": content_type_id} ) obj = content_type.get_object_for_this_type(pk=object_id) except (ObjectDoesNotExist, ValueError): raise Http404( - _('Content type %(ct_id)s object %(obj_id)s doesn’t exist') % - {'ct_id': content_type_id, 'obj_id': object_id} + _("Content type %(ct_id)s object %(obj_id)s doesn’t exist") + % {"ct_id": content_type_id, "obj_id": object_id} ) try: get_absolute_url = obj.get_absolute_url except AttributeError: raise Http404( - _('%(ct_name)s objects don’t have a get_absolute_url() method') % - {'ct_name': content_type.name} + _("%(ct_name)s objects don’t have a get_absolute_url() method") + % {"ct_name": content_type.name} ) absurl = get_absolute_url() @@ -38,7 +38,7 @@ def shortcut(request, content_type_id, object_id): # if necessary. # If the object actually defines a domain, we're done. - if absurl.startswith(('http://', 'https://', '//')): + if absurl.startswith(("http://", "https://", "//")): return HttpResponseRedirect(absurl) # Otherwise, we need to introspect the object's relationships for a @@ -48,8 +48,8 @@ def shortcut(request, content_type_id, object_id): except ObjectDoesNotExist: object_domain = None - if apps.is_installed('django.contrib.sites'): - Site = apps.get_model('sites.Site') + if apps.is_installed("django.contrib.sites"): + Site = apps.get_model("sites.Site") opts = obj._meta for field in opts.many_to_many: @@ -83,6 +83,6 @@ def shortcut(request, content_type_id, object_id): # to whatever get_absolute_url() returned. 
if object_domain is not None: protocol = request.scheme - return HttpResponseRedirect('%s://%s%s' % (protocol, object_domain, absurl)) + return HttpResponseRedirect("%s://%s%s" % (protocol, object_domain, absurl)) else: return HttpResponseRedirect(absurl) diff --git a/django/contrib/flatpages/admin.py b/django/contrib/flatpages/admin.py index ead6b52b50..5bbc2ad726 100644 --- a/django/contrib/flatpages/admin.py +++ b/django/contrib/flatpages/admin.py @@ -8,12 +8,15 @@ from django.utils.translation import gettext_lazy as _ class FlatPageAdmin(admin.ModelAdmin): form = FlatpageForm fieldsets = ( - (None, {'fields': ('url', 'title', 'content', 'sites')}), - (_('Advanced options'), { - 'classes': ('collapse',), - 'fields': ('registration_required', 'template_name'), - }), + (None, {"fields": ("url", "title", "content", "sites")}), + ( + _("Advanced options"), + { + "classes": ("collapse",), + "fields": ("registration_required", "template_name"), + }, + ), ) - list_display = ('url', 'title') - list_filter = ('sites', 'registration_required') - search_fields = ('url', 'title') + list_display = ("url", "title") + list_filter = ("sites", "registration_required") + search_fields = ("url", "title") diff --git a/django/contrib/flatpages/apps.py b/django/contrib/flatpages/apps.py index 4f5ef17004..eb9f470b59 100644 --- a/django/contrib/flatpages/apps.py +++ b/django/contrib/flatpages/apps.py @@ -3,6 +3,6 @@ from django.utils.translation import gettext_lazy as _ class FlatPagesConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.flatpages' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.flatpages" verbose_name = _("Flat Pages") diff --git a/django/contrib/flatpages/forms.py b/django/contrib/flatpages/forms.py index f5ee76443a..18fdfa658b 100644 --- a/django/contrib/flatpages/forms.py +++ b/django/contrib/flatpages/forms.py @@ -2,15 +2,18 @@ from django import forms from django.conf import settings from 
django.contrib.flatpages.models import FlatPage from django.core.exceptions import ValidationError -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ class FlatpageForm(forms.ModelForm): url = forms.RegexField( label=_("URL"), max_length=100, - regex=r'^[-\w/\.~]+$', - help_text=_('Example: “/about/contact/”. Make sure to have leading and trailing slashes.'), + regex=r"^[-\w/\.~]+$", + help_text=_( + "Example: “/about/contact/”. Make sure to have leading and trailing slashes." + ), error_messages={ "invalid": _( "This value must contain only letters, numbers, dots, " @@ -21,38 +24,38 @@ class FlatpageForm(forms.ModelForm): class Meta: model = FlatPage - fields = '__all__' + fields = "__all__" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self._trailing_slash_required(): - self.fields['url'].help_text = _( - 'Example: “/about/contact”. Make sure to have a leading slash.' + self.fields["url"].help_text = _( + "Example: “/about/contact”. Make sure to have a leading slash." 
) def _trailing_slash_required(self): return ( - settings.APPEND_SLASH and - 'django.middleware.common.CommonMiddleware' in settings.MIDDLEWARE + settings.APPEND_SLASH + and "django.middleware.common.CommonMiddleware" in settings.MIDDLEWARE ) def clean_url(self): - url = self.cleaned_data['url'] - if not url.startswith('/'): + url = self.cleaned_data["url"] + if not url.startswith("/"): raise ValidationError( gettext("URL is missing a leading slash."), - code='missing_leading_slash', + code="missing_leading_slash", ) - if self._trailing_slash_required() and not url.endswith('/'): + if self._trailing_slash_required() and not url.endswith("/"): raise ValidationError( gettext("URL is missing a trailing slash."), - code='missing_trailing_slash', + code="missing_trailing_slash", ) return url def clean(self): - url = self.cleaned_data.get('url') - sites = self.cleaned_data.get('sites') + url = self.cleaned_data.get("url") + sites = self.cleaned_data.get("sites") same_url = FlatPage.objects.filter(url=url) if self.instance.pk: @@ -62,9 +65,9 @@ class FlatpageForm(forms.ModelForm): for site in sites: if same_url.filter(sites=site).exists(): raise ValidationError( - _('Flatpage with url %(url)s already exists for site %(site)s'), - code='duplicate_url', - params={'url': url, 'site': site}, + _("Flatpage with url %(url)s already exists for site %(site)s"), + code="duplicate_url", + params={"url": url, "site": site}, ) return super().clean() diff --git a/django/contrib/flatpages/migrations/0001_initial.py b/django/contrib/flatpages/migrations/0001_initial.py index 867cd6d4ea..6faa610181 100644 --- a/django/contrib/flatpages/migrations/0001_initial.py +++ b/django/contrib/flatpages/migrations/0001_initial.py @@ -4,35 +4,62 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('sites', '0001_initial'), + ("sites", "0001_initial"), ] operations = [ migrations.CreateModel( - name='FlatPage', + name="FlatPage", fields=[ - ('id', 
models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('url', models.CharField(max_length=100, verbose_name='URL', db_index=True)), - ('title', models.CharField(max_length=200, verbose_name='title')), - ('content', models.TextField(verbose_name='content', blank=True)), - ('enable_comments', models.BooleanField(default=False, verbose_name='enable comments')), - ('template_name', models.CharField( - help_text=( - 'Example: “flatpages/contact_page.html”. If this isn’t provided, the system will use ' - '“flatpages/default.html”.' - ), max_length=70, verbose_name='template name', blank=True - )), - ('registration_required', models.BooleanField( - default=False, help_text='If this is checked, only logged-in users will be able to view the page.', - verbose_name='registration required' - )), - ('sites', models.ManyToManyField(to='sites.Site', verbose_name='sites')), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "url", + models.CharField(max_length=100, verbose_name="URL", db_index=True), + ), + ("title", models.CharField(max_length=200, verbose_name="title")), + ("content", models.TextField(verbose_name="content", blank=True)), + ( + "enable_comments", + models.BooleanField(default=False, verbose_name="enable comments"), + ), + ( + "template_name", + models.CharField( + help_text=( + "Example: “flatpages/contact_page.html”. If this isn’t provided, the system will use " + "“flatpages/default.html”." 
+ ), + max_length=70, + verbose_name="template name", + blank=True, + ), + ), + ( + "registration_required", + models.BooleanField( + default=False, + help_text="If this is checked, only logged-in users will be able to view the page.", + verbose_name="registration required", + ), + ), + ( + "sites", + models.ManyToManyField(to="sites.Site", verbose_name="sites"), + ), ], options={ - 'ordering': ['url'], - 'db_table': 'django_flatpage', - 'verbose_name': 'flat page', - 'verbose_name_plural': 'flat pages', + "ordering": ["url"], + "db_table": "django_flatpage", + "verbose_name": "flat page", + "verbose_name_plural": "flat pages", }, bases=(models.Model,), ), diff --git a/django/contrib/flatpages/models.py b/django/contrib/flatpages/models.py index 2f2473b842..d001cd182c 100644 --- a/django/contrib/flatpages/models.py +++ b/django/contrib/flatpages/models.py @@ -6,31 +6,33 @@ from django.utils.translation import gettext_lazy as _ class FlatPage(models.Model): - url = models.CharField(_('URL'), max_length=100, db_index=True) - title = models.CharField(_('title'), max_length=200) - content = models.TextField(_('content'), blank=True) - enable_comments = models.BooleanField(_('enable comments'), default=False) + url = models.CharField(_("URL"), max_length=100, db_index=True) + title = models.CharField(_("title"), max_length=200) + content = models.TextField(_("content"), blank=True) + enable_comments = models.BooleanField(_("enable comments"), default=False) template_name = models.CharField( - _('template name'), + _("template name"), max_length=70, blank=True, help_text=_( - 'Example: “flatpages/contact_page.html”. If this isn’t provided, ' - 'the system will use “flatpages/default.html”.' + "Example: “flatpages/contact_page.html”. If this isn’t provided, " + "the system will use “flatpages/default.html”." 
), ) registration_required = models.BooleanField( - _('registration required'), - help_text=_("If this is checked, only logged-in users will be able to view the page."), + _("registration required"), + help_text=_( + "If this is checked, only logged-in users will be able to view the page." + ), default=False, ) - sites = models.ManyToManyField(Site, verbose_name=_('sites')) + sites = models.ManyToManyField(Site, verbose_name=_("sites")) class Meta: - db_table = 'django_flatpage' - verbose_name = _('flat page') - verbose_name_plural = _('flat pages') - ordering = ['url'] + db_table = "django_flatpage" + verbose_name = _("flat page") + verbose_name_plural = _("flat pages") + ordering = ["url"] def __str__(self): return "%s -- %s" % (self.url, self.title) @@ -38,10 +40,10 @@ class FlatPage(models.Model): def get_absolute_url(self): from .views import flatpage - for url in (self.url.lstrip('/'), self.url): + for url in (self.url.lstrip("/"), self.url): try: - return reverse(flatpage, kwargs={'url': url}) + return reverse(flatpage, kwargs={"url": url}) except NoReverseMatch: pass # Handle script prefix manually because we bypass reverse() - return iri_to_uri(get_script_prefix().rstrip('/') + self.url) + return iri_to_uri(get_script_prefix().rstrip("/") + self.url) diff --git a/django/contrib/flatpages/sitemaps.py b/django/contrib/flatpages/sitemaps.py index a144023178..405507c3a3 100644 --- a/django/contrib/flatpages/sitemaps.py +++ b/django/contrib/flatpages/sitemaps.py @@ -5,8 +5,10 @@ from django.core.exceptions import ImproperlyConfigured class FlatPageSitemap(Sitemap): def items(self): - if not django_apps.is_installed('django.contrib.sites'): - raise ImproperlyConfigured("FlatPageSitemap requires django.contrib.sites, which isn't installed.") - Site = django_apps.get_model('sites.Site') + if not django_apps.is_installed("django.contrib.sites"): + raise ImproperlyConfigured( + "FlatPageSitemap requires django.contrib.sites, which isn't installed." 
+ ) + Site = django_apps.get_model("sites.Site") current_site = Site.objects.get_current() return current_site.flatpage_set.filter(registration_required=False) diff --git a/django/contrib/flatpages/templatetags/flatpages.py b/django/contrib/flatpages/templatetags/flatpages.py index 1d99104ce3..9ec6b8808a 100644 --- a/django/contrib/flatpages/templatetags/flatpages.py +++ b/django/contrib/flatpages/templatetags/flatpages.py @@ -19,15 +19,16 @@ class FlatpageNode(template.Node): self.user = None def render(self, context): - if 'request' in context: - site_pk = get_current_site(context['request']).pk + if "request" in context: + site_pk = get_current_site(context["request"]).pk else: site_pk = settings.SITE_ID flatpages = FlatPage.objects.filter(sites__id=site_pk) # If a prefix was specified, add a filter if self.starts_with: flatpages = flatpages.filter( - url__startswith=self.starts_with.resolve(context)) + url__startswith=self.starts_with.resolve(context) + ) # If the provided user is not authenticated, or no user # was provided, filter the list to only public flatpages. 
@@ -39,7 +40,7 @@ class FlatpageNode(template.Node): flatpages = flatpages.filter(registration_required=False) context[self.context_name] = flatpages - return '' + return "" @register.tag @@ -70,9 +71,10 @@ def get_flatpages(parser, token): {% get_flatpages '/about/' for someuser as about_pages %} """ bits = token.split_contents() - syntax_message = ("%(tag_name)s expects a syntax of %(tag_name)s " - "['url_starts_with'] [for user] as context_name" % - {'tag_name': bits[0]}) + syntax_message = ( + "%(tag_name)s expects a syntax of %(tag_name)s " + "['url_starts_with'] [for user] as context_name" % {"tag_name": bits[0]} + ) # Must have at 3-6 bits in the tag if 3 <= len(bits) <= 6: # If there's an even number of bits, there's no prefix @@ -82,13 +84,13 @@ def get_flatpages(parser, token): prefix = None # The very last bit must be the context name - if bits[-2] != 'as': + if bits[-2] != "as": raise template.TemplateSyntaxError(syntax_message) context_name = bits[-1] # If there are 5 or 6 bits, there is a user defined if len(bits) >= 5: - if bits[-4] != 'for': + if bits[-4] != "for": raise template.TemplateSyntaxError(syntax_message) user = bits[-3] else: diff --git a/django/contrib/flatpages/urls.py b/django/contrib/flatpages/urls.py index a087fe8c1d..d4480b3fba 100644 --- a/django/contrib/flatpages/urls.py +++ b/django/contrib/flatpages/urls.py @@ -2,5 +2,5 @@ from django.contrib.flatpages import views from django.urls import path urlpatterns = [ - path('<path:url>', views.flatpage, name='django.contrib.flatpages.views.flatpage'), + path("<path:url>", views.flatpage, name="django.contrib.flatpages.views.flatpage"), ] diff --git a/django/contrib/flatpages/views.py b/django/contrib/flatpages/views.py index 2722ec5f94..776f1796a6 100644 --- a/django/contrib/flatpages/views.py +++ b/django/contrib/flatpages/views.py @@ -7,7 +7,7 @@ from django.template import loader from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_protect 
-DEFAULT_TEMPLATE = 'flatpages/default.html' +DEFAULT_TEMPLATE = "flatpages/default.html" # This view is called from FlatpageFallbackMiddleware.process_response # when a 404 is raised, which often means CsrfViewMiddleware.process_view @@ -30,16 +30,16 @@ def flatpage(request, url): flatpage `flatpages.flatpages` object """ - if not url.startswith('/'): - url = '/' + url + if not url.startswith("/"): + url = "/" + url site_id = get_current_site(request).id try: f = get_object_or_404(FlatPage, url=url, sites=site_id) except Http404: - if not url.endswith('/') and settings.APPEND_SLASH: - url += '/' + if not url.endswith("/") and settings.APPEND_SLASH: + url += "/" f = get_object_or_404(FlatPage, url=url, sites=site_id) - return HttpResponsePermanentRedirect('%s/' % request.path) + return HttpResponsePermanentRedirect("%s/" % request.path) else: raise return render_flatpage(request, f) @@ -54,6 +54,7 @@ def render_flatpage(request, f): # logged in, redirect to the login page. if f.registration_required and not request.user.is_authenticated: from django.contrib.auth.views import redirect_to_login + return redirect_to_login(request.path) if f.template_name: template = loader.select_template((f.template_name, DEFAULT_TEMPLATE)) @@ -66,4 +67,4 @@ def render_flatpage(request, f): f.title = mark_safe(f.title) f.content = mark_safe(f.content) - return HttpResponse(template.render({'flatpage': f}, request)) + return HttpResponse(template.render({"flatpage": f}, request)) diff --git a/django/contrib/gis/admin/__init__.py b/django/contrib/gis/admin/__init__.py index 4abc4898ae..d5bd54c479 100644 --- a/django/contrib/gis/admin/__init__.py +++ b/django/contrib/gis/admin/__init__.py @@ -1,16 +1,34 @@ from django.contrib.admin import ( - HORIZONTAL, VERTICAL, AdminSite, ModelAdmin, StackedInline, TabularInline, - action, autodiscover, display, register, site, -) -from django.contrib.gis.admin.options import ( - GeoModelAdmin, GISModelAdmin, OSMGeoAdmin, + HORIZONTAL, + VERTICAL, + 
AdminSite, + ModelAdmin, + StackedInline, + TabularInline, + action, + autodiscover, + display, + register, + site, ) +from django.contrib.gis.admin.options import GeoModelAdmin, GISModelAdmin, OSMGeoAdmin from django.contrib.gis.admin.widgets import OpenLayersWidget __all__ = [ - 'HORIZONTAL', 'VERTICAL', 'AdminSite', 'ModelAdmin', 'StackedInline', - 'TabularInline', 'action', 'autodiscover', 'display', 'register', 'site', - 'GISModelAdmin', 'OpenLayersWidget', + "HORIZONTAL", + "VERTICAL", + "AdminSite", + "ModelAdmin", + "StackedInline", + "TabularInline", + "action", + "autodiscover", + "display", + "register", + "site", + "GISModelAdmin", + "OpenLayersWidget", # RemovedInDjango50Warning. - 'GeoModelAdmin', 'OSMGeoAdmin', + "GeoModelAdmin", + "OSMGeoAdmin", ] diff --git a/django/contrib/gis/admin/options.py b/django/contrib/gis/admin/options.py index 524ba4bdc5..9eb6af6335 100644 --- a/django/contrib/gis/admin/options.py +++ b/django/contrib/gis/admin/options.py @@ -14,11 +14,10 @@ class GeoModelAdminMixin: gis_widget_kwargs = {} def formfield_for_dbfield(self, db_field, request, **kwargs): - if ( - isinstance(db_field, models.GeometryField) and - (db_field.dim < 3 or self.gis_widget.supports_3d) + if isinstance(db_field, models.GeometryField) and ( + db_field.dim < 3 or self.gis_widget.supports_3d ): - kwargs['widget'] = self.gis_widget(**self.gis_widget_kwargs) + kwargs["widget"] = self.gis_widget(**self.gis_widget_kwargs) return db_field.formfield(**kwargs) else: return super().formfield_for_dbfield(db_field, request, **kwargs) @@ -38,6 +37,7 @@ class GeoModelAdmin(ModelAdmin): The administration options class for Geographic models. Map settings may be overloaded from their defaults to create custom maps. """ + # The default map settings that may be overloaded -- still subject # to API changes. 
default_lon = 0 @@ -60,22 +60,25 @@ class GeoModelAdmin(ModelAdmin): map_width = 600 map_height = 400 map_srid = 4326 - map_template = 'gis/admin/openlayers.html' - openlayers_url = 'https://cdnjs.cloudflare.com/ajax/libs/openlayers/2.13.1/OpenLayers.js' + map_template = "gis/admin/openlayers.html" + openlayers_url = ( + "https://cdnjs.cloudflare.com/ajax/libs/openlayers/2.13.1/OpenLayers.js" + ) point_zoom = num_zoom - 6 - wms_url = 'http://vmap0.tiles.osgeo.org/wms/vmap0' - wms_layer = 'basic' - wms_name = 'OpenLayers WMS' - wms_options = {'format': 'image/jpeg'} + wms_url = "http://vmap0.tiles.osgeo.org/wms/vmap0" + wms_layer = "basic" + wms_name = "OpenLayers WMS" + wms_options = {"format": "image/jpeg"} debug = False widget = OpenLayersWidget def __init__(self, *args, **kwargs): warnings.warn( - 'django.contrib.gis.admin.GeoModelAdmin and OSMGeoAdmin are ' - 'deprecated in favor of django.contrib.admin.ModelAdmin and ' - 'django.contrib.gis.admin.GISModelAdmin.', - RemovedInDjango50Warning, stacklevel=2, + "django.contrib.gis.admin.GeoModelAdmin and OSMGeoAdmin are " + "deprecated in favor of django.contrib.admin.ModelAdmin and " + "django.contrib.gis.admin.GISModelAdmin.", + RemovedInDjango50Warning, + stacklevel=2, ) super().__init__(*args, **kwargs) @@ -92,7 +95,7 @@ class GeoModelAdmin(ModelAdmin): """ if isinstance(db_field, models.GeometryField) and db_field.dim < 3: # Setting the widget with the newly defined widget. - kwargs['widget'] = self.get_map_widget(db_field) + kwargs["widget"] = self.get_map_widget(db_field) return db_field.formfield(**kwargs) else: return super().formfield_for_dbfield(db_field, request, **kwargs) @@ -103,68 +106,75 @@ class GeoModelAdmin(ModelAdmin): in the `widget` attribute) using the settings from the attributes set in this class. 
""" - is_collection = db_field.geom_type in ('MULTIPOINT', 'MULTILINESTRING', 'MULTIPOLYGON', 'GEOMETRYCOLLECTION') + is_collection = db_field.geom_type in ( + "MULTIPOINT", + "MULTILINESTRING", + "MULTIPOLYGON", + "GEOMETRYCOLLECTION", + ) if is_collection: - if db_field.geom_type == 'GEOMETRYCOLLECTION': - collection_type = 'Any' + if db_field.geom_type == "GEOMETRYCOLLECTION": + collection_type = "Any" else: - collection_type = OGRGeomType(db_field.geom_type.replace('MULTI', '')) + collection_type = OGRGeomType(db_field.geom_type.replace("MULTI", "")) else: - collection_type = 'None' + collection_type = "None" class OLMap(self.widget): template_name = self.map_template geom_type = db_field.geom_type - wms_options = '' + wms_options = "" if self.wms_options: wms_options = ["%s: '%s'" % pair for pair in self.wms_options.items()] - wms_options = ', %s' % ', '.join(wms_options) + wms_options = ", %s" % ", ".join(wms_options) params = { - 'default_lon': self.default_lon, - 'default_lat': self.default_lat, - 'default_zoom': self.default_zoom, - 'display_wkt': self.debug or self.display_wkt, - 'geom_type': OGRGeomType(db_field.geom_type), - 'field_name': db_field.name, - 'is_collection': is_collection, - 'scrollable': self.scrollable, - 'layerswitcher': self.layerswitcher, - 'collection_type': collection_type, - 'is_generic': db_field.geom_type == 'GEOMETRY', - 'is_linestring': db_field.geom_type in ('LINESTRING', 'MULTILINESTRING'), - 'is_polygon': db_field.geom_type in ('POLYGON', 'MULTIPOLYGON'), - 'is_point': db_field.geom_type in ('POINT', 'MULTIPOINT'), - 'num_zoom': self.num_zoom, - 'max_zoom': self.max_zoom, - 'min_zoom': self.min_zoom, - 'units': self.units, # likely should get from object - 'max_resolution': self.max_resolution, - 'max_extent': self.max_extent, - 'modifiable': self.modifiable, - 'mouse_position': self.mouse_position, - 'scale_text': self.scale_text, - 'map_width': self.map_width, - 'map_height': self.map_height, - 'point_zoom': 
self.point_zoom, - 'srid': self.map_srid, - 'display_srid': self.display_srid, - 'wms_url': self.wms_url, - 'wms_layer': self.wms_layer, - 'wms_name': self.wms_name, - 'wms_options': wms_options, - 'debug': self.debug, + "default_lon": self.default_lon, + "default_lat": self.default_lat, + "default_zoom": self.default_zoom, + "display_wkt": self.debug or self.display_wkt, + "geom_type": OGRGeomType(db_field.geom_type), + "field_name": db_field.name, + "is_collection": is_collection, + "scrollable": self.scrollable, + "layerswitcher": self.layerswitcher, + "collection_type": collection_type, + "is_generic": db_field.geom_type == "GEOMETRY", + "is_linestring": db_field.geom_type + in ("LINESTRING", "MULTILINESTRING"), + "is_polygon": db_field.geom_type in ("POLYGON", "MULTIPOLYGON"), + "is_point": db_field.geom_type in ("POINT", "MULTIPOINT"), + "num_zoom": self.num_zoom, + "max_zoom": self.max_zoom, + "min_zoom": self.min_zoom, + "units": self.units, # likely should get from object + "max_resolution": self.max_resolution, + "max_extent": self.max_extent, + "modifiable": self.modifiable, + "mouse_position": self.mouse_position, + "scale_text": self.scale_text, + "map_width": self.map_width, + "map_height": self.map_height, + "point_zoom": self.point_zoom, + "srid": self.map_srid, + "display_srid": self.display_srid, + "wms_url": self.wms_url, + "wms_layer": self.wms_layer, + "wms_name": self.wms_name, + "wms_options": wms_options, + "debug": self.debug, } + return OLMap # RemovedInDjango50Warning. 
class OSMGeoAdmin(GeoModelAdmin): - map_template = 'gis/admin/osm.html' + map_template = "gis/admin/osm.html" num_zoom = 20 map_srid = spherical_mercator_srid - max_extent = '-20037508,-20037508,20037508,20037508' - max_resolution = '156543.0339' + max_extent = "-20037508,-20037508,20037508,20037508" + max_resolution = "156543.0339" point_zoom = num_zoom - 6 - units = 'm' + units = "m" diff --git a/django/contrib/gis/admin/widgets.py b/django/contrib/gis/admin/widgets.py index 1ce94ec7c2..420c170608 100644 --- a/django/contrib/gis/admin/widgets.py +++ b/django/contrib/gis/admin/widgets.py @@ -7,26 +7,27 @@ from django.utils import translation # Creating a template context that contains Django settings # values needed by admin map templates. -geo_context = {'LANGUAGE_BIDI': translation.get_language_bidi()} -logger = logging.getLogger('django.contrib.gis') +geo_context = {"LANGUAGE_BIDI": translation.get_language_bidi()} +logger = logging.getLogger("django.contrib.gis") class OpenLayersWidget(Textarea): """ Render an OpenLayers map using the WKT of the geometry. """ + def get_context(self, name, value, attrs): # Update the template parameters with any attributes passed in. if attrs: self.params.update(attrs) - self.params['editable'] = self.params['modifiable'] + self.params["editable"] = self.params["modifiable"] else: - self.params['editable'] = True + self.params["editable"] = True # Defaulting the WKT value to a blank string -- this # will be tested in the JavaScript and the appropriate # interface will be constructed. - self.params['wkt'] = '' + self.params["wkt"] = "" # If a string reaches here (via a validation error on another # field) then just reconstruct the Geometry. 
@@ -37,26 +38,29 @@ class OpenLayersWidget(Textarea): logger.error("Error creating geometry from value '%s' (%s)", value, err) value = None - if (value and value.geom_type.upper() != self.geom_type and - self.geom_type != 'GEOMETRY'): + if ( + value + and value.geom_type.upper() != self.geom_type + and self.geom_type != "GEOMETRY" + ): value = None # Constructing the dictionary of the map options. - self.params['map_options'] = self.map_options() + self.params["map_options"] = self.map_options() # Constructing the JavaScript module name using the name of # the GeometryField (passed in via the `attrs` keyword). # Use the 'name' attr for the field name (rather than 'field') - self.params['name'] = name + self.params["name"] = name # note: we must switch out dashes for underscores since js # functions are created using the module variable - js_safe_name = self.params['name'].replace('-', '_') - self.params['module'] = 'geodjango_%s' % js_safe_name + js_safe_name = self.params["name"].replace("-", "_") + self.params["module"] = "geodjango_%s" % js_safe_name if value: # Transforming the geometry to the projection used on the # OpenLayers map. - srid = self.params['srid'] + srid = self.params["srid"] if value.srid != srid: try: ogr = value.ogr @@ -65,15 +69,17 @@ class OpenLayersWidget(Textarea): except GDALException as err: logger.error( "Error transforming geometry from srid '%s' to srid '%s' (%s)", - value.srid, srid, err + value.srid, + srid, + err, ) - wkt = '' + wkt = "" else: wkt = value.wkt # Setting the parameter WKT with that of the transformed # geometry. - self.params['wkt'] = wkt + self.params["wkt"] = wkt self.params.update(geo_context) return self.params @@ -82,30 +88,31 @@ class OpenLayersWidget(Textarea): """Build the map options hash for the OpenLayers template.""" # JavaScript construction utilities for the Bounds and Projection. 
def ol_bounds(extent): - return 'new OpenLayers.Bounds(%s)' % extent + return "new OpenLayers.Bounds(%s)" % extent def ol_projection(srid): return 'new OpenLayers.Projection("EPSG:%s")' % srid # An array of the parameter name, the name of their OpenLayers # counterpart, and the type of variable they are. - map_types = [('srid', 'projection', 'srid'), - ('display_srid', 'displayProjection', 'srid'), - ('units', 'units', str), - ('max_resolution', 'maxResolution', float), - ('max_extent', 'maxExtent', 'bounds'), - ('num_zoom', 'numZoomLevels', int), - ('max_zoom', 'maxZoomLevels', int), - ('min_zoom', 'minZoomLevel', int), - ] + map_types = [ + ("srid", "projection", "srid"), + ("display_srid", "displayProjection", "srid"), + ("units", "units", str), + ("max_resolution", "maxResolution", float), + ("max_extent", "maxExtent", "bounds"), + ("num_zoom", "numZoomLevels", int), + ("max_zoom", "maxZoomLevels", int), + ("min_zoom", "minZoomLevel", int), + ] # Building the map options hash. map_options = {} for param_name, js_name, option_type in map_types: if self.params.get(param_name, False): - if option_type == 'srid': + if option_type == "srid": value = ol_projection(self.params[param_name]) - elif option_type == 'bounds': + elif option_type == "bounds": value = ol_bounds(self.params[param_name]) elif option_type in (float, int): value = self.params[param_name] diff --git a/django/contrib/gis/apps.py b/django/contrib/gis/apps.py index e582e76760..6282501056 100644 --- a/django/contrib/gis/apps.py +++ b/django/contrib/gis/apps.py @@ -4,9 +4,11 @@ from django.utils.translation import gettext_lazy as _ class GISConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.gis' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.gis" verbose_name = _("GIS") def ready(self): - serializers.BUILTIN_SERIALIZERS.setdefault('geojson', 'django.contrib.gis.serializers.geojson') + serializers.BUILTIN_SERIALIZERS.setdefault( 
+ "geojson", "django.contrib.gis.serializers.geojson" + ) diff --git a/django/contrib/gis/db/backends/base/adapter.py b/django/contrib/gis/db/backends/base/adapter.py index f6f271915d..b472e3aacb 100644 --- a/django/contrib/gis/db/backends/base/adapter.py +++ b/django/contrib/gis/db/backends/base/adapter.py @@ -2,14 +2,16 @@ class WKTAdapter: """ An adaptor for Geometries sent to the MySQL and Oracle database backends. """ + def __init__(self, geom): self.wkt = geom.wkt self.srid = geom.srid def __eq__(self, other): return ( - isinstance(other, WKTAdapter) and - self.wkt == other.wkt and self.srid == other.srid + isinstance(other, WKTAdapter) + and self.wkt == other.wkt + and self.srid == other.srid ) def __hash__(self): diff --git a/django/contrib/gis/db/backends/base/features.py b/django/contrib/gis/db/backends/base/features.py index cf52c4e122..cc4ce1046b 100644 --- a/django/contrib/gis/db/backends/base/features.py +++ b/django/contrib/gis/db/backends/base/features.py @@ -60,15 +60,15 @@ class BaseSpatialFeatures: @property def supports_bbcontains_lookup(self): - return 'bbcontains' in self.connection.ops.gis_operators + return "bbcontains" in self.connection.ops.gis_operators @property def supports_contained_lookup(self): - return 'contained' in self.connection.ops.gis_operators + return "contained" in self.connection.ops.gis_operators @property def supports_crosses_lookup(self): - return 'crosses' in self.connection.ops.gis_operators + return "crosses" in self.connection.ops.gis_operators @property def supports_distances_lookups(self): @@ -76,11 +76,11 @@ class BaseSpatialFeatures: @property def supports_dwithin_lookup(self): - return 'dwithin' in self.connection.ops.gis_operators + return "dwithin" in self.connection.ops.gis_operators @property def supports_relate_lookup(self): - return 'relate' in self.connection.ops.gis_operators + return "relate" in self.connection.ops.gis_operators @property def supports_isvalid_lookup(self): @@ -104,7 +104,7 @@ class 
BaseSpatialFeatures: return models.Union not in self.connection.ops.disallowed_aggregates def __getattr__(self, name): - m = re.match(r'has_(\w*)_function$', name) + m = re.match(r"has_(\w*)_function$", name) if m: func_name = m[1] return func_name not in self.connection.ops.unsupported_functions diff --git a/django/contrib/gis/db/backends/base/models.py b/django/contrib/gis/db/backends/base/models.py index c25c1ccdee..589c872da6 100644 --- a/django/contrib/gis/db/backends/base/models.py +++ b/django/contrib/gis/db/backends/base/models.py @@ -6,13 +6,14 @@ class SpatialRefSysMixin: The SpatialRefSysMixin is a class used by the database-dependent SpatialRefSys objects to reduce redundant code. """ + @property def srs(self): """ Return a GDAL SpatialReference object. """ # TODO: Is caching really necessary here? Is complexity worth it? - if hasattr(self, '_srs'): + if hasattr(self, "_srs"): # Returning a clone of the cached SpatialReference object. return self._srs.clone() else: @@ -31,7 +32,10 @@ class SpatialRefSysMixin: except Exception as e: msg = e - raise Exception('Could not get OSR SpatialReference from WKT: %s\nError:\n%s' % (self.wkt, msg)) + raise Exception( + "Could not get OSR SpatialReference from WKT: %s\nError:\n%s" + % (self.wkt, msg) + ) @property def ellipsoid(self): @@ -49,12 +53,12 @@ class SpatialRefSysMixin: @property def spheroid(self): "Return the spheroid name for this spatial reference." - return self.srs['spheroid'] + return self.srs["spheroid"] @property def datum(self): "Return the datum for this spatial reference." 
- return self.srs['datum'] + return self.srs["datum"] @property def projected(self): @@ -117,7 +121,7 @@ class SpatialRefSysMixin: """ srs = gdal.SpatialReference(wkt) sphere_params = srs.ellipsoid - sphere_name = srs['spheroid'] + sphere_name = srs["spheroid"] if not string: return sphere_name, sphere_params diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index 84b1785a1c..5e56b82a78 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -1,8 +1,7 @@ from django.contrib.gis.db.models import GeometryField from django.contrib.gis.db.models.functions import Distance -from django.contrib.gis.measure import ( - Area as AreaMeasure, Distance as DistanceMeasure, -) +from django.contrib.gis.measure import Area as AreaMeasure +from django.contrib.gis.measure import Distance as DistanceMeasure from django.db import NotSupportedError from django.utils.functional import cached_property @@ -18,7 +17,7 @@ class BaseSpatialOperations: spatial_version = None # How the geometry column should be selected. - select = '%s' + select = "%s" @cached_property def select_extent(self): @@ -27,7 +26,7 @@ class BaseSpatialOperations: # Aggregates disallowed_aggregates = () - geom_func_prefix = '' + geom_func_prefix = "" # Mapping between Django function names and backend names, when names do not # match; used in spatial_function_name(). 
@@ -35,12 +34,36 @@ class BaseSpatialOperations: # Set of known unsupported functions of the backend unsupported_functions = { - 'Area', 'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'Azimuth', - 'BoundingCircle', 'Centroid', 'Difference', 'Distance', 'Envelope', - 'GeoHash', 'GeometryDistance', 'Intersection', 'IsValid', 'Length', - 'LineLocatePoint', 'MakeValid', 'MemSize', 'NumGeometries', - 'NumPoints', 'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', - 'SnapToGrid', 'SymDifference', 'Transform', 'Translate', 'Union', + "Area", + "AsGeoJSON", + "AsGML", + "AsKML", + "AsSVG", + "Azimuth", + "BoundingCircle", + "Centroid", + "Difference", + "Distance", + "Envelope", + "GeoHash", + "GeometryDistance", + "Intersection", + "IsValid", + "Length", + "LineLocatePoint", + "MakeValid", + "MemSize", + "NumGeometries", + "NumPoints", + "Perimeter", + "PointOnSurface", + "Reverse", + "Scale", + "SnapToGrid", + "SymDifference", + "Transform", + "Translate", + "Union", } # Constructors @@ -49,10 +72,14 @@ class BaseSpatialOperations: # Default conversion functions for aggregates; will be overridden if implemented # for the spatial backend. def convert_extent(self, box, srid): - raise NotImplementedError('Aggregate extent not implemented for this spatial backend.') + raise NotImplementedError( + "Aggregate extent not implemented for this spatial backend." + ) def convert_extent3d(self, box, srid): - raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.') + raise NotImplementedError( + "Aggregate 3D extent not implemented for this spatial backend." + ) # For quoting column values, rather than columns. def geo_quote_name(self, name): @@ -64,14 +91,18 @@ class BaseSpatialOperations: Return the database column type for the geometry field on the spatial backend. 
""" - raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_type() method') + raise NotImplementedError( + "subclasses of BaseSpatialOperations must provide a geo_db_type() method" + ) def get_distance(self, f, value, lookup_type): """ Return the distance parameters for the given geometry field, lookup value, and lookup type. """ - raise NotImplementedError('Distance operations not available on this spatial backend.') + raise NotImplementedError( + "Distance operations not available on this spatial backend." + ) def get_geom_placeholder(self, f, value, compiler): """ @@ -80,48 +111,60 @@ class BaseSpatialOperations: stored procedure call to the transformation function of the spatial backend. """ + def transform_value(value, field): return value is not None and value.srid != field.srid - if hasattr(value, 'as_sql'): + if hasattr(value, "as_sql"): return ( - '%s(%%s, %s)' % (self.spatial_function_name('Transform'), f.srid) + "%s(%%s, %s)" % (self.spatial_function_name("Transform"), f.srid) if transform_value(value.output_field, f) - else '%s' + else "%s" ) if transform_value(value, f): # Add Transform() to the SQL placeholder. - return '%s(%s(%%s,%s), %s)' % ( - self.spatial_function_name('Transform'), - self.from_text, value.srid, f.srid, + return "%s(%s(%%s,%s), %s)" % ( + self.spatial_function_name("Transform"), + self.from_text, + value.srid, + f.srid, ) elif self.connection.features.has_spatialrefsys_table: - return '%s(%%s,%s)' % (self.from_text, f.srid) + return "%s(%%s,%s)" % (self.from_text, f.srid) else: # For backwards compatibility on MySQL (#27464). - return '%s(%%s)' % self.from_text + return "%s(%%s)" % self.from_text def check_expression_support(self, expression): if isinstance(expression, self.disallowed_aggregates): raise NotSupportedError( - "%s spatial aggregation is not supported by this database backend." % expression.name + "%s spatial aggregation is not supported by this database backend." 
+ % expression.name ) super().check_expression_support(expression) def spatial_aggregate_name(self, agg_name): - raise NotImplementedError('Aggregate support not implemented for this spatial backend.') + raise NotImplementedError( + "Aggregate support not implemented for this spatial backend." + ) def spatial_function_name(self, func_name): if func_name in self.unsupported_functions: - raise NotSupportedError("This backend doesn't support the %s function." % func_name) + raise NotSupportedError( + "This backend doesn't support the %s function." % func_name + ) return self.function_names.get(func_name, self.geom_func_prefix + func_name) # Routines for getting the OGC-compliant models. def geometry_columns(self): - raise NotImplementedError('Subclasses of BaseSpatialOperations must provide a geometry_columns() method.') + raise NotImplementedError( + "Subclasses of BaseSpatialOperations must provide a geometry_columns() method." + ) def spatial_ref_sys(self): - raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method') + raise NotImplementedError( + "subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method" + ) distance_expr_for_lookup = staticmethod(Distance) @@ -133,15 +176,17 @@ class BaseSpatialOperations: def get_geometry_converter(self, expression): raise NotImplementedError( - 'Subclasses of BaseSpatialOperations must provide a ' - 'get_geometry_converter() method.' + "Subclasses of BaseSpatialOperations must provide a " + "get_geometry_converter() method." ) def get_area_att_for_field(self, field): if field.geodetic(self.connection): if self.connection.features.supports_area_geodetic: - return 'sq_m' - raise NotImplementedError('Area on geodetic coordinate systems not supported.') + return "sq_m" + raise NotImplementedError( + "Area on geodetic coordinate systems not supported." 
+ ) else: units_name = field.units_name(self.connection) if units_name: @@ -151,7 +196,7 @@ class BaseSpatialOperations: dist_att = None if field.geodetic(self.connection): if self.connection.features.supports_distance_geodetic: - dist_att = 'm' + dist_att = "m" else: units = field.units_name(self.connection) if units: diff --git a/django/contrib/gis/db/backends/mysql/base.py b/django/contrib/gis/db/backends/mysql/base.py index fccea5919d..4abc052e6e 100644 --- a/django/contrib/gis/db/backends/mysql/base.py +++ b/django/contrib/gis/db/backends/mysql/base.py @@ -1,6 +1,4 @@ -from django.db.backends.mysql.base import ( - DatabaseWrapper as MySQLDatabaseWrapper, -) +from django.db.backends.mysql.base import DatabaseWrapper as MySQLDatabaseWrapper from .features import DatabaseFeatures from .introspection import MySQLIntrospection diff --git a/django/contrib/gis/db/backends/mysql/features.py b/django/contrib/gis/db/backends/mysql/features.py index 613ccce9c5..8999a38bf9 100644 --- a/django/contrib/gis/db/backends/mysql/features.py +++ b/django/contrib/gis/db/backends/mysql/features.py @@ -1,7 +1,5 @@ from django.contrib.gis.db.backends.base.features import BaseSpatialFeatures -from django.db.backends.mysql.features import ( - DatabaseFeatures as MySQLDatabaseFeatures, -) +from django.db.backends.mysql.features import DatabaseFeatures as MySQLDatabaseFeatures from django.utils.functional import cached_property @@ -14,13 +12,13 @@ class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures): supports_transform = False supports_null_geometries = False supports_num_points_poly = False - unsupported_geojson_options = {'crs'} + unsupported_geojson_options = {"crs"} @cached_property def empty_intersection_returns_none(self): return ( - not self.connection.mysql_is_mariadb and - self.connection.mysql_version < (5, 7, 5) + not self.connection.mysql_is_mariadb + and self.connection.mysql_version < (5, 7, 5) ) @cached_property @@ -31,13 +29,16 @@ class 
DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures): @cached_property def django_test_skips(self): skips = super().django_test_skips - if ( - not self.connection.mysql_is_mariadb and - self.connection.mysql_version < (8, 0, 0) + if not self.connection.mysql_is_mariadb and self.connection.mysql_version < ( + 8, + 0, + 0, ): - skips.update({ - 'MySQL < 8 gives different results.': { - 'gis_tests.geoapp.tests.GeoLookupTest.test_disjoint_lookup', - }, - }) + skips.update( + { + "MySQL < 8 gives different results.": { + "gis_tests.geoapp.tests.GeoLookupTest.test_disjoint_lookup", + }, + } + ) return skips diff --git a/django/contrib/gis/db/backends/mysql/introspection.py b/django/contrib/gis/db/backends/mysql/introspection.py index ad78595e24..7497ad520c 100644 --- a/django/contrib/gis/db/backends/mysql/introspection.py +++ b/django/contrib/gis/db/backends/mysql/introspection.py @@ -8,14 +8,13 @@ class MySQLIntrospection(DatabaseIntrospection): # Updating the data_types_reverse dictionary with the appropriate # type for Geometry fields. data_types_reverse = DatabaseIntrospection.data_types_reverse.copy() - data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField' + data_types_reverse[FIELD_TYPE.GEOMETRY] = "GeometryField" def get_geometry_type(self, table_name, description): with self.connection.cursor() as cursor: # In order to get the specific geometry type of the field, # we introspect on the table definition using `DESCRIBE`. - cursor.execute('DESCRIBE %s' % - self.connection.ops.quote_name(table_name)) + cursor.execute("DESCRIBE %s" % self.connection.ops.quote_name(table_name)) # Increment over description info until we get to the geometry # column. for column, typ, null, key, default, extra in cursor.fetchall(): @@ -31,8 +30,8 @@ class MySQLIntrospection(DatabaseIntrospection): def supports_spatial_index(self, cursor, table_name): # Supported with MyISAM/Aria, or InnoDB on MySQL 5.7.5+/MariaDB. 
storage_engine = self.get_storage_engine(cursor, table_name) - if storage_engine == 'InnoDB': + if storage_engine == "InnoDB": if self.connection.mysql_is_mariadb: return True return self.connection.mysql_version >= (5, 7, 5) - return storage_engine in ('MyISAM', 'Aria') + return storage_engine in ("MyISAM", "Aria") diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 87ab371aa3..4ef014b36b 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -1,8 +1,6 @@ from django.contrib.gis.db import models from django.contrib.gis.db.backends.base.adapter import WKTAdapter -from django.contrib.gis.db.backends.base.operations import ( - BaseSpatialOperations, -) +from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.prototypes.io import wkb_r @@ -12,8 +10,8 @@ from django.utils.functional import cached_property class MySQLOperations(BaseSpatialOperations, DatabaseOperations): - name = 'mysql' - geom_func_prefix = 'ST_' + name = "mysql" + geom_func_prefix = "ST_" Adapter = WKTAdapter @@ -27,51 +25,69 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): @cached_property def select(self): - return self.geom_func_prefix + 'AsBinary(%s)' + return self.geom_func_prefix + "AsBinary(%s)" @cached_property def from_text(self): - return self.geom_func_prefix + 'GeomFromText' + return self.geom_func_prefix + "GeomFromText" @cached_property def gis_operators(self): operators = { - 'bbcontains': SpatialOperator(func='MBRContains'), # For consistency w/PostGIS API - 'bboverlaps': SpatialOperator(func='MBROverlaps'), # ... - 'contained': SpatialOperator(func='MBRWithin'), # ... 
- 'contains': SpatialOperator(func='ST_Contains'), - 'crosses': SpatialOperator(func='ST_Crosses'), - 'disjoint': SpatialOperator(func='ST_Disjoint'), - 'equals': SpatialOperator(func='ST_Equals'), - 'exact': SpatialOperator(func='ST_Equals'), - 'intersects': SpatialOperator(func='ST_Intersects'), - 'overlaps': SpatialOperator(func='ST_Overlaps'), - 'same_as': SpatialOperator(func='ST_Equals'), - 'touches': SpatialOperator(func='ST_Touches'), - 'within': SpatialOperator(func='ST_Within'), + "bbcontains": SpatialOperator( + func="MBRContains" + ), # For consistency w/PostGIS API + "bboverlaps": SpatialOperator(func="MBROverlaps"), # ... + "contained": SpatialOperator(func="MBRWithin"), # ... + "contains": SpatialOperator(func="ST_Contains"), + "crosses": SpatialOperator(func="ST_Crosses"), + "disjoint": SpatialOperator(func="ST_Disjoint"), + "equals": SpatialOperator(func="ST_Equals"), + "exact": SpatialOperator(func="ST_Equals"), + "intersects": SpatialOperator(func="ST_Intersects"), + "overlaps": SpatialOperator(func="ST_Overlaps"), + "same_as": SpatialOperator(func="ST_Equals"), + "touches": SpatialOperator(func="ST_Touches"), + "within": SpatialOperator(func="ST_Within"), } if self.connection.mysql_is_mariadb: - operators['relate'] = SpatialOperator(func='ST_Relate') + operators["relate"] = SpatialOperator(func="ST_Relate") return operators disallowed_aggregates = ( - models.Collect, models.Extent, models.Extent3D, models.MakeLine, + models.Collect, + models.Extent, + models.Extent3D, + models.MakeLine, models.Union, ) @cached_property def unsupported_functions(self): unsupported = { - 'AsGML', 'AsKML', 'AsSVG', 'Azimuth', 'BoundingCircle', - 'ForcePolygonCW', 'GeometryDistance', 'LineLocatePoint', - 'MakeValid', 'MemSize', 'Perimeter', 'PointOnSurface', 'Reverse', - 'Scale', 'SnapToGrid', 'Transform', 'Translate', + "AsGML", + "AsKML", + "AsSVG", + "Azimuth", + "BoundingCircle", + "ForcePolygonCW", + "GeometryDistance", + "LineLocatePoint", + "MakeValid", + 
"MemSize", + "Perimeter", + "PointOnSurface", + "Reverse", + "Scale", + "SnapToGrid", + "Transform", + "Translate", } if self.connection.mysql_is_mariadb: - unsupported.remove('PointOnSurface') - unsupported.update({'GeoHash', 'IsValid'}) + unsupported.remove("PointOnSurface") + unsupported.update({"GeoHash", "IsValid"}) elif self.connection.mysql_version < (5, 7, 5): - unsupported.update({'AsGeoJSON', 'GeoHash', 'IsValid'}) + unsupported.update({"AsGeoJSON", "GeoHash", "IsValid"}) return unsupported def geo_db_type(self, f): @@ -82,10 +98,12 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): if isinstance(value, Distance): if f.geodetic(self.connection): raise ValueError( - 'Only numeric values of degree units are allowed on ' - 'geodetic distance queries.' + "Only numeric values of degree units are allowed on " + "geodetic distance queries." ) - dist_param = getattr(value, Distance.unit_attname(f.units_name(self.connection))) + dist_param = getattr( + value, Distance.unit_attname(f.units_name(self.connection)) + ) else: dist_param = value return [dist_param] @@ -103,4 +121,5 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): if srid: geom.srid = srid return geom + return converter diff --git a/django/contrib/gis/db/backends/mysql/schema.py b/django/contrib/gis/db/backends/mysql/schema.py index da033ce5d6..b173df0198 100644 --- a/django/contrib/gis/db/backends/mysql/schema.py +++ b/django/contrib/gis/db/backends/mysql/schema.py @@ -4,12 +4,12 @@ from django.contrib.gis.db.models import GeometryField from django.db import OperationalError from django.db.backends.mysql.schema import DatabaseSchemaEditor -logger = logging.getLogger('django.contrib.gis') +logger = logging.getLogger("django.contrib.gis") class MySQLGISSchemaEditor(DatabaseSchemaEditor): - sql_add_spatial_index = 'CREATE SPATIAL INDEX %(index)s ON %(table)s(%(column)s)' - sql_drop_spatial_index = 'DROP INDEX %(index)s ON %(table)s' + sql_add_spatial_index = "CREATE 
SPATIAL INDEX %(index)s ON %(table)s(%(column)s)" + sql_drop_spatial_index = "DROP INDEX %(index)s ON %(table)s" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -18,7 +18,10 @@ class MySQLGISSchemaEditor(DatabaseSchemaEditor): def skip_default(self, field): # Geometry fields are stored as BLOB/TEXT, for which MySQL < 8.0.13 # doesn't support defaults. - if isinstance(field, GeometryField) and not self._supports_limited_data_type_defaults: + if ( + isinstance(field, GeometryField) + and not self._supports_limited_data_type_defaults + ): return True return super().skip_default(field) @@ -34,10 +37,11 @@ class MySQLGISSchemaEditor(DatabaseSchemaEditor): qn = self.connection.ops.quote_name db_table = model._meta.db_table self.geometry_sql.append( - self.sql_add_spatial_index % { - 'index': qn(self._create_spatial_index_name(model, field)), - 'table': qn(db_table), - 'column': qn(field.column), + self.sql_add_spatial_index + % { + "index": qn(self._create_spatial_index_name(model, field)), + "table": qn(db_table), + "column": qn(field.column), } ) return column_sql @@ -54,21 +58,22 @@ class MySQLGISSchemaEditor(DatabaseSchemaEditor): if isinstance(field, GeometryField) and field.spatial_index: qn = self.connection.ops.quote_name sql = self.sql_drop_spatial_index % { - 'index': qn(self._create_spatial_index_name(model, field)), - 'table': qn(model._meta.db_table), + "index": qn(self._create_spatial_index_name(model, field)), + "table": qn(model._meta.db_table), } try: self.execute(sql) except OperationalError: logger.error( "Couldn't remove spatial index: %s (may be expected " - "if your storage engine doesn't support them).", sql + "if your storage engine doesn't support them).", + sql, ) super().remove_field(model, field) def _create_spatial_index_name(self, model, field): - return '%s_%s_id' % (model._meta.db_table, field.column) + return "%s_%s_id" % (model._meta.db_table, field.column) def create_spatial_indexes(self): for sql in 
self.geometry_sql: @@ -77,6 +82,7 @@ class MySQLGISSchemaEditor(DatabaseSchemaEditor): except OperationalError: logger.error( "Cannot create SPATIAL INDEX %s. Only MyISAM and (as of " - "MySQL 5.7.5) InnoDB support them.", sql + "MySQL 5.7.5) InnoDB support them.", + sql, ) self.geometry_sql = [] diff --git a/django/contrib/gis/db/backends/oracle/adapter.py b/django/contrib/gis/db/backends/oracle/adapter.py index 3f96178091..7556b121b8 100644 --- a/django/contrib/gis/db/backends/oracle/adapter.py +++ b/django/contrib/gis/db/backends/oracle/adapter.py @@ -19,7 +19,9 @@ class OracleSpatialAdapter(WKTAdapter): if self._polygon_must_be_fixed(geom): geom = self._fix_polygon(geom) elif isinstance(geom, GeometryCollection): - if any(isinstance(g, Polygon) and self._polygon_must_be_fixed(g) for g in geom): + if any( + isinstance(g, Polygon) and self._polygon_must_be_fixed(g) for g in geom + ): geom = self._fix_geometry_collection(geom) self.wkt = geom.wkt @@ -27,12 +29,9 @@ class OracleSpatialAdapter(WKTAdapter): @staticmethod def _polygon_must_be_fixed(poly): - return ( - not poly.empty and - ( - not poly.exterior_ring.is_counterclockwise or - any(x.is_counterclockwise for x in poly) - ) + return not poly.empty and ( + not poly.exterior_ring.is_counterclockwise + or any(x.is_counterclockwise for x in poly) ) @classmethod diff --git a/django/contrib/gis/db/backends/oracle/base.py b/django/contrib/gis/db/backends/oracle/base.py index 0093ef83bb..d43516a72f 100644 --- a/django/contrib/gis/db/backends/oracle/base.py +++ b/django/contrib/gis/db/backends/oracle/base.py @@ -1,6 +1,4 @@ -from django.db.backends.oracle.base import ( - DatabaseWrapper as OracleDatabaseWrapper, -) +from django.db.backends.oracle.base import DatabaseWrapper as OracleDatabaseWrapper from .features import DatabaseFeatures from .introspection import OracleIntrospection diff --git a/django/contrib/gis/db/backends/oracle/features.py b/django/contrib/gis/db/backends/oracle/features.py index 
2ba951de42..ee4b296c07 100644 --- a/django/contrib/gis/db/backends/oracle/features.py +++ b/django/contrib/gis/db/backends/oracle/features.py @@ -12,14 +12,16 @@ class DatabaseFeatures(BaseSpatialFeatures, OracleDatabaseFeatures): supports_perimeter_geodetic = True supports_dwithin_distance_expr = False supports_tolerance_parameter = True - unsupported_geojson_options = {'bbox', 'crs', 'precision'} + unsupported_geojson_options = {"bbox", "crs", "precision"} @cached_property def django_test_skips(self): skips = super().django_test_skips - skips.update({ - "Oracle doesn't support spatial operators in constraints.": { - 'gis_tests.gis_migrations.test_operations.OperationTests.test_add_check_constraint', - }, - }) + skips.update( + { + "Oracle doesn't support spatial operators in constraints.": { + "gis_tests.gis_migrations.test_operations.OperationTests.test_add_check_constraint", + }, + } + ) return skips diff --git a/django/contrib/gis/db/backends/oracle/introspection.py b/django/contrib/gis/db/backends/oracle/introspection.py index 493fe74b46..096fee54e1 100644 --- a/django/contrib/gis/db/backends/oracle/introspection.py +++ b/django/contrib/gis/db/backends/oracle/introspection.py @@ -12,7 +12,7 @@ class OracleIntrospection(DatabaseIntrospection): def data_types_reverse(self): return { **super().data_types_reverse, - cx_Oracle.OBJECT: 'GeometryField', + cx_Oracle.OBJECT: "GeometryField", } def get_geometry_type(self, table_name, description): @@ -22,26 +22,26 @@ class OracleIntrospection(DatabaseIntrospection): cursor.execute( 'SELECT "DIMINFO", "SRID" FROM "USER_SDO_GEOM_METADATA" ' 'WHERE "TABLE_NAME"=%s AND "COLUMN_NAME"=%s', - (table_name.upper(), description.name.upper()) + (table_name.upper(), description.name.upper()), ) row = cursor.fetchone() except Exception as exc: raise Exception( - 'Could not find entry in USER_SDO_GEOM_METADATA ' + "Could not find entry in USER_SDO_GEOM_METADATA " 'corresponding to "%s"."%s"' % (table_name, description.name) ) from 
exc # TODO: Research way to find a more specific geometry field type for # the column's contents. - field_type = 'GeometryField' + field_type = "GeometryField" # Getting the field parameters. field_params = {} dim, srid = row if srid != 4326: - field_params['srid'] = srid + field_params["srid"] = srid # Size of object array (SDO_DIM_ARRAY) is number of dimensions. dim = dim.size() if dim != 2: - field_params['dim'] = dim + field_params["dim"] = dim return field_type, field_params diff --git a/django/contrib/gis/db/backends/oracle/models.py b/django/contrib/gis/db/backends/oracle/models.py index 6876eecefd..f06f73148e 100644 --- a/django/contrib/gis/db/backends/oracle/models.py +++ b/django/contrib/gis/db/backends/oracle/models.py @@ -19,12 +19,12 @@ class OracleGeometryColumns(models.Model): # TODO: Add support for `diminfo` column (type MDSYS.SDO_DIM_ARRAY). class Meta: - app_label = 'gis' - db_table = 'USER_SDO_GEOM_METADATA' + app_label = "gis" + db_table = "USER_SDO_GEOM_METADATA" managed = False def __str__(self): - return '%s - %s (SRID: %s)' % (self.table_name, self.column_name, self.srid) + return "%s - %s (SRID: %s)" % (self.table_name, self.column_name, self.srid) @classmethod def table_name_col(cls): @@ -32,7 +32,7 @@ class OracleGeometryColumns(models.Model): Return the name of the metadata column used to store the feature table name. """ - return 'table_name' + return "table_name" @classmethod def geom_col_name(cls): @@ -40,7 +40,7 @@ class OracleGeometryColumns(models.Model): Return the name of the metadata column used to store the feature geometry column. 
""" - return 'column_name' + return "column_name" class OracleSpatialRefSys(models.Model, SpatialRefSysMixin): @@ -55,8 +55,8 @@ class OracleSpatialRefSys(models.Model, SpatialRefSysMixin): cs_bounds = models.PolygonField(null=True) class Meta: - app_label = 'gis' - db_table = 'CS_SRS' + app_label = "gis" + db_table = "CS_SRS" managed = False @property diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 9b363aaf4e..45010ef517 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -10,9 +10,7 @@ import re from django.contrib.gis.db import models -from django.contrib.gis.db.backends.base.operations import ( - BaseSpatialOperations, -) +from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase @@ -20,7 +18,7 @@ from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.measure import Distance from django.db.backends.oracle.operations import DatabaseOperations -DEFAULT_TOLERANCE = '0.05' +DEFAULT_TOLERANCE = "0.05" class SDOOperator(SpatialOperator): @@ -32,57 +30,60 @@ class SDODWithin(SpatialOperator): class SDODisjoint(SpatialOperator): - sql_template = "SDO_GEOM.RELATE(%%(lhs)s, 'DISJOINT', %%(rhs)s, %s) = 'DISJOINT'" % DEFAULT_TOLERANCE + sql_template = ( + "SDO_GEOM.RELATE(%%(lhs)s, 'DISJOINT', %%(rhs)s, %s) = 'DISJOINT'" + % DEFAULT_TOLERANCE + ) class SDORelate(SpatialOperator): sql_template = "SDO_RELATE(%(lhs)s, %(rhs)s, 'mask=%(mask)s') = 'TRUE'" def check_relate_argument(self, arg): - masks = 'TOUCH|OVERLAPBDYDISJOINT|OVERLAPBDYINTERSECT|EQUAL|INSIDE|COVEREDBY|CONTAINS|COVERS|ANYINTERACT|ON' - mask_regex = re.compile(r'^(%s)(\+(%s))*$' % (masks, masks), 
re.I) + masks = "TOUCH|OVERLAPBDYDISJOINT|OVERLAPBDYINTERSECT|EQUAL|INSIDE|COVEREDBY|CONTAINS|COVERS|ANYINTERACT|ON" + mask_regex = re.compile(r"^(%s)(\+(%s))*$" % (masks, masks), re.I) if not isinstance(arg, str) or not mask_regex.match(arg): raise ValueError('Invalid SDO_RELATE mask: "%s"' % arg) def as_sql(self, connection, lookup, template_params, sql_params): - template_params['mask'] = sql_params[-1] + template_params["mask"] = sql_params[-1] return super().as_sql(connection, lookup, template_params, sql_params[:-1]) class OracleOperations(BaseSpatialOperations, DatabaseOperations): - name = 'oracle' + name = "oracle" oracle = True disallowed_aggregates = (models.Collect, models.Extent3D, models.MakeLine) Adapter = OracleSpatialAdapter - extent = 'SDO_AGGR_MBR' - unionagg = 'SDO_AGGR_UNION' + extent = "SDO_AGGR_MBR" + unionagg = "SDO_AGGR_UNION" - from_text = 'SDO_GEOMETRY' + from_text = "SDO_GEOMETRY" function_names = { - 'Area': 'SDO_GEOM.SDO_AREA', - 'AsGeoJSON': 'SDO_UTIL.TO_GEOJSON', - 'AsWKB': 'SDO_UTIL.TO_WKBGEOMETRY', - 'AsWKT': 'SDO_UTIL.TO_WKTGEOMETRY', - 'BoundingCircle': 'SDO_GEOM.SDO_MBC', - 'Centroid': 'SDO_GEOM.SDO_CENTROID', - 'Difference': 'SDO_GEOM.SDO_DIFFERENCE', - 'Distance': 'SDO_GEOM.SDO_DISTANCE', - 'Envelope': 'SDO_GEOM_MBR', - 'Intersection': 'SDO_GEOM.SDO_INTERSECTION', - 'IsValid': 'SDO_GEOM.VALIDATE_GEOMETRY_WITH_CONTEXT', - 'Length': 'SDO_GEOM.SDO_LENGTH', - 'NumGeometries': 'SDO_UTIL.GETNUMELEM', - 'NumPoints': 'SDO_UTIL.GETNUMVERTICES', - 'Perimeter': 'SDO_GEOM.SDO_LENGTH', - 'PointOnSurface': 'SDO_GEOM.SDO_POINTONSURFACE', - 'Reverse': 'SDO_UTIL.REVERSE_LINESTRING', - 'SymDifference': 'SDO_GEOM.SDO_XOR', - 'Transform': 'SDO_CS.TRANSFORM', - 'Union': 'SDO_GEOM.SDO_UNION', + "Area": "SDO_GEOM.SDO_AREA", + "AsGeoJSON": "SDO_UTIL.TO_GEOJSON", + "AsWKB": "SDO_UTIL.TO_WKBGEOMETRY", + "AsWKT": "SDO_UTIL.TO_WKTGEOMETRY", + "BoundingCircle": "SDO_GEOM.SDO_MBC", + "Centroid": "SDO_GEOM.SDO_CENTROID", + "Difference": 
"SDO_GEOM.SDO_DIFFERENCE", + "Distance": "SDO_GEOM.SDO_DISTANCE", + "Envelope": "SDO_GEOM_MBR", + "Intersection": "SDO_GEOM.SDO_INTERSECTION", + "IsValid": "SDO_GEOM.VALIDATE_GEOMETRY_WITH_CONTEXT", + "Length": "SDO_GEOM.SDO_LENGTH", + "NumGeometries": "SDO_UTIL.GETNUMELEM", + "NumPoints": "SDO_UTIL.GETNUMVERTICES", + "Perimeter": "SDO_GEOM.SDO_LENGTH", + "PointOnSurface": "SDO_GEOM.SDO_POINTONSURFACE", + "Reverse": "SDO_UTIL.REVERSE_LINESTRING", + "SymDifference": "SDO_GEOM.SDO_XOR", + "Transform": "SDO_CS.TRANSFORM", + "Union": "SDO_GEOM.SDO_UNION", } # We want to get SDO Geometries as WKT because it is much easier to @@ -90,28 +91,39 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): # However, this adversely affects performance (i.e., Java is called # to convert to WKT on every query). If someone wishes to write a # SDO_GEOMETRY(...) parser in Python, let me know =) - select = 'SDO_UTIL.TO_WKBGEOMETRY(%s)' + select = "SDO_UTIL.TO_WKBGEOMETRY(%s)" gis_operators = { - 'contains': SDOOperator(func='SDO_CONTAINS'), - 'coveredby': SDOOperator(func='SDO_COVEREDBY'), - 'covers': SDOOperator(func='SDO_COVERS'), - 'disjoint': SDODisjoint(), - 'intersects': SDOOperator(func='SDO_OVERLAPBDYINTERSECT'), # TODO: Is this really the same as ST_Intersects()? - 'equals': SDOOperator(func='SDO_EQUAL'), - 'exact': SDOOperator(func='SDO_EQUAL'), - 'overlaps': SDOOperator(func='SDO_OVERLAPS'), - 'same_as': SDOOperator(func='SDO_EQUAL'), - 'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch' - 'touches': SDOOperator(func='SDO_TOUCH'), - 'within': SDOOperator(func='SDO_INSIDE'), - 'dwithin': SDODWithin(), + "contains": SDOOperator(func="SDO_CONTAINS"), + "coveredby": SDOOperator(func="SDO_COVEREDBY"), + "covers": SDOOperator(func="SDO_COVERS"), + "disjoint": SDODisjoint(), + "intersects": SDOOperator( + func="SDO_OVERLAPBDYINTERSECT" + ), # TODO: Is this really the same as ST_Intersects()? 
+ "equals": SDOOperator(func="SDO_EQUAL"), + "exact": SDOOperator(func="SDO_EQUAL"), + "overlaps": SDOOperator(func="SDO_OVERLAPS"), + "same_as": SDOOperator(func="SDO_EQUAL"), + "relate": SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch' + "touches": SDOOperator(func="SDO_TOUCH"), + "within": SDOOperator(func="SDO_INSIDE"), + "dwithin": SDODWithin(), } unsupported_functions = { - 'AsKML', 'AsSVG', 'Azimuth', 'ForcePolygonCW', 'GeoHash', - 'GeometryDistance', 'LineLocatePoint', 'MakeValid', 'MemSize', - 'Scale', 'SnapToGrid', 'Translate', + "AsKML", + "AsSVG", + "Azimuth", + "ForcePolygonCW", + "GeoHash", + "GeometryDistance", + "LineLocatePoint", + "MakeValid", + "MemSize", + "Scale", + "SnapToGrid", + "Translate", } def geo_quote_name(self, name): @@ -124,15 +136,17 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): # table. ext_geom = GEOSGeometry(memoryview(clob.read())) gtype = str(ext_geom.geom_type) - if gtype == 'Polygon': + if gtype == "Polygon": # Construct the 4-tuple from the coordinates in the polygon. shell = ext_geom.shell ll, ur = shell[0][:2], shell[2][:2] - elif gtype == 'Point': + elif gtype == "Point": ll = ext_geom.coords[:2] ur = ll else: - raise Exception('Unexpected geometry type returned for extent: %s' % gtype) + raise Exception( + "Unexpected geometry type returned for extent: %s" % gtype + ) xmin, ymin = ll xmax, ymax = ur return (xmin, ymin, xmax, ymax) @@ -145,7 +159,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): backends, no stored procedure is necessary and it's the same for all geometry types. 
""" - return 'MDSYS.SDO_GEOMETRY' + return "MDSYS.SDO_GEOMETRY" def get_distance(self, f, value, lookup_type): """ @@ -161,47 +175,47 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): if f.geodetic(self.connection): dist_param = value.m else: - dist_param = getattr(value, Distance.unit_attname(f.units_name(self.connection))) + dist_param = getattr( + value, Distance.unit_attname(f.units_name(self.connection)) + ) else: dist_param = value # dwithin lookups on Oracle require a special string parameter # that starts with "distance=". - if lookup_type == 'dwithin': - dist_param = 'distance=%s' % dist_param + if lookup_type == "dwithin": + dist_param = "distance=%s" % dist_param return [dist_param] def get_geom_placeholder(self, f, value, compiler): if value is None: - return 'NULL' + return "NULL" return super().get_geom_placeholder(f, value, compiler) def spatial_aggregate_name(self, agg_name): """ Return the spatial aggregate SQL name. """ - agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower() + agg_name = "unionagg" if agg_name.lower() == "union" else agg_name.lower() return getattr(self, agg_name) # Routines for getting the OGC-compliant models. def geometry_columns(self): - from django.contrib.gis.db.backends.oracle.models import ( - OracleGeometryColumns, - ) + from django.contrib.gis.db.backends.oracle.models import OracleGeometryColumns + return OracleGeometryColumns def spatial_ref_sys(self): - from django.contrib.gis.db.backends.oracle.models import ( - OracleSpatialRefSys, - ) + from django.contrib.gis.db.backends.oracle.models import OracleSpatialRefSys + return OracleSpatialRefSys def modify_insert_params(self, placeholder, params): """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial backend due to #10888. 
""" - if placeholder == 'NULL': + if placeholder == "NULL": return [] return super().modify_insert_params(placeholder, params) @@ -218,7 +232,8 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): if srid: geom.srid = srid return geom + return converter def get_area_att_for_field(self, field): - return 'sq_m' + return "sq_m" diff --git a/django/contrib/gis/db/backends/oracle/schema.py b/django/contrib/gis/db/backends/oracle/schema.py index b281abc81b..edc692ffc7 100644 --- a/django/contrib/gis/db/backends/oracle/schema.py +++ b/django/contrib/gis/db/backends/oracle/schema.py @@ -4,7 +4,7 @@ from django.db.backends.utils import strip_quotes, truncate_name class OracleGISSchemaEditor(DatabaseSchemaEditor): - sql_add_geometry_metadata = (""" + sql_add_geometry_metadata = """ INSERT INTO USER_SDO_GEOM_METADATA ("TABLE_NAME", "COLUMN_NAME", "DIMINFO", "SRID") VALUES ( @@ -15,13 +15,15 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): MDSYS.SDO_DIM_ELEMENT('LAT', %(dim1)s, %(dim3)s, %(tolerance)s) ), %(srid)s - )""") - sql_add_spatial_index = 'CREATE INDEX %(index)s ON %(table)s(%(column)s) INDEXTYPE IS MDSYS.SPATIAL_INDEX' - sql_drop_spatial_index = 'DROP INDEX %(index)s' - sql_clear_geometry_table_metadata = 'DELETE FROM USER_SDO_GEOM_METADATA WHERE TABLE_NAME = %(table)s' + )""" + sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s(%(column)s) INDEXTYPE IS MDSYS.SPATIAL_INDEX" + sql_drop_spatial_index = "DROP INDEX %(index)s" + sql_clear_geometry_table_metadata = ( + "DELETE FROM USER_SDO_GEOM_METADATA WHERE TABLE_NAME = %(table)s" + ) sql_clear_geometry_field_metadata = ( - 'DELETE FROM USER_SDO_GEOM_METADATA WHERE TABLE_NAME = %(table)s ' - 'AND COLUMN_NAME = %(column)s' + "DELETE FROM USER_SDO_GEOM_METADATA WHERE TABLE_NAME = %(table)s " + "AND COLUMN_NAME = %(column)s" ) def __init__(self, *args, **kwargs): @@ -41,23 +43,27 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): if isinstance(field, GeometryField): db_table = 
model._meta.db_table self.geometry_sql.append( - self.sql_add_geometry_metadata % { - 'table': self.geo_quote_name(db_table), - 'column': self.geo_quote_name(field.column), - 'dim0': field._extent[0], - 'dim1': field._extent[1], - 'dim2': field._extent[2], - 'dim3': field._extent[3], - 'tolerance': field._tolerance, - 'srid': field.srid, + self.sql_add_geometry_metadata + % { + "table": self.geo_quote_name(db_table), + "column": self.geo_quote_name(field.column), + "dim0": field._extent[0], + "dim1": field._extent[1], + "dim2": field._extent[2], + "dim3": field._extent[3], + "tolerance": field._tolerance, + "srid": field.srid, } ) if field.spatial_index: self.geometry_sql.append( - self.sql_add_spatial_index % { - 'index': self.quote_name(self._create_spatial_index_name(model, field)), - 'table': self.quote_name(db_table), - 'column': self.quote_name(field.column), + self.sql_add_spatial_index + % { + "index": self.quote_name( + self._create_spatial_index_name(model, field) + ), + "table": self.quote_name(db_table), + "column": self.quote_name(field.column), } ) return column_sql @@ -68,9 +74,12 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): def delete_model(self, model): super().delete_model(model) - self.execute(self.sql_clear_geometry_table_metadata % { - 'table': self.geo_quote_name(model._meta.db_table), - }) + self.execute( + self.sql_clear_geometry_table_metadata + % { + "table": self.geo_quote_name(model._meta.db_table), + } + ) def add_field(self, model, field): super().add_field(model, field) @@ -78,14 +87,22 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): def remove_field(self, model, field): if isinstance(field, GeometryField): - self.execute(self.sql_clear_geometry_field_metadata % { - 'table': self.geo_quote_name(model._meta.db_table), - 'column': self.geo_quote_name(field.column), - }) + self.execute( + self.sql_clear_geometry_field_metadata + % { + "table": self.geo_quote_name(model._meta.db_table), + "column": 
self.geo_quote_name(field.column), + } + ) if field.spatial_index: - self.execute(self.sql_drop_spatial_index % { - 'index': self.quote_name(self._create_spatial_index_name(model, field)), - }) + self.execute( + self.sql_drop_spatial_index + % { + "index": self.quote_name( + self._create_spatial_index_name(model, field) + ), + } + ) super().remove_field(model, field) def run_geometry_sql(self): @@ -96,4 +113,6 @@ class OracleGISSchemaEditor(DatabaseSchemaEditor): def _create_spatial_index_name(self, model, field): # Oracle doesn't allow object names > 30 characters. Use this scheme # instead of self._create_index_name() for backwards compatibility. - return truncate_name('%s_%s_id' % (strip_quotes(model._meta.db_table), field.column), 30) + return truncate_name( + "%s_%s_id" % (strip_quotes(model._meta.db_table), field.column), 30 + ) diff --git a/django/contrib/gis/db/backends/postgis/adapter.py b/django/contrib/gis/db/backends/postgis/adapter.py index d6cb6ca47d..0e02427811 100644 --- a/django/contrib/gis/db/backends/postgis/adapter.py +++ b/django/contrib/gis/db/backends/postgis/adapter.py @@ -31,7 +31,9 @@ class PostGISAdapter: if proto == ISQLQuote: return self else: - raise Exception('Error implementing psycopg2 protocol. Is psycopg2 installed?') + raise Exception( + "Error implementing psycopg2 protocol. Is psycopg2 installed?" + ) def __eq__(self, other): return isinstance(other, PostGISAdapter) and self.ewkb == other.ewkb @@ -60,9 +62,9 @@ class PostGISAdapter: """ if self.is_geometry: # Psycopg will figure out whether to use E'\\000' or '\000'. - return b'%s(%s)' % ( - b'ST_GeogFromWKB' if self.geography else b'ST_GeomFromEWKB', - self._adapter.getquoted() + return b"%s(%s)" % ( + b"ST_GeogFromWKB" if self.geography else b"ST_GeomFromEWKB", + self._adapter.getquoted(), ) else: # For rasters, add explicit type cast to WKB string. 
diff --git a/django/contrib/gis/db/backends/postgis/base.py b/django/contrib/gis/db/backends/postgis/base.py index 5b93d7cf1e..87a30004a1 100644 --- a/django/contrib/gis/db/backends/postgis/base.py +++ b/django/contrib/gis/db/backends/postgis/base.py @@ -14,7 +14,7 @@ class DatabaseWrapper(Psycopg2DatabaseWrapper): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if kwargs.get('alias', '') != NO_DB_ALIAS: + if kwargs.get("alias", "") != NO_DB_ALIAS: self.features = DatabaseFeatures(self) self.ops = PostGISOperations(self) self.introspection = PostGISIntrospection(self) diff --git a/django/contrib/gis/db/backends/postgis/const.py b/django/contrib/gis/db/backends/postgis/const.py index 193aa3fc36..d2ad51799f 100644 --- a/django/contrib/gis/db/backends/postgis/const.py +++ b/django/contrib/gis/db/backends/postgis/const.py @@ -15,13 +15,23 @@ POSTGIS_TO_GDAL = [1, 1, 1, 3, 1, 3, 2, 5, 4, None, 6, 7, None, None] # # Scale, origin, and skew have x and y values. PostGIS currently uses # a fixed endianness (1) and there is only one version (0). -POSTGIS_HEADER_STRUCTURE = 'B H H d d d d d d i H H' +POSTGIS_HEADER_STRUCTURE = "B H H d d d d d d i H H" # Lookup values to convert GDAL pixel types to struct characters. This is # used to pack and unpack the pixel values of PostGIS raster bands. GDAL_TO_STRUCT = [ - None, 'B', 'H', 'h', 'L', 'l', 'f', 'd', - None, None, None, None, + None, + "B", + "H", + "h", + "L", + "l", + "f", + "d", + None, + None, + None, + None, ] # Size of the packed value in bytes for different numerical types. @@ -29,17 +39,17 @@ GDAL_TO_STRUCT = [ # when decomposing them into GDALRasters. 
# See https://docs.python.org/library/struct.html#format-characters STRUCT_SIZE = { - 'b': 1, # Signed char - 'B': 1, # Unsigned char - '?': 1, # _Bool - 'h': 2, # Short - 'H': 2, # Unsigned short - 'i': 4, # Integer - 'I': 4, # Unsigned Integer - 'l': 4, # Long - 'L': 4, # Unsigned Long - 'f': 4, # Float - 'd': 8, # Double + "b": 1, # Signed char + "B": 1, # Unsigned char + "?": 1, # _Bool + "h": 2, # Short + "H": 2, # Unsigned short + "i": 4, # Integer + "I": 4, # Unsigned Integer + "l": 4, # Long + "L": 4, # Unsigned Long + "f": 4, # Float + "d": 8, # Double } # Pixel type specifies type of pixel values in a band. Storage flag specifies diff --git a/django/contrib/gis/db/backends/postgis/introspection.py b/django/contrib/gis/db/backends/postgis/introspection.py index 0647012b04..b12b7b912f 100644 --- a/django/contrib/gis/db/backends/postgis/introspection.py +++ b/django/contrib/gis/db/backends/postgis/introspection.py @@ -6,11 +6,11 @@ class PostGISIntrospection(DatabaseIntrospection): postgis_oid_lookup = {} # Populated when introspection is performed. ignored_tables = DatabaseIntrospection.ignored_tables + [ - 'geography_columns', - 'geometry_columns', - 'raster_columns', - 'spatial_ref_sys', - 'raster_overviews', + "geography_columns", + "geometry_columns", + "raster_columns", + "spatial_ref_sys", + "raster_overviews", ] def get_field_type(self, data_type, description): @@ -21,9 +21,13 @@ class PostGISIntrospection(DatabaseIntrospection): # requests upon connection initialization, the `data_types_reverse` # dictionary isn't updated until introspection is performed here. 
with self.connection.cursor() as cursor: - cursor.execute("SELECT oid, typname FROM pg_type WHERE typname IN ('geometry', 'geography')") + cursor.execute( + "SELECT oid, typname FROM pg_type WHERE typname IN ('geometry', 'geography')" + ) self.postgis_oid_lookup = dict(cursor.fetchall()) - self.data_types_reverse.update((oid, 'GeometryField') for oid in self.postgis_oid_lookup) + self.data_types_reverse.update( + (oid, "GeometryField") for oid in self.postgis_oid_lookup + ) return super().get_field_type(data_type, description) def get_geometry_type(self, table_name, description): @@ -34,27 +38,32 @@ class PostGISIntrospection(DatabaseIntrospection): metadata tables to determine the geometry type. """ with self.connection.cursor() as cursor: - cursor.execute(""" + cursor.execute( + """ SELECT t.coord_dimension, t.srid, t.type FROM ( SELECT * FROM geometry_columns UNION ALL SELECT * FROM geography_columns ) AS t WHERE t.f_table_name = %s AND t.f_geometry_column = %s - """, (table_name, description.name)) + """, + (table_name, description.name), + ) row = cursor.fetchone() if not row: - raise Exception('Could not find a geometry or geography column for "%s"."%s"' % - (table_name, description.name)) + raise Exception( + 'Could not find a geometry or geography column for "%s"."%s"' + % (table_name, description.name) + ) dim, srid, field_type = row # OGRGeomType does not require GDAL and makes it easy to convert # from OGC geom type name to Django field. field_type = OGRGeomType(field_type).django # Getting any GeometryField keyword arguments that are not the default. 
field_params = {} - if self.postgis_oid_lookup.get(description.type_code) == 'geography': - field_params['geography'] = True + if self.postgis_oid_lookup.get(description.type_code) == "geography": + field_params["geography"] = True if srid != 4326: - field_params['srid'] = srid + field_params["srid"] = srid if dim != 2: - field_params['dim'] = dim + field_params["dim"] = dim return field_type, field_params diff --git a/django/contrib/gis/db/backends/postgis/models.py b/django/contrib/gis/db/backends/postgis/models.py index 0cd4d5505c..b7b568274a 100644 --- a/django/contrib/gis/db/backends/postgis/models.py +++ b/django/contrib/gis/db/backends/postgis/models.py @@ -10,6 +10,7 @@ class PostGISGeometryColumns(models.Model): The 'geometry_columns' view from PostGIS. See the PostGIS documentation at Ch. 4.3.2. """ + f_table_catalog = models.CharField(max_length=256) f_table_schema = models.CharField(max_length=256) f_table_name = models.CharField(max_length=256) @@ -19,12 +20,12 @@ class PostGISGeometryColumns(models.Model): type = models.CharField(max_length=30) class Meta: - app_label = 'gis' - db_table = 'geometry_columns' + app_label = "gis" + db_table = "geometry_columns" managed = False def __str__(self): - return '%s.%s - %dD %s field (SRID: %d)' % ( + return "%s.%s - %dD %s field (SRID: %d)" % ( self.f_table_name, self.f_geometry_column, self.coord_dimension, @@ -38,7 +39,7 @@ class PostGISGeometryColumns(models.Model): Return the name of the metadata column used to store the feature table name. """ - return 'f_table_name' + return "f_table_name" @classmethod def geom_col_name(cls): @@ -46,7 +47,7 @@ class PostGISGeometryColumns(models.Model): Return the name of the metadata column used to store the feature geometry column. 
""" - return 'f_geometry_column' + return "f_geometry_column" class PostGISSpatialRefSys(models.Model, SpatialRefSysMixin): @@ -54,6 +55,7 @@ class PostGISSpatialRefSys(models.Model, SpatialRefSysMixin): The 'spatial_ref_sys' table from PostGIS. See the PostGIS documentation at Ch. 4.2.1. """ + srid = models.IntegerField(primary_key=True) auth_name = models.CharField(max_length=256) auth_srid = models.IntegerField() @@ -61,8 +63,8 @@ class PostGISSpatialRefSys(models.Model, SpatialRefSysMixin): proj4text = models.CharField(max_length=2048) class Meta: - app_label = 'gis' - db_table = 'spatial_ref_sys' + app_label = "gis" + db_table = "spatial_ref_sys" managed = False @property diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 84110fce0f..b11b2efcf3 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -1,9 +1,7 @@ import re from django.conf import settings -from django.contrib.gis.db.backends.base.operations import ( - BaseSpatialOperations, -) +from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import GeometryField, RasterField from django.contrib.gis.gdal import GDALRaster @@ -22,7 +20,7 @@ from .models import PostGISGeometryColumns, PostGISSpatialRefSys from .pgraster import from_pgraster # Identifier to mark raster lookups as bilateral. -BILATERAL = 'bilateral' +BILATERAL = "bilateral" class PostGISOperator(SpatialOperator): @@ -39,56 +37,70 @@ class PostGISOperator(SpatialOperator): def as_sql(self, connection, lookup, template_params, *args): if lookup.lhs.output_field.geography and not self.geography: - raise ValueError('PostGIS geography does not support the "%s" ' - 'function/operator.' 
% (self.func or self.op,)) + raise ValueError( + 'PostGIS geography does not support the "%s" ' + "function/operator." % (self.func or self.op,) + ) template_params = self.check_raster(lookup, template_params) return super().as_sql(connection, lookup, template_params, *args) def check_raster(self, lookup, template_params): - spheroid = lookup.rhs_params and lookup.rhs_params[-1] == 'spheroid' + spheroid = lookup.rhs_params and lookup.rhs_params[-1] == "spheroid" # Check which input is a raster. - lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER' + lhs_is_raster = lookup.lhs.field.geom_type == "RASTER" rhs_is_raster = isinstance(lookup.rhs, GDALRaster) # Look for band indices and inject them if provided. if lookup.band_lhs is not None and lhs_is_raster: if not self.func: - raise ValueError('Band indices are not allowed for this operator, it works on bbox only.') - template_params['lhs'] = '%s, %s' % (template_params['lhs'], lookup.band_lhs) + raise ValueError( + "Band indices are not allowed for this operator, it works on bbox only." + ) + template_params["lhs"] = "%s, %s" % ( + template_params["lhs"], + lookup.band_lhs, + ) if lookup.band_rhs is not None and rhs_is_raster: if not self.func: - raise ValueError('Band indices are not allowed for this operator, it works on bbox only.') - template_params['rhs'] = '%s, %s' % (template_params['rhs'], lookup.band_rhs) + raise ValueError( + "Band indices are not allowed for this operator, it works on bbox only." + ) + template_params["rhs"] = "%s, %s" % ( + template_params["rhs"], + lookup.band_rhs, + ) # Convert rasters to polygons if necessary. if not self.raster or spheroid: # Operators without raster support. 
if lhs_is_raster: - template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs'] + template_params["lhs"] = "ST_Polygon(%s)" % template_params["lhs"] if rhs_is_raster: - template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs'] + template_params["rhs"] = "ST_Polygon(%s)" % template_params["rhs"] elif self.raster == BILATERAL: # Operators with raster support but don't support mixed (rast-geom) # lookups. if lhs_is_raster and not rhs_is_raster: - template_params['lhs'] = 'ST_Polygon(%s)' % template_params['lhs'] + template_params["lhs"] = "ST_Polygon(%s)" % template_params["lhs"] elif rhs_is_raster and not lhs_is_raster: - template_params['rhs'] = 'ST_Polygon(%s)' % template_params['rhs'] + template_params["rhs"] = "ST_Polygon(%s)" % template_params["rhs"] return template_params class ST_Polygon(Func): - function = 'ST_Polygon' + function = "ST_Polygon" def __init__(self, expr): super().__init__(expr) expr = self.source_expressions[0] if isinstance(expr, Value) and not expr._output_field_or_none: - self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid)) + self.source_expressions[0] = Value( + expr.value, output_field=RasterField(srid=expr.value.srid) + ) @cached_property def output_field(self): @@ -96,64 +108,70 @@ class ST_Polygon(Func): class PostGISOperations(BaseSpatialOperations, DatabaseOperations): - name = 'postgis' + name = "postgis" postgis = True - geom_func_prefix = 'ST_' + geom_func_prefix = "ST_" Adapter = PostGISAdapter - collect = geom_func_prefix + 'Collect' - extent = geom_func_prefix + 'Extent' - extent3d = geom_func_prefix + '3DExtent' - length3d = geom_func_prefix + '3DLength' - makeline = geom_func_prefix + 'MakeLine' - perimeter3d = geom_func_prefix + '3DPerimeter' - unionagg = geom_func_prefix + 'Union' + collect = geom_func_prefix + "Collect" + extent = geom_func_prefix + "Extent" + extent3d = geom_func_prefix + "3DExtent" + length3d = geom_func_prefix + "3DLength" + makeline = 
geom_func_prefix + "MakeLine" + perimeter3d = geom_func_prefix + "3DPerimeter" + unionagg = geom_func_prefix + "Union" gis_operators = { - 'bbcontains': PostGISOperator(op='~', raster=True), - 'bboverlaps': PostGISOperator(op='&&', geography=True, raster=True), - 'contained': PostGISOperator(op='@', raster=True), - 'overlaps_left': PostGISOperator(op='&<', raster=BILATERAL), - 'overlaps_right': PostGISOperator(op='&>', raster=BILATERAL), - 'overlaps_below': PostGISOperator(op='&<|'), - 'overlaps_above': PostGISOperator(op='|&>'), - 'left': PostGISOperator(op='<<'), - 'right': PostGISOperator(op='>>'), - 'strictly_below': PostGISOperator(op='<<|'), - 'strictly_above': PostGISOperator(op='|>>'), - 'same_as': PostGISOperator(op='~=', raster=BILATERAL), - 'exact': PostGISOperator(op='~=', raster=BILATERAL), # alias of same_as - 'contains': PostGISOperator(func='ST_Contains', raster=BILATERAL), - 'contains_properly': PostGISOperator(func='ST_ContainsProperly', raster=BILATERAL), - 'coveredby': PostGISOperator(func='ST_CoveredBy', geography=True, raster=BILATERAL), - 'covers': PostGISOperator(func='ST_Covers', geography=True, raster=BILATERAL), - 'crosses': PostGISOperator(func='ST_Crosses'), - 'disjoint': PostGISOperator(func='ST_Disjoint', raster=BILATERAL), - 'equals': PostGISOperator(func='ST_Equals'), - 'intersects': PostGISOperator(func='ST_Intersects', geography=True, raster=BILATERAL), - 'overlaps': PostGISOperator(func='ST_Overlaps', raster=BILATERAL), - 'relate': PostGISOperator(func='ST_Relate'), - 'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL), - 'within': PostGISOperator(func='ST_Within', raster=BILATERAL), - 'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL), + "bbcontains": PostGISOperator(op="~", raster=True), + "bboverlaps": PostGISOperator(op="&&", geography=True, raster=True), + "contained": PostGISOperator(op="@", raster=True), + "overlaps_left": PostGISOperator(op="&<", raster=BILATERAL), + 
"overlaps_right": PostGISOperator(op="&>", raster=BILATERAL), + "overlaps_below": PostGISOperator(op="&<|"), + "overlaps_above": PostGISOperator(op="|&>"), + "left": PostGISOperator(op="<<"), + "right": PostGISOperator(op=">>"), + "strictly_below": PostGISOperator(op="<<|"), + "strictly_above": PostGISOperator(op="|>>"), + "same_as": PostGISOperator(op="~=", raster=BILATERAL), + "exact": PostGISOperator(op="~=", raster=BILATERAL), # alias of same_as + "contains": PostGISOperator(func="ST_Contains", raster=BILATERAL), + "contains_properly": PostGISOperator( + func="ST_ContainsProperly", raster=BILATERAL + ), + "coveredby": PostGISOperator( + func="ST_CoveredBy", geography=True, raster=BILATERAL + ), + "covers": PostGISOperator(func="ST_Covers", geography=True, raster=BILATERAL), + "crosses": PostGISOperator(func="ST_Crosses"), + "disjoint": PostGISOperator(func="ST_Disjoint", raster=BILATERAL), + "equals": PostGISOperator(func="ST_Equals"), + "intersects": PostGISOperator( + func="ST_Intersects", geography=True, raster=BILATERAL + ), + "overlaps": PostGISOperator(func="ST_Overlaps", raster=BILATERAL), + "relate": PostGISOperator(func="ST_Relate"), + "touches": PostGISOperator(func="ST_Touches", raster=BILATERAL), + "within": PostGISOperator(func="ST_Within", raster=BILATERAL), + "dwithin": PostGISOperator(func="ST_DWithin", geography=True, raster=BILATERAL), } unsupported_functions = set() - select = '%s::bytea' + select = "%s::bytea" select_extent = None @cached_property def function_names(self): function_names = { - 'AsWKB': 'ST_AsBinary', - 'AsWKT': 'ST_AsText', - 'BoundingCircle': 'ST_MinimumBoundingCircle', - 'NumPoints': 'ST_NPoints', + "AsWKB": "ST_AsBinary", + "AsWKT": "ST_AsText", + "BoundingCircle": "ST_MinimumBoundingCircle", + "NumPoints": "ST_NPoints", } if self.spatial_version < (2, 4, 0): - function_names['ForcePolygonCW'] = 'ST_ForceRHR' + function_names["ForcePolygonCW"] = "ST_ForceRHR" return function_names @cached_property @@ -165,13 +183,13 @@ 
class PostGISOperations(BaseSpatialOperations, DatabaseOperations): # can be mitigated by setting `POSTGIS_VERSION` with a 3-tuple # comprising user-supplied values for the major, minor, and # subminor revision of PostGIS. - if hasattr(settings, 'POSTGIS_VERSION'): + if hasattr(settings, "POSTGIS_VERSION"): version = settings.POSTGIS_VERSION else: # Run a basic query to check the status of the connection so we're # sure we only raise the error below if the problem comes from # PostGIS and not from PostgreSQL itself (see #24862). - self._get_postgis_func('version') + self._get_postgis_func("version") try: vtup = self.postgis_version_tuple() @@ -179,9 +197,9 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): raise ImproperlyConfigured( 'Cannot determine PostGIS version for database "%s" ' 'using command "SELECT postgis_lib_version()". ' - 'GeoDjango requires at least PostGIS version 2.4. ' - 'Was the database created from a spatial database ' - 'template?' % self.connection.settings_dict['NAME'] + "GeoDjango requires at least PostGIS version 2.4. " + "Was the database created from a spatial database " + "template?" % self.connection.settings_dict["NAME"] ) version = vtup[1:] return version @@ -194,7 +212,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): """ if box is None: return None - ll, ur = box[4:-1].split(',') + ll, ur = box[4:-1].split(",") xmin, ymin = map(float, ll.split()) xmax, ymax = map(float, ur.split()) return (xmin, ymin, xmax, ymax) @@ -207,7 +225,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): """ if box3d is None: return None - ll, ur = box3d[6:-1].split(',') + ll, ur = box3d[6:-1].split(",") xmin, ymin, zmin = map(float, ll.split()) xmax, ymax, zmax = map(float, ur.split()) return (xmin, ymin, zmin, xmax, ymax, zmax) @@ -216,22 +234,24 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): """ Return the database field type for the given spatial field. 
""" - if f.geom_type == 'RASTER': - return 'raster' + if f.geom_type == "RASTER": + return "raster" # Type-based geometries. # TODO: Support 'M' extension. if f.dim == 3: - geom_type = f.geom_type + 'Z' + geom_type = f.geom_type + "Z" else: geom_type = f.geom_type if f.geography: if f.srid != 4326: - raise NotSupportedError('PostGIS only supports geography columns with an SRID of 4326.') + raise NotSupportedError( + "PostGIS only supports geography columns with an SRID of 4326." + ) - return 'geography(%s,%d)' % (geom_type, f.srid) + return "geography(%s,%d)" % (geom_type, f.srid) else: - return 'geometry(%s,%d)' % (geom_type, f.srid) + return "geometry(%s,%d)" % (geom_type, f.srid) def get_distance(self, f, dist_val, lookup_type): """ @@ -254,12 +274,16 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): if geography: dist_param = value.m elif geodetic: - if lookup_type == 'dwithin': - raise ValueError('Only numeric values of degree units are ' - 'allowed on geographic DWithin queries.') + if lookup_type == "dwithin": + raise ValueError( + "Only numeric values of degree units are " + "allowed on geographic DWithin queries." + ) dist_param = value.m else: - dist_param = getattr(value, Distance.unit_attname(f.units_name(self.connection))) + dist_param = getattr( + value, Distance.unit_attname(f.units_name(self.connection)) + ) else: # Assuming the distance is in the units of the field. dist_param = value @@ -272,12 +296,12 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): not in the SRID of the field. Specifically, this routine will substitute in the ST_Transform() function call. 
""" - transform_func = self.spatial_function_name('Transform') - if hasattr(value, 'as_sql'): + transform_func = self.spatial_function_name("Transform") + if hasattr(value, "as_sql"): if value.field.srid == f.srid: - placeholder = '%s' + placeholder = "%s" else: - placeholder = '%s(%%s, %s)' % (transform_func, f.srid) + placeholder = "%s(%%s, %s)" % (transform_func, f.srid) return placeholder # Get the srid for this object @@ -289,9 +313,9 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): # Adding Transform() to the SQL placeholder if the value srid # is not equal to the field srid. if value_srid is None or value_srid == f.srid: - placeholder = '%s' + placeholder = "%s" else: - placeholder = '%s(%%s, %s)' % (transform_func, f.srid) + placeholder = "%s(%%s, %s)" % (transform_func, f.srid) return placeholder @@ -301,28 +325,28 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): """ # Close out the connection. See #9437. with self.connection.temporary_connection() as cursor: - cursor.execute('SELECT %s()' % func) + cursor.execute("SELECT %s()" % func) return cursor.fetchone()[0] def postgis_geos_version(self): "Return the version of the GEOS library used with PostGIS." - return self._get_postgis_func('postgis_geos_version') + return self._get_postgis_func("postgis_geos_version") def postgis_lib_version(self): "Return the version number of the PostGIS library used with PostgreSQL." - return self._get_postgis_func('postgis_lib_version') + return self._get_postgis_func("postgis_lib_version") def postgis_proj_version(self): """Return the version of the PROJ library used with PostGIS.""" - return self._get_postgis_func('postgis_proj_version') + return self._get_postgis_func("postgis_proj_version") def postgis_version(self): "Return PostGIS version number and compile-time options." 
- return self._get_postgis_func('postgis_version') + return self._get_postgis_func("postgis_version") def postgis_full_version(self): "Return PostGIS version number and compile-time options." - return self._get_postgis_func('postgis_full_version') + return self._get_postgis_func("postgis_full_version") def postgis_version_tuple(self): """ @@ -337,16 +361,16 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): Return the version of PROJ used by PostGIS as a tuple of the major, minor, and subminor release numbers. """ - proj_regex = re.compile(r'(\d+)\.(\d+)\.(\d+)') + proj_regex = re.compile(r"(\d+)\.(\d+)\.(\d+)") proj_ver_str = self.postgis_proj_version() m = proj_regex.search(proj_ver_str) if m: return tuple(map(int, m.groups())) else: - raise Exception('Could not determine PROJ version from PostGIS.') + raise Exception("Could not determine PROJ version from PostGIS.") def spatial_aggregate_name(self, agg_name): - if agg_name == 'Extent3D': + if agg_name == "Extent3D": return self.extent3d else: return self.geom_func_prefix + agg_name @@ -366,15 +390,15 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): return super().distance_expr_for_lookup( self._normalize_distance_lookup_arg(lhs), self._normalize_distance_lookup_arg(rhs), - **kwargs + **kwargs, ) @staticmethod def _normalize_distance_lookup_arg(arg): is_raster = ( - arg.field.geom_type == 'RASTER' - if hasattr(arg, 'field') else - isinstance(arg, GDALRaster) + arg.field.geom_type == "RASTER" + if hasattr(arg, "field") + else isinstance(arg, GDALRaster) ) return ST_Polygon(arg) if is_raster else arg @@ -384,7 +408,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): def converter(value, expression, connection): return None if value is None else GEOSGeometryBase(read(value), geom_class) + return converter def get_area_att_for_field(self, field): - return 'sq_m' + return "sq_m" diff --git a/django/contrib/gis/db/backends/postgis/pgraster.py 
b/django/contrib/gis/db/backends/postgis/pgraster.py index bfdeb6faf3..22794342ca 100644 --- a/django/contrib/gis/db/backends/postgis/pgraster.py +++ b/django/contrib/gis/db/backends/postgis/pgraster.py @@ -3,8 +3,13 @@ import struct from django.core.exceptions import ValidationError from .const import ( - BANDTYPE_FLAG_HASNODATA, BANDTYPE_PIXTYPE_MASK, GDAL_TO_POSTGIS, - GDAL_TO_STRUCT, POSTGIS_HEADER_STRUCTURE, POSTGIS_TO_GDAL, STRUCT_SIZE, + BANDTYPE_FLAG_HASNODATA, + BANDTYPE_PIXTYPE_MASK, + GDAL_TO_POSTGIS, + GDAL_TO_STRUCT, + POSTGIS_HEADER_STRUCTURE, + POSTGIS_TO_GDAL, + STRUCT_SIZE, ) @@ -12,14 +17,14 @@ def pack(structure, data): """ Pack data into hex string with little endian format. """ - return struct.pack('<' + structure, *data) + return struct.pack("<" + structure, *data) def unpack(structure, data): """ Unpack little endian hexlified binary string into a list. """ - return struct.unpack('<' + structure, bytes.fromhex(data)) + return struct.unpack("<" + structure, bytes.fromhex(data)) def chunk(data, index): @@ -46,7 +51,7 @@ def from_pgraster(data): while data: # Get pixel type for this band pixeltype_with_flags, data = chunk(data, 2) - pixeltype_with_flags = unpack('B', pixeltype_with_flags)[0] + pixeltype_with_flags = unpack("B", pixeltype_with_flags)[0] pixeltype = pixeltype_with_flags & BANDTYPE_PIXTYPE_MASK # Convert datatype from PostGIS to GDAL & get pack type and size @@ -62,11 +67,11 @@ def from_pgraster(data): # Chunk and unpack band data (pack size times nr of pixels) band, data = chunk(data, pack_size * header[10] * header[11]) - band_result = {'data': bytes.fromhex(band)} + band_result = {"data": bytes.fromhex(band)} # Set the nodata value if the nodata flag is set. 
if pixeltype_with_flags & BANDTYPE_FLAG_HASNODATA: - band_result['nodata_value'] = nodata + band_result["nodata_value"] = nodata # Append band data to band list bands.append(band_result) @@ -81,13 +86,14 @@ def from_pgraster(data): raise ValidationError("Band pixeltypes are not all equal.") return { - 'srid': int(header[9]), - 'width': header[10], 'height': header[11], - 'datatype': pixeltypes[0], - 'origin': (header[5], header[6]), - 'scale': (header[3], header[4]), - 'skew': (header[7], header[8]), - 'bands': bands, + "srid": int(header[9]), + "width": header[10], + "height": header[11], + "datatype": pixeltypes[0], + "origin": (header[5], header[6]), + "scale": (header[3], header[4]), + "skew": (header[7], header[8]), + "bands": bands, } @@ -99,9 +105,18 @@ def to_pgraster(rast): # the endianness and the PostGIS Raster Version, both are fixed by # PostGIS at the moment. rasterheader = ( - 1, 0, len(rast.bands), rast.scale.x, rast.scale.y, - rast.origin.x, rast.origin.y, rast.skew.x, rast.skew.y, - rast.srs.srid, rast.width, rast.height, + 1, + 0, + len(rast.bands), + rast.scale.x, + rast.scale.y, + rast.origin.x, + rast.origin.y, + rast.skew.x, + rast.skew.y, + rast.srs.srid, + rast.width, + rast.height, ) # Pack raster header. @@ -119,7 +134,7 @@ def to_pgraster(rast): # For example, if the byte value is 71, then the datatype is # 71 & ~BANDTYPE_FLAG_HASNODATA = 7 (32BSI) # and the nodata value is True. 
- structure = 'B' + GDAL_TO_STRUCT[band.datatype()] + structure = "B" + GDAL_TO_STRUCT[band.datatype()] # Get band pixel type in PostGIS notation pixeltype = GDAL_TO_POSTGIS[band.datatype()] diff --git a/django/contrib/gis/db/backends/postgis/schema.py b/django/contrib/gis/db/backends/postgis/schema.py index c574bed84f..77a9096ef4 100644 --- a/django/contrib/gis/db/backends/postgis/schema.py +++ b/django/contrib/gis/db/backends/postgis/schema.py @@ -3,29 +3,33 @@ from django.db.models.expressions import Col, Func class PostGISSchemaEditor(DatabaseSchemaEditor): - geom_index_type = 'GIST' - geom_index_ops_nd = 'GIST_GEOMETRY_OPS_ND' - rast_index_template = 'ST_ConvexHull(%(expressions)s)' + geom_index_type = "GIST" + geom_index_ops_nd = "GIST_GEOMETRY_OPS_ND" + rast_index_template = "ST_ConvexHull(%(expressions)s)" - sql_alter_column_to_3d = "ALTER COLUMN %(column)s TYPE %(type)s USING ST_Force3D(%(column)s)::%(type)s" - sql_alter_column_to_2d = "ALTER COLUMN %(column)s TYPE %(type)s USING ST_Force2D(%(column)s)::%(type)s" + sql_alter_column_to_3d = ( + "ALTER COLUMN %(column)s TYPE %(type)s USING ST_Force3D(%(column)s)::%(type)s" + ) + sql_alter_column_to_2d = ( + "ALTER COLUMN %(column)s TYPE %(type)s USING ST_Force2D(%(column)s)::%(type)s" + ) def geo_quote_name(self, name): return self.connection.ops.geo_quote_name(name) def _field_should_be_indexed(self, model, field): - if getattr(field, 'spatial_index', False): + if getattr(field, "spatial_index", False): return True return super()._field_should_be_indexed(model, field) def _create_index_sql(self, model, *, fields=None, **kwargs): - if fields is None or len(fields) != 1 or not hasattr(fields[0], 'geodetic'): + if fields is None or len(fields) != 1 or not hasattr(fields[0], "geodetic"): return super()._create_index_sql(model, fields=fields, **kwargs) field = fields[0] expressions = None opclasses = None - if field.geom_type == 'RASTER': + if field.geom_type == "RASTER": # For raster fields, wrap index creation 
SQL statement with ST_ConvexHull. # Indexes on raster columns are based on the convex hull of the raster. expressions = Func(Col(None, field), template=self.rast_index_template) @@ -33,15 +37,15 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): elif field.dim > 2 and not field.geography: # Use "nd" ops which are fast on multidimensional cases opclasses = [self.geom_index_ops_nd] - name = kwargs.get('name') + name = kwargs.get("name") if not name: - name = self._create_index_name(model._meta.db_table, [field.column], '_id') + name = self._create_index_name(model._meta.db_table, [field.column], "_id") return super()._create_index_sql( model, fields=fields, name=name, - using=' USING %s' % self.geom_index_type, + using=" USING %s" % self.geom_index_type, opclasses=opclasses, expressions=expressions, ) @@ -50,7 +54,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): """ Special case when dimension changed. """ - if not hasattr(old_field, 'dim') or not hasattr(new_field, 'dim'): + if not hasattr(old_field, "dim") or not hasattr(new_field, "dim"): return super()._alter_column_type_sql(table, old_field, new_field, new_type) if old_field.dim == 2 and new_field.dim == 3: @@ -61,7 +65,8 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): sql_alter = self.sql_alter_column_type return ( ( - sql_alter % { + sql_alter + % { "column": self.quote_name(new_field.column), "type": new_type, }, diff --git a/django/contrib/gis/db/backends/spatialite/adapter.py b/django/contrib/gis/db/backends/spatialite/adapter.py index 2a36159b06..91d14dc3c7 100644 --- a/django/contrib/gis/db/backends/spatialite/adapter.py +++ b/django/contrib/gis/db/backends/spatialite/adapter.py @@ -4,6 +4,7 @@ from django.db.backends.sqlite3.base import Database class SpatiaLiteAdapter(WKTAdapter): "SQLite adapter for geometry objects." 
+ def __conform__(self, protocol): if protocol is Database.PrepareProtocol: return str(self) diff --git a/django/contrib/gis/db/backends/spatialite/base.py b/django/contrib/gis/db/backends/spatialite/base.py index fef7de62af..3359a7a971 100644 --- a/django/contrib/gis/db/backends/spatialite/base.py +++ b/django/contrib/gis/db/backends/spatialite/base.py @@ -2,9 +2,7 @@ from ctypes.util import find_library from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.db.backends.sqlite3.base import ( - DatabaseWrapper as SQLiteDatabaseWrapper, -) +from django.db.backends.sqlite3.base import DatabaseWrapper as SQLiteDatabaseWrapper from .client import SpatiaLiteClient from .features import DatabaseFeatures @@ -27,12 +25,16 @@ class DatabaseWrapper(SQLiteDatabaseWrapper): # (`libspatialite`). If it's not in the system library path (e.g., it # cannot be found by `ctypes.util.find_library`), then it may be set # manually in the settings via the `SPATIALITE_LIBRARY_PATH` setting. - self.lib_spatialite_paths = [name for name in [ - getattr(settings, 'SPATIALITE_LIBRARY_PATH', None), - 'mod_spatialite.so', - 'mod_spatialite', - find_library('spatialite'), - ] if name is not None] + self.lib_spatialite_paths = [ + name + for name in [ + getattr(settings, "SPATIALITE_LIBRARY_PATH", None), + "mod_spatialite.so", + "mod_spatialite", + find_library("spatialite"), + ] + if name is not None + ] super().__init__(*args, **kwargs) def get_new_connection(self, conn_params): @@ -42,26 +44,26 @@ class DatabaseWrapper(SQLiteDatabaseWrapper): conn.enable_load_extension(True) except AttributeError: raise ImproperlyConfigured( - 'SpatiaLite requires SQLite to be configured to allow ' - 'extension loading.' + "SpatiaLite requires SQLite to be configured to allow " + "extension loading." ) # Load the SpatiaLite library extension on the connection. 
for path in self.lib_spatialite_paths: try: conn.load_extension(path) except Exception: - if getattr(settings, 'SPATIALITE_LIBRARY_PATH', None): + if getattr(settings, "SPATIALITE_LIBRARY_PATH", None): raise ImproperlyConfigured( - 'Unable to load the SpatiaLite library extension ' - 'as specified in your SPATIALITE_LIBRARY_PATH setting.' + "Unable to load the SpatiaLite library extension " + "as specified in your SPATIALITE_LIBRARY_PATH setting." ) continue else: break else: raise ImproperlyConfigured( - 'Unable to load the SpatiaLite library extension. ' - 'Library names tried: %s' % ', '.join(self.lib_spatialite_paths) + "Unable to load the SpatiaLite library extension. " + "Library names tried: %s" % ", ".join(self.lib_spatialite_paths) ) return conn @@ -72,6 +74,6 @@ class DatabaseWrapper(SQLiteDatabaseWrapper): cursor.execute("PRAGMA table_info(geometry_columns);") if cursor.fetchall() == []: if self.ops.spatial_version < (5,): - cursor.execute('SELECT InitSpatialMetaData(1)') + cursor.execute("SELECT InitSpatialMetaData(1)") else: - cursor.execute('SELECT InitSpatialMetaDataFull(1)') + cursor.execute("SELECT InitSpatialMetaDataFull(1)") diff --git a/django/contrib/gis/db/backends/spatialite/client.py b/django/contrib/gis/db/backends/spatialite/client.py index c9dfd1a527..527fe153bb 100644 --- a/django/contrib/gis/db/backends/spatialite/client.py +++ b/django/contrib/gis/db/backends/spatialite/client.py @@ -2,4 +2,4 @@ from django.db.backends.sqlite3.client import DatabaseClient class SpatiaLiteClient(DatabaseClient): - executable_name = 'spatialite' + executable_name = "spatialite" diff --git a/django/contrib/gis/db/backends/spatialite/features.py b/django/contrib/gis/db/backends/spatialite/features.py index b6c1746c07..9504bb0949 100644 --- a/django/contrib/gis/db/backends/spatialite/features.py +++ b/django/contrib/gis/db/backends/spatialite/features.py @@ -16,9 +16,11 @@ class DatabaseFeatures(BaseSpatialFeatures, SQLiteDatabaseFeatures): @cached_property 
def django_test_skips(self): skips = super().django_test_skips - skips.update({ - "SpatiaLite doesn't support distance lookups with Distance objects.": { - 'gis_tests.geogapp.tests.GeographyTest.test02_distance_lookup', - }, - }) + skips.update( + { + "SpatiaLite doesn't support distance lookups with Distance objects.": { + "gis_tests.geogapp.tests.GeographyTest.test02_distance_lookup", + }, + } + ) return skips diff --git a/django/contrib/gis/db/backends/spatialite/introspection.py b/django/contrib/gis/db/backends/spatialite/introspection.py index 0feeb24397..8d0003fd53 100644 --- a/django/contrib/gis/db/backends/spatialite/introspection.py +++ b/django/contrib/gis/db/backends/spatialite/introspection.py @@ -1,6 +1,7 @@ from django.contrib.gis.gdal import OGRGeomType from django.db.backends.sqlite3.introspection import ( - DatabaseIntrospection, FlexibleFieldLookupDict, + DatabaseIntrospection, + FlexibleFieldLookupDict, ) @@ -9,15 +10,16 @@ class GeoFlexibleFieldLookupDict(FlexibleFieldLookupDict): Subclass that includes updates the `base_data_types_reverse` dict for geometry field types. """ + base_data_types_reverse = { **FlexibleFieldLookupDict.base_data_types_reverse, - 'point': 'GeometryField', - 'linestring': 'GeometryField', - 'polygon': 'GeometryField', - 'multipoint': 'GeometryField', - 'multilinestring': 'GeometryField', - 'multipolygon': 'GeometryField', - 'geometrycollection': 'GeometryField', + "point": "GeometryField", + "linestring": "GeometryField", + "polygon": "GeometryField", + "multipoint": "GeometryField", + "multilinestring": "GeometryField", + "multipolygon": "GeometryField", + "geometrycollection": "GeometryField", } @@ -27,14 +29,18 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): def get_geometry_type(self, table_name, description): with self.connection.cursor() as cursor: # Querying the `geometry_columns` table to get additional metadata. 
- cursor.execute('SELECT coord_dimension, srid, geometry_type ' - 'FROM geometry_columns ' - 'WHERE f_table_name=%s AND f_geometry_column=%s', - (table_name, description.name)) + cursor.execute( + "SELECT coord_dimension, srid, geometry_type " + "FROM geometry_columns " + "WHERE f_table_name=%s AND f_geometry_column=%s", + (table_name, description.name), + ) row = cursor.fetchone() if not row: - raise Exception('Could not find a geometry column for "%s"."%s"' % - (table_name, description.name)) + raise Exception( + 'Could not find a geometry column for "%s"."%s"' + % (table_name, description.name) + ) # OGRGeomType does not require GDAL and makes it easy to convert # from OGC geom type name to Django field. @@ -51,18 +57,21 @@ class SpatiaLiteIntrospection(DatabaseIntrospection): srid = row[1] field_params = {} if srid != 4326: - field_params['srid'] = srid - if (isinstance(dim, str) and 'Z' in dim) or dim == 3: - field_params['dim'] = 3 + field_params["srid"] = srid + if (isinstance(dim, str) and "Z" in dim) or dim == 3: + field_params["dim"] = 3 return field_type, field_params def get_constraints(self, cursor, table_name): constraints = super().get_constraints(cursor, table_name) - cursor.execute('SELECT f_geometry_column ' - 'FROM geometry_columns ' - 'WHERE f_table_name=%s AND spatial_index_enabled=1', (table_name,)) + cursor.execute( + "SELECT f_geometry_column " + "FROM geometry_columns " + "WHERE f_table_name=%s AND spatial_index_enabled=1", + (table_name,), + ) for row in cursor.fetchall(): - constraints['%s__spatial__index' % row[0]] = { + constraints["%s__spatial__index" % row[0]] = { "columns": [row[0]], "primary_key": False, "unique": False, diff --git a/django/contrib/gis/db/backends/spatialite/models.py b/django/contrib/gis/db/backends/spatialite/models.py index 577c7236e3..7cc98ae126 100644 --- a/django/contrib/gis/db/backends/spatialite/models.py +++ b/django/contrib/gis/db/backends/spatialite/models.py @@ -9,20 +9,21 @@ class 
SpatialiteGeometryColumns(models.Model): """ The 'geometry_columns' table from SpatiaLite. """ + f_table_name = models.CharField(max_length=256) f_geometry_column = models.CharField(max_length=256) coord_dimension = models.IntegerField() srid = models.IntegerField(primary_key=True) spatial_index_enabled = models.IntegerField() - type = models.IntegerField(db_column='geometry_type') + type = models.IntegerField(db_column="geometry_type") class Meta: - app_label = 'gis' - db_table = 'geometry_columns' + app_label = "gis" + db_table = "geometry_columns" managed = False def __str__(self): - return '%s.%s - %dD %s field (SRID: %d)' % ( + return "%s.%s - %dD %s field (SRID: %d)" % ( self.f_table_name, self.f_geometry_column, self.coord_dimension, @@ -36,7 +37,7 @@ class SpatialiteGeometryColumns(models.Model): Return the name of the metadata column used to store the feature table name. """ - return 'f_table_name' + return "f_table_name" @classmethod def geom_col_name(cls): @@ -44,13 +45,14 @@ class SpatialiteGeometryColumns(models.Model): Return the name of the metadata column used to store the feature geometry column. """ - return 'f_geometry_column' + return "f_geometry_column" class SpatialiteSpatialRefSys(models.Model, SpatialRefSysMixin): """ The 'spatial_ref_sys' table from SpatiaLite. 
""" + srid = models.IntegerField(primary_key=True) auth_name = models.CharField(max_length=256) auth_srid = models.IntegerField() @@ -59,8 +61,8 @@ class SpatialiteSpatialRefSys(models.Model, SpatialRefSysMixin): srtext = models.CharField(max_length=2048) class Meta: - app_label = 'gis' - db_table = 'spatial_ref_sys' + app_label = "gis" + db_table = "spatial_ref_sys" managed = False @property diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 8dcd62de9b..8003fcb6c6 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -3,9 +3,7 @@ SQL functions reference lists: https://www.gaia-gis.it/gaia-sins/spatialite-sql-4.3.0.html """ from django.contrib.gis.db import models -from django.contrib.gis.db.backends.base.operations import ( - BaseSpatialOperations, -) +from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase @@ -20,69 +18,69 @@ from django.utils.version import get_version_tuple class SpatialiteNullCheckOperator(SpatialOperator): def as_sql(self, connection, lookup, template_params, sql_params): sql, params = super().as_sql(connection, lookup, template_params, sql_params) - return '%s > 0' % sql, params + return "%s > 0" % sql, params class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): - name = 'spatialite' + name = "spatialite" spatialite = True Adapter = SpatiaLiteAdapter - collect = 'Collect' - extent = 'Extent' - makeline = 'MakeLine' - unionagg = 'GUnion' + collect = "Collect" + extent = "Extent" + makeline = "MakeLine" + unionagg = "GUnion" - from_text = 'GeomFromText' + from_text = "GeomFromText" gis_operators = { # Binary predicates - 
'equals': SpatialiteNullCheckOperator(func='Equals'), - 'disjoint': SpatialiteNullCheckOperator(func='Disjoint'), - 'touches': SpatialiteNullCheckOperator(func='Touches'), - 'crosses': SpatialiteNullCheckOperator(func='Crosses'), - 'within': SpatialiteNullCheckOperator(func='Within'), - 'overlaps': SpatialiteNullCheckOperator(func='Overlaps'), - 'contains': SpatialiteNullCheckOperator(func='Contains'), - 'intersects': SpatialiteNullCheckOperator(func='Intersects'), - 'relate': SpatialiteNullCheckOperator(func='Relate'), - 'coveredby': SpatialiteNullCheckOperator(func='CoveredBy'), - 'covers': SpatialiteNullCheckOperator(func='Covers'), + "equals": SpatialiteNullCheckOperator(func="Equals"), + "disjoint": SpatialiteNullCheckOperator(func="Disjoint"), + "touches": SpatialiteNullCheckOperator(func="Touches"), + "crosses": SpatialiteNullCheckOperator(func="Crosses"), + "within": SpatialiteNullCheckOperator(func="Within"), + "overlaps": SpatialiteNullCheckOperator(func="Overlaps"), + "contains": SpatialiteNullCheckOperator(func="Contains"), + "intersects": SpatialiteNullCheckOperator(func="Intersects"), + "relate": SpatialiteNullCheckOperator(func="Relate"), + "coveredby": SpatialiteNullCheckOperator(func="CoveredBy"), + "covers": SpatialiteNullCheckOperator(func="Covers"), # Returns true if B's bounding box completely contains A's bounding box. - 'contained': SpatialOperator(func='MbrWithin'), + "contained": SpatialOperator(func="MbrWithin"), # Returns true if A's bounding box completely contains B's bounding box. - 'bbcontains': SpatialOperator(func='MbrContains'), + "bbcontains": SpatialOperator(func="MbrContains"), # Returns true if A's bounding box overlaps B's bounding box. 
- 'bboverlaps': SpatialOperator(func='MbrOverlaps'), + "bboverlaps": SpatialOperator(func="MbrOverlaps"), # These are implemented here as synonyms for Equals - 'same_as': SpatialiteNullCheckOperator(func='Equals'), - 'exact': SpatialiteNullCheckOperator(func='Equals'), + "same_as": SpatialiteNullCheckOperator(func="Equals"), + "exact": SpatialiteNullCheckOperator(func="Equals"), # Distance predicates - 'dwithin': SpatialOperator(func='PtDistWithin'), + "dwithin": SpatialOperator(func="PtDistWithin"), } disallowed_aggregates = (models.Extent3D,) - select = 'CAST (AsEWKB(%s) AS BLOB)' + select = "CAST (AsEWKB(%s) AS BLOB)" function_names = { - 'AsWKB': 'St_AsBinary', - 'ForcePolygonCW': 'ST_ForceLHR', - 'Length': 'ST_Length', - 'LineLocatePoint': 'ST_Line_Locate_Point', - 'NumPoints': 'ST_NPoints', - 'Reverse': 'ST_Reverse', - 'Scale': 'ScaleCoords', - 'Translate': 'ST_Translate', - 'Union': 'ST_Union', + "AsWKB": "St_AsBinary", + "ForcePolygonCW": "ST_ForceLHR", + "Length": "ST_Length", + "LineLocatePoint": "ST_Line_Locate_Point", + "NumPoints": "ST_NPoints", + "Reverse": "ST_Reverse", + "Scale": "ScaleCoords", + "Translate": "ST_Translate", + "Union": "ST_Union", } @cached_property def unsupported_functions(self): - unsupported = {'BoundingCircle', 'GeometryDistance', 'MemSize'} + unsupported = {"BoundingCircle", "GeometryDistance", "MemSize"} if not self.geom_lib_version(): - unsupported |= {'Azimuth', 'GeoHash', 'MakeValid'} + unsupported |= {"Azimuth", "GeoHash", "MakeValid"} return unsupported @cached_property @@ -93,12 +91,11 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): except Exception as exc: raise ImproperlyConfigured( 'Cannot determine the SpatiaLite version for the "%s" database. ' - 'Was the SpatiaLite initialization SQL loaded on this database?' % ( - self.connection.settings_dict['NAME'], - ) + "Was the SpatiaLite initialization SQL loaded on this database?" 
+ % (self.connection.settings_dict["NAME"],) ) from exc if version < (4, 3, 0): - raise ImproperlyConfigured('GeoDjango supports SpatiaLite 4.3.0 and above.') + raise ImproperlyConfigured("GeoDjango supports SpatiaLite 4.3.0 and above.") return version def convert_extent(self, box): @@ -129,14 +126,16 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): value = value[0] if isinstance(value, Distance): if f.geodetic(self.connection): - if lookup_type == 'dwithin': + if lookup_type == "dwithin": raise ValueError( - 'Only numeric values of degree units are allowed on ' - 'geographic DWithin queries.' + "Only numeric values of degree units are allowed on " + "geographic DWithin queries." ) dist_param = value.m else: - dist_param = getattr(value, Distance.unit_attname(f.units_name(self.connection))) + dist_param = getattr( + value, Distance.unit_attname(f.units_name(self.connection)) + ) else: dist_param = value return [dist_param] @@ -149,7 +148,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): """ cursor = self.connection._cursor() try: - cursor.execute('SELECT %s' % func) + cursor.execute("SELECT %s" % func) row = cursor.fetchone() finally: cursor.close() @@ -157,19 +156,19 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): def geos_version(self): "Return the version of GEOS used by SpatiaLite as a string." 
- return self._get_spatialite_func('geos_version()') + return self._get_spatialite_func("geos_version()") def proj_version(self): """Return the version of the PROJ library used by SpatiaLite.""" - return self._get_spatialite_func('proj4_version()') + return self._get_spatialite_func("proj4_version()") def lwgeom_version(self): """Return the version of LWGEOM library used by SpatiaLite.""" - return self._get_spatialite_func('lwgeom_version()') + return self._get_spatialite_func("lwgeom_version()") def rttopo_version(self): """Return the version of RTTOPO library used by SpatiaLite.""" - return self._get_spatialite_func('rttopo_version()') + return self._get_spatialite_func("rttopo_version()") def geom_lib_version(self): """ @@ -183,7 +182,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): def spatialite_version(self): "Return the SpatiaLite library version as a string." - return self._get_spatialite_func('spatialite_version()') + return self._get_spatialite_func("spatialite_version()") def spatialite_version_tuple(self): """ @@ -198,7 +197,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): Return the spatial aggregate SQL template and function for the given Aggregate instance. """ - agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower() + agg_name = "unionagg" if agg_name.lower() == "union" else agg_name.lower() return getattr(self, agg_name) # Routines for getting the OGC-compliant models. 
@@ -206,12 +205,14 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): from django.contrib.gis.db.backends.spatialite.models import ( SpatialiteGeometryColumns, ) + return SpatialiteGeometryColumns def spatial_ref_sys(self): from django.contrib.gis.db.backends.spatialite.models import ( SpatialiteSpatialRefSys, ) + return SpatialiteSpatialRefSys def get_geometry_converter(self, expression): @@ -220,4 +221,5 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): def converter(value, expression, connection): return None if value is None else GEOSGeometryBase(read(value), geom_class) + return converter diff --git a/django/contrib/gis/db/backends/spatialite/schema.py b/django/contrib/gis/db/backends/spatialite/schema.py index 066ce6d732..d37632edf8 100644 --- a/django/contrib/gis/db/backends/spatialite/schema.py +++ b/django/contrib/gis/db/backends/spatialite/schema.py @@ -14,7 +14,9 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): "%(geom_type)s, %(dim)s)" ) sql_remove_geometry_metadata = "SELECT DiscardGeometryColumn(%(table)s, %(column)s)" - sql_discard_geometry_columns = "DELETE FROM %(geom_table)s WHERE f_table_name = %(table)s" + sql_discard_geometry_columns = ( + "DELETE FROM %(geom_table)s WHERE f_table_name = %(table)s" + ) sql_update_geometry_columns = ( "UPDATE %(geom_table)s SET f_table_name = %(new_table)s " "WHERE f_table_name = %(old_table)s" @@ -36,12 +38,14 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): def column_sql(self, model, field, include_default=False): from django.contrib.gis.db.models import GeometryField + if not isinstance(field, GeometryField): return super().column_sql(model, field, include_default) # Geometry columns are created by the `AddGeometryColumn` function self.geometry_sql.append( - self.sql_add_geometry_column % { + self.sql_add_geometry_column + % { "table": self.geo_quote_name(model._meta.db_table), "column": self.geo_quote_name(field.column), "srid": field.srid, @@ -53,7 
+57,8 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): if field.spatial_index: self.geometry_sql.append( - self.sql_add_spatial_index % { + self.sql_add_spatial_index + % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(field.column), } @@ -62,13 +67,15 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): def remove_geometry_metadata(self, model, field): self.execute( - self.sql_remove_geometry_metadata % { + self.sql_remove_geometry_metadata + % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(field.column), } ) self.execute( - self.sql_drop_spatial_index % { + self.sql_drop_spatial_index + % { "table": model._meta.db_table, "column": field.column, } @@ -92,7 +99,8 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): for geom_table in self.geometry_tables: try: self.execute( - self.sql_discard_geometry_columns % { + self.sql_discard_geometry_columns + % { "geom_table": geom_table, "table": self.quote_name(model._meta.db_table), } @@ -103,6 +111,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): def add_field(self, model, field): from django.contrib.gis.db.models import GeometryField + if isinstance(field, GeometryField): # Populate self.geometry_sql self.column_sql(model, field) @@ -125,14 +134,17 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): else: super().remove_field(model, field) - def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True): + def alter_db_table( + self, model, old_db_table, new_db_table, disable_constraints=True + ): from django.contrib.gis.db.models import GeometryField # Remove geometry-ness from temp table for field in model._meta.local_fields: if isinstance(field, GeometryField): self.execute( - self.sql_remove_geometry_metadata % { + self.sql_remove_geometry_metadata + % { "table": self.quote_name(old_db_table), "column": self.quote_name(field.column), } @@ -143,7 +155,8 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): for 
geom_table in self.geometry_tables: try: self.execute( - self.sql_update_geometry_columns % { + self.sql_update_geometry_columns + % { "geom_table": geom_table, "old_table": self.quote_name(old_db_table), "new_table": self.quote_name(new_db_table), @@ -154,15 +167,25 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor): # Re-add geometry-ness and rename spatial index tables for field in model._meta.local_fields: if isinstance(field, GeometryField): - self.execute(self.sql_recover_geometry_metadata % { - "table": self.geo_quote_name(new_db_table), - "column": self.geo_quote_name(field.column), - "srid": field.srid, - "geom_type": self.geo_quote_name(field.geom_type), - "dim": field.dim, - }) - if getattr(field, 'spatial_index', False): - self.execute(self.sql_rename_table % { - "old_table": self.quote_name("idx_%s_%s" % (old_db_table, field.column)), - "new_table": self.quote_name("idx_%s_%s" % (new_db_table, field.column)), - }) + self.execute( + self.sql_recover_geometry_metadata + % { + "table": self.geo_quote_name(new_db_table), + "column": self.geo_quote_name(field.column), + "srid": field.srid, + "geom_type": self.geo_quote_name(field.geom_type), + "dim": field.dim, + } + ) + if getattr(field, "spatial_index", False): + self.execute( + self.sql_rename_table + % { + "old_table": self.quote_name( + "idx_%s_%s" % (old_db_table, field.column) + ), + "new_table": self.quote_name( + "idx_%s_%s" % (new_db_table, field.column) + ), + } + ) diff --git a/django/contrib/gis/db/backends/utils.py b/django/contrib/gis/db/backends/utils.py index d479009d4e..ffb7420019 100644 --- a/django/contrib/gis/db/backends/utils.py +++ b/django/contrib/gis/db/backends/utils.py @@ -8,6 +8,7 @@ class SpatialOperator: """ Class encapsulating the behavior specific to a GIS operation (used by lookups). 
""" + sql_template = None def __init__(self, op=None, func=None): @@ -17,11 +18,11 @@ class SpatialOperator: @property def default_template(self): if self.func: - return '%(func)s(%(lhs)s, %(rhs)s)' + return "%(func)s(%(lhs)s, %(rhs)s)" else: - return '%(lhs)s %(op)s %(rhs)s' + return "%(lhs)s %(op)s %(rhs)s" def as_sql(self, connection, lookup, template_params, sql_params): sql_template = self.sql_template or lookup.sql_template or self.default_template - template_params.update({'op': self.op, 'func': self.func}) + template_params.update({"op": self.op, "func": self.func}) return sql_template % template_params, sql_params diff --git a/django/contrib/gis/db/models/__init__.py b/django/contrib/gis/db/models/__init__.py index 2e472557d8..0d5d45d272 100644 --- a/django/contrib/gis/db/models/__init__.py +++ b/django/contrib/gis/db/models/__init__.py @@ -5,14 +5,26 @@ import django.contrib.gis.db.models.lookups # NOQA from django.contrib.gis.db.models.aggregates import * # NOQA from django.contrib.gis.db.models.aggregates import __all__ as aggregates_all from django.contrib.gis.db.models.fields import ( - GeometryCollectionField, GeometryField, LineStringField, - MultiLineStringField, MultiPointField, MultiPolygonField, PointField, - PolygonField, RasterField, + GeometryCollectionField, + GeometryField, + LineStringField, + MultiLineStringField, + MultiPointField, + MultiPolygonField, + PointField, + PolygonField, + RasterField, ) __all__ = models_all + aggregates_all __all__ += [ - 'GeometryCollectionField', 'GeometryField', 'LineStringField', - 'MultiLineStringField', 'MultiPointField', 'MultiPolygonField', 'PointField', - 'PolygonField', 'RasterField', + "GeometryCollectionField", + "GeometryField", + "LineStringField", + "MultiLineStringField", + "MultiPointField", + "MultiPolygonField", + "PointField", + "PolygonField", + "RasterField", ] diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 
3aa46a52c0..c19cbd06c3 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -1,10 +1,13 @@ from django.contrib.gis.db.models.fields import ( - ExtentField, GeometryCollectionField, GeometryField, LineStringField, + ExtentField, + GeometryCollectionField, + GeometryField, + LineStringField, ) from django.db.models import Aggregate, Value from django.utils.functional import cached_property -__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] +__all__ = ["Collect", "Extent", "Extent3D", "MakeLine", "Union"] class GeoAggregate(Aggregate): @@ -23,37 +26,45 @@ class GeoAggregate(Aggregate): compiler, connection, function=function or connection.ops.spatial_aggregate_name(self.name), - **extra_context + **extra_context, ) def as_oracle(self, compiler, connection, **extra_context): if not self.is_extent: - tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) + tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05) clone = self.copy() - clone.set_source_expressions([ - *self.get_source_expressions(), - Value(tolerance), - ]) - template = '%(function)s(SDOAGGRTYPE(%(expressions)s))' - return clone.as_sql(compiler, connection, template=template, **extra_context) + clone.set_source_expressions( + [ + *self.get_source_expressions(), + Value(tolerance), + ] + ) + template = "%(function)s(SDOAGGRTYPE(%(expressions)s))" + return clone.as_sql( + compiler, connection, template=template, **extra_context + ) return self.as_sql(compiler, connection, **extra_context) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) for expr in c.get_source_expressions(): - if not hasattr(expr.field, 'geom_type'): - raise ValueError('Geospatial 
aggregates only allowed on geometry fields.') + if not hasattr(expr.field, "geom_type"): + raise ValueError( + "Geospatial aggregates only allowed on geometry fields." + ) return c class Collect(GeoAggregate): - name = 'Collect' + name = "Collect" output_field_class = GeometryCollectionField class Extent(GeoAggregate): - name = 'Extent' - is_extent = '2D' + name = "Extent" + is_extent = "2D" def __init__(self, expression, **extra): super().__init__(expression, output_field=ExtentField(), **extra) @@ -63,8 +74,8 @@ class Extent(GeoAggregate): class Extent3D(GeoAggregate): - name = 'Extent3D' - is_extent = '3D' + name = "Extent3D" + is_extent = "3D" def __init__(self, expression, **extra): super().__init__(expression, output_field=ExtentField(), **extra) @@ -74,10 +85,10 @@ class Extent3D(GeoAggregate): class MakeLine(GeoAggregate): - name = 'MakeLine' + name = "MakeLine" output_field_class = LineStringField class Union(GeoAggregate): - name = 'Union' + name = "Union" output_field_class = GeometryField diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 16b3bb5d08..889c1cfe84 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -4,8 +4,15 @@ from django.contrib.gis import forms, gdal from django.contrib.gis.db.models.proxy import SpatialProxy from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.geos import ( - GeometryCollection, GEOSException, GEOSGeometry, LineString, - MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, + GeometryCollection, + GEOSException, + GEOSGeometry, + LineString, + MultiLineString, + MultiPoint, + MultiPolygon, + Point, + Polygon, ) from django.core.exceptions import ImproperlyConfigured from django.db.models import Field @@ -17,7 +24,9 @@ from django.utils.translation import gettext_lazy as _ _srid_cache = defaultdict(dict) -SRIDCacheEntry = namedtuple('SRIDCacheEntry', ['units', 'units_name', 'spheroid', 
'geodetic']) +SRIDCacheEntry = namedtuple( + "SRIDCacheEntry", ["units", "units_name", "spheroid", "geodetic"] +) def get_srid_info(srid, connection): @@ -27,6 +36,7 @@ def get_srid_info(srid, connection): table for the given database connection. These results are cached. """ from django.contrib.gis.gdal import SpatialReference + global _srid_cache try: @@ -36,9 +46,14 @@ def get_srid_info(srid, connection): SpatialRefSys = None alias, get_srs = ( - (connection.alias, lambda srid: SpatialRefSys.objects.using(connection.alias).get(srid=srid).srs) - if SpatialRefSys else - (None, SpatialReference) + ( + connection.alias, + lambda srid: SpatialRefSys.objects.using(connection.alias) + .get(srid=srid) + .srs, + ) + if SpatialRefSys + else (None, SpatialReference) ) if srid not in _srid_cache[alias]: srs = get_srs(srid) @@ -46,7 +61,8 @@ def get_srid_info(srid, connection): _srid_cache[alias][srid] = SRIDCacheEntry( units=units, units_name=units_name, - spheroid='SPHEROID["%s",%s,%s]' % (srs['spheroid'], srs.semi_major, srs.inverse_flattening), + spheroid='SPHEROID["%s",%s,%s]' + % (srs["spheroid"], srs.semi_major, srs.inverse_flattening), geodetic=srs.geographic, ) @@ -61,6 +77,7 @@ class BaseSpatialField(Field): properties that are common to all GIS fields such as the characteristics of the spatial reference system of the field. """ + description = _("The base GIS field.") empty_strings_allowed = False @@ -88,7 +105,7 @@ class BaseSpatialField(Field): # Setting the verbose_name keyword argument with the positional # first parameter, so this works like normal fields. - kwargs['verbose_name'] = verbose_name + kwargs["verbose_name"] = verbose_name super().__init__(**kwargs) @@ -96,9 +113,9 @@ class BaseSpatialField(Field): name, path, args, kwargs = super().deconstruct() # Always include SRID for less fragility; include spatial index if it's # not the default value. 
- kwargs['srid'] = self.srid + kwargs["srid"] = self.srid if self.spatial_index is not True: - kwargs['spatial_index'] = self.spatial_index + kwargs["spatial_index"] = self.spatial_index return name, path, args, kwargs def db_type(self, connection): @@ -146,10 +163,10 @@ class BaseSpatialField(Field): return connection.ops.Adapter( super().get_db_prep_value(value, connection, *args, **kwargs), **( - {'geography': True} + {"geography": True} if self.geography and connection.features.supports_geography else {} - ) + ), ) def get_raster_prep_value(self, value, is_candidate): @@ -167,7 +184,9 @@ class BaseSpatialField(Field): try: return gdal.GDALRaster(value) except GDALException: - raise ValueError("Couldn't create spatial object from lookup value '%s'." % value) + raise ValueError( + "Couldn't create spatial object from lookup value '%s'." % value + ) def get_prep_value(self, value): obj = super().get_prep_value(value) @@ -179,7 +198,9 @@ class BaseSpatialField(Field): pass else: # Check if input is a candidate for conversion to raster or geometry. - is_candidate = isinstance(obj, (bytes, str)) or hasattr(obj, '__geo_interface__') + is_candidate = isinstance(obj, (bytes, str)) or hasattr( + obj, "__geo_interface__" + ) # Try to convert the input to raster. raster = self.get_raster_prep_value(obj, is_candidate) @@ -189,9 +210,14 @@ class BaseSpatialField(Field): try: obj = GEOSGeometry(obj) except (GEOSException, GDALException): - raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj) + raise ValueError( + "Couldn't create spatial object from lookup value '%s'." % obj + ) else: - raise ValueError('Cannot use object with type %s for a spatial lookup parameter.' % type(obj).__name__) + raise ValueError( + "Cannot use object with type %s for a spatial lookup parameter." + % type(obj).__name__ + ) # Assigning the SRID value. 
obj.srid = self.get_srid(obj) @@ -202,14 +228,25 @@ class GeometryField(BaseSpatialField): """ The base Geometry field -- maps to the OpenGIS Specification Geometry type. """ - description = _('The base Geometry field — maps to the OpenGIS Specification Geometry type.') + + description = _( + "The base Geometry field — maps to the OpenGIS Specification Geometry type." + ) form_class = forms.GeometryField # The OpenGIS Geometry name. - geom_type = 'GEOMETRY' + geom_type = "GEOMETRY" geom_class = None - def __init__(self, verbose_name=None, dim=2, geography=False, *, extent=(-180.0, -90.0, 180.0, 90.0), - tolerance=0.05, **kwargs): + def __init__( + self, + verbose_name=None, + dim=2, + geography=False, + *, + extent=(-180.0, -90.0, 180.0, 90.0), + tolerance=0.05, + **kwargs, + ): """ The initialization function for geometry fields. In addition to the parameters from BaseSpatialField, it takes the following as keyword @@ -244,30 +281,36 @@ class GeometryField(BaseSpatialField): name, path, args, kwargs = super().deconstruct() # Include kwargs if they're not the default values. if self.dim != 2: - kwargs['dim'] = self.dim + kwargs["dim"] = self.dim if self.geography is not False: - kwargs['geography'] = self.geography + kwargs["geography"] = self.geography if self._extent != (-180.0, -90.0, 180.0, 90.0): - kwargs['extent'] = self._extent + kwargs["extent"] = self._extent if self._tolerance != 0.05: - kwargs['tolerance'] = self._tolerance + kwargs["tolerance"] = self._tolerance return name, path, args, kwargs def contribute_to_class(self, cls, name, **kwargs): super().contribute_to_class(cls, name, **kwargs) # Setup for lazy-instantiated Geometry object. 
- setattr(cls, self.attname, SpatialProxy(self.geom_class or GEOSGeometry, self, load_func=GEOSGeometry)) + setattr( + cls, + self.attname, + SpatialProxy(self.geom_class or GEOSGeometry, self, load_func=GEOSGeometry), + ) def formfield(self, **kwargs): defaults = { - 'form_class': self.form_class, - 'geom_type': self.geom_type, - 'srid': self.srid, + "form_class": self.form_class, + "geom_type": self.geom_type, + "srid": self.srid, **kwargs, } - if self.dim > 2 and not getattr(defaults['form_class'].widget, 'supports_3d', False): - defaults.setdefault('widget', forms.Textarea) + if self.dim > 2 and not getattr( + defaults["form_class"].widget, "supports_3d", False + ): + defaults.setdefault("widget", forms.Textarea) return super().formfield(**defaults) def select_format(self, compiler, sql, params): @@ -283,49 +326,49 @@ class GeometryField(BaseSpatialField): # The OpenGIS Geometry Type Fields class PointField(GeometryField): - geom_type = 'POINT' + geom_type = "POINT" geom_class = Point form_class = forms.PointField description = _("Point") class LineStringField(GeometryField): - geom_type = 'LINESTRING' + geom_type = "LINESTRING" geom_class = LineString form_class = forms.LineStringField description = _("Line string") class PolygonField(GeometryField): - geom_type = 'POLYGON' + geom_type = "POLYGON" geom_class = Polygon form_class = forms.PolygonField description = _("Polygon") class MultiPointField(GeometryField): - geom_type = 'MULTIPOINT' + geom_type = "MULTIPOINT" geom_class = MultiPoint form_class = forms.MultiPointField description = _("Multi-point") class MultiLineStringField(GeometryField): - geom_type = 'MULTILINESTRING' + geom_type = "MULTILINESTRING" geom_class = MultiLineString form_class = forms.MultiLineStringField description = _("Multi-line string") class MultiPolygonField(GeometryField): - geom_type = 'MULTIPOLYGON' + geom_type = "MULTIPOLYGON" geom_class = MultiPolygon form_class = forms.MultiPolygonField description = _("Multi polygon") class 
GeometryCollectionField(GeometryField): - geom_type = 'GEOMETRYCOLLECTION' + geom_type = "GEOMETRYCOLLECTION" geom_class = GeometryCollection form_class = forms.GeometryCollectionField description = _("Geometry collection") @@ -350,13 +393,18 @@ class RasterField(BaseSpatialField): """ description = _("Raster Field") - geom_type = 'RASTER' + geom_type = "RASTER" geography = False def _check_connection(self, connection): # Make sure raster fields are used only on backends with raster support. - if not connection.features.gis_enabled or not connection.features.supports_raster: - raise ImproperlyConfigured('Raster fields require backends with raster support.') + if ( + not connection.features.gis_enabled + or not connection.features.supports_raster + ): + raise ImproperlyConfigured( + "Raster fields require backends with raster support." + ) def db_type(self, connection): self._check_connection(connection) @@ -375,12 +423,13 @@ class RasterField(BaseSpatialField): def get_transform(self, name): from django.contrib.gis.db.models.lookups import RasterBandTransform + try: band_index = int(name) return type( - 'SpecificRasterBandTransform', + "SpecificRasterBandTransform", (RasterBandTransform,), - {'band_index': band_index} + {"band_index": band_index}, ) except ValueError: pass diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 1f2d372ebb..c7b071c7c9 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -6,8 +6,14 @@ from django.contrib.gis.geos import GEOSGeometry from django.core.exceptions import FieldError from django.db import NotSupportedError from django.db.models import ( - BinaryField, BooleanField, FloatField, Func, IntegerField, TextField, - Transform, Value, + BinaryField, + BooleanField, + FloatField, + Func, + IntegerField, + TextField, + Transform, + Value, ) from django.db.models.functions import Cast from django.utils.functional import 
cached_property @@ -32,12 +38,21 @@ class GeoFuncMixin: except FieldError: output_field = None geom = expr.value - if not isinstance(geom, GEOSGeometry) or output_field and not isinstance(output_field, GeometryField): - raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1)) + if ( + not isinstance(geom, GEOSGeometry) + or output_field + and not isinstance(output_field, GeometryField) + ): + raise TypeError( + "%s function requires a geometric argument in position %d." + % (self.name, pos + 1) + ) if not geom.srid and not output_field: raise ValueError("SRID is required for all geometries.") if not output_field: - self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid)) + self.source_expressions[pos] = Value( + geom, output_field=GeometryField(srid=geom.srid) + ) @property def name(self): @@ -61,8 +76,11 @@ class GeoFuncMixin: field = source_fields[pos] if not isinstance(field, GeometryField): raise TypeError( - "%s function requires a GeometryField in position %s, got %s." % ( - self.name, pos + 1, type(field).__name__, + "%s function requires a GeometryField in position %s, got %s." + % ( + self.name, + pos + 1, + type(field).__name__, ) ) @@ -72,15 +90,17 @@ class GeoFuncMixin: expr_srid = expr.output_field.srid if expr_srid != base_srid: # Automatic SRID conversion so objects are comparable. - res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs) + res.source_expressions[pos] = Transform( + expr, base_srid + ).resolve_expression(*args, **kwargs) return res - def _handle_param(self, value, param_name='', check_types=None): - if not hasattr(value, 'resolve_expression'): + def _handle_param(self, value, param_name="", check_types=None): + if not hasattr(value, "resolve_expression"): if check_types and not isinstance(value, check_types): raise TypeError( - "The %s parameter has the wrong type: should be %s." 
% ( - param_name, check_types) + "The %s parameter has the wrong type: should be %s." + % (param_name, check_types) ) return value @@ -100,13 +120,17 @@ class SQLiteDecimalToFloatMixin: By default, Decimal values are converted to str by the SQLite backend, which is not acceptable by the GIS functions expecting numeric values. """ + def as_sqlite(self, compiler, connection, **extra_context): copy = self.copy() - copy.set_source_expressions([ - Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal) - else expr - for expr in copy.get_source_expressions() - ]) + copy.set_source_expressions( + [ + Value(float(expr.value)) + if hasattr(expr, "value") and isinstance(expr.value, Decimal) + else expr + for expr in copy.get_source_expressions() + ] + ) return copy.as_sql(compiler, connection, **extra_context) @@ -114,11 +138,13 @@ class OracleToleranceMixin: tolerance = 0.05 def as_oracle(self, compiler, connection, **extra_context): - tolerance = Value(self._handle_param( - self.extra.get('tolerance', self.tolerance), - 'tolerance', - NUMERIC_TYPES, - )) + tolerance = Value( + self._handle_param( + self.extra.get("tolerance", self.tolerance), + "tolerance", + NUMERIC_TYPES, + ) + ) clone = self.copy() clone.set_source_expressions([*self.get_source_expressions(), tolerance]) return clone.as_sql(compiler, connection, **extra_context) @@ -132,14 +158,18 @@ class Area(OracleToleranceMixin, GeoFunc): return AreaField(self.geo_field) def as_sql(self, compiler, connection, **extra_context): - if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection): - raise NotSupportedError('Area on geodetic coordinate systems not supported.') + if not connection.features.supports_area_geodetic and self.geo_field.geodetic( + connection + ): + raise NotSupportedError( + "Area on geodetic coordinate systems not supported." 
+ ) return super().as_sql(compiler, connection, **extra_context) def as_sqlite(self, compiler, connection, **extra_context): if self.geo_field.geodetic(connection): - extra_context['template'] = '%(function)s(%(expressions)s, %(spheroid)d)' - extra_context['spheroid'] = True + extra_context["template"] = "%(function)s(%(expressions)s, %(spheroid)d)" + extra_context["spheroid"] = True return self.as_sql(compiler, connection, **extra_context) @@ -155,7 +185,7 @@ class AsGeoJSON(GeoFunc): def __init__(self, expression, bbox=False, crs=False, precision=8, **extra): expressions = [expression] if precision is not None: - expressions.append(self._handle_param(precision, 'precision', int)) + expressions.append(self._handle_param(precision, "precision", int)) options = 0 if crs and bbox: options = 3 @@ -181,7 +211,7 @@ class AsGML(GeoFunc): def __init__(self, expression, version=2, precision=8, **extra): expressions = [version, expression] if precision is not None: - expressions.append(self._handle_param(precision, 'precision', int)) + expressions.append(self._handle_param(precision, "precision", int)) super().__init__(*expressions, **extra) def as_oracle(self, compiler, connection, **extra_context): @@ -189,7 +219,11 @@ class AsGML(GeoFunc): version = source_expressions[0] clone = self.copy() clone.set_source_expressions([source_expressions[1]]) - extra_context['function'] = 'SDO_UTIL.TO_GML311GEOMETRY' if version.value == 3 else 'SDO_UTIL.TO_GMLGEOMETRY' + extra_context["function"] = ( + "SDO_UTIL.TO_GML311GEOMETRY" + if version.value == 3 + else "SDO_UTIL.TO_GMLGEOMETRY" + ) return super(AsGML, clone).as_sql(compiler, connection, **extra_context) @@ -199,7 +233,7 @@ class AsKML(GeoFunc): def __init__(self, expression, precision=8, **extra): expressions = [expression] if precision is not None: - expressions.append(self._handle_param(precision, 'precision', int)) + expressions.append(self._handle_param(precision, "precision", int)) super().__init__(*expressions, **extra) 
@@ -207,11 +241,13 @@ class AsSVG(GeoFunc): output_field = TextField() def __init__(self, expression, relative=False, precision=8, **extra): - relative = relative if hasattr(relative, 'resolve_expression') else int(relative) + relative = ( + relative if hasattr(relative, "resolve_expression") else int(relative) + ) expressions = [ expression, relative, - self._handle_param(precision, 'precision', int), + self._handle_param(precision, "precision", int), ] super().__init__(*expressions, **extra) @@ -233,7 +269,9 @@ class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc): def as_oracle(self, compiler, connection, **extra_context): clone = self.copy() clone.set_source_expressions([self.get_source_expressions()[0]]) - return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context) + return super(BoundingCircle, clone).as_oracle( + compiler, connection, **extra_context + ) class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): @@ -261,7 +299,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def __init__(self, expr1, expr2, spheroid=None, **extra): expressions = [expr1, expr2] if spheroid is not None: - self.spheroid = self._handle_param(spheroid, 'spheroid', bool) + self.spheroid = self._handle_param(spheroid, "spheroid", bool) super().__init__(*expressions, **extra) def as_postgresql(self, compiler, connection, **extra_context): @@ -282,18 +320,24 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): # Geometry fields with geodetic (lon/lat) coordinates need special distance functions if self.spheroid: # DistanceSpheroid is more accurate and resource intensive than DistanceSphere - function = connection.ops.spatial_function_name('DistanceSpheroid') + function = connection.ops.spatial_function_name("DistanceSpheroid") # Replace boolean param by the real spheroid of the base field - clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) + clone.source_expressions.append( + 
Value(self.geo_field.spheroid(connection)) + ) else: - function = connection.ops.spatial_function_name('DistanceSphere') - return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context) + function = connection.ops.spatial_function_name("DistanceSphere") + return super(Distance, clone).as_sql( + compiler, connection, function=function, **extra_context + ) def as_sqlite(self, compiler, connection, **extra_context): if self.geo_field.geodetic(connection): # SpatiaLite returns NULL instead of zero on geodetic coordinates - extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)' - extra_context['spheroid'] = int(bool(self.spheroid)) + extra_context[ + "template" + ] = "COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)" + extra_context["spheroid"] = int(bool(self.spheroid)) return super().as_sql(compiler, connection, **extra_context) @@ -311,7 +355,7 @@ class GeoHash(GeoFunc): def __init__(self, expression, precision=None, **extra): expressions = [expression] if precision is not None: - expressions.append(self._handle_param(precision, 'precision', int)) + expressions.append(self._handle_param(precision, "precision", int)) super().__init__(*expressions, **extra) def as_mysql(self, compiler, connection, **extra_context): @@ -325,8 +369,8 @@ class GeoHash(GeoFunc): class GeometryDistance(GeoFunc): output_field = FloatField() arity = 2 - function = '' - arg_joiner = ' <-> ' + function = "" + arg_joiner = " <-> " geom_param_pos = (0, 1) @@ -337,7 +381,7 @@ class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): @BaseSpatialField.register_lookup class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform): - lookup_name = 'isvalid' + lookup_name = "isvalid" output_field = BooleanField() def as_oracle(self, compiler, connection, **extra_context): @@ -351,8 +395,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): super().__init__(expr1, **extra) def as_sql(self, compiler, connection, 
**extra_context): - if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: - raise NotSupportedError("This backend doesn't support Length on geodetic fields") + if ( + self.geo_field.geodetic(connection) + and not connection.features.supports_length_geodetic + ): + raise NotSupportedError( + "This backend doesn't support Length on geodetic fields" + ) return super().as_sql(compiler, connection, **extra_context) def as_postgresql(self, compiler, connection, **extra_context): @@ -362,18 +411,20 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): clone.source_expressions.append(Value(self.spheroid)) elif self.geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid - function = connection.ops.spatial_function_name('LengthSpheroid') + function = connection.ops.spatial_function_name("LengthSpheroid") clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) else: dim = min(f.dim for f in self.get_source_fields() if f) if dim > 2: function = connection.ops.length3d - return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context) + return super(Length, clone).as_sql( + compiler, connection, function=function, **extra_context + ) def as_sqlite(self, compiler, connection, **extra_context): function = None if self.geo_field.geodetic(connection): - function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' + function = "GeodesicLength" if self.spheroid else "GreatCircleLength" return super().as_sql(compiler, connection, function=function, **extra_context) @@ -408,7 +459,9 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def as_postgresql(self, compiler, connection, **extra_context): function = None if self.geo_field.geodetic(connection) and not self.source_is_geography(): - raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") + raise NotSupportedError( + 
"ST_Perimeter cannot use a non-projected non-geography field." + ) dim = min(f.dim for f in self.get_source_fields()) if dim > 2: function = connection.ops.perimeter3d @@ -432,11 +485,11 @@ class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): def __init__(self, expression, x, y, z=0.0, **extra): expressions = [ expression, - self._handle_param(x, 'x', NUMERIC_TYPES), - self._handle_param(y, 'y', NUMERIC_TYPES), + self._handle_param(x, "x", NUMERIC_TYPES), + self._handle_param(y, "y", NUMERIC_TYPES), ] if z != 0.0: - expressions.append(self._handle_param(z, 'z', NUMERIC_TYPES)) + expressions.append(self._handle_param(z, "z", NUMERIC_TYPES)) super().__init__(*expressions, **extra) @@ -446,16 +499,16 @@ class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc): expressions = [expression] if nargs in (1, 2): expressions.extend( - [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args] + [self._handle_param(arg, "", NUMERIC_TYPES) for arg in args] ) elif nargs == 4: # Reverse origin and size param ordering expressions += [ - *(self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[2:]), - *(self._handle_param(arg, '', NUMERIC_TYPES) for arg in args[0:2]), + *(self._handle_param(arg, "", NUMERIC_TYPES) for arg in args[2:]), + *(self._handle_param(arg, "", NUMERIC_TYPES) for arg in args[0:2]), ] else: - raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.') + raise ValueError("Must provide 1, 2, or 4 arguments to `SnapToGrid`.") super().__init__(*expressions, **extra) @@ -468,10 +521,10 @@ class Transform(GeomOutputGeoFunc): def __init__(self, expression, srid, **extra): expressions = [ expression, - self._handle_param(srid, 'srid', int), + self._handle_param(srid, "srid", int), ] - if 'output_field' not in extra: - extra['output_field'] = GeometryField(srid=srid) + if "output_field" not in extra: + extra["output_field"] = GeometryField(srid=srid) super().__init__(*expressions, **extra) diff --git 
a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index c1d021b28e..9c06922508 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -27,10 +27,10 @@ class GISLookup(Lookup): def process_rhs_params(self): if self.rhs_params: # Check if a band index was passed in the query argument. - if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1): + if len(self.rhs_params) == (2 if self.lookup_name == "relate" else 1): self.process_band_indices() elif len(self.rhs_params) > 1: - raise ValueError('Tuple too long for lookup %s.' % self.lookup_name) + raise ValueError("Tuple too long for lookup %s." % self.lookup_name) elif isinstance(self.lhs, RasterBandTransform): self.process_band_indices(only_lhs=True) @@ -55,7 +55,7 @@ class GISLookup(Lookup): def get_db_prep_lookup(self, value, connection): # get_db_prep_lookup is called by process_rhs from super class - return ('%s', [connection.ops.Adapter(value)]) + return ("%s", [connection.ops.Adapter(value)]) def process_rhs(self, compiler, connection): if isinstance(self.rhs, Query): @@ -64,7 +64,9 @@ class GISLookup(Lookup): if isinstance(self.rhs, Expression): self.rhs = self.rhs.resolve_expression(compiler.query) rhs, rhs_params = super().process_rhs(compiler, connection) - placeholder = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler) + placeholder = connection.ops.get_geom_placeholder( + self.lhs.output_field, self.rhs, compiler + ) return placeholder % rhs, rhs_params def get_rhs_op(self, connection, rhs): @@ -78,7 +80,12 @@ class GISLookup(Lookup): rhs_sql, rhs_params = self.process_rhs(compiler, connection) sql_params = (*lhs_params, *rhs_params) - template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s', **self.template_params} + template_params = { + "lhs": lhs_sql, + "rhs": rhs_sql, + "value": "%s", + **self.template_params, + } rhs_op = self.get_rhs_op(connection, rhs_sql) return 
rhs_op.as_sql(connection, self, template_params, sql_params) @@ -87,13 +94,15 @@ class GISLookup(Lookup): # Geometry operators # ------------------ + @BaseSpatialField.register_lookup class OverlapsLeftLookup(GISLookup): """ The overlaps_left operator returns true if A's bounding box overlaps or is to the left of B's bounding box. """ - lookup_name = 'overlaps_left' + + lookup_name = "overlaps_left" @BaseSpatialField.register_lookup @@ -102,7 +111,8 @@ class OverlapsRightLookup(GISLookup): The 'overlaps_right' operator returns true if A's bounding box overlaps or is to the right of B's bounding box. """ - lookup_name = 'overlaps_right' + + lookup_name = "overlaps_right" @BaseSpatialField.register_lookup @@ -111,7 +121,8 @@ class OverlapsBelowLookup(GISLookup): The 'overlaps_below' operator returns true if A's bounding box overlaps or is below B's bounding box. """ - lookup_name = 'overlaps_below' + + lookup_name = "overlaps_below" @BaseSpatialField.register_lookup @@ -120,7 +131,8 @@ class OverlapsAboveLookup(GISLookup): The 'overlaps_above' operator returns true if A's bounding box overlaps or is above B's bounding box. """ - lookup_name = 'overlaps_above' + + lookup_name = "overlaps_above" @BaseSpatialField.register_lookup @@ -129,7 +141,8 @@ class LeftLookup(GISLookup): The 'left' operator returns true if A's bounding box is strictly to the left of B's bounding box. """ - lookup_name = 'left' + + lookup_name = "left" @BaseSpatialField.register_lookup @@ -138,7 +151,8 @@ class RightLookup(GISLookup): The 'right' operator returns true if A's bounding box is strictly to the right of B's bounding box. """ - lookup_name = 'right' + + lookup_name = "right" @BaseSpatialField.register_lookup @@ -147,7 +161,8 @@ class StrictlyBelowLookup(GISLookup): The 'strictly_below' operator returns true if A's bounding box is strictly below B's bounding box. 
""" - lookup_name = 'strictly_below' + + lookup_name = "strictly_below" @BaseSpatialField.register_lookup @@ -156,7 +171,8 @@ class StrictlyAboveLookup(GISLookup): The 'strictly_above' operator returns true if A's bounding box is strictly above B's bounding box. """ - lookup_name = 'strictly_above' + + lookup_name = "strictly_above" @BaseSpatialField.register_lookup @@ -166,10 +182,11 @@ class SameAsLookup(GISLookup): equality of two features. So if A and B are the same feature, vertex-by-vertex, the operator returns true. """ - lookup_name = 'same_as' + + lookup_name = "same_as" -BaseSpatialField.register_lookup(SameAsLookup, 'exact') +BaseSpatialField.register_lookup(SameAsLookup, "exact") @BaseSpatialField.register_lookup @@ -178,7 +195,8 @@ class BBContainsLookup(GISLookup): The 'bbcontains' operator returns true if A's bounding box completely contains by B's bounding box. """ - lookup_name = 'bbcontains' + + lookup_name = "bbcontains" @BaseSpatialField.register_lookup @@ -186,7 +204,8 @@ class BBOverlapsLookup(GISLookup): """ The 'bboverlaps' operator returns true if A's bounding box overlaps B's bounding box. """ - lookup_name = 'bboverlaps' + + lookup_name = "bboverlaps" @BaseSpatialField.register_lookup @@ -195,69 +214,71 @@ class ContainedLookup(GISLookup): The 'contained' operator returns true if A's bounding box is completely contained by B's bounding box. 
""" - lookup_name = 'contained' + + lookup_name = "contained" # ------------------ # Geometry functions # ------------------ + @BaseSpatialField.register_lookup class ContainsLookup(GISLookup): - lookup_name = 'contains' + lookup_name = "contains" @BaseSpatialField.register_lookup class ContainsProperlyLookup(GISLookup): - lookup_name = 'contains_properly' + lookup_name = "contains_properly" @BaseSpatialField.register_lookup class CoveredByLookup(GISLookup): - lookup_name = 'coveredby' + lookup_name = "coveredby" @BaseSpatialField.register_lookup class CoversLookup(GISLookup): - lookup_name = 'covers' + lookup_name = "covers" @BaseSpatialField.register_lookup class CrossesLookup(GISLookup): - lookup_name = 'crosses' + lookup_name = "crosses" @BaseSpatialField.register_lookup class DisjointLookup(GISLookup): - lookup_name = 'disjoint' + lookup_name = "disjoint" @BaseSpatialField.register_lookup class EqualsLookup(GISLookup): - lookup_name = 'equals' + lookup_name = "equals" @BaseSpatialField.register_lookup class IntersectsLookup(GISLookup): - lookup_name = 'intersects' + lookup_name = "intersects" @BaseSpatialField.register_lookup class OverlapsLookup(GISLookup): - lookup_name = 'overlaps' + lookup_name = "overlaps" @BaseSpatialField.register_lookup class RelateLookup(GISLookup): - lookup_name = 'relate' - sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)' - pattern_regex = _lazy_re_compile(r'^[012TF\*]{9}$') + lookup_name = "relate" + sql_template = "%(func)s(%(lhs)s, %(rhs)s, %%s)" + pattern_regex = _lazy_re_compile(r"^[012TF\*]{9}$") def process_rhs(self, compiler, connection): # Check the pattern argument pattern = self.rhs_params[0] backend_op = connection.ops.gis_operators[self.lookup_name] - if hasattr(backend_op, 'check_relate_argument'): + if hasattr(backend_op, "check_relate_argument"): backend_op.check_relate_argument(pattern) elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern): raise ValueError('Invalid intersection matrix pattern 
"%s".' % pattern) @@ -267,93 +288,106 @@ class RelateLookup(GISLookup): @BaseSpatialField.register_lookup class TouchesLookup(GISLookup): - lookup_name = 'touches' + lookup_name = "touches" @BaseSpatialField.register_lookup class WithinLookup(GISLookup): - lookup_name = 'within' + lookup_name = "within" class DistanceLookupBase(GISLookup): distance = True - sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' + sql_template = "%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s" def process_rhs_params(self): if not 1 <= len(self.rhs_params) <= 3: - raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name) - elif len(self.rhs_params) == 3 and self.rhs_params[2] != 'spheroid': - raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.") + raise ValueError( + "2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name + ) + elif len(self.rhs_params) == 3 and self.rhs_params[2] != "spheroid": + raise ValueError( + "For 4-element tuples the last argument must be the 'spheroid' directive." + ) # Check if the second parameter is a band index. 
- if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid': + if len(self.rhs_params) > 1 and self.rhs_params[1] != "spheroid": self.process_band_indices() def process_distance(self, compiler, connection): dist_param = self.rhs_params[0] return ( compiler.compile(dist_param.resolve_expression(compiler.query)) - if hasattr(dist_param, 'resolve_expression') else - ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name)) + if hasattr(dist_param, "resolve_expression") + else ( + "%s", + connection.ops.get_distance( + self.lhs.output_field, self.rhs_params, self.lookup_name + ), + ) ) @BaseSpatialField.register_lookup class DWithinLookup(DistanceLookupBase): - lookup_name = 'dwithin' - sql_template = '%(func)s(%(lhs)s, %(rhs)s, %(value)s)' + lookup_name = "dwithin" + sql_template = "%(func)s(%(lhs)s, %(rhs)s, %(value)s)" def process_distance(self, compiler, connection): dist_param = self.rhs_params[0] if ( - not connection.features.supports_dwithin_distance_expr and - hasattr(dist_param, 'resolve_expression') and - not isinstance(dist_param, Distance) + not connection.features.supports_dwithin_distance_expr + and hasattr(dist_param, "resolve_expression") + and not isinstance(dist_param, Distance) ): raise NotSupportedError( - 'This backend does not support expressions for specifying ' - 'distance in the dwithin lookup.' + "This backend does not support expressions for specifying " + "distance in the dwithin lookup." 
) return super().process_distance(compiler, connection) def process_rhs(self, compiler, connection): dist_sql, dist_params = self.process_distance(compiler, connection) - self.template_params['value'] = dist_sql + self.template_params["value"] = dist_sql rhs_sql, params = super().process_rhs(compiler, connection) return rhs_sql, params + dist_params class DistanceLookupFromFunction(DistanceLookupBase): def as_sql(self, compiler, connection): - spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None - distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid) + spheroid = ( + len(self.rhs_params) == 2 and self.rhs_params[-1] == "spheroid" + ) or None + distance_expr = connection.ops.distance_expr_for_lookup( + self.lhs, self.rhs, spheroid=spheroid + ) sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query)) dist_sql, dist_params = self.process_distance(compiler, connection) return ( - '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql}, + "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql}, params + dist_params, ) @BaseSpatialField.register_lookup class DistanceGTLookup(DistanceLookupFromFunction): - lookup_name = 'distance_gt' - op = '>' + lookup_name = "distance_gt" + op = ">" @BaseSpatialField.register_lookup class DistanceGTELookup(DistanceLookupFromFunction): - lookup_name = 'distance_gte' - op = '>=' + lookup_name = "distance_gte" + op = ">=" @BaseSpatialField.register_lookup class DistanceLTLookup(DistanceLookupFromFunction): - lookup_name = 'distance_lt' - op = '<' + lookup_name = "distance_lt" + op = "<" @BaseSpatialField.register_lookup class DistanceLTELookup(DistanceLookupFromFunction): - lookup_name = 'distance_lte' - op = '<=' + lookup_name = "distance_lte" + op = "<=" diff --git a/django/contrib/gis/db/models/proxy.py b/django/contrib/gis/db/models/proxy.py index 07b26e9910..4db365dc16 100644 --- 
a/django/contrib/gis/db/models/proxy.py +++ b/django/contrib/gis/db/models/proxy.py @@ -37,7 +37,7 @@ class SpatialProxy(DeferredAttribute): if isinstance(geo_value, self._klass): geo_obj = geo_value - elif (geo_value is None) or (geo_value == ''): + elif (geo_value is None) or (geo_value == ""): geo_obj = None else: # Otherwise, a geometry or raster object is built using the field's @@ -57,7 +57,9 @@ class SpatialProxy(DeferredAttribute): # The geographic type of the field. gtype = self.field.geom_type - if gtype == 'RASTER' and (value is None or isinstance(value, (str, dict, self._klass))): + if gtype == "RASTER" and ( + value is None or isinstance(value, (str, dict, self._klass)) + ): # For raster fields, ensure input is None or a string, dict, or # raster instance. pass @@ -71,8 +73,10 @@ class SpatialProxy(DeferredAttribute): # Set geometries with None, WKT, HEX, or WKB pass else: - raise TypeError('Cannot set %s SpatialProxy (%s) with value of type: %s' % ( - instance.__class__.__name__, gtype, type(value))) + raise TypeError( + "Cannot set %s SpatialProxy (%s) with value of type: %s" + % (instance.__class__.__name__, gtype, type(value)) + ) # Setting the objects dictionary with the value, and returning. 
instance.__dict__[self.field.attname] = value diff --git a/django/contrib/gis/db/models/sql/__init__.py b/django/contrib/gis/db/models/sql/__init__.py index 850c644d5d..1376e8d7d1 100644 --- a/django/contrib/gis/db/models/sql/__init__.py +++ b/django/contrib/gis/db/models/sql/__init__.py @@ -1,7 +1,6 @@ -from django.contrib.gis.db.models.sql.conversion import ( - AreaField, DistanceField, -) +from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField __all__ = [ - 'AreaField', 'DistanceField', + "AreaField", + "DistanceField", ] diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index 99ab51e239..be712319fb 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -10,13 +10,14 @@ from django.db import models class AreaField(models.FloatField): "Wrapper for Area values." + def __init__(self, geo_field): super().__init__() self.geo_field = geo_field def get_prep_value(self, value): if not isinstance(value, Area): - raise ValueError('AreaField only accepts Area measurement objects.') + raise ValueError("AreaField only accepts Area measurement objects.") return value def get_db_prep_value(self, value, connection, prepared=False): @@ -37,11 +38,12 @@ class AreaField(models.FloatField): return Area(**{area_att: value}) if area_att else value def get_internal_type(self): - return 'AreaField' + return "AreaField" class DistanceField(models.FloatField): "Wrapper for Distance values." + def __init__(self, geo_field): super().__init__() self.geo_field = geo_field @@ -56,7 +58,9 @@ class DistanceField(models.FloatField): return value distance_att = connection.ops.get_distance_att_for_field(self.geo_field) if not distance_att: - raise ValueError('Distance measure is supplied, but units are unknown for result.') + raise ValueError( + "Distance measure is supplied, but units are unknown for result." 
+ ) return getattr(value, distance_att) def from_db_value(self, value, expression, connection): @@ -66,4 +70,4 @@ class DistanceField(models.FloatField): return Distance(**{distance_att: value}) if distance_att else value def get_internal_type(self): - return 'DistanceField' + return "DistanceField" diff --git a/django/contrib/gis/feeds.py b/django/contrib/gis/feeds.py index cfc078b781..ebd4511889 100644 --- a/django/contrib/gis/feeds.py +++ b/django/contrib/gis/feeds.py @@ -14,7 +14,7 @@ class GeoFeedMixin: a single white space. Given a tuple of coordinates, return a string GeoRSS representation. """ - return ' '.join('%f %f' % (coord[1], coord[0]) for coord in coords) + return " ".join("%f %f" % (coord[1], coord[0]) for coord in coords) def add_georss_point(self, handler, coords, w3c_geo=False): """ @@ -24,15 +24,15 @@ class GeoFeedMixin: """ if w3c_geo: lon, lat = coords[:2] - handler.addQuickElement('geo:lat', '%f' % lat) - handler.addQuickElement('geo:lon', '%f' % lon) + handler.addQuickElement("geo:lat", "%f" % lat) + handler.addQuickElement("geo:lon", "%f" % lon) else: - handler.addQuickElement('georss:point', self.georss_coords((coords,))) + handler.addQuickElement("georss:point", self.georss_coords((coords,))) def add_georss_element(self, handler, item, w3c_geo=False): """Add a GeoRSS XML element using the given item and handler.""" # Getting the Geometry object. - geom = item.get('geometry') + geom = item.get("geometry") if geom is not None: if isinstance(geom, (list, tuple)): # Special case if a tuple/list was passed in. 
The tuple may be @@ -43,7 +43,7 @@ class GeoFeedMixin: if len(geom) == 2: box_coords = geom else: - raise ValueError('Only should be two sets of coordinates.') + raise ValueError("Only should be two sets of coordinates.") else: if len(geom) == 2: # Point: (X, Y) @@ -52,36 +52,46 @@ class GeoFeedMixin: # Box: (X0, Y0, X1, Y1) box_coords = (geom[:2], geom[2:]) else: - raise ValueError('Only should be 2 or 4 numeric elements.') + raise ValueError("Only should be 2 or 4 numeric elements.") # If a GeoRSS box was given via tuple. if box_coords is not None: if w3c_geo: - raise ValueError('Cannot use simple GeoRSS box in W3C Geo feeds.') - handler.addQuickElement('georss:box', self.georss_coords(box_coords)) + raise ValueError( + "Cannot use simple GeoRSS box in W3C Geo feeds." + ) + handler.addQuickElement( + "georss:box", self.georss_coords(box_coords) + ) else: # Getting the lowercase geometry type. gtype = str(geom.geom_type).lower() - if gtype == 'point': + if gtype == "point": self.add_georss_point(handler, geom.coords, w3c_geo=w3c_geo) else: if w3c_geo: - raise ValueError('W3C Geo only supports Point geometries.') + raise ValueError("W3C Geo only supports Point geometries.") # For formatting consistent w/the GeoRSS simple standard: # http://georss.org/1.0#simple - if gtype in ('linestring', 'linearring'): - handler.addQuickElement('georss:line', self.georss_coords(geom.coords)) - elif gtype in ('polygon',): + if gtype in ("linestring", "linearring"): + handler.addQuickElement( + "georss:line", self.georss_coords(geom.coords) + ) + elif gtype in ("polygon",): # Only support the exterior ring. - handler.addQuickElement('georss:polygon', self.georss_coords(geom[0].coords)) + handler.addQuickElement( + "georss:polygon", self.georss_coords(geom[0].coords) + ) else: - raise ValueError('Geometry type "%s" not supported.' % geom.geom_type) + raise ValueError( + 'Geometry type "%s" not supported.' 
% geom.geom_type + ) # ### SyndicationFeed subclasses ### class GeoRSSFeed(Rss201rev2Feed, GeoFeedMixin): def rss_attributes(self): attrs = super().rss_attributes() - attrs['xmlns:georss'] = 'http://www.georss.org/georss' + attrs["xmlns:georss"] = "http://www.georss.org/georss" return attrs def add_item_elements(self, handler, item): @@ -96,7 +106,7 @@ class GeoRSSFeed(Rss201rev2Feed, GeoFeedMixin): class GeoAtom1Feed(Atom1Feed, GeoFeedMixin): def root_attributes(self): attrs = super().root_attributes() - attrs['xmlns:georss'] = 'http://www.georss.org/georss' + attrs["xmlns:georss"] = "http://www.georss.org/georss" return attrs def add_item_elements(self, handler, item): @@ -111,7 +121,7 @@ class GeoAtom1Feed(Atom1Feed, GeoFeedMixin): class W3CGeoFeed(Rss201rev2Feed, GeoFeedMixin): def rss_attributes(self): attrs = super().rss_attributes() - attrs['xmlns:geo'] = 'http://www.w3.org/2003/01/geo/wgs84_pos#' + attrs["xmlns:geo"] = "http://www.w3.org/2003/01/geo/wgs84_pos#" return attrs def add_item_elements(self, handler, item): @@ -131,10 +141,11 @@ class Feed(BaseFeed): methods on their own subclasses so that geo-referenced information may placed in the feed. 
""" + feed_type = GeoRSSFeed def feed_extra_kwargs(self, obj): - return {'geometry': self._get_dynamic_attr('geometry', obj)} + return {"geometry": self._get_dynamic_attr("geometry", obj)} def item_extra_kwargs(self, item): - return {'geometry': self._get_dynamic_attr('item_geometry', item)} + return {"geometry": self._get_dynamic_attr("item_geometry", item)} diff --git a/django/contrib/gis/forms/__init__.py b/django/contrib/gis/forms/__init__.py index 237cbace8d..c07720b2d0 100644 --- a/django/contrib/gis/forms/__init__.py +++ b/django/contrib/gis/forms/__init__.py @@ -1,8 +1,13 @@ from django.forms import * # NOQA from .fields import ( # NOQA - GeometryCollectionField, GeometryField, LineStringField, - MultiLineStringField, MultiPointField, MultiPolygonField, PointField, + GeometryCollectionField, + GeometryField, + LineStringField, + MultiLineStringField, + MultiPointField, + MultiPolygonField, + PointField, PolygonField, ) from .widgets import BaseGeometryWidget, OpenLayersWidget, OSMWidget # NOQA diff --git a/django/contrib/gis/forms/fields.py b/django/contrib/gis/forms/fields.py index 301d770836..1fd31530c1 100644 --- a/django/contrib/gis/forms/fields.py +++ b/django/contrib/gis/forms/fields.py @@ -13,15 +13,18 @@ class GeometryField(forms.Field): accepted by GEOSGeometry is accepted by this form. By default, this includes WKT, HEXEWKB, WKB (in a buffer), and GeoJSON. 
""" + widget = OpenLayersWidget - geom_type = 'GEOMETRY' + geom_type = "GEOMETRY" default_error_messages = { - 'required': _('No geometry value provided.'), - 'invalid_geom': _('Invalid geometry value.'), - 'invalid_geom_type': _('Invalid geometry type.'), - 'transform_error': _('An error occurred when transforming the geometry ' - 'to the SRID of the geometry form field.'), + "required": _("No geometry value provided."), + "invalid_geom": _("Invalid geometry value."), + "invalid_geom_type": _("Invalid geometry type."), + "transform_error": _( + "An error occurred when transforming the geometry " + "to the SRID of the geometry form field." + ), } def __init__(self, *, srid=None, geom_type=None, **kwargs): @@ -29,7 +32,7 @@ class GeometryField(forms.Field): if geom_type is not None: self.geom_type = geom_type super().__init__(**kwargs) - self.widget.attrs['geom_type'] = self.geom_type + self.widget.attrs["geom_type"] = self.geom_type def to_python(self, value): """Transform the value to a Geometry object.""" @@ -37,7 +40,7 @@ class GeometryField(forms.Field): return None if not isinstance(value, GEOSGeometry): - if hasattr(self.widget, 'deserialize'): + if hasattr(self.widget, "deserialize"): try: value = self.widget.deserialize(value) except GDALException: @@ -48,7 +51,9 @@ class GeometryField(forms.Field): except (GEOSException, ValueError, TypeError): value = None if value is None: - raise ValidationError(self.error_messages['invalid_geom'], code='invalid_geom') + raise ValidationError( + self.error_messages["invalid_geom"], code="invalid_geom" + ) # Try to set the srid if not value.srid: @@ -71,8 +76,13 @@ class GeometryField(forms.Field): # Ensuring that the geometry is of the correct type (indicated # using the OGC string label). 
- if str(geom.geom_type).upper() != self.geom_type and self.geom_type != 'GEOMETRY': - raise ValidationError(self.error_messages['invalid_geom_type'], code='invalid_geom_type') + if ( + str(geom.geom_type).upper() != self.geom_type + and self.geom_type != "GEOMETRY" + ): + raise ValidationError( + self.error_messages["invalid_geom_type"], code="invalid_geom_type" + ) # Transforming the geometry if the SRID was set. if self.srid and self.srid != -1 and self.srid != geom.srid: @@ -80,12 +90,13 @@ class GeometryField(forms.Field): geom.transform(self.srid) except GEOSException: raise ValidationError( - self.error_messages['transform_error'], code='transform_error') + self.error_messages["transform_error"], code="transform_error" + ) return geom def has_changed(self, initial, data): - """ Compare geographic value of data with its initial value. """ + """Compare geographic value of data with its initial value.""" try: data = self.to_python(data) @@ -106,28 +117,28 @@ class GeometryField(forms.Field): class GeometryCollectionField(GeometryField): - geom_type = 'GEOMETRYCOLLECTION' + geom_type = "GEOMETRYCOLLECTION" class PointField(GeometryField): - geom_type = 'POINT' + geom_type = "POINT" class MultiPointField(GeometryField): - geom_type = 'MULTIPOINT' + geom_type = "MULTIPOINT" class LineStringField(GeometryField): - geom_type = 'LINESTRING' + geom_type = "LINESTRING" class MultiLineStringField(GeometryField): - geom_type = 'MULTILINESTRING' + geom_type = "MULTILINESTRING" class PolygonField(GeometryField): - geom_type = 'POLYGON' + geom_type = "POLYGON" class MultiPolygonField(GeometryField): - geom_type = 'MULTIPOLYGON' + geom_type = "MULTIPOLYGON" diff --git a/django/contrib/gis/forms/widgets.py b/django/contrib/gis/forms/widgets.py index e38f173a98..0f53ee2e96 100644 --- a/django/contrib/gis/forms/widgets.py +++ b/django/contrib/gis/forms/widgets.py @@ -7,7 +7,7 @@ from django.contrib.gis.geos import GEOSException, GEOSGeometry from django.forms.widgets import 
Widget from django.utils import translation -logger = logging.getLogger('django.contrib.gis') +logger = logging.getLogger("django.contrib.gis") class BaseGeometryWidget(Widget): @@ -15,24 +15,25 @@ class BaseGeometryWidget(Widget): The base class for rich geometry widgets. Render a map using the WKT of the geometry. """ - geom_type = 'GEOMETRY' + + geom_type = "GEOMETRY" map_srid = 4326 map_width = 600 map_height = 400 display_raw = False supports_3d = False - template_name = '' # set on subclasses + template_name = "" # set on subclasses def __init__(self, attrs=None): self.attrs = {} - for key in ('geom_type', 'map_srid', 'map_width', 'map_height', 'display_raw'): + for key in ("geom_type", "map_srid", "map_width", "map_height", "display_raw"): self.attrs[key] = getattr(self, key) if attrs: self.attrs.update(attrs) def serialize(self, value): - return value.wkt if value else '' + return value.wkt if value else "" def deserialize(self, value): try: @@ -58,40 +59,47 @@ class BaseGeometryWidget(Widget): except gdal.GDALException as err: logger.error( "Error transforming geometry from srid '%s' to srid '%s' (%s)", - value.srid, self.map_srid, err + value.srid, + self.map_srid, + err, ) - geom_type = gdal.OGRGeomType(self.attrs['geom_type']).name - context.update(self.build_attrs(self.attrs, { - 'name': name, - 'module': 'geodjango_%s' % name.replace('-', '_'), # JS-safe - 'serialized': self.serialize(value), - 'geom_type': 'Geometry' if geom_type == 'Unknown' else geom_type, - 'STATIC_URL': settings.STATIC_URL, - 'LANGUAGE_BIDI': translation.get_language_bidi(), - **(attrs or {}), - })) + geom_type = gdal.OGRGeomType(self.attrs["geom_type"]).name + context.update( + self.build_attrs( + self.attrs, + { + "name": name, + "module": "geodjango_%s" % name.replace("-", "_"), # JS-safe + "serialized": self.serialize(value), + "geom_type": "Geometry" if geom_type == "Unknown" else geom_type, + "STATIC_URL": settings.STATIC_URL, + "LANGUAGE_BIDI": 
translation.get_language_bidi(), + **(attrs or {}), + }, + ) + ) return context class OpenLayersWidget(BaseGeometryWidget): - template_name = 'gis/openlayers.html' + template_name = "gis/openlayers.html" map_srid = 3857 class Media: css = { - 'all': ( - 'https://cdnjs.cloudflare.com/ajax/libs/ol3/4.6.5/ol.css', - 'gis/css/ol3.css', + "all": ( + "https://cdnjs.cloudflare.com/ajax/libs/ol3/4.6.5/ol.css", + "gis/css/ol3.css", ) } js = ( - 'https://cdnjs.cloudflare.com/ajax/libs/ol3/4.6.5/ol.js', - 'gis/js/OLMapWidget.js', + "https://cdnjs.cloudflare.com/ajax/libs/ol3/4.6.5/ol.js", + "gis/js/OLMapWidget.js", ) def serialize(self, value): - return value.json if value else '' + return value.json if value else "" def deserialize(self, value): geom = super().deserialize(value) @@ -105,14 +113,15 @@ class OSMWidget(OpenLayersWidget): """ An OpenLayers/OpenStreetMap-based widget. """ - template_name = 'gis/openlayers-osm.html' + + template_name = "gis/openlayers-osm.html" default_lon = 5 default_lat = 47 default_zoom = 12 def __init__(self, attrs=None): super().__init__() - for key in ('default_lon', 'default_lat', 'default_zoom'): + for key in ("default_lon", "default_lat", "default_zoom"): self.attrs[key] = getattr(self, key) if attrs: self.attrs.update(attrs) diff --git a/django/contrib/gis/gdal/__init__.py b/django/contrib/gis/gdal/__init__.py index 1d8c6884bd..9ed6e31156 100644 --- a/django/contrib/gis/gdal/__init__.py +++ b/django/contrib/gis/gdal/__init__.py @@ -28,22 +28,31 @@ from django.contrib.gis.gdal.datasource import DataSource from django.contrib.gis.gdal.driver import Driver from django.contrib.gis.gdal.envelope import Envelope -from django.contrib.gis.gdal.error import ( - GDALException, SRSException, check_err, -) +from django.contrib.gis.gdal.error import GDALException, SRSException, check_err from django.contrib.gis.gdal.geometries import OGRGeometry from django.contrib.gis.gdal.geomtype import OGRGeomType from django.contrib.gis.gdal.libgdal import ( - 
GDAL_VERSION, gdal_full_version, gdal_version, + GDAL_VERSION, + gdal_full_version, + gdal_version, ) from django.contrib.gis.gdal.raster.source import GDALRaster -from django.contrib.gis.gdal.srs import ( - AxisOrder, CoordTransform, SpatialReference, -) +from django.contrib.gis.gdal.srs import AxisOrder, CoordTransform, SpatialReference __all__ = ( - 'AxisOrder', 'Driver', 'DataSource', 'CoordTransform', 'Envelope', - 'GDALException', 'GDALRaster', 'GDAL_VERSION', 'OGRGeometry', - 'OGRGeomType', 'SpatialReference', 'SRSException', 'check_err', - 'gdal_version', 'gdal_full_version', + "AxisOrder", + "Driver", + "DataSource", + "CoordTransform", + "Envelope", + "GDALException", + "GDALRaster", + "GDAL_VERSION", + "OGRGeometry", + "OGRGeomType", + "SpatialReference", + "SRSException", + "check_err", + "gdal_version", + "gdal_full_version", ) diff --git a/django/contrib/gis/gdal/datasource.py b/django/contrib/gis/gdal/datasource.py index f2091c864a..dfd043ab0c 100644 --- a/django/contrib/gis/gdal/datasource.py +++ b/django/contrib/gis/gdal/datasource.py @@ -52,7 +52,7 @@ class DataSource(GDALBase): "Wraps an OGR Data Source object." destructor = capi.destroy_ds - def __init__(self, ds_input, ds_driver=False, write=False, encoding='utf-8'): + def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"): # The write flag. if write: self._write = 1 @@ -73,10 +73,12 @@ class DataSource(GDALBase): # Making the error message more clear rather than something # like "Invalid pointer returned from OGROpen". 
raise GDALException('Could not open the datasource at "%s"' % ds_input) - elif isinstance(ds_input, self.ptr_type) and isinstance(ds_driver, Driver.ptr_type): + elif isinstance(ds_input, self.ptr_type) and isinstance( + ds_driver, Driver.ptr_type + ): ds = ds_input else: - raise GDALException('Invalid data source input type: %s' % type(ds_input)) + raise GDALException("Invalid data source input type: %s" % type(ds_input)) if ds: self.ptr = ds @@ -91,14 +93,17 @@ class DataSource(GDALBase): try: layer = capi.get_layer_by_name(self.ptr, force_bytes(index)) except GDALException: - raise IndexError('Invalid OGR layer name given: %s.' % index) + raise IndexError("Invalid OGR layer name given: %s." % index) elif isinstance(index, int): if 0 <= index < self.layer_count: layer = capi.get_layer(self._ptr, index) else: - raise IndexError('Index out of range when accessing layers in a datasource: %s.' % index) + raise IndexError( + "Index out of range when accessing layers in a datasource: %s." + % index + ) else: - raise TypeError('Invalid index type: %s' % type(index)) + raise TypeError("Invalid index type: %s" % type(index)) return Layer(layer, self) def __len__(self): @@ -107,7 +112,7 @@ class DataSource(GDALBase): def __str__(self): "Return OGR GetName and Driver for the Data Source." 
- return '%s (%s)' % (self.name, self.driver) + return "%s (%s)" % (self.name, self.driver) @property def layer_count(self): diff --git a/django/contrib/gis/gdal/driver.py b/django/contrib/gis/gdal/driver.py index f90b886251..0ce7a2cdc8 100644 --- a/django/contrib/gis/gdal/driver.py +++ b/django/contrib/gis/gdal/driver.py @@ -2,7 +2,8 @@ from ctypes import c_void_p from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.error import GDALException -from django.contrib.gis.gdal.prototypes import ds as vcapi, raster as rcapi +from django.contrib.gis.gdal.prototypes import ds as vcapi +from django.contrib.gis.gdal.prototypes import raster as rcapi from django.utils.encoding import force_bytes, force_str @@ -20,16 +21,16 @@ class Driver(GDALBase): # https://gdal.org/drivers/raster/ _alias = { # vector - 'esri': 'ESRI Shapefile', - 'shp': 'ESRI Shapefile', - 'shape': 'ESRI Shapefile', - 'tiger': 'TIGER', - 'tiger/line': 'TIGER', + "esri": "ESRI Shapefile", + "shp": "ESRI Shapefile", + "shape": "ESRI Shapefile", + "tiger": "TIGER", + "tiger/line": "TIGER", # raster - 'tiff': 'GTiff', - 'tif': 'GTiff', - 'jpeg': 'JPEG', - 'jpg': 'JPEG', + "tiff": "GTiff", + "tif": "GTiff", + "jpeg": "JPEG", + "jpg": "JPEG", } def __init__(self, dr_input): @@ -61,11 +62,15 @@ class Driver(GDALBase): elif isinstance(dr_input, c_void_p): driver = dr_input else: - raise GDALException('Unrecognized input type for GDAL/OGR Driver: %s' % type(dr_input)) + raise GDALException( + "Unrecognized input type for GDAL/OGR Driver: %s" % type(dr_input) + ) # Making sure we get a valid pointer to the OGR Driver if not driver: - raise GDALException('Could not initialize GDAL/OGR Driver on input: %s' % dr_input) + raise GDALException( + "Could not initialize GDAL/OGR Driver on input: %s" % dr_input + ) self.ptr = driver def __str__(self): diff --git a/django/contrib/gis/gdal/envelope.py b/django/contrib/gis/gdal/envelope.py index 98c333f483..4c2c1e4a1a 100644 --- 
a/django/contrib/gis/gdal/envelope.py +++ b/django/contrib/gis/gdal/envelope.py @@ -20,11 +20,12 @@ from django.contrib.gis.gdal.error import GDALException # https://gdal.org/doxygen/ogr__core_8h_source.html class OGREnvelope(Structure): "Represent the OGREnvelope C Structure." - _fields_ = [("MinX", c_double), - ("MaxX", c_double), - ("MinY", c_double), - ("MaxY", c_double), - ] + _fields_ = [ + ("MinX", c_double), + ("MaxX", c_double), + ("MinY", c_double), + ("MaxY", c_double), + ] class Envelope: @@ -47,23 +48,25 @@ class Envelope: elif isinstance(args[0], (tuple, list)): # A tuple was passed in. if len(args[0]) != 4: - raise GDALException('Incorrect number of tuple elements (%d).' % len(args[0])) + raise GDALException( + "Incorrect number of tuple elements (%d)." % len(args[0]) + ) else: self._from_sequence(args[0]) else: - raise TypeError('Incorrect type of argument: %s' % type(args[0])) + raise TypeError("Incorrect type of argument: %s" % type(args[0])) elif len(args) == 4: # Individual parameters passed in. # Thanks to ww for the help self._from_sequence([float(a) for a in args]) else: - raise GDALException('Incorrect number (%d) of arguments.' % len(args)) + raise GDALException("Incorrect number (%d) of arguments." % len(args)) # Checking the x,y coordinates if self.min_x > self.max_x: - raise GDALException('Envelope minimum X > maximum X.') + raise GDALException("Envelope minimum X > maximum X.") if self.min_y > self.max_y: - raise GDALException('Envelope minimum Y > maximum Y.') + raise GDALException("Envelope minimum Y > maximum Y.") def __eq__(self, other): """ @@ -71,13 +74,21 @@ class Envelope: other Envelopes and 4-tuples. 
""" if isinstance(other, Envelope): - return (self.min_x == other.min_x) and (self.min_y == other.min_y) and \ - (self.max_x == other.max_x) and (self.max_y == other.max_y) + return ( + (self.min_x == other.min_x) + and (self.min_y == other.min_y) + and (self.max_x == other.max_x) + and (self.max_y == other.max_y) + ) elif isinstance(other, tuple) and len(other) == 4: - return (self.min_x == other[0]) and (self.min_y == other[1]) and \ - (self.max_x == other[2]) and (self.max_y == other[3]) + return ( + (self.min_x == other[0]) + and (self.min_y == other[1]) + and (self.max_x == other[2]) + and (self.max_y == other[3]) + ) else: - raise GDALException('Equivalence testing only works with other Envelopes.') + raise GDALException("Equivalence testing only works with other Envelopes.") def __str__(self): "Return a string representation of the tuple." @@ -104,12 +115,16 @@ class Envelope: if len(args) == 1: if isinstance(args[0], Envelope): return self.expand_to_include(args[0].tuple) - elif hasattr(args[0], 'x') and hasattr(args[0], 'y'): - return self.expand_to_include(args[0].x, args[0].y, args[0].x, args[0].y) + elif hasattr(args[0], "x") and hasattr(args[0], "y"): + return self.expand_to_include( + args[0].x, args[0].y, args[0].x, args[0].y + ) elif isinstance(args[0], (tuple, list)): # A tuple was passed in. if len(args[0]) == 2: - return self.expand_to_include((args[0][0], args[0][1], args[0][0], args[0][1])) + return self.expand_to_include( + (args[0][0], args[0][1], args[0][0], args[0][1]) + ) elif len(args[0]) == 4: (minx, miny, maxx, maxy) = args[0] if minx < self._envelope.MinX: @@ -121,9 +136,11 @@ class Envelope: if maxy > self._envelope.MaxY: self._envelope.MaxY = maxy else: - raise GDALException('Incorrect number of tuple elements (%d).' % len(args[0])) + raise GDALException( + "Incorrect number of tuple elements (%d)." 
% len(args[0]) + ) else: - raise TypeError('Incorrect type of argument: %s' % type(args[0])) + raise TypeError("Incorrect type of argument: %s" % type(args[0])) elif len(args) == 2: # An x and an y parameter were passed in return self.expand_to_include((args[0], args[1], args[0], args[1])) @@ -131,7 +148,7 @@ class Envelope: # Individual parameters passed in. return self.expand_to_include(args) else: - raise GDALException('Incorrect number (%d) of arguments.' % len(args[0])) + raise GDALException("Incorrect number (%d) of arguments." % len(args[0])) @property def min_x(self): @@ -172,7 +189,15 @@ class Envelope: def wkt(self): "Return WKT representing a Polygon for this envelope." # TODO: Fix significant figures. - return 'POLYGON((%s %s,%s %s,%s %s,%s %s,%s %s))' % \ - (self.min_x, self.min_y, self.min_x, self.max_y, - self.max_x, self.max_y, self.max_x, self.min_y, - self.min_x, self.min_y) + return "POLYGON((%s %s,%s %s,%s %s,%s %s,%s %s))" % ( + self.min_x, + self.min_y, + self.min_x, + self.max_y, + self.max_x, + self.max_y, + self.max_x, + self.min_y, + self.min_x, + self.min_y, + ) diff --git a/django/contrib/gis/gdal/error.py b/django/contrib/gis/gdal/error.py index b5646dd751..df19c2e4a7 100644 --- a/django/contrib/gis/gdal/error.py +++ b/django/contrib/gis/gdal/error.py @@ -18,29 +18,29 @@ class SRSException(Exception): # OGR Error Codes OGRERR_DICT = { - 1: (GDALException, 'Not enough data.'), - 2: (GDALException, 'Not enough memory.'), - 3: (GDALException, 'Unsupported geometry type.'), - 4: (GDALException, 'Unsupported operation.'), - 5: (GDALException, 'Corrupt data.'), - 6: (GDALException, 'OGR failure.'), - 7: (SRSException, 'Unsupported SRS.'), - 8: (GDALException, 'Invalid handle.'), + 1: (GDALException, "Not enough data."), + 2: (GDALException, "Not enough memory."), + 3: (GDALException, "Unsupported geometry type."), + 4: (GDALException, "Unsupported operation."), + 5: (GDALException, "Corrupt data."), + 6: (GDALException, "OGR failure."), + 7: 
(SRSException, "Unsupported SRS."), + 8: (GDALException, "Invalid handle."), } # CPL Error Codes # https://gdal.org/api/cpl.html#cpl-error-h CPLERR_DICT = { - 1: (GDALException, 'AppDefined'), - 2: (GDALException, 'OutOfMemory'), - 3: (GDALException, 'FileIO'), - 4: (GDALException, 'OpenFailed'), - 5: (GDALException, 'IllegalArg'), - 6: (GDALException, 'NotSupported'), - 7: (GDALException, 'AssertionFailed'), - 8: (GDALException, 'NoWriteAccess'), - 9: (GDALException, 'UserInterrupt'), - 10: (GDALException, 'ObjectNull'), + 1: (GDALException, "AppDefined"), + 2: (GDALException, "OutOfMemory"), + 3: (GDALException, "FileIO"), + 4: (GDALException, "OpenFailed"), + 5: (GDALException, "IllegalArg"), + 6: (GDALException, "NotSupported"), + 7: (GDALException, "AssertionFailed"), + 8: (GDALException, "NoWriteAccess"), + 9: (GDALException, "UserInterrupt"), + 10: (GDALException, "ObjectNull"), } ERR_NONE = 0 diff --git a/django/contrib/gis/gdal/feature.py b/django/contrib/gis/gdal/feature.py index 82a9089c16..6f08969984 100644 --- a/django/contrib/gis/gdal/feature.py +++ b/django/contrib/gis/gdal/feature.py @@ -2,7 +2,8 @@ from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.field import Field from django.contrib.gis.gdal.geometries import OGRGeometry, OGRGeomType -from django.contrib.gis.gdal.prototypes import ds as capi, geom as geom_api +from django.contrib.gis.gdal.prototypes import ds as capi +from django.contrib.gis.gdal.prototypes import geom as geom_api from django.utils.encoding import force_bytes, force_str @@ -15,6 +16,7 @@ class Feature(GDALBase): This class that wraps an OGR Feature, needs to be instantiated from a Layer object. """ + destructor = capi.destroy_feature def __init__(self, feat, layer): @@ -22,7 +24,7 @@ class Feature(GDALBase): Initialize Feature from a pointer and its Layer object. 
""" if not feat: - raise GDALException('Cannot create OGR Feature, invalid pointer given.') + raise GDALException("Cannot create OGR Feature, invalid pointer given.") self.ptr = feat self._layer = layer @@ -38,7 +40,9 @@ class Feature(GDALBase): elif 0 <= index < self.num_fields: i = index else: - raise IndexError('Index out of range when accessing field in a feature: %s.' % index) + raise IndexError( + "Index out of range when accessing field in a feature: %s." % index + ) return Field(self, i) def __len__(self): @@ -47,7 +51,7 @@ class Feature(GDALBase): def __str__(self): "The string name of the feature." - return 'Feature FID %d in Layer<%s>' % (self.fid, self.layer_name) + return "Feature FID %d in Layer<%s>" % (self.fid, self.layer_name) def __eq__(self, other): "Do equivalence testing on the features." @@ -81,8 +85,9 @@ class Feature(GDALBase): force_str( capi.get_field_name(capi.get_field_defn(self._layer._ldefn, i)), self.encoding, - strings_only=True - ) for i in range(self.num_fields) + strings_only=True, + ) + for i in range(self.num_fields) ] @property @@ -104,12 +109,12 @@ class Feature(GDALBase): object. May take a string of the field name or a Field object as parameters. """ - field_name = getattr(field, 'name', field) + field_name = getattr(field, "name", field) return self[field_name].value def index(self, field_name): "Return the index of the given field name." i = capi.get_field_index(self.ptr, force_bytes(field_name)) if i < 0: - raise IndexError('Invalid OFT field name given: %s.' % field_name) + raise IndexError("Invalid OFT field name given: %s." % field_name) return i diff --git a/django/contrib/gis/gdal/field.py b/django/contrib/gis/gdal/field.py index e803886182..5374f6c3cb 100644 --- a/django/contrib/gis/gdal/field.py +++ b/django/contrib/gis/gdal/field.py @@ -28,7 +28,7 @@ class Field(GDALBase): # Getting the pointer for this field. 
fld_ptr = capi.get_feat_field_defn(feat.ptr, index) if not fld_ptr: - raise GDALException('Cannot create OGR Field, invalid pointer given.') + raise GDALException("Cannot create OGR Field, invalid pointer given.") self.ptr = fld_ptr # Setting the class depending upon the OGR Field Type (OFT) @@ -41,14 +41,26 @@ class Field(GDALBase): # #### Field Methods #### def as_double(self): "Retrieve the Field's value as a double (float)." - return capi.get_field_as_double(self._feat.ptr, self._index) if self.is_set else None + return ( + capi.get_field_as_double(self._feat.ptr, self._index) + if self.is_set + else None + ) def as_int(self, is_64=False): "Retrieve the Field's value as an integer." if is_64: - return capi.get_field_as_integer64(self._feat.ptr, self._index) if self.is_set else None + return ( + capi.get_field_as_integer64(self._feat.ptr, self._index) + if self.is_set + else None + ) else: - return capi.get_field_as_integer(self._feat.ptr, self._index) if self.is_set else None + return ( + capi.get_field_as_integer(self._feat.ptr, self._index) + if self.is_set + else None + ) def as_string(self): "Retrieve the Field's value as a string." @@ -63,12 +75,22 @@ class Field(GDALBase): return None yy, mm, dd, hh, mn, ss, tz = [c_int() for i in range(7)] status = capi.get_field_as_datetime( - self._feat.ptr, self._index, byref(yy), byref(mm), byref(dd), - byref(hh), byref(mn), byref(ss), byref(tz)) + self._feat.ptr, + self._index, + byref(yy), + byref(mm), + byref(dd), + byref(hh), + byref(mn), + byref(ss), + byref(tz), + ) if status: return (yy, mm, dd, hh, mn, ss, tz) else: - raise GDALException('Unable to retrieve date & time information from the field.') + raise GDALException( + "Unable to retrieve date & time information from the field." 
+ ) # #### Field Properties #### @property diff --git a/django/contrib/gis/gdal/geometries.py b/django/contrib/gis/gdal/geometries.py index d807b3e8b9..dbb391f75f 100644 --- a/django/contrib/gis/gdal/geometries.py +++ b/django/contrib/gis/gdal/geometries.py @@ -46,7 +46,8 @@ from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.envelope import Envelope, OGREnvelope from django.contrib.gis.gdal.error import GDALException, SRSException from django.contrib.gis.gdal.geomtype import OGRGeomType -from django.contrib.gis.gdal.prototypes import geom as capi, srs as srs_api +from django.contrib.gis.gdal.prototypes import geom as capi +from django.contrib.gis.gdal.prototypes import srs as srs_api from django.contrib.gis.gdal.srs import CoordTransform, SpatialReference from django.contrib.gis.geometry import hex_regex, json_regex, wkt_regex from django.utils.encoding import force_bytes @@ -58,6 +59,7 @@ from django.utils.encoding import force_bytes # The OGR_G_* routines are relevant here. class OGRGeometry(GDALBase): """Encapsulate an OGR geometry.""" + destructor = capi.destroy_geom def __init__(self, geom_input, srs=None): @@ -74,16 +76,18 @@ class OGRGeometry(GDALBase): wkt_m = wkt_regex.match(geom_input) json_m = json_regex.match(geom_input) if wkt_m: - if wkt_m['srid']: + if wkt_m["srid"]: # If there's EWKT, set the SRS w/value of the SRID. - srs = int(wkt_m['srid']) - if wkt_m['type'].upper() == 'LINEARRING': + srs = int(wkt_m["srid"]) + if wkt_m["type"].upper() == "LINEARRING": # OGR_G_CreateFromWkt doesn't work with LINEARRING WKT. # See https://trac.osgeo.org/gdal/ticket/1992. 
- g = capi.create_geom(OGRGeomType(wkt_m['type']).num) - capi.import_wkt(g, byref(c_char_p(wkt_m['wkt'].encode()))) + g = capi.create_geom(OGRGeomType(wkt_m["type"]).num) + capi.import_wkt(g, byref(c_char_p(wkt_m["wkt"].encode()))) else: - g = capi.from_wkt(byref(c_char_p(wkt_m['wkt'].encode())), None, byref(c_void_p())) + g = capi.from_wkt( + byref(c_char_p(wkt_m["wkt"].encode())), None, byref(c_void_p()) + ) elif json_m: g = self._from_json(geom_input.encode()) else: @@ -101,12 +105,17 @@ class OGRGeometry(GDALBase): # OGR pointer (c_void_p) was the input. g = geom_input else: - raise GDALException('Invalid input type for OGR Geometry construction: %s' % type(geom_input)) + raise GDALException( + "Invalid input type for OGR Geometry construction: %s" + % type(geom_input) + ) # Now checking the Geometry pointer before finishing initialization # by setting the pointer for the object. if not g: - raise GDALException('Cannot create OGR Geometry from input: %s' % geom_input) + raise GDALException( + "Cannot create OGR Geometry from input: %s" % geom_input + ) self.ptr = g # Assigning the SpatialReference object to the geometry, if valid. @@ -129,13 +138,15 @@ class OGRGeometry(GDALBase): wkb, srs = state ptr = capi.from_wkb(wkb, None, byref(c_void_p()), len(wkb)) if not ptr: - raise GDALException('Invalid OGRGeometry loaded from pickled state.') + raise GDALException("Invalid OGRGeometry loaded from pickled state.") self.ptr = ptr self.srs = srs @classmethod def _from_wkb(cls, geom_input): - return capi.from_wkb(bytes(geom_input), None, byref(c_void_p()), len(geom_input)) + return capi.from_wkb( + bytes(geom_input), None, byref(c_void_p()), len(geom_input) + ) @staticmethod def _from_json(geom_input): @@ -145,8 +156,10 @@ class OGRGeometry(GDALBase): def from_bbox(cls, bbox): "Construct a Polygon from a bounding box (4-tuple)." 
x0, y0, x1, y1 = bbox - return OGRGeometry('POLYGON((%s %s, %s %s, %s %s, %s %s, %s %s))' % ( - x0, y0, x0, y1, x1, y1, x1, y0, x0, y0)) + return OGRGeometry( + "POLYGON((%s %s, %s %s, %s %s, %s %s, %s %s))" + % (x0, y0, x0, y1, x1, y1, x1, y0, x0, y0) + ) @staticmethod def from_json(geom_input): @@ -198,7 +211,7 @@ class OGRGeometry(GDALBase): def _set_coord_dim(self, dim): "Set the coordinate dimension of this Geometry." if dim not in (2, 3): - raise ValueError('Geometry dimension must be either 2 or 3') + raise ValueError("Geometry dimension must be either 2 or 3") capi.set_coord_dim(self.ptr, dim) coord_dim = property(_get_coord_dim, _set_coord_dim) @@ -278,7 +291,9 @@ class OGRGeometry(GDALBase): elif srs is None: srs_ptr = None else: - raise TypeError('Cannot assign spatial reference with object of type: %s' % type(srs)) + raise TypeError( + "Cannot assign spatial reference with object of type: %s" % type(srs) + ) capi.assign_srs(self.ptr, srs_ptr) srs = property(_get_srs, _set_srs) @@ -294,19 +309,21 @@ class OGRGeometry(GDALBase): if isinstance(srid, int) or srid is None: self.srs = srid else: - raise TypeError('SRID must be set with an integer.') + raise TypeError("SRID must be set with an integer.") srid = property(_get_srid, _set_srid) # #### Output Methods #### def _geos_ptr(self): from django.contrib.gis.geos import GEOSGeometry + return GEOSGeometry._from_wkb(self.wkb) @property def geos(self): "Return a GEOSGeometry object from this OGRGeometry." from django.contrib.gis.geos import GEOSGeometry + return GEOSGeometry(self._geos_ptr(), self.srid) @property @@ -325,6 +342,7 @@ class OGRGeometry(GDALBase): Return the GeoJSON representation of this Geometry. """ return capi.to_json(self.ptr) + geojson = json @property @@ -340,7 +358,7 @@ class OGRGeometry(GDALBase): @property def wkb(self): "Return the WKB representation of the Geometry." 
- if sys.byteorder == 'little': + if sys.byteorder == "little": byteorder = 1 # wkbNDR (from ogr_core.h) else: byteorder = 0 # wkbXDR @@ -361,7 +379,7 @@ class OGRGeometry(GDALBase): "Return the EWKT representation of the Geometry." srs = self.srs if srs and srs.srid: - return 'SRID=%s;%s' % (srs.srid, self.wkt) + return "SRID=%s;%s" % (srs.srid, self.wkt) else: return self.wkt @@ -402,15 +420,19 @@ class OGRGeometry(GDALBase): sr = SpatialReference(coord_trans) capi.geom_transform_to(self.ptr, sr.ptr) else: - raise TypeError('Transform only accepts CoordTransform, ' - 'SpatialReference, string, and integer objects.') + raise TypeError( + "Transform only accepts CoordTransform, " + "SpatialReference, string, and integer objects." + ) # #### Topology Methods #### def _topology(self, func, other): """A generalized function for topology operations, takes a GDAL function and the other geometry to perform the operation on.""" if not isinstance(other, OGRGeometry): - raise TypeError('Must use another OGRGeometry object for topology operations!') + raise TypeError( + "Must use another OGRGeometry object for topology operations!" + ) # Returning the output of the given function with the other geometry's # pointer. @@ -500,14 +522,14 @@ class OGRGeometry(GDALBase): # The subclasses for OGR Geometry. class Point(OGRGeometry): - def _geos_ptr(self): from django.contrib.gis import geos + return geos.Point._create_empty() if self.empty else super()._geos_ptr() @classmethod def _create_empty(cls): - return capi.create_geom(OGRGeomType('point').num) + return capi.create_geom(OGRGeomType("point").num) @property def x(self): @@ -532,11 +554,11 @@ class Point(OGRGeometry): return (self.x, self.y) elif self.coord_dim == 3: return (self.x, self.y, self.z) + coords = tuple class LineString(OGRGeometry): - def __getitem__(self, index): "Return the Point at the given index." 
if 0 <= index < self.point_count: @@ -550,7 +572,9 @@ class LineString(OGRGeometry): elif dim == 3: return (x.value, y.value, z.value) else: - raise IndexError('Index out of range when accessing points of a line string: %s.' % index) + raise IndexError( + "Index out of range when accessing points of a line string: %s." % index + ) def __len__(self): "Return the number of points in the LineString." @@ -560,6 +584,7 @@ class LineString(OGRGeometry): def tuple(self): "Return the tuple representation of this LineString." return tuple(self[i] for i in range(len(self))) + coords = tuple def _listarr(self, func): @@ -592,7 +617,6 @@ class LinearRing(LineString): class Polygon(OGRGeometry): - def __len__(self): "Return the number of interior rings in this Polygon." return self.geom_count @@ -600,21 +624,27 @@ class Polygon(OGRGeometry): def __getitem__(self, index): "Get the ring at the specified index." if 0 <= index < self.geom_count: - return OGRGeometry(capi.clone_geom(capi.get_geom_ref(self.ptr, index)), self.srs) + return OGRGeometry( + capi.clone_geom(capi.get_geom_ref(self.ptr, index)), self.srs + ) else: - raise IndexError('Index out of range when accessing rings of a polygon: %s.' % index) + raise IndexError( + "Index out of range when accessing rings of a polygon: %s." % index + ) # Polygon Properties @property def shell(self): "Return the shell of this Polygon." return self[0] # First ring is the shell + exterior_ring = shell @property def tuple(self): "Return a tuple of LinearRing coordinate tuples." return tuple(self[i].tuple for i in range(self.geom_count)) + coords = tuple @property @@ -627,7 +657,7 @@ class Polygon(OGRGeometry): def centroid(self): "Return the centroid (a Point) of this Polygon." # The centroid is a Point, create a geometry for this. 
- p = OGRGeometry(OGRGeomType('Point')) + p = OGRGeometry(OGRGeomType("Point")) capi.get_centroid(self.ptr, p.ptr) return p @@ -639,9 +669,14 @@ class GeometryCollection(OGRGeometry): def __getitem__(self, index): "Get the Geometry at the specified index." if 0 <= index < self.geom_count: - return OGRGeometry(capi.clone_geom(capi.get_geom_ref(self.ptr, index)), self.srs) + return OGRGeometry( + capi.clone_geom(capi.get_geom_ref(self.ptr, index)), self.srs + ) else: - raise IndexError('Index out of range when accessing geometry in a collection: %s.' % index) + raise IndexError( + "Index out of range when accessing geometry in a collection: %s." + % index + ) def __len__(self): "Return the number of geometries in this Geometry Collection." @@ -659,7 +694,7 @@ class GeometryCollection(OGRGeometry): tmp = OGRGeometry(geom) capi.add_geom(self.ptr, tmp.ptr) else: - raise GDALException('Must add an OGRGeometry.') + raise GDALException("Must add an OGRGeometry.") @property def point_count(self): @@ -671,6 +706,7 @@ class GeometryCollection(OGRGeometry): def tuple(self): "Return a tuple representation of this Geometry Collection." return tuple(self[i].tuple for i in range(self.geom_count)) + coords = tuple diff --git a/django/contrib/gis/gdal/geomtype.py b/django/contrib/gis/gdal/geomtype.py index 591c680c59..8b77cafa34 100644 --- a/django/contrib/gis/gdal/geomtype.py +++ b/django/contrib/gis/gdal/geomtype.py @@ -8,24 +8,24 @@ class OGRGeomType: # Dictionary of acceptable OGRwkbGeometryType s and their string names. 
_types = { - 0: 'Unknown', - 1: 'Point', - 2: 'LineString', - 3: 'Polygon', - 4: 'MultiPoint', - 5: 'MultiLineString', - 6: 'MultiPolygon', - 7: 'GeometryCollection', - 100: 'None', - 101: 'LinearRing', - 102: 'PointZ', - 1 + wkb25bit: 'Point25D', - 2 + wkb25bit: 'LineString25D', - 3 + wkb25bit: 'Polygon25D', - 4 + wkb25bit: 'MultiPoint25D', - 5 + wkb25bit: 'MultiLineString25D', - 6 + wkb25bit: 'MultiPolygon25D', - 7 + wkb25bit: 'GeometryCollection25D', + 0: "Unknown", + 1: "Point", + 2: "LineString", + 3: "Polygon", + 4: "MultiPoint", + 5: "MultiLineString", + 6: "MultiPolygon", + 7: "GeometryCollection", + 100: "None", + 101: "LinearRing", + 102: "PointZ", + 1 + wkb25bit: "Point25D", + 2 + wkb25bit: "LineString25D", + 3 + wkb25bit: "Polygon25D", + 4 + wkb25bit: "MultiPoint25D", + 5 + wkb25bit: "MultiLineString25D", + 6 + wkb25bit: "MultiPolygon25D", + 7 + wkb25bit: "GeometryCollection25D", } # Reverse type dictionary, keyed by lowercase of the name. _str_types = {v.lower(): k for k, v in _types.items()} @@ -36,17 +36,17 @@ class OGRGeomType: num = type_input.num elif isinstance(type_input, str): type_input = type_input.lower() - if type_input == 'geometry': - type_input = 'unknown' + if type_input == "geometry": + type_input = "unknown" num = self._str_types.get(type_input) if num is None: raise GDALException('Invalid OGR String Type "%s"' % type_input) elif isinstance(type_input, int): if type_input not in self._types: - raise GDALException('Invalid OGR Integer Type: %d' % type_input) + raise GDALException("Invalid OGR Integer Type: %d" % type_input) num = type_input else: - raise TypeError('Invalid OGR input type given.') + raise TypeError("Invalid OGR input type given.") # Setting the OGR geometry type number. self.num = num @@ -77,19 +77,19 @@ class OGRGeomType: @property def django(self): "Return the Django GeometryField for this OGR Type." 
- s = self.name.replace('25D', '') - if s in ('LinearRing', 'None'): + s = self.name.replace("25D", "") + if s in ("LinearRing", "None"): return None - elif s == 'Unknown': - s = 'Geometry' - elif s == 'PointZ': - s = 'Point' - return s + 'Field' + elif s == "Unknown": + s = "Geometry" + elif s == "PointZ": + s = "Point" + return s + "Field" def to_multi(self): """ Transform Point, LineString, Polygon, and their 25D equivalents to their Multi... counterpart. """ - if self.name.startswith(('Point', 'LineString', 'Polygon')): + if self.name.startswith(("Point", "LineString", "Polygon")): self.num += 3 diff --git a/django/contrib/gis/gdal/layer.py b/django/contrib/gis/gdal/layer.py index 012f70bdb0..e8f97b7552 100644 --- a/django/contrib/gis/gdal/layer.py +++ b/django/contrib/gis/gdal/layer.py @@ -7,9 +7,9 @@ from django.contrib.gis.gdal.feature import Feature from django.contrib.gis.gdal.field import OGRFieldTypes from django.contrib.gis.gdal.geometries import OGRGeometry from django.contrib.gis.gdal.geomtype import OGRGeomType -from django.contrib.gis.gdal.prototypes import ( - ds as capi, geom as geom_api, srs as srs_api, -) +from django.contrib.gis.gdal.prototypes import ds as capi +from django.contrib.gis.gdal.prototypes import geom as geom_api +from django.contrib.gis.gdal.prototypes import srs as srs_api from django.contrib.gis.gdal.srs import SpatialReference from django.utils.encoding import force_bytes, force_str @@ -29,12 +29,12 @@ class Layer(GDALBase): collection of the `DataSource` while this Layer is still active. """ if not layer_ptr: - raise GDALException('Cannot create Layer, invalid pointer given') + raise GDALException("Cannot create Layer, invalid pointer given") self.ptr = layer_ptr self._ds = ds self._ldefn = capi.get_layer_defn(self._ptr) # Does the Layer support random reading? 
- self._random_read = self.test_capability(b'RandomRead') + self._random_read = self.test_capability(b"RandomRead") def __getitem__(self, index): "Get the Feature at the specified index." @@ -43,14 +43,16 @@ class Layer(GDALBase): # number of features because the beginning and ending feature IDs # are not guaranteed to be 0 and len(layer)-1, respectively. if index < 0: - raise IndexError('Negative indices are not allowed on OGR Layers.') + raise IndexError("Negative indices are not allowed on OGR Layers.") return self._make_feature(index) elif isinstance(index, slice): # A slice was given start, stop, stride = index.indices(self.num_feat) return [self._make_feature(fid) for fid in range(start, stop, stride)] else: - raise TypeError('Integers and slices may only be used when indexing OGR Layers.') + raise TypeError( + "Integers and slices may only be used when indexing OGR Layers." + ) def __iter__(self): "Iterate over each Feature in the Layer." @@ -87,7 +89,7 @@ class Layer(GDALBase): if feat.fid == feat_id: return feat # Should have returned a Feature, raise an IndexError. - raise IndexError('Invalid feature id: %s.' % feat_id) + raise IndexError("Invalid feature id: %s." % feat_id) # #### Layer properties #### @property @@ -133,10 +135,14 @@ class Layer(GDALBase): Return a list of string names corresponding to each of the Fields available in this Layer. """ - return [force_str( - capi.get_field_name(capi.get_field_defn(self._ldefn, i)), - self._ds.encoding, strings_only=True, - ) for i in range(self.num_fields)] + return [ + force_str( + capi.get_field_name(capi.get_field_defn(self._ldefn, i)), + self._ds.encoding, + strings_only=True, + ) + for i in range(self.num_fields) + ] @property def field_types(self): @@ -145,20 +151,26 @@ class Layer(GDALBase): return the list [OFTInteger, OFTReal, OFTString] for an OGR layer that has an integer, a floating-point, and string fields. 
""" - return [OGRFieldTypes[capi.get_field_type(capi.get_field_defn(self._ldefn, i))] - for i in range(self.num_fields)] + return [ + OGRFieldTypes[capi.get_field_type(capi.get_field_defn(self._ldefn, i))] + for i in range(self.num_fields) + ] @property def field_widths(self): "Return a list of the maximum field widths for the features." - return [capi.get_field_width(capi.get_field_defn(self._ldefn, i)) - for i in range(self.num_fields)] + return [ + capi.get_field_width(capi.get_field_defn(self._ldefn, i)) + for i in range(self.num_fields) + ] @property def field_precisions(self): "Return the field precisions for the features." - return [capi.get_field_precision(capi.get_field_defn(self._ldefn, i)) - for i in range(self.num_fields)] + return [ + capi.get_field_precision(capi.get_field_defn(self._ldefn, i)) + for i in range(self.num_fields) + ] def _get_spatial_filter(self): try: @@ -171,7 +183,7 @@ class Layer(GDALBase): capi.set_spatial_filter(self.ptr, filter.ptr) elif isinstance(filter, (tuple, list)): if not len(filter) == 4: - raise ValueError('Spatial filter list/tuple must have 4 elements.') + raise ValueError("Spatial filter list/tuple must have 4 elements.") # Map c_double onto params -- if a bad type is passed in it # will be caught here. xmin, ymin, xmax, ymax = map(c_double, filter) @@ -179,7 +191,9 @@ class Layer(GDALBase): elif filter is None: capi.set_spatial_filter(self.ptr, None) else: - raise TypeError('Spatial filter must be either an OGRGeometry instance, a 4-tuple, or None.') + raise TypeError( + "Spatial filter must be either an OGRGeometry instance, a 4-tuple, or None." + ) spatial_filter = property(_get_spatial_filter, _set_spatial_filter) @@ -190,7 +204,7 @@ class Layer(GDALBase): in the Layer. 
""" if field_name not in self.fields: - raise GDALException('invalid field name: %s' % field_name) + raise GDALException("invalid field name: %s" % field_name) return [feat.get(field_name) for feat in self] def get_geoms(self, geos=False): @@ -200,6 +214,7 @@ class Layer(GDALBase): """ if geos: from django.contrib.gis.geos import GEOSGeometry + return [GEOSGeometry(feat.geom.wkb) for feat in self] else: return [feat.geom for feat in self] diff --git a/django/contrib/gis/gdal/libgdal.py b/django/contrib/gis/gdal/libgdal.py index 3ff82ef086..2cb8329451 100644 --- a/django/contrib/gis/gdal/libgdal.py +++ b/django/contrib/gis/gdal/libgdal.py @@ -7,29 +7,41 @@ from ctypes.util import find_library from django.contrib.gis.gdal.error import GDALException from django.core.exceptions import ImproperlyConfigured -logger = logging.getLogger('django.contrib.gis') +logger = logging.getLogger("django.contrib.gis") # Custom library path set? try: from django.conf import settings + lib_path = settings.GDAL_LIBRARY_PATH except (AttributeError, ImportError, ImproperlyConfigured, OSError): lib_path = None if lib_path: lib_names = None -elif os.name == 'nt': +elif os.name == "nt": # Windows NT shared libraries lib_names = [ - 'gdal303', 'gdal302', 'gdal301', 'gdal300', - 'gdal204', 'gdal203', 'gdal202', + "gdal303", + "gdal302", + "gdal301", + "gdal300", + "gdal204", + "gdal203", + "gdal202", ] -elif os.name == 'posix': +elif os.name == "posix": # *NIX library names. lib_names = [ - 'gdal', 'GDAL', - 'gdal3.3.0', 'gdal3.2.0', 'gdal3.1.0', 'gdal3.0.0', - 'gdal2.4.0', 'gdal2.3.0', 'gdal2.2.0', + "gdal", + "GDAL", + "gdal3.3.0", + "gdal3.2.0", + "gdal3.1.0", + "gdal3.0.0", + "gdal2.4.0", + "gdal2.3.0", + "gdal2.2.0", ] else: raise ImproperlyConfigured('GDAL is unsupported on OS "%s".' % os.name) @@ -45,7 +57,7 @@ if lib_names: if lib_path is None: raise ImproperlyConfigured( 'Could not find the GDAL library (tried "%s"). Is GDAL installed? 
' - 'If it is, try setting GDAL_LIBRARY_PATH in your settings.' + "If it is, try setting GDAL_LIBRARY_PATH in your settings." % '", "'.join(lib_names) ) @@ -56,8 +68,9 @@ lgdal = CDLL(lib_path) # STDCALL, while others are not. Thus, the library will also need to # be loaded up as WinDLL for said OSR functions that require the # different calling convention. -if os.name == 'nt': +if os.name == "nt": from ctypes import WinDLL + lwingdal = WinDLL(lib_path) @@ -66,7 +79,7 @@ def std_call(func): Return the correct STDCALL function for certain OSR routines on Win32 platforms. """ - if os.name == 'nt': + if os.name == "nt": return lwingdal[func] else: return lgdal[func] @@ -75,24 +88,24 @@ def std_call(func): # #### Version-information functions. #### # Return GDAL library version information with the given key. -_version_info = std_call('GDALVersionInfo') +_version_info = std_call("GDALVersionInfo") _version_info.argtypes = [c_char_p] _version_info.restype = c_char_p def gdal_version(): "Return only the GDAL version number information." - return _version_info(b'RELEASE_NAME') + return _version_info(b"RELEASE_NAME") def gdal_full_version(): "Return the full GDAL version information." 
- return _version_info(b'') + return _version_info(b"") def gdal_version_info(): ver = gdal_version() - m = re.match(br'^(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<subminor>\d+))?', ver) + m = re.match(rb"^(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<subminor>\d+))?", ver) if not m: raise GDALException('Could not parse GDAL version string "%s"' % ver) major, minor, subminor = m.groups() @@ -106,7 +119,7 @@ CPLErrorHandler = CFUNCTYPE(None, c_int, c_int, c_char_p) def err_handler(error_class, error_number, message): - logger.error('GDAL_ERROR %d: %s', error_number, message) + logger.error("GDAL_ERROR %d: %s", error_number, message) err_handler = CPLErrorHandler(err_handler) @@ -119,5 +132,5 @@ def function(name, args, restype): return func -set_error_handler = function('CPLSetErrorHandler', [CPLErrorHandler], CPLErrorHandler) +set_error_handler = function("CPLSetErrorHandler", [CPLErrorHandler], CPLErrorHandler) set_error_handler(err_handler) diff --git a/django/contrib/gis/gdal/prototypes/ds.py b/django/contrib/gis/gdal/prototypes/ds.py index 0729ab3386..bc5250e2db 100644 --- a/django/contrib/gis/gdal/prototypes/ds.py +++ b/django/contrib/gis/gdal/prototypes/ds.py @@ -8,8 +8,15 @@ from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_void_p from django.contrib.gis.gdal.envelope import OGREnvelope from django.contrib.gis.gdal.libgdal import lgdal from django.contrib.gis.gdal.prototypes.generation import ( - bool_output, const_string_output, double_output, geom_output, int64_output, - int_output, srs_output, void_output, voidptr_output, + bool_output, + const_string_output, + double_output, + geom_output, + int64_output, + int_output, + srs_output, + void_output, + voidptr_output, ) c_int_p = POINTER(c_int) # shortcut type @@ -18,9 +25,13 @@ c_int_p = POINTER(c_int) # shortcut type register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False) cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False) get_driver = voidptr_output(lgdal.OGRGetDriver, 
[c_int]) -get_driver_by_name = voidptr_output(lgdal.OGRGetDriverByName, [c_char_p], errcheck=False) +get_driver_by_name = voidptr_output( + lgdal.OGRGetDriverByName, [c_char_p], errcheck=False +) get_driver_count = int_output(lgdal.OGRGetDriverCount, []) -get_driver_name = const_string_output(lgdal.OGR_Dr_GetName, [c_void_p], decoding='ascii') +get_driver_name = const_string_output( + lgdal.OGR_Dr_GetName, [c_void_p], decoding="ascii" +) # DataSource open_ds = voidptr_output(lgdal.OGROpen, [c_char_p, c_int, POINTER(c_void_p)]) @@ -41,10 +52,13 @@ get_next_feature = voidptr_output(lgdal.OGR_L_GetNextFeature, [c_void_p]) reset_reading = void_output(lgdal.OGR_L_ResetReading, [c_void_p], errcheck=False) test_capability = int_output(lgdal.OGR_L_TestCapability, [c_void_p, c_char_p]) get_spatial_filter = geom_output(lgdal.OGR_L_GetSpatialFilter, [c_void_p]) -set_spatial_filter = void_output(lgdal.OGR_L_SetSpatialFilter, [c_void_p, c_void_p], errcheck=False) +set_spatial_filter = void_output( + lgdal.OGR_L_SetSpatialFilter, [c_void_p, c_void_p], errcheck=False +) set_spatial_filter_rect = void_output( lgdal.OGR_L_SetSpatialFilterRect, - [c_void_p, c_double, c_double, c_double, c_double], errcheck=False + [c_void_p, c_double, c_double, c_double, c_double], + errcheck=False, ) # Feature Definition Routines @@ -64,13 +78,17 @@ get_feat_field_defn = voidptr_output(lgdal.OGR_F_GetFieldDefnRef, [c_void_p, c_i get_fid = int_output(lgdal.OGR_F_GetFID, [c_void_p]) get_field_as_datetime = int_output( lgdal.OGR_F_GetFieldAsDateTime, - [c_void_p, c_int, c_int_p, c_int_p, c_int_p, c_int_p, c_int_p, c_int_p] + [c_void_p, c_int, c_int_p, c_int_p, c_int_p, c_int_p, c_int_p, c_int_p], ) get_field_as_double = double_output(lgdal.OGR_F_GetFieldAsDouble, [c_void_p, c_int]) get_field_as_integer = int_output(lgdal.OGR_F_GetFieldAsInteger, [c_void_p, c_int]) -get_field_as_integer64 = int64_output(lgdal.OGR_F_GetFieldAsInteger64, [c_void_p, c_int]) +get_field_as_integer64 = int64_output( + 
lgdal.OGR_F_GetFieldAsInteger64, [c_void_p, c_int] +) is_field_set = bool_output(lgdal.OGR_F_IsFieldSetAndNotNull, [c_void_p, c_int]) -get_field_as_string = const_string_output(lgdal.OGR_F_GetFieldAsString, [c_void_p, c_int]) +get_field_as_string = const_string_output( + lgdal.OGR_F_GetFieldAsString, [c_void_p, c_int] +) get_field_index = int_output(lgdal.OGR_F_GetFieldIndex, [c_void_p, c_char_p]) # Field Routines diff --git a/django/contrib/gis/gdal/prototypes/errcheck.py b/django/contrib/gis/gdal/prototypes/errcheck.py index 405f2e7070..52bb7cb083 100644 --- a/django/contrib/gis/gdal/prototypes/errcheck.py +++ b/django/contrib/gis/gdal/prototypes/errcheck.py @@ -4,9 +4,7 @@ """ from ctypes import c_void_p, string_at -from django.contrib.gis.gdal.error import ( - GDALException, SRSException, check_err, -) +from django.contrib.gis.gdal.error import GDALException, SRSException, check_err from django.contrib.gis.gdal.libgdal import lgdal @@ -63,6 +61,7 @@ def check_string(result, func, cargs, offset=-1, str_result=False): lgdal.VSIFree(ptr) return s + # ### DataSource, Layer error-checking ### @@ -80,7 +79,9 @@ def check_geom(result, func, cargs): if isinstance(result, int): result = c_void_p(result) if not result: - raise GDALException('Invalid geometry pointer returned from "%s".' % func.__name__) + raise GDALException( + 'Invalid geometry pointer returned from "%s".' % func.__name__ + ) return result @@ -96,7 +97,9 @@ def check_srs(result, func, cargs): if isinstance(result, int): result = c_void_p(result) if not result: - raise SRSException('Invalid spatial reference pointer returned from "%s".' % func.__name__) + raise SRSException( + 'Invalid spatial reference pointer returned from "%s".' 
% func.__name__ + ) return result diff --git a/django/contrib/gis/gdal/prototypes/generation.py b/django/contrib/gis/gdal/prototypes/generation.py index 0b26585161..230e56f665 100644 --- a/django/contrib/gis/gdal/prototypes/generation.py +++ b/django/contrib/gis/gdal/prototypes/generation.py @@ -2,14 +2,19 @@ This module contains functions that generate ctypes prototypes for the GDAL routines. """ -from ctypes import ( - POINTER, c_bool, c_char_p, c_double, c_int, c_int64, c_void_p, -) +from ctypes import POINTER, c_bool, c_char_p, c_double, c_int, c_int64, c_void_p from functools import partial from django.contrib.gis.gdal.prototypes.errcheck import ( - check_arg_errcode, check_const_string, check_errcode, check_geom, - check_geom_offset, check_pointer, check_srs, check_str_arg, check_string, + check_arg_errcode, + check_const_string, + check_errcode, + check_geom, + check_geom_offset, + check_pointer, + check_srs, + check_str_arg, + check_string, ) @@ -55,6 +60,7 @@ def geom_output(func, argtypes, offset=None): def geomerrcheck(result, func, cargs): return check_geom_offset(result, func, cargs, offset) + func.errcheck = geomerrcheck return func @@ -100,6 +106,7 @@ def const_string_output(func, argtypes, offset=None, decoding=None, cpl=False): if res and decoding: res = res.decode(decoding) return res + func.errcheck = _check_const return func @@ -129,6 +136,7 @@ def string_output(func, argtypes, offset=-1, str_result=False, decoding=None): if res and decoding: res = res.decode(decoding) return res + func.errcheck = _check_str return func diff --git a/django/contrib/gis/gdal/prototypes/geom.py b/django/contrib/gis/gdal/prototypes/geom.py index f7745f552f..06c2a75f3a 100644 --- a/django/contrib/gis/gdal/prototypes/geom.py +++ b/django/contrib/gis/gdal/prototypes/geom.py @@ -4,8 +4,13 @@ from django.contrib.gis.gdal.envelope import OGREnvelope from django.contrib.gis.gdal.libgdal import lgdal from django.contrib.gis.gdal.prototypes.errcheck import check_envelope 
from django.contrib.gis.gdal.prototypes.generation import ( - const_string_output, double_output, geom_output, int_output, srs_output, - string_output, void_output, + const_string_output, + double_output, + geom_output, + int_output, + srs_output, + string_output, + void_output, ) @@ -34,8 +39,12 @@ def topology_func(f): # GeoJSON routines. from_json = geom_output(lgdal.OGR_G_CreateGeometryFromJson, [c_char_p]) -to_json = string_output(lgdal.OGR_G_ExportToJson, [c_void_p], str_result=True, decoding='ascii') -to_kml = string_output(lgdal.OGR_G_ExportToKML, [c_void_p, c_char_p], str_result=True, decoding='ascii') +to_json = string_output( + lgdal.OGR_G_ExportToJson, [c_void_p], str_result=True, decoding="ascii" +) +to_kml = string_output( + lgdal.OGR_G_ExportToKML, [c_void_p, c_char_p], str_result=True, decoding="ascii" +) # GetX, GetY, GetZ all return doubles. getx = pnt_func(lgdal.OGR_G_GetX) @@ -43,8 +52,14 @@ gety = pnt_func(lgdal.OGR_G_GetY) getz = pnt_func(lgdal.OGR_G_GetZ) # Geometry creation routines. -from_wkb = geom_output(lgdal.OGR_G_CreateFromWkb, [c_char_p, c_void_p, POINTER(c_void_p), c_int], offset=-2) -from_wkt = geom_output(lgdal.OGR_G_CreateFromWkt, [POINTER(c_char_p), c_void_p, POINTER(c_void_p)], offset=-1) +from_wkb = geom_output( + lgdal.OGR_G_CreateFromWkb, [c_char_p, c_void_p, POINTER(c_void_p), c_int], offset=-2 +) +from_wkt = geom_output( + lgdal.OGR_G_CreateFromWkt, + [POINTER(c_char_p), c_void_p, POINTER(c_void_p)], + offset=-1, +) from_gml = geom_output(lgdal.OGR_G_CreateFromGML, [c_char_p]) create_geom = geom_output(lgdal.OGR_G_CreateGeometry, [c_int]) clone_geom = geom_output(lgdal.OGR_G_Clone, [c_void_p]) @@ -64,13 +79,21 @@ import_wkt = void_output(lgdal.OGR_G_ImportFromWkt, [c_void_p, POINTER(c_char_p) destroy_geom = void_output(lgdal.OGR_G_DestroyGeometry, [c_void_p], errcheck=False) # Geometry export routines. -to_wkb = void_output(lgdal.OGR_G_ExportToWkb, None, errcheck=True) # special handling for WKB. 
-to_wkt = string_output(lgdal.OGR_G_ExportToWkt, [c_void_p, POINTER(c_char_p)], decoding='ascii') -to_gml = string_output(lgdal.OGR_G_ExportToGML, [c_void_p], str_result=True, decoding='ascii') +to_wkb = void_output( + lgdal.OGR_G_ExportToWkb, None, errcheck=True +) # special handling for WKB. +to_wkt = string_output( + lgdal.OGR_G_ExportToWkt, [c_void_p, POINTER(c_char_p)], decoding="ascii" +) +to_gml = string_output( + lgdal.OGR_G_ExportToGML, [c_void_p], str_result=True, decoding="ascii" +) get_wkbsize = int_output(lgdal.OGR_G_WkbSize, [c_void_p]) # Geometry spatial-reference related routines. -assign_srs = void_output(lgdal.OGR_G_AssignSpatialReference, [c_void_p, c_void_p], errcheck=False) +assign_srs = void_output( + lgdal.OGR_G_AssignSpatialReference, [c_void_p, c_void_p], errcheck=False +) get_geom_srs = srs_output(lgdal.OGR_G_GetSpatialReference, [c_void_p]) # Geometry properties @@ -78,16 +101,23 @@ get_area = double_output(lgdal.OGR_G_GetArea, [c_void_p]) get_centroid = void_output(lgdal.OGR_G_Centroid, [c_void_p, c_void_p]) get_dims = int_output(lgdal.OGR_G_GetDimension, [c_void_p]) get_coord_dim = int_output(lgdal.OGR_G_GetCoordinateDimension, [c_void_p]) -set_coord_dim = void_output(lgdal.OGR_G_SetCoordinateDimension, [c_void_p, c_int], errcheck=False) -is_empty = int_output(lgdal.OGR_G_IsEmpty, [c_void_p], errcheck=lambda result, func, cargs: bool(result)) +set_coord_dim = void_output( + lgdal.OGR_G_SetCoordinateDimension, [c_void_p, c_int], errcheck=False +) +is_empty = int_output( + lgdal.OGR_G_IsEmpty, [c_void_p], errcheck=lambda result, func, cargs: bool(result) +) get_geom_count = int_output(lgdal.OGR_G_GetGeometryCount, [c_void_p]) -get_geom_name = const_string_output(lgdal.OGR_G_GetGeometryName, [c_void_p], decoding='ascii') +get_geom_name = const_string_output( + lgdal.OGR_G_GetGeometryName, [c_void_p], decoding="ascii" +) get_geom_type = int_output(lgdal.OGR_G_GetGeometryType, [c_void_p]) get_point_count = 
int_output(lgdal.OGR_G_GetPointCount, [c_void_p]) get_point = void_output( lgdal.OGR_G_GetPoint, - [c_void_p, c_int, POINTER(c_double), POINTER(c_double), POINTER(c_double)], errcheck=False + [c_void_p, c_int, POINTER(c_double), POINTER(c_double), POINTER(c_double)], + errcheck=False, ) geom_close_rings = void_output(lgdal.OGR_G_CloseRings, [c_void_p], errcheck=False) diff --git a/django/contrib/gis/gdal/prototypes/raster.py b/django/contrib/gis/gdal/prototypes/raster.py index 3fc9446ce8..59b930cb02 100644 --- a/django/contrib/gis/gdal/prototypes/raster.py +++ b/django/contrib/gis/gdal/prototypes/raster.py @@ -7,8 +7,12 @@ from functools import partial from django.contrib.gis.gdal.libgdal import std_call from django.contrib.gis.gdal.prototypes.generation import ( - chararray_output, const_string_output, double_output, int_output, - void_output, voidptr_output, + chararray_output, + const_string_output, + double_output, + int_output, + void_output, + voidptr_output, ) # For more detail about c function names and definitions see @@ -22,80 +26,152 @@ const_string_output = partial(const_string_output, cpl=True) double_output = partial(double_output, cpl=True) # Raster Driver Routines -register_all = void_output(std_call('GDALAllRegister'), [], errcheck=False) -get_driver = voidptr_output(std_call('GDALGetDriver'), [c_int]) -get_driver_by_name = voidptr_output(std_call('GDALGetDriverByName'), [c_char_p], errcheck=False) -get_driver_count = int_output(std_call('GDALGetDriverCount'), []) -get_driver_description = const_string_output(std_call('GDALGetDescription'), [c_void_p]) +register_all = void_output(std_call("GDALAllRegister"), [], errcheck=False) +get_driver = voidptr_output(std_call("GDALGetDriver"), [c_int]) +get_driver_by_name = voidptr_output( + std_call("GDALGetDriverByName"), [c_char_p], errcheck=False +) +get_driver_count = int_output(std_call("GDALGetDriverCount"), []) +get_driver_description = const_string_output(std_call("GDALGetDescription"), [c_void_p]) # 
Raster Data Source Routines -create_ds = voidptr_output(std_call('GDALCreate'), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p]) -open_ds = voidptr_output(std_call('GDALOpen'), [c_char_p, c_int]) -close_ds = void_output(std_call('GDALClose'), [c_void_p], errcheck=False) -flush_ds = int_output(std_call('GDALFlushCache'), [c_void_p]) -copy_ds = voidptr_output( - std_call('GDALCreateCopy'), - [c_void_p, c_char_p, c_void_p, c_int, POINTER(c_char_p), c_void_p, c_void_p] +create_ds = voidptr_output( + std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p] +) +open_ds = voidptr_output(std_call("GDALOpen"), [c_char_p, c_int]) +close_ds = void_output(std_call("GDALClose"), [c_void_p], errcheck=False) +flush_ds = int_output(std_call("GDALFlushCache"), [c_void_p]) +copy_ds = voidptr_output( + std_call("GDALCreateCopy"), + [c_void_p, c_char_p, c_void_p, c_int, POINTER(c_char_p), c_void_p, c_void_p], +) +add_band_ds = void_output(std_call("GDALAddBand"), [c_void_p, c_int]) +get_ds_description = const_string_output(std_call("GDALGetDescription"), [c_void_p]) +get_ds_driver = voidptr_output(std_call("GDALGetDatasetDriver"), [c_void_p]) +get_ds_info = const_string_output(std_call("GDALInfo"), [c_void_p, c_void_p]) +get_ds_xsize = int_output(std_call("GDALGetRasterXSize"), [c_void_p]) +get_ds_ysize = int_output(std_call("GDALGetRasterYSize"), [c_void_p]) +get_ds_raster_count = int_output(std_call("GDALGetRasterCount"), [c_void_p]) +get_ds_raster_band = voidptr_output(std_call("GDALGetRasterBand"), [c_void_p, c_int]) +get_ds_projection_ref = const_string_output( + std_call("GDALGetProjectionRef"), [c_void_p] +) +set_ds_projection_ref = void_output(std_call("GDALSetProjection"), [c_void_p, c_char_p]) +get_ds_geotransform = void_output( + std_call("GDALGetGeoTransform"), [c_void_p, POINTER(c_double * 6)], errcheck=False +) +set_ds_geotransform = void_output( + std_call("GDALSetGeoTransform"), [c_void_p, POINTER(c_double * 6)] ) -add_band_ds = 
void_output(std_call('GDALAddBand'), [c_void_p, c_int]) -get_ds_description = const_string_output(std_call('GDALGetDescription'), [c_void_p]) -get_ds_driver = voidptr_output(std_call('GDALGetDatasetDriver'), [c_void_p]) -get_ds_info = const_string_output(std_call('GDALInfo'), [c_void_p, c_void_p]) -get_ds_xsize = int_output(std_call('GDALGetRasterXSize'), [c_void_p]) -get_ds_ysize = int_output(std_call('GDALGetRasterYSize'), [c_void_p]) -get_ds_raster_count = int_output(std_call('GDALGetRasterCount'), [c_void_p]) -get_ds_raster_band = voidptr_output(std_call('GDALGetRasterBand'), [c_void_p, c_int]) -get_ds_projection_ref = const_string_output(std_call('GDALGetProjectionRef'), [c_void_p]) -set_ds_projection_ref = void_output(std_call('GDALSetProjection'), [c_void_p, c_char_p]) -get_ds_geotransform = void_output(std_call('GDALGetGeoTransform'), [c_void_p, POINTER(c_double * 6)], errcheck=False) -set_ds_geotransform = void_output(std_call('GDALSetGeoTransform'), [c_void_p, POINTER(c_double * 6)]) -get_ds_metadata = chararray_output(std_call('GDALGetMetadata'), [c_void_p, c_char_p], errcheck=False) -set_ds_metadata = void_output(std_call('GDALSetMetadata'), [c_void_p, POINTER(c_char_p), c_char_p]) -get_ds_metadata_domain_list = chararray_output(std_call('GDALGetMetadataDomainList'), [c_void_p], errcheck=False) -get_ds_metadata_item = const_string_output(std_call('GDALGetMetadataItem'), [c_void_p, c_char_p, c_char_p]) -set_ds_metadata_item = const_string_output(std_call('GDALSetMetadataItem'), [c_void_p, c_char_p, c_char_p, c_char_p]) -free_dsl = void_output(std_call('CSLDestroy'), [POINTER(c_char_p)], errcheck=False) +get_ds_metadata = chararray_output( + std_call("GDALGetMetadata"), [c_void_p, c_char_p], errcheck=False +) +set_ds_metadata = void_output( + std_call("GDALSetMetadata"), [c_void_p, POINTER(c_char_p), c_char_p] +) +get_ds_metadata_domain_list = chararray_output( + std_call("GDALGetMetadataDomainList"), [c_void_p], errcheck=False +) +get_ds_metadata_item = 
const_string_output( + std_call("GDALGetMetadataItem"), [c_void_p, c_char_p, c_char_p] +) +set_ds_metadata_item = const_string_output( + std_call("GDALSetMetadataItem"), [c_void_p, c_char_p, c_char_p, c_char_p] +) +free_dsl = void_output(std_call("CSLDestroy"), [POINTER(c_char_p)], errcheck=False) # Raster Band Routines band_io = void_output( - std_call('GDALRasterIO'), - [c_void_p, c_int, c_int, c_int, c_int, c_int, c_void_p, c_int, c_int, c_int, c_int, c_int] -) -get_band_xsize = int_output(std_call('GDALGetRasterBandXSize'), [c_void_p]) -get_band_ysize = int_output(std_call('GDALGetRasterBandYSize'), [c_void_p]) -get_band_index = int_output(std_call('GDALGetBandNumber'), [c_void_p]) -get_band_description = const_string_output(std_call('GDALGetDescription'), [c_void_p]) -get_band_ds = voidptr_output(std_call('GDALGetBandDataset'), [c_void_p]) -get_band_datatype = int_output(std_call('GDALGetRasterDataType'), [c_void_p]) -get_band_color_interp = int_output(std_call('GDALGetRasterColorInterpretation'), [c_void_p]) -get_band_nodata_value = double_output(std_call('GDALGetRasterNoDataValue'), [c_void_p, POINTER(c_int)]) -set_band_nodata_value = void_output(std_call('GDALSetRasterNoDataValue'), [c_void_p, c_double]) -delete_band_nodata_value = void_output(std_call('GDALDeleteRasterNoDataValue'), [c_void_p]) -get_band_statistics = void_output( - std_call('GDALGetRasterStatistics'), + std_call("GDALRasterIO"), [ - c_void_p, c_int, c_int, POINTER(c_double), POINTER(c_double), - POINTER(c_double), POINTER(c_double), c_void_p, c_void_p, + c_void_p, + c_int, + c_int, + c_int, + c_int, + c_int, + c_void_p, + c_int, + c_int, + c_int, + c_int, + c_int, + ], +) +get_band_xsize = int_output(std_call("GDALGetRasterBandXSize"), [c_void_p]) +get_band_ysize = int_output(std_call("GDALGetRasterBandYSize"), [c_void_p]) +get_band_index = int_output(std_call("GDALGetBandNumber"), [c_void_p]) +get_band_description = const_string_output(std_call("GDALGetDescription"), [c_void_p]) 
+get_band_ds = voidptr_output(std_call("GDALGetBandDataset"), [c_void_p]) +get_band_datatype = int_output(std_call("GDALGetRasterDataType"), [c_void_p]) +get_band_color_interp = int_output( + std_call("GDALGetRasterColorInterpretation"), [c_void_p] +) +get_band_nodata_value = double_output( + std_call("GDALGetRasterNoDataValue"), [c_void_p, POINTER(c_int)] +) +set_band_nodata_value = void_output( + std_call("GDALSetRasterNoDataValue"), [c_void_p, c_double] +) +delete_band_nodata_value = void_output( + std_call("GDALDeleteRasterNoDataValue"), [c_void_p] +) +get_band_statistics = void_output( + std_call("GDALGetRasterStatistics"), + [ + c_void_p, + c_int, + c_int, + POINTER(c_double), + POINTER(c_double), + POINTER(c_double), + POINTER(c_double), + c_void_p, + c_void_p, ], ) compute_band_statistics = void_output( - std_call('GDALComputeRasterStatistics'), - [c_void_p, c_int, POINTER(c_double), POINTER(c_double), POINTER(c_double), POINTER(c_double), c_void_p, c_void_p], + std_call("GDALComputeRasterStatistics"), + [ + c_void_p, + c_int, + POINTER(c_double), + POINTER(c_double), + POINTER(c_double), + POINTER(c_double), + c_void_p, + c_void_p, + ], ) # Reprojection routine reproject_image = void_output( - std_call('GDALReprojectImage'), - [c_void_p, c_char_p, c_void_p, c_char_p, c_int, c_double, c_double, c_void_p, c_void_p, c_void_p] + std_call("GDALReprojectImage"), + [ + c_void_p, + c_char_p, + c_void_p, + c_char_p, + c_int, + c_double, + c_double, + c_void_p, + c_void_p, + c_void_p, + ], ) auto_create_warped_vrt = voidptr_output( - std_call('GDALAutoCreateWarpedVRT'), - [c_void_p, c_char_p, c_char_p, c_int, c_double, c_void_p] + std_call("GDALAutoCreateWarpedVRT"), + [c_void_p, c_char_p, c_char_p, c_int, c_double, c_void_p], ) # Create VSI gdal raster files from in-memory buffers. 
# https://gdal.org/api/cpl.html#cpl-vsi-h -create_vsi_file_from_mem_buffer = voidptr_output(std_call('VSIFileFromMemBuffer'), [c_char_p, c_void_p, c_int, c_int]) -get_mem_buffer_from_vsi_file = voidptr_output(std_call('VSIGetMemFileBuffer'), [c_char_p, POINTER(c_int), c_bool]) -unlink_vsi_file = int_output(std_call('VSIUnlink'), [c_char_p]) +create_vsi_file_from_mem_buffer = voidptr_output( + std_call("VSIFileFromMemBuffer"), [c_char_p, c_void_p, c_int, c_int] +) +get_mem_buffer_from_vsi_file = voidptr_output( + std_call("VSIGetMemFileBuffer"), [c_char_p, POINTER(c_int), c_bool] +) +unlink_vsi_file = int_output(std_call("VSIUnlink"), [c_char_p]) diff --git a/django/contrib/gis/gdal/prototypes/srs.py b/django/contrib/gis/gdal/prototypes/srs.py index c392b7bf2a..721eba8622 100644 --- a/django/contrib/gis/gdal/prototypes/srs.py +++ b/django/contrib/gis/gdal/prototypes/srs.py @@ -2,7 +2,11 @@ from ctypes import POINTER, c_char_p, c_int, c_void_p from django.contrib.gis.gdal.libgdal import GDAL_VERSION, lgdal, std_call from django.contrib.gis.gdal.prototypes.generation import ( - const_string_output, double_output, int_output, srs_output, string_output, + const_string_output, + double_output, + int_output, + srs_output, + string_output, void_output, ) @@ -25,14 +29,18 @@ def units_func(f): # Creation & destruction. 
-clone_srs = srs_output(std_call('OSRClone'), [c_void_p]) -new_srs = srs_output(std_call('OSRNewSpatialReference'), [c_char_p]) +clone_srs = srs_output(std_call("OSRClone"), [c_void_p]) +new_srs = srs_output(std_call("OSRNewSpatialReference"), [c_char_p]) release_srs = void_output(lgdal.OSRRelease, [c_void_p], errcheck=False) -destroy_srs = void_output(std_call('OSRDestroySpatialReference'), [c_void_p], errcheck=False) +destroy_srs = void_output( + std_call("OSRDestroySpatialReference"), [c_void_p], errcheck=False +) srs_validate = void_output(lgdal.OSRValidate, [c_void_p]) if GDAL_VERSION >= (3, 0): - set_axis_strategy = void_output(lgdal.OSRSetAxisMappingStrategy, [c_void_p, c_int], errcheck=False) + set_axis_strategy = void_output( + lgdal.OSRSetAxisMappingStrategy, [c_void_p, c_int], errcheck=False + ) # Getting the semi_major, semi_minor, and flattening functions. semi_major = srs_double(lgdal.OSRGetSemiMajor) @@ -42,9 +50,9 @@ invflattening = srs_double(lgdal.OSRGetInvFlattening) # WKT, PROJ, EPSG, XML importation routines. from_wkt = void_output(lgdal.OSRImportFromWkt, [c_void_p, POINTER(c_char_p)]) from_proj = void_output(lgdal.OSRImportFromProj4, [c_void_p, c_char_p]) -from_epsg = void_output(std_call('OSRImportFromEPSG'), [c_void_p, c_int]) +from_epsg = void_output(std_call("OSRImportFromEPSG"), [c_void_p, c_int]) from_xml = void_output(lgdal.OSRImportFromXML, [c_void_p, c_char_p]) -from_user_input = void_output(std_call('OSRSetFromUserInput'), [c_void_p, c_char_p]) +from_user_input = void_output(std_call("OSRSetFromUserInput"), [c_void_p, c_char_p]) # Morphing to/from ESRI WKT. morph_to_esri = void_output(lgdal.OSRMorphToESRI, [c_void_p]) @@ -58,19 +66,36 @@ linear_units = units_func(lgdal.OSRGetLinearUnits) angular_units = units_func(lgdal.OSRGetAngularUnits) # For exporting to WKT, PROJ, "Pretty" WKT, and XML. 
-to_wkt = string_output(std_call('OSRExportToWkt'), [c_void_p, POINTER(c_char_p)], decoding='utf-8') -to_proj = string_output(std_call('OSRExportToProj4'), [c_void_p, POINTER(c_char_p)], decoding='ascii') +to_wkt = string_output( + std_call("OSRExportToWkt"), [c_void_p, POINTER(c_char_p)], decoding="utf-8" +) +to_proj = string_output( + std_call("OSRExportToProj4"), [c_void_p, POINTER(c_char_p)], decoding="ascii" +) to_pretty_wkt = string_output( - std_call('OSRExportToPrettyWkt'), - [c_void_p, POINTER(c_char_p), c_int], offset=-2, decoding='utf-8' + std_call("OSRExportToPrettyWkt"), + [c_void_p, POINTER(c_char_p), c_int], + offset=-2, + decoding="utf-8", ) -to_xml = string_output(lgdal.OSRExportToXML, [c_void_p, POINTER(c_char_p), c_char_p], offset=-2, decoding='utf-8') +to_xml = string_output( + lgdal.OSRExportToXML, + [c_void_p, POINTER(c_char_p), c_char_p], + offset=-2, + decoding="utf-8", +) # String attribute retrieval routines. -get_attr_value = const_string_output(std_call('OSRGetAttrValue'), [c_void_p, c_char_p, c_int], decoding='utf-8') -get_auth_name = const_string_output(lgdal.OSRGetAuthorityName, [c_void_p, c_char_p], decoding='ascii') -get_auth_code = const_string_output(lgdal.OSRGetAuthorityCode, [c_void_p, c_char_p], decoding='ascii') +get_attr_value = const_string_output( + std_call("OSRGetAttrValue"), [c_void_p, c_char_p, c_int], decoding="utf-8" +) +get_auth_name = const_string_output( + lgdal.OSRGetAuthorityName, [c_void_p, c_char_p], decoding="ascii" +) +get_auth_code = const_string_output( + lgdal.OSRGetAuthorityCode, [c_void_p, c_char_p], decoding="ascii" +) # SRS Properties isgeographic = int_output(lgdal.OSRIsGeographic, [c_void_p]) @@ -78,5 +103,7 @@ islocal = int_output(lgdal.OSRIsLocal, [c_void_p]) isprojected = int_output(lgdal.OSRIsProjected, [c_void_p]) # Coordinate transformation -new_ct = srs_output(std_call('OCTNewCoordinateTransformation'), [c_void_p, c_void_p]) -destroy_ct = 
void_output(std_call('OCTDestroyCoordinateTransformation'), [c_void_p], errcheck=False) +new_ct = srs_output(std_call("OCTNewCoordinateTransformation"), [c_void_p, c_void_p]) +destroy_ct = void_output( + std_call("OCTDestroyCoordinateTransformation"), [c_void_p], errcheck=False +) diff --git a/django/contrib/gis/gdal/raster/band.py b/django/contrib/gis/gdal/raster/band.py index 6739a52628..c3ec960643 100644 --- a/django/contrib/gis/gdal/raster/band.py +++ b/django/contrib/gis/gdal/raster/band.py @@ -7,7 +7,10 @@ from django.contrib.gis.shortcuts import numpy from django.utils.encoding import force_str from .const import ( - GDAL_COLOR_TYPES, GDAL_INTEGER_TYPES, GDAL_PIXEL_TYPES, GDAL_TO_CTYPES, + GDAL_COLOR_TYPES, + GDAL_INTEGER_TYPES, + GDAL_PIXEL_TYPES, + GDAL_TO_CTYPES, ) @@ -15,6 +18,7 @@ class GDALBand(GDALRasterBase): """ Wrap a GDAL raster band, needs to be obtained from a GDALRaster object. """ + def __init__(self, source, index): self.source = source self._ptr = capi.get_ds_raster_band(source._ptr, index) @@ -79,8 +83,14 @@ class GDALBand(GDALRasterBase): # Prepare array with arguments for capi function smin, smax, smean, sstd = c_double(), c_double(), c_double(), c_double() stats_args = [ - self._ptr, c_int(approximate), byref(smin), byref(smax), - byref(smean), byref(sstd), c_void_p(), c_void_p(), + self._ptr, + c_int(approximate), + byref(smin), + byref(smax), + byref(smean), + byref(sstd), + c_void_p(), + c_void_p(), ] if refresh or self._stats_refresh: @@ -154,7 +164,7 @@ class GDALBand(GDALRasterBase): if value is None: capi.delete_band_nodata_value(self._ptr) elif not isinstance(value, (int, float)): - raise ValueError('Nodata value must be numeric or None.') + raise ValueError("Nodata value must be numeric or None.") else: capi.set_band_nodata_value(self._ptr, value) self._flush() @@ -188,10 +198,10 @@ class GDALBand(GDALRasterBase): size = size or (self.width - offset[0], self.height - offset[1]) shape = shape or size if any(x <= 0 for x in size): 
- raise ValueError('Offset too big for this raster.') + raise ValueError("Offset too big for this raster.") if size[0] > self.width or size[1] > self.height: - raise ValueError('Size is larger than raster.') + raise ValueError("Size is larger than raster.") # Create ctypes type array generator ctypes_array = GDAL_TO_CTYPES[self.datatype()] * (shape[0] * shape[1]) @@ -206,15 +216,28 @@ class GDALBand(GDALRasterBase): access_flag = 1 # Instantiate ctypes array holding the input data - if isinstance(data, (bytes, memoryview)) or (numpy and isinstance(data, numpy.ndarray)): + if isinstance(data, (bytes, memoryview)) or ( + numpy and isinstance(data, numpy.ndarray) + ): data_array = ctypes_array.from_buffer_copy(data) else: data_array = ctypes_array(*data) # Access band - capi.band_io(self._ptr, access_flag, offset[0], offset[1], - size[0], size[1], byref(data_array), shape[0], - shape[1], self.datatype(), 0, 0) + capi.band_io( + self._ptr, + access_flag, + offset[0], + offset[1], + size[0], + size[1], + byref(data_array), + shape[0], + shape[1], + self.datatype(), + 0, + 0, + ) # Return data as numpy array if possible, otherwise as list if data is None: @@ -247,4 +270,4 @@ class BandList(list): try: return GDALBand(self.source, index + 1) except GDALException: - raise GDALException('Unable to get band index %d' % index) + raise GDALException("Unable to get band index %d" % index) diff --git a/django/contrib/gis/gdal/raster/base.py b/django/contrib/gis/gdal/raster/base.py index c5b438ee64..3d95c90dc7 100644 --- a/django/contrib/gis/gdal/raster/base.py +++ b/django/contrib/gis/gdal/raster/base.py @@ -6,6 +6,7 @@ class GDALRasterBase(GDALBase): """ Attributes that exist on both GDALRaster and GDALBand. """ + @property def metadata(self): """ @@ -15,7 +16,7 @@ class GDALRasterBase(GDALBase): """ # The initial metadata domain list contains the default domain. # The default is returned if domain name is None. 
- domain_list = ['DEFAULT'] + domain_list = ["DEFAULT"] # Get additional metadata domains from the raster. meta_list = capi.get_ds_metadata_domain_list(self._ptr) @@ -38,7 +39,7 @@ class GDALRasterBase(GDALBase): # Get metadata for this domain. data = capi.get_ds_metadata( self._ptr, - (None if domain == 'DEFAULT' else domain.encode()), + (None if domain == "DEFAULT" else domain.encode()), ) if not data: continue @@ -48,12 +49,12 @@ class GDALRasterBase(GDALBase): counter = 0 item = data[counter] while item: - key, val = item.decode().split('=') + key, val = item.decode().split("=") domain_meta[key] = val counter += 1 item = data[counter] # The default domain values are returned if domain is None. - result[domain or 'DEFAULT'] = domain_meta + result[domain or "DEFAULT"] = domain_meta return result @metadata.setter @@ -65,11 +66,12 @@ class GDALRasterBase(GDALBase): # Loop through domains. for domain, metadata in value.items(): # Set the domain to None for the default, otherwise encode. - domain = None if domain == 'DEFAULT' else domain.encode() + domain = None if domain == "DEFAULT" else domain.encode() # Set each metadata entry separately. 
for meta_name, meta_value in metadata.items(): capi.set_ds_metadata_item( - self._ptr, meta_name.encode(), + self._ptr, + meta_name.encode(), meta_value.encode() if meta_value else None, domain, ) diff --git a/django/contrib/gis/gdal/raster/const.py b/django/contrib/gis/gdal/raster/const.py index c246a6a190..6d3761d9fb 100644 --- a/django/contrib/gis/gdal/raster/const.py +++ b/django/contrib/gis/gdal/raster/const.py @@ -1,24 +1,22 @@ """ GDAL - Constant definitions """ -from ctypes import ( - c_double, c_float, c_int16, c_int32, c_ubyte, c_uint16, c_uint32, -) +from ctypes import c_double, c_float, c_int16, c_int32, c_ubyte, c_uint16, c_uint32 # See https://gdal.org/api/raster_c_api.html#_CPPv412GDALDataType GDAL_PIXEL_TYPES = { - 0: 'GDT_Unknown', # Unknown or unspecified type - 1: 'GDT_Byte', # Eight bit unsigned integer - 2: 'GDT_UInt16', # Sixteen bit unsigned integer - 3: 'GDT_Int16', # Sixteen bit signed integer - 4: 'GDT_UInt32', # Thirty-two bit unsigned integer - 5: 'GDT_Int32', # Thirty-two bit signed integer - 6: 'GDT_Float32', # Thirty-two bit floating point - 7: 'GDT_Float64', # Sixty-four bit floating point - 8: 'GDT_CInt16', # Complex Int16 - 9: 'GDT_CInt32', # Complex Int32 - 10: 'GDT_CFloat32', # Complex Float32 - 11: 'GDT_CFloat64', # Complex Float64 + 0: "GDT_Unknown", # Unknown or unspecified type + 1: "GDT_Byte", # Eight bit unsigned integer + 2: "GDT_UInt16", # Sixteen bit unsigned integer + 3: "GDT_Int16", # Sixteen bit signed integer + 4: "GDT_UInt32", # Thirty-two bit unsigned integer + 5: "GDT_Int32", # Thirty-two bit signed integer + 6: "GDT_Float32", # Thirty-two bit floating point + 7: "GDT_Float64", # Sixty-four bit floating point + 8: "GDT_CInt16", # Complex Int16 + 9: "GDT_CInt32", # Complex Int32 + 10: "GDT_CFloat32", # Complex Float32 + 11: "GDT_CFloat64", # Complex Float64 } # A list of gdal datatypes that are integers. @@ -29,47 +27,57 @@ GDAL_INTEGER_TYPES = [1, 2, 3, 4, 5] # or to hold the space for data to be read into. 
The lookup below helps # selecting the right ctypes object for a given gdal pixel type. GDAL_TO_CTYPES = [ - None, c_ubyte, c_uint16, c_int16, c_uint32, c_int32, - c_float, c_double, None, None, None, None + None, + c_ubyte, + c_uint16, + c_int16, + c_uint32, + c_int32, + c_float, + c_double, + None, + None, + None, + None, ] # List of resampling algorithms that can be used to warp a GDALRaster. GDAL_RESAMPLE_ALGORITHMS = { - 'NearestNeighbour': 0, - 'Bilinear': 1, - 'Cubic': 2, - 'CubicSpline': 3, - 'Lanczos': 4, - 'Average': 5, - 'Mode': 6, + "NearestNeighbour": 0, + "Bilinear": 1, + "Cubic": 2, + "CubicSpline": 3, + "Lanczos": 4, + "Average": 5, + "Mode": 6, } # See https://gdal.org/api/raster_c_api.html#_CPPv415GDALColorInterp GDAL_COLOR_TYPES = { - 0: 'GCI_Undefined', # Undefined, default value, i.e. not known - 1: 'GCI_GrayIndex', # Grayscale - 2: 'GCI_PaletteIndex', # Paletted - 3: 'GCI_RedBand', # Red band of RGBA image - 4: 'GCI_GreenBand', # Green band of RGBA image - 5: 'GCI_BlueBand', # Blue band of RGBA image - 6: 'GCI_AlphaBand', # Alpha (0=transparent, 255=opaque) - 7: 'GCI_HueBand', # Hue band of HLS image - 8: 'GCI_SaturationBand', # Saturation band of HLS image - 9: 'GCI_LightnessBand', # Lightness band of HLS image - 10: 'GCI_CyanBand', # Cyan band of CMYK image - 11: 'GCI_MagentaBand', # Magenta band of CMYK image - 12: 'GCI_YellowBand', # Yellow band of CMYK image - 13: 'GCI_BlackBand', # Black band of CMLY image - 14: 'GCI_YCbCr_YBand', # Y Luminance - 15: 'GCI_YCbCr_CbBand', # Cb Chroma - 16: 'GCI_YCbCr_CrBand', # Cr Chroma, also GCI_Max + 0: "GCI_Undefined", # Undefined, default value, i.e. 
not known + 1: "GCI_GrayIndex", # Grayscale + 2: "GCI_PaletteIndex", # Paletted + 3: "GCI_RedBand", # Red band of RGBA image + 4: "GCI_GreenBand", # Green band of RGBA image + 5: "GCI_BlueBand", # Blue band of RGBA image + 6: "GCI_AlphaBand", # Alpha (0=transparent, 255=opaque) + 7: "GCI_HueBand", # Hue band of HLS image + 8: "GCI_SaturationBand", # Saturation band of HLS image + 9: "GCI_LightnessBand", # Lightness band of HLS image + 10: "GCI_CyanBand", # Cyan band of CMYK image + 11: "GCI_MagentaBand", # Magenta band of CMYK image + 12: "GCI_YellowBand", # Yellow band of CMYK image + 13: "GCI_BlackBand", # Black band of CMLY image + 14: "GCI_YCbCr_YBand", # Y Luminance + 15: "GCI_YCbCr_CbBand", # Cb Chroma + 16: "GCI_YCbCr_CrBand", # Cr Chroma, also GCI_Max } # GDAL virtual filesystems prefix. -VSI_FILESYSTEM_PREFIX = '/vsi' +VSI_FILESYSTEM_PREFIX = "/vsi" # Fixed base path for buffer-based GDAL in-memory files. -VSI_MEM_FILESYSTEM_BASE_PATH = '/vsimem/' +VSI_MEM_FILESYSTEM_BASE_PATH = "/vsimem/" # Should the memory file system take ownership of the buffer, freeing it when # the file is deleted? (No, GDALRaster.__del__() will delete the buffer.) 
diff --git a/django/contrib/gis/gdal/raster/source.py b/django/contrib/gis/gdal/raster/source.py index 7c17b85925..ca7875752b 100644 --- a/django/contrib/gis/gdal/raster/source.py +++ b/django/contrib/gis/gdal/raster/source.py @@ -3,7 +3,14 @@ import os import sys import uuid from ctypes import ( - addressof, byref, c_buffer, c_char_p, c_double, c_int, c_void_p, string_at, + addressof, + byref, + c_buffer, + c_char_p, + c_double, + c_int, + c_void_p, + string_at, ) from django.contrib.gis.gdal.driver import Driver @@ -12,8 +19,11 @@ from django.contrib.gis.gdal.prototypes import raster as capi from django.contrib.gis.gdal.raster.band import BandList from django.contrib.gis.gdal.raster.base import GDALRasterBase from django.contrib.gis.gdal.raster.const import ( - GDAL_RESAMPLE_ALGORITHMS, VSI_DELETE_BUFFER_ON_READ, VSI_FILESYSTEM_PREFIX, - VSI_MEM_FILESYSTEM_BASE_PATH, VSI_TAKE_BUFFER_OWNERSHIP, + GDAL_RESAMPLE_ALGORITHMS, + VSI_DELETE_BUFFER_ON_READ, + VSI_FILESYSTEM_PREFIX, + VSI_MEM_FILESYSTEM_BASE_PATH, + VSI_TAKE_BUFFER_OWNERSHIP, ) from django.contrib.gis.gdal.srs import SpatialReference, SRSException from django.contrib.gis.geometry import json_regex @@ -23,9 +33,9 @@ from django.utils.functional import cached_property class TransformPoint(list): indices = { - 'origin': (0, 3), - 'scale': (1, 5), - 'skew': (2, 4), + "origin": (0, 3), + "scale": (1, 5), + "skew": (2, 4), } def __init__(self, raster, prop): @@ -60,6 +70,7 @@ class GDALRaster(GDALRasterBase): """ Wrap a raster GDAL Data Source object. """ + destructor = capi.close_ds def __init__(self, ds_input, write=False): @@ -73,9 +84,8 @@ class GDALRaster(GDALRasterBase): # If input is a valid file path, try setting file as source. if isinstance(ds_input, str): - if ( - not ds_input.startswith(VSI_FILESYSTEM_PREFIX) and - not os.path.exists(ds_input) + if not ds_input.startswith(VSI_FILESYSTEM_PREFIX) and not os.path.exists( + ds_input ): raise GDALException( 'Unable to read raster source input "%s".' 
% ds_input @@ -84,7 +94,9 @@ class GDALRaster(GDALRasterBase): # GDALOpen will auto-detect the data source type. self._ptr = capi.open_ds(force_bytes(ds_input), self._write) except GDALException as err: - raise GDALException('Could not open the datasource at "{}" ({}).'.format(ds_input, err)) + raise GDALException( + 'Could not open the datasource at "{}" ({}).'.format(ds_input, err) + ) elif isinstance(ds_input, bytes): # Create a new raster in write mode. self._write = 1 @@ -109,30 +121,36 @@ class GDALRaster(GDALRasterBase): except GDALException: # Remove the broken file from the VSI filesystem. capi.unlink_vsi_file(force_bytes(vsi_path)) - raise GDALException('Failed creating VSI raster from the input buffer.') + raise GDALException("Failed creating VSI raster from the input buffer.") elif isinstance(ds_input, dict): # A new raster needs to be created in write mode self._write = 1 # Create driver (in memory by default) - driver = Driver(ds_input.get('driver', 'MEM')) + driver = Driver(ds_input.get("driver", "MEM")) # For out of memory drivers, check filename argument - if driver.name != 'MEM' and 'name' not in ds_input: - raise GDALException('Specify name for creation of raster with driver "{}".'.format(driver.name)) + if driver.name != "MEM" and "name" not in ds_input: + raise GDALException( + 'Specify name for creation of raster with driver "{}".'.format( + driver.name + ) + ) # Check if width and height where specified - if 'width' not in ds_input or 'height' not in ds_input: - raise GDALException('Specify width and height attributes for JSON or dict input.') + if "width" not in ds_input or "height" not in ds_input: + raise GDALException( + "Specify width and height attributes for JSON or dict input." + ) # Check if srid was specified - if 'srid' not in ds_input: - raise GDALException('Specify srid for JSON or dict input.') + if "srid" not in ds_input: + raise GDALException("Specify srid for JSON or dict input.") # Create null terminated gdal options array. 
papsz_options = [] - for key, val in ds_input.get('papsz_options', {}).items(): - option = '{}={}'.format(key, val) + for key, val in ds_input.get("papsz_options", {}).items(): + option = "{}={}".format(key, val) papsz_options.append(option.upper().encode()) papsz_options.append(None) @@ -142,51 +160,54 @@ class GDALRaster(GDALRasterBase): # Create GDAL Raster self._ptr = capi.create_ds( driver._ptr, - force_bytes(ds_input.get('name', '')), - ds_input['width'], - ds_input['height'], - ds_input.get('nr_of_bands', len(ds_input.get('bands', []))), - ds_input.get('datatype', 6), + force_bytes(ds_input.get("name", "")), + ds_input["width"], + ds_input["height"], + ds_input.get("nr_of_bands", len(ds_input.get("bands", []))), + ds_input.get("datatype", 6), byref(papsz_options), ) # Set band data if provided - for i, band_input in enumerate(ds_input.get('bands', [])): + for i, band_input in enumerate(ds_input.get("bands", [])): band = self.bands[i] - if 'nodata_value' in band_input: - band.nodata_value = band_input['nodata_value'] + if "nodata_value" in band_input: + band.nodata_value = band_input["nodata_value"] # Instantiate band filled with nodata values if only # partial input data has been provided. if band.nodata_value is not None and ( - 'data' not in band_input or - 'size' in band_input or - 'shape' in band_input): + "data" not in band_input + or "size" in band_input + or "shape" in band_input + ): band.data(data=(band.nodata_value,), shape=(1, 1)) # Set band data values from input. 
band.data( - data=band_input.get('data'), - size=band_input.get('size'), - shape=band_input.get('shape'), - offset=band_input.get('offset'), + data=band_input.get("data"), + size=band_input.get("size"), + shape=band_input.get("shape"), + offset=band_input.get("offset"), ) # Set SRID - self.srs = ds_input.get('srid') + self.srs = ds_input.get("srid") # Set additional properties if provided - if 'origin' in ds_input: - self.origin.x, self.origin.y = ds_input['origin'] + if "origin" in ds_input: + self.origin.x, self.origin.y = ds_input["origin"] - if 'scale' in ds_input: - self.scale.x, self.scale.y = ds_input['scale'] + if "scale" in ds_input: + self.scale.x, self.scale.y = ds_input["scale"] - if 'skew' in ds_input: - self.skew.x, self.skew.y = ds_input['skew'] + if "skew" in ds_input: + self.skew.x, self.skew.y = ds_input["skew"] elif isinstance(ds_input, c_void_p): # Instantiate the object using an existing pointer to a gdal raster. self._ptr = ds_input else: - raise GDALException('Invalid data source input type: "{}".'.format(type(ds_input))) + raise GDALException( + 'Invalid data source input type: "{}".'.format(type(ds_input)) + ) def __del__(self): if self.is_vsi_based: @@ -201,7 +222,7 @@ class GDALRaster(GDALRasterBase): """ Short-hand representation because WKB may be very large. """ - return '<Raster object at %s>' % hex(addressof(self._ptr)) + return "<Raster object at %s>" % hex(addressof(self._ptr)) def _flush(self): """ @@ -212,14 +233,15 @@ class GDALRaster(GDALRasterBase): """ # Raise an Exception if the value is being changed in read mode. if not self._write: - raise GDALException('Raster needs to be opened in write mode to change values.') + raise GDALException( + "Raster needs to be opened in write mode to change values." 
+ ) capi.flush_ds(self._ptr) @property def vsi_buffer(self): if not ( - self.is_vsi_based and - self.name.startswith(VSI_MEM_FILESYSTEM_BASE_PATH) + self.is_vsi_based and self.name.startswith(VSI_MEM_FILESYSTEM_BASE_PATH) ): return None # Prepare an integer that will contain the buffer length. @@ -276,7 +298,7 @@ class GDALRaster(GDALRasterBase): wkt = capi.get_ds_projection_ref(self._ptr) if not wkt: return None - return SpatialReference(wkt, srs_type='wkt') + return SpatialReference(wkt, srs_type="wkt") except SRSException: return None @@ -292,7 +314,7 @@ class GDALRaster(GDALRasterBase): elif isinstance(value, (int, str)): srs = SpatialReference(value) else: - raise ValueError('Could not create a SpatialReference from input.') + raise ValueError("Could not create a SpatialReference from input.") capi.set_ds_projection_ref(self._ptr, srs.wkt.encode()) self._flush() @@ -326,7 +348,7 @@ class GDALRaster(GDALRasterBase): def geotransform(self, values): "Set the geotransform for the data source." if len(values) != 6 or not all(isinstance(x, (int, float)) for x in values): - raise ValueError('Geotransform must consist of 6 numeric values.') + raise ValueError("Geotransform must consist of 6 numeric values.") # Create ctypes double array with input and write data values = (c_double * 6)(*values) capi.set_ds_geotransform(self._ptr, byref(values)) @@ -337,21 +359,21 @@ class GDALRaster(GDALRasterBase): """ Coordinates of the raster origin. """ - return TransformPoint(self, 'origin') + return TransformPoint(self, "origin") @property def scale(self): """ Pixel scale in units of the raster projection. """ - return TransformPoint(self, 'scale') + return TransformPoint(self, "scale") @property def skew(self): """ Skew of pixels (rotation parameters). 
""" - return TransformPoint(self, 'skew') + return TransformPoint(self, "skew") @property def extent(self): @@ -373,7 +395,7 @@ class GDALRaster(GDALRasterBase): def bands(self): return BandList(self) - def warp(self, ds_input, resampling='NearestNeighbour', max_error=0.0): + def warp(self, ds_input, resampling="NearestNeighbour", max_error=0.0): """ Return a warped GDALRaster with the given input characteristics. @@ -391,23 +413,23 @@ class GDALRaster(GDALRasterBase): consult the GDAL_RESAMPLE_ALGORITHMS constant. """ # Get the parameters defining the geotransform, srid, and size of the raster - ds_input.setdefault('width', self.width) - ds_input.setdefault('height', self.height) - ds_input.setdefault('srid', self.srs.srid) - ds_input.setdefault('origin', self.origin) - ds_input.setdefault('scale', self.scale) - ds_input.setdefault('skew', self.skew) + ds_input.setdefault("width", self.width) + ds_input.setdefault("height", self.height) + ds_input.setdefault("srid", self.srs.srid) + ds_input.setdefault("origin", self.origin) + ds_input.setdefault("scale", self.scale) + ds_input.setdefault("skew", self.skew) # Get the driver, name, and datatype of the target raster - ds_input.setdefault('driver', self.driver.name) + ds_input.setdefault("driver", self.driver.name) - if 'name' not in ds_input: - ds_input['name'] = self.name + '_copy.' + self.driver.name + if "name" not in ds_input: + ds_input["name"] = self.name + "_copy." + self.driver.name - if 'datatype' not in ds_input: - ds_input['datatype'] = self.bands[0].datatype() + if "datatype" not in ds_input: + ds_input["datatype"] = self.bands[0].datatype() # Instantiate raster bands filled with nodata values. 
- ds_input['bands'] = [{'nodata_value': bnd.nodata_value} for bnd in self.bands] + ds_input["bands"] = [{"nodata_value": bnd.nodata_value} for bnd in self.bands] # Create target raster target = GDALRaster(ds_input, write=True) @@ -417,10 +439,16 @@ class GDALRaster(GDALRasterBase): # Reproject image capi.reproject_image( - self._ptr, self.srs.wkt.encode(), - target._ptr, target.srs.wkt.encode(), - algorithm, 0.0, max_error, - c_void_p(), c_void_p(), c_void_p() + self._ptr, + self.srs.wkt.encode(), + target._ptr, + target.srs.wkt.encode(), + algorithm, + 0.0, + max_error, + c_void_p(), + c_void_p(), + c_void_p(), ) # Make sure all data is written to file @@ -432,8 +460,8 @@ class GDALRaster(GDALRasterBase): """Return a clone of this GDALRaster.""" if name: clone_name = name - elif self.driver.name != 'MEM': - clone_name = self.name + '_copy.' + self.driver.name + elif self.driver.name != "MEM": + clone_name = self.name + "_copy." + self.driver.name else: clone_name = os.path.join(VSI_MEM_FILESYSTEM_BASE_PATH, str(uuid.uuid4())) return GDALRaster( @@ -449,8 +477,9 @@ class GDALRaster(GDALRasterBase): write=self._write, ) - def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour', - max_error=0.0): + def transform( + self, srs, driver=None, name=None, resampling="NearestNeighbour", max_error=0.0 + ): """ Return a copy of this raster reprojected into the given spatial reference system. @@ -464,35 +493,39 @@ class GDALRaster(GDALRasterBase): target_srs = SpatialReference(srs) else: raise TypeError( - 'Transform only accepts SpatialReference, string, and integer ' - 'objects.' + "Transform only accepts SpatialReference, string, and integer " + "objects." 
) if target_srs.srid == self.srid and (not driver or driver == self.driver.name): return self.clone(name) # Create warped virtual dataset in the target reference system target = capi.auto_create_warped_vrt( - self._ptr, self.srs.wkt.encode(), target_srs.wkt.encode(), - algorithm, max_error, c_void_p() + self._ptr, + self.srs.wkt.encode(), + target_srs.wkt.encode(), + algorithm, + max_error, + c_void_p(), ) target = GDALRaster(target) # Construct the target warp dictionary from the virtual raster data = { - 'srid': target_srs.srid, - 'width': target.width, - 'height': target.height, - 'origin': [target.origin.x, target.origin.y], - 'scale': [target.scale.x, target.scale.y], - 'skew': [target.skew.x, target.skew.y], + "srid": target_srs.srid, + "width": target.width, + "height": target.height, + "origin": [target.origin.x, target.origin.y], + "scale": [target.scale.x, target.scale.y], + "skew": [target.skew.x, target.skew.y], } # Set the driver and filepath if provided if driver: - data['driver'] = driver + data["driver"] = driver if name: - data['name'] = name + data["name"] = name # Warp the raster into new srid return self.warp(data, resampling=resampling, max_error=max_error) diff --git a/django/contrib/gis/gdal/srs.py b/django/contrib/gis/gdal/srs.py index 5df8b54ab2..41477d55db 100644 --- a/django/contrib/gis/gdal/srs.py +++ b/django/contrib/gis/gdal/srs.py @@ -47,9 +47,10 @@ class SpatialReference(GDALBase): the SpatialReference object "provide[s] services to represent coordinate systems (projections and datums) and to transform between them." """ + destructor = capi.release_srs - def __init__(self, srs_input='', srs_type='user', axis_order=None): + def __init__(self, srs_input="", srs_type="user", axis_order=None): """ Create a GDAL OSR Spatial Reference object from the given input. 
The input may be string of OGC Well Known Text (WKT), an integer @@ -58,56 +59,58 @@ class SpatialReference(GDALBase): """ if not isinstance(axis_order, (type(None), AxisOrder)): raise ValueError( - 'SpatialReference.axis_order must be an AxisOrder instance.' + "SpatialReference.axis_order must be an AxisOrder instance." ) self.axis_order = axis_order or AxisOrder.TRADITIONAL - if srs_type == 'wkt': - self.ptr = capi.new_srs(c_char_p(b'')) + if srs_type == "wkt": + self.ptr = capi.new_srs(c_char_p(b"")) self.import_wkt(srs_input) if self.axis_order == AxisOrder.TRADITIONAL and GDAL_VERSION >= (3, 0): capi.set_axis_strategy(self.ptr, self.axis_order) elif self.axis_order != AxisOrder.TRADITIONAL and GDAL_VERSION < (3, 0): - raise ValueError('%s is not supported in GDAL < 3.0.' % self.axis_order) + raise ValueError("%s is not supported in GDAL < 3.0." % self.axis_order) return elif isinstance(srs_input, str): try: # If SRID is a string, e.g., '4326', then make acceptable # as user input. srid = int(srs_input) - srs_input = 'EPSG:%d' % srid + srs_input = "EPSG:%d" % srid except ValueError: pass elif isinstance(srs_input, int): # EPSG integer code was input. - srs_type = 'epsg' + srs_type = "epsg" elif isinstance(srs_input, self.ptr_type): srs = srs_input - srs_type = 'ogr' + srs_type = "ogr" else: raise TypeError('Invalid SRS type "%s"' % srs_type) - if srs_type == 'ogr': + if srs_type == "ogr": # Input is already an SRS pointer. srs = srs_input else: # Creating a new SRS pointer, using the string buffer. - buf = c_char_p(b'') + buf = c_char_p(b"") srs = capi.new_srs(buf) # If the pointer is NULL, throw an exception. 
if not srs: - raise SRSException('Could not create spatial reference from: %s' % srs_input) + raise SRSException( + "Could not create spatial reference from: %s" % srs_input + ) else: self.ptr = srs if self.axis_order == AxisOrder.TRADITIONAL and GDAL_VERSION >= (3, 0): capi.set_axis_strategy(self.ptr, self.axis_order) elif self.axis_order != AxisOrder.TRADITIONAL and GDAL_VERSION < (3, 0): - raise ValueError('%s is not supported in GDAL < 3.0.' % self.axis_order) + raise ValueError("%s is not supported in GDAL < 3.0." % self.axis_order) # Importing from either the user input string or an integer SRID. - if srs_type == 'user': + if srs_type == "user": self.import_user_input(srs_input) - elif srs_type == 'epsg': + elif srs_type == "epsg": self.import_epsg(srs_input) def __getitem__(self, target): @@ -188,11 +191,11 @@ class SpatialReference(GDALBase): def name(self): "Return the name of this Spatial Reference." if self.projected: - return self.attr_value('PROJCS') + return self.attr_value("PROJCS") elif self.geographic: - return self.attr_value('GEOGCS') + return self.attr_value("GEOGCS") elif self.local: - return self.attr_value('LOCAL_CS') + return self.attr_value("LOCAL_CS") else: return None @@ -200,7 +203,7 @@ class SpatialReference(GDALBase): def srid(self): "Return the SRID of top-level authority, or None if undefined." try: - return int(self.attr_value('AUTHORITY', 1)) + return int(self.attr_value("AUTHORITY", 1)) except (TypeError, ValueError): return None @@ -333,7 +336,7 @@ class SpatialReference(GDALBase): return self.proj @property - def xml(self, dialect=''): + def xml(self, dialect=""): "Return the XML representation of this Spatial Reference." return capi.to_xml(self.ptr, byref(c_char_p()), force_bytes(dialect)) @@ -344,8 +347,10 @@ class CoordTransform(GDALBase): def __init__(self, source, target): "Initialize on a source and target SpatialReference objects." 
- if not isinstance(source, SpatialReference) or not isinstance(target, SpatialReference): - raise TypeError('source and target must be of type SpatialReference') + if not isinstance(source, SpatialReference) or not isinstance( + target, SpatialReference + ): + raise TypeError("source and target must be of type SpatialReference") self.ptr = capi.new_ct(source._ptr, target._ptr) self._srs1_name = source.name self._srs2_name = target.name diff --git a/django/contrib/gis/geoip2/__init__.py b/django/contrib/gis/geoip2/__init__.py index 191706b01e..71b71f68db 100644 --- a/django/contrib/gis/geoip2/__init__.py +++ b/django/contrib/gis/geoip2/__init__.py @@ -11,7 +11,7 @@ downloaded from MaxMind at https://dev.maxmind.com/geoip/geoip2/geolite2/. Grab GeoLite2-Country.mmdb.gz and GeoLite2-City.mmdb.gz, and unzip them in the directory corresponding to settings.GEOIP_PATH. """ -__all__ = ['HAS_GEOIP2'] +__all__ = ["HAS_GEOIP2"] try: import geoip2 # NOQA @@ -19,5 +19,6 @@ except ImportError: HAS_GEOIP2 = False else: from .base import GeoIP2, GeoIP2Exception + HAS_GEOIP2 = True - __all__ += ['GeoIP2', 'GeoIP2Exception'] + __all__ += ["GeoIP2", "GeoIP2Exception"] diff --git a/django/contrib/gis/geoip2/base.py b/django/contrib/gis/geoip2/base.py index bcd8f1e2d8..e74984fe45 100644 --- a/django/contrib/gis/geoip2/base.py +++ b/django/contrib/gis/geoip2/base.py @@ -12,9 +12,9 @@ from .resources import City, Country # Creating the settings dictionary with any settings, if needed. GEOIP_SETTINGS = { - 'GEOIP_PATH': getattr(settings, 'GEOIP_PATH', None), - 'GEOIP_CITY': getattr(settings, 'GEOIP_CITY', 'GeoLite2-City.mmdb'), - 'GEOIP_COUNTRY': getattr(settings, 'GEOIP_COUNTRY', 'GeoLite2-Country.mmdb'), + "GEOIP_PATH": getattr(settings, "GEOIP_PATH", None), + "GEOIP_CITY": getattr(settings, "GEOIP_CITY", "GeoLite2-City.mmdb"), + "GEOIP_COUNTRY": getattr(settings, "GEOIP_COUNTRY", "GeoLite2-Country.mmdb"), } @@ -34,11 +34,13 @@ class GeoIP2: MODE_FILE = 4 # Load database into memory. 
Pure Python. MODE_MEMORY = 8 - cache_options = frozenset((MODE_AUTO, MODE_MMAP_EXT, MODE_MMAP, MODE_FILE, MODE_MEMORY)) + cache_options = frozenset( + (MODE_AUTO, MODE_MMAP_EXT, MODE_MMAP, MODE_FILE, MODE_MEMORY) + ) # Paths to the city & country binary databases. - _city_file = '' - _country_file = '' + _city_file = "" + _country_file = "" # Initially, pointers to GeoIP file references are NULL. _city = None @@ -69,29 +71,31 @@ class GeoIP2: """ # Checking the given cache option. if cache not in self.cache_options: - raise GeoIP2Exception('Invalid GeoIP caching option: %s' % cache) + raise GeoIP2Exception("Invalid GeoIP caching option: %s" % cache) # Getting the GeoIP data path. - path = path or GEOIP_SETTINGS['GEOIP_PATH'] + path = path or GEOIP_SETTINGS["GEOIP_PATH"] if not path: - raise GeoIP2Exception('GeoIP path must be provided via parameter or the GEOIP_PATH setting.') + raise GeoIP2Exception( + "GeoIP path must be provided via parameter or the GEOIP_PATH setting." + ) path = to_path(path) if path.is_dir(): # Constructing the GeoIP database filenames using the settings # dictionary. If the database files for the GeoLite country # and/or city datasets exist, then try to open them. - country_db = path / (country or GEOIP_SETTINGS['GEOIP_COUNTRY']) + country_db = path / (country or GEOIP_SETTINGS["GEOIP_COUNTRY"]) if country_db.is_file(): self._country = geoip2.database.Reader(str(country_db), mode=cache) self._country_file = country_db - city_db = path / (city or GEOIP_SETTINGS['GEOIP_CITY']) + city_db = path / (city or GEOIP_SETTINGS["GEOIP_CITY"]) if city_db.is_file(): self._city = geoip2.database.Reader(str(city_db), mode=cache) self._city_file = city_db if not self._reader: - raise GeoIP2Exception('Could not load a database from %s.' % path) + raise GeoIP2Exception("Could not load a database from %s." 
% path) elif path.is_file(): # Otherwise, some detective work will be needed to figure out # whether the given database path is for the GeoIP country or city @@ -99,18 +103,20 @@ class GeoIP2: reader = geoip2.database.Reader(str(path), mode=cache) db_type = reader.metadata().database_type - if db_type.endswith('City'): + if db_type.endswith("City"): # GeoLite City database detected. self._city = reader self._city_file = path - elif db_type.endswith('Country'): + elif db_type.endswith("Country"): # GeoIP Country database detected. self._country = reader self._country_file = path else: - raise GeoIP2Exception('Unable to recognize database edition: %s' % db_type) + raise GeoIP2Exception( + "Unable to recognize database edition: %s" % db_type + ) else: - raise GeoIP2Exception('GeoIP path must be a valid file or directory.') + raise GeoIP2Exception("GeoIP path must be a valid file or directory.") @property def _reader(self): @@ -130,25 +136,33 @@ class GeoIP2: def __repr__(self): meta = self._reader.metadata() - version = '[v%s.%s]' % (meta.binary_format_major_version, meta.binary_format_minor_version) - return '<%(cls)s %(version)s _country_file="%(country)s", _city_file="%(city)s">' % { - 'cls': self.__class__.__name__, - 'version': version, - 'country': self._country_file, - 'city': self._city_file, - } + version = "[v%s.%s]" % ( + meta.binary_format_major_version, + meta.binary_format_minor_version, + ) + return ( + '<%(cls)s %(version)s _country_file="%(country)s", _city_file="%(city)s">' + % { + "cls": self.__class__.__name__, + "version": version, + "country": self._country_file, + "city": self._city_file, + } + ) def _check_query(self, query, city=False, city_or_country=False): "Check the query and database availability." # Making sure a string was passed in for the query. 
if not isinstance(query, str): - raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__) + raise TypeError( + "GeoIP query must be a string, not type %s" % type(query).__name__ + ) # Extra checks for the existence of country and city databases. if city_or_country and not (self._country or self._city): - raise GeoIP2Exception('Invalid GeoIP country and city data files.') + raise GeoIP2Exception("Invalid GeoIP country and city data files.") elif city and not self._city: - raise GeoIP2Exception('Invalid GeoIP city data file: %s' % self._city_file) + raise GeoIP2Exception("Invalid GeoIP city data file: %s" % self._city_file) # Return the query string back to the caller. GeoIP2 only takes IP addresses. try: @@ -169,11 +183,11 @@ class GeoIP2: def country_code(self, query): "Return the country code for the given IP Address or FQDN." - return self.country(query)['country_code'] + return self.country(query)["country_code"] def country_name(self, query): "Return the country name for the given IP Address or FQDN." - return self.country(query)['country_name'] + return self.country(query)["country_name"] def country(self, query): """ @@ -186,7 +200,7 @@ class GeoIP2: return Country(self._country_or_city(enc_query)) # #### Coordinate retrieval routines #### - def coords(self, query, ordering=('longitude', 'latitude')): + def coords(self, query, ordering=("longitude", "latitude")): cdict = self.city(query) if cdict is None: return None @@ -199,7 +213,7 @@ class GeoIP2: def lat_lon(self, query): "Return a tuple of the (latitude, longitude) for the given query." - return self.coords(query, ('latitude', 'longitude')) + return self.coords(query, ("latitude", "longitude")) def geos(self, query): "Return a GEOS Point object for the given query." @@ -214,7 +228,10 @@ class GeoIP2: def info(self): "Return information about the GeoIP library and databases in use." 
meta = self._reader.metadata() - return 'GeoIP Library:\n\t%s.%s\n' % (meta.binary_format_major_version, meta.binary_format_minor_version) + return "GeoIP Library:\n\t%s.%s\n" % ( + meta.binary_format_major_version, + meta.binary_format_minor_version, + ) @classmethod def open(cls, full_path, cache): diff --git a/django/contrib/gis/geoip2/resources.py b/django/contrib/gis/geoip2/resources.py index 08923b0998..74f4228697 100644 --- a/django/contrib/gis/geoip2/resources.py +++ b/django/contrib/gis/geoip2/resources.py @@ -1,22 +1,22 @@ def City(response): return { - 'city': response.city.name, - 'continent_code': response.continent.code, - 'continent_name': response.continent.name, - 'country_code': response.country.iso_code, - 'country_name': response.country.name, - 'dma_code': response.location.metro_code, - 'is_in_european_union': response.country.is_in_european_union, - 'latitude': response.location.latitude, - 'longitude': response.location.longitude, - 'postal_code': response.postal.code, - 'region': response.subdivisions[0].iso_code if response.subdivisions else None, - 'time_zone': response.location.time_zone, + "city": response.city.name, + "continent_code": response.continent.code, + "continent_name": response.continent.name, + "country_code": response.country.iso_code, + "country_name": response.country.name, + "dma_code": response.location.metro_code, + "is_in_european_union": response.country.is_in_european_union, + "latitude": response.location.latitude, + "longitude": response.location.longitude, + "postal_code": response.postal.code, + "region": response.subdivisions[0].iso_code if response.subdivisions else None, + "time_zone": response.location.time_zone, } def Country(response): return { - 'country_code': response.country.iso_code, - 'country_name': response.country.name, + "country_code": response.country.iso_code, + "country_name": response.country.name, } diff --git a/django/contrib/gis/geometry.py b/django/contrib/gis/geometry.py index 
f7a70618fa..0b289cf355 100644 --- a/django/contrib/gis/geometry.py +++ b/django/contrib/gis/geometry.py @@ -5,13 +5,13 @@ from django.utils.regex_helper import _lazy_re_compile # Regular expression for recognizing HEXEWKB and WKT. A prophylactic measure # to prevent potentially malicious input from reaching the underlying C # library. Not a substitute for good web security programming practices. -hex_regex = _lazy_re_compile(r'^[0-9A-F]+$', re.I) +hex_regex = _lazy_re_compile(r"^[0-9A-F]+$", re.I) wkt_regex = _lazy_re_compile( - r'^(SRID=(?P<srid>\-?[0-9]+);)?' - r'(?P<wkt>' - r'(?P<type>POINT|LINESTRING|LINEARRING|POLYGON|MULTIPOINT|' - r'MULTILINESTRING|MULTIPOLYGON|GEOMETRYCOLLECTION)' - r'[ACEGIMLONPSRUTYZ0-9,\.\-\+\(\) ]+)$', - re.I + r"^(SRID=(?P<srid>\-?[0-9]+);)?" + r"(?P<wkt>" + r"(?P<type>POINT|LINESTRING|LINEARRING|POLYGON|MULTIPOINT|" + r"MULTILINESTRING|MULTIPOLYGON|GEOMETRYCOLLECTION)" + r"[ACEGIMLONPSRUTYZ0-9,\.\-\+\(\) ]+)$", + re.I, ) -json_regex = _lazy_re_compile(r'^(\s+)?\{.*}(\s+)?$', re.DOTALL) +json_regex = _lazy_re_compile(r"^(\s+)?\{.*}(\s+)?$", re.DOTALL) diff --git a/django/contrib/gis/geos/__init__.py b/django/contrib/gis/geos/__init__.py index 65e7e54817..27de1ca8e6 100644 --- a/django/contrib/gis/geos/__init__.py +++ b/django/contrib/gis/geos/__init__.py @@ -3,7 +3,10 @@ The GeoDjango GEOS module. 
Please consult the GeoDjango documentation for more details: https://docs.djangoproject.com/en/dev/ref/contrib/gis/geos/ """ from .collections import ( # NOQA - GeometryCollection, MultiLineString, MultiPoint, MultiPolygon, + GeometryCollection, + MultiLineString, + MultiPoint, + MultiPolygon, ) from .error import GEOSException # NOQA from .factory import fromfile, fromstr # NOQA diff --git a/django/contrib/gis/geos/collections.py b/django/contrib/gis/geos/collections.py index 3c1f0c858e..abfec8af3d 100644 --- a/django/contrib/gis/geos/collections.py +++ b/django/contrib/gis/geos/collections.py @@ -46,11 +46,14 @@ class GeometryCollection(GEOSGeometry): # ### Methods for compatibility with ListMixin ### def _create_collection(self, length, items): # Creating the geometry pointer array. - geoms = (GEOM_PTR * length)(*[ - # this is a little sloppy, but makes life easier - # allow GEOSGeometry types (python wrappers) or pointer types - capi.geom_clone(getattr(g, 'ptr', g)) for g in items - ]) + geoms = (GEOM_PTR * length)( + *[ + # this is a little sloppy, but makes life easier + # allow GEOSGeometry types (python wrappers) or pointer types + capi.geom_clone(getattr(g, "ptr", g)) + for g in items + ] + ) return capi.create_collection(self._typeid, geoms, length) def _get_single_internal(self, index): @@ -59,7 +62,9 @@ class GeometryCollection(GEOSGeometry): def _get_single_external(self, index): "Return the Geometry from this Collection at the given index (0-based)." # Checking the index and returning the corresponding GEOS geometry. - return GEOSGeometry(capi.geom_clone(self._get_single_internal(index)), srid=self.srid) + return GEOSGeometry( + capi.geom_clone(self._get_single_internal(index)), srid=self.srid + ) def _set_list(self, length, items): "Create a new collection, and destroy the contents of the previous pointer." @@ -76,12 +81,13 @@ class GeometryCollection(GEOSGeometry): @property def kml(self): "Return the KML for this Geometry Collection." 
- return '<MultiGeometry>%s</MultiGeometry>' % ''.join(g.kml for g in self) + return "<MultiGeometry>%s</MultiGeometry>" % "".join(g.kml for g in self) @property def tuple(self): "Return a tuple of all the coordinates in this Geometry Collection" return tuple(g.tuple for g in self) + coords = tuple @@ -103,4 +109,12 @@ class MultiPolygon(GeometryCollection): # Setting the allowed types here since GeometryCollection is defined before # its subclasses. -GeometryCollection._allowed = (Point, LineString, LinearRing, Polygon, MultiPoint, MultiLineString, MultiPolygon) +GeometryCollection._allowed = ( + Point, + LineString, + LinearRing, + Polygon, + MultiPoint, + MultiLineString, + MultiPolygon, +) diff --git a/django/contrib/gis/geos/coordseq.py b/django/contrib/gis/geos/coordseq.py index d6a7c6cdc7..07a3b7d213 100644 --- a/django/contrib/gis/geos/coordseq.py +++ b/django/contrib/gis/geos/coordseq.py @@ -20,7 +20,7 @@ class GEOSCoordSeq(GEOSBase): def __init__(self, ptr, z=False): "Initialize from a GEOS pointer." if not isinstance(ptr, CS_PTR): - raise TypeError('Coordinate sequence should initialize with a CS_PTR.') + raise TypeError("Coordinate sequence should initialize with a CS_PTR.") self._ptr = ptr self._z = z @@ -50,7 +50,9 @@ class GEOSCoordSeq(GEOSBase): elif numpy and isinstance(value, numpy.ndarray): pass else: - raise TypeError('Must set coordinate with a sequence (list, tuple, or numpy array).') + raise TypeError( + "Must set coordinate with a sequence (list, tuple, or numpy array)." + ) # Checking the dims of the input if self.dims == 3 and self._z: n_args = 3 @@ -59,7 +61,7 @@ class GEOSCoordSeq(GEOSBase): n_args = 2 point_setter = self._set_point_2d if len(value) != n_args: - raise TypeError('Dimension of value does not match.') + raise TypeError("Dimension of value does not match.") self._checkindex(index) point_setter(index, value) @@ -67,7 +69,7 @@ class GEOSCoordSeq(GEOSBase): def _checkindex(self, index): "Check the given index." 
if not (0 <= index < self.size): - raise IndexError('invalid GEOS Geometry index: %s' % index) + raise IndexError("invalid GEOS Geometry index: %s" % index) def _checkdim(self, dim): "Check the given dimension." @@ -180,11 +182,13 @@ class GEOSCoordSeq(GEOSBase): # Getting the substitution string depending on whether the coordinates have # a Z dimension. if self.hasz: - substr = '%s,%s,%s ' + substr = "%s,%s,%s " else: - substr = '%s,%s,0 ' - return '<coordinates>%s</coordinates>' % \ - ''.join(substr % self[i] for i in range(len(self))).strip() + substr = "%s,%s,0 " + return ( + "<coordinates>%s</coordinates>" + % "".join(substr % self[i] for i in range(len(self))).strip() + ) @property def tuple(self): diff --git a/django/contrib/gis/geos/factory.py b/django/contrib/gis/geos/factory.py index 46ed1f0c51..2c1c24741b 100644 --- a/django/contrib/gis/geos/factory.py +++ b/django/contrib/gis/geos/factory.py @@ -8,7 +8,7 @@ def fromfile(file_h): """ # If given a file name, get a real handle. 
if isinstance(file_h, str): - with open(file_h, 'rb') as file_h: + with open(file_h, "rb") as file_h: buf = file_h.read() else: buf = file_h.read() diff --git a/django/contrib/gis/geos/geometry.py b/django/contrib/gis/geos/geometry.py index 36f6895224..9d22425fab 100644 --- a/django/contrib/gis/geos/geometry.py +++ b/django/contrib/gis/geos/geometry.py @@ -14,9 +14,7 @@ from django.contrib.gis.geos.error import GEOSException from django.contrib.gis.geos.libgeos import GEOM_PTR, geos_version_tuple from django.contrib.gis.geos.mutable_list import ListMixin from django.contrib.gis.geos.prepared import PreparedGeometry -from django.contrib.gis.geos.prototypes.io import ( - ewkb_w, wkb_r, wkb_w, wkt_r, wkt_w, -) +from django.contrib.gis.geos.prototypes.io import ewkb_w, wkb_r, wkb_w, wkt_r, wkt_w from django.utils.deconstruct import deconstructible from django.utils.encoding import force_bytes, force_str @@ -38,12 +36,15 @@ class GEOSGeometryBase(GEOSBase): if GEOSGeometryBase._GEOS_CLASSES is None: # Inner imports avoid import conflicts with GEOSGeometry. from .collections import ( - GeometryCollection, MultiLineString, MultiPoint, + GeometryCollection, + MultiLineString, + MultiPoint, MultiPolygon, ) from .linestring import LinearRing, LineString from .point import Point from .polygon import Polygon + GEOSGeometryBase._GEOS_CLASSES = { 0: Point, 1: LineString, @@ -62,7 +63,9 @@ class GEOSGeometryBase(GEOSBase): "Perform post-initialization setup." # Setting the coordinate sequence for the geometry (will be None on # geometries that do not have coordinate sequences) - self._cs = GEOSCoordSeq(capi.get_cs(self.ptr), self.hasz) if self.has_cs else None + self._cs = ( + GEOSCoordSeq(capi.get_cs(self.ptr), self.hasz) if self.has_cs else None + ) def __copy__(self): """ @@ -85,7 +88,7 @@ class GEOSGeometryBase(GEOSBase): def __repr__(self): "Short-hand representation because WKT may be very large." 
- return '<%s object at %s>' % (self.geom_type, hex(addressof(self.ptr))) + return "<%s object at %s>" % (self.geom_type, hex(addressof(self.ptr))) # Pickling support def _to_pickle_wkb(self): @@ -104,7 +107,7 @@ class GEOSGeometryBase(GEOSBase): wkb, srid = state ptr = self._from_pickle_wkb(wkb) if not ptr: - raise GEOSException('Invalid Geometry loaded from pickled state.') + raise GEOSException("Invalid Geometry loaded from pickled state.") self.ptr = ptr self._post_init() self.srid = srid @@ -117,17 +120,17 @@ class GEOSGeometryBase(GEOSBase): def from_ewkt(ewkt): ewkt = force_bytes(ewkt) srid = None - parts = ewkt.split(b';', 1) + parts = ewkt.split(b";", 1) if len(parts) == 2: srid_part, wkt = parts - match = re.match(br'SRID=(?P<srid>\-?\d+)', srid_part) + match = re.match(rb"SRID=(?P<srid>\-?\d+)", srid_part) if not match: - raise ValueError('EWKT has invalid SRID part.') - srid = int(match['srid']) + raise ValueError("EWKT has invalid SRID part.") + srid = int(match["srid"]) else: wkt = ewkt if not wkt: - raise ValueError('Expected WKT but got an empty string.') + raise ValueError("Expected WKT but got an empty string.") return GEOSGeometry(GEOSGeometry._from_wkt(wkt), srid=srid) @staticmethod @@ -149,7 +152,11 @@ class GEOSGeometryBase(GEOSBase): other = GEOSGeometry.from_ewkt(other) except (ValueError, GEOSException): return False - return isinstance(other, GEOSGeometry) and self.srid == other.srid and self.equals_exact(other) + return ( + isinstance(other, GEOSGeometry) + and self.srid == other.srid + and self.equals_exact(other) + ) def __hash__(self): return hash((self.srid, self.wkt)) @@ -225,7 +232,7 @@ class GEOSGeometryBase(GEOSBase): without losing any of the input vertices. 
""" if geos_version_tuple() < (3, 8): - raise GEOSException('GEOSGeometry.make_valid() requires GEOS >= 3.8.0.') + raise GEOSException("GEOSGeometry.make_valid() requires GEOS >= 3.8.0.") return GEOSGeometry(capi.geos_makevalid(self.ptr), srid=self.srid) # #### Unary predicates #### @@ -323,7 +330,7 @@ class GEOSGeometryBase(GEOSBase): two Geometries match the elements in pattern. """ if not isinstance(pattern, str) or len(pattern) > 9: - raise GEOSException('invalid intersection matrix pattern') + raise GEOSException("invalid intersection matrix pattern") return capi.geos_relatepattern(self.ptr, other.ptr, force_bytes(pattern)) def touches(self, other): @@ -362,7 +369,7 @@ class GEOSGeometryBase(GEOSBase): Return the EWKT (SRID + WKT) of the Geometry. """ srid = self.srid - return 'SRID=%s;%s' % (srid, self.wkt) if srid else self.wkt + return "SRID=%s;%s" % (srid, self.wkt) if srid else self.wkt @property def wkt(self): @@ -395,6 +402,7 @@ class GEOSGeometryBase(GEOSBase): Return GeoJSON representation of this Geometry. """ return self.ogr.json + geojson = json @property @@ -419,7 +427,7 @@ class GEOSGeometryBase(GEOSBase): def kml(self): "Return the KML representation of this Geometry." 
gtype = self.geom_type - return '<%s>%s</%s>' % (gtype, self.coord_seq.kml, gtype) + return "<%s>%s</%s>" % (gtype, self.coord_seq.kml, gtype) @property def prepared(self): @@ -493,7 +501,7 @@ class GEOSGeometryBase(GEOSBase): self._post_init() self.srid = g.srid else: - raise GEOSException('Transformed WKB was invalid.') + raise GEOSException("Transformed WKB was invalid.") # #### Topology Routines #### def _topology(self, gptr): @@ -515,7 +523,9 @@ class GEOSGeometryBase(GEOSBase): """ return self._topology(capi.geos_buffer(self.ptr, width, quadsegs)) - def buffer_with_style(self, width, quadsegs=8, end_cap_style=1, join_style=1, mitre_limit=5.0): + def buffer_with_style( + self, width, quadsegs=8, end_cap_style=1, join_style=1, mitre_limit=5.0 + ): """ Same as buffer() but allows customizing the style of the buffer. @@ -524,7 +534,9 @@ class GEOSGeometryBase(GEOSBase): Mitre ratio limit only affects mitered join style. """ return self._topology( - capi.geos_bufferwithstyle(self.ptr, width, quadsegs, end_cap_style, join_style, mitre_limit), + capi.geos_bufferwithstyle( + self.ptr, width, quadsegs, end_cap_style, join_style, mitre_limit + ), ) @property @@ -615,7 +627,7 @@ class GEOSGeometryBase(GEOSBase): the Geometry. """ if not isinstance(other, GEOSGeometry): - raise TypeError('distance() works only on other GEOS Geometries.') + raise TypeError("distance() works only on other GEOS Geometries.") return capi.geos_distance(self.ptr, other.ptr, byref(c_double())) @property @@ -625,6 +637,7 @@ class GEOSGeometryBase(GEOSBase): (xmin, ymin, xmax, ymax). """ from .point import Point + env = self.envelope if isinstance(env, Point): xmin, ymin = env.tuple @@ -651,6 +664,7 @@ class LinearGeometryMixin: """ Used for LineString and MultiLineString. 
""" + def interpolate(self, distance): return self._topology(capi.geos_interpolate(self.ptr, distance)) @@ -659,14 +673,16 @@ class LinearGeometryMixin: def project(self, point): from .point import Point + if not isinstance(point, Point): - raise TypeError('locate_point argument must be a Point') + raise TypeError("locate_point argument must be a Point") return capi.geos_project(self.ptr, point.ptr) def project_normalized(self, point): from .point import Point + if not isinstance(point, Point): - raise TypeError('locate_point argument must be a Point') + raise TypeError("locate_point argument must be a Point") return capi.geos_project_normalized(self.ptr, point.ptr) @property @@ -710,9 +726,9 @@ class GEOSGeometry(GEOSGeometryBase, ListMixin): wkt_m = wkt_regex.match(geo_input) if wkt_m: # Handle WKT input. - if wkt_m['srid']: - input_srid = int(wkt_m['srid']) - g = self._from_wkt(force_bytes(wkt_m['wkt'])) + if wkt_m["srid"]: + input_srid = int(wkt_m["srid"]) + g = self._from_wkt(force_bytes(wkt_m["wkt"])) elif hex_regex.match(geo_input): # Handle HEXEWKB input. g = wkb_r().read(force_bytes(geo_input)) @@ -722,7 +738,7 @@ class GEOSGeometry(GEOSGeometryBase, ListMixin): g = ogr._geos_ptr() input_srid = ogr.srid else: - raise ValueError('String input unrecognized as WKT EWKT, and HEXEWKB.') + raise ValueError("String input unrecognized as WKT EWKT, and HEXEWKB.") elif isinstance(geo_input, GEOM_PTR): # When the input is a pointer to a geometry (GEOM_PTR). 
g = geo_input @@ -732,14 +748,14 @@ class GEOSGeometry(GEOSGeometryBase, ListMixin): elif isinstance(geo_input, GEOSGeometry): g = capi.geom_clone(geo_input.ptr) else: - raise TypeError('Improper geometry input type: %s' % type(geo_input)) + raise TypeError("Improper geometry input type: %s" % type(geo_input)) if not g: - raise GEOSException('Could not initialize GEOS Geometry with given input.') + raise GEOSException("Could not initialize GEOS Geometry with given input.") input_srid = input_srid or capi.geos_get_srid(g) or None if input_srid and srid and input_srid != srid: - raise ValueError('Input geometry already has SRID: %d.' % input_srid) + raise ValueError("Input geometry already has SRID: %d." % input_srid) super().__init__(g, None) # Set the SRID, if given. diff --git a/django/contrib/gis/geos/io.py b/django/contrib/gis/geos/io.py index 15027ff68f..d7898065f0 100644 --- a/django/contrib/gis/geos/io.py +++ b/django/contrib/gis/geos/io.py @@ -5,10 +5,13 @@ reader and writer classes. """ from django.contrib.gis.geos.geometry import GEOSGeometry from django.contrib.gis.geos.prototypes.io import ( - WKBWriter, WKTWriter, _WKBReader, _WKTReader, + WKBWriter, + WKTWriter, + _WKBReader, + _WKTReader, ) -__all__ = ['WKBWriter', 'WKTWriter', 'WKBReader', 'WKTReader'] +__all__ = ["WKBWriter", "WKTWriter", "WKBReader", "WKTReader"] # Public classes for (WKB|WKT)Reader, which return GEOSGeometry diff --git a/django/contrib/gis/geos/libgeos.py b/django/contrib/gis/geos/libgeos.py index aa19b3a706..1121b4f715 100644 --- a/django/contrib/gis/geos/libgeos.py +++ b/django/contrib/gis/geos/libgeos.py @@ -15,13 +15,14 @@ from django.core.exceptions import ImproperlyConfigured from django.utils.functional import SimpleLazyObject, cached_property from django.utils.version import get_version_tuple -logger = logging.getLogger('django.contrib.gis') +logger = logging.getLogger("django.contrib.gis") def load_geos(): # Custom library path set? 
try: from django.conf import settings + lib_path = settings.GEOS_LIBRARY_PATH except (AttributeError, ImportError, ImproperlyConfigured, OSError): lib_path = None @@ -29,12 +30,12 @@ def load_geos(): # Setting the appropriate names for the GEOS-C library. if lib_path: lib_names = None - elif os.name == 'nt': + elif os.name == "nt": # Windows NT libraries - lib_names = ['geos_c', 'libgeos_c-1'] - elif os.name == 'posix': + lib_names = ["geos_c", "libgeos_c-1"] + elif os.name == "posix": # *NIX libraries - lib_names = ['geos_c', 'GEOS'] + lib_names = ["geos_c", "GEOS"] else: raise ImportError('Unsupported OS "%s"' % os.name) @@ -51,8 +52,7 @@ def load_geos(): if lib_path is None: raise ImportError( 'Could not find the GEOS library (tried "%s"). ' - 'Try setting GEOS_LIBRARY_PATH in your settings.' % - '", "'.join(lib_names) + "Try setting GEOS_LIBRARY_PATH in your settings." % '", "'.join(lib_names) ) # Getting the GEOS C library. The C interface (CDLL) is used for # both *NIX and Windows. @@ -82,7 +82,7 @@ def notice_h(fmt, lst): warn_msg = fmt % lst except TypeError: warn_msg = fmt - logger.warning('GEOS_NOTICE: %s\n', warn_msg) + logger.warning("GEOS_NOTICE: %s\n", warn_msg) notice_h = NOTICEFUNC(notice_h) @@ -96,7 +96,7 @@ def error_h(fmt, lst): err_msg = fmt % lst except TypeError: err_msg = fmt - logger.error('GEOS_ERROR: %s\n', err_msg) + logger.error("GEOS_ERROR: %s\n", err_msg) error_h = ERRORFUNC(error_h) @@ -135,6 +135,7 @@ class GEOSFuncFactory: """ Lazy loading of GEOS functions. 
""" + argtypes = None restype = None errcheck = None @@ -154,6 +155,7 @@ class GEOSFuncFactory: @cached_property def func(self): from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc + func = GEOSFunc(self.func_name) func.argtypes = self.argtypes or [] func.restype = self.restype diff --git a/django/contrib/gis/geos/linestring.py b/django/contrib/gis/geos/linestring.py index a85ecfff8c..78a265cebf 100644 --- a/django/contrib/gis/geos/linestring.py +++ b/django/contrib/gis/geos/linestring.py @@ -29,11 +29,15 @@ class LineString(LinearGeometryMixin, GEOSGeometry): else: coords = args - if not (isinstance(coords, (tuple, list)) or numpy and isinstance(coords, numpy.ndarray)): - raise TypeError('Invalid initialization input for LineStrings.') + if not ( + isinstance(coords, (tuple, list)) + or numpy + and isinstance(coords, numpy.ndarray) + ): + raise TypeError("Invalid initialization input for LineStrings.") # If SRID was passed in with the keyword arguments - srid = kwargs.get('srid') + srid = kwargs.get("srid") ncoords = len(coords) if not ncoords: @@ -42,7 +46,8 @@ class LineString(LinearGeometryMixin, GEOSGeometry): if ncoords < self._minlength: raise ValueError( - '%s requires at least %d points, got %s.' % ( + "%s requires at least %d points, got %s." + % ( self.__class__.__name__, self._minlength, ncoords, @@ -53,7 +58,7 @@ class LineString(LinearGeometryMixin, GEOSGeometry): if numpy_coords: shape = coords.shape # Using numpy's shape. 
if len(shape) != 2: - raise TypeError('Too many dimensions.') + raise TypeError("Too many dimensions.") self._checkdim(shape[1]) ndim = shape[1] else: @@ -63,13 +68,15 @@ class LineString(LinearGeometryMixin, GEOSGeometry): # Incrementing through each of the coordinates and verifying for coord in coords: if not isinstance(coord, (tuple, list, Point)): - raise TypeError('Each coordinate should be a sequence (list or tuple)') + raise TypeError( + "Each coordinate should be a sequence (list or tuple)" + ) if ndim is None: ndim = len(coord) self._checkdim(ndim) elif len(coord) != ndim: - raise TypeError('Dimension mismatch.') + raise TypeError("Dimension mismatch.") # Creating a coordinate sequence object because it is easier to # set the points using its methods. @@ -122,20 +129,21 @@ class LineString(LinearGeometryMixin, GEOSGeometry): self._post_init() else: # can this happen? - raise GEOSException('Geometry resulting from slice deletion was invalid.') + raise GEOSException("Geometry resulting from slice deletion was invalid.") def _set_single(self, index, value): self._cs[index] = value def _checkdim(self, dim): if dim not in (2, 3): - raise TypeError('Dimension mismatch.') + raise TypeError("Dimension mismatch.") # #### Sequence Properties #### @property def tuple(self): "Return a tuple version of the geometry from the coordinate sequence." return self._cs.tuple + coords = tuple def _listarr(self, func): @@ -181,7 +189,5 @@ class LinearRing(LineString): @property def is_counterclockwise(self): if self.empty: - raise ValueError( - 'Orientation of an empty LinearRing cannot be determined.' 
- ) + raise ValueError("Orientation of an empty LinearRing cannot be determined.") return self._cs.is_counterclockwise diff --git a/django/contrib/gis/geos/mutable_list.py b/django/contrib/gis/geos/mutable_list.py index b04a2128af..36131fe9ce 100644 --- a/django/contrib/gis/geos/mutable_list.py +++ b/django/contrib/gis/geos/mutable_list.py @@ -60,10 +60,10 @@ class ListMixin: # ### Python initialization and special list interface methods ### def __init__(self, *args, **kwargs): - if not hasattr(self, '_get_single_internal'): + if not hasattr(self, "_get_single_internal"): self._get_single_internal = self._get_single_external - if not hasattr(self, '_set_single'): + if not hasattr(self, "_set_single"): self._set_single = self._set_single_rebuild self._assign_extended_slice = self._assign_extended_slice_rebuild @@ -72,7 +72,9 @@ class ListMixin: def __getitem__(self, index): "Get the item(s) at the specified index/slice." if isinstance(index, slice): - return [self._get_single_external(i) for i in range(*index.indices(len(self)))] + return [ + self._get_single_external(i) for i in range(*index.indices(len(self))) + ] else: index = self._checkindex(index) return self._get_single_external(index) @@ -91,9 +93,9 @@ class ListMixin: indexRange = range(*index.indices(origLen)) newLen = origLen - len(indexRange) - newItems = (self._get_single_internal(i) - for i in range(origLen) - if i not in indexRange) + newItems = ( + self._get_single_internal(i) for i in range(origLen) if i not in indexRange + ) self._rebuild(newLen, newItems) @@ -108,28 +110,28 @@ class ListMixin: # ### Special methods for arithmetic operations ### def __add__(self, other): - 'add another list-like object' + "add another list-like object" return self.__class__([*self, *other]) def __radd__(self, other): - 'add to another list-like object' + "add to another list-like object" return other.__class__([*other, *self]) def __iadd__(self, other): - 'add another list-like object to self' + "add another 
list-like object to self" self.extend(other) return self def __mul__(self, n): - 'multiply' + "multiply" return self.__class__(list(self) * n) def __rmul__(self, n): - 'multiply' + "multiply" return self.__class__(list(self) * n) def __imul__(self, n): - 'multiply' + "multiply" if n <= 0: del self[:] else: @@ -179,16 +181,16 @@ class ListMixin: for i in range(0, len(self)): if self[i] == val: return i - raise ValueError('%s not found in object' % val) + raise ValueError("%s not found in object" % val) # ## Mutating ## def append(self, val): "Standard list append method" - self[len(self):] = [val] + self[len(self) :] = [val] def extend(self, vals): "Standard list extend method" - self[len(self):] = vals + self[len(self) :] = vals def insert(self, index, val): "Standard list insert method" @@ -217,9 +219,9 @@ class ListMixin: # ### Private routines ### def _rebuild(self, newLen, newItems): if newLen and newLen < self._minlength: - raise ValueError('Must have at least %d items' % self._minlength) + raise ValueError("Must have at least %d items" % self._minlength) if self._maxlength is not None and newLen > self._maxlength: - raise ValueError('Cannot have more than %d items' % self._maxlength) + raise ValueError("Cannot have more than %d items" % self._maxlength) self._set_list(newLen, newItems) @@ -232,19 +234,19 @@ class ListMixin: return index if -length <= index < 0: return index + length - raise IndexError('invalid index: %s' % index) + raise IndexError("invalid index: %s" % index) def _check_allowed(self, items): - if hasattr(self, '_allowed'): + if hasattr(self, "_allowed"): if False in [isinstance(val, self._allowed) for val in items]: - raise TypeError('Invalid type encountered in the arguments.') + raise TypeError("Invalid type encountered in the arguments.") def _set_slice(self, index, values): "Assign values to a slice of the object" try: valueList = list(values) except TypeError: - raise TypeError('can only assign an iterable to a slice') + raise 
TypeError("can only assign an iterable to a slice") self._check_allowed(valueList) @@ -259,13 +261,14 @@ class ListMixin: self._assign_extended_slice(start, stop, step, valueList) def _assign_extended_slice_rebuild(self, start, stop, step, valueList): - 'Assign an extended slice by rebuilding entire list' + "Assign an extended slice by rebuilding entire list" indexList = range(start, stop, step) # extended slice, only allow assigning slice of same size if len(valueList) != len(indexList): - raise ValueError('attempt to assign sequence of size %d ' - 'to extended slice of size %d' - % (len(valueList), len(indexList))) + raise ValueError( + "attempt to assign sequence of size %d " + "to extended slice of size %d" % (len(valueList), len(indexList)) + ) # we're not changing the length of the sequence newLen = len(self) @@ -281,19 +284,20 @@ class ListMixin: self._rebuild(newLen, newItems()) def _assign_extended_slice(self, start, stop, step, valueList): - 'Assign an extended slice by re-assigning individual items' + "Assign an extended slice by re-assigning individual items" indexList = range(start, stop, step) # extended slice, only allow assigning slice of same size if len(valueList) != len(indexList): - raise ValueError('attempt to assign sequence of size %d ' - 'to extended slice of size %d' - % (len(valueList), len(indexList))) + raise ValueError( + "attempt to assign sequence of size %d " + "to extended slice of size %d" % (len(valueList), len(indexList)) + ) for i, val in zip(indexList, valueList): self._set_single(i, val) def _assign_simple_slice(self, start, stop, valueList): - 'Assign a simple slice; Can assign slice of any length' + "Assign a simple slice; Can assign slice of any length" origLen = len(self) stop = max(start, stop) newLen = origLen - stop + start + len(valueList) diff --git a/django/contrib/gis/geos/point.py b/django/contrib/gis/geos/point.py index 00fdd96a27..06b2668621 100644 --- a/django/contrib/gis/geos/point.py +++ 
b/django/contrib/gis/geos/point.py @@ -32,7 +32,7 @@ class Point(GEOSGeometry): else: coords = [x, y] else: - raise TypeError('Invalid parameters given for Point initialization.') + raise TypeError("Invalid parameters given for Point initialization.") point = self._create_point(len(coords), coords) @@ -47,7 +47,9 @@ class Point(GEOSGeometry): return self._create_empty() if wkb is None else super()._from_pickle_wkb(wkb) def _ogr_ptr(self): - return gdal.geometries.Point._create_empty() if self.empty else super()._ogr_ptr() + return ( + gdal.geometries.Point._create_empty() if self.empty else super()._ogr_ptr() + ) @classmethod def _create_empty(cls): @@ -62,7 +64,7 @@ class Point(GEOSGeometry): return capi.create_point(None) if ndim < 2 or ndim > 3: - raise TypeError('Invalid point dimension: %s' % ndim) + raise TypeError("Invalid point dimension: %s" % ndim) cs = capi.create_cs(c_uint(1), c_uint(ndim)) i = iter(coords) @@ -84,7 +86,7 @@ class Point(GEOSGeometry): self._post_init() else: # can this happen? - raise GEOSException('Geometry resulting from slice deletion was invalid.') + raise GEOSException("Geometry resulting from slice deletion was invalid.") def _set_single(self, index, value): self._cs.setOrdinate(index, 0, value) @@ -142,7 +144,7 @@ class Point(GEOSGeometry): def z(self, value): "Set the Z component of the Point." if not self.hasz: - raise GEOSException('Cannot set Z on 2D Point.') + raise GEOSException("Cannot set Z on 2D Point.") self._cs.setOrdinate(2, 0, value) # ### Tuple setting and retrieval routines. 
### diff --git a/django/contrib/gis/geos/polygon.py b/django/contrib/gis/geos/polygon.py index d75106a9a8..452e72fcb6 100644 --- a/django/contrib/gis/geos/polygon.py +++ b/django/contrib/gis/geos/polygon.py @@ -59,8 +59,10 @@ class Polygon(GEOSGeometry): x0, y0, x1, y1 = bbox for z in bbox: if not isinstance(z, (float, int)): - return GEOSGeometry('POLYGON((%s %s, %s %s, %s %s, %s %s, %s %s))' % - (x0, y0, x0, y1, x1, y1, x1, y0, x0, y0)) + return GEOSGeometry( + "POLYGON((%s %s, %s %s, %s %s, %s %s, %s %s))" + % (x0, y0, x0, y1, x1, y1, x1, y0, x0, y0) + ) return Polygon(((x0, y0), (x0, y1), (x1, y1), (x1, y0), (x0, y0))) # ### These routines are needed for list-like operation w/ListMixin ### @@ -95,8 +97,13 @@ class Polygon(GEOSGeometry): else: return capi.geom_clone(g.ptr) - def _construct_ring(self, param, msg=( - 'Parameter must be a sequence of LinearRings or objects that can initialize to LinearRings')): + def _construct_ring( + self, + param, + msg=( + "Parameter must be a sequence of LinearRings or objects that can initialize to LinearRings" + ), + ): "Try to construct a ring from the given parameter." if isinstance(param, LinearRing): return param @@ -135,7 +142,9 @@ class Polygon(GEOSGeometry): return capi.get_intring(self.ptr, index - 1) def _get_single_external(self, index): - return GEOSGeometry(capi.geom_clone(self._get_single_internal(index)), srid=self.srid) + return GEOSGeometry( + capi.geom_clone(self._get_single_internal(index)), srid=self.srid + ) _set_single = GEOSGeometry._set_single_rebuild _assign_extended_slice = GEOSGeometry._assign_extended_slice_rebuild @@ -163,13 +172,17 @@ class Polygon(GEOSGeometry): def tuple(self): "Get the tuple for each ring in this Polygon." return tuple(self[i].tuple for i in range(len(self))) + coords = tuple @property def kml(self): "Return the KML representation of this Polygon." 
- inner_kml = ''.join( + inner_kml = "".join( "<innerBoundaryIs>%s</innerBoundaryIs>" % self[i + 1].kml for i in range(self.num_interior_rings) ) - return "<Polygon><outerBoundaryIs>%s</outerBoundaryIs>%s</Polygon>" % (self[0].kml, inner_kml) + return "<Polygon><outerBoundaryIs>%s</outerBoundaryIs>%s</Polygon>" % ( + self[0].kml, + inner_kml, + ) diff --git a/django/contrib/gis/geos/prepared.py b/django/contrib/gis/geos/prepared.py index 789432f81a..9c77d8a533 100644 --- a/django/contrib/gis/geos/prepared.py +++ b/django/contrib/gis/geos/prepared.py @@ -8,6 +8,7 @@ class PreparedGeometry(GEOSBase): At the moment this includes the contains covers, and intersects operations. """ + ptr_type = capi.PREPGEOM_PTR destructor = capi.prepared_destroy @@ -17,6 +18,7 @@ class PreparedGeometry(GEOSBase): # See #21662 self._base_geom = geom from .geometry import GEOSGeometry + if not isinstance(geom, GEOSGeometry): raise TypeError self.ptr = capi.geos_prepare(geom.ptr) diff --git a/django/contrib/gis/geos/prototypes/__init__.py b/django/contrib/gis/geos/prototypes/__init__.py index a79533fa14..8fa98f98e7 100644 --- a/django/contrib/gis/geos/prototypes/__init__.py +++ b/django/contrib/gis/geos/prototypes/__init__.py @@ -5,22 +5,62 @@ """ from django.contrib.gis.geos.prototypes.coordseq import ( # NOQA - create_cs, cs_clone, cs_getdims, cs_getordinate, cs_getsize, cs_getx, - cs_gety, cs_getz, cs_is_ccw, cs_setordinate, cs_setx, cs_sety, cs_setz, + create_cs, + cs_clone, + cs_getdims, + cs_getordinate, + cs_getsize, + cs_getx, + cs_gety, + cs_getz, + cs_is_ccw, + cs_setordinate, + cs_setx, + cs_sety, + cs_setz, get_cs, ) from django.contrib.gis.geos.prototypes.geom import ( # NOQA - create_collection, create_empty_polygon, create_linearring, - create_linestring, create_point, create_polygon, destroy_geom, geom_clone, - geos_get_srid, geos_makevalid, geos_normalize, geos_set_srid, geos_type, - geos_typeid, get_dims, get_extring, get_geomn, get_intring, get_nrings, - get_num_coords, 
get_num_geoms, + create_collection, + create_empty_polygon, + create_linearring, + create_linestring, + create_point, + create_polygon, + destroy_geom, + geom_clone, + geos_get_srid, + geos_makevalid, + geos_normalize, + geos_set_srid, + geos_type, + geos_typeid, + get_dims, + get_extring, + get_geomn, + get_intring, + get_nrings, + get_num_coords, + get_num_geoms, ) from django.contrib.gis.geos.prototypes.misc import * # NOQA from django.contrib.gis.geos.prototypes.predicates import ( # NOQA - geos_contains, geos_covers, geos_crosses, geos_disjoint, geos_equals, - geos_equalsexact, geos_hasz, geos_intersects, geos_isclosed, geos_isempty, - geos_isring, geos_issimple, geos_isvalid, geos_overlaps, - geos_relatepattern, geos_touches, geos_within, + geos_contains, + geos_covers, + geos_crosses, + geos_disjoint, + geos_equals, + geos_equalsexact, + geos_hasz, + geos_intersects, + geos_isclosed, + geos_isempty, + geos_isring, + geos_issimple, + geos_isvalid, + geos_overlaps, + geos_relatepattern, + geos_touches, + geos_within, ) from django.contrib.gis.geos.prototypes.topology import * # NOQA diff --git a/django/contrib/gis/geos/prototypes/coordseq.py b/django/contrib/gis/geos/prototypes/coordseq.py index aab5d3e75e..ed05de99f4 100644 --- a/django/contrib/gis/geos/prototypes/coordseq.py +++ b/django/contrib/gis/geos/prototypes/coordseq.py @@ -1,16 +1,14 @@ from ctypes import POINTER, c_byte, c_double, c_int, c_uint from django.contrib.gis.geos.libgeos import CS_PTR, GEOM_PTR, GEOSFuncFactory -from django.contrib.gis.geos.prototypes.errcheck import ( - GEOSException, last_arg_byref, -) +from django.contrib.gis.geos.prototypes.errcheck import GEOSException, last_arg_byref # ## Error-checking routines specific to coordinate sequences. ## def check_cs_op(result, func, cargs): "Check the status code of a coordinate sequence operation." 
if result == 0: - raise GEOSException('Could not set value on coordinate sequence') + raise GEOSException("Could not set value on coordinate sequence") else: return result @@ -49,7 +47,9 @@ class CsOperation(GEOSFuncFactory): else: argtypes = [CS_PTR, c_uint, dbl_param] - super().__init__(*args, **{**kwargs, 'errcheck': errcheck, 'argtypes': argtypes}) + super().__init__( + *args, **{**kwargs, "errcheck": errcheck, "argtypes": argtypes} + ) class CsOutput(GEOSFuncFactory): @@ -59,7 +59,7 @@ class CsOutput(GEOSFuncFactory): def errcheck(result, func, cargs): if not result: raise GEOSException( - 'Error encountered checking Coordinate Sequence returned from GEOS ' + "Error encountered checking Coordinate Sequence returned from GEOS " 'C function "%s".' % func.__name__ ) return result @@ -68,26 +68,28 @@ class CsOutput(GEOSFuncFactory): # ## Coordinate Sequence ctypes prototypes ## # Coordinate Sequence constructors & cloning. -cs_clone = CsOutput('GEOSCoordSeq_clone', argtypes=[CS_PTR]) -create_cs = CsOutput('GEOSCoordSeq_create', argtypes=[c_uint, c_uint]) -get_cs = CsOutput('GEOSGeom_getCoordSeq', argtypes=[GEOM_PTR]) +cs_clone = CsOutput("GEOSCoordSeq_clone", argtypes=[CS_PTR]) +create_cs = CsOutput("GEOSCoordSeq_create", argtypes=[c_uint, c_uint]) +get_cs = CsOutput("GEOSGeom_getCoordSeq", argtypes=[GEOM_PTR]) # Getting, setting ordinate -cs_getordinate = CsOperation('GEOSCoordSeq_getOrdinate', ordinate=True, get=True) -cs_setordinate = CsOperation('GEOSCoordSeq_setOrdinate', ordinate=True) +cs_getordinate = CsOperation("GEOSCoordSeq_getOrdinate", ordinate=True, get=True) +cs_setordinate = CsOperation("GEOSCoordSeq_setOrdinate", ordinate=True) # For getting, x, y, z -cs_getx = CsOperation('GEOSCoordSeq_getX', get=True) -cs_gety = CsOperation('GEOSCoordSeq_getY', get=True) -cs_getz = CsOperation('GEOSCoordSeq_getZ', get=True) +cs_getx = CsOperation("GEOSCoordSeq_getX", get=True) +cs_gety = CsOperation("GEOSCoordSeq_getY", get=True) +cs_getz = 
CsOperation("GEOSCoordSeq_getZ", get=True) # For setting, x, y, z -cs_setx = CsOperation('GEOSCoordSeq_setX') -cs_sety = CsOperation('GEOSCoordSeq_setY') -cs_setz = CsOperation('GEOSCoordSeq_setZ') +cs_setx = CsOperation("GEOSCoordSeq_setX") +cs_sety = CsOperation("GEOSCoordSeq_setY") +cs_setz = CsOperation("GEOSCoordSeq_setZ") # These routines return size & dimensions. -cs_getsize = CsInt('GEOSCoordSeq_getSize') -cs_getdims = CsInt('GEOSCoordSeq_getDimensions') +cs_getsize = CsInt("GEOSCoordSeq_getSize") +cs_getdims = CsInt("GEOSCoordSeq_getDimensions") -cs_is_ccw = GEOSFuncFactory('GEOSCoordSeq_isCCW', restype=c_int, argtypes=[CS_PTR, POINTER(c_byte)]) +cs_is_ccw = GEOSFuncFactory( + "GEOSCoordSeq_isCCW", restype=c_int, argtypes=[CS_PTR, POINTER(c_byte)] +) diff --git a/django/contrib/gis/geos/prototypes/errcheck.py b/django/contrib/gis/geos/prototypes/errcheck.py index 7d5f8422e0..a527f513a7 100644 --- a/django/contrib/gis/geos/prototypes/errcheck.py +++ b/django/contrib/gis/geos/prototypes/errcheck.py @@ -8,7 +8,7 @@ from django.contrib.gis.geos.libgeos import GEOSFuncFactory # Getting the `free` routine used to free the memory allocated for # string pointers returned by GEOS. -free = GEOSFuncFactory('GEOSFree') +free = GEOSFuncFactory("GEOSFree") free.argtypes = [c_void_p] @@ -29,14 +29,19 @@ def check_dbl(result, func, cargs): def check_geom(result, func, cargs): "Error checking on routines that return Geometries." if not result: - raise GEOSException('Error encountered checking Geometry returned from GEOS C function "%s".' % func.__name__) + raise GEOSException( + 'Error encountered checking Geometry returned from GEOS C function "%s".' + % func.__name__ + ) return result def check_minus_one(result, func, cargs): "Error checking on routines that should not return -1." if result == -1: - raise GEOSException('Error encountered in GEOS C function "%s".' % func.__name__) + raise GEOSException( + 'Error encountered in GEOS C function "%s".' 
% func.__name__ + ) else: return result @@ -48,7 +53,9 @@ def check_predicate(result, func, cargs): elif result == 0: return False else: - raise GEOSException('Error encountered on GEOS C predicate function "%s".' % func.__name__) + raise GEOSException( + 'Error encountered on GEOS C predicate function "%s".' % func.__name__ + ) def check_sized_string(result, func, cargs): @@ -58,7 +65,9 @@ def check_sized_string(result, func, cargs): This frees the memory allocated by GEOS at the result pointer. """ if not result: - raise GEOSException('Invalid string pointer returned by GEOS C function "%s"' % func.__name__) + raise GEOSException( + 'Invalid string pointer returned by GEOS C function "%s"' % func.__name__ + ) # A c_size_t object is passed in by reference for the second # argument on these routines, and its needed to determine the # correct size. @@ -75,7 +84,10 @@ def check_string(result, func, cargs): This frees the memory allocated by GEOS at the result pointer. """ if not result: - raise GEOSException('Error encountered checking string return value in GEOS C function "%s".' % func.__name__) + raise GEOSException( + 'Error encountered checking string return value in GEOS C function "%s".' + % func.__name__ + ) # Getting the string value at the pointer address. s = string_at(result) # Freeing the memory allocated within GEOS diff --git a/django/contrib/gis/geos/prototypes/geom.py b/django/contrib/gis/geos/prototypes/geom.py index bddf7b55b2..a456acd0c1 100644 --- a/django/contrib/gis/geos/prototypes/geom.py +++ b/django/contrib/gis/geos/prototypes/geom.py @@ -2,7 +2,9 @@ from ctypes import POINTER, c_char_p, c_int, c_ubyte, c_uint from django.contrib.gis.geos.libgeos import CS_PTR, GEOM_PTR, GEOSFuncFactory from django.contrib.gis.geos.prototypes.errcheck import ( - check_geom, check_minus_one, check_string, + check_geom, + check_minus_one, + check_string, ) # This is the return type used by binary output (WKB, HEX) routines. 
@@ -44,44 +46,46 @@ class StringFromGeom(GEOSFuncFactory): # ### ctypes prototypes ### # The GEOS geometry type, typeid, num_coordinates and number of geometries -geos_makevalid = GeomOutput('GEOSMakeValid', argtypes=[GEOM_PTR]) -geos_normalize = IntFromGeom('GEOSNormalize') -geos_type = StringFromGeom('GEOSGeomType') -geos_typeid = IntFromGeom('GEOSGeomTypeId') -get_dims = GEOSFuncFactory('GEOSGeom_getDimensions', argtypes=[GEOM_PTR], restype=c_int) -get_num_coords = IntFromGeom('GEOSGetNumCoordinates') -get_num_geoms = IntFromGeom('GEOSGetNumGeometries') +geos_makevalid = GeomOutput("GEOSMakeValid", argtypes=[GEOM_PTR]) +geos_normalize = IntFromGeom("GEOSNormalize") +geos_type = StringFromGeom("GEOSGeomType") +geos_typeid = IntFromGeom("GEOSGeomTypeId") +get_dims = GEOSFuncFactory("GEOSGeom_getDimensions", argtypes=[GEOM_PTR], restype=c_int) +get_num_coords = IntFromGeom("GEOSGetNumCoordinates") +get_num_geoms = IntFromGeom("GEOSGetNumGeometries") # Geometry creation factories -create_point = GeomOutput('GEOSGeom_createPoint', argtypes=[CS_PTR]) -create_linestring = GeomOutput('GEOSGeom_createLineString', argtypes=[CS_PTR]) -create_linearring = GeomOutput('GEOSGeom_createLinearRing', argtypes=[CS_PTR]) +create_point = GeomOutput("GEOSGeom_createPoint", argtypes=[CS_PTR]) +create_linestring = GeomOutput("GEOSGeom_createLineString", argtypes=[CS_PTR]) +create_linearring = GeomOutput("GEOSGeom_createLinearRing", argtypes=[CS_PTR]) # Polygon and collection creation routines need argument types defined # for compatibility with some platforms, e.g. macOS ARM64. With argtypes # defined, arrays are automatically cast and byref() calls are not needed. 
create_polygon = GeomOutput( - 'GEOSGeom_createPolygon', argtypes=[GEOM_PTR, POINTER(GEOM_PTR), c_uint], + "GEOSGeom_createPolygon", + argtypes=[GEOM_PTR, POINTER(GEOM_PTR), c_uint], ) -create_empty_polygon = GeomOutput('GEOSGeom_createEmptyPolygon', argtypes=[]) +create_empty_polygon = GeomOutput("GEOSGeom_createEmptyPolygon", argtypes=[]) create_collection = GeomOutput( - 'GEOSGeom_createCollection', argtypes=[c_int, POINTER(GEOM_PTR), c_uint], + "GEOSGeom_createCollection", + argtypes=[c_int, POINTER(GEOM_PTR), c_uint], ) # Ring routines -get_extring = GeomOutput('GEOSGetExteriorRing', argtypes=[GEOM_PTR]) -get_intring = GeomOutput('GEOSGetInteriorRingN', argtypes=[GEOM_PTR, c_int]) -get_nrings = IntFromGeom('GEOSGetNumInteriorRings') +get_extring = GeomOutput("GEOSGetExteriorRing", argtypes=[GEOM_PTR]) +get_intring = GeomOutput("GEOSGetInteriorRingN", argtypes=[GEOM_PTR, c_int]) +get_nrings = IntFromGeom("GEOSGetNumInteriorRings") # Collection Routines -get_geomn = GeomOutput('GEOSGetGeometryN', argtypes=[GEOM_PTR, c_int]) +get_geomn = GeomOutput("GEOSGetGeometryN", argtypes=[GEOM_PTR, c_int]) # Cloning -geom_clone = GEOSFuncFactory('GEOSGeom_clone', argtypes=[GEOM_PTR], restype=GEOM_PTR) +geom_clone = GEOSFuncFactory("GEOSGeom_clone", argtypes=[GEOM_PTR], restype=GEOM_PTR) # Destruction routine. 
-destroy_geom = GEOSFuncFactory('GEOSGeom_destroy', argtypes=[GEOM_PTR]) +destroy_geom = GEOSFuncFactory("GEOSGeom_destroy", argtypes=[GEOM_PTR]) # SRID routines -geos_get_srid = GEOSFuncFactory('GEOSGetSRID', argtypes=[GEOM_PTR], restype=c_int) -geos_set_srid = GEOSFuncFactory('GEOSSetSRID', argtypes=[GEOM_PTR, c_int]) +geos_get_srid = GEOSFuncFactory("GEOSGetSRID", argtypes=[GEOM_PTR], restype=c_int) +geos_set_srid = GEOSFuncFactory("GEOSSetSRID", argtypes=[GEOM_PTR, c_int]) diff --git a/django/contrib/gis/geos/prototypes/io.py b/django/contrib/gis/geos/prototypes/io.py index 4a1180ad7a..555d0e3a25 100644 --- a/django/contrib/gis/geos/prototypes/io.py +++ b/django/contrib/gis/geos/prototypes/io.py @@ -3,10 +3,14 @@ from ctypes import POINTER, Structure, byref, c_byte, c_char_p, c_int, c_size_t from django.contrib.gis.geos.base import GEOSBase from django.contrib.gis.geos.libgeos import ( - GEOM_PTR, GEOSFuncFactory, geos_version_tuple, + GEOM_PTR, + GEOSFuncFactory, + geos_version_tuple, ) from django.contrib.gis.geos.prototypes.errcheck import ( - check_geom, check_sized_string, check_string, + check_geom, + check_sized_string, + check_string, ) from django.contrib.gis.geos.prototypes.geom import c_uchar_p, geos_char_p from django.utils.encoding import force_bytes @@ -35,33 +39,43 @@ WKB_READ_PTR = POINTER(WKBReader_st) WKB_WRITE_PTR = POINTER(WKBReader_st) # WKTReader routines -wkt_reader_create = GEOSFuncFactory('GEOSWKTReader_create', restype=WKT_READ_PTR) -wkt_reader_destroy = GEOSFuncFactory('GEOSWKTReader_destroy', argtypes=[WKT_READ_PTR]) +wkt_reader_create = GEOSFuncFactory("GEOSWKTReader_create", restype=WKT_READ_PTR) +wkt_reader_destroy = GEOSFuncFactory("GEOSWKTReader_destroy", argtypes=[WKT_READ_PTR]) wkt_reader_read = GEOSFuncFactory( - 'GEOSWKTReader_read', argtypes=[WKT_READ_PTR, c_char_p], restype=GEOM_PTR, errcheck=check_geom + "GEOSWKTReader_read", + argtypes=[WKT_READ_PTR, c_char_p], + restype=GEOM_PTR, + errcheck=check_geom, ) # WKTWriter 
routines -wkt_writer_create = GEOSFuncFactory('GEOSWKTWriter_create', restype=WKT_WRITE_PTR) -wkt_writer_destroy = GEOSFuncFactory('GEOSWKTWriter_destroy', argtypes=[WKT_WRITE_PTR]) +wkt_writer_create = GEOSFuncFactory("GEOSWKTWriter_create", restype=WKT_WRITE_PTR) +wkt_writer_destroy = GEOSFuncFactory("GEOSWKTWriter_destroy", argtypes=[WKT_WRITE_PTR]) wkt_writer_write = GEOSFuncFactory( - 'GEOSWKTWriter_write', argtypes=[WKT_WRITE_PTR, GEOM_PTR], restype=geos_char_p, errcheck=check_string + "GEOSWKTWriter_write", + argtypes=[WKT_WRITE_PTR, GEOM_PTR], + restype=geos_char_p, + errcheck=check_string, ) wkt_writer_get_outdim = GEOSFuncFactory( - 'GEOSWKTWriter_getOutputDimension', argtypes=[WKT_WRITE_PTR], restype=c_int + "GEOSWKTWriter_getOutputDimension", argtypes=[WKT_WRITE_PTR], restype=c_int ) wkt_writer_set_outdim = GEOSFuncFactory( - 'GEOSWKTWriter_setOutputDimension', argtypes=[WKT_WRITE_PTR, c_int] + "GEOSWKTWriter_setOutputDimension", argtypes=[WKT_WRITE_PTR, c_int] ) -wkt_writer_set_trim = GEOSFuncFactory('GEOSWKTWriter_setTrim', argtypes=[WKT_WRITE_PTR, c_byte]) -wkt_writer_set_precision = GEOSFuncFactory('GEOSWKTWriter_setRoundingPrecision', argtypes=[WKT_WRITE_PTR, c_int]) +wkt_writer_set_trim = GEOSFuncFactory( + "GEOSWKTWriter_setTrim", argtypes=[WKT_WRITE_PTR, c_byte] +) +wkt_writer_set_precision = GEOSFuncFactory( + "GEOSWKTWriter_setRoundingPrecision", argtypes=[WKT_WRITE_PTR, c_int] +) # WKBReader routines -wkb_reader_create = GEOSFuncFactory('GEOSWKBReader_create', restype=WKB_READ_PTR) -wkb_reader_destroy = GEOSFuncFactory('GEOSWKBReader_destroy', argtypes=[WKB_READ_PTR]) +wkb_reader_create = GEOSFuncFactory("GEOSWKBReader_create", restype=WKB_READ_PTR) +wkb_reader_destroy = GEOSFuncFactory("GEOSWKBReader_destroy", argtypes=[WKB_READ_PTR]) class WKBReadFunc(GEOSFuncFactory): @@ -75,12 +89,12 @@ class WKBReadFunc(GEOSFuncFactory): errcheck = staticmethod(check_geom) -wkb_reader_read = WKBReadFunc('GEOSWKBReader_read') -wkb_reader_read_hex = 
WKBReadFunc('GEOSWKBReader_readHEX') +wkb_reader_read = WKBReadFunc("GEOSWKBReader_read") +wkb_reader_read_hex = WKBReadFunc("GEOSWKBReader_readHEX") # WKBWriter routines -wkb_writer_create = GEOSFuncFactory('GEOSWKBWriter_create', restype=WKB_WRITE_PTR) -wkb_writer_destroy = GEOSFuncFactory('GEOSWKBWriter_destroy', argtypes=[WKB_WRITE_PTR]) +wkb_writer_create = GEOSFuncFactory("GEOSWKBWriter_create", restype=WKB_WRITE_PTR) +wkb_writer_destroy = GEOSFuncFactory("GEOSWKBWriter_destroy", argtypes=[WKB_WRITE_PTR]) # WKB Writing prototypes. @@ -90,8 +104,8 @@ class WKBWriteFunc(GEOSFuncFactory): errcheck = staticmethod(check_sized_string) -wkb_writer_write = WKBWriteFunc('GEOSWKBWriter_write') -wkb_writer_write_hex = WKBWriteFunc('GEOSWKBWriter_writeHEX') +wkb_writer_write = WKBWriteFunc("GEOSWKBWriter_write") +wkb_writer_write_hex = WKBWriteFunc("GEOSWKBWriter_writeHEX") # WKBWriter property getter/setter prototypes. @@ -104,17 +118,22 @@ class WKBWriterSet(GEOSFuncFactory): argtypes = [WKB_WRITE_PTR, c_int] -wkb_writer_get_byteorder = WKBWriterGet('GEOSWKBWriter_getByteOrder') -wkb_writer_set_byteorder = WKBWriterSet('GEOSWKBWriter_setByteOrder') -wkb_writer_get_outdim = WKBWriterGet('GEOSWKBWriter_getOutputDimension') -wkb_writer_set_outdim = WKBWriterSet('GEOSWKBWriter_setOutputDimension') -wkb_writer_get_include_srid = WKBWriterGet('GEOSWKBWriter_getIncludeSRID', restype=c_byte) -wkb_writer_set_include_srid = WKBWriterSet('GEOSWKBWriter_setIncludeSRID', argtypes=[WKB_WRITE_PTR, c_byte]) +wkb_writer_get_byteorder = WKBWriterGet("GEOSWKBWriter_getByteOrder") +wkb_writer_set_byteorder = WKBWriterSet("GEOSWKBWriter_setByteOrder") +wkb_writer_get_outdim = WKBWriterGet("GEOSWKBWriter_getOutputDimension") +wkb_writer_set_outdim = WKBWriterSet("GEOSWKBWriter_setOutputDimension") +wkb_writer_get_include_srid = WKBWriterGet( + "GEOSWKBWriter_getIncludeSRID", restype=c_byte +) +wkb_writer_set_include_srid = WKBWriterSet( + "GEOSWKBWriter_setIncludeSRID", 
argtypes=[WKB_WRITE_PTR, c_byte] +) # ### Base I/O Class ### class IOBase(GEOSBase): "Base class for GEOS I/O objects." + def __init__(self): # Getting the pointer with the constructor. self.ptr = self._constructor() @@ -122,6 +141,7 @@ class IOBase(GEOSBase): # __del__ is too late (import error). self.destructor.func + # ### Base WKB/WKT Reading and Writing objects ### @@ -183,7 +203,7 @@ class WKTWriter(IOBase): @outdim.setter def outdim(self, new_dim): if new_dim not in (2, 3): - raise ValueError('WKT output dimension must be 2 or 3') + raise ValueError("WKT output dimension must be 2 or 3") wkt_writer_set_outdim(self.ptr, new_dim) @property @@ -203,7 +223,9 @@ class WKTWriter(IOBase): @precision.setter def precision(self, precision): if (not isinstance(precision, int) or precision < 0) and precision is not None: - raise AttributeError('WKT output rounding precision must be non-negative integer or None.') + raise AttributeError( + "WKT output rounding precision must be non-negative integer or None." + ) if precision != self._precision: self._precision = precision wkt_writer_set_precision(self.ptr, -1 if precision is None else precision) @@ -221,34 +243,37 @@ class WKBWriter(IOBase): def _handle_empty_point(self, geom): from django.contrib.gis.geos import Point + if isinstance(geom, Point) and geom.empty: if self.srid: # PostGIS uses POINT(NaN NaN) for WKB representation of empty # points. Use it for EWKB as it's a PostGIS specific format. # https://trac.osgeo.org/postgis/ticket/3181 - geom = Point(float('NaN'), float('NaN'), srid=geom.srid) + geom = Point(float("NaN"), float("NaN"), srid=geom.srid) else: - raise ValueError('Empty point is not representable in WKB.') + raise ValueError("Empty point is not representable in WKB.") return geom def write(self, geom): "Return the WKB representation of the given geometry." 
from django.contrib.gis.geos import Polygon + geom = self._handle_empty_point(geom) wkb = wkb_writer_write(self.ptr, geom.ptr, byref(c_size_t())) if self.geos_version < (3, 6, 1) and isinstance(geom, Polygon) and geom.empty: # Fix GEOS output for empty polygon. # See https://trac.osgeo.org/geos/ticket/680. - wkb = wkb[:-8] + b'\0' * 4 + wkb = wkb[:-8] + b"\0" * 4 return memoryview(wkb) def write_hex(self, geom): "Return the HEXEWKB representation of the given geometry." from django.contrib.gis.geos.polygon import Polygon + geom = self._handle_empty_point(geom) wkb = wkb_writer_write_hex(self.ptr, geom.ptr, byref(c_size_t())) if self.geos_version < (3, 6, 1) and isinstance(geom, Polygon) and geom.empty: - wkb = wkb[:-16] + b'0' * 8 + wkb = wkb[:-16] + b"0" * 8 return wkb # ### WKBWriter Properties ### @@ -259,7 +284,9 @@ class WKBWriter(IOBase): def _set_byteorder(self, order): if order not in (0, 1): - raise ValueError('Byte order parameter must be 0 (Big Endian) or 1 (Little Endian).') + raise ValueError( + "Byte order parameter must be 0 (Big Endian) or 1 (Little Endian)." + ) wkb_writer_set_byteorder(self.ptr, order) byteorder = property(_get_byteorder, _set_byteorder) @@ -272,7 +299,7 @@ class WKBWriter(IOBase): @outdim.setter def outdim(self, new_dim): if new_dim not in (2, 3): - raise ValueError('WKB output dimension must be 2 or 3') + raise ValueError("WKB output dimension must be 2 or 3") wkb_writer_set_outdim(self.ptr, new_dim) # Property for getting/setting the include srid flag. 
diff --git a/django/contrib/gis/geos/prototypes/misc.py b/django/contrib/gis/geos/prototypes/misc.py index 016c8328cb..fccd0ecc9e 100644 --- a/django/contrib/gis/geos/prototypes/misc.py +++ b/django/contrib/gis/geos/prototypes/misc.py @@ -8,7 +8,7 @@ from django.contrib.gis.geos.libgeos import GEOM_PTR, GEOSFuncFactory from django.contrib.gis.geos.prototypes.errcheck import check_dbl, check_string from django.contrib.gis.geos.prototypes.geom import geos_char_p -__all__ = ['geos_area', 'geos_distance', 'geos_length', 'geos_isvalidreason'] +__all__ = ["geos_area", "geos_distance", "geos_length", "geos_isvalidreason"] class DblFromGeom(GEOSFuncFactory): @@ -16,6 +16,7 @@ class DblFromGeom(GEOSFuncFactory): Argument is a Geometry, return type is double that is passed in by reference as the last argument. """ + restype = c_int # Status code returned errcheck = staticmethod(check_dbl) @@ -23,9 +24,11 @@ class DblFromGeom(GEOSFuncFactory): # ### ctypes prototypes ### # Area, distance, and length prototypes. 
-geos_area = DblFromGeom('GEOSArea', argtypes=[GEOM_PTR, POINTER(c_double)]) -geos_distance = DblFromGeom('GEOSDistance', argtypes=[GEOM_PTR, GEOM_PTR, POINTER(c_double)]) -geos_length = DblFromGeom('GEOSLength', argtypes=[GEOM_PTR, POINTER(c_double)]) -geos_isvalidreason = GEOSFuncFactory( - 'GEOSisValidReason', restype=geos_char_p, errcheck=check_string, argtypes=[GEOM_PTR] +geos_area = DblFromGeom("GEOSArea", argtypes=[GEOM_PTR, POINTER(c_double)]) +geos_distance = DblFromGeom( + "GEOSDistance", argtypes=[GEOM_PTR, GEOM_PTR, POINTER(c_double)] +) +geos_length = DblFromGeom("GEOSLength", argtypes=[GEOM_PTR, POINTER(c_double)]) +geos_isvalidreason = GEOSFuncFactory( + "GEOSisValidReason", restype=geos_char_p, errcheck=check_string, argtypes=[GEOM_PTR] ) diff --git a/django/contrib/gis/geos/prototypes/predicates.py b/django/contrib/gis/geos/prototypes/predicates.py index d4681c1689..d2e113a734 100644 --- a/django/contrib/gis/geos/prototypes/predicates.py +++ b/django/contrib/gis/geos/prototypes/predicates.py @@ -22,22 +22,26 @@ class BinaryPredicate(UnaryPredicate): # ## Unary Predicates ## -geos_hasz = UnaryPredicate('GEOSHasZ') -geos_isclosed = UnaryPredicate('GEOSisClosed') -geos_isempty = UnaryPredicate('GEOSisEmpty') -geos_isring = UnaryPredicate('GEOSisRing') -geos_issimple = UnaryPredicate('GEOSisSimple') -geos_isvalid = UnaryPredicate('GEOSisValid') +geos_hasz = UnaryPredicate("GEOSHasZ") +geos_isclosed = UnaryPredicate("GEOSisClosed") +geos_isempty = UnaryPredicate("GEOSisEmpty") +geos_isring = UnaryPredicate("GEOSisRing") +geos_issimple = UnaryPredicate("GEOSisSimple") +geos_isvalid = UnaryPredicate("GEOSisValid") # ## Binary Predicates ## -geos_contains = BinaryPredicate('GEOSContains') -geos_covers = BinaryPredicate('GEOSCovers') -geos_crosses = BinaryPredicate('GEOSCrosses') -geos_disjoint = BinaryPredicate('GEOSDisjoint') -geos_equals = BinaryPredicate('GEOSEquals') -geos_equalsexact = BinaryPredicate('GEOSEqualsExact', argtypes=[GEOM_PTR, GEOM_PTR, 
c_double]) -geos_intersects = BinaryPredicate('GEOSIntersects') -geos_overlaps = BinaryPredicate('GEOSOverlaps') -geos_relatepattern = BinaryPredicate('GEOSRelatePattern', argtypes=[GEOM_PTR, GEOM_PTR, c_char_p]) -geos_touches = BinaryPredicate('GEOSTouches') -geos_within = BinaryPredicate('GEOSWithin') +geos_contains = BinaryPredicate("GEOSContains") +geos_covers = BinaryPredicate("GEOSCovers") +geos_crosses = BinaryPredicate("GEOSCrosses") +geos_disjoint = BinaryPredicate("GEOSDisjoint") +geos_equals = BinaryPredicate("GEOSEquals") +geos_equalsexact = BinaryPredicate( + "GEOSEqualsExact", argtypes=[GEOM_PTR, GEOM_PTR, c_double] +) +geos_intersects = BinaryPredicate("GEOSIntersects") +geos_overlaps = BinaryPredicate("GEOSOverlaps") +geos_relatepattern = BinaryPredicate( + "GEOSRelatePattern", argtypes=[GEOM_PTR, GEOM_PTR, c_char_p] +) +geos_touches = BinaryPredicate("GEOSTouches") +geos_within = BinaryPredicate("GEOSWithin") diff --git a/django/contrib/gis/geos/prototypes/prepared.py b/django/contrib/gis/geos/prototypes/prepared.py index 52c31563d8..4fdfab30ed 100644 --- a/django/contrib/gis/geos/prototypes/prepared.py +++ b/django/contrib/gis/geos/prototypes/prepared.py @@ -1,13 +1,11 @@ from ctypes import c_byte -from django.contrib.gis.geos.libgeos import ( - GEOM_PTR, PREPGEOM_PTR, GEOSFuncFactory, -) +from django.contrib.gis.geos.libgeos import GEOM_PTR, PREPGEOM_PTR, GEOSFuncFactory from django.contrib.gis.geos.prototypes.errcheck import check_predicate # Prepared geometry constructor and destructors. -geos_prepare = GEOSFuncFactory('GEOSPrepare', argtypes=[GEOM_PTR], restype=PREPGEOM_PTR) -prepared_destroy = GEOSFuncFactory('GEOSPreparedGeom_destroy', argtypes=[PREPGEOM_PTR]) +geos_prepare = GEOSFuncFactory("GEOSPrepare", argtypes=[GEOM_PTR], restype=PREPGEOM_PTR) +prepared_destroy = GEOSFuncFactory("GEOSPreparedGeom_destroy", argtypes=[PREPGEOM_PTR]) # Prepared geometry binary predicate support. 
@@ -17,12 +15,12 @@ class PreparedPredicate(GEOSFuncFactory): errcheck = staticmethod(check_predicate) -prepared_contains = PreparedPredicate('GEOSPreparedContains') -prepared_contains_properly = PreparedPredicate('GEOSPreparedContainsProperly') -prepared_covers = PreparedPredicate('GEOSPreparedCovers') -prepared_crosses = PreparedPredicate('GEOSPreparedCrosses') -prepared_disjoint = PreparedPredicate('GEOSPreparedDisjoint') -prepared_intersects = PreparedPredicate('GEOSPreparedIntersects') -prepared_overlaps = PreparedPredicate('GEOSPreparedOverlaps') -prepared_touches = PreparedPredicate('GEOSPreparedTouches') -prepared_within = PreparedPredicate('GEOSPreparedWithin') +prepared_contains = PreparedPredicate("GEOSPreparedContains") +prepared_contains_properly = PreparedPredicate("GEOSPreparedContainsProperly") +prepared_covers = PreparedPredicate("GEOSPreparedCovers") +prepared_crosses = PreparedPredicate("GEOSPreparedCrosses") +prepared_disjoint = PreparedPredicate("GEOSPreparedDisjoint") +prepared_intersects = PreparedPredicate("GEOSPreparedIntersects") +prepared_overlaps = PreparedPredicate("GEOSPreparedOverlaps") +prepared_touches = PreparedPredicate("GEOSPreparedTouches") +prepared_within = PreparedPredicate("GEOSPreparedWithin") diff --git a/django/contrib/gis/geos/prototypes/threadsafe.py b/django/contrib/gis/geos/prototypes/threadsafe.py index 500f6a6019..d4f7ffb8ac 100644 --- a/django/contrib/gis/geos/prototypes/threadsafe.py +++ b/django/contrib/gis/geos/prototypes/threadsafe.py @@ -1,13 +1,12 @@ import threading from django.contrib.gis.geos.base import GEOSBase -from django.contrib.gis.geos.libgeos import ( - CONTEXT_PTR, error_h, lgeos, notice_h, -) +from django.contrib.gis.geos.libgeos import CONTEXT_PTR, error_h, lgeos, notice_h class GEOSContextHandle(GEOSBase): """Represent a GEOS context handle.""" + ptr_type = CONTEXT_PTR destructor = lgeos.finishGEOS_r @@ -31,10 +30,11 @@ class GEOSFunc: Serve as a wrapper for GEOS C Functions. 
Use thread-safe function variants when available. """ + def __init__(self, func_name): # GEOS thread-safe function signatures end with '_r' and take an # additional context handle parameter. - self.cfunc = getattr(lgeos, func_name + '_r') + self.cfunc = getattr(lgeos, func_name + "_r") # Create a reference to thread_context so it's not garbage-collected # before an attempt to call this object. self.thread_context = thread_context diff --git a/django/contrib/gis/geos/prototypes/topology.py b/django/contrib/gis/geos/prototypes/topology.py index 9ca0dce695..e61eae964a 100644 --- a/django/contrib/gis/geos/prototypes/topology.py +++ b/django/contrib/gis/geos/prototypes/topology.py @@ -6,7 +6,9 @@ from ctypes import c_double, c_int from django.contrib.gis.geos.libgeos import GEOM_PTR, GEOSFuncFactory from django.contrib.gis.geos.prototypes.errcheck import ( - check_geom, check_minus_one, check_string, + check_geom, + check_minus_one, + check_string, ) from django.contrib.gis.geos.prototypes.geom import geos_char_p @@ -19,35 +21,52 @@ class Topology(GEOSFuncFactory): # Topology Routines -geos_boundary = Topology('GEOSBoundary') -geos_buffer = Topology('GEOSBuffer', argtypes=[GEOM_PTR, c_double, c_int]) -geos_bufferwithstyle = Topology('GEOSBufferWithStyle', argtypes=[GEOM_PTR, c_double, c_int, c_int, c_int, c_double]) -geos_centroid = Topology('GEOSGetCentroid') -geos_convexhull = Topology('GEOSConvexHull') -geos_difference = Topology('GEOSDifference', argtypes=[GEOM_PTR, GEOM_PTR]) -geos_envelope = Topology('GEOSEnvelope') -geos_intersection = Topology('GEOSIntersection', argtypes=[GEOM_PTR, GEOM_PTR]) -geos_linemerge = Topology('GEOSLineMerge') -geos_pointonsurface = Topology('GEOSPointOnSurface') -geos_preservesimplify = Topology('GEOSTopologyPreserveSimplify', argtypes=[GEOM_PTR, c_double]) -geos_simplify = Topology('GEOSSimplify', argtypes=[GEOM_PTR, c_double]) -geos_symdifference = Topology('GEOSSymDifference', argtypes=[GEOM_PTR, GEOM_PTR]) -geos_union = 
Topology('GEOSUnion', argtypes=[GEOM_PTR, GEOM_PTR]) +geos_boundary = Topology("GEOSBoundary") +geos_buffer = Topology("GEOSBuffer", argtypes=[GEOM_PTR, c_double, c_int]) +geos_bufferwithstyle = Topology( + "GEOSBufferWithStyle", argtypes=[GEOM_PTR, c_double, c_int, c_int, c_int, c_double] +) +geos_centroid = Topology("GEOSGetCentroid") +geos_convexhull = Topology("GEOSConvexHull") +geos_difference = Topology("GEOSDifference", argtypes=[GEOM_PTR, GEOM_PTR]) +geos_envelope = Topology("GEOSEnvelope") +geos_intersection = Topology("GEOSIntersection", argtypes=[GEOM_PTR, GEOM_PTR]) +geos_linemerge = Topology("GEOSLineMerge") +geos_pointonsurface = Topology("GEOSPointOnSurface") +geos_preservesimplify = Topology( + "GEOSTopologyPreserveSimplify", argtypes=[GEOM_PTR, c_double] +) +geos_simplify = Topology("GEOSSimplify", argtypes=[GEOM_PTR, c_double]) +geos_symdifference = Topology("GEOSSymDifference", argtypes=[GEOM_PTR, GEOM_PTR]) +geos_union = Topology("GEOSUnion", argtypes=[GEOM_PTR, GEOM_PTR]) -geos_unary_union = GEOSFuncFactory('GEOSUnaryUnion', argtypes=[GEOM_PTR], restype=GEOM_PTR) +geos_unary_union = GEOSFuncFactory( + "GEOSUnaryUnion", argtypes=[GEOM_PTR], restype=GEOM_PTR +) # GEOSRelate returns a string, not a geometry. 
geos_relate = GEOSFuncFactory( - 'GEOSRelate', argtypes=[GEOM_PTR, GEOM_PTR], restype=geos_char_p, errcheck=check_string + "GEOSRelate", + argtypes=[GEOM_PTR, GEOM_PTR], + restype=geos_char_p, + errcheck=check_string, ) # Linear referencing routines geos_project = GEOSFuncFactory( - 'GEOSProject', argtypes=[GEOM_PTR, GEOM_PTR], restype=c_double, errcheck=check_minus_one + "GEOSProject", + argtypes=[GEOM_PTR, GEOM_PTR], + restype=c_double, + errcheck=check_minus_one, ) -geos_interpolate = Topology('GEOSInterpolate', argtypes=[GEOM_PTR, c_double]) +geos_interpolate = Topology("GEOSInterpolate", argtypes=[GEOM_PTR, c_double]) geos_project_normalized = GEOSFuncFactory( - 'GEOSProjectNormalized', argtypes=[GEOM_PTR, GEOM_PTR], restype=c_double, errcheck=check_minus_one + "GEOSProjectNormalized", + argtypes=[GEOM_PTR, GEOM_PTR], + restype=c_double, + errcheck=check_minus_one, +) +geos_interpolate_normalized = Topology( + "GEOSInterpolateNormalized", argtypes=[GEOM_PTR, c_double] ) -geos_interpolate_normalized = Topology('GEOSInterpolateNormalized', argtypes=[GEOM_PTR, c_double]) diff --git a/django/contrib/gis/management/commands/inspectdb.py b/django/contrib/gis/management/commands/inspectdb.py index 8c6f62932a..1bcc1a0793 100644 --- a/django/contrib/gis/management/commands/inspectdb.py +++ b/django/contrib/gis/management/commands/inspectdb.py @@ -1,16 +1,18 @@ -from django.core.management.commands.inspectdb import ( - Command as InspectDBCommand, -) +from django.core.management.commands.inspectdb import Command as InspectDBCommand class Command(InspectDBCommand): - db_module = 'django.contrib.gis.db' + db_module = "django.contrib.gis.db" def get_field_type(self, connection, table_name, row): - field_type, field_params, field_notes = super().get_field_type(connection, table_name, row) - if field_type == 'GeometryField': + field_type, field_params, field_notes = super().get_field_type( + connection, table_name, row + ) + if field_type == "GeometryField": # Getting a more 
specific field type and any additional parameters # from the `get_geometry_type` routine for the spatial backend. - field_type, geo_params = connection.introspection.get_geometry_type(table_name, row) + field_type, geo_params = connection.introspection.get_geometry_type( + table_name, row + ) field_params.update(geo_params) return field_type, field_params, field_notes diff --git a/django/contrib/gis/management/commands/ogrinspect.py b/django/contrib/gis/management/commands/ogrinspect.py index 12c1db392f..133cec3e60 100644 --- a/django/contrib/gis/management/commands/ogrinspect.py +++ b/django/contrib/gis/management/commands/ogrinspect.py @@ -10,6 +10,7 @@ class LayerOptionAction(argparse.Action): Custom argparse action for the `ogrinspect` `layer_key` keyword option which may be an integer or a string. """ + def __call__(self, parser, namespace, value, option_string=None): try: setattr(namespace, self.dest, int(value)) @@ -23,80 +24,92 @@ class ListOptionAction(argparse.Action): a string list. If the string is 'True'/'true' then the option value will be a boolean instead. """ + def __call__(self, parser, namespace, value, option_string=None): - if value.lower() == 'true': + if value.lower() == "true": setattr(namespace, self.dest, True) else: - setattr(namespace, self.dest, value.split(',')) + setattr(namespace, self.dest, value.split(",")) class Command(BaseCommand): help = ( - 'Inspects the given OGR-compatible data source (e.g., a shapefile) and outputs\n' - 'a GeoDjango model with the given model name. For example:\n' - ' ./manage.py ogrinspect zipcode.shp Zipcode' + "Inspects the given OGR-compatible data source (e.g., a shapefile) and outputs\n" + "a GeoDjango model with the given model name. 
For example:\n" + " ./manage.py ogrinspect zipcode.shp Zipcode" ) requires_system_checks = [] def add_arguments(self, parser): - parser.add_argument('data_source', help='Path to the data source.') - parser.add_argument('model_name', help='Name of the model to create.') + parser.add_argument("data_source", help="Path to the data source.") + parser.add_argument("model_name", help="Name of the model to create.") parser.add_argument( - '--blank', - action=ListOptionAction, default=False, - help='Use a comma separated list of OGR field names to add ' - 'the `blank=True` option to the field definition. Set to `true` ' - 'to apply to all applicable fields.', + "--blank", + action=ListOptionAction, + default=False, + help="Use a comma separated list of OGR field names to add " + "the `blank=True` option to the field definition. Set to `true` " + "to apply to all applicable fields.", ) parser.add_argument( - '--decimal', - action=ListOptionAction, default=False, - help='Use a comma separated list of OGR float fields to ' - 'generate `DecimalField` instead of the default ' - '`FloatField`. Set to `true` to apply to all OGR float fields.', + "--decimal", + action=ListOptionAction, + default=False, + help="Use a comma separated list of OGR float fields to " + "generate `DecimalField` instead of the default " + "`FloatField`. Set to `true` to apply to all OGR float fields.", ) parser.add_argument( - '--geom-name', default='geom', - help='Specifies the model name for the Geometry Field (defaults to `geom`)' + "--geom-name", + default="geom", + help="Specifies the model name for the Geometry Field (defaults to `geom`)", ) parser.add_argument( - '--layer', dest='layer_key', - action=LayerOptionAction, default=0, - help='The key for specifying which layer in the OGR data ' - 'source to use. Defaults to 0 (the first layer). 
May be ' - 'an integer or a string identifier for the layer.', + "--layer", + dest="layer_key", + action=LayerOptionAction, + default=0, + help="The key for specifying which layer in the OGR data " + "source to use. Defaults to 0 (the first layer). May be " + "an integer or a string identifier for the layer.", ) parser.add_argument( - '--multi-geom', action='store_true', - help='Treat the geometry in the data source as a geometry collection.', + "--multi-geom", + action="store_true", + help="Treat the geometry in the data source as a geometry collection.", ) parser.add_argument( - '--name-field', - help='Specifies a field name to return for the __str__() method.', + "--name-field", + help="Specifies a field name to return for the __str__() method.", ) parser.add_argument( - '--no-imports', action='store_false', dest='imports', - help='Do not include `from django.contrib.gis.db import models` statement.', + "--no-imports", + action="store_false", + dest="imports", + help="Do not include `from django.contrib.gis.db import models` statement.", ) parser.add_argument( - '--null', action=ListOptionAction, default=False, - help='Use a comma separated list of OGR field names to add ' - 'the `null=True` option to the field definition. Set to `true` ' - 'to apply to all applicable fields.', + "--null", + action=ListOptionAction, + default=False, + help="Use a comma separated list of OGR field names to add " + "the `null=True` option to the field definition. Set to `true` " + "to apply to all applicable fields.", ) parser.add_argument( - '--srid', - help='The SRID to use for the Geometry Field. If it can be ' - 'determined, the SRID of the data source is used.', + "--srid", + help="The SRID to use for the Geometry Field. 
If it can be " + "determined, the SRID of the data source is used.", ) parser.add_argument( - '--mapping', action='store_true', - help='Generate mapping dictionary for use with `LayerMapping`.', + "--mapping", + action="store_true", + help="Generate mapping dictionary for use with `LayerMapping`.", ) def handle(self, *args, **options): - data_source, model_name = options.pop('data_source'), options.pop('model_name') + data_source, model_name = options.pop("data_source"), options.pop("model_name") # Getting the OGR DataSource from the string parameter. try: @@ -109,26 +122,43 @@ class Command(BaseCommand): from django.contrib.gis.utils.ogrinspect import _ogrinspect, mapping # Filter options to params accepted by `_ogrinspect` - ogr_options = {k: v for k, v in options.items() - if k in get_func_args(_ogrinspect) and v is not None} + ogr_options = { + k: v + for k, v in options.items() + if k in get_func_args(_ogrinspect) and v is not None + } output = [s for s in _ogrinspect(ds, model_name, **ogr_options)] - if options['mapping']: + if options["mapping"]: # Constructing the keyword arguments for `mapping`, and # calling it on the data source. kwargs = { - 'geom_name': options['geom_name'], - 'layer_key': options['layer_key'], - 'multi_geom': options['multi_geom'], + "geom_name": options["geom_name"], + "layer_key": options["layer_key"], + "multi_geom": options["multi_geom"], } mapping_dict = mapping(ds, **kwargs) # This extra legwork is so that the dictionary definition comes # out in the same order as the fields in the model definition. 
rev_mapping = {v: k for k, v in mapping_dict.items()} - output.extend(['', '', '# Auto-generated `LayerMapping` dictionary for %s model' % model_name, - '%s_mapping = {' % model_name.lower()]) - output.extend(" '%s': '%s'," % ( - rev_mapping[ogr_fld], ogr_fld) for ogr_fld in ds[options['layer_key']].fields + output.extend( + [ + "", + "", + "# Auto-generated `LayerMapping` dictionary for %s model" + % model_name, + "%s_mapping = {" % model_name.lower(), + ] ) - output.extend([" '%s': '%s'," % (options['geom_name'], mapping_dict[options['geom_name']]), '}']) - return '\n'.join(output) + output.extend( + " '%s': '%s'," % (rev_mapping[ogr_fld], ogr_fld) + for ogr_fld in ds[options["layer_key"]].fields + ) + output.extend( + [ + " '%s': '%s'," + % (options["geom_name"], mapping_dict[options["geom_name"]]), + "}", + ] + ) + return "\n".join(output) diff --git a/django/contrib/gis/measure.py b/django/contrib/gis/measure.py index 7e571f14dd..640f0e0068 100644 --- a/django/contrib/gis/measure.py +++ b/django/contrib/gis/measure.py @@ -38,7 +38,7 @@ and Geoff Biggs' PhD work on dimensioned units for robotics. 
from decimal import Decimal from functools import total_ordering -__all__ = ['A', 'Area', 'D', 'Distance'] +__all__ = ["A", "Area", "D", "Distance"] NUMERIC_TYPES = (int, float, Decimal) AREA_PREFIX = "sq_" @@ -73,13 +73,17 @@ class MeasureBase: if name in self.UNITS: return self.standard / self.UNITS[name] else: - raise AttributeError('Unknown unit type: %s' % name) + raise AttributeError("Unknown unit type: %s" % name) def __repr__(self): - return '%s(%s=%s)' % (pretty_name(self), self._default_unit, getattr(self, self._default_unit)) + return "%s(%s=%s)" % ( + pretty_name(self), + self._default_unit, + getattr(self, self._default_unit), + ) def __str__(self): - return '%s %s' % (getattr(self, self._default_unit), self._default_unit) + return "%s %s" % (getattr(self, self._default_unit), self._default_unit) # **** Comparison methods **** @@ -104,49 +108,65 @@ class MeasureBase: if isinstance(other, self.__class__): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard + other.standard)} + **{self.STANDARD_UNIT: (self.standard + other.standard)}, ) else: - raise TypeError('%(class)s must be added with %(class)s' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be added with %(class)s" % {"class": pretty_name(self)} + ) def __iadd__(self, other): if isinstance(other, self.__class__): self.standard += other.standard return self else: - raise TypeError('%(class)s must be added with %(class)s' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be added with %(class)s" % {"class": pretty_name(self)} + ) def __sub__(self, other): if isinstance(other, self.__class__): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard - other.standard)} + **{self.STANDARD_UNIT: (self.standard - other.standard)}, ) else: - raise TypeError('%(class)s must be subtracted from %(class)s' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be subtracted 
from %(class)s" + % {"class": pretty_name(self)} + ) def __isub__(self, other): if isinstance(other, self.__class__): self.standard -= other.standard return self else: - raise TypeError('%(class)s must be subtracted from %(class)s' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be subtracted from %(class)s" + % {"class": pretty_name(self)} + ) def __mul__(self, other): if isinstance(other, NUMERIC_TYPES): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard * other)} + **{self.STANDARD_UNIT: (self.standard * other)}, ) else: - raise TypeError('%(class)s must be multiplied with number' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be multiplied with number" + % {"class": pretty_name(self)} + ) def __imul__(self, other): if isinstance(other, NUMERIC_TYPES): self.standard *= float(other) return self else: - raise TypeError('%(class)s must be multiplied with number' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be multiplied with number" + % {"class": pretty_name(self)} + ) def __rmul__(self, other): return self * other @@ -157,17 +177,22 @@ class MeasureBase: if isinstance(other, NUMERIC_TYPES): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard / other)} + **{self.STANDARD_UNIT: (self.standard / other)}, ) else: - raise TypeError('%(class)s must be divided with number or %(class)s' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be divided with number or %(class)s" + % {"class": pretty_name(self)} + ) def __itruediv__(self, other): if isinstance(other, NUMERIC_TYPES): self.standard /= float(other) return self else: - raise TypeError('%(class)s must be divided with number' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be divided with number" % {"class": pretty_name(self)} + ) def __bool__(self): return bool(self.standard) @@ -199,7 +224,7 @@ class MeasureBase: val += 
self.UNITS[u] * value default_unit = u else: - raise AttributeError('Unknown unit type: %s' % unit) + raise AttributeError("Unknown unit type: %s" % unit) return val, default_unit @classmethod @@ -217,85 +242,87 @@ class MeasureBase: elif lower in cls.LALIAS: return cls.LALIAS[lower] else: - raise Exception('Could not find a unit keyword associated with "%s"' % unit_str) + raise Exception( + 'Could not find a unit keyword associated with "%s"' % unit_str + ) class Distance(MeasureBase): STANDARD_UNIT = "m" UNITS = { - 'chain': 20.1168, - 'chain_benoit': 20.116782, - 'chain_sears': 20.1167645, - 'british_chain_benoit': 20.1167824944, - 'british_chain_sears': 20.1167651216, - 'british_chain_sears_truncated': 20.116756, - 'cm': 0.01, - 'british_ft': 0.304799471539, - 'british_yd': 0.914398414616, - 'clarke_ft': 0.3047972654, - 'clarke_link': 0.201166195164, - 'fathom': 1.8288, - 'ft': 0.3048, - 'furlong': 201.168, - 'german_m': 1.0000135965, - 'gold_coast_ft': 0.304799710181508, - 'indian_yd': 0.914398530744, - 'inch': 0.0254, - 'km': 1000.0, - 'link': 0.201168, - 'link_benoit': 0.20116782, - 'link_sears': 0.20116765, - 'm': 1.0, - 'mi': 1609.344, - 'mm': 0.001, - 'nm': 1852.0, - 'nm_uk': 1853.184, - 'rod': 5.0292, - 'sears_yd': 0.91439841, - 'survey_ft': 0.304800609601, - 'um': 0.000001, - 'yd': 0.9144, + "chain": 20.1168, + "chain_benoit": 20.116782, + "chain_sears": 20.1167645, + "british_chain_benoit": 20.1167824944, + "british_chain_sears": 20.1167651216, + "british_chain_sears_truncated": 20.116756, + "cm": 0.01, + "british_ft": 0.304799471539, + "british_yd": 0.914398414616, + "clarke_ft": 0.3047972654, + "clarke_link": 0.201166195164, + "fathom": 1.8288, + "ft": 0.3048, + "furlong": 201.168, + "german_m": 1.0000135965, + "gold_coast_ft": 0.304799710181508, + "indian_yd": 0.914398530744, + "inch": 0.0254, + "km": 1000.0, + "link": 0.201168, + "link_benoit": 0.20116782, + "link_sears": 0.20116765, + "m": 1.0, + "mi": 1609.344, + "mm": 0.001, + "nm": 1852.0, + 
"nm_uk": 1853.184, + "rod": 5.0292, + "sears_yd": 0.91439841, + "survey_ft": 0.304800609601, + "um": 0.000001, + "yd": 0.9144, } # Unit aliases for `UNIT` terms encountered in Spatial Reference WKT. ALIAS = { - 'centimeter': 'cm', - 'foot': 'ft', - 'inches': 'inch', - 'kilometer': 'km', - 'kilometre': 'km', - 'meter': 'm', - 'metre': 'm', - 'micrometer': 'um', - 'micrometre': 'um', - 'millimeter': 'mm', - 'millimetre': 'mm', - 'mile': 'mi', - 'yard': 'yd', - 'British chain (Benoit 1895 B)': 'british_chain_benoit', - 'British chain (Sears 1922)': 'british_chain_sears', - 'British chain (Sears 1922 truncated)': 'british_chain_sears_truncated', - 'British foot (Sears 1922)': 'british_ft', - 'British foot': 'british_ft', - 'British yard (Sears 1922)': 'british_yd', - 'British yard': 'british_yd', - "Clarke's Foot": 'clarke_ft', - "Clarke's link": 'clarke_link', - 'Chain (Benoit)': 'chain_benoit', - 'Chain (Sears)': 'chain_sears', - 'Foot (International)': 'ft', - 'Furrow Long': 'furlong', - 'German legal metre': 'german_m', - 'Gold Coast foot': 'gold_coast_ft', - 'Indian yard': 'indian_yd', - 'Link (Benoit)': 'link_benoit', - 'Link (Sears)': 'link_sears', - 'Nautical Mile': 'nm', - 'Nautical Mile (UK)': 'nm_uk', - 'US survey foot': 'survey_ft', - 'U.S. 
Foot': 'survey_ft', - 'Yard (Indian)': 'indian_yd', - 'Yard (Sears)': 'sears_yd' + "centimeter": "cm", + "foot": "ft", + "inches": "inch", + "kilometer": "km", + "kilometre": "km", + "meter": "m", + "metre": "m", + "micrometer": "um", + "micrometre": "um", + "millimeter": "mm", + "millimetre": "mm", + "mile": "mi", + "yard": "yd", + "British chain (Benoit 1895 B)": "british_chain_benoit", + "British chain (Sears 1922)": "british_chain_sears", + "British chain (Sears 1922 truncated)": "british_chain_sears_truncated", + "British foot (Sears 1922)": "british_ft", + "British foot": "british_ft", + "British yard (Sears 1922)": "british_yd", + "British yard": "british_yd", + "Clarke's Foot": "clarke_ft", + "Clarke's link": "clarke_link", + "Chain (Benoit)": "chain_benoit", + "Chain (Sears)": "chain_sears", + "Foot (International)": "ft", + "Furrow Long": "furlong", + "German legal metre": "german_m", + "Gold Coast foot": "gold_coast_ft", + "Indian yard": "indian_yd", + "Link (Benoit)": "link_benoit", + "Link (Sears)": "link_sears", + "Nautical Mile": "nm", + "Nautical Mile (UK)": "nm_uk", + "US survey foot": "survey_ft", + "U.S. 
Foot": "survey_ft", + "Yard (Indian)": "indian_yd", + "Yard (Sears)": "sears_yd", } LALIAS = {k.lower(): v for k, v in ALIAS.items()} @@ -303,34 +330,39 @@ class Distance(MeasureBase): if isinstance(other, self.__class__): return Area( default_unit=AREA_PREFIX + self._default_unit, - **{AREA_PREFIX + self.STANDARD_UNIT: (self.standard * other.standard)} + **{AREA_PREFIX + self.STANDARD_UNIT: (self.standard * other.standard)}, ) elif isinstance(other, NUMERIC_TYPES): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard * other)} + **{self.STANDARD_UNIT: (self.standard * other)}, ) else: - raise TypeError('%(distance)s must be multiplied with number or %(distance)s' % { - "distance": pretty_name(self.__class__), - }) + raise TypeError( + "%(distance)s must be multiplied with number or %(distance)s" + % { + "distance": pretty_name(self.__class__), + } + ) class Area(MeasureBase): STANDARD_UNIT = AREA_PREFIX + Distance.STANDARD_UNIT # Getting the square units values and the alias dictionary. 
- UNITS = {'%s%s' % (AREA_PREFIX, k): v ** 2 for k, v in Distance.UNITS.items()} - ALIAS = {k: '%s%s' % (AREA_PREFIX, v) for k, v in Distance.ALIAS.items()} + UNITS = {"%s%s" % (AREA_PREFIX, k): v**2 for k, v in Distance.UNITS.items()} + ALIAS = {k: "%s%s" % (AREA_PREFIX, v) for k, v in Distance.ALIAS.items()} LALIAS = {k.lower(): v for k, v in ALIAS.items()} def __truediv__(self, other): if isinstance(other, NUMERIC_TYPES): return self.__class__( default_unit=self._default_unit, - **{self.STANDARD_UNIT: (self.standard / other)} + **{self.STANDARD_UNIT: (self.standard / other)}, ) else: - raise TypeError('%(class)s must be divided by a number' % {"class": pretty_name(self)}) + raise TypeError( + "%(class)s must be divided by a number" % {"class": pretty_name(self)} + ) # Shortcuts diff --git a/django/contrib/gis/ptr.py b/django/contrib/gis/ptr.py index a5a117a19a..6754701613 100644 --- a/django/contrib/gis/ptr.py +++ b/django/contrib/gis/ptr.py @@ -6,6 +6,7 @@ class CPointerBase: Base class for objects that have a pointer access property that controls access to the underlying C pointer. """ + _ptr = None # Initially the pointer is NULL. ptr_type = c_void_p destructor = None @@ -17,14 +18,16 @@ class CPointerBase: # aren't passed to routines -- that's very bad. if self._ptr: return self._ptr - raise self.null_ptr_exception_class('NULL %s pointer encountered.' % self.__class__.__name__) + raise self.null_ptr_exception_class( + "NULL %s pointer encountered." % self.__class__.__name__ + ) @ptr.setter def ptr(self, ptr): # Only allow the pointer to be set with pointers of the compatible # type or None (NULL). if not (ptr is None or isinstance(ptr, self.ptr_type)): - raise TypeError('Incompatible pointer type: %s.' % type(ptr)) + raise TypeError("Incompatible pointer type: %s." 
% type(ptr)) self._ptr = ptr def __del__(self): diff --git a/django/contrib/gis/serializers/geojson.py b/django/contrib/gis/serializers/geojson.py index 0e4f744774..f90d3544e7 100644 --- a/django/contrib/gis/serializers/geojson.py +++ b/django/contrib/gis/serializers/geojson.py @@ -9,12 +9,16 @@ class Serializer(JSONSerializer): """ Convert a queryset to GeoJSON, http://geojson.org/ """ + def _init_options(self): super()._init_options() - self.geometry_field = self.json_kwargs.pop('geometry_field', None) - self.srid = self.json_kwargs.pop('srid', 4326) - if (self.selected_fields is not None and self.geometry_field is not None and - self.geometry_field not in self.selected_fields): + self.geometry_field = self.json_kwargs.pop("geometry_field", None) + self.srid = self.json_kwargs.pop("srid", 4326) + if ( + self.selected_fields is not None + and self.geometry_field is not None + and self.geometry_field not in self.selected_fields + ): self.selected_fields = [*self.selected_fields, self.geometry_field] def start_serialization(self): @@ -22,10 +26,11 @@ class Serializer(JSONSerializer): self._cts = {} # cache of CoordTransform's self.stream.write( '{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "EPSG:%d"}},' - ' "features": [' % self.srid) + ' "features": [' % self.srid + ) def end_serialization(self): - self.stream.write(']}') + self.stream.write("]}") def start_object(self, obj): super().start_object(obj) @@ -33,7 +38,7 @@ class Serializer(JSONSerializer): if self.geometry_field is None: # Find the first declared geometry field for field in obj._meta.fields: - if hasattr(field, 'geom_type'): + if hasattr(field, "geom_type"): self.geometry_field = field.name break @@ -42,15 +47,18 @@ class Serializer(JSONSerializer): "type": "Feature", "properties": self._current, } - if ((self.selected_fields is None or 'pk' in self.selected_fields) and - 'pk' not in data["properties"]): + if ( + self.selected_fields is None or "pk" in 
self.selected_fields + ) and "pk" not in data["properties"]: data["properties"]["pk"] = obj._meta.pk.value_to_string(obj) if self._geometry: if self._geometry.srid != self.srid: # If needed, transform the geometry in the srid of the global geojson srid if self._geometry.srid not in self._cts: srs = SpatialReference(self.srid) - self._cts[self._geometry.srid] = CoordTransform(self._geometry.srs, srs) + self._cts[self._geometry.srid] = CoordTransform( + self._geometry.srs, srs + ) self._geometry.transform(self._cts[self._geometry.srid]) data["geometry"] = json.loads(self._geometry.geojson) else: diff --git a/django/contrib/gis/shortcuts.py b/django/contrib/gis/shortcuts.py index aa68140b62..33ff2acb06 100644 --- a/django/contrib/gis/shortcuts.py +++ b/django/contrib/gis/shortcuts.py @@ -15,8 +15,8 @@ except ImportError: def compress_kml(kml): "Return compressed KMZ from the given KML string." kmz = BytesIO() - with zipfile.ZipFile(kmz, 'a', zipfile.ZIP_DEFLATED) as zf: - zf.writestr('doc.kml', kml.encode(settings.DEFAULT_CHARSET)) + with zipfile.ZipFile(kmz, "a", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("doc.kml", kml.encode(settings.DEFAULT_CHARSET)) kmz.seek(0) return kmz.read() @@ -25,7 +25,7 @@ def render_to_kml(*args, **kwargs): "Render the response as KML (using the correct MIME type)." return HttpResponse( loader.render_to_string(*args, **kwargs), - content_type='application/vnd.google-earth.kml+xml', + content_type="application/vnd.google-earth.kml+xml", ) @@ -36,5 +36,5 @@ def render_to_kmz(*args, **kwargs): """ return HttpResponse( compress_kml(loader.render_to_string(*args, **kwargs)), - content_type='application/vnd.google-earth.kmz', + content_type="application/vnd.google-earth.kmz", ) diff --git a/django/contrib/gis/sitemaps/__init__.py b/django/contrib/gis/sitemaps/__init__.py index 7337e65a9c..3654bfc80b 100644 --- a/django/contrib/gis/sitemaps/__init__.py +++ b/django/contrib/gis/sitemaps/__init__.py @@ -1,4 +1,4 @@ # Geo-enabled Sitemap classes. 
from django.contrib.gis.sitemaps.kml import KMLSitemap, KMZSitemap -__all__ = ['KMLSitemap', 'KMZSitemap'] +__all__ = ["KMLSitemap", "KMZSitemap"] diff --git a/django/contrib/gis/sitemaps/kml.py b/django/contrib/gis/sitemaps/kml.py index 103dc20408..5ec089d6f5 100644 --- a/django/contrib/gis/sitemaps/kml.py +++ b/django/contrib/gis/sitemaps/kml.py @@ -9,7 +9,8 @@ class KMLSitemap(Sitemap): """ A minimal hook to produce KML sitemaps. """ - geo_format = 'kml' + + geo_format = "kml" def __init__(self, locations=None): # If no locations specified, then we try to build for @@ -31,15 +32,21 @@ class KMLSitemap(Sitemap): if isinstance(source, models.base.ModelBase): for field in source._meta.fields: if isinstance(field, GeometryField): - kml_sources.append((source._meta.app_label, - source._meta.model_name, - field.name)) + kml_sources.append( + ( + source._meta.app_label, + source._meta.model_name, + field.name, + ) + ) elif isinstance(source, (list, tuple)): if len(source) != 3: - raise ValueError('Must specify a 3-tuple of (app_label, module_name, field_name).') + raise ValueError( + "Must specify a 3-tuple of (app_label, module_name, field_name)." 
+ ) kml_sources.append(source) else: - raise TypeError('KML Sources must be a model or a 3-tuple.') + raise TypeError("KML Sources must be a model or a 3-tuple.") return kml_sources def get_urls(self, page=1, site=None, protocol=None): @@ -49,7 +56,7 @@ class KMLSitemap(Sitemap): """ urls = Sitemap.get_urls(self, page=page, site=site, protocol=protocol) for url in urls: - url['geo_format'] = self.geo_format + url["geo_format"] = self.geo_format return urls def items(self): @@ -57,14 +64,14 @@ class KMLSitemap(Sitemap): def location(self, obj): return reverse( - 'django.contrib.gis.sitemaps.views.%s' % self.geo_format, + "django.contrib.gis.sitemaps.views.%s" % self.geo_format, kwargs={ - 'label': obj[0], - 'model': obj[1], - 'field_name': obj[2], + "label": obj[0], + "model": obj[1], + "field_name": obj[2], }, ) class KMZSitemap(KMLSitemap): - geo_format = 'kmz' + geo_format = "kmz" diff --git a/django/contrib/gis/sitemaps/views.py b/django/contrib/gis/sitemaps/views.py index be91d929f6..17eb54e60c 100644 --- a/django/contrib/gis/sitemaps/views.py +++ b/django/contrib/gis/sitemaps/views.py @@ -17,7 +17,10 @@ def kml(request, label, model, field_name=None, compress=False, using=DEFAULT_DB try: klass = apps.get_model(label, model) except LookupError: - raise Http404('You must supply a valid app label and module name. Got "%s.%s"' % (label, model)) + raise Http404( + 'You must supply a valid app label and module name. 
Got "%s.%s"' + % (label, model) + ) if field_name: try: @@ -25,7 +28,7 @@ def kml(request, label, model, field_name=None, compress=False, using=DEFAULT_DB if not isinstance(field, GeometryField): raise FieldDoesNotExist except FieldDoesNotExist: - raise Http404('Invalid geometry field.') + raise Http404("Invalid geometry field.") connection = connections[using] @@ -38,8 +41,9 @@ def kml(request, label, model, field_name=None, compress=False, using=DEFAULT_DB placemarks = [] if connection.features.has_Transform_function: qs = klass._default_manager.using(using).annotate( - **{'%s_4326' % field_name: Transform(field_name, 4326)}) - field_name += '_4326' + **{"%s_4326" % field_name: Transform(field_name, 4326)} + ) + field_name += "_4326" else: qs = klass._default_manager.using(using).all() for mod in qs: @@ -51,7 +55,7 @@ def kml(request, label, model, field_name=None, compress=False, using=DEFAULT_DB render = render_to_kmz else: render = render_to_kml - return render('gis/kml/placemarks.kml', {'places': placemarks}) + return render("gis/kml/placemarks.kml", {"places": placemarks}) def kmz(request, label, model, field_name=None, using=DEFAULT_DB_ALIAS): diff --git a/django/contrib/gis/utils/__init__.py b/django/contrib/gis/utils/__init__.py index c195ded932..26334fb6a4 100644 --- a/django/contrib/gis/utils/__init__.py +++ b/django/contrib/gis/utils/__init__.py @@ -10,7 +10,8 @@ try: # LayerMapping requires DJANGO_SETTINGS_MODULE to be set, # and ImproperlyConfigured is raised if that's not the case. 
from django.contrib.gis.utils.layermapping import ( # NOQA - LayerMapError, LayerMapping, + LayerMapError, + LayerMapping, ) except ImproperlyConfigured: pass diff --git a/django/contrib/gis/utils/layermapping.py b/django/contrib/gis/utils/layermapping.py index 3c78085489..bad4c05fb8 100644 --- a/django/contrib/gis/utils/layermapping.py +++ b/django/contrib/gis/utils/layermapping.py @@ -7,16 +7,26 @@ https://docs.djangoproject.com/en/dev/ref/contrib/gis/layermapping/ """ import sys -from decimal import Decimal, InvalidOperation as DecimalInvalidOperation +from decimal import Decimal +from decimal import InvalidOperation as DecimalInvalidOperation from pathlib import Path from django.contrib.gis.db.models import GeometryField from django.contrib.gis.gdal import ( - CoordTransform, DataSource, GDALException, OGRGeometry, OGRGeomType, + CoordTransform, + DataSource, + GDALException, + OGRGeometry, + OGRGeomType, SpatialReference, ) from django.contrib.gis.gdal.field import ( - OFTDate, OFTDateTime, OFTInteger, OFTInteger64, OFTReal, OFTString, + OFTDate, + OFTDateTime, + OFTInteger, + OFTInteger64, + OFTReal, + OFTString, OFTTime, ) from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist @@ -50,12 +60,12 @@ class LayerMapping: # Acceptable 'base' types for a multi-geometry type. 
MULTI_TYPES = { - 1: OGRGeomType('MultiPoint'), - 2: OGRGeomType('MultiLineString'), - 3: OGRGeomType('MultiPolygon'), - OGRGeomType('Point25D').num: OGRGeomType('MultiPoint25D'), - OGRGeomType('LineString25D').num: OGRGeomType('MultiLineString25D'), - OGRGeomType('Polygon25D').num: OGRGeomType('MultiPolygon25D'), + 1: OGRGeomType("MultiPoint"), + 2: OGRGeomType("MultiLineString"), + 3: OGRGeomType("MultiPolygon"), + OGRGeomType("Point25D").num: OGRGeomType("MultiPoint25D"), + OGRGeomType("LineString25D").num: OGRGeomType("MultiLineString25D"), + OGRGeomType("Polygon25D").num: OGRGeomType("MultiPolygon25D"), } # Acceptable Django field types and corresponding acceptable OGR # counterparts. @@ -83,10 +93,19 @@ class LayerMapping: models.PositiveSmallIntegerField: (OFTInteger, OFTReal, OFTString), } - def __init__(self, model, data, mapping, layer=0, - source_srs=None, encoding='utf-8', - transaction_mode='commit_on_success', - transform=True, unique=None, using=None): + def __init__( + self, + model, + data, + mapping, + layer=0, + source_srs=None, + encoding="utf-8", + transaction_mode="commit_on_success", + transform=True, + unique=None, + using=None, + ): """ A LayerMapping object is initialized using the given Model (not an instance), a DataSource (or string path to an OGR-supported data file), and a mapping @@ -133,6 +152,7 @@ class LayerMapping: # Making sure the encoding exists, if not a LookupError # exception will be thrown. from codecs import lookup + lookup(encoding) self.encoding = encoding else: @@ -140,7 +160,7 @@ class LayerMapping: if unique: self.check_unique(unique) - transaction_mode = 'autocommit' # Has to be set to autocommit. + transaction_mode = "autocommit" # Has to be set to autocommit. self.unique = unique else: self.unique = None @@ -148,12 +168,12 @@ class LayerMapping: # Setting the transaction decorator with the function in the # transaction modes dictionary. 
self.transaction_mode = transaction_mode - if transaction_mode == 'autocommit': + if transaction_mode == "autocommit": self.transaction_decorator = None - elif transaction_mode == 'commit_on_success': + elif transaction_mode == "commit_on_success": self.transaction_decorator = transaction.atomic else: - raise LayerMapError('Unrecognized transaction mode: %s' % transaction_mode) + raise LayerMapError("Unrecognized transaction mode: %s" % transaction_mode) # #### Checking routines used during initialization #### def check_fid_range(self, fid_range): @@ -190,7 +210,9 @@ class LayerMapping: try: idx = ogr_fields.index(ogr_map_fld) except ValueError: - raise LayerMapError('Given mapping OGR field "%s" not found in OGR Layer.' % ogr_map_fld) + raise LayerMapError( + 'Given mapping OGR field "%s" not found in OGR Layer.' % ogr_map_fld + ) return idx # No need to increment through each feature in the model, simply check @@ -201,32 +223,43 @@ class LayerMapping: try: model_field = self.model._meta.get_field(field_name) except FieldDoesNotExist: - raise LayerMapError('Given mapping field "%s" not in given Model fields.' % field_name) + raise LayerMapError( + 'Given mapping field "%s" not in given Model fields.' % field_name + ) # Getting the string name for the Django field class (e.g., 'PointField'). fld_name = model_field.__class__.__name__ if isinstance(model_field, GeometryField): if self.geom_field: - raise LayerMapError('LayerMapping does not support more than one GeometryField per model.') + raise LayerMapError( + "LayerMapping does not support more than one GeometryField per model." + ) # Getting the coordinate dimension of the geometry field. coord_dim = model_field.dim try: if coord_dim == 3: - gtype = OGRGeomType(ogr_name + '25D') + gtype = OGRGeomType(ogr_name + "25D") else: gtype = OGRGeomType(ogr_name) except GDALException: - raise LayerMapError('Invalid mapping for GeometryField "%s".' 
% field_name) + raise LayerMapError( + 'Invalid mapping for GeometryField "%s".' % field_name + ) # Making sure that the OGR Layer's Geometry is compatible. ltype = self.layer.geom_type - if not (ltype.name.startswith(gtype.name) or self.make_multi(ltype, model_field)): - raise LayerMapError('Invalid mapping geometry; model has %s%s, ' - 'layer geometry type is %s.' % - (fld_name, '(dim=3)' if coord_dim == 3 else '', ltype)) + if not ( + ltype.name.startswith(gtype.name) + or self.make_multi(ltype, model_field) + ): + raise LayerMapError( + "Invalid mapping geometry; model has %s%s, " + "layer geometry type is %s." + % (fld_name, "(dim=3)" if coord_dim == 3 else "", ltype) + ) # Setting the `geom_field` attribute w/the name of the model field # that is a Geometry. Also setting the coordinate dimension @@ -243,15 +276,19 @@ class LayerMapping: try: rel_model._meta.get_field(rel_name) except FieldDoesNotExist: - raise LayerMapError('ForeignKey mapping field "%s" not in %s fields.' % - (rel_name, rel_model.__class__.__name__)) + raise LayerMapError( + 'ForeignKey mapping field "%s" not in %s fields.' + % (rel_name, rel_model.__class__.__name__) + ) fields_val = rel_model else: - raise TypeError('ForeignKey mapping must be of dictionary type.') + raise TypeError("ForeignKey mapping must be of dictionary type.") else: # Is the model field type supported by LayerMapping? if model_field.__class__ not in self.FIELD_TYPES: - raise LayerMapError('Django field type "%s" has no OGR mapping (yet).' % fld_name) + raise LayerMapError( + 'Django field type "%s" has no OGR mapping (yet).' % fld_name + ) # Is the OGR field in the Layer? idx = check_ogr_fld(ogr_name) @@ -259,8 +296,10 @@ class LayerMapping: # Can the OGR field type be mapped to the Django field type? if not issubclass(ogr_field, self.FIELD_TYPES[model_field.__class__]): - raise LayerMapError('OGR field "%s" (of type %s) cannot be mapped to Django %s.' 
% - (ogr_field, ogr_field.__name__, fld_name)) + raise LayerMapError( + 'OGR field "%s" (of type %s) cannot be mapped to Django %s.' + % (ogr_field, ogr_field.__name__, fld_name) + ) fields_val = model_field self.fields[field_name] = fields_val @@ -279,7 +318,7 @@ class LayerMapping: sr = self.layer.srs if not sr: - raise LayerMapError('No source reference system defined.') + raise LayerMapError("No source reference system defined.") else: return sr @@ -295,7 +334,9 @@ class LayerMapping: if unique not in self.mapping: raise ValueError else: - raise TypeError('Unique keyword argument must be set with a tuple, list, or string.') + raise TypeError( + "Unique keyword argument must be set with a tuple, list, or string." + ) # Keyword argument retrieval routines #### def feature_kwargs(self, feat): @@ -316,7 +357,7 @@ class LayerMapping: try: val = self.verify_geom(feat.geom, model_field) except GDALException: - raise LayerMapError('Could not retrieve geometry from feature.') + raise LayerMapError("Could not retrieve geometry from feature.") elif isinstance(model_field, models.base.ModelBase): # The related _model_, not a field was passed in -- indicating # another mapping for the related Model. @@ -348,23 +389,34 @@ class LayerMapping: Verify if the OGR Field contents are acceptable to the model field. If they are, return the verified value, otherwise raise an exception. """ - if (isinstance(ogr_field, OFTString) and - isinstance(model_field, (models.CharField, models.TextField))): + if isinstance(ogr_field, OFTString) and isinstance( + model_field, (models.CharField, models.TextField) + ): if self.encoding and ogr_field.value is not None: # The encoding for OGR data sources may be specified here # (e.g., 'cp437' for Census Bureau boundary files). 
val = force_str(ogr_field.value, self.encoding) else: val = ogr_field.value - if model_field.max_length and val is not None and len(val) > model_field.max_length: - raise InvalidString('%s model field maximum string length is %s, given %s characters.' % - (model_field.name, model_field.max_length, len(val))) - elif isinstance(ogr_field, OFTReal) and isinstance(model_field, models.DecimalField): + if ( + model_field.max_length + and val is not None + and len(val) > model_field.max_length + ): + raise InvalidString( + "%s model field maximum string length is %s, given %s characters." + % (model_field.name, model_field.max_length, len(val)) + ) + elif isinstance(ogr_field, OFTReal) and isinstance( + model_field, models.DecimalField + ): try: # Creating an instance of the Decimal value to use. d = Decimal(str(ogr_field.value)) except DecimalInvalidOperation: - raise InvalidDecimal('Could not construct decimal from: %s' % ogr_field.value) + raise InvalidDecimal( + "Could not construct decimal from: %s" % ogr_field.value + ) # Getting the decimal value as a tuple. dtup = d.as_tuple() @@ -385,17 +437,21 @@ class LayerMapping: # InvalidDecimal exception. if n_prec > max_prec: raise InvalidDecimal( - 'A DecimalField with max_digits %d, decimal_places %d must ' - 'round to an absolute value less than 10^%d.' % - (model_field.max_digits, model_field.decimal_places, max_prec) + "A DecimalField with max_digits %d, decimal_places %d must " + "round to an absolute value less than 10^%d." + % (model_field.max_digits, model_field.decimal_places, max_prec) ) val = d - elif isinstance(ogr_field, (OFTReal, OFTString)) and isinstance(model_field, models.IntegerField): + elif isinstance(ogr_field, (OFTReal, OFTString)) and isinstance( + model_field, models.IntegerField + ): # Attempt to convert any OFTReal and OFTString value to an OFTInteger. 
try: val = int(ogr_field.value) except ValueError: - raise InvalidInteger('Could not construct integer from: %s' % ogr_field.value) + raise InvalidInteger( + "Could not construct integer from: %s" % ogr_field.value + ) else: val = ogr_field.value return val @@ -412,15 +468,17 @@ class LayerMapping: # Constructing and verifying the related model keyword arguments. fk_kwargs = {} for field_name, ogr_name in rel_mapping.items(): - fk_kwargs[field_name] = self.verify_ogr_field(feat[ogr_name], rel_model._meta.get_field(field_name)) + fk_kwargs[field_name] = self.verify_ogr_field( + feat[ogr_name], rel_model._meta.get_field(field_name) + ) # Attempting to retrieve and return the related model. try: return rel_model.objects.using(self.using).get(**fk_kwargs) except ObjectDoesNotExist: raise MissingForeignKey( - 'No ForeignKey %s model found with keyword arguments: %s' % - (rel_model.__name__, fk_kwargs) + "No ForeignKey %s model found with keyword arguments: %s" + % (rel_model.__name__, fk_kwargs) ) def verify_geom(self, geom, model_field): @@ -456,13 +514,17 @@ class LayerMapping: SpatialRefSys = self.spatial_backend.spatial_ref_sys() try: # Getting the target spatial reference system - target_srs = SpatialRefSys.objects.using(self.using).get(srid=self.geo_field.srid).srs + target_srs = ( + SpatialRefSys.objects.using(self.using) + .get(srid=self.geo_field.srid) + .srs + ) # Creating the CoordTransform object return CoordTransform(self.source_srs, target_srs) except Exception as exc: raise LayerMapError( - 'Could not translate between the data source and model geometry.' + "Could not translate between the data source and model geometry." ) from exc def geometry_field(self): @@ -477,11 +539,21 @@ class LayerMapping: Given the OGRGeomType for a geometry and its associated GeometryField, determine whether the geometry should be turned into a GeometryCollection. 
""" - return (geom_type.num in self.MULTI_TYPES and - model_field.__class__.__name__ == 'Multi%s' % geom_type.django) + return ( + geom_type.num in self.MULTI_TYPES + and model_field.__class__.__name__ == "Multi%s" % geom_type.django + ) - def save(self, verbose=False, fid_range=False, step=False, - progress=False, silent=False, stream=sys.stdout, strict=False): + def save( + self, + verbose=False, + fid_range=False, + step=False, + progress=False, + silent=False, + stream=sys.stdout, + strict=False, + ): """ Save the contents from the OGR DataSource Layer into the database according to the mapping dictionary given at initialization. @@ -547,7 +619,9 @@ class LayerMapping: if strict: raise elif not silent: - stream.write('Ignoring Feature ID %s because: %s\n' % (feat.fid, msg)) + stream.write( + "Ignoring Feature ID %s because: %s\n" % (feat.fid, msg) + ) else: # Constructing the model using the keyword args is_update = False @@ -585,23 +659,29 @@ class LayerMapping: m.save(using=self.using) num_saved += 1 if verbose: - stream.write('%s: %s\n' % ('Updated' if is_update else 'Saved', m)) + stream.write( + "%s: %s\n" % ("Updated" if is_update else "Saved", m) + ) except Exception as msg: if strict: # Bailing out if the `strict` keyword is set. if not silent: stream.write( - 'Failed to save the feature (id: %s) into the ' - 'model with the keyword arguments:\n' % feat.fid + "Failed to save the feature (id: %s) into the " + "model with the keyword arguments:\n" % feat.fid ) - stream.write('%s\n' % kwargs) + stream.write("%s\n" % kwargs) raise elif not silent: - stream.write('Failed to save %s:\n %s\nContinuing\n' % (kwargs, msg)) + stream.write( + "Failed to save %s:\n %s\nContinuing\n" % (kwargs, msg) + ) # Printing progress information, if requested. 
if progress and num_feat % progress_interval == 0: - stream.write('Processed %d features, saved %d ...\n' % (num_feat, num_saved)) + stream.write( + "Processed %d features, saved %d ...\n" % (num_feat, num_saved) + ) # Only used for status output purposes -- incremental saving uses the # values returned here. @@ -614,7 +694,9 @@ class LayerMapping: if step and isinstance(step, int) and step < nfeat: # Incremental saving is requested at the given interval (step) if default_range: - raise LayerMapError('The `step` keyword may not be used in conjunction with the `fid_range` keyword.') + raise LayerMapError( + "The `step` keyword may not be used in conjunction with the `fid_range` keyword." + ) beg, num_feat, num_saved = (0, 0, 0) indices = range(step, nfeat, step) n_i = len(indices) @@ -631,7 +713,9 @@ class LayerMapping: num_feat, num_saved = _save(step_slice, num_feat, num_saved) beg = end except Exception: # Deliberately catch everything - stream.write('%s\nFailed to save slice: %s\n' % ('=-' * 20, step_slice)) + stream.write( + "%s\nFailed to save slice: %s\n" % ("=-" * 20, step_slice) + ) raise else: # Otherwise, just calling the previously defined _save() function. diff --git a/django/contrib/gis/utils/ogrinfo.py b/django/contrib/gis/utils/ogrinfo.py index 324e41a0ad..eafa23ccae 100644 --- a/django/contrib/gis/utils/ogrinfo.py +++ b/django/contrib/gis/utils/ogrinfo.py @@ -20,7 +20,9 @@ def ogrinfo(data_source, num_features=10): elif isinstance(data_source, DataSource): pass else: - raise Exception('Data source parameter must be a string or a DataSource object.') + raise Exception( + "Data source parameter must be a string or a DataSource object." 
+ ) for i, layer in enumerate(data_source): print("data source : %s" % data_source.name) @@ -44,8 +46,8 @@ def ogrinfo(data_source, num_features=10): if isinstance(val, str): val_fmt = ' ("%s")' else: - val_fmt = ' (%s)' + val_fmt = " (%s)" output += val_fmt % val else: - output += ' (None)' + output += " (None)" print(output) diff --git a/django/contrib/gis/utils/ogrinspect.py b/django/contrib/gis/utils/ogrinspect.py index 8c83fd5a43..c50b39ad14 100644 --- a/django/contrib/gis/utils/ogrinspect.py +++ b/django/contrib/gis/utils/ogrinspect.py @@ -5,12 +5,17 @@ models for GeoDjango and/or mapping dictionaries for use with the """ from django.contrib.gis.gdal import DataSource from django.contrib.gis.gdal.field import ( - OFTDate, OFTDateTime, OFTInteger, OFTInteger64, OFTReal, OFTString, + OFTDate, + OFTDateTime, + OFTInteger, + OFTInteger64, + OFTReal, + OFTString, OFTTime, ) -def mapping(data_source, geom_name='geom', layer_key=0, multi_geom=False): +def mapping(data_source, geom_name="geom", layer_key=0, multi_geom=False): """ Given a DataSource, generate a dictionary that may be used for invoking the LayerMapping utility. @@ -30,7 +35,9 @@ def mapping(data_source, geom_name='geom', layer_key=0, multi_geom=False): elif isinstance(data_source, DataSource): pass else: - raise TypeError('Data source parameter must be a string or a DataSource object.') + raise TypeError( + "Data source parameter must be a string or a DataSource object." + ) # Creating the dictionary. _mapping = {} @@ -38,8 +45,8 @@ def mapping(data_source, geom_name='geom', layer_key=0, multi_geom=False): # Generating the field name for each field in the layer. 
for field in data_source[layer_key].fields: mfield = field.lower() - if mfield[-1:] == '_': - mfield += 'field' + if mfield[-1:] == "_": + mfield += "field" _mapping[mfield] = field gtype = data_source[layer_key].geom_type if multi_geom: @@ -116,12 +123,22 @@ def ogrinspect(*args, **kwargs): Note: Call the _ogrinspect() helper to do the heavy lifting. """ - return '\n'.join(_ogrinspect(*args, **kwargs)) + return "\n".join(_ogrinspect(*args, **kwargs)) -def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=None, - multi_geom=False, name_field=None, imports=True, - decimal=False, blank=False, null=False): +def _ogrinspect( + data_source, + model_name, + geom_name="geom", + layer_key=0, + srid=None, + multi_geom=False, + name_field=None, + imports=True, + decimal=False, + blank=False, + null=False, +): """ Helper routine for `ogrinspect` that generates GeoDjango models corresponding to the given data source. See the `ogrinspect` docstring for more details. @@ -132,7 +149,9 @@ def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=Non elif isinstance(data_source, DataSource): pass else: - raise TypeError('Data source parameter must be a string or a DataSource object.') + raise TypeError( + "Data source parameter must be a string or a DataSource object." + ) # Getting the layer corresponding to the layer key and getting # a string listing of all OGR fields in the Layer. 
@@ -148,6 +167,7 @@ def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=Non return [s.lower() for s in ogr_fields] else: return [] + null_fields = process_kwarg(null) blank_fields = process_kwarg(blank) decimal_fields = process_kwarg(decimal) @@ -156,29 +176,30 @@ def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=Non def get_kwargs_str(field_name): kwlist = [] if field_name.lower() in null_fields: - kwlist.append('null=True') + kwlist.append("null=True") if field_name.lower() in blank_fields: - kwlist.append('blank=True') + kwlist.append("blank=True") if kwlist: - return ', ' + ', '.join(kwlist) + return ", " + ", ".join(kwlist) else: - return '' + return "" # For those wishing to disable the imports. if imports: - yield '# This is an auto-generated Django model module created by ogrinspect.' - yield 'from django.contrib.gis.db import models' - yield '' - yield '' + yield "# This is an auto-generated Django model module created by ogrinspect." + yield "from django.contrib.gis.db import models" + yield "" + yield "" - yield 'class %s(models.Model):' % model_name + yield "class %s(models.Model):" % model_name for field_name, width, precision, field_type in zip( - ogr_fields, layer.field_widths, layer.field_precisions, layer.field_types): + ogr_fields, layer.field_widths, layer.field_precisions, layer.field_types + ): # The model field name. mfield = field_name.lower() - if mfield[-1:] == '_': - mfield += 'field' + if mfield[-1:] == "_": + mfield += "field" # Getting the keyword args string. kwargs_str = get_kwargs_str(field_name) @@ -188,25 +209,32 @@ def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=Non # may also be mapped to `DecimalField` if specified in the # `decimal` keyword. 
if field_name.lower() in decimal_fields: - yield ' %s = models.DecimalField(max_digits=%d, decimal_places=%d%s)' % ( - mfield, width, precision, kwargs_str + yield " %s = models.DecimalField(max_digits=%d, decimal_places=%d%s)" % ( + mfield, + width, + precision, + kwargs_str, ) else: - yield ' %s = models.FloatField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.FloatField(%s)" % (mfield, kwargs_str[2:]) elif field_type is OFTInteger: - yield ' %s = models.IntegerField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.IntegerField(%s)" % (mfield, kwargs_str[2:]) elif field_type is OFTInteger64: - yield ' %s = models.BigIntegerField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.BigIntegerField(%s)" % (mfield, kwargs_str[2:]) elif field_type is OFTString: - yield ' %s = models.CharField(max_length=%s%s)' % (mfield, width, kwargs_str) + yield " %s = models.CharField(max_length=%s%s)" % ( + mfield, + width, + kwargs_str, + ) elif field_type is OFTDate: - yield ' %s = models.DateField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.DateField(%s)" % (mfield, kwargs_str[2:]) elif field_type is OFTDateTime: - yield ' %s = models.DateTimeField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.DateTimeField(%s)" % (mfield, kwargs_str[2:]) elif field_type is OFTTime: - yield ' %s = models.TimeField(%s)' % (mfield, kwargs_str[2:]) + yield " %s = models.TimeField(%s)" % (mfield, kwargs_str[2:]) else: - raise TypeError('Unknown field type %s in %s' % (field_type, mfield)) + raise TypeError("Unknown field type %s in %s" % (field_type, mfield)) # TODO: Autodetection of multigeometry types (see #7218). gtype = layer.geom_type @@ -217,21 +245,21 @@ def _ogrinspect(data_source, model_name, geom_name='geom', layer_key=0, srid=Non # Setting up the SRID keyword string. 
if srid is None: if layer.srs is None: - srid_str = 'srid=-1' + srid_str = "srid=-1" else: srid = layer.srs.srid if srid is None: - srid_str = 'srid=-1' + srid_str = "srid=-1" elif srid == 4326: # WGS84 is already the default. - srid_str = '' + srid_str = "" else: - srid_str = 'srid=%s' % srid + srid_str = "srid=%s" % srid else: - srid_str = 'srid=%s' % srid + srid_str = "srid=%s" % srid - yield ' %s = models.%s(%s)' % (geom_name, geom_field, srid_str) + yield " %s = models.%s(%s)" % (geom_name, geom_field, srid_str) if name_field: - yield '' - yield ' def __str__(self): return self.%s' % name_field + yield "" + yield " def __str__(self): return self.%s" % name_field diff --git a/django/contrib/gis/utils/srs.py b/django/contrib/gis/utils/srs.py index d44d340383..2204767923 100644 --- a/django/contrib/gis/utils/srs.py +++ b/django/contrib/gis/utils/srs.py @@ -2,8 +2,9 @@ from django.contrib.gis.gdal import SpatialReference from django.db import DEFAULT_DB_ALIAS, connections -def add_srs_entry(srs, auth_name='EPSG', auth_srid=None, ref_sys_name=None, - database=None): +def add_srs_entry( + srs, auth_name="EPSG", auth_srid=None, ref_sys_name=None, database=None +): """ Take a GDAL SpatialReference system and add its information to the `spatial_ref_sys` table of the spatial backend. Doing this enables @@ -35,12 +36,10 @@ def add_srs_entry(srs, auth_name='EPSG', auth_srid=None, ref_sys_name=None, database = database or DEFAULT_DB_ALIAS connection = connections[database] - if not hasattr(connection.ops, 'spatial_version'): - raise Exception( - 'The `add_srs_entry` utility only works with spatial backends.' 
- ) + if not hasattr(connection.ops, "spatial_version"): + raise Exception("The `add_srs_entry` utility only works with spatial backends.") if not connection.features.supports_add_srs_entry: - raise Exception('This utility does not support your database backend.') + raise Exception("This utility does not support your database backend.") SpatialRefSys = connection.ops.spatial_ref_sys() # If argument is not a `SpatialReference` instance, use it as parameter @@ -49,24 +48,26 @@ def add_srs_entry(srs, auth_name='EPSG', auth_srid=None, ref_sys_name=None, srs = SpatialReference(srs) if srs.srid is None: - raise Exception('Spatial reference requires an SRID to be ' - 'compatible with the spatial backend.') + raise Exception( + "Spatial reference requires an SRID to be " + "compatible with the spatial backend." + ) # Initializing the keyword arguments dictionary for both PostGIS # and SpatiaLite. kwargs = { - 'srid': srs.srid, - 'auth_name': auth_name, - 'auth_srid': auth_srid or srs.srid, - 'proj4text': srs.proj4, + "srid": srs.srid, + "auth_name": auth_name, + "auth_srid": auth_srid or srs.srid, + "proj4text": srs.proj4, } # Backend-specific fields for the SpatialRefSys model. srs_field_names = {f.name for f in SpatialRefSys._meta.get_fields()} - if 'srtext' in srs_field_names: - kwargs['srtext'] = srs.wkt - if 'ref_sys_name' in srs_field_names: + if "srtext" in srs_field_names: + kwargs["srtext"] = srs.wkt + if "ref_sys_name" in srs_field_names: # SpatiaLite specific - kwargs['ref_sys_name'] = ref_sys_name or srs.name + kwargs["ref_sys_name"] = ref_sys_name or srs.name # Creating the spatial_ref_sys model. 
try: diff --git a/django/contrib/gis/views.py b/django/contrib/gis/views.py index 5b29db9fb2..346fb5b0c0 100644 --- a/django/contrib/gis/views.py +++ b/django/contrib/gis/views.py @@ -7,14 +7,16 @@ def feed(request, url, feed_dict=None): if not feed_dict: raise Http404(_("No feeds are registered.")) - slug = url.partition('/')[0] + slug = url.partition("/")[0] try: f = feed_dict[slug] except KeyError: - raise Http404(_('Slug %r isn’t registered.') % slug) + raise Http404(_("Slug %r isn’t registered.") % slug) instance = f() - instance.feed_url = getattr(f, 'feed_url', None) or request.path - instance.title_template = f.title_template or ('feeds/%s_title.html' % slug) - instance.description_template = f.description_template or ('feeds/%s_description.html' % slug) + instance.feed_url = getattr(f, "feed_url", None) or request.path + instance.title_template = f.title_template or ("feeds/%s_title.html" % slug) + instance.description_template = f.description_template or ( + "feeds/%s_description.html" % slug + ) return instance(request) diff --git a/django/contrib/humanize/apps.py b/django/contrib/humanize/apps.py index c5fcbca794..817ae96f68 100644 --- a/django/contrib/humanize/apps.py +++ b/django/contrib/humanize/apps.py @@ -3,5 +3,5 @@ from django.utils.translation import gettext_lazy as _ class HumanizeConfig(AppConfig): - name = 'django.contrib.humanize' + name = "django.contrib.humanize" verbose_name = _("Humanize") diff --git a/django/contrib/humanize/templatetags/humanize.py b/django/contrib/humanize/templatetags/humanize.py index 095a053fa1..bdceec344d 100644 --- a/django/contrib/humanize/templatetags/humanize.py +++ b/django/contrib/humanize/templatetags/humanize.py @@ -7,9 +7,14 @@ from django.template import defaultfilters from django.utils.formats import number_format from django.utils.safestring import mark_safe from django.utils.timezone import is_aware, utc +from django.utils.translation import gettext as _ from django.utils.translation import ( - 
gettext as _, gettext_lazy, ngettext, ngettext_lazy, npgettext_lazy, - pgettext, round_away_from_one, + gettext_lazy, + ngettext, + ngettext_lazy, + npgettext_lazy, + pgettext, + round_away_from_one, ) register = template.Library() @@ -27,29 +32,29 @@ def ordinal(value): return value if value % 100 in (11, 12, 13): # Translators: Ordinal format for 11 (11th), 12 (12th), and 13 (13th). - value = pgettext('ordinal 11, 12, 13', '{}th').format(value) + value = pgettext("ordinal 11, 12, 13", "{}th").format(value) else: templates = ( # Translators: Ordinal format when value ends with 0, e.g. 80th. - pgettext('ordinal 0', '{}th'), + pgettext("ordinal 0", "{}th"), # Translators: Ordinal format when value ends with 1, e.g. 81st, except 11. - pgettext('ordinal 1', '{}st'), + pgettext("ordinal 1", "{}st"), # Translators: Ordinal format when value ends with 2, e.g. 82nd, except 12. - pgettext('ordinal 2', '{}nd'), + pgettext("ordinal 2", "{}nd"), # Translators: Ordinal format when value ends with 3, e.g. 83th, except 13. - pgettext('ordinal 3', '{}rd'), + pgettext("ordinal 3", "{}rd"), # Translators: Ordinal format when value ends with 4, e.g. 84th. - pgettext('ordinal 4', '{}th'), + pgettext("ordinal 4", "{}th"), # Translators: Ordinal format when value ends with 5, e.g. 85th. - pgettext('ordinal 5', '{}th'), + pgettext("ordinal 5", "{}th"), # Translators: Ordinal format when value ends with 6, e.g. 86th. - pgettext('ordinal 6', '{}th'), + pgettext("ordinal 6", "{}th"), # Translators: Ordinal format when value ends with 7, e.g. 87th. - pgettext('ordinal 7', '{}th'), + pgettext("ordinal 7", "{}th"), # Translators: Ordinal format when value ends with 8, e.g. 88th. - pgettext('ordinal 8', '{}th'), + pgettext("ordinal 8", "{}th"), # Translators: Ordinal format when value ends with 9, e.g. 89th. 
- pgettext('ordinal 9', '{}th'), + pgettext("ordinal 9", "{}th"), ) value = templates[value % 10].format(value) # Mark value safe so i18n does not break with <sup> or <sub> see #19988 @@ -71,7 +76,7 @@ def intcomma(value, use_l10n=True): else: return number_format(value, use_l10n=True, force_grouping=True) orig = str(value) - new = re.sub(r"^(-?\d+)(\d{3})", r'\g<1>,\g<2>', orig) + new = re.sub(r"^(-?\d+)(\d{3})", r"\g<1>,\g<2>", orig) if orig == new: return new else: @@ -80,17 +85,33 @@ def intcomma(value, use_l10n=True): # A tuple of standard large number to their converters intword_converters = ( - (6, lambda number: ngettext('%(value)s million', '%(value)s million', number)), - (9, lambda number: ngettext('%(value)s billion', '%(value)s billion', number)), - (12, lambda number: ngettext('%(value)s trillion', '%(value)s trillion', number)), - (15, lambda number: ngettext('%(value)s quadrillion', '%(value)s quadrillion', number)), - (18, lambda number: ngettext('%(value)s quintillion', '%(value)s quintillion', number)), - (21, lambda number: ngettext('%(value)s sextillion', '%(value)s sextillion', number)), - (24, lambda number: ngettext('%(value)s septillion', '%(value)s septillion', number)), - (27, lambda number: ngettext('%(value)s octillion', '%(value)s octillion', number)), - (30, lambda number: ngettext('%(value)s nonillion', '%(value)s nonillion', number)), - (33, lambda number: ngettext('%(value)s decillion', '%(value)s decillion', number)), - (100, lambda number: ngettext('%(value)s googol', '%(value)s googol', number)), + (6, lambda number: ngettext("%(value)s million", "%(value)s million", number)), + (9, lambda number: ngettext("%(value)s billion", "%(value)s billion", number)), + (12, lambda number: ngettext("%(value)s trillion", "%(value)s trillion", number)), + ( + 15, + lambda number: ngettext( + "%(value)s quadrillion", "%(value)s quadrillion", number + ), + ), + ( + 18, + lambda number: ngettext( + "%(value)s quintillion", "%(value)s 
quintillion", number + ), + ), + ( + 21, + lambda number: ngettext("%(value)s sextillion", "%(value)s sextillion", number), + ), + ( + 24, + lambda number: ngettext("%(value)s septillion", "%(value)s septillion", number), + ), + (27, lambda number: ngettext("%(value)s octillion", "%(value)s octillion", number)), + (30, lambda number: ngettext("%(value)s nonillion", "%(value)s nonillion", number)), + (33, lambda number: ngettext("%(value)s decillion", "%(value)s decillion", number)), + (100, lambda number: ngettext("%(value)s googol", "%(value)s googol", number)), ) @@ -111,12 +132,12 @@ def intword(value): return value for exponent, converter in intword_converters: - large_number = 10 ** exponent + large_number = 10**exponent if abs_value < large_number * 1000: new_value = value / large_number rounded_value = round_away_from_one(new_value) return converter(abs(rounded_value)) % { - 'value': defaultfilters.floatformat(new_value, 1), + "value": defaultfilters.floatformat(new_value, 1), } return value @@ -133,8 +154,17 @@ def apnumber(value): return value if not 0 < value < 10: return value - return (_('one'), _('two'), _('three'), _('four'), _('five'), - _('six'), _('seven'), _('eight'), _('nine'))[value - 1] + return ( + _("one"), + _("two"), + _("three"), + _("four"), + _("five"), + _("six"), + _("seven"), + _("eight"), + _("nine"), + )[value - 1] # Perform the comparison in the default time zone when USE_TZ = True @@ -146,7 +176,7 @@ def naturalday(value, arg=None): present day return representing string. Otherwise, return a string formatted according to settings.DATE_FORMAT. 
""" - tzinfo = getattr(value, 'tzinfo', None) + tzinfo = getattr(value, "tzinfo", None) try: value = date(value.year, value.month, value.day) except AttributeError: @@ -155,11 +185,11 @@ def naturalday(value, arg=None): today = datetime.now(tzinfo).date() delta = value - today if delta.days == 0: - return _('today') + return _("today") elif delta.days == 1: - return _('tomorrow') + return _("tomorrow") elif delta.days == -1: - return _('yesterday') + return _("yesterday") return defaultfilters.date(value, arg) @@ -177,46 +207,74 @@ def naturaltime(value): class NaturalTimeFormatter: time_strings = { # Translators: delta will contain a string like '2 months' or '1 month, 2 weeks' - 'past-day': gettext_lazy('%(delta)s ago'), + "past-day": gettext_lazy("%(delta)s ago"), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. - 'past-hour': ngettext_lazy('an hour ago', '%(count)s hours ago', 'count'), + "past-hour": ngettext_lazy("an hour ago", "%(count)s hours ago", "count"), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. - 'past-minute': ngettext_lazy('a minute ago', '%(count)s minutes ago', 'count'), + "past-minute": ngettext_lazy("a minute ago", "%(count)s minutes ago", "count"), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. - 'past-second': ngettext_lazy('a second ago', '%(count)s seconds ago', 'count'), - 'now': gettext_lazy('now'), + "past-second": ngettext_lazy("a second ago", "%(count)s seconds ago", "count"), + "now": gettext_lazy("now"), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. - 'future-second': ngettext_lazy('a second from now', '%(count)s seconds from now', 'count'), + "future-second": ngettext_lazy( + "a second from now", "%(count)s seconds from now", "count" + ), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. 
- 'future-minute': ngettext_lazy('a minute from now', '%(count)s minutes from now', 'count'), + "future-minute": ngettext_lazy( + "a minute from now", "%(count)s minutes from now", "count" + ), # Translators: please keep a non-breaking space (U+00A0) between count # and time unit. - 'future-hour': ngettext_lazy('an hour from now', '%(count)s hours from now', 'count'), + "future-hour": ngettext_lazy( + "an hour from now", "%(count)s hours from now", "count" + ), # Translators: delta will contain a string like '2 months' or '1 month, 2 weeks' - 'future-day': gettext_lazy('%(delta)s from now'), + "future-day": gettext_lazy("%(delta)s from now"), } past_substrings = { # Translators: 'naturaltime-past' strings will be included in '%(delta)s ago' - 'year': npgettext_lazy('naturaltime-past', '%(num)d year', '%(num)d years', 'num'), - 'month': npgettext_lazy('naturaltime-past', '%(num)d month', '%(num)d months', 'num'), - 'week': npgettext_lazy('naturaltime-past', '%(num)d week', '%(num)d weeks', 'num'), - 'day': npgettext_lazy('naturaltime-past', '%(num)d day', '%(num)d days', 'num'), - 'hour': npgettext_lazy('naturaltime-past', '%(num)d hour', '%(num)d hours', 'num'), - 'minute': npgettext_lazy('naturaltime-past', '%(num)d minute', '%(num)d minutes', 'num'), + "year": npgettext_lazy( + "naturaltime-past", "%(num)d year", "%(num)d years", "num" + ), + "month": npgettext_lazy( + "naturaltime-past", "%(num)d month", "%(num)d months", "num" + ), + "week": npgettext_lazy( + "naturaltime-past", "%(num)d week", "%(num)d weeks", "num" + ), + "day": npgettext_lazy("naturaltime-past", "%(num)d day", "%(num)d days", "num"), + "hour": npgettext_lazy( + "naturaltime-past", "%(num)d hour", "%(num)d hours", "num" + ), + "minute": npgettext_lazy( + "naturaltime-past", "%(num)d minute", "%(num)d minutes", "num" + ), } future_substrings = { # Translators: 'naturaltime-future' strings will be included in '%(delta)s from now' - 'year': npgettext_lazy('naturaltime-future', '%(num)d year', 
'%(num)d years', 'num'), - 'month': npgettext_lazy('naturaltime-future', '%(num)d month', '%(num)d months', 'num'), - 'week': npgettext_lazy('naturaltime-future', '%(num)d week', '%(num)d weeks', 'num'), - 'day': npgettext_lazy('naturaltime-future', '%(num)d day', '%(num)d days', 'num'), - 'hour': npgettext_lazy('naturaltime-future', '%(num)d hour', '%(num)d hours', 'num'), - 'minute': npgettext_lazy('naturaltime-future', '%(num)d minute', '%(num)d minutes', 'num'), + "year": npgettext_lazy( + "naturaltime-future", "%(num)d year", "%(num)d years", "num" + ), + "month": npgettext_lazy( + "naturaltime-future", "%(num)d month", "%(num)d months", "num" + ), + "week": npgettext_lazy( + "naturaltime-future", "%(num)d week", "%(num)d weeks", "num" + ), + "day": npgettext_lazy( + "naturaltime-future", "%(num)d day", "%(num)d days", "num" + ), + "hour": npgettext_lazy( + "naturaltime-future", "%(num)d hour", "%(num)d hours", "num" + ), + "minute": npgettext_lazy( + "naturaltime-future", "%(num)d minute", "%(num)d minutes", "num" + ), } @classmethod @@ -228,32 +286,36 @@ class NaturalTimeFormatter: if value < now: delta = now - value if delta.days != 0: - return cls.time_strings['past-day'] % { - 'delta': defaultfilters.timesince(value, now, time_strings=cls.past_substrings), + return cls.time_strings["past-day"] % { + "delta": defaultfilters.timesince( + value, now, time_strings=cls.past_substrings + ), } elif delta.seconds == 0: - return cls.time_strings['now'] + return cls.time_strings["now"] elif delta.seconds < 60: - return cls.time_strings['past-second'] % {'count': delta.seconds} + return cls.time_strings["past-second"] % {"count": delta.seconds} elif delta.seconds // 60 < 60: count = delta.seconds // 60 - return cls.time_strings['past-minute'] % {'count': count} + return cls.time_strings["past-minute"] % {"count": count} else: count = delta.seconds // 60 // 60 - return cls.time_strings['past-hour'] % {'count': count} + return cls.time_strings["past-hour"] % {"count": 
count} else: delta = value - now if delta.days != 0: - return cls.time_strings['future-day'] % { - 'delta': defaultfilters.timeuntil(value, now, time_strings=cls.future_substrings), + return cls.time_strings["future-day"] % { + "delta": defaultfilters.timeuntil( + value, now, time_strings=cls.future_substrings + ), } elif delta.seconds == 0: - return cls.time_strings['now'] + return cls.time_strings["now"] elif delta.seconds < 60: - return cls.time_strings['future-second'] % {'count': delta.seconds} + return cls.time_strings["future-second"] % {"count": delta.seconds} elif delta.seconds // 60 < 60: count = delta.seconds // 60 - return cls.time_strings['future-minute'] % {'count': count} + return cls.time_strings["future-minute"] % {"count": count} else: count = delta.seconds // 60 // 60 - return cls.time_strings['future-hour'] % {'count': count} + return cls.time_strings["future-hour"] % {"count": count} diff --git a/django/contrib/messages/api.py b/django/contrib/messages/api.py index f0da16818a..7a67e8b4b0 100644 --- a/django/contrib/messages/api.py +++ b/django/contrib/messages/api.py @@ -2,10 +2,16 @@ from django.contrib.messages import constants from django.contrib.messages.storage import default_storage __all__ = ( - 'add_message', 'get_messages', - 'get_level', 'set_level', - 'debug', 'info', 'success', 'warning', 'error', - 'MessageFailure', + "add_message", + "get_messages", + "get_level", + "set_level", + "debug", + "info", + "success", + "warning", + "error", + "MessageFailure", ) @@ -13,22 +19,22 @@ class MessageFailure(Exception): pass -def add_message(request, level, message, extra_tags='', fail_silently=False): +def add_message(request, level, message, extra_tags="", fail_silently=False): """ Attempt to add a message to the request using the 'messages' app. 
""" try: messages = request._messages except AttributeError: - if not hasattr(request, 'META'): + if not hasattr(request, "META"): raise TypeError( "add_message() argument must be an HttpRequest object, not " "'%s'." % request.__class__.__name__ ) if not fail_silently: raise MessageFailure( - 'You cannot add messages without installing ' - 'django.contrib.messages.middleware.MessageMiddleware' + "You cannot add messages without installing " + "django.contrib.messages.middleware.MessageMiddleware" ) else: return messages.add(level, message, extra_tags) @@ -39,7 +45,7 @@ def get_messages(request): Return the message storage on the request if it exists, otherwise return an empty list. """ - return getattr(request, '_messages', []) + return getattr(request, "_messages", []) def get_level(request): @@ -49,7 +55,7 @@ def get_level(request): The default level is the ``MESSAGE_LEVEL`` setting. If this is not found, use the ``INFO`` level. """ - storage = getattr(request, '_messages', default_storage(request)) + storage = getattr(request, "_messages", default_storage(request)) return storage.level @@ -60,37 +66,62 @@ def set_level(request, level): If set to ``None``, use the default level (see the get_level() function). 
""" - if not hasattr(request, '_messages'): + if not hasattr(request, "_messages"): return False request._messages.level = level return True -def debug(request, message, extra_tags='', fail_silently=False): +def debug(request, message, extra_tags="", fail_silently=False): """Add a message with the ``DEBUG`` level.""" - add_message(request, constants.DEBUG, message, extra_tags=extra_tags, - fail_silently=fail_silently) + add_message( + request, + constants.DEBUG, + message, + extra_tags=extra_tags, + fail_silently=fail_silently, + ) -def info(request, message, extra_tags='', fail_silently=False): +def info(request, message, extra_tags="", fail_silently=False): """Add a message with the ``INFO`` level.""" - add_message(request, constants.INFO, message, extra_tags=extra_tags, - fail_silently=fail_silently) + add_message( + request, + constants.INFO, + message, + extra_tags=extra_tags, + fail_silently=fail_silently, + ) -def success(request, message, extra_tags='', fail_silently=False): +def success(request, message, extra_tags="", fail_silently=False): """Add a message with the ``SUCCESS`` level.""" - add_message(request, constants.SUCCESS, message, extra_tags=extra_tags, - fail_silently=fail_silently) + add_message( + request, + constants.SUCCESS, + message, + extra_tags=extra_tags, + fail_silently=fail_silently, + ) -def warning(request, message, extra_tags='', fail_silently=False): +def warning(request, message, extra_tags="", fail_silently=False): """Add a message with the ``WARNING`` level.""" - add_message(request, constants.WARNING, message, extra_tags=extra_tags, - fail_silently=fail_silently) + add_message( + request, + constants.WARNING, + message, + extra_tags=extra_tags, + fail_silently=fail_silently, + ) -def error(request, message, extra_tags='', fail_silently=False): +def error(request, message, extra_tags="", fail_silently=False): """Add a message with the ``ERROR`` level.""" - add_message(request, constants.ERROR, message, extra_tags=extra_tags, - 
fail_silently=fail_silently) + add_message( + request, + constants.ERROR, + message, + extra_tags=extra_tags, + fail_silently=fail_silently, + ) diff --git a/django/contrib/messages/apps.py b/django/contrib/messages/apps.py index 0b02f7d851..09a9554b70 100644 --- a/django/contrib/messages/apps.py +++ b/django/contrib/messages/apps.py @@ -6,12 +6,12 @@ from django.utils.translation import gettext_lazy as _ def update_level_tags(setting, **kwargs): - if setting == 'MESSAGE_TAGS': + if setting == "MESSAGE_TAGS": base.LEVEL_TAGS = get_level_tags() class MessagesConfig(AppConfig): - name = 'django.contrib.messages' + name = "django.contrib.messages" verbose_name = _("Messages") def ready(self): diff --git a/django/contrib/messages/constants.py b/django/contrib/messages/constants.py index b43144872c..92c92422be 100644 --- a/django/contrib/messages/constants.py +++ b/django/contrib/messages/constants.py @@ -5,17 +5,17 @@ WARNING = 30 ERROR = 40 DEFAULT_TAGS = { - DEBUG: 'debug', - INFO: 'info', - SUCCESS: 'success', - WARNING: 'warning', - ERROR: 'error', + DEBUG: "debug", + INFO: "info", + SUCCESS: "success", + WARNING: "warning", + ERROR: "error", } DEFAULT_LEVELS = { - 'DEBUG': DEBUG, - 'INFO': INFO, - 'SUCCESS': SUCCESS, - 'WARNING': WARNING, - 'ERROR': ERROR, + "DEBUG": DEBUG, + "INFO": INFO, + "SUCCESS": SUCCESS, + "WARNING": WARNING, + "ERROR": ERROR, } diff --git a/django/contrib/messages/context_processors.py b/django/contrib/messages/context_processors.py index b4b20956a6..e01cc31664 100644 --- a/django/contrib/messages/context_processors.py +++ b/django/contrib/messages/context_processors.py @@ -8,6 +8,6 @@ def messages(request): 'DEFAULT_MESSAGE_LEVELS'. 
""" return { - 'messages': get_messages(request), - 'DEFAULT_MESSAGE_LEVELS': DEFAULT_LEVELS, + "messages": get_messages(request), + "DEFAULT_MESSAGE_LEVELS": DEFAULT_LEVELS, } diff --git a/django/contrib/messages/middleware.py b/django/contrib/messages/middleware.py index d5b787cee7..00870318d1 100644 --- a/django/contrib/messages/middleware.py +++ b/django/contrib/messages/middleware.py @@ -19,8 +19,8 @@ class MessageMiddleware(MiddlewareMixin): """ # A higher middleware layer may return a request which does not contain # messages storage, so make no assumption that it will be there. - if hasattr(request, '_messages'): + if hasattr(request, "_messages"): unstored_messages = request._messages.update(response) if unstored_messages and settings.DEBUG: - raise ValueError('Not all temporary messages could be stored.') + raise ValueError("Not all temporary messages could be stored.") return response diff --git a/django/contrib/messages/storage/base.py b/django/contrib/messages/storage/base.py index 01422066a6..61c5758aab 100644 --- a/django/contrib/messages/storage/base.py +++ b/django/contrib/messages/storage/base.py @@ -34,11 +34,11 @@ class Message: @property def tags(self): - return ' '.join(tag for tag in [self.extra_tags, self.level_tag] if tag) + return " ".join(tag for tag in [self.extra_tags, self.level_tag] if tag) @property def level_tag(self): - return LEVEL_TAGS.get(self.level, '') + return LEVEL_TAGS.get(self.level, "") class BaseStorage: @@ -70,7 +70,7 @@ class BaseStorage: return item in self._loaded_messages or item in self._queued_messages def __repr__(self): - return f'<{self.__class__.__qualname__}: request={self.request!r}>' + return f"<{self.__class__.__qualname__}: request={self.request!r}>" @property def _loaded_messages(self): @@ -78,7 +78,7 @@ class BaseStorage: Return a list of loaded messages, retrieving them first if they have not been loaded yet. 
""" - if not hasattr(self, '_loaded_data'): + if not hasattr(self, "_loaded_data"): messages, all_retrieved = self._get() self._loaded_data = messages or [] return self._loaded_data @@ -96,7 +96,9 @@ class BaseStorage: just containing no messages) then ``None`` should be returned in place of ``messages``. """ - raise NotImplementedError('subclasses of BaseStorage must provide a _get() method') + raise NotImplementedError( + "subclasses of BaseStorage must provide a _get() method" + ) def _store(self, messages, response, *args, **kwargs): """ @@ -107,7 +109,9 @@ class BaseStorage: **This method must be implemented by a subclass.** """ - raise NotImplementedError('subclasses of BaseStorage must provide a _store() method') + raise NotImplementedError( + "subclasses of BaseStorage must provide a _store() method" + ) def _prepare_messages(self, messages): """ @@ -130,7 +134,7 @@ class BaseStorage: messages = self._loaded_messages + self._queued_messages return self._store(messages, response) - def add(self, level, message, extra_tags=''): + def add(self, level, message, extra_tags=""): """ Queue a message to be stored. @@ -155,8 +159,8 @@ class BaseStorage: The default level is the ``MESSAGE_LEVEL`` setting. If this is not found, the ``INFO`` level is used. """ - if not hasattr(self, '_level'): - self._level = getattr(settings, 'MESSAGE_LEVEL', constants.INFO) + if not hasattr(self, "_level"): + self._level = getattr(settings, "MESSAGE_LEVEL", constants.INFO) return self._level def _set_level(self, value=None): @@ -166,7 +170,7 @@ class BaseStorage: If set to ``None``, the default level will be used (see the ``_get_level`` method). 
""" - if value is None and hasattr(self, '_level'): + if value is None and hasattr(self, "_level"): del self._level else: self._level = int(value) diff --git a/django/contrib/messages/storage/cookie.py b/django/contrib/messages/storage/cookie.py index e48d4edb08..0fd7ab60b6 100644 --- a/django/contrib/messages/storage/cookie.py +++ b/django/contrib/messages/storage/cookie.py @@ -12,7 +12,8 @@ class MessageEncoder(json.JSONEncoder): """ Compactly serialize instances of the ``Message`` class as JSON. """ - message_key = '__json_message' + + message_key = "__json_message" def default(self, obj): if isinstance(obj, Message): @@ -38,8 +39,7 @@ class MessageDecoder(json.JSONDecoder): return Message(*obj[2:]) return [self.process_messages(item) for item in obj] if isinstance(obj, dict): - return {key: self.process_messages(value) - for key, value in obj.items()} + return {key: self.process_messages(value) for key, value in obj.items()} return obj def decode(self, s, **kwargs): @@ -51,25 +51,26 @@ class MessageSerializer: def dumps(self, obj): return json.dumps( obj, - separators=(',', ':'), + separators=(",", ":"), cls=MessageEncoder, - ).encode('latin-1') + ).encode("latin-1") def loads(self, data): - return json.loads(data.decode('latin-1'), cls=MessageDecoder) + return json.loads(data.decode("latin-1"), cls=MessageDecoder) class CookieStorage(BaseStorage): """ Store messages in a cookie. """ - cookie_name = 'messages' + + cookie_name = "messages" # uwsgi's default configuration enforces a maximum size of 4kb for all the # HTTP headers. In order to leave some room for other cookies and headers, # restrict the session cookie to 1/2 of 4kb. See #18781. 
max_cookie_size = 2048 - not_finished = '__messagesnotfinished__' - key_salt = 'django.contrib.messages' + not_finished = "__messagesnotfinished__" + key_salt = "django.contrib.messages" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -97,7 +98,8 @@ class CookieStorage(BaseStorage): """ if encoded_data: response.set_cookie( - self.cookie_name, encoded_data, + self.cookie_name, + encoded_data, domain=settings.SESSION_COOKIE_DOMAIN, secure=settings.SESSION_COOKIE_SECURE or None, httponly=settings.SESSION_COOKIE_HTTPONLY or None, @@ -134,8 +136,9 @@ class CookieStorage(BaseStorage): unstored_messages.append(messages.pop(0)) else: unstored_messages.insert(0, messages.pop()) - encoded_data = self._encode(messages + [self.not_finished], - encode_empty=unstored_messages) + encoded_data = self._encode( + messages + [self.not_finished], encode_empty=unstored_messages + ) self._update_cookie(encoded_data, response) return unstored_messages @@ -148,7 +151,9 @@ class CookieStorage(BaseStorage): also contains a hash to ensure that the data was not tampered with. """ if messages or encode_empty: - return self.signer.sign_object(messages, serializer=MessageSerializer, compress=True) + return self.signer.sign_object( + messages, serializer=MessageSerializer, compress=True + ) def _decode(self, data): """ diff --git a/django/contrib/messages/storage/fallback.py b/django/contrib/messages/storage/fallback.py index 39df6f3c9d..44e26fab64 100644 --- a/django/contrib/messages/storage/fallback.py +++ b/django/contrib/messages/storage/fallback.py @@ -8,12 +8,14 @@ class FallbackStorage(BaseStorage): Try to store all messages in the first backend. Store any unstored messages in each subsequent backend. 
""" + storage_classes = (CookieStorage, SessionStorage) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.storages = [storage_class(*args, **kwargs) - for storage_class in self.storage_classes] + self.storages = [ + storage_class(*args, **kwargs) for storage_class in self.storage_classes + ] self._used_storages = set() def _get(self, *args, **kwargs): diff --git a/django/contrib/messages/storage/session.py b/django/contrib/messages/storage/session.py index 6ec0a21cc2..3271c9258f 100644 --- a/django/contrib/messages/storage/session.py +++ b/django/contrib/messages/storage/session.py @@ -1,9 +1,7 @@ import json from django.contrib.messages.storage.base import BaseStorage -from django.contrib.messages.storage.cookie import ( - MessageDecoder, MessageEncoder, -) +from django.contrib.messages.storage.cookie import MessageDecoder, MessageEncoder from django.core.exceptions import ImproperlyConfigured @@ -11,14 +9,15 @@ class SessionStorage(BaseStorage): """ Store messages in the session (that is, django.contrib.sessions). """ - session_key = '_messages' + + session_key = "_messages" def __init__(self, request, *args, **kwargs): - if not hasattr(request, 'session'): + if not hasattr(request, "session"): raise ImproperlyConfigured( - 'The session-based temporary message storage requires session ' - 'middleware to be installed, and come before the message ' - 'middleware in the MIDDLEWARE list.' + "The session-based temporary message storage requires session " + "middleware to be installed, and come before the message " + "middleware in the MIDDLEWARE list." ) super().__init__(request, *args, **kwargs) @@ -28,7 +27,10 @@ class SessionStorage(BaseStorage): always stores everything it is given, so return True for the all_retrieved flag. 
""" - return self.deserialize_messages(self.request.session.get(self.session_key)), True + return ( + self.deserialize_messages(self.request.session.get(self.session_key)), + True, + ) def _store(self, messages, response, *args, **kwargs): """ diff --git a/django/contrib/messages/utils.py b/django/contrib/messages/utils.py index 9013044969..0dd8873576 100644 --- a/django/contrib/messages/utils.py +++ b/django/contrib/messages/utils.py @@ -8,5 +8,5 @@ def get_level_tags(): """ return { **constants.DEFAULT_TAGS, - **getattr(settings, 'MESSAGE_TAGS', {}), + **getattr(settings, "MESSAGE_TAGS", {}), } diff --git a/django/contrib/messages/views.py b/django/contrib/messages/views.py index eaa1bee9d5..38dee9bece 100644 --- a/django/contrib/messages/views.py +++ b/django/contrib/messages/views.py @@ -5,7 +5,8 @@ class SuccessMessageMixin: """ Add a success message on successful form submission. """ - success_message = '' + + success_message = "" def form_valid(self, form): response = super().form_valid(form) diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index d90ca50e2b..f8b40fb709 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,16 +1,20 @@ import warnings from django.contrib.postgres.fields import ArrayField -from django.db.models import ( - Aggregate, BooleanField, JSONField, TextField, Value, -) +from django.db.models import Aggregate, BooleanField, JSONField, TextField, Value from django.utils.deprecation import RemovedInDjango50Warning from .mixins import OrderableAggMixin __all__ = [ - 'ArrayAgg', 'BitAnd', 'BitOr', 'BitXor', 'BoolAnd', 'BoolOr', 'JSONBAgg', - 'StringAgg', + "ArrayAgg", + "BitAnd", + "BitOr", + "BitXor", + "BoolAnd", + "BoolOr", + "JSONBAgg", + "StringAgg", ] @@ -35,17 +39,17 @@ class DeprecatedConvertValueMixin: class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'ARRAY_AGG' - template = 
'%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "ARRAY_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True # RemovedInDjango50Warning deprecation_value = property(lambda self: []) deprecation_msg = ( - 'In Django 5.0, ArrayAgg() will return None instead of an empty list ' - 'if there are no rows. Pass default=None to opt into the new behavior ' - 'and silence this warning or default=Value([]) to keep the previous ' - 'behavior.' + "In Django 5.0, ArrayAgg() will return None instead of an empty list " + "if there are no rows. Pass default=None to opt into the new behavior " + "and silence this warning or default=Value([]) to keep the previous " + "behavior." ) @property @@ -54,35 +58,35 @@ class ArrayAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): class BitAnd(Aggregate): - function = 'BIT_AND' + function = "BIT_AND" class BitOr(Aggregate): - function = 'BIT_OR' + function = "BIT_OR" class BitXor(Aggregate): - function = 'BIT_XOR' + function = "BIT_XOR" class BoolAnd(Aggregate): - function = 'BOOL_AND' + function = "BOOL_AND" output_field = BooleanField() class BoolOr(Aggregate): - function = 'BOOL_OR' + function = "BOOL_OR" output_field = BooleanField() class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'JSONB_AGG' - template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "JSONB_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True output_field = JSONField() # RemovedInDjango50Warning - deprecation_value = '[]' + deprecation_value = "[]" deprecation_msg = ( "In Django 5.0, JSONBAgg() will return None instead of an empty list " "if there are no rows. 
Pass default=None to opt into the new behavior " @@ -92,13 +96,13 @@ class JSONBAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): class StringAgg(DeprecatedConvertValueMixin, OrderableAggMixin, Aggregate): - function = 'STRING_AGG' - template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + function = "STRING_AGG" + template = "%(function)s(%(distinct)s%(expressions)s %(ordering)s)" allow_distinct = True output_field = TextField() # RemovedInDjango50Warning - deprecation_value = '' + deprecation_value = "" deprecation_msg = ( "In Django 5.0, StringAgg() will return None instead of an empty " "string if there are no rows. Pass default=None to opt into the new " diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py index 4fedb9bd98..b2f4097b8f 100644 --- a/django/contrib/postgres/aggregates/mixins.py +++ b/django/contrib/postgres/aggregates/mixins.py @@ -2,7 +2,6 @@ from django.db.models.expressions import OrderByList class OrderableAggMixin: - def __init__(self, *expressions, ordering=(), **extra): if isinstance(ordering, (list, tuple)): self.order_by = OrderByList(*ordering) diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py index 2c83b78c0e..3dc442b290 100644 --- a/django/contrib/postgres/aggregates/statistics.py +++ b/django/contrib/postgres/aggregates/statistics.py @@ -1,8 +1,18 @@ from django.db.models import Aggregate, FloatField, IntegerField __all__ = [ - 'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept', - 'RegrR2', 'RegrSlope', 'RegrSXX', 'RegrSXY', 'RegrSYY', 'StatAggregate', + "CovarPop", + "Corr", + "RegrAvgX", + "RegrAvgY", + "RegrCount", + "RegrIntercept", + "RegrR2", + "RegrSlope", + "RegrSXX", + "RegrSXY", + "RegrSYY", + "StatAggregate", ] @@ -11,53 +21,55 @@ class StatAggregate(Aggregate): def __init__(self, y, x, output_field=None, filter=None, default=None): if not x or not y: - raise 
ValueError('Both y and x must be provided.') - super().__init__(y, x, output_field=output_field, filter=filter, default=default) + raise ValueError("Both y and x must be provided.") + super().__init__( + y, x, output_field=output_field, filter=filter, default=default + ) class Corr(StatAggregate): - function = 'CORR' + function = "CORR" class CovarPop(StatAggregate): def __init__(self, y, x, sample=False, filter=None, default=None): - self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' + self.function = "COVAR_SAMP" if sample else "COVAR_POP" super().__init__(y, x, filter=filter, default=default) class RegrAvgX(StatAggregate): - function = 'REGR_AVGX' + function = "REGR_AVGX" class RegrAvgY(StatAggregate): - function = 'REGR_AVGY' + function = "REGR_AVGY" class RegrCount(StatAggregate): - function = 'REGR_COUNT' + function = "REGR_COUNT" output_field = IntegerField() empty_result_set_value = 0 class RegrIntercept(StatAggregate): - function = 'REGR_INTERCEPT' + function = "REGR_INTERCEPT" class RegrR2(StatAggregate): - function = 'REGR_R2' + function = "REGR_R2" class RegrSlope(StatAggregate): - function = 'REGR_SLOPE' + function = "REGR_SLOPE" class RegrSXX(StatAggregate): - function = 'REGR_SXX' + function = "REGR_SXX" class RegrSXY(StatAggregate): - function = 'REGR_SXY' + function = "REGR_SXY" class RegrSYY(StatAggregate): - function = 'REGR_SYY' + function = "REGR_SYY" diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py index b8ec85b7a4..d917201f05 100644 --- a/django/contrib/postgres/apps.py +++ b/django/contrib/postgres/apps.py @@ -1,6 +1,4 @@ -from psycopg2.extras import ( - DateRange, DateTimeRange, DateTimeTZRange, NumericRange, -) +from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange from django.apps import AppConfig from django.core.signals import setting_changed @@ -25,7 +23,11 @@ def uninstall_if_needed(setting, value, enter, **kwargs): Undo the effects of PostgresConfig.ready() when 
django.contrib.postgres is "uninstalled" by override_settings(). """ - if not enter and setting == 'INSTALLED_APPS' and 'django.contrib.postgres' not in set(value): + if ( + not enter + and setting == "INSTALLED_APPS" + and "django.contrib.postgres" not in set(value) + ): connection_created.disconnect(register_type_handlers) CharField._unregister_lookup(Unaccent) TextField._unregister_lookup(Unaccent) @@ -43,21 +45,23 @@ def uninstall_if_needed(setting, value, enter, **kwargs): class PostgresConfig(AppConfig): - name = 'django.contrib.postgres' - verbose_name = _('PostgreSQL extensions') + name = "django.contrib.postgres" + verbose_name = _("PostgreSQL extensions") def ready(self): setting_changed.connect(uninstall_if_needed) # Connections may already exist before we are called. for conn in connections.all(): - if conn.vendor == 'postgresql': - conn.introspection.data_types_reverse.update({ - 3904: 'django.contrib.postgres.fields.IntegerRangeField', - 3906: 'django.contrib.postgres.fields.DecimalRangeField', - 3910: 'django.contrib.postgres.fields.DateTimeRangeField', - 3912: 'django.contrib.postgres.fields.DateRangeField', - 3926: 'django.contrib.postgres.fields.BigIntegerRangeField', - }) + if conn.vendor == "postgresql": + conn.introspection.data_types_reverse.update( + { + 3904: "django.contrib.postgres.fields.IntegerRangeField", + 3906: "django.contrib.postgres.fields.DecimalRangeField", + 3910: "django.contrib.postgres.fields.DateTimeRangeField", + 3912: "django.contrib.postgres.fields.DateRangeField", + 3926: "django.contrib.postgres.fields.BigIntegerRangeField", + } + ) if conn.connection is not None: register_type_handlers(conn) connection_created.connect(register_type_handlers) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 06ccd3616e..c0bc0c444f 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -10,71 +10,69 @@ from django.db.models.indexes import 
IndexExpression from django.db.models.sql import Query from django.utils.deprecation import RemovedInDjango50Warning -__all__ = ['ExclusionConstraint'] +__all__ = ["ExclusionConstraint"] class ExclusionConstraintExpression(IndexExpression): - template = '%(expressions)s WITH %(operator)s' + template = "%(expressions)s WITH %(operator)s" class ExclusionConstraint(BaseConstraint): - template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s' + template = "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s" def __init__( - self, *, name, expressions, index_type=None, condition=None, - deferrable=None, include=None, opclasses=(), + self, + *, + name, + expressions, + index_type=None, + condition=None, + deferrable=None, + include=None, + opclasses=(), ): - if index_type and index_type.lower() not in {'gist', 'spgist'}: + if index_type and index_type.lower() not in {"gist", "spgist"}: raise ValueError( - 'Exclusion constraints only support GiST or SP-GiST indexes.' + "Exclusion constraints only support GiST or SP-GiST indexes." ) if not expressions: raise ValueError( - 'At least one expression is required to define an exclusion ' - 'constraint.' + "At least one expression is required to define an exclusion " + "constraint." ) if not all( - isinstance(expr, (list, tuple)) and len(expr) == 2 - for expr in expressions + isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions ): - raise ValueError('The expressions must be a list of 2-tuples.') + raise ValueError("The expressions must be a list of 2-tuples.") if not isinstance(condition, (type(None), Q)): - raise ValueError( - 'ExclusionConstraint.condition must be a Q instance.' - ) + raise ValueError("ExclusionConstraint.condition must be a Q instance.") if condition and deferrable: - raise ValueError( - 'ExclusionConstraint with conditions cannot be deferred.' 
- ) + raise ValueError("ExclusionConstraint with conditions cannot be deferred.") if not isinstance(deferrable, (type(None), Deferrable)): raise ValueError( - 'ExclusionConstraint.deferrable must be a Deferrable instance.' + "ExclusionConstraint.deferrable must be a Deferrable instance." ) if not isinstance(include, (type(None), list, tuple)): - raise ValueError( - 'ExclusionConstraint.include must be a list or tuple.' - ) + raise ValueError("ExclusionConstraint.include must be a list or tuple.") if not isinstance(opclasses, (list, tuple)): - raise ValueError( - 'ExclusionConstraint.opclasses must be a list or tuple.' - ) + raise ValueError("ExclusionConstraint.opclasses must be a list or tuple.") if opclasses and len(expressions) != len(opclasses): raise ValueError( - 'ExclusionConstraint.expressions and ' - 'ExclusionConstraint.opclasses must have the same number of ' - 'elements.' + "ExclusionConstraint.expressions and " + "ExclusionConstraint.opclasses must have the same number of " + "elements." 
) self.expressions = expressions - self.index_type = index_type or 'GIST' + self.index_type = index_type or "GIST" self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () self.opclasses = opclasses if self.opclasses: warnings.warn( - 'The opclasses argument is deprecated in favor of using ' - 'django.contrib.postgres.indexes.OpClass in ' - 'ExclusionConstraint.expressions.', + "The opclasses argument is deprecated in favor of using " + "django.contrib.postgres.indexes.OpClass in " + "ExclusionConstraint.expressions.", category=RemovedInDjango50Warning, stacklevel=2, ) @@ -107,14 +105,18 @@ class ExclusionConstraint(BaseConstraint): expressions = self._get_expressions(schema_editor, query) table = model._meta.db_table condition = self._get_condition_sql(compiler, schema_editor, query) - include = [model._meta.get_field(field_name).column for field_name in self.include] + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] return Statement( self.template, table=Table(table, schema_editor.quote_name), name=schema_editor.quote_name(self.name), index_type=self.index_type, - expressions=Expressions(table, expressions, compiler, schema_editor.quote_value), - where=' WHERE (%s)' % condition if condition else '', + expressions=Expressions( + table, expressions, compiler, schema_editor.quote_value + ), + where=" WHERE (%s)" % condition if condition else "", include=schema_editor._index_include_sql(model, include), deferrable=schema_editor._deferrable_constraint_sql(self.deferrable), ) @@ -122,7 +124,7 @@ class ExclusionConstraint(BaseConstraint): def create_sql(self, model, schema_editor): self.check_supported(schema_editor) return Statement( - 'ALTER TABLE %(table)s ADD %(constraint)s', + "ALTER TABLE %(table)s ADD %(constraint)s", table=Table(model._meta.db_table, schema_editor.quote_name), constraint=self.constraint_sql(model, schema_editor), ) @@ -136,60 +138,60 @@ class 
ExclusionConstraint(BaseConstraint): def check_supported(self, schema_editor): if ( - self.include and - self.index_type.lower() == 'gist' and - not schema_editor.connection.features.supports_covering_gist_indexes + self.include + and self.index_type.lower() == "gist" + and not schema_editor.connection.features.supports_covering_gist_indexes ): raise NotSupportedError( - 'Covering exclusion constraints using a GiST index require ' - 'PostgreSQL 12+.' + "Covering exclusion constraints using a GiST index require " + "PostgreSQL 12+." ) if ( - self.include and - self.index_type.lower() == 'spgist' and - not schema_editor.connection.features.supports_covering_spgist_indexes + self.include + and self.index_type.lower() == "spgist" + and not schema_editor.connection.features.supports_covering_spgist_indexes ): raise NotSupportedError( - 'Covering exclusion constraints using an SP-GiST index ' - 'require PostgreSQL 14+.' + "Covering exclusion constraints using an SP-GiST index " + "require PostgreSQL 14+." 
) def deconstruct(self): path, args, kwargs = super().deconstruct() - kwargs['expressions'] = self.expressions + kwargs["expressions"] = self.expressions if self.condition is not None: - kwargs['condition'] = self.condition - if self.index_type.lower() != 'gist': - kwargs['index_type'] = self.index_type + kwargs["condition"] = self.condition + if self.index_type.lower() != "gist": + kwargs["index_type"] = self.index_type if self.deferrable: - kwargs['deferrable'] = self.deferrable + kwargs["deferrable"] = self.deferrable if self.include: - kwargs['include'] = self.include + kwargs["include"] = self.include if self.opclasses: - kwargs['opclasses'] = self.opclasses + kwargs["opclasses"] = self.opclasses return path, args, kwargs def __eq__(self, other): if isinstance(other, self.__class__): return ( - self.name == other.name and - self.index_type == other.index_type and - self.expressions == other.expressions and - self.condition == other.condition and - self.deferrable == other.deferrable and - self.include == other.include and - self.opclasses == other.opclasses + self.name == other.name + and self.index_type == other.index_type + and self.expressions == other.expressions + and self.condition == other.condition + and self.deferrable == other.deferrable + and self.include == other.include + and self.opclasses == other.opclasses ) return super().__eq__(other) def __repr__(self): - return '<%s: index_type=%s expressions=%s name=%s%s%s%s%s>' % ( + return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % ( self.__class__.__qualname__, repr(self.index_type), repr(self.expressions), repr(self.name), - '' if self.condition is None else ' condition=%s' % self.condition, - '' if self.deferrable is None else ' deferrable=%r' % self.deferrable, - '' if not self.include else ' include=%s' % repr(self.include), - '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses), + "" if self.condition is None else " condition=%s" % self.condition, + "" if self.deferrable 
is None else " deferrable=%r" % self.deferrable, + "" if not self.include else " include=%s" % repr(self.include), + "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), ) diff --git a/django/contrib/postgres/expressions.py b/django/contrib/postgres/expressions.py index ea7cbe038d..469f4e9fb6 100644 --- a/django/contrib/postgres/expressions.py +++ b/django/contrib/postgres/expressions.py @@ -4,7 +4,7 @@ from django.utils.functional import cached_property class ArraySubquery(Subquery): - template = 'ARRAY(%(subquery)s)' + template = "ARRAY(%(subquery)s)" def __init__(self, queryset, **kwargs): super().__init__(queryset, **kwargs) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 9c1bb96b61..7269198674 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -12,38 +12,43 @@ from django.utils.translation import gettext_lazy as _ from ..utils import prefix_validation_error from .utils import AttributeSetter -__all__ = ['ArrayField'] +__all__ = ["ArrayField"] class ArrayField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), - 'nested_array_mismatch': _('Nested arrays must have the same length.'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), + "nested_array_mismatch": _("Nested arrays must have the same length."), } - _default_hint = ('list', '[]') + _default_hint = ("list", "[]") def __init__(self, base_field, size=None, **kwargs): self.base_field = base_field self.size = size if self.size: - self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)] + self.default_validators = [ + *self.default_validators, + ArrayMaxLengthValidator(self.size), + ] # For performance, only add a from_db_value() method if the base field # implements it. 
- if hasattr(self.base_field, 'from_db_value'): + if hasattr(self.base_field, "from_db_value"): self.from_db_value = self._from_db_value super().__init__(**kwargs) @property def model(self): try: - return self.__dict__['model'] + return self.__dict__["model"] except KeyError: - raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) @model.setter def model(self, model): - self.__dict__['model'] = model + self.__dict__["model"] = model self.base_field.model = model @classmethod @@ -55,21 +60,23 @@ class ArrayField(CheckFieldDefaultMixin, Field): if self.base_field.remote_field: errors.append( checks.Error( - 'Base field for array cannot be a related field.', + "Base field for array cannot be a related field.", obj=self, - id='postgres.E002' + id="postgres.E002", ) ) else: # Remove the field name checks as they are not needed here. base_errors = self.base_field.check() if base_errors: - messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors) + messages = "\n ".join( + "%s (%s)" % (error.msg, error.id) for error in base_errors + ) errors.append( checks.Error( - 'Base field for array has errors:\n %s' % messages, + "Base field for array has errors:\n %s" % messages, obj=self, - id='postgres.E001' + id="postgres.E001", ) ) return errors @@ -80,32 +87,37 @@ class ArrayField(CheckFieldDefaultMixin, Field): @property def description(self): - return 'Array of %s' % self.base_field.description + return "Array of %s" % self.base_field.description def db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.db_type(connection), size) + size = self.size or "" + return "%s[%s]" % (self.base_field.db_type(connection), size) def cast_db_type(self, connection): - size = self.size or '' - return '%s[%s]' % (self.base_field.cast_db_type(connection), size) + size = self.size or "" + return "%s[%s]" % 
(self.base_field.cast_db_type(connection), size) def get_placeholder(self, value, compiler, connection): - return '%s::{}'.format(self.db_type(connection)) + return "%s::{}".format(self.db_type(connection)) def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, (list, tuple)): - return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] + return [ + self.base_field.get_db_prep_value(i, connection, prepared=False) + for i in value + ] return value def deconstruct(self): name, path, args, kwargs = super().deconstruct() - if path == 'django.contrib.postgres.fields.array.ArrayField': - path = 'django.contrib.postgres.fields.ArrayField' - kwargs.update({ - 'base_field': self.base_field.clone(), - 'size': self.size, - }) + if path == "django.contrib.postgres.fields.array.ArrayField": + path = "django.contrib.postgres.fields.ArrayField" + kwargs.update( + { + "base_field": self.base_field.clone(), + "size": self.size, + } + ) return name, path, args, kwargs def to_python(self, value): @@ -140,7 +152,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): transform = super().get_transform(name) if transform: return transform - if '_' not in name: + if "_" not in name: try: index = int(name) except ValueError: @@ -149,7 +161,7 @@ class ArrayField(CheckFieldDefaultMixin, Field): index += 1 # postgres uses 1-indexing return IndexTransformFactory(index, self.base_field) try: - start, end = name.split('_') + start, end = name.split("_") start = int(start) + 1 end = int(end) # don't add one here because postgres slices are weird except ValueError: @@ -165,15 +177,15 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) if isinstance(self.base_field, 
ArrayField): if len({len(i) for i in value}) > 1: raise exceptions.ValidationError( - self.error_messages['nested_array_mismatch'], - code='nested_array_mismatch', + self.error_messages["nested_array_mismatch"], + code="nested_array_mismatch", ) def run_validators(self, value): @@ -184,18 +196,20 @@ class ArrayField(CheckFieldDefaultMixin, Field): except exceptions.ValidationError as error: raise prefix_validation_error( error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, ) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': SimpleArrayField, - 'base_field': self.base_field.formfield(), - 'max_length': self.size, - **kwargs, - }) + return super().formfield( + **{ + "form_class": SimpleArrayField, + "base_field": self.base_field.formfield(), + "max_length": self.size, + **kwargs, + } + ) class ArrayRHSMixin: @@ -203,21 +217,21 @@ class ArrayRHSMixin: if isinstance(rhs, (tuple, list)): expressions = [] for value in rhs: - if not hasattr(value, 'resolve_expression'): + if not hasattr(value, "resolve_expression"): field = lhs.output_field value = Value(field.base_field.get_prep_value(value)) expressions.append(value) rhs = Func( *expressions, - function='ARRAY', - template='%(function)s[%(expressions)s]', + function="ARRAY", + template="%(function)s[%(expressions)s]", ) super().__init__(lhs, rhs) def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) cast_type = self.lhs.output_field.cast_db_type(connection) - return '%s::%s' % (rhs, cast_type), rhs_params + return "%s::%s" % (rhs, cast_type), rhs_params @ArrayField.register_lookup @@ -242,29 +256,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): @ArrayField.register_lookup class ArrayLenTransform(Transform): - lookup_name = 'len' + lookup_name = "len" output_field = IntegerField() def 
as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) # Distinguish NULL and empty arrays return ( - 'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE ' - 'coalesce(array_length(%(lhs)s, 1), 0) END' - ) % {'lhs': lhs}, params + "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " + "coalesce(array_length(%(lhs)s, 1), 0) END" + ) % {"lhs": lhs}, params @ArrayField.register_lookup class ArrayInLookup(In): def get_prep_lookup(self): values = super().get_prep_lookup() - if hasattr(values, 'resolve_expression'): + if hasattr(values, "resolve_expression"): return values # In.process_rhs() expects values to be hashable, so convert lists # to tuples. prepared_values = [] for value in values: - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): prepared_values.append(value) else: prepared_values.append(tuple(value)) @@ -272,7 +286,6 @@ class ArrayInLookup(In): class IndexTransform(Transform): - def __init__(self, index, base_field, *args, **kwargs): super().__init__(*args, **kwargs) self.index = index @@ -280,7 +293,7 @@ class IndexTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s]' % lhs, params + [self.index] + return "%s[%%s]" % lhs, params + [self.index] @property def output_field(self): @@ -288,7 +301,6 @@ class IndexTransform(Transform): class IndexTransformFactory: - def __init__(self, index, base_field): self.index = index self.base_field = base_field @@ -298,7 +310,6 @@ class IndexTransformFactory: class SliceTransform(Transform): - def __init__(self, start, end, *args, **kwargs): super().__init__(*args, **kwargs) self.start = start @@ -306,11 +317,10 @@ class SliceTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '%s[%%s:%%s]' % lhs, params + [self.start, self.end] + return "%s[%%s:%%s]" % lhs, params + [self.start, self.end] class SliceTransformFactory: - def __init__(self, start, end): 
self.start = start self.end = end diff --git a/django/contrib/postgres/fields/citext.py b/django/contrib/postgres/fields/citext.py index 46f6d3d1c2..2b943614d2 100644 --- a/django/contrib/postgres/fields/citext.py +++ b/django/contrib/postgres/fields/citext.py @@ -1,15 +1,14 @@ from django.db.models import CharField, EmailField, TextField -__all__ = ['CICharField', 'CIEmailField', 'CIText', 'CITextField'] +__all__ = ["CICharField", "CIEmailField", "CIText", "CITextField"] class CIText: - def get_internal_type(self): - return 'CI' + super().get_internal_type() + return "CI" + super().get_internal_type() def db_type(self, connection): - return 'citext' + return "citext" class CICharField(CIText, CharField): diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py index 2ec5766041..cfc156ab59 100644 --- a/django/contrib/postgres/fields/hstore.py +++ b/django/contrib/postgres/fields/hstore.py @@ -7,19 +7,19 @@ from django.db.models import Field, TextField, Transform from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.utils.translation import gettext_lazy as _ -__all__ = ['HStoreField'] +__all__ = ["HStoreField"] class HStoreField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False - description = _('Map of strings to strings/nulls') + description = _("Map of strings to strings/nulls") default_error_messages = { - 'not_a_string': _('The value of “%(key)s” is not a string or null.'), + "not_a_string": _("The value of “%(key)s” is not a string or null."), } - _default_hint = ('dict', '{}') + _default_hint = ("dict", "{}") def db_type(self, connection): - return 'hstore' + return "hstore" def get_transform(self, name): transform = super().get_transform(name) @@ -32,9 +32,9 @@ class HStoreField(CheckFieldDefaultMixin, Field): for key, val in value.items(): if not isinstance(val, str) and val is not None: raise exceptions.ValidationError( - self.error_messages['not_a_string'], - code='not_a_string', 
- params={'key': key}, + self.error_messages["not_a_string"], + code="not_a_string", + params={"key": key}, ) def to_python(self, value): @@ -46,10 +46,12 @@ class HStoreField(CheckFieldDefaultMixin, Field): return json.dumps(self.value_from_object(obj)) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.HStoreField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.HStoreField, + **kwargs, + } + ) def get_prep_value(self, value): value = super().get_prep_value(value) @@ -85,11 +87,10 @@ class KeyTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return '(%s -> %%s)' % lhs, tuple(params) + (self.key_name,) + return "(%s -> %%s)" % lhs, tuple(params) + (self.key_name,) class KeyTransformFactory: - def __init__(self, key_name): self.key_name = key_name @@ -99,13 +100,13 @@ class KeyTransformFactory: @HStoreField.register_lookup class KeysTransform(Transform): - lookup_name = 'keys' - function = 'akeys' + lookup_name = "keys" + function = "akeys" output_field = ArrayField(TextField()) @HStoreField.register_lookup class ValuesTransform(Transform): - lookup_name = 'values' - function = 'avals' + lookup_name = "values" + function = "avals" output_field = ArrayField(TextField()) diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index 29e8480665..760b5d8398 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -1,14 +1,14 @@ from django.db.models import JSONField as BuiltinJSONField -__all__ = ['JSONField'] +__all__ = ["JSONField"] class JSONField(BuiltinJSONField): system_check_removed_details = { - 'msg': ( - 'django.contrib.postgres.fields.JSONField is removed except for ' - 'support in historical migrations.' + "msg": ( + "django.contrib.postgres.fields.JSONField is removed except for " + "support in historical migrations." 
), - 'hint': 'Use django.db.models.JSONField instead.', - 'id': 'fields.E904', + "hint": "Use django.db.models.JSONField instead.", + "id": "fields.E904", } diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index b395e213a1..58ffacbbe5 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -10,17 +10,23 @@ from django.db.models.lookups import PostgresOperatorLookup from .utils import AttributeSetter __all__ = [ - 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField', - 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField', - 'RangeBoundary', 'RangeOperators', + "RangeField", + "IntegerRangeField", + "BigIntegerRangeField", + "DecimalRangeField", + "DateTimeRangeField", + "DateRangeField", + "RangeBoundary", + "RangeOperators", ] class RangeBoundary(models.Expression): """A class that represents range boundaries.""" + def __init__(self, inclusive_lower=True, inclusive_upper=False): - self.lower = '[' if inclusive_lower else '(' - self.upper = ']' if inclusive_upper else ')' + self.lower = "[" if inclusive_lower else "(" + self.upper = "]" if inclusive_upper else ")" def as_sql(self, compiler, connection): return "'%s%s'" % (self.lower, self.upper), [] @@ -28,41 +34,43 @@ class RangeBoundary(models.Expression): class RangeOperators: # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE - EQUAL = '=' - NOT_EQUAL = '<>' - CONTAINS = '@>' - CONTAINED_BY = '<@' - OVERLAPS = '&&' - FULLY_LT = '<<' - FULLY_GT = '>>' - NOT_LT = '&>' - NOT_GT = '&<' - ADJACENT_TO = '-|-' + EQUAL = "=" + NOT_EQUAL = "<>" + CONTAINS = "@>" + CONTAINED_BY = "<@" + OVERLAPS = "&&" + FULLY_LT = "<<" + FULLY_GT = ">>" + NOT_LT = "&>" + NOT_GT = "&<" + ADJACENT_TO = "-|-" class RangeField(models.Field): empty_strings_allowed = False def __init__(self, *args, **kwargs): - if 'default_bounds' in kwargs: + if "default_bounds" in kwargs: raise TypeError( f"Cannot use 
'default_bounds' with {self.__class__.__name__}." ) # Initializing base_field here ensures that its model matches the model for self. - if hasattr(self, 'base_field'): + if hasattr(self, "base_field"): self.base_field = self.base_field() super().__init__(*args, **kwargs) @property def model(self): try: - return self.__dict__['model'] + return self.__dict__["model"] except KeyError: - raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'model'" % self.__class__.__name__ + ) @model.setter def model(self, model): - self.__dict__['model'] = model + self.__dict__["model"] = model self.base_field.model = model @classmethod @@ -82,7 +90,7 @@ class RangeField(models.Field): if isinstance(value, str): # Assume we're deserializing vals = json.loads(value) - for end in ('lower', 'upper'): + for end in ("lower", "upper"): if end in vals: vals[end] = self.base_field.to_python(vals[end]) value = self.range_type(**vals) @@ -102,7 +110,7 @@ class RangeField(models.Field): return json.dumps({"empty": True}) base_field = self.base_field result = {"bounds": value._bounds} - for end in ('lower', 'upper'): + for end in ("lower", "upper"): val = getattr(value, end) if val is None: result[end] = None @@ -112,11 +120,11 @@ class RangeField(models.Field): return json.dumps(result) def formfield(self, **kwargs): - kwargs.setdefault('form_class', self.form_field) + kwargs.setdefault("form_class", self.form_field) return super().formfield(**kwargs) -CANONICAL_RANGE_BOUNDS = '[)' +CANONICAL_RANGE_BOUNDS = "[)" class ContinuousRangeField(RangeField): @@ -126,7 +134,7 @@ class ContinuousRangeField(RangeField): """ def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs): - if default_bounds not in ('[)', '(]', '()', '[]'): + if default_bounds not in ("[)", "(]", "()", "[]"): raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.") self.default_bounds = default_bounds 
super().__init__(*args, **kwargs) @@ -137,13 +145,13 @@ class ContinuousRangeField(RangeField): return super().get_prep_value(value) def formfield(self, **kwargs): - kwargs.setdefault('default_bounds', self.default_bounds) + kwargs.setdefault("default_bounds", self.default_bounds) return super().formfield(**kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS: - kwargs['default_bounds'] = self.default_bounds + kwargs["default_bounds"] = self.default_bounds return name, path, args, kwargs @@ -153,7 +161,7 @@ class IntegerRangeField(RangeField): form_field = forms.IntegerRangeField def db_type(self, connection): - return 'int4range' + return "int4range" class BigIntegerRangeField(RangeField): @@ -162,7 +170,7 @@ class BigIntegerRangeField(RangeField): form_field = forms.IntegerRangeField def db_type(self, connection): - return 'int8range' + return "int8range" class DecimalRangeField(ContinuousRangeField): @@ -171,7 +179,7 @@ class DecimalRangeField(ContinuousRangeField): form_field = forms.DecimalRangeField def db_type(self, connection): - return 'numrange' + return "numrange" class DateTimeRangeField(ContinuousRangeField): @@ -180,7 +188,7 @@ class DateTimeRangeField(ContinuousRangeField): form_field = forms.DateTimeRangeField def db_type(self, connection): - return 'tstzrange' + return "tstzrange" class DateRangeField(RangeField): @@ -189,7 +197,7 @@ class DateRangeField(RangeField): form_field = forms.DateRangeField def db_type(self, connection): - return 'daterange' + return "daterange" RangeField.register_lookup(lookups.DataContains) @@ -202,7 +210,8 @@ class DateTimeRangeContains(PostgresOperatorLookup): Lookup for Date/DateTimeRange containment to cast the rhs to the correct type. 
""" - lookup_name = 'contains' + + lookup_name = "contains" postgres_operator = RangeOperators.CONTAINS def process_rhs(self, compiler, connection): @@ -215,16 +224,19 @@ class DateTimeRangeContains(PostgresOperatorLookup): def as_postgresql(self, compiler, connection): sql, params = super().as_postgresql(compiler, connection) # Cast the rhs if needed. - cast_sql = '' + cast_sql = "" if ( - isinstance(self.rhs, models.Expression) and - self.rhs._output_field_or_none and + isinstance(self.rhs, models.Expression) + and self.rhs._output_field_or_none + and # Skip cast if rhs has a matching range type. - not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__) + not isinstance( + self.rhs._output_field_or_none, self.lhs.output_field.__class__ + ) ): cast_internal_type = self.lhs.output_field.base_field.get_internal_type() - cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type)) - return '%s%s' % (sql, cast_sql), params + cast_sql = "::{}".format(connection.data_types.get(cast_internal_type)) + return "%s%s" % (sql, cast_sql), params DateRangeField.register_lookup(DateTimeRangeContains) @@ -232,31 +244,31 @@ DateTimeRangeField.register_lookup(DateTimeRangeContains) class RangeContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' + lookup_name = "contained_by" type_mapping = { - 'smallint': 'int4range', - 'integer': 'int4range', - 'bigint': 'int8range', - 'double precision': 'numrange', - 'numeric': 'numrange', - 'date': 'daterange', - 'timestamp with time zone': 'tstzrange', + "smallint": "int4range", + "integer": "int4range", + "bigint": "int8range", + "double precision": "numrange", + "numeric": "numrange", + "date": "daterange", + "timestamp with time zone": "tstzrange", } postgres_operator = RangeOperators.CONTAINED_BY def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) # Ignore precision for DecimalFields. 
- db_type = self.lhs.output_field.cast_db_type(connection).split('(')[0] + db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0] cast_type = self.type_mapping[db_type] - return '%s::%s' % (rhs, cast_type), rhs_params + return "%s::%s" % (rhs, cast_type), rhs_params def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) if isinstance(self.lhs.output_field, models.FloatField): - lhs = '%s::numeric' % lhs + lhs = "%s::numeric" % lhs elif isinstance(self.lhs.output_field, models.SmallIntegerField): - lhs = '%s::integer' % lhs + lhs = "%s::integer" % lhs return lhs, lhs_params def get_prep_lookup(self): @@ -272,38 +284,38 @@ models.DecimalField.register_lookup(RangeContainedBy) @RangeField.register_lookup class FullyLessThan(PostgresOperatorLookup): - lookup_name = 'fully_lt' + lookup_name = "fully_lt" postgres_operator = RangeOperators.FULLY_LT @RangeField.register_lookup class FullGreaterThan(PostgresOperatorLookup): - lookup_name = 'fully_gt' + lookup_name = "fully_gt" postgres_operator = RangeOperators.FULLY_GT @RangeField.register_lookup class NotLessThan(PostgresOperatorLookup): - lookup_name = 'not_lt' + lookup_name = "not_lt" postgres_operator = RangeOperators.NOT_LT @RangeField.register_lookup class NotGreaterThan(PostgresOperatorLookup): - lookup_name = 'not_gt' + lookup_name = "not_gt" postgres_operator = RangeOperators.NOT_GT @RangeField.register_lookup class AdjacentToLookup(PostgresOperatorLookup): - lookup_name = 'adjacent_to' + lookup_name = "adjacent_to" postgres_operator = RangeOperators.ADJACENT_TO @RangeField.register_lookup class RangeStartsWith(models.Transform): - lookup_name = 'startswith' - function = 'lower' + lookup_name = "startswith" + function = "lower" @property def output_field(self): @@ -312,8 +324,8 @@ class RangeStartsWith(models.Transform): @RangeField.register_lookup class RangeEndsWith(models.Transform): - lookup_name = 'endswith' - function = 'upper' + lookup_name = 
"endswith" + function = "upper" @property def output_field(self): @@ -322,34 +334,34 @@ class RangeEndsWith(models.Transform): @RangeField.register_lookup class IsEmpty(models.Transform): - lookup_name = 'isempty' - function = 'isempty' + lookup_name = "isempty" + function = "isempty" output_field = models.BooleanField() @RangeField.register_lookup class LowerInclusive(models.Transform): - lookup_name = 'lower_inc' - function = 'LOWER_INC' + lookup_name = "lower_inc" + function = "LOWER_INC" output_field = models.BooleanField() @RangeField.register_lookup class LowerInfinite(models.Transform): - lookup_name = 'lower_inf' - function = 'LOWER_INF' + lookup_name = "lower_inf" + function = "LOWER_INF" output_field = models.BooleanField() @RangeField.register_lookup class UpperInclusive(models.Transform): - lookup_name = 'upper_inc' - function = 'UPPER_INC' + lookup_name = "upper_inc" + function = "UPPER_INC" output_field = models.BooleanField() @RangeField.register_lookup class UpperInfinite(models.Transform): - lookup_name = 'upper_inf' - function = 'UPPER_INF' + lookup_name = "upper_inf" + function = "UPPER_INF" output_field = models.BooleanField() diff --git a/django/contrib/postgres/forms/array.py b/django/contrib/postgres/forms/array.py index 2e19cd574a..ddb022afc3 100644 --- a/django/contrib/postgres/forms/array.py +++ b/django/contrib/postgres/forms/array.py @@ -3,7 +3,8 @@ from itertools import chain from django import forms from django.contrib.postgres.validators import ( - ArrayMaxLengthValidator, ArrayMinLengthValidator, + ArrayMaxLengthValidator, + ArrayMinLengthValidator, ) from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ @@ -13,10 +14,12 @@ from ..utils import prefix_validation_error class SimpleArrayField(forms.CharField): default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), } - def 
__init__(self, base_field, *, delimiter=',', max_length=None, min_length=None, **kwargs): + def __init__( + self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs + ): self.base_field = base_field self.delimiter = delimiter super().__init__(**kwargs) @@ -33,7 +36,9 @@ class SimpleArrayField(forms.CharField): def prepare_value(self, value): if isinstance(value, list): - return self.delimiter.join(str(self.base_field.prepare_value(v)) for v in value) + return self.delimiter.join( + str(self.base_field.prepare_value(v)) for v in value + ) return value def to_python(self, value): @@ -49,12 +54,14 @@ class SimpleArrayField(forms.CharField): try: values.append(self.base_field.to_python(item)) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) return values @@ -66,12 +73,14 @@ class SimpleArrayField(forms.CharField): try: self.base_field.validate(item) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) @@ -82,12 +91,14 @@ class SimpleArrayField(forms.CharField): try: self.base_field.run_validators(item) except ValidationError as error: - errors.append(prefix_validation_error( - error, - prefix=self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + 
code="item_invalid", + params={"nth": index + 1}, + ) + ) if errors: raise ValidationError(errors) @@ -103,7 +114,7 @@ class SimpleArrayField(forms.CharField): class SplitArrayWidget(forms.Widget): - template_name = 'postgres/widgets/split_array.html' + template_name = "postgres/widgets/split_array.html" def __init__(self, widget, size, **kwargs): self.widget = widget() if isinstance(widget, type) else widget @@ -115,19 +126,21 @@ class SplitArrayWidget(forms.Widget): return self.widget.is_hidden def value_from_datadict(self, data, files, name): - return [self.widget.value_from_datadict(data, files, '%s_%s' % (name, index)) - for index in range(self.size)] + return [ + self.widget.value_from_datadict(data, files, "%s_%s" % (name, index)) + for index in range(self.size) + ] def value_omitted_from_data(self, data, files, name): return all( - self.widget.value_omitted_from_data(data, files, '%s_%s' % (name, index)) + self.widget.value_omitted_from_data(data, files, "%s_%s" % (name, index)) for index in range(self.size) ) def id_for_label(self, id_): # See the comment for RadioSelect.id_for_label() if id_: - id_ += '_0' + id_ += "_0" return id_ def get_context(self, name, value, attrs=None): @@ -136,18 +149,20 @@ class SplitArrayWidget(forms.Widget): if self.is_localized: self.widget.is_localized = self.is_localized value = value or [] - context['widget']['subwidgets'] = [] + context["widget"]["subwidgets"] = [] final_attrs = self.build_attrs(attrs) - id_ = final_attrs.get('id') + id_ = final_attrs.get("id") for i in range(max(len(value), self.size)): try: widget_value = value[i] except IndexError: widget_value = None if id_: - final_attrs = {**final_attrs, 'id': '%s_%s' % (id_, i)} - context['widget']['subwidgets'].append( - self.widget.get_context(name + '_%s' % i, widget_value, final_attrs)['widget'] + final_attrs = {**final_attrs, "id": "%s_%s" % (id_, i)} + context["widget"]["subwidgets"].append( + self.widget.get_context(name + "_%s" % i, widget_value, 
final_attrs)[ + "widget" + ] ) return context @@ -167,7 +182,7 @@ class SplitArrayWidget(forms.Widget): class SplitArrayField(forms.Field): default_error_messages = { - 'item_invalid': _('Item %(nth)s in the array did not validate:'), + "item_invalid": _("Item %(nth)s in the array did not validate:"), } def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs): @@ -175,7 +190,7 @@ class SplitArrayField(forms.Field): self.size = size self.remove_trailing_nulls = remove_trailing_nulls widget = SplitArrayWidget(widget=base_field.widget, size=size) - kwargs.setdefault('widget', widget) + kwargs.setdefault("widget", widget) super().__init__(**kwargs) def _remove_trailing_nulls(self, values): @@ -198,19 +213,21 @@ class SplitArrayField(forms.Field): cleaned_data = [] errors = [] if not any(value) and self.required: - raise ValidationError(self.error_messages['required']) + raise ValidationError(self.error_messages["required"]) max_size = max(self.size, len(value)) for index in range(max_size): item = value[index] try: cleaned_data.append(self.base_field.clean(item)) except ValidationError as error: - errors.append(prefix_validation_error( - error, - self.error_messages['item_invalid'], - code='item_invalid', - params={'nth': index + 1}, - )) + errors.append( + prefix_validation_error( + error, + self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) cleaned_data.append(None) else: errors.append(None) diff --git a/django/contrib/postgres/forms/hstore.py b/django/contrib/postgres/forms/hstore.py index f5af8f10e3..6a20f7b729 100644 --- a/django/contrib/postgres/forms/hstore.py +++ b/django/contrib/postgres/forms/hstore.py @@ -4,17 +4,18 @@ from django import forms from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ -__all__ = ['HStoreField'] +__all__ = ["HStoreField"] class HStoreField(forms.CharField): """ A field for HStore data which accepts dictionary 
JSON input. """ + widget = forms.Textarea default_error_messages = { - 'invalid_json': _('Could not load JSON data.'), - 'invalid_format': _('Input must be a JSON dictionary.'), + "invalid_json": _("Could not load JSON data."), + "invalid_format": _("Input must be a JSON dictionary."), } def prepare_value(self, value): @@ -30,14 +31,14 @@ class HStoreField(forms.CharField): value = json.loads(value) except json.JSONDecodeError: raise ValidationError( - self.error_messages['invalid_json'], - code='invalid_json', + self.error_messages["invalid_json"], + code="invalid_json", ) if not isinstance(value, dict): raise ValidationError( - self.error_messages['invalid_format'], - code='invalid_format', + self.error_messages["invalid_format"], + code="invalid_format", ) # Cast everything to strings for ease. diff --git a/django/contrib/postgres/forms/ranges.py b/django/contrib/postgres/forms/ranges.py index 9c673ab40c..444991970d 100644 --- a/django/contrib/postgres/forms/ranges.py +++ b/django/contrib/postgres/forms/ranges.py @@ -6,8 +6,13 @@ from django.forms.widgets import HiddenInput, MultiWidget from django.utils.translation import gettext_lazy as _ __all__ = [ - 'BaseRangeField', 'IntegerRangeField', 'DecimalRangeField', - 'DateTimeRangeField', 'DateRangeField', 'HiddenRangeWidget', 'RangeWidget', + "BaseRangeField", + "IntegerRangeField", + "DecimalRangeField", + "DateTimeRangeField", + "DateRangeField", + "HiddenRangeWidget", + "RangeWidget", ] @@ -24,27 +29,33 @@ class RangeWidget(MultiWidget): class HiddenRangeWidget(RangeWidget): """A widget that splits input into two <input type="hidden"> inputs.""" + def __init__(self, attrs=None): super().__init__(HiddenInput, attrs) class BaseRangeField(forms.MultiValueField): default_error_messages = { - 'invalid': _('Enter two valid values.'), - 'bound_ordering': _('The start of the range must not exceed the end of the range.'), + "invalid": _("Enter two valid values."), + "bound_ordering": _( + "The start of the range must 
not exceed the end of the range." + ), } hidden_widget = HiddenRangeWidget def __init__(self, **kwargs): - if 'widget' not in kwargs: - kwargs['widget'] = RangeWidget(self.base_field.widget) - if 'fields' not in kwargs: - kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)] - kwargs.setdefault('required', False) - kwargs.setdefault('require_all_fields', False) + if "widget" not in kwargs: + kwargs["widget"] = RangeWidget(self.base_field.widget) + if "fields" not in kwargs: + kwargs["fields"] = [ + self.base_field(required=False), + self.base_field(required=False), + ] + kwargs.setdefault("required", False) + kwargs.setdefault("require_all_fields", False) self.range_kwargs = {} - if default_bounds := kwargs.pop('default_bounds', None): - self.range_kwargs = {'bounds': default_bounds} + if default_bounds := kwargs.pop("default_bounds", None): + self.range_kwargs = {"bounds": default_bounds} super().__init__(**kwargs) def prepare_value(self, value): @@ -67,39 +78,39 @@ class BaseRangeField(forms.MultiValueField): lower, upper = values if lower is not None and upper is not None and lower > upper: raise exceptions.ValidationError( - self.error_messages['bound_ordering'], - code='bound_ordering', + self.error_messages["bound_ordering"], + code="bound_ordering", ) try: range_value = self.range_type(lower, upper, **self.range_kwargs) except TypeError: raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', + self.error_messages["invalid"], + code="invalid", ) else: return range_value class IntegerRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two whole numbers.')} + default_error_messages = {"invalid": _("Enter two whole numbers.")} base_field = forms.IntegerField range_type = NumericRange class DecimalRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two numbers.')} + default_error_messages = {"invalid": _("Enter two numbers.")} base_field = 
forms.DecimalField range_type = NumericRange class DateTimeRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two valid date/times.')} + default_error_messages = {"invalid": _("Enter two valid date/times.")} base_field = forms.DateTimeField range_type = DateTimeTZRange class DateRangeField(BaseRangeField): - default_error_messages = {'invalid': _('Enter two valid dates.')} + default_error_messages = {"invalid": _("Enter two valid dates.")} base_field = forms.DateField range_type = DateRange diff --git a/django/contrib/postgres/functions.py b/django/contrib/postgres/functions.py index 819ce058e5..f001a04fdc 100644 --- a/django/contrib/postgres/functions.py +++ b/django/contrib/postgres/functions.py @@ -2,10 +2,10 @@ from django.db.models import DateTimeField, Func, UUIDField class RandomUUID(Func): - template = 'GEN_RANDOM_UUID()' + template = "GEN_RANDOM_UUID()" output_field = UUIDField() class TransactionNow(Func): - template = 'CURRENT_TIMESTAMP' + template = "CURRENT_TIMESTAMP" output_field = DateTimeField() diff --git a/django/contrib/postgres/indexes.py b/django/contrib/postgres/indexes.py index 2e3de1d275..409f514147 100644 --- a/django/contrib/postgres/indexes.py +++ b/django/contrib/postgres/indexes.py @@ -3,13 +3,17 @@ from django.db.models import Func, Index from django.utils.functional import cached_property __all__ = [ - 'BloomIndex', 'BrinIndex', 'BTreeIndex', 'GinIndex', 'GistIndex', - 'HashIndex', 'SpGistIndex', + "BloomIndex", + "BrinIndex", + "BTreeIndex", + "GinIndex", + "GistIndex", + "HashIndex", + "SpGistIndex", ] class PostgresIndex(Index): - @cached_property def max_name_length(self): # Allow an index name longer than 30 characters when the suffix is @@ -18,14 +22,16 @@ class PostgresIndex(Index): # indexes. 
return Index.max_name_length - len(Index.suffix) + len(self.suffix) - def create_sql(self, model, schema_editor, using='', **kwargs): + def create_sql(self, model, schema_editor, using="", **kwargs): self.check_supported(schema_editor) - statement = super().create_sql(model, schema_editor, using=' USING %s' % self.suffix, **kwargs) + statement = super().create_sql( + model, schema_editor, using=" USING %s" % self.suffix, **kwargs + ) with_params = self.get_with_params() if with_params: - statement.parts['extra'] = 'WITH (%s) %s' % ( - ', '.join(with_params), - statement.parts['extra'], + statement.parts["extra"] = "WITH (%s) %s" % ( + ", ".join(with_params), + statement.parts["extra"], ) return statement @@ -37,25 +43,23 @@ class PostgresIndex(Index): class BloomIndex(PostgresIndex): - suffix = 'bloom' + suffix = "bloom" def __init__(self, *expressions, length=None, columns=(), **kwargs): super().__init__(*expressions, **kwargs) if len(self.fields) > 32: - raise ValueError('Bloom indexes support a maximum of 32 fields.') + raise ValueError("Bloom indexes support a maximum of 32 fields.") if not isinstance(columns, (list, tuple)): - raise ValueError('BloomIndex.columns must be a list or tuple.') + raise ValueError("BloomIndex.columns must be a list or tuple.") if len(columns) > len(self.fields): - raise ValueError( - 'BloomIndex.columns cannot have more values than fields.' 
- ) + raise ValueError("BloomIndex.columns cannot have more values than fields.") if not all(0 < col <= 4095 for col in columns): raise ValueError( - 'BloomIndex.columns must contain integers from 1 to 4095.', + "BloomIndex.columns must contain integers from 1 to 4095.", ) if length is not None and not 0 < length <= 4096: raise ValueError( - 'BloomIndex.length must be None or an integer from 1 to 4096.', + "BloomIndex.length must be None or an integer from 1 to 4096.", ) self.length = length self.columns = columns @@ -63,29 +67,30 @@ class BloomIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.length is not None: - kwargs['length'] = self.length + kwargs["length"] = self.length if self.columns: - kwargs['columns'] = self.columns + kwargs["columns"] = self.columns return path, args, kwargs def get_with_params(self): with_params = [] if self.length is not None: - with_params.append('length = %d' % self.length) + with_params.append("length = %d" % self.length) if self.columns: with_params.extend( - 'col%d = %d' % (i, v) - for i, v in enumerate(self.columns, start=1) + "col%d = %d" % (i, v) for i, v in enumerate(self.columns, start=1) ) return with_params class BrinIndex(PostgresIndex): - suffix = 'brin' + suffix = "brin" - def __init__(self, *expressions, autosummarize=None, pages_per_range=None, **kwargs): + def __init__( + self, *expressions, autosummarize=None, pages_per_range=None, **kwargs + ): if pages_per_range is not None and pages_per_range <= 0: - raise ValueError('pages_per_range must be None or a positive integer') + raise ValueError("pages_per_range must be None or a positive integer") self.autosummarize = autosummarize self.pages_per_range = pages_per_range super().__init__(*expressions, **kwargs) @@ -93,22 +98,24 @@ class BrinIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.autosummarize is not None: - kwargs['autosummarize'] = self.autosummarize + 
kwargs["autosummarize"] = self.autosummarize if self.pages_per_range is not None: - kwargs['pages_per_range'] = self.pages_per_range + kwargs["pages_per_range"] = self.pages_per_range return path, args, kwargs def get_with_params(self): with_params = [] if self.autosummarize is not None: - with_params.append('autosummarize = %s' % ('on' if self.autosummarize else 'off')) + with_params.append( + "autosummarize = %s" % ("on" if self.autosummarize else "off") + ) if self.pages_per_range is not None: - with_params.append('pages_per_range = %d' % self.pages_per_range) + with_params.append("pages_per_range = %d" % self.pages_per_range) return with_params class BTreeIndex(PostgresIndex): - suffix = 'btree' + suffix = "btree" def __init__(self, *expressions, fillfactor=None, **kwargs): self.fillfactor = fillfactor @@ -117,20 +124,22 @@ class BTreeIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params class GinIndex(PostgresIndex): - suffix = 'gin' + suffix = "gin" - def __init__(self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs): + def __init__( + self, *expressions, fastupdate=None, gin_pending_list_limit=None, **kwargs + ): self.fastupdate = fastupdate self.gin_pending_list_limit = gin_pending_list_limit super().__init__(*expressions, **kwargs) @@ -138,22 +147,24 @@ class GinIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fastupdate is not None: - kwargs['fastupdate'] = self.fastupdate + kwargs["fastupdate"] = self.fastupdate if self.gin_pending_list_limit is not None: - kwargs['gin_pending_list_limit'] = 
self.gin_pending_list_limit + kwargs["gin_pending_list_limit"] = self.gin_pending_list_limit return path, args, kwargs def get_with_params(self): with_params = [] if self.gin_pending_list_limit is not None: - with_params.append('gin_pending_list_limit = %d' % self.gin_pending_list_limit) + with_params.append( + "gin_pending_list_limit = %d" % self.gin_pending_list_limit + ) if self.fastupdate is not None: - with_params.append('fastupdate = %s' % ('on' if self.fastupdate else 'off')) + with_params.append("fastupdate = %s" % ("on" if self.fastupdate else "off")) return with_params class GistIndex(PostgresIndex): - suffix = 'gist' + suffix = "gist" def __init__(self, *expressions, buffering=None, fillfactor=None, **kwargs): self.buffering = buffering @@ -163,73 +174,76 @@ class GistIndex(PostgresIndex): def deconstruct(self): path, args, kwargs = super().deconstruct() if self.buffering is not None: - kwargs['buffering'] = self.buffering + kwargs["buffering"] = self.buffering if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor + kwargs["fillfactor"] = self.fillfactor return path, args, kwargs def get_with_params(self): with_params = [] if self.buffering is not None: - with_params.append('buffering = %s' % ('on' if self.buffering else 'off')) + with_params.append("buffering = %s" % ("on" if self.buffering else "off")) if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) - return with_params - - def check_supported(self, schema_editor): - if self.include and not schema_editor.connection.features.supports_covering_gist_indexes: - raise NotSupportedError('Covering GiST indexes require PostgreSQL 12+.') - - -class HashIndex(PostgresIndex): - suffix = 'hash' - - def __init__(self, *expressions, fillfactor=None, **kwargs): - self.fillfactor = fillfactor - super().__init__(*expressions, **kwargs) - - def deconstruct(self): - path, args, kwargs = super().deconstruct() - if self.fillfactor is not None: - 
kwargs['fillfactor'] = self.fillfactor - return path, args, kwargs - - def get_with_params(self): - with_params = [] - if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) - return with_params - - -class SpGistIndex(PostgresIndex): - suffix = 'spgist' - - def __init__(self, *expressions, fillfactor=None, **kwargs): - self.fillfactor = fillfactor - super().__init__(*expressions, **kwargs) - - def deconstruct(self): - path, args, kwargs = super().deconstruct() - if self.fillfactor is not None: - kwargs['fillfactor'] = self.fillfactor - return path, args, kwargs - - def get_with_params(self): - with_params = [] - if self.fillfactor is not None: - with_params.append('fillfactor = %d' % self.fillfactor) + with_params.append("fillfactor = %d" % self.fillfactor) return with_params def check_supported(self, schema_editor): if ( - self.include and - not schema_editor.connection.features.supports_covering_spgist_indexes + self.include + and not schema_editor.connection.features.supports_covering_gist_indexes ): - raise NotSupportedError('Covering SP-GiST indexes require PostgreSQL 14+.') + raise NotSupportedError("Covering GiST indexes require PostgreSQL 12+.") + + +class HashIndex(PostgresIndex): + suffix = "hash" + + def __init__(self, *expressions, fillfactor=None, **kwargs): + self.fillfactor = fillfactor + super().__init__(*expressions, **kwargs) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + if self.fillfactor is not None: + kwargs["fillfactor"] = self.fillfactor + return path, args, kwargs + + def get_with_params(self): + with_params = [] + if self.fillfactor is not None: + with_params.append("fillfactor = %d" % self.fillfactor) + return with_params + + +class SpGistIndex(PostgresIndex): + suffix = "spgist" + + def __init__(self, *expressions, fillfactor=None, **kwargs): + self.fillfactor = fillfactor + super().__init__(*expressions, **kwargs) + + def deconstruct(self): + path, args, kwargs = 
super().deconstruct() + if self.fillfactor is not None: + kwargs["fillfactor"] = self.fillfactor + return path, args, kwargs + + def get_with_params(self): + with_params = [] + if self.fillfactor is not None: + with_params.append("fillfactor = %d" % self.fillfactor) + return with_params + + def check_supported(self, schema_editor): + if ( + self.include + and not schema_editor.connection.features.supports_covering_spgist_indexes + ): + raise NotSupportedError("Covering SP-GiST indexes require PostgreSQL 14+.") class OpClass(Func): - template = '%(expressions)s %(name)s' + template = "%(expressions)s %(name)s" def __init__(self, expression, name): super().__init__(expression, name=name) diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index f7c6fc4b0c..9fed0eea30 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -5,61 +5,61 @@ from .search import SearchVector, SearchVectorExact, SearchVectorField class DataContains(PostgresOperatorLookup): - lookup_name = 'contains' - postgres_operator = '@>' + lookup_name = "contains" + postgres_operator = "@>" class ContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' - postgres_operator = '<@' + lookup_name = "contained_by" + postgres_operator = "<@" class Overlap(PostgresOperatorLookup): - lookup_name = 'overlap' - postgres_operator = '&&' + lookup_name = "overlap" + postgres_operator = "&&" class HasKey(PostgresOperatorLookup): - lookup_name = 'has_key' - postgres_operator = '?' + lookup_name = "has_key" + postgres_operator = "?" 
prepare_rhs = False class HasKeys(PostgresOperatorLookup): - lookup_name = 'has_keys' - postgres_operator = '?&' + lookup_name = "has_keys" + postgres_operator = "?&" def get_prep_lookup(self): return [str(item) for item in self.rhs] class HasAnyKeys(HasKeys): - lookup_name = 'has_any_keys' - postgres_operator = '?|' + lookup_name = "has_any_keys" + postgres_operator = "?|" class Unaccent(Transform): bilateral = True - lookup_name = 'unaccent' - function = 'UNACCENT' + lookup_name = "unaccent" + function = "UNACCENT" class SearchLookup(SearchVectorExact): - lookup_name = 'search' + lookup_name = "search" def process_lhs(self, qn, connection): if not isinstance(self.lhs.output_field, SearchVectorField): - config = getattr(self.rhs, 'config', None) + config = getattr(self.rhs, "config", None) self.lhs = SearchVector(self.lhs, config=config) lhs, lhs_params = super().process_lhs(qn, connection) return lhs, lhs_params class TrigramSimilar(PostgresOperatorLookup): - lookup_name = 'trigram_similar' - postgres_operator = '%%' + lookup_name = "trigram_similar" + postgres_operator = "%%" class TrigramWordSimilar(PostgresOperatorLookup): - lookup_name = 'trigram_word_similar' - postgres_operator = '%%>' + lookup_name = "trigram_word_similar" + postgres_operator = "%%>" diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py index 037bb4ec22..374f5ee1ec 100644 --- a/django/contrib/postgres/operations.py +++ b/django/contrib/postgres/operations.py @@ -1,5 +1,7 @@ from django.contrib.postgres.signals import ( - get_citext_oids, get_hstore_oids, register_type_handlers, + get_citext_oids, + get_hstore_oids, + register_type_handlers, ) from django.db import NotSupportedError, router from django.db.migrations import AddConstraint, AddIndex, RemoveIndex @@ -17,14 +19,14 @@ class CreateExtension(Operation): pass def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - 
not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return if not self.extension_exists(schema_editor, self.name): schema_editor.execute( - 'CREATE EXTENSION IF NOT EXISTS %s' % schema_editor.quote_name(self.name) + "CREATE EXTENSION IF NOT EXISTS %s" + % schema_editor.quote_name(self.name) ) # Clear cached, stale oids. get_hstore_oids.cache_clear() @@ -39,7 +41,7 @@ class CreateExtension(Operation): return if self.extension_exists(schema_editor, self.name): schema_editor.execute( - 'DROP EXTENSION IF EXISTS %s' % schema_editor.quote_name(self.name) + "DROP EXTENSION IF EXISTS %s" % schema_editor.quote_name(self.name) ) # Clear cached, stale oids. get_hstore_oids.cache_clear() @@ -48,7 +50,7 @@ class CreateExtension(Operation): def extension_exists(self, schema_editor, extension): with schema_editor.connection.cursor() as cursor: cursor.execute( - 'SELECT 1 FROM pg_extension WHERE extname = %s', + "SELECT 1 FROM pg_extension WHERE extname = %s", [extension], ) return bool(cursor.fetchone()) @@ -58,75 +60,67 @@ class CreateExtension(Operation): @property def migration_name_fragment(self): - return 'create_extension_%s' % self.name + return "create_extension_%s" % self.name class BloomExtension(CreateExtension): - def __init__(self): - self.name = 'bloom' + self.name = "bloom" class BtreeGinExtension(CreateExtension): - def __init__(self): - self.name = 'btree_gin' + self.name = "btree_gin" class BtreeGistExtension(CreateExtension): - def __init__(self): - self.name = 'btree_gist' + self.name = "btree_gist" class CITextExtension(CreateExtension): - def __init__(self): - self.name = 'citext' + self.name = "citext" class CryptoExtension(CreateExtension): - def __init__(self): - self.name = 'pgcrypto' + self.name = "pgcrypto" class HStoreExtension(CreateExtension): - def __init__(self): - self.name = 'hstore' + self.name = 
"hstore" class TrigramExtension(CreateExtension): - def __init__(self): - self.name = 'pg_trgm' + self.name = "pg_trgm" class UnaccentExtension(CreateExtension): - def __init__(self): - self.name = 'unaccent' + self.name = "unaccent" class NotInTransactionMixin: def _ensure_not_in_transaction(self, schema_editor): if schema_editor.connection.in_atomic_block: raise NotSupportedError( - 'The %s operation cannot be executed inside a transaction ' - '(set atomic = False on the migration).' - % self.__class__.__name__ + "The %s operation cannot be executed inside a transaction " + "(set atomic = False on the migration)." % self.__class__.__name__ ) class AddIndexConcurrently(NotInTransactionMixin, AddIndex): """Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax.""" + atomic = False def describe(self): - return 'Concurrently create index %s on field(s) %s of model %s' % ( + return "Concurrently create index %s on field(s) %s of model %s" % ( self.index.name, - ', '.join(self.index.fields), + ", ".join(self.index.fields), self.model_name, ) @@ -145,10 +139,11 @@ class AddIndexConcurrently(NotInTransactionMixin, AddIndex): class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex): """Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax.""" + atomic = False def describe(self): - return 'Concurrently remove index %s from %s' % (self.name, self.model_name) + return "Concurrently remove index %s from %s" % (self.name, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): self._ensure_not_in_transaction(schema_editor) @@ -168,7 +163,7 @@ class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex): class CollationOperation(Operation): - def __init__(self, name, locale, *, provider='libc', deterministic=True): + def __init__(self, name, locale, *, provider="libc", deterministic=True): self.name = name self.locale = locale self.provider = provider @@ -178,11 +173,11 @@ class 
CollationOperation(Operation): pass def deconstruct(self): - kwargs = {'name': self.name, 'locale': self.locale} - if self.provider and self.provider != 'libc': - kwargs['provider'] = self.provider + kwargs = {"name": self.name, "locale": self.locale} + if self.provider and self.provider != "libc": + kwargs["provider"] = self.provider if self.deterministic is False: - kwargs['deterministic'] = self.deterministic + kwargs["deterministic"] = self.deterministic return ( self.__class__.__qualname__, [], @@ -191,34 +186,39 @@ class CollationOperation(Operation): def create_collation(self, schema_editor): if ( - self.deterministic is False and - not schema_editor.connection.features.supports_non_deterministic_collations + self.deterministic is False + and not schema_editor.connection.features.supports_non_deterministic_collations ): raise NotSupportedError( - 'Non-deterministic collations require PostgreSQL 12+.' + "Non-deterministic collations require PostgreSQL 12+." ) - args = {'locale': schema_editor.quote_name(self.locale)} - if self.provider != 'libc': - args['provider'] = schema_editor.quote_name(self.provider) + args = {"locale": schema_editor.quote_name(self.locale)} + if self.provider != "libc": + args["provider"] = schema_editor.quote_name(self.provider) if self.deterministic is False: - args['deterministic'] = 'false' - schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % { - 'name': schema_editor.quote_name(self.name), - 'args': ', '.join(f'{option}={value}' for option, value in args.items()), - }) + args["deterministic"] = "false" + schema_editor.execute( + "CREATE COLLATION %(name)s (%(args)s)" + % { + "name": schema_editor.quote_name(self.name), + "args": ", ".join( + f"{option}={value}" for option, value in args.items() + ), + } + ) def remove_collation(self, schema_editor): schema_editor.execute( - 'DROP COLLATION %s' % schema_editor.quote_name(self.name), + "DROP COLLATION %s" % schema_editor.quote_name(self.name), ) class 
CreateCollation(CollationOperation): """Create a collation.""" + def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return self.create_collation(schema_editor) @@ -229,19 +229,19 @@ class CreateCollation(CollationOperation): self.remove_collation(schema_editor) def describe(self): - return f'Create collation {self.name}' + return f"Create collation {self.name}" @property def migration_name_fragment(self): - return 'create_collation_%s' % self.name.lower() + return "create_collation_%s" % self.name.lower() class RemoveCollation(CollationOperation): """Remove a collation.""" + def database_forwards(self, app_label, schema_editor, from_state, to_state): - if ( - schema_editor.connection.vendor != 'postgresql' or - not router.allow_migrate(schema_editor.connection.alias, app_label) + if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate( + schema_editor.connection.alias, app_label ): return self.remove_collation(schema_editor) @@ -252,11 +252,11 @@ class RemoveCollation(CollationOperation): self.create_collation(schema_editor) def describe(self): - return f'Remove collation {self.name}' + return f"Remove collation {self.name}" @property def migration_name_fragment(self): - return 'remove_collation_%s' % self.name.lower() + return "remove_collation_%s" % self.name.lower() class AddConstraintNotValid(AddConstraint): @@ -268,12 +268,12 @@ class AddConstraintNotValid(AddConstraint): def __init__(self, model_name, constraint): if not isinstance(constraint, CheckConstraint): raise TypeError( - 'AddConstraintNotValid.constraint must be a check constraint.' + "AddConstraintNotValid.constraint must be a check constraint." 
) super().__init__(model_name, constraint) def describe(self): - return 'Create not valid constraint %s on model %s' % ( + return "Create not valid constraint %s on model %s" % ( self.constraint.name, self.model_name, ) @@ -286,11 +286,11 @@ class AddConstraintNotValid(AddConstraint): # Constraint.create_sql returns interpolated SQL which makes # params=None a necessity to avoid escaping attempts on # execution. - schema_editor.execute(str(constraint_sql) + ' NOT VALID', params=None) + schema_editor.execute(str(constraint_sql) + " NOT VALID", params=None) @property def migration_name_fragment(self): - return super().migration_name_fragment + '_not_valid' + return super().migration_name_fragment + "_not_valid" class ValidateConstraint(Operation): @@ -301,15 +301,18 @@ class ValidateConstraint(Operation): self.name = name def describe(self): - return 'Validate constraint %s on model %s' % (self.name, self.model_name) + return "Validate constraint %s on model %s" % (self.name, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = from_state.apps.get_model(app_label, self.model_name) if self.allow_migrate_model(schema_editor.connection.alias, model): - schema_editor.execute('ALTER TABLE %s VALIDATE CONSTRAINT %s' % ( - schema_editor.quote_name(model._meta.db_table), - schema_editor.quote_name(self.name), - )) + schema_editor.execute( + "ALTER TABLE %s VALIDATE CONSTRAINT %s" + % ( + schema_editor.quote_name(model._meta.db_table), + schema_editor.quote_name(self.name), + ) + ) def database_backwards(self, app_label, schema_editor, from_state, to_state): # PostgreSQL does not provide a way to make a constraint invalid. 
@@ -320,10 +323,14 @@ class ValidateConstraint(Operation): @property def migration_name_fragment(self): - return '%s_validate_%s' % (self.model_name.lower(), self.name.lower()) + return "%s_validate_%s" % (self.model_name.lower(), self.name.lower()) def deconstruct(self): - return self.__class__.__name__, [], { - 'model_name': self.model_name, - 'name': self.name, - } + return ( + self.__class__.__name__, + [], + { + "model_name": self.model_name, + "name": self.name, + }, + ) diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 164d359b91..f652c1d346 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -1,18 +1,25 @@ import psycopg2 from django.db.models import ( - CharField, Expression, Field, FloatField, Func, Lookup, TextField, Value, + CharField, + Expression, + Field, + FloatField, + Func, + Lookup, + TextField, + Value, ) from django.db.models.expressions import CombinedExpression from django.db.models.functions import Cast, Coalesce class SearchVectorExact(Lookup): - lookup_name = 'exact' + lookup_name = "exact" def process_rhs(self, qn, connection): if not isinstance(self.rhs, (SearchQuery, CombinedSearchQuery)): - config = getattr(self.lhs, 'config', None) + config = getattr(self.lhs, "config", None) self.rhs = SearchQuery(self.rhs, config=config) rhs, rhs_params = super().process_rhs(qn, connection) return rhs, rhs_params @@ -21,25 +28,23 @@ class SearchVectorExact(Lookup): lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) params = lhs_params + rhs_params - return '%s @@ %s' % (lhs, rhs), params + return "%s @@ %s" % (lhs, rhs), params class SearchVectorField(Field): - def db_type(self, connection): - return 'tsvector' + return "tsvector" class SearchQueryField(Field): - def db_type(self, connection): - return 'tsquery' + return "tsquery" class SearchConfig(Expression): def __init__(self, config): super().__init__() - if not 
hasattr(config, 'resolve_expression'): + if not hasattr(config, "resolve_expression"): config = Value(config) self.config = config @@ -53,21 +58,21 @@ class SearchConfig(Expression): return [self.config] def set_source_expressions(self, exprs): - self.config, = exprs + (self.config,) = exprs def as_sql(self, compiler, connection): sql, params = compiler.compile(self.config) - return '%s::regconfig' % sql, params + return "%s::regconfig" % sql, params class SearchVectorCombinable: - ADD = '||' + ADD = "||" def _combine(self, other, connector, reversed): if not isinstance(other, SearchVectorCombinable): raise TypeError( - 'SearchVector can only be combined with other SearchVector ' - 'instances, got %s.' % type(other).__name__ + "SearchVector can only be combined with other SearchVector " + "instances, got %s." % type(other).__name__ ) if reversed: return CombinedSearchVector(other, connector, self, self.config) @@ -75,49 +80,61 @@ class SearchVectorCombinable: class SearchVector(SearchVectorCombinable, Func): - function = 'to_tsvector' + function = "to_tsvector" arg_joiner = " || ' ' || " output_field = SearchVectorField() def __init__(self, *expressions, config=None, weight=None): super().__init__(*expressions) self.config = SearchConfig.from_parameter(config) - if weight is not None and not hasattr(weight, 'resolve_expression'): + if weight is not None and not hasattr(weight, "resolve_expression"): weight = Value(weight) self.weight = weight - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + resolved = super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) if self.config: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + 
resolved.config = self.config.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return resolved def as_sql(self, compiler, connection, function=None, template=None): clone = self.copy() - clone.set_source_expressions([ - Coalesce( - expression - if isinstance(expression.output_field, (CharField, TextField)) - else Cast(expression, TextField()), - Value('') - ) for expression in clone.get_source_expressions() - ]) + clone.set_source_expressions( + [ + Coalesce( + expression + if isinstance(expression.output_field, (CharField, TextField)) + else Cast(expression, TextField()), + Value(""), + ) + for expression in clone.get_source_expressions() + ] + ) config_sql = None config_params = [] if template is None: if clone.config: config_sql, config_params = compiler.compile(clone.config) - template = '%(function)s(%(config)s, %(expressions)s)' + template = "%(function)s(%(config)s, %(expressions)s)" else: template = clone.template sql, params = super(SearchVector, clone).as_sql( - compiler, connection, function=function, template=template, + compiler, + connection, + function=function, + template=template, config=config_sql, ) extra_params = [] if clone.weight: weight_sql, extra_params = compiler.compile(clone.weight) - sql = 'setweight({}, {})'.format(sql, weight_sql) + sql = "setweight({}, {})".format(sql, weight_sql) return sql, config_params + params + extra_params @@ -128,14 +145,14 @@ class CombinedSearchVector(SearchVectorCombinable, CombinedExpression): class SearchQueryCombinable: - BITAND = '&&' - BITOR = '||' + BITAND = "&&" + BITOR = "||" def _combine(self, other, connector, reversed): if not isinstance(other, SearchQueryCombinable): raise TypeError( - 'SearchQuery can only be combined with other SearchQuery ' - 'instances, got %s.' % type(other).__name__ + "SearchQuery can only be combined with other SearchQuery " + "instances, got %s." 
% type(other).__name__ ) if reversed: return CombinedSearchQuery(other, connector, self, self.config) @@ -160,17 +177,25 @@ class SearchQueryCombinable: class SearchQuery(SearchQueryCombinable, Func): output_field = SearchQueryField() SEARCH_TYPES = { - 'plain': 'plainto_tsquery', - 'phrase': 'phraseto_tsquery', - 'raw': 'to_tsquery', - 'websearch': 'websearch_to_tsquery', + "plain": "plainto_tsquery", + "phrase": "phraseto_tsquery", + "raw": "to_tsquery", + "websearch": "websearch_to_tsquery", } - def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): + def __init__( + self, + value, + output_field=None, + *, + config=None, + invert=False, + search_type="plain", + ): self.function = self.SEARCH_TYPES.get(search_type) if self.function is None: raise ValueError("Unknown search_type argument '%s'." % search_type) - if not hasattr(value, 'resolve_expression'): + if not hasattr(value, "resolve_expression"): value = Value(value) expressions = (value,) self.config = SearchConfig.from_parameter(config) @@ -182,7 +207,7 @@ class SearchQuery(SearchQueryCombinable, Func): def as_sql(self, compiler, connection, function=None, template=None): sql, params = super().as_sql(compiler, connection, function, template) if self.invert: - sql = '!!(%s)' % sql + sql = "!!(%s)" % sql return sql, params def __invert__(self): @@ -192,7 +217,7 @@ class SearchQuery(SearchQueryCombinable, Func): def __str__(self): result = super().__str__() - return ('~%s' % result) if self.invert else result + return ("~%s" % result) if self.invert else result class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): @@ -201,60 +226,73 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression): super().__init__(lhs, connector, rhs, output_field) def __str__(self): - return '(%s)' % super().__str__() + return "(%s)" % super().__str__() class SearchRank(Func): - function = 'ts_rank' + function = "ts_rank" output_field = FloatField() def 
__init__( - self, vector, query, weights=None, normalization=None, + self, + vector, + query, + weights=None, + normalization=None, cover_density=False, ): - if not hasattr(vector, 'resolve_expression'): + if not hasattr(vector, "resolve_expression"): vector = SearchVector(vector) - if not hasattr(query, 'resolve_expression'): + if not hasattr(query, "resolve_expression"): query = SearchQuery(query) expressions = (vector, query) if weights is not None: - if not hasattr(weights, 'resolve_expression'): + if not hasattr(weights, "resolve_expression"): weights = Value(weights) expressions = (weights,) + expressions if normalization is not None: - if not hasattr(normalization, 'resolve_expression'): + if not hasattr(normalization, "resolve_expression"): normalization = Value(normalization) expressions += (normalization,) if cover_density: - self.function = 'ts_rank_cd' + self.function = "ts_rank_cd" super().__init__(*expressions) class SearchHeadline(Func): - function = 'ts_headline' - template = '%(function)s(%(expressions)s%(options)s)' + function = "ts_headline" + template = "%(function)s(%(expressions)s%(options)s)" output_field = TextField() def __init__( - self, expression, query, *, config=None, start_sel=None, stop_sel=None, - max_words=None, min_words=None, short_word=None, highlight_all=None, - max_fragments=None, fragment_delimiter=None, + self, + expression, + query, + *, + config=None, + start_sel=None, + stop_sel=None, + max_words=None, + min_words=None, + short_word=None, + highlight_all=None, + max_fragments=None, + fragment_delimiter=None, ): - if not hasattr(query, 'resolve_expression'): + if not hasattr(query, "resolve_expression"): query = SearchQuery(query) options = { - 'StartSel': start_sel, - 'StopSel': stop_sel, - 'MaxWords': max_words, - 'MinWords': min_words, - 'ShortWord': short_word, - 'HighlightAll': highlight_all, - 'MaxFragments': max_fragments, - 'FragmentDelimiter': fragment_delimiter, + "StartSel": start_sel, + "StopSel": stop_sel, + 
"MaxWords": max_words, + "MinWords": min_words, + "ShortWord": short_word, + "HighlightAll": highlight_all, + "MaxFragments": max_fragments, + "FragmentDelimiter": fragment_delimiter, } self.options = { - option: value - for option, value in options.items() if value is not None + option: value for option, value in options.items() if value is not None } expressions = (expression, query) if config is not None: @@ -263,19 +301,26 @@ class SearchHeadline(Func): super().__init__(*expressions) def as_sql(self, compiler, connection, function=None, template=None): - options_sql = '' + options_sql = "" options_params = [] if self.options: # getquoted() returns a quoted bytestring of the adapted value. - options_params.append(', '.join( - '%s=%s' % ( - option, - psycopg2.extensions.adapt(value).getquoted().decode(), - ) for option, value in self.options.items() - )) - options_sql = ', %s' + options_params.append( + ", ".join( + "%s=%s" + % ( + option, + psycopg2.extensions.adapt(value).getquoted().decode(), + ) + for option, value in self.options.items() + ) + ) + options_sql = ", %s" sql, params = super().as_sql( - compiler, connection, function=function, template=template, + compiler, + connection, + function=function, + template=template, options=options_sql, ) return sql, params + options_params @@ -288,7 +333,7 @@ class TrigramBase(Func): output_field = FloatField() def __init__(self, expression, string, **extra): - if not hasattr(string, 'resolve_expression'): + if not hasattr(string, "resolve_expression"): string = Value(string) super().__init__(expression, string, **extra) @@ -297,24 +342,24 @@ class TrigramWordBase(Func): output_field = FloatField() def __init__(self, string, expression, **extra): - if not hasattr(string, 'resolve_expression'): + if not hasattr(string, "resolve_expression"): string = Value(string) super().__init__(string, expression, **extra) class TrigramSimilarity(TrigramBase): - function = 'SIMILARITY' + function = "SIMILARITY" class 
TrigramDistance(TrigramBase): - function = '' - arg_joiner = ' <-> ' + function = "" + arg_joiner = " <-> " class TrigramWordDistance(TrigramWordBase): - function = '' - arg_joiner = ' <<-> ' + function = "" + arg_joiner = " <<-> " class TrigramWordSimilarity(TrigramWordBase): - function = 'WORD_SIMILARITY' + function = "WORD_SIMILARITY" diff --git a/django/contrib/postgres/serializers.py b/django/contrib/postgres/serializers.py index 1b1c2f1112..d04bfdbc69 100644 --- a/django/contrib/postgres/serializers.py +++ b/django/contrib/postgres/serializers.py @@ -6,5 +6,5 @@ class RangeSerializer(BaseSerializer): module = self.value.__class__.__module__ # Ranges are implemented in psycopg2._range but the public import # location is psycopg2.extras. - module = 'psycopg2.extras' if module == 'psycopg2._range' else module - return '%s.%r' % (module, self.value), {'import %s' % module} + module = "psycopg2.extras" if module == "psycopg2._range" else module + return "%s.%r" % (module, self.value), {"import %s" % module} diff --git a/django/contrib/postgres/signals.py b/django/contrib/postgres/signals.py index 420b5f6033..b61673fe1f 100644 --- a/django/contrib/postgres/signals.py +++ b/django/contrib/postgres/signals.py @@ -35,12 +35,14 @@ def get_citext_oids(connection_alias): def register_type_handlers(connection, **kwargs): - if connection.vendor != 'postgresql' or connection.alias == NO_DB_ALIAS: + if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS: return try: oids, array_oids = get_hstore_oids(connection.alias) - register_hstore(connection.connection, globally=True, oid=oids, array_oid=array_oids) + register_hstore( + connection.connection, globally=True, oid=oids, array_oid=array_oids + ) except ProgrammingError: # Hstore is not available on the database. 
# @@ -54,7 +56,9 @@ def register_type_handlers(connection, **kwargs): try: citext_oids = get_citext_oids(connection.alias) - array_type = psycopg2.extensions.new_array_type(citext_oids, 'citext[]', psycopg2.STRING) + array_type = psycopg2.extensions.new_array_type( + citext_oids, "citext[]", psycopg2.STRING + ) psycopg2.extensions.register_type(array_type, None) except ProgrammingError: # citext is not available on the database. diff --git a/django/contrib/postgres/utils.py b/django/contrib/postgres/utils.py index f3c022f474..e4f4d81514 100644 --- a/django/contrib/postgres/utils.py +++ b/django/contrib/postgres/utils.py @@ -17,13 +17,13 @@ def prefix_validation_error(error, prefix, code, params): # ngettext calls require a count parameter and are converted # to an empty string if they are missing it. message=format_lazy( - '{} {}', + "{} {}", SimpleLazyObject(lambda: prefix % params), SimpleLazyObject(lambda: error.message % error_params), ), code=code, params={**error_params, **params}, ) - return ValidationError([ - prefix_validation_error(e, prefix, code, params) for e in error.error_list - ]) + return ValidationError( + [prefix_validation_error(e, prefix, code, params) for e in error.error_list] + ) diff --git a/django/contrib/postgres/validators.py b/django/contrib/postgres/validators.py index db6205f356..df2bd88eb9 100644 --- a/django/contrib/postgres/validators.py +++ b/django/contrib/postgres/validators.py @@ -1,24 +1,29 @@ from django.core.exceptions import ValidationError from django.core.validators import ( - MaxLengthValidator, MaxValueValidator, MinLengthValidator, + MaxLengthValidator, + MaxValueValidator, + MinLengthValidator, MinValueValidator, ) from django.utils.deconstruct import deconstructible -from django.utils.translation import gettext_lazy as _, ngettext_lazy +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import ngettext_lazy class ArrayMaxLengthValidator(MaxLengthValidator): message = ngettext_lazy( 
- 'List contains %(show_value)d item, it should contain no more than %(limit_value)d.', - 'List contains %(show_value)d items, it should contain no more than %(limit_value)d.', - 'limit_value') + "List contains %(show_value)d item, it should contain no more than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no more than %(limit_value)d.", + "limit_value", + ) class ArrayMinLengthValidator(MinLengthValidator): message = ngettext_lazy( - 'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.', - 'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.', - 'limit_value') + "List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.", + "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.", + "limit_value", + ) @deconstructible @@ -26,8 +31,8 @@ class KeysValidator: """A validator designed for HStore to require/restrict keys.""" messages = { - 'missing_keys': _('Some keys were missing: %(keys)s'), - 'extra_keys': _('Some unknown keys were provided: %(keys)s'), + "missing_keys": _("Some keys were missing: %(keys)s"), + "extra_keys": _("Some unknown keys were provided: %(keys)s"), } strict = False @@ -42,35 +47,41 @@ class KeysValidator: missing_keys = self.keys - keys if missing_keys: raise ValidationError( - self.messages['missing_keys'], - code='missing_keys', - params={'keys': ', '.join(missing_keys)}, + self.messages["missing_keys"], + code="missing_keys", + params={"keys": ", ".join(missing_keys)}, ) if self.strict: extra_keys = keys - self.keys if extra_keys: raise ValidationError( - self.messages['extra_keys'], - code='extra_keys', - params={'keys': ', '.join(extra_keys)}, + self.messages["extra_keys"], + code="extra_keys", + params={"keys": ", ".join(extra_keys)}, ) def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self.keys == other.keys and - self.messages == other.messages and - 
self.strict == other.strict + isinstance(other, self.__class__) + and self.keys == other.keys + and self.messages == other.messages + and self.strict == other.strict ) class RangeMaxValueValidator(MaxValueValidator): def compare(self, a, b): return a.upper is None or a.upper > b - message = _('Ensure that this range is completely less than or equal to %(limit_value)s.') + + message = _( + "Ensure that this range is completely less than or equal to %(limit_value)s." + ) class RangeMinValueValidator(MinValueValidator): def compare(self, a, b): return a.lower is None or a.lower < b - message = _('Ensure that this range is completely greater than or equal to %(limit_value)s.') + + message = _( + "Ensure that this range is completely greater than or equal to %(limit_value)s." + ) diff --git a/django/contrib/redirects/admin.py b/django/contrib/redirects/admin.py index f828747d76..39400ad265 100644 --- a/django/contrib/redirects/admin.py +++ b/django/contrib/redirects/admin.py @@ -4,7 +4,7 @@ from django.contrib.redirects.models import Redirect @admin.register(Redirect) class RedirectAdmin(admin.ModelAdmin): - list_display = ('old_path', 'new_path') - list_filter = ('site',) - search_fields = ('old_path', 'new_path') - radio_fields = {'site': admin.VERTICAL} + list_display = ("old_path", "new_path") + list_filter = ("site",) + search_fields = ("old_path", "new_path") + radio_fields = {"site": admin.VERTICAL} diff --git a/django/contrib/redirects/apps.py b/django/contrib/redirects/apps.py index c1d80ee3c1..d7706711b7 100644 --- a/django/contrib/redirects/apps.py +++ b/django/contrib/redirects/apps.py @@ -3,6 +3,6 @@ from django.utils.translation import gettext_lazy as _ class RedirectsConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.redirects' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.redirects" verbose_name = _("Redirects") diff --git a/django/contrib/redirects/middleware.py 
b/django/contrib/redirects/middleware.py index e148c17693..bc4fe9cbcd 100644 --- a/django/contrib/redirects/middleware.py +++ b/django/contrib/redirects/middleware.py @@ -13,7 +13,7 @@ class RedirectFallbackMiddleware(MiddlewareMixin): response_redirect_class = HttpResponsePermanentRedirect def __init__(self, get_response): - if not apps.is_installed('django.contrib.sites'): + if not apps.is_installed("django.contrib.sites"): raise ImproperlyConfigured( "You cannot use RedirectFallbackMiddleware when " "django.contrib.sites is not installed." @@ -33,7 +33,7 @@ class RedirectFallbackMiddleware(MiddlewareMixin): r = Redirect.objects.get(site=current_site, old_path=full_path) except Redirect.DoesNotExist: pass - if r is None and settings.APPEND_SLASH and not request.path.endswith('/'): + if r is None and settings.APPEND_SLASH and not request.path.endswith("/"): try: r = Redirect.objects.get( site=current_site, @@ -42,7 +42,7 @@ class RedirectFallbackMiddleware(MiddlewareMixin): except Redirect.DoesNotExist: pass if r is not None: - if r.new_path == '': + if r.new_path == "": return self.response_gone_class() return self.response_redirect_class(r.new_path) diff --git a/django/contrib/redirects/migrations/0001_initial.py b/django/contrib/redirects/migrations/0001_initial.py index 58165a13bb..79dd21f15d 100644 --- a/django/contrib/redirects/migrations/0001_initial.py +++ b/django/contrib/redirects/migrations/0001_initial.py @@ -4,35 +4,57 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('sites', '0001_initial'), + ("sites", "0001_initial"), ] operations = [ migrations.CreateModel( - name='Redirect', + name="Redirect", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('site', models.ForeignKey( - to='sites.Site', - on_delete=models.CASCADE, - verbose_name='site', - )), - ('old_path', models.CharField( - help_text=( - 'This should be an absolute path, 
excluding the domain name. Example: “/events/search/”.' - ), max_length=200, verbose_name='redirect from', db_index=True - )), - ('new_path', models.CharField( - help_text='This can be either an absolute path (as above) or a full URL starting with “http://”.', - max_length=200, verbose_name='redirect to', blank=True - )), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "site", + models.ForeignKey( + to="sites.Site", + on_delete=models.CASCADE, + verbose_name="site", + ), + ), + ( + "old_path", + models.CharField( + help_text=( + "This should be an absolute path, excluding the domain name. Example: “/events/search/”." + ), + max_length=200, + verbose_name="redirect from", + db_index=True, + ), + ), + ( + "new_path", + models.CharField( + help_text="This can be either an absolute path (as above) or a full URL starting with “http://”.", + max_length=200, + verbose_name="redirect to", + blank=True, + ), + ), ], options={ - 'ordering': ['old_path'], - 'unique_together': {('site', 'old_path')}, - 'db_table': 'django_redirect', - 'verbose_name': 'redirect', - 'verbose_name_plural': 'redirects', + "ordering": ["old_path"], + "unique_together": {("site", "old_path")}, + "db_table": "django_redirect", + "verbose_name": "redirect", + "verbose_name_plural": "redirects", }, bases=(models.Model,), ), diff --git a/django/contrib/redirects/migrations/0002_alter_redirect_new_path_help_text.py b/django/contrib/redirects/migrations/0002_alter_redirect_new_path_help_text.py index afbd19e415..84f67e10d9 100644 --- a/django/contrib/redirects/migrations/0002_alter_redirect_new_path_help_text.py +++ b/django/contrib/redirects/migrations/0002_alter_redirect_new_path_help_text.py @@ -4,21 +4,21 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('redirects', '0001_initial'), + ("redirects", "0001_initial"), ] operations = [ migrations.AlterField( - 
model_name='redirect', - name='new_path', + model_name="redirect", + name="new_path", field=models.CharField( blank=True, help_text=( - 'This can be either an absolute path (as above) or a full ' - 'URL starting with a scheme such as “https://”.' + "This can be either an absolute path (as above) or a full " + "URL starting with a scheme such as “https://”." ), max_length=200, - verbose_name='redirect to', + verbose_name="redirect to", ), ), ] diff --git a/django/contrib/redirects/models.py b/django/contrib/redirects/models.py index a200b88f94..0375e65c67 100644 --- a/django/contrib/redirects/models.py +++ b/django/contrib/redirects/models.py @@ -4,29 +4,31 @@ from django.utils.translation import gettext_lazy as _ class Redirect(models.Model): - site = models.ForeignKey(Site, models.CASCADE, verbose_name=_('site')) + site = models.ForeignKey(Site, models.CASCADE, verbose_name=_("site")) old_path = models.CharField( - _('redirect from'), + _("redirect from"), max_length=200, db_index=True, - help_text=_('This should be an absolute path, excluding the domain name. Example: “/events/search/”.'), + help_text=_( + "This should be an absolute path, excluding the domain name. Example: “/events/search/”." + ), ) new_path = models.CharField( - _('redirect to'), + _("redirect to"), max_length=200, blank=True, help_text=_( - 'This can be either an absolute path (as above) or a full URL ' - 'starting with a scheme such as “https://”.' + "This can be either an absolute path (as above) or a full URL " + "starting with a scheme such as “https://”." 
), ) class Meta: - verbose_name = _('redirect') - verbose_name_plural = _('redirects') - db_table = 'django_redirect' - unique_together = [['site', 'old_path']] - ordering = ['old_path'] + verbose_name = _("redirect") + verbose_name_plural = _("redirects") + db_table = "django_redirect" + unique_together = [["site", "old_path"]] + ordering = ["old_path"] def __str__(self): return "%s ---> %s" % (self.old_path, self.new_path) diff --git a/django/contrib/sessions/apps.py b/django/contrib/sessions/apps.py index 8b778d1109..83f0cefa2d 100644 --- a/django/contrib/sessions/apps.py +++ b/django/contrib/sessions/apps.py @@ -3,5 +3,5 @@ from django.utils.translation import gettext_lazy as _ class SessionsConfig(AppConfig): - name = 'django.contrib.sessions' + name = "django.contrib.sessions" verbose_name = _("Sessions") diff --git a/django/contrib/sessions/backends/base.py b/django/contrib/sessions/backends/base.py index f86f329550..050e7387be 100644 --- a/django/contrib/sessions/backends/base.py +++ b/django/contrib/sessions/backends/base.py @@ -18,6 +18,7 @@ class CreateError(Exception): Used internally as a consistent exception type to catch from save (see the docstring for SessionBase.save() for details). """ + pass @@ -25,6 +26,7 @@ class UpdateError(Exception): """ Occurs if Django tries to update a session that was deleted. """ + pass @@ -32,8 +34,9 @@ class SessionBase: """ Base class for all Session classes. """ - TEST_COOKIE_NAME = 'testcookie' - TEST_COOKIE_VALUE = 'worked' + + TEST_COOKIE_NAME = "testcookie" + TEST_COOKIE_VALUE = "worked" __not_given = object() @@ -59,7 +62,7 @@ class SessionBase: @property def key_salt(self): - return 'django.contrib.sessions.' + self.__class__.__qualname__ + return "django.contrib.sessions." 
+ self.__class__.__qualname__ def get(self, key, default=None): return self._session.get(key, default) @@ -89,16 +92,20 @@ class SessionBase: def encode(self, session_dict): "Return the given session dictionary serialized and encoded as a string." return signing.dumps( - session_dict, salt=self.key_salt, serializer=self.serializer, + session_dict, + salt=self.key_salt, + serializer=self.serializer, compress=True, ) def decode(self, session_data): try: - return signing.loads(session_data, salt=self.key_salt, serializer=self.serializer) + return signing.loads( + session_data, salt=self.key_salt, serializer=self.serializer + ) except signing.BadSignature: - logger = logging.getLogger('django.security.SuspiciousSession') - logger.warning('Session data corrupted') + logger = logging.getLogger("django.security.SuspiciousSession") + logger.warning("Session data corrupted") except Exception: # ValueError, unpickling exceptions. If any of these happen, just # return an empty dictionary (an empty session). @@ -197,18 +204,18 @@ class SessionBase: arguments specifying the modification and expiry of the session. """ try: - modification = kwargs['modification'] + modification = kwargs["modification"] except KeyError: modification = timezone.now() # Make the difference between "expiry=None passed in kwargs" and # "expiry not passed in kwargs", in order to guarantee not to trigger # self.load() when expiry is provided. try: - expiry = kwargs['expiry'] + expiry = kwargs["expiry"] except KeyError: - expiry = self.get('_session_expiry') + expiry = self.get("_session_expiry") - if not expiry: # Checks both None and 0 cases + if not expiry: # Checks both None and 0 cases return self.get_session_cookie_age() if not isinstance(expiry, (datetime, str)): return expiry @@ -224,14 +231,14 @@ class SessionBase: arguments specifying the modification and expiry of the session. 
""" try: - modification = kwargs['modification'] + modification = kwargs["modification"] except KeyError: modification = timezone.now() # Same comment as in get_expiry_age try: - expiry = kwargs['expiry'] + expiry = kwargs["expiry"] except KeyError: - expiry = self.get('_session_expiry') + expiry = self.get("_session_expiry") if isinstance(expiry, datetime): return expiry @@ -258,7 +265,7 @@ class SessionBase: if value is None: # Remove any custom expiration for this session. try: - del self['_session_expiry'] + del self["_session_expiry"] except KeyError: pass return @@ -266,7 +273,7 @@ class SessionBase: value = timezone.now() + value if isinstance(value, datetime): value = value.isoformat() - self['_session_expiry'] = value + self["_session_expiry"] = value def get_expire_at_browser_close(self): """ @@ -275,7 +282,7 @@ class SessionBase: ``get_expiry_date()`` or ``get_expiry_age()`` to find the actual expiry date/age, if there is one. """ - if (expiry := self.get('_session_expiry')) is None: + if (expiry := self.get("_session_expiry")) is None: return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE return expiry == 0 @@ -305,7 +312,9 @@ class SessionBase: """ Return True if the given session_key already exists. """ - raise NotImplementedError('subclasses of SessionBase must provide an exists() method') + raise NotImplementedError( + "subclasses of SessionBase must provide an exists() method" + ) def create(self): """ @@ -313,7 +322,9 @@ class SessionBase: a unique key and will have saved the result once (with empty data) before the method returns. """ - raise NotImplementedError('subclasses of SessionBase must provide a create() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a create() method" + ) def save(self, must_create=False): """ @@ -321,20 +332,26 @@ class SessionBase: object (or raise CreateError). Otherwise, only update an existing object and don't create one (raise UpdateError if needed). 
""" - raise NotImplementedError('subclasses of SessionBase must provide a save() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a save() method" + ) def delete(self, session_key=None): """ Delete the session data under this key. If the key is None, use the current session key value. """ - raise NotImplementedError('subclasses of SessionBase must provide a delete() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a delete() method" + ) def load(self): """ Load the session data and return a dictionary. """ - raise NotImplementedError('subclasses of SessionBase must provide a load() method') + raise NotImplementedError( + "subclasses of SessionBase must provide a load() method" + ) @classmethod def clear_expired(cls): @@ -345,4 +362,4 @@ class SessionBase: NotImplementedError. If it isn't necessary, because the backend has a built-in expiration mechanism, it should be a no-op. """ - raise NotImplementedError('This backend does not support clear_expired().') + raise NotImplementedError("This backend does not support clear_expired().") diff --git a/django/contrib/sessions/backends/cache.py b/django/contrib/sessions/backends/cache.py index 860d3a46c5..0c9d244f56 100644 --- a/django/contrib/sessions/backends/cache.py +++ b/django/contrib/sessions/backends/cache.py @@ -1,7 +1,5 @@ from django.conf import settings -from django.contrib.sessions.backends.base import ( - CreateError, SessionBase, UpdateError, -) +from django.contrib.sessions.backends.base import CreateError, SessionBase, UpdateError from django.core.cache import caches KEY_PREFIX = "django.contrib.sessions.cache" @@ -11,6 +9,7 @@ class SessionStore(SessionBase): """ A cache-based session store. """ + cache_key_prefix = KEY_PREFIX def __init__(self, session_key=None): @@ -49,7 +48,8 @@ class SessionStore(SessionBase): return raise RuntimeError( "Unable to create a new session key. 
" - "It is likely that the cache is unavailable.") + "It is likely that the cache is unavailable." + ) def save(self, must_create=False): if self.session_key is None: @@ -60,14 +60,18 @@ class SessionStore(SessionBase): func = self._cache.set else: raise UpdateError - result = func(self.cache_key, - self._get_session(no_load=must_create), - self.get_expiry_age()) + result = func( + self.cache_key, + self._get_session(no_load=must_create), + self.get_expiry_age(), + ) if must_create and not result: raise CreateError def exists(self, session_key): - return bool(session_key) and (self.cache_key_prefix + session_key) in self._cache + return ( + bool(session_key) and (self.cache_key_prefix + session_key) in self._cache + ) def delete(self, session_key=None): if session_key is None: diff --git a/django/contrib/sessions/backends/cached_db.py b/django/contrib/sessions/backends/cached_db.py index 453d390694..3125a71cd0 100644 --- a/django/contrib/sessions/backends/cached_db.py +++ b/django/contrib/sessions/backends/cached_db.py @@ -13,6 +13,7 @@ class SessionStore(DBStore): """ Implement cached, database backed sessions. 
""" + cache_key_prefix = KEY_PREFIX def __init__(self, session_key=None): @@ -35,13 +36,19 @@ class SessionStore(DBStore): s = self._get_session_from_db() if s: data = self.decode(s.session_data) - self._cache.set(self.cache_key, data, self.get_expiry_age(expiry=s.expire_date)) + self._cache.set( + self.cache_key, data, self.get_expiry_age(expiry=s.expire_date) + ) else: data = {} return data def exists(self, session_key): - return session_key and (self.cache_key_prefix + session_key) in self._cache or super().exists(session_key) + return ( + session_key + and (self.cache_key_prefix + session_key) in self._cache + or super().exists(session_key) + ) def save(self, must_create=False): super().save(must_create) diff --git a/django/contrib/sessions/backends/db.py b/django/contrib/sessions/backends/db.py index 7c905a2c99..e1f6b69c55 100644 --- a/django/contrib/sessions/backends/db.py +++ b/django/contrib/sessions/backends/db.py @@ -1,8 +1,6 @@ import logging -from django.contrib.sessions.backends.base import ( - CreateError, SessionBase, UpdateError, -) +from django.contrib.sessions.backends.base import CreateError, SessionBase, UpdateError from django.core.exceptions import SuspiciousOperation from django.db import DatabaseError, IntegrityError, router, transaction from django.utils import timezone @@ -13,6 +11,7 @@ class SessionStore(SessionBase): """ Implement database session store. """ + def __init__(self, session_key=None): super().__init__(session_key) @@ -21,6 +20,7 @@ class SessionStore(SessionBase): # Avoids a circular import and allows importing SessionStore when # django.contrib.sessions is not in INSTALLED_APPS. 
from django.contrib.sessions.models import Session + return Session @cached_property @@ -30,12 +30,11 @@ class SessionStore(SessionBase): def _get_session_from_db(self): try: return self.model.objects.get( - session_key=self.session_key, - expire_date__gt=timezone.now() + session_key=self.session_key, expire_date__gt=timezone.now() ) except (self.model.DoesNotExist, SuspiciousOperation) as e: if isinstance(e, SuspiciousOperation): - logger = logging.getLogger('django.security.%s' % e.__class__.__name__) + logger = logging.getLogger("django.security.%s" % e.__class__.__name__) logger.warning(str(e)) self._session_key = None @@ -84,7 +83,9 @@ class SessionStore(SessionBase): using = router.db_for_write(self.model, instance=obj) try: with transaction.atomic(using=using): - obj.save(force_insert=must_create, force_update=not must_create, using=using) + obj.save( + force_insert=must_create, force_update=not must_create, using=using + ) except IntegrityError: if must_create: raise CreateError diff --git a/django/contrib/sessions/backends/file.py b/django/contrib/sessions/backends/file.py index 7032b4b1f8..947388333f 100644 --- a/django/contrib/sessions/backends/file.py +++ b/django/contrib/sessions/backends/file.py @@ -6,7 +6,10 @@ import tempfile from django.conf import settings from django.contrib.sessions.backends.base import ( - VALID_KEY_CHARS, CreateError, SessionBase, UpdateError, + VALID_KEY_CHARS, + CreateError, + SessionBase, + UpdateError, ) from django.contrib.sessions.exceptions import InvalidSessionKey from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation @@ -17,6 +20,7 @@ class SessionStore(SessionBase): """ Implement a file based session store. 
""" + def __init__(self, session_key=None): self.storage_path = self._get_storage_path() self.file_prefix = settings.SESSION_COOKIE_NAME @@ -27,13 +31,16 @@ class SessionStore(SessionBase): try: return cls._storage_path except AttributeError: - storage_path = getattr(settings, 'SESSION_FILE_PATH', None) or tempfile.gettempdir() + storage_path = ( + getattr(settings, "SESSION_FILE_PATH", None) or tempfile.gettempdir() + ) # Make sure the storage path is valid. if not os.path.isdir(storage_path): raise ImproperlyConfigured( "The session storage path %r doesn't exist. Please set your" " SESSION_FILE_PATH setting to an existing directory in which" - " Django can store session data." % storage_path) + " Django can store session data." % storage_path + ) cls._storage_path = storage_path return storage_path @@ -49,8 +56,7 @@ class SessionStore(SessionBase): # should always be md5s, so they should never contain directory # components. if not set(session_key).issubset(VALID_KEY_CHARS): - raise InvalidSessionKey( - "Invalid characters in session key") + raise InvalidSessionKey("Invalid characters in session key") return os.path.join(self.storage_path, self.file_prefix + session_key) @@ -66,14 +72,15 @@ class SessionStore(SessionBase): """ Return the expiry time of the file storing the session's content. """ - return session_data.get('_session_expiry') or ( - self._last_modification() + datetime.timedelta(seconds=self.get_session_cookie_age()) + return session_data.get("_session_expiry") or ( + self._last_modification() + + datetime.timedelta(seconds=self.get_session_cookie_age()) ) def load(self): session_data = {} try: - with open(self._key_to_file(), encoding='ascii') as session_file: + with open(self._key_to_file(), encoding="ascii") as session_file: file_data = session_file.read() # Don't fail if there is no data in the session file. # We may have opened the empty placeholder file. 
@@ -82,7 +89,9 @@ class SessionStore(SessionBase): session_data = self.decode(file_data) except (EOFError, SuspiciousOperation) as e: if isinstance(e, SuspiciousOperation): - logger = logging.getLogger('django.security.%s' % e.__class__.__name__) + logger = logging.getLogger( + "django.security.%s" % e.__class__.__name__ + ) logger.warning(str(e)) self.create() @@ -118,7 +127,7 @@ class SessionStore(SessionBase): try: # Make sure the file exists. If it does not already exist, an # empty placeholder file is created. - flags = os.O_WRONLY | getattr(os, 'O_BINARY', 0) + flags = os.O_WRONLY | getattr(os, "O_BINARY", 0) if must_create: flags |= os.O_EXCL | os.O_CREAT fd = os.open(session_file_name, flags) @@ -148,7 +157,9 @@ class SessionStore(SessionBase): dir, prefix = os.path.split(session_file_name) try: - output_file_fd, output_file_name = tempfile.mkstemp(dir=dir, prefix=prefix + '_out_') + output_file_fd, output_file_name = tempfile.mkstemp( + dir=dir, prefix=prefix + "_out_" + ) renamed = False try: try: @@ -191,7 +202,7 @@ class SessionStore(SessionBase): for session_file in os.listdir(storage_path): if not session_file.startswith(file_prefix): continue - session_key = session_file[len(file_prefix):] + session_key = session_file[len(file_prefix) :] session = cls(session_key) # When an expired session is loaded, its file is removed, and a # new file is immediately created. 
Prevent this by disabling diff --git a/django/contrib/sessions/backends/signed_cookies.py b/django/contrib/sessions/backends/signed_cookies.py index 8942df1ea4..dc41c6f12b 100644 --- a/django/contrib/sessions/backends/signed_cookies.py +++ b/django/contrib/sessions/backends/signed_cookies.py @@ -3,7 +3,6 @@ from django.core import signing class SessionStore(SessionBase): - def load(self): """ Load the data from the key itself instead of fetching from some @@ -16,7 +15,7 @@ class SessionStore(SessionBase): serializer=self.serializer, # This doesn't handle non-default expiry dates, see #19201 max_age=self.get_session_cookie_age(), - salt='django.contrib.sessions.backends.signed_cookies', + salt="django.contrib.sessions.backends.signed_cookies", ) except Exception: # BadSignature, ValueError, or unpickling exceptions. If any of @@ -54,7 +53,7 @@ class SessionStore(SessionBase): and set the modified flag so that the cookie is set on the client for the current request. """ - self._session_key = '' + self._session_key = "" self._session_cache = {} self.modified = True @@ -71,8 +70,9 @@ class SessionStore(SessionBase): base64-encoded string of data as our session key. 
""" return signing.dumps( - self._session, compress=True, - salt='django.contrib.sessions.backends.signed_cookies', + self._session, + compress=True, + salt="django.contrib.sessions.backends.signed_cookies", serializer=self.serializer, ) diff --git a/django/contrib/sessions/base_session.py b/django/contrib/sessions/base_session.py index 1d653b5adf..603d2fe12c 100644 --- a/django/contrib/sessions/base_session.py +++ b/django/contrib/sessions/base_session.py @@ -24,16 +24,16 @@ class BaseSessionManager(models.Manager): class AbstractBaseSession(models.Model): - session_key = models.CharField(_('session key'), max_length=40, primary_key=True) - session_data = models.TextField(_('session data')) - expire_date = models.DateTimeField(_('expire date'), db_index=True) + session_key = models.CharField(_("session key"), max_length=40, primary_key=True) + session_data = models.TextField(_("session data")) + expire_date = models.DateTimeField(_("expire date"), db_index=True) objects = BaseSessionManager() class Meta: abstract = True - verbose_name = _('session') - verbose_name_plural = _('sessions') + verbose_name = _("session") + verbose_name_plural = _("sessions") def __str__(self): return self.session_key diff --git a/django/contrib/sessions/exceptions.py b/django/contrib/sessions/exceptions.py index 8dbca3d04f..8a4853c502 100644 --- a/django/contrib/sessions/exceptions.py +++ b/django/contrib/sessions/exceptions.py @@ -3,14 +3,17 @@ from django.core.exceptions import BadRequest, SuspiciousOperation class InvalidSessionKey(SuspiciousOperation): """Invalid characters in session key""" + pass class SuspiciousSession(SuspiciousOperation): """The session may be tampered with""" + pass class SessionInterrupted(BadRequest): """The session was interrupted.""" + pass diff --git a/django/contrib/sessions/middleware.py b/django/contrib/sessions/middleware.py index 4de80a5858..2fcd7d508a 100644 --- a/django/contrib/sessions/middleware.py +++ b/django/contrib/sessions/middleware.py @@ 
-40,10 +40,10 @@ class SessionMiddleware(MiddlewareMixin): domain=settings.SESSION_COOKIE_DOMAIN, samesite=settings.SESSION_COOKIE_SAMESITE, ) - patch_vary_headers(response, ('Cookie',)) + patch_vary_headers(response, ("Cookie",)) else: if accessed: - patch_vary_headers(response, ('Cookie',)) + patch_vary_headers(response, ("Cookie",)) if (modified or settings.SESSION_SAVE_EVERY_REQUEST) and not empty: if request.session.get_expire_at_browser_close(): max_age = None @@ -65,8 +65,10 @@ class SessionMiddleware(MiddlewareMixin): ) response.set_cookie( settings.SESSION_COOKIE_NAME, - request.session.session_key, max_age=max_age, - expires=expires, domain=settings.SESSION_COOKIE_DOMAIN, + request.session.session_key, + max_age=max_age, + expires=expires, + domain=settings.SESSION_COOKIE_DOMAIN, path=settings.SESSION_COOKIE_PATH, secure=settings.SESSION_COOKIE_SECURE or None, httponly=settings.SESSION_COOKIE_HTTPONLY or None, diff --git a/django/contrib/sessions/migrations/0001_initial.py b/django/contrib/sessions/migrations/0001_initial.py index 39eaa6db41..83b0bbc2ae 100644 --- a/django/contrib/sessions/migrations/0001_initial.py +++ b/django/contrib/sessions/migrations/0001_initial.py @@ -4,27 +4,35 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( - name='Session', + name="Session", fields=[ - ('session_key', models.CharField( - max_length=40, serialize=False, verbose_name='session key', primary_key=True - )), - ('session_data', models.TextField(verbose_name='session data')), - ('expire_date', models.DateTimeField(verbose_name='expire date', db_index=True)), + ( + "session_key", + models.CharField( + max_length=40, + serialize=False, + verbose_name="session key", + primary_key=True, + ), + ), + ("session_data", models.TextField(verbose_name="session data")), + ( + "expire_date", + models.DateTimeField(verbose_name="expire date", db_index=True), + ), ], 
options={ - 'abstract': False, - 'db_table': 'django_session', - 'verbose_name': 'session', - 'verbose_name_plural': 'sessions', + "abstract": False, + "db_table": "django_session", + "verbose_name": "session", + "verbose_name_plural": "sessions", }, managers=[ - ('objects', django.contrib.sessions.models.SessionManager()), + ("objects", django.contrib.sessions.models.SessionManager()), ], ), ] diff --git a/django/contrib/sessions/models.py b/django/contrib/sessions/models.py index cb7a3f7c2f..e786ab4eac 100644 --- a/django/contrib/sessions/models.py +++ b/django/contrib/sessions/models.py @@ -1,6 +1,4 @@ -from django.contrib.sessions.base_session import ( - AbstractBaseSession, BaseSessionManager, -) +from django.contrib.sessions.base_session import AbstractBaseSession, BaseSessionManager class SessionManager(BaseSessionManager): @@ -24,12 +22,14 @@ class Session(AbstractBaseSession): the sessions documentation that is shipped with Django (also available on the Django web site). """ + objects = SessionManager() @classmethod def get_session_store_class(cls): from django.contrib.sessions.backends.db import SessionStore + return SessionStore class Meta(AbstractBaseSession.Meta): - db_table = 'django_session' + db_table = "django_session" diff --git a/django/contrib/sessions/serializers.py b/django/contrib/sessions/serializers.py index 6a9452bba0..f1adf91c22 100644 --- a/django/contrib/sessions/serializers.py +++ b/django/contrib/sessions/serializers.py @@ -1,7 +1,5 @@ # RemovedInDjango50Warning. 
-from django.core.serializers.base import ( - PickleSerializer as BasePickleSerializer, -) +from django.core.serializers.base import PickleSerializer as BasePickleSerializer from django.core.signing import JSONSerializer as BaseJSONSerializer JSONSerializer = BaseJSONSerializer diff --git a/django/contrib/sitemaps/__init__.py b/django/contrib/sitemaps/__init__.py index 27a2f4b434..05249b218b 100644 --- a/django/contrib/sitemaps/__init__.py +++ b/django/contrib/sitemaps/__init__.py @@ -25,32 +25,36 @@ def ping_google(sitemap_url=None, ping_url=PING_URL, sitemap_uses_https=True): function will attempt to deduce it by using urls.reverse(). """ sitemap_full_url = _get_sitemap_full_url(sitemap_url, sitemap_uses_https) - params = urlencode({'sitemap': sitemap_full_url}) - urlopen('%s?%s' % (ping_url, params)) + params = urlencode({"sitemap": sitemap_full_url}) + urlopen("%s?%s" % (ping_url, params)) def _get_sitemap_full_url(sitemap_url, sitemap_uses_https=True): - if not django_apps.is_installed('django.contrib.sites'): - raise ImproperlyConfigured("ping_google requires django.contrib.sites, which isn't installed.") + if not django_apps.is_installed("django.contrib.sites"): + raise ImproperlyConfigured( + "ping_google requires django.contrib.sites, which isn't installed." + ) if sitemap_url is None: try: # First, try to get the "index" sitemap URL. - sitemap_url = reverse('django.contrib.sitemaps.views.index') + sitemap_url = reverse("django.contrib.sitemaps.views.index") except NoReverseMatch: try: # Next, try for the "global" sitemap URL. - sitemap_url = reverse('django.contrib.sitemaps.views.sitemap') + sitemap_url = reverse("django.contrib.sitemaps.views.sitemap") except NoReverseMatch: pass if sitemap_url is None: - raise SitemapNotFound("You didn't provide a sitemap_url, and the sitemap URL couldn't be auto-detected.") + raise SitemapNotFound( + "You didn't provide a sitemap_url, and the sitemap URL couldn't be auto-detected." 
+ ) - Site = django_apps.get_model('sites.Site') + Site = django_apps.get_model("sites.Site") current_site = Site.objects.get_current() - scheme = 'https' if sitemap_uses_https else 'http' - return '%s://%s%s' % (scheme, current_site.domain, sitemap_url) + scheme = "https" if sitemap_uses_https else "http" + return "%s://%s%s" % (scheme, current_site.domain, sitemap_url) class Sitemap: @@ -109,8 +113,8 @@ class Sitemap: obj, lang_code = item # Activate language from item-tuple or forced one before calling location. with translation.override(force_lang_code or lang_code): - return self._get('location', item) - return self._get('location', item) + return self._get("location", item) + return self._get("location", item) @property def paginator(self): @@ -134,13 +138,13 @@ class Sitemap: ) # RemovedInDjango50Warning: when the deprecation ends, replace 'http' # with 'https'. - return self.protocol or protocol or 'http' + return self.protocol or protocol or "http" def get_domain(self, site=None): # Determine domain if site is None: - if django_apps.is_installed('django.contrib.sites'): - Site = django_apps.get_model('sites.Site') + if django_apps.is_installed("django.contrib.sites"): + Site = django_apps.get_model("sites.Site") try: site = Site.objects.get_current() except Site.DoesNotExist: @@ -158,7 +162,7 @@ class Sitemap: return self._urls(page, protocol, domain) def get_latest_lastmod(self): - if not hasattr(self, 'lastmod'): + if not hasattr(self, "lastmod"): return None if callable(self.lastmod): try: @@ -175,40 +179,45 @@ class Sitemap: paginator_page = self.paginator.page(page) for item in paginator_page.object_list: - loc = f'{protocol}://{domain}{self._location(item)}' - priority = self._get('priority', item) - lastmod = self._get('lastmod', item) + loc = f"{protocol}://{domain}{self._location(item)}" + priority = self._get("priority", item) + lastmod = self._get("lastmod", item) if all_items_lastmod: all_items_lastmod = lastmod is not None - if 
(all_items_lastmod and - (latest_lastmod is None or lastmod > latest_lastmod)): + if all_items_lastmod and ( + latest_lastmod is None or lastmod > latest_lastmod + ): latest_lastmod = lastmod url_info = { - 'item': item, - 'location': loc, - 'lastmod': lastmod, - 'changefreq': self._get('changefreq', item), - 'priority': str(priority if priority is not None else ''), - 'alternates': [], + "item": item, + "location": loc, + "lastmod": lastmod, + "changefreq": self._get("changefreq", item), + "priority": str(priority if priority is not None else ""), + "alternates": [], } if self.i18n and self.alternates: for lang_code in self._languages(): - loc = f'{protocol}://{domain}{self._location(item, lang_code)}' - url_info['alternates'].append({ - 'location': loc, - 'lang_code': lang_code, - }) + loc = f"{protocol}://{domain}{self._location(item, lang_code)}" + url_info["alternates"].append( + { + "location": loc, + "lang_code": lang_code, + } + ) if self.x_default: lang_code = settings.LANGUAGE_CODE - loc = f'{protocol}://{domain}{self._location(item, lang_code)}' - loc = loc.replace(f'/{lang_code}/', '/', 1) - url_info['alternates'].append({ - 'location': loc, - 'lang_code': 'x-default', - }) + loc = f"{protocol}://{domain}{self._location(item, lang_code)}" + loc = loc.replace(f"/{lang_code}/", "/", 1) + url_info["alternates"].append( + { + "location": loc, + "lang_code": "x-default", + } + ) urls.append(url_info) @@ -223,8 +232,8 @@ class GenericSitemap(Sitemap): changefreq = None def __init__(self, info_dict, priority=None, changefreq=None, protocol=None): - self.queryset = info_dict['queryset'] - self.date_field = info_dict.get('date_field') + self.queryset = info_dict["queryset"] + self.date_field = info_dict.get("date_field") self.priority = self.priority or priority self.changefreq = self.changefreq or changefreq self.protocol = self.protocol or protocol @@ -240,5 +249,9 @@ class GenericSitemap(Sitemap): def get_latest_lastmod(self): if self.date_field is not None: 
- return self.queryset.order_by('-' + self.date_field).values_list(self.date_field, flat=True).first() + return ( + self.queryset.order_by("-" + self.date_field) + .values_list(self.date_field, flat=True) + .first() + ) return None diff --git a/django/contrib/sitemaps/apps.py b/django/contrib/sitemaps/apps.py index ec795eab87..70c200c63c 100644 --- a/django/contrib/sitemaps/apps.py +++ b/django/contrib/sitemaps/apps.py @@ -3,6 +3,6 @@ from django.utils.translation import gettext_lazy as _ class SiteMapsConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.sitemaps' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.sitemaps" verbose_name = _("Site Maps") diff --git a/django/contrib/sitemaps/management/commands/ping_google.py b/django/contrib/sitemaps/management/commands/ping_google.py index b2d8f84366..3db071ec59 100644 --- a/django/contrib/sitemaps/management/commands/ping_google.py +++ b/django/contrib/sitemaps/management/commands/ping_google.py @@ -6,11 +6,11 @@ class Command(BaseCommand): help = "Ping Google with an updated sitemap, pass optional url of sitemap" def add_arguments(self, parser): - parser.add_argument('sitemap_url', nargs='?') - parser.add_argument('--sitemap-uses-http', action='store_true') + parser.add_argument("sitemap_url", nargs="?") + parser.add_argument("--sitemap-uses-http", action="store_true") def handle(self, *args, **options): ping_google( - sitemap_url=options['sitemap_url'], - sitemap_uses_https=not options['sitemap_uses_http'], + sitemap_url=options["sitemap_url"], + sitemap_uses_https=not options["sitemap_uses_http"], ) diff --git a/django/contrib/sitemaps/views.py b/django/contrib/sitemaps/views.py index f742bd7f7a..f5c0935a96 100644 --- a/django/contrib/sitemaps/views.py +++ b/django/contrib/sitemaps/views.py @@ -20,7 +20,7 @@ class SitemapIndexItem: # RemovedInDjango50Warning def __str__(self): - msg = 'Calling `__str__` on SitemapIndexItem is deprecated, use 
the `location` attribute instead.' + msg = "Calling `__str__` on SitemapIndexItem is deprecated, use the `location` attribute instead." warnings.warn(msg, RemovedInDjango50Warning, stacklevel=2) return self.location @@ -29,8 +29,9 @@ def x_robots_tag(func): @wraps(func) def inner(request, *args, **kwargs): response = func(request, *args, **kwargs) - response.headers['X-Robots-Tag'] = 'noindex, noodp, noarchive' + response.headers["X-Robots-Tag"] = "noindex, noodp, noarchive" return response + return inner @@ -47,9 +48,13 @@ def _get_latest_lastmod(current_lastmod, new_lastmod): @x_robots_tag -def index(request, sitemaps, - template_name='sitemap_index.xml', content_type='application/xml', - sitemap_url_name='django.contrib.sitemaps.views.sitemap'): +def index( + request, + sitemaps, + template_name="sitemap_index.xml", + content_type="application/xml", + sitemap_url_name="django.contrib.sitemaps.views.sitemap", +): req_protocol = request.scheme req_site = get_current_site(request) @@ -63,8 +68,8 @@ def index(request, sitemaps, if callable(site): site = site() protocol = req_protocol if site.protocol is None else site.protocol - sitemap_url = reverse(sitemap_url_name, kwargs={'section': section}) - absolute_url = '%s://%s%s' % (protocol, req_site.domain, sitemap_url) + sitemap_url = reverse(sitemap_url_name, kwargs={"section": section}) + absolute_url = "%s://%s%s" % (protocol, req_site.domain, sitemap_url) site_lastmod = site.get_latest_lastmod() if all_indexes_lastmod: if site_lastmod is not None: @@ -74,25 +79,32 @@ def index(request, sitemaps, sites.append(SitemapIndexItem(absolute_url, site_lastmod)) # Add links to all pages of the sitemap. 
for page in range(2, site.paginator.num_pages + 1): - sites.append(SitemapIndexItem('%s?p=%s' % (absolute_url, page), site_lastmod)) + sites.append( + SitemapIndexItem("%s?p=%s" % (absolute_url, page), site_lastmod) + ) # If lastmod is defined for all sites, set header so as # ConditionalGetMiddleware is able to send 304 NOT MODIFIED if all_indexes_lastmod and latest_lastmod: - headers = {'Last-Modified': http_date(latest_lastmod.timestamp())} + headers = {"Last-Modified": http_date(latest_lastmod.timestamp())} else: headers = None return TemplateResponse( request, template_name, - {'sitemaps': sites}, + {"sitemaps": sites}, content_type=content_type, headers=headers, ) @x_robots_tag -def sitemap(request, sitemaps, section=None, - template_name='sitemap.xml', content_type='application/xml'): +def sitemap( + request, + sitemaps, + section=None, + template_name="sitemap.xml", + content_type="application/xml", +): req_protocol = request.scheme req_site = get_current_site(request) @@ -112,10 +124,9 @@ def sitemap(request, sitemaps, section=None, try: if callable(site): site = site() - urls.extend(site.get_urls(page=page, site=req_site, - protocol=req_protocol)) + urls.extend(site.get_urls(page=page, site=req_site, protocol=req_protocol)) if all_sites_lastmod: - site_lastmod = getattr(site, 'latest_lastmod', None) + site_lastmod = getattr(site, "latest_lastmod", None) if site_lastmod is not None: lastmod = _get_latest_lastmod(lastmod, site_lastmod) else: @@ -127,13 +138,13 @@ def sitemap(request, sitemaps, section=None, # If lastmod is defined for all sites, set header so as # ConditionalGetMiddleware is able to send 304 NOT MODIFIED if all_sites_lastmod: - headers = {'Last-Modified': http_date(lastmod.timestamp())} if lastmod else None + headers = {"Last-Modified": http_date(lastmod.timestamp())} if lastmod else None else: headers = None return TemplateResponse( request, template_name, - {'urlset': urls}, + {"urlset": urls}, content_type=content_type, headers=headers, 
) diff --git a/django/contrib/sites/admin.py b/django/contrib/sites/admin.py index 2b167fe38c..53ad53d622 100644 --- a/django/contrib/sites/admin.py +++ b/django/contrib/sites/admin.py @@ -4,5 +4,5 @@ from django.contrib.sites.models import Site @admin.register(Site) class SiteAdmin(admin.ModelAdmin): - list_display = ('domain', 'name') - search_fields = ('domain', 'name') + list_display = ("domain", "name") + search_fields = ("domain", "name") diff --git a/django/contrib/sites/apps.py b/django/contrib/sites/apps.py index 7f820dcc79..ac51a84e18 100644 --- a/django/contrib/sites/apps.py +++ b/django/contrib/sites/apps.py @@ -8,8 +8,8 @@ from .management import create_default_site class SitesConfig(AppConfig): - default_auto_field = 'django.db.models.AutoField' - name = 'django.contrib.sites' + default_auto_field = "django.db.models.AutoField" + name = "django.contrib.sites" verbose_name = _("Sites") def ready(self): diff --git a/django/contrib/sites/checks.py b/django/contrib/sites/checks.py index c7dfe9ea52..6db039fb56 100644 --- a/django/contrib/sites/checks.py +++ b/django/contrib/sites/checks.py @@ -3,11 +3,10 @@ from django.core.checks import Error def check_site_id(app_configs, **kwargs): - if ( - hasattr(settings, 'SITE_ID') and - not isinstance(settings.SITE_ID, (type(None), int)) + if hasattr(settings, "SITE_ID") and not isinstance( + settings.SITE_ID, (type(None), int) ): return [ - Error('The SITE_ID setting must be an integer', id='sites.E101'), + Error("The SITE_ID setting must be an integer", id="sites.E101"), ] return [] diff --git a/django/contrib/sites/management.py b/django/contrib/sites/management.py index 34262336a2..dd75bc1ba9 100644 --- a/django/contrib/sites/management.py +++ b/django/contrib/sites/management.py @@ -8,9 +8,16 @@ from django.core.management.color import no_style from django.db import DEFAULT_DB_ALIAS, connections, router -def create_default_site(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, 
apps=global_apps, **kwargs): +def create_default_site( + app_config, + verbosity=2, + interactive=True, + using=DEFAULT_DB_ALIAS, + apps=global_apps, + **kwargs, +): try: - Site = apps.get_model('sites', 'Site') + Site = apps.get_model("sites", "Site") except LookupError: return @@ -25,7 +32,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, using=DEFAULT # can also crop up outside of tests - see #15346. if verbosity >= 2: print("Creating example.com Site object") - Site(pk=getattr(settings, 'SITE_ID', 1), domain="example.com", name="example.com").save(using=using) + Site( + pk=getattr(settings, "SITE_ID", 1), domain="example.com", name="example.com" + ).save(using=using) # We set an explicit pk instead of relying on auto-incrementation, # so we need to reset the database sequence. See #17415. diff --git a/django/contrib/sites/managers.py b/django/contrib/sites/managers.py index 91c034e967..15682d7e38 100644 --- a/django/contrib/sites/managers.py +++ b/django/contrib/sites/managers.py @@ -25,36 +25,40 @@ class CurrentSiteManager(models.Manager): except FieldDoesNotExist: return [ checks.Error( - "CurrentSiteManager could not find a field named '%s'." % field_name, + "CurrentSiteManager could not find a field named '%s'." + % field_name, obj=self, - id='sites.E001', + id="sites.E001", ) ] if not field.many_to_many and not isinstance(field, (models.ForeignKey)): return [ checks.Error( - "CurrentSiteManager cannot use '%s.%s' as it is not a foreign key or a many-to-many field." % ( - self.model._meta.object_name, field_name - ), + "CurrentSiteManager cannot use '%s.%s' as it is not a foreign key or a many-to-many field." + % (self.model._meta.object_name, field_name), obj=self, - id='sites.E002', + id="sites.E002", ) ] return [] def _get_field_name(self): - """ Return self.__field_name or 'site' or 'sites'. 
""" + """Return self.__field_name or 'site' or 'sites'.""" if not self.__field_name: try: - self.model._meta.get_field('site') + self.model._meta.get_field("site") except FieldDoesNotExist: - self.__field_name = 'sites' + self.__field_name = "sites" else: - self.__field_name = 'site' + self.__field_name = "site" return self.__field_name def get_queryset(self): - return super().get_queryset().filter(**{self._get_field_name() + '__id': settings.SITE_ID}) + return ( + super() + .get_queryset() + .filter(**{self._get_field_name() + "__id": settings.SITE_ID}) + ) diff --git a/django/contrib/sites/migrations/0001_initial.py b/django/contrib/sites/migrations/0001_initial.py index 9b261900fa..181cf47ad7 100644 --- a/django/contrib/sites/migrations/0001_initial.py +++ b/django/contrib/sites/migrations/0001_initial.py @@ -9,23 +9,36 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='Site', + name="Site", fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('domain', models.CharField( - max_length=100, verbose_name='domain name', validators=[_simple_domain_name_validator] - )), - ('name', models.CharField(max_length=50, verbose_name='display name')), + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "domain", + models.CharField( + max_length=100, + verbose_name="domain name", + validators=[_simple_domain_name_validator], + ), + ), + ("name", models.CharField(max_length=50, verbose_name="display name")), ], options={ - 'ordering': ['domain'], - 'db_table': 'django_site', - 'verbose_name': 'site', - 'verbose_name_plural': 'sites', + "ordering": ["domain"], + "db_table": "django_site", + "verbose_name": "site", + "verbose_name_plural": "sites", }, bases=(models.Model,), managers=[ - ('objects', django.contrib.sites.models.SiteManager()), + ("objects", django.contrib.sites.models.SiteManager()), ], ), ] 
diff --git a/django/contrib/sites/migrations/0002_alter_domain_unique.py b/django/contrib/sites/migrations/0002_alter_domain_unique.py index 6a26ebcde6..ccc7bfc384 100644 --- a/django/contrib/sites/migrations/0002_alter_domain_unique.py +++ b/django/contrib/sites/migrations/0002_alter_domain_unique.py @@ -5,16 +5,18 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('sites', '0001_initial'), + ("sites", "0001_initial"), ] operations = [ migrations.AlterField( - model_name='site', - name='domain', + model_name="site", + name="domain", field=models.CharField( - max_length=100, unique=True, validators=[django.contrib.sites.models._simple_domain_name_validator], - verbose_name='domain name' + max_length=100, + unique=True, + validators=[django.contrib.sites.models._simple_domain_name_validator], + verbose_name="domain name", ), ), ] diff --git a/django/contrib/sites/models.py b/django/contrib/sites/models.py index 3dc4c254d7..e1544f1fa2 100644 --- a/django/contrib/sites/models.py +++ b/django/contrib/sites/models.py @@ -18,7 +18,7 @@ def _simple_domain_name_validator(value): if any(checks): raise ValidationError( _("The domain name cannot contain any spaces or tabs."), - code='invalid', + code="invalid", ) @@ -53,14 +53,15 @@ class SiteManager(models.Manager): retrieved from the database. """ from django.conf import settings - if getattr(settings, 'SITE_ID', ''): + + if getattr(settings, "SITE_ID", ""): site_id = settings.SITE_ID return self._get_site_by_id(site_id) elif request: return self._get_site_by_request(request) raise ImproperlyConfigured( - "You're using the Django \"sites framework\" without having " + 'You\'re using the Django "sites framework" without having ' "set the SITE_ID setting. Create a site in your database and " "set the SITE_ID setting or pass a request to " "Site.objects.get_current() to fix this error." 
@@ -78,20 +79,20 @@ class SiteManager(models.Manager): class Site(models.Model): domain = models.CharField( - _('domain name'), + _("domain name"), max_length=100, validators=[_simple_domain_name_validator], unique=True, ) - name = models.CharField(_('display name'), max_length=50) + name = models.CharField(_("display name"), max_length=50) objects = SiteManager() class Meta: - db_table = 'django_site' - verbose_name = _('site') - verbose_name_plural = _('sites') - ordering = ['domain'] + db_table = "django_site" + verbose_name = _("site") + verbose_name_plural = _("sites") + ordering = ["domain"] def __str__(self): return self.domain @@ -104,8 +105,8 @@ def clear_site_cache(sender, **kwargs): """ Clear the cache (if primed) each time a site is saved or deleted. """ - instance = kwargs['instance'] - using = kwargs['using'] + instance = kwargs["instance"] + using = kwargs["using"] try: del SITE_CACHE[instance.pk] except KeyError: diff --git a/django/contrib/sites/requests.py b/django/contrib/sites/requests.py index f6f0270c09..a0c9c18aa3 100644 --- a/django/contrib/sites/requests.py +++ b/django/contrib/sites/requests.py @@ -6,6 +6,7 @@ class RequestSite: The save() and delete() methods raise NotImplementedError. 
""" + def __init__(self, request): self.domain = self.name = request.get_host() @@ -13,7 +14,7 @@ class RequestSite: return self.domain def save(self, force_insert=False, force_update=False): - raise NotImplementedError('RequestSite cannot be saved.') + raise NotImplementedError("RequestSite cannot be saved.") def delete(self): - raise NotImplementedError('RequestSite cannot be deleted.') + raise NotImplementedError("RequestSite cannot be deleted.") diff --git a/django/contrib/sites/shortcuts.py b/django/contrib/sites/shortcuts.py index 1131dba1ea..d4b65101af 100644 --- a/django/contrib/sites/shortcuts.py +++ b/django/contrib/sites/shortcuts.py @@ -10,8 +10,9 @@ def get_current_site(request): """ # Import is inside the function because its point is to avoid importing the # Site models when django.contrib.sites isn't installed. - if apps.is_installed('django.contrib.sites'): + if apps.is_installed("django.contrib.sites"): from .models import Site + return Site.objects.get_current(request) else: return RequestSite(request) diff --git a/django/contrib/staticfiles/apps.py b/django/contrib/staticfiles/apps.py index 2fbc829bf8..67acf042aa 100644 --- a/django/contrib/staticfiles/apps.py +++ b/django/contrib/staticfiles/apps.py @@ -5,9 +5,9 @@ from django.utils.translation import gettext_lazy as _ class StaticFilesConfig(AppConfig): - name = 'django.contrib.staticfiles' + name = "django.contrib.staticfiles" verbose_name = _("Static Files") - ignore_patterns = ['CVS', '.*', '*~'] + ignore_patterns = ["CVS", ".*", "*~"] def ready(self): checks.register(check_finders, checks.Tags.staticfiles) diff --git a/django/contrib/staticfiles/finders.py b/django/contrib/staticfiles/finders.py index 35e62a3ae0..184d297568 100644 --- a/django/contrib/staticfiles/finders.py +++ b/django/contrib/staticfiles/finders.py @@ -6,9 +6,7 @@ from django.conf import settings from django.contrib.staticfiles import utils from django.core.checks import Error, Warning from django.core.exceptions import 
ImproperlyConfigured -from django.core.files.storage import ( - FileSystemStorage, Storage, default_storage, -) +from django.core.files.storage import FileSystemStorage, Storage, default_storage from django.utils._os import safe_join from django.utils.functional import LazyObject, empty from django.utils.module_loading import import_string @@ -21,10 +19,11 @@ class BaseFinder: """ A base file finder to be used for custom staticfiles finder classes. """ + def check(self, **kwargs): raise NotImplementedError( - 'subclasses may provide a check() method to verify the finder is ' - 'configured correctly.' + "subclasses may provide a check() method to verify the finder is " + "configured correctly." ) def find(self, path, all=False): @@ -34,14 +33,18 @@ class BaseFinder: If the ``all`` parameter is False (default) return only the first found file path; if True, return a list of all found files paths. """ - raise NotImplementedError('subclasses of BaseFinder must provide a find() method') + raise NotImplementedError( + "subclasses of BaseFinder must provide a find() method" + ) def list(self, ignore_patterns): """ Given an optional list of paths to ignore, return a two item iterable consisting of the relative path and storage instance. """ - raise NotImplementedError('subclasses of BaseFinder must provide a list() method') + raise NotImplementedError( + "subclasses of BaseFinder must provide a list() method" + ) class FileSystemFinder(BaseFinder): @@ -49,6 +52,7 @@ class FileSystemFinder(BaseFinder): A static files finder that uses the ``STATICFILES_DIRS`` setting to locate files. 
""" + def __init__(self, app_names=None, *args, **kwargs): # List of locations with static files self.locations = [] @@ -58,7 +62,7 @@ class FileSystemFinder(BaseFinder): if isinstance(root, (list, tuple)): prefix, root = root else: - prefix = '' + prefix = "" if (prefix, root) not in self.locations: self.locations.append((prefix, root)) for prefix, root in self.locations: @@ -70,33 +74,43 @@ class FileSystemFinder(BaseFinder): def check(self, **kwargs): errors = [] if not isinstance(settings.STATICFILES_DIRS, (list, tuple)): - errors.append(Error( - 'The STATICFILES_DIRS setting is not a tuple or list.', - hint='Perhaps you forgot a trailing comma?', - id='staticfiles.E001', - )) + errors.append( + Error( + "The STATICFILES_DIRS setting is not a tuple or list.", + hint="Perhaps you forgot a trailing comma?", + id="staticfiles.E001", + ) + ) return errors for root in settings.STATICFILES_DIRS: if isinstance(root, (list, tuple)): prefix, root = root - if prefix.endswith('/'): - errors.append(Error( - 'The prefix %r in the STATICFILES_DIRS setting must ' - 'not end with a slash.' % prefix, - id='staticfiles.E003', - )) - if settings.STATIC_ROOT and os.path.abspath(settings.STATIC_ROOT) == os.path.abspath(root): - errors.append(Error( - 'The STATICFILES_DIRS setting should not contain the ' - 'STATIC_ROOT setting.', - id='staticfiles.E002', - )) + if prefix.endswith("/"): + errors.append( + Error( + "The prefix %r in the STATICFILES_DIRS setting must " + "not end with a slash." 
% prefix, + id="staticfiles.E003", + ) + ) + if settings.STATIC_ROOT and os.path.abspath( + settings.STATIC_ROOT + ) == os.path.abspath(root): + errors.append( + Error( + "The STATICFILES_DIRS setting should not contain the " + "STATIC_ROOT setting.", + id="staticfiles.E002", + ) + ) if not os.path.isdir(root): - errors.append(Warning( - f"The directory '{root}' in the STATICFILES_DIRS setting " - f"does not exist.", - id='staticfiles.W004', - )) + errors.append( + Warning( + f"The directory '{root}' in the STATICFILES_DIRS setting " + f"does not exist.", + id="staticfiles.W004", + ) + ) return errors def find(self, path, all=False): @@ -120,10 +134,10 @@ class FileSystemFinder(BaseFinder): absolute path (or ``None`` if no match). """ if prefix: - prefix = '%s%s' % (prefix, os.sep) + prefix = "%s%s" % (prefix, os.sep) if not path.startswith(prefix): return None - path = path[len(prefix):] + path = path[len(prefix) :] path = safe_join(root, path) if os.path.exists(path): return path @@ -145,8 +159,9 @@ class AppDirectoriesFinder(BaseFinder): A static files finder that looks in the directory of each app as specified in the source_dir attribute. """ + storage_class = FileSystemStorage - source_dir = 'static' + source_dir = "static" def __init__(self, app_names=None, *args, **kwargs): # The list of apps that are handled @@ -159,7 +174,8 @@ class AppDirectoriesFinder(BaseFinder): app_configs = [ac for ac in app_configs if ac.name in app_names] for app_config in app_configs: app_storage = self.storage_class( - os.path.join(app_config.path, self.source_dir)) + os.path.join(app_config.path, self.source_dir) + ) if os.path.isdir(app_storage.location): self.storages[app_config.name] = app_storage if app_config.name not in self.apps: @@ -171,7 +187,7 @@ class AppDirectoriesFinder(BaseFinder): List all files in all app storages. 
""" for storage in self.storages.values(): - if storage.exists(''): # check if storage location exists + if storage.exists(""): # check if storage location exists for path in utils.get_files(storage, ignore_patterns): yield path, storage @@ -208,15 +224,18 @@ class BaseStorageFinder(BaseFinder): A base static files finder to be used to extended with an own storage class. """ + storage = None def __init__(self, storage=None, *args, **kwargs): if storage is not None: self.storage = storage if self.storage is None: - raise ImproperlyConfigured("The staticfiles storage finder %r " - "doesn't have a storage class " - "assigned." % self.__class__) + raise ImproperlyConfigured( + "The staticfiles storage finder %r " + "doesn't have a storage class " + "assigned." % self.__class__ + ) # Make sure we have a storage instance here. if not isinstance(self.storage, (Storage, LazyObject)): self.storage = self.storage() @@ -227,7 +246,7 @@ class BaseStorageFinder(BaseFinder): Look for files in the default file storage, if it's local. """ try: - self.storage.path('') + self.storage.path("") except NotImplementedError: pass else: @@ -252,15 +271,18 @@ class DefaultStorageFinder(BaseStorageFinder): """ A static files finder that uses the default storage backend. """ + storage = default_storage def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - base_location = getattr(self.storage, 'base_location', empty) + base_location = getattr(self.storage, "base_location", empty) if not base_location: - raise ImproperlyConfigured("The storage backend of the " - "staticfiles finder %r doesn't have " - "a valid location." % self.__class__) + raise ImproperlyConfigured( + "The storage backend of the " + "staticfiles finder %r doesn't have " + "a valid location." 
% self.__class__ + ) def find(path, all=False): @@ -298,6 +320,7 @@ def get_finder(import_path): """ Finder = import_string(import_path) if not issubclass(Finder, BaseFinder): - raise ImproperlyConfigured('Finder "%s" is not a subclass of "%s"' % - (Finder, BaseFinder)) + raise ImproperlyConfigured( + 'Finder "%s" is not a subclass of "%s"' % (Finder, BaseFinder) + ) return Finder() diff --git a/django/contrib/staticfiles/handlers.py b/django/contrib/staticfiles/handlers.py index 44fef0722d..da53149e78 100644 --- a/django/contrib/staticfiles/handlers.py +++ b/django/contrib/staticfiles/handlers.py @@ -16,6 +16,7 @@ class StaticFilesHandlerMixin: """ Common methods used by WSGI and ASGI handlers. """ + # May be used to differentiate between handler types (e.g. in a # request_finished signal) handles_files = True @@ -41,7 +42,7 @@ class StaticFilesHandlerMixin: """ Return the relative path to the media file on disk for the given URL. """ - relative_url = url[len(self.base_url[2]):] + relative_url = url[len(self.base_url[2]) :] return url2pathname(relative_url) def serve(self, request): @@ -58,7 +59,9 @@ class StaticFilesHandlerMixin: try: return await sync_to_async(self.serve, thread_sensitive=False)(request) except Http404 as e: - return await sync_to_async(response_for_exception, thread_sensitive=False)(request, e) + return await sync_to_async(response_for_exception, thread_sensitive=False)( + request, e + ) class StaticFilesHandler(StaticFilesHandlerMixin, WSGIHandler): @@ -66,6 +69,7 @@ class StaticFilesHandler(StaticFilesHandlerMixin, WSGIHandler): WSGI middleware that intercepts calls to the static files directory, as defined by the STATIC_URL setting, and serves those files. 
""" + def __init__(self, application): self.application = application self.base_url = urlparse(self.get_base_url()) @@ -82,13 +86,14 @@ class ASGIStaticFilesHandler(StaticFilesHandlerMixin, ASGIHandler): ASGI application which wraps another and intercepts requests for static files, passing them off to Django's static file serving. """ + def __init__(self, application): self.application = application self.base_url = urlparse(self.get_base_url()) async def __call__(self, scope, receive, send): # Only even look at HTTP requests - if scope['type'] == 'http' and self._should_handle(scope['path']): + if scope["type"] == "http" and self._should_handle(scope["path"]): # Serve static content # (the one thing super() doesn't do is __call__, apparently) return await super().__call__(scope, receive, send) diff --git a/django/contrib/staticfiles/management/commands/collectstatic.py b/django/contrib/staticfiles/management/commands/collectstatic.py index 9d4c2d5006..c6d561bc81 100644 --- a/django/contrib/staticfiles/management/commands/collectstatic.py +++ b/django/contrib/staticfiles/management/commands/collectstatic.py @@ -15,6 +15,7 @@ class Command(BaseCommand): Copies or symlinks static files from different locations to the settings.STATIC_ROOT. """ + help = "Collect static files in a single location." 
requires_system_checks = [Tags.staticfiles] @@ -30,41 +31,58 @@ class Command(BaseCommand): @cached_property def local(self): try: - self.storage.path('') + self.storage.path("") except NotImplementedError: return False return True def add_arguments(self, parser): parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", help="Do NOT prompt the user for input of any kind.", ) parser.add_argument( - '--no-post-process', action='store_false', dest='post_process', + "--no-post-process", + action="store_false", + dest="post_process", help="Do NOT post process collected files.", ) parser.add_argument( - '-i', '--ignore', action='append', default=[], - dest='ignore_patterns', metavar='PATTERN', + "-i", + "--ignore", + action="append", + default=[], + dest="ignore_patterns", + metavar="PATTERN", help="Ignore files or directories matching this glob-style " - "pattern. Use multiple times to ignore more.", + "pattern. 
Use multiple times to ignore more.", ) parser.add_argument( - '-n', '--dry-run', action='store_true', + "-n", + "--dry-run", + action="store_true", help="Do everything except modify the filesystem.", ) parser.add_argument( - '-c', '--clear', action='store_true', + "-c", + "--clear", + action="store_true", help="Clear the existing files using the storage " - "before trying to copy or link the original file.", + "before trying to copy or link the original file.", ) parser.add_argument( - '-l', '--link', action='store_true', + "-l", + "--link", + action="store_true", help="Create a symbolic link to each file instead of copying.", ) parser.add_argument( - '--no-default-ignore', action='store_false', dest='use_default_ignore_patterns', + "--no-default-ignore", + action="store_false", + dest="use_default_ignore_patterns", help="Don't ignore the common private glob-style patterns (defaults to 'CVS', '.*' and '*~').", ) @@ -72,16 +90,16 @@ class Command(BaseCommand): """ Set instance variables based on an options dict """ - self.interactive = options['interactive'] - self.verbosity = options['verbosity'] - self.symlink = options['link'] - self.clear = options['clear'] - self.dry_run = options['dry_run'] - ignore_patterns = options['ignore_patterns'] - if options['use_default_ignore_patterns']: - ignore_patterns += apps.get_app_config('staticfiles').ignore_patterns + self.interactive = options["interactive"] + self.verbosity = options["verbosity"] + self.symlink = options["link"] + self.clear = options["clear"] + self.dry_run = options["dry_run"] + ignore_patterns = options["ignore_patterns"] + if options["use_default_ignore_patterns"]: + ignore_patterns += apps.get_app_config("staticfiles").ignore_patterns self.ignore_patterns = list({os.path.normpath(p) for p in ignore_patterns}) - self.post_process = options['post_process'] + self.post_process = options["post_process"] def collect(self): """ @@ -93,7 +111,7 @@ class Command(BaseCommand): raise CommandError("Can't symlink 
to a remote destination.") if self.clear: - self.clear_dir('') + self.clear_dir("") if self.symlink: handler = self.link_file @@ -104,7 +122,7 @@ class Command(BaseCommand): for finder in get_finders(): for path, storage in finder.list(self.ignore_patterns): # Prefix the relative path if the source storage contains it - if getattr(storage, 'prefix', None): + if getattr(storage, "prefix", None): prefixed_path = os.path.join(storage.prefix, path) else: prefixed_path = path @@ -122,9 +140,8 @@ class Command(BaseCommand): ) # Storage backends may define a post_process() method. - if self.post_process and hasattr(self.storage, 'post_process'): - processor = self.storage.post_process(found_files, - dry_run=self.dry_run) + if self.post_process and hasattr(self.storage, "post_process"): + processor = self.storage.post_process(found_files, dry_run=self.dry_run) for original_path, processed_path, processed in processor: if isinstance(processed, Exception): self.stderr.write("Post-processing '%s' failed!" 
% original_path) @@ -133,75 +150,84 @@ class Command(BaseCommand): self.stderr.write() raise processed if processed: - self.log("Post-processed '%s' as '%s'" % - (original_path, processed_path), level=2) + self.log( + "Post-processed '%s' as '%s'" % (original_path, processed_path), + level=2, + ) self.post_processed_files.append(original_path) else: self.log("Skipped post-processing '%s'" % original_path) return { - 'modified': self.copied_files + self.symlinked_files, - 'unmodified': self.unmodified_files, - 'post_processed': self.post_processed_files, + "modified": self.copied_files + self.symlinked_files, + "unmodified": self.unmodified_files, + "post_processed": self.post_processed_files, } def handle(self, **options): self.set_options(**options) - message = ['\n'] + message = ["\n"] if self.dry_run: message.append( - 'You have activated the --dry-run option so no files will be modified.\n\n' + "You have activated the --dry-run option so no files will be modified.\n\n" ) message.append( - 'You have requested to collect static files at the destination\n' - 'location as specified in your settings' + "You have requested to collect static files at the destination\n" + "location as specified in your settings" ) if self.is_local_storage() and self.storage.location: destination_path = self.storage.location - message.append(':\n\n %s\n\n' % destination_path) - should_warn_user = ( - self.storage.exists(destination_path) and - any(self.storage.listdir(destination_path)) + message.append(":\n\n %s\n\n" % destination_path) + should_warn_user = self.storage.exists(destination_path) and any( + self.storage.listdir(destination_path) ) else: destination_path = None - message.append('.\n\n') + message.append(".\n\n") # Destination files existence not checked; play it safe and warn. 
should_warn_user = True if self.interactive and should_warn_user: if self.clear: - message.append('This will DELETE ALL FILES in this location!\n') + message.append("This will DELETE ALL FILES in this location!\n") else: - message.append('This will overwrite existing files!\n') + message.append("This will overwrite existing files!\n") message.append( - 'Are you sure you want to do this?\n\n' + "Are you sure you want to do this?\n\n" "Type 'yes' to continue, or 'no' to cancel: " ) - if input(''.join(message)) != 'yes': + if input("".join(message)) != "yes": raise CommandError("Collecting static files cancelled.") collected = self.collect() if self.verbosity >= 1: - modified_count = len(collected['modified']) - unmodified_count = len(collected['unmodified']) - post_processed_count = len(collected['post_processed']) + modified_count = len(collected["modified"]) + unmodified_count = len(collected["unmodified"]) + post_processed_count = len(collected["post_processed"]) return ( "\n%(modified_count)s %(identifier)s %(action)s" "%(destination)s%(unmodified)s%(post_processed)s." 
) % { - 'modified_count': modified_count, - 'identifier': 'static file' + ('' if modified_count == 1 else 's'), - 'action': 'symlinked' if self.symlink else 'copied', - 'destination': (" to '%s'" % destination_path if destination_path else ''), - 'unmodified': (', %s unmodified' % unmodified_count if collected['unmodified'] else ''), - 'post_processed': (collected['post_processed'] and - ', %s post-processed' - % post_processed_count or ''), + "modified_count": modified_count, + "identifier": "static file" + ("" if modified_count == 1 else "s"), + "action": "symlinked" if self.symlink else "copied", + "destination": ( + " to '%s'" % destination_path if destination_path else "" + ), + "unmodified": ( + ", %s unmodified" % unmodified_count + if collected["unmodified"] + else "" + ), + "post_processed": ( + collected["post_processed"] + and ", %s post-processed" % post_processed_count + or "" + ), } def log(self, msg, level=2): @@ -268,16 +294,17 @@ class Command(BaseCommand): # previous collectstatic was with --link), the old # links/files must be deleted so it's not safe to skip # unmodified files. - can_skip_unmodified_files = not (self.symlink ^ os.path.islink(full_path)) + can_skip_unmodified_files = not ( + self.symlink ^ os.path.islink(full_path) + ) else: # In remote storages, skipping is only based on the # modified times since symlinks aren't relevant. 
can_skip_unmodified_files = True # Avoid sub-second precision (see #14665, #19540) - file_is_unmodified = ( - target_last_modified.replace(microsecond=0) >= - source_last_modified.replace(microsecond=0) - ) + file_is_unmodified = target_last_modified.replace( + microsecond=0 + ) >= source_last_modified.replace(microsecond=0) if file_is_unmodified and can_skip_unmodified_files: if prefixed_path not in self.unmodified_files: self.unmodified_files.append(prefixed_path) @@ -316,8 +343,11 @@ class Command(BaseCommand): os.symlink(source_path, full_path) except NotImplementedError: import platform - raise CommandError("Symlinking is not supported in this " - "platform (%s)." % platform.platform()) + + raise CommandError( + "Symlinking is not supported in this " + "platform (%s)." % platform.platform() + ) except OSError as e: raise CommandError(e) if prefixed_path not in self.symlinked_files: diff --git a/django/contrib/staticfiles/management/commands/findstatic.py b/django/contrib/staticfiles/management/commands/findstatic.py index fe3b53cbd8..97413a64af 100644 --- a/django/contrib/staticfiles/management/commands/findstatic.py +++ b/django/contrib/staticfiles/management/commands/findstatic.py @@ -6,38 +6,43 @@ from django.core.management.base import LabelCommand class Command(LabelCommand): help = "Finds the absolute paths for the given static file(s)." 
- label = 'staticfile' + label = "staticfile" def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument( - '--first', action='store_false', dest='all', + "--first", + action="store_false", + dest="all", help="Only return the first match for each static file.", ) def handle_label(self, path, **options): - verbosity = options['verbosity'] - result = finders.find(path, all=options['all']) + verbosity = options["verbosity"] + result = finders.find(path, all=options["all"]) if verbosity >= 2: searched_locations = ( - "\nLooking in the following locations:\n %s" % - "\n ".join([str(loc) for loc in finders.searched_locations]) + "\nLooking in the following locations:\n %s" + % "\n ".join([str(loc) for loc in finders.searched_locations]) ) else: - searched_locations = '' + searched_locations = "" if result: if not isinstance(result, (list, tuple)): result = [result] result = (os.path.realpath(path) for path in result) if verbosity >= 1: - file_list = '\n '.join(result) - return ("Found '%s' here:\n %s%s" % - (path, file_list, searched_locations)) + file_list = "\n ".join(result) + return "Found '%s' here:\n %s%s" % ( + path, + file_list, + searched_locations, + ) else: - return '\n'.join(result) + return "\n".join(result) else: message = ["No matching file found for '%s'." 
% path] if verbosity >= 2: message.append(searched_locations) if verbosity >= 1: - self.stderr.write('\n'.join(message)) + self.stderr.write("\n".join(message)) diff --git a/django/contrib/staticfiles/management/commands/runserver.py b/django/contrib/staticfiles/management/commands/runserver.py index cf9605ee97..fd9ddb16a4 100644 --- a/django/contrib/staticfiles/management/commands/runserver.py +++ b/django/contrib/staticfiles/management/commands/runserver.py @@ -1,22 +1,26 @@ from django.conf import settings from django.contrib.staticfiles.handlers import StaticFilesHandler -from django.core.management.commands.runserver import ( - Command as RunserverCommand, -) +from django.core.management.commands.runserver import Command as RunserverCommand class Command(RunserverCommand): - help = "Starts a lightweight web server for development and also serves static files." + help = ( + "Starts a lightweight web server for development and also serves static files." + ) def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument( - '--nostatic', action="store_false", dest='use_static_handler', - help='Tells Django to NOT automatically serve static files at STATIC_URL.', + "--nostatic", + action="store_false", + dest="use_static_handler", + help="Tells Django to NOT automatically serve static files at STATIC_URL.", ) parser.add_argument( - '--insecure', action="store_true", dest='insecure_serving', - help='Allows serving static files even if DEBUG is False.', + "--insecure", + action="store_true", + dest="insecure_serving", + help="Allows serving static files even if DEBUG is False.", ) def get_handler(self, *args, **options): @@ -25,8 +29,8 @@ class Command(RunserverCommand): if static files should be served. Otherwise return the default handler. 
""" handler = super().get_handler(*args, **options) - use_static_handler = options['use_static_handler'] - insecure_serving = options['insecure_serving'] + use_static_handler = options["use_static_handler"] + insecure_serving = options["insecure_serving"] if use_static_handler and (settings.DEBUG or insecure_serving): return StaticFilesHandler(handler) return handler diff --git a/django/contrib/staticfiles/storage.py b/django/contrib/staticfiles/storage.py index 5b492abd75..754f173263 100644 --- a/django/contrib/staticfiles/storage.py +++ b/django/contrib/staticfiles/storage.py @@ -20,6 +20,7 @@ class StaticFilesStorage(FileSystemStorage): The defaults for ``location`` and ``base_url`` are ``STATIC_ROOT`` and ``STATIC_URL``. """ + def __init__(self, location=None, base_url=None, *args, **kwargs): if location is None: location = settings.STATIC_ROOT @@ -35,9 +36,11 @@ class StaticFilesStorage(FileSystemStorage): def path(self, name): if not self.location: - raise ImproperlyConfigured("You're using the staticfiles app " - "without having set the STATIC_ROOT " - "setting to a filesystem path.") + raise ImproperlyConfigured( + "You're using the staticfiles app " + "without having set the STATIC_ROOT " + "setting to a filesystem path." 
+ ) return super().path(name) @@ -45,23 +48,29 @@ class HashedFilesMixin: default_template = """url("%(url)s")""" max_post_process_passes = 5 patterns = ( - ("*.css", ( - r"""(?P<matched>url\(['"]{0,1}\s*(?P<url>.*?)["']{0,1}\))""", + ( + "*.css", ( - r"""(?P<matched>@import\s*["']\s*(?P<url>.*?)["'])""", - """@import url("%(url)s")""", + r"""(?P<matched>url\(['"]{0,1}\s*(?P<url>.*?)["']{0,1}\))""", + ( + r"""(?P<matched>@import\s*["']\s*(?P<url>.*?)["'])""", + """@import url("%(url)s")""", + ), + ( + r"(?m)(?P<matched>)^(/\*# (?-i:sourceMappingURL)=(?P<url>.*) \*/)$", + "/*# sourceMappingURL=%(url)s */", + ), ), + ), + ( + "*.js", ( - r'(?m)(?P<matched>)^(/\*# (?-i:sourceMappingURL)=(?P<url>.*) \*/)$', - '/*# sourceMappingURL=%(url)s */', + ( + r"(?m)(?P<matched>)^(//# (?-i:sourceMappingURL)=(?P<url>.*))$", + "//# sourceMappingURL=%(url)s", + ), ), - )), - ('*.js', ( - ( - r'(?m)(?P<matched>)^(//# (?-i:sourceMappingURL)=(?P<url>.*))$', - '//# sourceMappingURL=%(url)s', - ), - )), + ), ) keep_intermediate_files = True @@ -98,7 +107,9 @@ class HashedFilesMixin: opened = content is None if opened: if not self.exists(filename): - raise ValueError("The file '%s' could not be found with %r." % (filename, self)) + raise ValueError( + "The file '%s' could not be found with %r." 
% (filename, self) + ) try: content = self.open(filename) except OSError: @@ -111,15 +122,14 @@ class HashedFilesMixin: content.close() path, filename = os.path.split(clean_name) root, ext = os.path.splitext(filename) - file_hash = ('.%s' % file_hash) if file_hash else '' - hashed_name = os.path.join(path, "%s%s%s" % - (root, file_hash, ext)) + file_hash = (".%s" % file_hash) if file_hash else "" + hashed_name = os.path.join(path, "%s%s%s" % (root, file_hash, ext)) unparsed_name = list(parsed_name) unparsed_name[2] = hashed_name # Special casing for a @font-face hack, like url(myfont.eot?#iefix") # http://www.fontspring.com/blog/the-new-bulletproof-font-face-syntax - if '?#' in name and not unparsed_name[3]: - unparsed_name[2] += '?' + if "?#" in name and not unparsed_name[3]: + unparsed_name[2] += "?" return urlunsplit(unparsed_name) def _url(self, hashed_name_func, name, force=False, hashed_files=None): @@ -127,10 +137,10 @@ class HashedFilesMixin: Return the non-hashed URL in DEBUG mode. """ if settings.DEBUG and not force: - hashed_name, fragment = name, '' + hashed_name, fragment = name, "" else: clean_name, fragment = urldefrag(name) - if urlsplit(clean_name).path.endswith('/'): # don't hash paths + if urlsplit(clean_name).path.endswith("/"): # don't hash paths hashed_name = name else: args = (clean_name,) @@ -142,13 +152,13 @@ class HashedFilesMixin: # Special casing for a @font-face hack, like url(myfont.eot?#iefix") # http://www.fontspring.com/blog/the-new-bulletproof-font-face-syntax - query_fragment = '?#' in name # [sic!] + query_fragment = "?#" in name # [sic!] if fragment or query_fragment: urlparts = list(urlsplit(final_url)) if fragment and not urlparts[4]: urlparts[4] = fragment if query_fragment and not urlparts[3]: - urlparts[2] += '?' + urlparts[2] += "?" final_url = urlunsplit(urlparts) return unquote(final_url) @@ -174,44 +184,48 @@ class HashedFilesMixin: to and calling the url() method of the storage. 
""" matches = matchobj.groupdict() - matched = matches['matched'] - url = matches['url'] + matched = matches["matched"] + url = matches["url"] # Ignore absolute/protocol-relative and data-uri URLs. - if re.match(r'^[a-z]+:', url): + if re.match(r"^[a-z]+:", url): return matched # Ignore absolute URLs that don't point to a static file (dynamic # CSS / JS?). Note that STATIC_URL cannot be empty. - if url.startswith('/') and not url.startswith(settings.STATIC_URL): + if url.startswith("/") and not url.startswith(settings.STATIC_URL): return matched # Strip off the fragment so a path-like fragment won't interfere. url_path, fragment = urldefrag(url) - if url_path.startswith('/'): + if url_path.startswith("/"): # Otherwise the condition above would have returned prematurely. assert url_path.startswith(settings.STATIC_URL) - target_name = url_path[len(settings.STATIC_URL):] + target_name = url_path[len(settings.STATIC_URL) :] else: # We're using the posixpath module to mix paths and URLs conveniently. - source_name = name if os.sep == '/' else name.replace(os.sep, '/') + source_name = name if os.sep == "/" else name.replace(os.sep, "/") target_name = posixpath.join(posixpath.dirname(source_name), url_path) # Determine the hashed name of the target file with the storage backend. hashed_url = self._url( - self._stored_name, unquote(target_name), - force=True, hashed_files=hashed_files, + self._stored_name, + unquote(target_name), + force=True, + hashed_files=hashed_files, ) - transformed_url = '/'.join(url_path.split('/')[:-1] + hashed_url.split('/')[-1:]) + transformed_url = "/".join( + url_path.split("/")[:-1] + hashed_url.split("/")[-1:] + ) # Restore the fragment that was stripped off earlier. 
if fragment: - transformed_url += ('?#' if '?#' in url else '#') + fragment + transformed_url += ("?#" if "?#" in url else "#") + fragment # Return the hashed version to the file - matches['url'] = unquote(transformed_url) + matches["url"] = unquote(transformed_url) return template % matches return converter @@ -239,8 +253,7 @@ class HashedFilesMixin: # build a list of adjustable files adjustable_paths = [ - path for path in paths - if matches_patterns(path, self._patterns) + path for path in paths if matches_patterns(path, self._patterns) ] # Adjustable files to yield at end, keyed by the original path. @@ -248,7 +261,9 @@ class HashedFilesMixin: # Do a single pass first. Post-process all files once, yielding not # adjustable files and exceptions, and collecting adjustable files. - for name, hashed_name, processed, _ in self._post_process(paths, adjustable_paths, hashed_files): + for name, hashed_name, processed, _ in self._post_process( + paths, adjustable_paths, hashed_files + ): if name not in adjustable_paths or isinstance(processed, Exception): yield name, hashed_name, processed else: @@ -259,7 +274,9 @@ class HashedFilesMixin: for i in range(self.max_post_process_passes): substitutions = False - for name, hashed_name, processed, subst in self._post_process(paths, adjustable_paths, hashed_files): + for name, hashed_name, processed, subst in self._post_process( + paths, adjustable_paths, hashed_files + ): # Overwrite since hashed_name may be newer. processed_adjustable_paths[name] = (name, hashed_name, processed) substitutions = substitutions or subst @@ -268,7 +285,7 @@ class HashedFilesMixin: break if substitutions: - yield 'All', None, RuntimeError('Max post-process passes exceeded.') + yield "All", None, RuntimeError("Max post-process passes exceeded.") # Store the processed paths self.hashed_files.update(hashed_files) @@ -298,7 +315,7 @@ class HashedFilesMixin: hashed_name = hashed_files[hash_key] # then get the original's file content.. 
- if hasattr(original_file, 'seek'): + if hasattr(original_file, "seek"): original_file.seek(0) hashed_file_exists = self.exists(hashed_name) @@ -307,11 +324,13 @@ class HashedFilesMixin: # ..to apply each replacement pattern to the content if name in adjustable_paths: old_hashed_name = hashed_name - content = original_file.read().decode('utf-8') + content = original_file.read().decode("utf-8") for extension, patterns in self._patterns.items(): if matches_patterns(path, (extension,)): for pattern, template in patterns: - converter = self.url_converter(name, hashed_files, template) + converter = self.url_converter( + name, hashed_files, template + ) try: content = pattern.sub(converter, content) except ValueError as exc: @@ -349,7 +368,7 @@ class HashedFilesMixin: yield name, hashed_name, processed, substitutions def clean_name(self, name): - return name.replace('\\', '/') + return name.replace("\\", "/") def hash_key(self, name): return name @@ -391,8 +410,8 @@ class HashedFilesMixin: class ManifestFilesMixin(HashedFilesMixin): - manifest_version = '1.0' # the manifest format standard - manifest_name = 'staticfiles.json' + manifest_version = "1.0" # the manifest format standard + manifest_name = "staticfiles.json" manifest_strict = True keep_intermediate_files = False @@ -419,20 +438,22 @@ class ManifestFilesMixin(HashedFilesMixin): except json.JSONDecodeError: pass else: - version = stored.get('version') - if version == '1.0': - return stored.get('paths', {}) - raise ValueError("Couldn't load manifest '%s' (version %s)" % - (self.manifest_name, self.manifest_version)) + version = stored.get("version") + if version == "1.0": + return stored.get("paths", {}) + raise ValueError( + "Couldn't load manifest '%s' (version %s)" + % (self.manifest_name, self.manifest_version) + ) def post_process(self, *args, **kwargs): self.hashed_files = {} yield from super().post_process(*args, **kwargs) - if not kwargs.get('dry_run'): + if not kwargs.get("dry_run"): 
self.save_manifest() def save_manifest(self): - payload = {'paths': self.hashed_files, 'version': self.manifest_version} + payload = {"paths": self.hashed_files, "version": self.manifest_version} if self.manifest_storage.exists(self.manifest_name): self.manifest_storage.delete(self.manifest_name) contents = json.dumps(payload).encode() @@ -445,14 +466,16 @@ class ManifestFilesMixin(HashedFilesMixin): cache_name = self.hashed_files.get(hash_key) if cache_name is None: if self.manifest_strict: - raise ValueError("Missing staticfiles manifest entry for '%s'" % clean_name) + raise ValueError( + "Missing staticfiles manifest entry for '%s'" % clean_name + ) cache_name = self.clean_name(self.hashed_name(name)) unparsed_name = list(parsed_name) unparsed_name[2] = cache_name # Special casing for a @font-face hack, like url(myfont.eot?#iefix") # http://www.fontspring.com/blog/the-new-bulletproof-font-face-syntax - if '?#' in name and not unparsed_name[3]: - unparsed_name[2] += '?' + if "?#" in name and not unparsed_name[3]: + unparsed_name[2] += "?" return urlunsplit(unparsed_name) @@ -461,6 +484,7 @@ class ManifestStaticFilesStorage(ManifestFilesMixin, StaticFilesStorage): A static file system storage backend which also saves hashed copies of the files it saves. """ + pass diff --git a/django/contrib/staticfiles/utils.py b/django/contrib/staticfiles/utils.py index e4297aff2b..efd67ac8e8 100644 --- a/django/contrib/staticfiles/utils.py +++ b/django/contrib/staticfiles/utils.py @@ -13,7 +13,7 @@ def matches_patterns(path, patterns): return any(fnmatch.fnmatchcase(path, pattern) for pattern in patterns) -def get_files(storage, ignore_patterns=None, location=''): +def get_files(storage, ignore_patterns=None, location=""): """ Recursively walk the storage directories yielding the paths of all files that should be copied. 
@@ -48,18 +48,24 @@ def check_settings(base_url=None): if not base_url: raise ImproperlyConfigured( "You're using the staticfiles app " - "without having set the required STATIC_URL setting.") + "without having set the required STATIC_URL setting." + ) if settings.MEDIA_URL == base_url: raise ImproperlyConfigured( "The MEDIA_URL and STATIC_URL settings must have different values" ) - if (settings.DEBUG and settings.MEDIA_URL and settings.STATIC_URL and - settings.MEDIA_URL.startswith(settings.STATIC_URL)): + if ( + settings.DEBUG + and settings.MEDIA_URL + and settings.STATIC_URL + and settings.MEDIA_URL.startswith(settings.STATIC_URL) + ): raise ImproperlyConfigured( "runserver can't serve media if MEDIA_URL is within STATIC_URL." ) - if ((settings.MEDIA_ROOT and settings.STATIC_ROOT) and - (settings.MEDIA_ROOT == settings.STATIC_ROOT)): + if (settings.MEDIA_ROOT and settings.STATIC_ROOT) and ( + settings.MEDIA_ROOT == settings.STATIC_ROOT + ): raise ImproperlyConfigured( "The MEDIA_ROOT and STATIC_ROOT settings must have different values" ) diff --git a/django/contrib/staticfiles/views.py b/django/contrib/staticfiles/views.py index 6c36e7a91b..83d04d4cec 100644 --- a/django/contrib/staticfiles/views.py +++ b/django/contrib/staticfiles/views.py @@ -29,10 +29,10 @@ def serve(request, path, insecure=False, **kwargs): """ if not settings.DEBUG and not insecure: raise Http404 - normalized_path = posixpath.normpath(path).lstrip('/') + normalized_path = posixpath.normpath(path).lstrip("/") absolute_path = finders.find(normalized_path) if not absolute_path: - if path.endswith('/') or path == '': + if path.endswith("/") or path == "": raise Http404("Directory indexes are not allowed here.") raise Http404("'%s' could not be found" % path) document_root, path = os.path.split(absolute_path) diff --git a/django/contrib/syndication/apps.py b/django/contrib/syndication/apps.py index b3f7c6cd61..bb0f86aa21 100644 --- a/django/contrib/syndication/apps.py +++ 
b/django/contrib/syndication/apps.py @@ -3,5 +3,5 @@ from django.utils.translation import gettext_lazy as _ class SyndicationConfig(AppConfig): - name = 'django.contrib.syndication' + name = "django.contrib.syndication" verbose_name = _("Syndication") diff --git a/django/contrib/syndication/views.py b/django/contrib/syndication/views.py index 7200907d7d..a9d1bff5cf 100644 --- a/django/contrib/syndication/views.py +++ b/django/contrib/syndication/views.py @@ -11,12 +11,12 @@ from django.utils.translation import get_language def add_domain(domain, url, secure=False): - protocol = 'https' if secure else 'http' - if url.startswith('//'): + protocol = "https" if secure else "http" + if url.startswith("//"): # Support network-path reference (see #16753) - RSS requires a protocol - url = '%s:%s' % (protocol, url) - elif not url.startswith(('http://', 'https://', 'mailto:')): - url = iri_to_uri('%s://%s%s' % (protocol, domain, url)) + url = "%s:%s" % (protocol, url) + elif not url.startswith(("http://", "https://", "mailto:")): + url = iri_to_uri("%s://%s%s" % (protocol, domain, url)) return url @@ -34,14 +34,16 @@ class Feed: try: obj = self.get_object(request, *args, **kwargs) except ObjectDoesNotExist: - raise Http404('Feed object does not exist.') + raise Http404("Feed object does not exist.") feedgen = self.get_feed(obj, request) response = HttpResponse(content_type=feedgen.content_type) - if hasattr(self, 'item_pubdate') or hasattr(self, 'item_updateddate'): + if hasattr(self, "item_pubdate") or hasattr(self, "item_updateddate"): # if item_pubdate or item_updateddate is defined for the feed, set # header so as ConditionalGetMiddleware is able to send 304 NOT MODIFIED - response.headers['Last-Modified'] = http_date(feedgen.latest_post_date().timestamp()) - feedgen.write(response, 'utf-8') + response.headers["Last-Modified"] = http_date( + feedgen.latest_post_date().timestamp() + ) + feedgen.write(response, "utf-8") return response def item_title(self, item): @@ -56,17 
+58,17 @@ class Feed: return item.get_absolute_url() except AttributeError: raise ImproperlyConfigured( - 'Give your %s class a get_absolute_url() method, or define an ' - 'item_link() method in your Feed class.' % item.__class__.__name__ + "Give your %s class a get_absolute_url() method, or define an " + "item_link() method in your Feed class." % item.__class__.__name__ ) def item_enclosures(self, item): - enc_url = self._get_dynamic_attr('item_enclosure_url', item) + enc_url = self._get_dynamic_attr("item_enclosure_url", item) if enc_url: enc = feedgenerator.Enclosure( url=str(enc_url), - length=str(self._get_dynamic_attr('item_enclosure_length', item)), - mime_type=str(self._get_dynamic_attr('item_enclosure_mime_type', item)), + length=str(self._get_dynamic_attr("item_enclosure_length", item)), + mime_type=str(self._get_dynamic_attr("item_enclosure_mime_type", item)), ) return [enc] return [] @@ -84,7 +86,7 @@ class Feed: code = attr.__code__ except AttributeError: code = attr.__call__.__code__ - if code.co_argcount == 2: # one argument is 'self' + if code.co_argcount == 2: # one argument is 'self' return attr(obj) else: return attr() @@ -115,7 +117,7 @@ class Feed: Default implementation preserves the old behavior of using {'obj': item, 'site': current_site} as the context. 
""" - return {'obj': kwargs.get('item'), 'site': kwargs.get('site')} + return {"obj": kwargs.get("item"), "site": kwargs.get("site")} def get_feed(self, obj, request): """ @@ -124,28 +126,28 @@ class Feed: """ current_site = get_current_site(request) - link = self._get_dynamic_attr('link', obj) + link = self._get_dynamic_attr("link", obj) link = add_domain(current_site.domain, link, request.is_secure()) feed = self.feed_type( - title=self._get_dynamic_attr('title', obj), - subtitle=self._get_dynamic_attr('subtitle', obj), + title=self._get_dynamic_attr("title", obj), + subtitle=self._get_dynamic_attr("subtitle", obj), link=link, - description=self._get_dynamic_attr('description', obj), + description=self._get_dynamic_attr("description", obj), language=self.language or get_language(), feed_url=add_domain( current_site.domain, - self._get_dynamic_attr('feed_url', obj) or request.path, + self._get_dynamic_attr("feed_url", obj) or request.path, request.is_secure(), ), - author_name=self._get_dynamic_attr('author_name', obj), - author_link=self._get_dynamic_attr('author_link', obj), - author_email=self._get_dynamic_attr('author_email', obj), - categories=self._get_dynamic_attr('categories', obj), - feed_copyright=self._get_dynamic_attr('feed_copyright', obj), - feed_guid=self._get_dynamic_attr('feed_guid', obj), - ttl=self._get_dynamic_attr('ttl', obj), - **self.feed_extra_kwargs(obj) + author_name=self._get_dynamic_attr("author_name", obj), + author_link=self._get_dynamic_attr("author_link", obj), + author_email=self._get_dynamic_attr("author_email", obj), + categories=self._get_dynamic_attr("categories", obj), + feed_copyright=self._get_dynamic_attr("feed_copyright", obj), + feed_guid=self._get_dynamic_attr("feed_guid", obj), + ttl=self._get_dynamic_attr("ttl", obj), + **self.feed_extra_kwargs(obj), ) title_tmp = None @@ -162,37 +164,38 @@ class Feed: except TemplateDoesNotExist: pass - for item in self._get_dynamic_attr('items', obj): - context = 
self.get_context_data(item=item, site=current_site, - obj=obj, request=request) + for item in self._get_dynamic_attr("items", obj): + context = self.get_context_data( + item=item, site=current_site, obj=obj, request=request + ) if title_tmp is not None: title = title_tmp.render(context, request) else: - title = self._get_dynamic_attr('item_title', item) + title = self._get_dynamic_attr("item_title", item) if description_tmp is not None: description = description_tmp.render(context, request) else: - description = self._get_dynamic_attr('item_description', item) + description = self._get_dynamic_attr("item_description", item) link = add_domain( current_site.domain, - self._get_dynamic_attr('item_link', item), + self._get_dynamic_attr("item_link", item), request.is_secure(), ) - enclosures = self._get_dynamic_attr('item_enclosures', item) - author_name = self._get_dynamic_attr('item_author_name', item) + enclosures = self._get_dynamic_attr("item_enclosures", item) + author_name = self._get_dynamic_attr("item_author_name", item) if author_name is not None: - author_email = self._get_dynamic_attr('item_author_email', item) - author_link = self._get_dynamic_attr('item_author_link', item) + author_email = self._get_dynamic_attr("item_author_email", item) + author_link = self._get_dynamic_attr("item_author_link", item) else: author_email = author_link = None tz = get_default_timezone() - pubdate = self._get_dynamic_attr('item_pubdate', item) + pubdate = self._get_dynamic_attr("item_pubdate", item) if pubdate and is_naive(pubdate): pubdate = make_aware(pubdate, tz) - updateddate = self._get_dynamic_attr('item_updateddate', item) + updateddate = self._get_dynamic_attr("item_updateddate", item) if updateddate and is_naive(updateddate): updateddate = make_aware(updateddate, tz) @@ -200,18 +203,19 @@ class Feed: title=title, link=link, description=description, - unique_id=self._get_dynamic_attr('item_guid', item, link), + unique_id=self._get_dynamic_attr("item_guid", item, 
link), unique_id_is_permalink=self._get_dynamic_attr( - 'item_guid_is_permalink', item), + "item_guid_is_permalink", item + ), enclosures=enclosures, pubdate=pubdate, updateddate=updateddate, author_name=author_name, author_email=author_email, author_link=author_link, - comments=self._get_dynamic_attr('item_comments', item), - categories=self._get_dynamic_attr('item_categories', item), - item_copyright=self._get_dynamic_attr('item_copyright', item), - **self.item_extra_kwargs(item) + comments=self._get_dynamic_attr("item_comments", item), + categories=self._get_dynamic_attr("item_categories", item), + item_copyright=self._get_dynamic_attr("item_copyright", item), + **self.item_extra_kwargs(item), ) return feed diff --git a/django/core/cache/__init__.py b/django/core/cache/__init__.py index a311b50af6..f09c9ecc4b 100644 --- a/django/core/cache/__init__.py +++ b/django/core/cache/__init__.py @@ -14,27 +14,35 @@ See docs/topics/cache.txt for information on the public API. """ from django.core import signals from django.core.cache.backends.base import ( - BaseCache, CacheKeyWarning, InvalidCacheBackendError, InvalidCacheKey, + BaseCache, + CacheKeyWarning, + InvalidCacheBackendError, + InvalidCacheKey, ) from django.utils.connection import BaseConnectionHandler, ConnectionProxy from django.utils.module_loading import import_string __all__ = [ - 'cache', 'caches', 'DEFAULT_CACHE_ALIAS', 'InvalidCacheBackendError', - 'CacheKeyWarning', 'BaseCache', 'InvalidCacheKey', + "cache", + "caches", + "DEFAULT_CACHE_ALIAS", + "InvalidCacheBackendError", + "CacheKeyWarning", + "BaseCache", + "InvalidCacheKey", ] -DEFAULT_CACHE_ALIAS = 'default' +DEFAULT_CACHE_ALIAS = "default" class CacheHandler(BaseConnectionHandler): - settings_name = 'CACHES' + settings_name = "CACHES" exception_class = InvalidCacheBackendError def create_connection(self, alias): params = self.settings[alias].copy() - backend = params.pop('BACKEND') - location = params.pop('LOCATION', '') + backend = 
params.pop("BACKEND") + location = params.pop("LOCATION", "") try: backend_cls = import_string(backend) except ImportError as e: @@ -45,7 +53,8 @@ class CacheHandler(BaseConnectionHandler): def all(self, initialized_only=False): return [ - self[alias] for alias in self + self[alias] + for alias in self # If initialized_only is True, return only initialized caches. if not initialized_only or hasattr(self._connections, alias) ] diff --git a/django/core/cache/backends/base.py b/django/core/cache/backends/base.py index f632d851ea..eb4b3eac6d 100644 --- a/django/core/cache/backends/base.py +++ b/django/core/cache/backends/base.py @@ -36,7 +36,7 @@ def default_key_func(key, key_prefix, version): the `key_prefix`. KEY_FUNCTION can be used to specify an alternate function with custom key making behavior. """ - return '%s:%s:%s' % (key_prefix, version, key) + return "%s:%s:%s" % (key_prefix, version, key) def get_key_func(key_func): @@ -57,7 +57,7 @@ class BaseCache: _missing_key = object() def __init__(self, params): - timeout = params.get('timeout', params.get('TIMEOUT', 300)) + timeout = params.get("timeout", params.get("TIMEOUT", 300)) if timeout is not None: try: timeout = int(timeout) @@ -65,22 +65,22 @@ class BaseCache: timeout = 300 self.default_timeout = timeout - options = params.get('OPTIONS', {}) - max_entries = params.get('max_entries', options.get('MAX_ENTRIES', 300)) + options = params.get("OPTIONS", {}) + max_entries = params.get("max_entries", options.get("MAX_ENTRIES", 300)) try: self._max_entries = int(max_entries) except (ValueError, TypeError): self._max_entries = 300 - cull_frequency = params.get('cull_frequency', options.get('CULL_FREQUENCY', 3)) + cull_frequency = params.get("cull_frequency", options.get("CULL_FREQUENCY", 3)) try: self._cull_frequency = int(cull_frequency) except (ValueError, TypeError): self._cull_frequency = 3 - self.key_prefix = params.get('KEY_PREFIX', '') - self.version = params.get('VERSION', 1) - self.key_func = 
get_key_func(params.get('KEY_FUNCTION')) + self.key_prefix = params.get("KEY_PREFIX", "") + self.version = params.get("VERSION", 1) + self.key_func = get_key_func(params.get("KEY_FUNCTION")) def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT): """ @@ -130,47 +130,61 @@ class BaseCache: Return True if the value was stored, False otherwise. """ - raise NotImplementedError('subclasses of BaseCache must provide an add() method') + raise NotImplementedError( + "subclasses of BaseCache must provide an add() method" + ) async def aadd(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): - return await sync_to_async(self.add, thread_sensitive=True)(key, value, timeout, version) + return await sync_to_async(self.add, thread_sensitive=True)( + key, value, timeout, version + ) def get(self, key, default=None, version=None): """ Fetch a given key from the cache. If the key does not exist, return default, which itself defaults to None. """ - raise NotImplementedError('subclasses of BaseCache must provide a get() method') + raise NotImplementedError("subclasses of BaseCache must provide a get() method") async def aget(self, key, default=None, version=None): - return await sync_to_async(self.get, thread_sensitive=True)(key, default, version) + return await sync_to_async(self.get, thread_sensitive=True)( + key, default, version + ) def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): """ Set a value in the cache. If timeout is given, use that timeout for the key; otherwise use the default cache timeout. 
""" - raise NotImplementedError('subclasses of BaseCache must provide a set() method') + raise NotImplementedError("subclasses of BaseCache must provide a set() method") async def aset(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): - return await sync_to_async(self.set, thread_sensitive=True)(key, value, timeout, version) + return await sync_to_async(self.set, thread_sensitive=True)( + key, value, timeout, version + ) def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): """ Update the key's expiry time using timeout. Return True if successful or False if the key does not exist. """ - raise NotImplementedError('subclasses of BaseCache must provide a touch() method') + raise NotImplementedError( + "subclasses of BaseCache must provide a touch() method" + ) async def atouch(self, key, timeout=DEFAULT_TIMEOUT, version=None): - return await sync_to_async(self.touch, thread_sensitive=True)(key, timeout, version) + return await sync_to_async(self.touch, thread_sensitive=True)( + key, timeout, version + ) def delete(self, key, version=None): """ Delete a key from the cache and return whether it succeeded, failing silently. """ - raise NotImplementedError('subclasses of BaseCache must provide a delete() method') + raise NotImplementedError( + "subclasses of BaseCache must provide a delete() method" + ) async def adelete(self, key, version=None): return await sync_to_async(self.delete, thread_sensitive=True)(key, version) @@ -234,7 +248,9 @@ class BaseCache: """ Return True if the key is in the cache and has not expired. 
""" - return self.get(key, self._missing_key, version=version) is not self._missing_key + return ( + self.get(key, self._missing_key, version=version) is not self._missing_key + ) async def ahas_key(self, key, version=None): return ( @@ -318,7 +334,9 @@ class BaseCache: def clear(self): """Remove *all* values from the cache at once.""" - raise NotImplementedError('subclasses of BaseCache must provide a clear() method') + raise NotImplementedError( + "subclasses of BaseCache must provide a clear() method" + ) async def aclear(self): return await sync_to_async(self.clear, thread_sensitive=True)() @@ -373,13 +391,13 @@ class BaseCache: def memcache_key_warnings(key): if len(key) > MEMCACHE_MAX_KEY_LENGTH: yield ( - 'Cache key will cause errors if used with memcached: %r ' - '(longer than %s)' % (key, MEMCACHE_MAX_KEY_LENGTH) + "Cache key will cause errors if used with memcached: %r " + "(longer than %s)" % (key, MEMCACHE_MAX_KEY_LENGTH) ) for char in key: if ord(char) < 33 or ord(char) == 127: yield ( - 'Cache key contains characters that will cause errors if ' - 'used with memcached: %r' % key + "Cache key contains characters that will cause errors if " + "used with memcached: %r" % key ) break diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py index 5bb1c5aec5..e3d055084c 100644 --- a/django/core/cache/backends/db.py +++ b/django/core/cache/backends/db.py @@ -14,13 +14,14 @@ class Options: This allows cache operations to be controlled by the router """ + def __init__(self, table): self.db_table = table - self.app_label = 'django_cache' - self.model_name = 'cacheentry' - self.verbose_name = 'cache entry' - self.verbose_name_plural = 'cache entries' - self.object_name = 'CacheEntry' + self.app_label = "django_cache" + self.model_name = "cacheentry" + self.verbose_name = "cache entry" + self.verbose_name_plural = "cache entries" + self.object_name = "CacheEntry" self.abstract = False self.managed = True self.proxy = False @@ -34,6 +35,7 @@ 
class BaseDatabaseCache(BaseCache): class CacheEntry: _meta = Options(table) + self.cache_model_class = CacheEntry @@ -54,7 +56,9 @@ class DatabaseCache(BaseDatabaseCache): if not keys: return {} - key_map = {self.make_and_validate_key(key, version=version): key for key in keys} + key_map = { + self.make_and_validate_key(key, version=version): key for key in keys + } db = router.db_for_read(self.cache_model_class) connection = connections[db] @@ -63,13 +67,14 @@ class DatabaseCache(BaseDatabaseCache): with connection.cursor() as cursor: cursor.execute( - 'SELECT %s, %s, %s FROM %s WHERE %s IN (%s)' % ( - quote_name('cache_key'), - quote_name('value'), - quote_name('expires'), + "SELECT %s, %s, %s FROM %s WHERE %s IN (%s)" + % ( + quote_name("cache_key"), + quote_name("value"), + quote_name("expires"), table, - quote_name('cache_key'), - ', '.join(['%s'] * len(key_map)), + quote_name("cache_key"), + ", ".join(["%s"] * len(key_map)), ), list(key_map), ) @@ -78,7 +83,9 @@ class DatabaseCache(BaseDatabaseCache): result = {} expired_keys = [] expression = models.Expression(output_field=models.DateTimeField()) - converters = (connection.ops.get_db_converters(expression) + expression.get_db_converters(connection)) + converters = connection.ops.get_db_converters( + expression + ) + expression.get_db_converters(connection) for key, value, expires in rows: for converter in converters: expires = converter(expires, expression, connection) @@ -93,15 +100,15 @@ class DatabaseCache(BaseDatabaseCache): def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): key = self.make_and_validate_key(key, version=version) - self._base_set('set', key, value, timeout) + self._base_set("set", key, value, timeout) def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): key = self.make_and_validate_key(key, version=version) - return self._base_set('add', key, value, timeout) + return self._base_set("add", key, value, timeout) def touch(self, key, timeout=DEFAULT_TIMEOUT, 
version=None): key = self.make_and_validate_key(key, version=version) - return self._base_set('touch', key, None, timeout) + return self._base_set("touch", key, None, timeout) def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT): timeout = self.get_backend_timeout(timeout) @@ -126,7 +133,7 @@ class DatabaseCache(BaseDatabaseCache): pickled = pickle.dumps(value, self.pickle_protocol) # The DB column is expecting a string, so make sure the value is a # string, not bytes. Refs #19274. - b64encoded = base64.b64encode(pickled).decode('latin1') + b64encoded = base64.b64encode(pickled).decode("latin1") try: # Note: typecasting for datetimes is needed by some 3rd party # database backends. All core backends work without typecasting, @@ -134,52 +141,59 @@ class DatabaseCache(BaseDatabaseCache): # regressions. with transaction.atomic(using=db): cursor.execute( - 'SELECT %s, %s FROM %s WHERE %s = %%s' % ( - quote_name('cache_key'), - quote_name('expires'), + "SELECT %s, %s FROM %s WHERE %s = %%s" + % ( + quote_name("cache_key"), + quote_name("expires"), table, - quote_name('cache_key'), + quote_name("cache_key"), ), - [key] + [key], ) result = cursor.fetchone() if result: current_expires = result[1] - expression = models.Expression(output_field=models.DateTimeField()) - for converter in (connection.ops.get_db_converters(expression) + - expression.get_db_converters(connection)): - current_expires = converter(current_expires, expression, connection) + expression = models.Expression( + output_field=models.DateTimeField() + ) + for converter in connection.ops.get_db_converters( + expression + ) + expression.get_db_converters(connection): + current_expires = converter( + current_expires, expression, connection + ) exp = connection.ops.adapt_datetimefield_value(exp) - if result and mode == 'touch': + if result and mode == "touch": cursor.execute( - 'UPDATE %s SET %s = %%s WHERE %s = %%s' % ( - table, - quote_name('expires'), - quote_name('cache_key') - ), - [exp, key] + 
"UPDATE %s SET %s = %%s WHERE %s = %%s" + % (table, quote_name("expires"), quote_name("cache_key")), + [exp, key], ) - elif result and (mode == 'set' or (mode == 'add' and current_expires < now)): + elif result and ( + mode == "set" or (mode == "add" and current_expires < now) + ): cursor.execute( - 'UPDATE %s SET %s = %%s, %s = %%s WHERE %s = %%s' % ( + "UPDATE %s SET %s = %%s, %s = %%s WHERE %s = %%s" + % ( table, - quote_name('value'), - quote_name('expires'), - quote_name('cache_key'), + quote_name("value"), + quote_name("expires"), + quote_name("cache_key"), ), - [b64encoded, exp, key] + [b64encoded, exp, key], ) - elif mode != 'touch': + elif mode != "touch": cursor.execute( - 'INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)' % ( + "INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)" + % ( table, - quote_name('cache_key'), - quote_name('value'), - quote_name('expires'), + quote_name("cache_key"), + quote_name("value"), + quote_name("expires"), ), - [key, b64encoded, exp] + [key, b64encoded, exp], ) else: return False # touch failed. 
@@ -208,10 +222,11 @@ class DatabaseCache(BaseDatabaseCache): with connection.cursor() as cursor: cursor.execute( - 'DELETE FROM %s WHERE %s IN (%s)' % ( + "DELETE FROM %s WHERE %s IN (%s)" + % ( table, - quote_name('cache_key'), - ', '.join(['%s'] * len(keys)), + quote_name("cache_key"), + ", ".join(["%s"] * len(keys)), ), keys, ) @@ -228,13 +243,14 @@ class DatabaseCache(BaseDatabaseCache): with connection.cursor() as cursor: cursor.execute( - 'SELECT %s FROM %s WHERE %s = %%s and %s > %%s' % ( - quote_name('cache_key'), + "SELECT %s FROM %s WHERE %s = %%s and %s > %%s" + % ( + quote_name("cache_key"), quote_name(self._table), - quote_name('cache_key'), - quote_name('expires'), + quote_name("cache_key"), + quote_name("expires"), ), - [key, connection.ops.adapt_datetimefield_value(now)] + [key, connection.ops.adapt_datetimefield_value(now)], ) return cursor.fetchone() is not None @@ -244,23 +260,28 @@ class DatabaseCache(BaseDatabaseCache): else: connection = connections[db] table = connection.ops.quote_name(self._table) - cursor.execute('DELETE FROM %s WHERE %s < %%s' % ( - table, - connection.ops.quote_name('expires'), - ), [connection.ops.adapt_datetimefield_value(now)]) + cursor.execute( + "DELETE FROM %s WHERE %s < %%s" + % ( + table, + connection.ops.quote_name("expires"), + ), + [connection.ops.adapt_datetimefield_value(now)], + ) deleted_count = cursor.rowcount remaining_num = num - deleted_count if remaining_num > self._max_entries: cull_num = remaining_num // self._cull_frequency cursor.execute( - connection.ops.cache_key_culling_sql() % table, - [cull_num]) + connection.ops.cache_key_culling_sql() % table, [cull_num] + ) last_cache_key = cursor.fetchone() if last_cache_key: cursor.execute( - 'DELETE FROM %s WHERE %s < %%s' % ( + "DELETE FROM %s WHERE %s < %%s" + % ( table, - connection.ops.quote_name('cache_key'), + connection.ops.quote_name("cache_key"), ), [last_cache_key[0]], ) @@ -270,4 +291,4 @@ class DatabaseCache(BaseDatabaseCache): connection = 
connections[db] table = connection.ops.quote_name(self._table) with connection.cursor() as cursor: - cursor.execute('DELETE FROM %s' % table) + cursor.execute("DELETE FROM %s" % table) diff --git a/django/core/cache/backends/filebased.py b/django/core/cache/backends/filebased.py index fc99d11687..631da49444 100644 --- a/django/core/cache/backends/filebased.py +++ b/django/core/cache/backends/filebased.py @@ -14,7 +14,7 @@ from django.utils.crypto import md5 class FileBasedCache(BaseCache): - cache_suffix = '.djcache' + cache_suffix = ".djcache" pickle_protocol = pickle.HIGHEST_PROTOCOL def __init__(self, dir, params): @@ -31,7 +31,7 @@ class FileBasedCache(BaseCache): def get(self, key, default=None, version=None): fname = self._key_to_file(key, version) try: - with open(fname, 'rb') as f: + with open(fname, "rb") as f: if not self._is_expired(f): return pickle.loads(zlib.decompress(f.read())) except FileNotFoundError: @@ -50,7 +50,7 @@ class FileBasedCache(BaseCache): fd, tmp_path = tempfile.mkstemp(dir=self._dir) renamed = False try: - with open(fd, 'wb') as f: + with open(fd, "wb") as f: self._write_content(f, timeout, value) file_move_safe(tmp_path, fname, allow_overwrite=True) renamed = True @@ -60,7 +60,7 @@ class FileBasedCache(BaseCache): def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): try: - with open(self._key_to_file(key, version), 'r+b') as f: + with open(self._key_to_file(key, version), "r+b") as f: try: locks.lock(f, locks.LOCK_EX) if self._is_expired(f): @@ -91,7 +91,7 @@ class FileBasedCache(BaseCache): def has_key(self, key, version=None): fname = self._key_to_file(key, version) if os.path.exists(fname): - with open(fname, 'rb') as f: + with open(fname, "rb") as f: return not self._is_expired(f) return False @@ -108,8 +108,7 @@ class FileBasedCache(BaseCache): if self._cull_frequency == 0: return self.clear() # Clear the cache when CULL_FREQUENCY = 0 # Delete a random selection of entries - filelist = random.sample(filelist, - 
int(num_entries / self._cull_frequency)) + filelist = random.sample(filelist, int(num_entries / self._cull_frequency)) for fname in filelist: self._delete(fname) @@ -128,10 +127,15 @@ class FileBasedCache(BaseCache): root cache path joined with the md5sum of the key and a suffix. """ key = self.make_and_validate_key(key, version=version) - return os.path.join(self._dir, ''.join([ - md5(key.encode(), usedforsecurity=False).hexdigest(), - self.cache_suffix, - ])) + return os.path.join( + self._dir, + "".join( + [ + md5(key.encode(), usedforsecurity=False).hexdigest(), + self.cache_suffix, + ] + ), + ) def clear(self): """ @@ -161,5 +165,5 @@ class FileBasedCache(BaseCache): """ return [ os.path.join(self._dir, fname) - for fname in glob.glob1(self._dir, '*%s' % self.cache_suffix) + for fname in glob.glob1(self._dir, "*%s" % self.cache_suffix) ] diff --git a/django/core/cache/backends/memcached.py b/django/core/cache/backends/memcached.py index 472a28179c..2416168634 100644 --- a/django/core/cache/backends/memcached.py +++ b/django/core/cache/backends/memcached.py @@ -4,7 +4,10 @@ import re import time from django.core.cache.backends.base import ( - DEFAULT_TIMEOUT, BaseCache, InvalidCacheKey, memcache_key_warnings, + DEFAULT_TIMEOUT, + BaseCache, + InvalidCacheKey, + memcache_key_warnings, ) from django.utils.functional import cached_property @@ -13,7 +16,7 @@ class BaseMemcachedCache(BaseCache): def __init__(self, server, params, library, value_not_found_exception): super().__init__(params) if isinstance(server, str): - self._servers = re.split('[;,]', server) + self._servers = re.split("[;,]", server) else: self._servers = server @@ -23,7 +26,7 @@ class BaseMemcachedCache(BaseCache): self._lib = library self._class = library.Client - self._options = params.get('OPTIONS') or {} + self._options = params.get("OPTIONS") or {} @property def client_servers(self): @@ -86,7 +89,9 @@ class BaseMemcachedCache(BaseCache): return bool(self._cache.delete(key)) def 
get_many(self, keys, version=None): - key_map = {self.make_and_validate_key(key, version=version): key for key in keys} + key_map = { + self.make_and_validate_key(key, version=version): key for key in keys + } ret = self._cache.get_multi(key_map.keys()) return {key_map[k]: v for k, v in ret.items()} @@ -118,7 +123,9 @@ class BaseMemcachedCache(BaseCache): safe_key = self.make_and_validate_key(key, version=version) safe_data[safe_key] = value original_keys[safe_key] = key - failed_keys = self._cache.set_multi(safe_data, self.get_backend_timeout(timeout)) + failed_keys = self._cache.set_multi( + safe_data, self.get_backend_timeout(timeout) + ) return [original_keys[k] for k in failed_keys] def delete_many(self, keys, version=None): @@ -135,15 +142,19 @@ class BaseMemcachedCache(BaseCache): class PyLibMCCache(BaseMemcachedCache): "An implementation of a cache binding using pylibmc" + def __init__(self, server, params): import pylibmc - super().__init__(server, params, library=pylibmc, value_not_found_exception=pylibmc.NotFound) + + super().__init__( + server, params, library=pylibmc, value_not_found_exception=pylibmc.NotFound + ) @property def client_servers(self): output = [] for server in self._servers: - output.append(server[5:] if server.startswith('unix:') else server) + output.append(server[5:] if server.startswith("unix:") else server) return output def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): @@ -160,13 +171,17 @@ class PyLibMCCache(BaseMemcachedCache): class PyMemcacheCache(BaseMemcachedCache): """An implementation of a cache binding using pymemcache.""" + def __init__(self, server, params): import pymemcache.serde - super().__init__(server, params, library=pymemcache, value_not_found_exception=KeyError) + + super().__init__( + server, params, library=pymemcache, value_not_found_exception=KeyError + ) self._class = self._lib.HashClient self._options = { - 'allow_unicode_keys': True, - 'default_noreply': False, - 'serde': 
pymemcache.serde.pickle_serde, + "allow_unicode_keys": True, + "default_noreply": False, + "serde": pymemcache.serde.pickle_serde, **self._options, } diff --git a/django/core/cache/backends/redis.py b/django/core/cache/backends/redis.py index f168e93737..e0d30784ff 100644 --- a/django/core/cache/backends/redis.py +++ b/django/core/cache/backends/redis.py @@ -58,7 +58,7 @@ class RedisCacheClient: parser_class = import_string(parser_class) parser_class = parser_class or self._lib.connection.DefaultParser - self._pool_options = {'parser_class': parser_class, 'db': db} + self._pool_options = {"parser_class": parser_class, "db": db} def _get_connection_pool_index(self, write): # Write to the first server. Read from other servers if there are more, @@ -71,7 +71,8 @@ class RedisCacheClient: index = self._get_connection_pool_index(write) if index not in self._pools: self._pools[index] = self._pool_class.from_url( - self._servers[index], **self._pool_options, + self._servers[index], + **self._pool_options, ) return self._pools[index] @@ -159,12 +160,12 @@ class RedisCache(BaseCache): def __init__(self, server, params): super().__init__(params) if isinstance(server, str): - self._servers = re.split('[;,]', server) + self._servers = re.split("[;,]", server) else: self._servers = server self._class = RedisCacheClient - self._options = params.get('OPTIONS', {}) + self._options = params.get("OPTIONS", {}) @cached_property def _cache(self): @@ -198,7 +199,9 @@ class RedisCache(BaseCache): return self._cache.delete(key) def get_many(self, keys, version=None): - key_map = {self.make_and_validate_key(key, version=version): key for key in keys} + key_map = { + self.make_and_validate_key(key, version=version): key for key in keys + } ret = self._cache.get_many(key_map.keys()) return {key_map[k]: v for k, v in ret.items()} diff --git a/django/core/cache/utils.py b/django/core/cache/utils.py index d41960f6e4..ff2a23aa6f 100644 --- a/django/core/cache/utils.py +++ 
b/django/core/cache/utils.py @@ -1,6 +1,6 @@ from django.utils.crypto import md5 -TEMPLATE_FRAGMENT_KEY_TEMPLATE = 'template.cache.%s.%s' +TEMPLATE_FRAGMENT_KEY_TEMPLATE = "template.cache.%s.%s" def make_template_fragment_key(fragment_name, vary_on=None): @@ -8,5 +8,5 @@ def make_template_fragment_key(fragment_name, vary_on=None): if vary_on is not None: for arg in vary_on: hasher.update(str(arg).encode()) - hasher.update(b':') + hasher.update(b":") return TEMPLATE_FRAGMENT_KEY_TEMPLATE % (fragment_name, hasher.hexdigest()) diff --git a/django/core/checks/__init__.py b/django/core/checks/__init__.py index 296e991ddc..998ab9dee2 100644 --- a/django/core/checks/__init__.py +++ b/django/core/checks/__init__.py @@ -1,6 +1,15 @@ from .messages import ( - CRITICAL, DEBUG, ERROR, INFO, WARNING, CheckMessage, Critical, Debug, - Error, Info, Warning, + CRITICAL, + DEBUG, + ERROR, + INFO, + WARNING, + CheckMessage, + Critical, + Debug, + Error, + Info, + Warning, ) from .registry import Tags, register, run_checks, tag_exists @@ -20,8 +29,19 @@ import django.core.checks.urls # NOQA isort:skip __all__ = [ - 'CheckMessage', - 'Debug', 'Info', 'Warning', 'Error', 'Critical', - 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', - 'register', 'run_checks', 'tag_exists', 'Tags', + "CheckMessage", + "Debug", + "Info", + "Warning", + "Error", + "Critical", + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL", + "register", + "run_checks", + "tag_exists", + "Tags", ] diff --git a/django/core/checks/async_checks.py b/django/core/checks/async_checks.py index fbb5267358..a0e01867d9 100644 --- a/django/core/checks/async_checks.py +++ b/django/core/checks/async_checks.py @@ -3,14 +3,14 @@ import os from . import Error, Tags, register E001 = Error( - 'You should not set the DJANGO_ALLOW_ASYNC_UNSAFE environment variable in ' - 'deployment. 
This disables async safety protection.', - id='async.E001', + "You should not set the DJANGO_ALLOW_ASYNC_UNSAFE environment variable in " + "deployment. This disables async safety protection.", + id="async.E001", ) @register(Tags.async_support, deploy=True) def check_async_unsafe(app_configs, **kwargs): - if os.environ.get('DJANGO_ALLOW_ASYNC_UNSAFE'): + if os.environ.get("DJANGO_ALLOW_ASYNC_UNSAFE"): return [E001] return [] diff --git a/django/core/checks/caches.py b/django/core/checks/caches.py index b755e0035a..c288a6ff4a 100644 --- a/django/core/checks/caches.py +++ b/django/core/checks/caches.py @@ -8,7 +8,7 @@ from . import Error, Tags, Warning, register E001 = Error( "You must define a '%s' cache in your CACHES setting." % DEFAULT_CACHE_ALIAS, - id='caches.E001', + id="caches.E001", ) @@ -22,11 +22,11 @@ def check_default_cache_is_configured(app_configs, **kwargs): @register(Tags.caches, deploy=True) def check_cache_location_not_exposed(app_configs, **kwargs): errors = [] - for name in ('MEDIA_ROOT', 'STATIC_ROOT', 'STATICFILES_DIRS'): + for name in ("MEDIA_ROOT", "STATIC_ROOT", "STATICFILES_DIRS"): setting = getattr(settings, name, None) if not setting: continue - if name == 'STATICFILES_DIRS': + if name == "STATICFILES_DIRS": paths = set() for staticfiles_dir in setting: if isinstance(staticfiles_dir, (list, tuple)): @@ -40,19 +40,21 @@ def check_cache_location_not_exposed(app_configs, **kwargs): continue cache_path = pathlib.Path(cache._dir).resolve() if any(path == cache_path for path in paths): - relation = 'matches' + relation = "matches" elif any(path in cache_path.parents for path in paths): - relation = 'is inside' + relation = "is inside" elif any(cache_path in path.parents for path in paths): - relation = 'contains' + relation = "contains" else: continue - errors.append(Warning( - f"Your '{alias}' cache configuration might expose your cache " - f"or lead to corruption of your data because its LOCATION " - f"{relation} {name}.", - id='caches.W002', 
- )) + errors.append( + Warning( + f"Your '{alias}' cache configuration might expose your cache " + f"or lead to corruption of your data because its LOCATION " + f"{relation} {name}.", + id="caches.W002", + ) + ) return errors @@ -63,10 +65,12 @@ def check_file_based_cache_is_absolute(app_configs, **kwargs): cache = caches[alias] if not isinstance(cache, FileBasedCache): continue - if not pathlib.Path(config['LOCATION']).is_absolute(): - errors.append(Warning( - f"Your '{alias}' cache LOCATION path is relative. Use an " - f"absolute path instead.", - id='caches.W003', - )) + if not pathlib.Path(config["LOCATION"]).is_absolute(): + errors.append( + Warning( + f"Your '{alias}' cache LOCATION path is relative. Use an " + f"absolute path instead.", + id="caches.W003", + ) + ) return errors diff --git a/django/core/checks/compatibility/django_4_0.py b/django/core/checks/compatibility/django_4_0.py index 7788629735..79ee5fa6b3 100644 --- a/django/core/checks/compatibility/django_4_0.py +++ b/django/core/checks/compatibility/django_4_0.py @@ -7,12 +7,14 @@ from .. import Error, Tags, register def check_csrf_trusted_origins(app_configs, **kwargs): errors = [] for origin in settings.CSRF_TRUSTED_ORIGINS: - if '://' not in origin: - errors.append(Error( - 'As of Django 4.0, the values in the CSRF_TRUSTED_ORIGINS ' - 'setting must start with a scheme (usually http:// or ' - 'https://) but found %s. See the release notes for details.' - % origin, - id='4_0.E001', - )) + if "://" not in origin: + errors.append( + Error( + "As of Django 4.0, the values in the CSRF_TRUSTED_ORIGINS " + "setting must start with a scheme (usually http:// or " + "https://) but found %s. See the release notes for details." + % origin, + id="4_0.E001", + ) + ) return errors diff --git a/django/core/checks/files.py b/django/core/checks/files.py index 5f76ae5a17..40dc745840 100644 --- a/django/core/checks/files.py +++ b/django/core/checks/files.py @@ -7,7 +7,7 @@ from . 
import Error, Tags, register @register(Tags.files) def check_setting_file_upload_temp_dir(app_configs, **kwargs): - setting = getattr(settings, 'FILE_UPLOAD_TEMP_DIR', None) + setting = getattr(settings, "FILE_UPLOAD_TEMP_DIR", None) if setting and not Path(setting).is_dir(): return [ Error( diff --git a/django/core/checks/messages.py b/django/core/checks/messages.py index 0987c2d118..db7aa55119 100644 --- a/django/core/checks/messages.py +++ b/django/core/checks/messages.py @@ -7,10 +7,9 @@ CRITICAL = 50 class CheckMessage: - def __init__(self, level, msg, hint=None, obj=None, id=None): if not isinstance(level, int): - raise TypeError('The first argument should be level.') + raise TypeError("The first argument should be level.") self.level = level self.msg = msg self.hint = hint @@ -18,10 +17,9 @@ class CheckMessage: self.id = id def __eq__(self, other): - return ( - isinstance(other, self.__class__) and - all(getattr(self, attr) == getattr(other, attr) - for attr in ['level', 'msg', 'hint', 'obj', 'id']) + return isinstance(other, self.__class__) and all( + getattr(self, attr) == getattr(other, attr) + for attr in ["level", "msg", "hint", "obj", "id"] ) def __str__(self): @@ -36,18 +34,25 @@ class CheckMessage: else: obj = str(self.obj) id = "(%s) " % self.id if self.id else "" - hint = "\n\tHINT: %s" % self.hint if self.hint else '' + hint = "\n\tHINT: %s" % self.hint if self.hint else "" return "%s: %s%s%s" % (obj, id, self.msg, hint) def __repr__(self): - return "<%s: level=%r, msg=%r, hint=%r, obj=%r, id=%r>" % \ - (self.__class__.__name__, self.level, self.msg, self.hint, self.obj, self.id) + return "<%s: level=%r, msg=%r, hint=%r, obj=%r, id=%r>" % ( + self.__class__.__name__, + self.level, + self.msg, + self.hint, + self.obj, + self.id, + ) def is_serious(self, level=ERROR): return self.level >= level def is_silenced(self): from django.conf import settings + return self.id in settings.SILENCED_SYSTEM_CHECKS diff --git a/django/core/checks/model_checks.py 
b/django/core/checks/model_checks.py index 15d9b7fd86..7a5bef9b26 100644 --- a/django/core/checks/model_checks.py +++ b/django/core/checks/model_checks.py @@ -17,7 +17,9 @@ def check_all_models(app_configs=None, **kwargs): if app_configs is None: models = apps.get_models() else: - models = chain.from_iterable(app_config.get_models() for app_config in app_configs) + models = chain.from_iterable( + app_config.get_models() for app_config in app_configs + ) for model in models: if model._meta.managed and not model._meta.proxy: db_table_models[model._meta.db_table].append(model._meta.label) @@ -27,7 +29,7 @@ def check_all_models(app_configs=None, **kwargs): "The '%s.check()' class method is currently overridden by %r." % (model.__name__, model.check), obj=model, - id='models.E020' + id="models.E020", ) ) else: @@ -37,17 +39,17 @@ def check_all_models(app_configs=None, **kwargs): for model_constraint in model._meta.constraints: constraints[model_constraint.name].append(model._meta.label) if settings.DATABASE_ROUTERS: - error_class, error_id = Warning, 'models.W035' + error_class, error_id = Warning, "models.W035" error_hint = ( - 'You have configured settings.DATABASE_ROUTERS. Verify that %s ' - 'are correctly routed to separate databases.' + "You have configured settings.DATABASE_ROUTERS. Verify that %s " + "are correctly routed to separate databases." ) else: - error_class, error_id = Error, 'models.E028' + error_class, error_id = Error, "models.E028" error_hint = None for db_table, model_labels in db_table_models.items(): if len(model_labels) != 1: - model_labels_str = ', '.join(model_labels) + model_labels_str = ", ".join(model_labels) errors.append( error_class( "db_table '%s' is used by multiple models: %s." @@ -62,12 +64,13 @@ def check_all_models(app_configs=None, **kwargs): model_labels = set(model_labels) errors.append( Error( - "index name '%s' is not unique %s %s." % ( + "index name '%s' is not unique %s %s." 
+ % ( index_name, - 'for model' if len(model_labels) == 1 else 'among models:', - ', '.join(sorted(model_labels)), + "for model" if len(model_labels) == 1 else "among models:", + ", ".join(sorted(model_labels)), ), - id='models.E029' if len(model_labels) == 1 else 'models.E030', + id="models.E029" if len(model_labels) == 1 else "models.E030", ), ) for constraint_name, model_labels in constraints.items(): @@ -75,12 +78,13 @@ def check_all_models(app_configs=None, **kwargs): model_labels = set(model_labels) errors.append( Error( - "constraint name '%s' is not unique %s %s." % ( + "constraint name '%s' is not unique %s %s." + % ( constraint_name, - 'for model' if len(model_labels) == 1 else 'among models:', - ', '.join(sorted(model_labels)), + "for model" if len(model_labels) == 1 else "among models:", + ", ".join(sorted(model_labels)), ), - id='models.E031' if len(model_labels) == 1 else 'models.E032', + id="models.E031" if len(model_labels) == 1 else "models.E032", ), ) return errors @@ -104,8 +108,10 @@ def _check_lazy_references(apps, ignore=None): return [] from django.db.models import signals + model_signals = { - signal: name for name, signal in vars(signals).items() + signal: name + for name, signal in vars(signals).items() if isinstance(signal, signals.ModelSignal) } @@ -120,9 +126,9 @@ def _check_lazy_references(apps, ignore=None): annotated there with a `func` attribute so as to imitate a partial. """ operation, args, keywords = obj, [], {} - while hasattr(operation, 'func'): - args.extend(getattr(operation, 'args', [])) - keywords.update(getattr(operation, 'keywords', {})) + while hasattr(operation, "func"): + args.extend(getattr(operation, "args", [])) + keywords.update(getattr(operation, "keywords", {})) operation = operation.func return operation, args, keywords @@ -146,11 +152,11 @@ def _check_lazy_references(apps, ignore=None): "to '%(model)s', but %(model_error)s." 
) params = { - 'model': '.'.join(model_key), - 'field': keywords['field'], - 'model_error': app_model_error(model_key), + "model": ".".join(model_key), + "field": keywords["field"], + "model_error": app_model_error(model_key), } - return Error(error_msg % params, obj=keywords['field'], id='fields.E307') + return Error(error_msg % params, obj=keywords["field"], id="fields.E307") def signal_connect_error(model_key, func, args, keywords): error_msg = ( @@ -163,34 +169,39 @@ def _check_lazy_references(apps, ignore=None): if isinstance(receiver, types.FunctionType): description = "The function '%s'" % receiver.__name__ elif isinstance(receiver, types.MethodType): - description = "Bound method '%s.%s'" % (receiver.__self__.__class__.__name__, receiver.__name__) + description = "Bound method '%s.%s'" % ( + receiver.__self__.__class__.__name__, + receiver.__name__, + ) else: description = "An instance of class '%s'" % receiver.__class__.__name__ - signal_name = model_signals.get(func.__self__, 'unknown') + signal_name = model_signals.get(func.__self__, "unknown") params = { - 'model': '.'.join(model_key), - 'receiver': description, - 'signal': signal_name, - 'model_error': app_model_error(model_key), + "model": ".".join(model_key), + "receiver": description, + "signal": signal_name, + "model_error": app_model_error(model_key), } - return Error(error_msg % params, obj=receiver.__module__, id='signals.E001') + return Error(error_msg % params, obj=receiver.__module__, id="signals.E001") def default_error(model_key, func, args, keywords): - error_msg = "%(op)s contains a lazy reference to %(model)s, but %(model_error)s." + error_msg = ( + "%(op)s contains a lazy reference to %(model)s, but %(model_error)s." 
+ ) params = { - 'op': func, - 'model': '.'.join(model_key), - 'model_error': app_model_error(model_key), + "op": func, + "model": ".".join(model_key), + "model_error": app_model_error(model_key), } - return Error(error_msg % params, obj=func, id='models.E022') + return Error(error_msg % params, obj=func, id="models.E022") # Maps common uses of lazy operations to corresponding error functions # defined above. If a key maps to None, no error will be produced. # default_error() will be used for usages that don't appear in this dict. known_lazy = { - ('django.db.models.fields.related', 'resolve_related_class'): field_error, - ('django.db.models.fields.related', 'set_managed'): None, - ('django.dispatch.dispatcher', 'connect'): signal_connect_error, + ("django.db.models.fields.related", "resolve_related_class"): field_error, + ("django.db.models.fields.related", "set_managed"): None, + ("django.dispatch.dispatcher", "connect"): signal_connect_error, } def build_error(model_key, func, args, keywords): @@ -198,11 +209,17 @@ def _check_lazy_references(apps, ignore=None): error_fn = known_lazy.get(key, default_error) return error_fn(model_key, func, args, keywords) if error_fn else None - return sorted(filter(None, ( - build_error(model_key, *extract_operation(func)) - for model_key in pending_models - for func in apps._pending_operations[model_key] - )), key=lambda error: error.msg) + return sorted( + filter( + None, + ( + build_error(model_key, *extract_operation(func)) + for model_key in pending_models + for func in apps._pending_operations[model_key] + ), + ), + key=lambda error: error.msg, + ) @register(Tags.models) diff --git a/django/core/checks/registry.py b/django/core/checks/registry.py index d7bfa49548..f4bdea8691 100644 --- a/django/core/checks/registry.py +++ b/django/core/checks/registry.py @@ -8,24 +8,24 @@ class Tags: """ Built-in tags for internal checks. 
""" - admin = 'admin' - async_support = 'async_support' - caches = 'caches' - compatibility = 'compatibility' - database = 'database' - files = 'files' - models = 'models' - security = 'security' - signals = 'signals' - sites = 'sites' - staticfiles = 'staticfiles' - templates = 'templates' - translation = 'translation' - urls = 'urls' + + admin = "admin" + async_support = "async_support" + caches = "caches" + compatibility = "compatibility" + database = "database" + files = "files" + models = "models" + security = "security" + signals = "signals" + sites = "sites" + staticfiles = "staticfiles" + templates = "templates" + translation = "translation" + urls = "urls" class CheckRegistry: - def __init__(self): self.registered_checks = set() self.deployment_checks = set() @@ -46,13 +46,18 @@ class CheckRegistry: # or registry.register(my_check, 'mytag', 'anothertag') """ + def inner(check): if not func_accepts_kwargs(check): raise TypeError( - 'Check functions must accept keyword arguments (**kwargs).' + "Check functions must accept keyword arguments (**kwargs)." ) check.tags = tags - checks = self.deployment_checks if kwargs.get('deploy') else self.registered_checks + checks = ( + self.deployment_checks + if kwargs.get("deploy") + else self.registered_checks + ) checks.add(check) return check @@ -63,7 +68,13 @@ class CheckRegistry: tags += (check,) return inner - def run_checks(self, app_configs=None, tags=None, include_deployment_checks=False, databases=None): + def run_checks( + self, + app_configs=None, + tags=None, + include_deployment_checks=False, + databases=None, + ): """ Run all registered checks and return list of Errors and Warnings. """ @@ -77,9 +88,8 @@ class CheckRegistry: new_errors = check(app_configs=app_configs, databases=databases) if not is_iterable(new_errors): raise TypeError( - 'The function %r did not return a list. All functions ' - 'registered with the checks registry must return a list.' - % check, + "The function %r did not return a list. 
All functions " + "registered with the checks registry must return a list." % check, ) errors.extend(new_errors) return errors @@ -88,9 +98,11 @@ class CheckRegistry: return tag in self.tags_available(include_deployment_checks) def tags_available(self, deployment_checks=False): - return set(chain.from_iterable( - check.tags for check in self.get_checks(deployment_checks) - )) + return set( + chain.from_iterable( + check.tags for check in self.get_checks(deployment_checks) + ) + ) def get_checks(self, include_deployment_checks=False): checks = list(self.registered_checks) diff --git a/django/core/checks/security/base.py b/django/core/checks/security/base.py index d37b968a7d..f85adabd1a 100644 --- a/django/core/checks/security/base.py +++ b/django/core/checks/security/base.py @@ -4,15 +4,22 @@ from django.core.exceptions import ImproperlyConfigured from .. import Error, Tags, Warning, register CROSS_ORIGIN_OPENER_POLICY_VALUES = { - 'same-origin', 'same-origin-allow-popups', 'unsafe-none', + "same-origin", + "same-origin-allow-popups", + "unsafe-none", } REFERRER_POLICY_VALUES = { - 'no-referrer', 'no-referrer-when-downgrade', 'origin', - 'origin-when-cross-origin', 'same-origin', 'strict-origin', - 'strict-origin-when-cross-origin', 'unsafe-url', + "no-referrer", + "no-referrer-when-downgrade", + "origin", + "origin-when-cross-origin", + "same-origin", + "strict-origin", + "strict-origin-when-cross-origin", + "unsafe-url", } -SECRET_KEY_INSECURE_PREFIX = 'django-insecure-' +SECRET_KEY_INSECURE_PREFIX = "django-insecure-" SECRET_KEY_MIN_LENGTH = 50 SECRET_KEY_MIN_UNIQUE_CHARACTERS = 5 @@ -31,7 +38,7 @@ W001 = Warning( "SECURE_CONTENT_TYPE_NOSNIFF, SECURE_REFERRER_POLICY, " "SECURE_CROSS_ORIGIN_OPENER_POLICY, and SECURE_SSL_REDIRECT settings will " "have no effect.", - id='security.W001', + id="security.W001", ) W002 = Warning( @@ -41,7 +48,7 @@ W002 = Warning( "'x-frame-options' header. 
Unless there is a good reason for your " "site to be served in a frame, you should consider enabling this " "header to help prevent clickjacking attacks.", - id='security.W002', + id="security.W002", ) W004 = Warning( @@ -50,7 +57,7 @@ W004 = Warning( "setting a value and enabling HTTP Strict Transport Security. " "Be sure to read the documentation first; enabling HSTS carelessly " "can cause serious, irreversible problems.", - id='security.W004', + id="security.W004", ) W005 = Warning( @@ -59,7 +66,7 @@ W005 = Warning( "via an insecure connection to a subdomain. Only set this to True if " "you are certain that all subdomains of your domain should be served " "exclusively via SSL.", - id='security.W005', + id="security.W005", ) W006 = Warning( @@ -68,7 +75,7 @@ W006 = Warning( "'X-Content-Type-Options: nosniff' header. " "You should consider enabling this header to prevent the " "browser from identifying content types incorrectly.", - id='security.W006', + id="security.W006", ) W008 = Warning( @@ -77,17 +84,17 @@ W008 = Warning( "connections, you may want to either set this setting True " "or configure a load balancer or reverse-proxy server " "to redirect all connections to HTTPS.", - id='security.W008', + id="security.W008", ) W009 = Warning( - SECRET_KEY_WARNING_MSG % 'SECRET_KEY', - id='security.W009', + SECRET_KEY_WARNING_MSG % "SECRET_KEY", + id="security.W009", ) W018 = Warning( "You should not have DEBUG set to True in deployment.", - id='security.W018', + id="security.W018", ) W019 = Warning( @@ -96,51 +103,53 @@ W019 = Warning( "MIDDLEWARE, but X_FRAME_OPTIONS is not set to 'DENY'. " "Unless there is a good reason for your site to serve other parts of " "itself in a frame, you should change it to 'DENY'.", - id='security.W019', + id="security.W019", ) W020 = Warning( "ALLOWED_HOSTS must not be empty in deployment.", - id='security.W020', + id="security.W020", ) W021 = Warning( "You have not set the SECURE_HSTS_PRELOAD setting to True. 
Without this, " "your site cannot be submitted to the browser preload list.", - id='security.W021', + id="security.W021", ) W022 = Warning( - 'You have not set the SECURE_REFERRER_POLICY setting. Without this, your ' - 'site will not send a Referrer-Policy header. You should consider ' - 'enabling this header to protect user privacy.', - id='security.W022', + "You have not set the SECURE_REFERRER_POLICY setting. Without this, your " + "site will not send a Referrer-Policy header. You should consider " + "enabling this header to protect user privacy.", + id="security.W022", ) E023 = Error( - 'You have set the SECURE_REFERRER_POLICY setting to an invalid value.', - hint='Valid values are: {}.'.format(', '.join(sorted(REFERRER_POLICY_VALUES))), - id='security.E023', + "You have set the SECURE_REFERRER_POLICY setting to an invalid value.", + hint="Valid values are: {}.".format(", ".join(sorted(REFERRER_POLICY_VALUES))), + id="security.E023", ) E024 = Error( - 'You have set the SECURE_CROSS_ORIGIN_OPENER_POLICY setting to an invalid ' - 'value.', - hint='Valid values are: {}.'.format( - ', '.join(sorted(CROSS_ORIGIN_OPENER_POLICY_VALUES)), + "You have set the SECURE_CROSS_ORIGIN_OPENER_POLICY setting to an invalid " + "value.", + hint="Valid values are: {}.".format( + ", ".join(sorted(CROSS_ORIGIN_OPENER_POLICY_VALUES)), ), - id='security.E024', + id="security.E024", ) -W025 = Warning(SECRET_KEY_WARNING_MSG, id='security.W025') +W025 = Warning(SECRET_KEY_WARNING_MSG, id="security.W025") def _security_middleware(): - return 'django.middleware.security.SecurityMiddleware' in settings.MIDDLEWARE + return "django.middleware.security.SecurityMiddleware" in settings.MIDDLEWARE def _xframe_middleware(): - return 'django.middleware.clickjacking.XFrameOptionsMiddleware' in settings.MIDDLEWARE + return ( + "django.middleware.clickjacking.XFrameOptionsMiddleware" in settings.MIDDLEWARE + ) @register(Tags.security, deploy=True) @@ -164,9 +173,9 @@ def check_sts(app_configs, 
**kwargs): @register(Tags.security, deploy=True) def check_sts_include_subdomains(app_configs, **kwargs): passed_check = ( - not _security_middleware() or - not settings.SECURE_HSTS_SECONDS or - settings.SECURE_HSTS_INCLUDE_SUBDOMAINS is True + not _security_middleware() + or not settings.SECURE_HSTS_SECONDS + or settings.SECURE_HSTS_INCLUDE_SUBDOMAINS is True ) return [] if passed_check else [W005] @@ -174,9 +183,9 @@ def check_sts_include_subdomains(app_configs, **kwargs): @register(Tags.security, deploy=True) def check_sts_preload(app_configs, **kwargs): passed_check = ( - not _security_middleware() or - not settings.SECURE_HSTS_SECONDS or - settings.SECURE_HSTS_PRELOAD is True + not _security_middleware() + or not settings.SECURE_HSTS_SECONDS + or settings.SECURE_HSTS_PRELOAD is True ) return [] if passed_check else [W021] @@ -184,26 +193,22 @@ def check_sts_preload(app_configs, **kwargs): @register(Tags.security, deploy=True) def check_content_type_nosniff(app_configs, **kwargs): passed_check = ( - not _security_middleware() or - settings.SECURE_CONTENT_TYPE_NOSNIFF is True + not _security_middleware() or settings.SECURE_CONTENT_TYPE_NOSNIFF is True ) return [] if passed_check else [W006] @register(Tags.security, deploy=True) def check_ssl_redirect(app_configs, **kwargs): - passed_check = ( - not _security_middleware() or - settings.SECURE_SSL_REDIRECT is True - ) + passed_check = not _security_middleware() or settings.SECURE_SSL_REDIRECT is True return [] if passed_check else [W008] def _check_secret_key(secret_key): return ( - len(set(secret_key)) >= SECRET_KEY_MIN_UNIQUE_CHARACTERS and - len(secret_key) >= SECRET_KEY_MIN_LENGTH and - not secret_key.startswith(SECRET_KEY_INSECURE_PREFIX) + len(set(secret_key)) >= SECRET_KEY_MIN_UNIQUE_CHARACTERS + and len(secret_key) >= SECRET_KEY_MIN_LENGTH + and not secret_key.startswith(SECRET_KEY_INSECURE_PREFIX) ) @@ -224,14 +229,12 @@ def check_secret_key_fallbacks(app_configs, **kwargs): try: fallbacks = 
settings.SECRET_KEY_FALLBACKS except (ImproperlyConfigured, AttributeError): - warnings.append( - Warning(W025.msg % 'SECRET_KEY_FALLBACKS', id=W025.id) - ) + warnings.append(Warning(W025.msg % "SECRET_KEY_FALLBACKS", id=W025.id)) else: for index, key in enumerate(fallbacks): if not _check_secret_key(key): warnings.append( - Warning(W025.msg % f'SECRET_KEY_FALLBACKS[{index}]', id=W025.id) + Warning(W025.msg % f"SECRET_KEY_FALLBACKS[{index}]", id=W025.id) ) return warnings @@ -244,10 +247,7 @@ def check_debug(app_configs, **kwargs): @register(Tags.security, deploy=True) def check_xframe_deny(app_configs, **kwargs): - passed_check = ( - not _xframe_middleware() or - settings.X_FRAME_OPTIONS == 'DENY' - ) + passed_check = not _xframe_middleware() or settings.X_FRAME_OPTIONS == "DENY" return [] if passed_check else [W019] @@ -263,7 +263,7 @@ def check_referrer_policy(app_configs, **kwargs): return [W022] # Support a comma-separated string or iterable of values to allow fallback. if isinstance(settings.SECURE_REFERRER_POLICY, str): - values = {v.strip() for v in settings.SECURE_REFERRER_POLICY.split(',')} + values = {v.strip() for v in settings.SECURE_REFERRER_POLICY.split(",")} else: values = set(settings.SECURE_REFERRER_POLICY) if not values <= REFERRER_POLICY_VALUES: @@ -274,9 +274,10 @@ def check_referrer_policy(app_configs, **kwargs): @register(Tags.security, deploy=True) def check_cross_origin_opener_policy(app_configs, **kwargs): if ( - _security_middleware() and - settings.SECURE_CROSS_ORIGIN_OPENER_POLICY is not None and - settings.SECURE_CROSS_ORIGIN_OPENER_POLICY not in CROSS_ORIGIN_OPENER_POLICY_VALUES + _security_middleware() + and settings.SECURE_CROSS_ORIGIN_OPENER_POLICY is not None + and settings.SECURE_CROSS_ORIGIN_OPENER_POLICY + not in CROSS_ORIGIN_OPENER_POLICY_VALUES ): return [E024] return [] diff --git a/django/core/checks/security/csrf.py b/django/core/checks/security/csrf.py index 2b70d363e2..ea65e48c94 100644 --- 
a/django/core/checks/security/csrf.py +++ b/django/core/checks/security/csrf.py @@ -10,7 +10,7 @@ W003 = Warning( "('django.middleware.csrf.CsrfViewMiddleware' is not in your " "MIDDLEWARE). Enabling the middleware is the safest approach " "to ensure you don't leave any holes.", - id='security.W003', + id="security.W003", ) W016 = Warning( @@ -18,12 +18,12 @@ W016 = Warning( "MIDDLEWARE, but you have not set CSRF_COOKIE_SECURE to True. " "Using a secure-only CSRF cookie makes it more difficult for network " "traffic sniffers to steal the CSRF token.", - id='security.W016', + id="security.W016", ) def _csrf_middleware(): - return 'django.middleware.csrf.CsrfViewMiddleware' in settings.MIDDLEWARE + return "django.middleware.csrf.CsrfViewMiddleware" in settings.MIDDLEWARE @register(Tags.security, deploy=True) @@ -35,9 +35,9 @@ def check_csrf_middleware(app_configs, **kwargs): @register(Tags.security, deploy=True) def check_csrf_cookie_secure(app_configs, **kwargs): passed_check = ( - settings.CSRF_USE_SESSIONS or - not _csrf_middleware() or - settings.CSRF_COOKIE_SECURE + settings.CSRF_USE_SESSIONS + or not _csrf_middleware() + or settings.CSRF_COOKIE_SECURE ) return [] if passed_check else [W016] @@ -51,17 +51,17 @@ def check_csrf_failure_view(app_configs, **kwargs): view = _get_failure_view() except ImportError: msg = ( - "The CSRF failure view '%s' could not be imported." % - settings.CSRF_FAILURE_VIEW + "The CSRF failure view '%s' could not be imported." + % settings.CSRF_FAILURE_VIEW ) - errors.append(Error(msg, id='security.E102')) + errors.append(Error(msg, id="security.E102")) else: try: inspect.signature(view).bind(None, reason=None) except TypeError: msg = ( - "The CSRF failure view '%s' does not take the correct number of arguments." % - settings.CSRF_FAILURE_VIEW + "The CSRF failure view '%s' does not take the correct number of arguments." 
+ % settings.CSRF_FAILURE_VIEW ) - errors.append(Error(msg, id='security.E101')) + errors.append(Error(msg, id="security.E101")) return errors diff --git a/django/core/checks/security/sessions.py b/django/core/checks/security/sessions.py index 1f31a167fa..7c251c0601 100644 --- a/django/core/checks/security/sessions.py +++ b/django/core/checks/security/sessions.py @@ -15,7 +15,7 @@ W010 = Warning( "You have 'django.contrib.sessions' in your INSTALLED_APPS, " "but you have not set SESSION_COOKIE_SECURE to True." ), - id='security.W010', + id="security.W010", ) W011 = Warning( @@ -24,12 +24,12 @@ W011 = Warning( "in your MIDDLEWARE, but you have not set " "SESSION_COOKIE_SECURE to True." ), - id='security.W011', + id="security.W011", ) W012 = Warning( add_session_cookie_message("SESSION_COOKIE_SECURE is not set to True."), - id='security.W012', + id="security.W012", ) @@ -45,7 +45,7 @@ W013 = Warning( "You have 'django.contrib.sessions' in your INSTALLED_APPS, " "but you have not set SESSION_COOKIE_HTTPONLY to True.", ), - id='security.W013', + id="security.W013", ) W014 = Warning( @@ -54,12 +54,12 @@ W014 = Warning( "in your MIDDLEWARE, but you have not set " "SESSION_COOKIE_HTTPONLY to True." ), - id='security.W014', + id="security.W014", ) W015 = Warning( add_httponly_message("SESSION_COOKIE_HTTPONLY is not set to True."), - id='security.W015', + id="security.W015", ) @@ -90,7 +90,7 @@ def check_session_cookie_httponly(app_configs, **kwargs): def _session_middleware(): - return 'django.contrib.sessions.middleware.SessionMiddleware' in settings.MIDDLEWARE + return "django.contrib.sessions.middleware.SessionMiddleware" in settings.MIDDLEWARE def _session_app(): diff --git a/django/core/checks/templates.py b/django/core/checks/templates.py index 14325bd3e0..5214276987 100644 --- a/django/core/checks/templates.py +++ b/django/core/checks/templates.py @@ -9,34 +9,40 @@ from . 
import Error, Tags, register E001 = Error( "You have 'APP_DIRS': True in your TEMPLATES but also specify 'loaders' " "in OPTIONS. Either remove APP_DIRS or remove the 'loaders' option.", - id='templates.E001', + id="templates.E001", ) E002 = Error( "'string_if_invalid' in TEMPLATES OPTIONS must be a string but got: {} ({}).", id="templates.E002", ) E003 = Error( - '{} is used for multiple template tag modules: {}', - id='templates.E003', + "{} is used for multiple template tag modules: {}", + id="templates.E003", ) @register(Tags.templates) def check_setting_app_dirs_loaders(app_configs, **kwargs): - return [E001] if any( - conf.get('APP_DIRS') and 'loaders' in conf.get('OPTIONS', {}) - for conf in settings.TEMPLATES - ) else [] + return ( + [E001] + if any( + conf.get("APP_DIRS") and "loaders" in conf.get("OPTIONS", {}) + for conf in settings.TEMPLATES + ) + else [] + ) @register(Tags.templates) def check_string_if_invalid_is_string(app_configs, **kwargs): errors = [] for conf in settings.TEMPLATES: - string_if_invalid = conf.get('OPTIONS', {}).get('string_if_invalid', '') + string_if_invalid = conf.get("OPTIONS", {}).get("string_if_invalid", "") if not isinstance(string_if_invalid, str): error = copy.copy(E002) - error.msg = error.msg.format(string_if_invalid, type(string_if_invalid).__name__) + error.msg = error.msg.format( + string_if_invalid, type(string_if_invalid).__name__ + ) errors.append(error) return errors @@ -47,7 +53,7 @@ def check_for_template_tags_with_the_same_name(app_configs, **kwargs): libraries = defaultdict(list) for conf in settings.TEMPLATES: - custom_libraries = conf.get('OPTIONS', {}).get('libraries', {}) + custom_libraries = conf.get("OPTIONS", {}).get("libraries", {}) for module_name, module_path in custom_libraries.items(): libraries[module_name].append(module_path) @@ -56,12 +62,14 @@ def check_for_template_tags_with_the_same_name(app_configs, **kwargs): for library_name, items in libraries.items(): if len(items) > 1: - 
errors.append(Error( - E003.msg.format( - repr(library_name), - ', '.join(repr(item) for item in items), - ), - id=E003.id, - )) + errors.append( + Error( + E003.msg.format( + repr(library_name), + ", ".join(repr(item) for item in items), + ), + id=E003.id, + ) + ) return errors diff --git a/django/core/checks/translation.py b/django/core/checks/translation.py index 8457a6b89d..214e970373 100644 --- a/django/core/checks/translation.py +++ b/django/core/checks/translation.py @@ -5,24 +5,24 @@ from django.utils.translation.trans_real import language_code_re from . import Error, Tags, register E001 = Error( - 'You have provided an invalid value for the LANGUAGE_CODE setting: {!r}.', - id='translation.E001', + "You have provided an invalid value for the LANGUAGE_CODE setting: {!r}.", + id="translation.E001", ) E002 = Error( - 'You have provided an invalid language code in the LANGUAGES setting: {!r}.', - id='translation.E002', + "You have provided an invalid language code in the LANGUAGES setting: {!r}.", + id="translation.E002", ) E003 = Error( - 'You have provided an invalid language code in the LANGUAGES_BIDI setting: {!r}.', - id='translation.E003', + "You have provided an invalid language code in the LANGUAGES_BIDI setting: {!r}.", + id="translation.E003", ) E004 = Error( - 'You have provided a value for the LANGUAGE_CODE setting that is not in ' - 'the LANGUAGES setting.', - id='translation.E004', + "You have provided a value for the LANGUAGE_CODE setting that is not in " + "the LANGUAGES setting.", + id="translation.E004", ) @@ -40,7 +40,8 @@ def check_setting_languages(app_configs, **kwargs): """Error if LANGUAGES setting is invalid.""" return [ Error(E002.msg.format(tag), id=E002.id) - for tag, _ in settings.LANGUAGES if not isinstance(tag, str) or not language_code_re.match(tag) + for tag, _ in settings.LANGUAGES + if not isinstance(tag, str) or not language_code_re.match(tag) ] @@ -49,7 +50,8 @@ def check_setting_languages_bidi(app_configs, **kwargs): 
"""Error if LANGUAGES_BIDI setting is invalid.""" return [ Error(E003.msg.format(tag), id=E003.id) - for tag in settings.LANGUAGES_BIDI if not isinstance(tag, str) or not language_code_re.match(tag) + for tag in settings.LANGUAGES_BIDI + if not isinstance(tag, str) or not language_code_re.match(tag) ] diff --git a/django/core/checks/urls.py b/django/core/checks/urls.py index e51ca3fc1f..34eff9671d 100644 --- a/django/core/checks/urls.py +++ b/django/core/checks/urls.py @@ -7,8 +7,9 @@ from . import Error, Tags, Warning, register @register(Tags.urls) def check_url_config(app_configs, **kwargs): - if getattr(settings, 'ROOT_URLCONF', None): + if getattr(settings, "ROOT_URLCONF", None): from django.urls import get_resolver + resolver = get_resolver() return check_resolver(resolver) return [] @@ -18,10 +19,10 @@ def check_resolver(resolver): """ Recursively check the resolver. """ - check_method = getattr(resolver, 'check', None) + check_method = getattr(resolver, "check", None) if check_method is not None: return check_method() - elif not hasattr(resolver, 'resolve'): + elif not hasattr(resolver, "resolve"): return get_warning_for_invalid_pattern(resolver) else: return [] @@ -32,21 +33,24 @@ def check_url_namespaces_unique(app_configs, **kwargs): """ Warn if URL namespaces used in applications aren't unique. """ - if not getattr(settings, 'ROOT_URLCONF', None): + if not getattr(settings, "ROOT_URLCONF", None): return [] from django.urls import get_resolver + resolver = get_resolver() all_namespaces = _load_all_namespaces(resolver) counter = Counter(all_namespaces) non_unique_namespaces = [n for n, count in counter.items() if count > 1] errors = [] for namespace in non_unique_namespaces: - errors.append(Warning( - "URL namespace '{}' isn't unique. You may not be able to reverse " - "all URLs in this namespace".format(namespace), - id="urls.W005", - )) + errors.append( + Warning( + "URL namespace '{}' isn't unique. 
You may not be able to reverse " + "all URLs in this namespace".format(namespace), + id="urls.W005", + ) + ) return errors @@ -54,13 +58,14 @@ def _load_all_namespaces(resolver, parents=()): """ Recursively load all namespaces from URL patterns. """ - url_patterns = getattr(resolver, 'url_patterns', []) + url_patterns = getattr(resolver, "url_patterns", []) namespaces = [ - ':'.join(parents + (url.namespace,)) for url in url_patterns - if getattr(url, 'namespace', None) is not None + ":".join(parents + (url.namespace,)) + for url in url_patterns + if getattr(url, "namespace", None) is not None ] for pattern in url_patterns: - namespace = getattr(pattern, 'namespace', None) + namespace = getattr(pattern, "namespace", None) current = parents if namespace is not None: current += (namespace,) @@ -85,26 +90,28 @@ def get_warning_for_invalid_pattern(pattern): else: hint = None - return [Error( - "Your URL pattern {!r} is invalid. Ensure that urlpatterns is a list " - "of path() and/or re_path() instances.".format(pattern), - hint=hint, - id="urls.E004", - )] + return [ + Error( + "Your URL pattern {!r} is invalid. 
Ensure that urlpatterns is a list " + "of path() and/or re_path() instances.".format(pattern), + hint=hint, + id="urls.E004", + ) + ] @register(Tags.urls) def check_url_settings(app_configs, **kwargs): errors = [] - for name in ('STATIC_URL', 'MEDIA_URL'): + for name in ("STATIC_URL", "MEDIA_URL"): value = getattr(settings, name) - if value and not value.endswith('/'): + if value and not value.endswith("/"): errors.append(E006(name)) return errors def E006(name): return Error( - 'The {} setting must end with a slash.'.format(name), - id='urls.E006', + "The {} setting must end with a slash.".format(name), + id="urls.E006", ) diff --git a/django/core/exceptions.py b/django/core/exceptions.py index 673d004d57..7be4e16bc5 100644 --- a/django/core/exceptions.py +++ b/django/core/exceptions.py @@ -8,21 +8,25 @@ from django.utils.hashable import make_hashable class FieldDoesNotExist(Exception): """The requested model field does not exist""" + pass class AppRegistryNotReady(Exception): """The django.apps registry is not populated yet""" + pass class ObjectDoesNotExist(Exception): """The requested object does not exist""" + silent_variable_failure = True class MultipleObjectsReturned(Exception): """The query returned multiple objects when only one was expected.""" + pass @@ -32,21 +36,25 @@ class SuspiciousOperation(Exception): class SuspiciousMultipartForm(SuspiciousOperation): """Suspect MIME request in multipart form data""" + pass class SuspiciousFileOperation(SuspiciousOperation): """A Suspicious filesystem operation was attempted""" + pass class DisallowedHost(SuspiciousOperation): """HTTP_HOST header contains invalid value""" + pass class DisallowedRedirect(SuspiciousOperation): """Redirect to scheme not in allowed list""" + pass @@ -55,6 +63,7 @@ class TooManyFieldsSent(SuspiciousOperation): The number of fields in a GET or POST request exceeded settings.DATA_UPLOAD_MAX_NUMBER_FIELDS. 
""" + pass @@ -63,49 +72,58 @@ class RequestDataTooBig(SuspiciousOperation): The size of the request (excluding any file uploads) exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE. """ + pass class RequestAborted(Exception): """The request was closed before it was completed, or timed out.""" + pass class BadRequest(Exception): """The request is malformed and cannot be processed.""" + pass class PermissionDenied(Exception): """The user did not have permission to do that""" + pass class ViewDoesNotExist(Exception): """The requested view does not exist""" + pass class MiddlewareNotUsed(Exception): """This middleware is not used in this server configuration""" + pass class ImproperlyConfigured(Exception): """Django is somehow improperly configured""" + pass class FieldError(Exception): """Some kind of problem with a model field.""" + pass -NON_FIELD_ERRORS = '__all__' +NON_FIELD_ERRORS = "__all__" class ValidationError(Exception): """An error while validating data.""" + def __init__(self, message, code=None, params=None): """ The `message` argument can be a single error, a list of errors, or a @@ -118,9 +136,9 @@ class ValidationError(Exception): super().__init__(message, code, params) if isinstance(message, ValidationError): - if hasattr(message, 'error_dict'): + if hasattr(message, "error_dict"): message = message.error_dict - elif not hasattr(message, 'message'): + elif not hasattr(message, "message"): message = message.error_list else: message, code, params = message.message, message.code, message.params @@ -138,7 +156,7 @@ class ValidationError(Exception): # Normalize plain strings to instances of ValidationError. 
if not isinstance(message, ValidationError): message = ValidationError(message) - if hasattr(message, 'error_dict'): + if hasattr(message, "error_dict"): self.error_list.extend(sum(message.error_dict.values(), [])) else: self.error_list.extend(message.error_list) @@ -153,18 +171,18 @@ class ValidationError(Exception): def message_dict(self): # Trigger an AttributeError if this ValidationError # doesn't have an error_dict. - getattr(self, 'error_dict') + getattr(self, "error_dict") return dict(self) @property def messages(self): - if hasattr(self, 'error_dict'): + if hasattr(self, "error_dict"): return sum(dict(self).values(), []) return list(self) def update_error_dict(self, error_dict): - if hasattr(self, 'error_dict'): + if hasattr(self, "error_dict"): for field, error_list in self.error_dict.items(): error_dict.setdefault(field, []).extend(error_list) else: @@ -172,7 +190,7 @@ class ValidationError(Exception): return error_dict def __iter__(self): - if hasattr(self, 'error_dict'): + if hasattr(self, "error_dict"): for field, errors in self.error_dict.items(): yield field, list(ValidationError(errors)) else: @@ -183,12 +201,12 @@ class ValidationError(Exception): yield str(message) def __str__(self): - if hasattr(self, 'error_dict'): + if hasattr(self, "error_dict"): return repr(dict(self)) return repr(list(self)) def __repr__(self): - return 'ValidationError(%s)' % self + return "ValidationError(%s)" % self def __eq__(self, other): if not isinstance(other, ValidationError): @@ -196,22 +214,26 @@ class ValidationError(Exception): return hash(self) == hash(other) def __hash__(self): - if hasattr(self, 'message'): - return hash(( - self.message, - self.code, - make_hashable(self.params), - )) - if hasattr(self, 'error_dict'): + if hasattr(self, "message"): + return hash( + ( + self.message, + self.code, + make_hashable(self.params), + ) + ) + if hasattr(self, "error_dict"): return hash(make_hashable(self.error_dict)) - return hash(tuple(sorted(self.error_list, 
key=operator.attrgetter('message')))) + return hash(tuple(sorted(self.error_list, key=operator.attrgetter("message")))) class EmptyResultSet(Exception): """A database query predicate is impossible.""" + pass class SynchronousOnlyOperation(Exception): """The user tried to call a sync-only function from an async context.""" + pass diff --git a/django/core/files/__init__.py b/django/core/files/__init__.py index 58a6fd8f85..d046aca084 100644 --- a/django/core/files/__init__.py +++ b/django/core/files/__init__.py @@ -1,3 +1,3 @@ from django.core.files.base import File -__all__ = ['File'] +__all__ = ["File"] diff --git a/django/core/files/base.py b/django/core/files/base.py index 2ac662ed7c..3ca43ec254 100644 --- a/django/core/files/base.py +++ b/django/core/files/base.py @@ -6,18 +6,18 @@ from django.utils.functional import cached_property class File(FileProxyMixin): - DEFAULT_CHUNK_SIZE = 64 * 2 ** 10 + DEFAULT_CHUNK_SIZE = 64 * 2**10 def __init__(self, file, name=None): self.file = file if name is None: - name = getattr(file, 'name', None) + name = getattr(file, "name", None) self.name = name - if hasattr(file, 'mode'): + if hasattr(file, "mode"): self.mode = file.mode def __str__(self): - return self.name or '' + return self.name or "" def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self or "None") @@ -30,14 +30,14 @@ class File(FileProxyMixin): @cached_property def size(self): - if hasattr(self.file, 'size'): + if hasattr(self.file, "size"): return self.file.size - if hasattr(self.file, 'name'): + if hasattr(self.file, "name"): try: return os.path.getsize(self.file.name) except (OSError, TypeError): pass - if hasattr(self.file, 'tell') and hasattr(self.file, 'seek'): + if hasattr(self.file, "tell") and hasattr(self.file, "seek"): pos = self.file.tell() self.file.seek(0, os.SEEK_END) size = self.file.tell() @@ -122,13 +122,14 @@ class ContentFile(File): """ A File-like object that takes just raw content, rather than an actual file. 
""" + def __init__(self, content, name=None): stream_class = StringIO if isinstance(content, str) else BytesIO super().__init__(stream_class(content), name=name) self.size = len(content) def __str__(self): - return 'Raw content' + return "Raw content" def __bool__(self): return True @@ -141,20 +142,20 @@ class ContentFile(File): pass def write(self, data): - self.__dict__.pop('size', None) # Clear the computed size. + self.__dict__.pop("size", None) # Clear the computed size. return self.file.write(data) def endswith_cr(line): """Return True if line (a text or bytestring) ends with '\r'.""" - return line.endswith('\r' if isinstance(line, str) else b'\r') + return line.endswith("\r" if isinstance(line, str) else b"\r") def endswith_lf(line): """Return True if line (a text or bytestring) ends with '\n'.""" - return line.endswith('\n' if isinstance(line, str) else b'\n') + return line.endswith("\n" if isinstance(line, str) else b"\n") def equals_lf(line): """Return True if line (a text or bytestring) equals '\n'.""" - return line == ('\n' if isinstance(line, str) else b'\n') + return line == ("\n" if isinstance(line, str) else b"\n") diff --git a/django/core/files/images.py b/django/core/files/images.py index a1252f2c8b..6a603f24fc 100644 --- a/django/core/files/images.py +++ b/django/core/files/images.py @@ -14,6 +14,7 @@ class ImageFile(File): A mixin for use alongside django.core.files.base.File, which provides additional features for dealing with images. 
""" + @property def width(self): return self._get_image_dimensions()[0] @@ -23,7 +24,7 @@ class ImageFile(File): return self._get_image_dimensions()[1] def _get_image_dimensions(self): - if not hasattr(self, '_dimensions_cache'): + if not hasattr(self, "_dimensions_cache"): close = self.closed self.open() self._dimensions_cache = get_image_dimensions(self, close=close) @@ -39,13 +40,13 @@ def get_image_dimensions(file_or_path, close=False): from PIL import ImageFile as PillowImageFile p = PillowImageFile.Parser() - if hasattr(file_or_path, 'read'): + if hasattr(file_or_path, "read"): file = file_or_path file_pos = file.tell() file.seek(0) else: try: - file = open(file_or_path, 'rb') + file = open(file_or_path, "rb") except OSError: return (None, None) close = True diff --git a/django/core/files/locks.py b/django/core/files/locks.py index fdb9665332..da1a25fcf0 100644 --- a/django/core/files/locks.py +++ b/django/core/files/locks.py @@ -18,18 +18,25 @@ Example Usage:: """ import os -__all__ = ('LOCK_EX', 'LOCK_SH', 'LOCK_NB', 'lock', 'unlock') +__all__ = ("LOCK_EX", "LOCK_SH", "LOCK_NB", "lock", "unlock") def _fd(f): """Get a filedescriptor from something which could be a file or an fd.""" - return f.fileno() if hasattr(f, 'fileno') else f + return f.fileno() if hasattr(f, "fileno") else f -if os.name == 'nt': +if os.name == "nt": import msvcrt from ctypes import ( - POINTER, Structure, Union, byref, c_int64, c_ulong, c_void_p, sizeof, + POINTER, + Structure, + Union, + byref, + c_int64, + c_ulong, + c_void_p, + sizeof, windll, ) from ctypes.wintypes import BOOL, DWORD, HANDLE @@ -48,23 +55,20 @@ if os.name == 'nt': # --- Union inside Structure by stackoverflow:3480240 --- class _OFFSET(Structure): - _fields_ = [ - ('Offset', DWORD), - ('OffsetHigh', DWORD)] + _fields_ = [("Offset", DWORD), ("OffsetHigh", DWORD)] class _OFFSET_UNION(Union): - _anonymous_ = ['_offset'] - _fields_ = [ - ('_offset', _OFFSET), - ('Pointer', PVOID)] + _anonymous_ = ["_offset"] + _fields_ 
= [("_offset", _OFFSET), ("Pointer", PVOID)] class OVERLAPPED(Structure): - _anonymous_ = ['_offset_union'] + _anonymous_ = ["_offset_union"] _fields_ = [ - ('Internal', ULONG_PTR), - ('InternalHigh', ULONG_PTR), - ('_offset_union', _OFFSET_UNION), - ('hEvent', HANDLE)] + ("Internal", ULONG_PTR), + ("InternalHigh", ULONG_PTR), + ("_offset_union", _OFFSET_UNION), + ("hEvent", HANDLE), + ] LPOVERLAPPED = POINTER(OVERLAPPED) @@ -87,9 +91,11 @@ if os.name == 'nt': overlapped = OVERLAPPED() ret = UnlockFileEx(hfile, 0, 0, 0xFFFF0000, byref(overlapped)) return bool(ret) + else: try: import fcntl + LOCK_SH = fcntl.LOCK_SH # shared lock LOCK_NB = fcntl.LOCK_NB # non-blocking LOCK_EX = fcntl.LOCK_EX @@ -105,7 +111,9 @@ else: def unlock(f): # File is unlocked return True + else: + def lock(f, flags): try: fcntl.flock(_fd(f), flags) diff --git a/django/core/files/move.py b/django/core/files/move.py index 2cce7848ca..2d71e11885 100644 --- a/django/core/files/move.py +++ b/django/core/files/move.py @@ -11,23 +11,26 @@ from shutil import copystat from django.core.files import locks -__all__ = ['file_move_safe'] +__all__ = ["file_move_safe"] def _samefile(src, dst): # Macintosh, Unix. - if hasattr(os.path, 'samefile'): + if hasattr(os.path, "samefile"): try: return os.path.samefile(src, dst) except OSError: return False # All other platforms: check for same pathname. - return (os.path.normcase(os.path.abspath(src)) == - os.path.normcase(os.path.abspath(dst))) + return os.path.normcase(os.path.abspath(src)) == os.path.normcase( + os.path.abspath(dst) + ) -def file_move_safe(old_file_name, new_file_name, chunk_size=1024 * 64, allow_overwrite=False): +def file_move_safe( + old_file_name, new_file_name, chunk_size=1024 * 64, allow_overwrite=False +): """ Move a file from one location to another in the safest way possible. 
@@ -43,7 +46,10 @@ def file_move_safe(old_file_name, new_file_name, chunk_size=1024 * 64, allow_ove try: if not allow_overwrite and os.access(new_file_name, os.F_OK): - raise FileExistsError('Destination file %s exists and allow_overwrite is False.' % new_file_name) + raise FileExistsError( + "Destination file %s exists and allow_overwrite is False." + % new_file_name + ) os.rename(old_file_name, new_file_name) return @@ -53,14 +59,21 @@ def file_move_safe(old_file_name, new_file_name, chunk_size=1024 * 64, allow_ove pass # first open the old file, so that it won't go away - with open(old_file_name, 'rb') as old_file: + with open(old_file_name, "rb") as old_file: # now open the new file, not forgetting allow_overwrite - fd = os.open(new_file_name, (os.O_WRONLY | os.O_CREAT | getattr(os, 'O_BINARY', 0) | - (os.O_EXCL if not allow_overwrite else 0))) + fd = os.open( + new_file_name, + ( + os.O_WRONLY + | os.O_CREAT + | getattr(os, "O_BINARY", 0) + | (os.O_EXCL if not allow_overwrite else 0) + ), + ) try: locks.lock(fd, locks.LOCK_EX) current_chunk = None - while current_chunk != b'': + while current_chunk != b"": current_chunk = old_file.read(chunk_size) os.write(fd, current_chunk) finally: @@ -83,5 +96,5 @@ def file_move_safe(old_file_name, new_file_name, chunk_size=1024 * 64, allow_ove # fail when deleting opened files, ignore it. (For the # systems where this happens, temporary files will be auto-deleted # on close anyway.) 
- if getattr(e, 'winerror', 0) != 32: + if getattr(e, "winerror", 0) != 32: raise diff --git a/django/core/files/storage.py b/django/core/files/storage.py index bbb3e8b21d..690456aa71 100644 --- a/django/core/files/storage.py +++ b/django/core/files/storage.py @@ -19,8 +19,11 @@ from django.utils.module_loading import import_string from django.utils.text import get_valid_filename __all__ = ( - 'Storage', 'FileSystemStorage', 'DefaultStorage', 'default_storage', - 'get_storage_class', + "Storage", + "FileSystemStorage", + "DefaultStorage", + "default_storage", + "get_storage_class", ) @@ -33,7 +36,7 @@ class Storage: # The following methods represent a public interface to private methods. # These shouldn't be overridden by subclasses unless absolutely necessary. - def open(self, name, mode='rb'): + def open(self, name, mode="rb"): """Retrieve the specified file from storage.""" return self._open(name, mode) @@ -47,7 +50,7 @@ class Storage: if name is None: name = content.name - if not hasattr(content, 'chunks'): + if not hasattr(content, "chunks"): content = File(content, name) name = self.get_available_name(name, max_length=max_length) @@ -71,17 +74,19 @@ class Storage: character alphanumeric string (before the file extension, if one exists) to the filename. """ - return '%s_%s%s' % (file_root, get_random_string(7), file_ext) + return "%s_%s%s" % (file_root, get_random_string(7), file_ext) def get_available_name(self, name, max_length=None): """ Return a filename that's free on the target storage system and available for new content to be written to. """ - name = str(name).replace('\\', '/') + name = str(name).replace("\\", "/") dir_name, file_name = os.path.split(name) - if '..' in pathlib.PurePath(dir_name).parts: - raise SuspiciousFileOperation("Detected path traversal attempt in '%s'" % dir_name) + if ".." 
in pathlib.PurePath(dir_name).parts: + raise SuspiciousFileOperation( + "Detected path traversal attempt in '%s'" % dir_name + ) validate_file_name(file_name) file_root, file_ext = os.path.splitext(file_name) # If the filename already exists, generate an alternative filename @@ -90,7 +95,9 @@ class Storage: # exceed the max_length. while self.exists(name) or (max_length and len(name) > max_length): # file_ext includes the dot. - name = os.path.join(dir_name, self.get_alternative_name(file_root, file_ext)) + name = os.path.join( + dir_name, self.get_alternative_name(file_root, file_ext) + ) if max_length is None: continue # Truncate file_root if max_length exceeded. @@ -101,10 +108,12 @@ class Storage: if not file_root: raise SuspiciousFileOperation( 'Storage can not find an available filename for "%s". ' - 'Please make sure that the corresponding file field ' + "Please make sure that the corresponding file field " 'allows sufficient "max_length".' % name ) - name = os.path.join(dir_name, self.get_alternative_name(file_root, file_ext)) + name = os.path.join( + dir_name, self.get_alternative_name(file_root, file_ext) + ) return name def generate_filename(self, filename): @@ -112,11 +121,13 @@ class Storage: Validate the filename by calling get_valid_name() and return a filename to be passed to the save() method. """ - filename = str(filename).replace('\\', '/') + filename = str(filename).replace("\\", "/") # `filename` may include a path as returned by FileField.upload_to. dirname, filename = os.path.split(filename) - if '..' in pathlib.PurePath(dirname).parts: - raise SuspiciousFileOperation("Detected path traversal attempt in '%s'" % dirname) + if ".." in pathlib.PurePath(dirname).parts: + raise SuspiciousFileOperation( + "Detected path traversal attempt in '%s'" % dirname + ) return os.path.normpath(os.path.join(dirname, self.get_valid_name(filename))) def path(self, name): @@ -134,55 +145,67 @@ class Storage: """ Delete the specified file from the storage system. 
""" - raise NotImplementedError('subclasses of Storage must provide a delete() method') + raise NotImplementedError( + "subclasses of Storage must provide a delete() method" + ) def exists(self, name): """ Return True if a file referenced by the given name already exists in the storage system, or False if the name is available for a new file. """ - raise NotImplementedError('subclasses of Storage must provide an exists() method') + raise NotImplementedError( + "subclasses of Storage must provide an exists() method" + ) def listdir(self, path): """ List the contents of the specified path. Return a 2-tuple of lists: the first item being directories, the second item being files. """ - raise NotImplementedError('subclasses of Storage must provide a listdir() method') + raise NotImplementedError( + "subclasses of Storage must provide a listdir() method" + ) def size(self, name): """ Return the total size, in bytes, of the file specified by name. """ - raise NotImplementedError('subclasses of Storage must provide a size() method') + raise NotImplementedError("subclasses of Storage must provide a size() method") def url(self, name): """ Return an absolute URL where the file's contents can be accessed directly by a web browser. """ - raise NotImplementedError('subclasses of Storage must provide a url() method') + raise NotImplementedError("subclasses of Storage must provide a url() method") def get_accessed_time(self, name): """ Return the last accessed time (as a datetime) of the file specified by name. The datetime will be timezone-aware if USE_TZ=True. """ - raise NotImplementedError('subclasses of Storage must provide a get_accessed_time() method') + raise NotImplementedError( + "subclasses of Storage must provide a get_accessed_time() method" + ) def get_created_time(self, name): """ Return the creation time (as a datetime) of the file specified by name. The datetime will be timezone-aware if USE_TZ=True. 
""" - raise NotImplementedError('subclasses of Storage must provide a get_created_time() method') + raise NotImplementedError( + "subclasses of Storage must provide a get_created_time() method" + ) def get_modified_time(self, name): """ Return the last modified time (as a datetime) of the file specified by name. The datetime will be timezone-aware if USE_TZ=True. """ - raise NotImplementedError('subclasses of Storage must provide a get_modified_time() method') + raise NotImplementedError( + "subclasses of Storage must provide a get_modified_time() method" + ) @deconstructible @@ -190,12 +213,18 @@ class FileSystemStorage(Storage): """ Standard filesystem storage """ + # The combination of O_CREAT and O_EXCL makes os.open() raise OSError if # the file already exists before it's opened. - OS_OPEN_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, 'O_BINARY', 0) + OS_OPEN_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0) - def __init__(self, location=None, base_url=None, file_permissions_mode=None, - directory_permissions_mode=None): + def __init__( + self, + location=None, + base_url=None, + file_permissions_mode=None, + directory_permissions_mode=None, + ): self._location = location self._base_url = base_url self._file_permissions_mode = file_permissions_mode @@ -204,15 +233,15 @@ class FileSystemStorage(Storage): def _clear_cached_properties(self, setting, **kwargs): """Reset setting based property values.""" - if setting == 'MEDIA_ROOT': - self.__dict__.pop('base_location', None) - self.__dict__.pop('location', None) - elif setting == 'MEDIA_URL': - self.__dict__.pop('base_url', None) - elif setting == 'FILE_UPLOAD_PERMISSIONS': - self.__dict__.pop('file_permissions_mode', None) - elif setting == 'FILE_UPLOAD_DIRECTORY_PERMISSIONS': - self.__dict__.pop('directory_permissions_mode', None) + if setting == "MEDIA_ROOT": + self.__dict__.pop("base_location", None) + self.__dict__.pop("location", None) + elif setting == "MEDIA_URL": + 
self.__dict__.pop("base_url", None) + elif setting == "FILE_UPLOAD_PERMISSIONS": + self.__dict__.pop("file_permissions_mode", None) + elif setting == "FILE_UPLOAD_DIRECTORY_PERMISSIONS": + self.__dict__.pop("directory_permissions_mode", None) def _value_or_setting(self, value, setting): return setting if value is None else value @@ -227,19 +256,23 @@ class FileSystemStorage(Storage): @cached_property def base_url(self): - if self._base_url is not None and not self._base_url.endswith('/'): - self._base_url += '/' + if self._base_url is not None and not self._base_url.endswith("/"): + self._base_url += "/" return self._value_or_setting(self._base_url, settings.MEDIA_URL) @cached_property def file_permissions_mode(self): - return self._value_or_setting(self._file_permissions_mode, settings.FILE_UPLOAD_PERMISSIONS) + return self._value_or_setting( + self._file_permissions_mode, settings.FILE_UPLOAD_PERMISSIONS + ) @cached_property def directory_permissions_mode(self): - return self._value_or_setting(self._directory_permissions_mode, settings.FILE_UPLOAD_DIRECTORY_PERMISSIONS) + return self._value_or_setting( + self._directory_permissions_mode, settings.FILE_UPLOAD_DIRECTORY_PERMISSIONS + ) - def _open(self, name, mode='rb'): + def _open(self, name, mode="rb"): return File(open(self.path(name), mode)) def _save(self, name, content): @@ -253,13 +286,15 @@ class FileSystemStorage(Storage): # argument to intermediate-level directories. old_umask = os.umask(0o777 & ~self.directory_permissions_mode) try: - os.makedirs(directory, self.directory_permissions_mode, exist_ok=True) + os.makedirs( + directory, self.directory_permissions_mode, exist_ok=True + ) finally: os.umask(old_umask) else: os.makedirs(directory, exist_ok=True) except FileExistsError: - raise FileExistsError('%s exists and is not a directory.' % directory) + raise FileExistsError("%s exists and is not a directory." 
% directory) # There's a potential race condition between get_available_name and # saving the file; it's possible that two threads might return the @@ -270,7 +305,7 @@ class FileSystemStorage(Storage): while True: try: # This file has a file path that we can move. - if hasattr(content, 'temporary_file_path'): + if hasattr(content, "temporary_file_path"): file_move_safe(content.temporary_file_path(), full_path) # This is a normal uploadedfile that we can stream. @@ -282,7 +317,7 @@ class FileSystemStorage(Storage): locks.lock(fd, locks.LOCK_EX) for chunk in content.chunks(): if _file is None: - mode = 'wb' if isinstance(chunk, bytes) else 'wt' + mode = "wb" if isinstance(chunk, bytes) else "wt" _file = os.fdopen(fd, mode) _file.write(chunk) finally: @@ -305,11 +340,11 @@ class FileSystemStorage(Storage): # Ensure the saved path is always relative to the storage root. name = os.path.relpath(full_path, self.location) # Store filenames with forward slashes, even on Windows. - return str(name).replace('\\', '/') + return str(name).replace("\\", "/") def delete(self, name): if not name: - raise ValueError('The name must be given to delete().') + raise ValueError("The name must be given to delete().") name = self.path(name) # If the file or directory exists, delete it from the filesystem. 
try: @@ -347,7 +382,7 @@ class FileSystemStorage(Storage): raise ValueError("This file is not accessible via a URL.") url = filepath_to_uri(name) if url is not None: - url = url.lstrip('/') + url = url.lstrip("/") return urljoin(self.base_url, url) def _datetime_from_timestamp(self, ts): diff --git a/django/core/files/temp.py b/django/core/files/temp.py index 57a8107b37..5bd31dd5f2 100644 --- a/django/core/files/temp.py +++ b/django/core/files/temp.py @@ -21,10 +21,14 @@ import tempfile from django.core.files.utils import FileProxyMixin -__all__ = ('NamedTemporaryFile', 'gettempdir',) +__all__ = ( + "NamedTemporaryFile", + "gettempdir", +) -if os.name == 'nt': +if os.name == "nt": + class TemporaryFile(FileProxyMixin): """ Temporary file object constructor that supports reopening of the @@ -34,7 +38,8 @@ if os.name == 'nt': __init__() doesn't support the 'delete', 'buffering', 'encoding', or 'newline' keyword arguments. """ - def __init__(self, mode='w+b', bufsize=-1, suffix='', prefix='', dir=None): + + def __init__(self, mode="w+b", bufsize=-1, suffix="", prefix="", dir=None): fd, name = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=dir) self.name = name self.file = os.fdopen(fd, mode, bufsize) diff --git a/django/core/files/uploadedfile.py b/django/core/files/uploadedfile.py index f452bcd9a4..efbfcac4c8 100644 --- a/django/core/files/uploadedfile.py +++ b/django/core/files/uploadedfile.py @@ -10,8 +10,12 @@ from django.core.files import temp as tempfile from django.core.files.base import File from django.core.files.utils import validate_file_name -__all__ = ('UploadedFile', 'TemporaryUploadedFile', 'InMemoryUploadedFile', - 'SimpleUploadedFile') +__all__ = ( + "UploadedFile", + "TemporaryUploadedFile", + "InMemoryUploadedFile", + "SimpleUploadedFile", +) class UploadedFile(File): @@ -23,7 +27,15 @@ class UploadedFile(File): represents some file data that the user submitted with a form. 
""" - def __init__(self, file=None, name=None, content_type=None, size=None, charset=None, content_type_extra=None): + def __init__( + self, + file=None, + name=None, + content_type=None, + size=None, + charset=None, + content_type_extra=None, + ): super().__init__(file, name) self.size = size self.content_type = content_type @@ -46,7 +58,7 @@ class UploadedFile(File): if len(name) > 255: name, ext = os.path.splitext(name) ext = ext[:255] - name = name[:255 - len(ext)] + ext + name = name[: 255 - len(ext)] + ext name = validate_file_name(name) @@ -59,9 +71,12 @@ class TemporaryUploadedFile(UploadedFile): """ A file uploaded to a temporary location (i.e. stream-to-disk). """ + def __init__(self, name, content_type, size, charset, content_type_extra=None): _, ext = os.path.splitext(name) - file = tempfile.NamedTemporaryFile(suffix='.upload' + ext, dir=settings.FILE_UPLOAD_TEMP_DIR) + file = tempfile.NamedTemporaryFile( + suffix=".upload" + ext, dir=settings.FILE_UPLOAD_TEMP_DIR + ) super().__init__(file, name, content_type, size, charset, content_type_extra) def temporary_file_path(self): @@ -82,7 +97,17 @@ class InMemoryUploadedFile(UploadedFile): """ A file uploaded into memory (i.e. stream-to-memory). """ - def __init__(self, file, field_name, name, content_type, size, charset, content_type_extra=None): + + def __init__( + self, + file, + field_name, + name, + content_type, + size, + charset, + content_type_extra=None, + ): super().__init__(file, name, content_type, size, charset, content_type_extra) self.field_name = field_name @@ -103,9 +128,12 @@ class SimpleUploadedFile(InMemoryUploadedFile): """ A simple representation of a file, which just has content, size, and a name. 
""" - def __init__(self, name, content, content_type='text/plain'): - content = content or b'' - super().__init__(BytesIO(content), None, name, content_type, len(content), None, None) + + def __init__(self, name, content, content_type="text/plain"): + content = content or b"" + super().__init__( + BytesIO(content), None, name, content_type, len(content), None, None + ) @classmethod def from_dict(cls, file_dict): @@ -115,6 +143,8 @@ class SimpleUploadedFile(InMemoryUploadedFile): - content-type - content """ - return cls(file_dict['filename'], - file_dict['content'], - file_dict.get('content-type', 'text/plain')) + return cls( + file_dict["filename"], + file_dict["content"], + file_dict.get("content-type", "text/plain"), + ) diff --git a/django/core/files/uploadhandler.py b/django/core/files/uploadhandler.py index ee6bb31fce..64781f811b 100644 --- a/django/core/files/uploadhandler.py +++ b/django/core/files/uploadhandler.py @@ -5,15 +5,18 @@ import os from io import BytesIO from django.conf import settings -from django.core.files.uploadedfile import ( - InMemoryUploadedFile, TemporaryUploadedFile, -) +from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUploadedFile from django.utils.module_loading import import_string __all__ = [ - 'UploadFileException', 'StopUpload', 'SkipFile', 'FileUploadHandler', - 'TemporaryFileUploadHandler', 'MemoryFileUploadHandler', 'load_handler', - 'StopFutureHandlers' + "UploadFileException", + "StopUpload", + "SkipFile", + "FileUploadHandler", + "TemporaryFileUploadHandler", + "MemoryFileUploadHandler", + "load_handler", + "StopFutureHandlers", ] @@ -21,6 +24,7 @@ class UploadFileException(Exception): """ Any error having to do with uploading files. """ + pass @@ -28,6 +32,7 @@ class StopUpload(UploadFileException): """ This exception is raised when an upload must abort. 
""" + def __init__(self, connection_reset=False): """ If ``connection_reset`` is ``True``, Django knows will halt the upload @@ -38,15 +43,16 @@ class StopUpload(UploadFileException): def __str__(self): if self.connection_reset: - return 'StopUpload: Halt current upload.' + return "StopUpload: Halt current upload." else: - return 'StopUpload: Consume request data, then halt.' + return "StopUpload: Consume request data, then halt." class SkipFile(UploadFileException): """ This exception is raised by an upload handler that wants to skip a given file. """ + pass @@ -55,6 +61,7 @@ class StopFutureHandlers(UploadFileException): Upload handlers that have handled a file and do not want future handlers to run should raise this exception instead of returning None. """ + pass @@ -62,7 +69,8 @@ class FileUploadHandler: """ Base class for streaming upload handlers. """ - chunk_size = 64 * 2 ** 10 # : The default chunk size is 64 KB. + + chunk_size = 64 * 2**10 # : The default chunk size is 64 KB. def __init__(self, request=None): self.file_name = None @@ -72,7 +80,9 @@ class FileUploadHandler: self.content_type_extra = None self.request = request - def handle_raw_input(self, input_data, META, content_length, boundary, encoding=None): + def handle_raw_input( + self, input_data, META, content_length, boundary, encoding=None + ): """ Handle the raw input from the client. @@ -90,7 +100,15 @@ class FileUploadHandler: """ pass - def new_file(self, field_name, file_name, content_type, content_length, charset=None, content_type_extra=None): + def new_file( + self, + field_name, + file_name, + content_type, + content_length, + charset=None, + content_type_extra=None, + ): """ Signal that a new file has been started. @@ -109,7 +127,9 @@ class FileUploadHandler: Receive data from the streamed upload parser. ``start`` is the position in the file of the chunk. 
""" - raise NotImplementedError('subclasses of FileUploadHandler must provide a receive_data_chunk() method') + raise NotImplementedError( + "subclasses of FileUploadHandler must provide a receive_data_chunk() method" + ) def file_complete(self, file_size): """ @@ -118,7 +138,9 @@ class FileUploadHandler: Subclasses should return a valid ``UploadedFile`` object. """ - raise NotImplementedError('subclasses of FileUploadHandler must provide a file_complete() method') + raise NotImplementedError( + "subclasses of FileUploadHandler must provide a file_complete() method" + ) def upload_complete(self): """ @@ -139,12 +161,15 @@ class TemporaryFileUploadHandler(FileUploadHandler): """ Upload handler that streams data into a temporary file. """ + def new_file(self, *args, **kwargs): """ Create the file object to append to as data is coming in. """ super().new_file(*args, **kwargs) - self.file = TemporaryUploadedFile(self.file_name, self.content_type, 0, self.charset, self.content_type_extra) + self.file = TemporaryUploadedFile( + self.file_name, self.content_type, 0, self.charset, self.content_type_extra + ) def receive_data_chunk(self, raw_data, start): self.file.write(raw_data) @@ -155,7 +180,7 @@ class TemporaryFileUploadHandler(FileUploadHandler): return self.file def upload_interrupted(self): - if hasattr(self, 'file'): + if hasattr(self, "file"): temp_location = self.file.temporary_file_path() try: self.file.close() @@ -169,7 +194,9 @@ class MemoryFileUploadHandler(FileUploadHandler): File upload handler to stream uploads into memory (used for small files). """ - def handle_raw_input(self, input_data, META, content_length, boundary, encoding=None): + def handle_raw_input( + self, input_data, META, content_length, boundary, encoding=None + ): """ Use the content_length to signal whether or not this handler should be used. 
@@ -204,7 +231,7 @@ class MemoryFileUploadHandler(FileUploadHandler): content_type=self.content_type, size=file_size, charset=self.charset, - content_type_extra=self.content_type_extra + content_type_extra=self.content_type_extra, ) diff --git a/django/core/files/utils.py b/django/core/files/utils.py index f28cea1077..85342b2f3f 100644 --- a/django/core/files/utils.py +++ b/django/core/files/utils.py @@ -6,7 +6,7 @@ from django.core.exceptions import SuspiciousFileOperation def validate_file_name(name, allow_relative_path=False): # Remove potentially dangerous names - if os.path.basename(name) in {'', '.', '..'}: + if os.path.basename(name) in {"", ".", ".."}: raise SuspiciousFileOperation("Could not derive file name from '%s'" % name) if allow_relative_path: @@ -14,7 +14,7 @@ def validate_file_name(name, allow_relative_path=False): # FileField.generate_filename() where all file paths are expected to be # Unix style (with forward slashes). path = pathlib.PurePosixPath(name) - if path.is_absolute() or '..' in path.parts: + if path.is_absolute() or ".." 
in path.parts: raise SuspiciousFileOperation( "Detected path traversal attempt in '%s'" % name ) @@ -56,21 +56,21 @@ class FileProxyMixin: def readable(self): if self.closed: return False - if hasattr(self.file, 'readable'): + if hasattr(self.file, "readable"): return self.file.readable() return True def writable(self): if self.closed: return False - if hasattr(self.file, 'writable'): + if hasattr(self.file, "writable"): return self.file.writable() - return 'w' in getattr(self.file, 'mode', '') + return "w" in getattr(self.file, "mode", "") def seekable(self): if self.closed: return False - if hasattr(self.file, 'seekable'): + if hasattr(self.file, "seekable"): return self.file.seekable() return True diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 9d84efc964..7b17c58153 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -10,13 +10,18 @@ from django.core import signals from django.core.exceptions import RequestAborted, RequestDataTooBig from django.core.handlers import base from django.http import ( - FileResponse, HttpRequest, HttpResponse, HttpResponseBadRequest, - HttpResponseServerError, QueryDict, parse_cookie, + FileResponse, + HttpRequest, + HttpResponse, + HttpResponseBadRequest, + HttpResponseServerError, + QueryDict, + parse_cookie, ) from django.urls import set_script_prefix from django.utils.functional import cached_property -logger = logging.getLogger('django.request') +logger = logging.getLogger("django.request") class ASGIRequest(HttpRequest): @@ -24,6 +29,7 @@ class ASGIRequest(HttpRequest): Custom request subclass that decodes from an ASGI-standard request dict and wraps request body handling. """ + # Number of seconds until a Request gives up on trying to read a request # body and aborts. 
body_receive_timeout = 60 @@ -33,60 +39,60 @@ class ASGIRequest(HttpRequest): self._post_parse_error = False self._read_started = False self.resolver_match = None - self.script_name = self.scope.get('root_path', '') - if self.script_name and scope['path'].startswith(self.script_name): + self.script_name = self.scope.get("root_path", "") + if self.script_name and scope["path"].startswith(self.script_name): # TODO: Better is-prefix checking, slash handling? - self.path_info = scope['path'][len(self.script_name):] + self.path_info = scope["path"][len(self.script_name) :] else: - self.path_info = scope['path'] + self.path_info = scope["path"] # The Django path is different from ASGI scope path args, it should # combine with script name. if self.script_name: - self.path = '%s/%s' % ( - self.script_name.rstrip('/'), - self.path_info.replace('/', '', 1), + self.path = "%s/%s" % ( + self.script_name.rstrip("/"), + self.path_info.replace("/", "", 1), ) else: - self.path = scope['path'] + self.path = scope["path"] # HTTP basics. - self.method = self.scope['method'].upper() + self.method = self.scope["method"].upper() # Ensure query string is encoded correctly. 
- query_string = self.scope.get('query_string', '') + query_string = self.scope.get("query_string", "") if isinstance(query_string, bytes): query_string = query_string.decode() self.META = { - 'REQUEST_METHOD': self.method, - 'QUERY_STRING': query_string, - 'SCRIPT_NAME': self.script_name, - 'PATH_INFO': self.path_info, + "REQUEST_METHOD": self.method, + "QUERY_STRING": query_string, + "SCRIPT_NAME": self.script_name, + "PATH_INFO": self.path_info, # WSGI-expecting code will need these for a while - 'wsgi.multithread': True, - 'wsgi.multiprocess': True, + "wsgi.multithread": True, + "wsgi.multiprocess": True, } - if self.scope.get('client'): - self.META['REMOTE_ADDR'] = self.scope['client'][0] - self.META['REMOTE_HOST'] = self.META['REMOTE_ADDR'] - self.META['REMOTE_PORT'] = self.scope['client'][1] - if self.scope.get('server'): - self.META['SERVER_NAME'] = self.scope['server'][0] - self.META['SERVER_PORT'] = str(self.scope['server'][1]) + if self.scope.get("client"): + self.META["REMOTE_ADDR"] = self.scope["client"][0] + self.META["REMOTE_HOST"] = self.META["REMOTE_ADDR"] + self.META["REMOTE_PORT"] = self.scope["client"][1] + if self.scope.get("server"): + self.META["SERVER_NAME"] = self.scope["server"][0] + self.META["SERVER_PORT"] = str(self.scope["server"][1]) else: - self.META['SERVER_NAME'] = 'unknown' - self.META['SERVER_PORT'] = '0' + self.META["SERVER_NAME"] = "unknown" + self.META["SERVER_PORT"] = "0" # Headers go into META. 
- for name, value in self.scope.get('headers', []): - name = name.decode('latin1') - if name == 'content-length': - corrected_name = 'CONTENT_LENGTH' - elif name == 'content-type': - corrected_name = 'CONTENT_TYPE' + for name, value in self.scope.get("headers", []): + name = name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" else: - corrected_name = 'HTTP_%s' % name.upper().replace('-', '_') + corrected_name = "HTTP_%s" % name.upper().replace("-", "_") # HTTP/2 say only ASCII chars are allowed in headers, but decode # latin1 just in case. - value = value.decode('latin1') + value = value.decode("latin1") if corrected_name in self.META: - value = self.META[corrected_name] + ',' + value + value = self.META[corrected_name] + "," + value self.META[corrected_name] = value # Pull out request encoding, if provided. self._set_content_type_params(self.META) @@ -97,13 +103,13 @@ class ASGIRequest(HttpRequest): @cached_property def GET(self): - return QueryDict(self.META['QUERY_STRING']) + return QueryDict(self.META["QUERY_STRING"]) def _get_scheme(self): - return self.scope.get('scheme') or super()._get_scheme() + return self.scope.get("scheme") or super()._get_scheme() def _get_post(self): - if not hasattr(self, '_post'): + if not hasattr(self, "_post"): self._load_post_and_files() return self._post @@ -111,7 +117,7 @@ class ASGIRequest(HttpRequest): self._post = post def _get_files(self): - if not hasattr(self, '_files'): + if not hasattr(self, "_files"): self._load_post_and_files() return self._files @@ -120,14 +126,15 @@ class ASGIRequest(HttpRequest): @cached_property def COOKIES(self): - return parse_cookie(self.META.get('HTTP_COOKIE', '')) + return parse_cookie(self.META.get("HTTP_COOKIE", "")) class ASGIHandler(base.BaseHandler): """Handler for ASGI requests.""" + request_class = ASGIRequest # Size to chunk response bodies into for multiple response messages. 
- chunk_size = 2 ** 16 + chunk_size = 2**16 def __init__(self): super().__init__() @@ -139,10 +146,9 @@ class ASGIHandler(base.BaseHandler): """ # Serve only HTTP connections. # FIXME: Allow to override this. - if scope['type'] != 'http': + if scope["type"] != "http": raise ValueError( - 'Django can only handle ASGI/HTTP connections, not %s.' - % scope['type'] + "Django can only handle ASGI/HTTP connections, not %s." % scope["type"] ) async with ThreadSensitiveContext(): @@ -159,7 +165,9 @@ class ASGIHandler(base.BaseHandler): return # Request is complete and can be served. set_script_prefix(self.get_script_prefix(scope)) - await sync_to_async(signals.request_started.send, thread_sensitive=True)(sender=self.__class__, scope=scope) + await sync_to_async(signals.request_started.send, thread_sensitive=True)( + sender=self.__class__, scope=scope + ) # Get the request and check for basic issues. request, error_response = self.create_request(scope, body_file) if request is None: @@ -178,17 +186,19 @@ class ASGIHandler(base.BaseHandler): async def read_body(self, receive): """Reads an HTTP body from an ASGI connection.""" # Use the tempfile that auto rolls-over to a disk file as it fills up. - body_file = tempfile.SpooledTemporaryFile(max_size=settings.FILE_UPLOAD_MAX_MEMORY_SIZE, mode='w+b') + body_file = tempfile.SpooledTemporaryFile( + max_size=settings.FILE_UPLOAD_MAX_MEMORY_SIZE, mode="w+b" + ) while True: message = await receive() - if message['type'] == 'http.disconnect': + if message["type"] == "http.disconnect": # Early client disconnect. raise RequestAborted() # Add a body chunk from the message, if provided. - if 'body' in message: - body_file.write(message['body']) + if "body" in message: + body_file.write(message["body"]) # Quit out if that's the end. 
- if not message.get('more_body', False): + if not message.get("more_body", False): break body_file.seek(0) return body_file @@ -202,13 +212,13 @@ class ASGIHandler(base.BaseHandler): return self.request_class(scope, body_file), None except UnicodeDecodeError: logger.warning( - 'Bad Request (UnicodeDecodeError)', + "Bad Request (UnicodeDecodeError)", exc_info=sys.exc_info(), - extra={'status_code': 400}, + extra={"status_code": 400}, ) return None, HttpResponseBadRequest() except RequestDataTooBig: - return None, HttpResponse('413 Payload too large', status=413) + return None, HttpResponse("413 Payload too large", status=413) def handle_uncaught_exception(self, request, resolver, exc_info): """Last-chance handler for exceptions.""" @@ -218,8 +228,8 @@ class ASGIHandler(base.BaseHandler): return super().handle_uncaught_exception(request, resolver, exc_info) except Exception: return HttpResponseServerError( - traceback.format_exc() if settings.DEBUG else 'Internal Server Error', - content_type='text/plain', + traceback.format_exc() if settings.DEBUG else "Internal Server Error", + content_type="text/plain", ) async def send_response(self, response, send): @@ -229,44 +239,50 @@ class ASGIHandler(base.BaseHandler): response_headers = [] for header, value in response.items(): if isinstance(header, str): - header = header.encode('ascii') + header = header.encode("ascii") if isinstance(value, str): - value = value.encode('latin1') + value = value.encode("latin1") response_headers.append((bytes(header), bytes(value))) for c in response.cookies.values(): response_headers.append( - (b'Set-Cookie', c.output(header='').encode('ascii').strip()) + (b"Set-Cookie", c.output(header="").encode("ascii").strip()) ) # Initial response message. 
- await send({ - 'type': 'http.response.start', - 'status': response.status_code, - 'headers': response_headers, - }) + await send( + { + "type": "http.response.start", + "status": response.status_code, + "headers": response_headers, + } + ) # Streaming responses need to be pinned to their iterator. if response.streaming: # Access `__iter__` and not `streaming_content` directly in case # it has been overridden in a subclass. for part in response: for chunk, _ in self.chunk_bytes(part): - await send({ - 'type': 'http.response.body', - 'body': chunk, - # Ignore "more" as there may be more parts; instead, - # use an empty final closing message with False. - 'more_body': True, - }) + await send( + { + "type": "http.response.body", + "body": chunk, + # Ignore "more" as there may be more parts; instead, + # use an empty final closing message with False. + "more_body": True, + } + ) # Final closing message. - await send({'type': 'http.response.body'}) + await send({"type": "http.response.body"}) # Other responses just need chunking. else: # Yield chunks of response. 
for chunk, last in self.chunk_bytes(response.content): - await send({ - 'type': 'http.response.body', - 'body': chunk, - 'more_body': not last, - }) + await send( + { + "type": "http.response.body", + "body": chunk, + "more_body": not last, + } + ) await sync_to_async(response.close, thread_sensitive=True)() @classmethod @@ -281,7 +297,7 @@ class ASGIHandler(base.BaseHandler): return while position < len(data): yield ( - data[position:position + cls.chunk_size], + data[position : position + cls.chunk_size], (position + cls.chunk_size) >= len(data), ) position += cls.chunk_size @@ -292,4 +308,4 @@ class ASGIHandler(base.BaseHandler): """ if settings.FORCE_SCRIPT_NAME: return settings.FORCE_SCRIPT_NAME - return scope.get('root_path', '') or '' + return scope.get("root_path", "") or "" diff --git a/django/core/handlers/base.py b/django/core/handlers/base.py index 728e449703..7c863bb5c1 100644 --- a/django/core/handlers/base.py +++ b/django/core/handlers/base.py @@ -14,7 +14,7 @@ from django.utils.module_loading import import_string from .exception import convert_exception_to_response -logger = logging.getLogger('django.request') +logger = logging.getLogger("django.request") class BaseHandler: @@ -38,12 +38,12 @@ class BaseHandler: handler_is_async = is_async for middleware_path in reversed(settings.MIDDLEWARE): middleware = import_string(middleware_path) - middleware_can_sync = getattr(middleware, 'sync_capable', True) - middleware_can_async = getattr(middleware, 'async_capable', False) + middleware_can_sync = getattr(middleware, "sync_capable", True) + middleware_can_async = getattr(middleware, "async_capable", False) if not middleware_can_sync and not middleware_can_async: raise RuntimeError( - 'Middleware %s must have at least one of ' - 'sync_capable/async_capable set to True.' % middleware_path + "Middleware %s must have at least one of " + "sync_capable/async_capable set to True." 
% middleware_path ) elif not handler_is_async and middleware_can_sync: middleware_is_async = False @@ -52,35 +52,40 @@ class BaseHandler: try: # Adapt handler, if needed. adapted_handler = self.adapt_method_mode( - middleware_is_async, handler, handler_is_async, - debug=settings.DEBUG, name='middleware %s' % middleware_path, + middleware_is_async, + handler, + handler_is_async, + debug=settings.DEBUG, + name="middleware %s" % middleware_path, ) mw_instance = middleware(adapted_handler) except MiddlewareNotUsed as exc: if settings.DEBUG: if str(exc): - logger.debug('MiddlewareNotUsed(%r): %s', middleware_path, exc) + logger.debug("MiddlewareNotUsed(%r): %s", middleware_path, exc) else: - logger.debug('MiddlewareNotUsed: %r', middleware_path) + logger.debug("MiddlewareNotUsed: %r", middleware_path) continue else: handler = adapted_handler if mw_instance is None: raise ImproperlyConfigured( - 'Middleware factory %s returned None.' % middleware_path + "Middleware factory %s returned None." % middleware_path ) - if hasattr(mw_instance, 'process_view'): + if hasattr(mw_instance, "process_view"): self._view_middleware.insert( 0, self.adapt_method_mode(is_async, mw_instance.process_view), ) - if hasattr(mw_instance, 'process_template_response'): + if hasattr(mw_instance, "process_template_response"): self._template_response_middleware.append( - self.adapt_method_mode(is_async, mw_instance.process_template_response), + self.adapt_method_mode( + is_async, mw_instance.process_template_response + ), ) - if hasattr(mw_instance, 'process_exception'): + if hasattr(mw_instance, "process_exception"): # The exception-handling stack is still always synchronous for # now, so adapt that way. 
self._exception_middleware.append( @@ -97,7 +102,12 @@ class BaseHandler: self._middleware_chain = handler def adapt_method_mode( - self, is_async, method, method_is_async=None, debug=False, name=None, + self, + is_async, + method, + method_is_async=None, + debug=False, + name=None, ): """ Adapt a method to be in the correct "mode": @@ -111,15 +121,15 @@ class BaseHandler: if method_is_async is None: method_is_async = asyncio.iscoroutinefunction(method) if debug and not name: - name = name or 'method %s()' % method.__qualname__ + name = name or "method %s()" % method.__qualname__ if is_async: if not method_is_async: if debug: - logger.debug('Synchronous %s adapted.', name) + logger.debug("Synchronous %s adapted.", name) return sync_to_async(method, thread_sensitive=True) elif method_is_async: if debug: - logger.debug('Asynchronous %s adapted.', name) + logger.debug("Asynchronous %s adapted.", name) return async_to_sync(method) return method @@ -131,7 +141,9 @@ class BaseHandler: response._resource_closers.append(request.close) if response.status_code >= 400: log_response( - '%s: %s', response.reason_phrase, request.path, + "%s: %s", + response.reason_phrase, + request.path, response=response, request=request, ) @@ -151,7 +163,9 @@ class BaseHandler: response._resource_closers.append(request.close) if response.status_code >= 400: await sync_to_async(log_response, thread_sensitive=False)( - '%s: %s', response.reason_phrase, request.path, + "%s: %s", + response.reason_phrase, + request.path, response=response, request=request, ) @@ -168,7 +182,9 @@ class BaseHandler: # Apply view middleware for middleware_method in self._view_middleware: - response = middleware_method(request, callback, callback_args, callback_kwargs) + response = middleware_method( + request, callback, callback_args, callback_kwargs + ) if response: break @@ -189,16 +205,15 @@ class BaseHandler: # If the response supports deferred rendering, apply template # response middleware and then render the 
response - if hasattr(response, 'render') and callable(response.render): + if hasattr(response, "render") and callable(response.render): for middleware_method in self._template_response_middleware: response = middleware_method(request, response) # Complain if the template response middleware returned None (a common error). self.check_response( response, middleware_method, - name='%s.process_template_response' % ( - middleware_method.__self__.__class__.__name__, - ) + name="%s.process_template_response" + % (middleware_method.__self__.__class__.__name__,), ) try: response = response.render() @@ -220,7 +235,9 @@ class BaseHandler: # Apply view middleware. for middleware_method in self._view_middleware: - response = await middleware_method(request, callback, callback_args, callback_kwargs) + response = await middleware_method( + request, callback, callback_args, callback_kwargs + ) if response: break @@ -228,9 +245,13 @@ class BaseHandler: wrapped_callback = self.make_view_atomic(callback) # If it is a synchronous view, run it in a subthread if not asyncio.iscoroutinefunction(wrapped_callback): - wrapped_callback = sync_to_async(wrapped_callback, thread_sensitive=True) + wrapped_callback = sync_to_async( + wrapped_callback, thread_sensitive=True + ) try: - response = await wrapped_callback(request, *callback_args, **callback_kwargs) + response = await wrapped_callback( + request, *callback_args, **callback_kwargs + ) except Exception as e: response = await sync_to_async( self.process_exception_by_middleware, @@ -244,7 +265,7 @@ class BaseHandler: # If the response supports deferred rendering, apply template # response middleware and then render the response - if hasattr(response, 'render') and callable(response.render): + if hasattr(response, "render") and callable(response.render): for middleware_method in self._template_response_middleware: response = await middleware_method(request, response) # Complain if the template response middleware returned None or @@ 
-252,15 +273,16 @@ class BaseHandler: self.check_response( response, middleware_method, - name='%s.process_template_response' % ( - middleware_method.__self__.__class__.__name__, - ) + name="%s.process_template_response" + % (middleware_method.__self__.__class__.__name__,), ) try: if asyncio.iscoroutinefunction(response.render): response = await response.render() else: - response = await sync_to_async(response.render, thread_sensitive=True)() + response = await sync_to_async( + response.render, thread_sensitive=True + )() except Exception as e: response = await sync_to_async( self.process_exception_by_middleware, @@ -271,7 +293,7 @@ class BaseHandler: # Make sure the response is not a coroutine if asyncio.iscoroutine(response): - raise RuntimeError('Response is still a coroutine.') + raise RuntimeError("Response is still a coroutine.") return response def resolve_request(self, request): @@ -280,7 +302,7 @@ class BaseHandler: with its args and kwargs. """ # Work out the resolver. - if hasattr(request, 'urlconf'): + if hasattr(request, "urlconf"): urlconf = request.urlconf set_urlconf(urlconf) resolver = get_resolver(urlconf) @@ -295,13 +317,13 @@ class BaseHandler: """ Raise an error if the view returned None or an uncalled coroutine. """ - if not(response is None or asyncio.iscoroutine(response)): + if not (response is None or asyncio.iscoroutine(response)): return if not name: if isinstance(callback, types.FunctionType): # FBV - name = 'The view %s.%s' % (callback.__module__, callback.__name__) + name = "The view %s.%s" % (callback.__module__, callback.__name__) else: # CBV - name = 'The view %s.%s.__call__' % ( + name = "The view %s.%s.__call__" % ( callback.__module__, callback.__class__.__name__, ) @@ -320,12 +342,15 @@ class BaseHandler: # Other utility methods. 
def make_view_atomic(self, view): - non_atomic_requests = getattr(view, '_non_atomic_requests', set()) + non_atomic_requests = getattr(view, "_non_atomic_requests", set()) for db in connections.all(): - if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: + if ( + db.settings_dict["ATOMIC_REQUESTS"] + and db.alias not in non_atomic_requests + ): if asyncio.iscoroutinefunction(view): raise RuntimeError( - 'You cannot use ATOMIC_REQUESTS with async views.' + "You cannot use ATOMIC_REQUESTS with async views." ) view = transaction.atomic(using=db.alias)(view) return view diff --git a/django/core/handlers/exception.py b/django/core/handlers/exception.py index 5470b3dd53..79577c2d0a 100644 --- a/django/core/handlers/exception.py +++ b/django/core/handlers/exception.py @@ -8,7 +8,10 @@ from asgiref.sync import sync_to_async from django.conf import settings from django.core import signals from django.core.exceptions import ( - BadRequest, PermissionDenied, RequestDataTooBig, SuspiciousOperation, + BadRequest, + PermissionDenied, + RequestDataTooBig, + SuspiciousOperation, TooManyFieldsSent, ) from django.http import Http404 @@ -32,15 +35,20 @@ def convert_exception_to_response(get_response): can rely on getting a response instead of an exception. 
""" if asyncio.iscoroutinefunction(get_response): + @wraps(get_response) async def inner(request): try: response = await get_response(request) except Exception as exc: - response = await sync_to_async(response_for_exception, thread_sensitive=False)(request, exc) + response = await sync_to_async( + response_for_exception, thread_sensitive=False + )(request, exc) return response + return inner else: + @wraps(get_response) def inner(request): try: @@ -48,6 +56,7 @@ def convert_exception_to_response(get_response): except Exception as exc: response = response_for_exception(request, exc) return response + return inner @@ -56,21 +65,29 @@ def response_for_exception(request, exc): if settings.DEBUG: response = debug.technical_404_response(request, exc) else: - response = get_exception_response(request, get_resolver(get_urlconf()), 404, exc) + response = get_exception_response( + request, get_resolver(get_urlconf()), 404, exc + ) elif isinstance(exc, PermissionDenied): - response = get_exception_response(request, get_resolver(get_urlconf()), 403, exc) + response = get_exception_response( + request, get_resolver(get_urlconf()), 403, exc + ) log_response( - 'Forbidden (Permission denied): %s', request.path, + "Forbidden (Permission denied): %s", + request.path, response=response, request=request, exception=exc, ) elif isinstance(exc, MultiPartParserError): - response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc) + response = get_exception_response( + request, get_resolver(get_urlconf()), 400, exc + ) log_response( - 'Bad request (Unable to parse request body): %s', request.path, + "Bad request (Unable to parse request body): %s", + request.path, response=response, request=request, exception=exc, @@ -78,11 +95,17 @@ def response_for_exception(request, exc): elif isinstance(exc, BadRequest): if settings.DEBUG: - response = debug.technical_500_response(request, *sys.exc_info(), status_code=400) + response = debug.technical_500_response( + request, 
*sys.exc_info(), status_code=400 + ) else: - response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc) + response = get_exception_response( + request, get_resolver(get_urlconf()), 400, exc + ) log_response( - '%s: %s', str(exc), request.path, + "%s: %s", + str(exc), + request.path, response=response, request=request, exception=exc, @@ -95,29 +118,41 @@ def response_for_exception(request, exc): # The request logger receives events for any problematic request # The security logger receives events for all SuspiciousOperations - security_logger = logging.getLogger('django.security.%s' % exc.__class__.__name__) + security_logger = logging.getLogger( + "django.security.%s" % exc.__class__.__name__ + ) security_logger.error( str(exc), exc_info=exc, - extra={'status_code': 400, 'request': request}, + extra={"status_code": 400, "request": request}, ) if settings.DEBUG: - response = debug.technical_500_response(request, *sys.exc_info(), status_code=400) + response = debug.technical_500_response( + request, *sys.exc_info(), status_code=400 + ) else: - response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc) + response = get_exception_response( + request, get_resolver(get_urlconf()), 400, exc + ) else: signals.got_request_exception.send(sender=None, request=request) - response = handle_uncaught_exception(request, get_resolver(get_urlconf()), sys.exc_info()) + response = handle_uncaught_exception( + request, get_resolver(get_urlconf()), sys.exc_info() + ) log_response( - '%s: %s', response.reason_phrase, request.path, + "%s: %s", + response.reason_phrase, + request.path, response=response, request=request, exception=exc, ) # Force a TemplateResponse to be rendered. 
- if not getattr(response, 'is_rendered', True) and callable(getattr(response, 'render', None)): + if not getattr(response, "is_rendered", True) and callable( + getattr(response, "render", None) + ): response = response.render() return response diff --git a/django/core/handlers/wsgi.py b/django/core/handlers/wsgi.py index 30920da6d7..126e795fab 100644 --- a/django/core/handlers/wsgi.py +++ b/django/core/handlers/wsgi.py @@ -9,21 +9,22 @@ from django.utils.encoding import repercent_broken_unicode from django.utils.functional import cached_property from django.utils.regex_helper import _lazy_re_compile -_slashes_re = _lazy_re_compile(br'/+') +_slashes_re = _lazy_re_compile(rb"/+") class LimitedStream: """Wrap another stream to disallow reading it past a number of bytes.""" + def __init__(self, stream, limit): self.stream = stream self.remaining = limit - self.buffer = b'' + self.buffer = b"" def _read_limited(self, size=None): if size is None or size > self.remaining: size = self.remaining if size == 0: - return b'' + return b"" result = self.stream.read(size) self.remaining -= len(result) return result @@ -31,18 +32,17 @@ class LimitedStream: def read(self, size=None): if size is None: result = self.buffer + self._read_limited() - self.buffer = b'' + self.buffer = b"" elif size < len(self.buffer): result = self.buffer[:size] self.buffer = self.buffer[size:] else: # size >= len(self.buffer) result = self.buffer + self._read_limited(size - len(self.buffer)) - self.buffer = b'' + self.buffer = b"" return result def readline(self, size=None): - while b'\n' not in self.buffer and \ - (size is None or len(self.buffer) < size): + while b"\n" not in self.buffer and (size is None or len(self.buffer) < size): if size: # since size is not None here, len(self.buffer) < size chunk = self._read_limited(size - len(self.buffer)) @@ -65,39 +65,38 @@ class WSGIRequest(HttpRequest): script_name = get_script_name(environ) # If PATH_INFO is empty (e.g. 
accessing the SCRIPT_NAME URL without a # trailing slash), operate as if '/' was requested. - path_info = get_path_info(environ) or '/' + path_info = get_path_info(environ) or "/" self.environ = environ self.path_info = path_info # be careful to only replace the first slash in the path because of # http://test/something and http://test//something being different as # stated in https://www.ietf.org/rfc/rfc2396.txt - self.path = '%s/%s' % (script_name.rstrip('/'), - path_info.replace('/', '', 1)) + self.path = "%s/%s" % (script_name.rstrip("/"), path_info.replace("/", "", 1)) self.META = environ - self.META['PATH_INFO'] = path_info - self.META['SCRIPT_NAME'] = script_name - self.method = environ['REQUEST_METHOD'].upper() + self.META["PATH_INFO"] = path_info + self.META["SCRIPT_NAME"] = script_name + self.method = environ["REQUEST_METHOD"].upper() # Set content_type, content_params, and encoding. self._set_content_type_params(environ) try: - content_length = int(environ.get('CONTENT_LENGTH')) + content_length = int(environ.get("CONTENT_LENGTH")) except (ValueError, TypeError): content_length = 0 - self._stream = LimitedStream(self.environ['wsgi.input'], content_length) + self._stream = LimitedStream(self.environ["wsgi.input"], content_length) self._read_started = False self.resolver_match = None def _get_scheme(self): - return self.environ.get('wsgi.url_scheme') + return self.environ.get("wsgi.url_scheme") @cached_property def GET(self): # The WSGI spec says 'QUERY_STRING' may be absent. 
- raw_query_string = get_bytes_from_wsgi(self.environ, 'QUERY_STRING', '') + raw_query_string = get_bytes_from_wsgi(self.environ, "QUERY_STRING", "") return QueryDict(raw_query_string, encoding=self._encoding) def _get_post(self): - if not hasattr(self, '_post'): + if not hasattr(self, "_post"): self._load_post_and_files() return self._post @@ -106,12 +105,12 @@ class WSGIRequest(HttpRequest): @cached_property def COOKIES(self): - raw_cookie = get_str_from_wsgi(self.environ, 'HTTP_COOKIE', '') + raw_cookie = get_str_from_wsgi(self.environ, "HTTP_COOKIE", "") return parse_cookie(raw_cookie) @property def FILES(self): - if not hasattr(self, '_files'): + if not hasattr(self, "_files"): self._load_post_and_files() return self._files @@ -133,24 +132,28 @@ class WSGIHandler(base.BaseHandler): response._handler_class = self.__class__ - status = '%d %s' % (response.status_code, response.reason_phrase) + status = "%d %s" % (response.status_code, response.reason_phrase) response_headers = [ *response.items(), - *(('Set-Cookie', c.output(header='')) for c in response.cookies.values()), + *(("Set-Cookie", c.output(header="")) for c in response.cookies.values()), ] start_response(status, response_headers) - if getattr(response, 'file_to_stream', None) is not None and environ.get('wsgi.file_wrapper'): + if getattr(response, "file_to_stream", None) is not None and environ.get( + "wsgi.file_wrapper" + ): # If `wsgi.file_wrapper` is used the WSGI server does not call # .close on the response, but on the file wrapper. Patch it to use # response.close instead which takes care of closing all files. 
response.file_to_stream.close = response.close - response = environ['wsgi.file_wrapper'](response.file_to_stream, response.block_size) + response = environ["wsgi.file_wrapper"]( + response.file_to_stream, response.block_size + ) return response def get_path_info(environ): """Return the HTTP request's PATH_INFO as a string.""" - path_info = get_bytes_from_wsgi(environ, 'PATH_INFO', '/') + path_info = get_bytes_from_wsgi(environ, "PATH_INFO", "/") return repercent_broken_unicode(path_info).decode() @@ -171,17 +174,19 @@ def get_script_name(environ): # rewrites. Unfortunately not every web server (lighttpd!) passes this # information through all the time, so FORCE_SCRIPT_NAME, above, is still # needed. - script_url = get_bytes_from_wsgi(environ, 'SCRIPT_URL', '') or get_bytes_from_wsgi(environ, 'REDIRECT_URL', '') + script_url = get_bytes_from_wsgi(environ, "SCRIPT_URL", "") or get_bytes_from_wsgi( + environ, "REDIRECT_URL", "" + ) if script_url: - if b'//' in script_url: + if b"//" in script_url: # mod_wsgi squashes multiple successive slashes in PATH_INFO, # do the same with script_url before manipulating paths (#17133). - script_url = _slashes_re.sub(b'/', script_url) - path_info = get_bytes_from_wsgi(environ, 'PATH_INFO', '') - script_name = script_url[:-len(path_info)] if path_info else script_url + script_url = _slashes_re.sub(b"/", script_url) + path_info = get_bytes_from_wsgi(environ, "PATH_INFO", "") + script_name = script_url[: -len(path_info)] if path_info else script_url else: - script_name = get_bytes_from_wsgi(environ, 'SCRIPT_NAME', '') + script_name = get_bytes_from_wsgi(environ, "SCRIPT_NAME", "") return script_name.decode() @@ -196,7 +201,7 @@ def get_bytes_from_wsgi(environ, key, default): # Non-ASCII values in the WSGI environ are arbitrarily decoded with # ISO-8859-1. This is wrong for Django websites where UTF-8 is the default. # Re-encode to recover the original bytestring. 
- return value.encode('iso-8859-1') + return value.encode("iso-8859-1") def get_str_from_wsgi(environ, key, default): @@ -206,4 +211,4 @@ def get_str_from_wsgi(environ, key, default): key and default should be str objects. """ value = get_bytes_from_wsgi(environ, key, default) - return value.decode(errors='replace') + return value.decode(errors="replace") diff --git a/django/core/mail/__init__.py b/django/core/mail/__init__.py index f49cd07dce..dc63e8702c 100644 --- a/django/core/mail/__init__.py +++ b/django/core/mail/__init__.py @@ -2,24 +2,40 @@ Tools for sending email. """ from django.conf import settings + # Imported for backwards compatibility and for the sake # of a cleaner namespace. These symbols used to be in # django/core/mail.py before the introduction of email # backends and the subsequent reorganization (See #10355) from django.core.mail.message import ( - DEFAULT_ATTACHMENT_MIME_TYPE, BadHeaderError, EmailMessage, - EmailMultiAlternatives, SafeMIMEMultipart, SafeMIMEText, - forbid_multi_line_headers, make_msgid, + DEFAULT_ATTACHMENT_MIME_TYPE, + BadHeaderError, + EmailMessage, + EmailMultiAlternatives, + SafeMIMEMultipart, + SafeMIMEText, + forbid_multi_line_headers, + make_msgid, ) from django.core.mail.utils import DNS_NAME, CachedDnsName from django.utils.module_loading import import_string __all__ = [ - 'CachedDnsName', 'DNS_NAME', 'EmailMessage', 'EmailMultiAlternatives', - 'SafeMIMEText', 'SafeMIMEMultipart', 'DEFAULT_ATTACHMENT_MIME_TYPE', - 'make_msgid', 'BadHeaderError', 'forbid_multi_line_headers', - 'get_connection', 'send_mail', 'send_mass_mail', 'mail_admins', - 'mail_managers', + "CachedDnsName", + "DNS_NAME", + "EmailMessage", + "EmailMultiAlternatives", + "SafeMIMEText", + "SafeMIMEMultipart", + "DEFAULT_ATTACHMENT_MIME_TYPE", + "make_msgid", + "BadHeaderError", + "forbid_multi_line_headers", + "get_connection", + "send_mail", + "send_mass_mail", + "mail_admins", + "mail_managers", ] @@ -35,9 +51,17 @@ def get_connection(backend=None, 
fail_silently=False, **kwds): return klass(fail_silently=fail_silently, **kwds) -def send_mail(subject, message, from_email, recipient_list, - fail_silently=False, auth_user=None, auth_password=None, - connection=None, html_message=None): +def send_mail( + subject, + message, + from_email, + recipient_list, + fail_silently=False, + auth_user=None, + auth_password=None, + connection=None, + html_message=None, +): """ Easy wrapper for sending a single message to a recipient list. All members of the recipient list will see the other recipients in the 'To' field. @@ -54,15 +78,18 @@ def send_mail(subject, message, from_email, recipient_list, password=auth_password, fail_silently=fail_silently, ) - mail = EmailMultiAlternatives(subject, message, from_email, recipient_list, connection=connection) + mail = EmailMultiAlternatives( + subject, message, from_email, recipient_list, connection=connection + ) if html_message: - mail.attach_alternative(html_message, 'text/html') + mail.attach_alternative(html_message, "text/html") return mail.send() -def send_mass_mail(datatuple, fail_silently=False, auth_user=None, - auth_password=None, connection=None): +def send_mass_mail( + datatuple, fail_silently=False, auth_user=None, auth_password=None, connection=None +): """ Given a datatuple of (subject, message, from_email, recipient_list), send each message to each recipient list. Return the number of emails sent. 
@@ -87,35 +114,41 @@ def send_mass_mail(datatuple, fail_silently=False, auth_user=None, return connection.send_messages(messages) -def mail_admins(subject, message, fail_silently=False, connection=None, - html_message=None): +def mail_admins( + subject, message, fail_silently=False, connection=None, html_message=None +): """Send a message to the admins, as defined by the ADMINS setting.""" if not settings.ADMINS: return if not all(isinstance(a, (list, tuple)) and len(a) == 2 for a in settings.ADMINS): - raise ValueError('The ADMINS setting must be a list of 2-tuples.') + raise ValueError("The ADMINS setting must be a list of 2-tuples.") mail = EmailMultiAlternatives( - '%s%s' % (settings.EMAIL_SUBJECT_PREFIX, subject), message, - settings.SERVER_EMAIL, [a[1] for a in settings.ADMINS], + "%s%s" % (settings.EMAIL_SUBJECT_PREFIX, subject), + message, + settings.SERVER_EMAIL, + [a[1] for a in settings.ADMINS], connection=connection, ) if html_message: - mail.attach_alternative(html_message, 'text/html') + mail.attach_alternative(html_message, "text/html") mail.send(fail_silently=fail_silently) -def mail_managers(subject, message, fail_silently=False, connection=None, - html_message=None): +def mail_managers( + subject, message, fail_silently=False, connection=None, html_message=None +): """Send a message to the managers, as defined by the MANAGERS setting.""" if not settings.MANAGERS: return if not all(isinstance(a, (list, tuple)) and len(a) == 2 for a in settings.MANAGERS): - raise ValueError('The MANAGERS setting must be a list of 2-tuples.') + raise ValueError("The MANAGERS setting must be a list of 2-tuples.") mail = EmailMultiAlternatives( - '%s%s' % (settings.EMAIL_SUBJECT_PREFIX, subject), message, - settings.SERVER_EMAIL, [a[1] for a in settings.MANAGERS], + "%s%s" % (settings.EMAIL_SUBJECT_PREFIX, subject), + message, + settings.SERVER_EMAIL, + [a[1] for a in settings.MANAGERS], connection=connection, ) if html_message: - mail.attach_alternative(html_message, 
'text/html') + mail.attach_alternative(html_message, "text/html") mail.send(fail_silently=fail_silently) diff --git a/django/core/mail/backends/base.py b/django/core/mail/backends/base.py index d687703332..b35b964cb1 100644 --- a/django/core/mail/backends/base.py +++ b/django/core/mail/backends/base.py @@ -14,6 +14,7 @@ class BaseEmailBackend: # do something with connection pass """ + def __init__(self, fail_silently=False, **kwargs): self.fail_silently = fail_silently @@ -56,4 +57,6 @@ class BaseEmailBackend: Send one or more EmailMessage objects and return the number of email messages sent. """ - raise NotImplementedError('subclasses of BaseEmailBackend must override send_messages() method') + raise NotImplementedError( + "subclasses of BaseEmailBackend must override send_messages() method" + ) diff --git a/django/core/mail/backends/console.py b/django/core/mail/backends/console.py index a8bdcbd2c0..ee5dd28504 100644 --- a/django/core/mail/backends/console.py +++ b/django/core/mail/backends/console.py @@ -9,18 +9,20 @@ from django.core.mail.backends.base import BaseEmailBackend class EmailBackend(BaseEmailBackend): def __init__(self, *args, **kwargs): - self.stream = kwargs.pop('stream', sys.stdout) + self.stream = kwargs.pop("stream", sys.stdout) self._lock = threading.RLock() super().__init__(*args, **kwargs) def write_message(self, message): msg = message.message() msg_data = msg.as_bytes() - charset = msg.get_charset().get_output_charset() if msg.get_charset() else 'utf-8' + charset = ( + msg.get_charset().get_output_charset() if msg.get_charset() else "utf-8" + ) msg_data = msg_data.decode(charset) - self.stream.write('%s\n' % msg_data) - self.stream.write('-' * 79) - self.stream.write('\n') + self.stream.write("%s\n" % msg_data) + self.stream.write("-" * 79) + self.stream.write("\n") def send_messages(self, email_messages): """Write all messages to the stream in a thread-safe way.""" diff --git a/django/core/mail/backends/filebased.py 
b/django/core/mail/backends/filebased.py index 498d86fba8..3b2b037150 100644 --- a/django/core/mail/backends/filebased.py +++ b/django/core/mail/backends/filebased.py @@ -5,9 +5,7 @@ import os from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.core.mail.backends.console import ( - EmailBackend as ConsoleEmailBackend, -) +from django.core.mail.backends.console import EmailBackend as ConsoleEmailBackend class EmailBackend(ConsoleEmailBackend): @@ -16,31 +14,35 @@ class EmailBackend(ConsoleEmailBackend): if file_path is not None: self.file_path = file_path else: - self.file_path = getattr(settings, 'EMAIL_FILE_PATH', None) + self.file_path = getattr(settings, "EMAIL_FILE_PATH", None) self.file_path = os.path.abspath(self.file_path) try: os.makedirs(self.file_path, exist_ok=True) except FileExistsError: raise ImproperlyConfigured( - 'Path for saving email messages exists, but is not a directory: %s' % self.file_path + "Path for saving email messages exists, but is not a directory: %s" + % self.file_path ) except OSError as err: raise ImproperlyConfigured( - 'Could not create directory for saving email messages: %s (%s)' % (self.file_path, err) + "Could not create directory for saving email messages: %s (%s)" + % (self.file_path, err) ) # Make sure that self.file_path is writable. if not os.access(self.file_path, os.W_OK): - raise ImproperlyConfigured('Could not write to directory: %s' % self.file_path) + raise ImproperlyConfigured( + "Could not write to directory: %s" % self.file_path + ) # Finally, call super(). 
# Since we're using the console-based backend as a base, # force the stream to be None, so we don't default to stdout - kwargs['stream'] = None + kwargs["stream"] = None super().__init__(*args, **kwargs) def write_message(self, message): - self.stream.write(message.message().as_bytes() + b'\n') - self.stream.write(b'-' * 79) - self.stream.write(b'\n') + self.stream.write(message.message().as_bytes() + b"\n") + self.stream.write(b"-" * 79) + self.stream.write(b"\n") def _get_filename(self): """Return a unique file name.""" @@ -52,7 +54,7 @@ class EmailBackend(ConsoleEmailBackend): def open(self): if self.stream is None: - self.stream = open(self._get_filename(), 'ab') + self.stream = open(self._get_filename(), "ab") return True return False diff --git a/django/core/mail/backends/locmem.py b/django/core/mail/backends/locmem.py index 84732e997b..76676973a4 100644 --- a/django/core/mail/backends/locmem.py +++ b/django/core/mail/backends/locmem.py @@ -15,9 +15,10 @@ class EmailBackend(BaseEmailBackend): The dummy outbox is accessible through the outbox instance attribute. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not hasattr(mail, 'outbox'): + if not hasattr(mail, "outbox"): mail.outbox = [] def send_messages(self, messages): diff --git a/django/core/mail/backends/smtp.py b/django/core/mail/backends/smtp.py index 13ed4a2798..5df7c20ae0 100644 --- a/django/core/mail/backends/smtp.py +++ b/django/core/mail/backends/smtp.py @@ -13,10 +13,21 @@ class EmailBackend(BaseEmailBackend): """ A wrapper that manages the SMTP network connection. 
""" - def __init__(self, host=None, port=None, username=None, password=None, - use_tls=None, fail_silently=False, use_ssl=None, timeout=None, - ssl_keyfile=None, ssl_certfile=None, - **kwargs): + + def __init__( + self, + host=None, + port=None, + username=None, + password=None, + use_tls=None, + fail_silently=False, + use_ssl=None, + timeout=None, + ssl_keyfile=None, + ssl_certfile=None, + **kwargs, + ): super().__init__(fail_silently=fail_silently) self.host = host or settings.EMAIL_HOST self.port = port or settings.EMAIL_PORT @@ -25,12 +36,17 @@ class EmailBackend(BaseEmailBackend): self.use_tls = settings.EMAIL_USE_TLS if use_tls is None else use_tls self.use_ssl = settings.EMAIL_USE_SSL if use_ssl is None else use_ssl self.timeout = settings.EMAIL_TIMEOUT if timeout is None else timeout - self.ssl_keyfile = settings.EMAIL_SSL_KEYFILE if ssl_keyfile is None else ssl_keyfile - self.ssl_certfile = settings.EMAIL_SSL_CERTFILE if ssl_certfile is None else ssl_certfile + self.ssl_keyfile = ( + settings.EMAIL_SSL_KEYFILE if ssl_keyfile is None else ssl_keyfile + ) + self.ssl_certfile = ( + settings.EMAIL_SSL_CERTFILE if ssl_certfile is None else ssl_certfile + ) if self.use_ssl and self.use_tls: raise ValueError( "EMAIL_USE_TLS/EMAIL_USE_SSL are mutually exclusive, so only set " - "one of those settings to True.") + "one of those settings to True." + ) self.connection = None self._lock = threading.RLock() @@ -50,21 +66,27 @@ class EmailBackend(BaseEmailBackend): # If local_hostname is not specified, socket.getfqdn() gets used. # For performance, we use the cached FQDN for local_hostname. 
- connection_params = {'local_hostname': DNS_NAME.get_fqdn()} + connection_params = {"local_hostname": DNS_NAME.get_fqdn()} if self.timeout is not None: - connection_params['timeout'] = self.timeout + connection_params["timeout"] = self.timeout if self.use_ssl: - connection_params.update({ - 'keyfile': self.ssl_keyfile, - 'certfile': self.ssl_certfile, - }) + connection_params.update( + { + "keyfile": self.ssl_keyfile, + "certfile": self.ssl_certfile, + } + ) try: - self.connection = self.connection_class(self.host, self.port, **connection_params) + self.connection = self.connection_class( + self.host, self.port, **connection_params + ) # TLS/SSL are mutually exclusive, so only attempt TLS over # non-secure connections. if not self.use_ssl and self.use_tls: - self.connection.starttls(keyfile=self.ssl_keyfile, certfile=self.ssl_certfile) + self.connection.starttls( + keyfile=self.ssl_keyfile, certfile=self.ssl_certfile + ) if self.username and self.password: self.connection.login(self.username, self.password) return True @@ -119,10 +141,14 @@ class EmailBackend(BaseEmailBackend): return False encoding = email_message.encoding or settings.DEFAULT_CHARSET from_email = sanitize_address(email_message.from_email, encoding) - recipients = [sanitize_address(addr, encoding) for addr in email_message.recipients()] + recipients = [ + sanitize_address(addr, encoding) for addr in email_message.recipients() + ] message = email_message.message() try: - self.connection.sendmail(from_email, recipients, message.as_bytes(linesep='\r\n')) + self.connection.sendmail( + from_email, recipients, message.as_bytes(linesep="\r\n") + ) except smtplib.SMTPException: if not self.fail_silently: raise diff --git a/django/core/mail/message.py b/django/core/mail/message.py index ccc8a769ea..cd5b71ad51 100644 --- a/django/core/mail/message.py +++ b/django/core/mail/message.py @@ -1,7 +1,7 @@ import mimetypes -from email import ( - charset as Charset, encoders as Encoders, generator, 
message_from_string, -) +from email import charset as Charset +from email import encoders as Encoders +from email import generator, message_from_string from email.errors import HeaderParseError from email.header import Header from email.headerregistry import Address, parser @@ -20,14 +20,14 @@ from django.utils.encoding import force_str, punycode # Don't BASE64-encode UTF-8 messages so that we avoid unwanted attention from # some spam filters. -utf8_charset = Charset.Charset('utf-8') +utf8_charset = Charset.Charset("utf-8") utf8_charset.body_encoding = None # Python defaults to BASE64 -utf8_charset_qp = Charset.Charset('utf-8') +utf8_charset_qp = Charset.Charset("utf-8") utf8_charset_qp.body_encoding = Charset.QP # Default MIME type to use on attachments (if it is not explicitly given # and cannot be guessed). -DEFAULT_ATTACHMENT_MIME_TYPE = 'application/octet-stream' +DEFAULT_ATTACHMENT_MIME_TYPE = "application/octet-stream" RFC5322_EMAIL_LINE_LENGTH_LIMIT = 998 @@ -38,17 +38,17 @@ class BadHeaderError(ValueError): # Header names that contain structured address data (RFC #5322) ADDRESS_HEADERS = { - 'from', - 'sender', - 'reply-to', - 'to', - 'cc', - 'bcc', - 'resent-from', - 'resent-sender', - 'resent-to', - 'resent-cc', - 'resent-bcc', + "from", + "sender", + "reply-to", + "to", + "cc", + "bcc", + "resent-from", + "resent-sender", + "resent-to", + "resent-cc", + "resent-bcc", } @@ -56,17 +56,21 @@ def forbid_multi_line_headers(name, val, encoding): """Forbid multi-line headers to prevent header injection.""" encoding = encoding or settings.DEFAULT_CHARSET val = str(val) # val may be lazy - if '\n' in val or '\r' in val: - raise BadHeaderError("Header values can't contain newlines (got %r for header %r)" % (val, name)) + if "\n" in val or "\r" in val: + raise BadHeaderError( + "Header values can't contain newlines (got %r for header %r)" % (val, name) + ) try: - val.encode('ascii') + val.encode("ascii") except UnicodeEncodeError: if name.lower() in 
ADDRESS_HEADERS: - val = ', '.join(sanitize_address(addr, encoding) for addr in getaddresses((val,))) + val = ", ".join( + sanitize_address(addr, encoding) for addr in getaddresses((val,)) + ) else: val = Header(val, encoding).encode() else: - if name.lower() == 'subject': + if name.lower() == "subject": val = Header(val).encode() return name, val @@ -86,28 +90,27 @@ def sanitize_address(addr, encoding): if rest: # The entire email address must be parsed. raise ValueError( - 'Invalid address; only %s could be parsed from "%s"' - % (token, addr) + 'Invalid address; only %s could be parsed from "%s"' % (token, addr) ) - nm = token.display_name or '' + nm = token.display_name or "" localpart = token.local_part - domain = token.domain or '' + domain = token.domain or "" else: nm, address = addr - localpart, domain = address.rsplit('@', 1) + localpart, domain = address.rsplit("@", 1) address_parts = nm + localpart + domain - if '\n' in address_parts or '\r' in address_parts: - raise ValueError('Invalid address; address parts cannot contain newlines.') + if "\n" in address_parts or "\r" in address_parts: + raise ValueError("Invalid address; address parts cannot contain newlines.") # Avoid UTF-8 encode, if it's possible. try: - nm.encode('ascii') + nm.encode("ascii") nm = Header(nm).encode() except UnicodeEncodeError: nm = Header(nm, encoding).encode() try: - localpart.encode('ascii') + localpart.encode("ascii") except UnicodeEncodeError: localpart = Header(localpart, encoding).encode() domain = punycode(domain) @@ -117,7 +120,7 @@ def sanitize_address(addr, encoding): class MIMEMixin: - def as_string(self, unixfrom=False, linesep='\n'): + def as_string(self, unixfrom=False, linesep="\n"): """Return the entire formatted message as a string. Optional `unixfrom' when True, means include the Unix From_ envelope header. 
@@ -130,7 +133,7 @@ class MIMEMixin: g.flatten(self, unixfrom=unixfrom, linesep=linesep) return fp.getvalue() - def as_bytes(self, unixfrom=False, linesep='\n'): + def as_bytes(self, unixfrom=False, linesep="\n"): """Return the entire formatted message as bytes. Optional `unixfrom' when True, means include the Unix From_ envelope header. @@ -145,16 +148,14 @@ class MIMEMixin: class SafeMIMEMessage(MIMEMixin, MIMEMessage): - def __setitem__(self, name, val): # message/rfc822 attachments must be ASCII - name, val = forbid_multi_line_headers(name, val, 'ascii') + name, val = forbid_multi_line_headers(name, val, "ascii") MIMEMessage.__setitem__(self, name, val) class SafeMIMEText(MIMEMixin, MIMEText): - - def __init__(self, _text, _subtype='plain', _charset=None): + def __init__(self, _text, _subtype="plain", _charset=None): self.encoding = _charset MIMEText.__init__(self, _text, _subtype=_subtype, _charset=_charset) @@ -163,7 +164,7 @@ class SafeMIMEText(MIMEMixin, MIMEText): MIMEText.__setitem__(self, name, val) def set_payload(self, payload, charset=None): - if charset == 'utf-8' and not isinstance(charset, Charset.Charset): + if charset == "utf-8" and not isinstance(charset, Charset.Charset): has_long_lines = any( len(line.encode()) > RFC5322_EMAIL_LINE_LENGTH_LIMIT for line in payload.splitlines() @@ -175,8 +176,9 @@ class SafeMIMEText(MIMEMixin, MIMEText): class SafeMIMEMultipart(MIMEMixin, MIMEMultipart): - - def __init__(self, _subtype='mixed', boundary=None, _subparts=None, encoding=None, **_params): + def __init__( + self, _subtype="mixed", boundary=None, _subparts=None, encoding=None, **_params + ): self.encoding = encoding MIMEMultipart.__init__(self, _subtype, boundary, _subparts, **_params) @@ -187,13 +189,24 @@ class SafeMIMEMultipart(MIMEMixin, MIMEMultipart): class EmailMessage: """A container for email information.""" - content_subtype = 'plain' - mixed_subtype = 'mixed' - encoding = None # None => use settings default - def __init__(self, subject='', 
body='', from_email=None, to=None, bcc=None, - connection=None, attachments=None, headers=None, cc=None, - reply_to=None): + content_subtype = "plain" + mixed_subtype = "mixed" + encoding = None # None => use settings default + + def __init__( + self, + subject="", + body="", + from_email=None, + to=None, + bcc=None, + connection=None, + attachments=None, + headers=None, + cc=None, + reply_to=None, + ): """ Initialize a single email message (which can be sent to multiple recipients). @@ -224,7 +237,7 @@ class EmailMessage: self.reply_to = [] self.from_email = from_email or settings.DEFAULT_FROM_EMAIL self.subject = subject - self.body = body or '' + self.body = body or "" self.attachments = [] if attachments: for attachment in attachments: @@ -237,6 +250,7 @@ class EmailMessage: def get_connection(self, fail_silently=False): from django.core.mail import get_connection + if not self.connection: self.connection = get_connection(fail_silently=fail_silently) return self.connection @@ -245,26 +259,26 @@ class EmailMessage: encoding = self.encoding or settings.DEFAULT_CHARSET msg = SafeMIMEText(self.body, self.content_subtype, encoding) msg = self._create_message(msg) - msg['Subject'] = self.subject - msg['From'] = self.extra_headers.get('From', self.from_email) - self._set_list_header_if_not_empty(msg, 'To', self.to) - self._set_list_header_if_not_empty(msg, 'Cc', self.cc) - self._set_list_header_if_not_empty(msg, 'Reply-To', self.reply_to) + msg["Subject"] = self.subject + msg["From"] = self.extra_headers.get("From", self.from_email) + self._set_list_header_if_not_empty(msg, "To", self.to) + self._set_list_header_if_not_empty(msg, "Cc", self.cc) + self._set_list_header_if_not_empty(msg, "Reply-To", self.reply_to) # Email header names are case-insensitive (RFC 2045), so we have to # accommodate that when doing comparisons. 
header_names = [key.lower() for key in self.extra_headers] - if 'date' not in header_names: + if "date" not in header_names: # formatdate() uses stdlib methods to format the date, which use # the stdlib/OS concept of a timezone, however, Django sets the # TZ environment variable based on the TIME_ZONE setting which # will get picked up by formatdate(). - msg['Date'] = formatdate(localtime=settings.EMAIL_USE_LOCALTIME) - if 'message-id' not in header_names: + msg["Date"] = formatdate(localtime=settings.EMAIL_USE_LOCALTIME) + if "message-id" not in header_names: # Use cached DNS_NAME for performance - msg['Message-ID'] = make_msgid(domain=DNS_NAME) + msg["Message-ID"] = make_msgid(domain=DNS_NAME) for name, value in self.extra_headers.items(): - if name.lower() != 'from': # From is already handled + if name.lower() != "from": # From is already handled msg[name] = value return msg @@ -298,17 +312,21 @@ class EmailMessage: if isinstance(filename, MIMEBase): if content is not None or mimetype is not None: raise ValueError( - 'content and mimetype must not be given when a MIMEBase ' - 'instance is provided.' + "content and mimetype must not be given when a MIMEBase " + "instance is provided." ) self.attachments.append(filename) elif content is None: - raise ValueError('content must be provided.') + raise ValueError("content must be provided.") else: - mimetype = mimetype or mimetypes.guess_type(filename)[0] or DEFAULT_ATTACHMENT_MIME_TYPE - basetype, subtype = mimetype.split('/', 1) + mimetype = ( + mimetype + or mimetypes.guess_type(filename)[0] + or DEFAULT_ATTACHMENT_MIME_TYPE + ) + basetype, subtype = mimetype.split("/", 1) - if basetype == 'text': + if basetype == "text": if isinstance(content, bytes): try: content = content.decode() @@ -331,7 +349,7 @@ class EmailMessage: DEFAULT_ATTACHMENT_MIME_TYPE and don't decode the content. 
""" path = Path(path) - with path.open('rb') as file: + with path.open("rb") as file: content = file.read() self.attach(path.name, content, mimetype) @@ -359,11 +377,11 @@ class EmailMessage: If the mimetype is message/rfc822, content may be an email.Message or EmailMessage object, as well as a str. """ - basetype, subtype = mimetype.split('/', 1) - if basetype == 'text': + basetype, subtype = mimetype.split("/", 1) + if basetype == "text": encoding = self.encoding or settings.DEFAULT_CHARSET attachment = SafeMIMEText(content, subtype, encoding) - elif basetype == 'message' and subtype == 'rfc822': + elif basetype == "message" and subtype == "rfc822": # Bug #18967: per RFC2046 s5.2.1, message/rfc822 attachments # must not be base64 encoded. if isinstance(content, EmailMessage): @@ -390,10 +408,12 @@ class EmailMessage: attachment = self._create_mime_attachment(content, mimetype) if filename: try: - filename.encode('ascii') + filename.encode("ascii") except UnicodeEncodeError: - filename = ('utf-8', '', filename) - attachment.add_header('Content-Disposition', 'attachment', filename=filename) + filename = ("utf-8", "", filename) + attachment.add_header( + "Content-Disposition", "attachment", filename=filename + ) return attachment def _set_list_header_if_not_empty(self, msg, header, values): @@ -405,7 +425,7 @@ class EmailMessage: try: value = self.extra_headers[header] except KeyError: - value = ', '.join(str(v) for v in values) + value = ", ".join(str(v) for v in values) msg[header] = value @@ -415,25 +435,45 @@ class EmailMultiAlternatives(EmailMessage): messages. For example, including text and HTML versions of the text is made easier. 
""" - alternative_subtype = 'alternative' - def __init__(self, subject='', body='', from_email=None, to=None, bcc=None, - connection=None, attachments=None, headers=None, alternatives=None, - cc=None, reply_to=None): + alternative_subtype = "alternative" + + def __init__( + self, + subject="", + body="", + from_email=None, + to=None, + bcc=None, + connection=None, + attachments=None, + headers=None, + alternatives=None, + cc=None, + reply_to=None, + ): """ Initialize a single email message (which can be sent to multiple recipients). """ super().__init__( - subject, body, from_email, to, bcc, connection, attachments, - headers, cc, reply_to, + subject, + body, + from_email, + to, + bcc, + connection, + attachments, + headers, + cc, + reply_to, ) self.alternatives = alternatives or [] def attach_alternative(self, content, mimetype): """Attach an alternative content representation.""" if content is None or mimetype is None: - raise ValueError('Both content and mimetype must be provided.') + raise ValueError("Both content and mimetype must be provided.") self.alternatives.append((content, mimetype)) def _create_message(self, msg): @@ -443,7 +483,9 @@ class EmailMultiAlternatives(EmailMessage): encoding = self.encoding or settings.DEFAULT_CHARSET if self.alternatives: body_msg = msg - msg = SafeMIMEMultipart(_subtype=self.alternative_subtype, encoding=encoding) + msg = SafeMIMEMultipart( + _subtype=self.alternative_subtype, encoding=encoding + ) if self.body: msg.attach(body_msg) for alternative in self.alternatives: diff --git a/django/core/mail/utils.py b/django/core/mail/utils.py index 1e48faa366..8143c236d5 100644 --- a/django/core/mail/utils.py +++ b/django/core/mail/utils.py @@ -14,7 +14,7 @@ class CachedDnsName: return self.get_fqdn() def get_fqdn(self): - if not hasattr(self, '_fqdn'): + if not hasattr(self, "_fqdn"): self._fqdn = punycode(socket.getfqdn()) return self._fqdn diff --git a/django/core/management/__init__.py b/django/core/management/__init__.py 
index 6133e71c50..7049e06474 100644 --- a/django/core/management/__init__.py +++ b/django/core/management/__init__.py @@ -3,7 +3,10 @@ import os import pkgutil import sys from argparse import ( - _AppendConstAction, _CountAction, _StoreConstAction, _SubParsersAction, + _AppendConstAction, + _CountAction, + _StoreConstAction, + _SubParsersAction, ) from collections import defaultdict from difflib import get_close_matches @@ -14,7 +17,10 @@ from django.apps import apps from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.management.base import ( - BaseCommand, CommandError, CommandParser, handle_default_options, + BaseCommand, + CommandError, + CommandParser, + handle_default_options, ) from django.core.management.color import color_style from django.utils import autoreload @@ -25,9 +31,12 @@ def find_commands(management_dir): Given a path to a management directory, return a list of all the command names that are available. """ - command_dir = os.path.join(management_dir, 'commands') - return [name for _, name, is_pkg in pkgutil.iter_modules([command_dir]) - if not is_pkg and not name.startswith('_')] + command_dir = os.path.join(management_dir, "commands") + return [ + name + for _, name, is_pkg in pkgutil.iter_modules([command_dir]) + if not is_pkg and not name.startswith("_") + ] def load_command_class(app_name, name): @@ -36,7 +45,7 @@ def load_command_class(app_name, name): class instance. Allow all errors raised by the import process (ImportError, AttributeError) to propagate. """ - module = import_module('%s.management.commands.%s' % (app_name, name)) + module = import_module("%s.management.commands.%s" % (app_name, name)) return module.Command() @@ -63,13 +72,13 @@ def get_commands(): The dictionary is cached on the first call and reused on subsequent calls. 
""" - commands = {name: 'django.core' for name in find_commands(__path__[0])} + commands = {name: "django.core" for name in find_commands(__path__[0])} if not settings.configured: return commands for app_config in reversed(apps.get_app_configs()): - path = os.path.join(app_config.path, 'management') + path = os.path.join(app_config.path, "management") commands.update({name: app_config.name for name in find_commands(path)}) return commands @@ -98,7 +107,7 @@ def call_command(command_name, *args, **options): if isinstance(command_name, BaseCommand): # Command object passed in. command = command_name - command_name = command.__class__.__module__.split('.')[-1] + command_name = command.__class__.__module__.split(".")[-1] else: # Load the command object by name. try: @@ -113,11 +122,12 @@ def call_command(command_name, *args, **options): command = load_command_class(app_name, command_name) # Simulate argument parsing to get the option defaults (see #10080 for details). - parser = command.create_parser('', command_name) + parser = command.create_parser("", command_name) # Use the `dest` option name from the parser option opt_mapping = { - min(s_opt.option_strings).lstrip('-').replace('-', '_'): s_opt.dest - for s_opt in parser._actions if s_opt.option_strings + min(s_opt.option_strings).lstrip("-").replace("-", "_"): s_opt.dest + for s_opt in parser._actions + if s_opt.option_strings } arg_options = {opt_mapping.get(key, key): value for key, value in options.items()} parse_args = [] @@ -140,20 +150,20 @@ def call_command(command_name, *args, **options): mutually_exclusive_required_options = { opt for group in parser._mutually_exclusive_groups - for opt in group._group_actions if group.required + for opt in group._group_actions + if group.required } # Any required arguments which are passed in via **options must be passed # to parse_args(). 
for opt in parser_actions: - if ( - opt.dest in options and - (opt.required or opt in mutually_exclusive_required_options) + if opt.dest in options and ( + opt.required or opt in mutually_exclusive_required_options ): opt_dest_count = sum(v == opt.dest for v in opt_mapping.values()) if opt_dest_count > 1: raise TypeError( - f'Cannot pass the dest {opt.dest!r} that matches multiple ' - f'arguments via **options.' + f"Cannot pass the dest {opt.dest!r} that matches multiple " + f"arguments via **options." ) parse_args.append(min(opt.option_strings)) if isinstance(opt, (_AppendConstAction, _CountAction, _StoreConstAction)): @@ -173,16 +183,17 @@ def call_command(command_name, *args, **options): if unknown_options: raise TypeError( "Unknown option(s) for %s command: %s. " - "Valid options are: %s." % ( + "Valid options are: %s." + % ( command_name, - ', '.join(sorted(unknown_options)), - ', '.join(sorted(valid_options)), + ", ".join(sorted(unknown_options)), + ", ".join(sorted(valid_options)), ) ) # Move positional args out of options to mimic legacy optparse - args = defaults.pop('args', ()) - if 'skip_checks' not in options: - defaults['skip_checks'] = True + args = defaults.pop("args", ()) + if "skip_checks" not in options: + defaults["skip_checks"] = True return command.execute(*args, **defaults) @@ -191,11 +202,12 @@ class ManagementUtility: """ Encapsulate the logic of the django-admin and manage.py utilities. """ + def __init__(self, argv=None): self.argv = argv or sys.argv[:] self.prog_name = os.path.basename(self.argv[0]) - if self.prog_name == '__main__.py': - self.prog_name = 'python -m django' + if self.prog_name == "__main__.py": + self.prog_name = "python -m django" self.settings_exception = None def main_help_text(self, commands_only=False): @@ -205,16 +217,17 @@ class ManagementUtility: else: usage = [ "", - "Type '%s help <subcommand>' for help on a specific subcommand." 
% self.prog_name, + "Type '%s help <subcommand>' for help on a specific subcommand." + % self.prog_name, "", "Available subcommands:", ] commands_dict = defaultdict(lambda: []) for name, app in get_commands().items(): - if app == 'django.core': - app = 'django' + if app == "django.core": + app = "django" else: - app = app.rpartition('.')[-1] + app = app.rpartition(".")[-1] commands_dict[app].append(name) style = color_style() for app in sorted(commands_dict): @@ -224,12 +237,15 @@ class ManagementUtility: usage.append(" %s" % name) # Output an extra note if settings are not properly configured if self.settings_exception is not None: - usage.append(style.NOTICE( - "Note that only Django core commands are listed " - "as settings are not properly configured (error: %s)." - % self.settings_exception)) + usage.append( + style.NOTICE( + "Note that only Django core commands are listed " + "as settings are not properly configured (error: %s)." + % self.settings_exception + ) + ) - return '\n'.join(usage) + return "\n".join(usage) def fetch_command(self, subcommand): """ @@ -242,7 +258,7 @@ class ManagementUtility: try: app_name = commands[subcommand] except KeyError: - if os.environ.get('DJANGO_SETTINGS_MODULE'): + if os.environ.get("DJANGO_SETTINGS_MODULE"): # If `subcommand` is missing due to misconfigured settings, the # following line will retrigger an ImproperlyConfigured exception # (get_commands() swallows the original one) so the user is @@ -251,9 +267,9 @@ class ManagementUtility: elif not settings.configured: sys.stderr.write("No Django settings specified.\n") possible_matches = get_close_matches(subcommand, commands) - sys.stderr.write('Unknown command: %r' % subcommand) + sys.stderr.write("Unknown command: %r" % subcommand) if possible_matches: - sys.stderr.write('. Did you mean %s?' % possible_matches[0]) + sys.stderr.write(". Did you mean %s?" 
% possible_matches[0]) sys.stderr.write("\nType '%s help' for usage.\n" % self.prog_name) sys.exit(1) if isinstance(app_name, BaseCommand): @@ -285,29 +301,29 @@ class ManagementUtility: and formatted as potential completion suggestions. """ # Don't complete if user hasn't sourced bash_completion file. - if 'DJANGO_AUTO_COMPLETE' not in os.environ: + if "DJANGO_AUTO_COMPLETE" not in os.environ: return - cwords = os.environ['COMP_WORDS'].split()[1:] - cword = int(os.environ['COMP_CWORD']) + cwords = os.environ["COMP_WORDS"].split()[1:] + cword = int(os.environ["COMP_CWORD"]) try: curr = cwords[cword - 1] except IndexError: - curr = '' + curr = "" - subcommands = [*get_commands(), 'help'] - options = [('--help', False)] + subcommands = [*get_commands(), "help"] + options = [("--help", False)] # subcommand if cword == 1: - print(' '.join(sorted(filter(lambda x: x.startswith(curr), subcommands)))) + print(" ".join(sorted(filter(lambda x: x.startswith(curr), subcommands)))) # subcommand options # special case: the 'help' subcommand has no options - elif cwords[0] in subcommands and cwords[0] != 'help': + elif cwords[0] in subcommands and cwords[0] != "help": subcommand_cls = self.fetch_command(cwords[0]) # special case: add the names of installed apps to options - if cwords[0] in ('dumpdata', 'sqlmigrate', 'sqlsequencereset', 'test'): + if cwords[0] in ("dumpdata", "sqlmigrate", "sqlsequencereset", "test"): try: app_configs = apps.get_app_configs() # Get the last part of the dotted path as the app name. @@ -316,13 +332,14 @@ class ManagementUtility: # Fail silently if DJANGO_SETTINGS_MODULE isn't set. The # user will find out once they execute the command. 
pass - parser = subcommand_cls.create_parser('', cwords[0]) + parser = subcommand_cls.create_parser("", cwords[0]) options.extend( (min(s_opt.option_strings), s_opt.nargs != 0) - for s_opt in parser._actions if s_opt.option_strings + for s_opt in parser._actions + if s_opt.option_strings ) # filter out previously specified options from available options - prev_opts = {x.split('=')[0] for x in cwords[1:cword - 1]} + prev_opts = {x.split("=")[0] for x in cwords[1 : cword - 1]} options = (opt for opt in options if opt[0] not in prev_opts) # filter options by current input @@ -330,7 +347,7 @@ class ManagementUtility: for opt_label, require_arg in options: # append '=' to options which require args if require_arg: - opt_label += '=' + opt_label += "=" print(opt_label) # Exit code of the bash completion function is never passed back to # the user, so it's safe to always exit with 0. @@ -345,20 +362,20 @@ class ManagementUtility: try: subcommand = self.argv[1] except IndexError: - subcommand = 'help' # Display help if no arguments were given. + subcommand = "help" # Display help if no arguments were given. # Preprocess options to extract --settings and --pythonpath. # These options could affect the commands that are available, so they # must be processed early. parser = CommandParser( prog=self.prog_name, - usage='%(prog)s subcommand [options] [args]', + usage="%(prog)s subcommand [options] [args]", add_help=False, allow_abbrev=False, ) - parser.add_argument('--settings') - parser.add_argument('--pythonpath') - parser.add_argument('args', nargs='*') # catch-all + parser.add_argument("--settings") + parser.add_argument("--pythonpath") + parser.add_argument("args", nargs="*") # catch-all try: options, args = parser.parse_known_args(self.argv[2:]) handle_default_options(options) @@ -376,7 +393,7 @@ class ManagementUtility: # Start the auto-reloading dev server even if the code is broken. 
# The hardcoded condition is a code smell but we can't rely on a # flag on the command class because we haven't located it yet. - if subcommand == 'runserver' and '--noreload' not in self.argv: + if subcommand == "runserver" and "--noreload" not in self.argv: try: autoreload.check_errors(django.setup)() except Exception: @@ -391,7 +408,9 @@ class ManagementUtility: # (e.g. options for the contrib.staticfiles' runserver). # Changes here require manually testing as described in # #27522. - _parser = self.fetch_command('runserver').create_parser('django', 'runserver') + _parser = self.fetch_command("runserver").create_parser( + "django", "runserver" + ) _options, _args = _parser.parse_known_args(self.argv[2:]) for _arg in _args: self.argv.remove(_arg) @@ -402,19 +421,21 @@ class ManagementUtility: self.autocomplete() - if subcommand == 'help': - if '--commands' in args: - sys.stdout.write(self.main_help_text(commands_only=True) + '\n') + if subcommand == "help": + if "--commands" in args: + sys.stdout.write(self.main_help_text(commands_only=True) + "\n") elif not options.args: - sys.stdout.write(self.main_help_text() + '\n') + sys.stdout.write(self.main_help_text() + "\n") else: - self.fetch_command(options.args[0]).print_help(self.prog_name, options.args[0]) + self.fetch_command(options.args[0]).print_help( + self.prog_name, options.args[0] + ) # Special-cases: We want 'django-admin --version' and # 'django-admin --help' to work, for backwards compatibility. 
- elif subcommand == 'version' or self.argv[1:] == ['--version']: - sys.stdout.write(django.get_version() + '\n') - elif self.argv[1:] in (['--help'], ['-h']): - sys.stdout.write(self.main_help_text() + '\n') + elif subcommand == "version" or self.argv[1:] == ["--version"]: + sys.stdout.write(django.get_version() + "\n") + elif self.argv[1:] in (["--help"], ["-h"]): + sys.stdout.write(self.main_help_text() + "\n") else: self.fetch_command(subcommand).run_from_argv(self.argv) diff --git a/django/core/management/base.py b/django/core/management/base.py index 197230fc14..ab1f7e9f70 100644 --- a/django/core/management/base.py +++ b/django/core/management/base.py @@ -14,7 +14,7 @@ from django.core.exceptions import ImproperlyConfigured from django.core.management.color import color_style, no_style from django.db import DEFAULT_DB_ALIAS, connections -ALL_CHECKS = '__all__' +ALL_CHECKS = "__all__" class CommandError(Exception): @@ -29,6 +29,7 @@ class CommandError(Exception): error) is the preferred way to indicate that something has gone wrong in the execution of a command. """ + def __init__(self, *args, returncode=1, **kwargs): self.returncode = returncode super().__init__(*args, **kwargs) @@ -38,6 +39,7 @@ class SystemCheckError(CommandError): """ The system check framework detected unrecoverable errors. """ + pass @@ -47,15 +49,19 @@ class CommandParser(ArgumentParser): SystemExit in several occasions, as SystemExit is unacceptable when a command is called programmatically. 
""" - def __init__(self, *, missing_args_message=None, called_from_command_line=None, **kwargs): + + def __init__( + self, *, missing_args_message=None, called_from_command_line=None, **kwargs + ): self.missing_args_message = missing_args_message self.called_from_command_line = called_from_command_line super().__init__(**kwargs) def parse_args(self, args=None, namespace=None): # Catch missing argument for a better error message - if (self.missing_args_message and - not (args or any(not arg.startswith('-') for arg in args))): + if self.missing_args_message and not ( + args or any(not arg.startswith("-") for arg in args) + ): self.error(self.missing_args_message) return super().parse_args(args, namespace) @@ -73,15 +79,17 @@ def handle_default_options(options): user commands. """ if options.settings: - os.environ['DJANGO_SETTINGS_MODULE'] = options.settings + os.environ["DJANGO_SETTINGS_MODULE"] = options.settings if options.pythonpath: sys.path.insert(0, options.pythonpath) def no_translations(handle_func): """Decorator that forces a command to run with translations deactivated.""" + def wrapped(*args, **kwargs): from django.utils import translation + saved_locale = translation.get_language() translation.deactivate_all() try: @@ -90,6 +98,7 @@ def no_translations(handle_func): if saved_locale is not None: translation.activate(saved_locale) return res + return wrapped @@ -98,15 +107,21 @@ class DjangoHelpFormatter(HelpFormatter): Customized formatter so that command-specific arguments appear in the --help output before arguments common to all commands. 
""" + show_last = { - '--version', '--verbosity', '--traceback', '--settings', '--pythonpath', - '--no-color', '--force-color', '--skip-checks', + "--version", + "--verbosity", + "--traceback", + "--settings", + "--pythonpath", + "--no-color", + "--force-color", + "--skip-checks", } def _reordered_actions(self, actions): return sorted( - actions, - key=lambda a: set(a.option_strings) & self.show_last != set() + actions, key=lambda a: set(a.option_strings) & self.show_last != set() ) def add_usage(self, usage, actions, *args, **kwargs): @@ -120,6 +135,7 @@ class OutputWrapper(TextIOBase): """ Wrapper around stdout/stderr """ + @property def style_func(self): return self._style_func @@ -131,7 +147,7 @@ class OutputWrapper(TextIOBase): else: self._style_func = lambda x: x - def __init__(self, out, ending='\n'): + def __init__(self, out, ending="\n"): self._out = out self.style_func = None self.ending = ending @@ -140,13 +156,13 @@ class OutputWrapper(TextIOBase): return getattr(self._out, name) def flush(self): - if hasattr(self._out, 'flush'): + if hasattr(self._out, "flush"): self._out.flush() def isatty(self): - return hasattr(self._out, 'isatty') and self._out.isatty() + return hasattr(self._out, "isatty") and self._out.isatty() - def write(self, msg='', style_func=None, ending=None): + def write(self, msg="", style_func=None, ending=None): ending = self.ending if ending is None else ending if ending and not msg.endswith(ending): msg += ending @@ -225,17 +241,18 @@ class BaseCommand: A tuple of any options the command uses which aren't defined by the argument parser. """ + # Metadata about this command. - help = '' + help = "" # Configuration shortcuts that alter various logic. 
_called_from_command_line = False output_transaction = False # Whether to wrap the output in a "BEGIN; COMMIT;" requires_migrations_checks = False - requires_system_checks = '__all__' + requires_system_checks = "__all__" # Arguments, common to all commands, which aren't defined by the argument # parser. - base_stealth_options = ('stderr', 'stdout') + base_stealth_options = ("stderr", "stdout") # Command-specific options not defined by the argument parser. stealth_options = () suppressed_base_arguments = set() @@ -251,10 +268,10 @@ class BaseCommand: self.style = color_style(force_color) self.stderr.style_func = self.style.ERROR if ( - not isinstance(self.requires_system_checks, (list, tuple)) and - self.requires_system_checks != ALL_CHECKS + not isinstance(self.requires_system_checks, (list, tuple)) + and self.requires_system_checks != ALL_CHECKS ): - raise TypeError('requires_system_checks must be a list or tuple.') + raise TypeError("requires_system_checks must be a list or tuple.") def get_version(self): """ @@ -270,50 +287,66 @@ class BaseCommand: parse the arguments to this command. 
""" parser = CommandParser( - prog='%s %s' % (os.path.basename(prog_name), subcommand), + prog="%s %s" % (os.path.basename(prog_name), subcommand), description=self.help or None, formatter_class=DjangoHelpFormatter, - missing_args_message=getattr(self, 'missing_args_message', None), - called_from_command_line=getattr(self, '_called_from_command_line', None), - **kwargs + missing_args_message=getattr(self, "missing_args_message", None), + called_from_command_line=getattr(self, "_called_from_command_line", None), + **kwargs, ) self.add_base_argument( - parser, '--version', action='version', version=self.get_version(), + parser, + "--version", + action="version", + version=self.get_version(), help="Show program's version number and exit.", ) self.add_base_argument( - parser, '-v', '--verbosity', default=1, - type=int, choices=[0, 1, 2, 3], - help='Verbosity level; 0=minimal output, 1=normal output, 2=verbose output, 3=very verbose output', + parser, + "-v", + "--verbosity", + default=1, + type=int, + choices=[0, 1, 2, 3], + help="Verbosity level; 0=minimal output, 1=normal output, 2=verbose output, 3=very verbose output", ) self.add_base_argument( - parser, '--settings', + parser, + "--settings", help=( - 'The Python path to a settings module, e.g. ' + "The Python path to a settings module, e.g. " '"myproject.settings.main". If this isn\'t provided, the ' - 'DJANGO_SETTINGS_MODULE environment variable will be used.' + "DJANGO_SETTINGS_MODULE environment variable will be used." ), ) self.add_base_argument( - parser, '--pythonpath', + parser, + "--pythonpath", help='A directory to add to the Python path, e.g. 
"/home/djangoprojects/myproject".', ) self.add_base_argument( - parser, '--traceback', action='store_true', - help='Raise on CommandError exceptions.', + parser, + "--traceback", + action="store_true", + help="Raise on CommandError exceptions.", ) self.add_base_argument( - parser, '--no-color', action='store_true', + parser, + "--no-color", + action="store_true", help="Don't colorize the command output.", ) self.add_base_argument( - parser, '--force-color', action='store_true', - help='Force colorization of the command output.', + parser, + "--force-color", + action="store_true", + help="Force colorization of the command output.", ) if self.requires_system_checks: parser.add_argument( - '--skip-checks', action='store_true', - help='Skip system checks.', + "--skip-checks", + action="store_true", + help="Skip system checks.", ) self.add_arguments(parser) return parser @@ -331,7 +364,7 @@ class BaseCommand: """ for arg in args: if arg in self.suppressed_base_arguments: - kwargs['help'] = argparse.SUPPRESS + kwargs["help"] = argparse.SUPPRESS break parser.add_argument(*args, **kwargs) @@ -357,7 +390,7 @@ class BaseCommand: options = parser.parse_args(argv[2:]) cmd_options = vars(options) # Move positional args out of options to mimic legacy optparse - args = cmd_options.pop('args', ()) + args = cmd_options.pop("args", ()) handle_default_options(options) try: self.execute(*args, **cmd_options) @@ -369,7 +402,7 @@ class BaseCommand: if isinstance(e, SystemCheckError): self.stderr.write(str(e), lambda x: x) else: - self.stderr.write('%s: %s' % (e.__class__.__name__, e)) + self.stderr.write("%s: %s" % (e.__class__.__name__, e)) sys.exit(e.returncode) finally: try: @@ -385,19 +418,21 @@ class BaseCommand: controlled by the ``requires_system_checks`` attribute, except if force-skipped). 
""" - if options['force_color'] and options['no_color']: - raise CommandError("The --no-color and --force-color options can't be used together.") - if options['force_color']: + if options["force_color"] and options["no_color"]: + raise CommandError( + "The --no-color and --force-color options can't be used together." + ) + if options["force_color"]: self.style = color_style(force_color=True) - elif options['no_color']: + elif options["no_color"]: self.style = no_style() self.stderr.style_func = None - if options.get('stdout'): - self.stdout = OutputWrapper(options['stdout']) - if options.get('stderr'): - self.stderr = OutputWrapper(options['stderr']) + if options.get("stdout"): + self.stdout = OutputWrapper(options["stdout"]) + if options.get("stderr"): + self.stderr = OutputWrapper(options["stderr"]) - if self.requires_system_checks and not options['skip_checks']: + if self.requires_system_checks and not options["skip_checks"]: if self.requires_system_checks == ALL_CHECKS: self.check() else: @@ -407,8 +442,8 @@ class BaseCommand: output = self.handle(*args, **options) if output: if self.output_transaction: - connection = connections[options.get('database', DEFAULT_DB_ALIAS)] - output = '%s\n%s\n%s' % ( + connection = connections[options.get("database", DEFAULT_DB_ALIAS)] + output = "%s\n%s\n%s" % ( self.style.SQL_KEYWORD(connection.ops.start_transaction_sql()), output, self.style.SQL_KEYWORD(connection.ops.end_transaction_sql()), @@ -416,9 +451,15 @@ class BaseCommand: self.stdout.write(output) return output - def check(self, app_configs=None, tags=None, display_num_errors=False, - include_deployment_checks=False, fail_level=checks.ERROR, - databases=None): + def check( + self, + app_configs=None, + tags=None, + display_num_errors=False, + include_deployment_checks=False, + fail_level=checks.ERROR, + databases=None, + ): """ Use the system check framework to validate entire Django project. Raise CommandError for any serious message (error or critical errors). 
@@ -436,17 +477,35 @@ class BaseCommand: visible_issue_count = 0 # excludes silenced warnings if all_issues: - debugs = [e for e in all_issues if e.level < checks.INFO and not e.is_silenced()] - infos = [e for e in all_issues if checks.INFO <= e.level < checks.WARNING and not e.is_silenced()] - warnings = [e for e in all_issues if checks.WARNING <= e.level < checks.ERROR and not e.is_silenced()] - errors = [e for e in all_issues if checks.ERROR <= e.level < checks.CRITICAL and not e.is_silenced()] - criticals = [e for e in all_issues if checks.CRITICAL <= e.level and not e.is_silenced()] + debugs = [ + e for e in all_issues if e.level < checks.INFO and not e.is_silenced() + ] + infos = [ + e + for e in all_issues + if checks.INFO <= e.level < checks.WARNING and not e.is_silenced() + ] + warnings = [ + e + for e in all_issues + if checks.WARNING <= e.level < checks.ERROR and not e.is_silenced() + ] + errors = [ + e + for e in all_issues + if checks.ERROR <= e.level < checks.CRITICAL and not e.is_silenced() + ] + criticals = [ + e + for e in all_issues + if checks.CRITICAL <= e.level and not e.is_silenced() + ] sorted_issues = [ - (criticals, 'CRITICALS'), - (errors, 'ERRORS'), - (warnings, 'WARNINGS'), - (infos, 'INFOS'), - (debugs, 'DEBUGS'), + (criticals, "CRITICALS"), + (errors, "ERRORS"), + (warnings, "WARNINGS"), + (infos, "INFOS"), + (debugs, "DEBUGS"), ] for issues, group_name in sorted_issues: @@ -456,20 +515,23 @@ class BaseCommand: self.style.ERROR(str(e)) if e.is_serious() else self.style.WARNING(str(e)) - for e in issues) + for e in issues + ) formatted = "\n".join(sorted(formatted)) - body += '\n%s:\n%s\n' % (group_name, formatted) + body += "\n%s:\n%s\n" % (group_name, formatted) if visible_issue_count: header = "System check identified some issues:\n" if display_num_errors: if visible_issue_count: - footer += '\n' + footer += "\n" footer += "System check identified %s (%s silenced)." 
% ( - "no issues" if visible_issue_count == 0 else - "1 issue" if visible_issue_count == 1 else - "%s issues" % visible_issue_count, + "no issues" + if visible_issue_count == 0 + else "1 issue" + if visible_issue_count == 1 + else "%s issues" % visible_issue_count, len(all_issues) - visible_issue_count, ) @@ -491,6 +553,7 @@ class BaseCommand: migrations in the database. """ from django.db.migrations.executor import MigrationExecutor + try: executor = MigrationExecutor(connections[DEFAULT_DB_ALIAS]) except ImproperlyConfigured: @@ -499,25 +562,32 @@ class BaseCommand: plan = executor.migration_plan(executor.loader.graph.leaf_nodes()) if plan: - apps_waiting_migration = sorted({migration.app_label for migration, backwards in plan}) + apps_waiting_migration = sorted( + {migration.app_label for migration, backwards in plan} + ) self.stdout.write( self.style.NOTICE( "\nYou have %(unapplied_migration_count)s unapplied migration(s). " "Your project may not work properly until you apply the " - "migrations for app(s): %(apps_waiting_migration)s." % { + "migrations for app(s): %(apps_waiting_migration)s." + % { "unapplied_migration_count": len(plan), "apps_waiting_migration": ", ".join(apps_waiting_migration), } ) ) - self.stdout.write(self.style.NOTICE("Run 'python manage.py migrate' to apply them.")) + self.stdout.write( + self.style.NOTICE("Run 'python manage.py migrate' to apply them.") + ) def handle(self, *args, **options): """ The actual logic of the command. Subclasses must implement this method. """ - raise NotImplementedError('subclasses of BaseCommand must provide a handle() method') + raise NotImplementedError( + "subclasses of BaseCommand must provide a handle() method" + ) class AppCommand(BaseCommand): @@ -528,23 +598,32 @@ class AppCommand(BaseCommand): Rather than implementing ``handle()``, subclasses must implement ``handle_app_config()``, which will be called once for each application. """ + missing_args_message = "Enter at least one application label." 
def add_arguments(self, parser): - parser.add_argument('args', metavar='app_label', nargs='+', help='One or more application label.') + parser.add_argument( + "args", + metavar="app_label", + nargs="+", + help="One or more application label.", + ) def handle(self, *app_labels, **options): from django.apps import apps + try: app_configs = [apps.get_app_config(app_label) for app_label in app_labels] except (LookupError, ImportError) as e: - raise CommandError("%s. Are you sure your INSTALLED_APPS setting is correct?" % e) + raise CommandError( + "%s. Are you sure your INSTALLED_APPS setting is correct?" % e + ) output = [] for app_config in app_configs: app_output = self.handle_app_config(app_config, **options) if app_output: output.append(app_output) - return '\n'.join(output) + return "\n".join(output) def handle_app_config(self, app_config, **options): """ @@ -568,11 +647,12 @@ class LabelCommand(BaseCommand): If the arguments should be names of installed applications, use ``AppCommand`` instead. """ - label = 'label' + + label = "label" missing_args_message = "Enter at least one %s." % label def add_arguments(self, parser): - parser.add_argument('args', metavar=self.label, nargs='+') + parser.add_argument("args", metavar=self.label, nargs="+") def handle(self, *labels, **options): output = [] @@ -580,11 +660,13 @@ class LabelCommand(BaseCommand): label_output = self.handle_label(label, **options) if label_output: output.append(label_output) - return '\n'.join(output) + return "\n".join(output) def handle_label(self, label, **options): """ Perform the command's actions for ``label``, which will be the string as given on the command line. 
""" - raise NotImplementedError('subclasses of LabelCommand must provide a handle_label() method') + raise NotImplementedError( + "subclasses of LabelCommand must provide a handle_label() method" + ) diff --git a/django/core/management/color.py b/django/core/management/color.py index be8c31bb95..d2255d2282 100644 --- a/django/core/management/color.py +++ b/django/core/management/color.py @@ -10,6 +10,7 @@ from django.utils import termcolors try: import colorama + colorama.init() except (ImportError, OSError): HAS_COLORAMA = False @@ -22,6 +23,7 @@ def supports_color(): Return True if the running system's terminal supports color, and False otherwise. """ + def vt_codes_enabled_in_windows_registry(): """ Check the Windows Registry to see if VT code handling has been enabled @@ -33,26 +35,28 @@ def supports_color(): except ImportError: return False else: - reg_key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, 'Console') + reg_key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, "Console") try: - reg_key_value, _ = winreg.QueryValueEx(reg_key, 'VirtualTerminalLevel') + reg_key_value, _ = winreg.QueryValueEx(reg_key, "VirtualTerminalLevel") except FileNotFoundError: return False else: return reg_key_value == 1 # isatty is not always implemented, #6223. - is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() + is_a_tty = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() return is_a_tty and ( - sys.platform != 'win32' or - HAS_COLORAMA or - 'ANSICON' in os.environ or + sys.platform != "win32" + or HAS_COLORAMA + or "ANSICON" in os.environ + or # Windows Terminal supports VT codes. - 'WT_SESSION' in os.environ or + "WT_SESSION" in os.environ + or # Microsoft Visual Studio Code's built-in terminal supports colors. 
- os.environ.get('TERM_PROGRAM') == 'vscode' or - vt_codes_enabled_in_windows_registry() + os.environ.get("TERM_PROGRAM") == "vscode" + or vt_codes_enabled_in_windows_registry() ) @@ -60,7 +64,7 @@ class Style: pass -def make_style(config_string=''): +def make_style(config_string=""): """ Create a Style object from the given config_string. @@ -79,8 +83,10 @@ def make_style(config_string=''): format = color_settings.get(role, {}) style_func = termcolors.make_style(**format) else: + def style_func(x): return x + setattr(style, role, style_func) # For backwards compatibility, @@ -95,7 +101,7 @@ def no_style(): """ Return a Style object with no color scheme. """ - return make_style('nocolor') + return make_style("nocolor") def color_style(force_color=False): @@ -104,4 +110,4 @@ def color_style(force_color=False): """ if not force_color and not supports_color(): return no_style() - return make_style(os.environ.get('DJANGO_COLORS', '')) + return make_style(os.environ.get("DJANGO_COLORS", "")) diff --git a/django/core/management/commands/check.py b/django/core/management/commands/check.py index a92563641f..7624b85390 100644 --- a/django/core/management/commands/check.py +++ b/django/core/management/commands/check.py @@ -10,37 +10,46 @@ class Command(BaseCommand): requires_system_checks = [] def add_arguments(self, parser): - parser.add_argument('args', metavar='app_label', nargs='*') + parser.add_argument("args", metavar="app_label", nargs="*") parser.add_argument( - '--tag', '-t', action='append', dest='tags', - help='Run only checks labeled with given tag.', + "--tag", + "-t", + action="append", + dest="tags", + help="Run only checks labeled with given tag.", ) parser.add_argument( - '--list-tags', action='store_true', - help='List available tags.', + "--list-tags", + action="store_true", + help="List available tags.", ) parser.add_argument( - '--deploy', action='store_true', - help='Check deployment settings.', + "--deploy", + action="store_true", + help="Check 
deployment settings.", ) parser.add_argument( - '--fail-level', - default='ERROR', - choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], + "--fail-level", + default="ERROR", + choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], help=( - 'Message level that will cause the command to exit with a ' - 'non-zero status. Default is ERROR.' + "Message level that will cause the command to exit with a " + "non-zero status. Default is ERROR." ), ) parser.add_argument( - '--database', action='append', dest='databases', - help='Run database related checks against these aliases.', + "--database", + action="append", + dest="databases", + help="Run database related checks against these aliases.", ) def handle(self, *app_labels, **options): - include_deployment_checks = options['deploy'] - if options['list_tags']: - self.stdout.write('\n'.join(sorted(registry.tags_available(include_deployment_checks)))) + include_deployment_checks = options["deploy"] + if options["list_tags"]: + self.stdout.write( + "\n".join(sorted(registry.tags_available(include_deployment_checks))) + ) return if app_labels: @@ -48,23 +57,27 @@ class Command(BaseCommand): else: app_configs = None - tags = options['tags'] + tags = options["tags"] if tags: try: invalid_tag = next( - tag for tag in tags if not checks.tag_exists(tag, include_deployment_checks) + tag + for tag in tags + if not checks.tag_exists(tag, include_deployment_checks) ) except StopIteration: # no invalid tags pass else: - raise CommandError('There is no system check with the "%s" tag.' % invalid_tag) + raise CommandError( + 'There is no system check with the "%s" tag.' 
% invalid_tag + ) self.check( app_configs=app_configs, tags=tags, display_num_errors=True, include_deployment_checks=include_deployment_checks, - fail_level=getattr(checks, options['fail_level']), - databases=options['databases'], + fail_level=getattr(checks, options["fail_level"]), + databases=options["databases"], ) diff --git a/django/core/management/commands/compilemessages.py b/django/core/management/commands/compilemessages.py index 308fa8831b..bd055d087f 100644 --- a/django/core/management/commands/compilemessages.py +++ b/django/core/management/commands/compilemessages.py @@ -5,22 +5,22 @@ import os from pathlib import Path from django.core.management.base import BaseCommand, CommandError -from django.core.management.utils import ( - find_command, is_ignored_path, popen_wrapper, -) +from django.core.management.utils import find_command, is_ignored_path, popen_wrapper def has_bom(fn): - with fn.open('rb') as f: + with fn.open("rb") as f: sample = f.read(4) - return sample.startswith((codecs.BOM_UTF8, codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)) + return sample.startswith( + (codecs.BOM_UTF8, codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE) + ) def is_writable(path): # Known side effect: updating file access/modified time to current time if # it is writable. try: - with open(path, 'a'): + with open(path, "a"): os.utime(path, None) except OSError: return False @@ -28,71 +28,91 @@ def is_writable(path): class Command(BaseCommand): - help = 'Compiles .po files to .mo files for use with builtin gettext support.' + help = "Compiles .po files to .mo files for use with builtin gettext support." requires_system_checks = [] - program = 'msgfmt' - program_options = ['--check-format'] + program = "msgfmt" + program_options = ["--check-format"] def add_arguments(self, parser): parser.add_argument( - '--locale', '-l', action='append', default=[], - help='Locale(s) to process (e.g. de_AT). Default is to process all. 
' - 'Can be used multiple times.', + "--locale", + "-l", + action="append", + default=[], + help="Locale(s) to process (e.g. de_AT). Default is to process all. " + "Can be used multiple times.", ) parser.add_argument( - '--exclude', '-x', action='append', default=[], - help='Locales to exclude. Default is none. Can be used multiple times.', + "--exclude", + "-x", + action="append", + default=[], + help="Locales to exclude. Default is none. Can be used multiple times.", ) parser.add_argument( - '--use-fuzzy', '-f', dest='fuzzy', action='store_true', - help='Use fuzzy translations.', + "--use-fuzzy", + "-f", + dest="fuzzy", + action="store_true", + help="Use fuzzy translations.", ) parser.add_argument( - '--ignore', '-i', action='append', dest='ignore_patterns', - default=[], metavar='PATTERN', - help='Ignore directories matching this glob-style pattern. ' - 'Use multiple times to ignore more.', + "--ignore", + "-i", + action="append", + dest="ignore_patterns", + default=[], + metavar="PATTERN", + help="Ignore directories matching this glob-style pattern. " + "Use multiple times to ignore more.", ) def handle(self, **options): - locale = options['locale'] - exclude = options['exclude'] - ignore_patterns = set(options['ignore_patterns']) - self.verbosity = options['verbosity'] - if options['fuzzy']: - self.program_options = self.program_options + ['-f'] + locale = options["locale"] + exclude = options["exclude"] + ignore_patterns = set(options["ignore_patterns"]) + self.verbosity = options["verbosity"] + if options["fuzzy"]: + self.program_options = self.program_options + ["-f"] if find_command(self.program) is None: - raise CommandError("Can't find %s. Make sure you have GNU gettext " - "tools 0.15 or newer installed." % self.program) + raise CommandError( + "Can't find %s. Make sure you have GNU gettext " + "tools 0.15 or newer installed." 
% self.program + ) - basedirs = [os.path.join('conf', 'locale'), 'locale'] - if os.environ.get('DJANGO_SETTINGS_MODULE'): + basedirs = [os.path.join("conf", "locale"), "locale"] + if os.environ.get("DJANGO_SETTINGS_MODULE"): from django.conf import settings + basedirs.extend(settings.LOCALE_PATHS) # Walk entire tree, looking for locale directories - for dirpath, dirnames, filenames in os.walk('.', topdown=True): + for dirpath, dirnames, filenames in os.walk(".", topdown=True): for dirname in dirnames: - if is_ignored_path(os.path.normpath(os.path.join(dirpath, dirname)), ignore_patterns): + if is_ignored_path( + os.path.normpath(os.path.join(dirpath, dirname)), ignore_patterns + ): dirnames.remove(dirname) - elif dirname == 'locale': + elif dirname == "locale": basedirs.append(os.path.join(dirpath, dirname)) # Gather existing directories. basedirs = set(map(os.path.abspath, filter(os.path.isdir, basedirs))) if not basedirs: - raise CommandError("This script should be run from the Django Git " - "checkout or your project or app tree, or with " - "the settings module specified.") + raise CommandError( + "This script should be run from the Django Git " + "checkout or your project or app tree, or with " + "the settings module specified." 
+ ) # Build locale list all_locales = [] for basedir in basedirs: - locale_dirs = filter(os.path.isdir, glob.glob('%s/*' % basedir)) + locale_dirs = filter(os.path.isdir, glob.glob("%s/*" % basedir)) all_locales.extend(map(os.path.basename, locale_dirs)) # Account for excluded locales @@ -102,18 +122,22 @@ class Command(BaseCommand): self.has_errors = False for basedir in basedirs: if locales: - dirs = [os.path.join(basedir, locale, 'LC_MESSAGES') for locale in locales] + dirs = [ + os.path.join(basedir, locale, "LC_MESSAGES") for locale in locales + ] else: dirs = [basedir] locations = [] for ldir in dirs: for dirpath, dirnames, filenames in os.walk(ldir): - locations.extend((dirpath, f) for f in filenames if f.endswith('.po')) + locations.extend( + (dirpath, f) for f in filenames if f.endswith(".po") + ) if locations: self.compile_messages(locations) if self.has_errors: - raise CommandError('compilemessages generated one or more errors.') + raise CommandError("compilemessages generated one or more errors.") def compile_messages(self, locations): """ @@ -123,24 +147,25 @@ class Command(BaseCommand): futures = [] for i, (dirpath, f) in enumerate(locations): po_path = Path(dirpath) / f - mo_path = po_path.with_suffix('.mo') + mo_path = po_path.with_suffix(".mo") try: if mo_path.stat().st_mtime >= po_path.stat().st_mtime: if self.verbosity > 0: self.stdout.write( - 'File “%s” is already compiled and up to date.' + "File “%s” is already compiled and up to date." % po_path ) continue except FileNotFoundError: pass if self.verbosity > 0: - self.stdout.write('processing file %s in %s' % (f, dirpath)) + self.stdout.write("processing file %s in %s" % (f, dirpath)) if has_bom(po_path): self.stderr.write( - 'The %s file has a BOM (Byte Order Mark). Django only ' - 'supports .po files encoded in UTF-8 and without any BOM.' % po_path + "The %s file has a BOM (Byte Order Mark). Django only " + "supports .po files encoded in UTF-8 and without any BOM." 
+ % po_path ) self.has_errors = True continue @@ -148,13 +173,13 @@ class Command(BaseCommand): # Check writability on first location if i == 0 and not is_writable(mo_path): self.stderr.write( - 'The po files under %s are in a seemingly not writable location. ' - 'mo files will not be updated/created.' % dirpath + "The po files under %s are in a seemingly not writable location. " + "mo files will not be updated/created." % dirpath ) self.has_errors = True return - args = [self.program, *self.program_options, '-o', mo_path, po_path] + args = [self.program, *self.program_options, "-o", mo_path, po_path] futures.append(executor.submit(popen_wrapper, args)) for future in concurrent.futures.as_completed(futures): @@ -162,7 +187,9 @@ class Command(BaseCommand): if status: if self.verbosity > 0: if errors: - self.stderr.write("Execution of %s failed: %s" % (self.program, errors)) + self.stderr.write( + "Execution of %s failed: %s" % (self.program, errors) + ) else: self.stderr.write("Execution of %s failed" % self.program) self.has_errors = True diff --git a/django/core/management/commands/createcachetable.py b/django/core/management/commands/createcachetable.py index 84f61049cd..99dc3da040 100644 --- a/django/core/management/commands/createcachetable.py +++ b/django/core/management/commands/createcachetable.py @@ -3,7 +3,12 @@ from django.core.cache import caches from django.core.cache.backends.db import BaseDatabaseCache from django.core.management.base import BaseCommand, CommandError from django.db import ( - DEFAULT_DB_ALIAS, DatabaseError, connections, models, router, transaction, + DEFAULT_DB_ALIAS, + DatabaseError, + connections, + models, + router, + transaction, ) @@ -14,24 +19,27 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'args', metavar='table_name', nargs='*', - help='Optional table names. 
Otherwise, settings.CACHES is used to find cache tables.', + "args", + metavar="table_name", + nargs="*", + help="Optional table names. Otherwise, settings.CACHES is used to find cache tables.", ) parser.add_argument( - '--database', + "--database", default=DEFAULT_DB_ALIAS, - help='Nominates a database onto which the cache tables will be ' - 'installed. Defaults to the "default" database.', + help="Nominates a database onto which the cache tables will be " + 'installed. Defaults to the "default" database.', ) parser.add_argument( - '--dry-run', action='store_true', - help='Does not create the table, just prints the SQL that would be run.', + "--dry-run", + action="store_true", + help="Does not create the table, just prints the SQL that would be run.", ) def handle(self, *tablenames, **options): - db = options['database'] - self.verbosity = options['verbosity'] - dry_run = options['dry_run'] + db = options["database"] + self.verbosity = options["verbosity"] + dry_run = options["dry_run"] if tablenames: # Legacy behavior, tablename specified as argument for tablename in tablenames: @@ -55,9 +63,11 @@ class Command(BaseCommand): fields = ( # "key" is a reserved word in MySQL, so use "cache_key" instead. 
- models.CharField(name='cache_key', max_length=255, unique=True, primary_key=True), - models.TextField(name='value'), - models.DateTimeField(name='expires', db_index=True), + models.CharField( + name="cache_key", max_length=255, unique=True, primary_key=True + ), + models.TextField(name="value"), + models.DateTimeField(name="expires", db_index=True), ) table_output = [] index_output = [] @@ -66,7 +76,7 @@ class Command(BaseCommand): field_output = [ qn(f.name), f.db_type(connection=connection), - '%sNULL' % ('NOT ' if not f.null else ''), + "%sNULL" % ("NOT " if not f.null else ""), ] if f.primary_key: field_output.append("PRIMARY KEY") @@ -75,14 +85,21 @@ class Command(BaseCommand): if f.db_index: unique = "UNIQUE " if f.unique else "" index_output.append( - "CREATE %sINDEX %s ON %s (%s);" % - (unique, qn('%s_%s' % (tablename, f.name)), qn(tablename), qn(f.name)) + "CREATE %sINDEX %s ON %s (%s);" + % ( + unique, + qn("%s_%s" % (tablename, f.name)), + qn(tablename), + qn(f.name), + ) ) table_output.append(" ".join(field_output)) full_statement = ["CREATE TABLE %s (" % qn(tablename)] for i, line in enumerate(table_output): - full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else '')) - full_statement.append(');') + full_statement.append( + " %s%s" % (line, "," if i < len(table_output) - 1 else "") + ) + full_statement.append(");") full_statement = "\n".join(full_statement) @@ -92,14 +109,17 @@ class Command(BaseCommand): self.stdout.write(statement) return - with transaction.atomic(using=database, savepoint=connection.features.can_rollback_ddl): + with transaction.atomic( + using=database, savepoint=connection.features.can_rollback_ddl + ): with connection.cursor() as curs: try: curs.execute(full_statement) except DatabaseError as e: raise CommandError( - "Cache table '%s' could not be created.\nThe error was: %s." % - (tablename, e)) + "Cache table '%s' could not be created.\nThe error was: %s." 
+ % (tablename, e) + ) for statement in index_output: curs.execute(statement) diff --git a/django/core/management/commands/dbshell.py b/django/core/management/commands/dbshell.py index cd94787f3d..9cdd64f190 100644 --- a/django/core/management/commands/dbshell.py +++ b/django/core/management/commands/dbshell.py @@ -14,29 +14,31 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database onto which to open a shell. Defaults to the "default" database.', ) - parameters = parser.add_argument_group('parameters', prefix_chars='--') - parameters.add_argument('parameters', nargs='*') + parameters = parser.add_argument_group("parameters", prefix_chars="--") + parameters.add_argument("parameters", nargs="*") def handle(self, **options): - connection = connections[options['database']] + connection = connections[options["database"]] try: - connection.client.runshell(options['parameters']) + connection.client.runshell(options["parameters"]) except FileNotFoundError: # Note that we're assuming the FileNotFoundError relates to the # command missing. It could be raised for some other reason, in # which case this error message would be inaccurate. Still, this # message catches the common case. raise CommandError( - 'You appear not to have the %r program installed or on your path.' % - connection.client.executable_name + "You appear not to have the %r program installed or on your path." + % connection.client.executable_name ) except subprocess.CalledProcessError as e: raise CommandError( - '"%s" returned non-zero exit status %s.' % ( - ' '.join(e.cmd), + '"%s" returned non-zero exit status %s.' 
+ % ( + " ".join(e.cmd), e.returncode, ), returncode=e.returncode, diff --git a/django/core/management/commands/diffsettings.py b/django/core/management/commands/diffsettings.py index 5adf35eb66..27cd575294 100644 --- a/django/core/management/commands/diffsettings.py +++ b/django/core/management/commands/diffsettings.py @@ -1,7 +1,7 @@ from django.core.management.base import BaseCommand -def module_to_dict(module, omittable=lambda k: k.startswith('_') or not k.isupper()): +def module_to_dict(module, omittable=lambda k: k.startswith("_") or not k.isupper()): """Convert a module namespace to a Python dictionary.""" return {k: repr(getattr(module, k)) for k in dir(module) if not omittable(k)} @@ -14,21 +14,25 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - '--all', action='store_true', + "--all", + action="store_true", help=( 'Display all settings, regardless of their value. In "hash" ' 'mode, default values are prefixed by "###".' ), ) parser.add_argument( - '--default', metavar='MODULE', + "--default", + metavar="MODULE", help=( "The settings module to compare the current settings against. Leave empty to " "compare against Django's default settings." ), ) parser.add_argument( - '--output', default='hash', choices=('hash', 'unified'), + "--output", + default="hash", + choices=("hash", "unified"), help=( "Selects the output format. 
'hash' mode displays each changed " "setting, with the settings that don't appear in the defaults " @@ -46,13 +50,15 @@ class Command(BaseCommand): settings._setup() user_settings = module_to_dict(settings._wrapped) - default = options['default'] - default_settings = module_to_dict(Settings(default) if default else global_settings) + default = options["default"] + default_settings = module_to_dict( + Settings(default) if default else global_settings + ) output_func = { - 'hash': self.output_hash, - 'unified': self.output_unified, - }[options['output']] - return '\n'.join(output_func(user_settings, default_settings, **options)) + "hash": self.output_hash, + "unified": self.output_unified, + }[options["output"]] + return "\n".join(output_func(user_settings, default_settings, **options)) def output_hash(self, user_settings, default_settings, **options): # Inspired by Postfix's "postconf -n". @@ -62,7 +68,7 @@ class Command(BaseCommand): output.append("%s = %s ###" % (key, user_settings[key])) elif user_settings[key] != default_settings[key]: output.append("%s = %s" % (key, user_settings[key])) - elif options['all']: + elif options["all"]: output.append("### %s = %s" % (key, user_settings[key])) return output @@ -70,10 +76,16 @@ class Command(BaseCommand): output = [] for key in sorted(user_settings): if key not in default_settings: - output.append(self.style.SUCCESS("+ %s = %s" % (key, user_settings[key]))) + output.append( + self.style.SUCCESS("+ %s = %s" % (key, user_settings[key])) + ) elif user_settings[key] != default_settings[key]: - output.append(self.style.ERROR("- %s = %s" % (key, default_settings[key]))) - output.append(self.style.SUCCESS("+ %s = %s" % (key, user_settings[key]))) - elif options['all']: + output.append( + self.style.ERROR("- %s = %s" % (key, default_settings[key])) + ) + output.append( + self.style.SUCCESS("+ %s = %s" % (key, user_settings[key])) + ) + elif options["all"]: output.append(" %s = %s" % (key, user_settings[key])) return output 
diff --git a/django/core/management/commands/dumpdata.py b/django/core/management/commands/dumpdata.py index 925a23a56d..74a5b2d22a 100644 --- a/django/core/management/commands/dumpdata.py +++ b/django/core/management/commands/dumpdata.py @@ -10,12 +10,14 @@ from django.db import DEFAULT_DB_ALIAS, router try: import bz2 + has_bz2 = True except ImportError: has_bz2 = False try: import lzma + has_lzma = True except ImportError: has_lzma = False @@ -33,65 +35,79 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'args', metavar='app_label[.ModelName]', nargs='*', - help='Restricts dumped data to the specified app_label or app_label.ModelName.', + "args", + metavar="app_label[.ModelName]", + nargs="*", + help="Restricts dumped data to the specified app_label or app_label.ModelName.", ) parser.add_argument( - '--format', default='json', - help='Specifies the output serialization format for fixtures.', + "--format", + default="json", + help="Specifies the output serialization format for fixtures.", ) parser.add_argument( - '--indent', type=int, - help='Specifies the indent level to use when pretty-printing output.', + "--indent", + type=int, + help="Specifies the indent level to use when pretty-printing output.", ) parser.add_argument( - '--database', + "--database", default=DEFAULT_DB_ALIAS, - help='Nominates a specific database to dump fixtures from. ' - 'Defaults to the "default" database.', + help="Nominates a specific database to dump fixtures from. 
" + 'Defaults to the "default" database.', ) parser.add_argument( - '-e', '--exclude', action='append', default=[], - help='An app_label or app_label.ModelName to exclude ' - '(use multiple --exclude to exclude multiple apps/models).', + "-e", + "--exclude", + action="append", + default=[], + help="An app_label or app_label.ModelName to exclude " + "(use multiple --exclude to exclude multiple apps/models).", ) parser.add_argument( - '--natural-foreign', action='store_true', dest='use_natural_foreign_keys', - help='Use natural foreign keys if they are available.', + "--natural-foreign", + action="store_true", + dest="use_natural_foreign_keys", + help="Use natural foreign keys if they are available.", ) parser.add_argument( - '--natural-primary', action='store_true', dest='use_natural_primary_keys', - help='Use natural primary keys if they are available.', + "--natural-primary", + action="store_true", + dest="use_natural_primary_keys", + help="Use natural primary keys if they are available.", ) parser.add_argument( - '-a', '--all', action='store_true', dest='use_base_manager', + "-a", + "--all", + action="store_true", + dest="use_base_manager", help="Use Django's base manager to dump all models stored in the database, " - "including those that would otherwise be filtered or modified by a custom manager.", + "including those that would otherwise be filtered or modified by a custom manager.", ) parser.add_argument( - '--pks', dest='primary_keys', + "--pks", + dest="primary_keys", help="Only dump objects with given primary keys. Accepts a comma-separated " - "list of keys. This option only works when you specify one model.", + "list of keys. This option only works when you specify one model.", ) parser.add_argument( - '-o', '--output', - help='Specifies file to which the output is written.' + "-o", "--output", help="Specifies file to which the output is written." 
) def handle(self, *app_labels, **options): - format = options['format'] - indent = options['indent'] - using = options['database'] - excludes = options['exclude'] - output = options['output'] - show_traceback = options['traceback'] - use_natural_foreign_keys = options['use_natural_foreign_keys'] - use_natural_primary_keys = options['use_natural_primary_keys'] - use_base_manager = options['use_base_manager'] - pks = options['primary_keys'] + format = options["format"] + indent = options["indent"] + using = options["database"] + excludes = options["exclude"] + output = options["output"] + show_traceback = options["traceback"] + use_natural_foreign_keys = options["use_natural_foreign_keys"] + use_natural_primary_keys = options["use_natural_primary_keys"] + use_base_manager = options["use_base_manager"] + pks = options["primary_keys"] if pks: - primary_keys = [pk.strip() for pk in pks.split(',')] + primary_keys = [pk.strip() for pk in pks.split(",")] else: primary_keys = [] @@ -101,8 +117,10 @@ class Command(BaseCommand): if primary_keys: raise CommandError("You can only use --pks option with one model") app_list = dict.fromkeys( - app_config for app_config in apps.get_app_configs() - if app_config.models_module is not None and app_config not in excluded_apps + app_config + for app_config in apps.get_app_configs() + if app_config.models_module is not None + and app_config not in excluded_apps ) else: if len(app_labels) > 1 and primary_keys: @@ -110,7 +128,7 @@ class Command(BaseCommand): app_list = {} for label in app_labels: try: - app_label, model_label = label.split('.') + app_label, model_label = label.split(".") try: app_config = apps.get_app_config(app_label) except LookupError as e: @@ -120,7 +138,9 @@ class Command(BaseCommand): try: model = app_config.get_model(model_label) except LookupError: - raise CommandError("Unknown model: %s.%s" % (app_label, model_label)) + raise CommandError( + "Unknown model: %s.%s" % (app_label, model_label) + ) app_list_value = 
app_list.setdefault(app_config, []) @@ -131,7 +151,9 @@ class Command(BaseCommand): app_list_value.append(model) except ValueError: if primary_keys: - raise CommandError("You can only use --pks option with one model") + raise CommandError( + "You can only use --pks option with one model" + ) # This is just an app - no model qualifier app_label = label try: @@ -158,7 +180,9 @@ class Command(BaseCommand): count the number of objects to be serialized. """ if use_natural_foreign_keys: - models = serializers.sort_dependencies(app_list.items(), allow_cycles=True) + models = serializers.sort_dependencies( + app_list.items(), allow_cycles=True + ) else: # There is no need to sort dependencies when natural foreign # keys are not used. @@ -173,7 +197,8 @@ class Command(BaseCommand): continue if model._meta.proxy and model._meta.proxy_for_model not in models: warnings.warn( - "%s is a proxy model and won't be serialized." % model._meta.label, + "%s is a proxy model and won't be serialized." + % model._meta.label, category=ProxyModelWarning, ) if not model._meta.proxy and router.allow_migrate_model(using, model): @@ -195,25 +220,27 @@ class Command(BaseCommand): progress_output = None object_count = 0 # If dumpdata is outputting to stdout, there is no way to display progress - if output and self.stdout.isatty() and options['verbosity'] > 0: + if output and self.stdout.isatty() and options["verbosity"] > 0: progress_output = self.stdout object_count = sum(get_objects(count_only=True)) if output: file_root, file_ext = os.path.splitext(output) compression_formats = { - '.bz2': (open, {}, file_root), - '.gz': (gzip.open, {}, output), - '.lzma': (open, {}, file_root), - '.xz': (open, {}, file_root), - '.zip': (open, {}, file_root), + ".bz2": (open, {}, file_root), + ".gz": (gzip.open, {}, output), + ".lzma": (open, {}, file_root), + ".xz": (open, {}, file_root), + ".zip": (open, {}, file_root), } if has_bz2: - compression_formats['.bz2'] = (bz2.open, {}, output) + 
compression_formats[".bz2"] = (bz2.open, {}, output) if has_lzma: - compression_formats['.lzma'] = ( - lzma.open, {'format': lzma.FORMAT_ALONE}, output + compression_formats[".lzma"] = ( + lzma.open, + {"format": lzma.FORMAT_ALONE}, + output, ) - compression_formats['.xz'] = (lzma.open, {}, output) + compression_formats[".xz"] = (lzma.open, {}, output) try: open_method, kwargs, file_path = compression_formats[file_ext] except KeyError: @@ -225,15 +252,18 @@ class Command(BaseCommand): f"Fixtures saved in '{file_name}'.", RuntimeWarning, ) - stream = open_method(file_path, 'wt', **kwargs) + stream = open_method(file_path, "wt", **kwargs) else: stream = None try: serializers.serialize( - format, get_objects(), indent=indent, + format, + get_objects(), + indent=indent, use_natural_foreign_keys=use_natural_foreign_keys, use_natural_primary_keys=use_natural_primary_keys, - stream=stream or self.stdout, progress_output=progress_output, + stream=stream or self.stdout, + progress_output=progress_output, object_count=object_count, ) finally: diff --git a/django/core/management/commands/flush.py b/django/core/management/commands/flush.py index 6737b9be40..8bad63d41d 100644 --- a/django/core/management/commands/flush.py +++ b/django/core/management/commands/flush.py @@ -9,30 +9,34 @@ from django.db import DEFAULT_DB_ALIAS, connections class Command(BaseCommand): help = ( - 'Removes ALL DATA from the database, including data added during ' + "Removes ALL DATA from the database, including data added during " 'migrations. Does not achieve a "fresh install" state.' 
) - stealth_options = ('reset_sequences', 'allow_cascade', 'inhibit_post_migrate') + stealth_options = ("reset_sequences", "allow_cascade", "inhibit_post_migrate") def add_arguments(self, parser): parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database to flush. Defaults to the "default" database.', ) def handle(self, **options): - database = options['database'] + database = options["database"] connection = connections[database] - verbosity = options['verbosity'] - interactive = options['interactive'] + verbosity = options["verbosity"] + interactive = options["interactive"] # The following are stealth options used by Django's internals. - reset_sequences = options.get('reset_sequences', True) - allow_cascade = options.get('allow_cascade', False) - inhibit_post_migrate = options.get('inhibit_post_migrate', False) + reset_sequences = options.get("reset_sequences", True) + allow_cascade = options.get("allow_cascade", False) + inhibit_post_migrate = options.get("inhibit_post_migrate", False) self.style = no_style() @@ -40,25 +44,31 @@ class Command(BaseCommand): # dispatcher events. for app_config in apps.get_app_configs(): try: - import_module('.management', app_config.name) + import_module(".management", app_config.name) except ImportError: pass - sql_list = sql_flush(self.style, connection, - reset_sequences=reset_sequences, - allow_cascade=allow_cascade) + sql_list = sql_flush( + self.style, + connection, + reset_sequences=reset_sequences, + allow_cascade=allow_cascade, + ) if interactive: - confirm = input("""You have requested a flush of the database. 
+ confirm = input( + """You have requested a flush of the database. This will IRREVERSIBLY DESTROY all data currently in the "%s" database, and return each table to an empty state. Are you sure you want to do this? - Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['NAME']) + Type 'yes' to continue, or 'no' to cancel: """ + % connection.settings_dict["NAME"] + ) else: - confirm = 'yes' + confirm = "yes" - if confirm == 'yes': + if confirm == "yes": try: connection.ops.execute_sql_flush(sql_list) except Exception as exc: @@ -68,9 +78,8 @@ Are you sure you want to do this? " * At least one of the expected database tables doesn't exist.\n" " * The SQL was invalid.\n" "Hint: Look at the output of 'django-admin sqlflush'. " - "That's the SQL this command wasn't able to run." % ( - connection.settings_dict['NAME'], - ) + "That's the SQL this command wasn't able to run." + % (connection.settings_dict["NAME"],) ) from exc # Empty sql_list may signify an empty database and post_migrate would then crash @@ -79,4 +88,4 @@ Are you sure you want to do this? # respond as if the database had been migrated from scratch. emit_post_migrate_signal(verbosity, interactive, database) else: - self.stdout.write('Flush cancelled.') + self.stdout.write("Flush cancelled.") diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index d64725ca73..753290d574 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -9,23 +9,30 @@ from django.db.models.constants import LOOKUP_SEP class Command(BaseCommand): help = "Introspects the database tables in the given database and outputs a Django model module." 
requires_system_checks = [] - stealth_options = ('table_name_filter',) - db_module = 'django.db' + stealth_options = ("table_name_filter",) + db_module = "django.db" def add_arguments(self, parser): parser.add_argument( - 'table', nargs='*', type=str, - help='Selects what tables or views should be introspected.', + "table", + nargs="*", + type=str, + help="Selects what tables or views should be introspected.", ) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database to introspect. Defaults to using the "default" database.', ) parser.add_argument( - '--include-partitions', action='store_true', help='Also output models for partition tables.', + "--include-partitions", + action="store_true", + help="Also output models for partition tables.", ) parser.add_argument( - '--include-views', action='store_true', help='Also output models for database views.', + "--include-views", + action="store_true", + help="Also output models for database views.", ) def handle(self, **options): @@ -33,15 +40,17 @@ class Command(BaseCommand): for line in self.handle_inspection(options): self.stdout.write(line) except NotImplementedError: - raise CommandError("Database inspection isn't supported for the currently selected database backend.") + raise CommandError( + "Database inspection isn't supported for the currently selected database backend." + ) def handle_inspection(self, options): - connection = connections[options['database']] + connection = connections[options["database"]] # 'table_name_filter' is a stealth option - table_name_filter = options.get('table_name_filter') + table_name_filter = options.get("table_name_filter") def table2model(table_name): - return re.sub(r'[^a-zA-Z0-9]', '', table_name.title()) + return re.sub(r"[^a-zA-Z0-9]", "", table_name.title()) with connection.cursor() as cursor: yield "# This is an auto-generated Django model module." 
@@ -54,55 +63,71 @@ class Command(BaseCommand): "Django to create, modify, and delete the table" ) yield "# Feel free to rename the models, but don't rename db_table values or field names." - yield 'from %s import models' % self.db_module + yield "from %s import models" % self.db_module known_models = [] table_info = connection.introspection.get_table_list(cursor) # Determine types of tables and/or views to be introspected. - types = {'t'} - if options['include_partitions']: - types.add('p') - if options['include_views']: - types.add('v') + types = {"t"} + if options["include_partitions"]: + types.add("p") + if options["include_views"]: + types.add("v") - for table_name in (options['table'] or sorted(info.name for info in table_info if info.type in types)): + for table_name in options["table"] or sorted( + info.name for info in table_info if info.type in types + ): if table_name_filter is not None and callable(table_name_filter): if not table_name_filter(table_name): continue try: try: - relations = connection.introspection.get_relations(cursor, table_name) + relations = connection.introspection.get_relations( + cursor, table_name + ) except NotImplementedError: relations = {} try: - constraints = connection.introspection.get_constraints(cursor, table_name) + constraints = connection.introspection.get_constraints( + cursor, table_name + ) except NotImplementedError: constraints = {} - primary_key_column = connection.introspection.get_primary_key_column(cursor, table_name) + primary_key_column = ( + connection.introspection.get_primary_key_column( + cursor, table_name + ) + ) unique_columns = [ - c['columns'][0] for c in constraints.values() - if c['unique'] and len(c['columns']) == 1 + c["columns"][0] + for c in constraints.values() + if c["unique"] and len(c["columns"]) == 1 ] - table_description = connection.introspection.get_table_description(cursor, table_name) + table_description = connection.introspection.get_table_description( + cursor, table_name + ) except 
Exception as e: yield "# Unable to inspect table '%s'" % table_name yield "# The error was: %s" % e continue - yield '' - yield '' - yield 'class %s(models.Model):' % table2model(table_name) + yield "" + yield "" + yield "class %s(models.Model):" % table2model(table_name) known_models.append(table2model(table_name)) used_column_names = [] # Holds column names used in the table so far column_to_field_name = {} # Maps column names to names of model fields for row in table_description: - comment_notes = [] # Holds Field notes, to be displayed in a Python comment. + comment_notes = ( + [] + ) # Holds Field notes, to be displayed in a Python comment. extra_params = {} # Holds Field parameters such as 'db_column'. column_name = row.name is_relation = column_name in relations att_name, params, notes = self.normalize_col_name( - column_name, used_column_names, is_relation) + column_name, used_column_names, is_relation + ) extra_params.update(params) comment_notes.extend(notes) @@ -111,70 +136,91 @@ class Command(BaseCommand): # Add primary_key and unique, if necessary. 
if column_name == primary_key_column: - extra_params['primary_key'] = True + extra_params["primary_key"] = True elif column_name in unique_columns: - extra_params['unique'] = True + extra_params["unique"] = True if is_relation: ref_db_column, ref_db_table = relations[column_name] - if extra_params.pop('unique', False) or extra_params.get('primary_key'): - rel_type = 'OneToOneField' + if extra_params.pop("unique", False) or extra_params.get( + "primary_key" + ): + rel_type = "OneToOneField" else: - rel_type = 'ForeignKey' - ref_pk_column = connection.introspection.get_primary_key_column(cursor, ref_db_table) + rel_type = "ForeignKey" + ref_pk_column = ( + connection.introspection.get_primary_key_column( + cursor, ref_db_table + ) + ) if ref_pk_column and ref_pk_column != ref_db_column: - extra_params['to_field'] = ref_db_column + extra_params["to_field"] = ref_db_column rel_to = ( - 'self' if ref_db_table == table_name + "self" + if ref_db_table == table_name else table2model(ref_db_table) ) if rel_to in known_models: - field_type = '%s(%s' % (rel_type, rel_to) + field_type = "%s(%s" % (rel_type, rel_to) else: field_type = "%s('%s'" % (rel_type, rel_to) else: # Calling `get_field_type` to get the field type string and any # additional parameters and notes. - field_type, field_params, field_notes = self.get_field_type(connection, table_name, row) + field_type, field_params, field_notes = self.get_field_type( + connection, table_name, row + ) extra_params.update(field_params) comment_notes.extend(field_notes) - field_type += '(' + field_type += "(" # Don't output 'id = meta.AutoField(primary_key=True)', because # that's assumed if it doesn't exist. 
- if att_name == 'id' and extra_params == {'primary_key': True}: - if field_type == 'AutoField(': + if att_name == "id" and extra_params == {"primary_key": True}: + if field_type == "AutoField(": continue - elif field_type == connection.features.introspected_field_types['AutoField'] + '(': - comment_notes.append('AutoField?') + elif ( + field_type + == connection.features.introspected_field_types["AutoField"] + + "(" + ): + comment_notes.append("AutoField?") # Add 'null' and 'blank', if the 'null_ok' flag was present in the # table description. if row.null_ok: # If it's NULL... - extra_params['blank'] = True - extra_params['null'] = True + extra_params["blank"] = True + extra_params["null"] = True - field_desc = '%s = %s%s' % ( + field_desc = "%s = %s%s" % ( att_name, # Custom fields will have a dotted path - '' if '.' in field_type else 'models.', + "" if "." in field_type else "models.", field_type, ) - if field_type.startswith(('ForeignKey(', 'OneToOneField(')): - field_desc += ', models.DO_NOTHING' + if field_type.startswith(("ForeignKey(", "OneToOneField(")): + field_desc += ", models.DO_NOTHING" if extra_params: - if not field_desc.endswith('('): - field_desc += ', ' - field_desc += ', '.join('%s=%r' % (k, v) for k, v in extra_params.items()) - field_desc += ')' + if not field_desc.endswith("("): + field_desc += ", " + field_desc += ", ".join( + "%s=%r" % (k, v) for k, v in extra_params.items() + ) + field_desc += ")" if comment_notes: - field_desc += ' # ' + ' '.join(comment_notes) - yield ' %s' % field_desc - is_view = any(info.name == table_name and info.type == 'v' for info in table_info) - is_partition = any(info.name == table_name and info.type == 'p' for info in table_info) - yield from self.get_meta(table_name, constraints, column_to_field_name, is_view, is_partition) + field_desc += " # " + " ".join(comment_notes) + yield " %s" % field_desc + is_view = any( + info.name == table_name and info.type == "v" for info in table_info + ) + is_partition = 
any( + info.name == table_name and info.type == "p" for info in table_info + ) + yield from self.get_meta( + table_name, constraints, column_to_field_name, is_view, is_partition + ) def normalize_col_name(self, col_name, used_column_names, is_relation): """ @@ -185,50 +231,54 @@ class Command(BaseCommand): new_name = col_name.lower() if new_name != col_name: - field_notes.append('Field name made lowercase.') + field_notes.append("Field name made lowercase.") if is_relation: - if new_name.endswith('_id'): + if new_name.endswith("_id"): new_name = new_name[:-3] else: - field_params['db_column'] = col_name + field_params["db_column"] = col_name - new_name, num_repl = re.subn(r'\W', '_', new_name) + new_name, num_repl = re.subn(r"\W", "_", new_name) if num_repl > 0: - field_notes.append('Field renamed to remove unsuitable characters.') + field_notes.append("Field renamed to remove unsuitable characters.") if new_name.find(LOOKUP_SEP) >= 0: while new_name.find(LOOKUP_SEP) >= 0: - new_name = new_name.replace(LOOKUP_SEP, '_') + new_name = new_name.replace(LOOKUP_SEP, "_") if col_name.lower().find(LOOKUP_SEP) >= 0: # Only add the comment if the double underscore was in the original name - field_notes.append("Field renamed because it contained more than one '_' in a row.") + field_notes.append( + "Field renamed because it contained more than one '_' in a row." 
+ ) - if new_name.startswith('_'): - new_name = 'field%s' % new_name + if new_name.startswith("_"): + new_name = "field%s" % new_name field_notes.append("Field renamed because it started with '_'.") - if new_name.endswith('_'): - new_name = '%sfield' % new_name + if new_name.endswith("_"): + new_name = "%sfield" % new_name field_notes.append("Field renamed because it ended with '_'.") if keyword.iskeyword(new_name): - new_name += '_field' - field_notes.append('Field renamed because it was a Python reserved word.') + new_name += "_field" + field_notes.append("Field renamed because it was a Python reserved word.") if new_name[0].isdigit(): - new_name = 'number_%s' % new_name - field_notes.append("Field renamed because it wasn't a valid Python identifier.") + new_name = "number_%s" % new_name + field_notes.append( + "Field renamed because it wasn't a valid Python identifier." + ) if new_name in used_column_names: num = 0 - while '%s_%d' % (new_name, num) in used_column_names: + while "%s_%d" % (new_name, num) in used_column_names: num += 1 - new_name = '%s_%d' % (new_name, num) - field_notes.append('Field renamed because of name conflict.') + new_name = "%s_%d" % (new_name, num) + field_notes.append("Field renamed because of name conflict.") if col_name != new_name and field_notes: - field_params['db_column'] = col_name + field_params["db_column"] = col_name return new_name, field_params, field_notes @@ -244,30 +294,37 @@ class Command(BaseCommand): try: field_type = connection.introspection.get_field_type(row.type_code, row) except KeyError: - field_type = 'TextField' - field_notes.append('This field type is a guess.') + field_type = "TextField" + field_notes.append("This field type is a guess.") # Add max_length for all CharFields. 
- if field_type == 'CharField' and row.internal_size: - field_params['max_length'] = int(row.internal_size) + if field_type == "CharField" and row.internal_size: + field_params["max_length"] = int(row.internal_size) - if field_type in {'CharField', 'TextField'} and row.collation: - field_params['db_collation'] = row.collation + if field_type in {"CharField", "TextField"} and row.collation: + field_params["db_collation"] = row.collation - if field_type == 'DecimalField': + if field_type == "DecimalField": if row.precision is None or row.scale is None: field_notes.append( - 'max_digits and decimal_places have been guessed, as this ' - 'database handles decimal fields as float') - field_params['max_digits'] = row.precision if row.precision is not None else 10 - field_params['decimal_places'] = row.scale if row.scale is not None else 5 + "max_digits and decimal_places have been guessed, as this " + "database handles decimal fields as float" + ) + field_params["max_digits"] = ( + row.precision if row.precision is not None else 10 + ) + field_params["decimal_places"] = ( + row.scale if row.scale is not None else 5 + ) else: - field_params['max_digits'] = row.precision - field_params['decimal_places'] = row.scale + field_params["max_digits"] = row.precision + field_params["decimal_places"] = row.scale return field_type, field_params, field_notes - def get_meta(self, table_name, constraints, column_to_field_name, is_view, is_partition): + def get_meta( + self, table_name, constraints, column_to_field_name, is_view, is_partition + ): """ Return a sequence comprising the lines of code necessary to construct the inner Meta class for the model corresponding @@ -276,28 +333,30 @@ class Command(BaseCommand): unique_together = [] has_unsupported_constraint = False for params in constraints.values(): - if params['unique']: - columns = params['columns'] + if params["unique"]: + columns = params["columns"] if None in columns: has_unsupported_constraint = True columns = [x for x in 
columns if x is not None] if len(columns) > 1: - unique_together.append(str(tuple(column_to_field_name[c] for c in columns))) + unique_together.append( + str(tuple(column_to_field_name[c] for c in columns)) + ) if is_view: managed_comment = " # Created from a view. Don't remove." elif is_partition: managed_comment = " # Created from a partition. Don't remove." else: - managed_comment = '' - meta = [''] + managed_comment = "" + meta = [""] if has_unsupported_constraint: - meta.append(' # A unique constraint could not be introspected.') + meta.append(" # A unique constraint could not be introspected.") meta += [ - ' class Meta:', - ' managed = False%s' % managed_comment, - ' db_table = %r' % table_name + " class Meta:", + " managed = False%s" % managed_comment, + " db_table = %r" % table_name, ] if unique_together: - tup = '(' + ', '.join(unique_together) + ',)' + tup = "(" + ", ".join(unique_together) + ",)" meta += [" unique_together = %s" % tup] return meta diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 20428f9f10..38a2818d5c 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -15,64 +15,82 @@ from django.core.management.base import BaseCommand, CommandError from django.core.management.color import no_style from django.core.management.utils import parse_apps_and_model_labels from django.db import ( - DEFAULT_DB_ALIAS, DatabaseError, IntegrityError, connections, router, + DEFAULT_DB_ALIAS, + DatabaseError, + IntegrityError, + connections, + router, transaction, ) from django.utils.functional import cached_property try: import bz2 + has_bz2 = True except ImportError: has_bz2 = False try: import lzma + has_lzma = True except ImportError: has_lzma = False -READ_STDIN = '-' +READ_STDIN = "-" class Command(BaseCommand): - help = 'Installs the named fixture(s) in the database.' + help = "Installs the named fixture(s) in the database." 
missing_args_message = ( "No database fixture specified. Please provide the path of at least " "one fixture in the command line." ) def add_arguments(self, parser): - parser.add_argument('args', metavar='fixture', nargs='+', help='Fixture labels.') parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "args", metavar="fixture", nargs="+", help="Fixture labels." + ) + parser.add_argument( + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a specific database to load fixtures into. Defaults to the "default" database.', ) parser.add_argument( - '--app', dest='app_label', - help='Only look for fixtures in the specified app.', + "--app", + dest="app_label", + help="Only look for fixtures in the specified app.", ) parser.add_argument( - '--ignorenonexistent', '-i', action='store_true', dest='ignore', - help='Ignores entries in the serialized data for fields that do not ' - 'currently exist on the model.', + "--ignorenonexistent", + "-i", + action="store_true", + dest="ignore", + help="Ignores entries in the serialized data for fields that do not " + "currently exist on the model.", ) parser.add_argument( - '-e', '--exclude', action='append', default=[], - help='An app_label or app_label.ModelName to exclude. Can be used multiple times.', + "-e", + "--exclude", + action="append", + default=[], + help="An app_label or app_label.ModelName to exclude. 
Can be used multiple times.", ) parser.add_argument( - '--format', - help='Format of serialized data when reading from stdin.', + "--format", + help="Format of serialized data when reading from stdin.", ) def handle(self, *fixture_labels, **options): - self.ignore = options['ignore'] - self.using = options['database'] - self.app_label = options['app_label'] - self.verbosity = options['verbosity'] - self.excluded_models, self.excluded_apps = parse_apps_and_model_labels(options['exclude']) - self.format = options['format'] + self.ignore = options["ignore"] + self.using = options["database"] + self.app_label = options["app_label"] + self.verbosity = options["verbosity"] + self.excluded_models, self.excluded_apps = parse_apps_and_model_labels( + options["exclude"] + ) + self.format = options["format"] with transaction.atomic(using=self.using): self.loaddata(fixture_labels) @@ -89,16 +107,16 @@ class Command(BaseCommand): """A dict mapping format names to (open function, mode arg) tuples.""" # Forcing binary mode may be revisited after dropping Python 2 support (see #22399) compression_formats = { - None: (open, 'rb'), - 'gz': (gzip.GzipFile, 'rb'), - 'zip': (SingleZipReader, 'r'), - 'stdin': (lambda *args: sys.stdin, None), + None: (open, "rb"), + "gz": (gzip.GzipFile, "rb"), + "zip": (SingleZipReader, "r"), + "stdin": (lambda *args: sys.stdin, None), } if has_bz2: - compression_formats['bz2'] = (bz2.BZ2File, 'r') + compression_formats["bz2"] = (bz2.BZ2File, "r") if has_lzma: - compression_formats['lzma'] = (lzma.LZMAFile, 'r') - compression_formats['xz'] = (lzma.LZMAFile, 'r') + compression_formats["lzma"] = (lzma.LZMAFile, "r") + compression_formats["xz"] = (lzma.LZMAFile, "r") return compression_formats def reset_sequences(self, connection, models): @@ -106,7 +124,7 @@ class Command(BaseCommand): sequence_sql = connection.ops.sequence_reset_sql(no_style(), models) if sequence_sql: if self.verbosity >= 2: - self.stdout.write('Resetting sequences') + 
self.stdout.write("Resetting sequences") with connection.cursor() as cursor: for line in sequence_sql: cursor.execute(line) @@ -162,14 +180,18 @@ class Command(BaseCommand): else: self.stdout.write( "Installed %d object(s) (of %d) from %d fixture(s)" - % (self.loaded_object_count, self.fixture_object_count, self.fixture_count) + % ( + self.loaded_object_count, + self.fixture_object_count, + self.fixture_count, + ) ) def save_obj(self, obj): """Save an object if permitted.""" if ( - obj.object._meta.app_config in self.excluded_apps or - type(obj.object) in self.excluded_models + obj.object._meta.app_config in self.excluded_apps + or type(obj.object) in self.excluded_models ): return False saved = False @@ -180,11 +202,14 @@ class Command(BaseCommand): obj.save(using=self.using) # psycopg2 raises ValueError if data contains NUL chars. except (DatabaseError, IntegrityError, ValueError) as e: - e.args = ('Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s' % { - 'object_label': obj.object._meta.label, - 'pk': obj.object.pk, - 'error_msg': e, - },) + e.args = ( + "Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s" + % { + "object_label": obj.object._meta.label, + "pk": obj.object.pk, + "error_msg": e, + }, + ) raise if obj.deferred_fields: self.objs_with_deferred_fields.append(obj) @@ -193,7 +218,9 @@ class Command(BaseCommand): def load_label(self, fixture_label): """Load fixtures files for a given label.""" show_progress = self.verbosity >= 3 - for fixture_file, fixture_dir, fixture_name in self.find_fixtures(fixture_label): + for fixture_file, fixture_dir, fixture_name in self.find_fixtures( + fixture_label + ): _, ser_fmt, cmp_fmt = self.parse_name(os.path.basename(fixture_file)) open_method, mode = self.compression_formats[cmp_fmt] fixture = open_method(fixture_file, mode) @@ -207,7 +234,10 @@ class Command(BaseCommand): ) try: objects = serializers.deserialize( - ser_fmt, fixture, using=self.using, ignorenonexistent=self.ignore, + ser_fmt, + 
fixture, + using=self.using, + ignorenonexistent=self.ignore, handle_forward_references=True, ) @@ -217,12 +247,14 @@ class Command(BaseCommand): loaded_objects_in_fixture += 1 if show_progress: self.stdout.write( - '\rProcessed %i object(s).' % loaded_objects_in_fixture, - ending='' + "\rProcessed %i object(s)." % loaded_objects_in_fixture, + ending="", ) except Exception as e: if not isinstance(e, CommandError): - e.args = ("Problem installing fixture '%s': %s" % (fixture_file, e),) + e.args = ( + "Problem installing fixture '%s': %s" % (fixture_file, e), + ) raise finally: fixture.close() @@ -236,7 +268,7 @@ class Command(BaseCommand): warnings.warn( "No fixture data found for '%s'. (File format may be " "invalid.)" % fixture_name, - RuntimeWarning + RuntimeWarning, ) def get_fixture_name_and_dirs(self, fixture_name): @@ -254,16 +286,18 @@ class Command(BaseCommand): cmp_fmts = self.compression_formats if cmp_fmt is None else [cmp_fmt] ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt] return { - '%s.%s' % ( + "%s.%s" + % ( fixture_name, - '.'.join([ext for ext in combo if ext]), - ) for combo in product(databases, ser_fmts, cmp_fmts) + ".".join([ext for ext in combo if ext]), + ) + for combo in product(databases, ser_fmts, cmp_fmts) } def find_fixture_files_in_dir(self, fixture_dir, fixture_name, targets): fixture_files_in_dir = [] path = os.path.join(fixture_dir, fixture_name) - for candidate in glob.iglob(glob.escape(path) + '*'): + for candidate in glob.iglob(glob.escape(path) + "*"): if os.path.basename(candidate) in targets: # Save the fixture_dir and fixture_name for future error # messages. @@ -287,18 +321,22 @@ class Command(BaseCommand): if self.verbosity >= 2: self.stdout.write("Checking %s for fixtures..." 
% humanize(fixture_dir)) fixture_files_in_dir = self.find_fixture_files_in_dir( - fixture_dir, fixture_name, targets, + fixture_dir, + fixture_name, + targets, ) if self.verbosity >= 2 and not fixture_files_in_dir: - self.stdout.write("No fixture '%s' in %s." % - (fixture_name, humanize(fixture_dir))) + self.stdout.write( + "No fixture '%s' in %s." % (fixture_name, humanize(fixture_dir)) + ) # Check kept for backwards-compatibility; it isn't clear why # duplicates are only allowed in different directories. if len(fixture_files_in_dir) > 1: raise CommandError( - "Multiple fixtures named '%s' in %s. Aborting." % - (fixture_name, humanize(fixture_dir))) + "Multiple fixtures named '%s' in %s. Aborting." + % (fixture_name, humanize(fixture_dir)) + ) fixture_files.extend(fixture_files_in_dir) if not fixture_files: @@ -321,11 +359,12 @@ class Command(BaseCommand): raise ImproperlyConfigured("settings.FIXTURE_DIRS contains duplicates.") for app_config in apps.get_app_configs(): app_label = app_config.label - app_dir = os.path.join(app_config.path, 'fixtures') + app_dir = os.path.join(app_config.path, "fixtures") if app_dir in fixture_dirs: raise ImproperlyConfigured( "'%s' is a default fixture directory for the '%s' app " - "and cannot be listed in settings.FIXTURE_DIRS." % (app_dir, app_label) + "and cannot be listed in settings.FIXTURE_DIRS." + % (app_dir, app_label) ) if self.app_label and app_label != self.app_label: @@ -333,7 +372,7 @@ class Command(BaseCommand): if os.path.isdir(app_dir): dirs.append(app_dir) dirs.extend(fixture_dirs) - dirs.append('') + dirs.append("") return [os.path.realpath(d) for d in dirs] def parse_name(self, fixture_name): @@ -342,10 +381,12 @@ class Command(BaseCommand): """ if fixture_name == READ_STDIN: if not self.format: - raise CommandError('--format must be specified when reading from stdin.') - return READ_STDIN, self.format, 'stdin' + raise CommandError( + "--format must be specified when reading from stdin." 
+ ) + return READ_STDIN, self.format, "stdin" - parts = fixture_name.rsplit('.', 2) + parts = fixture_name.rsplit(".", 2) if len(parts) > 1 and parts[-1] in self.compression_formats: cmp_fmt = parts[-1] @@ -360,17 +401,17 @@ class Command(BaseCommand): else: raise CommandError( "Problem installing fixture '%s': %s is not a known " - "serialization format." % ('.'.join(parts[:-1]), parts[-1])) + "serialization format." % (".".join(parts[:-1]), parts[-1]) + ) else: ser_fmt = None - name = '.'.join(parts) + name = ".".join(parts) return name, ser_fmt, cmp_fmt class SingleZipReader(zipfile.ZipFile): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if len(self.namelist()) != 1: @@ -381,4 +422,4 @@ class SingleZipReader(zipfile.ZipFile): def humanize(dirname): - return "'%s'" % dirname if dirname else 'absolute path' + return "'%s'" % dirname if dirname else "absolute path" diff --git a/django/core/management/commands/makemessages.py b/django/core/management/commands/makemessages.py index 0070342181..af1b43b7c7 100644 --- a/django/core/management/commands/makemessages.py +++ b/django/core/management/commands/makemessages.py @@ -12,7 +12,10 @@ from django.core.exceptions import ImproperlyConfigured from django.core.files.temp import NamedTemporaryFile from django.core.management.base import BaseCommand, CommandError from django.core.management.utils import ( - find_command, handle_extensions, is_ignored_path, popen_wrapper, + find_command, + handle_extensions, + is_ignored_path, + popen_wrapper, ) from django.utils.encoding import DEFAULT_LOCALE_ENCODING from django.utils.functional import cached_property @@ -21,7 +24,9 @@ from django.utils.regex_helper import _lazy_re_compile from django.utils.text import get_text_list from django.utils.translation import templatize -plural_forms_re = _lazy_re_compile(r'^(?P<value>"Plural-Forms.+?\\n")\s*$', re.MULTILINE | re.DOTALL) +plural_forms_re = _lazy_re_compile( + r'^(?P<value>"Plural-Forms.+?\\n")\s*$', 
re.MULTILINE | re.DOTALL +) STATUS_OK = 0 NO_LOCALE_DIR = object() @@ -63,6 +68,7 @@ class BuildFile: """ Represent the state of a translatable file during the build process. """ + def __init__(self, command, domain, translatable): self.command = command self.domain = domain @@ -70,11 +76,11 @@ class BuildFile: @cached_property def is_templatized(self): - if self.domain == 'djangojs': + if self.domain == "djangojs": return self.command.gettext_version < (0, 18, 3) - elif self.domain == 'django': + elif self.domain == "django": file_ext = os.path.splitext(self.translatable.file)[1] - return file_ext != '.py' + return file_ext != ".py" return False @cached_property @@ -90,10 +96,10 @@ class BuildFile: if not self.is_templatized: return self.path extension = { - 'djangojs': 'c', - 'django': 'py', + "djangojs": "c", + "django": "py", }.get(self.domain) - filename = '%s.%s' % (self.translatable.file, extension) + filename = "%s.%s" % (self.translatable.file, extension) return os.path.join(self.translatable.dirpath, filename) def preprocess(self): @@ -104,15 +110,15 @@ class BuildFile: if not self.is_templatized: return - with open(self.path, encoding='utf-8') as fp: + with open(self.path, encoding="utf-8") as fp: src_data = fp.read() - if self.domain == 'djangojs': + if self.domain == "djangojs": content = prepare_js_for_gettext(src_data) - elif self.domain == 'django': + elif self.domain == "django": content = templatize(src_data, origin=self.path[2:]) - with open(self.work_path, 'w', encoding='utf-8') as fp: + with open(self.work_path, "w", encoding="utf-8") as fp: fp.write(content) def postprocess_messages(self, msgs): @@ -126,7 +132,7 @@ class BuildFile: return msgs # Remove '.py' suffix - if os.name == 'nt': + if os.name == "nt": # Preserve '.\' prefix on Windows to respect gettext behavior old_path = self.work_path new_path = self.path @@ -135,10 +141,10 @@ class BuildFile: new_path = self.path[2:] return re.sub( - r'^(#: .*)(' + re.escape(old_path) + r')', + 
r"^(#: .*)(" + re.escape(old_path) + r")", lambda match: match[0].replace(old_path, new_path), msgs, - flags=re.MULTILINE + flags=re.MULTILINE, ) def cleanup(self): @@ -164,8 +170,8 @@ def normalize_eols(raw_contents): lines_list = raw_contents.splitlines() # Ensure last line has its EOL if lines_list and lines_list[-1]: - lines_list.append('') - return '\n'.join(lines_list) + lines_list.append("") + return "\n".join(lines_list) def write_pot_file(potfile, msgs): @@ -182,16 +188,16 @@ def write_pot_file(potfile, msgs): found, header_read = False, False for line in pot_lines: if not found and not header_read: - if 'charset=CHARSET' in line: + if "charset=CHARSET" in line: found = True - line = line.replace('charset=CHARSET', 'charset=UTF-8') + line = line.replace("charset=CHARSET", "charset=UTF-8") if not line and not found: header_read = True lines.append(line) - msgs = '\n'.join(lines) + msgs = "\n".join(lines) # Force newlines of POT files to '\n' to work around # https://savannah.gnu.org/bugs/index.php?52395 - with open(potfile, 'a', encoding='utf-8', newline='\n') as fp: + with open(potfile, "a", encoding="utf-8", newline="\n") as fp: fp.write(msgs) @@ -209,61 +215,86 @@ class Command(BaseCommand): requires_system_checks = [] - msgmerge_options = ['-q', '--backup=none', '--previous', '--update'] - msguniq_options = ['--to-code=utf-8'] - msgattrib_options = ['--no-obsolete'] - xgettext_options = ['--from-code=UTF-8', '--add-comments=Translators'] + msgmerge_options = ["-q", "--backup=none", "--previous", "--update"] + msguniq_options = ["--to-code=utf-8"] + msgattrib_options = ["--no-obsolete"] + xgettext_options = ["--from-code=UTF-8", "--add-comments=Translators"] def add_arguments(self, parser): parser.add_argument( - '--locale', '-l', default=[], action='append', - help='Creates or updates the message files for the given locale(s) (e.g. pt_BR). 
' - 'Can be used multiple times.', + "--locale", + "-l", + default=[], + action="append", + help="Creates or updates the message files for the given locale(s) (e.g. pt_BR). " + "Can be used multiple times.", ) parser.add_argument( - '--exclude', '-x', default=[], action='append', - help='Locales to exclude. Default is none. Can be used multiple times.', + "--exclude", + "-x", + default=[], + action="append", + help="Locales to exclude. Default is none. Can be used multiple times.", ) parser.add_argument( - '--domain', '-d', default='django', + "--domain", + "-d", + default="django", help='The domain of the message files (default: "django").', ) parser.add_argument( - '--all', '-a', action='store_true', - help='Updates the message files for all existing locales.', + "--all", + "-a", + action="store_true", + help="Updates the message files for all existing locales.", ) parser.add_argument( - '--extension', '-e', dest='extensions', action='append', + "--extension", + "-e", + dest="extensions", + action="append", help='The file extension(s) to examine (default: "html,txt,py", or "js" ' - 'if the domain is "djangojs"). Separate multiple extensions with ' - 'commas, or use -e multiple times.', + 'if the domain is "djangojs"). Separate multiple extensions with ' + "commas, or use -e multiple times.", ) parser.add_argument( - '--symlinks', '-s', action='store_true', - help='Follows symlinks to directories when examining source code ' - 'and templates for translation strings.', + "--symlinks", + "-s", + action="store_true", + help="Follows symlinks to directories when examining source code " + "and templates for translation strings.", ) parser.add_argument( - '--ignore', '-i', action='append', dest='ignore_patterns', - default=[], metavar='PATTERN', - help='Ignore files or directories matching this glob-style pattern. 
' - 'Use multiple times to ignore more.', + "--ignore", + "-i", + action="append", + dest="ignore_patterns", + default=[], + metavar="PATTERN", + help="Ignore files or directories matching this glob-style pattern. " + "Use multiple times to ignore more.", ) parser.add_argument( - '--no-default-ignore', action='store_false', dest='use_default_ignore_patterns', + "--no-default-ignore", + action="store_false", + dest="use_default_ignore_patterns", help="Don't ignore the common glob-style patterns 'CVS', '.*', '*~' and '*.pyc'.", ) parser.add_argument( - '--no-wrap', action='store_true', + "--no-wrap", + action="store_true", help="Don't break long message lines into several lines.", ) parser.add_argument( - '--no-location', action='store_true', + "--no-location", + action="store_true", help="Don't write '#: filename:line' lines.", ) parser.add_argument( - '--add-location', - choices=('full', 'file', 'never'), const='full', nargs='?', + "--add-location", + choices=("full", "file", "never"), + const="full", + nargs="?", help=( "Controls '#: filename:line' lines. If the option is 'full' " "(the default if not given), the lines include both file name " @@ -273,61 +304,65 @@ class Command(BaseCommand): ), ) parser.add_argument( - '--no-obsolete', action='store_true', + "--no-obsolete", + action="store_true", help="Remove obsolete message strings.", ) parser.add_argument( - '--keep-pot', action='store_true', + "--keep-pot", + action="store_true", help="Keep .pot file after making messages. 
Useful when debugging.", ) def handle(self, *args, **options): - locale = options['locale'] - exclude = options['exclude'] - self.domain = options['domain'] - self.verbosity = options['verbosity'] - process_all = options['all'] - extensions = options['extensions'] - self.symlinks = options['symlinks'] + locale = options["locale"] + exclude = options["exclude"] + self.domain = options["domain"] + self.verbosity = options["verbosity"] + process_all = options["all"] + extensions = options["extensions"] + self.symlinks = options["symlinks"] - ignore_patterns = options['ignore_patterns'] - if options['use_default_ignore_patterns']: - ignore_patterns += ['CVS', '.*', '*~', '*.pyc'] + ignore_patterns = options["ignore_patterns"] + if options["use_default_ignore_patterns"]: + ignore_patterns += ["CVS", ".*", "*~", "*.pyc"] self.ignore_patterns = list(set(ignore_patterns)) # Avoid messing with mutable class variables - if options['no_wrap']: - self.msgmerge_options = self.msgmerge_options[:] + ['--no-wrap'] - self.msguniq_options = self.msguniq_options[:] + ['--no-wrap'] - self.msgattrib_options = self.msgattrib_options[:] + ['--no-wrap'] - self.xgettext_options = self.xgettext_options[:] + ['--no-wrap'] - if options['no_location']: - self.msgmerge_options = self.msgmerge_options[:] + ['--no-location'] - self.msguniq_options = self.msguniq_options[:] + ['--no-location'] - self.msgattrib_options = self.msgattrib_options[:] + ['--no-location'] - self.xgettext_options = self.xgettext_options[:] + ['--no-location'] - if options['add_location']: + if options["no_wrap"]: + self.msgmerge_options = self.msgmerge_options[:] + ["--no-wrap"] + self.msguniq_options = self.msguniq_options[:] + ["--no-wrap"] + self.msgattrib_options = self.msgattrib_options[:] + ["--no-wrap"] + self.xgettext_options = self.xgettext_options[:] + ["--no-wrap"] + if options["no_location"]: + self.msgmerge_options = self.msgmerge_options[:] + ["--no-location"] + self.msguniq_options = self.msguniq_options[:] 
+ ["--no-location"] + self.msgattrib_options = self.msgattrib_options[:] + ["--no-location"] + self.xgettext_options = self.xgettext_options[:] + ["--no-location"] + if options["add_location"]: if self.gettext_version < (0, 19): raise CommandError( "The --add-location option requires gettext 0.19 or later. " - "You have %s." % '.'.join(str(x) for x in self.gettext_version) + "You have %s." % ".".join(str(x) for x in self.gettext_version) ) - arg_add_location = "--add-location=%s" % options['add_location'] + arg_add_location = "--add-location=%s" % options["add_location"] self.msgmerge_options = self.msgmerge_options[:] + [arg_add_location] self.msguniq_options = self.msguniq_options[:] + [arg_add_location] self.msgattrib_options = self.msgattrib_options[:] + [arg_add_location] self.xgettext_options = self.xgettext_options[:] + [arg_add_location] - self.no_obsolete = options['no_obsolete'] - self.keep_pot = options['keep_pot'] + self.no_obsolete = options["no_obsolete"] + self.keep_pot = options["keep_pot"] - if self.domain not in ('django', 'djangojs'): - raise CommandError("currently makemessages only supports domains " - "'django' and 'djangojs'") - if self.domain == 'djangojs': - exts = extensions or ['js'] + if self.domain not in ("django", "djangojs"): + raise CommandError( + "currently makemessages only supports domains " + "'django' and 'djangojs'" + ) + if self.domain == "djangojs": + exts = extensions or ["js"] else: - exts = extensions or ['html', 'txt', 'py'] + exts = extensions or ["html", "txt", "py"] self.extensions = handle_extensions(exts) if (not locale and not exclude and not process_all) or self.domain is None: @@ -338,32 +373,35 @@ class Command(BaseCommand): if self.verbosity > 1: self.stdout.write( - 'examining files with the extensions: %s' - % get_text_list(list(self.extensions), 'and') + "examining files with the extensions: %s" + % get_text_list(list(self.extensions), "and") ) self.invoked_for_django = False self.locale_paths = [] 
self.default_locale_path = None - if os.path.isdir(os.path.join('conf', 'locale')): - self.locale_paths = [os.path.abspath(os.path.join('conf', 'locale'))] + if os.path.isdir(os.path.join("conf", "locale")): + self.locale_paths = [os.path.abspath(os.path.join("conf", "locale"))] self.default_locale_path = self.locale_paths[0] self.invoked_for_django = True else: if self.settings_available: self.locale_paths.extend(settings.LOCALE_PATHS) # Allow to run makemessages inside an app dir - if os.path.isdir('locale'): - self.locale_paths.append(os.path.abspath('locale')) + if os.path.isdir("locale"): + self.locale_paths.append(os.path.abspath("locale")) if self.locale_paths: self.default_locale_path = self.locale_paths[0] os.makedirs(self.default_locale_path, exist_ok=True) # Build locale list - looks_like_locale = re.compile(r'[a-z]{2}') - locale_dirs = filter(os.path.isdir, glob.glob('%s/*' % self.default_locale_path)) + looks_like_locale = re.compile(r"[a-z]{2}") + locale_dirs = filter( + os.path.isdir, glob.glob("%s/*" % self.default_locale_path) + ) all_locales = [ - lang_code for lang_code in map(os.path.basename, locale_dirs) + lang_code + for lang_code in map(os.path.basename, locale_dirs) if looks_like_locale.match(lang_code) ] @@ -375,25 +413,26 @@ class Command(BaseCommand): locales = set(locales).difference(exclude) if locales: - check_programs('msguniq', 'msgmerge', 'msgattrib') + check_programs("msguniq", "msgmerge", "msgattrib") - check_programs('xgettext') + check_programs("xgettext") try: potfiles = self.build_potfiles() # Build po files for each selected locale for locale in locales: - if '-' in locale: + if "-" in locale: self.stdout.write( - 'invalid locale %s, did you mean %s?' % ( + "invalid locale %s, did you mean %s?" 
+ % ( locale, - locale.replace('-', '_'), + locale.replace("-", "_"), ), ) continue if self.verbosity > 0: - self.stdout.write('processing locale %s' % locale) + self.stdout.write("processing locale %s" % locale) for potfile in potfiles: self.write_po_file(potfile, locale) finally: @@ -405,10 +444,10 @@ class Command(BaseCommand): # Gettext tools will output system-encoded bytestrings instead of UTF-8, # when looking up the version. It's especially a problem on Windows. out, err, status = popen_wrapper( - ['xgettext', '--version'], + ["xgettext", "--version"], stdout_encoding=DEFAULT_LOCALE_ENCODING, ) - m = re.search(r'(\d+)\.(\d+)\.?(\d+)?', out) + m = re.search(r"(\d+)\.(\d+)\.?(\d+)?", out) if m: return tuple(int(d) for d in m.groups() if d is not None) else: @@ -433,26 +472,27 @@ class Command(BaseCommand): self.process_files(file_list) potfiles = [] for path in self.locale_paths: - potfile = os.path.join(path, '%s.pot' % self.domain) + potfile = os.path.join(path, "%s.pot" % self.domain) if not os.path.exists(potfile): continue - args = ['msguniq'] + self.msguniq_options + [potfile] + args = ["msguniq"] + self.msguniq_options + [potfile] msgs, errors, status = popen_wrapper(args) if errors: if status != STATUS_OK: raise CommandError( - "errors happened while running msguniq\n%s" % errors) + "errors happened while running msguniq\n%s" % errors + ) elif self.verbosity > 0: self.stdout.write(errors) msgs = normalize_eols(msgs) - with open(potfile, 'w', encoding='utf-8') as fp: + with open(potfile, "w", encoding="utf-8") as fp: fp.write(msgs) potfiles.append(potfile) return potfiles def remove_potfiles(self): for path in self.locale_paths: - pot_path = os.path.join(path, '%s.pot' % self.domain) + pot_path = os.path.join(path, "%s.pot" % self.domain) if os.path.exists(pot_path): os.unlink(pot_path) @@ -464,23 +504,40 @@ class Command(BaseCommand): all_files = [] ignored_roots = [] if self.settings_available: - ignored_roots = [os.path.normpath(p) for p in 
(settings.MEDIA_ROOT, settings.STATIC_ROOT) if p] - for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=self.symlinks): + ignored_roots = [ + os.path.normpath(p) + for p in (settings.MEDIA_ROOT, settings.STATIC_ROOT) + if p + ] + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=self.symlinks + ): for dirname in dirnames[:]: - if (is_ignored_path(os.path.normpath(os.path.join(dirpath, dirname)), self.ignore_patterns) or - os.path.join(os.path.abspath(dirpath), dirname) in ignored_roots): + if ( + is_ignored_path( + os.path.normpath(os.path.join(dirpath, dirname)), + self.ignore_patterns, + ) + or os.path.join(os.path.abspath(dirpath), dirname) in ignored_roots + ): dirnames.remove(dirname) if self.verbosity > 1: - self.stdout.write('ignoring directory %s' % dirname) - elif dirname == 'locale': + self.stdout.write("ignoring directory %s" % dirname) + elif dirname == "locale": dirnames.remove(dirname) - self.locale_paths.insert(0, os.path.join(os.path.abspath(dirpath), dirname)) + self.locale_paths.insert( + 0, os.path.join(os.path.abspath(dirpath), dirname) + ) for filename in filenames: file_path = os.path.normpath(os.path.join(dirpath, filename)) file_ext = os.path.splitext(filename)[1] - if file_ext not in self.extensions or is_ignored_path(file_path, self.ignore_patterns): + if file_ext not in self.extensions or is_ignored_path( + file_path, self.ignore_patterns + ): if self.verbosity > 1: - self.stdout.write('ignoring file %s in %s' % (filename, dirpath)) + self.stdout.write( + "ignoring file %s in %s" % (filename, dirpath) + ) else: locale_dir = None for path in self.locale_paths: @@ -488,7 +545,9 @@ class Command(BaseCommand): locale_dir = path break locale_dir = locale_dir or self.default_locale_path or NO_LOCALE_DIR - all_files.append(self.translatable_file_class(dirpath, filename, locale_dir)) + all_files.append( + self.translatable_file_class(dirpath, filename, locale_dir) + ) return sorted(all_files) 
def process_files(self, file_list): @@ -513,18 +572,22 @@ class Command(BaseCommand): build_files = [] for translatable in files: if self.verbosity > 1: - self.stdout.write('processing file %s in %s' % ( - translatable.file, translatable.dirpath - )) - if self.domain not in ('djangojs', 'django'): + self.stdout.write( + "processing file %s in %s" + % (translatable.file, translatable.dirpath) + ) + if self.domain not in ("djangojs", "django"): continue build_file = self.build_file_class(self, self.domain, translatable) try: build_file.preprocess() except UnicodeDecodeError as e: self.stdout.write( - 'UnicodeDecodeError: skipped file %s in %s (reason: %s)' % ( - translatable.file, translatable.dirpath, e, + "UnicodeDecodeError: skipped file %s in %s (reason: %s)" + % ( + translatable.file, + translatable.dirpath, + e, ) ) continue @@ -535,41 +598,43 @@ class Command(BaseCommand): raise build_files.append(build_file) - if self.domain == 'djangojs': + if self.domain == "djangojs": is_templatized = build_file.is_templatized args = [ - 'xgettext', - '-d', self.domain, - '--language=%s' % ('C' if is_templatized else 'JavaScript',), - '--keyword=gettext_noop', - '--keyword=gettext_lazy', - '--keyword=ngettext_lazy:1,2', - '--keyword=pgettext:1c,2', - '--keyword=npgettext:1c,2,3', - '--output=-', + "xgettext", + "-d", + self.domain, + "--language=%s" % ("C" if is_templatized else "JavaScript",), + "--keyword=gettext_noop", + "--keyword=gettext_lazy", + "--keyword=ngettext_lazy:1,2", + "--keyword=pgettext:1c,2", + "--keyword=npgettext:1c,2,3", + "--output=-", ] - elif self.domain == 'django': + elif self.domain == "django": args = [ - 'xgettext', - '-d', self.domain, - '--language=Python', - '--keyword=gettext_noop', - '--keyword=gettext_lazy', - '--keyword=ngettext_lazy:1,2', - '--keyword=pgettext:1c,2', - '--keyword=npgettext:1c,2,3', - '--keyword=pgettext_lazy:1c,2', - '--keyword=npgettext_lazy:1c,2,3', - '--output=-', + "xgettext", + "-d", + self.domain, + 
"--language=Python", + "--keyword=gettext_noop", + "--keyword=gettext_lazy", + "--keyword=ngettext_lazy:1,2", + "--keyword=pgettext:1c,2", + "--keyword=npgettext:1c,2,3", + "--keyword=pgettext_lazy:1c,2", + "--keyword=npgettext_lazy:1c,2,3", + "--output=-", ] else: return input_files = [bf.work_path for bf in build_files] - with NamedTemporaryFile(mode='w+') as input_files_list: - input_files_list.write('\n'.join(input_files)) + with NamedTemporaryFile(mode="w+") as input_files_list: + input_files_list.write("\n".join(input_files)) input_files_list.flush() - args.extend(['--files-from', input_files_list.name]) + args.extend(["--files-from", input_files_list.name]) args.extend(self.xgettext_options) msgs, errors, status = popen_wrapper(args) @@ -578,8 +643,8 @@ class Command(BaseCommand): for build_file in build_files: build_file.cleanup() raise CommandError( - 'errors happened while running xgettext on %s\n%s' % - ('\n'.join(input_files), errors) + "errors happened while running xgettext on %s\n%s" + % ("\n".join(input_files), errors) ) elif self.verbosity > 0: # Print warnings @@ -597,7 +662,7 @@ class Command(BaseCommand): ) for build_file in build_files: msgs = build_file.postprocess_messages(msgs) - potfile = os.path.join(locale_dir, '%s.pot' % self.domain) + potfile = os.path.join(locale_dir, "%s.pot" % self.domain) write_pot_file(potfile, msgs) for build_file in build_files: @@ -610,38 +675,41 @@ class Command(BaseCommand): Use msgmerge and msgattrib GNU gettext utilities. 
""" - basedir = os.path.join(os.path.dirname(potfile), locale, 'LC_MESSAGES') + basedir = os.path.join(os.path.dirname(potfile), locale, "LC_MESSAGES") os.makedirs(basedir, exist_ok=True) - pofile = os.path.join(basedir, '%s.po' % self.domain) + pofile = os.path.join(basedir, "%s.po" % self.domain) if os.path.exists(pofile): - args = ['msgmerge'] + self.msgmerge_options + [pofile, potfile] + args = ["msgmerge"] + self.msgmerge_options + [pofile, potfile] _, errors, status = popen_wrapper(args) if errors: if status != STATUS_OK: raise CommandError( - "errors happened while running msgmerge\n%s" % errors) + "errors happened while running msgmerge\n%s" % errors + ) elif self.verbosity > 0: self.stdout.write(errors) - msgs = Path(pofile).read_text(encoding='utf-8') + msgs = Path(pofile).read_text(encoding="utf-8") else: - with open(potfile, encoding='utf-8') as fp: + with open(potfile, encoding="utf-8") as fp: msgs = fp.read() if not self.invoked_for_django: msgs = self.copy_plural_forms(msgs, locale) msgs = normalize_eols(msgs) msgs = msgs.replace( - "#. #-#-#-#-# %s.pot (PACKAGE VERSION) #-#-#-#-#\n" % self.domain, "") - with open(pofile, 'w', encoding='utf-8') as fp: + "#. #-#-#-#-# %s.pot (PACKAGE VERSION) #-#-#-#-#\n" % self.domain, "" + ) + with open(pofile, "w", encoding="utf-8") as fp: fp.write(msgs) if self.no_obsolete: - args = ['msgattrib'] + self.msgattrib_options + ['-o', pofile, pofile] + args = ["msgattrib"] + self.msgattrib_options + ["-o", pofile, pofile] msgs, errors, status = popen_wrapper(args) if errors: if status != STATUS_OK: raise CommandError( - "errors happened while running msgattrib\n%s" % errors) + "errors happened while running msgattrib\n%s" % errors + ) elif self.verbosity > 0: self.stdout.write(errors) @@ -652,19 +720,21 @@ class Command(BaseCommand): contents of a newly created .po file. 
""" django_dir = os.path.normpath(os.path.join(os.path.dirname(django.__file__))) - if self.domain == 'djangojs': - domains = ('djangojs', 'django') + if self.domain == "djangojs": + domains = ("djangojs", "django") else: - domains = ('django',) + domains = ("django",) for domain in domains: - django_po = os.path.join(django_dir, 'conf', 'locale', locale, 'LC_MESSAGES', '%s.po' % domain) + django_po = os.path.join( + django_dir, "conf", "locale", locale, "LC_MESSAGES", "%s.po" % domain + ) if os.path.exists(django_po): - with open(django_po, encoding='utf-8') as fp: + with open(django_po, encoding="utf-8") as fp: m = plural_forms_re.search(fp.read()) if m: - plural_form_line = m['value'] + plural_form_line = m["value"] if self.verbosity > 1: - self.stdout.write('copying plural forms: %s' % plural_form_line) + self.stdout.write("copying plural forms: %s" % plural_form_line) lines = [] found = False for line in msgs.splitlines(): @@ -672,6 +742,6 @@ class Command(BaseCommand): line = plural_form_line found = True lines.append(line) - msgs = '\n'.join(lines) + msgs = "\n".join(lines) break return msgs diff --git a/django/core/management/commands/makemigrations.py b/django/core/management/commands/makemigrations.py index 4349f33a61..325848d8b2 100644 --- a/django/core/management/commands/makemigrations.py +++ b/django/core/management/commands/makemigrations.py @@ -5,15 +5,14 @@ from itertools import takewhile from django.apps import apps from django.conf import settings -from django.core.management.base import ( - BaseCommand, CommandError, no_translations, -) +from django.core.management.base import BaseCommand, CommandError, no_translations from django.db import DEFAULT_DB_ALIAS, OperationalError, connections, router from django.db.migrations import Migration from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.loader import MigrationLoader from django.db.migrations.questioner import ( - InteractiveMigrationQuestioner, 
MigrationQuestioner, + InteractiveMigrationQuestioner, + MigrationQuestioner, NonInteractiveMigrationQuestioner, ) from django.db.migrations.state import ProjectState @@ -26,42 +25,57 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'args', metavar='app_label', nargs='*', - help='Specify the app label(s) to create migrations for.', + "args", + metavar="app_label", + nargs="*", + help="Specify the app label(s) to create migrations for.", ) parser.add_argument( - '--dry-run', action='store_true', + "--dry-run", + action="store_true", help="Just show what migrations would be made; don't actually write them.", ) parser.add_argument( - '--merge', action='store_true', + "--merge", + action="store_true", help="Enable fixing of migration conflicts.", ) parser.add_argument( - '--empty', action='store_true', + "--empty", + action="store_true", help="Create an empty migration.", ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '-n', '--name', + "-n", + "--name", help="Use this name for migration file(s).", ) parser.add_argument( - '--no-header', action='store_false', dest='include_header', - help='Do not add header comments to new migration file(s).', + "--no-header", + action="store_false", + dest="include_header", + help="Do not add header comments to new migration file(s).", ) parser.add_argument( - '--check', action='store_true', dest='check_changes', - help='Exit with a non-zero status if model changes are missing migrations.', + "--check", + action="store_true", + dest="check_changes", + help="Exit with a non-zero status if model changes are missing migrations.", ) parser.add_argument( - '--scriptable', action='store_true', dest='scriptable', + 
"--scriptable", + action="store_true", + dest="scriptable", help=( - 'Divert log output and input prompts to stderr, writing only ' - 'paths of generated migration files to stdout.' + "Divert log output and input prompts to stderr, writing only " + "paths of generated migration files to stdout." ), ) @@ -74,17 +88,17 @@ class Command(BaseCommand): @no_translations def handle(self, *app_labels, **options): - self.verbosity = options['verbosity'] - self.interactive = options['interactive'] - self.dry_run = options['dry_run'] - self.merge = options['merge'] - self.empty = options['empty'] - self.migration_name = options['name'] + self.verbosity = options["verbosity"] + self.interactive = options["interactive"] + self.dry_run = options["dry_run"] + self.merge = options["merge"] + self.empty = options["empty"] + self.migration_name = options["name"] if self.migration_name and not self.migration_name.isidentifier(): - raise CommandError('The migration name must be a valid Python identifier.') - self.include_header = options['include_header'] - check_changes = options['check_changes'] - self.scriptable = options['scriptable'] + raise CommandError("The migration name must be a valid Python identifier.") + self.include_header = options["include_header"] + check_changes = options["check_changes"] + self.scriptable = options["scriptable"] # If logs and prompts are diverted to stderr, remove the ERROR style. if self.scriptable: self.stderr.style_func = None @@ -108,22 +122,25 @@ class Command(BaseCommand): # Raise an error if any migrations are applied before their dependencies. consistency_check_labels = {config.label for config in apps.get_app_configs()} # Non-default databases are only checked if database routers used. 
- aliases_to_check = connections if settings.DATABASE_ROUTERS else [DEFAULT_DB_ALIAS] + aliases_to_check = ( + connections if settings.DATABASE_ROUTERS else [DEFAULT_DB_ALIAS] + ) for alias in sorted(aliases_to_check): connection = connections[alias] - if (connection.settings_dict['ENGINE'] != 'django.db.backends.dummy' and any( - # At least one model must be migrated to the database. - router.allow_migrate(connection.alias, app_label, model_name=model._meta.object_name) - for app_label in consistency_check_labels - for model in apps.get_app_config(app_label).get_models() - )): + if connection.settings_dict["ENGINE"] != "django.db.backends.dummy" and any( + # At least one model must be migrated to the database. + router.allow_migrate( + connection.alias, app_label, model_name=model._meta.object_name + ) + for app_label in consistency_check_labels + for model in apps.get_app_config(app_label).get_models() + ): try: loader.check_consistent_history(connection) except OperationalError as error: warnings.warn( "Got an error checking a consistent migration history " - "performed for database connection '%s': %s" - % (alias, error), + "performed for database connection '%s': %s" % (alias, error), RuntimeWarning, ) # Before anything else, see if there's conflicting apps and drop out @@ -133,14 +150,14 @@ class Command(BaseCommand): # If app_labels is specified, filter out conflicting migrations for unspecified apps if app_labels: conflicts = { - app_label: conflict for app_label, conflict in conflicts.items() + app_label: conflict + for app_label, conflict in conflicts.items() if app_label in app_labels } if conflicts and not self.merge: name_str = "; ".join( - "%s in %s" % (", ".join(names), app) - for app, names in conflicts.items() + "%s in %s" % (", ".join(names), app) for app, names in conflicts.items() ) raise CommandError( "Conflicting migrations detected; multiple leaf nodes in the " @@ -150,7 +167,7 @@ class Command(BaseCommand): # If they want to merge and 
there's nothing to merge, then politely exit if self.merge and not conflicts: - self.log('No conflicts detected to merge.') + self.log("No conflicts detected to merge.") return # If they want to merge and there is something to merge, then @@ -181,12 +198,11 @@ class Command(BaseCommand): # If they want to make an empty migration, make one for each app if self.empty: if not app_labels: - raise CommandError("You must supply at least one app label when using --empty.") + raise CommandError( + "You must supply at least one app label when using --empty." + ) # Make a fake changes() result we can pass to arrange_for_graph - changes = { - app: [Migration("custom", app)] - for app in app_labels - } + changes = {app: [Migration("custom", app)] for app in app_labels} changes = autodetector.arrange_for_graph( changes=changes, graph=loader.graph, @@ -210,9 +226,12 @@ class Command(BaseCommand): if len(app_labels) == 1: self.log("No changes detected in app '%s'" % app_labels.pop()) else: - self.log("No changes detected in apps '%s'" % ("', '".join(app_labels))) + self.log( + "No changes detected in apps '%s'" + % ("', '".join(app_labels)) + ) else: - self.log('No changes detected') + self.log("No changes detected") else: self.write_migration_files(changes) if check_changes: @@ -236,11 +255,11 @@ class Command(BaseCommand): migration_string = os.path.relpath(writer.path) except ValueError: migration_string = writer.path - if migration_string.startswith('..'): + if migration_string.startswith(".."): migration_string = writer.path - self.log(' %s\n' % self.style.MIGRATE_LABEL(migration_string)) + self.log(" %s\n" % self.style.MIGRATE_LABEL(migration_string)) for operation in migration.operations: - self.log(' - %s' % operation.describe()) + self.log(" - %s" % operation.describe()) if self.scriptable: self.stdout.write(migration_string) if not self.dry_run: @@ -254,15 +273,17 @@ class Command(BaseCommand): # We just do this once per app directory_created[app_label] = True 
migration_string = writer.as_string() - with open(writer.path, "w", encoding='utf-8') as fh: + with open(writer.path, "w", encoding="utf-8") as fh: fh.write(migration_string) elif self.verbosity == 3: # Alternatively, makemigrations --dry-run --verbosity 3 # will log the migrations rather than saving the file to # the disk. - self.log(self.style.MIGRATE_HEADING( - "Full migrations file '%s':" % writer.filename - )) + self.log( + self.style.MIGRATE_HEADING( + "Full migrations file '%s':" % writer.filename + ) + ) self.log(writer.as_string()) def handle_merge(self, loader, conflicts): @@ -273,7 +294,7 @@ class Command(BaseCommand): if self.interactive: questioner = InteractiveMigrationQuestioner(prompt_output=self.log_output) else: - questioner = MigrationQuestioner(defaults={'ask_merge': True}) + questioner = MigrationQuestioner(defaults={"ask_merge": True}) for app_label, migration_names in conflicts.items(): # Grab out the migrations in question, and work out their @@ -282,7 +303,8 @@ class Command(BaseCommand): for migration_name in migration_names: migration = loader.get_migration(app_label, migration_name) migration.ancestry = [ - mig for mig in loader.graph.forwards_plan((app_label, migration_name)) + mig + for mig in loader.graph.forwards_plan((app_label, migration_name)) if mig[0] == migration.app_label ] merge_migrations.append(migration) @@ -291,25 +313,33 @@ class Command(BaseCommand): return all(item == seq[0] for item in seq[1:]) merge_migrations_generations = zip(*(m.ancestry for m in merge_migrations)) - common_ancestor_count = sum(1 for common_ancestor_generation - in takewhile(all_items_equal, merge_migrations_generations)) + common_ancestor_count = sum( + 1 + for common_ancestor_generation in takewhile( + all_items_equal, merge_migrations_generations + ) + ) if not common_ancestor_count: - raise ValueError("Could not find common ancestor of %s" % migration_names) + raise ValueError( + "Could not find common ancestor of %s" % migration_names + ) # 
Now work out the operations along each divergent branch for migration in merge_migrations: migration.branch = migration.ancestry[common_ancestor_count:] - migrations_ops = (loader.get_migration(node_app, node_name).operations - for node_app, node_name in migration.branch) + migrations_ops = ( + loader.get_migration(node_app, node_name).operations + for node_app, node_name in migration.branch + ) migration.merged_operations = sum(migrations_ops, []) # In future, this could use some of the Optimizer code # (can_optimize_through) to automatically see if they're # mergeable. For now, we always just prompt the user. if self.verbosity > 0: - self.log(self.style.MIGRATE_HEADING('Merging %s' % app_label)) + self.log(self.style.MIGRATE_HEADING("Merging %s" % app_label)) for migration in merge_migrations: - self.log(self.style.MIGRATE_LABEL(' Branch %s' % migration.name)) + self.log(self.style.MIGRATE_LABEL(" Branch %s" % migration.name)) for operation in migration.merged_operations: - self.log(' - %s' % operation.describe()) + self.log(" - %s" % operation.describe()) if questioner.ask_merge(app_label): # If they still want to merge it, then write out an empty # file depending on the migrations needing merging. 
@@ -321,36 +351,47 @@ class Command(BaseCommand): biggest_number = max(x for x in numbers if x is not None) except ValueError: biggest_number = 1 - subclass = type("Migration", (Migration,), { - "dependencies": [(app_label, migration.name) for migration in merge_migrations], - }) - parts = ['%04i' % (biggest_number + 1)] + subclass = type( + "Migration", + (Migration,), + { + "dependencies": [ + (app_label, migration.name) + for migration in merge_migrations + ], + }, + ) + parts = ["%04i" % (biggest_number + 1)] if self.migration_name: parts.append(self.migration_name) else: - parts.append('merge') - leaf_names = '_'.join(sorted(migration.name for migration in merge_migrations)) + parts.append("merge") + leaf_names = "_".join( + sorted(migration.name for migration in merge_migrations) + ) if len(leaf_names) > 47: parts.append(get_migration_name_timestamp()) else: parts.append(leaf_names) - migration_name = '_'.join(parts) + migration_name = "_".join(parts) new_migration = subclass(migration_name, app_label) writer = MigrationWriter(new_migration, self.include_header) if not self.dry_run: # Write the merge migrations file to the disk - with open(writer.path, "w", encoding='utf-8') as fh: + with open(writer.path, "w", encoding="utf-8") as fh: fh.write(writer.as_string()) if self.verbosity > 0: - self.log('\nCreated new merge migration %s' % writer.path) + self.log("\nCreated new merge migration %s" % writer.path) if self.scriptable: self.stdout.write(writer.path) elif self.verbosity == 3: # Alternatively, makemigrations --merge --dry-run --verbosity 3 # will log the merge migrations rather than saving the file # to the disk. 
- self.log(self.style.MIGRATE_HEADING( - "Full merge migrations file '%s':" % writer.filename - )) + self.log( + self.style.MIGRATE_HEADING( + "Full merge migrations file '%s':" % writer.filename + ) + ) self.log(writer.as_string()) diff --git a/django/core/management/commands/migrate.py b/django/core/management/commands/migrate.py index a4ad1f3e20..59fd1f0d55 100644 --- a/django/core/management/commands/migrate.py +++ b/django/core/management/commands/migrate.py @@ -3,12 +3,8 @@ import time from importlib import import_module from django.apps import apps -from django.core.management.base import ( - BaseCommand, CommandError, no_translations, -) -from django.core.management.sql import ( - emit_post_migrate_signal, emit_pre_migrate_signal, -) +from django.core.management.base import BaseCommand, CommandError, no_translations +from django.core.management.sql import emit_post_migrate_signal, emit_pre_migrate_signal from django.db import DEFAULT_DB_ALIAS, connections, router from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.executor import MigrationExecutor @@ -19,73 +15,89 @@ from django.utils.text import Truncator class Command(BaseCommand): - help = "Updates database schema. Manages both apps with migrations and those without." + help = ( + "Updates database schema. Manages both apps with migrations and those without." + ) requires_system_checks = [] def add_arguments(self, parser): parser.add_argument( - '--skip-checks', action='store_true', - help='Skip system checks.', + "--skip-checks", + action="store_true", + help="Skip system checks.", ) parser.add_argument( - 'app_label', nargs='?', - help='App label of an application to synchronize the state.', + "app_label", + nargs="?", + help="App label of an application to synchronize the state.", ) parser.add_argument( - 'migration_name', nargs='?', - help='Database state will be brought to the state after that ' - 'migration. 
Use the name "zero" to unapply all migrations.', + "migration_name", + nargs="?", + help="Database state will be brought to the state after that " + 'migration. Use the name "zero" to unapply all migrations.', ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--database', + "--database", default=DEFAULT_DB_ALIAS, help='Nominates a database to synchronize. Defaults to the "default" database.', ) parser.add_argument( - '--fake', action='store_true', - help='Mark migrations as run without actually running them.', + "--fake", + action="store_true", + help="Mark migrations as run without actually running them.", ) parser.add_argument( - '--fake-initial', action='store_true', - help='Detect if tables already exist and fake-apply initial migrations if so. Make sure ' - 'that the current database schema matches your initial migration before using this ' - 'flag. Django will only check for an existing table name.', + "--fake-initial", + action="store_true", + help="Detect if tables already exist and fake-apply initial migrations if so. Make sure " + "that the current database schema matches your initial migration before using this " + "flag. 
Django will only check for an existing table name.", ) parser.add_argument( - '--plan', action='store_true', - help='Shows a list of the migration actions that will be performed.', + "--plan", + action="store_true", + help="Shows a list of the migration actions that will be performed.", ) parser.add_argument( - '--run-syncdb', action='store_true', - help='Creates tables for apps without migrations.', + "--run-syncdb", + action="store_true", + help="Creates tables for apps without migrations.", ) parser.add_argument( - '--check', action='store_true', dest='check_unapplied', - help='Exits with a non-zero status if unapplied migrations exist.', + "--check", + action="store_true", + dest="check_unapplied", + help="Exits with a non-zero status if unapplied migrations exist.", ) parser.add_argument( - '--prune', action='store_true', dest='prune', - help='Delete nonexistent migrations from the django_migrations table.', + "--prune", + action="store_true", + dest="prune", + help="Delete nonexistent migrations from the django_migrations table.", ) @no_translations def handle(self, *args, **options): - database = options['database'] - if not options['skip_checks']: + database = options["database"] + if not options["skip_checks"]: self.check(databases=[database]) - self.verbosity = options['verbosity'] - self.interactive = options['interactive'] + self.verbosity = options["verbosity"] + self.interactive = options["interactive"] # Import the 'management' module within each installed app, to register # dispatcher events. 
for app_config in apps.get_app_configs(): if module_has_submodule(app_config.module, "management"): - import_module('.management', app_config.name) + import_module(".management", app_config.name) # Get the database we're operating from connection = connections[database] @@ -103,8 +115,7 @@ class Command(BaseCommand): conflicts = executor.loader.detect_conflicts() if conflicts: name_str = "; ".join( - "%s in %s" % (", ".join(names), app) - for app, names in conflicts.items() + "%s in %s" % (", ".join(names), app) for app, names in conflicts.items() ) raise CommandError( "Conflicting migrations detected; multiple leaf nodes in the " @@ -113,163 +124,185 @@ class Command(BaseCommand): ) # If they supplied command line arguments, work out what they mean. - run_syncdb = options['run_syncdb'] + run_syncdb = options["run_syncdb"] target_app_labels_only = True - if options['app_label']: + if options["app_label"]: # Validate app_label. - app_label = options['app_label'] + app_label = options["app_label"] try: apps.get_app_config(app_label) except LookupError as err: raise CommandError(str(err)) if run_syncdb: if app_label in executor.loader.migrated_apps: - raise CommandError("Can't use run_syncdb with app '%s' as it has migrations." % app_label) + raise CommandError( + "Can't use run_syncdb with app '%s' as it has migrations." + % app_label + ) elif app_label not in executor.loader.migrated_apps: raise CommandError("App '%s' does not have migrations." 
% app_label) - if options['app_label'] and options['migration_name']: - migration_name = options['migration_name'] + if options["app_label"] and options["migration_name"]: + migration_name = options["migration_name"] if migration_name == "zero": targets = [(app_label, None)] else: try: - migration = executor.loader.get_migration_by_prefix(app_label, migration_name) + migration = executor.loader.get_migration_by_prefix( + app_label, migration_name + ) except AmbiguityError: raise CommandError( "More than one migration matches '%s' in app '%s'. " - "Please be more specific." % - (migration_name, app_label) + "Please be more specific." % (migration_name, app_label) ) except KeyError: - raise CommandError("Cannot find a migration matching '%s' from app '%s'." % ( - migration_name, app_label)) + raise CommandError( + "Cannot find a migration matching '%s' from app '%s'." + % (migration_name, app_label) + ) target = (app_label, migration.name) # Partially applied squashed migrations are not included in the # graph, use the last replacement instead. if ( - target not in executor.loader.graph.nodes and - target in executor.loader.replacements + target not in executor.loader.graph.nodes + and target in executor.loader.replacements ): incomplete_migration = executor.loader.replacements[target] target = incomplete_migration.replaces[-1] targets = [target] target_app_labels_only = False - elif options['app_label']: - targets = [key for key in executor.loader.graph.leaf_nodes() if key[0] == app_label] + elif options["app_label"]: + targets = [ + key for key in executor.loader.graph.leaf_nodes() if key[0] == app_label + ] else: targets = executor.loader.graph.leaf_nodes() - if options['prune']: - if not options['app_label']: + if options["prune"]: + if not options["app_label"]: raise CommandError( - 'Migrations can be pruned only when an app is specified.' + "Migrations can be pruned only when an app is specified." 
) if self.verbosity > 0: - self.stdout.write('Pruning migrations:', self.style.MIGRATE_HEADING) - to_prune = set(executor.loader.applied_migrations) - set(executor.loader.disk_migrations) + self.stdout.write("Pruning migrations:", self.style.MIGRATE_HEADING) + to_prune = set(executor.loader.applied_migrations) - set( + executor.loader.disk_migrations + ) squashed_migrations_with_deleted_replaced_migrations = [ migration_key for migration_key, migration_obj in executor.loader.replacements.items() if any(replaced in to_prune for replaced in migration_obj.replaces) ] if squashed_migrations_with_deleted_replaced_migrations: - self.stdout.write(self.style.NOTICE( - " Cannot use --prune because the following squashed " - "migrations have their 'replaces' attributes and may not " - "be recorded as applied:" - )) + self.stdout.write( + self.style.NOTICE( + " Cannot use --prune because the following squashed " + "migrations have their 'replaces' attributes and may not " + "be recorded as applied:" + ) + ) for migration in squashed_migrations_with_deleted_replaced_migrations: app, name = migration - self.stdout.write(f' {app}.{name}') - self.stdout.write(self.style.NOTICE( - " Re-run 'manage.py migrate' if they are not marked as " - "applied, and remove 'replaces' attributes in their " - "Migration classes." - )) + self.stdout.write(f" {app}.{name}") + self.stdout.write( + self.style.NOTICE( + " Re-run 'manage.py migrate' if they are not marked as " + "applied, and remove 'replaces' attributes in their " + "Migration classes." 
+ ) + ) else: to_prune = sorted( - migration - for migration in to_prune - if migration[0] == app_label + migration for migration in to_prune if migration[0] == app_label ) if to_prune: for migration in to_prune: app, name = migration if self.verbosity > 0: - self.stdout.write(self.style.MIGRATE_LABEL( - f' Pruning {app}.{name}' - ), ending='') + self.stdout.write( + self.style.MIGRATE_LABEL(f" Pruning {app}.{name}"), + ending="", + ) executor.recorder.record_unapplied(app, name) if self.verbosity > 0: - self.stdout.write(self.style.SUCCESS(' OK')) + self.stdout.write(self.style.SUCCESS(" OK")) elif self.verbosity > 0: - self.stdout.write(' No migrations to prune.') + self.stdout.write(" No migrations to prune.") plan = executor.migration_plan(targets) - exit_dry = plan and options['check_unapplied'] + exit_dry = plan and options["check_unapplied"] - if options['plan']: - self.stdout.write('Planned operations:', self.style.MIGRATE_LABEL) + if options["plan"]: + self.stdout.write("Planned operations:", self.style.MIGRATE_LABEL) if not plan: - self.stdout.write(' No planned migration operations.') + self.stdout.write(" No planned migration operations.") for migration, backwards in plan: self.stdout.write(str(migration), self.style.MIGRATE_HEADING) for operation in migration.operations: message, is_error = self.describe_operation(operation, backwards) style = self.style.WARNING if is_error else None - self.stdout.write(' ' + message, style) + self.stdout.write(" " + message, style) if exit_dry: sys.exit(1) return if exit_dry: sys.exit(1) - if options['prune']: + if options["prune"]: return # At this point, ignore run_syncdb if there aren't any apps to sync. 
- run_syncdb = options['run_syncdb'] and executor.loader.unmigrated_apps + run_syncdb = options["run_syncdb"] and executor.loader.unmigrated_apps # Print some useful info if self.verbosity >= 1: self.stdout.write(self.style.MIGRATE_HEADING("Operations to perform:")) if run_syncdb: - if options['app_label']: + if options["app_label"]: self.stdout.write( - self.style.MIGRATE_LABEL(" Synchronize unmigrated app: %s" % app_label) + self.style.MIGRATE_LABEL( + " Synchronize unmigrated app: %s" % app_label + ) ) else: self.stdout.write( - self.style.MIGRATE_LABEL(" Synchronize unmigrated apps: ") + - (", ".join(sorted(executor.loader.unmigrated_apps))) + self.style.MIGRATE_LABEL(" Synchronize unmigrated apps: ") + + (", ".join(sorted(executor.loader.unmigrated_apps))) ) if target_app_labels_only: self.stdout.write( - self.style.MIGRATE_LABEL(" Apply all migrations: ") + - (", ".join(sorted({a for a, n in targets})) or "(none)") + self.style.MIGRATE_LABEL(" Apply all migrations: ") + + (", ".join(sorted({a for a, n in targets})) or "(none)") ) else: if targets[0][1] is None: self.stdout.write( - self.style.MIGRATE_LABEL(' Unapply all migrations: ') + - str(targets[0][0]) + self.style.MIGRATE_LABEL(" Unapply all migrations: ") + + str(targets[0][0]) ) else: - self.stdout.write(self.style.MIGRATE_LABEL( - " Target specific migration: ") + "%s, from %s" - % (targets[0][1], targets[0][0]) + self.stdout.write( + self.style.MIGRATE_LABEL(" Target specific migration: ") + + "%s, from %s" % (targets[0][1], targets[0][0]) ) pre_migrate_state = executor._create_project_state(with_applied_migrations=True) pre_migrate_apps = pre_migrate_state.apps emit_pre_migrate_signal( - self.verbosity, self.interactive, connection.alias, stdout=self.stdout, apps=pre_migrate_apps, plan=plan, + self.verbosity, + self.interactive, + connection.alias, + stdout=self.stdout, + apps=pre_migrate_apps, + plan=plan, ) # Run the syncdb phase. 
if run_syncdb: if self.verbosity >= 1: - self.stdout.write(self.style.MIGRATE_HEADING("Synchronizing apps without migrations:")) - if options['app_label']: + self.stdout.write( + self.style.MIGRATE_HEADING("Synchronizing apps without migrations:") + ) + if options["app_label"]: self.sync_apps(connection, [app_label]) else: self.sync_apps(connection, executor.loader.unmigrated_apps) @@ -287,23 +320,30 @@ class Command(BaseCommand): ) changes = autodetector.changes(graph=executor.loader.graph) if changes: - self.stdout.write(self.style.NOTICE( - " Your models in app(s): %s have changes that are not " - "yet reflected in a migration, and so won't be " - "applied." % ", ".join(repr(app) for app in sorted(changes)) - )) - self.stdout.write(self.style.NOTICE( - " Run 'manage.py makemigrations' to make new " - "migrations, and then re-run 'manage.py migrate' to " - "apply them." - )) + self.stdout.write( + self.style.NOTICE( + " Your models in app(s): %s have changes that are not " + "yet reflected in a migration, and so won't be " + "applied." % ", ".join(repr(app) for app in sorted(changes)) + ) + ) + self.stdout.write( + self.style.NOTICE( + " Run 'manage.py makemigrations' to make new " + "migrations, and then re-run 'manage.py migrate' to " + "apply them." + ) + ) fake = False fake_initial = False else: - fake = options['fake'] - fake_initial = options['fake_initial'] + fake = options["fake"] + fake_initial = options["fake_initial"] post_migrate_state = executor.migrate( - targets, plan=plan, state=pre_migrate_state.clone(), fake=fake, + targets, + plan=plan, + state=pre_migrate_state.clone(), + fake=fake, fake_initial=fake_initial, ) # post_migrate signals have access to all models. 
Ensure that all models @@ -320,14 +360,19 @@ class Command(BaseCommand): model_key = model_state.app_label, model_state.name_lower model_keys.append(model_key) post_migrate_apps.unregister_model(*model_key) - post_migrate_apps.render_multiple([ - ModelState.from_model(apps.get_model(*model)) for model in model_keys - ]) + post_migrate_apps.render_multiple( + [ModelState.from_model(apps.get_model(*model)) for model in model_keys] + ) # Send the post_migrate signal, so individual apps can do whatever they need # to do at this point. emit_post_migrate_signal( - self.verbosity, self.interactive, connection.alias, stdout=self.stdout, apps=post_migrate_apps, plan=plan, + self.verbosity, + self.interactive, + connection.alias, + stdout=self.stdout, + apps=post_migrate_apps, + plan=plan, ) def migration_progress_callback(self, action, migration=None, fake=False): @@ -339,7 +384,9 @@ class Command(BaseCommand): self.stdout.write(" Applying %s..." % migration, ending="") self.stdout.flush() elif action == "apply_success": - elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + elapsed = ( + " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + ) if fake: self.stdout.write(self.style.SUCCESS(" FAKED" + elapsed)) else: @@ -350,7 +397,9 @@ class Command(BaseCommand): self.stdout.write(" Unapplying %s..." 
% migration, ending="") self.stdout.flush() elif action == "unapply_success": - elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + elapsed = ( + " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + ) if fake: self.stdout.write(self.style.SUCCESS(" FAKED" + elapsed)) else: @@ -361,7 +410,9 @@ class Command(BaseCommand): self.stdout.write(" Rendering model states...", ending="") self.stdout.flush() elif action == "render_success": - elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + elapsed = ( + " (%.3fs)" % (time.monotonic() - self.start) if compute_time else "" + ) self.stdout.write(self.style.SUCCESS(" DONE" + elapsed)) def sync_apps(self, connection, app_labels): @@ -373,7 +424,9 @@ class Command(BaseCommand): all_models = [ ( app_config.label, - router.get_migratable_models(app_config, connection.alias, include_auto_created=False), + router.get_migratable_models( + app_config, connection.alias, include_auto_created=False + ), ) for app_config in apps.get_app_configs() if app_config.models_module is not None and app_config.label in app_labels @@ -383,8 +436,11 @@ class Command(BaseCommand): opts = model._meta converter = connection.introspection.identifier_converter return not ( - (converter(opts.db_table) in tables) or - (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables) + (converter(opts.db_table) in tables) + or ( + opts.auto_created + and converter(opts.auto_created._meta.db_table) in tables + ) ) manifest = { @@ -394,7 +450,7 @@ class Command(BaseCommand): # Create the tables for each model if self.verbosity >= 1: - self.stdout.write(' Creating tables...') + self.stdout.write(" Creating tables...") with connection.schema_editor() as editor: for app_name, model_list in manifest.items(): for model in model_list: @@ -403,36 +459,39 @@ class Command(BaseCommand): continue if self.verbosity >= 3: self.stdout.write( - ' Processing %s.%s model' % (app_name, 
model._meta.object_name) + " Processing %s.%s model" + % (app_name, model._meta.object_name) ) if self.verbosity >= 1: - self.stdout.write(' Creating table %s' % model._meta.db_table) + self.stdout.write( + " Creating table %s" % model._meta.db_table + ) editor.create_model(model) # Deferred SQL is executed when exiting the editor's context. if self.verbosity >= 1: - self.stdout.write(' Running deferred SQL...') + self.stdout.write(" Running deferred SQL...") @staticmethod def describe_operation(operation, backwards): """Return a string that describes a migration operation for --plan.""" - prefix = '' + prefix = "" is_error = False - if hasattr(operation, 'code'): + if hasattr(operation, "code"): code = operation.reverse_code if backwards else operation.code - action = (code.__doc__ or '') if code else None - elif hasattr(operation, 'sql'): + action = (code.__doc__ or "") if code else None + elif hasattr(operation, "sql"): action = operation.reverse_sql if backwards else operation.sql else: - action = '' + action = "" if backwards: - prefix = 'Undo ' + prefix = "Undo " if action is not None: - action = str(action).replace('\n', '') + action = str(action).replace("\n", "") elif backwards: - action = 'IRREVERSIBLE' + action = "IRREVERSIBLE" is_error = True if action: - action = ' -> ' + action + action = " -> " + action truncated = Truncator(action) return prefix + operation.describe() + truncated.chars(40), is_error diff --git a/django/core/management/commands/runserver.py b/django/core/management/commands/runserver.py index 473fde0de0..3c39f57e4d 100644 --- a/django/core/management/commands/runserver.py +++ b/django/core/management/commands/runserver.py @@ -7,18 +7,19 @@ from datetime import datetime from django.conf import settings from django.core.management.base import BaseCommand, CommandError -from django.core.servers.basehttp import ( - WSGIServer, get_internal_wsgi_application, run, -) +from django.core.servers.basehttp import WSGIServer, 
get_internal_wsgi_application, run from django.utils import autoreload from django.utils.regex_helper import _lazy_re_compile -naiveip_re = _lazy_re_compile(r"""^(?: +naiveip_re = _lazy_re_compile( + r"""^(?: (?P<addr> (?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address (?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address (?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN -):)?(?P<port>\d+)$""", re.X) +):)?(?P<port>\d+)$""", + re.X, +) class Command(BaseCommand): @@ -26,39 +27,46 @@ class Command(BaseCommand): # Validation is called explicitly each time the server is reloaded. requires_system_checks = [] - stealth_options = ('shutdown_message',) - suppressed_base_arguments = {'--verbosity', '--traceback'} + stealth_options = ("shutdown_message",) + suppressed_base_arguments = {"--verbosity", "--traceback"} - default_addr = '127.0.0.1' - default_addr_ipv6 = '::1' - default_port = '8000' - protocol = 'http' + default_addr = "127.0.0.1" + default_addr_ipv6 = "::1" + default_port = "8000" + protocol = "http" server_cls = WSGIServer def add_arguments(self, parser): parser.add_argument( - 'addrport', nargs='?', - help='Optional port number, or ipaddr:port' + "addrport", nargs="?", help="Optional port number, or ipaddr:port" ) parser.add_argument( - '--ipv6', '-6', action='store_true', dest='use_ipv6', - help='Tells Django to use an IPv6 address.', + "--ipv6", + "-6", + action="store_true", + dest="use_ipv6", + help="Tells Django to use an IPv6 address.", ) parser.add_argument( - '--nothreading', action='store_false', dest='use_threading', - help='Tells Django to NOT use threading.', + "--nothreading", + action="store_false", + dest="use_threading", + help="Tells Django to NOT use threading.", ) parser.add_argument( - '--noreload', action='store_false', dest='use_reloader', - help='Tells Django to NOT use the auto-reloader.', + "--noreload", + action="store_false", + dest="use_reloader", + help="Tells Django to NOT use the auto-reloader.", ) parser.add_argument( - '--skip-checks', 
action='store_true', - help='Skip system checks.', + "--skip-checks", + action="store_true", + help="Skip system checks.", ) def execute(self, *args, **options): - if options['no_color']: + if options["no_color"]: # We rely on the environment because it's currently the only # way to reach WSGIRequestHandler. This seems an acceptable # compromise considering `runserver` runs indefinitely. @@ -71,20 +79,22 @@ class Command(BaseCommand): def handle(self, *args, **options): if not settings.DEBUG and not settings.ALLOWED_HOSTS: - raise CommandError('You must set settings.ALLOWED_HOSTS if DEBUG is False.') + raise CommandError("You must set settings.ALLOWED_HOSTS if DEBUG is False.") - self.use_ipv6 = options['use_ipv6'] + self.use_ipv6 = options["use_ipv6"] if self.use_ipv6 and not socket.has_ipv6: - raise CommandError('Your Python does not support IPv6.') + raise CommandError("Your Python does not support IPv6.") self._raw_ipv6 = False - if not options['addrport']: - self.addr = '' + if not options["addrport"]: + self.addr = "" self.port = self.default_port else: - m = re.match(naiveip_re, options['addrport']) + m = re.match(naiveip_re, options["addrport"]) if m is None: - raise CommandError('"%s" is not a valid port number ' - 'or address:port pair.' % options['addrport']) + raise CommandError( + '"%s" is not a valid port number ' + "or address:port pair." % options["addrport"] + ) self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups() if not self.port.isdigit(): raise CommandError("%r is not a valid port number." % self.port) @@ -102,7 +112,7 @@ class Command(BaseCommand): def run(self, **options): """Run the server, using the autoreloader if needed.""" - use_reloader = options['use_reloader'] + use_reloader = options["use_reloader"] if use_reloader: autoreload.run_with_reloader(self.inner_run, **options) @@ -114,36 +124,45 @@ class Command(BaseCommand): # to be raised in the child process, raise it now. 
autoreload.raise_last_exception() - threading = options['use_threading'] + threading = options["use_threading"] # 'shutdown_message' is a stealth option. - shutdown_message = options.get('shutdown_message', '') - quit_command = 'CTRL-BREAK' if sys.platform == 'win32' else 'CONTROL-C' + shutdown_message = options.get("shutdown_message", "") + quit_command = "CTRL-BREAK" if sys.platform == "win32" else "CONTROL-C" - if not options['skip_checks']: - self.stdout.write('Performing system checks...\n\n') + if not options["skip_checks"]: + self.stdout.write("Performing system checks...\n\n") self.check(display_num_errors=True) # Need to check migrations here, so can't use the # requires_migrations_check attribute. self.check_migrations() - now = datetime.now().strftime('%B %d, %Y - %X') + now = datetime.now().strftime("%B %d, %Y - %X") self.stdout.write(now) - self.stdout.write(( - "Django version %(version)s, using settings %(settings)r\n" - "Starting development server at %(protocol)s://%(addr)s:%(port)s/\n" - "Quit the server with %(quit_command)s." - ) % { - "version": self.get_version(), - "settings": settings.SETTINGS_MODULE, - "protocol": self.protocol, - "addr": '[%s]' % self.addr if self._raw_ipv6 else self.addr, - "port": self.port, - "quit_command": quit_command, - }) + self.stdout.write( + ( + "Django version %(version)s, using settings %(settings)r\n" + "Starting development server at %(protocol)s://%(addr)s:%(port)s/\n" + "Quit the server with %(quit_command)s." 
+ ) + % { + "version": self.get_version(), + "settings": settings.SETTINGS_MODULE, + "protocol": self.protocol, + "addr": "[%s]" % self.addr if self._raw_ipv6 else self.addr, + "port": self.port, + "quit_command": quit_command, + } + ) try: handler = self.get_handler(*args, **options) - run(self.addr, int(self.port), handler, - ipv6=self.use_ipv6, threading=threading, server_cls=self.server_cls) + run( + self.addr, + int(self.port), + handler, + ipv6=self.use_ipv6, + threading=threading, + server_cls=self.server_cls, + ) except OSError as e: # Use helpful error messages instead of ugly tracebacks. ERRORS = { diff --git a/django/core/management/commands/sendtestemail.py b/django/core/management/commands/sendtestemail.py index 9ed1e9600f..6a69849300 100644 --- a/django/core/management/commands/sendtestemail.py +++ b/django/core/management/commands/sendtestemail.py @@ -11,30 +11,33 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'email', nargs='*', - help='One or more email addresses to send a test email to.', + "email", + nargs="*", + help="One or more email addresses to send a test email to.", ) parser.add_argument( - '--managers', action='store_true', - help='Send a test email to the addresses specified in settings.MANAGERS.', + "--managers", + action="store_true", + help="Send a test email to the addresses specified in settings.MANAGERS.", ) parser.add_argument( - '--admins', action='store_true', - help='Send a test email to the addresses specified in settings.ADMINS.', + "--admins", + action="store_true", + help="Send a test email to the addresses specified in settings.ADMINS.", ) def handle(self, *args, **kwargs): - subject = 'Test email from %s on %s' % (socket.gethostname(), timezone.now()) + subject = "Test email from %s on %s" % (socket.gethostname(), timezone.now()) send_mail( subject=subject, - message="If you\'re reading this, it was successful.", + message="If you're reading this, it was successful.", 
from_email=None, - recipient_list=kwargs['email'], + recipient_list=kwargs["email"], ) - if kwargs['managers']: + if kwargs["managers"]: mail_managers(subject, "This email was sent to the site managers.") - if kwargs['admins']: + if kwargs["admins"]: mail_admins(subject, "This email was sent to the site admins.") diff --git a/django/core/management/commands/shell.py b/django/core/management/commands/shell.py index cbd4c86620..52ab27cb21 100644 --- a/django/core/management/commands/shell.py +++ b/django/core/management/commands/shell.py @@ -15,28 +15,34 @@ class Command(BaseCommand): ) requires_system_checks = [] - shells = ['ipython', 'bpython', 'python'] + shells = ["ipython", "bpython", "python"] def add_arguments(self, parser): parser.add_argument( - '--no-startup', action='store_true', - help='When using plain Python, ignore the PYTHONSTARTUP environment variable and ~/.pythonrc.py script.', + "--no-startup", + action="store_true", + help="When using plain Python, ignore the PYTHONSTARTUP environment variable and ~/.pythonrc.py script.", ) parser.add_argument( - '-i', '--interface', choices=self.shells, + "-i", + "--interface", + choices=self.shells, help='Specify an interactive interpreter interface. Available options: "ipython", "bpython", and "python"', ) parser.add_argument( - '-c', '--command', - help='Instead of opening an interactive shell, run a command as Django and exit.', + "-c", + "--command", + help="Instead of opening an interactive shell, run a command as Django and exit.", ) def ipython(self, options): from IPython import start_ipython + start_ipython(argv=[]) def bpython(self, options): import bpython + bpython.embed() def python(self, options): @@ -47,8 +53,10 @@ class Command(BaseCommand): # We want to honor both $PYTHONSTARTUP and .pythonrc.py, so follow system # conventions and get $PYTHONSTARTUP first then .pythonrc.py. 
- if not options['no_startup']: - for pythonrc in OrderedSet([os.environ.get("PYTHONSTARTUP"), os.path.expanduser('~/.pythonrc.py')]): + if not options["no_startup"]: + for pythonrc in OrderedSet( + [os.environ.get("PYTHONSTARTUP"), os.path.expanduser("~/.pythonrc.py")] + ): if not pythonrc: continue if not os.path.isfile(pythonrc): @@ -58,7 +66,7 @@ class Command(BaseCommand): # Match the behavior of the cpython shell where an error in # PYTHONSTARTUP prints an exception and continues. try: - exec(compile(pythonrc_code, pythonrc, 'exec'), imported_objects) + exec(compile(pythonrc_code, pythonrc, "exec"), imported_objects) except Exception: traceback.print_exc() @@ -78,7 +86,7 @@ class Command(BaseCommand): # Match the behavior of the cpython shell where an error in # sys.__interactivehook__ prints a warning and the exception # and continues. - print('Failed calling sys.__interactivehook__') + print("Failed calling sys.__interactivehook__") traceback.print_exc() # Set up tab completion for objects imported by $PYTHONSTARTUP or @@ -86,6 +94,7 @@ class Command(BaseCommand): try: import readline import rlcompleter + readline.set_completer(rlcompleter.Completer(imported_objects).complete) except ImportError: pass @@ -95,17 +104,23 @@ class Command(BaseCommand): def handle(self, **options): # Execute the command and exit. - if options['command']: - exec(options['command'], globals()) + if options["command"]: + exec(options["command"], globals()) return # Execute stdin if it has anything to read and exit. # Not supported on Windows due to select.select() limitations. 
- if sys.platform != 'win32' and not sys.stdin.isatty() and select.select([sys.stdin], [], [], 0)[0]: + if ( + sys.platform != "win32" + and not sys.stdin.isatty() + and select.select([sys.stdin], [], [], 0)[0] + ): exec(sys.stdin.read(), globals()) return - available_shells = [options['interface']] if options['interface'] else self.shells + available_shells = ( + [options["interface"]] if options["interface"] else self.shells + ) for shell in available_shells: try: diff --git a/django/core/management/commands/showmigrations.py b/django/core/management/commands/showmigrations.py index e3227457ce..1f3a64f9b1 100644 --- a/django/core/management/commands/showmigrations.py +++ b/django/core/management/commands/showmigrations.py @@ -12,48 +12,58 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'app_label', nargs='*', - help='App labels of applications to limit the output to.', + "app_label", + nargs="*", + help="App labels of applications to limit the output to.", ) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help=( - 'Nominates a database to show migrations for. Defaults to the ' + "Nominates a database to show migrations for. Defaults to the " '"default" database.' ), ) formats = parser.add_mutually_exclusive_group() formats.add_argument( - '--list', '-l', action='store_const', dest='format', const='list', + "--list", + "-l", + action="store_const", + dest="format", + const="list", help=( - 'Shows a list of all migrations and which are applied. ' - 'With a verbosity level of 2 or above, the applied datetimes ' - 'will be included.' + "Shows a list of all migrations and which are applied. " + "With a verbosity level of 2 or above, the applied datetimes " + "will be included." 
), ) formats.add_argument( - '--plan', '-p', action='store_const', dest='format', const='plan', + "--plan", + "-p", + action="store_const", + dest="format", + const="plan", help=( - 'Shows all migrations in the order they will be applied. ' - 'With a verbosity level of 2 or above all direct migration dependencies ' - 'and reverse dependencies (run_before) will be included.' - ) + "Shows all migrations in the order they will be applied. " + "With a verbosity level of 2 or above all direct migration dependencies " + "and reverse dependencies (run_before) will be included." + ), ) - parser.set_defaults(format='list') + parser.set_defaults(format="list") def handle(self, *args, **options): - self.verbosity = options['verbosity'] + self.verbosity = options["verbosity"] # Get the database we're operating from - db = options['database'] + db = options["database"] connection = connections[db] - if options['format'] == "plan": - return self.show_plan(connection, options['app_label']) + if options["format"] == "plan": + return self.show_plan(connection, options["app_label"]) else: - return self.show_list(connection, options['app_label']) + return self.show_list(connection, options["app_label"]) def _validate_app_names(self, loader, app_names): has_bad_names = False @@ -93,17 +103,26 @@ class Command(BaseCommand): # Give it a nice title if it's a squashed one title = plan_node[1] if graph.nodes[plan_node].replaces: - title += " (%s squashed migrations)" % len(graph.nodes[plan_node].replaces) + title += " (%s squashed migrations)" % len( + graph.nodes[plan_node].replaces + ) applied_migration = loader.applied_migrations.get(plan_node) # Mark it as applied/unapplied if applied_migration: if plan_node in recorded_migrations: - output = ' [X] %s' % title + output = " [X] %s" % title else: title += " Run 'manage.py migrate' to finish recording." 
- output = ' [-] %s' % title - if self.verbosity >= 2 and hasattr(applied_migration, 'applied'): - output += ' (applied at %s)' % applied_migration.applied.strftime('%Y-%m-%d %H:%M:%S') + output = " [-] %s" % title + if self.verbosity >= 2 and hasattr( + applied_migration, "applied" + ): + output += ( + " (applied at %s)" + % applied_migration.applied.strftime( + "%Y-%m-%d %H:%M:%S" + ) + ) self.stdout.write(output) else: self.stdout.write(" [ ] %s" % title) @@ -154,4 +173,4 @@ class Command(BaseCommand): else: self.stdout.write("[ ] %s.%s%s" % (node.key[0], node.key[1], deps)) if not plan: - self.stdout.write('(no migrations)', self.style.ERROR) + self.stdout.write("(no migrations)", self.style.ERROR) diff --git a/django/core/management/commands/sqlflush.py b/django/core/management/commands/sqlflush.py index 29782607bb..e6701349d6 100644 --- a/django/core/management/commands/sqlflush.py +++ b/django/core/management/commands/sqlflush.py @@ -14,12 +14,13 @@ class Command(BaseCommand): def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database to print the SQL for. 
Defaults to the "default" database.', ) def handle(self, **options): - sql_statements = sql_flush(self.style, connections[options['database']]) - if not sql_statements and options['verbosity'] >= 1: - self.stderr.write('No tables found.') - return '\n'.join(sql_statements) + sql_statements = sql_flush(self.style, connections[options["database"]]) + if not sql_statements and options["verbosity"] >= 1: + self.stderr.write("No tables found.") + return "\n".join(sql_statements) diff --git a/django/core/management/commands/sqlmigrate.py b/django/core/management/commands/sqlmigrate.py index f687360fb4..880eb11d9b 100644 --- a/django/core/management/commands/sqlmigrate.py +++ b/django/core/management/commands/sqlmigrate.py @@ -10,34 +10,40 @@ class Command(BaseCommand): output_transaction = True def add_arguments(self, parser): - parser.add_argument('app_label', help='App label of the application containing the migration.') - parser.add_argument('migration_name', help='Migration name to print the SQL for.') parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "app_label", help="App label of the application containing the migration." + ) + parser.add_argument( + "migration_name", help="Migration name to print the SQL for." + ) + parser.add_argument( + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database to create SQL for. Defaults to the "default" database.', ) parser.add_argument( - '--backwards', action='store_true', - help='Creates SQL to unapply the migration, rather than to apply it', + "--backwards", + action="store_true", + help="Creates SQL to unapply the migration, rather than to apply it", ) def execute(self, *args, **options): # sqlmigrate doesn't support coloring its output but we need to force # no_color=True so that the BEGIN/COMMIT statements added by # output_transaction don't get colored either. 
- options['no_color'] = True + options["no_color"] = True return super().execute(*args, **options) def handle(self, *args, **options): # Get the database we're operating from - connection = connections[options['database']] + connection = connections[options["database"]] # Load up a loader to get all the migration data, but don't replace # migrations. loader = MigrationLoader(connection, replace_migrations=False) # Resolve command-line arguments into a migration - app_label, migration_name = options['app_label'], options['migration_name'] + app_label, migration_name = options["app_label"], options["migration_name"] # Validate app_label try: apps.get_app_config(app_label) @@ -48,21 +54,27 @@ class Command(BaseCommand): try: migration = loader.get_migration_by_prefix(app_label, migration_name) except AmbiguityError: - raise CommandError("More than one migration matches '%s' in app '%s'. Please be more specific." % ( - migration_name, app_label)) + raise CommandError( + "More than one migration matches '%s' in app '%s'. Please be more specific." + % (migration_name, app_label) + ) except KeyError: - raise CommandError("Cannot find a migration matching '%s' from app '%s'. Is it in INSTALLED_APPS?" % ( - migration_name, app_label)) + raise CommandError( + "Cannot find a migration matching '%s' from app '%s'. Is it in INSTALLED_APPS?" + % (migration_name, app_label) + ) target = (app_label, migration.name) # Show begin/end around output for atomic migrations, if the database # supports transactional DDL. 
- self.output_transaction = migration.atomic and connection.features.can_rollback_ddl + self.output_transaction = ( + migration.atomic and connection.features.can_rollback_ddl + ) # Make a plan that represents just the requested migrations and show SQL # for it - plan = [(loader.graph.nodes[target], options['backwards'])] + plan = [(loader.graph.nodes[target], options["backwards"])] sql_statements = loader.collect_sql(plan) - if not sql_statements and options['verbosity'] >= 1: - self.stderr.write('No operations found.') - return '\n'.join(sql_statements) + if not sql_statements and options["verbosity"] >= 1: + self.stderr.write("No operations found.") + return "\n".join(sql_statements) diff --git a/django/core/management/commands/sqlsequencereset.py b/django/core/management/commands/sqlsequencereset.py index 1d74ed9f55..454a2ab10c 100644 --- a/django/core/management/commands/sqlsequencereset.py +++ b/django/core/management/commands/sqlsequencereset.py @@ -3,23 +3,26 @@ from django.db import DEFAULT_DB_ALIAS, connections class Command(AppCommand): - help = 'Prints the SQL statements for resetting sequences for the given app name(s).' + help = ( + "Prints the SQL statements for resetting sequences for the given app name(s)." + ) output_transaction = True def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument( - '--database', default=DEFAULT_DB_ALIAS, + "--database", + default=DEFAULT_DB_ALIAS, help='Nominates a database to print the SQL for. 
Defaults to the "default" database.', ) def handle_app_config(self, app_config, **options): if app_config.models_module is None: return - connection = connections[options['database']] + connection = connections[options["database"]] models = app_config.get_models(include_auto_created=True) statements = connection.ops.sequence_reset_sql(self.style, models) - if not statements and options['verbosity'] >= 1: - self.stderr.write('No sequences found.') - return '\n'.join(statements) + if not statements and options["verbosity"] >= 1: + self.stderr.write("No sequences found.") + return "\n".join(statements) diff --git a/django/core/management/commands/squashmigrations.py b/django/core/management/commands/squashmigrations.py index 80fdc0cbc1..1592e792f8 100644 --- a/django/core/management/commands/squashmigrations.py +++ b/django/core/management/commands/squashmigrations.py @@ -16,44 +16,51 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'app_label', - help='App label of the application to squash migrations for.', + "app_label", + help="App label of the application to squash migrations for.", ) parser.add_argument( - 'start_migration_name', nargs='?', - help='Migrations will be squashed starting from and including this migration.', + "start_migration_name", + nargs="?", + help="Migrations will be squashed starting from and including this migration.", ) parser.add_argument( - 'migration_name', - help='Migrations will be squashed until and including this migration.', + "migration_name", + help="Migrations will be squashed until and including this migration.", ) parser.add_argument( - '--no-optimize', action='store_true', - help='Do not try to optimize the squashed operations.', + "--no-optimize", + action="store_true", + help="Do not try to optimize the squashed operations.", ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + 
"--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--squashed-name', - help='Sets the name of the new squashed migration.', + "--squashed-name", + help="Sets the name of the new squashed migration.", ) parser.add_argument( - '--no-header', action='store_false', dest='include_header', - help='Do not add a header comment to the new squashed migration.', + "--no-header", + action="store_false", + dest="include_header", + help="Do not add a header comment to the new squashed migration.", ) def handle(self, **options): - self.verbosity = options['verbosity'] - self.interactive = options['interactive'] - app_label = options['app_label'] - start_migration_name = options['start_migration_name'] - migration_name = options['migration_name'] - no_optimize = options['no_optimize'] - squashed_name = options['squashed_name'] - include_header = options['include_header'] + self.verbosity = options["verbosity"] + self.interactive = options["interactive"] + app_label = options["app_label"] + start_migration_name = options["start_migration_name"] + migration_name = options["migration_name"] + no_optimize = options["no_optimize"] + squashed_name = options["squashed_name"] + include_header = options["include_header"] # Validate app_label. 
try: apps.get_app_config(app_label) @@ -72,13 +79,19 @@ class Command(BaseCommand): # Work out the list of predecessor migrations migrations_to_squash = [ loader.get_migration(al, mn) - for al, mn in loader.graph.forwards_plan((migration.app_label, migration.name)) + for al, mn in loader.graph.forwards_plan( + (migration.app_label, migration.name) + ) if al == migration.app_label ] if start_migration_name: - start_migration = self.find_migration(loader, app_label, start_migration_name) - start = loader.get_migration(start_migration.app_label, start_migration.name) + start_migration = self.find_migration( + loader, app_label, start_migration_name + ) + start = loader.get_migration( + start_migration.app_label, start_migration.name + ) try: start_index = migrations_to_squash.index(start) migrations_to_squash = migrations_to_squash[start_index:] @@ -93,7 +106,9 @@ class Command(BaseCommand): # Tell them what we're doing and optionally ask if we should proceed if self.verbosity > 0 or self.interactive: - self.stdout.write(self.style.MIGRATE_HEADING("Will squash the following migrations:")) + self.stdout.write( + self.style.MIGRATE_HEADING("Will squash the following migrations:") + ) for migration in migrations_to_squash: self.stdout.write(" - %s" % migration.name) @@ -122,7 +137,8 @@ class Command(BaseCommand): raise CommandError( "You cannot squash squashed migrations! 
Please transition " "it to a normal migration first: " - "https://docs.djangoproject.com/en/%s/topics/migrations/#squashing-migrations" % get_docs_version() + "https://docs.djangoproject.com/en/%s/topics/migrations/#squashing-migrations" + % get_docs_version() ) operations.extend(smigration.operations) for dependency in smigration.dependencies: @@ -137,7 +153,9 @@ class Command(BaseCommand): if no_optimize: if self.verbosity > 0: - self.stdout.write(self.style.MIGRATE_HEADING("(Skipping optimization.)")) + self.stdout.write( + self.style.MIGRATE_HEADING("(Skipping optimization.)") + ) new_operations = operations else: if self.verbosity > 0: @@ -151,8 +169,8 @@ class Command(BaseCommand): self.stdout.write(" No optimizations possible.") else: self.stdout.write( - " Optimized from %s operations to %s operations." % - (len(operations), len(new_operations)) + " Optimized from %s operations to %s operations." + % (len(operations), len(new_operations)) ) # Work out the value of replaces (any squashed ones we're re-squashing) @@ -165,22 +183,26 @@ class Command(BaseCommand): replaces.append((migration.app_label, migration.name)) # Make a new migration with those operations - subclass = type("Migration", (migrations.Migration,), { - "dependencies": dependencies, - "operations": new_operations, - "replaces": replaces, - }) + subclass = type( + "Migration", + (migrations.Migration,), + { + "dependencies": dependencies, + "operations": new_operations, + "replaces": replaces, + }, + ) if start_migration_name: if squashed_name: # Use the name from --squashed-name. - prefix, _ = start_migration.name.split('_', 1) - name = '%s_%s' % (prefix, squashed_name) + prefix, _ = start_migration.name.split("_", 1) + name = "%s_%s" % (prefix, squashed_name) else: # Generate a name. 
- name = '%s_squashed_%s' % (start_migration.name, migration.name) + name = "%s_squashed_%s" % (start_migration.name, migration.name) new_migration = subclass(name, app_label) else: - name = '0001_%s' % (squashed_name or 'squashed_%s' % migration.name) + name = "0001_%s" % (squashed_name or "squashed_%s" % migration.name) new_migration = subclass(name, app_label) new_migration.initial = True @@ -188,25 +210,28 @@ class Command(BaseCommand): writer = MigrationWriter(new_migration, include_header) if os.path.exists(writer.path): raise CommandError( - f'Migration {new_migration.name} already exists. Use a different name.' + f"Migration {new_migration.name} already exists. Use a different name." ) - with open(writer.path, "w", encoding='utf-8') as fh: + with open(writer.path, "w", encoding="utf-8") as fh: fh.write(writer.as_string()) if self.verbosity > 0: self.stdout.write( - self.style.MIGRATE_HEADING('Created new squashed migration %s' % writer.path) + '\n' - ' You should commit this migration but leave the old ones in place;\n' - ' the new migration will be used for new installs. Once you are sure\n' - ' all instances of the codebase have applied the migrations you squashed,\n' - ' you can delete them.' + self.style.MIGRATE_HEADING( + "Created new squashed migration %s" % writer.path + ) + + "\n" + " You should commit this migration but leave the old ones in place;\n" + " the new migration will be used for new installs. Once you are sure\n" + " all instances of the codebase have applied the migrations you squashed,\n" + " you can delete them." ) if writer.needs_manual_porting: self.stdout.write( - self.style.MIGRATE_HEADING('Manual porting required') + '\n' - ' Your migrations contained functions that must be manually copied over,\n' - ' as we could not safely copy their implementation.\n' - ' See the comment at the top of the squashed migration for details.' 
+ self.style.MIGRATE_HEADING("Manual porting required") + "\n" + " Your migrations contained functions that must be manually copied over,\n" + " as we could not safely copy their implementation.\n" + " See the comment at the top of the squashed migration for details." ) def find_migration(self, loader, app_label, name): @@ -219,6 +244,6 @@ class Command(BaseCommand): ) except KeyError: raise CommandError( - "Cannot find a migration matching '%s' from app '%s'." % - (name, app_label) + "Cannot find a migration matching '%s' from app '%s'." + % (name, app_label) ) diff --git a/django/core/management/commands/startapp.py b/django/core/management/commands/startapp.py index bba9f3dee0..e85833b9a8 100644 --- a/django/core/management/commands/startapp.py +++ b/django/core/management/commands/startapp.py @@ -9,6 +9,6 @@ class Command(TemplateCommand): missing_args_message = "You must provide an application name." def handle(self, **options): - app_name = options.pop('name') - target = options.pop('directory') - super().handle('app', app_name, target, **options) + app_name = options.pop("name") + target = options.pop("directory") + super().handle("app", app_name, target, **options) diff --git a/django/core/management/commands/startproject.py b/django/core/management/commands/startproject.py index 164ccdffb5..ca17fa54cd 100644 --- a/django/core/management/commands/startproject.py +++ b/django/core/management/commands/startproject.py @@ -12,10 +12,10 @@ class Command(TemplateCommand): missing_args_message = "You must provide a project name." def handle(self, **options): - project_name = options.pop('name') - target = options.pop('directory') + project_name = options.pop("name") + target = options.pop("directory") # Create a random SECRET_KEY to put it in the main settings. 
- options['secret_key'] = SECRET_KEY_INSECURE_PREFIX + get_random_secret_key() + options["secret_key"] = SECRET_KEY_INSECURE_PREFIX + get_random_secret_key() - super().handle('project', project_name, target, **options) + super().handle("project", project_name, target, **options) diff --git a/django/core/management/commands/test.py b/django/core/management/commands/test.py index 7a76033424..e5660955cd 100644 --- a/django/core/management/commands/test.py +++ b/django/core/management/commands/test.py @@ -8,7 +8,7 @@ from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner class Command(BaseCommand): - help = 'Discover and run tests in the specified modules or the current directory.' + help = "Discover and run tests in the specified modules or the current directory." # DiscoverRunner runs the checks after databases are set up. requires_system_checks = [] @@ -20,42 +20,48 @@ class Command(BaseCommand): option. This allows a test runner to define additional command line arguments. 
""" - self.test_runner = get_command_line_option(argv, '--testrunner') + self.test_runner = get_command_line_option(argv, "--testrunner") super().run_from_argv(argv) def add_arguments(self, parser): parser.add_argument( - 'args', metavar='test_label', nargs='*', - help='Module paths to test; can be modulename, modulename.TestCase or modulename.TestCase.test_method' + "args", + metavar="test_label", + nargs="*", + help="Module paths to test; can be modulename, modulename.TestCase or modulename.TestCase.test_method", ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--failfast', action='store_true', - help='Tells Django to stop running the test suite after first failed test.', + "--failfast", + action="store_true", + help="Tells Django to stop running the test suite after first failed test.", ) parser.add_argument( - '--testrunner', - help='Tells Django to use specified test runner class instead of ' - 'the one specified by the TEST_RUNNER setting.', + "--testrunner", + help="Tells Django to use specified test runner class instead of " + "the one specified by the TEST_RUNNER setting.", ) test_runner_class = get_runner(settings, self.test_runner) - if hasattr(test_runner_class, 'add_arguments'): + if hasattr(test_runner_class, "add_arguments"): test_runner_class.add_arguments(parser) def handle(self, *test_labels, **options): - TestRunner = get_runner(settings, options['testrunner']) + TestRunner = get_runner(settings, options["testrunner"]) - time_keeper = TimeKeeper() if options.get('timing', False) else NullTimeKeeper() - parallel = options.get('parallel') - if parallel == 'auto': - options['parallel'] = get_max_test_processes() + time_keeper = TimeKeeper() if options.get("timing", 
False) else NullTimeKeeper() + parallel = options.get("parallel") + if parallel == "auto": + options["parallel"] = get_max_test_processes() test_runner = TestRunner(**options) - with time_keeper.timed('Total run'): + with time_keeper.timed("Total run"): failures = test_runner.run_tests(test_labels) time_keeper.print_results() if failures: diff --git a/django/core/management/commands/testserver.py b/django/core/management/commands/testserver.py index ee8709af8b..caff6c65cd 100644 --- a/django/core/management/commands/testserver.py +++ b/django/core/management/commands/testserver.py @@ -4,51 +4,62 @@ from django.db import connection class Command(BaseCommand): - help = 'Runs a development server with data from the given fixture(s).' + help = "Runs a development server with data from the given fixture(s)." requires_system_checks = [] def add_arguments(self, parser): parser.add_argument( - 'args', metavar='fixture', nargs='*', - help='Path(s) to fixtures to load before running the server.', + "args", + metavar="fixture", + nargs="*", + help="Path(s) to fixtures to load before running the server.", ) parser.add_argument( - '--noinput', '--no-input', action='store_false', dest='interactive', - help='Tells Django to NOT prompt the user for input of any kind.', + "--noinput", + "--no-input", + action="store_false", + dest="interactive", + help="Tells Django to NOT prompt the user for input of any kind.", ) parser.add_argument( - '--addrport', default='', - help='Port number or ipaddr:port to run the server on.', + "--addrport", + default="", + help="Port number or ipaddr:port to run the server on.", ) parser.add_argument( - '--ipv6', '-6', action='store_true', dest='use_ipv6', - help='Tells Django to use an IPv6 address.', + "--ipv6", + "-6", + action="store_true", + dest="use_ipv6", + help="Tells Django to use an IPv6 address.", ) def handle(self, *fixture_labels, **options): - verbosity = options['verbosity'] - interactive = options['interactive'] + verbosity = 
options["verbosity"] + interactive = options["interactive"] # Create a test database. - db_name = connection.creation.create_test_db(verbosity=verbosity, autoclobber=not interactive, serialize=False) + db_name = connection.creation.create_test_db( + verbosity=verbosity, autoclobber=not interactive, serialize=False + ) # Import the fixture data into the test database. - call_command('loaddata', *fixture_labels, **{'verbosity': verbosity}) + call_command("loaddata", *fixture_labels, **{"verbosity": verbosity}) # Run the development server. Turn off auto-reloading because it causes # a strange error -- it causes this handle() method to be called # multiple times. shutdown_message = ( - '\nServer stopped.\nNote that the test database, %r, has not been ' - 'deleted. You can explore it on your own.' % db_name + "\nServer stopped.\nNote that the test database, %r, has not been " + "deleted. You can explore it on your own." % db_name ) use_threading = connection.features.test_db_allows_multiple_connections call_command( - 'runserver', - addrport=options['addrport'], + "runserver", + addrport=options["addrport"], shutdown_message=shutdown_message, use_reloader=False, - use_ipv6=options['use_ipv6'], - use_threading=use_threading + use_ipv6=options["use_ipv6"], + use_threading=use_threading, ) diff --git a/django/core/management/sql.py b/django/core/management/sql.py index a7e122a15f..2375cc23ab 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -8,7 +8,9 @@ def sql_flush(style, connection, reset_sequences=True, allow_cascade=False): """ Return a list of the SQL statements used to flush the database. 
""" - tables = connection.introspection.django_table_names(only_existing=True, include_views=False) + tables = connection.introspection.django_table_names( + only_existing=True, include_views=False + ) return connection.ops.sql_flush( style, tables, @@ -23,15 +25,17 @@ def emit_pre_migrate_signal(verbosity, interactive, db, **kwargs): if app_config.models_module is None: continue if verbosity >= 2: - stdout = kwargs.get('stdout', sys.stdout) - stdout.write('Running pre-migrate handlers for application %s' % app_config.label) + stdout = kwargs.get("stdout", sys.stdout) + stdout.write( + "Running pre-migrate handlers for application %s" % app_config.label + ) models.signals.pre_migrate.send( sender=app_config, app_config=app_config, verbosity=verbosity, interactive=interactive, using=db, - **kwargs + **kwargs, ) @@ -41,13 +45,15 @@ def emit_post_migrate_signal(verbosity, interactive, db, **kwargs): if app_config.models_module is None: continue if verbosity >= 2: - stdout = kwargs.get('stdout', sys.stdout) - stdout.write('Running post-migrate handlers for application %s' % app_config.label) + stdout = kwargs.get("stdout", sys.stdout) + stdout.write( + "Running post-migrate handlers for application %s" % app_config.label + ) models.signals.post_migrate.send( sender=app_config, app_config=app_config, verbosity=verbosity, interactive=interactive, using=db, - **kwargs + **kwargs, ) diff --git a/django/core/management/templates.py b/django/core/management/templates.py index cfcaff7c0f..58005c23ed 100644 --- a/django/core/management/templates.py +++ b/django/core/management/templates.py @@ -29,46 +29,61 @@ class TemplateCommand(BaseCommand): :param directory: The directory to which the template should be copied. 
:param options: The additional variables passed to project or app templates """ + requires_system_checks = [] # The supported URL schemes - url_schemes = ['http', 'https', 'ftp'] + url_schemes = ["http", "https", "ftp"] # Rewrite the following suffixes when determining the target filename. rewrite_template_suffixes = ( # Allow shipping invalid .py files without byte-compilation. - ('.py-tpl', '.py'), + (".py-tpl", ".py"), ) def add_arguments(self, parser): - parser.add_argument('name', help='Name of the application or project.') - parser.add_argument('directory', nargs='?', help='Optional destination directory') - parser.add_argument('--template', help='The path or URL to load the template from.') + parser.add_argument("name", help="Name of the application or project.") parser.add_argument( - '--extension', '-e', dest='extensions', - action='append', default=['py'], + "directory", nargs="?", help="Optional destination directory" + ) + parser.add_argument( + "--template", help="The path or URL to load the template from." + ) + parser.add_argument( + "--extension", + "-e", + dest="extensions", + action="append", + default=["py"], help='The file extension(s) to render (default: "py"). ' - 'Separate multiple extensions with commas, or use ' - '-e multiple times.' + "Separate multiple extensions with commas, or use " + "-e multiple times.", ) parser.add_argument( - '--name', '-n', dest='files', - action='append', default=[], - help='The file name(s) to render. Separate multiple file names ' - 'with commas, or use -n multiple times.' + "--name", + "-n", + dest="files", + action="append", + default=[], + help="The file name(s) to render. 
Separate multiple file names " + "with commas, or use -n multiple times.", ) parser.add_argument( - '--exclude', '-x', - action='append', default=argparse.SUPPRESS, nargs='?', const='', + "--exclude", + "-x", + action="append", + default=argparse.SUPPRESS, + nargs="?", + const="", help=( - 'The directory name(s) to exclude, in addition to .git and ' - '__pycache__. Can be used multiple times.' + "The directory name(s) to exclude, in addition to .git and " + "__pycache__. Can be used multiple times." ), ) def handle(self, app_or_project, name, target=None, **options): self.app_or_project = app_or_project - self.a_or_an = 'an' if app_or_project == 'app' else 'a' + self.a_or_an = "an" if app_or_project == "app" else "a" self.paths_to_remove = [] - self.verbosity = options['verbosity'] + self.verbosity = options["verbosity"] self.validate_name(name) @@ -83,51 +98,55 @@ class TemplateCommand(BaseCommand): raise CommandError(e) else: top_dir = os.path.abspath(os.path.expanduser(target)) - if app_or_project == 'app': - self.validate_name(os.path.basename(top_dir), 'directory') + if app_or_project == "app": + self.validate_name(os.path.basename(top_dir), "directory") if not os.path.exists(top_dir): - raise CommandError("Destination directory '%s' does not " - "exist, please create it first." % top_dir) + raise CommandError( + "Destination directory '%s' does not " + "exist, please create it first." 
% top_dir + ) - extensions = tuple(handle_extensions(options['extensions'])) + extensions = tuple(handle_extensions(options["extensions"])) extra_files = [] - excluded_directories = ['.git', '__pycache__'] - for file in options['files']: - extra_files.extend(map(lambda x: x.strip(), file.split(','))) - if exclude := options.get('exclude'): + excluded_directories = [".git", "__pycache__"] + for file in options["files"]: + extra_files.extend(map(lambda x: x.strip(), file.split(","))) + if exclude := options.get("exclude"): for directory in exclude: excluded_directories.append(directory.strip()) if self.verbosity >= 2: self.stdout.write( - 'Rendering %s template files with extensions: %s' - % (app_or_project, ', '.join(extensions)) + "Rendering %s template files with extensions: %s" + % (app_or_project, ", ".join(extensions)) ) self.stdout.write( - 'Rendering %s template files with filenames: %s' - % (app_or_project, ', '.join(extra_files)) + "Rendering %s template files with filenames: %s" + % (app_or_project, ", ".join(extra_files)) ) - base_name = '%s_name' % app_or_project - base_subdir = '%s_template' % app_or_project - base_directory = '%s_directory' % app_or_project - camel_case_name = 'camel_case_%s_name' % app_or_project - camel_case_value = ''.join(x for x in name.title() if x != '_') + base_name = "%s_name" % app_or_project + base_subdir = "%s_template" % app_or_project + base_directory = "%s_directory" % app_or_project + camel_case_name = "camel_case_%s_name" % app_or_project + camel_case_value = "".join(x for x in name.title() if x != "_") - context = Context({ - **options, - base_name: name, - base_directory: top_dir, - camel_case_name: camel_case_value, - 'docs_version': get_docs_version(), - 'django_version': django.__version__, - }, autoescape=False) + context = Context( + { + **options, + base_name: name, + base_directory: top_dir, + camel_case_name: camel_case_value, + "docs_version": get_docs_version(), + "django_version": django.__version__, + }, 
+ autoescape=False, + ) # Setup a stub settings environment for template rendering if not settings.configured: settings.configure() django.setup() - template_dir = self.handle_template(options['template'], - base_subdir) + template_dir = self.handle_template(options["template"], base_subdir) prefix_length = len(template_dir) + 1 for root, dirs, files in os.walk(template_dir): @@ -139,14 +158,14 @@ class TemplateCommand(BaseCommand): os.makedirs(target_dir, exist_ok=True) for dirname in dirs[:]: - if 'exclude' not in options: - if dirname.startswith('.') or dirname == '__pycache__': + if "exclude" not in options: + if dirname.startswith(".") or dirname == "__pycache__": dirs.remove(dirname) elif dirname in excluded_directories: dirs.remove(dirname) for filename in files: - if filename.endswith(('.pyo', '.pyc', '.py.class')): + if filename.endswith((".pyo", ".pyc", ".py.class")): # Ignore some files as they cause various breakages. continue old_path = os.path.join(root, filename) @@ -155,31 +174,34 @@ class TemplateCommand(BaseCommand): ) for old_suffix, new_suffix in self.rewrite_template_suffixes: if new_path.endswith(old_suffix): - new_path = new_path[:-len(old_suffix)] + new_suffix + new_path = new_path[: -len(old_suffix)] + new_suffix break # Only rewrite once if os.path.exists(new_path): raise CommandError( "%s already exists. Overlaying %s %s into an existing " - "directory won't replace conflicting files." % ( - new_path, self.a_or_an, app_or_project, + "directory won't replace conflicting files." 
+ % ( + new_path, + self.a_or_an, + app_or_project, ) ) # Only render the Python files, as we don't want to # accidentally render Django templates files if new_path.endswith(extensions) or filename in extra_files: - with open(old_path, encoding='utf-8') as template_file: + with open(old_path, encoding="utf-8") as template_file: content = template_file.read() template = Engine().from_string(content) content = template.render(context) - with open(new_path, 'w', encoding='utf-8') as new_file: + with open(new_path, "w", encoding="utf-8") as new_file: new_file.write(content) else: shutil.copyfile(old_path, new_path) if self.verbosity >= 2: - self.stdout.write('Creating %s' % new_path) + self.stdout.write("Creating %s" % new_path) try: self.apply_umask(old_path, new_path) self.make_writeable(new_path) @@ -187,11 +209,13 @@ class TemplateCommand(BaseCommand): self.stderr.write( "Notice: Couldn't set permission bits on %s. You're " "probably using an uncommon filesystem setup. No " - "problem." % new_path, self.style.NOTICE) + "problem." % new_path, + self.style.NOTICE, + ) if self.paths_to_remove: if self.verbosity >= 2: - self.stdout.write('Cleaning up temporary files.') + self.stdout.write("Cleaning up temporary files.") for path_to_remove in self.paths_to_remove: if os.path.isfile(path_to_remove): os.remove(path_to_remove) @@ -205,9 +229,9 @@ class TemplateCommand(BaseCommand): directory isn't known. """ if template is None: - return os.path.join(django.__path__[0], 'conf', subdir) + return os.path.join(django.__path__[0], "conf", subdir) else: - if template.startswith('file://'): + if template.startswith("file://"): template = template[7:] expanded_template = os.path.expanduser(template) expanded_template = os.path.normpath(expanded_template) @@ -221,15 +245,18 @@ class TemplateCommand(BaseCommand): if os.path.exists(absolute_path): return self.extract(absolute_path) - raise CommandError("couldn't handle %s template %s." 
% - (self.app_or_project, template)) + raise CommandError( + "couldn't handle %s template %s." % (self.app_or_project, template) + ) - def validate_name(self, name, name_or_dir='name'): + def validate_name(self, name, name_or_dir="name"): if name is None: - raise CommandError('you must provide {an} {app} name'.format( - an=self.a_or_an, - app=self.app_or_project, - )) + raise CommandError( + "you must provide {an} {app} name".format( + an=self.a_or_an, + app=self.app_or_project, + ) + ) # Check it's a valid directory name. if not name.isidentifier(): raise CommandError( @@ -261,47 +288,49 @@ class TemplateCommand(BaseCommand): """ Download the given URL and return the file name. """ + def cleanup_url(url): - tmp = url.rstrip('/') - filename = tmp.split('/')[-1] - if url.endswith('/'): - display_url = tmp + '/' + tmp = url.rstrip("/") + filename = tmp.split("/")[-1] + if url.endswith("/"): + display_url = tmp + "/" else: display_url = url return filename, display_url - prefix = 'django_%s_template_' % self.app_or_project - tempdir = tempfile.mkdtemp(prefix=prefix, suffix='_download') + prefix = "django_%s_template_" % self.app_or_project + tempdir = tempfile.mkdtemp(prefix=prefix, suffix="_download") self.paths_to_remove.append(tempdir) filename, display_url = cleanup_url(url) if self.verbosity >= 2: - self.stdout.write('Downloading %s' % display_url) + self.stdout.write("Downloading %s" % display_url) the_path = os.path.join(tempdir, filename) opener = build_opener() - opener.addheaders = [('User-Agent', f'Django/{django.__version__}')] + opener.addheaders = [("User-Agent", f"Django/{django.__version__}")] try: - with opener.open(url) as source, open(the_path, 'wb') as target: + with opener.open(url) as source, open(the_path, "wb") as target: headers = source.info() target.write(source.read()) except OSError as e: - raise CommandError("couldn't download URL %s to %s: %s" % - (url, filename, e)) + raise CommandError( + "couldn't download URL %s to %s: %s" % (url, 
filename, e) + ) - used_name = the_path.split('/')[-1] + used_name = the_path.split("/")[-1] # Trying to get better name from response headers - content_disposition = headers['content-disposition'] + content_disposition = headers["content-disposition"] if content_disposition: _, params = cgi.parse_header(content_disposition) - guessed_filename = params.get('filename') or used_name + guessed_filename = params.get("filename") or used_name else: guessed_filename = used_name # Falling back to content type guessing ext = self.splitext(guessed_filename)[1] - content_type = headers['content-type'] + content_type = headers["content-type"] if not ext and content_type: ext = mimetypes.guess_extension(content_type) if ext: @@ -322,7 +351,7 @@ class TemplateCommand(BaseCommand): Like os.path.splitext, but takes off .tar, too """ base, ext = posixpath.splitext(the_path) - if base.lower().endswith('.tar'): + if base.lower().endswith(".tar"): ext = base[-4:] + ext base = base[:-4] return base, ext @@ -332,23 +361,24 @@ class TemplateCommand(BaseCommand): Extract the given file to a temporary directory and return the path of the directory with the extracted content. 
""" - prefix = 'django_%s_template_' % self.app_or_project - tempdir = tempfile.mkdtemp(prefix=prefix, suffix='_extract') + prefix = "django_%s_template_" % self.app_or_project + tempdir = tempfile.mkdtemp(prefix=prefix, suffix="_extract") self.paths_to_remove.append(tempdir) if self.verbosity >= 2: - self.stdout.write('Extracting %s' % filename) + self.stdout.write("Extracting %s" % filename) try: archive.extract(filename, tempdir) return tempdir except (archive.ArchiveException, OSError) as e: - raise CommandError("couldn't extract file %s to %s: %s" % - (filename, tempdir, e)) + raise CommandError( + "couldn't extract file %s to %s: %s" % (filename, tempdir, e) + ) def is_url(self, template): """Return True if the name looks like a URL.""" - if ':' not in template: + if ":" not in template: return False - scheme = template.split(':', 1)[0].lower() + scheme = template.split(":", 1)[0].lower() return scheme in self.url_schemes def apply_umask(self, old_path, new_path): diff --git a/django/core/management/utils.py b/django/core/management/utils.py index c6901aa3d5..c12d90f6ae 100644 --- a/django/core/management/utils.py +++ b/django/core/management/utils.py @@ -10,20 +10,20 @@ from django.utils.encoding import DEFAULT_LOCALE_ENCODING from .base import CommandError, CommandParser -def popen_wrapper(args, stdout_encoding='utf-8'): +def popen_wrapper(args, stdout_encoding="utf-8"): """ Friendly wrapper around Popen. Return stdout output, stderr output, and OS status code. 
""" try: - p = run(args, capture_output=True, close_fds=os.name != 'nt') + p = run(args, capture_output=True, close_fds=os.name != "nt") except OSError as err: - raise CommandError('Error executing %s' % args[0]) from err + raise CommandError("Error executing %s" % args[0]) from err return ( p.stdout.decode(stdout_encoding), - p.stderr.decode(DEFAULT_LOCALE_ENCODING, errors='replace'), - p.returncode + p.stderr.decode(DEFAULT_LOCALE_ENCODING, errors="replace"), + p.returncode, ) @@ -42,25 +42,25 @@ def handle_extensions(extensions): """ ext_list = [] for ext in extensions: - ext_list.extend(ext.replace(' ', '').split(',')) + ext_list.extend(ext.replace(" ", "").split(",")) for i, ext in enumerate(ext_list): - if not ext.startswith('.'): - ext_list[i] = '.%s' % ext_list[i] + if not ext.startswith("."): + ext_list[i] = ".%s" % ext_list[i] return set(ext_list) def find_command(cmd, path=None, pathext=None): if path is None: - path = os.environ.get('PATH', '').split(os.pathsep) + path = os.environ.get("PATH", "").split(os.pathsep) if isinstance(path, str): path = [path] # check if there are funny path extensions for executables, e.g. Windows if pathext is None: - pathext = os.environ.get('PATHEXT', '.COM;.EXE;.BAT;.CMD').split(os.pathsep) + pathext = os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD").split(os.pathsep) # don't use extensions if the command ends with one of them for ext in pathext: if cmd.endswith(ext): - pathext = [''] + pathext = [""] break # check if we find the command on PATH for p in path: @@ -78,7 +78,7 @@ def get_random_secret_key(): """ Return a 50 character random string usable as a SECRET_KEY setting value. """ - chars = 'abcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*(-_=+)' + chars = "abcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*(-_=+)" return get_random_string(50, chars) @@ -93,11 +93,11 @@ def parse_apps_and_model_labels(labels): models = set() for label in labels: - if '.' in label: + if "." 
in label: try: model = installed_apps.get_model(label) except LookupError: - raise CommandError('Unknown model: %s' % label) + raise CommandError("Unknown model: %s" % label) models.add(model) else: try: @@ -116,7 +116,7 @@ def get_command_line_option(argv, option): option wasn't passed or if the argument list couldn't be parsed. """ parser = CommandParser(add_help=False, allow_abbrev=False) - parser.add_argument(option, dest='value') + parser.add_argument(option, dest="value") try: options, _ = parser.parse_known_args(argv[2:]) except CommandError: @@ -128,12 +128,12 @@ def get_command_line_option(argv, option): def normalize_path_patterns(patterns): """Normalize an iterable of glob style patterns based on OS.""" patterns = [os.path.normcase(p) for p in patterns] - dir_suffixes = {'%s*' % path_sep for path_sep in {'/', os.sep}} + dir_suffixes = {"%s*" % path_sep for path_sep in {"/", os.sep}} norm_patterns = [] for pattern in patterns: for dir_suffix in dir_suffixes: if pattern.endswith(dir_suffix): - norm_patterns.append(pattern[:-len(dir_suffix)]) + norm_patterns.append(pattern[: -len(dir_suffix)]) break else: norm_patterns.append(pattern) @@ -148,6 +148,8 @@ def is_ignored_path(path, ignore_patterns): path = Path(path) def ignore(pattern): - return fnmatch.fnmatchcase(path.name, pattern) or fnmatch.fnmatchcase(str(path), pattern) + return fnmatch.fnmatchcase(path.name, pattern) or fnmatch.fnmatchcase( + str(path), pattern + ) return any(ignore(pattern) for pattern in normalize_path_patterns(ignore_patterns)) diff --git a/django/core/paginator.py b/django/core/paginator.py index 7db64913d9..568445607e 100644 --- a/django/core/paginator.py +++ b/django/core/paginator.py @@ -27,10 +27,9 @@ class EmptyPage(InvalidPage): class Paginator: # Translators: String used to replace omitted page numbers in elided page # range generated by paginators, e.g. [1, 2, '…', 5, 6, 7, '…', 9, 10]. 
- ELLIPSIS = _('…') + ELLIPSIS = _("…") - def __init__(self, object_list, per_page, orphans=0, - allow_empty_first_page=True): + def __init__(self, object_list, per_page, orphans=0, allow_empty_first_page=True): self.object_list = object_list self._check_object_list_is_ordered() self.per_page = int(per_page) @@ -48,14 +47,14 @@ class Paginator: raise ValueError number = int(number) except (TypeError, ValueError): - raise PageNotAnInteger(_('That page number is not an integer')) + raise PageNotAnInteger(_("That page number is not an integer")) if number < 1: - raise EmptyPage(_('That page number is less than 1')) + raise EmptyPage(_("That page number is less than 1")) if number > self.num_pages: if number == 1 and self.allow_empty_first_page: pass else: - raise EmptyPage(_('That page contains no results')) + raise EmptyPage(_("That page contains no results")) return number def get_page(self, number): @@ -92,7 +91,7 @@ class Paginator: @cached_property def count(self): """Return the total number of objects, across all pages.""" - c = getattr(self.object_list, 'count', None) + c = getattr(self.object_list, "count", None) if callable(c) and not inspect.isbuiltin(c) and method_has_no_args(c): return c() return len(self.object_list) @@ -117,18 +116,20 @@ class Paginator: """ Warn if self.object_list is unordered (typically a QuerySet). 
""" - ordered = getattr(self.object_list, 'ordered', None) + ordered = getattr(self.object_list, "ordered", None) if ordered is not None and not ordered: obj_list_repr = ( - '{} {}'.format(self.object_list.model, self.object_list.__class__.__name__) - if hasattr(self.object_list, 'model') - else '{!r}'.format(self.object_list) + "{} {}".format( + self.object_list.model, self.object_list.__class__.__name__ + ) + if hasattr(self.object_list, "model") + else "{!r}".format(self.object_list) ) warnings.warn( - 'Pagination may yield inconsistent results with an unordered ' - 'object_list: {}.'.format(obj_list_repr), + "Pagination may yield inconsistent results with an unordered " + "object_list: {}.".format(obj_list_repr), UnorderedObjectListWarning, - stacklevel=3 + stacklevel=3, ) def get_elided_page_range(self, number=1, *, on_each_side=3, on_ends=2): @@ -164,14 +165,13 @@ class Paginator: class Page(collections.abc.Sequence): - def __init__(self, object_list, number, paginator): self.object_list = object_list self.number = number self.paginator = paginator def __repr__(self): - return '<Page %s of %s>' % (self.number, self.paginator.num_pages) + return "<Page %s of %s>" % (self.number, self.paginator.num_pages) def __len__(self): return len(self.object_list) @@ -179,7 +179,7 @@ class Page(collections.abc.Sequence): def __getitem__(self, index): if not isinstance(index, (int, slice)): raise TypeError( - 'Page indices must be integers or slices, not %s.' + "Page indices must be integers or slices, not %s." 
% type(index).__name__ ) # The object_list is converted to a list so that if it was a QuerySet diff --git a/django/core/serializers/__init__.py b/django/core/serializers/__init__.py index 793f6dc2bd..480c54b79b 100644 --- a/django/core/serializers/__init__.py +++ b/django/core/serializers/__init__.py @@ -42,6 +42,7 @@ class BadSerializer: is an error raised in the process of creating a serializer it will be raised and passed along to the caller when the serializer is used. """ + internal_use_only = False def __init__(self, exception): @@ -72,10 +73,14 @@ def register_serializer(format, serializer_module, serializers=None): except ImportError as exc: bad_serializer = BadSerializer(exc) - module = type('BadSerializerModule', (), { - 'Deserializer': bad_serializer, - 'Serializer': bad_serializer, - }) + module = type( + "BadSerializerModule", + (), + { + "Deserializer": bad_serializer, + "Serializer": bad_serializer, + }, + ) if serializers is None: _serializers[format] = module @@ -153,7 +158,9 @@ def _load_serializers(): register_serializer(format, BUILTIN_SERIALIZERS[format], serializers) if hasattr(settings, "SERIALIZATION_MODULES"): for format in settings.SERIALIZATION_MODULES: - register_serializer(format, settings.SERIALIZATION_MODULES[format], serializers) + register_serializer( + format, settings.SERIALIZATION_MODULES[format], serializers + ) _serializers = serializers @@ -177,8 +184,8 @@ def sort_dependencies(app_list, allow_cycles=False): for model in model_list: models.add(model) # Add any explicitly defined dependencies - if hasattr(model, 'natural_key'): - deps = getattr(model.natural_key, 'dependencies', []) + if hasattr(model, "natural_key"): + deps = getattr(model.natural_key, "dependencies", []) if deps: deps = [apps.get_model(dep) for dep in deps] else: @@ -189,7 +196,7 @@ def sort_dependencies(app_list, allow_cycles=False): for field in model._meta.fields: if field.remote_field: rel_model = field.remote_field.model - if hasattr(rel_model, 
'natural_key') and rel_model != model: + if hasattr(rel_model, "natural_key") and rel_model != model: deps.append(rel_model) # Also add a dependency for any simple M2M relation with a model # that defines a natural key. M2M relations with explicit through @@ -197,7 +204,7 @@ def sort_dependencies(app_list, allow_cycles=False): for field in model._meta.many_to_many: if field.remote_field.through._meta.auto_created: rel_model = field.remote_field.model - if hasattr(rel_model, 'natural_key') and rel_model != model: + if hasattr(rel_model, "natural_key") and rel_model != model: deps.append(rel_model) model_dependencies.append((model, deps)) @@ -235,9 +242,11 @@ def sort_dependencies(app_list, allow_cycles=False): else: raise RuntimeError( "Can't resolve dependencies for %s in serialized app list." - % ', '.join( + % ", ".join( model._meta.label - for model, deps in sorted(skipped, key=lambda obj: obj[0].__name__) + for model, deps in sorted( + skipped, key=lambda obj: obj[0].__name__ + ) ), ) model_dependencies = skipped diff --git a/django/core/serializers/base.py b/django/core/serializers/base.py index 45c43a77d6..da85cb4b92 100644 --- a/django/core/serializers/base.py +++ b/django/core/serializers/base.py @@ -17,10 +17,11 @@ class PickleSerializer: Simple wrapper around pickle to be used in signing.dumps()/loads() and cache backends. """ + def __init__(self, protocol=None): warnings.warn( - 'PickleSerializer is deprecated due to its security risk. Use ' - 'JSONSerializer instead.', + "PickleSerializer is deprecated due to its security risk. 
Use " + "JSONSerializer instead.", RemovedInDjango50Warning, ) self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol @@ -34,11 +35,13 @@ class PickleSerializer: class SerializerDoesNotExist(KeyError): """The requested serializer was not found.""" + pass class SerializationError(Exception): """Something bad happened during serialization.""" + pass @@ -51,11 +54,15 @@ class DeserializationError(Exception): Factory method for creating a deserialization error which has a more explanatory message. """ - return cls("%s: (%s:pk=%s) field_value was '%s'" % (original_exc, model, fk, field_value)) + return cls( + "%s: (%s:pk=%s) field_value was '%s'" + % (original_exc, model, fk, field_value) + ) class M2MDeserializationError(Exception): """Something bad happened during deserialization of a ManyToManyField.""" + def __init__(self, original_exc, pk): self.original_exc = original_exc self.pk = pk @@ -77,10 +84,12 @@ class ProgressBar: if self.prev_done >= done: return self.prev_done = done - cr = '' if self.total_count == 1 else '\r' - self.output.write(cr + '[' + '.' * done + ' ' * (self.progress_width - done) + ']') + cr = "" if self.total_count == 1 else "\r" + self.output.write( + cr + "[" + "." * done + " " * (self.progress_width - done) + "]" + ) if done == self.progress_width: - self.output.write('\n') + self.output.write("\n") self.output.flush() @@ -95,8 +104,18 @@ class Serializer: progress_class = ProgressBar stream_class = StringIO - def serialize(self, queryset, *, stream=None, fields=None, use_natural_foreign_keys=False, - use_natural_primary_keys=False, progress_output=None, object_count=0, **options): + def serialize( + self, + queryset, + *, + stream=None, + fields=None, + use_natural_foreign_keys=False, + use_natural_primary_keys=False, + progress_output=None, + object_count=0, + **options, + ): """ Serialize a queryset. """ @@ -120,20 +139,31 @@ class Serializer: # be serialized, otherwise deserialization isn't possible. 
if self.use_natural_primary_keys: pk = concrete_model._meta.pk - pk_parent = pk if pk.remote_field and pk.remote_field.parent_link else None + pk_parent = ( + pk if pk.remote_field and pk.remote_field.parent_link else None + ) else: pk_parent = None for field in concrete_model._meta.local_fields: if field.serialize or field is pk_parent: if field.remote_field is None: - if self.selected_fields is None or field.attname in self.selected_fields: + if ( + self.selected_fields is None + or field.attname in self.selected_fields + ): self.handle_field(obj, field) else: - if self.selected_fields is None or field.attname[:-3] in self.selected_fields: + if ( + self.selected_fields is None + or field.attname[:-3] in self.selected_fields + ): self.handle_fk_field(obj, field) for field in concrete_model._meta.local_many_to_many: if field.serialize: - if self.selected_fields is None or field.attname in self.selected_fields: + if ( + self.selected_fields is None + or field.attname in self.selected_fields + ): self.handle_m2m_field(obj, field) self.end_object(obj) progress_bar.update(count) @@ -145,7 +175,9 @@ class Serializer: """ Called when serializing of the queryset starts. """ - raise NotImplementedError('subclasses of Serializer must provide a start_serialization() method') + raise NotImplementedError( + "subclasses of Serializer must provide a start_serialization() method" + ) def end_serialization(self): """ @@ -157,7 +189,9 @@ class Serializer: """ Called when serializing of an object starts. """ - raise NotImplementedError('subclasses of Serializer must provide a start_object() method') + raise NotImplementedError( + "subclasses of Serializer must provide a start_object() method" + ) def end_object(self, obj): """ @@ -169,26 +203,32 @@ class Serializer: """ Called to handle each individual (non-relational) field on an object. 
""" - raise NotImplementedError('subclasses of Serializer must provide a handle_field() method') + raise NotImplementedError( + "subclasses of Serializer must provide a handle_field() method" + ) def handle_fk_field(self, obj, field): """ Called to handle a ForeignKey field. """ - raise NotImplementedError('subclasses of Serializer must provide a handle_fk_field() method') + raise NotImplementedError( + "subclasses of Serializer must provide a handle_fk_field() method" + ) def handle_m2m_field(self, obj, field): """ Called to handle a ManyToManyField. """ - raise NotImplementedError('subclasses of Serializer must provide a handle_m2m_field() method') + raise NotImplementedError( + "subclasses of Serializer must provide a handle_m2m_field() method" + ) def getvalue(self): """ Return the fully serialized queryset (or None if the output stream is not seekable). """ - if callable(getattr(self.stream, 'getvalue', None)): + if callable(getattr(self.stream, "getvalue", None)): return self.stream.getvalue() @@ -212,7 +252,9 @@ class Deserializer: def __next__(self): """Iteration interface -- return the next item in the stream""" - raise NotImplementedError('subclasses of Deserializer must provide a __next__() method') + raise NotImplementedError( + "subclasses of Deserializer must provide a __next__() method" + ) class DeserializedObject: @@ -256,18 +298,26 @@ class DeserializedObject: self.m2m_data = {} for field, field_value in self.deferred_fields.items(): opts = self.object._meta - label = opts.app_label + '.' + opts.model_name + label = opts.app_label + "." 
+ opts.model_name if isinstance(field.remote_field, models.ManyToManyRel): try: - values = deserialize_m2m_values(field, field_value, using, handle_forward_references=False) + values = deserialize_m2m_values( + field, field_value, using, handle_forward_references=False + ) except M2MDeserializationError as e: - raise DeserializationError.WithData(e.original_exc, label, self.object.pk, e.pk) + raise DeserializationError.WithData( + e.original_exc, label, self.object.pk, e.pk + ) self.m2m_data[field.name] = values elif isinstance(field.remote_field, models.ManyToOneRel): try: - value = deserialize_fk_value(field, field_value, using, handle_forward_references=False) + value = deserialize_fk_value( + field, field_value, using, handle_forward_references=False + ) except Exception as e: - raise DeserializationError.WithData(e, label, self.object.pk, field_value) + raise DeserializationError.WithData( + e, label, self.object.pk, field_value + ) setattr(self.object, field.attname, value) self.save() @@ -281,8 +331,11 @@ def build_instance(Model, data, db): """ default_manager = Model._meta.default_manager pk = data.get(Model._meta.pk.attname) - if (pk is None and hasattr(default_manager, 'get_by_natural_key') and - hasattr(Model, 'natural_key')): + if ( + pk is None + and hasattr(default_manager, "get_by_natural_key") + and hasattr(Model, "natural_key") + ): natural_key = Model(**data).natural_key() try: data[Model._meta.pk.attname] = Model._meta.pk.to_python( @@ -295,13 +348,20 @@ def build_instance(Model, data, db): def deserialize_m2m_values(field, field_value, using, handle_forward_references): model = field.remote_field.model - if hasattr(model._default_manager, 'get_by_natural_key'): + if hasattr(model._default_manager, "get_by_natural_key"): + def m2m_convert(value): - if hasattr(value, '__iter__') and not isinstance(value, str): - return model._default_manager.db_manager(using).get_by_natural_key(*value).pk + if hasattr(value, "__iter__") and not isinstance(value, 
str): + return ( + model._default_manager.db_manager(using) + .get_by_natural_key(*value) + .pk + ) else: return model._meta.pk.to_python(value) + else: + def m2m_convert(v): return model._meta.pk.to_python(v) @@ -327,8 +387,11 @@ def deserialize_fk_value(field, field_value, using, handle_forward_references): model = field.remote_field.model default_manager = model._default_manager field_name = field.remote_field.field_name - if (hasattr(default_manager, 'get_by_natural_key') and - hasattr(field_value, '__iter__') and not isinstance(field_value, str)): + if ( + hasattr(default_manager, "get_by_natural_key") + and hasattr(field_value, "__iter__") + and not isinstance(field_value, str) + ): try: obj = default_manager.db_manager(using).get_by_natural_key(*field_value) except ObjectDoesNotExist: diff --git a/django/core/serializers/json.py b/django/core/serializers/json.py index 886e8f894c..59d7318409 100644 --- a/django/core/serializers/json.py +++ b/django/core/serializers/json.py @@ -8,9 +8,8 @@ import json import uuid from django.core.serializers.base import DeserializationError -from django.core.serializers.python import ( - Deserializer as PythonDeserializer, Serializer as PythonSerializer, -) +from django.core.serializers.python import Deserializer as PythonDeserializer +from django.core.serializers.python import Serializer as PythonSerializer from django.utils.duration import duration_iso_string from django.utils.functional import Promise from django.utils.timezone import is_aware @@ -18,18 +17,19 @@ from django.utils.timezone import is_aware class Serializer(PythonSerializer): """Convert a queryset to JSON.""" + internal_use_only = False def _init_options(self): self._current = None self.json_kwargs = self.options.copy() - self.json_kwargs.pop('stream', None) - self.json_kwargs.pop('fields', None) - if self.options.get('indent'): + self.json_kwargs.pop("stream", None) + self.json_kwargs.pop("fields", None) + if self.options.get("indent"): # Prevent trailing 
spaces - self.json_kwargs['separators'] = (',', ': ') - self.json_kwargs.setdefault('cls', DjangoJSONEncoder) - self.json_kwargs.setdefault('ensure_ascii', False) + self.json_kwargs["separators"] = (",", ": ") + self.json_kwargs.setdefault("cls", DjangoJSONEncoder) + self.json_kwargs.setdefault("ensure_ascii", False) def start_serialization(self): self._init_options() @@ -79,14 +79,15 @@ class DjangoJSONEncoder(json.JSONEncoder): JSONEncoder subclass that knows how to encode date/time, decimal types, and UUIDs. """ + def default(self, o): # See "Date Time String Format" in the ECMA-262 specification. if isinstance(o, datetime.datetime): r = o.isoformat() if o.microsecond: r = r[:23] + r[26:] - if r.endswith('+00:00'): - r = r[:-6] + 'Z' + if r.endswith("+00:00"): + r = r[:-6] + "Z" return r elif isinstance(o, datetime.date): return o.isoformat() diff --git a/django/core/serializers/jsonl.py b/django/core/serializers/jsonl.py index 4b3e46ed8e..c264c2ccaf 100644 --- a/django/core/serializers/jsonl.py +++ b/django/core/serializers/jsonl.py @@ -6,24 +6,24 @@ import json from django.core.serializers.base import DeserializationError from django.core.serializers.json import DjangoJSONEncoder -from django.core.serializers.python import ( - Deserializer as PythonDeserializer, Serializer as PythonSerializer, -) +from django.core.serializers.python import Deserializer as PythonDeserializer +from django.core.serializers.python import Serializer as PythonSerializer class Serializer(PythonSerializer): """Convert a queryset to JSON Lines.""" + internal_use_only = False def _init_options(self): self._current = None self.json_kwargs = self.options.copy() - self.json_kwargs.pop('stream', None) - self.json_kwargs.pop('fields', None) - self.json_kwargs.pop('indent', None) - self.json_kwargs['separators'] = (',', ': ') - self.json_kwargs.setdefault('cls', DjangoJSONEncoder) - self.json_kwargs.setdefault('ensure_ascii', False) + self.json_kwargs.pop("stream", None) + 
self.json_kwargs.pop("fields", None) + self.json_kwargs.pop("indent", None) + self.json_kwargs["separators"] = (",", ": ") + self.json_kwargs.setdefault("cls", DjangoJSONEncoder) + self.json_kwargs.setdefault("ensure_ascii", False) def start_serialization(self): self._init_options() diff --git a/django/core/serializers/python.py b/django/core/serializers/python.py index 0ceb676e90..a3918bf9d2 100644 --- a/django/core/serializers/python.py +++ b/django/core/serializers/python.py @@ -32,10 +32,10 @@ class Serializer(base.Serializer): self._current = None def get_dump_object(self, obj): - data = {'model': str(obj._meta)} - if not self.use_natural_primary_keys or not hasattr(obj, 'natural_key'): + data = {"model": str(obj._meta)} + if not self.use_natural_primary_keys or not hasattr(obj, "natural_key"): data["pk"] = self._value_from_field(obj, obj._meta.pk) - data['fields'] = self._current + data["fields"] = self._current return data def _value_from_field(self, obj, field): @@ -49,7 +49,9 @@ class Serializer(base.Serializer): self._current[field.name] = self._value_from_field(obj, field) def handle_fk_field(self, obj, field): - if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'): + if self.use_natural_foreign_keys and hasattr( + field.remote_field.model, "natural_key" + ): related = getattr(obj, field.name) if related: value = related.natural_key() @@ -61,13 +63,19 @@ class Serializer(base.Serializer): def handle_m2m_field(self, obj, field): if field.remote_field.through._meta.auto_created: - if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'): + if self.use_natural_foreign_keys and hasattr( + field.remote_field.model, "natural_key" + ): + def m2m_value(value): return value.natural_key() + else: + def m2m_value(value): return self._value_from_field(value, value._meta.pk) - m2m_iter = getattr(obj, '_prefetched_objects_cache', {}).get( + + m2m_iter = getattr(obj, "_prefetched_objects_cache", {}).get( 
field.name, getattr(obj, field.name).iterator(), ) @@ -77,14 +85,16 @@ class Serializer(base.Serializer): return self.objects -def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options): +def Deserializer( + object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options +): """ Deserialize simple Python objects back into Django ORM instances. It's expected that you pass the Python objects themselves (instead of a stream or a string) to the constructor """ - handle_forward_references = options.pop('handle_forward_references', False) + handle_forward_references = options.pop("handle_forward_references", False) field_names_cache = {} # Model: <list of field_names> for d in object_list: @@ -97,11 +107,13 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False else: raise data = {} - if 'pk' in d: + if "pk" in d: try: - data[Model._meta.pk.attname] = Model._meta.pk.to_python(d.get('pk')) + data[Model._meta.pk.attname] = Model._meta.pk.to_python(d.get("pk")) except Exception as e: - raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None) + raise base.DeserializationError.WithData( + e, d["model"], d.get("pk"), None + ) m2m_data = {} deferred_fields = {} @@ -119,21 +131,33 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False field = Model._meta.get_field(field_name) # Handle M2M relations - if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel): + if field.remote_field and isinstance( + field.remote_field, models.ManyToManyRel + ): try: - values = base.deserialize_m2m_values(field, field_value, using, handle_forward_references) + values = base.deserialize_m2m_values( + field, field_value, using, handle_forward_references + ) except base.M2MDeserializationError as e: - raise base.DeserializationError.WithData(e.original_exc, d['model'], d.get('pk'), e.pk) + raise base.DeserializationError.WithData( + e.original_exc, d["model"], 
d.get("pk"), e.pk + ) if values == base.DEFER_FIELD: deferred_fields[field] = field_value else: m2m_data[field.name] = values # Handle FK fields - elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel): + elif field.remote_field and isinstance( + field.remote_field, models.ManyToOneRel + ): try: - value = base.deserialize_fk_value(field, field_value, using, handle_forward_references) + value = base.deserialize_fk_value( + field, field_value, using, handle_forward_references + ) except Exception as e: - raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value) + raise base.DeserializationError.WithData( + e, d["model"], d.get("pk"), field_value + ) if value == base.DEFER_FIELD: deferred_fields[field] = field_value else: @@ -143,7 +167,9 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False try: data[field.name] = field.to_python(field_value) except Exception as e: - raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value) + raise base.DeserializationError.WithData( + e, d["model"], d.get("pk"), field_value + ) obj = base.build_instance(Model, data, using) yield base.DeserializedObject(obj, m2m_data, deferred_fields) @@ -154,4 +180,6 @@ def _get_model(model_identifier): try: return apps.get_model(model_identifier) except (LookupError, TypeError): - raise base.DeserializationError("Invalid model identifier: '%s'" % model_identifier) + raise base.DeserializationError( + "Invalid model identifier: '%s'" % model_identifier + ) diff --git a/django/core/serializers/pyyaml.py b/django/core/serializers/pyyaml.py index 9719f6e1b4..9a20b6658f 100644 --- a/django/core/serializers/pyyaml.py +++ b/django/core/serializers/pyyaml.py @@ -11,28 +11,30 @@ from io import StringIO import yaml from django.core.serializers.base import DeserializationError -from django.core.serializers.python import ( - Deserializer as PythonDeserializer, Serializer as PythonSerializer, -) +from 
django.core.serializers.python import Deserializer as PythonDeserializer +from django.core.serializers.python import Serializer as PythonSerializer from django.db import models # Use the C (faster) implementation if possible try: - from yaml import CSafeDumper as SafeDumper, CSafeLoader as SafeLoader + from yaml import CSafeDumper as SafeDumper + from yaml import CSafeLoader as SafeLoader except ImportError: from yaml import SafeDumper, SafeLoader class DjangoSafeDumper(SafeDumper): def represent_decimal(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', str(data)) + return self.represent_scalar("tag:yaml.org,2002:str", str(data)) def represent_ordered_dict(self, data): - return self.represent_mapping('tag:yaml.org,2002:map', data.items()) + return self.represent_mapping("tag:yaml.org,2002:map", data.items()) DjangoSafeDumper.add_representer(decimal.Decimal, DjangoSafeDumper.represent_decimal) -DjangoSafeDumper.add_representer(collections.OrderedDict, DjangoSafeDumper.represent_ordered_dict) +DjangoSafeDumper.add_representer( + collections.OrderedDict, DjangoSafeDumper.represent_ordered_dict +) # Workaround to represent dictionaries in insertion order. # See https://github.com/yaml/pyyaml/pull/143. 
DjangoSafeDumper.add_representer(dict, DjangoSafeDumper.represent_ordered_dict) @@ -56,7 +58,7 @@ class Serializer(PythonSerializer): super().handle_field(obj, field) def end_serialization(self): - self.options.setdefault('allow_unicode', True) + self.options.setdefault("allow_unicode", True) yaml.dump(self.objects, self.stream, Dumper=DjangoSafeDumper, **self.options) def getvalue(self): diff --git a/django/core/serializers/xml_serializer.py b/django/core/serializers/xml_serializer.py index 88bfa59032..8d3918cfaa 100644 --- a/django/core/serializers/xml_serializer.py +++ b/django/core/serializers/xml_serializer.py @@ -11,23 +11,25 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.core.serializers import base from django.db import DEFAULT_DB_ALIAS, models -from django.utils.xmlutils import ( - SimplerXMLGenerator, UnserializableContentError, -) +from django.utils.xmlutils import SimplerXMLGenerator, UnserializableContentError class Serializer(base.Serializer): """Serialize a QuerySet to XML.""" def indent(self, level): - if self.options.get('indent') is not None: - self.xml.ignorableWhitespace('\n' + ' ' * self.options.get('indent') * level) + if self.options.get("indent") is not None: + self.xml.ignorableWhitespace( + "\n" + " " * self.options.get("indent") * level + ) def start_serialization(self): """ Start serialization -- open the XML document and the root element. """ - self.xml = SimplerXMLGenerator(self.stream, self.options.get("encoding", settings.DEFAULT_CHARSET)) + self.xml = SimplerXMLGenerator( + self.stream, self.options.get("encoding", settings.DEFAULT_CHARSET) + ) self.xml.startDocument() self.xml.startElement("django-objects", {"version": "1.0"}) @@ -44,14 +46,16 @@ class Serializer(base.Serializer): Called as each object is handled. 
""" if not hasattr(obj, "_meta"): - raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj)) + raise base.SerializationError( + "Non-model object (%s) encountered during serialization" % type(obj) + ) self.indent(1) - attrs = {'model': str(obj._meta)} - if not self.use_natural_primary_keys or not hasattr(obj, 'natural_key'): + attrs = {"model": str(obj._meta)} + if not self.use_natural_primary_keys or not hasattr(obj, "natural_key"): obj_pk = obj.pk if obj_pk is not None: - attrs['pk'] = str(obj_pk) + attrs["pk"] = str(obj_pk) self.xml.startElement("object", attrs) @@ -68,23 +72,28 @@ class Serializer(base.Serializer): ManyToManyFields). """ self.indent(2) - self.xml.startElement('field', { - 'name': field.name, - 'type': field.get_internal_type(), - }) + self.xml.startElement( + "field", + { + "name": field.name, + "type": field.get_internal_type(), + }, + ) # Get a "string version" of the object's data. if getattr(obj, field.name) is not None: value = field.value_to_string(obj) - if field.get_internal_type() == 'JSONField': + if field.get_internal_type() == "JSONField": # Dump value since JSONField.value_to_string() doesn't output # strings. 
value = json.dumps(value, cls=field.encoder) try: self.xml.characters(value) except UnserializableContentError: - raise ValueError("%s.%s (pk:%s) contains unserializable characters" % ( - obj.__class__.__name__, field.name, obj.pk)) + raise ValueError( + "%s.%s (pk:%s) contains unserializable characters" + % (obj.__class__.__name__, field.name, obj.pk) + ) else: self.xml.addQuickElement("None") @@ -98,7 +107,9 @@ class Serializer(base.Serializer): self._start_relational_field(field) related_att = getattr(obj, field.get_attname()) if related_att is not None: - if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'): + if self.use_natural_foreign_keys and hasattr( + field.remote_field.model, "natural_key" + ): related = getattr(obj, field.name) # If related object has a natural key, use it related = related.natural_key() @@ -121,7 +132,9 @@ class Serializer(base.Serializer): """ if field.remote_field.through._meta.auto_created: self._start_relational_field(field) - if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'): + if self.use_natural_foreign_keys and hasattr( + field.remote_field.model, "natural_key" + ): # If the objects in the m2m have a natural key, use it def handle_m2m(value): natural = value.natural_key() @@ -132,12 +145,13 @@ class Serializer(base.Serializer): self.xml.characters(str(key_value)) self.xml.endElement("natural") self.xml.endElement("object") + else: + def handle_m2m(value): - self.xml.addQuickElement("object", attrs={ - 'pk': str(value.pk) - }) - m2m_iter = getattr(obj, '_prefetched_objects_cache', {}).get( + self.xml.addQuickElement("object", attrs={"pk": str(value.pk)}) + + m2m_iter = getattr(obj, "_prefetched_objects_cache", {}).get( field.name, getattr(obj, field.name).iterator(), ) @@ -149,19 +163,29 @@ class Serializer(base.Serializer): def _start_relational_field(self, field): """Output the <field> element for relational fields.""" self.indent(2) - 
self.xml.startElement('field', { - 'name': field.name, - 'rel': field.remote_field.__class__.__name__, - 'to': str(field.remote_field.model._meta), - }) + self.xml.startElement( + "field", + { + "name": field.name, + "rel": field.remote_field.__class__.__name__, + "to": str(field.remote_field.model._meta), + }, + ) class Deserializer(base.Deserializer): """Deserialize XML.""" - def __init__(self, stream_or_string, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options): + def __init__( + self, + stream_or_string, + *, + using=DEFAULT_DB_ALIAS, + ignorenonexistent=False, + **options, + ): super().__init__(stream_or_string, **options) - self.handle_forward_references = options.pop('handle_forward_references', False) + self.handle_forward_references = options.pop("handle_forward_references", False) self.event_stream = pulldom.parse(self.stream, self._make_parser()) self.db = using self.ignore = ignorenonexistent @@ -185,9 +209,10 @@ class Deserializer(base.Deserializer): # Start building a data dictionary from the object. data = {} - if node.hasAttribute('pk'): + if node.hasAttribute("pk"): data[Model._meta.pk.attname] = Model._meta.pk.to_python( - node.getAttribute('pk')) + node.getAttribute("pk") + ) # Also start building a dict of m2m data (this is saved as # {m2m_accessor_attribute : [list_of_related_objects]}) @@ -201,7 +226,9 @@ class Deserializer(base.Deserializer): # sensing a pattern here?) field_name = field_node.getAttribute("name") if not field_name: - raise base.DeserializationError("<field> node is missing the 'name' attribute") + raise base.DeserializationError( + "<field> node is missing the 'name' attribute" + ) # Get the field from the Model. This will raise a # FieldDoesNotExist if, well, the field doesn't exist, which will @@ -211,34 +238,38 @@ class Deserializer(base.Deserializer): field = Model._meta.get_field(field_name) # As is usually the case, relation fields get the special treatment. 
- if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel): + if field.remote_field and isinstance( + field.remote_field, models.ManyToManyRel + ): value = self._handle_m2m_field_node(field_node, field) if value == base.DEFER_FIELD: deferred_fields[field] = [ [ getInnerText(nat_node).strip() - for nat_node in obj_node.getElementsByTagName('natural') + for nat_node in obj_node.getElementsByTagName("natural") ] - for obj_node in field_node.getElementsByTagName('object') + for obj_node in field_node.getElementsByTagName("object") ] else: m2m_data[field.name] = value - elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel): + elif field.remote_field and isinstance( + field.remote_field, models.ManyToOneRel + ): value = self._handle_fk_field_node(field_node, field) if value == base.DEFER_FIELD: deferred_fields[field] = [ getInnerText(k).strip() - for k in field_node.getElementsByTagName('natural') + for k in field_node.getElementsByTagName("natural") ] else: data[field.attname] = value else: - if field_node.getElementsByTagName('None'): + if field_node.getElementsByTagName("None"): value = None else: value = field.to_python(getInnerText(field_node).strip()) # Load value since JSONField.to_python() outputs strings. - if field.get_internal_type() == 'JSONField': + if field.get_internal_type() == "JSONField": value = json.loads(value, cls=field.decoder) data[field.name] = value @@ -252,17 +283,19 @@ class Deserializer(base.Deserializer): Handle a <field> node for a ForeignKey """ # Check if there is a child node named 'None', returning None if so. 
- if node.getElementsByTagName('None'): + if node.getElementsByTagName("None"): return None else: model = field.remote_field.model - if hasattr(model._default_manager, 'get_by_natural_key'): - keys = node.getElementsByTagName('natural') + if hasattr(model._default_manager, "get_by_natural_key"): + keys = node.getElementsByTagName("natural") if keys: # If there are 'natural' subelements, it must be a natural key field_value = [getInnerText(k).strip() for k in keys] try: - obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value) + obj = model._default_manager.db_manager( + self.db + ).get_by_natural_key(*field_value) except ObjectDoesNotExist: if self.handle_forward_references: return base.DEFER_FIELD @@ -276,11 +309,15 @@ class Deserializer(base.Deserializer): else: # Otherwise, treat like a normal PK field_value = getInnerText(node).strip() - obj_pk = model._meta.get_field(field.remote_field.field_name).to_python(field_value) + obj_pk = model._meta.get_field( + field.remote_field.field_name + ).to_python(field_value) return obj_pk else: field_value = getInnerText(node).strip() - return model._meta.get_field(field.remote_field.field_name).to_python(field_value) + return model._meta.get_field(field.remote_field.field_name).to_python( + field_value + ) def _handle_m2m_field_node(self, node, field): """ @@ -288,23 +325,31 @@ class Deserializer(base.Deserializer): """ model = field.remote_field.model default_manager = model._default_manager - if hasattr(default_manager, 'get_by_natural_key'): + if hasattr(default_manager, "get_by_natural_key"): + def m2m_convert(n): - keys = n.getElementsByTagName('natural') + keys = n.getElementsByTagName("natural") if keys: # If there are 'natural' subelements, it must be a natural key field_value = [getInnerText(k).strip() for k in keys] - obj_pk = default_manager.db_manager(self.db).get_by_natural_key(*field_value).pk + obj_pk = ( + default_manager.db_manager(self.db) + .get_by_natural_key(*field_value) + 
.pk + ) else: # Otherwise, treat like a normal PK value. - obj_pk = model._meta.pk.to_python(n.getAttribute('pk')) + obj_pk = model._meta.pk.to_python(n.getAttribute("pk")) return obj_pk + else: + def m2m_convert(n): - return model._meta.pk.to_python(n.getAttribute('pk')) + return model._meta.pk.to_python(n.getAttribute("pk")) + values = [] try: - for c in node.getElementsByTagName('object'): + for c in node.getElementsByTagName("object"): values.append(m2m_convert(c)) except Exception as e: if isinstance(e, ObjectDoesNotExist) and self.handle_forward_references: @@ -323,13 +368,15 @@ class Deserializer(base.Deserializer): if not model_identifier: raise base.DeserializationError( "<%s> node is missing the required '%s' attribute" - % (node.nodeName, attr)) + % (node.nodeName, attr) + ) try: return apps.get_model(model_identifier) except (LookupError, TypeError): raise base.DeserializationError( "<%s> node has invalid model identifier: '%s'" - % (node.nodeName, model_identifier)) + % (node.nodeName, model_identifier) + ) def getInnerText(node): @@ -337,7 +384,10 @@ def getInnerText(node): # inspired by https://mail.python.org/pipermail/xml-sig/2005-March/011022.html inner_text = [] for child in node.childNodes: - if child.nodeType == child.TEXT_NODE or child.nodeType == child.CDATA_SECTION_NODE: + if ( + child.nodeType == child.TEXT_NODE + or child.nodeType == child.CDATA_SECTION_NODE + ): inner_text.append(child.data) elif child.nodeType == child.ELEMENT_NODE: inner_text.extend(getInnerText(child)) @@ -355,6 +405,7 @@ class DefusedExpatParser(_ExpatParser): Forbid DTDs, external entity references """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.setFeature(handler.feature_external_ges, False) @@ -363,8 +414,9 @@ class DefusedExpatParser(_ExpatParser): def start_doctype_decl(self, name, sysid, pubid, has_internal_subset): raise DTDForbidden(name, sysid, pubid) - def entity_decl(self, name, is_parameter_entity, value, base, - sysid, 
pubid, notation_name): + def entity_decl( + self, name, is_parameter_entity, value, base, sysid, pubid, notation_name + ): raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) def unparsed_entity_decl(self, name, base, sysid, pubid, notation_name): @@ -385,12 +437,14 @@ class DefusedExpatParser(_ExpatParser): class DefusedXmlException(ValueError): """Base exception.""" + def __repr__(self): return str(self) class DTDForbidden(DefusedXmlException): """Document type definition is forbidden.""" + def __init__(self, name, sysid, pubid): super().__init__() self.name = name @@ -404,6 +458,7 @@ class DTDForbidden(DefusedXmlException): class EntitiesForbidden(DefusedXmlException): """Entity definition is forbidden.""" + def __init__(self, name, value, base, sysid, pubid, notation_name): super().__init__() self.name = name @@ -420,6 +475,7 @@ class EntitiesForbidden(DefusedXmlException): class ExternalReferenceForbidden(DefusedXmlException): """Resolving an external reference is forbidden.""" + def __init__(self, context, base, sysid, pubid): super().__init__() self.context = context diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py index 6cc8a46778..440e7bc9cb 100644 --- a/django/core/servers/basehttp.py +++ b/django/core/servers/basehttp.py @@ -19,9 +19,9 @@ from django.core.wsgi import get_wsgi_application from django.db import connections from django.utils.module_loading import import_string -__all__ = ('WSGIServer', 'WSGIRequestHandler') +__all__ = ("WSGIServer", "WSGIRequestHandler") -logger = logging.getLogger('django.server') +logger = logging.getLogger("django.server") def get_internal_wsgi_application(): @@ -38,7 +38,8 @@ def get_internal_wsgi_application(): whatever ``django.core.wsgi.get_wsgi_application`` returns. 
""" from django.conf import settings - app_path = getattr(settings, 'WSGI_APPLICATION') + + app_path = getattr(settings, "WSGI_APPLICATION") if app_path is None: return get_wsgi_application() @@ -53,11 +54,14 @@ def get_internal_wsgi_application(): def is_broken_pipe_error(): exc_type, _, _ = sys.exc_info() - return issubclass(exc_type, ( - BrokenPipeError, - ConnectionAbortedError, - ConnectionResetError, - )) + return issubclass( + exc_type, + ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + ), + ) class WSGIServer(simple_server.WSGIServer): @@ -80,6 +84,7 @@ class WSGIServer(simple_server.WSGIServer): class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer): """A threaded version of the WSGIServer""" + daemon_threads = True def __init__(self, *args, connections_override=None, **kwargs): @@ -106,7 +111,7 @@ class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer): class ServerHandler(simple_server.ServerHandler): - http_version = '1.1' + http_version = "1.1" def __init__(self, stdin, stdout, stderr, environ, **kwargs): """ @@ -116,24 +121,26 @@ class ServerHandler(simple_server.ServerHandler): This fix applies only for testserver/runserver. """ try: - content_length = int(environ.get('CONTENT_LENGTH')) + content_length = int(environ.get("CONTENT_LENGTH")) except (ValueError, TypeError): content_length = 0 - super().__init__(LimitedStream(stdin, content_length), stdout, stderr, environ, **kwargs) + super().__init__( + LimitedStream(stdin, content_length), stdout, stderr, environ, **kwargs + ) def cleanup_headers(self): super().cleanup_headers() # HTTP/1.1 requires support for persistent connections. Send 'close' if # the content length is unknown to prevent clients from reusing the # connection. - if 'Content-Length' not in self.headers: - self.headers['Connection'] = 'close' + if "Content-Length" not in self.headers: + self.headers["Connection"] = "close" # Persistent connections require threading server. 
elif not isinstance(self.request_handler.server, socketserver.ThreadingMixIn): - self.headers['Connection'] = 'close' + self.headers["Connection"] = "close" # Mark the connection for closing if it's set as such above or if the # application sent the header. - if self.headers.get('Connection') == 'close': + if self.headers.get("Connection") == "close": self.request_handler.close_connection = True def close(self): @@ -142,7 +149,7 @@ class ServerHandler(simple_server.ServerHandler): class WSGIRequestHandler(simple_server.WSGIRequestHandler): - protocol_version = 'HTTP/1.1' + protocol_version = "HTTP/1.1" def address_string(self): # Short-circuit parent method to not call socket.getfqdn @@ -150,22 +157,23 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): def log_message(self, format, *args): extra = { - 'request': self.request, - 'server_time': self.log_date_time_string(), + "request": self.request, + "server_time": self.log_date_time_string(), } - if args[1][0] == '4': + if args[1][0] == "4": # 0x16 = Handshake, 0x03 = SSL 3.0 or TLS 1.x - if args[0].startswith('\x16\x03'): - extra['status_code'] = 500 + if args[0].startswith("\x16\x03"): + extra["status_code"] = 500 logger.error( "You're accessing the development server over HTTPS, but " - "it only supports HTTP.\n", extra=extra, + "it only supports HTTP.\n", + extra=extra, ) return if args[1].isdigit() and len(args[1]) == 3: status_code = int(args[1]) - extra['status_code'] = status_code + extra["status_code"] = status_code if status_code >= 500: level = logger.error @@ -184,7 +192,7 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): # between underscores and dashes both normalized to underscores in WSGI # env vars. Nginx and Apache 2.4+ both do this as well. 
for k in self.headers: - if '_' in k: + if "_" in k: del self.headers[k] return super().get_environ() @@ -203,9 +211,9 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): """Copy of WSGIRequestHandler.handle() but with different ServerHandler""" self.raw_requestline = self.rfile.readline(65537) if len(self.raw_requestline) > 65536: - self.requestline = '' - self.request_version = '' - self.command = '' + self.requestline = "" + self.request_version = "" + self.command = "" self.send_error(414) return @@ -215,14 +223,14 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): handler = ServerHandler( self.rfile, self.wfile, self.get_stderr(), self.get_environ() ) - handler.request_handler = self # backpointer for logging & connection closing + handler.request_handler = self # backpointer for logging & connection closing handler.run(self.server.get_app()) def run(addr, port, wsgi_handler, ipv6=False, threading=False, server_cls=WSGIServer): server_address = (addr, port) if threading: - httpd_cls = type('WSGIServer', (socketserver.ThreadingMixIn, server_cls), {}) + httpd_cls = type("WSGIServer", (socketserver.ThreadingMixIn, server_cls), {}) else: httpd_cls = server_cls httpd = httpd_cls(server_address, WSGIRequestHandler, ipv6=ipv6) diff --git a/django/core/signing.py b/django/core/signing.py index cd86fdfab6..916885abb3 100644 --- a/django/core/signing.py +++ b/django/core/signing.py @@ -45,26 +45,28 @@ from django.utils.encoding import force_bytes from django.utils.module_loading import import_string from django.utils.regex_helper import _lazy_re_compile -_SEP_UNSAFE = _lazy_re_compile(r'^[A-z0-9-_=]*$') -BASE62_ALPHABET = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_SEP_UNSAFE = _lazy_re_compile(r"^[A-z0-9-_=]*$") +BASE62_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" class BadSignature(Exception): """Signature does not match.""" + pass class SignatureExpired(BadSignature): """Signature timestamp 
is older than required max_age.""" + pass def b62_encode(s): if s == 0: - return '0' - sign = '-' if s < 0 else '' + return "0" + sign = "-" if s < 0 else "" s = abs(s) - encoded = '' + encoded = "" while s > 0: s, remainder = divmod(s, 62) encoded = BASE62_ALPHABET[remainder] + encoded @@ -72,10 +74,10 @@ def b62_encode(s): def b62_decode(s): - if s == '0': + if s == "0": return 0 sign = 1 - if s[0] == '-': + if s[0] == "-": s = s[1:] sign = -1 decoded = 0 @@ -85,24 +87,26 @@ def b62_decode(s): def b64_encode(s): - return base64.urlsafe_b64encode(s).strip(b'=') + return base64.urlsafe_b64encode(s).strip(b"=") def b64_decode(s): - pad = b'=' * (-len(s) % 4) + pad = b"=" * (-len(s) % 4) return base64.urlsafe_b64decode(s + pad) -def base64_hmac(salt, value, key, algorithm='sha1'): - return b64_encode(salted_hmac(salt, value, key, algorithm=algorithm).digest()).decode() +def base64_hmac(salt, value, key, algorithm="sha1"): + return b64_encode( + salted_hmac(salt, value, key, algorithm=algorithm).digest() + ).decode() def _cookie_signer_key(key): # SECRET_KEYS items may be str or bytes. - return b'django.http.cookies' + force_bytes(key) + return b"django.http.cookies" + force_bytes(key) -def get_cookie_signer(salt='django.core.signing.get_cookie_signer'): +def get_cookie_signer(salt="django.core.signing.get_cookie_signer"): Signer = import_string(settings.SIGNING_BACKEND) return Signer( key=_cookie_signer_key(settings.SECRET_KEY), @@ -116,14 +120,17 @@ class JSONSerializer: Simple wrapper around json to be used in signing.dumps and signing.loads. 
""" + def dumps(self, obj): - return json.dumps(obj, separators=(',', ':')).encode('latin-1') + return json.dumps(obj, separators=(",", ":")).encode("latin-1") def loads(self, data): - return json.loads(data.decode('latin-1')) + return json.loads(data.decode("latin-1")) -def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer, compress=False): +def dumps( + obj, key=None, salt="django.core.signing", serializer=JSONSerializer, compress=False +): """ Return URL-safe, hmac signed base64 compressed JSON string. If key is None, use settings.SECRET_KEY instead. The hmac algorithm is the default @@ -140,13 +147,15 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer, The serializer is expected to return a bytestring. """ - return TimestampSigner(key, salt=salt).sign_object(obj, serializer=serializer, compress=compress) + return TimestampSigner(key, salt=salt).sign_object( + obj, serializer=serializer, compress=compress + ) def loads( s, key=None, - salt='django.core.signing', + salt="django.core.signing", serializer=JSONSerializer, max_age=None, fallback_keys=None, @@ -167,7 +176,7 @@ class Signer: def __init__( self, key=None, - sep=':', + sep=":", salt=None, algorithm=None, fallback_keys=None, @@ -181,18 +190,21 @@ class Signer: self.sep = sep if _SEP_UNSAFE.match(self.sep): raise ValueError( - 'Unsafe Signer separator: %r (cannot be empty or consist of ' - 'only A-z0-9-_=)' % sep, + "Unsafe Signer separator: %r (cannot be empty or consist of " + "only A-z0-9-_=)" % sep, ) - self.salt = salt or '%s.%s' % (self.__class__.__module__, self.__class__.__name__) - self.algorithm = algorithm or 'sha256' + self.salt = salt or "%s.%s" % ( + self.__class__.__module__, + self.__class__.__name__, + ) + self.algorithm = algorithm or "sha256" def signature(self, value, key=None): key = key or self.key - return base64_hmac(self.salt + 'signer', value, key, algorithm=self.algorithm) + return base64_hmac(self.salt + "signer", value, 
key, algorithm=self.algorithm) def sign(self, value): - return '%s%s%s' % (value, self.sep, self.signature(value)) + return "%s%s%s" % (value, self.sep, self.signature(value)) def unsign(self, signed_value): if self.sep not in signed_value: @@ -225,14 +237,14 @@ class Signer: is_compressed = True base64d = b64_encode(data).decode() if is_compressed: - base64d = '.' + base64d + base64d = "." + base64d return self.sign(base64d) def unsign_object(self, signed_obj, serializer=JSONSerializer, **kwargs): # Signer.unsign() returns str but base64 and zlib compression operate # on bytes. base64d = self.unsign(signed_obj, **kwargs).encode() - decompress = base64d[:1] == b'.' + decompress = base64d[:1] == b"." if decompress: # It's compressed; uncompress it first. base64d = base64d[1:] @@ -243,12 +255,11 @@ class Signer: class TimestampSigner(Signer): - def timestamp(self): return b62_encode(int(time.time())) def sign(self, value): - value = '%s%s%s' % (value, self.sep, self.timestamp()) + value = "%s%s%s" % (value, self.sep, self.timestamp()) return super().sign(value) def unsign(self, value, max_age=None): @@ -265,6 +276,5 @@ class TimestampSigner(Signer): # Check timestamp is not older than max_age age = time.time() - timestamp if age > max_age: - raise SignatureExpired( - 'Signature age %s > %s seconds' % (age, max_age)) + raise SignatureExpired("Signature age %s > %s seconds" % (age, max_age)) return value diff --git a/django/core/validators.py b/django/core/validators.py index 9ad90f006f..5272258a77 100644 --- a/django/core/validators.py +++ b/django/core/validators.py @@ -8,21 +8,24 @@ from django.utils.deconstruct import deconstructible from django.utils.encoding import punycode from django.utils.ipv6 import is_valid_ipv6_address from django.utils.regex_helper import _lazy_re_compile -from django.utils.translation import gettext_lazy as _, ngettext_lazy +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import ngettext_lazy # These 
values, if given to validate(), will trigger the self.required check. -EMPTY_VALUES = (None, '', [], (), {}) +EMPTY_VALUES = (None, "", [], (), {}) @deconstructible class RegexValidator: - regex = '' - message = _('Enter a valid value.') - code = 'invalid' + regex = "" + message = _("Enter a valid value.") + code = "invalid" inverse_match = False flags = 0 - def __init__(self, regex=None, message=None, code=None, inverse_match=None, flags=None): + def __init__( + self, regex=None, message=None, code=None, inverse_match=None, flags=None + ): if regex is not None: self.regex = regex if message is not None: @@ -34,7 +37,9 @@ class RegexValidator: if flags is not None: self.flags = flags if self.flags and not isinstance(self.regex, str): - raise TypeError("If the flags are set, regex must be a regular expression string.") + raise TypeError( + "If the flags are set, regex must be a regular expression string." + ) self.regex = _lazy_re_compile(self.regex, self.flags) @@ -46,54 +51,58 @@ class RegexValidator: regex_matches = self.regex.search(str(value)) invalid_input = regex_matches if self.inverse_match else not regex_matches if invalid_input: - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) def __eq__(self, other): return ( - isinstance(other, RegexValidator) and - self.regex.pattern == other.regex.pattern and - self.regex.flags == other.regex.flags and - (self.message == other.message) and - (self.code == other.code) and - (self.inverse_match == other.inverse_match) + isinstance(other, RegexValidator) + and self.regex.pattern == other.regex.pattern + and self.regex.flags == other.regex.flags + and (self.message == other.message) + and (self.code == other.code) + and (self.inverse_match == other.inverse_match) ) @deconstructible class URLValidator(RegexValidator): - ul = '\u00a1-\uffff' # Unicode letters range (must not be a raw string). 
+ ul = "\u00a1-\uffff" # Unicode letters range (must not be a raw string). # IP patterns ipv4_re = ( - r'(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)' - r'(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}' + r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)" + r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}" ) - ipv6_re = r'\[[0-9a-f:.]+\]' # (simple regex, validated later) + ipv6_re = r"\[[0-9a-f:.]+\]" # (simple regex, validated later) # Host patterns - hostname_re = r'[a-z' + ul + r'0-9](?:[a-z' + ul + r'0-9-]{0,61}[a-z' + ul + r'0-9])?' - # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1 - domain_re = r'(?:\.(?!-)[a-z' + ul + r'0-9-]{1,63}(?<!-))*' - tld_re = ( - r'\.' # dot - r'(?!-)' # can't start with a dash - r'(?:[a-z' + ul + '-]{2,63}' # domain label - r'|xn--[a-z0-9]{1,59})' # or punycode label - r'(?<!-)' # can't end with a dash - r'\.?' # may have a trailing dot + hostname_re = ( + r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?" ) - host_re = '(' + hostname_re + domain_re + tld_re + '|localhost)' + # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1 + domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*" + tld_re = ( + r"\." # dot + r"(?!-)" # can't start with a dash + r"(?:[a-z" + ul + "-]{2,63}" # domain label + r"|xn--[a-z0-9]{1,59})" # or punycode label + r"(?<!-)" # can't end with a dash + r"\.?" # may have a trailing dot + ) + host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)" regex = _lazy_re_compile( - r'^(?:[a-z0-9.+-]*)://' # scheme is validated separately - r'(?:[^\s:@/]+(?::[^\s:@/]*)?@)?' # user:pass authentication - r'(?:' + ipv4_re + '|' + ipv6_re + '|' + host_re + ')' - r'(?::[0-9]{1,5})?' # port - r'(?:[/?#][^\s]*)?' 
# resource path - r'\Z', re.IGNORECASE) - message = _('Enter a valid URL.') - schemes = ['http', 'https', 'ftp', 'ftps'] - unsafe_chars = frozenset('\t\r\n') + r"^(?:[a-z0-9.+-]*)://" # scheme is validated separately + r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?" # user:pass authentication + r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")" + r"(?::[0-9]{1,5})?" # port + r"(?:[/?#][^\s]*)?" # resource path + r"\Z", + re.IGNORECASE, + ) + message = _("Enter a valid URL.") + schemes = ["http", "https", "ftp", "ftps"] + unsafe_chars = frozenset("\t\r\n") def __init__(self, schemes=None, **kwargs): super().__init__(**kwargs) @@ -102,19 +111,19 @@ class URLValidator(RegexValidator): def __call__(self, value): if not isinstance(value, str): - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) if self.unsafe_chars.intersection(value): - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) # Check if the scheme is valid. 
- scheme = value.split('://')[0].lower() + scheme = value.split("://")[0].lower() if scheme not in self.schemes: - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) # Then check full URL try: splitted_url = urlsplit(value) except ValueError: - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) try: super().__call__(value) except ValidationError as e: @@ -131,26 +140,28 @@ class URLValidator(RegexValidator): raise else: # Now verify IPv6 in the netloc part - host_match = re.search(r'^\[(.+)\](?::[0-9]{1,5})?$', splitted_url.netloc) + host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc) if host_match: potential_ip = host_match[1] try: validate_ipv6_address(potential_ip) except ValidationError: - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError( + self.message, code=self.code, params={"value": value} + ) # The maximum length of a full host name is 253 characters per RFC 1034 # section 3.1. It's defined to be 255 bytes or less, but this includes # one byte for the length of the name and one byte for the trailing dot # that's used to indicate absolute names in DNS. 
if splitted_url.hostname is None or len(splitted_url.hostname) > 253: - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) integer_validator = RegexValidator( - _lazy_re_compile(r'^-?\d+\Z'), - message=_('Enter a valid integer.'), - code='invalid', + _lazy_re_compile(r"^-?\d+\Z"), + message=_("Enter a valid integer."), + code="invalid", ) @@ -160,21 +171,24 @@ def validate_integer(value): @deconstructible class EmailValidator: - message = _('Enter a valid email address.') - code = 'invalid' + message = _("Enter a valid email address.") + code = "invalid" user_regex = _lazy_re_compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z" # dot-atom r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)', # quoted-string - re.IGNORECASE) + re.IGNORECASE, + ) domain_regex = _lazy_re_compile( # max length for domain name labels is 63 characters per RFC 1034 - r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z', - re.IGNORECASE) + r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z", + re.IGNORECASE, + ) literal_regex = _lazy_re_compile( # literal form, ipv4 or ipv6 address (SMTP 4.1.3) - r'\[([A-F0-9:.]+)\]\Z', - re.IGNORECASE) - domain_allowlist = ['localhost'] + r"\[([A-F0-9:.]+)\]\Z", + re.IGNORECASE, + ) + domain_allowlist = ["localhost"] def __init__(self, message=None, code=None, allowlist=None): if message is not None: @@ -185,16 +199,17 @@ class EmailValidator: self.domain_allowlist = allowlist def __call__(self, value): - if not value or '@' not in value: - raise ValidationError(self.message, code=self.code, params={'value': value}) + if not value or "@" not in value: + raise ValidationError(self.message, code=self.code, params={"value": value}) - user_part, domain_part = value.rsplit('@', 1) + user_part, domain_part = value.rsplit("@", 1) if not 
self.user_regex.match(user_part): - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) - if (domain_part not in self.domain_allowlist and - not self.validate_domain_part(domain_part)): + if domain_part not in self.domain_allowlist and not self.validate_domain_part( + domain_part + ): # Try for possible IDN domain-part try: domain_part = punycode(domain_part) @@ -203,7 +218,7 @@ class EmailValidator: else: if self.validate_domain_part(domain_part): return - raise ValidationError(self.message, code=self.code, params={'value': value}) + raise ValidationError(self.message, code=self.code, params={"value": value}) def validate_domain_part(self, domain_part): if self.domain_regex.match(domain_part): @@ -221,28 +236,30 @@ class EmailValidator: def __eq__(self, other): return ( - isinstance(other, EmailValidator) and - (self.domain_allowlist == other.domain_allowlist) and - (self.message == other.message) and - (self.code == other.code) + isinstance(other, EmailValidator) + and (self.domain_allowlist == other.domain_allowlist) + and (self.message == other.message) + and (self.code == other.code) ) validate_email = EmailValidator() -slug_re = _lazy_re_compile(r'^[-a-zA-Z0-9_]+\Z') +slug_re = _lazy_re_compile(r"^[-a-zA-Z0-9_]+\Z") validate_slug = RegexValidator( slug_re, # Translators: "letters" means latin letters: a-z and A-Z. 
- _('Enter a valid “slug” consisting of letters, numbers, underscores or hyphens.'), - 'invalid' + _("Enter a valid “slug” consisting of letters, numbers, underscores or hyphens."), + "invalid", ) -slug_unicode_re = _lazy_re_compile(r'^[-\w]+\Z') +slug_unicode_re = _lazy_re_compile(r"^[-\w]+\Z") validate_unicode_slug = RegexValidator( slug_unicode_re, - _('Enter a valid “slug” consisting of Unicode letters, numbers, underscores, or hyphens.'), - 'invalid' + _( + "Enter a valid “slug” consisting of Unicode letters, numbers, underscores, or hyphens." + ), + "invalid", ) @@ -250,25 +267,26 @@ def validate_ipv4_address(value): try: ipaddress.IPv4Address(value) except ValueError: - raise ValidationError(_('Enter a valid IPv4 address.'), code='invalid', params={'value': value}) + raise ValidationError( + _("Enter a valid IPv4 address."), code="invalid", params={"value": value} + ) else: # Leading zeros are forbidden to avoid ambiguity with the octal # notation. This restriction is included in Python 3.9.5+. # TODO: Remove when dropping support for PY39. 
- if any( - octet != '0' and octet[0] == '0' - for octet in value.split('.') - ): + if any(octet != "0" and octet[0] == "0" for octet in value.split(".")): raise ValidationError( - _('Enter a valid IPv4 address.'), - code='invalid', - params={'value': value}, + _("Enter a valid IPv4 address."), + code="invalid", + params={"value": value}, ) def validate_ipv6_address(value): if not is_valid_ipv6_address(value): - raise ValidationError(_('Enter a valid IPv6 address.'), code='invalid', params={'value': value}) + raise ValidationError( + _("Enter a valid IPv6 address."), code="invalid", params={"value": value} + ) def validate_ipv46_address(value): @@ -278,13 +296,17 @@ def validate_ipv46_address(value): try: validate_ipv6_address(value) except ValidationError: - raise ValidationError(_('Enter a valid IPv4 or IPv6 address.'), code='invalid', params={'value': value}) + raise ValidationError( + _("Enter a valid IPv4 or IPv6 address."), + code="invalid", + params={"value": value}, + ) ip_address_validator_map = { - 'both': ([validate_ipv46_address], _('Enter a valid IPv4 or IPv6 address.')), - 'ipv4': ([validate_ipv4_address], _('Enter a valid IPv4 address.')), - 'ipv6': ([validate_ipv6_address], _('Enter a valid IPv6 address.')), + "both": ([validate_ipv46_address], _("Enter a valid IPv4 or IPv6 address.")), + "ipv4": ([validate_ipv4_address], _("Enter a valid IPv4 address.")), + "ipv6": ([validate_ipv6_address], _("Enter a valid IPv6 address.")), } @@ -293,33 +315,39 @@ def ip_address_validators(protocol, unpack_ipv4): Depending on the given parameters, return the appropriate validators for the GenericIPAddressField. 
""" - if protocol != 'both' and unpack_ipv4: + if protocol != "both" and unpack_ipv4: raise ValueError( - "You can only use `unpack_ipv4` if `protocol` is set to 'both'") + "You can only use `unpack_ipv4` if `protocol` is set to 'both'" + ) try: return ip_address_validator_map[protocol.lower()] except KeyError: - raise ValueError("The protocol '%s' is unknown. Supported: %s" - % (protocol, list(ip_address_validator_map))) + raise ValueError( + "The protocol '%s' is unknown. Supported: %s" + % (protocol, list(ip_address_validator_map)) + ) -def int_list_validator(sep=',', message=None, code='invalid', allow_negative=False): - regexp = _lazy_re_compile(r'^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z' % { - 'neg': '(-)?' if allow_negative else '', - 'sep': re.escape(sep), - }) +def int_list_validator(sep=",", message=None, code="invalid", allow_negative=False): + regexp = _lazy_re_compile( + r"^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z" + % { + "neg": "(-)?" if allow_negative else "", + "sep": re.escape(sep), + } + ) return RegexValidator(regexp, message=message, code=code) validate_comma_separated_integer_list = int_list_validator( - message=_('Enter only digits separated by commas.'), + message=_("Enter only digits separated by commas."), ) @deconstructible class BaseValidator: - message = _('Ensure this value is %(limit_value)s (it is %(show_value)s).') - code = 'limit_value' + message = _("Ensure this value is %(limit_value)s (it is %(show_value)s).") + code = "limit_value" def __init__(self, limit_value, message=None): self.limit_value = limit_value @@ -328,8 +356,10 @@ class BaseValidator: def __call__(self, value): cleaned = self.clean(value) - limit_value = self.limit_value() if callable(self.limit_value) else self.limit_value - params = {'limit_value': limit_value, 'show_value': cleaned, 'value': value} + limit_value = ( + self.limit_value() if callable(self.limit_value) else self.limit_value + ) + params = {"limit_value": limit_value, "show_value": cleaned, "value": value} 
if self.compare(cleaned, limit_value): raise ValidationError(self.message, code=self.code, params=params) @@ -337,9 +367,9 @@ class BaseValidator: if not isinstance(other, self.__class__): return NotImplemented return ( - self.limit_value == other.limit_value and - self.message == other.message and - self.code == other.code + self.limit_value == other.limit_value + and self.message == other.message + and self.code == other.code ) def compare(self, a, b): @@ -351,8 +381,8 @@ class BaseValidator: @deconstructible class MaxValueValidator(BaseValidator): - message = _('Ensure this value is less than or equal to %(limit_value)s.') - code = 'max_value' + message = _("Ensure this value is less than or equal to %(limit_value)s.") + code = "max_value" def compare(self, a, b): return a > b @@ -360,8 +390,8 @@ class MaxValueValidator(BaseValidator): @deconstructible class MinValueValidator(BaseValidator): - message = _('Ensure this value is greater than or equal to %(limit_value)s.') - code = 'min_value' + message = _("Ensure this value is greater than or equal to %(limit_value)s.") + code = "min_value" def compare(self, a, b): return a < b @@ -370,10 +400,11 @@ class MinValueValidator(BaseValidator): @deconstructible class MinLengthValidator(BaseValidator): message = ngettext_lazy( - 'Ensure this value has at least %(limit_value)d character (it has %(show_value)d).', - 'Ensure this value has at least %(limit_value)d characters (it has %(show_value)d).', - 'limit_value') - code = 'min_length' + "Ensure this value has at least %(limit_value)d character (it has %(show_value)d).", + "Ensure this value has at least %(limit_value)d characters (it has %(show_value)d).", + "limit_value", + ) + code = "min_length" def compare(self, a, b): return a < b @@ -385,10 +416,11 @@ class MinLengthValidator(BaseValidator): @deconstructible class MaxLengthValidator(BaseValidator): message = ngettext_lazy( - 'Ensure this value has at most %(limit_value)d character (it has %(show_value)d).', - 
'Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).', - 'limit_value') - code = 'max_length' + "Ensure this value has at most %(limit_value)d character (it has %(show_value)d).", + "Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).", + "limit_value", + ) + code = "max_length" def compare(self, a, b): return a > b @@ -403,22 +435,23 @@ class DecimalValidator: Validate that the input does not exceed the maximum number of digits expected, otherwise raise ValidationError. """ + messages = { - 'invalid': _('Enter a number.'), - 'max_digits': ngettext_lazy( - 'Ensure that there are no more than %(max)s digit in total.', - 'Ensure that there are no more than %(max)s digits in total.', - 'max' + "invalid": _("Enter a number."), + "max_digits": ngettext_lazy( + "Ensure that there are no more than %(max)s digit in total.", + "Ensure that there are no more than %(max)s digits in total.", + "max", ), - 'max_decimal_places': ngettext_lazy( - 'Ensure that there are no more than %(max)s decimal place.', - 'Ensure that there are no more than %(max)s decimal places.', - 'max' + "max_decimal_places": ngettext_lazy( + "Ensure that there are no more than %(max)s decimal place.", + "Ensure that there are no more than %(max)s decimal places.", + "max", ), - 'max_whole_digits': ngettext_lazy( - 'Ensure that there are no more than %(max)s digit before the decimal point.', - 'Ensure that there are no more than %(max)s digits before the decimal point.', - 'max' + "max_whole_digits": ngettext_lazy( + "Ensure that there are no more than %(max)s digit before the decimal point.", + "Ensure that there are no more than %(max)s digits before the decimal point.", + "max", ), } @@ -428,8 +461,10 @@ class DecimalValidator: def __call__(self, value): digit_tuple, exponent = value.as_tuple()[1:] - if exponent in {'F', 'n', 'N'}: - raise ValidationError(self.messages['invalid'], code='invalid', params={'value': value}) + if exponent in {"F", 
"n", "N"}: + raise ValidationError( + self.messages["invalid"], code="invalid", params={"value": value} + ) if exponent >= 0: # A positive exponent adds that many trailing zeros. digits = len(digit_tuple) + exponent @@ -449,43 +484,48 @@ class DecimalValidator: if self.max_digits is not None and digits > self.max_digits: raise ValidationError( - self.messages['max_digits'], - code='max_digits', - params={'max': self.max_digits, 'value': value}, + self.messages["max_digits"], + code="max_digits", + params={"max": self.max_digits, "value": value}, ) if self.decimal_places is not None and decimals > self.decimal_places: raise ValidationError( - self.messages['max_decimal_places'], - code='max_decimal_places', - params={'max': self.decimal_places, 'value': value}, + self.messages["max_decimal_places"], + code="max_decimal_places", + params={"max": self.decimal_places, "value": value}, ) - if (self.max_digits is not None and self.decimal_places is not None and - whole_digits > (self.max_digits - self.decimal_places)): + if ( + self.max_digits is not None + and self.decimal_places is not None + and whole_digits > (self.max_digits - self.decimal_places) + ): raise ValidationError( - self.messages['max_whole_digits'], - code='max_whole_digits', - params={'max': (self.max_digits - self.decimal_places), 'value': value}, + self.messages["max_whole_digits"], + code="max_whole_digits", + params={"max": (self.max_digits - self.decimal_places), "value": value}, ) def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self.max_digits == other.max_digits and - self.decimal_places == other.decimal_places + isinstance(other, self.__class__) + and self.max_digits == other.max_digits + and self.decimal_places == other.decimal_places ) @deconstructible class FileExtensionValidator: message = _( - 'File extension “%(extension)s” is not allowed. ' - 'Allowed extensions are: %(allowed_extensions)s.' + "File extension “%(extension)s” is not allowed. 
" + "Allowed extensions are: %(allowed_extensions)s." ) - code = 'invalid_extension' + code = "invalid_extension" def __init__(self, allowed_extensions=None, message=None, code=None): if allowed_extensions is not None: - allowed_extensions = [allowed_extension.lower() for allowed_extension in allowed_extensions] + allowed_extensions = [ + allowed_extension.lower() for allowed_extension in allowed_extensions + ] self.allowed_extensions = allowed_extensions if message is not None: self.message = message @@ -494,23 +534,26 @@ class FileExtensionValidator: def __call__(self, value): extension = Path(value.name).suffix[1:].lower() - if self.allowed_extensions is not None and extension not in self.allowed_extensions: + if ( + self.allowed_extensions is not None + and extension not in self.allowed_extensions + ): raise ValidationError( self.message, code=self.code, params={ - 'extension': extension, - 'allowed_extensions': ', '.join(self.allowed_extensions), - 'value': value, - } + "extension": extension, + "allowed_extensions": ", ".join(self.allowed_extensions), + "value": value, + }, ) def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self.allowed_extensions == other.allowed_extensions and - self.message == other.message and - self.code == other.code + isinstance(other, self.__class__) + and self.allowed_extensions == other.allowed_extensions + and self.message == other.message + and self.code == other.code ) @@ -525,14 +568,17 @@ def get_available_image_extensions(): def validate_image_file_extension(value): - return FileExtensionValidator(allowed_extensions=get_available_image_extensions())(value) + return FileExtensionValidator(allowed_extensions=get_available_image_extensions())( + value + ) @deconstructible class ProhibitNullCharactersValidator: """Validate that the string doesn't contain the null character.""" - message = _('Null characters are not allowed.') - code = 'null_characters_not_allowed' + + message = _("Null characters are not 
allowed.") + code = "null_characters_not_allowed" def __init__(self, message=None, code=None): if message is not None: @@ -541,12 +587,12 @@ class ProhibitNullCharactersValidator: self.code = code def __call__(self, value): - if '\x00' in str(value): - raise ValidationError(self.message, code=self.code, params={'value': value}) + if "\x00" in str(value): + raise ValidationError(self.message, code=self.code, params={"value": value}) def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self.message == other.message and - self.code == other.code + isinstance(other, self.__class__) + and self.message == other.message + and self.code == other.code ) diff --git a/django/db/__init__.py b/django/db/__init__.py index 26127860ed..b0cae97e01 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -1,17 +1,36 @@ from django.core import signals from django.db.utils import ( - DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, ConnectionHandler, - ConnectionRouter, DatabaseError, DataError, Error, IntegrityError, - InterfaceError, InternalError, NotSupportedError, OperationalError, + DEFAULT_DB_ALIAS, + DJANGO_VERSION_PICKLE_KEY, + ConnectionHandler, + ConnectionRouter, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, ProgrammingError, ) from django.utils.connection import ConnectionProxy __all__ = [ - 'connection', 'connections', 'router', 'DatabaseError', 'IntegrityError', - 'InternalError', 'ProgrammingError', 'DataError', 'NotSupportedError', - 'Error', 'InterfaceError', 'OperationalError', 'DEFAULT_DB_ALIAS', - 'DJANGO_VERSION_PICKLE_KEY', + "connection", + "connections", + "router", + "DatabaseError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "DataError", + "NotSupportedError", + "Error", + "InterfaceError", + "OperationalError", + "DEFAULT_DB_ALIAS", + "DJANGO_VERSION_PICKLE_KEY", ] connections = ConnectionHandler() diff --git 
a/django/db/backends/base/base.py b/django/db/backends/base/base.py index 58dd6d43bd..1aee03848b 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -23,19 +23,21 @@ from django.utils import timezone from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property -NO_DB_ALIAS = '__no_db__' +NO_DB_ALIAS = "__no_db__" # RemovedInDjango50Warning def timezone_constructor(tzname): if settings.USE_DEPRECATED_PYTZ: import pytz + return pytz.timezone(tzname) return zoneinfo.ZoneInfo(tzname) class BaseDatabaseWrapper: """Represent a database connection.""" + # Mapping of Field objects to their column types. data_types = {} # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT. @@ -43,8 +45,8 @@ class BaseDatabaseWrapper: # Mapping of Field objects to their SQL for CHECK constraints. data_type_check_constraints = {} ops = None - vendor = 'unknown' - display_name = 'unknown' + vendor = "unknown" + display_name = "unknown" SchemaEditorClass = None # Classes instantiated in __init__(). 
client_class = None @@ -124,8 +126,8 @@ class BaseDatabaseWrapper: def __repr__(self): return ( - f'<{self.__class__.__qualname__} ' - f'vendor={self.vendor!r} alias={self.alias!r}>' + f"<{self.__class__.__qualname__} " + f"vendor={self.vendor!r} alias={self.alias!r}>" ) def ensure_timezone(self): @@ -153,10 +155,10 @@ class BaseDatabaseWrapper: """ if not settings.USE_TZ: return None - elif self.settings_dict['TIME_ZONE'] is None: + elif self.settings_dict["TIME_ZONE"] is None: return timezone.utc else: - return timezone_constructor(self.settings_dict['TIME_ZONE']) + return timezone_constructor(self.settings_dict["TIME_ZONE"]) @cached_property def timezone_name(self): @@ -165,10 +167,10 @@ class BaseDatabaseWrapper: """ if not settings.USE_TZ: return settings.TIME_ZONE - elif self.settings_dict['TIME_ZONE'] is None: - return 'UTC' + elif self.settings_dict["TIME_ZONE"] is None: + return "UTC" else: - return self.settings_dict['TIME_ZONE'] + return self.settings_dict["TIME_ZONE"] @property def queries_logged(self): @@ -179,26 +181,35 @@ class BaseDatabaseWrapper: if len(self.queries_log) == self.queries_log.maxlen: warnings.warn( "Limit for query logging exceeded, only the last {} queries " - "will be returned.".format(self.queries_log.maxlen)) + "will be returned.".format(self.queries_log.maxlen) + ) return list(self.queries_log) # ##### Backend-specific methods for creating connections and cursors ##### def get_connection_params(self): """Return a dict of parameters suitable for get_new_connection.""" - raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method') + raise NotImplementedError( + "subclasses of BaseDatabaseWrapper may require a get_connection_params() method" + ) def get_new_connection(self, conn_params): """Open a connection to the database.""" - raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method') + raise NotImplementedError( + "subclasses of 
BaseDatabaseWrapper may require a get_new_connection() method" + ) def init_connection_state(self): """Initialize the database connection settings.""" - raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method') + raise NotImplementedError( + "subclasses of BaseDatabaseWrapper may require an init_connection_state() method" + ) def create_cursor(self, name=None): """Create a cursor. Assume that a connection is established.""" - raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method') + raise NotImplementedError( + "subclasses of BaseDatabaseWrapper may require a create_cursor() method" + ) # ##### Backend-specific methods for creating connections ##### @@ -213,8 +224,8 @@ class BaseDatabaseWrapper: self.atomic_blocks = [] self.needs_rollback = False # Reset parameters defining when to close/health-check the connection. - self.health_check_enabled = self.settings_dict['CONN_HEALTH_CHECKS'] - max_age = self.settings_dict['CONN_MAX_AGE'] + self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"] + max_age = self.settings_dict["CONN_MAX_AGE"] self.close_at = None if max_age is None else time.monotonic() + max_age self.closed_in_transaction = False self.errors_occurred = False @@ -223,14 +234,14 @@ class BaseDatabaseWrapper: # Establish the connection conn_params = self.get_connection_params() self.connection = self.get_new_connection(conn_params) - self.set_autocommit(self.settings_dict['AUTOCOMMIT']) + self.set_autocommit(self.settings_dict["AUTOCOMMIT"]) self.init_connection_state() connection_created.send(sender=self.__class__, connection=self) self.run_on_commit = [] def check_settings(self): - if self.settings_dict['TIME_ZONE'] is not None and not settings.USE_TZ: + if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ: raise ImproperlyConfigured( "Connection '%s' cannot set TIME_ZONE because USE_TZ is False." 
% self.alias @@ -356,7 +367,7 @@ class BaseDatabaseWrapper: return thread_ident = _thread.get_ident() - tid = str(thread_ident).replace('-', '') + tid = str(thread_ident).replace("-", "") self.savepoint_state += 1 sid = "s%s_x%d" % (tid, self.savepoint_state) @@ -406,7 +417,9 @@ class BaseDatabaseWrapper: """ Backend-specific implementation to enable or disable autocommit. """ - raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method') + raise NotImplementedError( + "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method" + ) # ##### Generic transaction management methods ##### @@ -415,7 +428,9 @@ class BaseDatabaseWrapper: self.ensure_connection() return self.autocommit - def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): + def set_autocommit( + self, autocommit, force_begin_transaction_with_broken_autocommit=False + ): """ Enable or disable autocommit. @@ -432,8 +447,9 @@ class BaseDatabaseWrapper: self.ensure_connection() start_transaction_under_autocommit = ( - force_begin_transaction_with_broken_autocommit and not autocommit and - hasattr(self, '_start_transaction_under_autocommit') + force_begin_transaction_with_broken_autocommit + and not autocommit + and hasattr(self, "_start_transaction_under_autocommit") ) if start_transaction_under_autocommit: @@ -451,7 +467,8 @@ class BaseDatabaseWrapper: """Get the "needs rollback" flag -- for *advanced use* only.""" if not self.in_atomic_block: raise TransactionManagementError( - "The rollback flag doesn't work outside of an 'atomic' block.") + "The rollback flag doesn't work outside of an 'atomic' block." + ) return self.needs_rollback def set_rollback(self, rollback): @@ -460,20 +477,23 @@ class BaseDatabaseWrapper: """ if not self.in_atomic_block: raise TransactionManagementError( - "The rollback flag doesn't work outside of an 'atomic' block.") + "The rollback flag doesn't work outside of an 'atomic' block." 
+ ) self.needs_rollback = rollback def validate_no_atomic_block(self): """Raise an error if an atomic block is active.""" if self.in_atomic_block: raise TransactionManagementError( - "This is forbidden when an 'atomic' block is active.") + "This is forbidden when an 'atomic' block is active." + ) def validate_no_broken_transaction(self): if self.needs_rollback: raise TransactionManagementError( "An error occurred in the current transaction. You can't " - "execute queries until the end of the 'atomic' block.") + "execute queries until the end of the 'atomic' block." + ) # ##### Foreign key constraints checks handling ##### @@ -524,14 +544,15 @@ class BaseDatabaseWrapper: as that may prevent Django from recycling unusable connections. """ raise NotImplementedError( - "subclasses of BaseDatabaseWrapper may require an is_usable() method") + "subclasses of BaseDatabaseWrapper may require an is_usable() method" + ) def close_if_health_check_failed(self): """Close existing connection if it fails a health check.""" if ( - self.connection is None or - not self.health_check_enabled or - self.health_check_done + self.connection is None + or not self.health_check_enabled + or self.health_check_done ): return @@ -548,7 +569,7 @@ class BaseDatabaseWrapper: self.health_check_done = False # If the application didn't restore the original autocommit setting, # don't take chances, drop the connection. - if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']: + if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]: self.close() return @@ -580,7 +601,9 @@ class BaseDatabaseWrapper: def dec_thread_sharing(self): with self._thread_sharing_lock: if self._thread_sharing_count <= 0: - raise RuntimeError('Cannot decrement the thread sharing count below zero.') + raise RuntimeError( + "Cannot decrement the thread sharing count below zero." 
+ ) self._thread_sharing_count -= 1 def validate_thread_sharing(self): @@ -595,8 +618,7 @@ class BaseDatabaseWrapper: "DatabaseWrapper objects created in a " "thread can only be used in that same thread. The object " "with alias '%s' was created in thread id %s and this is " - "thread id %s." - % (self.alias, self._thread_ident, _thread.get_ident()) + "thread id %s." % (self.alias, self._thread_ident, _thread.get_ident()) ) # ##### Miscellaneous ##### @@ -657,7 +679,7 @@ class BaseDatabaseWrapper: being exposed to potential child threads while (or after) the test database is destroyed. Refs #10868, #17786, #16969. """ - conn = self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS) + conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS) try: with conn.cursor() as cursor: yield cursor @@ -670,7 +692,8 @@ class BaseDatabaseWrapper: """ if self.SchemaEditorClass is None: raise NotImplementedError( - 'The SchemaEditorClass attribute of this database wrapper is still None') + "The SchemaEditorClass attribute of this database wrapper is still None" + ) return self.SchemaEditorClass(self, *args, **kwargs) def on_commit(self, func): @@ -680,7 +703,9 @@ class BaseDatabaseWrapper: # Transaction in progress; save for execution on commit. self.run_on_commit.append((set(self.savepoint_ids), func)) elif not self.get_autocommit(): - raise TransactionManagementError('on_commit() cannot be used in manual transaction management') + raise TransactionManagementError( + "on_commit() cannot be used in manual transaction management" + ) else: # No transaction in progress and in autocommit mode; execute # immediately. 
diff --git a/django/db/backends/base/client.py b/django/db/backends/base/client.py index 8aca821fd2..031056372d 100644 --- a/django/db/backends/base/client.py +++ b/django/db/backends/base/client.py @@ -4,6 +4,7 @@ import subprocess class BaseDatabaseClient: """Encapsulate backend-specific methods for opening a client shell.""" + # This should be a string representing the name of the executable # (e.g., "psql"). Subclasses must override this. executable_name = None @@ -15,11 +16,13 @@ class BaseDatabaseClient: @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): raise NotImplementedError( - 'subclasses of BaseDatabaseClient must provide a ' - 'settings_to_cmd_args_env() method or override a runshell().' + "subclasses of BaseDatabaseClient must provide a " + "settings_to_cmd_args_env() method or override a runshell()." ) def runshell(self, parameters): - args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters) + args, env = self.settings_to_cmd_args_env( + self.connection.settings_dict, parameters + ) env = {**os.environ, **env} if env else None subprocess.run(args, env=env, check=True) diff --git a/django/db/backends/base/creation.py b/django/db/backends/base/creation.py index d1c0e1ac96..78480fc0f8 100644 --- a/django/db/backends/base/creation.py +++ b/django/db/backends/base/creation.py @@ -11,7 +11,7 @@ from django.utils.module_loading import import_string # The prefix to put on the default database name when creating # the test database. -TEST_DATABASE_PREFIX = 'test_' +TEST_DATABASE_PREFIX = "test_" class BaseDatabaseCreation: @@ -19,6 +19,7 @@ class BaseDatabaseCreation: Encapsulate backend-specific differences pertaining to creation and destruction of the test database. 
""" + def __init__(self, connection): self.connection = connection @@ -28,7 +29,9 @@ class BaseDatabaseCreation: def log(self, msg): sys.stderr.write(msg + os.linesep) - def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False): + def create_test_db( + self, verbosity=1, autoclobber=False, serialize=True, keepdb=False + ): """ Create a test database, prompting the user for confirmation if the database already exists. Return the name of the test database created. @@ -39,14 +42,17 @@ class BaseDatabaseCreation: test_database_name = self._get_test_db_name() if verbosity >= 1: - action = 'Creating' + action = "Creating" if keepdb: action = "Using existing" - self.log('%s test database for alias %s...' % ( - action, - self._get_database_display_str(verbosity, test_database_name), - )) + self.log( + "%s test database for alias %s..." + % ( + action, + self._get_database_display_str(verbosity, test_database_name), + ) + ) # We could skip this call if keepdb is True, but we instead # give it the keepdb param. This is to handle the case @@ -60,25 +66,24 @@ class BaseDatabaseCreation: self.connection.settings_dict["NAME"] = test_database_name try: - if self.connection.settings_dict['TEST']['MIGRATE'] is False: + if self.connection.settings_dict["TEST"]["MIGRATE"] is False: # Disable migrations for all apps. old_migration_modules = settings.MIGRATION_MODULES settings.MIGRATION_MODULES = { - app.label: None - for app in apps.get_app_configs() + app.label: None for app in apps.get_app_configs() } # We report migrate messages at one level lower than that # requested. This ensures we don't get flooded with messages during # testing (unless you really ask to be flooded). 
call_command( - 'migrate', + "migrate", verbosity=max(verbosity - 1, 0), interactive=False, database=self.connection.alias, run_syncdb=True, ) finally: - if self.connection.settings_dict['TEST']['MIGRATE'] is False: + if self.connection.settings_dict["TEST"]["MIGRATE"] is False: settings.MIGRATION_MODULES = old_migration_modules # We then serialize the current state of the database into a string @@ -88,12 +93,12 @@ class BaseDatabaseCreation: if serialize: self.connection._test_serialized_contents = self.serialize_db_to_string() - call_command('createcachetable', database=self.connection.alias) + call_command("createcachetable", database=self.connection.alias) # Ensure a connection for the side effect of initializing the test database. self.connection.ensure_connection() - if os.environ.get('RUNNING_DJANGOS_TEST_SUITE') == 'true': + if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true": self.mark_expected_failures_and_skips() return test_database_name @@ -103,7 +108,7 @@ class BaseDatabaseCreation: Set this database up to be used in testing as a mirror of a primary database whose settings are given. """ - self.connection.settings_dict['NAME'] = primary_settings_dict['NAME'] + self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"] def serialize_db_to_string(self): """ @@ -114,22 +119,23 @@ class BaseDatabaseCreation: # Iteratively return every object for all models to serialize. 
def get_objects(): from django.db.migrations.loader import MigrationLoader + loader = MigrationLoader(self.connection) for app_config in apps.get_app_configs(): if ( - app_config.models_module is not None and - app_config.label in loader.migrated_apps and - app_config.name not in settings.TEST_NON_SERIALIZED_APPS + app_config.models_module is not None + and app_config.label in loader.migrated_apps + and app_config.name not in settings.TEST_NON_SERIALIZED_APPS ): for model in app_config.get_models(): - if ( - model._meta.can_migrate(self.connection) and - router.allow_migrate_model(self.connection.alias, model) - ): + if model._meta.can_migrate( + self.connection + ) and router.allow_migrate_model(self.connection.alias, model): queryset = model._base_manager.using( self.connection.alias, ).order_by(model._meta.pk.name) yield from queryset.iterator() + # Serialize to a string out = StringIO() serializers.serialize("json", get_objects(), indent=None, stream=out) @@ -147,7 +153,9 @@ class BaseDatabaseCreation: # Disable constraint checks, because some databases (MySQL) doesn't # support deferred checks. with self.connection.constraint_checks_disabled(): - for obj in serializers.deserialize('json', data, using=self.connection.alias): + for obj in serializers.deserialize( + "json", data, using=self.connection.alias + ): obj.save() table_names.add(obj.object.__class__._meta.db_table) # Manually check for any invalid keys that might have been added, @@ -160,7 +168,7 @@ class BaseDatabaseCreation: """ return "'%s'%s" % ( self.connection.alias, - (" ('%s')" % database_name) if verbosity >= 2 else '', + (" ('%s')" % database_name) if verbosity >= 2 else "", ) def _get_test_db_name(self): @@ -170,12 +178,12 @@ class BaseDatabaseCreation: _create_test_db() and when no external munging is done with the 'NAME' settings. 
""" - if self.connection.settings_dict['TEST']['NAME']: - return self.connection.settings_dict['TEST']['NAME'] - return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME'] + if self.connection.settings_dict["TEST"]["NAME"]: + return self.connection.settings_dict["TEST"]["NAME"] + return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"] def _execute_create_test_db(self, cursor, parameters, keepdb=False): - cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters) + cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters) def _create_test_db(self, verbosity, autoclobber, keepdb=False): """ @@ -183,8 +191,8 @@ class BaseDatabaseCreation: """ test_database_name = self._get_test_db_name() test_db_params = { - 'dbname': self.connection.ops.quote_name(test_database_name), - 'suffix': self.sql_table_creation_suffix(), + "dbname": self.connection.ops.quote_name(test_database_name), + "suffix": self.sql_table_creation_suffix(), } # Create the test database and connect to it. with self._nodb_cursor() as cursor: @@ -196,24 +204,30 @@ class BaseDatabaseCreation: if keepdb: return test_database_name - self.log('Got an error creating the test database: %s' % e) + self.log("Got an error creating the test database: %s" % e) if not autoclobber: confirm = input( "Type 'yes' if you would like to try deleting the test " - "database '%s', or 'no' to cancel: " % test_database_name) - if autoclobber or confirm == 'yes': + "database '%s', or 'no' to cancel: " % test_database_name + ) + if autoclobber or confirm == "yes": try: if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, test_database_name), - )) - cursor.execute('DROP DATABASE %(dbname)s' % test_db_params) + self.log( + "Destroying old test database for alias %s..." 
+ % ( + self._get_database_display_str( + verbosity, test_database_name + ), + ) + ) + cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: - self.log('Got an error recreating the test database: %s' % e) + self.log("Got an error recreating the test database: %s" % e) sys.exit(2) else: - self.log('Tests cancelled.') + self.log("Tests cancelled.") sys.exit(1) return test_database_name @@ -222,16 +236,19 @@ class BaseDatabaseCreation: """ Clone a test database. """ - source_database_name = self.connection.settings_dict['NAME'] + source_database_name = self.connection.settings_dict["NAME"] if verbosity >= 1: - action = 'Cloning test database' + action = "Cloning test database" if keepdb: - action = 'Using existing clone' - self.log('%s for alias %s...' % ( - action, - self._get_database_display_str(verbosity, source_database_name), - )) + action = "Using existing clone" + self.log( + "%s for alias %s..." + % ( + action, + self._get_database_display_str(verbosity, source_database_name), + ) + ) # We could skip this call if keepdb is True, but we instead # give it the keepdb param. See create_test_db for details. @@ -245,7 +262,10 @@ class BaseDatabaseCreation: # already and its name has been copied to settings_dict['NAME'] so # we don't need to call _get_test_db_name. orig_settings_dict = self.connection.settings_dict - return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)} + return { + **orig_settings_dict, + "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix), + } def _clone_test_db(self, suffix, verbosity, keepdb=False): """ @@ -253,27 +273,33 @@ class BaseDatabaseCreation: """ raise NotImplementedError( "The database backend doesn't support cloning databases. " - "Disable the option to run tests in parallel processes.") + "Disable the option to run tests in parallel processes." 
+ ) - def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None): + def destroy_test_db( + self, old_database_name=None, verbosity=1, keepdb=False, suffix=None + ): """ Destroy a test database, prompting the user for confirmation if the database already exists. """ self.connection.close() if suffix is None: - test_database_name = self.connection.settings_dict['NAME'] + test_database_name = self.connection.settings_dict["NAME"] else: - test_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + test_database_name = self.get_test_db_clone_settings(suffix)["NAME"] if verbosity >= 1: - action = 'Destroying' + action = "Destroying" if keepdb: - action = 'Preserving' - self.log('%s test database for alias %s...' % ( - action, - self._get_database_display_str(verbosity, test_database_name), - )) + action = "Preserving" + self.log( + "%s test database for alias %s..." + % ( + action, + self._get_database_display_str(verbosity, test_database_name), + ) + ) # if we want to preserve the database # skip the actual destroying piece. @@ -294,8 +320,9 @@ class BaseDatabaseCreation: # to do so, because it's not allowed to delete a database while being # connected to it. with self._nodb_cursor() as cursor: - cursor.execute("DROP DATABASE %s" - % self.connection.ops.quote_name(test_database_name)) + cursor.execute( + "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name) + ) def mark_expected_failures_and_skips(self): """ @@ -304,9 +331,10 @@ class BaseDatabaseCreation: """ # Only load unittest if we're actually testing. from unittest import expectedFailure, skip + for test_name in self.connection.features.django_test_expected_failures: - test_case_name, _, test_method_name = test_name.rpartition('.') - test_app = test_name.split('.')[0] + test_case_name, _, test_method_name = test_name.rpartition(".") + test_app = test_name.split(".")[0] # Importing a test app that isn't installed raises RuntimeError. 
if test_app in settings.INSTALLED_APPS: test_case = import_string(test_case_name) @@ -314,8 +342,8 @@ class BaseDatabaseCreation: setattr(test_case, test_method_name, expectedFailure(test_method)) for reason, tests in self.connection.features.django_test_skips.items(): for test_name in tests: - test_case_name, _, test_method_name = test_name.rpartition('.') - test_app = test_name.split('.')[0] + test_case_name, _, test_method_name = test_name.rpartition(".") + test_app = test_name.split(".")[0] # Importing a test app that isn't installed raises RuntimeError. if test_app in settings.INSTALLED_APPS: test_case = import_string(test_case_name) @@ -326,7 +354,7 @@ class BaseDatabaseCreation: """ SQL to append to the end of the test table creation statements. """ - return '' + return "" def test_db_signature(self): """ @@ -336,8 +364,8 @@ class BaseDatabaseCreation: """ settings_dict = self.connection.settings_dict return ( - settings_dict['HOST'], - settings_dict['PORT'], - settings_dict['ENGINE'], + settings_dict["HOST"], + settings_dict["PORT"], + settings_dict["ENGINE"], self._get_test_db_name(), ) diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 20d5c8f772..42399b769a 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -130,21 +130,21 @@ class BaseDatabaseFeatures: # Map fields which some backends may not be able to differentiate to the # field it's introspected as. 
introspected_field_types = { - 'AutoField': 'AutoField', - 'BigAutoField': 'BigAutoField', - 'BigIntegerField': 'BigIntegerField', - 'BinaryField': 'BinaryField', - 'BooleanField': 'BooleanField', - 'CharField': 'CharField', - 'DurationField': 'DurationField', - 'GenericIPAddressField': 'GenericIPAddressField', - 'IntegerField': 'IntegerField', - 'PositiveBigIntegerField': 'PositiveBigIntegerField', - 'PositiveIntegerField': 'PositiveIntegerField', - 'PositiveSmallIntegerField': 'PositiveSmallIntegerField', - 'SmallAutoField': 'SmallAutoField', - 'SmallIntegerField': 'SmallIntegerField', - 'TimeField': 'TimeField', + "AutoField": "AutoField", + "BigAutoField": "BigAutoField", + "BigIntegerField": "BigIntegerField", + "BinaryField": "BinaryField", + "BooleanField": "BooleanField", + "CharField": "CharField", + "DurationField": "DurationField", + "GenericIPAddressField": "GenericIPAddressField", + "IntegerField": "IntegerField", + "PositiveBigIntegerField": "PositiveBigIntegerField", + "PositiveIntegerField": "PositiveIntegerField", + "PositiveSmallIntegerField": "PositiveSmallIntegerField", + "SmallAutoField": "SmallAutoField", + "SmallIntegerField": "SmallIntegerField", + "TimeField": "TimeField", } # Can the backend introspect the column order (ASC/DESC) for indexes? @@ -201,7 +201,7 @@ class BaseDatabaseFeatures: has_case_insensitive_like = False # Suffix for backends that don't support "SELECT xxx;" queries. - bare_select_suffix = '' + bare_select_suffix = "" # If NULL is implied on columns without needing to be explicitly specified implied_column_null = False @@ -325,10 +325,10 @@ class BaseDatabaseFeatures: # Collation names for use by the Django test suite. test_collations = { - 'ci': None, # Case-insensitive. - 'cs': None, # Case-sensitive. - 'non_default': None, # Non-default. - 'swedish_ci': None # Swedish case-insensitive. + "ci": None, # Case-insensitive. + "cs": None, # Case-sensitive. + "non_default": None, # Non-default. 
+ "swedish_ci": None, # Swedish case-insensitive. } # SQL template override for tests.aggregation.tests.NowUTC test_now_utc_template = None @@ -352,14 +352,14 @@ class BaseDatabaseFeatures: def supports_transactions(self): """Confirm support for transactions.""" with self.connection.cursor() as cursor: - cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') + cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)") self.connection.set_autocommit(False) - cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') + cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)") self.connection.rollback() self.connection.set_autocommit(True) - cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') - count, = cursor.fetchone() - cursor.execute('DROP TABLE ROLLBACK_TEST') + cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST") + (count,) = cursor.fetchone() + cursor.execute("DROP TABLE ROLLBACK_TEST") return count == 0 def allows_group_by_selected_pks_on_model(self, model): diff --git a/django/db/backends/base/introspection.py b/django/db/backends/base/introspection.py index 079c1835b0..c8036ef1e9 100644 --- a/django/db/backends/base/introspection.py +++ b/django/db/backends/base/introspection.py @@ -1,18 +1,19 @@ from collections import namedtuple # Structure returned by DatabaseIntrospection.get_table_list() -TableInfo = namedtuple('TableInfo', ['name', 'type']) +TableInfo = namedtuple("TableInfo", ["name", "type"]) # Structure returned by the DB-API cursor.description interface (PEP 249) FieldInfo = namedtuple( - 'FieldInfo', - 'name type_code display_size internal_size precision scale null_ok ' - 'default collation' + "FieldInfo", + "name type_code display_size internal_size precision scale null_ok " + "default collation", ) class BaseDatabaseIntrospection: """Encapsulate backend-specific introspection utilities.""" + data_types_reverse = {} def __init__(self, connection): @@ -43,9 +44,14 @@ class BaseDatabaseIntrospection: the database's ORDER BY here to avoid subtle 
differences in sorting order between databases. """ + def get_names(cursor): - return sorted(ti.name for ti in self.get_table_list(cursor) - if include_views or ti.type == 't') + return sorted( + ti.name + for ti in self.get_table_list(cursor) + if include_views or ti.type == "t" + ) + if cursor is None: with self.connection.cursor() as cursor: return get_names(cursor) @@ -56,7 +62,9 @@ class BaseDatabaseIntrospection: Return an unsorted list of TableInfo named tuples of all tables and views that exist in the database. """ - raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method') + raise NotImplementedError( + "subclasses of BaseDatabaseIntrospection may require a get_table_list() method" + ) def get_table_description(self, cursor, table_name): """ @@ -64,13 +72,14 @@ class BaseDatabaseIntrospection: interface. """ raise NotImplementedError( - 'subclasses of BaseDatabaseIntrospection may require a ' - 'get_table_description() method.' + "subclasses of BaseDatabaseIntrospection may require a " + "get_table_description() method." 
) def get_migratable_models(self): from django.apps import apps from django.db import router + return ( model for app_config in apps.get_app_configs() @@ -91,16 +100,15 @@ class BaseDatabaseIntrospection: continue tables.add(model._meta.db_table) tables.update( - f.m2m_db_table() for f in model._meta.local_many_to_many + f.m2m_db_table() + for f in model._meta.local_many_to_many if f.remote_field.through._meta.managed ) tables = list(tables) if only_existing: existing_tables = set(self.table_names(include_views=include_views)) tables = [ - t - for t in tables - if self.identifier_converter(t) in existing_tables + t for t in tables if self.identifier_converter(t) in existing_tables ] return tables @@ -111,7 +119,8 @@ class BaseDatabaseIntrospection: """ tables = set(map(self.identifier_converter, tables)) return { - m for m in self.get_migratable_models() + m + for m in self.get_migratable_models() if self.identifier_converter(m._meta.db_table) in tables } @@ -127,13 +136,19 @@ class BaseDatabaseIntrospection: continue if model._meta.swapped: continue - sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields)) + sequence_list.extend( + self.get_sequences( + cursor, model._meta.db_table, model._meta.local_fields + ) + ) for f in model._meta.local_many_to_many: # If this is an m2m using an intermediate table, # we don't need to reset the sequence. if f.remote_field.through._meta.auto_created: sequence = self.get_sequences(cursor, f.m2m_db_table()) - sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}]) + sequence_list.extend( + sequence or [{"table": f.m2m_db_table(), "column": None}] + ) return sequence_list def get_sequences(self, cursor, table_name, table_fields=()): @@ -142,7 +157,9 @@ class BaseDatabaseIntrospection: is a dict: {'table': <table_name>, 'column': <column_name>}. An optional 'name' key can be added if the backend supports named sequences. 
""" - raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method') + raise NotImplementedError( + "subclasses of BaseDatabaseIntrospection may require a get_sequences() method" + ) def get_relations(self, cursor, table_name): """ @@ -150,8 +167,8 @@ class BaseDatabaseIntrospection: representing all foreign keys in the given table. """ raise NotImplementedError( - 'subclasses of BaseDatabaseIntrospection may require a ' - 'get_relations() method.' + "subclasses of BaseDatabaseIntrospection may require a " + "get_relations() method." ) def get_primary_key_column(self, cursor, table_name): @@ -159,8 +176,8 @@ class BaseDatabaseIntrospection: Return the name of the primary key column for the given table. """ for constraint in self.get_constraints(cursor, table_name).values(): - if constraint['primary_key']: - return constraint['columns'][0] + if constraint["primary_key"]: + return constraint["columns"][0] return None def get_constraints(self, cursor, table_name): @@ -182,4 +199,6 @@ class BaseDatabaseIntrospection: Some backends may return special constraint names that don't exist if they don't name constraints of a certain type (e.g. SQLite) """ - raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method') + raise NotImplementedError( + "subclasses of BaseDatabaseIntrospection may require a get_constraints() method" + ) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 7422137304..5201e53af6 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -16,25 +16,26 @@ class BaseDatabaseOperations: Encapsulate backend-specific differences, such as the way a backend performs ordering or calculates the ID of a recently-inserted row. """ + compiler_module = "django.db.models.sql.compiler" # Integer field safe ranges by `internal_type` as documented # in docs/ref/models/fields.txt. 
integer_field_ranges = { - 'SmallIntegerField': (-32768, 32767), - 'IntegerField': (-2147483648, 2147483647), - 'BigIntegerField': (-9223372036854775808, 9223372036854775807), - 'PositiveBigIntegerField': (0, 9223372036854775807), - 'PositiveSmallIntegerField': (0, 32767), - 'PositiveIntegerField': (0, 2147483647), - 'SmallAutoField': (-32768, 32767), - 'AutoField': (-2147483648, 2147483647), - 'BigAutoField': (-9223372036854775808, 9223372036854775807), + "SmallIntegerField": (-32768, 32767), + "IntegerField": (-2147483648, 2147483647), + "BigIntegerField": (-9223372036854775808, 9223372036854775807), + "PositiveBigIntegerField": (0, 9223372036854775807), + "PositiveSmallIntegerField": (0, 32767), + "PositiveIntegerField": (0, 2147483647), + "SmallAutoField": (-32768, 32767), + "AutoField": (-2147483648, 2147483647), + "BigAutoField": (-9223372036854775808, 9223372036854775807), } set_operators = { - 'union': 'UNION', - 'intersection': 'INTERSECT', - 'difference': 'EXCEPT', + "union": "UNION", + "intersection": "INTERSECT", + "difference": "EXCEPT", } # Mapping of Field.get_internal_type() (typically the model field's class # name) to the data type to use for the Cast() function, if different from @@ -44,11 +45,11 @@ class BaseDatabaseOperations: cast_char_field_without_max_length = None # Start and end points for window expressions. - PRECEDING = 'PRECEDING' - FOLLOWING = 'FOLLOWING' - UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING - UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING - CURRENT_ROW = 'CURRENT ROW' + PRECEDING = "PRECEDING" + FOLLOWING = "FOLLOWING" + UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING + UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING + CURRENT_ROW = "CURRENT ROW" # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported. 
explain_prefix = None @@ -76,8 +77,8 @@ class BaseDatabaseOperations: def format_for_duration_arithmetic(self, sql): raise NotImplementedError( - 'subclasses of BaseDatabaseOperations may require a ' - 'format_for_duration_arithmetic() method.' + "subclasses of BaseDatabaseOperations may require a " + "format_for_duration_arithmetic() method." ) def cache_key_culling_sql(self): @@ -88,8 +89,8 @@ class BaseDatabaseOperations: This is used by the 'db' cache backend to determine where to start culling. """ - cache_key = self.quote_name('cache_key') - return f'SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s' + cache_key = self.quote_name("cache_key") + return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s" def unification_cast_sql(self, output_field): """ @@ -97,14 +98,16 @@ class BaseDatabaseOperations: to that type. The resulting string should contain a '%s' placeholder for the expression being cast. """ - return '%s' + return "%s" def date_extract_sql(self, lookup_type, field_name): """ Given a lookup_type of 'year', 'month', or 'day', return the SQL that extracts a value from the given date field field_name. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a date_extract_sql() method" + ) def date_trunc_sql(self, lookup_type, field_name, tzname=None): """ @@ -115,22 +118,26 @@ class BaseDatabaseOperations: If `tzname` is provided, the given value is truncated in a specific timezone. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a date_trunc_sql() method." + ) def datetime_cast_date_sql(self, field_name, tzname): """ Return the SQL to cast a datetime value to date value. 
""" raise NotImplementedError( - 'subclasses of BaseDatabaseOperations may require a ' - 'datetime_cast_date_sql() method.' + "subclasses of BaseDatabaseOperations may require a " + "datetime_cast_date_sql() method." ) def datetime_cast_time_sql(self, field_name, tzname): """ Return the SQL to cast a datetime value to time value. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method" + ) def datetime_extract_sql(self, lookup_type, field_name, tzname): """ @@ -138,7 +145,9 @@ class BaseDatabaseOperations: 'second', return the SQL that extracts a value from the given datetime field field_name. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method" + ) def datetime_trunc_sql(self, lookup_type, field_name, tzname): """ @@ -146,7 +155,9 @@ class BaseDatabaseOperations: 'second', return the SQL that truncates the given datetime field field_name to a datetime object with only the given specificity. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method" + ) def time_trunc_sql(self, lookup_type, field_name, tzname=None): """ @@ -157,7 +168,9 @@ class BaseDatabaseOperations: If `tzname` is provided, the given value is truncated in a specific timezone. 
""" - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method" + ) def time_extract_sql(self, lookup_type, field_name): """ @@ -171,7 +184,7 @@ class BaseDatabaseOperations: Return the SQL to make a constraint "initially deferred" during a CREATE TABLE statement. """ - return '' + return "" def distinct_sql(self, fields, params): """ @@ -180,9 +193,11 @@ class BaseDatabaseOperations: duplicates. """ if fields: - raise NotSupportedError('DISTINCT ON fields is not supported by this database backend') + raise NotSupportedError( + "DISTINCT ON fields is not supported by this database backend" + ) else: - return ['DISTINCT'], [] + return ["DISTINCT"], [] def fetch_returned_insert_columns(self, cursor, returning_params): """ @@ -198,7 +213,7 @@ class BaseDatabaseOperations: it in a WHERE statement. The resulting string should contain a '%s' placeholder for the column being searched against. """ - return '%s' + return "%s" def force_no_ordering(self): """ @@ -211,11 +226,11 @@ class BaseDatabaseOperations: """ Return the FOR UPDATE SQL clause to lock rows for an update operation. 
""" - return 'FOR%s UPDATE%s%s%s' % ( - ' NO KEY' if no_key else '', - ' OF %s' % ', '.join(of) if of else '', - ' NOWAIT' if nowait else '', - ' SKIP LOCKED' if skip_locked else '', + return "FOR%s UPDATE%s%s%s" % ( + " NO KEY" if no_key else "", + " OF %s" % ", ".join(of) if of else "", + " NOWAIT" if nowait else "", + " SKIP LOCKED" if skip_locked else "", ) def _get_limit_offset_params(self, low_mark, high_mark): @@ -229,10 +244,14 @@ class BaseDatabaseOperations: def limit_offset_sql(self, low_mark, high_mark): """Return LIMIT/OFFSET SQL clause.""" limit, offset = self._get_limit_offset_params(low_mark, high_mark) - return ' '.join(sql for sql in ( - ('LIMIT %d' % limit) if limit else None, - ('OFFSET %d' % offset) if offset else None, - ) if sql) + return " ".join( + sql + for sql in ( + ("LIMIT %d" % limit) if limit else None, + ("OFFSET %d" % offset) if offset else None, + ) + if sql + ) def last_executed_query(self, cursor, sql, params): """ @@ -246,7 +265,8 @@ class BaseDatabaseOperations: """ # Convert params to contain string values. def to_string(s): - return force_str(s, strings_only=True, errors='replace') + return force_str(s, strings_only=True, errors="replace") + if isinstance(params, (list, tuple)): u_params = tuple(to_string(val) for val in params) elif params is None: @@ -292,14 +312,16 @@ class BaseDatabaseOperations: Return the value to use for the LIMIT when we are wanting "LIMIT infinity". Return None if the limit clause can be omitted in this case. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a no_limit_value() method" + ) def pk_default_value(self): """ Return the value to use during an INSERT statement to specify that the field should use its default value. 
""" - return 'DEFAULT' + return "DEFAULT" def prepare_sql_script(self, sql): """ @@ -312,7 +334,8 @@ class BaseDatabaseOperations: """ return [ sqlparse.format(statement, strip_comments=True) - for statement in sqlparse.split(sql) if statement + for statement in sqlparse.split(sql) + if statement ] def process_clob(self, value): @@ -345,7 +368,9 @@ class BaseDatabaseOperations: Return a quoted version of the given table, index, or column name. Do not quote the given name if it's already been quoted. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a quote_name() method" + ) def regex_lookup(self, lookup_type): """ @@ -356,7 +381,9 @@ class BaseDatabaseOperations: If the feature is not supported (or part of it is not supported), raise NotImplementedError. """ - raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a regex_lookup() method" + ) def savepoint_create_sql(self, sid): """ @@ -384,7 +411,7 @@ class BaseDatabaseOperations: Return '' if the backend doesn't support time zones. """ - return '' + return "" def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): """ @@ -402,7 +429,9 @@ class BaseDatabaseOperations: to tables with foreign keys pointing the tables being truncated. PostgreSQL requires a cascade even if these tables are empty. 
""" - raise NotImplementedError('subclasses of BaseDatabaseOperations must provide an sql_flush() method') + raise NotImplementedError( + "subclasses of BaseDatabaseOperations must provide an sql_flush() method" + ) def execute_sql_flush(self, sql_list): """Execute a list of SQL statements to flush the database.""" @@ -453,7 +482,7 @@ class BaseDatabaseOperations: If `inline` is True, append the SQL to a row; otherwise append it to the entire CREATE TABLE or CREATE INDEX statement. """ - return '' + return "" def prep_for_like_query(self, x): """Prepare a value for use in a LIKE query.""" @@ -479,7 +508,7 @@ class BaseDatabaseOperations: cases where the target type isn't known, such as .raw() SQL queries. As a consequence it may not work perfectly in all circumstances. """ - if isinstance(value, datetime.datetime): # must be before date + if isinstance(value, datetime.datetime): # must be before date return self.adapt_datetimefield_value(value) elif isinstance(value, datetime.date): return self.adapt_datefield_value(value) @@ -507,7 +536,7 @@ class BaseDatabaseOperations: if value is None: return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value return str(value) @@ -520,7 +549,7 @@ class BaseDatabaseOperations: if value is None: return None # Expression values are adapted by the database. 
- if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value if timezone.is_aware(value): @@ -552,10 +581,9 @@ class BaseDatabaseOperations: """ if iso_year: first = datetime.date.fromisocalendar(value, 1, 1) - second = ( - datetime.date.fromisocalendar(value + 1, 1, 1) - - datetime.timedelta(days=1) - ) + second = datetime.date.fromisocalendar( + value + 1, 1, 1 + ) - datetime.timedelta(days=1) else: first = datetime.date(value, 1, 1) second = datetime.date(value, 12, 31) @@ -574,10 +602,9 @@ class BaseDatabaseOperations: """ if iso_year: first = datetime.datetime.fromisocalendar(value, 1, 1) - second = ( - datetime.datetime.fromisocalendar(value + 1, 1, 1) - - datetime.timedelta(microseconds=1) - ) + second = datetime.datetime.fromisocalendar( + value + 1, 1, 1 + ) - datetime.timedelta(microseconds=1) else: first = datetime.datetime(value, 1, 1) second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999) @@ -627,7 +654,7 @@ class BaseDatabaseOperations: can vary between backends (e.g., Oracle with %% and &) and between subexpression types (e.g., date expressions). """ - conn = ' %s ' % connector + conn = " %s " % connector return conn.join(sub_expressions) def combine_duration_expression(self, connector, sub_expressions): @@ -638,7 +665,7 @@ class BaseDatabaseOperations: Some backends require special syntax to insert binary content (MySQL for example uses '_binary %s'). """ - return '%s' + return "%s" def modify_insert_params(self, placeholder, params): """ @@ -659,66 +686,76 @@ class BaseDatabaseOperations: if self.connection.features.supports_temporal_subtraction: lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs - return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params) - raise NotSupportedError("This backend does not support %s subtraction." % internal_type) + return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params) + raise NotSupportedError( + "This backend does not support %s subtraction." 
% internal_type + ) def window_frame_start(self, start): if isinstance(start, int): if start < 0: - return '%d %s' % (abs(start), self.PRECEDING) + return "%d %s" % (abs(start), self.PRECEDING) elif start == 0: return self.CURRENT_ROW elif start is None: return self.UNBOUNDED_PRECEDING - raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start) + raise ValueError( + "start argument must be a negative integer, zero, or None, but got '%s'." + % start + ) def window_frame_end(self, end): if isinstance(end, int): if end == 0: return self.CURRENT_ROW elif end > 0: - return '%d %s' % (end, self.FOLLOWING) + return "%d %s" % (end, self.FOLLOWING) elif end is None: return self.UNBOUNDED_FOLLOWING - raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end) + raise ValueError( + "end argument must be a positive integer, zero, or None, but got '%s'." + % end + ) def window_frame_rows_start_end(self, start=None, end=None): """ Return SQL for start and end points in an OVER clause window frame. """ if not self.connection.features.supports_over_clause: - raise NotSupportedError('This backend does not support window expressions.') + raise NotSupportedError("This backend does not support window expressions.") return self.window_frame_start(start), self.window_frame_end(end) def window_frame_range_start_end(self, start=None, end=None): start_, end_ = self.window_frame_rows_start_end(start, end) if ( - self.connection.features.only_supports_unbounded_with_preceding_and_following and - ((start and start < 0) or (end and end > 0)) + self.connection.features.only_supports_unbounded_with_preceding_and_following + and ((start and start < 0) or (end and end > 0)) ): raise NotSupportedError( - '%s only supports UNBOUNDED together with PRECEDING and ' - 'FOLLOWING.' % self.connection.display_name + "%s only supports UNBOUNDED together with PRECEDING and " + "FOLLOWING." 
% self.connection.display_name ) return start_, end_ def explain_query_prefix(self, format=None, **options): if not self.connection.features.supports_explaining_query_execution: - raise NotSupportedError('This backend does not support explaining query execution.') + raise NotSupportedError( + "This backend does not support explaining query execution." + ) if format: supported_formats = self.connection.features.supported_explain_formats normalized_format = format.upper() if normalized_format not in supported_formats: - msg = '%s is not a recognized format.' % normalized_format + msg = "%s is not a recognized format." % normalized_format if supported_formats: - msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats)) + msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats)) raise ValueError(msg) if options: - raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys()))) + raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys()))) return self.explain_prefix def insert_statement(self, on_conflict=None): - return 'INSERT INTO' + return "INSERT INTO" def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): - return '' + return "" diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 4cd4567cbc..ea98e86b77 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -2,7 +2,12 @@ import logging from datetime import datetime from django.db.backends.ddl_references import ( - Columns, Expressions, ForeignKeyName, IndexName, Statement, Table, + Columns, + Expressions, + ForeignKeyName, + IndexName, + Statement, + Table, ) from django.db.backends.utils import names_digest, split_identifier from django.db.models import Deferrable, Index @@ -10,7 +15,7 @@ from django.db.models.sql import Query from django.db.transaction import TransactionManagementError, atomic from django.utils import timezone -logger = 
logging.getLogger('django.db.backends.schema') +logger = logging.getLogger("django.db.backends.schema") def _is_relevant_relation(relation, altered_field): @@ -31,7 +36,10 @@ def _is_relevant_relation(relation, altered_field): def _all_related_fields(model): return model._meta._get_fields( - forward=False, reverse=True, include_hidden=True, include_parents=False, + forward=False, + reverse=True, + include_hidden=True, + include_parents=False, ) @@ -39,8 +47,16 @@ def _related_non_m2m_objects(old_field, new_field): # Filter out m2m objects from reverse relations. # Return (old_relation, new_relation) tuples. related_fields = zip( - (obj for obj in _all_related_fields(old_field.model) if _is_relevant_relation(obj, old_field)), - (obj for obj in _all_related_fields(new_field.model) if _is_relevant_relation(obj, new_field)), + ( + obj + for obj in _all_related_fields(old_field.model) + if _is_relevant_relation(obj, old_field) + ), + ( + obj + for obj in _all_related_fields(new_field.model) + if _is_relevant_relation(obj, new_field) + ), ) for old_rel, new_rel in related_fields: yield old_rel, new_rel @@ -73,8 +89,12 @@ class BaseDatabaseSchemaEditor: sql_alter_column_no_default_null = sql_alter_column_no_default sql_alter_column_collate = "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s" sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" - sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" - sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" + sql_rename_column = ( + "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" + ) + sql_update_with_default = ( + "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" + ) sql_unique_constraint = "UNIQUE (%(columns)s)%(deferrable)s" sql_check_constraint = "CHECK (%(check)s)" @@ -99,10 +119,12 @@ class BaseDatabaseSchemaEditor: sql_create_unique_index = "CREATE UNIQUE INDEX %(name)s 
ON %(table)s (%(columns)s)%(include)s%(condition)s" sql_delete_index = "DROP INDEX %(name)s" - sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" + sql_create_pk = ( + "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" + ) sql_delete_pk = sql_delete_constraint - sql_delete_procedure = 'DROP PROCEDURE %(procedure)s' + sql_delete_procedure = "DROP PROCEDURE %(procedure)s" def __init__(self, connection, collect_sql=False, atomic=True): self.connection = connection @@ -133,7 +155,11 @@ class BaseDatabaseSchemaEditor: """Execute the given SQL statement, with optional parameters.""" # Don't perform the transactional DDL check if SQL is being collected # as it's not going to be executed anyway. - if not self.collect_sql and self.connection.in_atomic_block and not self.connection.features.can_rollback_ddl: + if ( + not self.collect_sql + and self.connection.in_atomic_block + and not self.connection.features.can_rollback_ddl + ): raise TransactionManagementError( "Executing DDL statements while in a transaction on databases " "that can't perform a rollback is prohibited." @@ -141,11 +167,15 @@ class BaseDatabaseSchemaEditor: # Account for non-string statement objects. sql = str(sql) # Log the command we're running, then run it - logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql}) + logger.debug( + "%s; (params %r)", sql, params, extra={"params": params, "sql": sql} + ) if self.collect_sql: ending = "" if sql.rstrip().endswith(";") else ";" if params is not None: - self.collected_sql.append((sql % tuple(map(self.quote_value, params))) + ending) + self.collected_sql.append( + (sql % tuple(map(self.quote_value, params))) + ending + ) else: self.collected_sql.append(sql + ending) else: @@ -172,59 +202,82 @@ class BaseDatabaseSchemaEditor: continue # Check constraints can go on the column SQL here. 
db_params = field.db_parameters(connection=self.connection) - if db_params['check']: - definition += ' ' + self.sql_check_constraint % db_params + if db_params["check"]: + definition += " " + self.sql_check_constraint % db_params # Autoincrement SQL (for backends with inline variant). col_type_suffix = field.db_type_suffix(connection=self.connection) if col_type_suffix: - definition += ' %s' % col_type_suffix + definition += " %s" % col_type_suffix params.extend(extra_params) # FK. if field.remote_field and field.db_constraint: to_table = field.remote_field.model._meta.db_table - to_column = field.remote_field.model._meta.get_field(field.remote_field.field_name).column + to_column = field.remote_field.model._meta.get_field( + field.remote_field.field_name + ).column if self.sql_create_inline_fk: - definition += ' ' + self.sql_create_inline_fk % { - 'to_table': self.quote_name(to_table), - 'to_column': self.quote_name(to_column), + definition += " " + self.sql_create_inline_fk % { + "to_table": self.quote_name(to_table), + "to_column": self.quote_name(to_column), } elif self.connection.features.supports_foreign_keys: - self.deferred_sql.append(self._create_fk_sql(model, field, '_fk_%(to_table)s_%(to_column)s')) + self.deferred_sql.append( + self._create_fk_sql( + model, field, "_fk_%(to_table)s_%(to_column)s" + ) + ) # Add the SQL to our big list. - column_sqls.append('%s %s' % ( - self.quote_name(field.column), - definition, - )) + column_sqls.append( + "%s %s" + % ( + self.quote_name(field.column), + definition, + ) + ) # Autoincrement SQL (for backends with post table definition # variant). 
- if field.get_internal_type() in ('AutoField', 'BigAutoField', 'SmallAutoField'): - autoinc_sql = self.connection.ops.autoinc_sql(model._meta.db_table, field.column) + if field.get_internal_type() in ( + "AutoField", + "BigAutoField", + "SmallAutoField", + ): + autoinc_sql = self.connection.ops.autoinc_sql( + model._meta.db_table, field.column + ) if autoinc_sql: self.deferred_sql.extend(autoinc_sql) - constraints = [constraint.constraint_sql(model, self) for constraint in model._meta.constraints] + constraints = [ + constraint.constraint_sql(model, self) + for constraint in model._meta.constraints + ] sql = self.sql_create_table % { - 'table': self.quote_name(model._meta.db_table), - 'definition': ', '.join(constraint for constraint in (*column_sqls, *constraints) if constraint), + "table": self.quote_name(model._meta.db_table), + "definition": ", ".join( + constraint for constraint in (*column_sqls, *constraints) if constraint + ), } if model._meta.db_tablespace: - tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace) + tablespace_sql = self.connection.ops.tablespace_sql( + model._meta.db_tablespace + ) if tablespace_sql: - sql += ' ' + tablespace_sql + sql += " " + tablespace_sql return sql, params # Field <-> database mapping functions def _iter_column_sql(self, column_db_type, params, model, field, include_default): yield column_db_type - collation = getattr(field, 'db_collation', None) + collation = getattr(field, "db_collation", None) if collation: yield self._collate_sql(collation) # Work out nullability. null = field.null # Include a default value, if requested. include_default = ( - include_default and - not self.skip_default(field) and + include_default + and not self.skip_default(field) + and # Don't include a default value if it's a nullable field and the # default cannot be dropped in the ALTER COLUMN statement (e.g. # MySQL longtext and longblob). 
@@ -233,7 +286,7 @@ class BaseDatabaseSchemaEditor: if include_default: default_value = self.effective_default(field) if default_value is not None: - column_default = 'DEFAULT ' + self._column_default_sql(field) + column_default = "DEFAULT " + self._column_default_sql(field) if self.connection.features.requires_literal_defaults: # Some databases can't take defaults as a parameter (Oracle). # If this is the case, the individual schema backend should @@ -244,20 +297,27 @@ class BaseDatabaseSchemaEditor: params.append(default_value) # Oracle treats the empty string ('') as null, so coerce the null # option whenever '' is a possible value. - if (field.empty_strings_allowed and not field.primary_key and - self.connection.features.interprets_empty_strings_as_nulls): + if ( + field.empty_strings_allowed + and not field.primary_key + and self.connection.features.interprets_empty_strings_as_nulls + ): null = True if not null: - yield 'NOT NULL' + yield "NOT NULL" elif not self.connection.features.implied_column_null: - yield 'NULL' + yield "NULL" if field.primary_key: - yield 'PRIMARY KEY' + yield "PRIMARY KEY" elif field.unique: - yield 'UNIQUE' + yield "UNIQUE" # Optionally add the tablespace if it's an implicitly indexed column. tablespace = field.db_tablespace or model._meta.db_tablespace - if tablespace and self.connection.features.supports_tablespaces and field.unique: + if ( + tablespace + and self.connection.features.supports_tablespaces + and field.unique + ): yield self.connection.ops.tablespace_sql(tablespace, inline=True) def column_sql(self, model, field, include_default=False): @@ -267,15 +327,20 @@ class BaseDatabaseSchemaEditor: """ # Get the column's type and use that as the basis of the SQL. db_params = field.db_parameters(connection=self.connection) - column_db_type = db_params['type'] + column_db_type = db_params["type"] # Check for fields that aren't actually columns (e.g. M2M). 
if column_db_type is None: return None, None params = [] - return ' '.join( - # This appends to the params being returned. - self._iter_column_sql(column_db_type, params, model, field, include_default) - ), params + return ( + " ".join( + # This appends to the params being returned. + self._iter_column_sql( + column_db_type, params, model, field, include_default + ) + ), + params, + ) def skip_default(self, field): """ @@ -296,8 +361,8 @@ class BaseDatabaseSchemaEditor: Only used for backends which have requires_literal_defaults feature """ raise NotImplementedError( - 'subclasses of BaseDatabaseSchemaEditor for backends which have ' - 'requires_literal_defaults must provide a prepare_default() method' + "subclasses of BaseDatabaseSchemaEditor for backends which have " + "requires_literal_defaults must provide a prepare_default() method" ) def _column_default_sql(self, field): @@ -305,7 +370,7 @@ class BaseDatabaseSchemaEditor: Return the SQL to use in a DEFAULT clause. The resulting string should contain a '%s' placeholder for a default value. 
""" - return '%s' + return "%s" @staticmethod def _effective_default(field): @@ -314,18 +379,18 @@ class BaseDatabaseSchemaEditor: default = field.get_default() elif not field.null and field.blank and field.empty_strings_allowed: if field.get_internal_type() == "BinaryField": - default = b'' + default = b"" else: - default = '' - elif getattr(field, 'auto_now', False) or getattr(field, 'auto_now_add', False): + default = "" + elif getattr(field, "auto_now", False) or getattr(field, "auto_now_add", False): internal_type = field.get_internal_type() - if internal_type == 'DateTimeField': + if internal_type == "DateTimeField": default = timezone.now() else: default = datetime.now() - if internal_type == 'DateField': + if internal_type == "DateField": default = default.date() - elif internal_type == 'TimeField': + elif internal_type == "TimeField": default = default.time() else: default = None @@ -372,19 +437,24 @@ class BaseDatabaseSchemaEditor: self.delete_model(field.remote_field.through) # Delete the table - self.execute(self.sql_delete_table % { - "table": self.quote_name(model._meta.db_table), - }) + self.execute( + self.sql_delete_table + % { + "table": self.quote_name(model._meta.db_table), + } + ) # Remove all deferred statements referencing the deleted table. 
for sql in list(self.deferred_sql): - if isinstance(sql, Statement) and sql.references_table(model._meta.db_table): + if isinstance(sql, Statement) and sql.references_table( + model._meta.db_table + ): self.deferred_sql.remove(sql) def add_index(self, model, index): """Add an index on a model.""" if ( - index.contains_expressions and - not self.connection.features.supports_expression_indexes + index.contains_expressions + and not self.connection.features.supports_expression_indexes ): return None # Index.create_sql returns interpolated SQL which makes params=None a @@ -394,8 +464,8 @@ class BaseDatabaseSchemaEditor: def remove_index(self, model, index): """Remove an index from a model.""" if ( - index.contains_expressions and - not self.connection.features.supports_expression_indexes + index.contains_expressions + and not self.connection.features.supports_expression_indexes ): return None self.execute(index.remove_sql(model, self)) @@ -424,7 +494,9 @@ class BaseDatabaseSchemaEditor: news = {tuple(fields) for fields in new_unique_together} # Deleted uniques for fields in olds.difference(news): - self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_unique) + self._delete_composed_index( + model, fields, {"unique": True}, self.sql_delete_unique + ) # Created uniques for field_names in news.difference(olds): fields = [model._meta.get_field(field) for field in field_names] @@ -443,40 +515,51 @@ class BaseDatabaseSchemaEditor: self._delete_composed_index( model, fields, - {'index': True, 'unique': False}, + {"index": True, "unique": False}, self.sql_delete_index, ) # Created indexes for field_names in news.difference(olds): fields = [model._meta.get_field(field) for field in field_names] - self.execute(self._create_index_sql(model, fields=fields, suffix='_idx')) + self.execute(self._create_index_sql(model, fields=fields, suffix="_idx")) def _delete_composed_index(self, model, fields, constraint_kwargs, sql): - meta_constraint_names = 
{constraint.name for constraint in model._meta.constraints} + meta_constraint_names = { + constraint.name for constraint in model._meta.constraints + } meta_index_names = {constraint.name for constraint in model._meta.indexes} columns = [model._meta.get_field(field).column for field in fields] constraint_names = self._constraint_names( - model, columns, exclude=meta_constraint_names | meta_index_names, - **constraint_kwargs + model, + columns, + exclude=meta_constraint_names | meta_index_names, + **constraint_kwargs, ) if len(constraint_names) != 1: - raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % ( - len(constraint_names), - model._meta.db_table, - ", ".join(columns), - )) + raise ValueError( + "Found wrong number (%s) of constraints for %s(%s)" + % ( + len(constraint_names), + model._meta.db_table, + ", ".join(columns), + ) + ) self.execute(self._delete_constraint_sql(sql, model, constraint_names[0])) def alter_db_table(self, model, old_db_table, new_db_table): """Rename the table a model points to.""" - if (old_db_table == new_db_table or - (self.connection.features.ignores_table_name_case and - old_db_table.lower() == new_db_table.lower())): + if old_db_table == new_db_table or ( + self.connection.features.ignores_table_name_case + and old_db_table.lower() == new_db_table.lower() + ): return - self.execute(self.sql_rename_table % { - "old_table": self.quote_name(old_db_table), - "new_table": self.quote_name(new_db_table), - }) + self.execute( + self.sql_rename_table + % { + "old_table": self.quote_name(old_db_table), + "new_table": self.quote_name(new_db_table), + } + ) # Rename all references to the old table name. 
for sql in self.deferred_sql: if isinstance(sql, Statement): @@ -484,11 +567,14 @@ class BaseDatabaseSchemaEditor: def alter_db_tablespace(self, model, old_db_tablespace, new_db_tablespace): """Move a model's table between tablespaces.""" - self.execute(self.sql_retablespace_table % { - "table": self.quote_name(model._meta.db_table), - "old_tablespace": self.quote_name(old_db_tablespace), - "new_tablespace": self.quote_name(new_db_tablespace), - }) + self.execute( + self.sql_retablespace_table + % { + "table": self.quote_name(model._meta.db_table), + "old_tablespace": self.quote_name(old_db_tablespace), + "new_tablespace": self.quote_name(new_db_tablespace), + } + ) def add_field(self, model, field): """ @@ -505,26 +591,36 @@ class BaseDatabaseSchemaEditor: return # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) - if db_params['check']: + if db_params["check"]: definition += " " + self.sql_check_constraint % db_params - if field.remote_field and self.connection.features.supports_foreign_keys and field.db_constraint: - constraint_suffix = '_fk_%(to_table)s_%(to_column)s' + if ( + field.remote_field + and self.connection.features.supports_foreign_keys + and field.db_constraint + ): + constraint_suffix = "_fk_%(to_table)s_%(to_column)s" # Add FK constraint inline, if supported. if self.sql_create_column_inline_fk: to_table = field.remote_field.model._meta.db_table - to_column = field.remote_field.model._meta.get_field(field.remote_field.field_name).column + to_column = field.remote_field.model._meta.get_field( + field.remote_field.field_name + ).column namespace, _ = split_identifier(model._meta.db_table) definition += " " + self.sql_create_column_inline_fk % { - 'name': self._fk_constraint_name(model, field, constraint_suffix), - 'namespace': '%s.' 
% self.quote_name(namespace) if namespace else '', - 'column': self.quote_name(field.column), - 'to_table': self.quote_name(to_table), - 'to_column': self.quote_name(to_column), - 'deferrable': self.connection.ops.deferrable_sql() + "name": self._fk_constraint_name(model, field, constraint_suffix), + "namespace": "%s." % self.quote_name(namespace) + if namespace + else "", + "column": self.quote_name(field.column), + "to_table": self.quote_name(to_table), + "to_column": self.quote_name(to_column), + "deferrable": self.connection.ops.deferrable_sql(), } # Otherwise, add FK constraints later. else: - self.deferred_sql.append(self._create_fk_sql(model, field, constraint_suffix)) + self.deferred_sql.append( + self._create_fk_sql(model, field, constraint_suffix) + ) # Build the SQL and run it sql = self.sql_create_column % { "table": self.quote_name(model._meta.db_table), @@ -534,8 +630,13 @@ class BaseDatabaseSchemaEditor: self.execute(sql, params) # Drop the default if we need to # (Django usually does not use in-database defaults) - if not self.skip_default_on_alter(field) and self.effective_default(field) is not None: - changes_sql, params = self._alter_column_default_sql(model, None, field, drop=True) + if ( + not self.skip_default_on_alter(field) + and self.effective_default(field) is not None + ): + changes_sql, params = self._alter_column_default_sql( + model, None, field, drop=True + ) sql = self.sql_alter_column % { "table": self.quote_name(model._meta.db_table), "changes": changes_sql, @@ -556,7 +657,7 @@ class BaseDatabaseSchemaEditor: if field.many_to_many and field.remote_field.through._meta.auto_created: return self.delete_model(field.remote_field.through) # It might not actually have a column behind it - if field.db_parameters(connection=self.connection)['type'] is None: + if field.db_parameters(connection=self.connection)["type"] is None: return # Drop any FK constraints, MySQL requires explicit deletion if field.remote_field: @@ -574,7 +675,9 @@ class 
BaseDatabaseSchemaEditor: self.connection.close() # Remove all deferred statements referencing the deleted column. for sql in list(self.deferred_sql): - if isinstance(sql, Statement) and sql.references_column(model._meta.db_table, field.column): + if isinstance(sql, Statement) and sql.references_column( + model._meta.db_table, field.column + ): self.deferred_sql.remove(sql) def alter_field(self, model, old_field, new_field, strict=False): @@ -589,25 +692,38 @@ class BaseDatabaseSchemaEditor: return # Ensure this field is even column-based old_db_params = old_field.db_parameters(connection=self.connection) - old_type = old_db_params['type'] + old_type = old_db_params["type"] new_db_params = new_field.db_parameters(connection=self.connection) - new_type = new_db_params['type'] - if ((old_type is None and old_field.remote_field is None) or - (new_type is None and new_field.remote_field is None)): + new_type = new_db_params["type"] + if (old_type is None and old_field.remote_field is None) or ( + new_type is None and new_field.remote_field is None + ): raise ValueError( "Cannot alter field %s into %s - they do not properly define " - "db_type (are you using a badly-written custom field?)" % - (old_field, new_field), + "db_type (are you using a badly-written custom field?)" + % (old_field, new_field), ) - elif old_type is None and new_type is None and ( - old_field.remote_field.through and new_field.remote_field.through and - old_field.remote_field.through._meta.auto_created and - new_field.remote_field.through._meta.auto_created): + elif ( + old_type is None + and new_type is None + and ( + old_field.remote_field.through + and new_field.remote_field.through + and old_field.remote_field.through._meta.auto_created + and new_field.remote_field.through._meta.auto_created + ) + ): return self._alter_many_to_many(model, old_field, new_field, strict) - elif old_type is None and new_type is None and ( - old_field.remote_field.through and new_field.remote_field.through and - 
not old_field.remote_field.through._meta.auto_created and - not new_field.remote_field.through._meta.auto_created): + elif ( + old_type is None + and new_type is None + and ( + old_field.remote_field.through + and new_field.remote_field.through + and not old_field.remote_field.through._meta.auto_created + and not new_field.remote_field.through._meta.auto_created + ) + ): # Both sides have through models; this is a no-op. return elif old_type is None or new_type is None: @@ -617,52 +733,86 @@ class BaseDatabaseSchemaEditor: "through= on M2M fields)" % (old_field, new_field) ) - self._alter_field(model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict) + self._alter_field( + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict, + ) - def _alter_field(self, model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict=False): + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): """Perform a "physical" (non-ManyToMany) field update.""" # Drop any FK constraints, we'll remake them later fks_dropped = set() if ( - self.connection.features.supports_foreign_keys and - old_field.remote_field and - old_field.db_constraint + self.connection.features.supports_foreign_keys + and old_field.remote_field + and old_field.db_constraint ): - fk_names = self._constraint_names(model, [old_field.column], foreign_key=True) + fk_names = self._constraint_names( + model, [old_field.column], foreign_key=True + ) if strict and len(fk_names) != 1: - raise ValueError("Found wrong number (%s) of foreign key constraints for %s.%s" % ( - len(fk_names), - model._meta.db_table, - old_field.column, - )) + raise ValueError( + "Found wrong number (%s) of foreign key constraints for %s.%s" + % ( + len(fk_names), + model._meta.db_table, + old_field.column, + ) + ) for fk_name in fk_names: 
fks_dropped.add((old_field.column,)) self.execute(self._delete_fk_sql(model, fk_name)) # Has unique been removed? - if old_field.unique and (not new_field.unique or self._field_became_primary_key(old_field, new_field)): + if old_field.unique and ( + not new_field.unique or self._field_became_primary_key(old_field, new_field) + ): # Find the unique constraint for this field - meta_constraint_names = {constraint.name for constraint in model._meta.constraints} + meta_constraint_names = { + constraint.name for constraint in model._meta.constraints + } constraint_names = self._constraint_names( - model, [old_field.column], unique=True, primary_key=False, + model, + [old_field.column], + unique=True, + primary_key=False, exclude=meta_constraint_names, ) if strict and len(constraint_names) != 1: - raise ValueError("Found wrong number (%s) of unique constraints for %s.%s" % ( - len(constraint_names), - model._meta.db_table, - old_field.column, - )) + raise ValueError( + "Found wrong number (%s) of unique constraints for %s.%s" + % ( + len(constraint_names), + model._meta.db_table, + old_field.column, + ) + ) for constraint_name in constraint_names: self.execute(self._delete_unique_sql(model, constraint_name)) # Drop incoming FK constraints if the field is a primary key or unique, # which might be a to_field target, and things are going to change. 
drop_foreign_keys = ( - self.connection.features.supports_foreign_keys and ( - (old_field.primary_key and new_field.primary_key) or - (old_field.unique and new_field.unique) - ) and old_type != new_type + self.connection.features.supports_foreign_keys + and ( + (old_field.primary_key and new_field.primary_key) + or (old_field.unique and new_field.unique) + ) + and old_type != new_type ) if drop_foreign_keys: # '_meta.related_field' also contains M2M reverse fields, these @@ -683,13 +833,20 @@ class BaseDatabaseSchemaEditor: # True | False | False | False # True | False | False | True # True | False | True | True - if old_field.db_index and not old_field.unique and (not new_field.db_index or new_field.unique): + if ( + old_field.db_index + and not old_field.unique + and (not new_field.db_index or new_field.unique) + ): # Find the index for this field meta_index_names = {index.name for index in model._meta.indexes} # Retrieve only BTREE indexes since this is what's created with # db_index=True. index_names = self._constraint_names( - model, [old_field.column], index=True, type_=Index.suffix, + model, + [old_field.column], + index=True, + type_=Index.suffix, exclude=meta_index_names, ) for index_name in index_names: @@ -698,41 +855,58 @@ class BaseDatabaseSchemaEditor: # is to look at its name (refs #28053). self.execute(self._delete_index_sql(model, index_name)) # Change check constraints? 
- if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: - meta_constraint_names = {constraint.name for constraint in model._meta.constraints} + if old_db_params["check"] != new_db_params["check"] and old_db_params["check"]: + meta_constraint_names = { + constraint.name for constraint in model._meta.constraints + } constraint_names = self._constraint_names( - model, [old_field.column], check=True, + model, + [old_field.column], + check=True, exclude=meta_constraint_names, ) if strict and len(constraint_names) != 1: - raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % ( - len(constraint_names), - model._meta.db_table, - old_field.column, - )) + raise ValueError( + "Found wrong number (%s) of check constraints for %s.%s" + % ( + len(constraint_names), + model._meta.db_table, + old_field.column, + ) + ) for constraint_name in constraint_names: self.execute(self._delete_check_sql(model, constraint_name)) # Have they renamed the column? if old_field.column != new_field.column: - self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) + self.execute( + self._rename_field_sql( + model._meta.db_table, old_field, new_field, new_type + ) + ) # Rename all references to the renamed column. for sql in self.deferred_sql: if isinstance(sql, Statement): - sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column) + sql.rename_column_references( + model._meta.db_table, old_field.column, new_field.column + ) # Next, start accumulating actions to do actions = [] null_actions = [] post_actions = [] # Collation change? - old_collation = getattr(old_field, 'db_collation', None) - new_collation = getattr(new_field, 'db_collation', None) + old_collation = getattr(old_field, "db_collation", None) + new_collation = getattr(new_field, "db_collation", None) if old_collation != new_collation: # Collation change handles also a type change. 
- fragment = self._alter_column_collation_sql(model, new_field, new_type, new_collation) + fragment = self._alter_column_collation_sql( + model, new_field, new_type, new_collation + ) actions.append(fragment) # Type change? elif old_type != new_type: - fragment, other_actions = self._alter_column_type_sql(model, old_field, new_field, new_type) + fragment, other_actions = self._alter_column_type_sql( + model, old_field, new_field, new_type + ) actions.append(fragment) post_actions.extend(other_actions) # When changing a column NULL constraint to NOT NULL with a given @@ -747,21 +921,22 @@ class BaseDatabaseSchemaEditor: old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) if ( - not self.skip_default_on_alter(new_field) and - old_default != new_default and - new_default is not None + not self.skip_default_on_alter(new_field) + and old_default != new_default + and new_default is not None ): needs_database_default = True - actions.append(self._alter_column_default_sql(model, old_field, new_field)) + actions.append( + self._alter_column_default_sql(model, old_field, new_field) + ) # Nullability change? 
if old_field.null != new_field.null: fragment = self._alter_column_null_sql(model, old_field, new_field) if fragment: null_actions.append(fragment) # Only if we have a default and there is a change from NULL to NOT NULL - four_way_default_alteration = ( - new_field.has_default() and - (old_field.null and not new_field.null) + four_way_default_alteration = new_field.has_default() and ( + old_field.null and not new_field.null ) if actions or null_actions: if not four_way_default_alteration: @@ -775,7 +950,8 @@ class BaseDatabaseSchemaEditor: # Apply those actions for sql, params in actions: self.execute( - self.sql_alter_column % { + self.sql_alter_column + % { "table": self.quote_name(model._meta.db_table), "changes": sql, }, @@ -784,7 +960,8 @@ class BaseDatabaseSchemaEditor: if four_way_default_alteration: # Update existing rows with default value self.execute( - self.sql_update_with_default % { + self.sql_update_with_default + % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(new_field.column), "default": "%s", @@ -795,7 +972,8 @@ class BaseDatabaseSchemaEditor: # now for sql, params in null_actions: self.execute( - self.sql_alter_column % { + self.sql_alter_column + % { "table": self.quote_name(model._meta.db_table), "changes": sql, }, @@ -819,7 +997,11 @@ class BaseDatabaseSchemaEditor: # False | False | True | False # False | True | True | False # True | True | True | False - if (not old_field.db_index or old_field.unique) and new_field.db_index and not new_field.unique: + if ( + (not old_field.db_index or old_field.unique) + and new_field.db_index + and not new_field.unique + ): self.execute(self._create_index_sql(model, fields=[new_field])) # Type alteration on primary key? Then we need to alter the column # referring to us. 
@@ -835,12 +1017,13 @@ class BaseDatabaseSchemaEditor: # Handle our type alters on the other end of rels from the PK stuff above for old_rel, new_rel in rels_to_update: rel_db_params = new_rel.field.db_parameters(connection=self.connection) - rel_type = rel_db_params['type'] + rel_type = rel_db_params["type"] fragment, other_actions = self._alter_column_type_sql( new_rel.related_model, old_rel.field, new_rel.field, rel_type ) self.execute( - self.sql_alter_column % { + self.sql_alter_column + % { "table": self.quote_name(new_rel.related_model._meta.db_table), "changes": fragment[0], }, @@ -849,23 +1032,38 @@ class BaseDatabaseSchemaEditor: for sql, params in other_actions: self.execute(sql, params) # Does it have a foreign key? - if (self.connection.features.supports_foreign_keys and new_field.remote_field and - (fks_dropped or not old_field.remote_field or not old_field.db_constraint) and - new_field.db_constraint): - self.execute(self._create_fk_sql(model, new_field, "_fk_%(to_table)s_%(to_column)s")) + if ( + self.connection.features.supports_foreign_keys + and new_field.remote_field + and ( + fks_dropped or not old_field.remote_field or not old_field.db_constraint + ) + and new_field.db_constraint + ): + self.execute( + self._create_fk_sql(model, new_field, "_fk_%(to_table)s_%(to_column)s") + ) # Rebuild FKs that pointed to us if we previously had to drop them if drop_foreign_keys: for _, rel in rels_to_update: if rel.field.db_constraint: - self.execute(self._create_fk_sql(rel.related_model, rel.field, "_fk")) + self.execute( + self._create_fk_sql(rel.related_model, rel.field, "_fk") + ) # Does it have check constraints we need to add? 
- if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: - constraint_name = self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check') - self.execute(self._create_check_sql(model, constraint_name, new_db_params['check'])) + if old_db_params["check"] != new_db_params["check"] and new_db_params["check"]: + constraint_name = self._create_index_name( + model._meta.db_table, [new_field.column], suffix="_check" + ) + self.execute( + self._create_check_sql(model, constraint_name, new_db_params["check"]) + ) # Drop the default if we need to # (Django usually does not use in-database defaults) if needs_database_default: - changes_sql, params = self._alter_column_default_sql(model, old_field, new_field, drop=True) + changes_sql, params = self._alter_column_default_sql( + model, old_field, new_field, drop=True + ) sql = self.sql_alter_column % { "table": self.quote_name(model._meta.db_table), "changes": changes_sql, @@ -883,18 +1081,23 @@ class BaseDatabaseSchemaEditor: as required by new_field, or None if no changes are required. """ if ( - self.connection.features.interprets_empty_strings_as_nulls and - new_field.empty_strings_allowed + self.connection.features.interprets_empty_strings_as_nulls + and new_field.empty_strings_allowed ): # The field is nullable in the database anyway, leave it alone. 
return else: new_db_params = new_field.db_parameters(connection=self.connection) - sql = self.sql_alter_column_null if new_field.null else self.sql_alter_column_not_null + sql = ( + self.sql_alter_column_null + if new_field.null + else self.sql_alter_column_not_null + ) return ( - sql % { - 'column': self.quote_name(new_field.column), - 'type': new_db_params['type'], + sql + % { + "column": self.quote_name(new_field.column), + "type": new_db_params["type"], }, [], ) @@ -928,10 +1131,11 @@ class BaseDatabaseSchemaEditor: else: sql = self.sql_alter_column_default return ( - sql % { - 'column': self.quote_name(new_field.column), - 'type': new_db_params['type'], - 'default': default, + sql + % { + "column": self.quote_name(new_field.column), + "type": new_db_params["type"], + "default": default, }, params, ) @@ -948,7 +1152,8 @@ class BaseDatabaseSchemaEditor: """ return ( ( - self.sql_alter_column_type % { + self.sql_alter_column_type + % { "column": self.quote_name(new_field.column), "type": new_type, }, @@ -959,10 +1164,13 @@ class BaseDatabaseSchemaEditor: def _alter_column_collation_sql(self, model, new_field, new_type, new_collation): return ( - self.sql_alter_column_collate % { - 'column': self.quote_name(new_field.column), - 'type': new_type, - 'collation': ' ' + self._collate_sql(new_collation) if new_collation else '', + self.sql_alter_column_collate + % { + "column": self.quote_name(new_field.column), + "type": new_type, + "collation": " " + self._collate_sql(new_collation) + if new_collation + else "", }, [], ) @@ -970,16 +1178,26 @@ class BaseDatabaseSchemaEditor: def _alter_many_to_many(self, model, old_field, new_field, strict): """Alter M2Ms to repoint their to= endpoints.""" # Rename the through table - if old_field.remote_field.through._meta.db_table != new_field.remote_field.through._meta.db_table: - self.alter_db_table(old_field.remote_field.through, old_field.remote_field.through._meta.db_table, - new_field.remote_field.through._meta.db_table) + if 
( + old_field.remote_field.through._meta.db_table + != new_field.remote_field.through._meta.db_table + ): + self.alter_db_table( + old_field.remote_field.through, + old_field.remote_field.through._meta.db_table, + new_field.remote_field.through._meta.db_table, + ) # Repoint the FK to the other side self.alter_field( new_field.remote_field.through, # We need the field that points to the target model, so we can tell alter_field to change it - # this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model) - old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()), - new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()), + old_field.remote_field.through._meta.get_field( + old_field.m2m_reverse_field_name() + ), + new_field.remote_field.through._meta.get_field( + new_field.m2m_reverse_field_name() + ), ) self.alter_field( new_field.remote_field.through, @@ -996,19 +1214,22 @@ class BaseDatabaseSchemaEditor: and a unique digest and suffix. """ _, table_name = split_identifier(table_name) - hash_suffix_part = '%s%s' % (names_digest(table_name, *column_names, length=8), suffix) + hash_suffix_part = "%s%s" % ( + names_digest(table_name, *column_names, length=8), + suffix, + ) max_length = self.connection.ops.max_name_length() or 200 # If everything fits into max_length, use that name. - index_name = '%s_%s_%s' % (table_name, '_'.join(column_names), hash_suffix_part) + index_name = "%s_%s_%s" % (table_name, "_".join(column_names), hash_suffix_part) if len(index_name) <= max_length: return index_name # Shorten a long suffix. 
if len(hash_suffix_part) > max_length / 3: - hash_suffix_part = hash_suffix_part[:max_length // 3] + hash_suffix_part = hash_suffix_part[: max_length // 3] other_length = (max_length - len(hash_suffix_part)) // 2 - 1 - index_name = '%s_%s_%s' % ( + index_name = "%s_%s_%s" % ( table_name[:other_length], - '_'.join(column_names)[:other_length], + "_".join(column_names)[:other_length], hash_suffix_part, ) # Prepend D if needed to prevent the name from starting with an @@ -1024,25 +1245,38 @@ class BaseDatabaseSchemaEditor: elif model._meta.db_tablespace: db_tablespace = model._meta.db_tablespace if db_tablespace is not None: - return ' ' + self.connection.ops.tablespace_sql(db_tablespace) - return '' + return " " + self.connection.ops.tablespace_sql(db_tablespace) + return "" def _index_condition_sql(self, condition): if condition: - return ' WHERE ' + condition - return '' + return " WHERE " + condition + return "" def _index_include_sql(self, model, columns): if not columns or not self.connection.features.supports_covering_indexes: - return '' + return "" return Statement( - ' INCLUDE (%(columns)s)', + " INCLUDE (%(columns)s)", columns=Columns(model._meta.db_table, columns, self.quote_name), ) - def _create_index_sql(self, model, *, fields=None, name=None, suffix='', using='', - db_tablespace=None, col_suffixes=(), sql=None, opclasses=(), - condition=None, include=None, expressions=None): + def _create_index_sql( + self, + model, + *, + fields=None, + name=None, + suffix="", + using="", + db_tablespace=None, + col_suffixes=(), + sql=None, + opclasses=(), + condition=None, + include=None, + expressions=None, + ): """ Return the SQL statement to create the index for one or several fields or expressions. 
`sql` can be specified if the syntax differs from the @@ -1053,7 +1287,9 @@ class BaseDatabaseSchemaEditor: compiler = Query(model, alias_cols=False).get_compiler( connection=self.connection, ) - tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace) + tablespace_sql = self._get_index_tablespace_sql( + model, fields, db_tablespace=db_tablespace + ) columns = [field.column for field in fields] sql_create_index = sql or self.sql_create_index table = model._meta.db_table @@ -1102,12 +1338,12 @@ class BaseDatabaseSchemaEditor: for field_names in model._meta.index_together: fields = [model._meta.get_field(field) for field in field_names] - output.append(self._create_index_sql(model, fields=fields, suffix='_idx')) + output.append(self._create_index_sql(model, fields=fields, suffix="_idx")) for index in model._meta.indexes: if ( - not index.contains_expressions or - self.connection.features.supports_expression_indexes + not index.contains_expressions + or self.connection.features.supports_expression_indexes ): output.append(index.create_sql(model, self)) return output @@ -1129,26 +1365,25 @@ class BaseDatabaseSchemaEditor: # - changing an attribute that doesn't affect the schema # - adding only a db_column and the column name is not changed non_database_attrs = [ - 'blank', - 'db_column', - 'editable', - 'error_messages', - 'help_text', - 'limit_choices_to', + "blank", + "db_column", + "editable", + "error_messages", + "help_text", + "limit_choices_to", # Database-level options are not supported, see #21961. 
- 'on_delete', - 'related_name', - 'related_query_name', - 'validators', - 'verbose_name', + "on_delete", + "related_name", + "related_query_name", + "validators", + "verbose_name", ] for attr in non_database_attrs: old_kwargs.pop(attr, None) new_kwargs.pop(attr, None) - return ( - self.quote_name(old_field.column) != self.quote_name(new_field.column) or - (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs) - ) + return self.quote_name(old_field.column) != self.quote_name( + new_field.column + ) or (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs) def _field_should_be_indexed(self, model, field): return field.db_index and not field.unique @@ -1158,9 +1393,9 @@ class BaseDatabaseSchemaEditor: def _unique_should_be_added(self, old_field, new_field): return ( - not new_field.primary_key and - new_field.unique and - (not old_field.unique or old_field.primary_key) + not new_field.primary_key + and new_field.unique + and (not old_field.unique or old_field.primary_key) ) def _rename_field_sql(self, table, old_field, new_field, new_type): @@ -1176,7 +1411,11 @@ class BaseDatabaseSchemaEditor: name = self._fk_constraint_name(model, field, suffix) column = Columns(model._meta.db_table, [field.column], self.quote_name) to_table = Table(field.target_field.model._meta.db_table, self.quote_name) - to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name) + to_column = Columns( + field.target_field.model._meta.db_table, + [field.target_field.column], + self.quote_name, + ) deferrable = self.connection.ops.deferrable_sql() return Statement( self.sql_create_fk, @@ -1206,19 +1445,26 @@ class BaseDatabaseSchemaEditor: def _deferrable_constraint_sql(self, deferrable): if deferrable is None: - return '' + return "" if deferrable == Deferrable.DEFERRED: - return ' DEFERRABLE INITIALLY DEFERRED' + return " DEFERRABLE INITIALLY DEFERRED" if deferrable == Deferrable.IMMEDIATE: - return ' DEFERRABLE 
INITIALLY IMMEDIATE' + return " DEFERRABLE INITIALLY IMMEDIATE" def _unique_sql( - self, model, fields, name, condition=None, deferrable=None, - include=None, opclasses=None, expressions=None, + self, + model, + fields, + name, + condition=None, + deferrable=None, + include=None, + opclasses=None, + expressions=None, ): if ( - deferrable and - not self.connection.features.supports_deferrable_unique_constraints + deferrable + and not self.connection.features.supports_deferrable_unique_constraints ): return None if condition or include or opclasses or expressions: @@ -1237,37 +1483,48 @@ class BaseDatabaseSchemaEditor: self.deferred_sql.append(sql) return None constraint = self.sql_unique_constraint % { - 'columns': ', '.join([self.quote_name(field.column) for field in fields]), - 'deferrable': self._deferrable_constraint_sql(deferrable), + "columns": ", ".join([self.quote_name(field.column) for field in fields]), + "deferrable": self._deferrable_constraint_sql(deferrable), } return self.sql_constraint % { - 'name': self.quote_name(name), - 'constraint': constraint, + "name": self.quote_name(name), + "constraint": constraint, } def _create_unique_sql( - self, model, fields, name=None, condition=None, deferrable=None, - include=None, opclasses=None, expressions=None, + self, + model, + fields, + name=None, + condition=None, + deferrable=None, + include=None, + opclasses=None, + expressions=None, ): if ( ( - deferrable and - not self.connection.features.supports_deferrable_unique_constraints - ) or - (condition and not self.connection.features.supports_partial_indexes) or - (include and not self.connection.features.supports_covering_indexes) or - (expressions and not self.connection.features.supports_expression_indexes) + deferrable + and not self.connection.features.supports_deferrable_unique_constraints + ) + or (condition and not self.connection.features.supports_partial_indexes) + or (include and not self.connection.features.supports_covering_indexes) + or ( + 
expressions and not self.connection.features.supports_expression_indexes + ) ): return None def create_unique_name(*args, **kwargs): return self.quote_name(self._create_index_name(*args, **kwargs)) - compiler = Query(model, alias_cols=False).get_compiler(connection=self.connection) + compiler = Query(model, alias_cols=False).get_compiler( + connection=self.connection + ) table = model._meta.db_table columns = [field.column for field in fields] if name is None: - name = IndexName(table, columns, '_uniq', create_unique_name) + name = IndexName(table, columns, "_uniq", create_unique_name) else: name = self.quote_name(name) if condition or include or opclasses or expressions: @@ -1275,7 +1532,9 @@ class BaseDatabaseSchemaEditor: else: sql = self.sql_create_unique if columns: - columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses) + columns = self._index_columns( + table, columns, col_suffixes=(), opclasses=opclasses + ) else: columns = Expressions(table, expressions, compiler, self.quote_value) return Statement( @@ -1289,18 +1548,25 @@ class BaseDatabaseSchemaEditor: ) def _delete_unique_sql( - self, model, name, condition=None, deferrable=None, include=None, - opclasses=None, expressions=None, + self, + model, + name, + condition=None, + deferrable=None, + include=None, + opclasses=None, + expressions=None, ): if ( ( - deferrable and - not self.connection.features.supports_deferrable_unique_constraints - ) or - (condition and not self.connection.features.supports_partial_indexes) or - (include and not self.connection.features.supports_covering_indexes) or - (expressions and not self.connection.features.supports_expression_indexes) - + deferrable + and not self.connection.features.supports_deferrable_unique_constraints + ) + or (condition and not self.connection.features.supports_partial_indexes) + or (include and not self.connection.features.supports_covering_indexes) + or ( + expressions and not 
self.connection.features.supports_expression_indexes + ) ): return None if condition or include or opclasses or expressions: @@ -1311,8 +1577,8 @@ class BaseDatabaseSchemaEditor: def _check_sql(self, name, check): return self.sql_constraint % { - 'name': self.quote_name(name), - 'constraint': self.sql_check_constraint % {'check': check}, + "name": self.quote_name(name), + "constraint": self.sql_check_constraint % {"check": check}, } def _create_check_sql(self, model, name, check): @@ -1333,9 +1599,18 @@ class BaseDatabaseSchemaEditor: name=self.quote_name(name), ) - def _constraint_names(self, model, column_names=None, unique=None, - primary_key=None, index=None, foreign_key=None, - check=None, type_=None, exclude=None): + def _constraint_names( + self, + model, + column_names=None, + unique=None, + primary_key=None, + index=None, + foreign_key=None, + check=None, + type_=None, + exclude=None, + ): """Return all constraint names matching the columns and conditions.""" if column_names is not None: column_names = [ @@ -1343,21 +1618,23 @@ class BaseDatabaseSchemaEditor: for name in column_names ] with self.connection.cursor() as cursor: - constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table) + constraints = self.connection.introspection.get_constraints( + cursor, model._meta.db_table + ) result = [] for name, infodict in constraints.items(): - if column_names is None or column_names == infodict['columns']: - if unique is not None and infodict['unique'] != unique: + if column_names is None or column_names == infodict["columns"]: + if unique is not None and infodict["unique"] != unique: continue - if primary_key is not None and infodict['primary_key'] != primary_key: + if primary_key is not None and infodict["primary_key"] != primary_key: continue - if index is not None and infodict['index'] != index: + if index is not None and infodict["index"] != index: continue - if check is not None and infodict['check'] != check: + if check is 
not None and infodict["check"] != check: continue - if foreign_key is not None and not infodict['foreign_key']: + if foreign_key is not None and not infodict["foreign_key"]: continue - if type_ is not None and infodict['type'] != type_: + if type_ is not None and infodict["type"] != type_: continue if not exclude or name not in exclude: result.append(name) @@ -1366,10 +1643,13 @@ class BaseDatabaseSchemaEditor: def _delete_primary_key(self, model, strict=False): constraint_names = self._constraint_names(model, primary_key=True) if strict and len(constraint_names) != 1: - raise ValueError('Found wrong number (%s) of PK constraints for %s' % ( - len(constraint_names), - model._meta.db_table, - )) + raise ValueError( + "Found wrong number (%s) of PK constraints for %s" + % ( + len(constraint_names), + model._meta.db_table, + ) + ) for constraint_name in constraint_names: self.execute(self._delete_primary_key_sql(model, constraint_name)) @@ -1378,7 +1658,9 @@ class BaseDatabaseSchemaEditor: self.sql_create_pk, table=Table(model._meta.db_table, self.quote_name), name=self.quote_name( - self._create_index_name(model._meta.db_table, [field.column], suffix="_pk") + self._create_index_name( + model._meta.db_table, [field.column], suffix="_pk" + ) ), columns=Columns(model._meta.db_table, [field.column], self.quote_name), ) @@ -1387,11 +1669,11 @@ class BaseDatabaseSchemaEditor: return self._delete_constraint_sql(self.sql_delete_pk, model, name) def _collate_sql(self, collation): - return 'COLLATE ' + self.quote_name(collation) + return "COLLATE " + self.quote_name(collation) def remove_procedure(self, procedure_name, param_types=()): sql = self.sql_delete_procedure % { - 'procedure': self.quote_name(procedure_name), - 'param_types': ','.join(param_types), + "procedure": self.quote_name(procedure_name), + "param_types": ",".join(param_types), } self.execute(sql) diff --git a/django/db/backends/base/validation.py b/django/db/backends/base/validation.py index 
a02780a694..d0e3e2157d 100644 --- a/django/db/backends/base/validation.py +++ b/django/db/backends/base/validation.py @@ -1,5 +1,6 @@ class BaseDatabaseValidation: """Encapsulate backend-specific validation.""" + def __init__(self, connection): self.connection = connection @@ -9,9 +10,12 @@ class BaseDatabaseValidation: def check_field(self, field, **kwargs): errors = [] # Backends may implement a check_field_type() method. - if (hasattr(self, 'check_field_type') and - # Ignore any related fields. - not getattr(field, 'remote_field', None)): + if ( + hasattr(self, "check_field_type") + and + # Ignore any related fields. + not getattr(field, "remote_field", None) + ): # Ignore fields with unsupported features. db_supports_all_required_features = all( getattr(self.connection.features, feature, False) diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py index f798fd648b..412d07a993 100644 --- a/django/db/backends/ddl_references.py +++ b/django/db/backends/ddl_references.py @@ -33,10 +33,12 @@ class Reference: pass def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, str(self)) + return "<%s %r>" % (self.__class__.__name__, str(self)) def __str__(self): - raise NotImplementedError('Subclasses must define how they should be converted to string.') + raise NotImplementedError( + "Subclasses must define how they should be converted to string." 
+ ) class Table(Reference): @@ -88,12 +90,14 @@ class Columns(TableColumns): try: suffix = self.col_suffixes[idx] if suffix: - col = '{} {}'.format(col, suffix) + col = "{} {}".format(col, suffix) except IndexError: pass return col - return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns)) + return ", ".join( + col_str(column, idx) for idx, column in enumerate(self.columns) + ) class IndexName(TableColumns): @@ -117,35 +121,49 @@ class IndexColumns(Columns): def col_str(column, idx): # Index.__init__() guarantees that self.opclasses is the same # length as self.columns. - col = '{} {}'.format(self.quote_name(column), self.opclasses[idx]) + col = "{} {}".format(self.quote_name(column), self.opclasses[idx]) try: suffix = self.col_suffixes[idx] if suffix: - col = '{} {}'.format(col, suffix) + col = "{} {}".format(col, suffix) except IndexError: pass return col - return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns)) + return ", ".join( + col_str(column, idx) for idx, column in enumerate(self.columns) + ) class ForeignKeyName(TableColumns): """Hold a reference to a foreign key name.""" - def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name): + def __init__( + self, + from_table, + from_columns, + to_table, + to_columns, + suffix_template, + create_fk_name, + ): self.to_reference = TableColumns(to_table, to_columns) self.suffix_template = suffix_template self.create_fk_name = create_fk_name - super().__init__(from_table, from_columns,) + super().__init__( + from_table, + from_columns, + ) def references_table(self, table): - return super().references_table(table) or self.to_reference.references_table(table) + return super().references_table(table) or self.to_reference.references_table( + table + ) def references_column(self, table, column): - return ( - super().references_column(table, column) or - self.to_reference.references_column(table, column) - ) + return 
super().references_column( + table, column + ) or self.to_reference.references_column(table, column) def rename_table_references(self, old_table, new_table): super().rename_table_references(old_table, new_table) @@ -157,8 +175,8 @@ class ForeignKeyName(TableColumns): def __str__(self): suffix = self.suffix_template % { - 'to_table': self.to_reference.table, - 'to_column': self.to_reference.columns[0], + "to_table": self.to_reference.table, + "to_column": self.to_reference.columns[0], } return self.create_fk_name(self.table, self.columns, suffix) @@ -171,30 +189,31 @@ class Statement(Reference): that might have to be adjusted if they're referencing a table or column that is removed """ + def __init__(self, template, **parts): self.template = template self.parts = parts def references_table(self, table): return any( - hasattr(part, 'references_table') and part.references_table(table) + hasattr(part, "references_table") and part.references_table(table) for part in self.parts.values() ) def references_column(self, table, column): return any( - hasattr(part, 'references_column') and part.references_column(table, column) + hasattr(part, "references_column") and part.references_column(table, column) for part in self.parts.values() ) def rename_table_references(self, old_table, new_table): for part in self.parts.values(): - if hasattr(part, 'rename_table_references'): + if hasattr(part, "rename_table_references"): part.rename_table_references(old_table, new_table) def rename_column_references(self, table, old_column, new_column): for part in self.parts.values(): - if hasattr(part, 'rename_column_references'): + if hasattr(part, "rename_column_references"): part.rename_column_references(table, old_column, new_column) def __str__(self): @@ -206,7 +225,10 @@ class Expressions(TableColumns): self.compiler = compiler self.expressions = expressions self.quote_value = quote_value - columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])] + 
columns = [ + col.target.column + for col in self.compiler.query._gen_cols([self.expressions]) + ] super().__init__(table, columns) def rename_table_references(self, old_table, new_table): diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index 06a25f2276..36c6480a78 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -17,9 +17,11 @@ from django.db.backends.dummy.features import DummyDatabaseFeatures def complain(*args, **kwargs): - raise ImproperlyConfigured("settings.DATABASES is improperly configured. " - "Please supply the ENGINE value. Check " - "settings documentation for more details.") + raise ImproperlyConfigured( + "settings.DATABASES is improperly configured. " + "Please supply the ENGINE value. Check " + "settings documentation for more details." + ) def ignore(*args, **kwargs): diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index ef2cd3a5ae..b689040f7f 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -15,7 +15,7 @@ try: import MySQLdb as Database except ImportError as err: raise ImproperlyConfigured( - 'Error loading MySQLdb module.\nDid you install mysqlclient?' + "Error loading MySQLdb module.\nDid you install mysqlclient?" ) from err from MySQLdb.constants import CLIENT, FIELD_TYPE @@ -32,7 +32,9 @@ from .validation import DatabaseValidation version = Database.version_info if version < (1, 4, 0): - raise ImproperlyConfigured('mysqlclient 1.4.0 or newer is required; you have %s.' % Database.__version__) + raise ImproperlyConfigured( + "mysqlclient 1.4.0 or newer is required; you have %s." % Database.__version__ + ) # MySQLdb returns TIME columns as timedelta -- they are more like timedelta in @@ -45,7 +47,7 @@ django_conversions = { # This should match the numerical portion of the version numbers (we can treat # versions like 5.0.24 and 5.0.24a as the same). 
-server_version_re = _lazy_re_compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') +server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})") class CursorWrapper: @@ -56,6 +58,7 @@ class CursorWrapper: Implemented as a wrapper, rather than a subclass, so that it isn't stuck to the particular underlying representation returned by Connection.cursor(). """ + codes_for_integrityerror = ( 1048, # Column cannot be null 1690, # BIGINT UNSIGNED value is out of range @@ -95,39 +98,39 @@ class CursorWrapper: class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'mysql' + vendor = "mysql" # This dictionary maps Field objects to their associated MySQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. data_types = { - 'AutoField': 'integer AUTO_INCREMENT', - 'BigAutoField': 'bigint AUTO_INCREMENT', - 'BinaryField': 'longblob', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime(6)', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'DurationField': 'bigint', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'char(15)', - 'GenericIPAddressField': 'char(39)', - 'JSONField': 'json', - 'OneToOneField': 'integer', - 'PositiveBigIntegerField': 'bigint UNSIGNED', - 'PositiveIntegerField': 'integer UNSIGNED', - 'PositiveSmallIntegerField': 'smallint UNSIGNED', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallAutoField': 'smallint AUTO_INCREMENT', - 'SmallIntegerField': 'smallint', - 'TextField': 'longtext', - 'TimeField': 'time(6)', - 'UUIDField': 'char(32)', + "AutoField": "integer AUTO_INCREMENT", + "BigAutoField": "bigint AUTO_INCREMENT", + "BinaryField": 
"longblob", + "BooleanField": "bool", + "CharField": "varchar(%(max_length)s)", + "DateField": "date", + "DateTimeField": "datetime(6)", + "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", + "DurationField": "bigint", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": "double precision", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "char(15)", + "GenericIPAddressField": "char(39)", + "JSONField": "json", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint UNSIGNED", + "PositiveIntegerField": "integer UNSIGNED", + "PositiveSmallIntegerField": "smallint UNSIGNED", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "smallint AUTO_INCREMENT", + "SmallIntegerField": "smallint", + "TextField": "longtext", + "TimeField": "time(6)", + "UUIDField": "char(32)", } # For these data types: @@ -136,23 +139,30 @@ class DatabaseWrapper(BaseDatabaseWrapper): # - all versions of MySQL and MariaDB don't support full width database # indexes _limited_data_types = ( - 'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text', - 'mediumtext', 'longtext', 'json', + "tinyblob", + "blob", + "mediumblob", + "longblob", + "tinytext", + "text", + "mediumtext", + "longtext", + "json", ) operators = { - 'exact': '= %s', - 'iexact': 'LIKE %s', - 'contains': 'LIKE BINARY %s', - 'icontains': 'LIKE %s', - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': 'LIKE BINARY %s', - 'endswith': 'LIKE BINARY %s', - 'istartswith': 'LIKE %s', - 'iendswith': 'LIKE %s', + "exact": "= %s", + "iexact": "LIKE %s", + "contains": "LIKE BINARY %s", + "icontains": "LIKE %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE BINARY %s", + "endswith": "LIKE BINARY %s", + "istartswith": "LIKE %s", + "iendswith": "LIKE %s", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -165,19 +175,19 @@ 
class DatabaseWrapper(BaseDatabaseWrapper): # the LIKE operator. pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')" pattern_ops = { - 'contains': "LIKE BINARY CONCAT('%%', {}, '%%')", - 'icontains': "LIKE CONCAT('%%', {}, '%%')", - 'startswith': "LIKE BINARY CONCAT({}, '%%')", - 'istartswith': "LIKE CONCAT({}, '%%')", - 'endswith': "LIKE BINARY CONCAT('%%', {})", - 'iendswith': "LIKE CONCAT('%%', {})", + "contains": "LIKE BINARY CONCAT('%%', {}, '%%')", + "icontains": "LIKE CONCAT('%%', {}, '%%')", + "startswith": "LIKE BINARY CONCAT({}, '%%')", + "istartswith": "LIKE CONCAT({}, '%%')", + "endswith": "LIKE BINARY CONCAT('%%', {})", + "iendswith": "LIKE CONCAT('%%', {})", } isolation_levels = { - 'read uncommitted', - 'read committed', - 'repeatable read', - 'serializable', + "read uncommitted", + "read committed", + "repeatable read", + "serializable", } Database = Database @@ -192,37 +202,39 @@ class DatabaseWrapper(BaseDatabaseWrapper): def get_connection_params(self): kwargs = { - 'conv': django_conversions, - 'charset': 'utf8', + "conv": django_conversions, + "charset": "utf8", } settings_dict = self.settings_dict - if settings_dict['USER']: - kwargs['user'] = settings_dict['USER'] - if settings_dict['NAME']: - kwargs['database'] = settings_dict['NAME'] - if settings_dict['PASSWORD']: - kwargs['password'] = settings_dict['PASSWORD'] - if settings_dict['HOST'].startswith('/'): - kwargs['unix_socket'] = settings_dict['HOST'] - elif settings_dict['HOST']: - kwargs['host'] = settings_dict['HOST'] - if settings_dict['PORT']: - kwargs['port'] = int(settings_dict['PORT']) + if settings_dict["USER"]: + kwargs["user"] = settings_dict["USER"] + if settings_dict["NAME"]: + kwargs["database"] = settings_dict["NAME"] + if settings_dict["PASSWORD"]: + kwargs["password"] = settings_dict["PASSWORD"] + if settings_dict["HOST"].startswith("/"): + kwargs["unix_socket"] = settings_dict["HOST"] + elif settings_dict["HOST"]: + kwargs["host"] = 
settings_dict["HOST"] + if settings_dict["PORT"]: + kwargs["port"] = int(settings_dict["PORT"]) # We need the number of potentially affected rows after an # "UPDATE", not the number of changed rows. - kwargs['client_flag'] = CLIENT.FOUND_ROWS + kwargs["client_flag"] = CLIENT.FOUND_ROWS # Validate the transaction isolation level, if specified. - options = settings_dict['OPTIONS'].copy() - isolation_level = options.pop('isolation_level', 'read committed') + options = settings_dict["OPTIONS"].copy() + isolation_level = options.pop("isolation_level", "read committed") if isolation_level: isolation_level = isolation_level.lower() if isolation_level not in self.isolation_levels: raise ImproperlyConfigured( "Invalid transaction isolation level '%s' specified.\n" - "Use one of %s, or None." % ( + "Use one of %s, or None." + % ( isolation_level, - ', '.join("'%s'" % s for s in sorted(self.isolation_levels)) - )) + ", ".join("'%s'" % s for s in sorted(self.isolation_levels)), + ) + ) self.isolation_level = isolation_level kwargs.update(options) return kwargs @@ -245,14 +257,17 @@ class DatabaseWrapper(BaseDatabaseWrapper): # a recently inserted row will return when the field is tested # for NULL. Disabling this brings this aspect of MySQL in line # with SQL standards. - assignments.append('SET SQL_AUTO_IS_NULL = 0') + assignments.append("SET SQL_AUTO_IS_NULL = 0") if self.isolation_level: - assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper()) + assignments.append( + "SET SESSION TRANSACTION ISOLATION LEVEL %s" + % self.isolation_level.upper() + ) if assignments: with self.cursor() as cursor: - cursor.execute('; '.join(assignments)) + cursor.execute("; ".join(assignments)) @async_unsafe def create_cursor(self, name=None): @@ -276,7 +291,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): need to be re-enabled. 
""" with self.cursor() as cursor: - cursor.execute('SET foreign_key_checks=0') + cursor.execute("SET foreign_key_checks=0") return True def enable_constraint_checking(self): @@ -288,7 +303,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.needs_rollback, needs_rollback = False, self.needs_rollback try: with self.cursor() as cursor: - cursor.execute('SET foreign_key_checks=1') + cursor.execute("SET foreign_key_checks=1") finally: self.needs_rollback = needs_rollback @@ -304,21 +319,32 @@ class DatabaseWrapper(BaseDatabaseWrapper): if table_names is None: table_names = self.introspection.table_names(cursor) for table_name in table_names: - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + primary_key_column_name = self.introspection.get_primary_key_column( + cursor, table_name + ) if not primary_key_column_name: continue relations = self.introspection.get_relations(cursor, table_name) - for column_name, (referenced_column_name, referenced_table_name) in relations.items(): + for column_name, ( + referenced_column_name, + referenced_table_name, + ) in relations.items(): cursor.execute( """ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING LEFT JOIN `%s` as REFERRED ON (REFERRING.`%s` = REFERRED.`%s`) WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL - """ % ( - primary_key_column_name, column_name, table_name, - referenced_table_name, column_name, referenced_column_name, - column_name, referenced_column_name, + """ + % ( + primary_key_column_name, + column_name, + table_name, + referenced_table_name, + column_name, + referenced_column_name, + column_name, + referenced_column_name, ) ) for bad_row in cursor.fetchall(): @@ -327,8 +353,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): "foreign key: %s.%s contains a value '%s' that does not " "have a corresponding value in %s.%s." 
% ( - table_name, bad_row[0], table_name, column_name, - bad_row[1], referenced_table_name, referenced_column_name, + table_name, + bad_row[0], + table_name, + column_name, + bad_row[1], + referenced_table_name, + referenced_column_name, ) ) @@ -342,20 +373,20 @@ class DatabaseWrapper(BaseDatabaseWrapper): @cached_property def display_name(self): - return 'MariaDB' if self.mysql_is_mariadb else 'MySQL' + return "MariaDB" if self.mysql_is_mariadb else "MySQL" @cached_property def data_type_check_constraints(self): if self.features.supports_column_check_constraints: check_constraints = { - 'PositiveBigIntegerField': '`%(column)s` >= 0', - 'PositiveIntegerField': '`%(column)s` >= 0', - 'PositiveSmallIntegerField': '`%(column)s` >= 0', + "PositiveBigIntegerField": "`%(column)s` >= 0", + "PositiveIntegerField": "`%(column)s` >= 0", + "PositiveSmallIntegerField": "`%(column)s` >= 0", } if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3): # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as # a check constraint. - check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)' + check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)" return check_constraints return {} @@ -365,40 +396,45 @@ class DatabaseWrapper(BaseDatabaseWrapper): # Select some server variables and test if the time zone # definitions are installed. CONVERT_TZ returns NULL if 'UTC' # timezone isn't loaded into the mysql.time_zone table. 
- cursor.execute(""" + cursor.execute( + """ SELECT VERSION(), @@sql_mode, @@default_storage_engine, @@sql_auto_is_null, @@lower_case_table_names, CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL - """) + """ + ) row = cursor.fetchone() return { - 'version': row[0], - 'sql_mode': row[1], - 'default_storage_engine': row[2], - 'sql_auto_is_null': bool(row[3]), - 'lower_case_table_names': bool(row[4]), - 'has_zoneinfo_database': bool(row[5]), + "version": row[0], + "sql_mode": row[1], + "default_storage_engine": row[2], + "sql_auto_is_null": bool(row[3]), + "lower_case_table_names": bool(row[4]), + "has_zoneinfo_database": bool(row[5]), } @cached_property def mysql_server_info(self): - return self.mysql_server_data['version'] + return self.mysql_server_data["version"] @cached_property def mysql_version(self): match = server_version_re.match(self.mysql_server_info) if not match: - raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info) + raise Exception( + "Unable to determine MySQL version from version string %r" + % self.mysql_server_info + ) return tuple(int(x) for x in match.groups()) @cached_property def mysql_is_mariadb(self): - return 'mariadb' in self.mysql_server_info.lower() + return "mariadb" in self.mysql_server_info.lower() @cached_property def sql_mode(self): - sql_mode = self.mysql_server_data['sql_mode'] - return set(sql_mode.split(',') if sql_mode else ()) + sql_mode = self.mysql_server_data["sql_mode"] + return set(sql_mode.split(",") if sql_mode else ()) diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index 7cbe314afe..0c09a2ca1e 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -2,28 +2,28 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'mysql' + executable_name = "mysql" @classmethod def settings_to_cmd_args_env(cls, settings_dict, 
parameters): args = [cls.executable_name] env = None - database = settings_dict['OPTIONS'].get( - 'database', - settings_dict['OPTIONS'].get('db', settings_dict['NAME']), + database = settings_dict["OPTIONS"].get( + "database", + settings_dict["OPTIONS"].get("db", settings_dict["NAME"]), ) - user = settings_dict['OPTIONS'].get('user', settings_dict['USER']) - password = settings_dict['OPTIONS'].get( - 'password', - settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD']) + user = settings_dict["OPTIONS"].get("user", settings_dict["USER"]) + password = settings_dict["OPTIONS"].get( + "password", + settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]), ) - host = settings_dict['OPTIONS'].get('host', settings_dict['HOST']) - port = settings_dict['OPTIONS'].get('port', settings_dict['PORT']) - server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca') - client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert') - client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key') - defaults_file = settings_dict['OPTIONS'].get('read_default_file') - charset = settings_dict['OPTIONS'].get('charset') + host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"]) + port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"]) + server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca") + client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert") + client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key") + defaults_file = settings_dict["OPTIONS"].get("read_default_file") + charset = settings_dict["OPTIONS"].get("charset") # Seems to be no good way to set sql_mode with CLI. if defaults_file: @@ -38,9 +38,9 @@ class DatabaseClient(BaseDatabaseClient): # prevents password exposure if the subprocess.run(check=True) call # raises a CalledProcessError since the string representation of # the latter includes all of the provided `args`. 
- env = {'MYSQL_PWD': password} + env = {"MYSQL_PWD": password} if host: - if '/' in host: + if "/" in host: args += ["--socket=%s" % host] else: args += ["--host=%s" % host] @@ -53,7 +53,7 @@ class DatabaseClient(BaseDatabaseClient): if client_key: args += ["--ssl-key=%s" % client_key] if charset: - args += ['--default-character-set=%s' % charset] + args += ["--default-character-set=%s" % charset] if database: args += [database] args.extend(parameters) diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index 49b47961a1..a8ab03a55e 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -8,7 +8,14 @@ class SQLCompiler(compiler.SQLCompiler): qn = compiler.quote_name_unless_alias qn2 = self.connection.ops.quote_name sql, params = self.as_sql() - return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params + return ( + "(%s) IN (%s)" + % ( + ", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns), + sql, + ), + params, + ) class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): @@ -27,16 +34,15 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): # since it doesn't allow for GROUP BY and HAVING clauses. 
return super().as_sql() result = [ - 'DELETE %s FROM' % self.quote_name_unless_alias( - self.query.get_initial_alias() - ) + "DELETE %s FROM" + % self.quote_name_unless_alias(self.query.get_initial_alias()) ] from_sql, from_params = self.get_from_clause() result.extend(from_sql) where_sql, where_params = self.compile(where) if where_sql: - result.append('WHERE %s' % where_sql) - return ' '.join(result), tuple(from_params) + tuple(where_params) + result.append("WHERE %s" % where_sql) + return " ".join(result), tuple(from_params) + tuple(where_params) class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): @@ -50,15 +56,15 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): try: for resolved, (sql, params, _) in self.get_order_by(): if ( - isinstance(resolved.expression, Col) and - resolved.expression.alias != db_table + isinstance(resolved.expression, Col) + and resolved.expression.alias != db_table ): # Ignore ordering if it contains joined fields, because # they cannot be used in the ORDER BY clause. 
raise FieldError order_by_sql.append(sql) order_by_params.extend(params) - update_query += ' ORDER BY ' + ', '.join(order_by_sql) + update_query += " ORDER BY " + ", ".join(order_by_sql) update_params += tuple(order_by_params) except FieldError: # Ignore ordering if it contains annotations, because they're diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py index 1f0261b667..a060f41d18 100644 --- a/django/db/backends/mysql/creation.py +++ b/django/db/backends/mysql/creation.py @@ -8,15 +8,14 @@ from .client import DatabaseClient class DatabaseCreation(BaseDatabaseCreation): - def sql_table_creation_suffix(self): suffix = [] - test_settings = self.connection.settings_dict['TEST'] - if test_settings['CHARSET']: - suffix.append('CHARACTER SET %s' % test_settings['CHARSET']) - if test_settings['COLLATION']: - suffix.append('COLLATE %s' % test_settings['COLLATION']) - return ' '.join(suffix) + test_settings = self.connection.settings_dict["TEST"] + if test_settings["CHARSET"]: + suffix.append("CHARACTER SET %s" % test_settings["CHARSET"]) + if test_settings["COLLATION"]: + suffix.append("COLLATE %s" % test_settings["COLLATION"]) + return " ".join(suffix) def _execute_create_test_db(self, cursor, parameters, keepdb=False): try: @@ -24,17 +23,17 @@ class DatabaseCreation(BaseDatabaseCreation): except Exception as e: if len(e.args) < 1 or e.args[0] != 1007: # All errors except "database exists" (1007) cancel tests. 
- self.log('Got an error creating the test database: %s' % e) + self.log("Got an error creating the test database: %s" % e) sys.exit(2) else: raise def _clone_test_db(self, suffix, verbosity, keepdb=False): - source_database_name = self.connection.settings_dict['NAME'] - target_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] test_db_params = { - 'dbname': self.connection.ops.quote_name(target_database_name), - 'suffix': self.sql_table_creation_suffix(), + "dbname": self.connection.ops.quote_name(target_database_name), + "suffix": self.sql_table_creation_suffix(), } with self._nodb_cursor() as cursor: try: @@ -45,24 +44,44 @@ class DatabaseCreation(BaseDatabaseCreation): return try: if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, target_database_name), - )) - cursor.execute('DROP DATABASE %(dbname)s' % test_db_params) + self.log( + "Destroying old test database for alias %s..." 
+ % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) + cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: - self.log('Got an error recreating the test database: %s' % e) + self.log("Got an error recreating the test database: %s" % e) sys.exit(2) self._clone_db(source_database_name, target_database_name) def _clone_db(self, source_database_name, target_database_name): - cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, []) - dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name] + cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env( + self.connection.settings_dict, [] + ) + dump_cmd = [ + "mysqldump", + *cmd_args[1:-1], + "--routines", + "--events", + source_database_name, + ] dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None load_cmd = cmd_args load_cmd[-1] = target_database_name - with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc: - with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env): + with subprocess.Popen( + dump_cmd, stdout=subprocess.PIPE, env=dump_env + ) as dump_proc: + with subprocess.Popen( + load_cmd, + stdin=dump_proc.stdout, + stdout=subprocess.DEVNULL, + env=load_env, + ): # Allow dump_proc to receive a SIGPIPE if the load process exits. 
dump_proc.stdout.close() diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 5d6c4afde0..d485d40d60 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -50,87 +50,104 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def test_collations(self): - charset = 'utf8' - if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 6): + charset = "utf8" + if self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( + 10, + 6, + ): # utf8 is an alias for utf8mb3 in MariaDB 10.6+. - charset = 'utf8mb3' + charset = "utf8mb3" return { - 'ci': f'{charset}_general_ci', - 'non_default': f'{charset}_esperanto_ci', - 'swedish_ci': f'{charset}_swedish_ci', + "ci": f"{charset}_general_ci", + "non_default": f"{charset}_esperanto_ci", + "swedish_ci": f"{charset}_swedish_ci", } - test_now_utc_template = 'UTC_TIMESTAMP' + test_now_utc_template = "UTC_TIMESTAMP" @cached_property def django_test_skips(self): skips = { "This doesn't work on MySQL.": { - 'db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround', - 'db_functions.comparison.test_least.LeastTests.test_coalesce_workaround', + "db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround", + "db_functions.comparison.test_least.LeastTests.test_coalesce_workaround", }, - 'Running on MySQL requires utf8mb4 encoding (#18392).': { - 'model_fields.test_textfield.TextFieldTests.test_emoji', - 'model_fields.test_charfield.TestCharField.test_emoji', + "Running on MySQL requires utf8mb4 encoding (#18392).": { + "model_fields.test_textfield.TextFieldTests.test_emoji", + "model_fields.test_charfield.TestCharField.test_emoji", }, "MySQL doesn't support functional indexes on a function that " "returns JSON": { - 'schema.tests.SchemaTests.test_func_index_json_key_transform', + "schema.tests.SchemaTests.test_func_index_json_key_transform", }, "MySQL supports multiplying and 
dividing DurationFields by a " "scalar value but it's not implemented (#25287).": { - 'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide', + "expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide", }, } - if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode: - skips.update({ - 'GROUP BY optimization does not work properly when ' - 'ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.': { - 'aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued', - 'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o', - }, - }) - if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,): - skips.update({ - 'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': { - 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python', - 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python', - }, - 'MySQL < 8.0 returns string type instead of datetime/time. 
(#30224)': { - 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database', - 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database', - }, - }) - if ( - self.connection.mysql_is_mariadb and - (10, 4, 3) < self.connection.mysql_version < (10, 5, 2) + if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode: + skips.update( + { + "GROUP BY optimization does not work properly when " + "ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.": { + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o", + }, + } + ) + if not self.connection.mysql_is_mariadb and self.connection.mysql_version < ( + 8, ): - skips.update({ - 'https://jira.mariadb.org/browse/MDEV-19598': { - 'schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key', - }, - }) - if ( - self.connection.mysql_is_mariadb and - (10, 4, 12) < self.connection.mysql_version < (10, 5) - ): - skips.update({ - 'https://jira.mariadb.org/browse/MDEV-22775': { - 'schema.tests.SchemaTests.test_alter_pk_with_self_referential_field', - }, - }) + skips.update( + { + "Casting to datetime/time is not supported by MySQL < 8.0. (#30224)": { + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python", + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python", + }, + "MySQL < 8.0 returns string type instead of datetime/time. 
(#30224)": { + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database", + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database", + }, + } + ) + if self.connection.mysql_is_mariadb and ( + 10, + 4, + 3, + ) < self.connection.mysql_version < (10, 5, 2): + skips.update( + { + "https://jira.mariadb.org/browse/MDEV-19598": { + "schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key", + }, + } + ) + if self.connection.mysql_is_mariadb and ( + 10, + 4, + 12, + ) < self.connection.mysql_version < (10, 5): + skips.update( + { + "https://jira.mariadb.org/browse/MDEV-22775": { + "schema.tests.SchemaTests.test_alter_pk_with_self_referential_field", + }, + } + ) if not self.supports_explain_analyze: - skips.update({ - 'MariaDB and MySQL >= 8.0.18 specific.': { - 'queries.test_explain.ExplainTests.test_mysql_analyze', - }, - }) + skips.update( + { + "MariaDB and MySQL >= 8.0.18 specific.": { + "queries.test_explain.ExplainTests.test_mysql_analyze", + }, + } + ) return skips @cached_property def _mysql_storage_engine(self): "Internal method used in Django tests. Don't rely on this from your code" - return self.connection.mysql_server_data['default_storage_engine'] + return self.connection.mysql_server_data["default_storage_engine"] @cached_property def allows_auto_pk_0(self): @@ -138,40 +155,50 @@ class DatabaseFeatures(BaseDatabaseFeatures): Autoincrement primary key can be set to 0 if it doesn't generate new autoincrement values. 
""" - return 'NO_AUTO_VALUE_ON_ZERO' in self.connection.sql_mode + return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode @cached_property def update_can_self_select(self): - return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 3, 2) + return self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( + 10, + 3, + 2, + ) @cached_property def can_introspect_foreign_keys(self): "Confirm support for introspected foreign keys" - return self._mysql_storage_engine != 'MyISAM' + return self._mysql_storage_engine != "MyISAM" @cached_property def introspected_field_types(self): return { **super().introspected_field_types, - 'BinaryField': 'TextField', - 'BooleanField': 'IntegerField', - 'DurationField': 'BigIntegerField', - 'GenericIPAddressField': 'CharField', + "BinaryField": "TextField", + "BooleanField": "IntegerField", + "DurationField": "BigIntegerField", + "GenericIPAddressField": "CharField", } @cached_property def can_return_columns_from_insert(self): - return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 5, 0) + return self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( + 10, + 5, + 0, + ) - can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert')) + can_return_rows_from_bulk_insert = property( + operator.attrgetter("can_return_columns_from_insert") + ) @cached_property def has_zoneinfo_database(self): - return self.connection.mysql_server_data['has_zoneinfo_database'] + return self.connection.mysql_server_data["has_zoneinfo_database"] @cached_property def is_sql_auto_is_null_enabled(self): - return self.connection.mysql_server_data['sql_auto_is_null'] + return self.connection.mysql_server_data["sql_auto_is_null"] @cached_property def supports_over_clause(self): @@ -179,7 +206,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): return True return self.connection.mysql_version >= (8, 0, 2) - supports_frame_range_fixed_distance = 
property(operator.attrgetter('supports_over_clause')) + supports_frame_range_fixed_distance = property( + operator.attrgetter("supports_over_clause") + ) @cached_property def supports_column_check_constraints(self): @@ -187,7 +216,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): return True return self.connection.mysql_version >= (8, 0, 16) - supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints')) + supports_table_check_constraints = property( + operator.attrgetter("supports_column_check_constraints") + ) @cached_property def can_introspect_check_constraints(self): @@ -210,19 +241,30 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def has_select_for_update_of(self): - return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1) + return ( + not self.connection.mysql_is_mariadb + and self.connection.mysql_version >= (8, 0, 1) + ) @cached_property def supports_explain_analyze(self): - return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (8, 0, 18) + return self.connection.mysql_is_mariadb or self.connection.mysql_version >= ( + 8, + 0, + 18, + ) @cached_property def supported_explain_formats(self): # Alias MySQL's TRADITIONAL to TEXT for consistency with other # backends. - formats = {'JSON', 'TEXT', 'TRADITIONAL'} - if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 16): - formats.add('TREE') + formats = {"JSON", "TEXT", "TRADITIONAL"} + if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= ( + 8, + 0, + 16, + ): + formats.add("TREE") return formats @cached_property @@ -230,11 +272,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): """ All storage engines except MyISAM support transactions. 
""" - return self._mysql_storage_engine != 'MyISAM' + return self._mysql_storage_engine != "MyISAM" @cached_property def ignores_table_name_case(self): - return self.connection.mysql_server_data['lower_case_table_names'] + return self.connection.mysql_server_data["lower_case_table_names"] @cached_property def supports_default_in_lead_lag(self): @@ -256,13 +298,13 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def supports_index_column_ordering(self): return ( - not self.connection.mysql_is_mariadb and - self.connection.mysql_version >= (8, 0, 1) + not self.connection.mysql_is_mariadb + and self.connection.mysql_version >= (8, 0, 1) ) @cached_property def supports_expression_indexes(self): return ( - not self.connection.mysql_is_mariadb and - self.connection.mysql_version >= (8, 0, 13) + not self.connection.mysql_is_mariadb + and self.connection.mysql_version >= (8, 0, 13) ) diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 3a76168227..3cf56dffce 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -3,72 +3,76 @@ from collections import namedtuple import sqlparse from MySQLdb.constants import FIELD_TYPE -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo from django.db.models import Index from django.utils.datastructures import OrderedSet -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint')) +FieldInfo = namedtuple( + "FieldInfo", BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint") +) InfoLine = namedtuple( - 'InfoLine', - 'col_name data_type max_len num_prec num_scale extra 
column_default ' - 'collation is_unsigned' + "InfoLine", + "col_name data_type max_len num_prec num_scale extra column_default " + "collation is_unsigned", ) class DatabaseIntrospection(BaseDatabaseIntrospection): data_types_reverse = { - FIELD_TYPE.BLOB: 'TextField', - FIELD_TYPE.CHAR: 'CharField', - FIELD_TYPE.DECIMAL: 'DecimalField', - FIELD_TYPE.NEWDECIMAL: 'DecimalField', - FIELD_TYPE.DATE: 'DateField', - FIELD_TYPE.DATETIME: 'DateTimeField', - FIELD_TYPE.DOUBLE: 'FloatField', - FIELD_TYPE.FLOAT: 'FloatField', - FIELD_TYPE.INT24: 'IntegerField', - FIELD_TYPE.JSON: 'JSONField', - FIELD_TYPE.LONG: 'IntegerField', - FIELD_TYPE.LONGLONG: 'BigIntegerField', - FIELD_TYPE.SHORT: 'SmallIntegerField', - FIELD_TYPE.STRING: 'CharField', - FIELD_TYPE.TIME: 'TimeField', - FIELD_TYPE.TIMESTAMP: 'DateTimeField', - FIELD_TYPE.TINY: 'IntegerField', - FIELD_TYPE.TINY_BLOB: 'TextField', - FIELD_TYPE.MEDIUM_BLOB: 'TextField', - FIELD_TYPE.LONG_BLOB: 'TextField', - FIELD_TYPE.VAR_STRING: 'CharField', + FIELD_TYPE.BLOB: "TextField", + FIELD_TYPE.CHAR: "CharField", + FIELD_TYPE.DECIMAL: "DecimalField", + FIELD_TYPE.NEWDECIMAL: "DecimalField", + FIELD_TYPE.DATE: "DateField", + FIELD_TYPE.DATETIME: "DateTimeField", + FIELD_TYPE.DOUBLE: "FloatField", + FIELD_TYPE.FLOAT: "FloatField", + FIELD_TYPE.INT24: "IntegerField", + FIELD_TYPE.JSON: "JSONField", + FIELD_TYPE.LONG: "IntegerField", + FIELD_TYPE.LONGLONG: "BigIntegerField", + FIELD_TYPE.SHORT: "SmallIntegerField", + FIELD_TYPE.STRING: "CharField", + FIELD_TYPE.TIME: "TimeField", + FIELD_TYPE.TIMESTAMP: "DateTimeField", + FIELD_TYPE.TINY: "IntegerField", + FIELD_TYPE.TINY_BLOB: "TextField", + FIELD_TYPE.MEDIUM_BLOB: "TextField", + FIELD_TYPE.LONG_BLOB: "TextField", + FIELD_TYPE.VAR_STRING: "CharField", } def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if 'auto_increment' in description.extra: - if field_type == 'IntegerField': - return 'AutoField' - elif field_type == 
'BigIntegerField': - return 'BigAutoField' - elif field_type == 'SmallIntegerField': - return 'SmallAutoField' + if "auto_increment" in description.extra: + if field_type == "IntegerField": + return "AutoField" + elif field_type == "BigIntegerField": + return "BigAutoField" + elif field_type == "SmallIntegerField": + return "SmallAutoField" if description.is_unsigned: - if field_type == 'BigIntegerField': - return 'PositiveBigIntegerField' - elif field_type == 'IntegerField': - return 'PositiveIntegerField' - elif field_type == 'SmallIntegerField': - return 'PositiveSmallIntegerField' + if field_type == "BigIntegerField": + return "PositiveBigIntegerField" + elif field_type == "IntegerField": + return "PositiveIntegerField" + elif field_type == "SmallIntegerField": + return "PositiveSmallIntegerField" # JSON data type is an alias for LONGTEXT in MariaDB, use check # constraints clauses to introspect JSONField. if description.has_json_constraint: - return 'JSONField' + return "JSONField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" cursor.execute("SHOW FULL TABLES") - return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1])) - for row in cursor.fetchall()] + return [ + TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1])) + for row in cursor.fetchall() + ] def get_table_description(self, cursor, table_name): """ @@ -76,7 +80,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): interface." """ json_constraints = {} - if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field: + if ( + self.connection.mysql_is_mariadb + and self.connection.features.can_introspect_json_field + ): # JSON data type is an alias for LONGTEXT in MariaDB, select # JSON_VALID() constraints to introspect JSONField. 
cursor.execute( @@ -102,7 +109,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): [table_name], ) row = cursor.fetchone() - default_column_collation = row[0] if row else '' + default_column_collation = row[0] if row else "" # information_schema database gives more accurate results for some figures: # - varchar length returned by cursor.description is an internal length, # not visible length (#5725) @@ -128,7 +135,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ) field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()} - cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name) + ) def to_int(i): return int(i) if i is not None else i @@ -136,25 +145,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): fields = [] for line in cursor.description: info = field_info[line[0]] - fields.append(FieldInfo( - *line[:3], - to_int(info.max_len) or line[3], - to_int(info.num_prec) or line[4], - to_int(info.num_scale) or line[5], - line[6], - info.column_default, - info.collation, - info.extra, - info.is_unsigned, - line[0] in json_constraints, - )) + fields.append( + FieldInfo( + *line[:3], + to_int(info.max_len) or line[3], + to_int(info.num_prec) or line[4], + to_int(info.num_scale) or line[5], + line[6], + info.column_default, + info.collation, + info.extra, + info.is_unsigned, + line[0] in json_constraints, + ) + ) return fields def get_sequences(self, cursor, table_name, table_fields=()): for field_info in self.get_table_description(cursor, table_name): - if 'auto_increment' in field_info.extra: + if "auto_increment" in field_info.extra: # MySQL allows only one auto-increment column per table. 
- return [{'table': table_name, 'column': field_info.name}] + return [{"table": table_name, "column": field_info.name}] return [] def get_relations(self, cursor, table_name): @@ -204,9 +215,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): tokens = (token for token in statement.flatten() if not token.is_whitespace) for token in tokens: if ( - token.ttype == sqlparse.tokens.Name and - self.connection.ops.quote_name(token.value) == token.value and - token.value[1:-1] in columns + token.ttype == sqlparse.tokens.Name + and self.connection.ops.quote_name(token.value) == token.value + and token.value[1:-1] in columns ): check_columns.add(token.value[1:-1]) return check_columns @@ -237,20 +248,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): for constraint, column, ref_table, ref_column, kind in cursor.fetchall(): if constraint not in constraints: constraints[constraint] = { - 'columns': OrderedSet(), - 'primary_key': kind == 'PRIMARY KEY', - 'unique': kind in {'PRIMARY KEY', 'UNIQUE'}, - 'index': False, - 'check': False, - 'foreign_key': (ref_table, ref_column) if ref_column else None, + "columns": OrderedSet(), + "primary_key": kind == "PRIMARY KEY", + "unique": kind in {"PRIMARY KEY", "UNIQUE"}, + "index": False, + "check": False, + "foreign_key": (ref_table, ref_column) if ref_column else None, } if self.connection.features.supports_index_column_ordering: - constraints[constraint]['orders'] = [] - constraints[constraint]['columns'].add(column) + constraints[constraint]["orders"] = [] + constraints[constraint]["columns"].add(column) # Add check constraints. 
if self.connection.features.can_introspect_check_constraints: unnamed_constraints_index = 0 - columns = {info.name for info in self.get_table_description(cursor, table_name)} + columns = { + info.name for info in self.get_table_description(cursor, table_name) + } if self.connection.mysql_is_mariadb: type_query = """ SELECT c.constraint_name, c.check_clause @@ -274,42 +287,48 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): """ cursor.execute(type_query, [table_name]) for constraint, check_clause in cursor.fetchall(): - constraint_columns = self._parse_constraint_columns(check_clause, columns) + constraint_columns = self._parse_constraint_columns( + check_clause, columns + ) # Ensure uniqueness of unnamed constraints. Unnamed unique # and check columns constraints have the same name as # a column. if set(constraint_columns) == {constraint}: unnamed_constraints_index += 1 - constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index + constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index constraints[constraint] = { - 'columns': constraint_columns, - 'primary_key': False, - 'unique': False, - 'index': False, - 'check': True, - 'foreign_key': None, + "columns": constraint_columns, + "primary_key": False, + "unique": False, + "index": False, + "check": True, + "foreign_key": None, } # Now add in the indexes - cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name) + ) for table, non_unique, index, colseq, column, order, type_ in [ x[:6] + (x[10],) for x in cursor.fetchall() ]: if index not in constraints: constraints[index] = { - 'columns': OrderedSet(), - 'primary_key': False, - 'unique': not non_unique, - 'check': False, - 'foreign_key': None, + "columns": OrderedSet(), + "primary_key": False, + "unique": not non_unique, + "check": False, + "foreign_key": None, } if self.connection.features.supports_index_column_ordering: - 
constraints[index]['orders'] = [] - constraints[index]['index'] = True - constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower() - constraints[index]['columns'].add(column) + constraints[index]["orders"] = [] + constraints[index]["index"] = True + constraints[index]["type"] = ( + Index.suffix if type_ == "BTREE" else type_.lower() + ) + constraints[index]["columns"].add(column) if self.connection.features.supports_index_column_ordering: - constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC') + constraints[index]["orders"].append("DESC" if order == "D" else "ASC") # Convert the sorted sets to lists for constraint in constraints.values(): - constraint['columns'] = list(constraint['columns']) + constraint["columns"] = list(constraint["columns"]) return constraints diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 7f1994e657..5bcc03a67b 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -15,42 +15,42 @@ class DatabaseOperations(BaseDatabaseOperations): # MySQL stores positive fields as UNSIGNED ints. 
integer_field_ranges = { **BaseDatabaseOperations.integer_field_ranges, - 'PositiveSmallIntegerField': (0, 65535), - 'PositiveIntegerField': (0, 4294967295), - 'PositiveBigIntegerField': (0, 18446744073709551615), + "PositiveSmallIntegerField": (0, 65535), + "PositiveIntegerField": (0, 4294967295), + "PositiveBigIntegerField": (0, 18446744073709551615), } cast_data_types = { - 'AutoField': 'signed integer', - 'BigAutoField': 'signed integer', - 'SmallAutoField': 'signed integer', - 'CharField': 'char(%(max_length)s)', - 'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)', - 'TextField': 'char', - 'IntegerField': 'signed integer', - 'BigIntegerField': 'signed integer', - 'SmallIntegerField': 'signed integer', - 'PositiveBigIntegerField': 'unsigned integer', - 'PositiveIntegerField': 'unsigned integer', - 'PositiveSmallIntegerField': 'unsigned integer', - 'DurationField': 'signed integer', + "AutoField": "signed integer", + "BigAutoField": "signed integer", + "SmallAutoField": "signed integer", + "CharField": "char(%(max_length)s)", + "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)", + "TextField": "char", + "IntegerField": "signed integer", + "BigIntegerField": "signed integer", + "SmallIntegerField": "signed integer", + "PositiveBigIntegerField": "unsigned integer", + "PositiveIntegerField": "unsigned integer", + "PositiveSmallIntegerField": "unsigned integer", + "DurationField": "signed integer", } - cast_char_field_without_max_length = 'char' - explain_prefix = 'EXPLAIN' + cast_char_field_without_max_length = "char" + explain_prefix = "EXPLAIN" def date_extract_sql(self, lookup_type, field_name): # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html - if lookup_type == 'week_day': + if lookup_type == "week_day": # DAYOFWEEK() returns an integer, 1-7, Sunday=1. return "DAYOFWEEK(%s)" % field_name - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": # WEEKDAY() returns an integer, 0-6, Monday=0. 
return "WEEKDAY(%s) + 1" % field_name - elif lookup_type == 'week': + elif lookup_type == "week": # Override the value of default_week_format for consistency with # other database backends. # Mode 3: Monday, 1-53, with 4 or more days this year. return "WEEK(%s, 3)" % field_name - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": # Get the year part from the YEARWEEK function, which returns a # number as year * 100 + week. return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name @@ -61,26 +61,25 @@ class DatabaseOperations(BaseDatabaseOperations): def date_trunc_sql(self, lookup_type, field_name, tzname=None): field_name = self._convert_field_to_tz(field_name, tzname) fields = { - 'year': '%%Y-01-01', - 'month': '%%Y-%%m-01', + "year": "%%Y-01-01", + "month": "%%Y-%%m-01", } # Use double percents to escape. if lookup_type in fields: format_str = fields[lookup_type] return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str) - elif lookup_type == 'quarter': - return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % ( - field_name, field_name - ) - elif lookup_type == 'week': - return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % ( - field_name, field_name + elif lookup_type == "quarter": + return ( + "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" + % (field_name, field_name) ) + elif lookup_type == "week": + return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name) else: return "DATE(%s)" % (field_name) def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) - return f'{sign}{offset}' if offset else tzname + return f"{sign}{offset}" if offset else tzname def _convert_field_to_tz(self, field_name, tzname): if tzname and settings.USE_TZ and self.connection.timezone_name != tzname: @@ -105,16 +104,23 @@ class DatabaseOperations(BaseDatabaseOperations): def datetime_trunc_sql(self, lookup_type, field_name, tzname): field_name = 
self._convert_field_to_tz(field_name, tzname) - fields = ['year', 'month', 'day', 'hour', 'minute', 'second'] - format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape. - format_def = ('0000-', '01', '-01', ' 00:', '00', ':00') - if lookup_type == 'quarter': + fields = ["year", "month", "day", "hour", "minute", "second"] + format = ( + "%%Y-", + "%%m", + "-%%d", + " %%H:", + "%%i", + ":%%s", + ) # Use double percents to escape. + format_def = ("0000-", "01", "-01", " 00:", "00", ":00") + if lookup_type == "quarter": return ( "CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + " - "INTERVAL QUARTER({field_name}) QUARTER - " + - "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)" + "INTERVAL QUARTER({field_name}) QUARTER - " + + "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)" ).format(field_name=field_name) - if lookup_type == 'week': + if lookup_type == "week": return ( "CAST(DATE_FORMAT(DATE_SUB({field_name}, " "INTERVAL WEEKDAY({field_name}) DAY), " @@ -125,16 +131,16 @@ class DatabaseOperations(BaseDatabaseOperations): except ValueError: sql = field_name else: - format_str = ''.join(format[:i] + format_def[i:]) + format_str = "".join(format[:i] + format_def[i:]) sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) return sql def time_trunc_sql(self, lookup_type, field_name, tzname=None): field_name = self._convert_field_to_tz(field_name, tzname) fields = { - 'hour': '%%H:00:00', - 'minute': '%%H:%%i:00', - 'second': '%%H:%%i:%%s', + "hour": "%%H:00:00", + "minute": "%%H:%%i:00", + "second": "%%H:%%i:%%s", } # Use double percents to escape. 
if lookup_type in fields: format_str = fields[lookup_type] @@ -150,7 +156,7 @@ class DatabaseOperations(BaseDatabaseOperations): return cursor.fetchall() def format_for_duration_arithmetic(self, sql): - return 'INTERVAL %s MICROSECOND' % sql + return "INTERVAL %s MICROSECOND" % sql def force_no_ordering(self): """ @@ -168,7 +174,7 @@ class DatabaseOperations(BaseDatabaseOperations): # attribute where the exact query sent to the database is saved. # See MySQLdb/cursors.py in the source distribution. # MySQLdb returns string, PyMySQL bytes. - return force_str(getattr(cursor, '_executed', None), errors='replace') + return force_str(getattr(cursor, "_executed", None), errors="replace") def no_limit_value(self): # 2**64 - 1, as recommended by the MySQL documentation @@ -183,50 +189,58 @@ class DatabaseOperations(BaseDatabaseOperations): # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING # statement. if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False): if not tables: return [] - sql = ['SET FOREIGN_KEY_CHECKS = 0;'] + sql = ["SET FOREIGN_KEY_CHECKS = 0;"] if reset_sequences: # It's faster to TRUNCATE tables that require a sequence reset # since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE. sql.extend( - '%s %s;' % ( - style.SQL_KEYWORD('TRUNCATE'), + "%s %s;" + % ( + style.SQL_KEYWORD("TRUNCATE"), style.SQL_FIELD(self.quote_name(table_name)), - ) for table_name in tables + ) + for table_name in tables ) else: # Otherwise issue a simple DELETE since it's faster than TRUNCATE # and preserves sequences. 
sql.extend( - '%s %s %s;' % ( - style.SQL_KEYWORD('DELETE'), - style.SQL_KEYWORD('FROM'), + "%s %s %s;" + % ( + style.SQL_KEYWORD("DELETE"), + style.SQL_KEYWORD("FROM"), style.SQL_FIELD(self.quote_name(table_name)), - ) for table_name in tables + ) + for table_name in tables ) - sql.append('SET FOREIGN_KEY_CHECKS = 1;') + sql.append("SET FOREIGN_KEY_CHECKS = 1;") return sql def sequence_reset_by_name_sql(self, style, sequences): return [ - '%s %s %s %s = 1;' % ( - style.SQL_KEYWORD('ALTER'), - style.SQL_KEYWORD('TABLE'), - style.SQL_FIELD(self.quote_name(sequence_info['table'])), - style.SQL_FIELD('AUTO_INCREMENT'), - ) for sequence_info in sequences + "%s %s %s %s = 1;" + % ( + style.SQL_KEYWORD("ALTER"), + style.SQL_KEYWORD("TABLE"), + style.SQL_FIELD(self.quote_name(sequence_info["table"])), + style.SQL_FIELD("AUTO_INCREMENT"), + ) + for sequence_info in sequences ] def validate_autopk_value(self, value): @@ -234,7 +248,7 @@ class DatabaseOperations(BaseDatabaseOperations): # NO_AUTO_VALUE_ON_ZERO SQL mode. if value == 0 and not self.connection.features.allows_auto_pk_0: raise ValueError( - 'The database backend does not accept 0 as a value for AutoField.' + "The database backend does not accept 0 as a value for AutoField." ) return value @@ -243,7 +257,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # MySQL doesn't support tz-aware datetimes @@ -251,7 +265,9 @@ class DatabaseOperations(BaseDatabaseOperations): if settings.USE_TZ: value = timezone.make_naive(value, self.connection.timezone) else: - raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.") + raise ValueError( + "MySQL backend does not support timezone-aware datetimes when USE_TZ is False." 
+ ) return str(value) def adapt_timefield_value(self, value): @@ -259,20 +275,20 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # MySQL doesn't support tz-aware times if timezone.is_aware(value): raise ValueError("MySQL backend does not support timezone-aware times.") - return value.isoformat(timespec='microseconds') + return value.isoformat(timespec="microseconds") def max_name_length(self): return 64 def pk_default_value(self): - return 'NULL' + return "NULL" def bulk_insert_sql(self, fields, placeholder_rows): placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) @@ -280,27 +296,27 @@ class DatabaseOperations(BaseDatabaseOperations): return "VALUES " + values_sql def combine_expression(self, connector, sub_expressions): - if connector == '^': - return 'POW(%s)' % ','.join(sub_expressions) + if connector == "^": + return "POW(%s)" % ",".join(sub_expressions) # Convert the result to a signed integer since MySQL's binary operators # return an unsigned integer. 
- elif connector in ('&', '|', '<<', '#'): - connector = '^' if connector == '#' else connector - return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions) - elif connector == '>>': + elif connector in ("&", "|", "<<", "#"): + connector = "^" if connector == "#" else connector + return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions) + elif connector == ">>": lhs, rhs = sub_expressions - return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs} + return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} return super().combine_expression(connector, sub_expressions) def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == 'BooleanField': + if internal_type == "BooleanField": converters.append(self.convert_booleanfield_value) - elif internal_type == 'DateTimeField': + elif internal_type == "DateTimeField": if settings.USE_TZ: converters.append(self.convert_datetimefield_value) - elif internal_type == 'UUIDField': + elif internal_type == "UUIDField": converters.append(self.convert_uuidfield_value) return converters @@ -320,66 +336,88 @@ class DatabaseOperations(BaseDatabaseOperations): return value def binary_placeholder_sql(self, value): - return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s' + return ( + "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s" + ) def subtract_temporals(self, internal_type, lhs, rhs): lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs - if internal_type == 'TimeField': + if internal_type == "TimeField": if self.connection.mysql_is_mariadb: # MariaDB includes the microsecond component in TIME_TO_SEC as # a decimal. MySQL returns an integer without microseconds. 
- return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % { - 'lhs': lhs_sql, 'rhs': rhs_sql - }, (*lhs_params, *rhs_params) + return "CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)" % { + "lhs": lhs_sql, + "rhs": rhs_sql, + }, ( + *lhs_params, + *rhs_params, + ) return ( "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -" " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))" - ) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2 + ) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple( + rhs_params + ) * 2 params = (*rhs_params, *lhs_params) return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params def explain_query_prefix(self, format=None, **options): # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends. - if format and format.upper() == 'TEXT': - format = 'TRADITIONAL' - elif not format and 'TREE' in self.connection.features.supported_explain_formats: + if format and format.upper() == "TEXT": + format = "TRADITIONAL" + elif ( + not format and "TREE" in self.connection.features.supported_explain_formats + ): # Use TREE by default (if supported) as it's more informative. - format = 'TREE' - analyze = options.pop('analyze', False) + format = "TREE" + analyze = options.pop("analyze", False) prefix = super().explain_query_prefix(format, **options) if analyze and self.connection.features.supports_explain_analyze: # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE. - prefix = 'ANALYZE' if self.connection.mysql_is_mariadb else prefix + ' ANALYZE' + prefix = ( + "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE" + ) if format and not (analyze and not self.connection.mysql_is_mariadb): # Only MariaDB supports the analyze option with formats. 
- prefix += ' FORMAT=%s' % format + prefix += " FORMAT=%s" % format return prefix def regex_lookup(self, lookup_type): # REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE # doesn't exist in MySQL 5.x or in MariaDB. - if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb: - if lookup_type == 'regex': - return '%s REGEXP BINARY %s' - return '%s REGEXP %s' + if ( + self.connection.mysql_version < (8, 0, 0) + or self.connection.mysql_is_mariadb + ): + if lookup_type == "regex": + return "%s REGEXP BINARY %s" + return "%s REGEXP %s" - match_option = 'c' if lookup_type == 'regex' else 'i' + match_option = "c" if lookup_type == "regex" else "i" return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option def insert_statement(self, on_conflict=None): if on_conflict == OnConflict.IGNORE: - return 'INSERT IGNORE INTO' + return "INSERT IGNORE INTO" return super().insert_statement(on_conflict=on_conflict) def lookup_cast(self, lookup_type, internal_type=None): - lookup = '%s' - if internal_type == 'JSONField': + lookup = "%s" + if internal_type == "JSONField": if self.connection.mysql_is_mariadb or lookup_type in ( - 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', - 'endswith', 'iendswith', 'regex', 'iregex', + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", ): - lookup = 'JSON_UNQUOTE(%s)' + lookup = "JSON_UNQUOTE(%s)" return lookup def conditional_expression_supported_in_where_clause(self, expression): @@ -388,31 +426,38 @@ class DatabaseOperations(BaseDatabaseOperations): if isinstance(expression, (Exists, Lookup)): return True if isinstance(expression, ExpressionWrapper) and expression.conditional: - return self.conditional_expression_supported_in_where_clause(expression.expression) - if getattr(expression, 'conditional', False): + return self.conditional_expression_supported_in_where_clause( + expression.expression + ) + if getattr(expression, 
"conditional", False): return False return super().conditional_expression_supported_in_where_clause(expression) def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if on_conflict == OnConflict.UPDATE: - conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s' - field_sql = '%(field)s = VALUES(%(field)s)' + conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s" + field_sql = "%(field)s = VALUES(%(field)s)" # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use # aliases for the new row and its columns available in MySQL # 8.0.19+. if not self.connection.mysql_is_mariadb: if self.connection.mysql_version >= (8, 0, 19): - conflict_suffix_sql = f'AS new {conflict_suffix_sql}' - field_sql = '%(field)s = new.%(field)s' + conflict_suffix_sql = f"AS new {conflict_suffix_sql}" + field_sql = "%(field)s = new.%(field)s" # VALUES() was renamed to VALUE() in MariaDB 10.3.3+. elif self.connection.mysql_version >= (10, 3, 3): - field_sql = '%(field)s = VALUE(%(field)s)' + field_sql = "%(field)s = VALUE(%(field)s)" - fields = ', '.join([ - field_sql % {'field': field} - for field in map(self.quote_name, update_fields) - ]) - return conflict_suffix_sql % {'fields': fields} + fields = ", ".join( + [ + field_sql % {"field": field} + for field in map(self.quote_name, update_fields) + ] + ) + return conflict_suffix_sql % {"fields": fields} return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) diff --git a/django/db/backends/mysql/schema.py b/django/db/backends/mysql/schema.py index 17827c2195..562b209eef 100644 --- a/django/db/backends/mysql/schema.py +++ b/django/db/backends/mysql/schema.py @@ -10,24 +10,26 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL" sql_alter_column_type = "MODIFY %(column)s %(type)s" sql_alter_column_collate = "MODIFY %(column)s 
%(type)s%(collation)s" - sql_alter_column_no_default_null = 'ALTER COLUMN %(column)s SET DEFAULT NULL' + sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL" # No 'CASCADE' which works as a no-op in MySQL but is undocumented sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s" sql_create_column_inline_fk = ( - ', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) ' - 'REFERENCES %(to_table)s(%(to_column)s)' + ", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) " + "REFERENCES %(to_table)s(%(to_column)s)" ) sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s" sql_delete_index = "DROP INDEX %(name)s ON %(table)s" - sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" + sql_create_pk = ( + "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" + ) sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY" - sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s' + sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s" @property def sql_delete_check(self): @@ -35,8 +37,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # The name of the column check constraint is the same as the field # name on MariaDB. Adding IF EXISTS clause prevents migrations # crash. Constraint is removed during a "MODIFY" column statement. 
- return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s' - return 'ALTER TABLE %(table)s DROP CHECK %(name)s' + return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s" + return "ALTER TABLE %(table)s DROP CHECK %(name)s" @property def sql_rename_column(self): @@ -47,21 +49,26 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): return super().sql_rename_column elif self.connection.mysql_version >= (8, 0, 4): return super().sql_rename_column - return 'ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s' + return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s" def quote_value(self, value): self.connection.ensure_connection() if isinstance(value, str): - value = value.replace('%', '%%') + value = value.replace("%", "%%") # MySQLdb escapes to string, PyMySQL to bytes. - quoted = self.connection.connection.escape(value, self.connection.connection.encoders) + quoted = self.connection.connection.escape( + value, self.connection.connection.encoders + ) if isinstance(value, str) and isinstance(quoted, bytes): quoted = quoted.decode() return quoted def _is_limited_data_type(self, field): db_type = field.db_type(self.connection) - return db_type is not None and db_type.lower() in self.connection._limited_data_types + return ( + db_type is not None + and db_type.lower() in self.connection._limited_data_types + ) def skip_default(self, field): if not self._supports_limited_data_type_defaults: @@ -84,13 +91,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _column_default_sql(self, field): if ( - not self.connection.mysql_is_mariadb and - self._supports_limited_data_type_defaults and - self._is_limited_data_type(field) + not self.connection.mysql_is_mariadb + and self._supports_limited_data_type_defaults + and self._is_limited_data_type(field) ): # MySQL supports defaults for BLOB and TEXT columns only if the # default value is written as an expression i.e. in parentheses. 
- return '(%s)' + return "(%s)" return super()._column_default_sql(field) def add_field(self, model, field): @@ -100,10 +107,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # field.default may be unhashable, so a set isn't used for "in" check. if self.skip_default(field) and field.default not in (None, NOT_PROVIDED): effective_default = self.effective_default(field) - self.execute('UPDATE %(table)s SET %(column)s = %%s' % { - 'table': self.quote_name(model._meta.db_table), - 'column': self.quote_name(field.column), - }, [effective_default]) + self.execute( + "UPDATE %(table)s SET %(column)s = %%s" + % { + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(field.column), + }, + [effective_default], + ) def _field_should_be_indexed(self, model, field): if not super()._field_should_be_indexed(model, field): @@ -115,9 +126,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # No need to create an index for ForeignKey fields except if # db_constraint=False because the index from that constraint won't be # created. - if (storage == "InnoDB" and - field.get_internal_type() == 'ForeignKey' and - field.db_constraint): + if ( + storage == "InnoDB" + and field.get_internal_type() == "ForeignKey" + and field.db_constraint + ): return False return not self._is_limited_data_type(field) @@ -131,11 +144,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): recreate a FK index. 
""" first_field = model._meta.get_field(fields[0]) - if first_field.get_internal_type() == 'ForeignKey': - constraint_names = self._constraint_names(model, [first_field.column], index=True) + if first_field.get_internal_type() == "ForeignKey": + constraint_names = self._constraint_names( + model, [first_field.column], index=True + ) if not constraint_names: self.execute( - self._create_index_sql(model, fields=[first_field], suffix='') + self._create_index_sql(model, fields=[first_field], suffix="") ) return super()._delete_composed_index(model, fields, *args) diff --git a/django/db/backends/mysql/validation.py b/django/db/backends/mysql/validation.py index 41e600a856..5d61b4865a 100644 --- a/django/db/backends/mysql/validation.py +++ b/django/db/backends/mysql/validation.py @@ -10,24 +10,28 @@ class DatabaseValidation(BaseDatabaseValidation): return issues def _check_sql_mode(self, **kwargs): - if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}): - return [checks.Warning( - "%s Strict Mode is not set for database connection '%s'" - % (self.connection.display_name, self.connection.alias), - hint=( - "%s's Strict Mode fixes many data integrity problems in " - "%s, such as data truncation upon insertion, by " - "escalating warnings into errors. It is strongly " - "recommended you activate it. See: " - "https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode" - % ( - self.connection.display_name, - self.connection.display_name, - get_docs_version(), + if not ( + self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"} + ): + return [ + checks.Warning( + "%s Strict Mode is not set for database connection '%s'" + % (self.connection.display_name, self.connection.alias), + hint=( + "%s's Strict Mode fixes many data integrity problems in " + "%s, such as data truncation upon insertion, by " + "escalating warnings into errors. It is strongly " + "recommended you activate it. 
See: " + "https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode" + % ( + self.connection.display_name, + self.connection.display_name, + get_docs_version(), + ), ), - ), - id='mysql.W002', - )] + id="mysql.W002", + ) + ] return [] def check_field_type(self, field, field_type): @@ -38,32 +42,35 @@ class DatabaseValidation(BaseDatabaseValidation): MySQL doesn't support a database index on some data types. """ errors = [] - if (field_type.startswith('varchar') and field.unique and - (field.max_length is None or int(field.max_length) > 255)): + if ( + field_type.startswith("varchar") + and field.unique + and (field.max_length is None or int(field.max_length) > 255) + ): errors.append( checks.Warning( - '%s may not allow unique CharFields to have a max_length ' - '> 255.' % self.connection.display_name, + "%s may not allow unique CharFields to have a max_length " + "> 255." % self.connection.display_name, obj=field, hint=( - 'See: https://docs.djangoproject.com/en/%s/ref/' - 'databases/#mysql-character-fields' % get_docs_version() + "See: https://docs.djangoproject.com/en/%s/ref/" + "databases/#mysql-character-fields" % get_docs_version() ), - id='mysql.W003', + id="mysql.W003", ) ) if field.db_index and field_type.lower() in self.connection._limited_data_types: errors.append( checks.Warning( - '%s does not support a database index on %s columns.' + "%s does not support a database index on %s columns." % (self.connection.display_name, field_type), hint=( "An index won't be created. Silence this warning if " "you don't care about it." 
), obj=field, - id='fields.W162', + id="fields.W162", ) ) return errors diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 966eb4b6f4..b13c5f8bb2 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -21,27 +21,31 @@ from django.utils.functional import cached_property def _setup_environment(environ): # Cygwin requires some special voodoo to set the environment variables # properly so that Oracle will see them. - if platform.system().upper().startswith('CYGWIN'): + if platform.system().upper().startswith("CYGWIN"): try: import ctypes except ImportError as e: - raise ImproperlyConfigured("Error loading ctypes: %s; " - "the Oracle backend requires ctypes to " - "operate correctly under Cygwin." % e) - kernel32 = ctypes.CDLL('kernel32') + raise ImproperlyConfigured( + "Error loading ctypes: %s; " + "the Oracle backend requires ctypes to " + "operate correctly under Cygwin." % e + ) + kernel32 = ctypes.CDLL("kernel32") for name, value in environ: kernel32.SetEnvironmentVariableA(name, value) else: os.environ.update(environ) -_setup_environment([ - # Oracle takes client-side character set encoding from the environment. - ('NLS_LANG', '.AL32UTF8'), - # This prevents Unicode from getting mangled by getting encoded into the - # potentially non-Unicode database character set. - ('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'), -]) +_setup_environment( + [ + # Oracle takes client-side character set encoding from the environment. + ("NLS_LANG", ".AL32UTF8"), + # This prevents Unicode from getting mangled by getting encoded into the + # potentially non-Unicode database character set. + ("ORA_NCHAR_LITERAL_REPLACE", "TRUE"), + ] +) try: @@ -77,17 +81,16 @@ def wrap_oracle_errors(): # Convert that case to Django's IntegrityError exception. 
x = e.args[0] if ( - hasattr(x, 'code') and - hasattr(x, 'message') and - x.code == 2091 and - ('ORA-02291' in x.message or 'ORA-00001' in x.message) + hasattr(x, "code") + and hasattr(x, "message") + and x.code == 2091 + and ("ORA-02291" in x.message or "ORA-00001" in x.message) ): raise IntegrityError(*tuple(e.args)) raise class _UninitializedOperatorsDescriptor: - def __get__(self, instance, cls=None): # If connection.operators is looked up before a connection has been # created, transparently initialize connection.operators to avert an @@ -96,12 +99,12 @@ class _UninitializedOperatorsDescriptor: raise AttributeError("operators not available as class attribute") # Creating a cursor will initialize the operators. instance.cursor().close() - return instance.__dict__['operators'] + return instance.__dict__["operators"] class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'oracle' - display_name = 'Oracle' + vendor = "oracle" + display_name = "Oracle" # This dictionary maps Field objects to their associated Oracle column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. @@ -110,71 +113,71 @@ class DatabaseWrapper(BaseDatabaseWrapper): # Any format strings starting with "qn_" are quoted before being used in the # output (the "qn_" prefix is stripped before the lookup is performed. 
data_types = { - 'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY', - 'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY', - 'BinaryField': 'BLOB', - 'BooleanField': 'NUMBER(1)', - 'CharField': 'NVARCHAR2(%(max_length)s)', - 'DateField': 'DATE', - 'DateTimeField': 'TIMESTAMP', - 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)', - 'DurationField': 'INTERVAL DAY(9) TO SECOND(6)', - 'FileField': 'NVARCHAR2(%(max_length)s)', - 'FilePathField': 'NVARCHAR2(%(max_length)s)', - 'FloatField': 'DOUBLE PRECISION', - 'IntegerField': 'NUMBER(11)', - 'JSONField': 'NCLOB', - 'BigIntegerField': 'NUMBER(19)', - 'IPAddressField': 'VARCHAR2(15)', - 'GenericIPAddressField': 'VARCHAR2(39)', - 'OneToOneField': 'NUMBER(11)', - 'PositiveBigIntegerField': 'NUMBER(19)', - 'PositiveIntegerField': 'NUMBER(11)', - 'PositiveSmallIntegerField': 'NUMBER(11)', - 'SlugField': 'NVARCHAR2(%(max_length)s)', - 'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY', - 'SmallIntegerField': 'NUMBER(11)', - 'TextField': 'NCLOB', - 'TimeField': 'TIMESTAMP', - 'URLField': 'VARCHAR2(%(max_length)s)', - 'UUIDField': 'VARCHAR2(32)', + "AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY", + "BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY", + "BinaryField": "BLOB", + "BooleanField": "NUMBER(1)", + "CharField": "NVARCHAR2(%(max_length)s)", + "DateField": "DATE", + "DateTimeField": "TIMESTAMP", + "DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)", + "DurationField": "INTERVAL DAY(9) TO SECOND(6)", + "FileField": "NVARCHAR2(%(max_length)s)", + "FilePathField": "NVARCHAR2(%(max_length)s)", + "FloatField": "DOUBLE PRECISION", + "IntegerField": "NUMBER(11)", + "JSONField": "NCLOB", + "BigIntegerField": "NUMBER(19)", + "IPAddressField": "VARCHAR2(15)", + "GenericIPAddressField": "VARCHAR2(39)", + "OneToOneField": "NUMBER(11)", + "PositiveBigIntegerField": "NUMBER(19)", + "PositiveIntegerField": "NUMBER(11)", + 
"PositiveSmallIntegerField": "NUMBER(11)", + "SlugField": "NVARCHAR2(%(max_length)s)", + "SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY", + "SmallIntegerField": "NUMBER(11)", + "TextField": "NCLOB", + "TimeField": "TIMESTAMP", + "URLField": "VARCHAR2(%(max_length)s)", + "UUIDField": "VARCHAR2(32)", } data_type_check_constraints = { - 'BooleanField': '%(qn_column)s IN (0,1)', - 'JSONField': '%(qn_column)s IS JSON', - 'PositiveBigIntegerField': '%(qn_column)s >= 0', - 'PositiveIntegerField': '%(qn_column)s >= 0', - 'PositiveSmallIntegerField': '%(qn_column)s >= 0', + "BooleanField": "%(qn_column)s IN (0,1)", + "JSONField": "%(qn_column)s IS JSON", + "PositiveBigIntegerField": "%(qn_column)s >= 0", + "PositiveIntegerField": "%(qn_column)s >= 0", + "PositiveSmallIntegerField": "%(qn_column)s >= 0", } # Oracle doesn't support a database index on these columns. - _limited_data_types = ('clob', 'nclob', 'blob') + _limited_data_types = ("clob", "nclob", "blob") operators = _UninitializedOperatorsDescriptor() _standard_operators = { - 'exact': '= %s', - 'iexact': '= UPPER(%s)', - 'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", - 'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", - 'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", - 'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", - 'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "exact": "= %s", + "iexact": "= UPPER(%s)", + "contains": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "icontains": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "gt": "> %s", + "gte": ">= 
%s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "endswith": "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "istartswith": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + "iendswith": "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", } _likec_operators = { **_standard_operators, - 'contains': "LIKEC %s ESCAPE '\\'", - 'icontains': "LIKEC UPPER(%s) ESCAPE '\\'", - 'startswith': "LIKEC %s ESCAPE '\\'", - 'endswith': "LIKEC %s ESCAPE '\\'", - 'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'", - 'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'", + "contains": "LIKEC %s ESCAPE '\\'", + "icontains": "LIKEC UPPER(%s) ESCAPE '\\'", + "startswith": "LIKEC %s ESCAPE '\\'", + "endswith": "LIKEC %s ESCAPE '\\'", + "istartswith": "LIKEC UPPER(%s) ESCAPE '\\'", + "iendswith": "LIKEC UPPER(%s) ESCAPE '\\'", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -187,19 +190,22 @@ class DatabaseWrapper(BaseDatabaseWrapper): # the LIKE operator. 
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')" _pattern_ops = { - 'contains': "'%%' || {} || '%%'", - 'icontains': "'%%' || UPPER({}) || '%%'", - 'startswith': "{} || '%%'", - 'istartswith': "UPPER({}) || '%%'", - 'endswith': "'%%' || {}", - 'iendswith': "'%%' || UPPER({})", + "contains": "'%%' || {} || '%%'", + "icontains": "'%%' || UPPER({}) || '%%'", + "startswith": "{} || '%%'", + "istartswith": "UPPER({}) || '%%'", + "endswith": "'%%' || {}", + "iendswith": "'%%' || UPPER({})", } - _standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)" - " ESCAPE TRANSLATE('\\' USING NCHAR_CS)" - for k, v in _pattern_ops.items()} - _likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'" - for k, v in _pattern_ops.items()} + _standard_pattern_ops = { + k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)" + " ESCAPE TRANSLATE('\\' USING NCHAR_CS)" + for k, v in _pattern_ops.items() + } + _likec_pattern_ops = { + k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items() + } Database = Database SchemaEditorClass = DatabaseSchemaEditor @@ -213,20 +219,22 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) + use_returning_into = self.settings_dict["OPTIONS"].get( + "use_returning_into", True + ) self.features.can_return_columns_from_insert = use_returning_into def get_connection_params(self): - conn_params = self.settings_dict['OPTIONS'].copy() - if 'use_returning_into' in conn_params: - del conn_params['use_returning_into'] + conn_params = self.settings_dict["OPTIONS"].copy() + if "use_returning_into" in conn_params: + del conn_params["use_returning_into"] return conn_params @async_unsafe def get_new_connection(self, conn_params): return Database.connect( - user=self.settings_dict['USER'], - password=self.settings_dict['PASSWORD'], + user=self.settings_dict["USER"], + 
password=self.settings_dict["PASSWORD"], dsn=dsn(self.settings_dict), **conn_params, ) @@ -244,11 +252,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): # TO_CHAR(). cursor.execute( "ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'" - " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" + - (" TIME_ZONE = 'UTC'" if settings.USE_TZ else '') + " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" + + (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "") ) cursor.close() - if 'operators' not in self.__dict__: + if "operators" not in self.__dict__: # Ticket #14149: Check whether our LIKE implementation will # work for this connection or we need to fall back on LIKEC. # This check is performed only once per DatabaseWrapper @@ -256,9 +264,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): # the same settings. cursor = self.create_cursor() try: - cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s" - % self._standard_operators['contains'], - ['X']) + cursor.execute( + "SELECT 1 FROM DUAL WHERE DUMMY %s" + % self._standard_operators["contains"], + ["X"], + ) except Database.DatabaseError: self.operators = self._likec_operators self.pattern_ops = self._likec_pattern_ops @@ -284,10 +294,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): # logging is enabled to keep query counts consistent with other backends. def _savepoint_commit(self, sid): if self.queries_logged: - self.queries_log.append({ - 'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid), - 'time': '0.000', - }) + self.queries_log.append( + { + "sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid), + "time": "0.000", + } + ) def _set_autocommit(self, autocommit): with self.wrap_database_errors: @@ -299,8 +311,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): afterward. 
""" with self.cursor() as cursor: - cursor.execute('SET CONSTRAINTS ALL IMMEDIATE') - cursor.execute('SET CONSTRAINTS ALL DEFERRED') + cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") + cursor.execute("SET CONSTRAINTS ALL DEFERRED") def is_usable(self): try: @@ -312,12 +324,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): @cached_property def cx_oracle_version(self): - return tuple(int(x) for x in Database.version.split('.')) + return tuple(int(x) for x in Database.version.split(".")) @cached_property def oracle_version(self): with self.temporary_connection(): - return tuple(int(x) for x in self.connection.version.split('.')) + return tuple(int(x) for x in self.connection.version.split(".")) class OracleParam: @@ -333,8 +345,10 @@ class OracleParam: def __init__(self, param, cursor, strings_only=False): # With raw SQL queries, datetimes can reach this function # without being converted by DateTimeField.get_db_prep_value. - if settings.USE_TZ and (isinstance(param, datetime.datetime) and - not isinstance(param, Oracle_datetime)): + if settings.USE_TZ and ( + isinstance(param, datetime.datetime) + and not isinstance(param, Oracle_datetime) + ): param = Oracle_datetime.from_datetime(param) string_size = 0 @@ -343,7 +357,7 @@ class OracleParam: param = 1 elif param is False: param = 0 - if hasattr(param, 'bind_parameter'): + if hasattr(param, "bind_parameter"): self.force_bytes = param.bind_parameter(cursor) elif isinstance(param, (Database.Binary, datetime.timedelta)): self.force_bytes = param @@ -354,7 +368,7 @@ class OracleParam: if isinstance(self.force_bytes, str): # We could optimize by only converting up to 4000 bytes here string_size = len(force_bytes(param, cursor.charset, strings_only)) - if hasattr(param, 'input_size'): + if hasattr(param, "input_size"): # If parameter has `input_size` attribute, use that. 
self.input_size = param.input_size elif string_size > 4000: @@ -384,7 +398,7 @@ class VariableWrapper: return getattr(self.var, key) def __setattr__(self, key, value): - if key == 'var': + if key == "var": self.__dict__[key] = value else: setattr(self.var, key, value) @@ -396,7 +410,8 @@ class FormatStylePlaceholderCursor: style. This fixes it -- but note that if you want to use a literal "%s" in a query, you'll need to use "%%s". """ - charset = 'utf-8' + + charset = "utf-8" def __init__(self, connection): self.cursor = connection.cursor() @@ -404,7 +419,7 @@ class FormatStylePlaceholderCursor: @staticmethod def _output_number_converter(value): - return decimal.Decimal(value) if '.' in value else int(value) + return decimal.Decimal(value) if "." in value else int(value) @staticmethod def _get_decimal_converter(precision, scale): @@ -434,7 +449,9 @@ class FormatStylePlaceholderCursor: elif precision > 0: # NUMBER(p,s) column: decimal-precision fixed point. # This comes from IntegerField and DecimalField columns. - outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale) + outconverter = FormatStylePlaceholderCursor._get_decimal_converter( + precision, scale + ) else: # No type information. This normally comes from a # mathematical expression in the SELECT list. 
Guess int @@ -455,7 +472,7 @@ class FormatStylePlaceholderCursor: def _guess_input_sizes(self, params_list): # Try dict handling; if that fails, treat as sequence - if hasattr(params_list[0], 'keys'): + if hasattr(params_list[0], "keys"): sizes = {} for params in params_list: for k, value in params.items(): @@ -475,7 +492,7 @@ class FormatStylePlaceholderCursor: def _param_generator(self, params): # Try dict handling; if that fails, treat as sequence - if hasattr(params, 'items'): + if hasattr(params, "items"): return {k: v.force_bytes for k, v in params.items()} else: return [p.force_bytes for p in params] @@ -485,11 +502,11 @@ class FormatStylePlaceholderCursor: # it does want a trailing ';' but not a trailing '/'. However, these # characters must be included in the original query in case the query # is being passed to SQL*Plus. - if query.endswith(';') or query.endswith('/'): + if query.endswith(";") or query.endswith("/"): query = query[:-1] if params is None: params = [] - elif hasattr(params, 'keys'): + elif hasattr(params, "keys"): # Handle params as dict args = {k: ":%s" % k for k in params} query = query % args @@ -502,15 +519,14 @@ class FormatStylePlaceholderCursor: # args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0'] # params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'} params_dict = { - param: ':arg%d' % i - for i, param in enumerate(dict.fromkeys(params)) + param: ":arg%d" % i for i, param in enumerate(dict.fromkeys(params)) } args = [params_dict[param] for param in params] params = {value: key for key, value in params_dict.items()} query = query % tuple(args) else: # Handle params as sequence - args = [(':arg%d' % i) for i in range(len(params))] + args = [(":arg%d" % i) for i in range(len(params))] query = query % tuple(args) return query, self._format_params(params) @@ -532,7 +548,9 @@ class FormatStylePlaceholderCursor: formatted = [firstparams] + [self._format_params(p) for p in params_iter] self._guess_input_sizes(formatted) with 
wrap_oracle_errors(): - return self.cursor.executemany(query, [self._param_generator(p) for p in formatted]) + return self.cursor.executemany( + query, [self._param_generator(p) for p in formatted] + ) def close(self): try: diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py index 9920f4ca67..365b116046 100644 --- a/django/db/backends/oracle/client.py +++ b/django/db/backends/oracle/client.py @@ -4,22 +4,22 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'sqlplus' - wrapper_name = 'rlwrap' + executable_name = "sqlplus" + wrapper_name = "rlwrap" @staticmethod def connect_string(settings_dict): from django.db.backends.oracle.utils import dsn return '%s/"%s"@%s' % ( - settings_dict['USER'], - settings_dict['PASSWORD'], + settings_dict["USER"], + settings_dict["PASSWORD"], dsn(settings_dict), ) @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): - args = [cls.executable_name, '-L', cls.connect_string(settings_dict)] + args = [cls.executable_name, "-L", cls.connect_string(settings_dict)] wrapper_path = shutil.which(cls.wrapper_name) if wrapper_path: args = [wrapper_path, *args] diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py index 3ca3754e15..bdde162aa8 100644 --- a/django/db/backends/oracle/creation.py +++ b/django/db/backends/oracle/creation.py @@ -6,11 +6,10 @@ from django.db.backends.base.creation import BaseDatabaseCreation from django.utils.crypto import get_random_string from django.utils.functional import cached_property -TEST_DATABASE_PREFIX = 'test_' +TEST_DATABASE_PREFIX = "test_" class DatabaseCreation(BaseDatabaseCreation): - @cached_property def _maindb_connection(self): """ @@ -21,9 +20,9 @@ class DatabaseCreation(BaseDatabaseCreation): is the main (non-test) connection. 
""" settings_dict = settings.DATABASES[self.connection.alias] - user = settings_dict.get('SAVED_USER') or settings_dict['USER'] - password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD'] - settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password} + user = settings_dict.get("SAVED_USER") or settings_dict["USER"] + password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"] + settings_dict = {**settings_dict, "USER": user, "PASSWORD": password} DatabaseWrapper = type(self.connection) return DatabaseWrapper(settings_dict, alias=self.connection.alias) @@ -32,72 +31,95 @@ class DatabaseCreation(BaseDatabaseCreation): with self._maindb_connection.cursor() as cursor: if self._test_database_create(): try: - self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) + self._execute_test_db_creation( + cursor, parameters, verbosity, keepdb + ) except Exception as e: - if 'ORA-01543' not in str(e): + if "ORA-01543" not in str(e): # All errors except "tablespace already exists" cancel tests - self.log('Got an error creating the test database: %s' % e) + self.log("Got an error creating the test database: %s" % e) sys.exit(2) if not autoclobber: confirm = input( "It appears the test database, %s, already exists. " - "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user']) - if autoclobber or confirm == 'yes': + "Type 'yes' to delete it, or 'no' to cancel: " + % parameters["user"] + ) + if autoclobber or confirm == "yes": if verbosity >= 1: - self.log("Destroying old test database for alias '%s'..." % self.connection.alias) + self.log( + "Destroying old test database for alias '%s'..." 
+ % self.connection.alias + ) try: - self._execute_test_db_destruction(cursor, parameters, verbosity) + self._execute_test_db_destruction( + cursor, parameters, verbosity + ) except DatabaseError as e: - if 'ORA-29857' in str(e): - self._handle_objects_preventing_db_destruction(cursor, parameters, - verbosity, autoclobber) + if "ORA-29857" in str(e): + self._handle_objects_preventing_db_destruction( + cursor, parameters, verbosity, autoclobber + ) else: # Ran into a database error that isn't about leftover objects in the tablespace - self.log('Got an error destroying the old test database: %s' % e) + self.log( + "Got an error destroying the old test database: %s" + % e + ) sys.exit(2) except Exception as e: - self.log('Got an error destroying the old test database: %s' % e) + self.log( + "Got an error destroying the old test database: %s" % e + ) sys.exit(2) try: - self._execute_test_db_creation(cursor, parameters, verbosity, keepdb) + self._execute_test_db_creation( + cursor, parameters, verbosity, keepdb + ) except Exception as e: - self.log('Got an error recreating the test database: %s' % e) + self.log( + "Got an error recreating the test database: %s" % e + ) sys.exit(2) else: - self.log('Tests cancelled.') + self.log("Tests cancelled.") sys.exit(1) if self._test_user_create(): if verbosity >= 1: - self.log('Creating test user...') + self.log("Creating test user...") try: self._create_test_user(cursor, parameters, verbosity, keepdb) except Exception as e: - if 'ORA-01920' not in str(e): + if "ORA-01920" not in str(e): # All errors except "user already exists" cancel tests - self.log('Got an error creating the test user: %s' % e) + self.log("Got an error creating the test user: %s" % e) sys.exit(2) if not autoclobber: confirm = input( "It appears the test user, %s, already exists. 
Type " - "'yes' to delete it, or 'no' to cancel: " % parameters['user']) - if autoclobber or confirm == 'yes': + "'yes' to delete it, or 'no' to cancel: " + % parameters["user"] + ) + if autoclobber or confirm == "yes": try: if verbosity >= 1: - self.log('Destroying old test user...') + self.log("Destroying old test user...") self._destroy_test_user(cursor, parameters, verbosity) if verbosity >= 1: - self.log('Creating test user...') - self._create_test_user(cursor, parameters, verbosity, keepdb) + self.log("Creating test user...") + self._create_test_user( + cursor, parameters, verbosity, keepdb + ) except Exception as e: - self.log('Got an error recreating the test user: %s' % e) + self.log("Got an error recreating the test user: %s" % e) sys.exit(2) else: - self.log('Tests cancelled.') + self.log("Tests cancelled.") sys.exit(1) self._maindb_connection.close() # done with main user -- test user and tablespaces created self._switch_to_test_user(parameters) - return self.connection.settings_dict['NAME'] + return self.connection.settings_dict["NAME"] def _switch_to_test_user(self, parameters): """ @@ -109,59 +131,71 @@ class DatabaseCreation(BaseDatabaseCreation): credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict. 
""" real_settings = settings.DATABASES[self.connection.alias] - real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \ - self.connection.settings_dict['USER'] - real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \ - self.connection.settings_dict['PASSWORD'] - real_test_settings = real_settings['TEST'] - test_settings = self.connection.settings_dict['TEST'] - real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \ - self.connection.settings_dict['USER'] = parameters['user'] - real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password'] + real_settings["SAVED_USER"] = self.connection.settings_dict[ + "SAVED_USER" + ] = self.connection.settings_dict["USER"] + real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[ + "SAVED_PASSWORD" + ] = self.connection.settings_dict["PASSWORD"] + real_test_settings = real_settings["TEST"] + test_settings = self.connection.settings_dict["TEST"] + real_test_settings["USER"] = real_settings["USER"] = test_settings[ + "USER" + ] = self.connection.settings_dict["USER"] = parameters["user"] + real_settings["PASSWORD"] = self.connection.settings_dict[ + "PASSWORD" + ] = parameters["password"] def set_as_test_mirror(self, primary_settings_dict): """ Set this database up to be used in testing as a mirror of a primary database whose settings are given. 
""" - self.connection.settings_dict['USER'] = primary_settings_dict['USER'] - self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD'] + self.connection.settings_dict["USER"] = primary_settings_dict["USER"] + self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"] - def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber): + def _handle_objects_preventing_db_destruction( + self, cursor, parameters, verbosity, autoclobber + ): # There are objects in the test tablespace which prevent dropping it # The easy fix is to drop the test user -- but are we allowed to do so? self.log( - 'There are objects in the old test database which prevent its destruction.\n' - 'If they belong to the test user, deleting the user will allow the test ' - 'database to be recreated.\n' - 'Otherwise, you will need to find and remove each of these objects, ' - 'or use a different tablespace.\n' + "There are objects in the old test database which prevent its destruction.\n" + "If they belong to the test user, deleting the user will allow the test " + "database to be recreated.\n" + "Otherwise, you will need to find and remove each of these objects, " + "or use a different tablespace.\n" ) if self._test_user_create(): if not autoclobber: - confirm = input("Type 'yes' to delete user %s: " % parameters['user']) - if autoclobber or confirm == 'yes': + confirm = input("Type 'yes' to delete user %s: " % parameters["user"]) + if autoclobber or confirm == "yes": try: if verbosity >= 1: - self.log('Destroying old test user...') + self.log("Destroying old test user...") self._destroy_test_user(cursor, parameters, verbosity) except Exception as e: - self.log('Got an error destroying the test user: %s' % e) + self.log("Got an error destroying the test user: %s" % e) sys.exit(2) try: if verbosity >= 1: - self.log("Destroying old test database for alias '%s'..." 
% self.connection.alias) + self.log( + "Destroying old test database for alias '%s'..." + % self.connection.alias + ) self._execute_test_db_destruction(cursor, parameters, verbosity) except Exception as e: - self.log('Got an error destroying the test database: %s' % e) + self.log("Got an error destroying the test database: %s" % e) sys.exit(2) else: - self.log('Tests cancelled -- test database cannot be recreated.') + self.log("Tests cancelled -- test database cannot be recreated.") sys.exit(1) else: - self.log("Django is configured to use pre-existing test user '%s'," - " and will not attempt to delete it." % parameters['user']) - self.log('Tests cancelled -- test database cannot be recreated.') + self.log( + "Django is configured to use pre-existing test user '%s'," + " and will not attempt to delete it." % parameters["user"] + ) + self.log("Tests cancelled -- test database cannot be recreated.") sys.exit(1) def _destroy_test_db(self, test_database_name, verbosity=1): @@ -169,24 +203,28 @@ class DatabaseCreation(BaseDatabaseCreation): Destroy a test database, prompting the user for confirmation if the database already exists. Return the name of the test database created. 
""" - self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER'] - self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] + self.connection.settings_dict["USER"] = self.connection.settings_dict[ + "SAVED_USER" + ] + self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[ + "SAVED_PASSWORD" + ] self.connection.close() parameters = self._get_test_db_params() with self._maindb_connection.cursor() as cursor: if self._test_user_create(): if verbosity >= 1: - self.log('Destroying test user...') + self.log("Destroying test user...") self._destroy_test_user(cursor, parameters, verbosity) if self._test_database_create(): if verbosity >= 1: - self.log('Destroying test database tables...') + self.log("Destroying test database tables...") self._execute_test_db_destruction(cursor, parameters, verbosity) self._maindb_connection.close() def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False): if verbosity >= 2: - self.log('_create_test_db(): dbname = %s' % parameters['user']) + self.log("_create_test_db(): dbname = %s" % parameters["user"]) if self._test_database_oracle_managed_files(): statements = [ """ @@ -214,12 +252,14 @@ class DatabaseCreation(BaseDatabaseCreation): """, ] # Ignore "tablespace already exists" error when keepdb is on. 
- acceptable_ora_err = 'ORA-01543' if keepdb else None - self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err) + acceptable_ora_err = "ORA-01543" if keepdb else None + self._execute_allow_fail_statements( + cursor, statements, parameters, verbosity, acceptable_ora_err + ) def _create_test_user(self, cursor, parameters, verbosity, keepdb=False): if verbosity >= 2: - self.log('_create_test_user(): username = %s' % parameters['user']) + self.log("_create_test_user(): username = %s" % parameters["user"]) statements = [ """CREATE USER %(user)s IDENTIFIED BY "%(password)s" @@ -235,40 +275,49 @@ class DatabaseCreation(BaseDatabaseCreation): TO %(user)s""", ] # Ignore "user already exists" error when keepdb is on - acceptable_ora_err = 'ORA-01920' if keepdb else None - success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err) + acceptable_ora_err = "ORA-01920" if keepdb else None + success = self._execute_allow_fail_statements( + cursor, statements, parameters, verbosity, acceptable_ora_err + ) # If the password was randomly generated, change the user accordingly. - if not success and self._test_settings_get('PASSWORD') is None: + if not success and self._test_settings_get("PASSWORD") is None: set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"' self._execute_statements(cursor, [set_password], parameters, verbosity) # Most test suites can be run without "create view" and # "create materialized view" privileges. But some need it. 
- for object_type in ('VIEW', 'MATERIALIZED VIEW'): - extra = 'GRANT CREATE %(object_type)s TO %(user)s' - parameters['object_type'] = object_type - success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031') + for object_type in ("VIEW", "MATERIALIZED VIEW"): + extra = "GRANT CREATE %(object_type)s TO %(user)s" + parameters["object_type"] = object_type + success = self._execute_allow_fail_statements( + cursor, [extra], parameters, verbosity, "ORA-01031" + ) if not success and verbosity >= 2: - self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type) + self.log( + "Failed to grant CREATE %s permission to test user. This may be ok." + % object_type + ) def _execute_test_db_destruction(self, cursor, parameters, verbosity): if verbosity >= 2: - self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user']) + self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"]) statements = [ - 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', - 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', + "DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS", + "DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS", ] self._execute_statements(cursor, statements, parameters, verbosity) def _destroy_test_user(self, cursor, parameters, verbosity): if verbosity >= 2: - self.log('_destroy_test_user(): user=%s' % parameters['user']) - self.log('Be patient. This can take some time...') + self.log("_destroy_test_user(): user=%s" % parameters["user"]) + self.log("Be patient. 
This can take some time...") statements = [ - 'DROP USER %(user)s CASCADE', + "DROP USER %(user)s CASCADE", ] self._execute_statements(cursor, statements, parameters, verbosity) - def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False): + def _execute_statements( + self, cursor, statements, parameters, verbosity, allow_quiet_fail=False + ): for template in statements: stmt = template % parameters if verbosity >= 2: @@ -277,10 +326,12 @@ class DatabaseCreation(BaseDatabaseCreation): cursor.execute(stmt) except Exception as err: if (not allow_quiet_fail) or verbosity >= 2: - self.log('Failed (%s)' % (err)) + self.log("Failed (%s)" % (err)) raise - def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err): + def _execute_allow_fail_statements( + self, cursor, statements, parameters, verbosity, acceptable_ora_err + ): """ Execute statements which are allowed to fail silently if the Oracle error code given by `acceptable_ora_err` is raised. 
Return True if the @@ -288,8 +339,16 @@ class DatabaseCreation(BaseDatabaseCreation): """ try: # Statement can fail when acceptable_ora_err is not None - allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0 - self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail) + allow_quiet_fail = ( + acceptable_ora_err is not None and len(acceptable_ora_err) > 0 + ) + self._execute_statements( + cursor, + statements, + parameters, + verbosity, + allow_quiet_fail=allow_quiet_fail, + ) return True except DatabaseError as err: description = str(err) @@ -299,19 +358,19 @@ class DatabaseCreation(BaseDatabaseCreation): def _get_test_db_params(self): return { - 'dbname': self._test_database_name(), - 'user': self._test_database_user(), - 'password': self._test_database_passwd(), - 'tblspace': self._test_database_tblspace(), - 'tblspace_temp': self._test_database_tblspace_tmp(), - 'datafile': self._test_database_tblspace_datafile(), - 'datafile_tmp': self._test_database_tblspace_tmp_datafile(), - 'maxsize': self._test_database_tblspace_maxsize(), - 'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(), - 'size': self._test_database_tblspace_size(), - 'size_tmp': self._test_database_tblspace_tmp_size(), - 'extsize': self._test_database_tblspace_extsize(), - 'extsize_tmp': self._test_database_tblspace_tmp_extsize(), + "dbname": self._test_database_name(), + "user": self._test_database_user(), + "password": self._test_database_passwd(), + "tblspace": self._test_database_tblspace(), + "tblspace_temp": self._test_database_tblspace_tmp(), + "datafile": self._test_database_tblspace_datafile(), + "datafile_tmp": self._test_database_tblspace_tmp_datafile(), + "maxsize": self._test_database_tblspace_maxsize(), + "maxsize_tmp": self._test_database_tblspace_tmp_maxsize(), + "size": self._test_database_tblspace_size(), + "size_tmp": self._test_database_tblspace_tmp_size(), + "extsize": 
self._test_database_tblspace_extsize(), + "extsize_tmp": self._test_database_tblspace_tmp_extsize(), } def _test_settings_get(self, key, default=None, prefixed=None): @@ -320,66 +379,67 @@ class DatabaseCreation(BaseDatabaseCreation): prefixed entry from the main settings dict. """ settings_dict = self.connection.settings_dict - val = settings_dict['TEST'].get(key, default) + val = settings_dict["TEST"].get(key, default) if val is None and prefixed: val = TEST_DATABASE_PREFIX + settings_dict[prefixed] return val def _test_database_name(self): - return self._test_settings_get('NAME', prefixed='NAME') + return self._test_settings_get("NAME", prefixed="NAME") def _test_database_create(self): - return self._test_settings_get('CREATE_DB', default=True) + return self._test_settings_get("CREATE_DB", default=True) def _test_user_create(self): - return self._test_settings_get('CREATE_USER', default=True) + return self._test_settings_get("CREATE_USER", default=True) def _test_database_user(self): - return self._test_settings_get('USER', prefixed='USER') + return self._test_settings_get("USER", prefixed="USER") def _test_database_passwd(self): - password = self._test_settings_get('PASSWORD') + password = self._test_settings_get("PASSWORD") if password is None and self._test_user_create(): # Oracle passwords are limited to 30 chars and can't contain symbols. 
password = get_random_string(30) return password def _test_database_tblspace(self): - return self._test_settings_get('TBLSPACE', prefixed='USER') + return self._test_settings_get("TBLSPACE", prefixed="USER") def _test_database_tblspace_tmp(self): settings_dict = self.connection.settings_dict - return settings_dict['TEST'].get('TBLSPACE_TMP', - TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp') + return settings_dict["TEST"].get( + "TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp" + ) def _test_database_tblspace_datafile(self): - tblspace = '%s.dbf' % self._test_database_tblspace() - return self._test_settings_get('DATAFILE', default=tblspace) + tblspace = "%s.dbf" % self._test_database_tblspace() + return self._test_settings_get("DATAFILE", default=tblspace) def _test_database_tblspace_tmp_datafile(self): - tblspace = '%s.dbf' % self._test_database_tblspace_tmp() - return self._test_settings_get('DATAFILE_TMP', default=tblspace) + tblspace = "%s.dbf" % self._test_database_tblspace_tmp() + return self._test_settings_get("DATAFILE_TMP", default=tblspace) def _test_database_tblspace_maxsize(self): - return self._test_settings_get('DATAFILE_MAXSIZE', default='500M') + return self._test_settings_get("DATAFILE_MAXSIZE", default="500M") def _test_database_tblspace_tmp_maxsize(self): - return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M') + return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M") def _test_database_tblspace_size(self): - return self._test_settings_get('DATAFILE_SIZE', default='50M') + return self._test_settings_get("DATAFILE_SIZE", default="50M") def _test_database_tblspace_tmp_size(self): - return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M') + return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M") def _test_database_tblspace_extsize(self): - return self._test_settings_get('DATAFILE_EXTSIZE', default='25M') + return self._test_settings_get("DATAFILE_EXTSIZE", default="25M") 
def _test_database_tblspace_tmp_extsize(self): - return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M') + return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M") def _test_database_oracle_managed_files(self): - return self._test_settings_get('ORACLE_MANAGED_FILES', default=False) + return self._test_settings_get("ORACLE_MANAGED_FILES", default=False) def _get_test_db_name(self): """ @@ -387,14 +447,14 @@ class DatabaseCreation(BaseDatabaseCreation): to work. This isn't a great deal in this case because DB names as handled by Django don't have real counterparts in Oracle. """ - return self.connection.settings_dict['NAME'] + return self.connection.settings_dict["NAME"] def test_db_signature(self): settings_dict = self.connection.settings_dict return ( - settings_dict['HOST'], - settings_dict['PORT'], - settings_dict['ENGINE'], - settings_dict['NAME'], + settings_dict["HOST"], + settings_dict["PORT"], + settings_dict["ENGINE"], + settings_dict["NAME"], self._test_database_user(), ) diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 898a82e5d5..6a3c9dab79 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -65,51 +65,51 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_json_field_contains = False supports_collation_on_textfield = False test_collations = { - 'ci': 'BINARY_CI', - 'cs': 'BINARY', - 'non_default': 'SWEDISH_CI', - 'swedish_ci': 'SWEDISH_CI', + "ci": "BINARY_CI", + "cs": "BINARY", + "non_default": "SWEDISH_CI", + "swedish_ci": "SWEDISH_CI", } test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'" django_test_skips = { "Oracle doesn't support SHA224.": { - 'db_functions.text.test_sha224.SHA224Tests.test_basic', - 'db_functions.text.test_sha224.SHA224Tests.test_transform', + "db_functions.text.test_sha224.SHA224Tests.test_basic", + "db_functions.text.test_sha224.SHA224Tests.test_transform", }, "Oracle doesn't correctly 
calculate ISO 8601 week numbering before " "1583 (the Gregorian calendar was introduced in 1582).": { - 'db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_week_before_1000', - 'db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_week_before_1000', + "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_week_before_1000", + "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_week_before_1000", }, "Oracle doesn't support bitwise XOR.": { - 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor', - 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null', - 'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null', + "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor", + "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null", + "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", }, "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": { - 'expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering', + "expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering", }, - 'Raises ORA-00600: internal error code.': { - 'model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery', + "Raises ORA-00600: internal error code.": { + "model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery", }, } django_test_expected_failures = { # A bug in Django/cx_Oracle with respect to string handling (#23843). 
- 'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions', - 'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions', + "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions", + "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions", } @cached_property def introspected_field_types(self): return { **super().introspected_field_types, - 'GenericIPAddressField': 'CharField', - 'PositiveBigIntegerField': 'BigIntegerField', - 'PositiveIntegerField': 'IntegerField', - 'PositiveSmallIntegerField': 'IntegerField', - 'SmallIntegerField': 'IntegerField', - 'TimeField': 'DateTimeField', + "GenericIPAddressField": "CharField", + "PositiveBigIntegerField": "BigIntegerField", + "PositiveIntegerField": "IntegerField", + "PositiveSmallIntegerField": "IntegerField", + "SmallIntegerField": "IntegerField", + "TimeField": "DateTimeField", } @cached_property diff --git a/django/db/backends/oracle/functions.py b/django/db/backends/oracle/functions.py index 1aeb4597e3..936cc9e73f 100644 --- a/django/db/backends/oracle/functions.py +++ b/django/db/backends/oracle/functions.py @@ -2,7 +2,7 @@ from django.db.models import DecimalField, DurationField, Func class IntervalToSeconds(Func): - function = '' + function = "" template = """ EXTRACT(day from %(expressions)s) * 86400 + EXTRACT(hour from %(expressions)s) * 3600 + @@ -11,12 +11,16 @@ class IntervalToSeconds(Func): """ def __init__(self, expression, *, output_field=None, **extra): - super().__init__(expression, output_field=output_field or DecimalField(), **extra) + super().__init__( + expression, output_field=output_field or DecimalField(), **extra + ) class SecondsToInterval(Func): - function = 'NUMTODSINTERVAL' + function = "NUMTODSINTERVAL" template = "%(function)s(%(expressions)s, 'SECOND')" def __init__(self, expression, *, output_field=None, **extra): - super().__init__(expression, output_field=output_field or 
DurationField(), **extra) + super().__init__( + expression, output_field=output_field or DurationField(), **extra + ) diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index b8882e3cd8..17ffd3a99d 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -3,12 +3,12 @@ from collections import namedtuple import cx_Oracle from django.db import models -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo from django.utils.functional import cached_property -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json')) +FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json")) class DatabaseIntrospection(BaseDatabaseIntrospection): @@ -19,33 +19,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def data_types_reverse(self): if self.connection.cx_oracle_version < (8,): return { - cx_Oracle.BLOB: 'BinaryField', - cx_Oracle.CLOB: 'TextField', - cx_Oracle.DATETIME: 'DateField', - cx_Oracle.FIXED_CHAR: 'CharField', - cx_Oracle.FIXED_NCHAR: 'CharField', - cx_Oracle.INTERVAL: 'DurationField', - cx_Oracle.NATIVE_FLOAT: 'FloatField', - cx_Oracle.NCHAR: 'CharField', - cx_Oracle.NCLOB: 'TextField', - cx_Oracle.NUMBER: 'DecimalField', - cx_Oracle.STRING: 'CharField', - cx_Oracle.TIMESTAMP: 'DateTimeField', + cx_Oracle.BLOB: "BinaryField", + cx_Oracle.CLOB: "TextField", + cx_Oracle.DATETIME: "DateField", + cx_Oracle.FIXED_CHAR: "CharField", + cx_Oracle.FIXED_NCHAR: "CharField", + cx_Oracle.INTERVAL: "DurationField", + cx_Oracle.NATIVE_FLOAT: "FloatField", + cx_Oracle.NCHAR: "CharField", + cx_Oracle.NCLOB: "TextField", + 
cx_Oracle.NUMBER: "DecimalField", + cx_Oracle.STRING: "CharField", + cx_Oracle.TIMESTAMP: "DateTimeField", } else: return { - cx_Oracle.DB_TYPE_DATE: 'DateField', - cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'FloatField', - cx_Oracle.DB_TYPE_BLOB: 'BinaryField', - cx_Oracle.DB_TYPE_CHAR: 'CharField', - cx_Oracle.DB_TYPE_CLOB: 'TextField', - cx_Oracle.DB_TYPE_INTERVAL_DS: 'DurationField', - cx_Oracle.DB_TYPE_NCHAR: 'CharField', - cx_Oracle.DB_TYPE_NCLOB: 'TextField', - cx_Oracle.DB_TYPE_NVARCHAR: 'CharField', - cx_Oracle.DB_TYPE_NUMBER: 'DecimalField', - cx_Oracle.DB_TYPE_TIMESTAMP: 'DateTimeField', - cx_Oracle.DB_TYPE_VARCHAR: 'CharField', + cx_Oracle.DB_TYPE_DATE: "DateField", + cx_Oracle.DB_TYPE_BINARY_DOUBLE: "FloatField", + cx_Oracle.DB_TYPE_BLOB: "BinaryField", + cx_Oracle.DB_TYPE_CHAR: "CharField", + cx_Oracle.DB_TYPE_CLOB: "TextField", + cx_Oracle.DB_TYPE_INTERVAL_DS: "DurationField", + cx_Oracle.DB_TYPE_NCHAR: "CharField", + cx_Oracle.DB_TYPE_NCLOB: "TextField", + cx_Oracle.DB_TYPE_NVARCHAR: "CharField", + cx_Oracle.DB_TYPE_NUMBER: "DecimalField", + cx_Oracle.DB_TYPE_TIMESTAMP: "DateTimeField", + cx_Oracle.DB_TYPE_VARCHAR: "CharField", } def get_field_type(self, data_type, description): @@ -53,25 +53,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): precision, scale = description[4:6] if scale == 0: if precision > 11: - return 'BigAutoField' if description.is_autofield else 'BigIntegerField' + return ( + "BigAutoField" + if description.is_autofield + else "BigIntegerField" + ) elif 1 < precision < 6 and description.is_autofield: - return 'SmallAutoField' + return "SmallAutoField" elif precision == 1: - return 'BooleanField' + return "BooleanField" elif description.is_autofield: - return 'AutoField' + return "AutoField" else: - return 'IntegerField' + return "IntegerField" elif scale == -127: - return 'FloatField' + return "FloatField" elif data_type == cx_Oracle.NCLOB and description.is_json: - return 'JSONField' + return "JSONField" return 
super().get_field_type(data_type, description) def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" - cursor.execute(""" + cursor.execute( + """ SELECT table_name, 't' FROM user_tables WHERE @@ -84,8 +89,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): SELECT view_name, 'v' FROM user_views UNION ALL SELECT mview_name, 'v' FROM user_mviews - """) - return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()] + """ + ) + return [ + TableInfo(self.identifier_converter(row[0]), row[1]) + for row in cursor.fetchall() + ] def get_table_description(self, cursor, table_name): """ @@ -131,22 +140,40 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): [table_name], ) field_map = { - column: (internal_size, default if default != 'NULL' else None, collation, is_autofield, is_json) + column: ( + internal_size, + default if default != "NULL" else None, + collation, + is_autofield, + is_json, + ) for column, default, collation, internal_size, is_autofield, is_json in cursor.fetchall() } self.cache_bust_counter += 1 - cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format( - self.connection.ops.quote_name(table_name), - self.cache_bust_counter)) + cursor.execute( + "SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format( + self.connection.ops.quote_name(table_name), self.cache_bust_counter + ) + ) description = [] for desc in cursor.description: name = desc[0] internal_size, default, collation, is_autofield, is_json = field_map[name] name = name % {} # cx_Oracle, for some reason, doubles percent signs. 
- description.append(FieldInfo( - self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0, - desc[5] or 0, *desc[6:], default, collation, is_autofield, is_json, - )) + description.append( + FieldInfo( + self.identifier_converter(name), + *desc[1:3], + internal_size, + desc[4] or 0, + desc[5] or 0, + *desc[6:], + default, + collation, + is_autofield, + is_json, + ) + ) return description def identifier_converter(self, name): @@ -175,16 +202,18 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Oracle allows only one identity column per table. row = cursor.fetchone() if row: - return [{ - 'name': self.identifier_converter(row[0]), - 'table': self.identifier_converter(table_name), - 'column': self.identifier_converter(row[1]), - }] + return [ + { + "name": self.identifier_converter(row[0]), + "table": self.identifier_converter(table_name), + "column": self.identifier_converter(row[1]), + } + ] # To keep backward compatibility for AutoFields that aren't Oracle # identity columns. for f in table_fields: if isinstance(f, models.AutoField): - return [{'table': table_name, 'column': f.column}] + return [{"table": table_name, "column": f.column}] return [] def get_relations(self, cursor, table_name): @@ -193,19 +222,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): representing all foreign keys in the given table. 
""" table_name = table_name.upper() - cursor.execute(""" + cursor.execute( + """ SELECT ca.column_name, cb.table_name, cb.column_name FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb WHERE user_constraints.table_name = %s AND user_constraints.constraint_name = ca.constraint_name AND user_constraints.r_constraint_name = cb.constraint_name AND - ca.position = cb.position""", [table_name]) + ca.position = cb.position""", + [table_name], + ) return { self.identifier_converter(field_name): ( self.identifier_converter(rel_field_name), self.identifier_converter(rel_table_name), - ) for field_name, rel_table_name, rel_field_name in cursor.fetchall() + ) + for field_name, rel_table_name, rel_field_name in cursor.fetchall() } def get_primary_key_column(self, cursor, table_name): @@ -265,12 +298,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): for constraint, columns, pk, unique, check in cursor.fetchall(): constraint = self.identifier_converter(constraint) constraints[constraint] = { - 'columns': columns.split(','), - 'primary_key': pk, - 'unique': unique, - 'foreign_key': None, - 'check': check, - 'index': unique, # All uniques come with an index + "columns": columns.split(","), + "primary_key": pk, + "unique": unique, + "foreign_key": None, + "check": check, + "index": unique, # All uniques come with an index } # Foreign key constraints cursor.execute( @@ -296,12 +329,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): for constraint, columns, other_table, other_column in cursor.fetchall(): constraint = self.identifier_converter(constraint) constraints[constraint] = { - 'primary_key': False, - 'unique': False, - 'foreign_key': (other_table, other_column), - 'check': False, - 'index': False, - 'columns': columns.split(','), + "primary_key": False, + "unique": False, + "foreign_key": (other_table, other_column), + "check": False, + "index": False, + "columns": columns.split(","), } # Now get indexes cursor.execute( @@ -328,13 +361,13 
@@ class DatabaseIntrospection(BaseDatabaseIntrospection): for constraint, type_, unique, columns, orders in cursor.fetchall(): constraint = self.identifier_converter(constraint) constraints[constraint] = { - 'primary_key': False, - 'unique': unique == 'unique', - 'foreign_key': None, - 'check': False, - 'index': True, - 'type': 'idx' if type_ == 'normal' else type_, - 'columns': columns.split(','), - 'orders': orders.split(','), + "primary_key": False, + "unique": unique == "unique", + "foreign_key": None, + "check": False, + "index": True, + "type": "idx" if type_ == "normal" else type_, + "columns": columns.split(","), + "orders": orders.split(","), } return constraints diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index d53942b919..63f9714333 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -5,9 +5,7 @@ from functools import lru_cache from django.conf import settings from django.db import DatabaseError, NotSupportedError from django.db.backends.base.operations import BaseDatabaseOperations -from django.db.backends.utils import ( - split_tzname_delta, strip_quotes, truncate_name, -) +from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup from django.db.models.expressions import RawSQL from django.db.models.sql.where import WhereNode @@ -25,17 +23,17 @@ class DatabaseOperations(BaseDatabaseOperations): # SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by # SmallAutoField, to preserve backward compatibility. 
integer_field_ranges = { - 'SmallIntegerField': (-99999999999, 99999999999), - 'IntegerField': (-99999999999, 99999999999), - 'BigIntegerField': (-9999999999999999999, 9999999999999999999), - 'PositiveBigIntegerField': (0, 9999999999999999999), - 'PositiveSmallIntegerField': (0, 99999999999), - 'PositiveIntegerField': (0, 99999999999), - 'SmallAutoField': (-99999, 99999), - 'AutoField': (-99999999999, 99999999999), - 'BigAutoField': (-9999999999999999999, 9999999999999999999), + "SmallIntegerField": (-99999999999, 99999999999), + "IntegerField": (-99999999999, 99999999999), + "BigIntegerField": (-9999999999999999999, 9999999999999999999), + "PositiveBigIntegerField": (0, 9999999999999999999), + "PositiveSmallIntegerField": (0, 99999999999), + "PositiveIntegerField": (0, 99999999999), + "SmallAutoField": (-99999, 99999), + "AutoField": (-99999999999, 99999999999), + "BigAutoField": (-9999999999999999999, 9999999999999999999), } - set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'} + set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"} # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc. _sequence_reset_sql = """ @@ -63,34 +61,34 @@ END; /""" # Oracle doesn't support string without precision; use the max string size. 
- cast_char_field_without_max_length = 'NVARCHAR2(2000)' + cast_char_field_without_max_length = "NVARCHAR2(2000)" cast_data_types = { - 'AutoField': 'NUMBER(11)', - 'BigAutoField': 'NUMBER(19)', - 'SmallAutoField': 'NUMBER(5)', - 'TextField': cast_char_field_without_max_length, + "AutoField": "NUMBER(11)", + "BigAutoField": "NUMBER(19)", + "SmallAutoField": "NUMBER(5)", + "TextField": cast_char_field_without_max_length, } def cache_key_culling_sql(self): - cache_key = self.quote_name('cache_key') + cache_key = self.quote_name("cache_key") return ( - f'SELECT {cache_key} ' - f'FROM %s ' - f'ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY' + f"SELECT {cache_key} " + f"FROM %s " + f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY" ) def date_extract_sql(self, lookup_type, field_name): - if lookup_type == 'week_day': + if lookup_type == "week_day": # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday. return "TO_CHAR(%s, 'D')" % field_name - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return "TO_CHAR(%s - 1, 'D')" % field_name - elif lookup_type == 'week': + elif lookup_type == "week": # IW = ISO week number return "TO_CHAR(%s, 'IW')" % field_name - elif lookup_type == 'quarter': + elif lookup_type == "quarter": return "TO_CHAR(%s, 'Q')" % field_name - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return "TO_CHAR(%s, 'IYYY')" % field_name else: # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html @@ -99,11 +97,11 @@ END; def date_trunc_sql(self, lookup_type, field_name, tzname=None): field_name = self._convert_field_to_tz(field_name, tzname) # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html - if lookup_type in ('year', 'month'): + if lookup_type in ("year", "month"): return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) - elif lookup_type == 'quarter': + elif lookup_type == 
"quarter": return "TRUNC(%s, 'Q')" % field_name - elif lookup_type == 'week': + elif lookup_type == "week": return "TRUNC(%s, 'IW')" % field_name else: return "TRUNC(%s)" % field_name @@ -112,11 +110,11 @@ END; # if the time zone name is passed in parameter. Use interpolation instead. # https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ # This regexp matches all time zone names from the zoneinfo database. - _tzname_re = _lazy_re_compile(r'^[\w/:+-]+$') + _tzname_re = _lazy_re_compile(r"^[\w/:+-]+$") def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) - return f'{sign}{offset}' if offset else tzname + return f"{sign}{offset}" if offset else tzname def _convert_field_to_tz(self, field_name, tzname): if not (settings.USE_TZ and tzname): @@ -136,7 +134,7 @@ END; def datetime_cast_date_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return 'TRUNC(%s)' % field_name + return "TRUNC(%s)" % field_name def datetime_cast_time_sql(self, field_name, tzname): # Since `TimeField` values are stored as TIMESTAMP change to the @@ -146,7 +144,8 @@ END; "'YYYY-MM-DD HH24:MI:SS.FF')" ) % self._convert_field_to_tz(field_name, tzname) return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % ( - field_name, convert_datetime_sql, + field_name, + convert_datetime_sql, ) def datetime_extract_sql(self, lookup_type, field_name, tzname): @@ -156,20 +155,22 @@ END; def datetime_trunc_sql(self, lookup_type, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html - if lookup_type in ('year', 'month'): + if lookup_type in ("year", "month"): sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) - elif lookup_type == 'quarter': + elif lookup_type == "quarter": sql = "TRUNC(%s, 'Q')" % field_name - elif lookup_type == 'week': + elif lookup_type == "week": 
sql = "TRUNC(%s, 'IW')" % field_name - elif lookup_type == 'day': + elif lookup_type == "day": sql = "TRUNC(%s)" % field_name - elif lookup_type == 'hour': + elif lookup_type == "hour": sql = "TRUNC(%s, 'HH24')" % field_name - elif lookup_type == 'minute': + elif lookup_type == "minute": sql = "TRUNC(%s, 'MI')" % field_name else: - sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision. + sql = ( + "CAST(%s AS DATE)" % field_name + ) # Cast to DATE removes sub-second precision. return sql def time_trunc_sql(self, lookup_type, field_name, tzname=None): @@ -177,31 +178,33 @@ END; # `DateTimeField` and `TimeField` are stored as TIMESTAMP where # the date part of the later is ignored. field_name = self._convert_field_to_tz(field_name, tzname) - if lookup_type == 'hour': + if lookup_type == "hour": sql = "TRUNC(%s, 'HH24')" % field_name - elif lookup_type == 'minute': + elif lookup_type == "minute": sql = "TRUNC(%s, 'MI')" % field_name - elif lookup_type == 'second': - sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision. + elif lookup_type == "second": + sql = ( + "CAST(%s AS DATE)" % field_name + ) # Cast to DATE removes sub-second precision. 
return sql def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type in ['JSONField', 'TextField']: + if internal_type in ["JSONField", "TextField"]: converters.append(self.convert_textfield_value) - elif internal_type == 'BinaryField': + elif internal_type == "BinaryField": converters.append(self.convert_binaryfield_value) - elif internal_type == 'BooleanField': + elif internal_type == "BooleanField": converters.append(self.convert_booleanfield_value) - elif internal_type == 'DateTimeField': + elif internal_type == "DateTimeField": if settings.USE_TZ: converters.append(self.convert_datetimefield_value) - elif internal_type == 'DateField': + elif internal_type == "DateField": converters.append(self.convert_datefield_value) - elif internal_type == 'TimeField': + elif internal_type == "TimeField": converters.append(self.convert_timefield_value) - elif internal_type == 'UUIDField': + elif internal_type == "UUIDField": converters.append(self.convert_uuidfield_value) # Oracle stores empty strings as null. 
If the field accepts the empty # string, undo this to adhere to the Django convention of using @@ -209,8 +212,8 @@ END; if expression.output_field.empty_strings_allowed: converters.append( self.convert_empty_bytes - if internal_type == 'BinaryField' else - self.convert_empty_string + if internal_type == "BinaryField" + else self.convert_empty_string ) return converters @@ -255,11 +258,11 @@ END; @staticmethod def convert_empty_string(value, expression, connection): - return '' if value is None else value + return "" if value is None else value @staticmethod def convert_empty_bytes(value, expression, connection): - return b'' if value is None else value + return b"" if value is None else value def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" @@ -270,16 +273,16 @@ END; value = param.get_value() if value == []: raise DatabaseError( - 'The database did not return a new row id. Probably ' + "The database did not return a new row id. Probably " '"ORA-1403: no data found" was raised internally but was ' - 'hidden by the Oracle OCI library (see ' - 'https://code.djangoproject.com/ticket/28859).' + "hidden by the Oracle OCI library (see " + "https://code.djangoproject.com/ticket/28859)." 
) columns.append(value[0]) return tuple(columns) def field_cast_sql(self, db_type, internal_type): - if db_type and db_type.endswith('LOB') and internal_type != 'JSONField': + if db_type and db_type.endswith("LOB") and internal_type != "JSONField": return "DBMS_LOB.SUBSTR(%s)" else: return "%s" @@ -289,10 +292,14 @@ END; def limit_offset_sql(self, low_mark, high_mark): fetch, offset = self._get_limit_offset_params(low_mark, high_mark) - return ' '.join(sql for sql in ( - ('OFFSET %d ROWS' % offset) if offset else None, - ('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None, - ) if sql) + return " ".join( + sql + for sql in ( + ("OFFSET %d ROWS" % offset) if offset else None, + ("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None, + ) + if sql + ) def last_executed_query(self, cursor, sql, params): # https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.statement @@ -303,10 +310,14 @@ END; # parameters manually. if isinstance(params, (tuple, list)): for i, param in enumerate(params): - statement = statement.replace(':arg%d' % i, force_str(param, errors='replace')) + statement = statement.replace( + ":arg%d" % i, force_str(param, errors="replace") + ) elif isinstance(params, dict): for key, param in params.items(): - statement = statement.replace(':%s' % key, force_str(param, errors='replace')) + statement = statement.replace( + ":%s" % key, force_str(param, errors="replace") + ) return statement def last_insert_id(self, cursor, table_name, pk_name): @@ -315,10 +326,10 @@ END; return cursor.fetchone()[0] def lookup_cast(self, lookup_type, internal_type=None): - if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): + if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): return "UPPER(%s)" - if internal_type == 'JSONField' and lookup_type == 'exact': - return 'DBMS_LOB.SUBSTR(%s)' + if internal_type == "JSONField" and lookup_type == "exact": + return "DBMS_LOB.SUBSTR(%s)" return "%s" def max_in_list_size(self): 
@@ -335,7 +346,7 @@ END; def process_clob(self, value): if value is None: - return '' + return "" return value.read() def quote_name(self, name): @@ -348,30 +359,33 @@ END; # Oracle puts the query text into a (query % args) construct, so % signs # in names need to be escaped. The '%%' will be collapsed back to '%' at # that stage so we aren't really making the name longer here. - name = name.replace('%', '%%') + name = name.replace("%", "%%") return name.upper() def regex_lookup(self, lookup_type): - if lookup_type == 'regex': + if lookup_type == "regex": match_option = "'c'" else: match_option = "'i'" - return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option + return "REGEXP_LIKE(%%s, %%s, %s)" % match_option def return_insert_columns(self, fields): if not fields: - return '', () + return "", () field_names = [] params = [] for field in fields: - field_names.append('%s.%s' % ( - self.quote_name(field.model._meta.db_table), - self.quote_name(field.column), - )) + field_names.append( + "%s.%s" + % ( + self.quote_name(field.model._meta.db_table), + self.quote_name(field.column), + ) + ) params.append(InsertVar(field)) - return 'RETURNING %s INTO %s' % ( - ', '.join(field_names), - ', '.join(['%s'] * len(params)), + return "RETURNING %s INTO %s" % ( + ", ".join(field_names), + ", ".join(["%s"] * len(params)), ), tuple(params) def __foreign_key_constraints(self, table_name, recursive): @@ -430,42 +444,54 @@ END; # which truncates all dependent tables by manually retrieving all # foreign key constraints and resolving dependencies. 
for table in tables: - for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade): + for foreign_table, constraint in self._foreign_key_constraints( + table, recursive=allow_cascade + ): if allow_cascade: truncated_tables.add(foreign_table) constraints.add((foreign_table, constraint)) - sql = [ - '%s %s %s %s %s %s %s %s;' % ( - style.SQL_KEYWORD('ALTER'), - style.SQL_KEYWORD('TABLE'), - style.SQL_FIELD(self.quote_name(table)), - style.SQL_KEYWORD('DISABLE'), - style.SQL_KEYWORD('CONSTRAINT'), - style.SQL_FIELD(self.quote_name(constraint)), - style.SQL_KEYWORD('KEEP'), - style.SQL_KEYWORD('INDEX'), - ) for table, constraint in constraints - ] + [ - '%s %s %s;' % ( - style.SQL_KEYWORD('TRUNCATE'), - style.SQL_KEYWORD('TABLE'), - style.SQL_FIELD(self.quote_name(table)), - ) for table in truncated_tables - ] + [ - '%s %s %s %s %s %s;' % ( - style.SQL_KEYWORD('ALTER'), - style.SQL_KEYWORD('TABLE'), - style.SQL_FIELD(self.quote_name(table)), - style.SQL_KEYWORD('ENABLE'), - style.SQL_KEYWORD('CONSTRAINT'), - style.SQL_FIELD(self.quote_name(constraint)), - ) for table, constraint in constraints - ] + sql = ( + [ + "%s %s %s %s %s %s %s %s;" + % ( + style.SQL_KEYWORD("ALTER"), + style.SQL_KEYWORD("TABLE"), + style.SQL_FIELD(self.quote_name(table)), + style.SQL_KEYWORD("DISABLE"), + style.SQL_KEYWORD("CONSTRAINT"), + style.SQL_FIELD(self.quote_name(constraint)), + style.SQL_KEYWORD("KEEP"), + style.SQL_KEYWORD("INDEX"), + ) + for table, constraint in constraints + ] + + [ + "%s %s %s;" + % ( + style.SQL_KEYWORD("TRUNCATE"), + style.SQL_KEYWORD("TABLE"), + style.SQL_FIELD(self.quote_name(table)), + ) + for table in truncated_tables + ] + + [ + "%s %s %s %s %s %s;" + % ( + style.SQL_KEYWORD("ALTER"), + style.SQL_KEYWORD("TABLE"), + style.SQL_FIELD(self.quote_name(table)), + style.SQL_KEYWORD("ENABLE"), + style.SQL_KEYWORD("CONSTRAINT"), + style.SQL_FIELD(self.quote_name(constraint)), + ) + for table, constraint in constraints + ] + ) if 
reset_sequences: sequences = [ sequence for sequence in self.connection.introspection.sequence_list() - if sequence['table'].upper() in truncated_tables + if sequence["table"].upper() in truncated_tables ] # Since we've just deleted all the rows, running our sequence ALTER # code will reset the sequence to 0. @@ -475,15 +501,17 @@ END; def sequence_reset_by_name_sql(self, style, sequences): sql = [] for sequence_info in sequences: - no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table']) - table = self.quote_name(sequence_info['table']) - column = self.quote_name(sequence_info['column'] or 'id') + no_autofield_sequence_name = self._get_no_autofield_sequence_name( + sequence_info["table"] + ) + table = self.quote_name(sequence_info["table"]) + column = self.quote_name(sequence_info["column"] or "id") query = self._sequence_reset_sql % { - 'no_autofield_sequence_name': no_autofield_sequence_name, - 'table': table, - 'column': column, - 'table_name': strip_quotes(table), - 'column_name': strip_quotes(column), + "no_autofield_sequence_name": no_autofield_sequence_name, + "table": table, + "column": column, + "table_name": strip_quotes(table), + "column_name": strip_quotes(column), } sql.append(query) return sql @@ -494,23 +522,28 @@ END; for model in model_list: for f in model._meta.local_fields: if isinstance(f, AutoField): - no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table) + no_autofield_sequence_name = self._get_no_autofield_sequence_name( + model._meta.db_table + ) table = self.quote_name(model._meta.db_table) column = self.quote_name(f.column) - output.append(query % { - 'no_autofield_sequence_name': no_autofield_sequence_name, - 'table': table, - 'column': column, - 'table_name': strip_quotes(table), - 'column_name': strip_quotes(column), - }) + output.append( + query + % { + "no_autofield_sequence_name": no_autofield_sequence_name, + "table": table, + "column": column, + "table_name": 
strip_quotes(table), + "column_name": strip_quotes(column), + } + ) # Only one AutoField is allowed per model, so don't # continue to loop break return output def start_transaction_sql(self): - return '' + return "" def tablespace_sql(self, tablespace, inline=False): if inline: @@ -541,7 +574,7 @@ END; return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # cx_Oracle doesn't support tz-aware datetimes @@ -549,7 +582,9 @@ END; if settings.USE_TZ: value = timezone.make_naive(value, self.connection.timezone) else: - raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.") + raise ValueError( + "Oracle backend does not support timezone-aware datetimes when USE_TZ is False." + ) return Oracle_datetime.from_datetime(value) @@ -558,38 +593,39 @@ END; return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value if isinstance(value, str): - return datetime.datetime.strptime(value, '%H:%M:%S') + return datetime.datetime.strptime(value, "%H:%M:%S") # Oracle doesn't support tz-aware times if timezone.is_aware(value): raise ValueError("Oracle backend does not support timezone-aware times.") - return Oracle_datetime(1900, 1, 1, value.hour, value.minute, - value.second, value.microsecond) + return Oracle_datetime( + 1900, 1, 1, value.hour, value.minute, value.second, value.microsecond + ) def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): return value def combine_expression(self, connector, sub_expressions): lhs, rhs = sub_expressions - if connector == '%%': - return 'MOD(%s)' % ','.join(sub_expressions) - elif connector == '&': - return 'BITAND(%s)' % ','.join(sub_expressions) - elif connector == '|': - return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs} - elif connector == '<<': - return 
'(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs} - elif connector == '>>': - return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs} - elif connector == '^': - return 'POWER(%s)' % ','.join(sub_expressions) - elif connector == '#': - raise NotSupportedError('Bitwise XOR is not supported in Oracle.') + if connector == "%%": + return "MOD(%s)" % ",".join(sub_expressions) + elif connector == "&": + return "BITAND(%s)" % ",".join(sub_expressions) + elif connector == "|": + return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs} + elif connector == "<<": + return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} + elif connector == ">>": + return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} + elif connector == "^": + return "POWER(%s)" % ",".join(sub_expressions) + elif connector == "#": + raise NotSupportedError("Bitwise XOR is not supported in Oracle.") return super().combine_expression(connector, sub_expressions) def _get_no_autofield_sequence_name(self, table): @@ -598,14 +634,17 @@ END; AutoFields that aren't Oracle identity columns. """ name_length = self.max_name_length() - 3 - return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper() + return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper() def _get_sequence_name(self, cursor, table, pk_name): - cursor.execute(""" + cursor.execute( + """ SELECT sequence_name FROM user_tab_identity_cols WHERE table_name = UPPER(%s) - AND column_name = UPPER(%s)""", [table, pk_name]) + AND column_name = UPPER(%s)""", + [table, pk_name], + ) row = cursor.fetchone() return self._get_no_autofield_sequence_name(table) if row is None else row[0] @@ -616,26 +655,33 @@ END; for i, placeholder in enumerate(row): # A model without any fields has fields=[None]. 
if fields[i]: - internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type() - placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder + internal_type = getattr( + fields[i], "target_field", fields[i] + ).get_internal_type() + placeholder = ( + BulkInsertMapper.types.get(internal_type, "%s") % placeholder + ) # Add columns aliases to the first select to avoid "ORA-00918: # column ambiguously defined" when two or more columns in the # first select have the same value. if not query: - placeholder = '%s col_%s' % (placeholder, i) + placeholder = "%s col_%s" % (placeholder, i) select.append(placeholder) - query.append('SELECT %s FROM DUAL' % ', '.join(select)) + query.append("SELECT %s FROM DUAL" % ", ".join(select)) # Bulk insert to tables with Oracle identity columns causes Oracle to # add sequence.nextval to it. Sequence.nextval cannot be used with the # UNION operator. To prevent incorrect SQL, move UNION to a subquery. - return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query) + return "SELECT * FROM (%s)" % " UNION ALL ".join(query) def subtract_temporals(self, internal_type, lhs, rhs): - if internal_type == 'DateField': + if internal_type == "DateField": lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) - return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params + return ( + "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), + params, + ) return super().subtract_temporals(internal_type, lhs, rhs) def bulk_batch_size(self, fields, objs): @@ -652,7 +698,9 @@ END; if isinstance(expression, (Exists, Lookup, WhereNode)): return True if isinstance(expression, ExpressionWrapper) and expression.conditional: - return self.conditional_expression_supported_in_where_clause(expression.expression) + return self.conditional_expression_supported_in_where_clause( + expression.expression + ) if isinstance(expression, RawSQL) and expression.conditional: return True 
return False diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py index 98e49413c9..2b1027d6b5 100644 --- a/django/db/backends/oracle/schema.py +++ b/django/db/backends/oracle/schema.py @@ -4,7 +4,8 @@ import re from django.db import DatabaseError from django.db.backends.base.schema import ( - BaseDatabaseSchemaEditor, _related_non_m2m_objects, + BaseDatabaseSchemaEditor, + _related_non_m2m_objects, ) from django.utils.duration import duration_iso_string @@ -21,7 +22,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s" sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" - sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s' + sql_create_column_inline_fk = ( + "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" + ) sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS" sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s" @@ -31,7 +34,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): elif isinstance(value, datetime.timedelta): return "'%s'" % duration_iso_string(value) elif isinstance(value, str): - return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%') + return "'%s'" % value.replace("'", "''").replace("%", "%%") elif isinstance(value, (bytes, bytearray, memoryview)): return "'%s'" % value.hex() elif isinstance(value, bool): @@ -50,7 +53,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Run superclass action super().delete_model(model) # Clean up manually created sequence. 
- self.execute(""" + self.execute( + """ DECLARE i INTEGER; BEGIN @@ -60,7 +64,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"'; END IF; END; - /""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)}) + /""" + % { + "sq_name": self.connection.ops._get_no_autofield_sequence_name( + model._meta.db_table + ) + } + ) def alter_field(self, model, old_field, new_field, strict=False): try: @@ -69,16 +79,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): description = str(e) # If we're changing type to an unsupported type we need a # SQLite-ish workaround - if 'ORA-22858' in description or 'ORA-22859' in description: + if "ORA-22858" in description or "ORA-22859" in description: self._alter_field_type_workaround(model, old_field, new_field) # If an identity column is changing to a non-numeric type, drop the # identity first. - elif 'ORA-30675' in description: + elif "ORA-30675" in description: self._drop_identity(model._meta.db_table, old_field.column) self.alter_field(model, old_field, new_field, strict) # If a primary key column is changing to an identity column, drop # the primary key first. - elif 'ORA-30673' in description and old_field.primary_key: + elif "ORA-30673" in description and old_field.primary_key: self._delete_primary_key(model, strict=True) self._alter_field_type_workaround(model, old_field, new_field) else: @@ -98,7 +108,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Make a new field that's like the new one but with a temporary # column name. 
new_temp_field = copy.deepcopy(new_field) - new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField')) + new_temp_field.null = new_field.get_internal_type() not in ( + "AutoField", + "BigAutoField", + "SmallAutoField", + ) new_temp_field.column = self._generate_temp_name(new_field.column) # Add it self.add_field(model, new_temp_field) @@ -107,24 +121,30 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD new_value = self.quote_name(old_field.column) old_type = old_field.db_type(self.connection) - if re.match('^N?CLOB', old_type): + if re.match("^N?CLOB", old_type): new_value = "TO_CHAR(%s)" % new_value - old_type = 'VARCHAR2' - if re.match('^N?VARCHAR2', old_type): + old_type = "VARCHAR2" + if re.match("^N?VARCHAR2", old_type): new_internal_type = new_field.get_internal_type() - if new_internal_type == 'DateField': + if new_internal_type == "DateField": new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value - elif new_internal_type == 'DateTimeField': + elif new_internal_type == "DateTimeField": new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value - elif new_internal_type == 'TimeField': + elif new_internal_type == "TimeField": # TimeField are stored as TIMESTAMP with a 1900-01-01 date part. 
- new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value + new_value = ( + "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" + % new_value + ) # Transfer values across - self.execute("UPDATE %s set %s=%s" % ( - self.quote_name(model._meta.db_table), - self.quote_name(new_temp_field.column), - new_value, - )) + self.execute( + "UPDATE %s set %s=%s" + % ( + self.quote_name(model._meta.db_table), + self.quote_name(new_temp_field.column), + new_value, + ) + ) # Drop the old field self.remove_field(model, old_field) # Rename and possibly make the new field NOT NULL @@ -134,20 +154,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # new_field always match. new_type = new_field.db_type(self.connection) if ( - (old_field.primary_key and new_field.primary_key) or - (old_field.unique and new_field.unique) + (old_field.primary_key and new_field.primary_key) + or (old_field.unique and new_field.unique) ) and old_type != new_type: for _, rel in _related_non_m2m_objects(new_temp_field, new_field): if rel.field.db_constraint: - self.execute(self._create_fk_sql(rel.related_model, rel.field, '_fk')) + self.execute( + self._create_fk_sql(rel.related_model, rel.field, "_fk") + ) def _alter_column_type_sql(self, model, old_field, new_field, new_type): - auto_field_types = {'AutoField', 'BigAutoField', 'SmallAutoField'} + auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"} # Drop the identity if migrating away from AutoField. 
if ( - old_field.get_internal_type() in auto_field_types and - new_field.get_internal_type() not in auto_field_types and - self._is_identity_column(model._meta.db_table, new_field.column) + old_field.get_internal_type() in auto_field_types + and new_field.get_internal_type() not in auto_field_types + and self._is_identity_column(model._meta.db_table, new_field.column) ): self._drop_identity(model._meta.db_table, new_field.column) return super()._alter_column_type_sql(model, old_field, new_field, new_type) @@ -173,7 +195,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _field_should_be_indexed(self, model, field): create_index = super()._field_should_be_indexed(model, field) db_type = field.db_type(self.connection) - if db_type is not None and db_type.lower() in self.connection._limited_data_types: + if ( + db_type is not None + and db_type.lower() in self.connection._limited_data_types + ): return False return create_index @@ -193,10 +218,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): return row[0] if row else False def _drop_identity(self, table_name, column_name): - self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % { - 'table': self.quote_name(table_name), - 'column': self.quote_name(column_name), - }) + self.execute( + "ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY" + % { + "table": self.quote_name(table_name), + "column": self.quote_name(column_name), + } + ) def _get_default_collation(self, table_name): with self.connection.cursor() as cursor: @@ -211,4 +239,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _alter_column_collation_sql(self, model, new_field, new_type, new_collation): if new_collation is None: new_collation = self._get_default_collation(model._meta.db_table) - return super()._alter_column_collation_sql(model, new_field, new_type, new_collation) + return super()._alter_column_collation_sql( + model, new_field, new_type, new_collation + ) diff --git 
a/django/db/backends/oracle/utils.py b/django/db/backends/oracle/utils.py index e3786541af..8941a85967 100644 --- a/django/db/backends/oracle/utils.py +++ b/django/db/backends/oracle/utils.py @@ -9,24 +9,25 @@ class InsertVar: as a parameter, in order to receive the id of the row created by an insert statement. """ + types = { - 'AutoField': int, - 'BigAutoField': int, - 'SmallAutoField': int, - 'IntegerField': int, - 'BigIntegerField': int, - 'SmallIntegerField': int, - 'PositiveBigIntegerField': int, - 'PositiveSmallIntegerField': int, - 'PositiveIntegerField': int, - 'FloatField': Database.NATIVE_FLOAT, - 'DateTimeField': Database.TIMESTAMP, - 'DateField': Database.Date, - 'DecimalField': Database.NUMBER, + "AutoField": int, + "BigAutoField": int, + "SmallAutoField": int, + "IntegerField": int, + "BigIntegerField": int, + "SmallIntegerField": int, + "PositiveBigIntegerField": int, + "PositiveSmallIntegerField": int, + "PositiveIntegerField": int, + "FloatField": Database.NATIVE_FLOAT, + "DateTimeField": Database.TIMESTAMP, + "DateField": Database.Date, + "DecimalField": Database.NUMBER, } def __init__(self, field): - internal_type = getattr(field, 'target_field', field).get_internal_type() + internal_type = getattr(field, "target_field", field).get_internal_type() self.db_type = self.types.get(internal_type, str) self.bound_param = None @@ -43,48 +44,54 @@ class Oracle_datetime(datetime.datetime): A datetime object, with an additional class attribute to tell cx_Oracle to save the microseconds too. 
""" + input_size = Database.TIMESTAMP @classmethod def from_datetime(cls, dt): return Oracle_datetime( - dt.year, dt.month, dt.day, - dt.hour, dt.minute, dt.second, dt.microsecond, + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + dt.microsecond, ) class BulkInsertMapper: - BLOB = 'TO_BLOB(%s)' - DATE = 'TO_DATE(%s)' - INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))' - NCLOB = 'TO_NCLOB(%s)' - NUMBER = 'TO_NUMBER(%s)' - TIMESTAMP = 'TO_TIMESTAMP(%s)' + BLOB = "TO_BLOB(%s)" + DATE = "TO_DATE(%s)" + INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))" + NCLOB = "TO_NCLOB(%s)" + NUMBER = "TO_NUMBER(%s)" + TIMESTAMP = "TO_TIMESTAMP(%s)" types = { - 'AutoField': NUMBER, - 'BigAutoField': NUMBER, - 'BigIntegerField': NUMBER, - 'BinaryField': BLOB, - 'BooleanField': NUMBER, - 'DateField': DATE, - 'DateTimeField': TIMESTAMP, - 'DecimalField': NUMBER, - 'DurationField': INTERVAL, - 'FloatField': NUMBER, - 'IntegerField': NUMBER, - 'PositiveBigIntegerField': NUMBER, - 'PositiveIntegerField': NUMBER, - 'PositiveSmallIntegerField': NUMBER, - 'SmallAutoField': NUMBER, - 'SmallIntegerField': NUMBER, - 'TextField': NCLOB, - 'TimeField': TIMESTAMP, + "AutoField": NUMBER, + "BigAutoField": NUMBER, + "BigIntegerField": NUMBER, + "BinaryField": BLOB, + "BooleanField": NUMBER, + "DateField": DATE, + "DateTimeField": TIMESTAMP, + "DecimalField": NUMBER, + "DurationField": INTERVAL, + "FloatField": NUMBER, + "IntegerField": NUMBER, + "PositiveBigIntegerField": NUMBER, + "PositiveIntegerField": NUMBER, + "PositiveSmallIntegerField": NUMBER, + "SmallAutoField": NUMBER, + "SmallIntegerField": NUMBER, + "TextField": NCLOB, + "TimeField": TIMESTAMP, } def dsn(settings_dict): - if settings_dict['PORT']: - host = settings_dict['HOST'].strip() or 'localhost' - return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME']) - return settings_dict['NAME'] + if settings_dict["PORT"]: + host = settings_dict["HOST"].strip() or "localhost" + return 
Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"]) + return settings_dict["NAME"] diff --git a/django/db/backends/oracle/validation.py b/django/db/backends/oracle/validation.py index e5a35fd3ca..4035b12085 100644 --- a/django/db/backends/oracle/validation.py +++ b/django/db/backends/oracle/validation.py @@ -9,14 +9,14 @@ class DatabaseValidation(BaseDatabaseValidation): if field.db_index and field_type.lower() in self.connection._limited_data_types: errors.append( checks.Warning( - 'Oracle does not support a database index on %s columns.' + "Oracle does not support a database index on %s columns." % field_type, hint=( "An index won't be created. Silence this warning if " "you don't care about it." ), obj=field, - id='fields.W162', + id="fields.W162", ) ) return errors diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index e49d453f94..92f393227e 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -11,11 +11,10 @@ from contextlib import contextmanager from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.db import DatabaseError as WrappedDatabaseError, connections +from django.db import DatabaseError as WrappedDatabaseError +from django.db import connections from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.utils import ( - CursorDebugWrapper as BaseCursorDebugWrapper, -) +from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from django.utils.safestring import SafeString @@ -30,14 +29,17 @@ except ImportError as e: def psycopg2_version(): - version = psycopg2.__version__.split(' ', 1)[0] + version = psycopg2.__version__.split(" ", 1)[0] return get_version_tuple(version) PSYCOPG2_VERSION = psycopg2_version() if PSYCOPG2_VERSION < (2, 8, 4): - raise 
ImproperlyConfigured("psycopg2 version 2.8.4 or newer is required; you have %s" % psycopg2.__version__) + raise ImproperlyConfigured( + "psycopg2 version 2.8.4 or newer is required; you have %s" + % psycopg2.__version__ + ) # Some of these import psycopg2, so import them after checking if it's installed. @@ -56,68 +58,68 @@ psycopg2.extras.register_uuid() INETARRAY_OID = 1041 INETARRAY = psycopg2.extensions.new_array_type( (INETARRAY_OID,), - 'INETARRAY', + "INETARRAY", psycopg2.extensions.UNICODE, ) psycopg2.extensions.register_type(INETARRAY) class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'postgresql' - display_name = 'PostgreSQL' + vendor = "postgresql" + display_name = "PostgreSQL" # This dictionary maps Field objects to their associated PostgreSQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. 
data_types = { - 'AutoField': 'serial', - 'BigAutoField': 'bigserial', - 'BinaryField': 'bytea', - 'BooleanField': 'boolean', - 'CharField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'timestamp with time zone', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'DurationField': 'interval', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'inet', - 'GenericIPAddressField': 'inet', - 'JSONField': 'jsonb', - 'OneToOneField': 'integer', - 'PositiveBigIntegerField': 'bigint', - 'PositiveIntegerField': 'integer', - 'PositiveSmallIntegerField': 'smallint', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallAutoField': 'smallserial', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'UUIDField': 'uuid', + "AutoField": "serial", + "BigAutoField": "bigserial", + "BinaryField": "bytea", + "BooleanField": "boolean", + "CharField": "varchar(%(max_length)s)", + "DateField": "date", + "DateTimeField": "timestamp with time zone", + "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)", + "DurationField": "interval", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": "double precision", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "inet", + "GenericIPAddressField": "inet", + "JSONField": "jsonb", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint", + "PositiveIntegerField": "integer", + "PositiveSmallIntegerField": "smallint", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "smallserial", + "SmallIntegerField": "smallint", + "TextField": "text", + "TimeField": "time", + "UUIDField": "uuid", } data_type_check_constraints = { - 'PositiveBigIntegerField': '"%(column)s" >= 0', - 'PositiveIntegerField': '"%(column)s" >= 0', - 
'PositiveSmallIntegerField': '"%(column)s" >= 0', + "PositiveBigIntegerField": '"%(column)s" >= 0', + "PositiveIntegerField": '"%(column)s" >= 0', + "PositiveSmallIntegerField": '"%(column)s" >= 0', } operators = { - 'exact': '= %s', - 'iexact': '= UPPER(%s)', - 'contains': 'LIKE %s', - 'icontains': 'LIKE UPPER(%s)', - 'regex': '~ %s', - 'iregex': '~* %s', - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': 'LIKE %s', - 'endswith': 'LIKE %s', - 'istartswith': 'LIKE UPPER(%s)', - 'iendswith': 'LIKE UPPER(%s)', + "exact": "= %s", + "iexact": "= UPPER(%s)", + "contains": "LIKE %s", + "icontains": "LIKE UPPER(%s)", + "regex": "~ %s", + "iregex": "~* %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s", + "endswith": "LIKE %s", + "istartswith": "LIKE UPPER(%s)", + "iendswith": "LIKE UPPER(%s)", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -128,14 +130,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): # # Note: we use str.format() here for readability as '%' is used as a wildcard for # the LIKE operator. 
- pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" + pattern_esc = ( + r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')" + ) pattern_ops = { - 'contains': "LIKE '%%' || {} || '%%'", - 'icontains': "LIKE '%%' || UPPER({}) || '%%'", - 'startswith': "LIKE {} || '%%'", - 'istartswith': "LIKE UPPER({}) || '%%'", - 'endswith': "LIKE '%%' || {}", - 'iendswith': "LIKE '%%' || UPPER({})", + "contains": "LIKE '%%' || {} || '%%'", + "icontains": "LIKE '%%' || UPPER({}) || '%%'", + "startswith": "LIKE {} || '%%'", + "istartswith": "LIKE UPPER({}) || '%%'", + "endswith": "LIKE '%%' || {}", + "iendswith": "LIKE '%%' || UPPER({})", } Database = Database @@ -152,46 +156,46 @@ class DatabaseWrapper(BaseDatabaseWrapper): def get_connection_params(self): settings_dict = self.settings_dict # None may be used to connect to the default 'postgres' db - if ( - settings_dict['NAME'] == '' and - not settings_dict.get('OPTIONS', {}).get('service') + if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get( + "service" ): raise ImproperlyConfigured( "settings.DATABASES is improperly configured. " "Please supply the NAME or OPTIONS['service'] value." ) - if len(settings_dict['NAME'] or '') > self.ops.max_name_length(): + if len(settings_dict["NAME"] or "") > self.ops.max_name_length(): raise ImproperlyConfigured( "The database name '%s' (%d characters) is longer than " "PostgreSQL's limit of %d characters. Supply a shorter NAME " - "in settings.DATABASES." % ( - settings_dict['NAME'], - len(settings_dict['NAME']), + "in settings.DATABASES." 
+ % ( + settings_dict["NAME"], + len(settings_dict["NAME"]), self.ops.max_name_length(), ) ) conn_params = {} - if settings_dict['NAME']: + if settings_dict["NAME"]: conn_params = { - 'database': settings_dict['NAME'], - **settings_dict['OPTIONS'], + "database": settings_dict["NAME"], + **settings_dict["OPTIONS"], } - elif settings_dict['NAME'] is None: + elif settings_dict["NAME"] is None: # Connect to the default 'postgres' db. - settings_dict.get('OPTIONS', {}).pop('service', None) - conn_params = {'database': 'postgres', **settings_dict['OPTIONS']} + settings_dict.get("OPTIONS", {}).pop("service", None) + conn_params = {"database": "postgres", **settings_dict["OPTIONS"]} else: - conn_params = {**settings_dict['OPTIONS']} + conn_params = {**settings_dict["OPTIONS"]} - conn_params.pop('isolation_level', None) - if settings_dict['USER']: - conn_params['user'] = settings_dict['USER'] - if settings_dict['PASSWORD']: - conn_params['password'] = settings_dict['PASSWORD'] - if settings_dict['HOST']: - conn_params['host'] = settings_dict['HOST'] - if settings_dict['PORT']: - conn_params['port'] = settings_dict['PORT'] + conn_params.pop("isolation_level", None) + if settings_dict["USER"]: + conn_params["user"] = settings_dict["USER"] + if settings_dict["PASSWORD"]: + conn_params["password"] = settings_dict["PASSWORD"] + if settings_dict["HOST"]: + conn_params["host"] = settings_dict["HOST"] + if settings_dict["PORT"]: + conn_params["port"] = settings_dict["PORT"] return conn_params @async_unsafe @@ -203,9 +207,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): # default when no value is explicitly specified in options. # - before calling _set_autocommit() because if autocommit is on, that # will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT. 
- options = self.settings_dict['OPTIONS'] + options = self.settings_dict["OPTIONS"] try: - self.isolation_level = options['isolation_level'] + self.isolation_level = options["isolation_level"] except KeyError: self.isolation_level = connection.isolation_level else: @@ -215,13 +219,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): # Register dummy loads() to avoid a round trip from psycopg2's decode # to json.dumps() to json.loads(), when using a custom decoder in # JSONField. - psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x) + psycopg2.extras.register_default_jsonb( + conn_or_curs=connection, loads=lambda x: x + ) return connection def ensure_timezone(self): if self.connection is None: return False - conn_timezone_name = self.connection.get_parameter_status('TimeZone') + conn_timezone_name = self.connection.get_parameter_status("TimeZone") timezone_name = self.timezone_name if timezone_name and conn_timezone_name != timezone_name: with self.connection.cursor() as cursor: @@ -230,7 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return False def init_connection_state(self): - self.connection.set_client_encoding('UTF8') + self.connection.set_client_encoding("UTF8") timezone_changed = self.ensure_timezone() if timezone_changed: @@ -243,7 +249,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): if name: # In autocommit mode, the cursor will be used outside of a # transaction, hence use a holdable cursor. 
- cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit) + cursor = self.connection.cursor( + name, scrollable=False, withhold=self.connection.autocommit + ) else: cursor = self.connection.cursor() cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None @@ -268,10 +276,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): if current_task: task_ident = str(id(current_task)) else: - task_ident = 'sync' + task_ident = "sync" # Use that and the thread ident to get a unique name return self._cursor( - name='_django_curs_%d_%s_%d' % ( + name="_django_curs_%d_%s_%d" + % ( # Avoid reusing name in other threads / tasks threading.current_thread().ident, task_ident, @@ -289,14 +298,14 @@ class DatabaseWrapper(BaseDatabaseWrapper): afterward. """ with self.cursor() as cursor: - cursor.execute('SET CONSTRAINTS ALL IMMEDIATE') - cursor.execute('SET CONSTRAINTS ALL DEFERRED') + cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") + cursor.execute("SET CONSTRAINTS ALL DEFERRED") def is_usable(self): try: # Use a psycopg cursor directly, bypassing Django's utilities. with self.connection.cursor() as cursor: - cursor.execute('SELECT 1') + cursor.execute("SELECT 1") except Database.Error: return False else: @@ -317,12 +326,18 @@ class DatabaseWrapper(BaseDatabaseWrapper): "database when it's not needed (for example, when running tests). 
" "Django was unable to create a connection to the 'postgres' database " "and will use the first PostgreSQL database instead.", - RuntimeWarning + RuntimeWarning, ) for connection in connections.all(): - if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres': + if ( + connection.vendor == "postgresql" + and connection.settings_dict["NAME"] != "postgres" + ): conn = self.__class__( - {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, + { + **self.settings_dict, + "NAME": connection.settings_dict["NAME"], + }, alias=self.alias, ) try: @@ -349,5 +364,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper): return self.cursor.copy_expert(sql, file, *args) def copy_to(self, file, table, *args, **kwargs): - with self.debug_sql(sql='COPY %s TO STDOUT' % table): + with self.debug_sql(sql="COPY %s TO STDOUT" % table): return self.cursor.copy_to(file, table, *args, **kwargs) diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 0effcc44e6..4c9bd63546 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -4,53 +4,53 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'psql' + executable_name = "psql" @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): args = [cls.executable_name] - options = settings_dict.get('OPTIONS', {}) + options = settings_dict.get("OPTIONS", {}) - host = settings_dict.get('HOST') - port = settings_dict.get('PORT') - dbname = settings_dict.get('NAME') - user = settings_dict.get('USER') - passwd = settings_dict.get('PASSWORD') - passfile = options.get('passfile') - service = options.get('service') - sslmode = options.get('sslmode') - sslrootcert = options.get('sslrootcert') - sslcert = options.get('sslcert') - sslkey = options.get('sslkey') + host = settings_dict.get("HOST") + port = settings_dict.get("PORT") + dbname = 
settings_dict.get("NAME") + user = settings_dict.get("USER") + passwd = settings_dict.get("PASSWORD") + passfile = options.get("passfile") + service = options.get("service") + sslmode = options.get("sslmode") + sslrootcert = options.get("sslrootcert") + sslcert = options.get("sslcert") + sslkey = options.get("sslkey") if not dbname and not service: # Connect to the default 'postgres' db. - dbname = 'postgres' + dbname = "postgres" if user: - args += ['-U', user] + args += ["-U", user] if host: - args += ['-h', host] + args += ["-h", host] if port: - args += ['-p', str(port)] + args += ["-p", str(port)] if dbname: args += [dbname] args.extend(parameters) env = {} if passwd: - env['PGPASSWORD'] = str(passwd) + env["PGPASSWORD"] = str(passwd) if service: - env['PGSERVICE'] = str(service) + env["PGSERVICE"] = str(service) if sslmode: - env['PGSSLMODE'] = str(sslmode) + env["PGSSLMODE"] = str(sslmode) if sslrootcert: - env['PGSSLROOTCERT'] = str(sslrootcert) + env["PGSSLROOTCERT"] = str(sslrootcert) if sslcert: - env['PGSSLCERT'] = str(sslcert) + env["PGSSLCERT"] = str(sslcert) if sslkey: - env['PGSSLKEY'] = str(sslkey) + env["PGSSLKEY"] = str(sslkey) if passfile: - env['PGPASSFILE'] = str(passfile) + env["PGPASSFILE"] = str(passfile) return args, (env or None) def runshell(self, parameters): diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py index eb8ac3bcf5..70c3eda566 100644 --- a/django/db/backends/postgresql/creation.py +++ b/django/db/backends/postgresql/creation.py @@ -8,7 +8,6 @@ from django.db.backends.utils import strip_quotes class DatabaseCreation(BaseDatabaseCreation): - def _quote_name(self, name): return self.connection.ops.quote_name(name) @@ -21,32 +20,35 @@ class DatabaseCreation(BaseDatabaseCreation): return suffix and "WITH" + suffix def sql_table_creation_suffix(self): - test_settings = self.connection.settings_dict['TEST'] - if test_settings.get('COLLATION') is not None: + test_settings = 
self.connection.settings_dict["TEST"] + if test_settings.get("COLLATION") is not None: raise ImproperlyConfigured( - 'PostgreSQL does not support collation setting at database ' - 'creation time.' + "PostgreSQL does not support collation setting at database " + "creation time." ) return self._get_database_create_suffix( - encoding=test_settings['CHARSET'], - template=test_settings.get('TEMPLATE'), + encoding=test_settings["CHARSET"], + template=test_settings.get("TEMPLATE"), ) def _database_exists(self, cursor, database_name): - cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)]) + cursor.execute( + "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s", + [strip_quotes(database_name)], + ) return cursor.fetchone() is not None def _execute_create_test_db(self, cursor, parameters, keepdb=False): try: - if keepdb and self._database_exists(cursor, parameters['dbname']): + if keepdb and self._database_exists(cursor, parameters["dbname"]): # If the database should be kept and it already exists, don't # try to create a new one. return super()._execute_create_test_db(cursor, parameters, keepdb) except Exception as e: - if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE: + if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE: # All errors except "database already exists" cancel tests. - self.log('Got an error creating the test database: %s' % e) + self.log("Got an error creating the test database: %s" % e) sys.exit(2) elif not keepdb: # If the database should be kept, ignore "database already @@ -58,11 +60,11 @@ class DatabaseCreation(BaseDatabaseCreation): # to the template database. 
self.connection.close() - source_database_name = self.connection.settings_dict['NAME'] - target_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] test_db_params = { - 'dbname': self._quote_name(target_database_name), - 'suffix': self._get_database_create_suffix(template=source_database_name), + "dbname": self._quote_name(target_database_name), + "suffix": self._get_database_create_suffix(template=source_database_name), } with self._nodb_cursor() as cursor: try: @@ -70,11 +72,16 @@ class DatabaseCreation(BaseDatabaseCreation): except Exception: try: if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, target_database_name), - )) - cursor.execute('DROP DATABASE %(dbname)s' % test_db_params) + self.log( + "Destroying old test database for alias %s..." + % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) + cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: - self.log('Got an error cloning the test database: %s' % e) + self.log("Got an error cloning the test database: %s" % e) sys.exit(2) diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 1ce73fb0a8..182c230c75 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -52,7 +52,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_over_clause = True only_supports_unbounded_with_preceding_and_following = True supports_aggregate_filter_clause = True - supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'} + supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"} validates_explain_options = False # A query will error on invalid options. 
supports_deferrable_unique_constraints = True has_json_operators = True @@ -60,14 +60,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_update_conflicts = True supports_update_conflicts_with_target = True test_collations = { - 'non_default': 'sv-x-icu', - 'swedish_ci': 'sv-x-icu', + "non_default": "sv-x-icu", + "swedish_ci": "sv-x-icu", } test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'" django_test_skips = { - 'opclasses are PostgreSQL only.': { - 'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses', + "opclasses are PostgreSQL only.": { + "indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses", }, } @@ -75,9 +75,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): def introspected_field_types(self): return { **super().introspected_field_types, - 'PositiveBigIntegerField': 'BigIntegerField', - 'PositiveIntegerField': 'IntegerField', - 'PositiveSmallIntegerField': 'SmallIntegerField', + "PositiveBigIntegerField": "BigIntegerField", + "PositiveIntegerField": "IntegerField", + "PositiveSmallIntegerField": "SmallIntegerField", } @cached_property @@ -96,9 +96,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): def is_postgresql_14(self): return self.connection.pg_version >= 140000 - has_bit_xor = property(operator.attrgetter('is_postgresql_14')) - has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11')) - supports_covering_indexes = property(operator.attrgetter('is_postgresql_11')) - supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12')) - supports_covering_spgist_indexes = property(operator.attrgetter('is_postgresql_14')) - supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12')) + has_bit_xor = property(operator.attrgetter("is_postgresql_14")) + has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11")) + supports_covering_indexes = property(operator.attrgetter("is_postgresql_11")) + 
supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12")) + supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14")) + supports_non_deterministic_collations = property( + operator.attrgetter("is_postgresql_12") + ) diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index f31d906a2f..a7e9a13d61 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -1,5 +1,7 @@ from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo, TableInfo, + BaseDatabaseIntrospection, + FieldInfo, + TableInfo, ) from django.db.models import Index @@ -7,46 +9,47 @@ from django.db.models import Index class DatabaseIntrospection(BaseDatabaseIntrospection): # Maps type codes to Django Field types. data_types_reverse = { - 16: 'BooleanField', - 17: 'BinaryField', - 20: 'BigIntegerField', - 21: 'SmallIntegerField', - 23: 'IntegerField', - 25: 'TextField', - 700: 'FloatField', - 701: 'FloatField', - 869: 'GenericIPAddressField', - 1042: 'CharField', # blank-padded - 1043: 'CharField', - 1082: 'DateField', - 1083: 'TimeField', - 1114: 'DateTimeField', - 1184: 'DateTimeField', - 1186: 'DurationField', - 1266: 'TimeField', - 1700: 'DecimalField', - 2950: 'UUIDField', - 3802: 'JSONField', + 16: "BooleanField", + 17: "BinaryField", + 20: "BigIntegerField", + 21: "SmallIntegerField", + 23: "IntegerField", + 25: "TextField", + 700: "FloatField", + 701: "FloatField", + 869: "GenericIPAddressField", + 1042: "CharField", # blank-padded + 1043: "CharField", + 1082: "DateField", + 1083: "TimeField", + 1114: "DateTimeField", + 1184: "DateTimeField", + 1186: "DurationField", + 1266: "TimeField", + 1700: "DecimalField", + 2950: "UUIDField", + 3802: "JSONField", } # A hook for subclasses. 
- index_default_access_method = 'btree' + index_default_access_method = "btree" ignored_tables = [] def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if description.default and 'nextval' in description.default: - if field_type == 'IntegerField': - return 'AutoField' - elif field_type == 'BigIntegerField': - return 'BigAutoField' - elif field_type == 'SmallIntegerField': - return 'SmallAutoField' + if description.default and "nextval" in description.default: + if field_type == "IntegerField": + return "AutoField" + elif field_type == "BigIntegerField": + return "BigAutoField" + elif field_type == "SmallIntegerField": + return "SmallAutoField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" - cursor.execute(""" + cursor.execute( + """ SELECT c.relname, CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END FROM pg_catalog.pg_class c @@ -54,8 +57,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v') AND n.nspname NOT IN ('pg_catalog', 'pg_toast') AND pg_catalog.pg_table_is_visible(c.oid) - """) - return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables] + """ + ) + return [ + TableInfo(*row) + for row in cursor.fetchall() + if row[0] not in self.ignored_tables + ] def get_table_description(self, cursor, table_name): """ @@ -65,7 +73,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Query the pg_catalog tables as cursor.description does not reliably # return the nullable property and information_schema.columns does not # contain details of materialized views. 
- cursor.execute(""" + cursor.execute( + """ SELECT a.attname AS column_name, NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable, @@ -81,9 +90,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): AND c.relname = %s AND n.nspname NOT IN ('pg_catalog', 'pg_toast') AND pg_catalog.pg_table_is_visible(c.oid) - """, [table_name]) + """, + [table_name], + ) field_map = {line[0]: line[1:] for line in cursor.fetchall()} - cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name) + ) return [ FieldInfo( line.name, @@ -98,7 +111,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ] def get_sequences(self, cursor, table_name, table_fields=()): - cursor.execute(""" + cursor.execute( + """ SELECT s.relname as sequence_name, col.attname FROM pg_class s JOIN pg_namespace sn ON sn.oid = s.relnamespace @@ -110,9 +124,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): AND d.deptype in ('a', 'n') AND pg_catalog.pg_table_is_visible(tbl.oid) AND tbl.relname = %s - """, [table_name]) + """, + [table_name], + ) return [ - {'name': row[0], 'table': table_name, 'column': row[1]} + {"name": row[0], "table": table_name, "column": row[1]} for row in cursor.fetchall() ] @@ -121,7 +137,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a dictionary of {field_name: (field_name_other_table, other_table)} representing all foreign keys in the given table. 
""" - cursor.execute(""" + cursor.execute( + """ SELECT a1.attname, c2.relname, a2.attname FROM pg_constraint con LEFT JOIN pg_class c1 ON con.conrelid = c1.oid @@ -133,7 +150,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): con.contype = 'f' AND c1.relnamespace = c2.relnamespace AND pg_catalog.pg_table_is_visible(c1.oid) - """, [table_name]) + """, + [table_name], + ) return {row[0]: (row[2], row[1]) for row in cursor.fetchall()} def get_constraints(self, cursor, table_name): @@ -146,7 +165,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Loop over the key table, collecting things as constraints. The column # array must return column names in the same order in which they were # created. - cursor.execute(""" + cursor.execute( + """ SELECT c.conname, array( @@ -165,7 +185,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): FROM pg_constraint AS c JOIN pg_class AS cl ON c.conrelid = cl.oid WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid) - """, [table_name]) + """, + [table_name], + ) for constraint, columns, kind, used_cols, options in cursor.fetchall(): constraints[constraint] = { "columns": columns, @@ -178,7 +200,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "options": options, } # Now get indexes - cursor.execute(""" + cursor.execute( + """ SELECT indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary, array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions @@ -207,14 +230,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid) ) s2 GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions; - """, [self.index_default_access_method, table_name]) - for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall(): + """, + [self.index_default_access_method, table_name], + ) + for ( + index, + columns, + unique, + primary, + orders, + type_, + definition, + 
options, + ) in cursor.fetchall(): if index not in constraints: basic_index = ( - type_ == self.index_default_access_method and + type_ == self.index_default_access_method + and # '_btree' references # django.contrib.postgres.indexes.BTreeIndex.suffix. - not index.endswith('_btree') and options is None + not index.endswith("_btree") + and options is None ) constraints[index] = { "columns": columns if columns != [None] else [], diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 762cd8d23e..68448157ec 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -7,17 +7,22 @@ from django.db.models.constants import OnConflict class DatabaseOperations(BaseDatabaseOperations): - cast_char_field_without_max_length = 'varchar' - explain_prefix = 'EXPLAIN' + cast_char_field_without_max_length = "varchar" + explain_prefix = "EXPLAIN" cast_data_types = { - 'AutoField': 'integer', - 'BigAutoField': 'bigint', - 'SmallAutoField': 'smallint', + "AutoField": "integer", + "BigAutoField": "bigint", + "SmallAutoField": "smallint", } def unification_cast_sql(self, output_field): internal_type = output_field.get_internal_type() - if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"): + if internal_type in ( + "GenericIPAddressField", + "IPAddressField", + "TimeField", + "UUIDField", + ): # PostgreSQL will resolve a union as type 'text' if input types are # 'unknown'. # https://www.postgresql.org/docs/current/typeconv-union-case.html @@ -25,17 +30,19 @@ class DatabaseOperations(BaseDatabaseOperations): # PostgreSQL configuration so we need to explicitly cast them. # We must also remove components of the type within brackets: # varchar(255) -> varchar. 
- return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0] - return '%s' + return ( + "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0] + ) + return "%s" def date_extract_sql(self, lookup_type, field_name): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT - if lookup_type == 'week_day': + if lookup_type == "week_day": # For consistency across backends, we return Sunday=1, Saturday=7. return "EXTRACT('dow' FROM %s) + 1" % field_name - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return "EXTRACT('isodow' FROM %s)" % field_name - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return "EXTRACT('isoyear' FROM %s)" % field_name else: return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) @@ -48,22 +55,25 @@ class DatabaseOperations(BaseDatabaseOperations): def _prepare_tzname_delta(self, tzname): tzname, sign, offset = split_tzname_delta(tzname) if offset: - sign = '-' if sign == '+' else '+' - return f'{tzname}{sign}{offset}' + sign = "-" if sign == "+" else "+" + return f"{tzname}{sign}{offset}" return tzname def _convert_field_to_tz(self, field_name, tzname): if tzname and settings.USE_TZ: - field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname)) + field_name = "%s AT TIME ZONE '%s'" % ( + field_name, + self._prepare_tzname_delta(tzname), + ) return field_name def datetime_cast_date_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::date' % field_name + return "(%s)::date" % field_name def datetime_cast_time_sql(self, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) - return '(%s)::time' % field_name + return "(%s)::time" % field_name def datetime_extract_sql(self, lookup_type, field_name, tzname): field_name = self._convert_field_to_tz(field_name, tzname) @@ -89,21 +99,30 @@ class 
DatabaseOperations(BaseDatabaseOperations): return cursor.fetchall() def lookup_cast(self, lookup_type, internal_type=None): - lookup = '%s' + lookup = "%s" # Cast text lookups to text to allow things like filter(x__contains=4) - if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): - if internal_type in ('IPAddressField', 'GenericIPAddressField'): + if lookup_type in ( + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", + ): + if internal_type in ("IPAddressField", "GenericIPAddressField"): lookup = "HOST(%s)" - elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'): - lookup = '%s::citext' + elif internal_type in ("CICharField", "CIEmailField", "CITextField"): + lookup = "%s::citext" else: lookup = "%s::text" # Use UPPER(x) for case-insensitive lookups; it's faster. - if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): - lookup = 'UPPER(%s)' % lookup + if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"): + lookup = "UPPER(%s)" % lookup return lookup @@ -128,29 +147,32 @@ class DatabaseOperations(BaseDatabaseOperations): # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us # to truncate tables referenced by a foreign key in any other table. 
sql_parts = [ - style.SQL_KEYWORD('TRUNCATE'), - ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables), + style.SQL_KEYWORD("TRUNCATE"), + ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables), ] if reset_sequences: - sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY')) + sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY")) if allow_cascade: - sql_parts.append(style.SQL_KEYWORD('CASCADE')) - return ['%s;' % ' '.join(sql_parts)] + sql_parts.append(style.SQL_KEYWORD("CASCADE")) + return ["%s;" % " ".join(sql_parts)] def sequence_reset_by_name_sql(self, style, sequences): # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements # to reset sequence indices sql = [] for sequence_info in sequences: - table_name = sequence_info['table'] + table_name = sequence_info["table"] # 'id' will be the case if it's an m2m using an autogenerated # intermediate table (see BaseDatabaseIntrospection.sequence_list). - column_name = sequence_info['column'] or 'id' - sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % ( - style.SQL_KEYWORD('SELECT'), - style.SQL_TABLE(self.quote_name(table_name)), - style.SQL_FIELD(column_name), - )) + column_name = sequence_info["column"] or "id" + sql.append( + "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" + % ( + style.SQL_KEYWORD("SELECT"), + style.SQL_TABLE(self.quote_name(table_name)), + style.SQL_FIELD(column_name), + ) + ) return sql def tablespace_sql(self, tablespace, inline=False): @@ -161,6 +183,7 @@ class DatabaseOperations(BaseDatabaseOperations): def sequence_reset_sql(self, style, model_list): from django.db import models + output = [] qn = self.quote_name for model in model_list: @@ -174,14 +197,15 @@ class DatabaseOperations(BaseDatabaseOperations): if isinstance(f, models.AutoField): output.append( "%s setval(pg_get_serial_sequence('%s','%s'), " - "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % ( - style.SQL_KEYWORD('SELECT'), + "coalesce(max(%s), 
1), max(%s) %s null) %s %s;" + % ( + style.SQL_KEYWORD("SELECT"), style.SQL_TABLE(qn(model._meta.db_table)), style.SQL_FIELD(f.column), style.SQL_FIELD(qn(f.column)), style.SQL_FIELD(qn(f.column)), - style.SQL_KEYWORD('IS NOT'), - style.SQL_KEYWORD('FROM'), + style.SQL_KEYWORD("IS NOT"), + style.SQL_KEYWORD("FROM"), style.SQL_TABLE(qn(model._meta.db_table)), ) ) @@ -207,9 +231,9 @@ class DatabaseOperations(BaseDatabaseOperations): def distinct_sql(self, fields, params): if fields: params = [param for param_list in params for param in param_list] - return (['DISTINCT ON (%s)' % ', '.join(fields)], params) + return (["DISTINCT ON (%s)" % ", ".join(fields)], params) else: - return ['DISTINCT'], [] + return ["DISTINCT"], [] def last_executed_query(self, cursor, sql, params): # https://www.psycopg.org/docs/cursor.html#cursor.query @@ -220,14 +244,16 @@ class DatabaseOperations(BaseDatabaseOperations): def return_insert_columns(self, fields): if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def bulk_insert_sql(self, fields, placeholder_rows): placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) @@ -252,7 +278,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None def subtract_temporals(self, internal_type, lhs, rhs): - if internal_type == 'DateField': + if internal_type == "DateField": lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) @@ -263,27 +289,34 @@ class DatabaseOperations(BaseDatabaseOperations): prefix = super().explain_query_prefix(format) extra = {} if format: - extra['FORMAT'] = format + extra["FORMAT"] = format if options: - extra.update({ - name.upper(): 'true' if value else 'false' - for name, value in options.items() - }) + 
extra.update( + { + name.upper(): "true" if value else "false" + for name, value in options.items() + } + ) if extra: - prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) + prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items()) return prefix def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if on_conflict == OnConflict.IGNORE: - return 'ON CONFLICT DO NOTHING' + return "ON CONFLICT DO NOTHING" if on_conflict == OnConflict.UPDATE: - return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( - ', '.join(map(self.quote_name, unique_fields)), - ', '.join([ - f'{field} = EXCLUDED.{field}' - for field in map(self.quote_name, update_fields) - ]), + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), ) return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py index f3b5baecbe..47e9a6a8f3 100644 --- a/django/db/backends/postgresql/schema.py +++ b/django/db/backends/postgresql/schema.py @@ -9,16 +9,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_create_sequence = "CREATE SEQUENCE %(sequence)s" sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" - sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" - sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s' + sql_set_sequence_max = ( + "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" + ) + sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s" sql_create_index = ( - 'CREATE INDEX %(name)s ON %(table)s%(using)s ' - '(%(columns)s)%(include)s%(extra)s%(condition)s' + "CREATE INDEX %(name)s ON 
%(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" ) sql_create_index_concurrently = ( - 'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s ' - '(%(columns)s)%(include)s%(extra)s%(condition)s' + "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s " + "(%(columns)s)%(include)s%(extra)s%(condition)s" ) sql_delete_index = "DROP INDEX IF EXISTS %(name)s" sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s" @@ -26,21 +28,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Setting the constraint to IMMEDIATE to allow changing data in the same # transaction. sql_create_column_inline_fk = ( - 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s' - '; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE' + "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s" + "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE" ) # Setting the constraint to IMMEDIATE runs any deferred checks to allow # dropping it in the same transaction. sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" - sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)' + sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)" def quote_value(self, value): if isinstance(value, str): - value = value.replace('%', '%%') + value = value.replace("%", "%%") adapted = psycopg2.extensions.adapt(value) - if hasattr(adapted, 'encoding'): - adapted.encoding = 'utf8' + if hasattr(adapted, "encoding"): + adapted.encoding = "utf8" # getquoted() returns a quoted bytestring of the adapted value. return adapted.getquoted().decode() @@ -61,7 +63,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _field_base_data_types(self, field): # Yield base data types for array fields. 
- if field.base_field.get_internal_type() == 'ArrayField': + if field.base_field.get_internal_type() == "ArrayField": yield from self._field_base_data_types(field.base_field) else: yield self._field_data_type(field.base_field) @@ -80,45 +82,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # # The same doesn't apply to array fields such as varchar[size] # and text[size], so skip them. - if '[' in db_type: + if "[" in db_type: return None - if db_type.startswith('varchar'): + if db_type.startswith("varchar"): return self._create_index_sql( model, fields=[field], - suffix='_like', - opclasses=['varchar_pattern_ops'], + suffix="_like", + opclasses=["varchar_pattern_ops"], ) - elif db_type.startswith('text'): + elif db_type.startswith("text"): return self._create_index_sql( model, fields=[field], - suffix='_like', - opclasses=['text_pattern_ops'], + suffix="_like", + opclasses=["text_pattern_ops"], ) return None def _alter_column_type_sql(self, model, old_field, new_field, new_type): - self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s' + self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s" # Cast when data type changed. - using_sql = ' USING %(column)s::%(type)s' + using_sql = " USING %(column)s::%(type)s" new_internal_type = new_field.get_internal_type() old_internal_type = old_field.get_internal_type() - if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type: + if new_internal_type == "ArrayField" and new_internal_type == old_internal_type: # Compare base data types for array fields. - if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)): + if list(self._field_base_data_types(old_field)) != list( + self._field_base_data_types(new_field) + ): self.sql_alter_column_type += using_sql elif self._field_data_type(old_field) != self._field_data_type(new_field): self.sql_alter_column_type += using_sql # Make ALTER TYPE with SERIAL make sense. 
table = strip_quotes(model._meta.db_table) - serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'} + serial_fields_map = { + "bigserial": "bigint", + "serial": "integer", + "smallserial": "smallint", + } if new_type.lower() in serial_fields_map: column = strip_quotes(new_field.column) sequence_name = "%s_%s_seq" % (table, column) return ( ( - self.sql_alter_column_type % { + self.sql_alter_column_type + % { "column": self.quote_name(column), "type": serial_fields_map[new_type.lower()], }, @@ -126,29 +135,35 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): ), [ ( - self.sql_delete_sequence % { + self.sql_delete_sequence + % { "sequence": self.quote_name(sequence_name), }, [], ), ( - self.sql_create_sequence % { + self.sql_create_sequence + % { "sequence": self.quote_name(sequence_name), }, [], ), ( - self.sql_alter_column % { + self.sql_alter_column + % { "table": self.quote_name(table), - "changes": self.sql_alter_column_default % { + "changes": self.sql_alter_column_default + % { "column": self.quote_name(column), - "default": "nextval('%s')" % self.quote_name(sequence_name), - } + "default": "nextval('%s')" + % self.quote_name(sequence_name), + }, }, [], ), ( - self.sql_set_sequence_max % { + self.sql_set_sequence_max + % { "table": self.quote_name(table), "column": self.quote_name(column), "sequence": self.quote_name(sequence_name), @@ -156,24 +171,31 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): [], ), ( - self.sql_set_sequence_owner % { - 'table': self.quote_name(table), - 'column': self.quote_name(column), - 'sequence': self.quote_name(sequence_name), + self.sql_set_sequence_owner + % { + "table": self.quote_name(table), + "column": self.quote_name(column), + "sequence": self.quote_name(sequence_name), }, [], ), ], ) - elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map: + elif ( + old_field.db_parameters(connection=self.connection)["type"] + in serial_fields_map + ): # 
Drop the sequence if migrating away from AutoField. column = strip_quotes(new_field.column) - sequence_name = '%s_%s_seq' % (table, column) - fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type) + sequence_name = "%s_%s_seq" % (table, column) + fragment, _ = super()._alter_column_type_sql( + model, old_field, new_field, new_type + ) return fragment, [ ( - self.sql_delete_sequence % { - 'sequence': self.quote_name(sequence_name), + self.sql_delete_sequence + % { + "sequence": self.quote_name(sequence_name), }, [], ), @@ -181,58 +203,114 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): else: return super()._alter_column_type_sql(model, old_field, new_field, new_type) - def _alter_field(self, model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict=False): + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): # Drop indexes on varchar/text/citext columns that are changing to a # different type. 
if (old_field.db_index or old_field.unique) and ( - (old_type.startswith('varchar') and not new_type.startswith('varchar')) or - (old_type.startswith('text') and not new_type.startswith('text')) or - (old_type.startswith('citext') and not new_type.startswith('citext')) + (old_type.startswith("varchar") and not new_type.startswith("varchar")) + or (old_type.startswith("text") and not new_type.startswith("text")) + or (old_type.startswith("citext") and not new_type.startswith("citext")) ): - index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') + index_name = self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ) self.execute(self._delete_index_sql(model, index_name)) super()._alter_field( - model, old_field, new_field, old_type, new_type, old_db_params, - new_db_params, strict, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict, ) # Added an index? Create any PostgreSQL-specific indexes. - if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or - (not old_field.unique and new_field.unique)): + if (not (old_field.db_index or old_field.unique) and new_field.db_index) or ( + not old_field.unique and new_field.unique + ): like_index_statement = self._create_like_index_sql(model, new_field) if like_index_statement is not None: self.execute(like_index_statement) # Removed an index? Drop any PostgreSQL-specific indexes. 
if old_field.unique and not (new_field.db_index or new_field.unique): - index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') + index_to_remove = self._create_index_name( + model._meta.db_table, [old_field.column], suffix="_like" + ) self.execute(self._delete_index_sql(model, index_to_remove)) def _index_columns(self, table, columns, col_suffixes, opclasses): if opclasses: - return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses) + return IndexColumns( + table, + columns, + self.quote_name, + col_suffixes=col_suffixes, + opclasses=opclasses, + ) return super()._index_columns(table, columns, col_suffixes, opclasses) def add_index(self, model, index, concurrently=False): - self.execute(index.create_sql(model, self, concurrently=concurrently), params=None) + self.execute( + index.create_sql(model, self, concurrently=concurrently), params=None + ) def remove_index(self, model, index, concurrently=False): self.execute(index.remove_sql(model, self, concurrently=concurrently)) def _delete_index_sql(self, model, name, sql=None, concurrently=False): - sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index + sql = ( + self.sql_delete_index_concurrently + if concurrently + else self.sql_delete_index + ) return super()._delete_index_sql(model, name, sql) def _create_index_sql( - self, model, *, fields=None, name=None, suffix='', using='', - db_tablespace=None, col_suffixes=(), sql=None, opclasses=(), - condition=None, concurrently=False, include=None, expressions=None, + self, + model, + *, + fields=None, + name=None, + suffix="", + using="", + db_tablespace=None, + col_suffixes=(), + sql=None, + opclasses=(), + condition=None, + concurrently=False, + include=None, + expressions=None, ): - sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently + sql = ( + self.sql_create_index + if not concurrently + else 
self.sql_create_index_concurrently + ) return super()._create_index_sql( - model, fields=fields, name=name, suffix=suffix, using=using, - db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql, - opclasses=opclasses, condition=condition, include=include, + model, + fields=fields, + name=name, + suffix=suffix, + using=using, + db_tablespace=db_tablespace, + col_suffixes=col_suffixes, + sql=sql, + opclasses=opclasses, + condition=condition, + include=include, expressions=expressions, ) diff --git a/django/db/backends/sqlite3/_functions.py b/django/db/backends/sqlite3/_functions.py index 3529a99dd6..86684c1907 100644 --- a/django/db/backends/sqlite3/_functions.py +++ b/django/db/backends/sqlite3/_functions.py @@ -7,14 +7,30 @@ import statistics from datetime import timedelta from hashlib import sha1, sha224, sha256, sha384, sha512 from math import ( - acos, asin, atan, atan2, ceil, cos, degrees, exp, floor, fmod, log, pi, - radians, sin, sqrt, tan, + acos, + asin, + atan, + atan2, + ceil, + cos, + degrees, + exp, + floor, + fmod, + log, + pi, + radians, + sin, + sqrt, + tan, ) from re import search as re_search from django.db.backends.base.base import timezone_constructor from django.db.backends.utils import ( - split_tzname_delta, typecast_time, typecast_timestamp, + split_tzname_delta, + typecast_time, + typecast_timestamp, ) from django.utils import timezone from django.utils.crypto import md5 @@ -26,56 +42,62 @@ def register(connection): connection.create_function, deterministic=True, ) - create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract) - create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc) - create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date) - create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time) - create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract) - 
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc) - create_deterministic_function('django_time_extract', 2, _sqlite_time_extract) - create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc) - create_deterministic_function('django_time_diff', 2, _sqlite_time_diff) - create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff) - create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta) - create_deterministic_function('regexp', 2, _sqlite_regexp) - create_deterministic_function('ACOS', 1, _sqlite_acos) - create_deterministic_function('ASIN', 1, _sqlite_asin) - create_deterministic_function('ATAN', 1, _sqlite_atan) - create_deterministic_function('ATAN2', 2, _sqlite_atan2) - create_deterministic_function('BITXOR', 2, _sqlite_bitxor) - create_deterministic_function('CEILING', 1, _sqlite_ceiling) - create_deterministic_function('COS', 1, _sqlite_cos) - create_deterministic_function('COT', 1, _sqlite_cot) - create_deterministic_function('DEGREES', 1, _sqlite_degrees) - create_deterministic_function('EXP', 1, _sqlite_exp) - create_deterministic_function('FLOOR', 1, _sqlite_floor) - create_deterministic_function('LN', 1, _sqlite_ln) - create_deterministic_function('LOG', 2, _sqlite_log) - create_deterministic_function('LPAD', 3, _sqlite_lpad) - create_deterministic_function('MD5', 1, _sqlite_md5) - create_deterministic_function('MOD', 2, _sqlite_mod) - create_deterministic_function('PI', 0, _sqlite_pi) - create_deterministic_function('POWER', 2, _sqlite_power) - create_deterministic_function('RADIANS', 1, _sqlite_radians) - create_deterministic_function('REPEAT', 2, _sqlite_repeat) - create_deterministic_function('REVERSE', 1, _sqlite_reverse) - create_deterministic_function('RPAD', 3, _sqlite_rpad) - create_deterministic_function('SHA1', 1, _sqlite_sha1) - create_deterministic_function('SHA224', 1, _sqlite_sha224) - create_deterministic_function('SHA256', 1, _sqlite_sha256) - 
create_deterministic_function('SHA384', 1, _sqlite_sha384) - create_deterministic_function('SHA512', 1, _sqlite_sha512) - create_deterministic_function('SIGN', 1, _sqlite_sign) - create_deterministic_function('SIN', 1, _sqlite_sin) - create_deterministic_function('SQRT', 1, _sqlite_sqrt) - create_deterministic_function('TAN', 1, _sqlite_tan) + create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract) + create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc) + create_deterministic_function( + "django_datetime_cast_date", 3, _sqlite_datetime_cast_date + ) + create_deterministic_function( + "django_datetime_cast_time", 3, _sqlite_datetime_cast_time + ) + create_deterministic_function( + "django_datetime_extract", 4, _sqlite_datetime_extract + ) + create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc) + create_deterministic_function("django_time_extract", 2, _sqlite_time_extract) + create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc) + create_deterministic_function("django_time_diff", 2, _sqlite_time_diff) + create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff) + create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) + create_deterministic_function("regexp", 2, _sqlite_regexp) + create_deterministic_function("ACOS", 1, _sqlite_acos) + create_deterministic_function("ASIN", 1, _sqlite_asin) + create_deterministic_function("ATAN", 1, _sqlite_atan) + create_deterministic_function("ATAN2", 2, _sqlite_atan2) + create_deterministic_function("BITXOR", 2, _sqlite_bitxor) + create_deterministic_function("CEILING", 1, _sqlite_ceiling) + create_deterministic_function("COS", 1, _sqlite_cos) + create_deterministic_function("COT", 1, _sqlite_cot) + create_deterministic_function("DEGREES", 1, _sqlite_degrees) + create_deterministic_function("EXP", 1, _sqlite_exp) + create_deterministic_function("FLOOR", 1, _sqlite_floor) + 
create_deterministic_function("LN", 1, _sqlite_ln) + create_deterministic_function("LOG", 2, _sqlite_log) + create_deterministic_function("LPAD", 3, _sqlite_lpad) + create_deterministic_function("MD5", 1, _sqlite_md5) + create_deterministic_function("MOD", 2, _sqlite_mod) + create_deterministic_function("PI", 0, _sqlite_pi) + create_deterministic_function("POWER", 2, _sqlite_power) + create_deterministic_function("RADIANS", 1, _sqlite_radians) + create_deterministic_function("REPEAT", 2, _sqlite_repeat) + create_deterministic_function("REVERSE", 1, _sqlite_reverse) + create_deterministic_function("RPAD", 3, _sqlite_rpad) + create_deterministic_function("SHA1", 1, _sqlite_sha1) + create_deterministic_function("SHA224", 1, _sqlite_sha224) + create_deterministic_function("SHA256", 1, _sqlite_sha256) + create_deterministic_function("SHA384", 1, _sqlite_sha384) + create_deterministic_function("SHA512", 1, _sqlite_sha512) + create_deterministic_function("SIGN", 1, _sqlite_sign) + create_deterministic_function("SIN", 1, _sqlite_sin) + create_deterministic_function("SQRT", 1, _sqlite_sqrt) + create_deterministic_function("TAN", 1, _sqlite_tan) # Don't use the built-in RANDOM() function because it returns a value # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1). 
- connection.create_function('RAND', 0, random.random) - connection.create_aggregate('STDDEV_POP', 1, StdDevPop) - connection.create_aggregate('STDDEV_SAMP', 1, StdDevSamp) - connection.create_aggregate('VAR_POP', 1, VarPop) - connection.create_aggregate('VAR_SAMP', 1, VarSamp) + connection.create_function("RAND", 0, random.random) + connection.create_aggregate("STDDEV_POP", 1, StdDevPop) + connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp) + connection.create_aggregate("VAR_POP", 1, VarPop) + connection.create_aggregate("VAR_SAMP", 1, VarSamp) def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): @@ -90,9 +112,9 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): if tzname is not None and tzname != conn_tzname: tzname, sign, offset = split_tzname_delta(tzname) if offset: - hours, minutes = offset.split(':') + hours, minutes = offset.split(":") offset_delta = timedelta(hours=int(hours), minutes=int(minutes)) - dt += offset_delta if sign == '+' else -offset_delta + dt += offset_delta if sign == "+" else -offset_delta dt = timezone.localtime(dt, timezone_constructor(tzname)) return dt @@ -101,19 +123,19 @@ def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if lookup_type == 'year': - return f'{dt.year:04d}-01-01' - elif lookup_type == 'quarter': + if lookup_type == "year": + return f"{dt.year:04d}-01-01" + elif lookup_type == "quarter": month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01' - elif lookup_type == 'week': + return f"{dt.year:04d}-{month_in_quarter:02d}-01" + elif lookup_type == "month": + return f"{dt.year:04d}-{dt.month:02d}-01" + elif lookup_type == "week": dt = dt - timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - elif lookup_type == 'day': - return 
f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" + elif lookup_type == "day": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): @@ -127,13 +149,13 @@ def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): return None else: dt = dt_parsed - if lookup_type == 'hour': - return f'{dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + if lookup_type == "hour": + return f"{dt.hour:02d}:00:00" + elif lookup_type == "minute": + return f"{dt.hour:02d}:{dt.minute:02d}:00" + elif lookup_type == "second": + return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_datetime_cast_date(dt, tzname, conn_tzname): @@ -154,15 +176,15 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if lookup_type == 'week_day': + if lookup_type == "week_day": return (dt.isoweekday() % 7) + 1 - elif lookup_type == 'iso_week_day': + elif lookup_type == "iso_week_day": return dt.isoweekday() - elif lookup_type == 'week': + elif lookup_type == "week": return dt.isocalendar()[1] - elif lookup_type == 'quarter': + elif lookup_type == "quarter": return ceil(dt.month / 3) - elif lookup_type == 'iso_year': + elif lookup_type == "iso_year": return dt.isocalendar()[0] else: return getattr(dt, lookup_type) @@ -172,25 +194,25 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname): dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None - if 
lookup_type == 'year': - return f'{dt.year:04d}-01-01 00:00:00' - elif lookup_type == 'quarter': + if lookup_type == "year": + return f"{dt.year:04d}-01-01 00:00:00" + elif lookup_type == "quarter": month_in_quarter = dt.month - (dt.month - 1) % 3 - return f'{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00' - elif lookup_type == 'month': - return f'{dt.year:04d}-{dt.month:02d}-01 00:00:00' - elif lookup_type == 'week': + return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00" + elif lookup_type == "month": + return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00" + elif lookup_type == "week": dt = dt - timedelta(days=dt.weekday()) - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'day': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00' - elif lookup_type == 'hour': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00' - elif lookup_type == 'minute': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00' - elif lookup_type == 'second': - return f'{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}' - raise ValueError(f'Unsupported lookup type: {lookup_type!r}') + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" + elif lookup_type == "day": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00" + elif lookup_type == "hour": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00" + elif lookup_type == "minute": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:00" + elif lookup_type == "second": + return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}" + raise ValueError(f"Unsupported lookup type: {lookup_type!r}") def _sqlite_time_extract(lookup_type, dt): @@ -204,7 +226,7 @@ def _sqlite_time_extract(lookup_type, dt): def _sqlite_prepare_dtdelta_param(conn, param): - if conn in ['+', '-']: + if conn in ["+", "-"]: if 
isinstance(param, int): return timedelta(0, 0, param) else: @@ -227,13 +249,13 @@ def _sqlite_format_dtdelta(connector, lhs, rhs): real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs) except (ValueError, TypeError): return None - if connector == '+': + if connector == "+": # typecast_timestamp() returns a date or a datetime without timezone. # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" out = str(real_lhs + real_rhs) - elif connector == '-': + elif connector == "-": out = str(real_lhs - real_rhs) - elif connector == '*': + elif connector == "*": out = real_lhs * real_rhs else: out = real_lhs / real_rhs @@ -246,14 +268,14 @@ def _sqlite_time_diff(lhs, rhs): left = typecast_time(lhs) right = typecast_time(rhs) return ( - (left.hour * 60 * 60 * 1000000) + - (left.minute * 60 * 1000000) + - (left.second * 1000000) + - (left.microsecond) - - (right.hour * 60 * 60 * 1000000) - - (right.minute * 60 * 1000000) - - (right.second * 1000000) - - (right.microsecond) + (left.hour * 60 * 60 * 1000000) + + (left.minute * 60 * 1000000) + + (left.second * 1000000) + + (left.microsecond) + - (right.hour * 60 * 60 * 1000000) + - (right.minute * 60 * 1000000) + - (right.second * 1000000) + - (right.microsecond) ) @@ -380,7 +402,7 @@ def _sqlite_pi(): def _sqlite_power(x, y): if x is None or y is None: return None - return x ** y + return x**y def _sqlite_radians(x): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 4343ea180e..5bcd61eb96 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -32,13 +32,13 @@ def decoder(conv_func): def check_sqlite_version(): if Database.sqlite_version_info < (3, 9, 0): raise ImproperlyConfigured( - 'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version + "SQLite 3.9.0 or later is required (found %s)." 
% Database.sqlite_version ) check_sqlite_version() -Database.register_converter("bool", b'1'.__eq__) +Database.register_converter("bool", b"1".__eq__) Database.register_converter("time", decoder(parse_time)) Database.register_converter("datetime", decoder(parse_datetime)) Database.register_converter("timestamp", decoder(parse_datetime)) @@ -47,69 +47,69 @@ Database.register_adapter(decimal.Decimal, str) class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'sqlite' - display_name = 'SQLite' + vendor = "sqlite" + display_name = "SQLite" # SQLite doesn't actually support most of these types, but it "does the right # thing" given more verbose field definitions, so leave them as is so that # schema inspection is more useful. data_types = { - 'AutoField': 'integer', - 'BigAutoField': 'integer', - 'BinaryField': 'BLOB', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'decimal', - 'DurationField': 'bigint', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'real', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'char(15)', - 'GenericIPAddressField': 'char(39)', - 'JSONField': 'text', - 'OneToOneField': 'integer', - 'PositiveBigIntegerField': 'bigint unsigned', - 'PositiveIntegerField': 'integer unsigned', - 'PositiveSmallIntegerField': 'smallint unsigned', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallAutoField': 'integer', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'UUIDField': 'char(32)', + "AutoField": "integer", + "BigAutoField": "integer", + "BinaryField": "BLOB", + "BooleanField": "bool", + "CharField": "varchar(%(max_length)s)", + "DateField": "date", + "DateTimeField": "datetime", + "DecimalField": "decimal", + "DurationField": "bigint", + "FileField": "varchar(%(max_length)s)", + "FilePathField": "varchar(%(max_length)s)", + "FloatField": 
"real", + "IntegerField": "integer", + "BigIntegerField": "bigint", + "IPAddressField": "char(15)", + "GenericIPAddressField": "char(39)", + "JSONField": "text", + "OneToOneField": "integer", + "PositiveBigIntegerField": "bigint unsigned", + "PositiveIntegerField": "integer unsigned", + "PositiveSmallIntegerField": "smallint unsigned", + "SlugField": "varchar(%(max_length)s)", + "SmallAutoField": "integer", + "SmallIntegerField": "smallint", + "TextField": "text", + "TimeField": "time", + "UUIDField": "char(32)", } data_type_check_constraints = { - 'PositiveBigIntegerField': '"%(column)s" >= 0', - 'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', - 'PositiveIntegerField': '"%(column)s" >= 0', - 'PositiveSmallIntegerField': '"%(column)s" >= 0', + "PositiveBigIntegerField": '"%(column)s" >= 0', + "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', + "PositiveIntegerField": '"%(column)s" >= 0', + "PositiveSmallIntegerField": '"%(column)s" >= 0', } data_types_suffix = { - 'AutoField': 'AUTOINCREMENT', - 'BigAutoField': 'AUTOINCREMENT', - 'SmallAutoField': 'AUTOINCREMENT', + "AutoField": "AUTOINCREMENT", + "BigAutoField": "AUTOINCREMENT", + "SmallAutoField": "AUTOINCREMENT", } # SQLite requires LIKE statements to include an ESCAPE clause if the value # being escaped has a percent or underscore in it. # See https://www.sqlite.org/lang_expr.html for an explanation. 
operators = { - 'exact': '= %s', - 'iexact': "LIKE %s ESCAPE '\\'", - 'contains': "LIKE %s ESCAPE '\\'", - 'icontains': "LIKE %s ESCAPE '\\'", - 'regex': 'REGEXP %s', - 'iregex': "REGEXP '(?i)' || %s", - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', - 'startswith': "LIKE %s ESCAPE '\\'", - 'endswith': "LIKE %s ESCAPE '\\'", - 'istartswith': "LIKE %s ESCAPE '\\'", - 'iendswith': "LIKE %s ESCAPE '\\'", + "exact": "= %s", + "iexact": "LIKE %s ESCAPE '\\'", + "contains": "LIKE %s ESCAPE '\\'", + "icontains": "LIKE %s ESCAPE '\\'", + "regex": "REGEXP %s", + "iregex": "REGEXP '(?i)' || %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s ESCAPE '\\'", + "endswith": "LIKE %s ESCAPE '\\'", + "istartswith": "LIKE %s ESCAPE '\\'", + "iendswith": "LIKE %s ESCAPE '\\'", } # The patterns below are used to generate SQL pattern lookup clauses when @@ -122,12 +122,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): # the LIKE operator. pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')" pattern_ops = { - 'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'", - 'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'", - 'startswith': r"LIKE {} || '%%' ESCAPE '\'", - 'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'", - 'endswith': r"LIKE '%%' || {} ESCAPE '\'", - 'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'", + "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'", + "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'", + "startswith": r"LIKE {} || '%%' ESCAPE '\'", + "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'", + "endswith": r"LIKE '%%' || {} ESCAPE '\'", + "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'", } Database = Database @@ -141,14 +141,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): def get_connection_params(self): settings_dict = self.settings_dict - if not settings_dict['NAME']: + if not settings_dict["NAME"]: raise ImproperlyConfigured( "settings.DATABASES is 
improperly configured. " - "Please supply the NAME value.") + "Please supply the NAME value." + ) kwargs = { - 'database': settings_dict['NAME'], - 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, - **settings_dict['OPTIONS'], + "database": settings_dict["NAME"], + "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, + **settings_dict["OPTIONS"], } # Always allow the underlying SQLite connection to be shareable # between multiple threads. The safe-guarding will be handled at a @@ -156,15 +157,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): # property. This is necessary as the shareability is disabled by # default in pysqlite and it cannot be changed once a connection is # opened. - if 'check_same_thread' in kwargs and kwargs['check_same_thread']: + if "check_same_thread" in kwargs and kwargs["check_same_thread"]: warnings.warn( - 'The `check_same_thread` option was provided and set to ' - 'True. It will be overridden with False. Use the ' - '`DatabaseWrapper.allow_thread_sharing` property instead ' - 'for controlling thread shareability.', - RuntimeWarning + "The `check_same_thread` option was provided and set to " + "True. It will be overridden with False. 
Use the " + "`DatabaseWrapper.allow_thread_sharing` property instead " + "for controlling thread shareability.", + RuntimeWarning, ) - kwargs.update({'check_same_thread': False, 'uri': True}) + kwargs.update({"check_same_thread": False, "uri": True}) return kwargs @async_unsafe @@ -172,10 +173,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): conn = Database.connect(**conn_params) register_functions(conn) - conn.execute('PRAGMA foreign_keys = ON') + conn.execute("PRAGMA foreign_keys = ON") # The macOS bundled SQLite defaults legacy_alter_table ON, which # prevents atomic table renames (feature supports_atomic_references_rename) - conn.execute('PRAGMA legacy_alter_table = OFF') + conn.execute("PRAGMA legacy_alter_table = OFF") return conn def init_connection_state(self): @@ -207,7 +208,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): else: # sqlite3's internal default is ''. It's different from None. # See Modules/_sqlite/connection.c. - level = '' + level = "" # 'isolation_level' is a misleading API. # SQLite always runs at the SERIALIZABLE isolation level. with self.wrap_database_errors: @@ -215,16 +216,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): def disable_constraint_checking(self): with self.cursor() as cursor: - cursor.execute('PRAGMA foreign_keys = OFF') + cursor.execute("PRAGMA foreign_keys = OFF") # Foreign key constraints cannot be turned off while in a multi- # statement transaction. Fetch the current state of the pragma # to determine if constraints are effectively disabled. 
- enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0] + enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0] return not bool(enabled) def enable_constraint_checking(self): with self.cursor() as cursor: - cursor.execute('PRAGMA foreign_keys = ON') + cursor.execute("PRAGMA foreign_keys = ON") def check_constraints(self, table_names=None): """ @@ -237,24 +238,32 @@ class DatabaseWrapper(BaseDatabaseWrapper): if self.features.supports_pragma_foreign_key_check: with self.cursor() as cursor: if table_names is None: - violations = cursor.execute('PRAGMA foreign_key_check').fetchall() + violations = cursor.execute("PRAGMA foreign_key_check").fetchall() else: violations = chain.from_iterable( cursor.execute( - 'PRAGMA foreign_key_check(%s)' + "PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name) ).fetchall() for table_name in table_names ) # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check - for table_name, rowid, referenced_table_name, foreign_key_index in violations: + for ( + table_name, + rowid, + referenced_table_name, + foreign_key_index, + ) in violations: foreign_key = cursor.execute( - 'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name) + "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name) ).fetchall()[foreign_key_index] column_name, referenced_column_name = foreign_key[3:5] - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + primary_key_column_name = self.introspection.get_primary_key_column( + cursor, table_name + ) primary_key_value, bad_value = cursor.execute( - 'SELECT %s, %s FROM %s WHERE rowid = %%s' % ( + "SELECT %s, %s FROM %s WHERE rowid = %%s" + % ( self.ops.quote_name(primary_key_column_name), self.ops.quote_name(column_name), self.ops.quote_name(table_name), @@ -264,9 +273,15 @@ class DatabaseWrapper(BaseDatabaseWrapper): raise IntegrityError( "The row in table '%s' with primary key '%s' has an " "invalid foreign key: %s.%s contains a value '%s' 
that " - "does not have a corresponding value in %s.%s." % ( - table_name, primary_key_value, table_name, column_name, - bad_value, referenced_table_name, referenced_column_name + "does not have a corresponding value in %s.%s." + % ( + table_name, + primary_key_value, + table_name, + column_name, + bad_value, + referenced_table_name, + referenced_column_name, ) ) else: @@ -274,11 +289,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): if table_names is None: table_names = self.introspection.table_names(cursor) for table_name in table_names: - primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) + primary_key_column_name = self.introspection.get_primary_key_column( + cursor, table_name + ) if not primary_key_column_name: continue relations = self.introspection.get_relations(cursor, table_name) - for column_name, (referenced_column_name, referenced_table_name) in relations: + for column_name, ( + referenced_column_name, + referenced_table_name, + ) in relations: cursor.execute( """ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING @@ -287,18 +307,29 @@ class DatabaseWrapper(BaseDatabaseWrapper): WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL """ % ( - primary_key_column_name, column_name, table_name, - referenced_table_name, column_name, referenced_column_name, - column_name, referenced_column_name, + primary_key_column_name, + column_name, + table_name, + referenced_table_name, + column_name, + referenced_column_name, + column_name, + referenced_column_name, ) ) for bad_row in cursor.fetchall(): raise IntegrityError( "The row in table '%s' with primary key '%s' has an " "invalid foreign key: %s.%s contains a value '%s' that " - "does not have a corresponding value in %s.%s." % ( - table_name, bad_row[0], table_name, column_name, - bad_row[1], referenced_table_name, referenced_column_name, + "does not have a corresponding value in %s.%s." 
+ % ( + table_name, + bad_row[0], + table_name, + column_name, + bad_row[1], + referenced_table_name, + referenced_column_name, ) ) @@ -315,10 +346,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.cursor().execute("BEGIN") def is_in_memory_db(self): - return self.creation.is_in_memory_db(self.settings_dict['NAME']) + return self.creation.is_in_memory_db(self.settings_dict["NAME"]) -FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s') +FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s") class SQLiteCursorWrapper(Database.Cursor): @@ -327,6 +358,7 @@ class SQLiteCursorWrapper(Database.Cursor): This fixes it -- but note that if you want to use a literal "%s" in a query, you'll need to use "%%s". """ + def execute(self, query, params=None): if params is None: return Database.Cursor.execute(self, query) @@ -338,4 +370,4 @@ class SQLiteCursorWrapper(Database.Cursor): return Database.Cursor.executemany(self, query, param_list) def convert_query(self, query): - return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') + return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%") diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index 69b9568db3..7cee35dc81 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -2,9 +2,9 @@ from django.db.backends.base.client import BaseDatabaseClient class DatabaseClient(BaseDatabaseClient): - executable_name = 'sqlite3' + executable_name = "sqlite3" @classmethod def settings_to_cmd_args_env(cls, settings_dict, parameters): - args = [cls.executable_name, settings_dict['NAME'], *parameters] + args = [cls.executable_name, settings_dict["NAME"], *parameters] return args, None diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index 4a4046c670..9d8d4a63ad 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -7,17 +7,16 @@ from django.db.backends.base.creation import 
BaseDatabaseCreation class DatabaseCreation(BaseDatabaseCreation): - @staticmethod def is_in_memory_db(database_name): return not isinstance(database_name, Path) and ( - database_name == ':memory:' or 'mode=memory' in database_name + database_name == ":memory:" or "mode=memory" in database_name ) def _get_test_db_name(self): - test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:' - if test_database_name == ':memory:': - return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias + test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:" + if test_database_name == ":memory:": + return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias return test_database_name def _create_test_db(self, verbosity, autoclobber, keepdb=False): @@ -28,38 +27,39 @@ class DatabaseCreation(BaseDatabaseCreation): if not self.is_in_memory_db(test_database_name): # Erase the old test database if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, test_database_name), - )) + self.log( + "Destroying old test database for alias %s..." 
+ % (self._get_database_display_str(verbosity, test_database_name),) + ) if os.access(test_database_name, os.F_OK): if not autoclobber: confirm = input( "Type 'yes' if you would like to try deleting the test " "database '%s', or 'no' to cancel: " % test_database_name ) - if autoclobber or confirm == 'yes': + if autoclobber or confirm == "yes": try: os.remove(test_database_name) except Exception as e: - self.log('Got an error deleting the old test database: %s' % e) + self.log("Got an error deleting the old test database: %s" % e) sys.exit(2) else: - self.log('Tests cancelled.') + self.log("Tests cancelled.") sys.exit(1) return test_database_name def get_test_db_clone_settings(self, suffix): orig_settings_dict = self.connection.settings_dict - source_database_name = orig_settings_dict['NAME'] + source_database_name = orig_settings_dict["NAME"] if self.is_in_memory_db(source_database_name): return orig_settings_dict else: - root, ext = os.path.splitext(orig_settings_dict['NAME']) - return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)} + root, ext = os.path.splitext(orig_settings_dict["NAME"]) + return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)} def _clone_test_db(self, suffix, verbosity, keepdb=False): - source_database_name = self.connection.settings_dict['NAME'] - target_database_name = self.get_test_db_clone_settings(suffix)['NAME'] + source_database_name = self.connection.settings_dict["NAME"] + target_database_name = self.get_test_db_clone_settings(suffix)["NAME"] # Forking automatically makes a copy of an in-memory database. if not self.is_in_memory_db(source_database_name): # Erase the old test database @@ -67,18 +67,23 @@ class DatabaseCreation(BaseDatabaseCreation): if keepdb: return if verbosity >= 1: - self.log('Destroying old test database for alias %s...' % ( - self._get_database_display_str(verbosity, target_database_name), - )) + self.log( + "Destroying old test database for alias %s..." 
+ % ( + self._get_database_display_str( + verbosity, target_database_name + ), + ) + ) try: os.remove(target_database_name) except Exception as e: - self.log('Got an error deleting the old test database: %s' % e) + self.log("Got an error deleting the old test database: %s" % e) sys.exit(2) try: shutil.copy(source_database_name, target_database_name) except Exception as e: - self.log('Got an error cloning the test database: %s' % e) + self.log("Got an error cloning the test database: %s" % e) sys.exit(2) def _destroy_test_db(self, test_database_name, verbosity): @@ -95,7 +100,7 @@ class DatabaseCreation(BaseDatabaseCreation): TEST NAME. See https://www.sqlite.org/inmemorydb.html """ test_database_name = self._get_test_db_name() - sig = [self.connection.settings_dict['NAME']] + sig = [self.connection.settings_dict["NAME"]] if self.is_in_memory_db(test_database_name): sig.append(self.connection.alias) else: diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 153ce8d1d1..c076f0121e 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -43,49 +43,53 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0) supports_update_conflicts_with_target = supports_update_conflicts test_collations = { - 'ci': 'nocase', - 'cs': 'binary', - 'non_default': 'nocase', + "ci": "nocase", + "cs": "binary", + "non_default": "nocase", } django_test_expected_failures = { # The django_format_dtdelta() function doesn't properly handle mixed # Date/DateTime fields and timedeltas. 
- 'expressions.tests.FTimeDeltaTests.test_mixed_comparisons1', + "expressions.tests.FTimeDeltaTests.test_mixed_comparisons1", } @cached_property def django_test_skips(self): skips = { - 'SQLite stores values rounded to 15 significant digits.': { - 'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding', + "SQLite stores values rounded to 15 significant digits.": { + "model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding", }, - 'SQLite naively remakes the table on field alteration.': { - 'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops', - 'schema.tests.SchemaTests.test_unique_and_reverse_m2m', - 'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries', - 'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references', + "SQLite naively remakes the table on field alteration.": { + "schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops", + "schema.tests.SchemaTests.test_unique_and_reverse_m2m", + "schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries", + "schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references", }, "SQLite doesn't support negative precision for ROUND().": { - 'db_functions.math.test_round.RoundTests.test_null_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_float_with_negative_precision', - 'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision', + "db_functions.math.test_round.RoundTests.test_null_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_float_with_negative_precision", + "db_functions.math.test_round.RoundTests.test_integer_with_negative_precision", }, } if Database.sqlite_version_info < (3, 27): - skips.update({ - 'Nondeterministic failure on 
SQLite < 3.27.': { - 'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank', - }, - }) + skips.update( + { + "Nondeterministic failure on SQLite < 3.27.": { + "expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank", + }, + } + ) if self.connection.is_in_memory_db(): - skips.update({ - "the sqlite backend's close() method is a no-op when using an " - "in-memory database": { - 'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections', - 'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections', - }, - }) + skips.update( + { + "the sqlite backend's close() method is a no-op when using an " + "in-memory database": { + "servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections", + "servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections", + }, + } + ) return skips @cached_property @@ -94,12 +98,12 @@ class DatabaseFeatures(BaseDatabaseFeatures): @cached_property def introspected_field_types(self): - return{ + return { **super().introspected_field_types, - 'BigAutoField': 'AutoField', - 'DurationField': 'BigIntegerField', - 'GenericIPAddressField': 'CharField', - 'SmallAutoField': 'AutoField', + "BigAutoField": "AutoField", + "DurationField": "BigIntegerField", + "GenericIPAddressField": "CharField", + "SmallAutoField": "AutoField", } @cached_property @@ -112,11 +116,13 @@ class DatabaseFeatures(BaseDatabaseFeatures): return False return True - can_introspect_json_field = property(operator.attrgetter('supports_json_field')) - has_json_object_function = property(operator.attrgetter('supports_json_field')) + can_introspect_json_field = property(operator.attrgetter("supports_json_field")) + has_json_object_function = property(operator.attrgetter("supports_json_field")) @cached_property def can_return_columns_from_insert(self): return Database.sqlite_version_info >= (3, 35) - can_return_rows_from_bulk_insert = 
property(operator.attrgetter('can_return_columns_from_insert')) + can_return_rows_from_bulk_insert = property( + operator.attrgetter("can_return_columns_from_insert") + ) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 81884a7951..f5a5e81e9d 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -3,19 +3,21 @@ from collections import namedtuple import sqlparse from django.db import DatabaseError -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo from django.db.models import Index from django.utils.regex_helper import _lazy_re_compile -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint')) +FieldInfo = namedtuple( + "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint") +) -field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') +field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$") def get_field_size(name): - """ Extract the size number from a "varchar(11)" type name """ + """Extract the size number from a "varchar(11)" type name""" m = field_size_re.search(name) return int(m[1]) if m else None @@ -28,29 +30,29 @@ class FlexibleFieldLookupDict: # entries here because SQLite allows for anything and doesn't normalize the # field type; it uses whatever was given. 
base_data_types_reverse = { - 'bool': 'BooleanField', - 'boolean': 'BooleanField', - 'smallint': 'SmallIntegerField', - 'smallint unsigned': 'PositiveSmallIntegerField', - 'smallinteger': 'SmallIntegerField', - 'int': 'IntegerField', - 'integer': 'IntegerField', - 'bigint': 'BigIntegerField', - 'integer unsigned': 'PositiveIntegerField', - 'bigint unsigned': 'PositiveBigIntegerField', - 'decimal': 'DecimalField', - 'real': 'FloatField', - 'text': 'TextField', - 'char': 'CharField', - 'varchar': 'CharField', - 'blob': 'BinaryField', - 'date': 'DateField', - 'datetime': 'DateTimeField', - 'time': 'TimeField', + "bool": "BooleanField", + "boolean": "BooleanField", + "smallint": "SmallIntegerField", + "smallint unsigned": "PositiveSmallIntegerField", + "smallinteger": "SmallIntegerField", + "int": "IntegerField", + "integer": "IntegerField", + "bigint": "BigIntegerField", + "integer unsigned": "PositiveIntegerField", + "bigint unsigned": "PositiveBigIntegerField", + "decimal": "DecimalField", + "real": "FloatField", + "text": "TextField", + "char": "CharField", + "varchar": "CharField", + "blob": "BinaryField", + "date": "DateField", + "datetime": "DateTimeField", + "time": "TimeField", } def __getitem__(self, key): - key = key.lower().split('(', 1)[0].strip() + key = key.lower().split("(", 1)[0].strip() return self.base_data_types_reverse[key] @@ -59,22 +61,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_field_type(self, data_type, description): field_type = super().get_field_type(data_type, description) - if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}: + if description.pk and field_type in { + "BigIntegerField", + "IntegerField", + "SmallIntegerField", + }: # No support for BigAutoField or SmallAutoField as SQLite treats # all integer primary keys as signed 64-bit integers. 
- return 'AutoField' + return "AutoField" if description.has_json_constraint: - return 'JSONField' + return "JSONField" return field_type def get_table_list(self, cursor): """Return a list of table and view names in the current database.""" # Skip the sqlite_sequence system table used for autoincrement key # generation. - cursor.execute(""" + cursor.execute( + """ SELECT name, type FROM sqlite_master WHERE type in ('table', 'view') AND NOT name='sqlite_sequence' - ORDER BY name""") + ORDER BY name""" + ) return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()] def get_table_description(self, cursor, table_name): @@ -82,37 +90,51 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a description of the table with the DB-API cursor.description interface. """ - cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) + ) table_info = cursor.fetchall() if not table_info: - raise DatabaseError(f'Table {table_name} does not exist (empty pragma).') + raise DatabaseError(f"Table {table_name} does not exist (empty pragma).") collations = self._get_column_collations(cursor, table_name) json_columns = set() if self.connection.features.can_introspect_json_field: for line in table_info: column = line[1] json_constraint_sql = '%%json_valid("%s")%%' % column - has_json_constraint = cursor.execute(""" + has_json_constraint = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s AND sql LIKE %s - """, [table_name, json_constraint_sql]).fetchone() + """, + [table_name, json_constraint_sql], + ).fetchone() if has_json_constraint: json_columns.add(column) return [ FieldInfo( - name, data_type, None, get_field_size(data_type), None, None, - not notnull, default, collations.get(name), pk == 1, name in json_columns + name, + data_type, + None, + get_field_size(data_type), + None, + None, + not notnull, + default, + 
collations.get(name), + pk == 1, + name in json_columns, ) for cid, name, data_type, notnull, default, pk in table_info ] def get_sequences(self, cursor, table_name, table_fields=()): pk_col = self.get_primary_key_column(cursor, table_name) - return [{'table': table_name, 'column': pk_col}] + return [{"table": table_name, "column": pk_col}] def get_relations(self, cursor, table_name): """ @@ -120,7 +142,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): representing all foreign keys in the given table. """ cursor.execute( - 'PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name) ) return { column_name: (ref_column_name, ref_table_name) @@ -130,7 +152,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): def get_primary_key_column(self, cursor, table_name): """Return the column name of the primary key for the given table.""" cursor.execute( - 'PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name) + "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name) ) for _, name, *_, pk in cursor.fetchall(): if pk: @@ -148,19 +170,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): check_columns = [] braces_deep = 0 for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): braces_deep += 1 - elif token.match(sqlparse.tokens.Punctuation, ')'): + elif token.match(sqlparse.tokens.Punctuation, ")"): braces_deep -= 1 if braces_deep < 0: # End of columns and constraints for table definition. break - elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','): + elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","): # End of current column or constraint definition. break # Detect column or constraint definition by first token. 
if is_constraint_definition is None: - is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT') + is_constraint_definition = token.match( + sqlparse.tokens.Keyword, "CONSTRAINT" + ) if is_constraint_definition: continue if is_constraint_definition: @@ -171,7 +195,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: constraint_name = token.value[1:-1] # Start constraint columns parsing after UNIQUE keyword. - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique = True unique_braces_deep = braces_deep elif unique: @@ -191,10 +215,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): field_name = token.value elif token.ttype == sqlparse.tokens.Literal.String.Symbol: field_name = token.value[1:-1] - if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + if token.match(sqlparse.tokens.Keyword, "UNIQUE"): unique_columns = [field_name] # Start constraint columns parsing after CHECK keyword. 
- if token.match(sqlparse.tokens.Keyword, 'CHECK'): + if token.match(sqlparse.tokens.Keyword, "CHECK"): check = True check_braces_deep = braces_deep elif check: @@ -209,22 +233,30 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): elif token.ttype == sqlparse.tokens.Literal.String.Symbol: if token.value[1:-1] in columns: check_columns.append(token.value[1:-1]) - unique_constraint = { - 'unique': True, - 'columns': unique_columns, - 'primary_key': False, - 'foreign_key': None, - 'check': False, - 'index': False, - } if unique_columns else None - check_constraint = { - 'check': True, - 'columns': check_columns, - 'primary_key': False, - 'unique': False, - 'foreign_key': None, - 'index': False, - } if check_columns else None + unique_constraint = ( + { + "unique": True, + "columns": unique_columns, + "primary_key": False, + "foreign_key": None, + "check": False, + "index": False, + } + if unique_columns + else None + ) + check_constraint = ( + { + "check": True, + "columns": check_columns, + "primary_key": False, + "unique": False, + "foreign_key": None, + "index": False, + } + if check_columns + else None + ) return constraint_name, unique_constraint, check_constraint, token def _parse_table_constraints(self, sql, columns): @@ -236,24 +268,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): tokens = (token for token in statement.flatten() if not token.is_whitespace) # Go to columns and constraint definition for token in tokens: - if token.match(sqlparse.tokens.Punctuation, '('): + if token.match(sqlparse.tokens.Punctuation, "("): break # Parse columns and constraint definition while True: - constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns) + ( + constraint_name, + unique, + check, + end_token, + ) = self._parse_column_or_constraint_definition(tokens, columns) if unique: if constraint_name: constraints[constraint_name] = unique else: unnamed_constrains_index += 1 - 
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = unique if check: if constraint_name: constraints[constraint_name] = check else: unnamed_constrains_index += 1 - constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check - if end_token.match(sqlparse.tokens.Punctuation, ')'): + constraints[ + "__unnamed_constraint_%s__" % unnamed_constrains_index + ] = check + if end_token.match(sqlparse.tokens.Punctuation, ")"): break return constraints @@ -266,19 +307,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Find inline check constraints. try: table_schema = cursor.execute( - "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % ( - self.connection.ops.quote_name(table_name), - ) + "SELECT sql FROM sqlite_master WHERE type='table' and name=%s" + % (self.connection.ops.quote_name(table_name),) ).fetchone()[0] except TypeError: # table_name is a view. pass else: - columns = {info.name for info in self.get_table_description(cursor, table_name)} + columns = { + info.name for info in self.get_table_description(cursor, table_name) + } constraints.update(self._parse_table_constraints(table_schema, columns)) # Get the index info - cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) + cursor.execute( + "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name) + ) for row in cursor.fetchall(): # SQLite 3.8.9+ has 5 columns, however older versions only give 3 # columns. Discard last 2 columns if there. @@ -288,7 +332,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index) ) # There's at most one row. - sql, = cursor.fetchone() or (None,) + (sql,) = cursor.fetchone() or (None,) # Inline constraints are already detected in # _parse_table_constraints(). 
The reasons to avoid fetching inline # constraints from `PRAGMA index_list` are: @@ -299,7 +343,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # An inline constraint continue # Get the index info for that index - cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) + cursor.execute( + "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index) + ) for index_rank, column_rank, column in cursor.fetchall(): if index not in constraints: constraints[index] = { @@ -310,14 +356,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "check": False, "index": True, } - constraints[index]['columns'].append(column) + constraints[index]["columns"].append(column) # Add type and column orders for indexes - if constraints[index]['index']: + if constraints[index]["index"]: # SQLite doesn't support any index type other than b-tree - constraints[index]['type'] = Index.suffix + constraints[index]["type"] = Index.suffix orders = self._get_index_columns_orders(sql) if orders is not None: - constraints[index]['orders'] = orders + constraints[index]["orders"] = orders # Get the PK pk_column = self.get_primary_key_column(cursor, table_name) if pk_column: @@ -334,44 +380,49 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): "index": False, } relations = enumerate(self.get_relations(cursor, table_name).items()) - constraints.update({ - f'fk_{index}': { - 'columns': [column_name], - 'primary_key': False, - 'unique': False, - 'foreign_key': (ref_table_name, ref_column_name), - 'check': False, - 'index': False, + constraints.update( + { + f"fk_{index}": { + "columns": [column_name], + "primary_key": False, + "unique": False, + "foreign_key": (ref_table_name, ref_column_name), + "check": False, + "index": False, + } + for index, (column_name, (ref_column_name, ref_table_name)) in relations } - for index, (column_name, (ref_column_name, ref_table_name)) in relations - }) + ) return constraints def _get_index_columns_orders(self, sql): 
tokens = sqlparse.parse(sql)[0] for token in tokens: if isinstance(token, sqlparse.sql.Parenthesis): - columns = str(token).strip('()').split(', ') - return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns] + columns = str(token).strip("()").split(", ") + return ["DESC" if info.endswith("DESC") else "ASC" for info in columns] return None def _get_column_collations(self, cursor, table_name): - row = cursor.execute(""" + row = cursor.execute( + """ SELECT sql FROM sqlite_master WHERE type = 'table' AND name = %s - """, [table_name]).fetchone() + """, + [table_name], + ).fetchone() if not row: return {} sql = row[0] - columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ') + columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ") collations = {} for column in columns: tokens = column[1:].split() column_name = tokens[0].strip('"') for index, token in enumerate(tokens): - if token == 'COLLATE': + if token == "COLLATE": collation = tokens[index + 1] break else: diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index c1a6da4e5d..ef8b91c0f0 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -16,15 +16,15 @@ from django.utils.functional import cached_property class DatabaseOperations(BaseDatabaseOperations): - cast_char_field_without_max_length = 'text' + cast_char_field_without_max_length = "text" cast_data_types = { - 'DateField': 'TEXT', - 'DateTimeField': 'TEXT', + "DateField": "TEXT", + "DateTimeField": "TEXT", } - explain_prefix = 'EXPLAIN QUERY PLAN' + explain_prefix = "EXPLAIN QUERY PLAN" # List of datatypes to that cannot be extracted with JSON_EXTRACT() on # SQLite. Use JSON_TYPE() instead. 
- jsonfield_datatype_values = frozenset(['null', 'false', 'true']) + jsonfield_datatype_values = frozenset(["null", "false", "true"]) def bulk_batch_size(self, fields, objs): """ @@ -55,14 +55,14 @@ class DatabaseOperations(BaseDatabaseOperations): else: if isinstance(output_field, bad_fields): raise NotSupportedError( - 'You cannot use Sum, Avg, StdDev, and Variance ' - 'aggregations on date/time fields in sqlite3 ' - 'since date/time is saved as text.' + "You cannot use Sum, Avg, StdDev, and Variance " + "aggregations on date/time fields in sqlite3 " + "since date/time is saved as text." ) if ( - isinstance(expression, models.Aggregate) and - expression.distinct and - len(expression.source_expressions) > 1 + isinstance(expression, models.Aggregate) + and expression.distinct + and len(expression.source_expressions) > 1 ): raise NotSupportedError( "SQLite doesn't support DISTINCT on aggregate functions " @@ -105,26 +105,32 @@ class DatabaseOperations(BaseDatabaseOperations): def _convert_tznames_to_sql(self, tzname): if tzname and settings.USE_TZ: return "'%s'" % tzname, "'%s'" % self.connection.timezone_name - return 'NULL', 'NULL' + return "NULL", "NULL" def datetime_cast_date_sql(self, field_name, tzname): - return 'django_datetime_cast_date(%s, %s, %s)' % ( - field_name, *self._convert_tznames_to_sql(tzname), + return "django_datetime_cast_date(%s, %s, %s)" % ( + field_name, + *self._convert_tznames_to_sql(tzname), ) def datetime_cast_time_sql(self, field_name, tzname): - return 'django_datetime_cast_time(%s, %s, %s)' % ( - field_name, *self._convert_tznames_to_sql(tzname), + return "django_datetime_cast_time(%s, %s, %s)" % ( + field_name, + *self._convert_tznames_to_sql(tzname), ) def datetime_extract_sql(self, lookup_type, field_name, tzname): return "django_datetime_extract('%s', %s, %s, %s)" % ( - lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname), + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), ) def 
datetime_trunc_sql(self, lookup_type, field_name, tzname): return "django_datetime_trunc('%s', %s, %s, %s)" % ( - lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname), + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), ) def time_extract_sql(self, lookup_type, field_name): @@ -146,11 +152,11 @@ class DatabaseOperations(BaseDatabaseOperations): if len(params) > BATCH_SIZE: results = () for index in range(0, len(params), BATCH_SIZE): - chunk = params[index:index + BATCH_SIZE] + chunk = params[index : index + BATCH_SIZE] results += self._quote_params_for_last_executed_query(chunk) return results - sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params)) + sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params)) # Bypass Django's wrappers and use the underlying sqlite3 connection # to avoid logging this query - it would trigger infinite recursion. cursor = self.connection.connection.cursor() @@ -215,14 +221,20 @@ class DatabaseOperations(BaseDatabaseOperations): if tables and allow_cascade: # Simulate TRUNCATE CASCADE by recursively collecting the tables # referencing the tables to be flushed. 
- tables = set(chain.from_iterable(self._references_graph(table) for table in tables)) - sql = ['%s %s %s;' % ( - style.SQL_KEYWORD('DELETE'), - style.SQL_KEYWORD('FROM'), - style.SQL_FIELD(self.quote_name(table)) - ) for table in tables] + tables = set( + chain.from_iterable(self._references_graph(table) for table in tables) + ) + sql = [ + "%s %s %s;" + % ( + style.SQL_KEYWORD("DELETE"), + style.SQL_KEYWORD("FROM"), + style.SQL_FIELD(self.quote_name(table)), + ) + for table in tables + ] if reset_sequences: - sequences = [{'table': table} for table in tables] + sequences = [{"table": table} for table in tables] sql.extend(self.sequence_reset_by_name_sql(style, sequences)) return sql @@ -230,17 +242,18 @@ class DatabaseOperations(BaseDatabaseOperations): if not sequences: return [] return [ - '%s %s %s %s = 0 %s %s %s (%s);' % ( - style.SQL_KEYWORD('UPDATE'), - style.SQL_TABLE(self.quote_name('sqlite_sequence')), - style.SQL_KEYWORD('SET'), - style.SQL_FIELD(self.quote_name('seq')), - style.SQL_KEYWORD('WHERE'), - style.SQL_FIELD(self.quote_name('name')), - style.SQL_KEYWORD('IN'), - ', '.join([ - "'%s'" % sequence_info['table'] for sequence_info in sequences - ]), + "%s %s %s %s = 0 %s %s %s (%s);" + % ( + style.SQL_KEYWORD("UPDATE"), + style.SQL_TABLE(self.quote_name("sqlite_sequence")), + style.SQL_KEYWORD("SET"), + style.SQL_FIELD(self.quote_name("seq")), + style.SQL_KEYWORD("WHERE"), + style.SQL_FIELD(self.quote_name("name")), + style.SQL_KEYWORD("IN"), + ", ".join( + ["'%s'" % sequence_info["table"] for sequence_info in sequences] + ), ), ] @@ -249,7 +262,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. 
- if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # SQLite doesn't support tz-aware datetimes @@ -257,7 +270,9 @@ class DatabaseOperations(BaseDatabaseOperations): if settings.USE_TZ: value = timezone.make_naive(value, self.connection.timezone) else: - raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.") + raise ValueError( + "SQLite backend does not support timezone-aware datetimes when USE_TZ is False." + ) return str(value) @@ -266,7 +281,7 @@ class DatabaseOperations(BaseDatabaseOperations): return None # Expression values are adapted by the database. - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): return value # SQLite doesn't support tz-aware datetimes @@ -278,17 +293,17 @@ class DatabaseOperations(BaseDatabaseOperations): def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == 'DateTimeField': + if internal_type == "DateTimeField": converters.append(self.convert_datetimefield_value) - elif internal_type == 'DateField': + elif internal_type == "DateField": converters.append(self.convert_datefield_value) - elif internal_type == 'TimeField': + elif internal_type == "TimeField": converters.append(self.convert_timefield_value) - elif internal_type == 'DecimalField': + elif internal_type == "DecimalField": converters.append(self.get_decimalfield_converter(expression)) - elif internal_type == 'UUIDField': + elif internal_type == "UUIDField": converters.append(self.convert_uuidfield_value) - elif internal_type == 'BooleanField': + elif internal_type == "BooleanField": converters.append(self.convert_booleanfield_value) return converters @@ -317,15 +332,22 @@ class DatabaseOperations(BaseDatabaseOperations): # float inaccuracy must be removed. 
create_decimal = decimal.Context(prec=15).create_decimal_from_float if isinstance(expression, Col): - quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places) + quantize_value = decimal.Decimal(1).scaleb( + -expression.output_field.decimal_places + ) def converter(value, expression, connection): if value is not None: - return create_decimal(value).quantize(quantize_value, context=expression.output_field.context) + return create_decimal(value).quantize( + quantize_value, context=expression.output_field.context + ) + else: + def converter(value, expression, connection): if value is not None: return create_decimal(value) + return converter def convert_uuidfield_value(self, value, expression, connection): @@ -337,26 +359,26 @@ class DatabaseOperations(BaseDatabaseOperations): return bool(value) if value in (1, 0) else value def bulk_insert_sql(self, fields, placeholder_rows): - placeholder_rows_sql = (', '.join(row) for row in placeholder_rows) - values_sql = ', '.join(f'({sql})' for sql in placeholder_rows_sql) - return f'VALUES {values_sql}' + placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) + values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql) + return f"VALUES {values_sql}" def combine_expression(self, connector, sub_expressions): # SQLite doesn't have a ^ operator, so use the user-defined POWER # function that's registered in connect(). - if connector == '^': - return 'POWER(%s)' % ','.join(sub_expressions) - elif connector == '#': - return 'BITXOR(%s)' % ','.join(sub_expressions) + if connector == "^": + return "POWER(%s)" % ",".join(sub_expressions) + elif connector == "#": + return "BITXOR(%s)" % ",".join(sub_expressions) return super().combine_expression(connector, sub_expressions) def combine_duration_expression(self, connector, sub_expressions): - if connector not in ['+', '-', '*', '/']: - raise DatabaseError('Invalid connector for timedelta: %s.' 
% connector) + if connector not in ["+", "-", "*", "/"]: + raise DatabaseError("Invalid connector for timedelta: %s." % connector) fn_params = ["'%s'" % connector] + sub_expressions if len(fn_params) > 3: - raise ValueError('Too many params for timedelta operations.') - return "django_format_dtdelta(%s)" % ', '.join(fn_params) + raise ValueError("Too many params for timedelta operations.") + return "django_format_dtdelta(%s)" % ", ".join(fn_params) def integer_field_range(self, internal_type): # SQLite doesn't enforce any integer constraints @@ -366,39 +388,46 @@ class DatabaseOperations(BaseDatabaseOperations): lhs_sql, lhs_params = lhs rhs_sql, rhs_params = rhs params = (*lhs_params, *rhs_params) - if internal_type == 'TimeField': - return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params - return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params + if internal_type == "TimeField": + return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params + return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params def insert_statement(self, on_conflict=None): if on_conflict == OnConflict.IGNORE: - return 'INSERT OR IGNORE INTO' + return "INSERT OR IGNORE INTO" return super().insert_statement(on_conflict=on_conflict) def return_insert_columns(self, fields): # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. 
if not fields: - return '', () + return "", () columns = [ - '%s.%s' % ( + "%s.%s" + % ( self.quote_name(field.model._meta.db_table), self.quote_name(field.column), - ) for field in fields + ) + for field in fields ] - return 'RETURNING %s' % ', '.join(columns), () + return "RETURNING %s" % ", ".join(columns), () def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): if ( - on_conflict == OnConflict.UPDATE and - self.connection.features.supports_update_conflicts_with_target + on_conflict == OnConflict.UPDATE + and self.connection.features.supports_update_conflicts_with_target ): - return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( - ', '.join(map(self.quote_name, unique_fields)), - ', '.join([ - f'{field} = EXCLUDED.{field}' - for field in map(self.quote_name, update_fields) - ]), + return "ON CONFLICT(%s) DO UPDATE SET %s" % ( + ", ".join(map(self.quote_name, unique_fields)), + ", ".join( + [ + f"{field} = EXCLUDED.{field}" + for field in map(self.quote_name, update_fields) + ] + ), ) return super().on_conflict_suffix_sql( - fields, on_conflict, update_fields, unique_fields, + fields, + on_conflict, + update_fields, + unique_fields, ) diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 3ff0a3f7db..c9af8088e5 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -14,7 +14,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_delete_table = "DROP TABLE %(table)s" sql_create_fk = None - sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" + sql_create_inline_fk = ( + "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" + ) sql_create_column_inline_fk = sql_create_inline_fk sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)" sql_delete_unique = "DROP INDEX %(name)s" @@ -24,11 +26,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # disabled. 
Enforce it here for the duration of the schema edition. if not self.connection.disable_constraint_checking(): raise NotSupportedError( - 'SQLite schema editor cannot be used while foreign key ' - 'constraint checks are enabled. Make sure to disable them ' - 'before entering a transaction.atomic() context because ' - 'SQLite does not support disabling them in the middle of ' - 'a multi-statement transaction.' + "SQLite schema editor cannot be used while foreign key " + "constraint checks are enabled. Make sure to disable them " + "before entering a transaction.atomic() context because " + "SQLite does not support disabling them in the middle of " + "a multi-statement transaction." ) return super().__enter__() @@ -43,6 +45,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # security hardening). try: import sqlite3 + value = sqlite3.adapt(value) except ImportError: pass @@ -54,7 +57,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): elif isinstance(value, (Decimal, float, int)): return str(value) elif isinstance(value, str): - return "'%s'" % value.replace("\'", "\'\'") + return "'%s'" % value.replace("'", "''") elif value is None: return "NULL" elif isinstance(value, (bytes, bytearray, memoryview)): @@ -63,12 +66,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # character. return "X'%s'" % value.hex() else: - raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value))) + raise ValueError( + "Cannot quote parameter value %r of type %s" % (value, type(value)) + ) def prepare_default(self, value): return self.quote_value(value) - def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False): + def _is_referenced_by_fk_constraint( + self, table_name, column_name=None, ignore_self=False + ): """ Return whether or not the provided table name is referenced by another one. 
If `column_name` is specified, only references pointing to that @@ -79,22 +86,33 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): for other_table in self.connection.introspection.get_table_list(cursor): if ignore_self and other_table.name == table_name: continue - relations = self.connection.introspection.get_relations(cursor, other_table.name) + relations = self.connection.introspection.get_relations( + cursor, other_table.name + ) for constraint_column, constraint_table in relations.values(): - if (constraint_table == table_name and - (column_name is None or constraint_column == column_name)): + if constraint_table == table_name and ( + column_name is None or constraint_column == column_name + ): return True return False - def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True): - if (not self.connection.features.supports_atomic_references_rename and - disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)): + def alter_db_table( + self, model, old_db_table, new_db_table, disable_constraints=True + ): + if ( + not self.connection.features.supports_atomic_references_rename + and disable_constraints + and self._is_referenced_by_fk_constraint(old_db_table) + ): if self.connection.in_atomic_block: - raise NotSupportedError(( - 'Renaming the %r table while in a transaction is not ' - 'supported on SQLite < 3.26 because it would break referential ' - 'integrity. Try adding `atomic = False` to the Migration class.' - ) % old_db_table) + raise NotSupportedError( + ( + "Renaming the %r table while in a transaction is not " + "supported on SQLite < 3.26 because it would break referential " + "integrity. Try adding `atomic = False` to the Migration class." 
+ ) + % old_db_table + ) self.connection.enable_constraint_checking() super().alter_db_table(model, old_db_table, new_db_table) self.connection.disable_constraint_checking() @@ -107,42 +125,56 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): old_field_name = old_field.name table_name = model._meta.db_table _, old_column_name = old_field.get_attname_column() - if (new_field.name != old_field_name and - not self.connection.features.supports_atomic_references_rename and - self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)): + if ( + new_field.name != old_field_name + and not self.connection.features.supports_atomic_references_rename + and self._is_referenced_by_fk_constraint( + table_name, old_column_name, ignore_self=True + ) + ): if self.connection.in_atomic_block: - raise NotSupportedError(( - 'Renaming the %r.%r column while in a transaction is not ' - 'supported on SQLite < 3.26 because it would break referential ' - 'integrity. Try adding `atomic = False` to the Migration class.' - ) % (model._meta.db_table, old_field_name)) + raise NotSupportedError( + ( + "Renaming the %r.%r column while in a transaction is not " + "supported on SQLite < 3.26 because it would break referential " + "integrity. Try adding `atomic = False` to the Migration class." + ) + % (model._meta.db_table, old_field_name) + ) with atomic(self.connection.alias): super().alter_field(model, old_field, new_field, strict=strict) # Follow SQLite's documented procedure for performing changes # that don't affect the on-disk content. 
# https://sqlite.org/lang_altertable.html#otheralter with self.connection.cursor() as cursor: - schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0] - cursor.execute('PRAGMA writable_schema = 1') + schema_version = cursor.execute("PRAGMA schema_version").fetchone()[ + 0 + ] + cursor.execute("PRAGMA writable_schema = 1") references_template = ' REFERENCES "%s" ("%%s") ' % table_name new_column_name = new_field.get_attname_column()[1] search = references_template % old_column_name replacement = references_template % new_column_name - cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement)) - cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1)) - cursor.execute('PRAGMA writable_schema = 0') + cursor.execute( + "UPDATE sqlite_master SET sql = replace(sql, %s, %s)", + (search, replacement), + ) + cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1)) + cursor.execute("PRAGMA writable_schema = 0") # The integrity check will raise an exception and rollback # the transaction if the sqlite_master updates corrupt the # database. - cursor.execute('PRAGMA integrity_check') + cursor.execute("PRAGMA integrity_check") # Perform a VACUUM to refresh the database representation from # the sqlite_master table. with self.connection.cursor() as cursor: - cursor.execute('VACUUM') + cursor.execute("VACUUM") else: super().alter_field(model, old_field, new_field, strict=strict) - def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None): + def _remake_table( + self, model, create_field=None, delete_field=None, alter_field=None + ): """ Shortcut to transform a model from old_model into new_model @@ -163,6 +195,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # to an altered field. 
def is_self_referential(f): return f.is_relation and f.remote_field.model is model + # Work out the new fields dict / mapping body = { f.name: f.clone() if is_self_referential(f) else f @@ -170,14 +203,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): } # Since mapping might mix column names and default values, # its values must be already quoted. - mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields} + mapping = { + f.column: self.quote_name(f.column) + for f in model._meta.local_concrete_fields + } # This maps field names (not columns) for things like unique_together rename_mapping = {} # If any of the new or altered fields is introducing a new PK, # remove the old one restore_pk_field = None - if getattr(create_field, 'primary_key', False) or ( - alter_field and getattr(alter_field[1], 'primary_key', False)): + if getattr(create_field, "primary_key", False) or ( + alter_field and getattr(alter_field[1], "primary_key", False) + ): for name, field in list(body.items()): if field.primary_key: field.primary_key = False @@ -201,8 +238,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): body[new_field.name] = new_field if old_field.null and not new_field.null: case_sql = "coalesce(%(col)s, %(default)s)" % { - 'col': self.quote_name(old_field.column), - 'default': self.prepare_default(self.effective_default(new_field)), + "col": self.quote_name(old_field.column), + "default": self.prepare_default(self.effective_default(new_field)), } mapping[new_field.column] = case_sql else: @@ -213,7 +250,10 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): del body[delete_field.name] del mapping[delete_field.column] # Remove any implicit M2M tables - if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created: + if ( + delete_field.many_to_many + and delete_field.remote_field.through._meta.auto_created + ): return self.delete_model(delete_field.remote_field.through) # Work inside a new app 
registry apps = Apps() @@ -235,8 +275,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): indexes = model._meta.indexes if delete_field: indexes = [ - index for index in indexes - if delete_field.name not in index.fields + index for index in indexes if delete_field.name not in index.fields ] constraints = list(model._meta.constraints) @@ -252,52 +291,57 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # This wouldn't be required if the schema editor was operating on model # states instead of rendered models. meta_contents = { - 'app_label': model._meta.app_label, - 'db_table': model._meta.db_table, - 'unique_together': unique_together, - 'index_together': index_together, - 'indexes': indexes, - 'constraints': constraints, - 'apps': apps, + "app_label": model._meta.app_label, + "db_table": model._meta.db_table, + "unique_together": unique_together, + "index_together": index_together, + "indexes": indexes, + "constraints": constraints, + "apps": apps, } meta = type("Meta", (), meta_contents) - body_copy['Meta'] = meta - body_copy['__module__'] = model.__module__ + body_copy["Meta"] = meta + body_copy["__module__"] = model.__module__ type(model._meta.object_name, model.__bases__, body_copy) # Construct a model with a renamed table name. 
body_copy = copy.deepcopy(body) meta_contents = { - 'app_label': model._meta.app_label, - 'db_table': 'new__%s' % strip_quotes(model._meta.db_table), - 'unique_together': unique_together, - 'index_together': index_together, - 'indexes': indexes, - 'constraints': constraints, - 'apps': apps, + "app_label": model._meta.app_label, + "db_table": "new__%s" % strip_quotes(model._meta.db_table), + "unique_together": unique_together, + "index_together": index_together, + "indexes": indexes, + "constraints": constraints, + "apps": apps, } meta = type("Meta", (), meta_contents) - body_copy['Meta'] = meta - body_copy['__module__'] = model.__module__ - new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy) + body_copy["Meta"] = meta + body_copy["__module__"] = model.__module__ + new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy) # Create a new table with the updated schema. self.create_model(new_model) # Copy data from the old table into the new table - self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % ( - self.quote_name(new_model._meta.db_table), - ', '.join(self.quote_name(x) for x in mapping), - ', '.join(mapping.values()), - self.quote_name(model._meta.db_table), - )) + self.execute( + "INSERT INTO %s (%s) SELECT %s FROM %s" + % ( + self.quote_name(new_model._meta.db_table), + ", ".join(self.quote_name(x) for x in mapping), + ", ".join(mapping.values()), + self.quote_name(model._meta.db_table), + ) + ) # Delete the old table to make way for the new self.delete_model(model, handle_autom2m=False) # Rename the new table to take way for the old self.alter_db_table( - new_model, new_model._meta.db_table, model._meta.db_table, + new_model, + new_model._meta.db_table, + model._meta.db_table, disable_constraints=False, ) @@ -314,12 +358,17 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): super().delete_model(model) else: # Delete the table (and only that) - self.execute(self.sql_delete_table % { - "table": 
self.quote_name(model._meta.db_table), - }) + self.execute( + self.sql_delete_table + % { + "table": self.quote_name(model._meta.db_table), + } + ) # Remove all deferred statements referencing the deleted table. for sql in list(self.deferred_sql): - if isinstance(sql, Statement) and sql.references_table(model._meta.db_table): + if isinstance(sql, Statement) and sql.references_table( + model._meta.db_table + ): self.deferred_sql.remove(sql) def add_field(self, model, field): @@ -327,11 +376,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if ( # Primary keys and unique fields are not supported in ALTER TABLE # ADD COLUMN. - field.primary_key or field.unique or + field.primary_key + or field.unique + or # Fields with default values cannot by handled by ALTER TABLE ADD # COLUMN statement because DROP DEFAULT is not supported in # ALTER TABLE. - not field.null or self.effective_default(field) is not None + not field.null + or self.effective_default(field) is not None ): self._remake_table(model, create_field=field) else: @@ -351,21 +403,40 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # For everything else, remake. else: # It might not actually have a column behind it - if field.db_parameters(connection=self.connection)['type'] is None: + if field.db_parameters(connection=self.connection)["type"] is None: return self._remake_table(model, delete_field=field) - def _alter_field(self, model, old_field, new_field, old_type, new_type, - old_db_params, new_db_params, strict=False): + def _alter_field( + self, + model, + old_field, + new_field, + old_type, + new_type, + old_db_params, + new_db_params, + strict=False, + ): """Perform a "physical" (non-ManyToMany) field update.""" # Use "ALTER TABLE ... RENAME COLUMN" if only the column name # changed and there aren't any constraints. 
- if (self.connection.features.can_alter_table_rename_column and - old_field.column != new_field.column and - self.column_sql(model, old_field) == self.column_sql(model, new_field) and - not (old_field.remote_field and old_field.db_constraint or - new_field.remote_field and new_field.db_constraint)): - return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) + if ( + self.connection.features.can_alter_table_rename_column + and old_field.column != new_field.column + and self.column_sql(model, old_field) == self.column_sql(model, new_field) + and not ( + old_field.remote_field + and old_field.db_constraint + or new_field.remote_field + and new_field.db_constraint + ) + ): + return self.execute( + self._rename_field_sql( + model._meta.db_table, old_field, new_field, new_type + ) + ) # Alter by remaking table self._remake_table(model, alter_field=(old_field, new_field)) # Rebuild tables with FKs pointing to this field. @@ -393,15 +464,22 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def _alter_many_to_many(self, model, old_field, new_field, strict): """Alter M2Ms to repoint their to= endpoints.""" - if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table: + if ( + old_field.remote_field.through._meta.db_table + == new_field.remote_field.through._meta.db_table + ): # The field name didn't change, but some options did; we have to propagate this altering. 
self._remake_table( old_field.remote_field.through, alter_field=( # We need the field that points to the target model, so we can tell alter_field to change it - # this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model) - old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()), - new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()), + old_field.remote_field.through._meta.get_field( + old_field.m2m_reverse_field_name() + ), + new_field.remote_field.through._meta.get_field( + new_field.m2m_reverse_field_name() + ), ), ) return @@ -409,29 +487,36 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Make a new through table self.create_model(new_field.remote_field.through) # Copy the data across - self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % ( - self.quote_name(new_field.remote_field.through._meta.db_table), - ', '.join([ - "id", - new_field.m2m_column_name(), - new_field.m2m_reverse_name(), - ]), - ', '.join([ - "id", - old_field.m2m_column_name(), - old_field.m2m_reverse_name(), - ]), - self.quote_name(old_field.remote_field.through._meta.db_table), - )) + self.execute( + "INSERT INTO %s (%s) SELECT %s FROM %s" + % ( + self.quote_name(new_field.remote_field.through._meta.db_table), + ", ".join( + [ + "id", + new_field.m2m_column_name(), + new_field.m2m_reverse_name(), + ] + ), + ", ".join( + [ + "id", + old_field.m2m_column_name(), + old_field.m2m_reverse_name(), + ] + ), + self.quote_name(old_field.remote_field.through._meta.db_table), + ) + ) # Delete the old through table self.delete_model(old_field.remote_field.through) def add_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and ( - constraint.condition or - constraint.contains_expressions or - constraint.include or - constraint.deferrable + constraint.condition + or constraint.contains_expressions + or constraint.include + or constraint.deferrable ): 
super().add_constraint(model, constraint) else: @@ -439,14 +524,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and ( - constraint.condition or - constraint.contains_expressions or - constraint.include or - constraint.deferrable + constraint.condition + or constraint.contains_expressions + or constraint.include + or constraint.deferrable ): super().remove_constraint(model, constraint) else: self._remake_table(model) def _collate_sql(self, collation): - return 'COLLATE ' + collation + return "COLLATE " + collation diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index b7318bae62..d505cd7904 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -9,7 +9,7 @@ from django.db import NotSupportedError from django.utils.crypto import md5 from django.utils.dateparse import parse_time -logger = logging.getLogger('django.db.backends') +logger = logging.getLogger("django.db.backends") class CursorWrapper: @@ -17,7 +17,7 @@ class CursorWrapper: self.cursor = cursor self.db = db - WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset']) + WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"]) def __getattr__(self, attr): cursor_attr = getattr(self.cursor, attr) @@ -50,8 +50,8 @@ class CursorWrapper: # database driver may support them (e.g. cx_Oracle). if kparams is not None and not self.db.features.supports_callproc_kwargs: raise NotSupportedError( - 'Keyword parameters for callproc are not supported on this ' - 'database backend.' + "Keyword parameters for callproc are not supported on this " + "database backend." 
) self.db.validate_no_broken_transaction() with self.db.wrap_database_errors: @@ -64,13 +64,17 @@ class CursorWrapper: return self.cursor.callproc(procname, params, kparams) def execute(self, sql, params=None): - return self._execute_with_wrappers(sql, params, many=False, executor=self._execute) + return self._execute_with_wrappers( + sql, params, many=False, executor=self._execute + ) def executemany(self, sql, param_list): - return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany) + return self._execute_with_wrappers( + sql, param_list, many=True, executor=self._executemany + ) def _execute_with_wrappers(self, sql, params, many, executor): - context = {'connection': self.db, 'cursor': self} + context = {"connection": self.db, "cursor": self} for wrapper in reversed(self.db.execute_wrappers): executor = functools.partial(wrapper, executor) return executor(sql, params, many, context) @@ -103,7 +107,9 @@ class CursorDebugWrapper(CursorWrapper): return super().executemany(sql, param_list) @contextmanager - def debug_sql(self, sql=None, params=None, use_last_executed_query=False, many=False): + def debug_sql( + self, sql=None, params=None, use_last_executed_query=False, many=False + ): start = time.monotonic() try: yield @@ -113,21 +119,28 @@ class CursorDebugWrapper(CursorWrapper): if use_last_executed_query: sql = self.db.ops.last_executed_query(self.cursor, sql, params) try: - times = len(params) if many else '' + times = len(params) if many else "" except TypeError: # params could be an iterator. - times = '?' - self.db.queries_log.append({ - 'sql': '%s times: %s' % (times, sql) if many else sql, - 'time': '%.3f' % duration, - }) + times = "?" 
+ self.db.queries_log.append( + { + "sql": "%s times: %s" % (times, sql) if many else sql, + "time": "%.3f" % duration, + } + ) logger.debug( - '(%.3f) %s; args=%s; alias=%s', + "(%.3f) %s; args=%s; alias=%s", duration, sql, params, self.db.alias, - extra={'duration': duration, 'sql': sql, 'params': params, 'alias': self.db.alias}, + extra={ + "duration": duration, + "sql": sql, + "params": params, + "alias": self.db.alias, + }, ) @@ -135,7 +148,7 @@ def split_tzname_delta(tzname): """ Split a time zone name into a 3-tuple of (name, sign, offset). """ - for sign in ['+', '-']: + for sign in ["+", "-"]: if sign in tzname: name, offset = tzname.rsplit(sign, 1) if offset and parse_time(offset): @@ -147,19 +160,24 @@ def split_tzname_delta(tzname): # Converters from database (string) to Python # ############################################### + def typecast_date(s): - return datetime.date(*map(int, s.split('-'))) if s else None # return None if s is null + return ( + datetime.date(*map(int, s.split("-"))) if s else None + ) # return None if s is null def typecast_time(s): # does NOT store time zone information if not s: return None - hour, minutes, seconds = s.split(':') - if '.' in seconds: # check whether seconds have a fractional part - seconds, microseconds = seconds.split('.') + hour, minutes, seconds = s.split(":") + if "." 
in seconds: # check whether seconds have a fractional part + seconds, microseconds = seconds.split(".") else: - microseconds = '0' - return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6])) + microseconds = "0" + return datetime.time( + int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6]) + ) def typecast_timestamp(s): # does NOT store time zone information @@ -167,25 +185,29 @@ def typecast_timestamp(s): # does NOT store time zone information # "2005-07-29 09:56:00-05" if not s: return None - if ' ' not in s: + if " " not in s: return typecast_date(s) d, t = s.split() # Remove timezone information. - if '-' in t: - t, _ = t.split('-', 1) - elif '+' in t: - t, _ = t.split('+', 1) - dates = d.split('-') - times = t.split(':') + if "-" in t: + t, _ = t.split("-", 1) + elif "+" in t: + t, _ = t.split("+", 1) + dates = d.split("-") + times = t.split(":") seconds = times[2] - if '.' in seconds: # check whether seconds have a fractional part - seconds, microseconds = seconds.split('.') + if "." in seconds: # check whether seconds have a fractional part + seconds, microseconds = seconds.split(".") else: - microseconds = '0' + microseconds = "0" return datetime.datetime( - int(dates[0]), int(dates[1]), int(dates[2]), - int(times[0]), int(times[1]), int(seconds), - int((microseconds + '000000')[:6]) + int(dates[0]), + int(dates[1]), + int(dates[2]), + int(times[0]), + int(times[1]), + int(seconds), + int((microseconds + "000000")[:6]), ) @@ -193,6 +215,7 @@ def typecast_timestamp(s): # does NOT store time zone information # Converters from Python to database (string) # ############################################### + def split_identifier(identifier): """ Split an SQL identifier into a two element tuple of (namespace, name). 
@@ -203,7 +226,7 @@ def split_identifier(identifier): try: namespace, name = identifier.split('"."') except ValueError: - namespace, name = '', identifier + namespace, name = "", identifier return namespace.strip('"'), name.strip('"') @@ -221,7 +244,11 @@ def truncate_name(identifier, length=None, hash_len=4): return identifier digest = names_digest(name, length=hash_len) - return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest) + return "%s%s%s" % ( + '%s"."' % namespace if namespace else "", + name[: length - hash_len], + digest, + ) def names_digest(*args, length): @@ -246,7 +273,9 @@ def format_number(value, max_digits, decimal_places): if max_digits is not None: context.prec = max_digits if decimal_places is not None: - value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=context) + value = value.quantize( + decimal.Decimal(1).scaleb(-decimal_places), context=context + ) else: context.traps[decimal.Rounded] = 1 value = context.create_decimal(value) diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index f1238a3504..f8140f1845 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -10,7 +10,9 @@ from django.db.migrations.operations.models import AlterModelOptions from django.db.migrations.optimizer import MigrationOptimizer from django.db.migrations.questioner import MigrationQuestioner from django.db.migrations.utils import ( - COMPILED_REGEX_TYPE, RegexObject, resolve_relation, + COMPILED_REGEX_TYPE, + RegexObject, + resolve_relation, ) from django.utils.topological_sort import stable_topological_sort @@ -57,19 +59,20 @@ class MigrationAutodetector: elif isinstance(obj, tuple): return tuple(self.deep_deconstruct(value) for value in obj) elif isinstance(obj, dict): - return { - key: self.deep_deconstruct(value) - for key, value in obj.items() - } + return {key: self.deep_deconstruct(value) for key, value in obj.items()} 
elif isinstance(obj, functools.partial): - return (obj.func, self.deep_deconstruct(obj.args), self.deep_deconstruct(obj.keywords)) + return ( + obj.func, + self.deep_deconstruct(obj.args), + self.deep_deconstruct(obj.keywords), + ) elif isinstance(obj, COMPILED_REGEX_TYPE): return RegexObject(obj) elif isinstance(obj, type): # If this is a type that implements 'deconstruct' as an instance method, # avoid treating this as being deconstructible itself - see #22951 return obj - elif hasattr(obj, 'deconstruct'): + elif hasattr(obj, "deconstruct"): deconstructed = obj.deconstruct() if isinstance(obj, models.Field): # we have a field which also returns a name @@ -78,10 +81,7 @@ class MigrationAutodetector: return ( path, [self.deep_deconstruct(value) for value in args], - { - key: self.deep_deconstruct(value) - for key, value in kwargs.items() - }, + {key: self.deep_deconstruct(value) for key, value in kwargs.items()}, ) else: return obj @@ -96,7 +96,7 @@ class MigrationAutodetector: for name, field in sorted(fields.items()): deconstruction = self.deep_deconstruct(field) if field.remote_field and field.remote_field.model: - deconstruction[2].pop('to', None) + deconstruction[2].pop("to", None) fields_def.append(deconstruction) return fields_def @@ -132,22 +132,21 @@ class MigrationAutodetector: self.new_proxy_keys = set() self.new_unmanaged_keys = set() for (app_label, model_name), model_state in self.from_state.models.items(): - if not model_state.options.get('managed', True): + if not model_state.options.get("managed", True): self.old_unmanaged_keys.add((app_label, model_name)) elif app_label not in self.from_state.real_apps: - if model_state.options.get('proxy'): + if model_state.options.get("proxy"): self.old_proxy_keys.add((app_label, model_name)) else: self.old_model_keys.add((app_label, model_name)) for (app_label, model_name), model_state in self.to_state.models.items(): - if not model_state.options.get('managed', True): + if not model_state.options.get("managed", 
True): self.new_unmanaged_keys.add((app_label, model_name)) - elif ( - app_label not in self.from_state.real_apps or - (convert_apps and app_label in convert_apps) + elif app_label not in self.from_state.real_apps or ( + convert_apps and app_label in convert_apps ): - if model_state.options.get('proxy'): + if model_state.options.get("proxy"): self.new_proxy_keys.add((app_label, model_name)) else: self.new_model_keys.add((app_label, model_name)) @@ -214,8 +213,7 @@ class MigrationAutodetector: (app_label, model_name, field_name) for app_label, model_name in self.kept_model_keys for field_name in self.from_state.models[ - app_label, - self.renamed_models.get((app_label, model_name), model_name) + app_label, self.renamed_models.get((app_label, model_name), model_name) ].fields } self.new_field_keys = { @@ -227,12 +225,22 @@ class MigrationAutodetector: def _generate_through_model_map(self): """Through model map generation.""" for app_label, model_name in sorted(self.old_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] for field_name, field in old_model_state.fields.items(): - if hasattr(field, 'remote_field') and getattr(field.remote_field, 'through', None): - through_key = resolve_relation(field.remote_field.through, app_label, model_name) - self.through_users[through_key] = (app_label, old_model_name, field_name) + if hasattr(field, "remote_field") and getattr( + field.remote_field, "through", None + ): + through_key = resolve_relation( + field.remote_field.through, app_label, model_name + ) + self.through_users[through_key] = ( + app_label, + old_model_name, + field_name, + ) @staticmethod def _resolve_dependency(dependency): @@ -240,9 +248,11 @@ class MigrationAutodetector: Return the resolved dependency and a boolean denoting whether or not it was swappable. 
""" - if dependency[0] != '__setting__': + if dependency[0] != "__setting__": return dependency, False - resolved_app_label, resolved_object_name = getattr(settings, dependency[1]).split('.') + resolved_app_label, resolved_object_name = getattr( + settings, dependency[1] + ).split(".") return (resolved_app_label, resolved_object_name.lower()) + dependency[2:], True def _build_migration_list(self, graph=None): @@ -282,7 +292,9 @@ class MigrationAutodetector: if dep[0] != app_label: # External app dependency. See if it's not yet # satisfied. - for other_operation in self.generated_operations.get(dep[0], []): + for other_operation in self.generated_operations.get( + dep[0], [] + ): if self.check_dependency(other_operation, dep): deps_satisfied = False break @@ -290,9 +302,13 @@ class MigrationAutodetector: break else: if is_swappable_dep: - operation_dependencies.add((original_dep[0], original_dep[1])) + operation_dependencies.add( + (original_dep[0], original_dep[1]) + ) elif dep[0] in self.migrations: - operation_dependencies.add((dep[0], self.migrations[dep[0]][-1].name)) + operation_dependencies.add( + (dep[0], self.migrations[dep[0]][-1].name) + ) else: # If we can't find the other app, we add a first/last dependency, # but only if we've already been through once and checked everything @@ -301,9 +317,13 @@ class MigrationAutodetector: # as we don't know which migration contains the target field. # If it's not yet migrated or has no migrations, we use __first__ if graph and graph.leaf_nodes(dep[0]): - operation_dependencies.add(graph.leaf_nodes(dep[0])[0]) + operation_dependencies.add( + graph.leaf_nodes(dep[0])[0] + ) else: - operation_dependencies.add((dep[0], "__first__")) + operation_dependencies.add( + (dep[0], "__first__") + ) else: deps_satisfied = False if deps_satisfied: @@ -315,21 +335,33 @@ class MigrationAutodetector: # Make a migration! 
Well, only if there's stuff to put in it if dependencies or chopped: if not self.generated_operations[app_label] or chop_mode: - subclass = type("Migration", (Migration,), {"operations": [], "dependencies": []}) - instance = subclass("auto_%i" % (len(self.migrations.get(app_label, [])) + 1), app_label) + subclass = type( + "Migration", + (Migration,), + {"operations": [], "dependencies": []}, + ) + instance = subclass( + "auto_%i" % (len(self.migrations.get(app_label, [])) + 1), + app_label, + ) instance.dependencies = list(dependencies) instance.operations = chopped instance.initial = app_label not in self.existing_apps self.migrations.setdefault(app_label, []).append(instance) chop_mode = False else: - self.generated_operations[app_label] = chopped + self.generated_operations[app_label] + self.generated_operations[app_label] = ( + chopped + self.generated_operations[app_label] + ) new_num_ops = sum(len(x) for x in self.generated_operations.values()) if new_num_ops == num_ops: if not chop_mode: chop_mode = True else: - raise ValueError("Cannot resolve operation dependencies: %r" % self.generated_operations) + raise ValueError( + "Cannot resolve operation dependencies: %r" + % self.generated_operations + ) num_ops = new_num_ops def _sort_migrations(self): @@ -351,7 +383,9 @@ class MigrationAutodetector: dependency_graph[op].add(op2) # we use a stable sort for deterministic tests & general behavior - self.generated_operations[app_label] = stable_topological_sort(ops, dependency_graph) + self.generated_operations[app_label] = stable_topological_sort( + ops, dependency_graph + ) def _optimize_migrations(self): # Add in internal dependencies among the migrations @@ -367,7 +401,9 @@ class MigrationAutodetector: # Optimize migrations for app_label, migrations in self.migrations.items(): for migration in migrations: - migration.operations = MigrationOptimizer().optimize(migration.operations, app_label) + migration.operations = MigrationOptimizer().optimize( + 
migration.operations, app_label + ) def check_dependency(self, operation, dependency): """ @@ -377,56 +413,56 @@ class MigrationAutodetector: # Created model if dependency[2] is None and dependency[3] is True: return ( - isinstance(operation, operations.CreateModel) and - operation.name_lower == dependency[1].lower() + isinstance(operation, operations.CreateModel) + and operation.name_lower == dependency[1].lower() ) # Created field elif dependency[2] is not None and dependency[3] is True: return ( - ( - isinstance(operation, operations.CreateModel) and - operation.name_lower == dependency[1].lower() and - any(dependency[2] == x for x, y in operation.fields) - ) or - ( - isinstance(operation, operations.AddField) and - operation.model_name_lower == dependency[1].lower() and - operation.name_lower == dependency[2].lower() - ) + isinstance(operation, operations.CreateModel) + and operation.name_lower == dependency[1].lower() + and any(dependency[2] == x for x, y in operation.fields) + ) or ( + isinstance(operation, operations.AddField) + and operation.model_name_lower == dependency[1].lower() + and operation.name_lower == dependency[2].lower() ) # Removed field elif dependency[2] is not None and dependency[3] is False: return ( - isinstance(operation, operations.RemoveField) and - operation.model_name_lower == dependency[1].lower() and - operation.name_lower == dependency[2].lower() + isinstance(operation, operations.RemoveField) + and operation.model_name_lower == dependency[1].lower() + and operation.name_lower == dependency[2].lower() ) # Removed model elif dependency[2] is None and dependency[3] is False: return ( - isinstance(operation, operations.DeleteModel) and - operation.name_lower == dependency[1].lower() + isinstance(operation, operations.DeleteModel) + and operation.name_lower == dependency[1].lower() ) # Field being altered elif dependency[2] is not None and dependency[3] == "alter": return ( - isinstance(operation, operations.AlterField) and - 
operation.model_name_lower == dependency[1].lower() and - operation.name_lower == dependency[2].lower() + isinstance(operation, operations.AlterField) + and operation.model_name_lower == dependency[1].lower() + and operation.name_lower == dependency[2].lower() ) # order_with_respect_to being unset for a field elif dependency[2] is not None and dependency[3] == "order_wrt_unset": return ( - isinstance(operation, operations.AlterOrderWithRespectTo) and - operation.name_lower == dependency[1].lower() and - (operation.order_with_respect_to or "").lower() != dependency[2].lower() + isinstance(operation, operations.AlterOrderWithRespectTo) + and operation.name_lower == dependency[1].lower() + and (operation.order_with_respect_to or "").lower() + != dependency[2].lower() ) # Field is removed and part of an index/unique_together elif dependency[2] is not None and dependency[3] == "foo_together_change": return ( - isinstance(operation, (operations.AlterUniqueTogether, - operations.AlterIndexTogether)) and - operation.name_lower == dependency[1].lower() + isinstance( + operation, + (operations.AlterUniqueTogether, operations.AlterIndexTogether), + ) + and operation.name_lower == dependency[1].lower() ) # Unknown dependency. Raise an error. 
else: @@ -453,10 +489,10 @@ class MigrationAutodetector: } string_version = "%s.%s" % (item[0], item[1]) if ( - model_state.options.get('swappable') or - "AbstractUser" in base_names or - "AbstractBaseUser" in base_names or - settings.AUTH_USER_MODEL.lower() == string_version.lower() + model_state.options.get("swappable") + or "AbstractUser" in base_names + or "AbstractBaseUser" in base_names + or settings.AUTH_USER_MODEL.lower() == string_version.lower() ): return ("___" + item[0], "___" + item[1]) except LookupError: @@ -479,21 +515,32 @@ class MigrationAutodetector: removed_models = self.old_model_keys - self.new_model_keys for rem_app_label, rem_model_name in removed_models: if rem_app_label == app_label: - rem_model_state = self.from_state.models[rem_app_label, rem_model_name] - rem_model_fields_def = self.only_relation_agnostic_fields(rem_model_state.fields) + rem_model_state = self.from_state.models[ + rem_app_label, rem_model_name + ] + rem_model_fields_def = self.only_relation_agnostic_fields( + rem_model_state.fields + ) if model_fields_def == rem_model_fields_def: - if self.questioner.ask_rename_model(rem_model_state, model_state): + if self.questioner.ask_rename_model( + rem_model_state, model_state + ): dependencies = [] fields = list(model_state.fields.values()) + [ field.remote_field - for relations in self.to_state.relations[app_label, model_name].values() + for relations in self.to_state.relations[ + app_label, model_name + ].values() for field in relations.values() ] for field in fields: if field.is_relation: dependencies.extend( self._get_dependencies_for_foreign_key( - app_label, model_name, field, self.to_state, + app_label, + model_name, + field, + self.to_state, ) ) self.add_operation( @@ -505,11 +552,13 @@ class MigrationAutodetector: dependencies=dependencies, ) self.renamed_models[app_label, model_name] = rem_model_name - renamed_models_rel_key = '%s.%s' % ( + renamed_models_rel_key = "%s.%s" % ( rem_model_state.app_label, 
rem_model_state.name_lower, ) - self.renamed_models_rel[renamed_models_rel_key] = '%s.%s' % ( + self.renamed_models_rel[ + renamed_models_rel_key + ] = "%s.%s" % ( model_state.app_label, model_state.name_lower, ) @@ -532,7 +581,7 @@ class MigrationAutodetector: added_unmanaged_models = self.new_unmanaged_keys - old_keys all_added_models = chain( sorted(added_models, key=self.swappable_first_key, reverse=True), - sorted(added_unmanaged_models, key=self.swappable_first_key, reverse=True) + sorted(added_unmanaged_models, key=self.swappable_first_key, reverse=True), ) for app_label, model_name in all_added_models: model_state = self.to_state.models[app_label, model_name] @@ -546,15 +595,17 @@ class MigrationAutodetector: primary_key_rel = field.remote_field.model elif not field.remote_field.parent_link: related_fields[field_name] = field - if getattr(field.remote_field, 'through', None): + if getattr(field.remote_field, "through", None): related_fields[field_name] = field # Are there indexes/unique|index_together to defer? - indexes = model_state.options.pop('indexes') - constraints = model_state.options.pop('constraints') - unique_together = model_state.options.pop('unique_together', None) - index_together = model_state.options.pop('index_together', None) - order_with_respect_to = model_state.options.pop('order_with_respect_to', None) + indexes = model_state.options.pop("indexes") + constraints = model_state.options.pop("constraints") + unique_together = model_state.options.pop("unique_together", None) + index_together = model_state.options.pop("index_together", None) + order_with_respect_to = model_state.options.pop( + "order_with_respect_to", None + ) # Depend on the deletion of any possible proxy version of us dependencies = [ (app_label, model_name, None, False), @@ -566,27 +617,44 @@ class MigrationAutodetector: dependencies.append((base_app_label, base_name, None, True)) # Depend on the removal of base fields if the new model has # a field with the same name. 
- old_base_model_state = self.from_state.models.get((base_app_label, base_name)) - new_base_model_state = self.to_state.models.get((base_app_label, base_name)) + old_base_model_state = self.from_state.models.get( + (base_app_label, base_name) + ) + new_base_model_state = self.to_state.models.get( + (base_app_label, base_name) + ) if old_base_model_state and new_base_model_state: - removed_base_fields = set(old_base_model_state.fields).difference( - new_base_model_state.fields, - ).intersection(model_state.fields) + removed_base_fields = ( + set(old_base_model_state.fields) + .difference( + new_base_model_state.fields, + ) + .intersection(model_state.fields) + ) for removed_base_field in removed_base_fields: - dependencies.append((base_app_label, base_name, removed_base_field, False)) + dependencies.append( + (base_app_label, base_name, removed_base_field, False) + ) # Depend on the other end of the primary key if it's a relation if primary_key_rel: dependencies.append( resolve_relation( - primary_key_rel, app_label, model_name, - ) + (None, True) + primary_key_rel, + app_label, + model_name, + ) + + (None, True) ) # Generate creation operation self.add_operation( app_label, operations.CreateModel( name=model_state.name, - fields=[d for d in model_state.fields.items() if d[0] not in related_fields], + fields=[ + d + for d in model_state.fields.items() + if d[0] not in related_fields + ], options=model_state.options, bases=model_state.bases, managers=model_state.managers, @@ -596,13 +664,16 @@ class MigrationAutodetector: ) # Don't add operations which modify the database for unmanaged models - if not model_state.options.get('managed', True): + if not model_state.options.get("managed", True): continue # Generate operations for each related field for name, field in sorted(related_fields.items()): dependencies = self._get_dependencies_for_foreign_key( - app_label, model_name, field, self.to_state, + app_label, + model_name, + field, + self.to_state, ) # Depend on our 
own model being created dependencies.append((app_label, model_name, None, True)) @@ -627,11 +698,10 @@ class MigrationAutodetector: dependencies=[ (app_label, model_name, order_with_respect_to, True), (app_label, model_name, None, True), - ] + ], ) related_dependencies = [ - (app_label, model_name, name, True) - for name in sorted(related_fields) + (app_label, model_name, name, True) for name in sorted(related_fields) ] related_dependencies.append((app_label, model_name, None, True)) for index in indexes: @@ -659,7 +729,7 @@ class MigrationAutodetector: name=model_name, unique_together=unique_together, ), - dependencies=related_dependencies + dependencies=related_dependencies, ) if index_together: self.add_operation( @@ -668,13 +738,15 @@ class MigrationAutodetector: name=model_name, index_together=index_together, ), - dependencies=related_dependencies + dependencies=related_dependencies, ) # Fix relationships if the model changed from a proxy model to a # concrete model. relations = self.to_state.relations if (app_label, model_name) in self.old_proxy_keys: - for related_model_key, related_fields in relations[app_label, model_name].items(): + for related_model_key, related_fields in relations[ + app_label, model_name + ].items(): related_model_state = self.to_state.models[related_model_key] for related_field_name, related_field in related_fields.items(): self.add_operation( @@ -733,7 +805,9 @@ class MigrationAutodetector: new_keys = self.new_model_keys | self.new_unmanaged_keys deleted_models = self.old_model_keys - new_keys deleted_unmanaged_models = self.old_unmanaged_keys - new_keys - all_deleted_models = chain(sorted(deleted_models), sorted(deleted_unmanaged_models)) + all_deleted_models = chain( + sorted(deleted_models), sorted(deleted_unmanaged_models) + ) for app_label, model_name in all_deleted_models: model_state = self.from_state.models[app_label, model_name] # Gather related fields @@ -742,18 +816,18 @@ class MigrationAutodetector: if field.remote_field: 
if field.remote_field.model: related_fields[field_name] = field - if getattr(field.remote_field, 'through', None): + if getattr(field.remote_field, "through", None): related_fields[field_name] = field # Generate option removal first - unique_together = model_state.options.pop('unique_together', None) - index_together = model_state.options.pop('index_together', None) + unique_together = model_state.options.pop("unique_together", None) + index_together = model_state.options.pop("index_together", None) if unique_together: self.add_operation( app_label, operations.AlterUniqueTogether( name=model_name, unique_together=None, - ) + ), ) if index_together: self.add_operation( @@ -761,7 +835,7 @@ class MigrationAutodetector: operations.AlterIndexTogether( name=model_name, index_together=None, - ) + ), ) # Then remove each related field for name in sorted(related_fields): @@ -770,7 +844,7 @@ class MigrationAutodetector: operations.RemoveField( model_name=model_name, name=name, - ) + ), ) # Finally, remove the model. # This depends on both the removal/alteration of all incoming fields @@ -778,16 +852,22 @@ class MigrationAutodetector: # a through model the field that references it. 
dependencies = [] relations = self.from_state.relations - for (related_object_app_label, object_name), relation_related_fields in ( - relations[app_label, model_name].items() - ): + for ( + related_object_app_label, + object_name, + ), relation_related_fields in relations[app_label, model_name].items(): for field_name, field in relation_related_fields.items(): dependencies.append( (related_object_app_label, object_name, field_name, False), ) if not field.many_to_many: dependencies.append( - (related_object_app_label, object_name, field_name, 'alter'), + ( + related_object_app_label, + object_name, + field_name, + "alter", + ), ) for name in sorted(related_fields): @@ -795,7 +875,9 @@ class MigrationAutodetector: # We're referenced in another field's through= through_user = self.through_users.get((app_label, model_state.name_lower)) if through_user: - dependencies.append((through_user[0], through_user[1], through_user[2], False)) + dependencies.append( + (through_user[0], through_user[1], through_user[2], False) + ) # Finally, make the operation, deduping any dependencies self.add_operation( app_label, @@ -821,29 +903,43 @@ class MigrationAutodetector: def generate_renamed_fields(self): """Work out renamed fields.""" self.renamed_fields = {} - for app_label, model_name, field_name in sorted(self.new_field_keys - self.old_field_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + for app_label, model_name, field_name in sorted( + self.new_field_keys - self.old_field_keys + ): + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] field = new_model_state.get_field(field_name) # Scan to see if this is actually a rename! 
field_dec = self.deep_deconstruct(field) - for rem_app_label, rem_model_name, rem_field_name in sorted(self.old_field_keys - self.new_field_keys): + for rem_app_label, rem_model_name, rem_field_name in sorted( + self.old_field_keys - self.new_field_keys + ): if rem_app_label == app_label and rem_model_name == model_name: old_field = old_model_state.get_field(rem_field_name) old_field_dec = self.deep_deconstruct(old_field) - if field.remote_field and field.remote_field.model and 'to' in old_field_dec[2]: - old_rel_to = old_field_dec[2]['to'] + if ( + field.remote_field + and field.remote_field.model + and "to" in old_field_dec[2] + ): + old_rel_to = old_field_dec[2]["to"] if old_rel_to in self.renamed_models_rel: - old_field_dec[2]['to'] = self.renamed_models_rel[old_rel_to] + old_field_dec[2]["to"] = self.renamed_models_rel[old_rel_to] old_field.set_attributes_from_name(rem_field_name) old_db_column = old_field.get_attname_column()[1] - if (old_field_dec == field_dec or ( - # Was the field renamed and db_column equal to the - # old field's column added? - old_field_dec[0:2] == field_dec[0:2] and - dict(old_field_dec[2], db_column=old_db_column) == field_dec[2])): - if self.questioner.ask_rename(model_name, rem_field_name, field_name, field): + if old_field_dec == field_dec or ( + # Was the field renamed and db_column equal to the + # old field's column added? + old_field_dec[0:2] == field_dec[0:2] + and dict(old_field_dec[2], db_column=old_db_column) + == field_dec[2] + ): + if self.questioner.ask_rename( + model_name, rem_field_name, field_name, field + ): # A db_column mismatch requires a prior noop # AlterField for the subsequent RenameField to be a # noop on attempts at preserving the old name. 
@@ -864,16 +960,22 @@ class MigrationAutodetector: model_name=model_name, old_name=rem_field_name, new_name=field_name, - ) + ), + ) + self.old_field_keys.remove( + (rem_app_label, rem_model_name, rem_field_name) ) - self.old_field_keys.remove((rem_app_label, rem_model_name, rem_field_name)) self.old_field_keys.add((app_label, model_name, field_name)) - self.renamed_fields[app_label, model_name, field_name] = rem_field_name + self.renamed_fields[ + app_label, model_name, field_name + ] = rem_field_name break def generate_added_fields(self): """Make AddField operations.""" - for app_label, model_name, field_name in sorted(self.new_field_keys - self.old_field_keys): + for app_label, model_name, field_name in sorted( + self.new_field_keys - self.old_field_keys + ): self._generate_added_field(app_label, model_name, field_name) def _generate_added_field(self, app_label, model_name, field_name): @@ -881,27 +983,38 @@ class MigrationAutodetector: # Fields that are foreignkeys/m2ms depend on stuff dependencies = [] if field.remote_field and field.remote_field.model: - dependencies.extend(self._get_dependencies_for_foreign_key( - app_label, model_name, field, self.to_state, - )) + dependencies.extend( + self._get_dependencies_for_foreign_key( + app_label, + model_name, + field, + self.to_state, + ) + ) # You can't just add NOT NULL fields with no default or fields # which don't allow empty strings as default. 
time_fields = (models.DateField, models.DateTimeField, models.TimeField) preserve_default = ( - field.null or field.has_default() or field.many_to_many or - (field.blank and field.empty_strings_allowed) or - (isinstance(field, time_fields) and field.auto_now) + field.null + or field.has_default() + or field.many_to_many + or (field.blank and field.empty_strings_allowed) + or (isinstance(field, time_fields) and field.auto_now) ) if not preserve_default: field = field.clone() if isinstance(field, time_fields) and field.auto_now_add: - field.default = self.questioner.ask_auto_now_add_addition(field_name, model_name) + field.default = self.questioner.ask_auto_now_add_addition( + field_name, model_name + ) else: - field.default = self.questioner.ask_not_null_addition(field_name, model_name) + field.default = self.questioner.ask_not_null_addition( + field_name, model_name + ) if ( - field.unique and - field.default is not models.NOT_PROVIDED and - callable(field.default) + field.unique + and field.default is not models.NOT_PROVIDED + and callable(field.default) ): self.questioner.ask_unique_callable_default_addition(field_name, model_name) self.add_operation( @@ -917,7 +1030,9 @@ class MigrationAutodetector: def generate_removed_fields(self): """Make RemoveField operations.""" - for app_label, model_name, field_name in sorted(self.old_field_keys - self.new_field_keys): + for app_label, model_name, field_name in sorted( + self.old_field_keys - self.new_field_keys + ): self._generate_removed_field(app_label, model_name, field_name) def _generate_removed_field(self, app_label, model_name, field_name): @@ -941,21 +1056,35 @@ class MigrationAutodetector: Make AlterField operations, or possibly RemovedField/AddField if alter isn't possible. """ - for app_label, model_name, field_name in sorted(self.old_field_keys & self.new_field_keys): + for app_label, model_name, field_name in sorted( + self.old_field_keys & self.new_field_keys + ): # Did the field change? 
- old_model_name = self.renamed_models.get((app_label, model_name), model_name) - old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name) - old_field = self.from_state.models[app_label, old_model_name].get_field(old_field_name) - new_field = self.to_state.models[app_label, model_name].get_field(field_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) + old_field_name = self.renamed_fields.get( + (app_label, model_name, field_name), field_name + ) + old_field = self.from_state.models[app_label, old_model_name].get_field( + old_field_name + ) + new_field = self.to_state.models[app_label, model_name].get_field( + field_name + ) dependencies = [] # Implement any model renames on relations; these are handled by RenameModel # so we need to exclude them from the comparison - if hasattr(new_field, "remote_field") and getattr(new_field.remote_field, "model", None): - rename_key = resolve_relation(new_field.remote_field.model, app_label, model_name) + if hasattr(new_field, "remote_field") and getattr( + new_field.remote_field, "model", None + ): + rename_key = resolve_relation( + new_field.remote_field.model, app_label, model_name + ) if rename_key in self.renamed_models: new_field.remote_field.model = old_field.remote_field.model # Handle ForeignKey which can only have a single to_field. - remote_field_name = getattr(new_field.remote_field, 'field_name', None) + remote_field_name = getattr(new_field.remote_field, "field_name", None) if remote_field_name: to_field_rename_key = rename_key + (remote_field_name,) if to_field_rename_key in self.renamed_fields: @@ -963,27 +1092,41 @@ class MigrationAutodetector: # inclusion in ForeignKey.deconstruct() is based on # both. 
new_field.remote_field.model = old_field.remote_field.model - new_field.remote_field.field_name = old_field.remote_field.field_name + new_field.remote_field.field_name = ( + old_field.remote_field.field_name + ) # Handle ForeignObjects which can have multiple from_fields/to_fields. - from_fields = getattr(new_field, 'from_fields', None) + from_fields = getattr(new_field, "from_fields", None) if from_fields: from_rename_key = (app_label, model_name) - new_field.from_fields = tuple([ - self.renamed_fields.get(from_rename_key + (from_field,), from_field) - for from_field in from_fields - ]) - new_field.to_fields = tuple([ - self.renamed_fields.get(rename_key + (to_field,), to_field) - for to_field in new_field.to_fields - ]) - dependencies.extend(self._get_dependencies_for_foreign_key( - app_label, model_name, new_field, self.to_state, - )) - if ( - hasattr(new_field, 'remote_field') and - getattr(new_field.remote_field, 'through', None) + new_field.from_fields = tuple( + [ + self.renamed_fields.get( + from_rename_key + (from_field,), from_field + ) + for from_field in from_fields + ] + ) + new_field.to_fields = tuple( + [ + self.renamed_fields.get(rename_key + (to_field,), to_field) + for to_field in new_field.to_fields + ] + ) + dependencies.extend( + self._get_dependencies_for_foreign_key( + app_label, + model_name, + new_field, + self.to_state, + ) + ) + if hasattr(new_field, "remote_field") and getattr( + new_field.remote_field, "through", None ): - rename_key = resolve_relation(new_field.remote_field.through, app_label, model_name) + rename_key = resolve_relation( + new_field.remote_field.through, app_label, model_name + ) if rename_key in self.renamed_models: new_field.remote_field.through = old_field.remote_field.through old_field_dec = self.deep_deconstruct(old_field) @@ -997,10 +1140,16 @@ class MigrationAutodetector: if both_m2m or neither_m2m: # Either both fields are m2m or neither is preserve_default = True - if (old_field.null and not new_field.null and 
not new_field.has_default() and - not new_field.many_to_many): + if ( + old_field.null + and not new_field.null + and not new_field.has_default() + and not new_field.many_to_many + ): field = new_field.clone() - new_default = self.questioner.ask_not_null_alteration(field_name, model_name) + new_default = self.questioner.ask_not_null_alteration( + field_name, model_name + ) if new_default is not models.NOT_PROVIDED: field.default = new_default preserve_default = False @@ -1024,7 +1173,9 @@ class MigrationAutodetector: def create_altered_indexes(self): option_name = operations.AddIndex.option_name for app_label, model_name in sorted(self.kept_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] @@ -1033,38 +1184,43 @@ class MigrationAutodetector: add_idx = [idx for idx in new_indexes if idx not in old_indexes] rem_idx = [idx for idx in old_indexes if idx not in new_indexes] - self.altered_indexes.update({ - (app_label, model_name): { - 'added_indexes': add_idx, 'removed_indexes': rem_idx, + self.altered_indexes.update( + { + (app_label, model_name): { + "added_indexes": add_idx, + "removed_indexes": rem_idx, + } } - }) + ) def generate_added_indexes(self): for (app_label, model_name), alt_indexes in self.altered_indexes.items(): - for index in alt_indexes['added_indexes']: + for index in alt_indexes["added_indexes"]: self.add_operation( app_label, operations.AddIndex( model_name=model_name, index=index, - ) + ), ) def generate_removed_indexes(self): for (app_label, model_name), alt_indexes in self.altered_indexes.items(): - for index in alt_indexes['removed_indexes']: + for index in alt_indexes["removed_indexes"]: self.add_operation( app_label, operations.RemoveIndex( model_name=model_name, name=index.name, - ) + ), 
) def create_altered_constraints(self): option_name = operations.AddConstraint.option_name for app_label, model_name in sorted(self.kept_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] @@ -1073,38 +1229,47 @@ class MigrationAutodetector: add_constraints = [c for c in new_constraints if c not in old_constraints] rem_constraints = [c for c in old_constraints if c not in new_constraints] - self.altered_constraints.update({ - (app_label, model_name): { - 'added_constraints': add_constraints, 'removed_constraints': rem_constraints, + self.altered_constraints.update( + { + (app_label, model_name): { + "added_constraints": add_constraints, + "removed_constraints": rem_constraints, + } } - }) + ) def generate_added_constraints(self): - for (app_label, model_name), alt_constraints in self.altered_constraints.items(): - for constraint in alt_constraints['added_constraints']: + for ( + app_label, + model_name, + ), alt_constraints in self.altered_constraints.items(): + for constraint in alt_constraints["added_constraints"]: self.add_operation( app_label, operations.AddConstraint( model_name=model_name, constraint=constraint, - ) + ), ) def generate_removed_constraints(self): - for (app_label, model_name), alt_constraints in self.altered_constraints.items(): - for constraint in alt_constraints['removed_constraints']: + for ( + app_label, + model_name, + ), alt_constraints in self.altered_constraints.items(): + for constraint in alt_constraints["removed_constraints"]: self.add_operation( app_label, operations.RemoveConstraint( model_name=model_name, name=constraint.name, - ) + ), ) @staticmethod def _get_dependencies_for_foreign_key(app_label, model_name, field, project_state): remote_field_model = None - if 
hasattr(field.remote_field, 'model'): + if hasattr(field.remote_field, "model"): remote_field_model = field.remote_field.model else: relations = project_state.relations[app_label, model_name] @@ -1113,40 +1278,50 @@ class MigrationAutodetector: field == related_field.remote_field for related_field in fields.values() ): - remote_field_model = f'{remote_app_label}.{remote_model_name}' + remote_field_model = f"{remote_app_label}.{remote_model_name}" break # Account for FKs to swappable models - swappable_setting = getattr(field, 'swappable_setting', None) + swappable_setting = getattr(field, "swappable_setting", None) if swappable_setting is not None: dep_app_label = "__setting__" dep_object_name = swappable_setting else: dep_app_label, dep_object_name = resolve_relation( - remote_field_model, app_label, model_name, + remote_field_model, + app_label, + model_name, ) dependencies = [(dep_app_label, dep_object_name, None, True)] - if getattr(field.remote_field, 'through', None): + if getattr(field.remote_field, "through", None): through_app_label, through_object_name = resolve_relation( - remote_field_model, app_label, model_name, + remote_field_model, + app_label, + model_name, ) dependencies.append((through_app_label, through_object_name, None, True)) return dependencies def _get_altered_foo_together_operations(self, option_name): for app_label, model_name in sorted(self.kept_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] # We run the old version through the field renames to account for those old_value = old_model_state.options.get(option_name) - old_value = { - tuple( - self.renamed_fields.get((app_label, model_name, n), n) - for n in unique - ) - for unique in old_value - } if old_value else set() + old_value 
= ( + { + tuple( + self.renamed_fields.get((app_label, model_name, n), n) + for n in unique + ) + for unique in old_value + } + if old_value + else set() + ) new_value = new_model_state.options.get(option_name) new_value = set(new_value) if new_value else set() @@ -1157,9 +1332,14 @@ class MigrationAutodetector: for field_name in foo_togethers: field = new_model_state.get_field(field_name) if field.remote_field and field.remote_field.model: - dependencies.extend(self._get_dependencies_for_foreign_key( - app_label, model_name, field, self.to_state, - )) + dependencies.extend( + self._get_dependencies_for_foreign_key( + app_label, + model_name, + field, + self.to_state, + ) + ) yield ( old_value, new_value, @@ -1180,7 +1360,9 @@ class MigrationAutodetector: if removal_value or old_value: self.add_operation( app_label, - operation(name=model_name, **{operation.option_name: removal_value}), + operation( + name=model_name, **{operation.option_name: removal_value} + ), dependencies=dependencies, ) @@ -1213,20 +1395,24 @@ class MigrationAutodetector: self._generate_altered_foo_together(operations.AlterIndexTogether) def generate_altered_db_table(self): - models_to_check = self.kept_model_keys.union(self.kept_proxy_keys, self.kept_unmanaged_keys) + models_to_check = self.kept_model_keys.union( + self.kept_proxy_keys, self.kept_unmanaged_keys + ) for app_label, model_name in sorted(models_to_check): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] - old_db_table_name = old_model_state.options.get('db_table') - new_db_table_name = new_model_state.options.get('db_table') + old_db_table_name = old_model_state.options.get("db_table") + new_db_table_name = new_model_state.options.get("db_table") if old_db_table_name != 
new_db_table_name: self.add_operation( app_label, operations.AlterModelTable( name=model_name, table=new_db_table_name, - ) + ), ) def generate_altered_options(self): @@ -1245,15 +1431,19 @@ class MigrationAutodetector: ) for app_label, model_name in sorted(models_to_check): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] old_options = { - key: value for key, value in old_model_state.options.items() + key: value + for key, value in old_model_state.options.items() if key in AlterModelOptions.ALTER_OPTION_KEYS } new_options = { - key: value for key, value in new_model_state.options.items() + key: value + for key, value in new_model_state.options.items() if key in AlterModelOptions.ALTER_OPTION_KEYS } if old_options != new_options: @@ -1262,39 +1452,48 @@ class MigrationAutodetector: operations.AlterModelOptions( name=model_name, options=new_options, - ) + ), ) def generate_altered_order_with_respect_to(self): for app_label, model_name in sorted(self.kept_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] - if (old_model_state.options.get("order_with_respect_to") != - new_model_state.options.get("order_with_respect_to")): + if old_model_state.options.get( + "order_with_respect_to" + ) != new_model_state.options.get("order_with_respect_to"): # Make sure it comes second if we're adding # (removal dependency is part of RemoveField) dependencies = [] if new_model_state.options.get("order_with_respect_to"): - dependencies.append(( - app_label, - model_name, - 
new_model_state.options["order_with_respect_to"], - True, - )) + dependencies.append( + ( + app_label, + model_name, + new_model_state.options["order_with_respect_to"], + True, + ) + ) # Actually generate the operation self.add_operation( app_label, operations.AlterOrderWithRespectTo( name=model_name, - order_with_respect_to=new_model_state.options.get('order_with_respect_to'), + order_with_respect_to=new_model_state.options.get( + "order_with_respect_to" + ), ), dependencies=dependencies, ) def generate_altered_managers(self): for app_label, model_name in sorted(self.kept_model_keys): - old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_name = self.renamed_models.get( + (app_label, model_name), model_name + ) old_model_state = self.from_state.models[app_label, old_model_name] new_model_state = self.to_state.models[app_label, model_name] if old_model_state.managers != new_model_state.managers: @@ -1303,7 +1502,7 @@ class MigrationAutodetector: operations.AlterModelManagers( name=model_name, managers=new_model_state.managers, - ) + ), ) def arrange_for_graph(self, changes, graph, migration_name=None): @@ -1339,21 +1538,23 @@ class MigrationAutodetector: for i, migration in enumerate(migrations): if i == 0 and app_leaf: migration.dependencies.append(app_leaf) - new_name_parts = ['%04i' % next_number] + new_name_parts = ["%04i" % next_number] if migration_name: new_name_parts.append(migration_name) elif i == 0 and not app_leaf: - new_name_parts.append('initial') + new_name_parts.append("initial") else: new_name_parts.append(migration.suggest_name()[:100]) - new_name = '_'.join(new_name_parts) + new_name = "_".join(new_name_parts) name_map[(app_label, migration.name)] = (app_label, new_name) next_number += 1 migration.name = new_name # Now fix dependencies for migrations in changes.values(): for migration in migrations: - migration.dependencies = [name_map.get(d, d) for d in migration.dependencies] + migration.dependencies = [ + 
name_map.get(d, d) for d in migration.dependencies + ] return changes def _trim_to_apps(self, changes, app_labels): @@ -1374,7 +1575,9 @@ class MigrationAutodetector: old_required_apps = None while old_required_apps != required_apps: old_required_apps = set(required_apps) - required_apps.update(*[app_dependencies.get(app_label, ()) for app_label in required_apps]) + required_apps.update( + *[app_dependencies.get(app_label, ()) for app_label in required_apps] + ) # Remove all migrations that aren't needed for app_label in list(changes): if app_label not in required_apps: @@ -1388,9 +1591,9 @@ class MigrationAutodetector: it. For a squashed migration such as '0001_squashed_0004…', return the second number. If no number is found, return None. """ - if squashed_match := re.search(r'.*_squashed_(\d+)', name): + if squashed_match := re.search(r".*_squashed_(\d+)", name): return int(squashed_match[1]) - match = re.match(r'^\d+', name) + match = re.match(r"^\d+", name) if match: return int(match[0]) return None diff --git a/django/db/migrations/exceptions.py b/django/db/migrations/exceptions.py index 8def99da5b..dd556dacb5 100644 --- a/django/db/migrations/exceptions.py +++ b/django/db/migrations/exceptions.py @@ -3,31 +3,37 @@ from django.db import DatabaseError class AmbiguityError(Exception): """More than one migration matches a name prefix.""" + pass class BadMigrationError(Exception): """There's a bad migration (unreadable/bad format/etc.).""" + pass class CircularDependencyError(Exception): """There's an impossible-to-resolve circular dependency.""" + pass class InconsistentMigrationHistory(Exception): """An applied migration has some of its dependencies not applied.""" + pass class InvalidBasesError(ValueError): """A model's base classes can't be resolved.""" + pass class IrreversibleError(RuntimeError): """An irreversible migration is about to be reversed.""" + pass diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py index 
89e9344a68..ef7f9060f1 100644 --- a/django/db/migrations/executor.py +++ b/django/db/migrations/executor.py @@ -43,8 +43,8 @@ class MigrationExecutor: # If the target is missing, it's likely a replaced migration. # Reload the graph without replacements. if ( - self.loader.replace_migrations and - target not in self.loader.graph.node_map + self.loader.replace_migrations + and target not in self.loader.graph.node_map ): self.loader.replace_migrations = False self.loader.build_graph() @@ -54,8 +54,8 @@ class MigrationExecutor: # be rolled back); instead roll back through target's immediate # child(ren) in the same app, and no further. next_in_app = sorted( - n for n in - self.loader.graph.node_map[target].children + n + for n in self.loader.graph.node_map[target].children if n[0] == target[0] ) for node in next_in_app: @@ -78,9 +78,12 @@ class MigrationExecutor: state = ProjectState(real_apps=self.loader.unmigrated_apps) if with_applied_migrations: # Create the forwards plan Django would follow on an empty database - full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True) + full_plan = self.migration_plan( + self.loader.graph.leaf_nodes(), clean_start=True + ) applied_migrations = { - self.loader.graph.nodes[key] for key in self.loader.applied_migrations + self.loader.graph.nodes[key] + for key in self.loader.applied_migrations if key in self.loader.graph.nodes } for migration, _ in full_plan: @@ -106,7 +109,9 @@ class MigrationExecutor: if plan is None: plan = self.migration_plan(targets) # Create the forwards plan Django would follow on an empty database - full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True) + full_plan = self.migration_plan( + self.loader.graph.leaf_nodes(), clean_start=True + ) all_forwards = all(not backwards for mig, backwards in plan) all_backwards = all(backwards for mig, backwards in plan) @@ -121,13 +126,15 @@ class MigrationExecutor: "Migration plans with both forwards and backwards 
migrations " "are not supported. Please split your migration process into " "separate plans of only forwards OR backwards migrations.", - plan + plan, ) elif all_forwards: if state is None: # The resulting state should still include applied migrations. state = self._create_project_state(with_applied_migrations=True) - state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial) + state = self._migrate_all_forwards( + state, plan, full_plan, fake=fake, fake_initial=fake_initial + ) else: # No need to check for `elif all_backwards` here, as that condition # would always evaluate to true. @@ -151,13 +158,15 @@ class MigrationExecutor: # process. break if migration in migrations_to_run: - if 'apps' not in state.__dict__: + if "apps" not in state.__dict__: if self.progress_callback: self.progress_callback("render_start") state.apps # Render all -- performance critical if self.progress_callback: self.progress_callback("render_success") - state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial) + state = self.apply_migration( + state, migration, fake=fake, fake_initial=fake_initial + ) migrations_to_run.remove(migration) return state @@ -177,7 +186,8 @@ class MigrationExecutor: states = {} state = self._create_project_state() applied_migrations = { - self.loader.graph.nodes[key] for key in self.loader.applied_migrations + self.loader.graph.nodes[key] + for key in self.loader.applied_migrations if key in self.loader.graph.nodes } if self.progress_callback: @@ -190,7 +200,7 @@ class MigrationExecutor: # process. 
break if migration in migrations_to_run: - if 'apps' not in state.__dict__: + if "apps" not in state.__dict__: state.apps # Render all -- performance critical # The state before this migration states[migration] = state @@ -236,7 +246,9 @@ class MigrationExecutor: fake = True if not fake: # Alright, do it normally - with self.connection.schema_editor(atomic=migration.atomic) as schema_editor: + with self.connection.schema_editor( + atomic=migration.atomic + ) as schema_editor: state = migration.apply(state, schema_editor) if not schema_editor.deferred_sql: self.record_migration(migration) @@ -261,7 +273,9 @@ class MigrationExecutor: if self.progress_callback: self.progress_callback("unapply_start", migration, fake) if not fake: - with self.connection.schema_editor(atomic=migration.atomic) as schema_editor: + with self.connection.schema_editor( + atomic=migration.atomic + ) as schema_editor: state = migration.unapply(state, schema_editor) # For replacement migrations, also record individual statuses. if migration.replaces: @@ -296,15 +310,18 @@ class MigrationExecutor: tables or columns it would create exist. This is intended only for use on initial migrations (as it only looks for CreateModel and AddField). """ + def should_skip_detecting_model(migration, model): """ No need to detect tables for proxy models, unmanaged models, or models that can't be migrated on the current database. 
""" return ( - model._meta.proxy or not model._meta.managed or not - router.allow_migrate( - self.connection.alias, migration.app_label, + model._meta.proxy + or not model._meta.managed + or not router.allow_migrate( + self.connection.alias, + migration.app_label, model_name=model._meta.model_name, ) ) @@ -318,7 +335,9 @@ class MigrationExecutor: return False, project_state if project_state is None: - after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True) + after_state = self.loader.project_state( + (migration.app_label, migration.name), at_end=True + ) else: after_state = migration.mutate_state(project_state) apps = after_state.apps @@ -326,9 +345,13 @@ class MigrationExecutor: found_add_field_migration = False fold_identifier_case = self.connection.features.ignores_table_name_case with self.connection.cursor() as cursor: - existing_table_names = set(self.connection.introspection.table_names(cursor)) + existing_table_names = set( + self.connection.introspection.table_names(cursor) + ) if fold_identifier_case: - existing_table_names = {name.casefold() for name in existing_table_names} + existing_table_names = { + name.casefold() for name in existing_table_names + } # Make sure all create model and add field operations are done for operation in migration.operations: if isinstance(operation, migrations.CreateModel): @@ -368,7 +391,9 @@ class MigrationExecutor: found_add_field_migration = True continue with self.connection.cursor() as cursor: - columns = self.connection.introspection.get_table_description(cursor, table) + columns = self.connection.introspection.get_table_description( + cursor, table + ) for column in columns: field_column = field.column column_name = column.name diff --git a/django/db/migrations/graph.py b/django/db/migrations/graph.py index 4d66822e17..dd845c13e8 100644 --- a/django/db/migrations/graph.py +++ b/django/db/migrations/graph.py @@ -11,6 +11,7 @@ class Node: A single node in the migration graph. 
Contains direct links to adjacent nodes in either direction. """ + def __init__(self, key): self.key = key self.children = set() @@ -32,7 +33,7 @@ class Node: return str(self.key) def __repr__(self): - return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1]) + return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1]) def add_child(self, child): self.children.add(child) @@ -49,6 +50,7 @@ class DummyNode(Node): After the migration graph is processed, all dummy nodes should be removed. If there are any left, a nonexistent dependency error is raised. """ + def __init__(self, key, origin, error_message): super().__init__(key) self.origin = origin @@ -133,7 +135,7 @@ class MigrationGraph: raise NodeNotFoundError( "Unable to find replacement node %r. It was either never added" " to the migration graph, or has been removed." % (replacement,), - replacement + replacement, ) from err for replaced_key in replaced: self.nodes.pop(replaced_key, None) @@ -167,8 +169,9 @@ class MigrationGraph: except KeyError as err: raise NodeNotFoundError( "Unable to remove replacement node %r. It was either never added" - " to the migration graph, or has been removed already." % (replacement,), - replacement + " to the migration graph, or has been removed already." 
+ % (replacement,), + replacement, ) from err replaced_nodes = set() replaced_nodes_parents = set() @@ -228,7 +231,10 @@ class MigrationGraph: visited.append(node.key) else: stack.append((node, True)) - stack += [(n, False) for n in sorted(node.parents if forwards else node.children)] + stack += [ + (n, False) + for n in sorted(node.parents if forwards else node.children) + ] return visited def root_nodes(self, app=None): @@ -238,7 +244,9 @@ class MigrationGraph: """ roots = set() for node in self.nodes: - if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]): + if all(key[0] != node[0] for key in self.node_map[node].parents) and ( + not app or app == node[0] + ): roots.add(node) return sorted(roots) @@ -252,7 +260,9 @@ class MigrationGraph: """ leaves = set() for node in self.nodes: - if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]): + if all(key[0] != node[0] for key in self.node_map[node].children) and ( + not app or app == node[0] + ): leaves.add(node) return sorted(leaves) @@ -270,8 +280,10 @@ class MigrationGraph: # hashing. 
node = child.key if node in stack: - cycle = stack[stack.index(node):] - raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle)) + cycle = stack[stack.index(node) :] + raise CircularDependencyError( + ", ".join("%s.%s" % n for n in cycle) + ) if node in todo: stack.append(node) todo.remove(node) @@ -280,14 +292,16 @@ class MigrationGraph: node = stack.pop() def __str__(self): - return 'Graph: %s nodes, %s edges' % self._nodes_and_edges() + return "Graph: %s nodes, %s edges" % self._nodes_and_edges() def __repr__(self): nodes, edges = self._nodes_and_edges() - return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges) + return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges) def _nodes_and_edges(self): - return len(self.nodes), sum(len(node.parents) for node in self.node_map.values()) + return len(self.nodes), sum( + len(node.parents) for node in self.node_map.values() + ) def _generate_plan(self, nodes, at_end): plan = [] diff --git a/django/db/migrations/loader.py b/django/db/migrations/loader.py index 93fb2c3bd5..81dcd06e04 100644 --- a/django/db/migrations/loader.py +++ b/django/db/migrations/loader.py @@ -8,11 +8,13 @@ from django.db.migrations.graph import MigrationGraph from django.db.migrations.recorder import MigrationRecorder from .exceptions import ( - AmbiguityError, BadMigrationError, InconsistentMigrationHistory, + AmbiguityError, + BadMigrationError, + InconsistentMigrationHistory, NodeNotFoundError, ) -MIGRATIONS_MODULE_NAME = 'migrations' +MIGRATIONS_MODULE_NAME = "migrations" class MigrationLoader: @@ -41,7 +43,10 @@ class MigrationLoader: """ def __init__( - self, connection, load=True, ignore_no_migrations=False, + self, + connection, + load=True, + ignore_no_migrations=False, replace_migrations=True, ): self.connection = connection @@ -63,7 +68,7 @@ class MigrationLoader: return settings.MIGRATION_MODULES[app_label], True else: app_package_name = apps.get_app_config(app_label).name - return 
'%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False + return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False def load_disk(self): """Load the migrations from all INSTALLED_APPS from disk.""" @@ -80,24 +85,22 @@ class MigrationLoader: try: module = import_module(module_name) except ModuleNotFoundError as e: - if ( - (explicit and self.ignore_no_migrations) or - (not explicit and MIGRATIONS_MODULE_NAME in e.name.split('.')) + if (explicit and self.ignore_no_migrations) or ( + not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".") ): self.unmigrated_apps.add(app_config.label) continue raise else: # Module is not a package (e.g. migrations.py). - if not hasattr(module, '__path__'): + if not hasattr(module, "__path__"): self.unmigrated_apps.add(app_config.label) continue # Empty directories are namespaces. Namespace packages have no # __file__ and don't use a list for __path__. See # https://docs.python.org/3/reference/import.html#namespace-packages - if ( - getattr(module, '__file__', None) is None and - not isinstance(module.__path__, list) + if getattr(module, "__file__", None) is None and not isinstance( + module.__path__, list ): self.unmigrated_apps.add(app_config.label) continue @@ -106,16 +109,17 @@ class MigrationLoader: reload(module) self.migrated_apps.add(app_config.label) migration_names = { - name for _, name, is_pkg in pkgutil.iter_modules(module.__path__) - if not is_pkg and name[0] not in '_~' + name + for _, name, is_pkg in pkgutil.iter_modules(module.__path__) + if not is_pkg and name[0] not in "_~" } # Load migrations for migration_name in migration_names: - migration_path = '%s.%s' % (module_name, migration_name) + migration_path = "%s.%s" % (module_name, migration_name) try: migration_module = import_module(migration_path) except ImportError as e: - if 'bad magic number' in str(e): + if "bad magic number" in str(e): raise ImportError( "Couldn't import %r as it appears to be a stale " ".pyc file." 
% migration_path @@ -124,9 +128,12 @@ class MigrationLoader: raise if not hasattr(migration_module, "Migration"): raise BadMigrationError( - "Migration %s in app %s has no Migration class" % (migration_name, app_config.label) + "Migration %s in app %s has no Migration class" + % (migration_name, app_config.label) ) - self.disk_migrations[app_config.label, migration_name] = migration_module.Migration( + self.disk_migrations[ + app_config.label, migration_name + ] = migration_module.Migration( migration_name, app_config.label, ) @@ -142,11 +149,14 @@ class MigrationLoader: # Do the search results = [] for migration_app_label, migration_name in self.disk_migrations: - if migration_app_label == app_label and migration_name.startswith(name_prefix): + if migration_app_label == app_label and migration_name.startswith( + name_prefix + ): results.append((migration_app_label, migration_name)) if len(results) > 1: raise AmbiguityError( - "There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix) + "There is more than one migration for '%s' with the prefix '%s'" + % (app_label, name_prefix) ) elif not results: raise KeyError( @@ -181,7 +191,9 @@ class MigrationLoader: if self.ignore_no_migrations: return None else: - raise ValueError("Dependency on app with no migrations: %s" % key[0]) + raise ValueError( + "Dependency on app with no migrations: %s" % key[0] + ) raise ValueError("Dependency on unknown app: %s" % key[0]) def add_internal_dependencies(self, key, migration): @@ -191,7 +203,7 @@ class MigrationLoader: """ for parent in migration.dependencies: # Ignore __first__ references to the same app. 
- if parent[0] == key[0] and parent[1] != '__first__': + if parent[0] == key[0] and parent[1] != "__first__": self.graph.add_dependency(migration, key, parent, skip_validation=True) def add_external_dependencies(self, key, migration): @@ -241,7 +253,9 @@ class MigrationLoader: for key, migration in self.replacements.items(): # Get applied status of each of this migration's replacement # targets. - applied_statuses = [(target in self.applied_migrations) for target in migration.replaces] + applied_statuses = [ + (target in self.applied_migrations) for target in migration.replaces + ] # The replacing migration is only marked as applied if all of # its replacement targets are. if all(applied_statuses): @@ -273,9 +287,11 @@ class MigrationLoader: # Try to reraise exception with more detail. if exc.node in reverse_replacements: candidates = reverse_replacements.get(exc.node, set()) - is_replaced = any(candidate in self.graph.nodes for candidate in candidates) + is_replaced = any( + candidate in self.graph.nodes for candidate in candidates + ) if not is_replaced: - tries = ', '.join('%s.%s' % c for c in candidates) + tries = ", ".join("%s.%s" % c for c in candidates) raise NodeNotFoundError( "Migration {0} depends on nonexistent node ('{1}', '{2}'). " "Django tried to replace migration {1}.{2} with any of [{3}] " @@ -283,7 +299,7 @@ class MigrationLoader: "are already applied.".format( exc.origin, exc.node[0], exc.node[1], tries ), - exc.node + exc.node, ) from exc raise self.graph.ensure_not_cyclic() @@ -304,12 +320,17 @@ class MigrationLoader: # Skip unapplied squashed migrations that have all of their # `replaces` applied. 
if parent in self.replacements: - if all(m in applied for m in self.replacements[parent].replaces): + if all( + m in applied for m in self.replacements[parent].replaces + ): continue raise InconsistentMigrationHistory( "Migration {}.{} is applied before its dependency " "{}.{} on database '{}'.".format( - migration[0], migration[1], parent[0], parent[1], + migration[0], + migration[1], + parent[0], + parent[1], connection.alias, ) ) @@ -326,7 +347,9 @@ class MigrationLoader: if app_label in seen_apps: conflicting_apps.add(app_label) seen_apps.setdefault(app_label, set()).add(migration_name) - return {app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps} + return { + app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps + } def project_state(self, nodes=None, at_end=True): """ @@ -335,7 +358,9 @@ class MigrationLoader: See graph.make_state() for the meaning of "nodes" and "at_end". """ - return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps) + return self.graph.make_state( + nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps + ) def collect_sql(self, plan): """ @@ -345,9 +370,13 @@ class MigrationLoader: statements = [] state = None for migration, backwards in plan: - with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor: + with self.connection.schema_editor( + collect_sql=True, atomic=migration.atomic + ) as schema_editor: if state is None: - state = self.project_state((migration.app_label, migration.name), at_end=False) + state = self.project_state( + (migration.app_label, migration.name), at_end=False + ) if not backwards: state = migration.apply(state, schema_editor, collect_sql=True) else: diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py index 5ee0ae5191..39278d4cc7 100644 --- a/django/db/migrations/migration.py +++ b/django/db/migrations/migration.py @@ -60,9 +60,9 @@ class Migration: def __eq__(self, 
other): return ( - isinstance(other, Migration) and - self.name == other.name and - self.app_label == other.app_label + isinstance(other, Migration) + and self.name == other.name + and self.app_label == other.app_label ) def __repr__(self): @@ -114,15 +114,21 @@ class Migration: old_state = project_state.clone() operation.state_forwards(self.app_label, project_state) # Run the operation - atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False) + atomic_operation = operation.atomic or ( + self.atomic and operation.atomic is not False + ) if not schema_editor.atomic_migration and atomic_operation: # Force a transaction on a non-transactional-DDL backend or an # atomic operation inside a non-atomic migration. with atomic(schema_editor.connection.alias): - operation.database_forwards(self.app_label, schema_editor, old_state, project_state) + operation.database_forwards( + self.app_label, schema_editor, old_state, project_state + ) else: # Normal behaviour - operation.database_forwards(self.app_label, schema_editor, old_state, project_state) + operation.database_forwards( + self.app_label, schema_editor, old_state, project_state + ) return project_state def unapply(self, project_state, schema_editor, collect_sql=False): @@ -145,7 +151,9 @@ class Migration: for operation in self.operations: # If it's irreversible, error out if not operation.reversible: - raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self)) + raise IrreversibleError( + "Operation %s in %s is not reversible" % (operation, self) + ) # Preserve new state from previous run to not tamper the same state # over all operations new_state = new_state.clone() @@ -165,15 +173,21 @@ class Migration: schema_editor.collected_sql.append("--") if not operation.reduces_to_sql: continue - atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False) + atomic_operation = operation.atomic or ( + self.atomic and operation.atomic is not False 
+ ) if not schema_editor.atomic_migration and atomic_operation: # Force a transaction on a non-transactional-DDL backend or an # atomic operation inside a non-atomic migration. with atomic(schema_editor.connection.alias): - operation.database_backwards(self.app_label, schema_editor, from_state, to_state) + operation.database_backwards( + self.app_label, schema_editor, from_state, to_state + ) else: # Normal behaviour - operation.database_backwards(self.app_label, schema_editor, from_state, to_state) + operation.database_backwards( + self.app_label, schema_editor, from_state, to_state + ) return project_state def suggest_name(self): @@ -183,19 +197,19 @@ class Migration: name to avoid VCS conflicts if possible. """ if self.initial: - return 'initial' + return "initial" raw_fragments = [op.migration_name_fragment for op in self.operations] fragments = [name for name in raw_fragments if name] if not fragments or len(fragments) != len(self.operations): - return 'auto_%s' % get_migration_name_timestamp() + return "auto_%s" % get_migration_name_timestamp() name = fragments[0] for fragment in fragments[1:]: - new_name = f'{name}_{fragment}' + new_name = f"{name}_{fragment}" if len(new_name) > 52: - name = f'{name}_and_more' + name = f"{name}_and_more" break name = new_name return name diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py index 119c955868..793969ed12 100644 --- a/django/db/migrations/operations/__init__.py +++ b/django/db/migrations/operations/__init__.py @@ -1,17 +1,40 @@ from .fields import AddField, AlterField, RemoveField, RenameField from .models import ( - AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers, - AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo, - AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint, - RemoveIndex, RenameModel, + AddConstraint, + AddIndex, + AlterIndexTogether, + AlterModelManagers, + AlterModelOptions, + AlterModelTable, + 
AlterOrderWithRespectTo, + AlterUniqueTogether, + CreateModel, + DeleteModel, + RemoveConstraint, + RemoveIndex, + RenameModel, ) from .special import RunPython, RunSQL, SeparateDatabaseAndState __all__ = [ - 'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether', - 'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex', - 'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField', - 'AddConstraint', 'RemoveConstraint', - 'SeparateDatabaseAndState', 'RunSQL', 'RunPython', - 'AlterOrderWithRespectTo', 'AlterModelManagers', + "CreateModel", + "DeleteModel", + "AlterModelTable", + "AlterUniqueTogether", + "RenameModel", + "AlterIndexTogether", + "AlterModelOptions", + "AddIndex", + "RemoveIndex", + "AddField", + "RemoveField", + "AlterField", + "RenameField", + "AddConstraint", + "RemoveConstraint", + "SeparateDatabaseAndState", + "RunSQL", + "RunPython", + "AlterOrderWithRespectTo", + "AlterModelManagers", ] diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py index 18935520f8..7d4dff2597 100644 --- a/django/db/migrations/operations/base.py +++ b/django/db/migrations/operations/base.py @@ -56,14 +56,18 @@ class Operation: Take the state from the previous migration, and mutate it so that it matches what this migration would perform. """ - raise NotImplementedError('subclasses of Operation must provide a state_forwards() method') + raise NotImplementedError( + "subclasses of Operation must provide a state_forwards() method" + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): """ Perform the mutation on the database schema in the normal (forwards) direction. 
""" - raise NotImplementedError('subclasses of Operation must provide a database_forwards() method') + raise NotImplementedError( + "subclasses of Operation must provide a database_forwards() method" + ) def database_backwards(self, app_label, schema_editor, from_state, to_state): """ @@ -71,7 +75,9 @@ class Operation: direction - e.g. if this were CreateModel, it would in fact drop the model's table. """ - raise NotImplementedError('subclasses of Operation must provide a database_backwards() method') + raise NotImplementedError( + "subclasses of Operation must provide a database_backwards() method" + ) def describe(self): """ diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 094c3e3cda..cd3aab43ad 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -23,16 +23,23 @@ class FieldOperation(Operation): return self.model_name_lower == operation.model_name_lower def is_same_field_operation(self, operation): - return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower + return ( + self.is_same_model_operation(operation) + and self.name_lower == operation.name_lower + ) def references_model(self, name, app_label): name_lower = name.lower() if name_lower == self.model_name_lower: return True if self.field: - return bool(field_references( - (app_label, self.model_name_lower), self.field, (app_label, name_lower) - )) + return bool( + field_references( + (app_label, self.model_name_lower), + self.field, + (app_label, name_lower), + ) + ) return False def references_field(self, model_name, name, app_label): @@ -41,22 +48,27 @@ class FieldOperation(Operation): if model_name_lower == self.model_name_lower: if name == self.name: return True - elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: + elif ( + self.field + and hasattr(self.field, "from_fields") + and name in self.field.from_fields + ): return 
True # Check if this operation remotely references the field. if self.field is None: return False - return bool(field_references( - (app_label, self.model_name_lower), - self.field, - (app_label, model_name_lower), - name, - )) + return bool( + field_references( + (app_label, self.model_name_lower), + self.field, + (app_label, model_name_lower), + name, + ) + ) def reduce(self, operation, app_label): - return ( - super().reduce(operation, app_label) or - not operation.references_field(self.model_name, self.name, app_label) + return super().reduce(operation, app_label) or not operation.references_field( + self.model_name, self.name, app_label ) @@ -69,17 +81,13 @@ class AddField(FieldOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'name': self.name, - 'field': self.field, + "model_name": self.model_name, + "name": self.name, + "field": self.field, } if self.preserve_default is not True: - kwargs['preserve_default'] = self.preserve_default - return ( - self.__class__.__name__, - [], - kwargs - ) + kwargs["preserve_default"] = self.preserve_default + return (self.__class__.__name__, [], kwargs) def state_forwards(self, app_label, state): state.add_field( @@ -107,17 +115,21 @@ class AddField(FieldOperation): def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.apps.get_model(app_label, self.model_name) if self.allow_migrate_model(schema_editor.connection.alias, from_model): - schema_editor.remove_field(from_model, from_model._meta.get_field(self.name)) + schema_editor.remove_field( + from_model, from_model._meta.get_field(self.name) + ) def describe(self): return "Add field %s to %s" % (self.name, self.model_name) @property def migration_name_fragment(self): - return '%s_%s' % (self.model_name_lower, self.name_lower) + return "%s_%s" % (self.model_name_lower, self.name_lower) def reduce(self, operation, app_label): - if isinstance(operation, FieldOperation) and 
self.is_same_field_operation(operation): + if isinstance(operation, FieldOperation) and self.is_same_field_operation( + operation + ): if isinstance(operation, AlterField): return [ AddField( @@ -144,14 +156,10 @@ class RemoveField(FieldOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'name': self.name, + "model_name": self.model_name, + "name": self.name, } - return ( - self.__class__.__name__, - [], - kwargs - ) + return (self.__class__.__name__, [], kwargs) def state_forwards(self, app_label, state): state.remove_field(app_label, self.model_name_lower, self.name) @@ -159,7 +167,9 @@ class RemoveField(FieldOperation): def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.apps.get_model(app_label, self.model_name) if self.allow_migrate_model(schema_editor.connection.alias, from_model): - schema_editor.remove_field(from_model, from_model._meta.get_field(self.name)) + schema_editor.remove_field( + from_model, from_model._meta.get_field(self.name) + ) def database_backwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -172,11 +182,15 @@ class RemoveField(FieldOperation): @property def migration_name_fragment(self): - return 'remove_%s_%s' % (self.model_name_lower, self.name_lower) + return "remove_%s_%s" % (self.model_name_lower, self.name_lower) def reduce(self, operation, app_label): from .models import DeleteModel - if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower: + + if ( + isinstance(operation, DeleteModel) + and operation.name_lower == self.model_name_lower + ): return [operation] return super().reduce(operation, app_label) @@ -193,17 +207,13 @@ class AlterField(FieldOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'name': self.name, - 'field': self.field, + "model_name": self.model_name, + "name": self.name, + "field": self.field, } if 
self.preserve_default is not True: - kwargs['preserve_default'] = self.preserve_default - return ( - self.__class__.__name__, - [], - kwargs - ) + kwargs["preserve_default"] = self.preserve_default + return (self.__class__.__name__, [], kwargs) def state_forwards(self, app_label, state): state.alter_field( @@ -234,15 +244,17 @@ class AlterField(FieldOperation): @property def migration_name_fragment(self): - return 'alter_%s_%s' % (self.model_name_lower, self.name_lower) + return "alter_%s_%s" % (self.model_name_lower, self.name_lower) def reduce(self, operation, app_label): - if isinstance(operation, RemoveField) and self.is_same_field_operation(operation): + if isinstance(operation, RemoveField) and self.is_same_field_operation( + operation + ): return [operation] elif ( - isinstance(operation, RenameField) and - self.is_same_field_operation(operation) and - self.field.db_column is None + isinstance(operation, RenameField) + and self.is_same_field_operation(operation) + and self.field.db_column is None ): return [ operation, @@ -273,18 +285,16 @@ class RenameField(FieldOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'old_name': self.old_name, - 'new_name': self.new_name, + "model_name": self.model_name, + "old_name": self.old_name, + "new_name": self.new_name, } - return ( - self.__class__.__name__, - [], - kwargs - ) + return (self.__class__.__name__, [], kwargs) def state_forwards(self, app_label, state): - state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name) + state.rename_field( + app_label, self.model_name_lower, self.old_name, self.new_name + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -307,11 +317,15 @@ class RenameField(FieldOperation): ) def describe(self): - return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name) + return "Rename field %s on %s to %s" % ( + 
self.old_name, + self.model_name, + self.new_name, + ) @property def migration_name_fragment(self): - return 'rename_%s_%s_%s' % ( + return "rename_%s_%s_%s" % ( self.old_name_lower, self.model_name_lower, self.new_name_lower, @@ -319,14 +333,15 @@ class RenameField(FieldOperation): def references_field(self, model_name, name, app_label): return self.references_model(model_name, app_label) and ( - name.lower() == self.old_name_lower or - name.lower() == self.new_name_lower + name.lower() == self.old_name_lower or name.lower() == self.new_name_lower ) def reduce(self, operation, app_label): - if (isinstance(operation, RenameField) and - self.is_same_model_operation(operation) and - self.new_name_lower == operation.old_name_lower): + if ( + isinstance(operation, RenameField) + and self.is_same_model_operation(operation) + and self.new_name_lower == operation.old_name_lower + ): return [ RenameField( self.model_name, @@ -336,10 +351,7 @@ class RenameField(FieldOperation): ] # Skip `FieldOperation.reduce` as we want to run `references_field` # against self.old_name and self.new_name. 
- return ( - super(FieldOperation, self).reduce(operation, app_label) or - not ( - operation.references_field(self.model_name, self.old_name, app_label) or - operation.references_field(self.model_name, self.new_name, app_label) - ) + return super(FieldOperation, self).reduce(operation, app_label) or not ( + operation.references_field(self.model_name, self.old_name, app_label) + or operation.references_field(self.model_name, self.new_name, app_label) ) diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 01c44a9a26..90fc31bee5 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -5,9 +5,7 @@ from django.db.migrations.utils import field_references, resolve_relation from django.db.models.options import normalize_together from django.utils.functional import cached_property -from .fields import ( - AddField, AlterField, FieldOperation, RemoveField, RenameField, -) +from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField def _check_for_duplicates(arg_name, objs): @@ -32,9 +30,8 @@ class ModelOperation(Operation): return name.lower() == self.name_lower def reduce(self, operation, app_label): - return ( - super().reduce(operation, app_label) or - self.can_reduce_through(operation, app_label) + return super().reduce(operation, app_label) or self.can_reduce_through( + operation, app_label ) def can_reduce_through(self, operation, app_label): @@ -44,7 +41,7 @@ class ModelOperation(Operation): class CreateModel(ModelOperation): """Create a model's table.""" - serialization_expand_args = ['fields', 'options', 'managers'] + serialization_expand_args = ["fields", "options", "managers"] def __init__(self, name, fields, options=None, bases=None, managers=None): self.fields = fields @@ -54,40 +51,44 @@ class CreateModel(ModelOperation): super().__init__(name) # Sanity-check that there are no duplicated field names, bases, or # manager names - 
_check_for_duplicates('fields', (name for name, _ in self.fields)) - _check_for_duplicates('bases', ( - base._meta.label_lower if hasattr(base, '_meta') else - base.lower() if isinstance(base, str) else base - for base in self.bases - )) - _check_for_duplicates('managers', (name for name, _ in self.managers)) + _check_for_duplicates("fields", (name for name, _ in self.fields)) + _check_for_duplicates( + "bases", + ( + base._meta.label_lower + if hasattr(base, "_meta") + else base.lower() + if isinstance(base, str) + else base + for base in self.bases + ), + ) + _check_for_duplicates("managers", (name for name, _ in self.managers)) def deconstruct(self): kwargs = { - 'name': self.name, - 'fields': self.fields, + "name": self.name, + "fields": self.fields, } if self.options: - kwargs['options'] = self.options + kwargs["options"] = self.options if self.bases and self.bases != (models.Model,): - kwargs['bases'] = self.bases - if self.managers and self.managers != [('objects', models.Manager())]: - kwargs['managers'] = self.managers - return ( - self.__class__.__qualname__, - [], - kwargs - ) + kwargs["bases"] = self.bases + if self.managers and self.managers != [("objects", models.Manager())]: + kwargs["managers"] = self.managers + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): - state.add_model(ModelState( - app_label, - self.name, - list(self.fields), - dict(self.options), - tuple(self.bases), - list(self.managers), - )) + state.add_model( + ModelState( + app_label, + self.name, + list(self.fields), + dict(self.options), + tuple(self.bases), + list(self.managers), + ) + ) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = to_state.apps.get_model(app_label, self.name) @@ -100,7 +101,10 @@ class CreateModel(ModelOperation): schema_editor.delete_model(model) def describe(self): - return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name) + return "Create 
%smodel %s" % ( + "proxy " if self.options.get("proxy", False) else "", + self.name, + ) @property def migration_name_fragment(self): @@ -114,22 +118,32 @@ class CreateModel(ModelOperation): # Check we didn't inherit from the model reference_model_tuple = (app_label, name_lower) for base in self.bases: - if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and - resolve_relation(base, app_label) == reference_model_tuple): + if ( + base is not models.Model + and isinstance(base, (models.base.ModelBase, str)) + and resolve_relation(base, app_label) == reference_model_tuple + ): return True # Check we have no FKs/M2Ms with it for _name, field in self.fields: - if field_references((app_label, self.name_lower), field, reference_model_tuple): + if field_references( + (app_label, self.name_lower), field, reference_model_tuple + ): return True return False def reduce(self, operation, app_label): - if (isinstance(operation, DeleteModel) and - self.name_lower == operation.name_lower and - not self.options.get("proxy", False)): + if ( + isinstance(operation, DeleteModel) + and self.name_lower == operation.name_lower + and not self.options.get("proxy", False) + ): return [] - elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower: + elif ( + isinstance(operation, RenameModel) + and self.name_lower == operation.old_name_lower + ): return [ CreateModel( operation.new_name, @@ -139,7 +153,10 @@ class CreateModel(ModelOperation): managers=self.managers, ), ] - elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower: + elif ( + isinstance(operation, AlterModelOptions) + and self.name_lower == operation.name_lower + ): options = {**self.options, **operation.options} for key in operation.ALTER_OPTION_KEYS: if key not in operation.options: @@ -153,27 +170,42 @@ class CreateModel(ModelOperation): managers=self.managers, ), ] - elif isinstance(operation, AlterTogetherOptionOperation) and 
self.name_lower == operation.name_lower: + elif ( + isinstance(operation, AlterTogetherOptionOperation) + and self.name_lower == operation.name_lower + ): return [ CreateModel( self.name, fields=self.fields, - options={**self.options, **{operation.option_name: operation.option_value}}, + options={ + **self.options, + **{operation.option_name: operation.option_value}, + }, bases=self.bases, managers=self.managers, ), ] - elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower: + elif ( + isinstance(operation, AlterOrderWithRespectTo) + and self.name_lower == operation.name_lower + ): return [ CreateModel( self.name, fields=self.fields, - options={**self.options, 'order_with_respect_to': operation.order_with_respect_to}, + options={ + **self.options, + "order_with_respect_to": operation.order_with_respect_to, + }, bases=self.bases, managers=self.managers, ), ] - elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower: + elif ( + isinstance(operation, FieldOperation) + and self.name_lower == operation.model_name_lower + ): if isinstance(operation, AddField): return [ CreateModel( @@ -199,17 +231,25 @@ class CreateModel(ModelOperation): ] elif isinstance(operation, RemoveField): options = self.options.copy() - for option_name in ('unique_together', 'index_together'): + for option_name in ("unique_together", "index_together"): option = options.pop(option_name, None) if option: - option = set(filter(bool, ( - tuple(f for f in fields if f != operation.name_lower) for fields in option - ))) + option = set( + filter( + bool, + ( + tuple( + f for f in fields if f != operation.name_lower + ) + for fields in option + ), + ) + ) if option: options[option_name] = option - order_with_respect_to = options.get('order_with_respect_to') + order_with_respect_to = options.get("order_with_respect_to") if order_with_respect_to == operation.name_lower: - del options['order_with_respect_to'] + del 
options["order_with_respect_to"] return [ CreateModel( self.name, @@ -225,16 +265,19 @@ class CreateModel(ModelOperation): ] elif isinstance(operation, RenameField): options = self.options.copy() - for option_name in ('unique_together', 'index_together'): + for option_name in ("unique_together", "index_together"): option = options.get(option_name) if option: options[option_name] = { - tuple(operation.new_name if f == operation.old_name else f for f in fields) + tuple( + operation.new_name if f == operation.old_name else f + for f in fields + ) for fields in option } - order_with_respect_to = options.get('order_with_respect_to') + order_with_respect_to = options.get("order_with_respect_to") if order_with_respect_to == operation.old_name: - options['order_with_respect_to'] = operation.new_name + options["order_with_respect_to"] = operation.new_name return [ CreateModel( self.name, @@ -255,13 +298,9 @@ class DeleteModel(ModelOperation): def deconstruct(self): kwargs = { - 'name': self.name, + "name": self.name, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): state.remove_model(app_label, self.name_lower) @@ -286,7 +325,7 @@ class DeleteModel(ModelOperation): @property def migration_name_fragment(self): - return 'delete_%s' % self.name_lower + return "delete_%s" % self.name_lower class RenameModel(ModelOperation): @@ -307,14 +346,10 @@ class RenameModel(ModelOperation): def deconstruct(self): kwargs = { - 'old_name': self.old_name, - 'new_name': self.new_name, + "old_name": self.old_name, + "new_name": self.new_name, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): state.rename_model(app_label, self.old_name, self.new_name) @@ -341,19 +376,24 @@ class RenameModel(ModelOperation): related_object.related_model._meta.app_label, 
related_object.related_model._meta.model_name, ) - to_field = to_state.apps.get_model( - *related_key - )._meta.get_field(related_object.field.name) + to_field = to_state.apps.get_model(*related_key)._meta.get_field( + related_object.field.name + ) schema_editor.alter_field( model, related_object.field, to_field, ) # Rename M2M fields whose name is based on this model's name. - fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many) + fields = zip( + old_model._meta.local_many_to_many, new_model._meta.local_many_to_many + ) for (old_field, new_field) in fields: # Skip self-referential fields as these are renamed above. - if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created: + if ( + new_field.model == new_field.related_model + or not new_field.remote_field.through._meta.auto_created + ): continue # Rename the M2M table that's based on this model's name. old_m2m_model = old_field.remote_field.through @@ -372,18 +412,23 @@ class RenameModel(ModelOperation): ) def database_backwards(self, app_label, schema_editor, from_state, to_state): - self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower + self.new_name_lower, self.old_name_lower = ( + self.old_name_lower, + self.new_name_lower, + ) self.new_name, self.old_name = self.old_name, self.new_name self.database_forwards(app_label, schema_editor, from_state, to_state) - self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower + self.new_name_lower, self.old_name_lower = ( + self.old_name_lower, + self.new_name_lower, + ) self.new_name, self.old_name = self.old_name, self.new_name def references_model(self, name, app_label): return ( - name.lower() == self.old_name_lower or - name.lower() == self.new_name_lower + name.lower() == self.old_name_lower or name.lower() == self.new_name_lower ) def describe(self): @@ -391,11 +436,13 @@ class RenameModel(ModelOperation): @property def 
migration_name_fragment(self): - return 'rename_%s_%s' % (self.old_name_lower, self.new_name_lower) + return "rename_%s_%s" % (self.old_name_lower, self.new_name_lower) def reduce(self, operation, app_label): - if (isinstance(operation, RenameModel) and - self.new_name_lower == operation.old_name_lower): + if ( + isinstance(operation, RenameModel) + and self.new_name_lower == operation.old_name_lower + ): return [ RenameModel( self.old_name, @@ -404,15 +451,17 @@ class RenameModel(ModelOperation): ] # Skip `ModelOperation.reduce` as we want to run `references_model` # against self.new_name. - return ( - super(ModelOperation, self).reduce(operation, app_label) or - not operation.references_model(self.new_name, app_label) - ) + return super(ModelOperation, self).reduce( + operation, app_label + ) or not operation.references_model(self.new_name, app_label) class ModelOptionOperation(ModelOperation): def reduce(self, operation, app_label): - if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower: + if ( + isinstance(operation, (self.__class__, DeleteModel)) + and self.name_lower == operation.name_lower + ): return [operation] return super().reduce(operation, app_label) @@ -426,17 +475,13 @@ class AlterModelTable(ModelOptionOperation): def deconstruct(self): kwargs = { - 'name': self.name, - 'table': self.table, + "name": self.name, + "table": self.table, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): - state.alter_model_options(app_label, self.name_lower, {'db_table': self.table}) + state.alter_model_options(app_label, self.name_lower, {"db_table": self.table}) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -448,7 +493,9 @@ class AlterModelTable(ModelOptionOperation): new_model._meta.db_table, ) # Rename M2M fields whose name 
is based on this model's db_table - for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many): + for (old_field, new_field) in zip( + old_model._meta.local_many_to_many, new_model._meta.local_many_to_many + ): if new_field.remote_field.through._meta.auto_created: schema_editor.alter_db_table( new_field.remote_field.through, @@ -462,12 +509,12 @@ class AlterModelTable(ModelOptionOperation): def describe(self): return "Rename table for %s to %s" % ( self.name, - self.table if self.table is not None else "(default)" + self.table if self.table is not None else "(default)", ) @property def migration_name_fragment(self): - return 'alter_%s_table' % self.name_lower + return "alter_%s_table" % self.name_lower class AlterTogetherOptionOperation(ModelOptionOperation): @@ -485,14 +532,10 @@ class AlterTogetherOptionOperation(ModelOptionOperation): def deconstruct(self): kwargs = { - 'name': self.name, + "name": self.name, self.option_name: self.option_value, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): state.alter_model_options( @@ -505,7 +548,7 @@ class AlterTogetherOptionOperation(ModelOptionOperation): new_model = to_state.apps.get_model(app_label, self.name) if self.allow_migrate_model(schema_editor.connection.alias, new_model): old_model = from_state.apps.get_model(app_label, self.name) - alter_together = getattr(schema_editor, 'alter_%s' % self.option_name) + alter_together = getattr(schema_editor, "alter_%s" % self.option_name) alter_together( new_model, getattr(old_model._meta, self.option_name, set()), @@ -516,27 +559,26 @@ class AlterTogetherOptionOperation(ModelOptionOperation): return self.database_forwards(app_label, schema_editor, from_state, to_state) def references_field(self, model_name, name, app_label): - return ( - self.references_model(model_name, app_label) and - ( - not self.option_value or - 
any((name in fields) for fields in self.option_value) - ) + return self.references_model(model_name, app_label) and ( + not self.option_value + or any((name in fields) for fields in self.option_value) ) def describe(self): - return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or '')) + return "Alter %s for %s (%s constraint(s))" % ( + self.option_name, + self.name, + len(self.option_value or ""), + ) @property def migration_name_fragment(self): - return 'alter_%s_%s' % (self.name_lower, self.option_name) + return "alter_%s_%s" % (self.name_lower, self.option_name) def can_reduce_through(self, operation, app_label): - return ( - super().can_reduce_through(operation, app_label) or ( - isinstance(operation, AlterTogetherOptionOperation) and - type(operation) is not type(self) - ) + return super().can_reduce_through(operation, app_label) or ( + isinstance(operation, AlterTogetherOptionOperation) + and type(operation) is not type(self) ) @@ -545,7 +587,8 @@ class AlterUniqueTogether(AlterTogetherOptionOperation): Change the value of unique_together to the target one. Input value of unique_together must be a set of tuples. """ - option_name = 'unique_together' + + option_name = "unique_together" def __init__(self, name, unique_together): super().__init__(name, unique_together) @@ -556,6 +599,7 @@ class AlterIndexTogether(AlterTogetherOptionOperation): Change the value of index_together to the target one. Input value of index_together must be a set of tuples. 
""" + option_name = "index_together" def __init__(self, name, index_together): @@ -565,7 +609,7 @@ class AlterIndexTogether(AlterTogetherOptionOperation): class AlterOrderWithRespectTo(ModelOptionOperation): """Represent a change with the order_with_respect_to option.""" - option_name = 'order_with_respect_to' + option_name = "order_with_respect_to" def __init__(self, name, order_with_respect_to): self.order_with_respect_to = order_with_respect_to @@ -573,14 +617,10 @@ class AlterOrderWithRespectTo(ModelOptionOperation): def deconstruct(self): kwargs = { - 'name': self.name, - 'order_with_respect_to': self.order_with_respect_to, + "name": self.name, + "order_with_respect_to": self.order_with_respect_to, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): state.alter_model_options( @@ -594,11 +634,19 @@ class AlterOrderWithRespectTo(ModelOptionOperation): if self.allow_migrate_model(schema_editor.connection.alias, to_model): from_model = from_state.apps.get_model(app_label, self.name) # Remove a field if we need to - if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to: - schema_editor.remove_field(from_model, from_model._meta.get_field("_order")) + if ( + from_model._meta.order_with_respect_to + and not to_model._meta.order_with_respect_to + ): + schema_editor.remove_field( + from_model, from_model._meta.get_field("_order") + ) # Add a field if we need to (altering the column is untouched as # it's likely a rename) - elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to: + elif ( + to_model._meta.order_with_respect_to + and not from_model._meta.order_with_respect_to + ): field = to_model._meta.get_field("_order") if not field.has_default(): field.default = 0 @@ -611,20 +659,19 @@ class AlterOrderWithRespectTo(ModelOptionOperation): self.database_forwards(app_label, schema_editor, from_state, 
to_state) def references_field(self, model_name, name, app_label): - return ( - self.references_model(model_name, app_label) and - ( - self.order_with_respect_to is None or - name == self.order_with_respect_to - ) + return self.references_model(model_name, app_label) and ( + self.order_with_respect_to is None or name == self.order_with_respect_to ) def describe(self): - return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to) + return "Set order_with_respect_to on %s to %s" % ( + self.name, + self.order_with_respect_to, + ) @property def migration_name_fragment(self): - return 'alter_%s_order_with_respect_to' % self.name_lower + return "alter_%s_order_with_respect_to" % self.name_lower class AlterModelOptions(ModelOptionOperation): @@ -655,14 +702,10 @@ class AlterModelOptions(ModelOptionOperation): def deconstruct(self): kwargs = { - 'name': self.name, - 'options': self.options, + "name": self.name, + "options": self.options, } - return ( - self.__class__.__qualname__, - [], - kwargs - ) + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): state.alter_model_options( @@ -683,24 +726,20 @@ class AlterModelOptions(ModelOptionOperation): @property def migration_name_fragment(self): - return 'alter_%s_options' % self.name_lower + return "alter_%s_options" % self.name_lower class AlterModelManagers(ModelOptionOperation): """Alter the model's managers.""" - serialization_expand_args = ['managers'] + serialization_expand_args = ["managers"] def __init__(self, name, managers): self.managers = managers super().__init__(name) def deconstruct(self): - return ( - self.__class__.__qualname__, - [self.name, self.managers], - {} - ) + return (self.__class__.__qualname__, [self.name, self.managers], {}) def state_forwards(self, app_label, state): state.alter_model_managers(app_label, self.name_lower, self.managers) @@ -716,11 +755,11 @@ class AlterModelManagers(ModelOptionOperation): @property def 
migration_name_fragment(self): - return 'alter_%s_managers' % self.name_lower + return "alter_%s_managers" % self.name_lower class IndexOperation(Operation): - option_name = 'indexes' + option_name = "indexes" @cached_property def model_name_lower(self): @@ -754,8 +793,8 @@ class AddIndex(IndexOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'index': self.index, + "model_name": self.model_name, + "index": self.index, } return ( self.__class__.__qualname__, @@ -765,20 +804,20 @@ class AddIndex(IndexOperation): def describe(self): if self.index.expressions: - return 'Create index %s on %s on model %s' % ( + return "Create index %s on %s on model %s" % ( self.index.name, - ', '.join([str(expression) for expression in self.index.expressions]), + ", ".join([str(expression) for expression in self.index.expressions]), self.model_name, ) - return 'Create index %s on field(s) %s of model %s' % ( + return "Create index %s on field(s) %s of model %s" % ( self.index.name, - ', '.join(self.index.fields), + ", ".join(self.index.fields), self.model_name, ) @property def migration_name_fragment(self): - return '%s_%s' % (self.model_name_lower, self.index.name.lower()) + return "%s_%s" % (self.model_name_lower, self.index.name.lower()) class RemoveIndex(IndexOperation): @@ -807,8 +846,8 @@ class RemoveIndex(IndexOperation): def deconstruct(self): kwargs = { - 'model_name': self.model_name, - 'name': self.name, + "model_name": self.model_name, + "name": self.name, } return ( self.__class__.__qualname__, @@ -817,15 +856,15 @@ class RemoveIndex(IndexOperation): ) def describe(self): - return 'Remove index %s from %s' % (self.name, self.model_name) + return "Remove index %s from %s" % (self.name, self.model_name) @property def migration_name_fragment(self): - return 'remove_%s_%s' % (self.model_name_lower, self.name.lower()) + return "remove_%s_%s" % (self.model_name_lower, self.name.lower()) class AddConstraint(IndexOperation): - option_name = 
'constraints' + option_name = "constraints" def __init__(self, model_name, constraint): self.model_name = model_name @@ -845,21 +884,28 @@ class AddConstraint(IndexOperation): schema_editor.remove_constraint(model, self.constraint) def deconstruct(self): - return self.__class__.__name__, [], { - 'model_name': self.model_name, - 'constraint': self.constraint, - } + return ( + self.__class__.__name__, + [], + { + "model_name": self.model_name, + "constraint": self.constraint, + }, + ) def describe(self): - return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name) + return "Create constraint %s on model %s" % ( + self.constraint.name, + self.model_name, + ) @property def migration_name_fragment(self): - return '%s_%s' % (self.model_name_lower, self.constraint.name.lower()) + return "%s_%s" % (self.model_name_lower, self.constraint.name.lower()) class RemoveConstraint(IndexOperation): - option_name = 'constraints' + option_name = "constraints" def __init__(self, model_name, name): self.model_name = model_name @@ -883,14 +929,18 @@ class RemoveConstraint(IndexOperation): schema_editor.add_constraint(model, constraint) def deconstruct(self): - return self.__class__.__name__, [], { - 'model_name': self.model_name, - 'name': self.name, - } + return ( + self.__class__.__name__, + [], + { + "model_name": self.model_name, + "name": self.name, + }, + ) def describe(self): - return 'Remove constraint %s from model %s' % (self.name, self.model_name) + return "Remove constraint %s from model %s" % (self.name, self.model_name) @property def migration_name_fragment(self): - return 'remove_%s_%s' % (self.model_name_lower, self.name.lower()) + return "remove_%s_%s" % (self.model_name_lower, self.name.lower()) diff --git a/django/db/migrations/operations/special.py b/django/db/migrations/operations/special.py index 5a8510ec02..94a6ec72de 100644 --- a/django/db/migrations/operations/special.py +++ b/django/db/migrations/operations/special.py @@ -11,7 +11,7 @@ 
class SeparateDatabaseAndState(Operation): that affect the state or not the database, or so on. """ - serialization_expand_args = ['database_operations', 'state_operations'] + serialization_expand_args = ["database_operations", "state_operations"] def __init__(self, database_operations=None, state_operations=None): self.database_operations = database_operations or [] @@ -20,14 +20,10 @@ class SeparateDatabaseAndState(Operation): def deconstruct(self): kwargs = {} if self.database_operations: - kwargs['database_operations'] = self.database_operations + kwargs["database_operations"] = self.database_operations if self.state_operations: - kwargs['state_operations'] = self.state_operations - return ( - self.__class__.__qualname__, - [], - kwargs - ) + kwargs["state_operations"] = self.state_operations + return (self.__class__.__qualname__, [], kwargs) def state_forwards(self, app_label, state): for state_operation in self.state_operations: @@ -38,7 +34,9 @@ class SeparateDatabaseAndState(Operation): for database_operation in self.database_operations: to_state = from_state.clone() database_operation.state_forwards(app_label, to_state) - database_operation.database_forwards(app_label, schema_editor, from_state, to_state) + database_operation.database_forwards( + app_label, schema_editor, from_state, to_state + ) from_state = to_state def database_backwards(self, app_label, schema_editor, from_state, to_state): @@ -54,7 +52,9 @@ class SeparateDatabaseAndState(Operation): for database_operation in reversed(self.database_operations): from_state = to_state to_state = to_states[database_operation] - database_operation.database_backwards(app_label, schema_editor, from_state, to_state) + database_operation.database_backwards( + app_label, schema_editor, from_state, to_state + ) def describe(self): return "Custom state/database change combination" @@ -67,9 +67,12 @@ class RunSQL(Operation): Also accept a list of operations that represent the state change effected by this SQL 
change, in case it's custom column/table creation/deletion. """ - noop = '' - def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False): + noop = "" + + def __init__( + self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False + ): self.sql = sql self.reverse_sql = reverse_sql self.state_operations = state_operations or [] @@ -78,19 +81,15 @@ class RunSQL(Operation): def deconstruct(self): kwargs = { - 'sql': self.sql, + "sql": self.sql, } if self.reverse_sql is not None: - kwargs['reverse_sql'] = self.reverse_sql + kwargs["reverse_sql"] = self.reverse_sql if self.state_operations: - kwargs['state_operations'] = self.state_operations + kwargs["state_operations"] = self.state_operations if self.hints: - kwargs['hints'] = self.hints - return ( - self.__class__.__qualname__, - [], - kwargs - ) + kwargs["hints"] = self.hints + return (self.__class__.__qualname__, [], kwargs) @property def reversible(self): @@ -101,13 +100,17 @@ class RunSQL(Operation): state_operation.state_forwards(app_label, state) def database_forwards(self, app_label, schema_editor, from_state, to_state): - if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints): + if router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): self._run_sql(schema_editor, self.sql) def database_backwards(self, app_label, schema_editor, from_state, to_state): if self.reverse_sql is None: raise NotImplementedError("You cannot reverse this operation") - if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints): + if router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): self._run_sql(schema_editor, self.reverse_sql) def describe(self): @@ -137,7 +140,9 @@ class RunPython(Operation): reduces_to_sql = False - def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False): + def __init__( + self, code, reverse_code=None, atomic=None, 
hints=None, elidable=False + ): self.atomic = atomic # Forwards code if not callable(code): @@ -155,19 +160,15 @@ class RunPython(Operation): def deconstruct(self): kwargs = { - 'code': self.code, + "code": self.code, } if self.reverse_code is not None: - kwargs['reverse_code'] = self.reverse_code + kwargs["reverse_code"] = self.reverse_code if self.atomic is not None: - kwargs['atomic'] = self.atomic + kwargs["atomic"] = self.atomic if self.hints: - kwargs['hints'] = self.hints - return ( - self.__class__.__qualname__, - [], - kwargs - ) + kwargs["hints"] = self.hints + return (self.__class__.__qualname__, [], kwargs) @property def reversible(self): @@ -182,7 +183,9 @@ class RunPython(Operation): # RunPython has access to all models. Ensure that all models are # reloaded in case any are delayed. from_state.clear_delayed_apps_cache() - if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints): + if router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): # We now execute the Python code in a context that contains a 'models' # object, representing the versioned models as an app registry. 
# We could try to override the global cache, but then people will still @@ -192,7 +195,9 @@ class RunPython(Operation): def database_backwards(self, app_label, schema_editor, from_state, to_state): if self.reverse_code is None: raise NotImplementedError("You cannot reverse this operation") - if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints): + if router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): self.reverse_code(from_state.apps, schema_editor) def describe(self): diff --git a/django/db/migrations/optimizer.py b/django/db/migrations/optimizer.py index ee20f62af2..7e5dea2377 100644 --- a/django/db/migrations/optimizer.py +++ b/django/db/migrations/optimizer.py @@ -28,7 +28,7 @@ class MigrationOptimizer: """ # Internal tracking variable for test assertions about # of loops if app_label is None: - raise TypeError('app_label must be a str.') + raise TypeError("app_label must be a str.") self._iterations = 0 while True: result = self.optimize_inner(operations, app_label) @@ -43,10 +43,10 @@ class MigrationOptimizer: for i, operation in enumerate(operations): right = True # Should we reduce on the right or on the left. # Compare it to each operation after it - for j, other in enumerate(operations[i + 1:]): + for j, other in enumerate(operations[i + 1 :]): result = operation.reduce(other, app_label) if isinstance(result, list): - in_between = operations[i + 1:i + j + 1] + in_between = operations[i + 1 : i + j + 1] if right: new_operations.extend(in_between) new_operations.extend(result) @@ -59,7 +59,7 @@ class MigrationOptimizer: # Otherwise keep trying. new_operations.append(operation) break - new_operations.extend(operations[i + j + 2:]) + new_operations.extend(operations[i + j + 2 :]) return new_operations elif not result: # Can't perform a right reduction. 
diff --git a/django/db/migrations/questioner.py b/django/db/migrations/questioner.py index 3460e2b3ab..e1081ab70a 100644 --- a/django/db/migrations/questioner.py +++ b/django/db/migrations/questioner.py @@ -35,7 +35,7 @@ class MigrationQuestioner: # file check will ensure we skip South ones. try: app_config = apps.get_app_config(app_label) - except LookupError: # It's a fake app. + except LookupError: # It's a fake app. return self.defaults.get("ask_initial", False) migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label) if migrations_import_path is None: @@ -88,25 +88,29 @@ class MigrationQuestioner: class InteractiveMigrationQuestioner(MigrationQuestioner): - def __init__(self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None): - super().__init__(defaults=defaults, specified_apps=specified_apps, dry_run=dry_run) + def __init__( + self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None + ): + super().__init__( + defaults=defaults, specified_apps=specified_apps, dry_run=dry_run + ) self.prompt_output = prompt_output or OutputWrapper(sys.stdout) def _boolean_input(self, question, default=None): - self.prompt_output.write(f'{question} ', ending='') + self.prompt_output.write(f"{question} ", ending="") result = input() if not result and default is not None: return default while not result or result[0].lower() not in "yn": - self.prompt_output.write('Please answer yes or no: ', ending='') + self.prompt_output.write("Please answer yes or no: ", ending="") result = input() return result[0].lower() == "y" def _choice_input(self, question, choices): - self.prompt_output.write(f'{question}') + self.prompt_output.write(f"{question}") for i, choice in enumerate(choices): - self.prompt_output.write(' %s) %s' % (i + 1, choice)) - self.prompt_output.write('Select an option: ', ending='') + self.prompt_output.write(" %s) %s" % (i + 1, choice)) + self.prompt_output.write("Select an option: ", ending="") result = 
input() while True: try: @@ -116,10 +120,10 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): else: if 0 < value <= len(choices): return value - self.prompt_output.write('Please select a valid option: ', ending='') + self.prompt_output.write("Please select a valid option: ", ending="") result = input() - def _ask_default(self, default=''): + def _ask_default(self, default=""): """ Prompt for a default value. @@ -127,15 +131,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): string) which will be shown to the user and used as the return value if the user doesn't provide any other input. """ - self.prompt_output.write('Please enter the default value as valid Python.') + self.prompt_output.write("Please enter the default value as valid Python.") if default: self.prompt_output.write( f"Accept the default '{default}' by pressing 'Enter' or " f"provide another value." ) self.prompt_output.write( - 'The datetime and django.utils.timezone modules are available, so ' - 'it is possible to provide e.g. timezone.now as a value.' + "The datetime and django.utils.timezone modules are available, so " + "it is possible to provide e.g. timezone.now as a value." ) self.prompt_output.write("Type 'exit' to exit this prompt") while True: @@ -143,19 +147,21 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): prompt = "[default: {}] >>> ".format(default) else: prompt = ">>> " - self.prompt_output.write(prompt, ending='') + self.prompt_output.write(prompt, ending="") code = input() if not code and default: code = default if not code: - self.prompt_output.write("Please enter some code, or 'exit' (without quotes) to exit.") + self.prompt_output.write( + "Please enter some code, or 'exit' (without quotes) to exit." 
+ ) elif code == "exit": sys.exit(1) else: try: - return eval(code, {}, {'datetime': datetime, 'timezone': timezone}) + return eval(code, {}, {"datetime": datetime, "timezone": timezone}) except (SyntaxError, NameError) as e: - self.prompt_output.write('Invalid input: %s' % e) + self.prompt_output.write("Invalid input: %s" % e) def ask_not_null_addition(self, field_name, model_name): """Adding a NOT NULL field to a model.""" @@ -167,10 +173,12 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): f"rows.\n" f"Please select a fix:", [ - ("Provide a one-off default now (will be set on all existing " - "rows with a null value for this column)"), - 'Quit and manually define a default value in models.py.', - ] + ( + "Provide a one-off default now (will be set on all existing " + "rows with a null value for this column)" + ), + "Quit and manually define a default value in models.py.", + ], ) if choice == 2: sys.exit(3) @@ -188,13 +196,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): f"populate existing rows.\n" f"Please select a fix:", [ - ("Provide a one-off default now (will be set on all existing " - "rows with a null value for this column)"), - 'Ignore for now. Existing rows that contain NULL values ' - 'will have to be handled manually, for example with a ' - 'RunPython or RunSQL operation.', - 'Quit and manually define a default value in models.py.', - ] + ( + "Provide a one-off default now (will be set on all existing " + "rows with a null value for this column)" + ), + "Ignore for now. 
Existing rows that contain NULL values " + "will have to be handled manually, for example with a " + "RunPython or RunSQL operation.", + "Quit and manually define a default value in models.py.", + ], ) if choice == 2: return NOT_PROVIDED @@ -206,21 +216,33 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): def ask_rename(self, model_name, old_name, new_name, field_instance): """Was this field really renamed?""" - msg = 'Was %s.%s renamed to %s.%s (a %s)? [y/N]' - return self._boolean_input(msg % (model_name, old_name, model_name, new_name, - field_instance.__class__.__name__), False) + msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]" + return self._boolean_input( + msg + % ( + model_name, + old_name, + model_name, + new_name, + field_instance.__class__.__name__, + ), + False, + ) def ask_rename_model(self, old_model_state, new_model_state): """Was this model really renamed?""" - msg = 'Was the model %s.%s renamed to %s? [y/N]' - return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name, - new_model_state.name), False) + msg = "Was the model %s.%s renamed to %s? [y/N]" + return self._boolean_input( + msg + % (old_model_state.app_label, old_model_state.name, new_model_state.name), + False, + ) def ask_merge(self, app_label): return self._boolean_input( - "\nMerging will only work if the operations printed above do not conflict\n" + - "with each other (working on different fields or models)\n" + - 'Should these migration branches be merged? [y/N]', + "\nMerging will only work if the operations printed above do not conflict\n" + + "with each other (working on different fields or models)\n" + + "Should these migration branches be merged? [y/N]", False, ) @@ -233,15 +255,15 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): f"default. 
This is because the database needs something to " f"populate existing rows.\n", [ - 'Provide a one-off default now which will be set on all ' - 'existing rows', - 'Quit and manually define a default value in models.py.', - ] + "Provide a one-off default now which will be set on all " + "existing rows", + "Quit and manually define a default value in models.py.", + ], ) if choice == 2: sys.exit(3) else: - return self._ask_default(default='timezone.now') + return self._ask_default(default="timezone.now") return None def ask_unique_callable_default_addition(self, field_name, model_name): @@ -249,16 +271,16 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): if not self.dry_run: version = get_docs_version() choice = self._choice_input( - f'Callable default on unique field {model_name}.{field_name} ' - f'will not generate unique values upon migrating.\n' - f'Please choose how to proceed:\n', + f"Callable default on unique field {model_name}.{field_name} " + f"will not generate unique values upon migrating.\n" + f"Please choose how to proceed:\n", [ - f'Continue making this migration as the first step in ' - f'writing a manual migration to generate unique values ' - f'described here: ' - f'https://docs.djangoproject.com/en/{version}/howto/' - f'writing-migrations/#migrations-that-add-unique-fields.', - 'Quit and edit field options in models.py.', + f"Continue making this migration as the first step in " + f"writing a manual migration to generate unique values " + f"described here: " + f"https://docs.djangoproject.com/en/{version}/howto/" + f"writing-migrations/#migrations-that-add-unique-fields.", + "Quit and edit field options in models.py.", ], ) if choice == 2: @@ -268,13 +290,19 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): class NonInteractiveMigrationQuestioner(MigrationQuestioner): def __init__( - self, defaults=None, specified_apps=None, dry_run=None, verbosity=1, + self, + defaults=None, + specified_apps=None, + dry_run=None, + 
verbosity=1, log=None, ): self.verbosity = verbosity self.log = log super().__init__( - defaults=defaults, specified_apps=specified_apps, dry_run=dry_run, + defaults=defaults, + specified_apps=specified_apps, + dry_run=dry_run, ) def log_lack_of_migration(self, field_name, model_name, reason): @@ -289,8 +317,8 @@ class NonInteractiveMigrationQuestioner(MigrationQuestioner): self.log_lack_of_migration( field_name, model_name, - 'it is impossible to add a non-nullable field without specifying ' - 'a default', + "it is impossible to add a non-nullable field without specifying " + "a default", ) sys.exit(3) diff --git a/django/db/migrations/recorder.py b/django/db/migrations/recorder.py index 1a37c6b7d0..50876a9ee3 100644 --- a/django/db/migrations/recorder.py +++ b/django/db/migrations/recorder.py @@ -18,6 +18,7 @@ class MigrationRecorder: If a migration is unapplied its row is removed from the table. Having a row in the table always means a migration is applied. """ + _migration_class = None @classproperty @@ -27,6 +28,7 @@ class MigrationRecorder: MigrationRecorder. 
""" if cls._migration_class is None: + class Migration(models.Model): app = models.CharField(max_length=255) name = models.CharField(max_length=255) @@ -34,11 +36,11 @@ class MigrationRecorder: class Meta: apps = Apps() - app_label = 'migrations' - db_table = 'django_migrations' + app_label = "migrations" + db_table = "django_migrations" def __str__(self): - return 'Migration %s for %s' % (self.name, self.app) + return "Migration %s for %s" % (self.name, self.app) cls._migration_class = Migration return cls._migration_class @@ -67,7 +69,9 @@ class MigrationRecorder: with self.connection.schema_editor() as editor: editor.create_model(self.Migration) except DatabaseError as exc: - raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc) + raise MigrationSchemaMissing( + "Unable to create the django_migrations table (%s)" % exc + ) def applied_migrations(self): """ @@ -75,7 +79,10 @@ class MigrationRecorder: for all applied migrations. """ if self.has_table(): - return {(migration.app, migration.name): migration for migration in self.migration_qs} + return { + (migration.app, migration.name): migration + for migration in self.migration_qs + } else: # If the django_migrations table doesn't exist, then no migrations # are applied. diff --git a/django/db/migrations/serializer.py b/django/db/migrations/serializer.py index 9c58f38e28..fb4a1964d9 100644 --- a/django/db/migrations/serializer.py +++ b/django/db/migrations/serializer.py @@ -25,12 +25,16 @@ class BaseSerializer: self.value = value def serialize(self): - raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.') + raise NotImplementedError( + "Subclasses of BaseSerializer must implement the serialize() method." 
+ ) class BaseSequenceSerializer(BaseSerializer): def _format(self): - raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.') + raise NotImplementedError( + "Subclasses of BaseSequenceSerializer must implement the _format() method." + ) def serialize(self): imports = set() @@ -55,19 +59,21 @@ class ChoicesSerializer(BaseSerializer): class DateTimeSerializer(BaseSerializer): """For datetime.*, except datetime.datetime.""" + def serialize(self): - return repr(self.value), {'import datetime'} + return repr(self.value), {"import datetime"} class DatetimeDatetimeSerializer(BaseSerializer): """For datetime.datetime.""" + def serialize(self): if self.value.tzinfo is not None and self.value.tzinfo != utc: self.value = self.value.astimezone(utc) imports = ["import datetime"] if self.value.tzinfo is not None: imports.append("from django.utils.timezone import utc") - return repr(self.value).replace('datetime.timezone.utc', 'utc'), set(imports) + return repr(self.value).replace("datetime.timezone.utc", "utc"), set(imports) class DecimalSerializer(BaseSerializer): @@ -123,8 +129,8 @@ class EnumSerializer(BaseSerializer): enum_class = self.value.__class__ module = enum_class.__module__ return ( - '%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name), - {'import %s' % module}, + "%s.%s[%r]" % (module, enum_class.__qualname__, self.value.name), + {"import %s" % module}, ) @@ -142,23 +148,29 @@ class FrozensetSerializer(BaseSequenceSerializer): class FunctionTypeSerializer(BaseSerializer): def serialize(self): - if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type): + if getattr(self.value, "__self__", None) and isinstance( + self.value.__self__, type + ): klass = self.value.__self__ module = klass.__module__ - return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module} + return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), { + "import %s" % module + } 
# Further error checking - if self.value.__name__ == '<lambda>': + if self.value.__name__ == "<lambda>": raise ValueError("Cannot serialize function: lambda") if self.value.__module__ is None: raise ValueError("Cannot serialize function %r: No module" % self.value) module_name = self.value.__module__ - if '<' not in self.value.__qualname__: # Qualname can include <locals> - return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__} + if "<" not in self.value.__qualname__: # Qualname can include <locals> + return "%s.%s" % (module_name, self.value.__qualname__), { + "import %s" % self.value.__module__ + } raise ValueError( - 'Could not find function %s in %s.\n' % (self.value.__name__, module_name) + "Could not find function %s in %s.\n" % (self.value.__name__, module_name) ) @@ -167,11 +179,14 @@ class FunctoolsPartialSerializer(BaseSerializer): # Serialize functools.partial() arguments func_string, func_imports = serializer_factory(self.value.func).serialize() args_string, args_imports = serializer_factory(self.value.args).serialize() - keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize() + keywords_string, keywords_imports = serializer_factory( + self.value.keywords + ).serialize() # Add any imports needed by arguments - imports = {'import functools', *func_imports, *args_imports, *keywords_imports} + imports = {"import functools", *func_imports, *args_imports, *keywords_imports} return ( - 'functools.%s(%s, *%s, **%s)' % ( + "functools.%s(%s, *%s, **%s)" + % ( self.value.__class__.__name__, func_string, args_string, @@ -214,9 +229,10 @@ class ModelManagerSerializer(DeconstructableSerializer): class OperationSerializer(BaseSerializer): def serialize(self): from django.db.migrations.writer import OperationWriter + string, imports = OperationWriter(self.value, indentation=0).serialize() # Nested operation, trailing comma is handled in upper OperationWriter._write() - return string.rstrip(','), 
imports + return string.rstrip(","), imports class PathLikeSerializer(BaseSerializer): @@ -228,22 +244,24 @@ class PathSerializer(BaseSerializer): def serialize(self): # Convert concrete paths to pure paths to avoid issues with migrations # generated on one platform being used on a different platform. - prefix = 'Pure' if isinstance(self.value, pathlib.Path) else '' - return 'pathlib.%s%r' % (prefix, self.value), {'import pathlib'} + prefix = "Pure" if isinstance(self.value, pathlib.Path) else "" + return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"} class RegexSerializer(BaseSerializer): def serialize(self): - regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize() + regex_pattern, pattern_imports = serializer_factory( + self.value.pattern + ).serialize() # Turn off default implicit flags (e.g. re.U) because regexes with the # same implicit and explicit flags aren't equal. - flags = self.value.flags ^ re.compile('').flags + flags = self.value.flags ^ re.compile("").flags regex_flags, flag_imports = serializer_factory(flags).serialize() - imports = {'import re', *pattern_imports, *flag_imports} + imports = {"import re", *pattern_imports, *flag_imports} args = [regex_pattern] if flags: args.append(regex_flags) - return "re.compile(%s)" % ', '.join(args), imports + return "re.compile(%s)" % ", ".join(args), imports class SequenceSerializer(BaseSequenceSerializer): @@ -255,12 +273,14 @@ class SetSerializer(BaseSequenceSerializer): def _format(self): # Serialize as a set literal except when value is empty because {} # is an empty dict. 
- return '{%s}' if self.value else 'set(%s)' + return "{%s}" if self.value else "set(%s)" class SettingsReferenceSerializer(BaseSerializer): def serialize(self): - return "settings.%s" % self.value.setting_name, {"from django.conf import settings"} + return "settings.%s" % self.value.setting_name, { + "from django.conf import settings" + } class TupleSerializer(BaseSequenceSerializer): @@ -273,8 +293,8 @@ class TupleSerializer(BaseSequenceSerializer): class TypeSerializer(BaseSerializer): def serialize(self): special_cases = [ - (models.Model, "models.Model", ['from django.db import models']), - (type(None), 'type(None)', []), + (models.Model, "models.Model", ["from django.db import models"]), + (type(None), "type(None)", []), ] for case, string, imports in special_cases: if case is self.value: @@ -284,7 +304,9 @@ class TypeSerializer(BaseSerializer): if module == builtins.__name__: return self.value.__name__, set() else: - return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module} + return "%s.%s" % (module, self.value.__qualname__), { + "import %s" % module + } class UUIDSerializer(BaseSerializer): @@ -309,7 +331,11 @@ class Serializer: (bool, int, type(None), bytes, str, range): BaseSimpleSerializer, decimal.Decimal: DecimalSerializer, (functools.partial, functools.partialmethod): FunctoolsPartialSerializer, - (types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodType, + ): FunctionTypeSerializer, collections.abc.Iterable: IterableSerializer, (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer, uuid.UUID: UUIDSerializer, @@ -320,7 +346,9 @@ class Serializer: @classmethod def register(cls, type_, serializer): if not issubclass(serializer, BaseSerializer): - raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__) + raise ValueError( + "'%s' must inherit from 'BaseSerializer'." 
% serializer.__name__ + ) cls._registry[type_] = serializer @classmethod @@ -345,7 +373,7 @@ def serializer_factory(value): if isinstance(value, type): return TypeSerializer(value) # Anything that knows how to deconstruct itself. - if hasattr(value, 'deconstruct'): + if hasattr(value, "deconstruct"): return DeconstructableSerializer(value) for type_, serializer_cls in Serializer._registry.items(): if isinstance(value, type_): diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index dfb51d579c..0c9416ae45 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -4,7 +4,8 @@ from contextlib import contextmanager from functools import partial from django.apps import AppConfig -from django.apps.registry import Apps, apps as global_apps +from django.apps.registry import Apps +from django.apps.registry import apps as global_apps from django.conf import settings from django.core.exceptions import FieldDoesNotExist from django.db import models @@ -21,9 +22,9 @@ from .exceptions import InvalidBasesError from .utils import resolve_relation -def _get_app_label_and_model_name(model, app_label=''): +def _get_app_label_and_model_name(model, app_label=""): if isinstance(model, str): - split = model.split('.', 1) + split = model.split(".", 1) return tuple(split) if len(split) == 2 else (app_label, split[0]) else: return model._meta.app_label, model._meta.model_name @@ -32,12 +33,17 @@ def _get_app_label_and_model_name(model, app_label=''): def _get_related_models(m): """Return all models that have a direct relationship to the given model.""" related_models = [ - subclass for subclass in m.__subclasses__() + subclass + for subclass in m.__subclasses__() if issubclass(subclass, models.Model) ] related_fields_models = set() for f in m._meta.get_fields(include_parents=True, include_hidden=True): - if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str): + if ( + f.is_relation + and f.related_model is 
not None + and not isinstance(f.related_model, str) + ): related_fields_models.add(f.model) related_models.append(f.related_model) # Reverse accessors of foreign keys to proxy models are attached to their @@ -73,7 +79,10 @@ def get_related_models_recursive(model): seen = set() queue = _get_related_models(model) for rel_mod in queue: - rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name + rel_app_label, rel_model_name = ( + rel_mod._meta.app_label, + rel_mod._meta.model_name, + ) if (rel_app_label, rel_model_name) in seen: continue seen.add((rel_app_label, rel_model_name)) @@ -111,7 +120,7 @@ class ProjectState: self.models[model_key] = model_state if self._relations is not None: self.resolve_model_relations(model_key) - if 'apps' in self.__dict__: # hasattr would cache the property + if "apps" in self.__dict__: # hasattr would cache the property self.reload_model(*model_key) def remove_model(self, app_label, model_name): @@ -124,7 +133,7 @@ class ProjectState: model_relations.pop(model_key, None) if not model_relations: del self._relations[related_model_key] - if 'apps' in self.__dict__: # hasattr would cache the property + if "apps" in self.__dict__: # hasattr would cache the property self.apps.unregister_model(*model_key) # Need to do this explicitly since unregister_model() doesn't clear # the cache automatically (#24513) @@ -139,9 +148,11 @@ class ProjectState: self.models[app_label, new_name_lower] = renamed_model # Repoint all fields pointing to the old model to the new one. 
old_model_tuple = (app_label, old_name_lower) - new_remote_model = f'{app_label}.{new_name}' + new_remote_model = f"{app_label}.{new_name}" to_reload = set() - for model_state, name, field, reference in get_references(self, old_model_tuple): + for model_state, name, field, reference in get_references( + self, old_model_tuple + ): changed_field = None if reference.to: changed_field = field.clone() @@ -193,16 +204,16 @@ class ProjectState: self.reload_model(app_label, model_name, delay=True) def add_index(self, app_label, model_name, index): - self._append_option(app_label, model_name, 'indexes', index) + self._append_option(app_label, model_name, "indexes", index) def remove_index(self, app_label, model_name, index_name): - self._remove_option(app_label, model_name, 'indexes', index_name) + self._remove_option(app_label, model_name, "indexes", index_name) def add_constraint(self, app_label, model_name, constraint): - self._append_option(app_label, model_name, 'constraints', constraint) + self._append_option(app_label, model_name, "constraints", constraint) def remove_constraint(self, app_label, model_name, constraint_name): - self._remove_option(app_label, model_name, 'constraints', constraint_name) + self._remove_option(app_label, model_name, "constraints", constraint_name) def add_field(self, app_label, model_name, name, field, preserve_default): # If preserve default is off, don't use the default for future state. @@ -250,9 +261,8 @@ class ProjectState: # it's sufficient if the new field is (#27737). # Delay rendering of relationships if it's not a relational field and # not referenced by a foreign key. 
- delay = ( - not field.is_relation and - not field_is_referenced(self, model_key, (name, field)) + delay = not field.is_relation and not field_is_referenced( + self, model_key, (name, field) ) self.reload_model(*model_key, delay=delay) @@ -270,15 +280,17 @@ class ProjectState: fields[new_name] = found for field in fields.values(): # Fix from_fields to refer to the new field. - from_fields = getattr(field, 'from_fields', None) + from_fields = getattr(field, "from_fields", None) if from_fields: - field.from_fields = tuple([ - new_name if from_field_name == old_name else from_field_name - for from_field_name in from_fields - ]) + field.from_fields = tuple( + [ + new_name if from_field_name == old_name else from_field_name + for from_field_name in from_fields + ] + ) # Fix index/unique_together to refer to the new field. options = model_state.options - for option in ('index_together', 'unique_together'): + for option in ("index_together", "unique_together"): if option in options: options[option] = [ [new_name if n == old_name else n for n in together] @@ -291,13 +303,15 @@ class ProjectState: delay = False if reference.to: remote_field, to_fields = reference.to - if getattr(remote_field, 'field_name', None) == old_name: + if getattr(remote_field, "field_name", None) == old_name: remote_field.field_name = new_name if to_fields: - field.to_fields = tuple([ - new_name if to_field_name == old_name else to_field_name - for to_field_name in to_fields - ]) + field.to_fields = tuple( + [ + new_name if to_field_name == old_name else to_field_name + for to_field_name in to_fields + ] + ) if self._relations is not None: old_name_lower = old_name.lower() new_name_lower = new_name.lower() @@ -335,7 +349,9 @@ class ProjectState: if field.is_relation: if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT: continue - rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label) + rel_app_label, rel_model_name = _get_app_label_and_model_name( + 
field.related_model, app_label + ) direct_related_models.add((rel_app_label, rel_model_name.lower())) # For all direct related models recursively get all related models. @@ -357,15 +373,17 @@ class ProjectState: return related_models def reload_model(self, app_label, model_name, delay=False): - if 'apps' in self.__dict__: # hasattr would cache the property + if "apps" in self.__dict__: # hasattr would cache the property related_models = self._find_reload_model(app_label, model_name, delay) self._reload(related_models) def reload_models(self, models, delay=True): - if 'apps' in self.__dict__: # hasattr would cache the property + if "apps" in self.__dict__: # hasattr would cache the property related_models = set() for app_label, model_name in models: - related_models.update(self._find_reload_model(app_label, model_name, delay)) + related_models.update( + self._find_reload_model(app_label, model_name, delay) + ) self._reload(related_models) def _reload(self, related_models): @@ -395,7 +413,12 @@ class ProjectState: self.apps.render_multiple(states_to_be_rendered) def update_model_field_relation( - self, model, model_key, field_name, field, concretes, + self, + model, + model_key, + field_name, + field, + concretes, ): remote_model_key = resolve_relation(model, *model_key) if remote_model_key[0] not in self.real_apps and remote_model_key in concretes: @@ -413,7 +436,11 @@ class ProjectState: del relations_to_remote_model[model_key] def resolve_model_field_relations( - self, model_key, field_name, field, concretes=None, + self, + model_key, + field_name, + field, + concretes=None, ): remote_field = field.remote_field if not remote_field: @@ -422,13 +449,19 @@ class ProjectState: concretes, _ = self._get_concrete_models_mapping_and_proxy_models() self.update_model_field_relation( - remote_field.model, model_key, field_name, field, concretes, + remote_field.model, + model_key, + field_name, + field, + concretes, ) - through = getattr(remote_field, 'through', None) + 
through = getattr(remote_field, "through", None) if not through: return - self.update_model_field_relation(through, model_key, field_name, field, concretes) + self.update_model_field_relation( + through, model_key, field_name, field, concretes + ) def resolve_model_relations(self, model_key, concretes=None): if concretes is None: @@ -455,7 +488,10 @@ class ProjectState: self._relations[model_key] = self._relations[concretes[model_key]] def get_concrete_model_key(self, model): - concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models() + ( + concrete_models_mapping, + _, + ) = self._get_concrete_models_mapping_and_proxy_models() model_key = make_model_tuple(model) return concrete_models_mapping[model_key] @@ -464,11 +500,14 @@ class ProjectState: proxy_models = {} # Split models to proxy and concrete models. for model_key, model_state in self.models.items(): - if model_state.options.get('proxy'): + if model_state.options.get("proxy"): proxy_models[model_key] = model_state # Find a concrete model for the proxy. - concrete_models_mapping[model_key] = self._find_concrete_model_from_proxy( - proxy_models, model_state, + concrete_models_mapping[ + model_key + ] = self._find_concrete_model_from_proxy( + proxy_models, + model_state, ) else: concrete_models_mapping[model_key] = model_key @@ -491,14 +530,14 @@ class ProjectState: models={k: v.clone() for k, v in self.models.items()}, real_apps=self.real_apps, ) - if 'apps' in self.__dict__: + if "apps" in self.__dict__: new_state.apps = self.apps.clone() new_state.is_delayed = self.is_delayed return new_state def clear_delayed_apps_cache(self): - if self.is_delayed and 'apps' in self.__dict__: - del self.__dict__['apps'] + if self.is_delayed and "apps" in self.__dict__: + del self.__dict__["apps"] @cached_property def apps(self): @@ -519,6 +558,7 @@ class ProjectState: class AppConfigStub(AppConfig): """Stub of an AppConfig. 
Only provides a label and a dict of models.""" + def __init__(self, label): self.apps = None self.models = {} @@ -537,6 +577,7 @@ class StateApps(Apps): Subclass of the global Apps registry class to better handle dynamic model additions and removals. """ + def __init__(self, real_apps, models, ignore_swappable=False): # Any apps in self.real_apps should have all their models included # in the render. We don't use the original model instances as there @@ -550,7 +591,9 @@ class StateApps(Apps): self.real_models.append(ModelState.from_model(model, exclude_rels=True)) # Populate the app registry with a stub for each application. app_labels = {model_state.app_label for model_state in models.values()} - app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])] + app_configs = [ + AppConfigStub(label) for label in sorted([*real_apps, *app_labels]) + ] super().__init__(app_configs) # These locks get in the way of copying as implemented in clone(), @@ -563,7 +606,10 @@ class StateApps(Apps): # There shouldn't be any operations pending at this point. from django.core.checks.model_checks import _check_lazy_references - ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set() + + ignore = ( + {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set() + ) errors = _check_lazy_references(self, ignore=ignore) if errors: raise ValueError("\n".join(error.msg for error in errors)) @@ -646,34 +692,36 @@ class ModelState: assign new ones, as these are not detached during a clone. 
""" - def __init__(self, app_label, name, fields, options=None, bases=None, managers=None): + def __init__( + self, app_label, name, fields, options=None, bases=None, managers=None + ): self.app_label = app_label self.name = name self.fields = dict(fields) self.options = options or {} - self.options.setdefault('indexes', []) - self.options.setdefault('constraints', []) + self.options.setdefault("indexes", []) + self.options.setdefault("constraints", []) self.bases = bases or (models.Model,) self.managers = managers or [] for name, field in self.fields.items(): # Sanity-check that fields are NOT already bound to a model. - if hasattr(field, 'model'): + if hasattr(field, "model"): raise ValueError( 'ModelState.fields cannot be bound to a model - "%s" is.' % name ) # Sanity-check that relation fields are NOT referring to a model class. - if field.is_relation and hasattr(field.related_model, '_meta'): + if field.is_relation and hasattr(field.related_model, "_meta"): raise ValueError( 'ModelState.fields cannot refer to a model class - "%s.to" does. ' - 'Use a string reference instead.' % name + "Use a string reference instead." % name ) - if field.many_to_many and hasattr(field.remote_field.through, '_meta'): + if field.many_to_many and hasattr(field.remote_field.through, "_meta"): raise ValueError( 'ModelState.fields cannot refer to a model class - "%s.through" does. ' - 'Use a string reference instead.' % name + "Use a string reference instead." % name ) # Sanity-check that indexes have their name set. - for index in self.options['indexes']: + for index in self.options["indexes"]: if not index.name: raise ValueError( "Indexes passed to ModelState require a name attribute. 
" @@ -685,8 +733,8 @@ class ModelState: return self.name.lower() def get_field(self, field_name): - if field_name == '_order': - field_name = self.options.get('order_with_respect_to', field_name) + if field_name == "_order": + field_name = self.options.get("order_with_respect_to", field_name) return self.fields[field_name] @classmethod @@ -703,22 +751,28 @@ class ModelState: try: fields.append((name, field.clone())) except TypeError as e: - raise TypeError("Couldn't reconstruct field %s on %s: %s" % ( - name, - model._meta.label, - e, - )) + raise TypeError( + "Couldn't reconstruct field %s on %s: %s" + % ( + name, + model._meta.label, + e, + ) + ) if not exclude_rels: for field in model._meta.local_many_to_many: name = field.name try: fields.append((name, field.clone())) except TypeError as e: - raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % ( - name, - model._meta.object_name, - e, - )) + raise TypeError( + "Couldn't reconstruct m2m field %s on %s: %s" + % ( + name, + model._meta.object_name, + e, + ) + ) # Extract the options options = {} for name in DEFAULT_NAMES: @@ -737,9 +791,11 @@ class ModelState: for index in indexes: if not index.name: index.set_name_with_model(model) - options['indexes'] = indexes - elif name == 'constraints': - options['constraints'] = [con.clone() for con in model._meta.constraints] + options["indexes"] = indexes + elif name == "constraints": + options["constraints"] = [ + con.clone() for con in model._meta.constraints + ] else: options[name] = model._meta.original_attrs[name] # If we're ignoring relationships, remove all field-listing model @@ -749,8 +805,10 @@ class ModelState: if key in options: del options[key] # Private fields are ignored, so remove options that refer to them. 
- elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}: - del options['order_with_respect_to'] + elif options.get("order_with_respect_to") in { + field.name for field in model._meta.private_fields + }: + del options["order_with_respect_to"] def flatten_bases(model): bases = [] @@ -766,19 +824,19 @@ class ModelState: # __bases__ we may end up with duplicates and ordering issues, we # therefore discard any duplicates and reorder the bases according # to their index in the MRO. - flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)) + flattened_bases = sorted( + set(flatten_bases(model)), key=lambda x: model.__mro__.index(x) + ) # Make our record bases = tuple( - ( - base._meta.label_lower - if hasattr(base, "_meta") else - base - ) + (base._meta.label_lower if hasattr(base, "_meta") else base) for base in flattened_bases ) # Ensure at least one base inherits from models.Model - if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases): + if not any( + (isinstance(base, str) or issubclass(base, models.Model)) for base in bases + ): bases = (models.Model,) managers = [] @@ -805,7 +863,7 @@ class ModelState: managers.append((manager.name, new_manager)) # Ignore a shimmed default manager called objects if it's the only one. 
- if managers == [('objects', default_manager_shim)]: + if managers == [("objects", default_manager_shim)]: managers = [] # Construct the new ModelState @@ -848,7 +906,7 @@ class ModelState: def render(self, apps): """Create a Model object from our current state into the given apps.""" # First, make a Meta object - meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options} + meta_contents = {"app_label": self.app_label, "apps": apps, **self.options} meta = type("Meta", (), meta_contents) # Then, work out our bases try: @@ -857,11 +915,13 @@ class ModelState: for base in self.bases ) except LookupError: - raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,)) + raise InvalidBasesError( + "Cannot resolve one or more bases from %r" % (self.bases,) + ) # Clone fields for the body, add other bits. body = {name: field.clone() for name, field in self.fields.items()} - body['Meta'] = meta - body['__module__'] = "__fake__" + body["Meta"] = meta + body["__module__"] = "__fake__" # Restore managers body.update(self.construct_managers()) @@ -869,33 +929,33 @@ class ModelState: return type(self.name, bases, body) def get_index_by_name(self, name): - for index in self.options['indexes']: + for index in self.options["indexes"]: if index.name == name: return index raise ValueError("No index named %s on model %s" % (name, self.name)) def get_constraint_by_name(self, name): - for constraint in self.options['constraints']: + for constraint in self.options["constraints"]: if constraint.name == name: return constraint - raise ValueError('No constraint named %s on model %s' % (name, self.name)) + raise ValueError("No constraint named %s on model %s" % (name, self.name)) def __repr__(self): return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name) def __eq__(self, other): return ( - (self.app_label == other.app_label) and - (self.name == other.name) and - (len(self.fields) == len(other.fields)) and - all( + (self.app_label 
== other.app_label) + and (self.name == other.name) + and (len(self.fields) == len(other.fields)) + and all( k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:] for (k1, f1), (k2, f2) in zip( sorted(self.fields.items()), sorted(other.fields.items()), ) - ) and - (self.options == other.options) and - (self.bases == other.bases) and - (self.managers == other.managers) + ) + and (self.options == other.options) + and (self.bases == other.bases) + and (self.managers == other.managers) ) diff --git a/django/db/migrations/utils.py b/django/db/migrations/utils.py index 42a4d90340..2b45a6033b 100644 --- a/django/db/migrations/utils.py +++ b/django/db/migrations/utils.py @@ -4,9 +4,9 @@ from collections import namedtuple from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT -FieldReference = namedtuple('FieldReference', 'to through') +FieldReference = namedtuple("FieldReference", "to through") -COMPILED_REGEX_TYPE = type(re.compile('')) +COMPILED_REGEX_TYPE = type(re.compile("")) class RegexObject: @@ -33,16 +33,16 @@ def resolve_relation(model, app_label=None, model_name=None): if model == RECURSIVE_RELATIONSHIP_CONSTANT: if app_label is None or model_name is None: raise TypeError( - 'app_label and model_name must be provided to resolve ' - 'recursive relationships.' + "app_label and model_name must be provided to resolve " + "recursive relationships." ) return app_label, model_name - if '.' in model: - app_label, model_name = model.split('.', 1) + if "." in model: + app_label, model_name = model.split(".", 1) return app_label, model_name.lower() if app_label is None: raise TypeError( - 'app_label must be provided to resolve unscoped model relationships.' + "app_label must be provided to resolve unscoped model relationships." 
) return app_label, model.lower() return model._meta.app_label, model._meta.model_name @@ -70,24 +70,32 @@ def field_references( references_to = None references_through = None if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple: - to_fields = getattr(field, 'to_fields', None) + to_fields = getattr(field, "to_fields", None) if ( - reference_field_name is None or + reference_field_name is None + or # Unspecified to_field(s). - to_fields is None or + to_fields is None + or # Reference to primary key. - (None in to_fields and (reference_field is None or reference_field.primary_key)) or + ( + None in to_fields + and (reference_field is None or reference_field.primary_key) + ) + or # Reference to field. reference_field_name in to_fields ): references_to = (remote_field, to_fields) - through = getattr(remote_field, 'through', None) + through = getattr(remote_field, "through", None) if through and resolve_relation(through, *model_tuple) == reference_model_tuple: through_fields = remote_field.through_fields if ( - reference_field_name is None or + reference_field_name is None + or # Unspecified through_fields. - through_fields is None or + through_fields is None + or # Reference to field. 
reference_field_name in through_fields ): @@ -107,7 +115,9 @@ def get_references(state, model_tuple, field_tuple=()): """ for state_model_tuple, model_state in state.models.items(): for name, field in model_state.fields.items(): - reference = field_references(state_model_tuple, field, model_tuple, *field_tuple) + reference = field_references( + state_model_tuple, field, model_tuple, *field_tuple + ) if reference: yield model_state, name, field, reference diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 4918261fb0..a59f0c8dcb 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -1,10 +1,10 @@ - import os import re from importlib import import_module from django import get_version from django.apps import apps + # SettingsReference imported for backwards compatibility in Django 2.2. from django.conf import SettingsReference # NOQA from django.db import migrations @@ -22,30 +22,30 @@ class OperationWriter: self.indentation = indentation def serialize(self): - def _write(_arg_name, _arg_value): - if (_arg_name in self.operation.serialization_expand_args and - isinstance(_arg_value, (list, tuple, dict))): + if _arg_name in self.operation.serialization_expand_args and isinstance( + _arg_value, (list, tuple, dict) + ): if isinstance(_arg_value, dict): - self.feed('%s={' % _arg_name) + self.feed("%s={" % _arg_name) self.indent() for key, value in _arg_value.items(): key_string, key_imports = MigrationWriter.serialize(key) arg_string, arg_imports = MigrationWriter.serialize(value) args = arg_string.splitlines() if len(args) > 1: - self.feed('%s: %s' % (key_string, args[0])) + self.feed("%s: %s" % (key_string, args[0])) for arg in args[1:-1]: self.feed(arg) - self.feed('%s,' % args[-1]) + self.feed("%s," % args[-1]) else: - self.feed('%s: %s,' % (key_string, arg_string)) + self.feed("%s: %s," % (key_string, arg_string)) imports.update(key_imports) imports.update(arg_imports) self.unindent() - self.feed('},') + 
self.feed("},") else: - self.feed('%s=[' % _arg_name) + self.feed("%s=[" % _arg_name) self.indent() for item in _arg_value: arg_string, arg_imports = MigrationWriter.serialize(item) @@ -53,22 +53,22 @@ class OperationWriter: if len(args) > 1: for arg in args[:-1]: self.feed(arg) - self.feed('%s,' % args[-1]) + self.feed("%s," % args[-1]) else: - self.feed('%s,' % arg_string) + self.feed("%s," % arg_string) imports.update(arg_imports) self.unindent() - self.feed('],') + self.feed("],") else: arg_string, arg_imports = MigrationWriter.serialize(_arg_value) args = arg_string.splitlines() if len(args) > 1: - self.feed('%s=%s' % (_arg_name, args[0])) + self.feed("%s=%s" % (_arg_name, args[0])) for arg in args[1:-1]: self.feed(arg) - self.feed('%s,' % args[-1]) + self.feed("%s," % args[-1]) else: - self.feed('%s=%s,' % (_arg_name, arg_string)) + self.feed("%s=%s," % (_arg_name, arg_string)) imports.update(arg_imports) imports = set() @@ -79,10 +79,10 @@ class OperationWriter: # We can just use the fact we already have that imported, # otherwise, we need to add an import for the operation class. 
if getattr(migrations, name, None) == self.operation.__class__: - self.feed('migrations.%s(' % name) + self.feed("migrations.%s(" % name) else: - imports.add('import %s' % (self.operation.__class__.__module__)) - self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) + imports.add("import %s" % (self.operation.__class__.__module__)) + self.feed("%s.%s(" % (self.operation.__class__.__module__, name)) self.indent() @@ -99,7 +99,7 @@ class OperationWriter: _write(arg_name, arg_value) self.unindent() - self.feed('),') + self.feed("),") return self.render(), imports def indent(self): @@ -109,10 +109,10 @@ class OperationWriter: self.indentation -= 1 def feed(self, line): - self.buff.append(' ' * (self.indentation * 4) + line) + self.buff.append(" " * (self.indentation * 4) + line) def render(self): - return '\n'.join(self.buff) + return "\n".join(self.buff) class MigrationWriter: @@ -147,7 +147,10 @@ class MigrationWriter: dependencies = [] for dependency in self.migration.dependencies: if dependency[0] == "__setting__": - dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1]) + dependencies.append( + " migrations.swappable_dependency(settings.%s)," + % dependency[1] + ) imports.add("from django.conf import settings") else: dependencies.append(" %s," % self.serialize(dependency)[0]) @@ -183,24 +186,28 @@ class MigrationWriter: ) % "\n# ".join(sorted(migration_imports)) # If there's a replaces, make a string for it if self.migration.replaces: - items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0] + items["replaces_str"] = ( + "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0] + ) # Hinting that goes into comment if self.include_header: - items['migration_header'] = MIGRATION_HEADER_TEMPLATE % { - 'version': get_version(), - 'timestamp': now().strftime("%Y-%m-%d %H:%M"), + items["migration_header"] = MIGRATION_HEADER_TEMPLATE % { + "version": get_version(), + "timestamp": 
now().strftime("%Y-%m-%d %H:%M"), } else: - items['migration_header'] = "" + items["migration_header"] = "" if self.migration.initial: - items['initial_str'] = "\n initial = True\n" + items["initial_str"] = "\n initial = True\n" return MIGRATION_TEMPLATE % items @property def basedir(self): - migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label) + migrations_package_name, _ = MigrationLoader.migrations_module( + self.migration.app_label + ) if migrations_package_name is None: raise ValueError( @@ -222,7 +229,11 @@ class MigrationWriter: # Alright, see if it's a direct submodule of the app app_config = apps.get_app_config(self.migration.app_label) - maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".") + ( + maybe_app_name, + _, + migrations_package_basename, + ) = migrations_package_name.rpartition(".") if app_config.name == maybe_app_name: return os.path.join(app_config.path, migrations_package_basename) @@ -246,8 +257,8 @@ class MigrationWriter: raise ValueError( "Could not locate an appropriate location to create " "migrations package %s. Make sure the toplevel " - "package exists and can be imported." % - migrations_package_name) + "package exists and can be imported." 
% migrations_package_name + ) final_dir = os.path.join(base_dir, *missing_dirs) os.makedirs(final_dir, exist_ok=True) diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index a583af2aff..ffca81de91 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -5,14 +5,34 @@ from django.db.models.aggregates import __all__ as aggregates_all from django.db.models.constraints import * # NOQA from django.db.models.constraints import __all__ as constraints_all from django.db.models.deletion import ( - CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL, - ProtectedError, RestrictedError, + CASCADE, + DO_NOTHING, + PROTECT, + RESTRICT, + SET, + SET_DEFAULT, + SET_NULL, + ProtectedError, + RestrictedError, ) from django.db.models.enums import * # NOQA from django.db.models.enums import __all__ as enums_all from django.db.models.expressions import ( - Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, - OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, + Case, + Exists, + Expression, + ExpressionList, + ExpressionWrapper, + F, + Func, + OrderBy, + OuterRef, + RowRange, + Subquery, + Value, + ValueRange, + When, + Window, WindowFrame, ) from django.db.models.fields import * # NOQA @@ -30,23 +50,66 @@ from django.db.models.query_utils import FilteredRelation, Q # Imports that would create circular imports if sorted from django.db.models.base import DEFERRED, Model # isort:skip from django.db.models.fields.related import ( # isort:skip - ForeignKey, ForeignObject, OneToOneField, ManyToManyField, - ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel, + ForeignKey, + ForeignObject, + OneToOneField, + ManyToManyField, + ForeignObjectRel, + ManyToOneRel, + ManyToManyRel, + OneToOneRel, ) __all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all __all__ += [ - 'ObjectDoesNotExist', 'signals', - 'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 
'SET_DEFAULT', - 'SET_NULL', 'ProtectedError', 'RestrictedError', - 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F', - 'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value', - 'ValueRange', 'When', - 'Window', 'WindowFrame', - 'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform', - 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', - 'DEFERRED', 'Model', 'FilteredRelation', - 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', - 'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', + "ObjectDoesNotExist", + "signals", + "CASCADE", + "DO_NOTHING", + "PROTECT", + "RESTRICT", + "SET", + "SET_DEFAULT", + "SET_NULL", + "ProtectedError", + "RestrictedError", + "Case", + "Exists", + "Expression", + "ExpressionList", + "ExpressionWrapper", + "F", + "Func", + "OrderBy", + "OuterRef", + "RowRange", + "Subquery", + "Value", + "ValueRange", + "When", + "Window", + "WindowFrame", + "FileField", + "ImageField", + "JSONField", + "OrderWrt", + "Lookup", + "Transform", + "Manager", + "Prefetch", + "Q", + "QuerySet", + "prefetch_related_objects", + "DEFERRED", + "Model", + "FilteredRelation", + "ForeignKey", + "ForeignObject", + "OneToOneField", + "ManyToManyField", + "ForeignObjectRel", + "ManyToOneRel", + "ManyToManyRel", + "OneToOneRel", ] diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index bc31b48d8d..2ffed7cd2c 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -6,28 +6,38 @@ from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import IntegerField from django.db.models.functions.comparison import Coalesce from django.db.models.functions.mixins import ( - FixDurationInputMixin, NumericOutputFieldMixin, + FixDurationInputMixin, + NumericOutputFieldMixin, ) __all__ = [ - 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', + "Aggregate", + "Avg", + "Count", + 
"Max", + "Min", + "StdDev", + "Sum", + "Variance", ] class Aggregate(Func): - template = '%(function)s(%(distinct)s%(expressions)s)' + template = "%(function)s(%(distinct)s%(expressions)s)" contains_aggregate = True name = None - filter_template = '%s FILTER (WHERE %%(filter)s)' + filter_template = "%s FILTER (WHERE %%(filter)s)" window_compatible = True allow_distinct = False empty_result_set_value = None - def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra): + def __init__( + self, *expressions, distinct=False, filter=None, default=None, **extra + ): if distinct and not self.allow_distinct: raise TypeError("%s does not allow distinct." % self.__class__.__name__) if default is not None and self.empty_result_set_value is not None: - raise TypeError(f'{self.__class__.__name__} does not allow default.') + raise TypeError(f"{self.__class__.__name__} does not allow default.") self.distinct = distinct self.filter = filter self.default = default @@ -47,10 +57,14 @@ class Aggregate(Func): self.filter = self.filter and exprs.pop() return super().set_source_expressions(exprs) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super().resolve_expression(query, allow_joins, reuse, summarize) - c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize) + c.filter = c.filter and c.filter.resolve_expression( + query, allow_joins, reuse, summarize + ) if not summarize: # Call Aggregate.get_source_expressions() to avoid # returning self.filter and including that in this loop. 
@@ -58,11 +72,18 @@ class Aggregate(Func): for index, expr in enumerate(expressions): if expr.contains_aggregate: before_resolved = self.get_source_expressions()[index] - name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved) - raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name)) + name = ( + before_resolved.name + if hasattr(before_resolved, "name") + else repr(before_resolved) + ) + raise FieldError( + "Cannot compute %s('%s'): '%s' is an aggregate" + % (c.name, name, name) + ) if (default := c.default) is None: return c - if hasattr(default, 'resolve_expression'): + if hasattr(default, "resolve_expression"): default = default.resolve_expression(query, allow_joins, reuse, summarize) c.default = None # Reset the default argument before wrapping. coalesce = Coalesce(c, default, output_field=c._output_field_or_none) @@ -72,22 +93,27 @@ class Aggregate(Func): @property def default_alias(self): expressions = self.get_source_expressions() - if len(expressions) == 1 and hasattr(expressions[0], 'name'): - return '%s__%s' % (expressions[0].name, self.name.lower()) + if len(expressions) == 1 and hasattr(expressions[0], "name"): + return "%s__%s" % (expressions[0].name, self.name.lower()) raise TypeError("Complex expressions require an alias") def get_group_by_cols(self, alias=None): return [] def as_sql(self, compiler, connection, **extra_context): - extra_context['distinct'] = 'DISTINCT ' if self.distinct else '' + extra_context["distinct"] = "DISTINCT " if self.distinct else "" if self.filter: if connection.features.supports_aggregate_filter_clause: filter_sql, filter_params = self.filter.as_sql(compiler, connection) - template = self.filter_template % extra_context.get('template', self.template) + template = self.filter_template % extra_context.get( + "template", self.template + ) sql, params = super().as_sql( - compiler, connection, template=template, filter=filter_sql, - **extra_context + 
compiler, + connection, + template=template, + filter=filter_sql, + **extra_context, ) return sql, (*params, *filter_params) else: @@ -96,72 +122,74 @@ class Aggregate(Func): source_expressions = copy.get_source_expressions() condition = When(self.filter, then=source_expressions[0]) copy.set_source_expressions([Case(condition)] + source_expressions[1:]) - return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) + return super(Aggregate, copy).as_sql( + compiler, connection, **extra_context + ) return super().as_sql(compiler, connection, **extra_context) def _get_repr_options(self): options = super()._get_repr_options() if self.distinct: - options['distinct'] = self.distinct + options["distinct"] = self.distinct if self.filter: - options['filter'] = self.filter + options["filter"] = self.filter return options class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate): - function = 'AVG' - name = 'Avg' + function = "AVG" + name = "Avg" allow_distinct = True class Count(Aggregate): - function = 'COUNT' - name = 'Count' + function = "COUNT" + name = "Count" output_field = IntegerField() allow_distinct = True empty_result_set_value = 0 def __init__(self, expression, filter=None, **extra): - if expression == '*': + if expression == "*": expression = Star() if isinstance(expression, Star) and filter is not None: - raise ValueError('Star cannot be used with filter. Please specify a field.') + raise ValueError("Star cannot be used with filter. 
Please specify a field.") super().__init__(expression, filter=filter, **extra) class Max(Aggregate): - function = 'MAX' - name = 'Max' + function = "MAX" + name = "Max" class Min(Aggregate): - function = 'MIN' - name = 'Min' + function = "MIN" + name = "Min" class StdDev(NumericOutputFieldMixin, Aggregate): - name = 'StdDev' + name = "StdDev" def __init__(self, expression, sample=False, **extra): - self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' + self.function = "STDDEV_SAMP" if sample else "STDDEV_POP" super().__init__(expression, **extra) def _get_repr_options(self): - return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'} + return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"} class Sum(FixDurationInputMixin, Aggregate): - function = 'SUM' - name = 'Sum' + function = "SUM" + name = "Sum" allow_distinct = True class Variance(NumericOutputFieldMixin, Aggregate): - name = 'Variance' + name = "Variance" def __init__(self, expression, sample=False, **extra): - self.function = 'VAR_SAMP' if sample else 'VAR_POP' + self.function = "VAR_SAMP" if sample else "VAR_POP" super().__init__(expression, **extra) def _get_repr_options(self): - return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'} + return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"} diff --git a/django/db/models/base.py b/django/db/models/base.py index 37f6a3dd58..8127a9895a 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -9,28 +9,42 @@ from django.apps import apps from django.conf import settings from django.core import checks from django.core.exceptions import ( - NON_FIELD_ERRORS, FieldDoesNotExist, FieldError, MultipleObjectsReturned, - ObjectDoesNotExist, ValidationError, + NON_FIELD_ERRORS, + FieldDoesNotExist, + FieldError, + MultipleObjectsReturned, + ObjectDoesNotExist, + ValidationError, ) from django.db import ( - DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, 
DatabaseError, connection, - connections, router, transaction, -) -from django.db.models import ( - NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, Value, + DEFAULT_DB_ALIAS, + DJANGO_VERSION_PICKLE_KEY, + DatabaseError, + connection, + connections, + router, + transaction, ) +from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max, Value from django.db.models.constants import LOOKUP_SEP from django.db.models.constraints import CheckConstraint, UniqueConstraint from django.db.models.deletion import CASCADE, Collector from django.db.models.fields.related import ( - ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation, + ForeignObjectRel, + OneToOneField, + lazy_related_operation, + resolve_relation, ) from django.db.models.functions import Coalesce from django.db.models.manager import Manager from django.db.models.options import Options from django.db.models.query import F, Q from django.db.models.signals import ( - class_prepared, post_init, post_save, pre_init, pre_save, + class_prepared, + post_init, + post_save, + pre_init, + pre_save, ) from django.db.models.utils import make_model_tuple from django.utils.encoding import force_str @@ -41,10 +55,10 @@ from django.utils.translation import gettext_lazy as _ class Deferred: def __repr__(self): - return '<Deferred field>' + return "<Deferred field>" def __str__(self): - return '<Deferred field>' + return "<Deferred field>" DEFERRED = Deferred() @@ -58,19 +72,24 @@ def subclass_exception(name, bases, module, attached_to): that the returned exception class will be added as an attribute to the 'attached_to' class. """ - return type(name, bases, { - '__module__': module, - '__qualname__': '%s.%s' % (attached_to.__qualname__, name), - }) + return type( + name, + bases, + { + "__module__": module, + "__qualname__": "%s.%s" % (attached_to.__qualname__, name), + }, + ) def _has_contribute_to_class(value): # Only call contribute_to_class() if it's bound. 
- return not inspect.isclass(value) and hasattr(value, 'contribute_to_class') + return not inspect.isclass(value) and hasattr(value, "contribute_to_class") class ModelBase(type): """Metaclass for all models.""" + def __new__(cls, name, bases, attrs, **kwargs): super_new = super().__new__ @@ -81,12 +100,12 @@ class ModelBase(type): return super_new(cls, name, bases, attrs) # Create the class. - module = attrs.pop('__module__') - new_attrs = {'__module__': module} - classcell = attrs.pop('__classcell__', None) + module = attrs.pop("__module__") + new_attrs = {"__module__": module} + classcell = attrs.pop("__classcell__", None) if classcell is not None: - new_attrs['__classcell__'] = classcell - attr_meta = attrs.pop('Meta', None) + new_attrs["__classcell__"] = classcell + attr_meta = attrs.pop("Meta", None) # Pass all attrs without a (Django-specific) contribute_to_class() # method to type.__new__() so that they're properly initialized # (i.e. __set_name__()). @@ -98,16 +117,16 @@ class ModelBase(type): new_attrs[obj_name] = obj new_class = super_new(cls, name, bases, new_attrs, **kwargs) - abstract = getattr(attr_meta, 'abstract', False) - meta = attr_meta or getattr(new_class, 'Meta', None) - base_meta = getattr(new_class, '_meta', None) + abstract = getattr(attr_meta, "abstract", False) + meta = attr_meta or getattr(new_class, "Meta", None) + base_meta = getattr(new_class, "_meta", None) app_label = None # Look for an application configuration to attach the model to. 
app_config = apps.get_containing_app_config(module) - if getattr(meta, 'app_label', None) is None: + if getattr(meta, "app_label", None) is None: if app_config is None: if not abstract: raise RuntimeError( @@ -119,33 +138,43 @@ class ModelBase(type): else: app_label = app_config.label - new_class.add_to_class('_meta', Options(meta, app_label)) + new_class.add_to_class("_meta", Options(meta, app_label)) if not abstract: new_class.add_to_class( - 'DoesNotExist', + "DoesNotExist", subclass_exception( - 'DoesNotExist', + "DoesNotExist", tuple( - x.DoesNotExist for x in parents if hasattr(x, '_meta') and not x._meta.abstract - ) or (ObjectDoesNotExist,), + x.DoesNotExist + for x in parents + if hasattr(x, "_meta") and not x._meta.abstract + ) + or (ObjectDoesNotExist,), module, - attached_to=new_class)) + attached_to=new_class, + ), + ) new_class.add_to_class( - 'MultipleObjectsReturned', + "MultipleObjectsReturned", subclass_exception( - 'MultipleObjectsReturned', + "MultipleObjectsReturned", tuple( - x.MultipleObjectsReturned for x in parents if hasattr(x, '_meta') and not x._meta.abstract - ) or (MultipleObjectsReturned,), + x.MultipleObjectsReturned + for x in parents + if hasattr(x, "_meta") and not x._meta.abstract + ) + or (MultipleObjectsReturned,), module, - attached_to=new_class)) + attached_to=new_class, + ), + ) if base_meta and not base_meta.abstract: # Non-abstract child classes inherit some attributes from their # non-abstract parent (unless an ABC comes before it in the # method resolution order). - if not hasattr(meta, 'ordering'): + if not hasattr(meta, "ordering"): new_class._meta.ordering = base_meta.ordering - if not hasattr(meta, 'get_latest_by'): + if not hasattr(meta, "get_latest_by"): new_class._meta.get_latest_by = base_meta.get_latest_by is_proxy = new_class._meta.proxy @@ -153,7 +182,9 @@ class ModelBase(type): # If the model is a proxy, ensure that the base class # hasn't been swapped out. 
if is_proxy and base_meta and base_meta.swapped: - raise TypeError("%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped)) + raise TypeError( + "%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped) + ) # Add remaining attributes (those with a contribute_to_class() method) # to the class. @@ -164,14 +195,14 @@ class ModelBase(type): new_fields = chain( new_class._meta.local_fields, new_class._meta.local_many_to_many, - new_class._meta.private_fields + new_class._meta.private_fields, ) field_names = {f.name for f in new_fields} # Basic setup for proxy models. if is_proxy: base = None - for parent in [kls for kls in parents if hasattr(kls, '_meta')]: + for parent in [kls for kls in parents if hasattr(kls, "_meta")]: if parent._meta.abstract: if parent._meta.fields: raise TypeError( @@ -183,9 +214,14 @@ class ModelBase(type): if base is None: base = parent elif parent._meta.concrete_model is not base._meta.concrete_model: - raise TypeError("Proxy model '%s' has more than one non-abstract model base class." % name) + raise TypeError( + "Proxy model '%s' has more than one non-abstract model base class." + % name + ) if base is None: - raise TypeError("Proxy model '%s' has no non-abstract model base class." % name) + raise TypeError( + "Proxy model '%s' has no non-abstract model base class." % name + ) new_class._meta.setup_proxy(base) new_class._meta.concrete_model = base._meta.concrete_model else: @@ -195,7 +231,7 @@ class ModelBase(type): parent_links = {} for base in reversed([new_class] + parents): # Conceptually equivalent to `if base is Model`. - if not hasattr(base, '_meta'): + if not hasattr(base, "_meta"): continue # Skip concrete parent classes. if base != new_class and not base._meta.abstract: @@ -210,7 +246,7 @@ class ModelBase(type): inherited_attributes = set() # Do the appropriate setup for any model parents. 
for base in new_class.mro(): - if base not in parents or not hasattr(base, '_meta'): + if base not in parents or not hasattr(base, "_meta"): # Things without _meta aren't functional models, so they're # uninteresting parents. inherited_attributes.update(base.__dict__) @@ -223,8 +259,9 @@ class ModelBase(type): for field in parent_fields: if field.name in field_names: raise FieldError( - 'Local field %r in class %r clashes with field of ' - 'the same name from base class %r.' % ( + "Local field %r in class %r clashes with field of " + "the same name from base class %r." + % ( field.name, name, base.__name__, @@ -239,7 +276,7 @@ class ModelBase(type): if base_key in parent_links: field = parent_links[base_key] elif not is_proxy: - attr_name = '%s_ptr' % base._meta.model_name + attr_name = "%s_ptr" % base._meta.model_name field = OneToOneField( base, on_delete=CASCADE, @@ -252,7 +289,8 @@ class ModelBase(type): raise FieldError( "Auto-generated field '%s' in class %r for " "parent_link to base class %r clashes with " - "declared field of the same name." % ( + "declared field of the same name." + % ( attr_name, name, base.__name__, @@ -271,9 +309,11 @@ class ModelBase(type): # Add fields from abstract base class if it wasn't overridden. for field in parent_fields: - if (field.name not in field_names and - field.name not in new_class.__dict__ and - field.name not in inherited_attributes): + if ( + field.name not in field_names + and field.name not in new_class.__dict__ + and field.name not in inherited_attributes + ): new_field = copy.deepcopy(field) new_class.add_to_class(field.name, new_field) # Replace parent links defined on this base by the new @@ -292,8 +332,9 @@ class ModelBase(type): if field.name in field_names: if not base._meta.abstract: raise FieldError( - 'Local field %r in class %r clashes with field of ' - 'the same name from base class %r.' % ( + "Local field %r in class %r clashes with field of " + "the same name from base class %r." 
+ % ( field.name, name, base.__name__, @@ -307,7 +348,9 @@ class ModelBase(type): # Copy indexes so that index names are unique when models extend an # abstract model. - new_class._meta.indexes = [copy.deepcopy(idx) for idx in new_class._meta.indexes] + new_class._meta.indexes = [ + copy.deepcopy(idx) for idx in new_class._meta.indexes + ] if abstract: # Abstract base models can't be instantiated and don't appear in @@ -333,8 +376,12 @@ class ModelBase(type): opts._prepare(cls) if opts.order_with_respect_to: - cls.get_next_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=True) - cls.get_previous_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=False) + cls.get_next_in_order = partialmethod( + cls._get_next_or_previous_in_order, is_next=True + ) + cls.get_previous_in_order = partialmethod( + cls._get_next_or_previous_in_order, is_next=False + ) # Defer creating accessors on the foreign class until it has been # created and registered. If remote_field is None, we're ordering @@ -348,21 +395,26 @@ class ModelBase(type): # Give the class a docstring -- its definition. if cls.__doc__ is None: - cls.__doc__ = "%s(%s)" % (cls.__name__, ", ".join(f.name for f in opts.fields)) + cls.__doc__ = "%s(%s)" % ( + cls.__name__, + ", ".join(f.name for f in opts.fields), + ) - get_absolute_url_override = settings.ABSOLUTE_URL_OVERRIDES.get(opts.label_lower) + get_absolute_url_override = settings.ABSOLUTE_URL_OVERRIDES.get( + opts.label_lower + ) if get_absolute_url_override: - setattr(cls, 'get_absolute_url', get_absolute_url_override) + setattr(cls, "get_absolute_url", get_absolute_url_override) if not opts.managers: - if any(f.name == 'objects' for f in opts.fields): + if any(f.name == "objects" for f in opts.fields): raise ValueError( "Model %s must specify a custom Manager, because it has a " "field named 'objects'." 
% cls.__name__ ) manager = Manager() manager.auto_created = True - cls.add_to_class('objects', manager) + cls.add_to_class("objects", manager) # Set the name of _meta.indexes. This can't be done in # Options.contribute_to_class() because fields haven't been added to @@ -399,6 +451,7 @@ class ModelStateCacheDescriptor: class ModelState: """Store model instance state.""" + db = None # If true, uniqueness validation checks will consider this a new, unsaved # object. Necessary for correct validation of new instances of objects with @@ -410,19 +463,18 @@ class ModelState: def __getstate__(self): state = self.__dict__.copy() - if 'fields_cache' in state: - state['fields_cache'] = self.fields_cache.copy() + if "fields_cache" in state: + state["fields_cache"] = self.fields_cache.copy() # Manager instances stored in related_managers_cache won't necessarily # be deserializable if they were dynamically created via an inner # scope, e.g. create_forward_many_to_many_manager() and # create_generic_related_manager(). 
- if 'related_managers_cache' in state: - state['related_managers_cache'] = {} + if "related_managers_cache" in state: + state["related_managers_cache"] = {} return state class Model(metaclass=ModelBase): - def __init__(self, *args, **kwargs): # Alias some things as locals to avoid repeat global lookups cls = self.__class__ @@ -430,7 +482,7 @@ class Model(metaclass=ModelBase): _setattr = setattr _DEFERRED = DEFERRED if opts.abstract: - raise TypeError('Abstract models cannot be instantiated.') + raise TypeError("Abstract models cannot be instantiated.") pre_init.send(sender=cls, args=args, kwargs=kwargs) @@ -529,10 +581,10 @@ class Model(metaclass=ModelBase): if value is not _DEFERRED: _setattr(self, prop, value) if unexpected: - unexpected_names = ', '.join(repr(n) for n in unexpected) + unexpected_names = ", ".join(repr(n) for n in unexpected) raise TypeError( - f'{cls.__name__}() got unexpected keyword arguments: ' - f'{unexpected_names}' + f"{cls.__name__}() got unexpected keyword arguments: " + f"{unexpected_names}" ) super().__init__() post_init.send(sender=cls, instance=self) @@ -551,10 +603,10 @@ class Model(metaclass=ModelBase): return new def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self) + return "<%s: %s>" % (self.__class__.__name__, self) def __str__(self): - return '%s object (%s)' % (self.__class__.__name__, self.pk) + return "%s object (%s)" % (self.__class__.__name__, self.pk) def __eq__(self, other): if not isinstance(other, Model): @@ -580,7 +632,7 @@ class Model(metaclass=ModelBase): def __getstate__(self): """Hook to allow choosing the attributes to pickle.""" state = self.__dict__.copy() - state['_state'] = copy.copy(state['_state']) + state["_state"] = copy.copy(state["_state"]) # memoryview cannot be pickled, so cast it to bytes and store # separately. 
_memoryview_attrs = [] @@ -588,7 +640,7 @@ class Model(metaclass=ModelBase): if isinstance(value, memoryview): _memoryview_attrs.append((attr, bytes(value))) if _memoryview_attrs: - state['_memoryview_attrs'] = _memoryview_attrs + state["_memoryview_attrs"] = _memoryview_attrs for attr, value in _memoryview_attrs: state.pop(attr) return state @@ -610,8 +662,8 @@ class Model(metaclass=ModelBase): RuntimeWarning, stacklevel=2, ) - if '_memoryview_attrs' in state: - for attr, value in state.pop('_memoryview_attrs'): + if "_memoryview_attrs" in state: + for attr, value in state.pop("_memoryview_attrs"): state[attr] = memoryview(value) self.__dict__.update(state) @@ -632,7 +684,8 @@ class Model(metaclass=ModelBase): Return a set containing names of deferred fields for this instance. """ return { - f.attname for f in self._meta.concrete_fields + f.attname + for f in self._meta.concrete_fields if f.attname not in self.__dict__ } @@ -654,7 +707,7 @@ class Model(metaclass=ModelBase): if fields is None: self._prefetched_objects_cache = {} else: - prefetched_objects_cache = getattr(self, '_prefetched_objects_cache', ()) + prefetched_objects_cache = getattr(self, "_prefetched_objects_cache", ()) for field in fields: if field in prefetched_objects_cache: del prefetched_objects_cache[field] @@ -664,10 +717,13 @@ class Model(metaclass=ModelBase): if any(LOOKUP_SEP in f for f in fields): raise ValueError( 'Found "%s" in fields argument. Relations and transforms ' - 'are not allowed in fields.' % LOOKUP_SEP) + "are not allowed in fields." % LOOKUP_SEP + ) - hints = {'instance': self} - db_instance_qs = self.__class__._base_manager.db_manager(using, hints=hints).filter(pk=self.pk) + hints = {"instance": self} + db_instance_qs = self.__class__._base_manager.db_manager( + using, hints=hints + ).filter(pk=self.pk) # Use provided fields, if not set then reload all non-deferred fields. 
deferred_fields = self.get_deferred_fields() @@ -675,8 +731,11 @@ class Model(metaclass=ModelBase): fields = list(fields) db_instance_qs = db_instance_qs.only(*fields) elif deferred_fields: - fields = [f.attname for f in self._meta.concrete_fields - if f.attname not in deferred_fields] + fields = [ + f.attname + for f in self._meta.concrete_fields + if f.attname not in deferred_fields + ] db_instance_qs = db_instance_qs.only(*fields) db_instance = db_instance_qs.get() @@ -714,8 +773,9 @@ class Model(metaclass=ModelBase): return getattr(self, field_name) return getattr(self, field.attname) - def save(self, force_insert=False, force_update=False, using=None, - update_fields=None): + def save( + self, force_insert=False, force_update=False, using=None, update_fields=None + ): """ Save the current instance. Override this in a subclass if you want to control the saving process. @@ -724,7 +784,7 @@ class Model(metaclass=ModelBase): that the "save" must be an SQL insert or update (or equivalent for non-SQL backends), respectively. Normally, they should not be set. 
""" - self._prepare_related_fields_for_save(operation_name='save') + self._prepare_related_fields_for_save(operation_name="save") using = using or router.db_for_write(self.__class__, instance=self) if force_insert and (force_update or update_fields): @@ -752,9 +812,9 @@ class Model(metaclass=ModelBase): if non_model_fields: raise ValueError( - 'The following fields do not exist in this model, are m2m ' - 'fields, or are non-concrete fields: %s' - % ', '.join(non_model_fields) + "The following fields do not exist in this model, are m2m " + "fields, or are non-concrete fields: %s" + % ", ".join(non_model_fields) ) # If saving to the same database, and this model is deferred, then @@ -762,18 +822,29 @@ class Model(metaclass=ModelBase): elif not force_insert and deferred_fields and using == self._state.db: field_names = set() for field in self._meta.concrete_fields: - if not field.primary_key and not hasattr(field, 'through'): + if not field.primary_key and not hasattr(field, "through"): field_names.add(field.attname) loaded_fields = field_names.difference(deferred_fields) if loaded_fields: update_fields = frozenset(loaded_fields) - self.save_base(using=using, force_insert=force_insert, - force_update=force_update, update_fields=update_fields) + self.save_base( + using=using, + force_insert=force_insert, + force_update=force_update, + update_fields=update_fields, + ) + save.alters_data = True - def save_base(self, raw=False, force_insert=False, - force_update=False, using=None, update_fields=None): + def save_base( + self, + raw=False, + force_insert=False, + force_update=False, + using=None, + update_fields=None, + ): """ Handle the parts of saving which should be done only once per save, yet need to be done in raw saves, too. 
This includes some sanity @@ -793,7 +864,10 @@ class Model(metaclass=ModelBase): meta = cls._meta if not meta.auto_created: pre_save.send( - sender=origin, instance=self, raw=raw, using=using, + sender=origin, + instance=self, + raw=raw, + using=using, update_fields=update_fields, ) # A transaction isn't needed if one query is issued. @@ -806,8 +880,12 @@ class Model(metaclass=ModelBase): if not raw: parent_inserted = self._save_parents(cls, using, update_fields) updated = self._save_table( - raw, cls, force_insert or parent_inserted, - force_update, using, update_fields, + raw, + cls, + force_insert or parent_inserted, + force_update, + using, + update_fields, ) # Store the database on which the object was saved self._state.db = using @@ -817,8 +895,12 @@ class Model(metaclass=ModelBase): # Signal that the save is complete if not meta.auto_created: post_save.send( - sender=origin, instance=self, created=(not updated), - update_fields=update_fields, raw=raw, using=using, + sender=origin, + instance=self, + created=(not updated), + update_fields=update_fields, + raw=raw, + using=using, ) save_base.alters_data = True @@ -829,12 +911,19 @@ class Model(metaclass=ModelBase): inserted = False for parent, field in meta.parents.items(): # Make sure the link fields are synced between parent and self. 
- if (field and getattr(self, parent._meta.pk.attname) is None and - getattr(self, field.attname) is not None): + if ( + field + and getattr(self, parent._meta.pk.attname) is None + and getattr(self, field.attname) is not None + ): setattr(self, parent._meta.pk.attname, getattr(self, field.attname)) - parent_inserted = self._save_parents(cls=parent, using=using, update_fields=update_fields) + parent_inserted = self._save_parents( + cls=parent, using=using, update_fields=update_fields + ) updated = self._save_table( - cls=parent, using=using, update_fields=update_fields, + cls=parent, + using=using, + update_fields=update_fields, force_insert=parent_inserted, ) if not updated: @@ -851,8 +940,15 @@ class Model(metaclass=ModelBase): field.delete_cached_value(self) return inserted - def _save_table(self, raw=False, cls=None, force_insert=False, - force_update=False, using=None, update_fields=None): + def _save_table( + self, + raw=False, + cls=None, + force_insert=False, + force_update=False, + using=None, + update_fields=None, + ): """ Do the heavy-lifting involved in saving. Update or insert the data for a single table. @@ -861,8 +957,11 @@ class Model(metaclass=ModelBase): non_pks = [f for f in meta.local_concrete_fields if not f.primary_key] if update_fields: - non_pks = [f for f in non_pks - if f.name in update_fields or f.attname in update_fields] + non_pks = [ + f + for f in non_pks + if f.name in update_fields or f.attname in update_fields + ] pk_val = self._get_pk_val(meta) if pk_val is None: @@ -874,21 +973,28 @@ class Model(metaclass=ModelBase): updated = False # Skip an UPDATE when adding an instance and primary key has a default. if ( - not raw and - not force_insert and - self._state.adding and - meta.pk.default and - meta.pk.default is not NOT_PROVIDED + not raw + and not force_insert + and self._state.adding + and meta.pk.default + and meta.pk.default is not NOT_PROVIDED ): force_insert = True # If possible, try an UPDATE. 
If that doesn't update anything, do an INSERT. if pk_set and not force_insert: base_qs = cls._base_manager.using(using) - values = [(f, None, (getattr(self, f.attname) if raw else f.pre_save(self, False))) - for f in non_pks] + values = [ + ( + f, + None, + (getattr(self, f.attname) if raw else f.pre_save(self, False)), + ) + for f in non_pks + ] forced_update = update_fields or force_update - updated = self._do_update(base_qs, using, pk_val, values, update_fields, - forced_update) + updated = self._do_update( + base_qs, using, pk_val, values, update_fields, forced_update + ) if force_update and not updated: raise DatabaseError("Forced update did not affect any rows.") if update_fields and not updated: @@ -899,18 +1005,26 @@ class Model(metaclass=ModelBase): # autopopulate the _order field field = meta.order_with_respect_to filter_args = field.get_filter_kwargs_for_object(self) - self._order = cls._base_manager.using(using).filter(**filter_args).aggregate( - _order__max=Coalesce( - ExpressionWrapper(Max('_order') + Value(1), output_field=IntegerField()), - Value(0), - ), - )['_order__max'] + self._order = ( + cls._base_manager.using(using) + .filter(**filter_args) + .aggregate( + _order__max=Coalesce( + ExpressionWrapper( + Max("_order") + Value(1), output_field=IntegerField() + ), + Value(0), + ), + )["_order__max"] + ) fields = meta.local_concrete_fields if not pk_set: fields = [f for f in fields if f is not meta.auto_field] returning_fields = meta.db_returning_fields - results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw) + results = self._do_insert( + cls._base_manager, using, fields, returning_fields, raw + ) if results: for value, field in zip(results[0], returning_fields): setattr(self, field.attname, value) @@ -931,7 +1045,8 @@ class Model(metaclass=ModelBase): return update_fields is not None or filtered.exists() if self._meta.select_on_save and not forced_update: return ( - filtered.exists() and + filtered.exists() + and # It 
may happen that the object is deleted from the DB right after # this check, causing the subsequent UPDATE to return zero matching # rows. The same result can occur in some rare cases when the @@ -949,8 +1064,11 @@ class Model(metaclass=ModelBase): return the newly created data for the model. """ return manager._insert( - [self], fields=fields, returning_fields=returning_fields, - using=using, raw=raw, + [self], + fields=fields, + returning_fields=returning_fields, + using=using, + raw=raw, ) def _prepare_related_fields_for_save(self, operation_name, fields=None): @@ -986,7 +1104,9 @@ class Model(metaclass=ModelBase): setattr(self, field.attname, obj.pk) # If the relationship's pk/to_field was changed, clear the # cached relationship. - if getattr(obj, field.target_field.attname) != getattr(self, field.attname): + if getattr(obj, field.target_field.attname) != getattr( + self, field.attname + ): field.delete_cached_value(self) def delete(self, using=None, keep_parents=False): @@ -1006,42 +1126,59 @@ class Model(metaclass=ModelBase): value = getattr(self, field.attname) choices_dict = dict(make_hashable(field.flatchoices)) # force_str() to coerce lazy strings. 
- return force_str(choices_dict.get(make_hashable(value), value), strings_only=True) + return force_str( + choices_dict.get(make_hashable(value), value), strings_only=True + ) def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs): if not self.pk: raise ValueError("get_next/get_previous cannot be used on unsaved objects.") - op = 'gt' if is_next else 'lt' - order = '' if is_next else '-' + op = "gt" if is_next else "lt" + order = "" if is_next else "-" param = getattr(self, field.attname) - q = Q((field.name, param), (f'pk__{op}', self.pk), _connector=Q.AND) - q = Q(q, (f'{field.name}__{op}', param), _connector=Q.OR) - qs = self.__class__._default_manager.using(self._state.db).filter(**kwargs).filter(q).order_by( - '%s%s' % (order, field.name), '%spk' % order + q = Q((field.name, param), (f"pk__{op}", self.pk), _connector=Q.AND) + q = Q(q, (f"{field.name}__{op}", param), _connector=Q.OR) + qs = ( + self.__class__._default_manager.using(self._state.db) + .filter(**kwargs) + .filter(q) + .order_by("%s%s" % (order, field.name), "%spk" % order) ) try: return qs[0] except IndexError: - raise self.DoesNotExist("%s matching query does not exist." % self.__class__._meta.object_name) + raise self.DoesNotExist( + "%s matching query does not exist." 
% self.__class__._meta.object_name + ) def _get_next_or_previous_in_order(self, is_next): cachename = "__%s_order_cache" % is_next if not hasattr(self, cachename): - op = 'gt' if is_next else 'lt' - order = '_order' if is_next else '-_order' + op = "gt" if is_next else "lt" + order = "_order" if is_next else "-_order" order_field = self._meta.order_with_respect_to filter_args = order_field.get_filter_kwargs_for_object(self) - obj = self.__class__._default_manager.filter(**filter_args).filter(**{ - '_order__%s' % op: self.__class__._default_manager.values('_order').filter(**{ - self._meta.pk.name: self.pk - }) - }).order_by(order)[:1].get() + obj = ( + self.__class__._default_manager.filter(**filter_args) + .filter( + **{ + "_order__%s" + % op: self.__class__._default_manager.values("_order").filter( + **{self._meta.pk.name: self.pk} + ) + } + ) + .order_by(order)[:1] + .get() + ) setattr(self, cachename, obj) return getattr(self, cachename) def prepare_database_save(self, field): if self.pk is None: - raise ValueError("Unsaved model instance %r cannot be used in an ORM query." % self) + raise ValueError( + "Unsaved model instance %r cannot be used in an ORM query." 
% self + ) return getattr(self, field.remote_field.get_related_field().attname) def clean(self): @@ -1085,7 +1222,9 @@ class Model(metaclass=ModelBase): constraints = [(self.__class__, self._meta.total_unique_constraints)] for parent_class in self._meta.get_parent_list(): if parent_class._meta.unique_together: - unique_togethers.append((parent_class, parent_class._meta.unique_together)) + unique_togethers.append( + (parent_class, parent_class._meta.unique_together) + ) if parent_class._meta.total_unique_constraints: constraints.append( (parent_class, parent_class._meta.total_unique_constraints) @@ -1120,11 +1259,11 @@ class Model(metaclass=ModelBase): if f.unique: unique_checks.append((model_class, (name,))) if f.unique_for_date and f.unique_for_date not in exclude: - date_checks.append((model_class, 'date', name, f.unique_for_date)) + date_checks.append((model_class, "date", name, f.unique_for_date)) if f.unique_for_year and f.unique_for_year not in exclude: - date_checks.append((model_class, 'year', name, f.unique_for_year)) + date_checks.append((model_class, "year", name, f.unique_for_year)) if f.unique_for_month and f.unique_for_month not in exclude: - date_checks.append((model_class, 'month', name, f.unique_for_month)) + date_checks.append((model_class, "month", name, f.unique_for_month)) return unique_checks, date_checks def _perform_unique_checks(self, unique_checks): @@ -1139,8 +1278,10 @@ class Model(metaclass=ModelBase): f = self._meta.get_field(field_name) lookup_value = getattr(self, f.attname) # TODO: Handle multiple backends with different feature flags. 
- if (lookup_value is None or - (lookup_value == '' and connection.features.interprets_empty_strings_as_nulls)): + if lookup_value is None or ( + lookup_value == "" + and connection.features.interprets_empty_strings_as_nulls + ): # no value, skip the lookup continue if f.primary_key and not self._state.adding: @@ -1168,7 +1309,9 @@ class Model(metaclass=ModelBase): key = unique_check[0] else: key = NON_FIELD_ERRORS - errors.setdefault(key, []).append(self.unique_error_message(model_class, unique_check)) + errors.setdefault(key, []).append( + self.unique_error_message(model_class, unique_check) + ) return errors @@ -1181,12 +1324,14 @@ class Model(metaclass=ModelBase): date = getattr(self, unique_for) if date is None: continue - if lookup_type == 'date': - lookup_kwargs['%s__day' % unique_for] = date.day - lookup_kwargs['%s__month' % unique_for] = date.month - lookup_kwargs['%s__year' % unique_for] = date.year + if lookup_type == "date": + lookup_kwargs["%s__day" % unique_for] = date.day + lookup_kwargs["%s__month" % unique_for] = date.month + lookup_kwargs["%s__year" % unique_for] = date.year else: - lookup_kwargs['%s__%s' % (unique_for, lookup_type)] = getattr(date, lookup_type) + lookup_kwargs["%s__%s" % (unique_for, lookup_type)] = getattr( + date, lookup_type + ) lookup_kwargs[field] = getattr(self, field) qs = model_class._default_manager.filter(**lookup_kwargs) @@ -1205,46 +1350,48 @@ class Model(metaclass=ModelBase): opts = self._meta field = opts.get_field(field_name) return ValidationError( - message=field.error_messages['unique_for_date'], - code='unique_for_date', + message=field.error_messages["unique_for_date"], + code="unique_for_date", params={ - 'model': self, - 'model_name': capfirst(opts.verbose_name), - 'lookup_type': lookup_type, - 'field': field_name, - 'field_label': capfirst(field.verbose_name), - 'date_field': unique_for, - 'date_field_label': capfirst(opts.get_field(unique_for).verbose_name), - } + "model": self, + "model_name": 
capfirst(opts.verbose_name), + "lookup_type": lookup_type, + "field": field_name, + "field_label": capfirst(field.verbose_name), + "date_field": unique_for, + "date_field_label": capfirst(opts.get_field(unique_for).verbose_name), + }, ) def unique_error_message(self, model_class, unique_check): opts = model_class._meta params = { - 'model': self, - 'model_class': model_class, - 'model_name': capfirst(opts.verbose_name), - 'unique_check': unique_check, + "model": self, + "model_class": model_class, + "model_name": capfirst(opts.verbose_name), + "unique_check": unique_check, } # A unique field if len(unique_check) == 1: field = opts.get_field(unique_check[0]) - params['field_label'] = capfirst(field.verbose_name) + params["field_label"] = capfirst(field.verbose_name) return ValidationError( - message=field.error_messages['unique'], - code='unique', + message=field.error_messages["unique"], + code="unique", params=params, ) # unique_together else: - field_labels = [capfirst(opts.get_field(f).verbose_name) for f in unique_check] - params['field_labels'] = get_text_list(field_labels, _('and')) + field_labels = [ + capfirst(opts.get_field(f).verbose_name) for f in unique_check + ] + params["field_labels"] = get_text_list(field_labels, _("and")) return ValidationError( message=_("%(model_name)s with this %(field_labels)s already exists."), - code='unique_together', + code="unique_together", params=params, ) @@ -1311,9 +1458,13 @@ class Model(metaclass=ModelBase): @classmethod def check(cls, **kwargs): - errors = [*cls._check_swappable(), *cls._check_model(), *cls._check_managers(**kwargs)] + errors = [ + *cls._check_swappable(), + *cls._check_model(), + *cls._check_managers(**kwargs), + ] if not cls._meta.swapped: - databases = kwargs.get('databases') or [] + databases = kwargs.get("databases") or [] errors += [ *cls._check_fields(**kwargs), *cls._check_m2m_through_same_relationship(), @@ -1345,16 +1496,17 @@ class Model(metaclass=ModelBase): @classmethod def 
_check_default_pk(cls): if ( - not cls._meta.abstract and - cls._meta.pk.auto_created and + not cls._meta.abstract + and cls._meta.pk.auto_created + and # Inherited PKs are checked in parents models. not ( - isinstance(cls._meta.pk, OneToOneField) and - cls._meta.pk.remote_field.parent_link - ) and - not settings.is_overridden('DEFAULT_AUTO_FIELD') and - cls._meta.app_config and - not cls._meta.app_config._is_default_auto_field_overridden + isinstance(cls._meta.pk, OneToOneField) + and cls._meta.pk.remote_field.parent_link + ) + and not settings.is_overridden("DEFAULT_AUTO_FIELD") + and cls._meta.app_config + and not cls._meta.app_config._is_default_auto_field_overridden ): return [ checks.Warning( @@ -1368,7 +1520,7 @@ class Model(metaclass=ModelBase): f"of AutoField, e.g. 'django.db.models.BigAutoField'." ), obj=cls, - id='models.W042', + id="models.W042", ), ] return [] @@ -1383,19 +1535,19 @@ class Model(metaclass=ModelBase): except ValueError: errors.append( checks.Error( - "'%s' is not of the form 'app_label.app_name'." % cls._meta.swappable, - id='models.E001', + "'%s' is not of the form 'app_label.app_name'." + % cls._meta.swappable, + id="models.E001", ) ) except LookupError: - app_label, model_name = cls._meta.swapped.split('.') + app_label, model_name = cls._meta.swapped.split(".") errors.append( checks.Error( "'%s' references '%s.%s', which has not been " - "installed, or is abstract." % ( - cls._meta.swappable, app_label, model_name - ), - id='models.E002', + "installed, or is abstract." + % (cls._meta.swappable, app_label, model_name), + id="models.E002", ) ) return errors @@ -1408,7 +1560,7 @@ class Model(metaclass=ModelBase): errors.append( checks.Error( "Proxy model '%s' contains model fields." 
% cls.__name__, - id='models.E017', + id="models.E017", ) ) return errors @@ -1433,8 +1585,7 @@ class Model(metaclass=ModelBase): @classmethod def _check_m2m_through_same_relationship(cls): - """ Check if no relationship model is used by more than one m2m field. - """ + """Check if no relationship model is used by more than one m2m field.""" errors = [] seen_intermediary_signatures = [] @@ -1448,15 +1599,20 @@ class Model(metaclass=ModelBase): fields = (f for f in fields if isinstance(f.remote_field.through, ModelBase)) for f in fields: - signature = (f.remote_field.model, cls, f.remote_field.through, f.remote_field.through_fields) + signature = ( + f.remote_field.model, + cls, + f.remote_field.through, + f.remote_field.through_fields, + ) if signature in seen_intermediary_signatures: errors.append( checks.Error( "The model has two identical many-to-many relations " - "through the intermediate model '%s'." % - f.remote_field.through._meta.label, + "through the intermediate model '%s'." + % f.remote_field.through._meta.label, obj=cls, - id='models.E003', + id="models.E003", ) ) else: @@ -1466,15 +1622,17 @@ class Model(metaclass=ModelBase): @classmethod def _check_id_field(cls): """Check if `id` field is a primary key.""" - fields = [f for f in cls._meta.local_fields if f.name == 'id' and f != cls._meta.pk] + fields = [ + f for f in cls._meta.local_fields if f.name == "id" and f != cls._meta.pk + ] # fields is empty or consists of the invalid "id" field - if fields and not fields[0].primary_key and cls._meta.pk.name == 'id': + if fields and not fields[0].primary_key and cls._meta.pk.name == "id": return [ checks.Error( "'id' can only be used as a field name if the field also " "sets 'primary_key=True'.", obj=cls, - id='models.E004', + id="models.E004", ) ] else: @@ -1495,12 +1653,10 @@ class Model(metaclass=ModelBase): checks.Error( "The field '%s' from parent model " "'%s' clashes with the field '%s' " - "from parent model '%s'." 
% ( - clash.name, clash.model._meta, - f.name, f.model._meta - ), + "from parent model '%s'." + % (clash.name, clash.model._meta, f.name, f.model._meta), obj=cls, - id='models.E005', + id="models.E005", ) ) used_fields[f.name] = f @@ -1520,16 +1676,16 @@ class Model(metaclass=ModelBase): # field "id" and automatically added unique field "id", both # defined at the same model. This special case is considered in # _check_id_field and here we ignore it. - id_conflict = f.name == "id" and clash and clash.name == "id" and clash.model == cls + id_conflict = ( + f.name == "id" and clash and clash.name == "id" and clash.model == cls + ) if clash and not id_conflict: errors.append( checks.Error( "The field '%s' clashes with the field '%s' " - "from model '%s'." % ( - f.name, clash.name, clash.model._meta - ), + "from model '%s'." % (f.name, clash.name, clash.model._meta), obj=f, - id='models.E006', + id="models.E006", ) ) used_fields[f.name] = f @@ -1554,7 +1710,7 @@ class Model(metaclass=ModelBase): "another field." % (f.name, column_name), hint="Specify a 'db_column' for the field.", obj=cls, - id='models.E007' + id="models.E007", ) ) else: @@ -1566,13 +1722,13 @@ class Model(metaclass=ModelBase): def _check_model_name_db_lookup_clashes(cls): errors = [] model_name = cls.__name__ - if model_name.startswith('_') or model_name.endswith('_'): + if model_name.startswith("_") or model_name.endswith("_"): errors.append( checks.Error( "The model name '%s' cannot start or end with an underscore " "as it collides with the query lookup syntax." % model_name, obj=cls, - id='models.E023' + id="models.E023", ) ) elif LOOKUP_SEP in model_name: @@ -1581,7 +1737,7 @@ class Model(metaclass=ModelBase): "The model name '%s' cannot contain double underscores as " "it collides with the query lookup syntax." 
% model_name, obj=cls, - id='models.E024' + id="models.E024", ) ) return errors @@ -1591,7 +1747,8 @@ class Model(metaclass=ModelBase): errors = [] property_names = cls._meta._property_names related_field_accessors = ( - f.get_attname() for f in cls._meta._get_fields(reverse=False) + f.get_attname() + for f in cls._meta._get_fields(reverse=False) if f.is_relation and f.related_model is not None ) for accessor in related_field_accessors: @@ -1601,7 +1758,7 @@ class Model(metaclass=ModelBase): "The property '%s' clashes with a related field " "accessor." % accessor, obj=cls, - id='models.E025', + id="models.E025", ) ) return errors @@ -1615,7 +1772,7 @@ class Model(metaclass=ModelBase): "The model cannot have more than one field with " "'primary_key=True'.", obj=cls, - id='models.E026', + id="models.E026", ) ) return errors @@ -1628,16 +1785,18 @@ class Model(metaclass=ModelBase): checks.Error( "'index_together' must be a list or tuple.", obj=cls, - id='models.E008', + id="models.E008", ) ] - elif any(not isinstance(fields, (tuple, list)) for fields in cls._meta.index_together): + elif any( + not isinstance(fields, (tuple, list)) for fields in cls._meta.index_together + ): return [ checks.Error( "All 'index_together' elements must be lists or tuples.", obj=cls, - id='models.E009', + id="models.E009", ) ] @@ -1655,16 +1814,19 @@ class Model(metaclass=ModelBase): checks.Error( "'unique_together' must be a list or tuple.", obj=cls, - id='models.E010', + id="models.E010", ) ] - elif any(not isinstance(fields, (tuple, list)) for fields in cls._meta.unique_together): + elif any( + not isinstance(fields, (tuple, list)) + for fields in cls._meta.unique_together + ): return [ checks.Error( "All 'unique_together' elements must be lists or tuples.", obj=cls, - id='models.E011', + id="models.E011", ) ] @@ -1682,13 +1844,13 @@ class Model(metaclass=ModelBase): for index in cls._meta.indexes: # Index name can't start with an underscore or a number, restricted # for cross-database 
compatibility with Oracle. - if index.name[0] == '_' or index.name[0].isdigit(): + if index.name[0] == "_" or index.name[0].isdigit(): errors.append( checks.Error( "The index name '%s' cannot start with an underscore " "or a number." % index.name, obj=cls, - id='models.E033', + id="models.E033", ), ) if len(index.name) > index.max_name_length: @@ -1697,7 +1859,7 @@ class Model(metaclass=ModelBase): "The index name '%s' cannot be longer than %d " "characters." % (index.name, index.max_name_length), obj=cls, - id='models.E034', + id="models.E034", ), ) if index.contains_expressions: @@ -1710,57 +1872,59 @@ class Model(metaclass=ModelBase): continue connection = connections[db] if not ( - connection.features.supports_partial_indexes or - 'supports_partial_indexes' in cls._meta.required_db_features + connection.features.supports_partial_indexes + or "supports_partial_indexes" in cls._meta.required_db_features ) and any(index.condition is not None for index in cls._meta.indexes): errors.append( checks.Warning( - '%s does not support indexes with conditions.' + "%s does not support indexes with conditions." % connection.display_name, hint=( "Conditions will be ignored. Silence this warning " "if you don't care about it." ), obj=cls, - id='models.W037', + id="models.W037", ) ) if not ( - connection.features.supports_covering_indexes or - 'supports_covering_indexes' in cls._meta.required_db_features + connection.features.supports_covering_indexes + or "supports_covering_indexes" in cls._meta.required_db_features ) and any(index.include for index in cls._meta.indexes): errors.append( checks.Warning( - '%s does not support indexes with non-key columns.' + "%s does not support indexes with non-key columns." % connection.display_name, hint=( "Non-key columns will be ignored. Silence this " "warning if you don't care about it." 
), obj=cls, - id='models.W040', + id="models.W040", ) ) if not ( - connection.features.supports_expression_indexes or - 'supports_expression_indexes' in cls._meta.required_db_features + connection.features.supports_expression_indexes + or "supports_expression_indexes" in cls._meta.required_db_features ) and any(index.contains_expressions for index in cls._meta.indexes): errors.append( checks.Warning( - '%s does not support indexes on expressions.' + "%s does not support indexes on expressions." % connection.display_name, hint=( "An index won't be created. Silence this warning " "if you don't care about it." ), obj=cls, - id='models.W043', + id="models.W043", ) ) - fields = [field for index in cls._meta.indexes for field, _ in index.fields_orders] + fields = [ + field for index in cls._meta.indexes for field, _ in index.fields_orders + ] fields += [include for index in cls._meta.indexes for include in index.include] fields += references - errors.extend(cls._check_local_fields(fields, 'indexes')) + errors.extend(cls._check_local_fields(fields, "indexes")) return errors @classmethod @@ -1772,7 +1936,7 @@ class Model(metaclass=ModelBase): forward_fields_map = {} for field in cls._meta._get_fields(reverse=False): forward_fields_map[field.name] = field - if hasattr(field, 'attname'): + if hasattr(field, "attname"): forward_fields_map[field.attname] = field errors = [] @@ -1782,11 +1946,13 @@ class Model(metaclass=ModelBase): except KeyError: errors.append( checks.Error( - "'%s' refers to the nonexistent field '%s'." % ( - option, field_name, + "'%s' refers to the nonexistent field '%s'." + % ( + option, + field_name, ), obj=cls, - id='models.E012', + id="models.E012", ) ) else: @@ -1794,11 +1960,14 @@ class Model(metaclass=ModelBase): errors.append( checks.Error( "'%s' refers to a ManyToManyField '%s', but " - "ManyToManyFields are not permitted in '%s'." % ( - option, field_name, option, + "ManyToManyFields are not permitted in '%s'." 
+ % ( + option, + field_name, + option, ), obj=cls, - id='models.E013', + id="models.E013", ) ) elif field not in cls._meta.local_fields: @@ -1808,7 +1977,7 @@ class Model(metaclass=ModelBase): % (option, field_name, cls._meta.object_name), hint="This issue may be caused by multi-table inheritance.", obj=cls, - id='models.E016', + id="models.E016", ) ) return errors @@ -1824,7 +1993,7 @@ class Model(metaclass=ModelBase): checks.Error( "'ordering' and 'order_with_respect_to' cannot be used together.", obj=cls, - id='models.E021', + id="models.E021", ), ] @@ -1836,7 +2005,7 @@ class Model(metaclass=ModelBase): checks.Error( "'ordering' must be a tuple or list (even if you want to order by only one field).", obj=cls, - id='models.E014', + id="models.E014", ) ] @@ -1844,10 +2013,10 @@ class Model(metaclass=ModelBase): fields = cls._meta.ordering # Skip expressions and '?' fields. - fields = (f for f in fields if isinstance(f, str) and f != '?') + fields = (f for f in fields if isinstance(f, str) and f != "?") # Convert "-field" to "field". - fields = ((f[1:] if f.startswith('-') else f) for f in fields) + fields = ((f[1:] if f.startswith("-") else f) for f in fields) # Separate related fields and non-related fields. _fields = [] @@ -1866,7 +2035,7 @@ class Model(metaclass=ModelBase): for part in field.split(LOOKUP_SEP): try: # pk is an alias that won't be found by opts.get_field. - if part == 'pk': + if part == "pk": fld = _cls._meta.pk else: fld = _cls._meta.get_field(part) @@ -1883,13 +2052,13 @@ class Model(metaclass=ModelBase): "'ordering' refers to the nonexistent field, " "related field, or lookup '%s'." % field, obj=cls, - id='models.E015', + id="models.E015", ) ) # Skip ordering on pk. This is always a valid order_by field # but is an alias and therefore won't be found by opts.get_field. - fields = {f for f in fields if f != 'pk'} + fields = {f for f in fields if f != "pk"} # Check for invalid or nonexistent fields in ordering. 
invalid_fields = [] @@ -1897,10 +2066,14 @@ class Model(metaclass=ModelBase): # Any field name that is not present in field_names does not exist. # Also, ordering by m2m fields is not allowed. opts = cls._meta - valid_fields = set(chain.from_iterable( - (f.name, f.attname) if not (f.auto_created and not f.concrete) else (f.field.related_query_name(),) - for f in chain(opts.fields, opts.related_objects) - )) + valid_fields = set( + chain.from_iterable( + (f.name, f.attname) + if not (f.auto_created and not f.concrete) + else (f.field.related_query_name(),) + for f in chain(opts.fields, opts.related_objects) + ) + ) invalid_fields.extend(fields - valid_fields) @@ -1910,7 +2083,7 @@ class Model(metaclass=ModelBase): "'ordering' refers to the nonexistent field, related " "field, or lookup '%s'." % invalid_field, obj=cls, - id='models.E015', + id="models.E015", ) ) return errors @@ -1952,7 +2125,11 @@ class Model(metaclass=ModelBase): # Check if auto-generated name for the field is too long # for the database. - if f.db_column is None and column_name is not None and len(column_name) > allowed_len: + if ( + f.db_column is None + and column_name is not None + and len(column_name) > allowed_len + ): errors.append( checks.Error( 'Autogenerated column name too long for field "%s". ' @@ -1960,7 +2137,7 @@ class Model(metaclass=ModelBase): % (column_name, allowed_len, db_alias), hint="Set the column name manually using 'db_column'.", obj=cls, - id='models.E018', + id="models.E018", ) ) @@ -1973,10 +2150,14 @@ class Model(metaclass=ModelBase): # for the database. for m2m in f.remote_field.through._meta.local_fields: _, rel_name = m2m.get_attname_column() - if m2m.db_column is None and rel_name is not None and len(rel_name) > allowed_len: + if ( + m2m.db_column is None + and rel_name is not None + and len(rel_name) > allowed_len + ): errors.append( checks.Error( - 'Autogenerated column name too long for M2M field ' + "Autogenerated column name too long for M2M field " '"%s". 
Maximum length is "%s" for database "%s".' % (rel_name, allowed_len, db_alias), hint=( @@ -1984,7 +2165,7 @@ class Model(metaclass=ModelBase): "M2M and then set column_name using 'db_column'." ), obj=cls, - id='models.E019', + id="models.E019", ) ) @@ -2002,7 +2183,7 @@ class Model(metaclass=ModelBase): yield from cls._get_expr_references(child) elif isinstance(expr, F): yield tuple(expr.name.split(LOOKUP_SEP)) - elif hasattr(expr, 'get_source_expressions'): + elif hasattr(expr, "get_source_expressions"): for src_expr in expr.get_source_expressions(): yield from cls._get_expr_references(src_expr) @@ -2014,132 +2195,145 @@ class Model(metaclass=ModelBase): continue connection = connections[db] if not ( - connection.features.supports_table_check_constraints or - 'supports_table_check_constraints' in cls._meta.required_db_features + connection.features.supports_table_check_constraints + or "supports_table_check_constraints" in cls._meta.required_db_features ) and any( isinstance(constraint, CheckConstraint) for constraint in cls._meta.constraints ): errors.append( checks.Warning( - '%s does not support check constraints.' % connection.display_name, - hint=( - "A constraint won't be created. Silence this " - "warning if you don't care about it." - ), - obj=cls, - id='models.W027', - ) - ) - if not ( - connection.features.supports_partial_indexes or - 'supports_partial_indexes' in cls._meta.required_db_features - ) and any( - isinstance(constraint, UniqueConstraint) and constraint.condition is not None - for constraint in cls._meta.constraints - ): - errors.append( - checks.Warning( - '%s does not support unique constraints with ' - 'conditions.' % connection.display_name, - hint=( - "A constraint won't be created. Silence this " - "warning if you don't care about it." 
- ), - obj=cls, - id='models.W036', - ) - ) - if not ( - connection.features.supports_deferrable_unique_constraints or - 'supports_deferrable_unique_constraints' in cls._meta.required_db_features - ) and any( - isinstance(constraint, UniqueConstraint) and constraint.deferrable is not None - for constraint in cls._meta.constraints - ): - errors.append( - checks.Warning( - '%s does not support deferrable unique constraints.' + "%s does not support check constraints." % connection.display_name, hint=( "A constraint won't be created. Silence this " "warning if you don't care about it." ), obj=cls, - id='models.W038', + id="models.W027", ) ) if not ( - connection.features.supports_covering_indexes or - 'supports_covering_indexes' in cls._meta.required_db_features + connection.features.supports_partial_indexes + or "supports_partial_indexes" in cls._meta.required_db_features + ) and any( + isinstance(constraint, UniqueConstraint) + and constraint.condition is not None + for constraint in cls._meta.constraints + ): + errors.append( + checks.Warning( + "%s does not support unique constraints with " + "conditions." % connection.display_name, + hint=( + "A constraint won't be created. Silence this " + "warning if you don't care about it." + ), + obj=cls, + id="models.W036", + ) + ) + if not ( + connection.features.supports_deferrable_unique_constraints + or "supports_deferrable_unique_constraints" + in cls._meta.required_db_features + ) and any( + isinstance(constraint, UniqueConstraint) + and constraint.deferrable is not None + for constraint in cls._meta.constraints + ): + errors.append( + checks.Warning( + "%s does not support deferrable unique constraints." + % connection.display_name, + hint=( + "A constraint won't be created. Silence this " + "warning if you don't care about it." 
+ ), + obj=cls, + id="models.W038", + ) + ) + if not ( + connection.features.supports_covering_indexes + or "supports_covering_indexes" in cls._meta.required_db_features ) and any( isinstance(constraint, UniqueConstraint) and constraint.include for constraint in cls._meta.constraints ): errors.append( checks.Warning( - '%s does not support unique constraints with non-key ' - 'columns.' % connection.display_name, + "%s does not support unique constraints with non-key " + "columns." % connection.display_name, hint=( "A constraint won't be created. Silence this " "warning if you don't care about it." ), obj=cls, - id='models.W039', + id="models.W039", ) ) if not ( - connection.features.supports_expression_indexes or - 'supports_expression_indexes' in cls._meta.required_db_features + connection.features.supports_expression_indexes + or "supports_expression_indexes" in cls._meta.required_db_features ) and any( - isinstance(constraint, UniqueConstraint) and constraint.contains_expressions + isinstance(constraint, UniqueConstraint) + and constraint.contains_expressions for constraint in cls._meta.constraints ): errors.append( checks.Warning( - '%s does not support unique constraints on ' - 'expressions.' % connection.display_name, + "%s does not support unique constraints on " + "expressions." % connection.display_name, hint=( "A constraint won't be created. Silence this " "warning if you don't care about it." 
), obj=cls, - id='models.W044', + id="models.W044", ) ) - fields = set(chain.from_iterable( - (*constraint.fields, *constraint.include) - for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint) - )) + fields = set( + chain.from_iterable( + (*constraint.fields, *constraint.include) + for constraint in cls._meta.constraints + if isinstance(constraint, UniqueConstraint) + ) + ) references = set() for constraint in cls._meta.constraints: if isinstance(constraint, UniqueConstraint): if ( - connection.features.supports_partial_indexes or - 'supports_partial_indexes' not in cls._meta.required_db_features + connection.features.supports_partial_indexes + or "supports_partial_indexes" + not in cls._meta.required_db_features ) and isinstance(constraint.condition, Q): - references.update(cls._get_expr_references(constraint.condition)) + references.update( + cls._get_expr_references(constraint.condition) + ) if ( - connection.features.supports_expression_indexes or - 'supports_expression_indexes' not in cls._meta.required_db_features + connection.features.supports_expression_indexes + or "supports_expression_indexes" + not in cls._meta.required_db_features ) and constraint.contains_expressions: for expression in constraint.expressions: references.update(cls._get_expr_references(expression)) elif isinstance(constraint, CheckConstraint): if ( - connection.features.supports_table_check_constraints or - 'supports_table_check_constraints' not in cls._meta.required_db_features + connection.features.supports_table_check_constraints + or "supports_table_check_constraints" + not in cls._meta.required_db_features ) and isinstance(constraint.check, Q): references.update(cls._get_expr_references(constraint.check)) for field_name, *lookups in references: # pk is an alias that won't be found by opts.get_field. - if field_name != 'pk': + if field_name != "pk": fields.add(field_name) if not lookups: # If it has no lookups it cannot result in a JOIN. 
continue try: - if field_name == 'pk': + if field_name == "pk": field = cls._meta.pk else: field = cls._meta.get_field(field_name) @@ -2150,20 +2344,20 @@ class Model(metaclass=ModelBase): # JOIN must happen at the first lookup. first_lookup = lookups[0] if ( - hasattr(field, 'get_transform') and - hasattr(field, 'get_lookup') and - field.get_transform(first_lookup) is None and - field.get_lookup(first_lookup) is None + hasattr(field, "get_transform") + and hasattr(field, "get_lookup") + and field.get_transform(first_lookup) is None + and field.get_lookup(first_lookup) is None ): errors.append( checks.Error( "'constraints' refers to the joined field '%s'." % LOOKUP_SEP.join([field_name] + lookups), obj=cls, - id='models.E041', + id="models.E041", ) ) - errors.extend(cls._check_local_fields(fields, 'constraints')) + errors.extend(cls._check_local_fields(fields, "constraints")) return errors @@ -2173,14 +2367,16 @@ class Model(metaclass=ModelBase): # ORDERING METHODS ######################### + def method_set_order(self, ordered_obj, id_list, using=None): if using is None: using = DEFAULT_DB_ALIAS order_wrt = ordered_obj._meta.order_with_respect_to filter_args = order_wrt.get_forward_related_filter(self) - ordered_obj.objects.db_manager(using).filter(**filter_args).bulk_update([ - ordered_obj(pk=pk, _order=order) for order, pk in enumerate(id_list) - ], ['_order']) + ordered_obj.objects.db_manager(using).filter(**filter_args).bulk_update( + [ordered_obj(pk=pk, _order=order) for order, pk in enumerate(id_list)], + ["_order"], + ) def method_get_order(self, ordered_obj): @@ -2193,15 +2389,16 @@ def method_get_order(self, ordered_obj): def make_foreign_order_accessors(model, related_model): setattr( related_model, - 'get_%s_order' % model.__name__.lower(), - partialmethod(method_get_order, model) + "get_%s_order" % model.__name__.lower(), + partialmethod(method_get_order, model), ) setattr( related_model, - 'set_%s_order' % model.__name__.lower(), - 
partialmethod(method_set_order, model) + "set_%s_order" % model.__name__.lower(), + partialmethod(method_set_order, model), ) + ######## # MISC # ######## diff --git a/django/db/models/constants.py b/django/db/models/constants.py index 95addd2ab0..a0c99c95fc 100644 --- a/django/db/models/constants.py +++ b/django/db/models/constants.py @@ -4,9 +4,9 @@ Constants used across the ORM in general. from enum import Enum # Separator used to split filter strings apart. -LOOKUP_SEP = '__' +LOOKUP_SEP = "__" class OnConflict(Enum): - IGNORE = 'ignore' - UPDATE = 'update' + IGNORE = "ignore" + UPDATE = "update" diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 5abedaf3d1..721e43ae58 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -5,7 +5,7 @@ from django.db.models.indexes import IndexExpression from django.db.models.query_utils import Q from django.db.models.sql.query import Query -__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint'] +__all__ = ["CheckConstraint", "Deferrable", "UniqueConstraint"] class BaseConstraint: @@ -17,18 +17,18 @@ class BaseConstraint: return False def constraint_sql(self, model, schema_editor): - raise NotImplementedError('This method must be implemented by a subclass.') + raise NotImplementedError("This method must be implemented by a subclass.") def create_sql(self, model, schema_editor): - raise NotImplementedError('This method must be implemented by a subclass.') + raise NotImplementedError("This method must be implemented by a subclass.") def remove_sql(self, model, schema_editor): - raise NotImplementedError('This method must be implemented by a subclass.') + raise NotImplementedError("This method must be implemented by a subclass.") def deconstruct(self): - path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) - path = path.replace('django.db.models.constraints', 'django.db.models') - return (path, (), {'name': self.name}) + path = "%s.%s" % 
(self.__class__.__module__, self.__class__.__name__) + path = path.replace("django.db.models.constraints", "django.db.models") + return (path, (), {"name": self.name}) def clone(self): _, args, kwargs = self.deconstruct() @@ -38,9 +38,9 @@ class BaseConstraint: class CheckConstraint(BaseConstraint): def __init__(self, *, check, name): self.check = check - if not getattr(check, 'conditional', False): + if not getattr(check, "conditional", False): raise TypeError( - 'CheckConstraint.check must be a Q instance or boolean expression.' + "CheckConstraint.check must be a Q instance or boolean expression." ) super().__init__(name) @@ -63,7 +63,7 @@ class CheckConstraint(BaseConstraint): return schema_editor._delete_check_sql(model, self.name) def __repr__(self): - return '<%s: check=%s name=%s>' % ( + return "<%s: check=%s name=%s>" % ( self.__class__.__qualname__, self.check, repr(self.name), @@ -76,17 +76,17 @@ class CheckConstraint(BaseConstraint): def deconstruct(self): path, args, kwargs = super().deconstruct() - kwargs['check'] = self.check + kwargs["check"] = self.check return path, args, kwargs class Deferrable(Enum): - DEFERRED = 'deferred' - IMMEDIATE = 'immediate' + DEFERRED = "deferred" + IMMEDIATE = "immediate" # A similar format was proposed for Python 3.10. def __repr__(self): - return f'{self.__class__.__qualname__}.{self._name_}' + return f"{self.__class__.__qualname__}.{self._name_}" class UniqueConstraint(BaseConstraint): @@ -101,51 +101,43 @@ class UniqueConstraint(BaseConstraint): opclasses=(), ): if not name: - raise ValueError('A unique constraint must be named.') + raise ValueError("A unique constraint must be named.") if not expressions and not fields: raise ValueError( - 'At least one field or expression is required to define a ' - 'unique constraint.' + "At least one field or expression is required to define a " + "unique constraint." 
) if expressions and fields: raise ValueError( - 'UniqueConstraint.fields and expressions are mutually exclusive.' + "UniqueConstraint.fields and expressions are mutually exclusive." ) if not isinstance(condition, (type(None), Q)): - raise ValueError('UniqueConstraint.condition must be a Q instance.') + raise ValueError("UniqueConstraint.condition must be a Q instance.") if condition and deferrable: - raise ValueError( - 'UniqueConstraint with conditions cannot be deferred.' - ) + raise ValueError("UniqueConstraint with conditions cannot be deferred.") if include and deferrable: - raise ValueError( - 'UniqueConstraint with include fields cannot be deferred.' - ) + raise ValueError("UniqueConstraint with include fields cannot be deferred.") if opclasses and deferrable: - raise ValueError( - 'UniqueConstraint with opclasses cannot be deferred.' - ) + raise ValueError("UniqueConstraint with opclasses cannot be deferred.") if expressions and deferrable: - raise ValueError( - 'UniqueConstraint with expressions cannot be deferred.' - ) + raise ValueError("UniqueConstraint with expressions cannot be deferred.") if expressions and opclasses: raise ValueError( - 'UniqueConstraint.opclasses cannot be used with expressions. ' - 'Use django.contrib.postgres.indexes.OpClass() instead.' + "UniqueConstraint.opclasses cannot be used with expressions. " + "Use django.contrib.postgres.indexes.OpClass() instead." ) if not isinstance(deferrable, (type(None), Deferrable)): raise ValueError( - 'UniqueConstraint.deferrable must be a Deferrable instance.' + "UniqueConstraint.deferrable must be a Deferrable instance." 
) if not isinstance(include, (type(None), list, tuple)): - raise ValueError('UniqueConstraint.include must be a list or tuple.') + raise ValueError("UniqueConstraint.include must be a list or tuple.") if not isinstance(opclasses, (list, tuple)): - raise ValueError('UniqueConstraint.opclasses must be a list or tuple.') + raise ValueError("UniqueConstraint.opclasses must be a list or tuple.") if opclasses and len(fields) != len(opclasses): raise ValueError( - 'UniqueConstraint.fields and UniqueConstraint.opclasses must ' - 'have the same number of elements.' + "UniqueConstraint.fields and UniqueConstraint.opclasses must " + "have the same number of elements." ) self.fields = tuple(fields) self.condition = condition @@ -185,70 +177,91 @@ class UniqueConstraint(BaseConstraint): def constraint_sql(self, model, schema_editor): fields = [model._meta.get_field(field_name) for field_name in self.fields] - include = [model._meta.get_field(field_name).column for field_name in self.include] + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] condition = self._get_condition_sql(model, schema_editor) expressions = self._get_index_expressions(model, schema_editor) return schema_editor._unique_sql( - model, fields, self.name, condition=condition, - deferrable=self.deferrable, include=include, - opclasses=self.opclasses, expressions=expressions, + model, + fields, + self.name, + condition=condition, + deferrable=self.deferrable, + include=include, + opclasses=self.opclasses, + expressions=expressions, ) def create_sql(self, model, schema_editor): fields = [model._meta.get_field(field_name) for field_name in self.fields] - include = [model._meta.get_field(field_name).column for field_name in self.include] + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] condition = self._get_condition_sql(model, schema_editor) expressions = self._get_index_expressions(model, schema_editor) return 
schema_editor._create_unique_sql( - model, fields, self.name, condition=condition, - deferrable=self.deferrable, include=include, - opclasses=self.opclasses, expressions=expressions, + model, + fields, + self.name, + condition=condition, + deferrable=self.deferrable, + include=include, + opclasses=self.opclasses, + expressions=expressions, ) def remove_sql(self, model, schema_editor): condition = self._get_condition_sql(model, schema_editor) - include = [model._meta.get_field(field_name).column for field_name in self.include] + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] expressions = self._get_index_expressions(model, schema_editor) return schema_editor._delete_unique_sql( - model, self.name, condition=condition, deferrable=self.deferrable, - include=include, opclasses=self.opclasses, expressions=expressions, + model, + self.name, + condition=condition, + deferrable=self.deferrable, + include=include, + opclasses=self.opclasses, + expressions=expressions, ) def __repr__(self): - return '<%s:%s%s%s%s%s%s%s>' % ( + return "<%s:%s%s%s%s%s%s%s>" % ( self.__class__.__qualname__, - '' if not self.fields else ' fields=%s' % repr(self.fields), - '' if not self.expressions else ' expressions=%s' % repr(self.expressions), - ' name=%s' % repr(self.name), - '' if self.condition is None else ' condition=%s' % self.condition, - '' if self.deferrable is None else ' deferrable=%r' % self.deferrable, - '' if not self.include else ' include=%s' % repr(self.include), - '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses), + "" if not self.fields else " fields=%s" % repr(self.fields), + "" if not self.expressions else " expressions=%s" % repr(self.expressions), + " name=%s" % repr(self.name), + "" if self.condition is None else " condition=%s" % self.condition, + "" if self.deferrable is None else " deferrable=%r" % self.deferrable, + "" if not self.include else " include=%s" % repr(self.include), + "" if not 
self.opclasses else " opclasses=%s" % repr(self.opclasses), ) def __eq__(self, other): if isinstance(other, UniqueConstraint): return ( - self.name == other.name and - self.fields == other.fields and - self.condition == other.condition and - self.deferrable == other.deferrable and - self.include == other.include and - self.opclasses == other.opclasses and - self.expressions == other.expressions + self.name == other.name + and self.fields == other.fields + and self.condition == other.condition + and self.deferrable == other.deferrable + and self.include == other.include + and self.opclasses == other.opclasses + and self.expressions == other.expressions ) return super().__eq__(other) def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fields: - kwargs['fields'] = self.fields + kwargs["fields"] = self.fields if self.condition: - kwargs['condition'] = self.condition + kwargs["condition"] = self.condition if self.deferrable: - kwargs['deferrable'] = self.deferrable + kwargs["deferrable"] = self.deferrable if self.include: - kwargs['include'] = self.include + kwargs["include"] = self.include if self.opclasses: - kwargs['opclasses'] = self.opclasses + kwargs["opclasses"] = self.opclasses return path, self.expressions, kwargs diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index b99337a309..6912b49498 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -21,8 +21,11 @@ class RestrictedError(IntegrityError): def CASCADE(collector, field, sub_objs, using): collector.collect( - sub_objs, source=field.remote_field.model, source_attr=field.name, - nullable=field.null, fail_on_restricted=False, + sub_objs, + source=field.remote_field.model, + source_attr=field.name, + nullable=field.null, + fail_on_restricted=False, ) if field.null and not connections[using].features.can_defer_constraint_checks: collector.add_field_update(field, None, sub_objs) @@ -31,10 +34,13 @@ def CASCADE(collector, field, sub_objs, 
using): def PROTECT(collector, field, sub_objs, using): raise ProtectedError( "Cannot delete some instances of model '%s' because they are " - "referenced through a protected foreign key: '%s.%s'" % ( - field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name + "referenced through a protected foreign key: '%s.%s'" + % ( + field.remote_field.model.__name__, + sub_objs[0].__class__.__name__, + field.name, ), - sub_objs + sub_objs, ) @@ -45,12 +51,16 @@ def RESTRICT(collector, field, sub_objs, using): def SET(value): if callable(value): + def set_on_delete(collector, field, sub_objs, using): collector.add_field_update(field, value(), sub_objs) + else: + def set_on_delete(collector, field, sub_objs, using): collector.add_field_update(field, value, sub_objs) - set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {}) + + set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {}) return set_on_delete @@ -70,7 +80,8 @@ def get_candidate_relations_to_delete(opts): # The candidate relations are the ones that come from N-1 and 1-1 relations. # N-N (i.e., many-to-many) relations aren't candidates for deletion. 
return ( - f for f in opts.get_fields(include_hidden=True) + f + for f in opts.get_fields(include_hidden=True) if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many) ) @@ -124,7 +135,9 @@ class Collector: def add_dependency(self, model, dependency, reverse_dependency=False): if reverse_dependency: model, dependency = dependency, model - self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model) + self.dependencies[model._meta.concrete_model].add( + dependency._meta.concrete_model + ) self.data.setdefault(dependency, self.data.default_factory()) def add_field_update(self, field, value, objs): @@ -151,17 +164,21 @@ class Collector: def clear_restricted_objects_from_queryset(self, model, qs): if model in self.restricted_objects: - objs = set(qs.filter(pk__in=[ - obj.pk - for objs in self.restricted_objects[model].values() for obj in objs - ])) + objs = set( + qs.filter( + pk__in=[ + obj.pk + for objs in self.restricted_objects[model].values() + for obj in objs + ] + ) + ) self.clear_restricted_objects_from_set(model, objs) def _has_signal_listeners(self, model): - return ( - signals.pre_delete.has_listeners(model) or - signals.post_delete.has_listeners(model) - ) + return signals.pre_delete.has_listeners( + model + ) or signals.post_delete.has_listeners(model) def can_fast_delete(self, objs, from_field=None): """ @@ -176,9 +193,9 @@ class Collector: """ if from_field and from_field.remote_field.on_delete is not CASCADE: return False - if hasattr(objs, '_meta'): + if hasattr(objs, "_meta"): model = objs._meta.model - elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'): + elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"): model = objs.model else: return False @@ -188,14 +205,22 @@ class Collector: # parent when parent delete is cascading to child. 
opts = model._meta return ( - all(link == from_field for link in opts.concrete_model._meta.parents.values()) and + all( + link == from_field + for link in opts.concrete_model._meta.parents.values() + ) + and # Foreign keys pointing to this model. all( related.field.remote_field.on_delete is DO_NOTHING for related in get_candidate_relations_to_delete(opts) - ) and ( + ) + and ( # Something like generic foreign key. - not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields) + not any( + hasattr(field, "bulk_related_objects") + for field in opts.private_fields + ) ) ) @@ -205,16 +230,27 @@ class Collector: """ field_names = [field.name for field in fields] conn_batch_size = max( - connections[self.using].ops.bulk_batch_size(field_names, objs), 1) + connections[self.using].ops.bulk_batch_size(field_names, objs), 1 + ) if len(objs) > conn_batch_size: - return [objs[i:i + conn_batch_size] - for i in range(0, len(objs), conn_batch_size)] + return [ + objs[i : i + conn_batch_size] + for i in range(0, len(objs), conn_batch_size) + ] else: return [objs] - def collect(self, objs, source=None, nullable=False, collect_related=True, - source_attr=None, reverse_dependency=False, keep_parents=False, - fail_on_restricted=True): + def collect( + self, + objs, + source=None, + nullable=False, + collect_related=True, + source_attr=None, + reverse_dependency=False, + keep_parents=False, + fail_on_restricted=True, + ): """ Add 'objs' to the collection of objects to be deleted as well as all parent instances. 
'objs' must be a homogeneous iterable collection of @@ -241,8 +277,9 @@ class Collector: if self.can_fast_delete(objs): self.fast_deletes.append(objs) return - new_objs = self.add(objs, source, nullable, - reverse_dependency=reverse_dependency) + new_objs = self.add( + objs, source, nullable, reverse_dependency=reverse_dependency + ) if not new_objs: return @@ -255,11 +292,14 @@ class Collector: for ptr in concrete_model._meta.parents.values(): if ptr: parent_objs = [getattr(obj, ptr.name) for obj in new_objs] - self.collect(parent_objs, source=model, - source_attr=ptr.remote_field.related_name, - collect_related=False, - reverse_dependency=True, - fail_on_restricted=False) + self.collect( + parent_objs, + source=model, + source_attr=ptr.remote_field.related_name, + collect_related=False, + reverse_dependency=True, + fail_on_restricted=False, + ) if not collect_related: return @@ -287,11 +327,18 @@ class Collector: # relationships are select_related as interactions between both # features are hard to get right. This should only happen in # the rare cases where .related_objects is overridden anyway. - if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)): - referenced_fields = set(chain.from_iterable( - (rf.attname for rf in rel.field.foreign_related_fields) - for rel in get_candidate_relations_to_delete(related_model._meta) - )) + if not ( + sub_objs.query.select_related + or self._has_signal_listeners(related_model) + ): + referenced_fields = set( + chain.from_iterable( + (rf.attname for rf in rel.field.foreign_related_fields) + for rel in get_candidate_relations_to_delete( + related_model._meta + ) + ) + ) sub_objs = sub_objs.only(*tuple(referenced_fields)) if sub_objs: try: @@ -301,10 +348,11 @@ class Collector: protected_objects[key] += error.protected_objects if protected_objects: raise ProtectedError( - 'Cannot delete some instances of model %r because they are ' - 'referenced through protected foreign keys: %s.' 
% ( + "Cannot delete some instances of model %r because they are " + "referenced through protected foreign keys: %s." + % ( model.__name__, - ', '.join(protected_objects), + ", ".join(protected_objects), ), set(chain.from_iterable(protected_objects.values())), ) @@ -314,10 +362,12 @@ class Collector: sub_objs = self.related_objects(related_model, related_fields, batch) self.fast_deletes.append(sub_objs) for field in model._meta.private_fields: - if hasattr(field, 'bulk_related_objects'): + if hasattr(field, "bulk_related_objects"): # It's something like generic foreign key. sub_objs = field.bulk_related_objects(new_objs, self.using) - self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False) + self.collect( + sub_objs, source=model, nullable=True, fail_on_restricted=False + ) if fail_on_restricted: # Raise an error if collected restricted objects (RESTRICT) aren't @@ -335,11 +385,12 @@ class Collector: restricted_objects[key] += objs if restricted_objects: raise RestrictedError( - 'Cannot delete some instances of model %r because ' - 'they are referenced through restricted foreign keys: ' - '%s.' % ( + "Cannot delete some instances of model %r because " + "they are referenced through restricted foreign keys: " + "%s." + % ( model.__name__, - ', '.join(restricted_objects), + ", ".join(restricted_objects), ), set(chain.from_iterable(restricted_objects.values())), ) @@ -349,10 +400,7 @@ class Collector: Get a QuerySet of the related model to objs via related fields. 
""" predicate = query_utils.Q( - *( - (f'{related_field.name}__in', objs) - for related_field in related_fields - ), + *((f"{related_field.name}__in", objs) for related_field in related_fields), _connector=query_utils.Q.OR, ) return related_model._base_manager.using(self.using).filter(predicate) @@ -397,7 +445,9 @@ class Collector: instance = list(instances)[0] if self.can_fast_delete(instance): with transaction.mark_for_rollback_on_error(self.using): - count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using) + count = sql.DeleteQuery(model).delete_batch( + [instance.pk], self.using + ) setattr(instance, model._meta.pk.attname, None) return count, {model._meta.label: count} @@ -406,7 +456,9 @@ class Collector: for model, obj in self.instances_with_model(): if not model._meta.auto_created: signals.pre_delete.send( - sender=model, instance=obj, using=self.using, + sender=model, + instance=obj, + using=self.using, origin=self.origin, ) @@ -420,8 +472,9 @@ class Collector: for model, instances_for_fieldvalues in self.field_updates.items(): for (field, value), instances in instances_for_fieldvalues.items(): query = sql.UpdateQuery(model) - query.update_batch([obj.pk for obj in instances], - {field.name: value}, self.using) + query.update_batch( + [obj.pk for obj in instances], {field.name: value}, self.using + ) # reverse instance collections for instances in self.data.values(): @@ -438,7 +491,9 @@ class Collector: if not model._meta.auto_created: for obj in instances: signals.post_delete.send( - sender=model, instance=obj, using=self.using, + sender=model, + instance=obj, + using=self.using, origin=self.origin, ) diff --git a/django/db/models/enums.py b/django/db/models/enums.py index 8474c87c94..9a7a2bb70f 100644 --- a/django/db/models/enums.py +++ b/django/db/models/enums.py @@ -3,7 +3,7 @@ from types import DynamicClassAttribute from django.utils.functional import Promise -__all__ = ['Choices', 'IntegerChoices', 'TextChoices'] +__all__ = ["Choices", 
"IntegerChoices", "TextChoices"] class ChoicesMeta(enum.EnumMeta): @@ -14,14 +14,14 @@ class ChoicesMeta(enum.EnumMeta): for key in classdict._member_names: value = classdict[key] if ( - isinstance(value, (list, tuple)) and - len(value) > 1 and - isinstance(value[-1], (Promise, str)) + isinstance(value, (list, tuple)) + and len(value) > 1 + and isinstance(value[-1], (Promise, str)) ): *value, label = value value = tuple(value) else: - label = key.replace('_', ' ').title() + label = key.replace("_", " ").title() labels.append(label) # Use dict.__setitem__() to suppress defenses against double # assignment in enum's classdict. @@ -39,12 +39,12 @@ class ChoicesMeta(enum.EnumMeta): @property def names(cls): - empty = ['__empty__'] if hasattr(cls, '__empty__') else [] + empty = ["__empty__"] if hasattr(cls, "__empty__") else [] return empty + [member.name for member in cls] @property def choices(cls): - empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else [] + empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else [] return empty + [(member.value, member.label) for member in cls] @property @@ -76,11 +76,12 @@ class Choices(enum.Enum, metaclass=ChoicesMeta): # A similar format was proposed for Python 3.10. def __repr__(self): - return f'{self.__class__.__qualname__}.{self._name_}' + return f"{self.__class__.__qualname__}.{self._name_}" class IntegerChoices(int, Choices): """Class for creating enumerated integer choices.""" + pass diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index f31ff4d3df..a2da1f6e38 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -20,11 +20,12 @@ class SQLiteNumericMixin: Some expressions with output_field=DecimalField() must be cast to numeric to be properly filtered. 
""" + def as_sqlite(self, compiler, connection, **extra_context): sql, params = self.as_sql(compiler, connection, **extra_context) try: - if self.output_field.get_internal_type() == 'DecimalField': - sql = 'CAST(%s AS NUMERIC)' % sql + if self.output_field.get_internal_type() == "DecimalField": + sql = "CAST(%s AS NUMERIC)" % sql except FieldError: pass return sql, params @@ -37,26 +38,26 @@ class Combinable: """ # Arithmetic connectors - ADD = '+' - SUB = '-' - MUL = '*' - DIV = '/' - POW = '^' + ADD = "+" + SUB = "-" + MUL = "*" + DIV = "/" + POW = "^" # The following is a quoted % operator - it is quoted because it can be # used in strings that also have parameter substitution. - MOD = '%%' + MOD = "%%" # Bitwise operators - note that these are generated by .bitand() # and .bitor(), the '&' and '|' are reserved for boolean operator # usage. - BITAND = '&' - BITOR = '|' - BITLEFTSHIFT = '<<' - BITRIGHTSHIFT = '>>' - BITXOR = '#' + BITAND = "&" + BITOR = "|" + BITLEFTSHIFT = "<<" + BITRIGHTSHIFT = ">>" + BITXOR = "#" def _combine(self, other, connector, reversed): - if not hasattr(other, 'resolve_expression'): + if not hasattr(other, "resolve_expression"): # everything must be resolvable to an expression other = Value(other) @@ -90,7 +91,7 @@ class Combinable: return self._combine(other, self.POW, False) def __and__(self, other): - if getattr(self, 'conditional', False) and getattr(other, 'conditional', False): + if getattr(self, "conditional", False) and getattr(other, "conditional", False): return Q(self) & Q(other) raise NotImplementedError( "Use .bitand() and .bitor() for bitwise logical operations." 
@@ -109,7 +110,7 @@ class Combinable: return self._combine(other, self.BITXOR, False) def __or__(self, other): - if getattr(self, 'conditional', False) and getattr(other, 'conditional', False): + if getattr(self, "conditional", False) and getattr(other, "conditional", False): return Q(self) | Q(other) raise NotImplementedError( "Use .bitand() and .bitor() for bitwise logical operations." @@ -165,14 +166,14 @@ class BaseExpression: def __getstate__(self): state = self.__dict__.copy() - state.pop('convert_value', None) + state.pop("convert_value", None) return state def get_db_converters(self, connection): return ( [] - if self.convert_value is self._convert_value_noop else - [self.convert_value] + if self.convert_value is self._convert_value_noop + else [self.convert_value] ) + self.output_field.get_db_converters(connection) def get_source_expressions(self): @@ -183,9 +184,10 @@ class BaseExpression: def _parse_expressions(self, *expressions): return [ - arg if hasattr(arg, 'resolve_expression') else ( - F(arg) if isinstance(arg, str) else Value(arg) - ) for arg in expressions + arg + if hasattr(arg, "resolve_expression") + else (F(arg) if isinstance(arg, str) else Value(arg)) + for arg in expressions ] def as_sql(self, compiler, connection): @@ -218,17 +220,26 @@ class BaseExpression: @cached_property def contains_aggregate(self): - return any(expr and expr.contains_aggregate for expr in self.get_source_expressions()) + return any( + expr and expr.contains_aggregate for expr in self.get_source_expressions() + ) @cached_property def contains_over_clause(self): - return any(expr and expr.contains_over_clause for expr in self.get_source_expressions()) + return any( + expr and expr.contains_over_clause for expr in self.get_source_expressions() + ) @cached_property def contains_column_references(self): - return any(expr and expr.contains_column_references for expr in self.get_source_expressions()) + return any( + expr and expr.contains_column_references + for expr in 
self.get_source_expressions() + ) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): """ Provide the chance to do any preprocessing or validation before being added to the query. @@ -245,11 +256,14 @@ class BaseExpression: """ c = self.copy() c.is_summary = summarize - c.set_source_expressions([ - expr.resolve_expression(query, allow_joins, reuse, summarize) - if expr else None - for expr in c.get_source_expressions() - ]) + c.set_source_expressions( + [ + expr.resolve_expression(query, allow_joins, reuse, summarize) + if expr + else None + for expr in c.get_source_expressions() + ] + ) return c @property @@ -266,7 +280,7 @@ class BaseExpression: output_field = self._resolve_output_field() if output_field is None: self._output_field_resolved_to_none = True - raise FieldError('Cannot resolve expression type, unknown output_field') + raise FieldError("Cannot resolve expression type, unknown output_field") return output_field @cached_property @@ -295,13 +309,16 @@ class BaseExpression: If all sources are None, then an error is raised higher up the stack in the output_field property. """ - sources_iter = (source for source in self.get_source_fields() if source is not None) + sources_iter = ( + source for source in self.get_source_fields() if source is not None + ) for output_field in sources_iter: for source in sources_iter: if not isinstance(output_field, source.__class__): raise FieldError( - 'Expression contains mixed types: %s, %s. You must ' - 'set output_field.' % ( + "Expression contains mixed types: %s, %s. You must " + "set output_field." 
+ % ( output_field.__class__.__name__, source.__class__.__name__, ) @@ -321,12 +338,24 @@ class BaseExpression: """ field = self.output_field internal_type = field.get_internal_type() - if internal_type == 'FloatField': - return lambda value, expression, connection: None if value is None else float(value) - elif internal_type.endswith('IntegerField'): - return lambda value, expression, connection: None if value is None else int(value) - elif internal_type == 'DecimalField': - return lambda value, expression, connection: None if value is None else Decimal(value) + if internal_type == "FloatField": + return ( + lambda value, expression, connection: None + if value is None + else float(value) + ) + elif internal_type.endswith("IntegerField"): + return ( + lambda value, expression, connection: None + if value is None + else int(value) + ) + elif internal_type == "DecimalField": + return ( + lambda value, expression, connection: None + if value is None + else Decimal(value) + ) return self._convert_value_noop def get_lookup(self, lookup): @@ -337,10 +366,12 @@ class BaseExpression: def relabeled_clone(self, change_map): clone = self.copy() - clone.set_source_expressions([ - e.relabeled_clone(change_map) if e is not None else None - for e in self.get_source_expressions() - ]) + clone.set_source_expressions( + [ + e.relabeled_clone(change_map) if e is not None else None + for e in self.get_source_expressions() + ] + ) return clone def copy(self): @@ -375,7 +406,7 @@ class BaseExpression: yield self for expr in self.get_source_expressions(): if expr: - if hasattr(expr, 'flatten'): + if hasattr(expr, "flatten"): yield from expr.flatten() else: yield expr @@ -385,7 +416,7 @@ class BaseExpression: Custom format for select clauses. For example, EXISTS expressions need to be wrapped in CASE WHEN on Oracle. 
""" - if hasattr(self.output_field, 'select_format'): + if hasattr(self.output_field, "select_format"): return self.output_field.select_format(compiler, sql, params) return sql, params @@ -438,12 +469,13 @@ _connector_combinators = { def _resolve_combined_type(connector, lhs_type, rhs_type): combinators = _connector_combinators.get(connector, ()) for combinator_lhs_type, combinator_rhs_type, combined_type in combinators: - if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type): + if issubclass(lhs_type, combinator_lhs_type) and issubclass( + rhs_type, combinator_rhs_type + ): return combined_type class CombinedExpression(SQLiteNumericMixin, Expression): - def __init__(self, lhs, connector, rhs, output_field=None): super().__init__(output_field=output_field) self.connector = connector @@ -485,13 +517,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression): expressions.append(sql) expression_params.extend(params) # order of precedence - expression_wrapper = '(%s)' + expression_wrapper = "(%s)" sql = connection.ops.combine_expression(self.connector, expressions) return expression_wrapper % sql, expression_params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) - rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + lhs = self.lhs.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + rhs = self.rhs.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) if not isinstance(self, (DurationExpression, TemporalSubtraction)): try: lhs_type = lhs.output_field.get_internal_type() @@ -501,14 +539,28 @@ class CombinedExpression(SQLiteNumericMixin, Expression): rhs_type = rhs.output_field.get_internal_type() except 
(AttributeError, FieldError): rhs_type = None - if 'DurationField' in {lhs_type, rhs_type} and lhs_type != rhs_type: - return DurationExpression(self.lhs, self.connector, self.rhs).resolve_expression( - query, allow_joins, reuse, summarize, for_save, + if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type: + return DurationExpression( + self.lhs, self.connector, self.rhs + ).resolve_expression( + query, + allow_joins, + reuse, + summarize, + for_save, ) - datetime_fields = {'DateField', 'DateTimeField', 'TimeField'} - if self.connector == self.SUB and lhs_type in datetime_fields and lhs_type == rhs_type: + datetime_fields = {"DateField", "DateTimeField", "TimeField"} + if ( + self.connector == self.SUB + and lhs_type in datetime_fields + and lhs_type == rhs_type + ): return TemporalSubtraction(self.lhs, self.rhs).resolve_expression( - query, allow_joins, reuse, summarize, for_save, + query, + allow_joins, + reuse, + summarize, + for_save, ) c = self.copy() c.is_summary = summarize @@ -524,7 +576,7 @@ class DurationExpression(CombinedExpression): except FieldError: pass else: - if output.get_internal_type() == 'DurationField': + if output.get_internal_type() == "DurationField": sql, params = compiler.compile(side) return connection.ops.format_for_duration_arithmetic(sql), params return compiler.compile(side) @@ -542,7 +594,7 @@ class DurationExpression(CombinedExpression): expressions.append(sql) expression_params.extend(params) # order of precedence - expression_wrapper = '(%s)' + expression_wrapper = "(%s)" sql = connection.ops.combine_duration_expression(self.connector, expressions) return expression_wrapper % sql, expression_params @@ -556,11 +608,14 @@ class DurationExpression(CombinedExpression): pass else: allowed_fields = { - 'DecimalField', 'DurationField', 'FloatField', 'IntegerField', + "DecimalField", + "DurationField", + "FloatField", + "IntegerField", } if lhs_type not in allowed_fields or rhs_type not in allowed_fields: raise 
DatabaseError( - f'Invalid arguments for operator {self.connector}.' + f"Invalid arguments for operator {self.connector}." ) return sql, params @@ -575,10 +630,12 @@ class TemporalSubtraction(CombinedExpression): connection.ops.check_expression_support(self) lhs = compiler.compile(self.lhs) rhs = compiler.compile(self.rhs) - return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs) + return connection.ops.subtract_temporals( + self.lhs.output_field.get_internal_type(), lhs, rhs + ) -@deconstructible(path='django.db.models.F') +@deconstructible(path="django.db.models.F") class F(Combinable): """An object capable of resolving references to existing query objects.""" @@ -592,8 +649,9 @@ class F(Combinable): def __repr__(self): return "{}({})".format(self.__class__.__name__, self.name) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, - summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): return query.resolve_ref(self.name, allow_joins, reuse, summarize) def asc(self, **kwargs): @@ -616,12 +674,13 @@ class ResolvedOuterRef(F): In this case, the reference to the outer query has been resolved because the inner query has been used as a subquery. """ + contains_aggregate = False def as_sql(self, *args, **kwargs): raise ValueError( - 'This queryset contains a reference to an outer query and may ' - 'only be used in a subquery.' + "This queryset contains a reference to an outer query and may " + "only be used in a subquery." 
) def resolve_expression(self, *args, **kwargs): @@ -651,18 +710,20 @@ class OuterRef(F): return self -@deconstructible(path='django.db.models.Func') +@deconstructible(path="django.db.models.Func") class Func(SQLiteNumericMixin, Expression): """An SQL function call.""" + function = None - template = '%(function)s(%(expressions)s)' - arg_joiner = ', ' + template = "%(function)s(%(expressions)s)" + arg_joiner = ", " arity = None # The number of arguments the function accepts. def __init__(self, *expressions, output_field=None, **extra): if self.arity is not None and len(expressions) != self.arity: raise TypeError( - "'%s' takes exactly %s %s (%s given)" % ( + "'%s' takes exactly %s %s (%s given)" + % ( self.__class__.__name__, self.arity, "argument" if self.arity == 1 else "arguments", @@ -677,7 +738,9 @@ class Func(SQLiteNumericMixin, Expression): args = self.arg_joiner.join(str(arg) for arg in self.source_expressions) extra = {**self.extra, **self._get_repr_options()} if extra: - extra = ', '.join(str(key) + '=' + str(val) for key, val in sorted(extra.items())) + extra = ", ".join( + str(key) + "=" + str(val) for key, val in sorted(extra.items()) + ) return "{}({}, {})".format(self.__class__.__name__, args, extra) return "{}({})".format(self.__class__.__name__, args) @@ -691,14 +754,26 @@ class Func(SQLiteNumericMixin, Expression): def set_source_expressions(self, exprs): self.source_expressions = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = self.copy() c.is_summary = summarize for pos, arg in enumerate(c.source_expressions): - c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.source_expressions[pos] = arg.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return c - def as_sql(self, compiler, connection, 
function=None, template=None, arg_joiner=None, **extra_context): + def as_sql( + self, + compiler, + connection, + function=None, + template=None, + arg_joiner=None, + **extra_context, + ): connection.ops.check_expression_support(self) sql_parts = [] params = [] @@ -706,7 +781,9 @@ class Func(SQLiteNumericMixin, Expression): try: arg_sql, arg_params = compiler.compile(arg) except EmptyResultSet: - empty_result_set_value = getattr(arg, 'empty_result_set_value', NotImplemented) + empty_result_set_value = getattr( + arg, "empty_result_set_value", NotImplemented + ) if empty_result_set_value is NotImplemented: raise arg_sql, arg_params = compiler.compile(Value(empty_result_set_value)) @@ -717,12 +794,12 @@ class Func(SQLiteNumericMixin, Expression): # method, a value supplied in __init__()'s **extra (the value in # `data`), or the value defined on the class. if function is not None: - data['function'] = function + data["function"] = function else: - data.setdefault('function', self.function) - template = template or data.get('template', self.template) - arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner) - data['expressions'] = data['field'] = arg_joiner.join(sql_parts) + data.setdefault("function", self.function) + template = template or data.get("template", self.template) + arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner) + data["expressions"] = data["field"] = arg_joiner.join(sql_parts) return template % data, params def copy(self): @@ -732,9 +809,10 @@ class Func(SQLiteNumericMixin, Expression): return copy -@deconstructible(path='django.db.models.Value') +@deconstructible(path="django.db.models.Value") class Value(SQLiteNumericMixin, Expression): """Represent a wrapped value as a node within an expression.""" + # Provide a default value for `for_save` in order to allow unresolved # instances to be compiled until a decision is taken in #25425. 
for_save = False @@ -752,7 +830,7 @@ class Value(SQLiteNumericMixin, Expression): self.value = value def __repr__(self): - return f'{self.__class__.__name__}({self.value!r})' + return f"{self.__class__.__name__}({self.value!r})" def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) @@ -763,16 +841,18 @@ class Value(SQLiteNumericMixin, Expression): val = output_field.get_db_prep_save(val, connection=connection) else: val = output_field.get_db_prep_value(val, connection=connection) - if hasattr(output_field, 'get_placeholder'): + if hasattr(output_field, "get_placeholder"): return output_field.get_placeholder(val, compiler, connection), [val] if val is None: # cx_Oracle does not always convert None to the appropriate # NULL type (like in case expressions using numbers), so we # use a literal SQL NULL - return 'NULL', [] - return '%s', [val] + return "NULL", [] + return "%s", [val] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) c.for_save = for_save return c @@ -820,12 +900,14 @@ class RawSQL(Expression): return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params) def as_sql(self, compiler, connection): - return '(%s)' % self.sql, self.params + return "(%s)" % self.sql, self.params def get_group_by_cols(self, alias=None): return [self] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): # Resolve parents fields used in raw SQL. 
for parent in query.model._meta.get_parent_list(): for parent_field in parent._meta.local_fields: @@ -833,7 +915,9 @@ class RawSQL(Expression): if column_name.lower() in self.sql.lower(): query.resolve_ref(parent_field.name, allow_joins, reuse, summarize) break - return super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + return super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) class Star(Expression): @@ -841,7 +925,7 @@ class Star(Expression): return "'*'" def as_sql(self, compiler, connection): - return '*', [] + return "*", [] class Col(Expression): @@ -858,18 +942,20 @@ class Col(Expression): def __repr__(self): alias, target = self.alias, self.target identifiers = (alias, str(target)) if alias else (str(target),) - return '{}({})'.format(self.__class__.__name__, ', '.join(identifiers)) + return "{}({})".format(self.__class__.__name__, ", ".join(identifiers)) def as_sql(self, compiler, connection): alias, column = self.alias, self.target.column identifiers = (alias, column) if alias else (column,) - sql = '.'.join(map(compiler.quote_name_unless_alias, identifiers)) + sql = ".".join(map(compiler.quote_name_unless_alias, identifiers)) return sql, [] def relabeled_clone(self, relabels): if self.alias is None: return self - return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field) + return self.__class__( + relabels.get(self.alias, self.alias), self.target, self.output_field + ) def get_group_by_cols(self, alias=None): return [self] @@ -877,8 +963,9 @@ class Col(Expression): def get_db_converters(self, connection): if self.target == self.output_field: return self.output_field.get_db_converters(connection) - return (self.output_field.get_db_converters(connection) + - self.target.get_db_converters(connection)) + return self.output_field.get_db_converters( + connection + ) + self.target.get_db_converters(connection) class Ref(Expression): @@ -886,6 +973,7 @@ class Ref(Expression): 
Reference to column alias of the query. For example, Ref('sum_cost') in qs.annotate(sum_cost=Sum('cost')) query. """ + def __init__(self, refs, source): super().__init__() self.refs, self.source = refs, source @@ -897,9 +985,11 @@ class Ref(Expression): return [self.source] def set_source_expressions(self, exprs): - self.source, = exprs + (self.source,) = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): # The sub-expression `source` has already been resolved, as this is # just a reference to the name of `source`. return self @@ -920,11 +1010,14 @@ class ExpressionList(Func): list of expressions as an argument to another expression, like a partition clause. """ - template = '%(expressions)s' + + template = "%(expressions)s" def __init__(self, *expressions, **extra): if not expressions: - raise ValueError('%s requires at least one expression.' % self.__class__.__name__) + raise ValueError( + "%s requires at least one expression." 
% self.__class__.__name__ + ) super().__init__(*expressions, **extra) def __str__(self): @@ -936,13 +1029,13 @@ class ExpressionList(Func): class OrderByList(Func): - template = 'ORDER BY %(expressions)s' + template = "ORDER BY %(expressions)s" def __init__(self, *expressions, **extra): expressions = ( ( OrderBy(F(expr[1:]), descending=True) - if isinstance(expr, str) and expr[0] == '-' + if isinstance(expr, str) and expr[0] == "-" else expr ) for expr in expressions @@ -951,11 +1044,11 @@ class OrderByList(Func): def as_sql(self, *args, **kwargs): if not self.source_expressions: - return '', () + return "", () return super().as_sql(*args, **kwargs) -@deconstructible(path='django.db.models.ExpressionWrapper') +@deconstructible(path="django.db.models.ExpressionWrapper") class ExpressionWrapper(SQLiteNumericMixin, Expression): """ An expression that can wrap another expression so that it can provide @@ -988,9 +1081,9 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression): return "{}({})".format(self.__class__.__name__, self.expression) -@deconstructible(path='django.db.models.When') +@deconstructible(path="django.db.models.When") class When(Expression): - template = 'WHEN %(condition)s THEN %(result)s' + template = "WHEN %(condition)s THEN %(result)s" # This isn't a complete conditional expression, must be used in Case(). conditional = False @@ -998,12 +1091,12 @@ class When(Expression): if lookups: if condition is None: condition, lookups = Q(**lookups), None - elif getattr(condition, 'conditional', False): + elif getattr(condition, "conditional", False): condition, lookups = Q(condition, **lookups), None - if condition is None or not getattr(condition, 'conditional', False) or lookups: + if condition is None or not getattr(condition, "conditional", False) or lookups: raise TypeError( - 'When() supports a Q object, a boolean expression, or lookups ' - 'as a condition.' + "When() supports a Q object, a boolean expression, or lookups " + "as a condition." 
) if isinstance(condition, Q) and not condition: raise ValueError("An empty Q() can't be used as a When() condition.") @@ -1027,12 +1120,18 @@ class When(Expression): # We're only interested in the fields of the result expressions. return [self.result._output_field_or_none] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = self.copy() c.is_summary = summarize - if hasattr(c.condition, 'resolve_expression'): - c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False) - c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save) + if hasattr(c.condition, "resolve_expression"): + c.condition = c.condition.resolve_expression( + query, allow_joins, reuse, summarize, False + ) + c.result = c.result.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return c def as_sql(self, compiler, connection, template=None, **extra_context): @@ -1040,10 +1139,10 @@ class When(Expression): template_params = extra_context sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) - template_params['condition'] = condition_sql + template_params["condition"] = condition_sql sql_params.extend(condition_params) result_sql, result_params = compiler.compile(self.result) - template_params['result'] = result_sql + template_params["result"] = result_sql sql_params.extend(result_params) template = template or self.template return template % template_params, sql_params @@ -1056,7 +1155,7 @@ class When(Expression): return cols -@deconstructible(path='django.db.models.Case') +@deconstructible(path="django.db.models.Case") class Case(SQLiteNumericMixin, Expression): """ An SQL searched CASE expression: @@ -1069,8 +1168,9 @@ class Case(SQLiteNumericMixin, Expression): ELSE 'zero' END """ - template = 'CASE %(cases)s ELSE 
%(default)s END' - case_joiner = ' ' + + template = "CASE %(cases)s ELSE %(default)s END" + case_joiner = " " def __init__(self, *cases, default=None, output_field=None, **extra): if not all(isinstance(case, When) for case in cases): @@ -1081,7 +1181,10 @@ class Case(SQLiteNumericMixin, Expression): self.extra = extra def __str__(self): - return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) + return "CASE %s, ELSE %r" % ( + ", ".join(str(c) for c in self.cases), + self.default, + ) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self) @@ -1092,12 +1195,18 @@ class Case(SQLiteNumericMixin, Expression): def set_source_expressions(self, exprs): *self.cases, self.default = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = self.copy() c.is_summary = summarize for pos, case in enumerate(c.cases): - c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save) - c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.cases[pos] = case.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + c.default = c.default.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return c def copy(self): @@ -1105,7 +1214,9 @@ class Case(SQLiteNumericMixin, Expression): c.cases = c.cases[:] return c - def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context): + def as_sql( + self, compiler, connection, template=None, case_joiner=None, **extra_context + ): connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) @@ -1123,10 +1234,10 @@ class Case(SQLiteNumericMixin, Expression): if not case_parts: return default_sql, default_params case_joiner = case_joiner or self.case_joiner - 
template_params['cases'] = case_joiner.join(case_parts) - template_params['default'] = default_sql + template_params["cases"] = case_joiner.join(case_parts) + template_params["default"] = default_sql sql_params.extend(default_params) - template = template or template_params.get('template', self.template) + template = template or template_params.get("template", self.template) sql = template % template_params if self._output_field_or_none is not None: sql = connection.ops.unification_cast_sql(self.output_field) % sql @@ -1143,13 +1254,14 @@ class Subquery(BaseExpression, Combinable): An explicit subquery. It may contain OuterRef() references to the outer query which will be resolved when it is applied to that query. """ - template = '(%(subquery)s)' + + template = "(%(subquery)s)" contains_aggregate = False empty_result_set_value = None def __init__(self, queryset, output_field=None, **extra): # Allow the usage of both QuerySet and sql.Query objects. - self.query = getattr(queryset, 'query', queryset).clone() + self.query = getattr(queryset, "query", queryset).clone() self.query.subquery = True self.extra = extra super().__init__(output_field) @@ -1180,9 +1292,9 @@ class Subquery(BaseExpression, Combinable): template_params = {**self.extra, **extra_context} query = query or self.query subquery_sql, sql_params = query.as_sql(compiler, connection) - template_params['subquery'] = subquery_sql[1:-1] + template_params["subquery"] = subquery_sql[1:-1] - template = template or template_params.get('template', self.template) + template = template or template_params.get("template", self.template) sql = template % template_params return sql, sql_params @@ -1197,7 +1309,7 @@ class Subquery(BaseExpression, Combinable): class Exists(Subquery): - template = 'EXISTS(%(subquery)s)' + template = "EXISTS(%(subquery)s)" output_field = fields.BooleanField() def __init__(self, queryset, negated=False, **kwargs): @@ -1227,7 +1339,7 @@ class Exists(Subquery): return 
compiler.compile(Value(True)) raise if self.negated: - sql = 'NOT {}'.format(sql) + sql = "NOT {}".format(sql) return sql, params def select_format(self, compiler, sql, params): @@ -1235,28 +1347,31 @@ class Exists(Subquery): # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP # BY list. if not compiler.connection.features.supports_boolean_expr_in_select_clause: - sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql) + sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql) return sql, params -@deconstructible(path='django.db.models.OrderBy') +@deconstructible(path="django.db.models.OrderBy") class OrderBy(Expression): - template = '%(expression)s %(ordering)s' + template = "%(expression)s %(ordering)s" conditional = False - def __init__(self, expression, descending=False, nulls_first=False, nulls_last=False): + def __init__( + self, expression, descending=False, nulls_first=False, nulls_last=False + ): if nulls_first and nulls_last: - raise ValueError('nulls_first and nulls_last are mutually exclusive') + raise ValueError("nulls_first and nulls_last are mutually exclusive") self.nulls_first = nulls_first self.nulls_last = nulls_last self.descending = descending - if not hasattr(expression, 'resolve_expression'): - raise ValueError('expression must be an expression type') + if not hasattr(expression, "resolve_expression"): + raise ValueError("expression must be an expression type") self.expression = expression def __repr__(self): return "{}({}, descending={})".format( - self.__class__.__name__, self.expression, self.descending) + self.__class__.__name__, self.expression, self.descending + ) def set_source_expressions(self, exprs): self.expression = exprs[0] @@ -1268,32 +1383,34 @@ class OrderBy(Expression): template = template or self.template if connection.features.supports_order_by_nulls_modifier: if self.nulls_last: - template = '%s NULLS LAST' % template + template = "%s NULLS LAST" % template elif self.nulls_first: - template = '%s NULLS FIRST' % 
template + template = "%s NULLS FIRST" % template else: if self.nulls_last and not ( self.descending and connection.features.order_by_nulls_first ): - template = '%%(expression)s IS NULL, %s' % template + template = "%%(expression)s IS NULL, %s" % template elif self.nulls_first and not ( not self.descending and connection.features.order_by_nulls_first ): - template = '%%(expression)s IS NOT NULL, %s' % template + template = "%%(expression)s IS NOT NULL, %s" % template connection.ops.check_expression_support(self) expression_sql, params = compiler.compile(self.expression) placeholders = { - 'expression': expression_sql, - 'ordering': 'DESC' if self.descending else 'ASC', + "expression": expression_sql, + "ordering": "DESC" if self.descending else "ASC", **extra_context, } - params *= template.count('%(expression)s') + params *= template.count("%(expression)s") return (template % placeholders).rstrip(), params def as_oracle(self, compiler, connection): # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped # in a CASE WHEN. - if connection.ops.conditional_expression_supported_in_where_clause(self.expression): + if connection.ops.conditional_expression_supported_in_where_clause( + self.expression + ): copy = self.copy() copy.expression = Case( When(self.expression, then=True), @@ -1323,7 +1440,7 @@ class OrderBy(Expression): class Window(SQLiteNumericMixin, Expression): - template = '%(expression)s OVER (%(window)s)' + template = "%(expression)s OVER (%(window)s)" # Although the main expression may either be an aggregate or an # expression with an aggregate function, the GROUP BY that will # be introduced in the query as a result is not desired. 
@@ -1331,15 +1448,22 @@ class Window(SQLiteNumericMixin, Expression): contains_over_clause = True filterable = False - def __init__(self, expression, partition_by=None, order_by=None, frame=None, output_field=None): + def __init__( + self, + expression, + partition_by=None, + order_by=None, + frame=None, + output_field=None, + ): self.partition_by = partition_by self.order_by = order_by self.frame = frame - if not getattr(expression, 'window_compatible', False): + if not getattr(expression, "window_compatible", False): raise ValueError( - "Expression '%s' isn't compatible with OVER clauses." % - expression.__class__.__name__ + "Expression '%s' isn't compatible with OVER clauses." + % expression.__class__.__name__ ) if self.partition_by is not None: @@ -1354,8 +1478,8 @@ class Window(SQLiteNumericMixin, Expression): self.order_by = OrderByList(self.order_by) else: raise ValueError( - 'Window.order_by must be either a string reference to a ' - 'field, an expression, or a list or tuple of them.' + "Window.order_by must be either a string reference to a " + "field, an expression, or a list or tuple of them." 
) super().__init__(output_field=output_field) self.source_expression = self._parse_expressions(expression)[0] @@ -1372,14 +1496,15 @@ class Window(SQLiteNumericMixin, Expression): def as_sql(self, compiler, connection, template=None): connection.ops.check_expression_support(self) if not connection.features.supports_over_clause: - raise NotSupportedError('This backend does not support window expressions.') + raise NotSupportedError("This backend does not support window expressions.") expr_sql, params = compiler.compile(self.source_expression) window_sql, window_params = [], [] if self.partition_by is not None: sql_expr, sql_params = self.partition_by.as_sql( - compiler=compiler, connection=connection, - template='PARTITION BY %(expressions)s', + compiler=compiler, + connection=connection, + template="PARTITION BY %(expressions)s", ) window_sql.append(sql_expr) window_params.extend(sql_params) @@ -1397,10 +1522,10 @@ class Window(SQLiteNumericMixin, Expression): params.extend(window_params) template = template or self.template - return template % { - 'expression': expr_sql, - 'window': ' '.join(window_sql).strip() - }, params + return ( + template % {"expression": expr_sql, "window": " ".join(window_sql).strip()}, + params, + ) def as_sqlite(self, compiler, connection): if isinstance(self.output_field, fields.DecimalField): @@ -1413,15 +1538,15 @@ class Window(SQLiteNumericMixin, Expression): return self.as_sql(compiler, connection) def __str__(self): - return '{} OVER ({}{}{})'.format( + return "{} OVER ({}{}{})".format( str(self.source_expression), - 'PARTITION BY ' + str(self.partition_by) if self.partition_by else '', - str(self.order_by or ''), - str(self.frame or ''), + "PARTITION BY " + str(self.partition_by) if self.partition_by else "", + str(self.order_by or ""), + str(self.frame or ""), ) def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self) + return "<%s: %s>" % (self.__class__.__name__, self) def get_group_by_cols(self, alias=None): 
return [] @@ -1435,7 +1560,8 @@ class WindowFrame(Expression): frame is optional (the default is UNBOUNDED FOLLOWING, which is the last row in the frame). """ - template = '%(frame_type)s BETWEEN %(start)s AND %(end)s' + + template = "%(frame_type)s BETWEEN %(start)s AND %(end)s" def __init__(self, start=None, end=None): self.start = Value(start) @@ -1449,52 +1575,58 @@ class WindowFrame(Expression): def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) - start, end = self.window_frame_start_end(connection, self.start.value, self.end.value) - return self.template % { - 'frame_type': self.frame_type, - 'start': start, - 'end': end, - }, [] + start, end = self.window_frame_start_end( + connection, self.start.value, self.end.value + ) + return ( + self.template + % { + "frame_type": self.frame_type, + "start": start, + "end": end, + }, + [], + ) def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self) + return "<%s: %s>" % (self.__class__.__name__, self) def get_group_by_cols(self, alias=None): return [] def __str__(self): if self.start.value is not None and self.start.value < 0: - start = '%d %s' % (abs(self.start.value), connection.ops.PRECEDING) + start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING) elif self.start.value is not None and self.start.value == 0: start = connection.ops.CURRENT_ROW else: start = connection.ops.UNBOUNDED_PRECEDING if self.end.value is not None and self.end.value > 0: - end = '%d %s' % (self.end.value, connection.ops.FOLLOWING) + end = "%d %s" % (self.end.value, connection.ops.FOLLOWING) elif self.end.value is not None and self.end.value == 0: end = connection.ops.CURRENT_ROW else: end = connection.ops.UNBOUNDED_FOLLOWING return self.template % { - 'frame_type': self.frame_type, - 'start': start, - 'end': end, + "frame_type": self.frame_type, + "start": start, + "end": end, } def window_frame_start_end(self, connection, start, end): - raise NotImplementedError('Subclasses 
must implement window_frame_start_end().') + raise NotImplementedError("Subclasses must implement window_frame_start_end().") class RowRange(WindowFrame): - frame_type = 'ROWS' + frame_type = "ROWS" def window_frame_start_end(self, connection, start, end): return connection.ops.window_frame_rows_start_end(start, end) class ValueRange(WindowFrame): - frame_type = 'RANGE' + frame_type = "RANGE" def window_frame_start_end(self, connection, start, end): return connection.ops.window_frame_range_start_end(start, end) diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 6d6d10a483..313e31b5f5 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -19,7 +19,10 @@ from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin from django.utils import timezone from django.utils.datastructures import DictWrapper from django.utils.dateparse import ( - parse_date, parse_datetime, parse_duration, parse_time, + parse_date, + parse_datetime, + parse_duration, + parse_time, ) from django.utils.duration import duration_microseconds, duration_string from django.utils.functional import Promise, cached_property @@ -29,14 +32,38 @@ from django.utils.text import capfirst from django.utils.translation import gettext_lazy as _ __all__ = [ - 'AutoField', 'BLANK_CHOICE_DASH', 'BigAutoField', 'BigIntegerField', - 'BinaryField', 'BooleanField', 'CharField', 'CommaSeparatedIntegerField', - 'DateField', 'DateTimeField', 'DecimalField', 'DurationField', - 'EmailField', 'Empty', 'Field', 'FilePathField', 'FloatField', - 'GenericIPAddressField', 'IPAddressField', 'IntegerField', 'NOT_PROVIDED', - 'NullBooleanField', 'PositiveBigIntegerField', 'PositiveIntegerField', - 'PositiveSmallIntegerField', 'SlugField', 'SmallAutoField', - 'SmallIntegerField', 'TextField', 'TimeField', 'URLField', 'UUIDField', + "AutoField", + "BLANK_CHOICE_DASH", + "BigAutoField", + "BigIntegerField", + "BinaryField", + 
"BooleanField", + "CharField", + "CommaSeparatedIntegerField", + "DateField", + "DateTimeField", + "DecimalField", + "DurationField", + "EmailField", + "Empty", + "Field", + "FilePathField", + "FloatField", + "GenericIPAddressField", + "IPAddressField", + "IntegerField", + "NOT_PROVIDED", + "NullBooleanField", + "PositiveBigIntegerField", + "PositiveIntegerField", + "PositiveSmallIntegerField", + "SlugField", + "SmallAutoField", + "SmallIntegerField", + "TextField", + "TimeField", + "URLField", + "UUIDField", ] @@ -72,6 +99,7 @@ def _load_field(app_label, model_name, field_name): # # getattr(obj, opts.pk.attname) + def _empty(of_cls): new = Empty() new.__class__ = of_cls @@ -98,14 +126,16 @@ class Field(RegisterLookupMixin): auto_creation_counter = -1 default_validators = [] # Default set of validators default_error_messages = { - 'invalid_choice': _('Value %(value)r is not a valid choice.'), - 'null': _('This field cannot be null.'), - 'blank': _('This field cannot be blank.'), - 'unique': _('%(model_name)s with this %(field_label)s already exists.'), + "invalid_choice": _("Value %(value)r is not a valid choice."), + "null": _("This field cannot be null."), + "blank": _("This field cannot be blank."), + "unique": _("%(model_name)s with this %(field_label)s already exists."), # Translators: The 'lookup_type' is one of 'date', 'year' or 'month'. # Eg: "Title must be unique for pub_date year" - 'unique_for_date': _("%(field_label)s must be unique for " - "%(date_field_label)s %(lookup_type)s."), + "unique_for_date": _( + "%(field_label)s must be unique for " + "%(date_field_label)s %(lookup_type)s." 
+ ), } system_check_deprecated_details = None system_check_removed_details = None @@ -123,18 +153,37 @@ class Field(RegisterLookupMixin): # Generic field type description, usually overridden by subclasses def _description(self): - return _('Field of type: %(field_type)s') % { - 'field_type': self.__class__.__name__ + return _("Field of type: %(field_type)s") % { + "field_type": self.__class__.__name__ } + description = property(_description) - def __init__(self, verbose_name=None, name=None, primary_key=False, - max_length=None, unique=False, blank=False, null=False, - db_index=False, rel=None, default=NOT_PROVIDED, editable=True, - serialize=True, unique_for_date=None, unique_for_month=None, - unique_for_year=None, choices=None, help_text='', db_column=None, - db_tablespace=None, auto_created=False, validators=(), - error_messages=None): + def __init__( + self, + verbose_name=None, + name=None, + primary_key=False, + max_length=None, + unique=False, + blank=False, + null=False, + db_index=False, + rel=None, + default=NOT_PROVIDED, + editable=True, + serialize=True, + unique_for_date=None, + unique_for_month=None, + unique_for_year=None, + choices=None, + help_text="", + db_column=None, + db_tablespace=None, + auto_created=False, + validators=(), + error_messages=None, + ): self.name = name self.verbose_name = verbose_name # May be set by set_attributes_from_name self._verbose_name = verbose_name # Store original for deconstruction @@ -170,7 +219,7 @@ class Field(RegisterLookupMixin): messages = {} for c in reversed(self.__class__.__mro__): - messages.update(getattr(c, 'default_error_messages', {})) + messages.update(getattr(c, "default_error_messages", {})) messages.update(error_messages or {}) self._error_messages = error_messages # Store for deconstruction later self.error_messages = messages @@ -180,18 +229,18 @@ class Field(RegisterLookupMixin): Return "app_label.model_label.field_name" for fields attached to models. 
""" - if not hasattr(self, 'model'): + if not hasattr(self, "model"): return super().__str__() model = self.model - return '%s.%s' % (model._meta.label, self.name) + return "%s.%s" % (model._meta.label, self.name) def __repr__(self): """Display the module, class, and name of the field.""" - path = '%s.%s' % (self.__class__.__module__, self.__class__.__qualname__) - name = getattr(self, 'name', None) + path = "%s.%s" % (self.__class__.__module__, self.__class__.__qualname__) + name = getattr(self, "name", None) if name is not None: - return '<%s: %s>' % (path, name) - return '<%s>' % path + return "<%s: %s>" % (path, name) + return "<%s>" % path def check(self, **kwargs): return [ @@ -209,12 +258,12 @@ class Field(RegisterLookupMixin): Check if field name is valid, i.e. 1) does not end with an underscore, 2) does not contain "__" and 3) is not "pk". """ - if self.name.endswith('_'): + if self.name.endswith("_"): return [ checks.Error( - 'Field names must not end with an underscore.', + "Field names must not end with an underscore.", obj=self, - id='fields.E001', + id="fields.E001", ) ] elif LOOKUP_SEP in self.name: @@ -222,15 +271,15 @@ class Field(RegisterLookupMixin): checks.Error( 'Field names must not contain "%s".' 
% LOOKUP_SEP, obj=self, - id='fields.E002', + id="fields.E002", ) ] - elif self.name == 'pk': + elif self.name == "pk": return [ checks.Error( "'pk' is a reserved word that cannot be used as a field name.", obj=self, - id='fields.E003', + id="fields.E003", ) ] else: @@ -249,7 +298,7 @@ class Field(RegisterLookupMixin): checks.Error( "'choices' must be an iterable (e.g., a list or tuple).", obj=self, - id='fields.E004', + id="fields.E004", ) ] @@ -268,14 +317,22 @@ class Field(RegisterLookupMixin): ): break if self.max_length is not None and group_choices: - choice_max_length = max([ - choice_max_length, - *(len(value) for value, _ in group_choices if isinstance(value, str)), - ]) + choice_max_length = max( + [ + choice_max_length, + *( + len(value) + for value, _ in group_choices + if isinstance(value, str) + ), + ] + ) except (TypeError, ValueError): # No groups, choices in the form [value, display] value, human_name = group_name, group_choices - if not self._choices_is_value(value) or not self._choices_is_value(human_name): + if not self._choices_is_value(value) or not self._choices_is_value( + human_name + ): break if self.max_length is not None and isinstance(value, str): choice_max_length = max(choice_max_length, len(value)) @@ -290,7 +347,7 @@ class Field(RegisterLookupMixin): "'max_length' is too small to fit the longest value " "in 'choices' (%d characters)." 
% choice_max_length, obj=self, - id='fields.E009', + id="fields.E009", ), ] return [] @@ -300,7 +357,7 @@ class Field(RegisterLookupMixin): "'choices' must be an iterable containing " "(actual value, human readable name) tuples.", obj=self, - id='fields.E005', + id="fields.E005", ) ] @@ -310,25 +367,30 @@ class Field(RegisterLookupMixin): checks.Error( "'db_index' must be None, True or False.", obj=self, - id='fields.E006', + id="fields.E006", ) ] else: return [] def _check_null_allowed_for_primary_keys(self): - if (self.primary_key and self.null and - not connection.features.interprets_empty_strings_as_nulls): + if ( + self.primary_key + and self.null + and not connection.features.interprets_empty_strings_as_nulls + ): # We cannot reliably check this for backends like Oracle which # consider NULL and '' to be equal (and thus set up # character-based fields a little differently). return [ checks.Error( - 'Primary keys must not have null=True.', - hint=('Set null=False on the field, or ' - 'remove primary_key=True argument.'), + "Primary keys must not have null=True.", + hint=( + "Set null=False on the field, or " + "remove primary_key=True argument." 
+ ), obj=self, - id='fields.E007', + id="fields.E007", ) ] else: @@ -340,7 +402,9 @@ class Field(RegisterLookupMixin): app_label = self.model._meta.app_label errors = [] for alias in databases: - if router.allow_migrate(alias, app_label, model_name=self.model._meta.model_name): + if router.allow_migrate( + alias, app_label, model_name=self.model._meta.model_name + ): errors.extend(connections[alias].validation.check_field(self, **kwargs)) return errors @@ -354,11 +418,12 @@ class Field(RegisterLookupMixin): hint=( "validators[{i}] ({repr}) isn't a function or " "instance of a validator class.".format( - i=i, repr=repr(validator), + i=i, + repr=repr(validator), ) ), obj=self, - id='fields.E008', + id="fields.E008", ) ) return errors @@ -368,41 +433,41 @@ class Field(RegisterLookupMixin): return [ checks.Error( self.system_check_removed_details.get( - 'msg', - '%s has been removed except for support in historical ' - 'migrations.' % self.__class__.__name__ + "msg", + "%s has been removed except for support in historical " + "migrations." % self.__class__.__name__, ), - hint=self.system_check_removed_details.get('hint'), + hint=self.system_check_removed_details.get("hint"), obj=self, - id=self.system_check_removed_details.get('id', 'fields.EXXX'), + id=self.system_check_removed_details.get("id", "fields.EXXX"), ) ] elif self.system_check_deprecated_details is not None: return [ checks.Warning( self.system_check_deprecated_details.get( - 'msg', - '%s has been deprecated.' % self.__class__.__name__ + "msg", "%s has been deprecated." 
% self.__class__.__name__ ), - hint=self.system_check_deprecated_details.get('hint'), + hint=self.system_check_deprecated_details.get("hint"), obj=self, - id=self.system_check_deprecated_details.get('id', 'fields.WXXX'), + id=self.system_check_deprecated_details.get("id", "fields.WXXX"), ) ] return [] def get_col(self, alias, output_field=None): - if ( - alias == self.model._meta.db_table and - (output_field is None or output_field == self) + if alias == self.model._meta.db_table and ( + output_field is None or output_field == self ): return self.cached_col from django.db.models.expressions import Col + return Col(alias, self, output_field) @cached_property def cached_col(self): from django.db.models.expressions import Col + return Col(self.model._meta.db_table, self) def select_format(self, compiler, sql, params): @@ -462,7 +527,7 @@ class Field(RegisterLookupMixin): "unique_for_month": None, "unique_for_year": None, "choices": None, - "help_text": '', + "help_text": "", "db_column": None, "db_tablespace": None, "auto_created": False, @@ -495,8 +560,8 @@ class Field(RegisterLookupMixin): path = path.replace("django.db.models.fields.related", "django.db.models") elif path.startswith("django.db.models.fields.files"): path = path.replace("django.db.models.fields.files", "django.db.models") - elif path.startswith('django.db.models.fields.json'): - path = path.replace('django.db.models.fields.json', 'django.db.models') + elif path.startswith("django.db.models.fields.json"): + path = path.replace("django.db.models.fields.json", "django.db.models") elif path.startswith("django.db.models.fields.proxy"): path = path.replace("django.db.models.fields.proxy", "django.db.models") elif path.startswith("django.db.models.fields"): @@ -515,10 +580,9 @@ class Field(RegisterLookupMixin): def __eq__(self, other): # Needed for @total_ordering if isinstance(other, Field): - return ( - self.creation_counter == other.creation_counter and - getattr(self, 'model', None) == getattr(other, 
'model', None) - ) + return self.creation_counter == other.creation_counter and getattr( + self, "model", None + ) == getattr(other, "model", None) return NotImplemented def __lt__(self, other): @@ -526,17 +590,18 @@ class Field(RegisterLookupMixin): # Order by creation_counter first for backward compatibility. if isinstance(other, Field): if ( - self.creation_counter != other.creation_counter or - not hasattr(self, 'model') and not hasattr(other, 'model') + self.creation_counter != other.creation_counter + or not hasattr(self, "model") + and not hasattr(other, "model") ): return self.creation_counter < other.creation_counter - elif hasattr(self, 'model') != hasattr(other, 'model'): - return not hasattr(self, 'model') # Order no-model fields first + elif hasattr(self, "model") != hasattr(other, "model"): + return not hasattr(self, "model") # Order no-model fields first else: # creation_counter's are equal, compare only models. - return ( - (self.model._meta.app_label, self.model._meta.model_name) < - (other.model._meta.app_label, other.model._meta.model_name) + return (self.model._meta.app_label, self.model._meta.model_name) < ( + other.model._meta.app_label, + other.model._meta.model_name, ) return NotImplemented @@ -549,7 +614,7 @@ class Field(RegisterLookupMixin): obj = copy.copy(self) if self.remote_field: obj.remote_field = copy.copy(self.remote_field) - if hasattr(self.remote_field, 'field') and self.remote_field.field is self: + if hasattr(self.remote_field, "field") and self.remote_field.field is self: obj.remote_field.field = obj memodict[id(self)] = obj return obj @@ -568,7 +633,7 @@ class Field(RegisterLookupMixin): not a new copy of that field. So, use the app registry to load the model and then the field back. """ - if not hasattr(self, 'model'): + if not hasattr(self, "model"): # Fields are sometimes used without attaching them to models (for # example in aggregation). In this case give back a plain field # instance. 
The code below will create a new empty instance of @@ -577,10 +642,13 @@ class Field(RegisterLookupMixin): state = self.__dict__.copy() # The _get_default cached_property can't be pickled due to lambda # usage. - state.pop('_get_default', None) + state.pop("_get_default", None) return _empty, (self.__class__,), state - return _load_field, (self.model._meta.app_label, self.model._meta.object_name, - self.name) + return _load_field, ( + self.model._meta.app_label, + self.model._meta.object_name, + self.name, + ) def get_pk_value_on_save(self, instance): """ @@ -618,7 +686,7 @@ class Field(RegisterLookupMixin): try: v(value) except exceptions.ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: + if hasattr(e, "code") and e.code in self.error_messages: e.message = self.error_messages[e.code] errors.extend(e.error_list) @@ -645,16 +713,16 @@ class Field(RegisterLookupMixin): elif value == option_key: return raise exceptions.ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': value}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": value}, ) if value is None and not self.null: - raise exceptions.ValidationError(self.error_messages['null'], code='null') + raise exceptions.ValidationError(self.error_messages["null"], code="null") if not self.blank and value in self.empty_values: - raise exceptions.ValidationError(self.error_messages['blank'], code='blank') + raise exceptions.ValidationError(self.error_messages["blank"], code="blank") def clean(self, value, model_instance): """ @@ -668,7 +736,7 @@ class Field(RegisterLookupMixin): return value def db_type_parameters(self, connection): - return DictWrapper(self.__dict__, connection.ops.quote_name, 'qn_') + return DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") def db_check(self, connection): """ @@ -678,7 +746,9 @@ class Field(RegisterLookupMixin): """ data = self.db_type_parameters(connection) try: 
- return connection.data_type_check_constraints[self.get_internal_type()] % data + return ( + connection.data_type_check_constraints[self.get_internal_type()] % data + ) except KeyError: return None @@ -740,7 +810,7 @@ class Field(RegisterLookupMixin): return connection.data_types_suffix.get(self.get_internal_type()) def get_db_converters(self, connection): - if hasattr(self, 'from_db_value'): + if hasattr(self, "from_db_value"): return [self.from_db_value] return [] @@ -765,7 +835,7 @@ class Field(RegisterLookupMixin): self.attname, self.column = self.get_attname_column() self.concrete = self.column is not None if self.verbose_name is None and self.name: - self.verbose_name = self.name.replace('_', ' ') + self.verbose_name = self.name.replace("_", " ") def contribute_to_class(self, cls, name, private_only=False): """ @@ -784,10 +854,10 @@ class Field(RegisterLookupMixin): # this class, but don't check methods derived from inheritance, to # allow overriding inherited choices. For more complex inheritance # structures users should override contribute_to_class(). 
- if 'get_%s_display' % self.name not in cls.__dict__: + if "get_%s_display" % self.name not in cls.__dict__: setattr( cls, - 'get_%s_display' % self.name, + "get_%s_display" % self.name, partialmethod(cls._get_FIELD_display, field=self), ) @@ -848,11 +918,21 @@ class Field(RegisterLookupMixin): return self.default return lambda: self.default - if not self.empty_strings_allowed or self.null and not connection.features.interprets_empty_strings_as_nulls: + if ( + not self.empty_strings_allowed + or self.null + and not connection.features.interprets_empty_strings_as_nulls + ): return return_None return str # return empty string - def get_choices(self, include_blank=True, blank_choice=BLANK_CHOICE_DASH, limit_choices_to=None, ordering=()): + def get_choices( + self, + include_blank=True, + blank_choice=BLANK_CHOICE_DASH, + limit_choices_to=None, + ordering=(), + ): """ Return choices with a default blank choices included, for use as <select> choices for this field. @@ -860,7 +940,9 @@ class Field(RegisterLookupMixin): if self.choices is not None: choices = list(self.choices) if include_blank: - blank_defined = any(choice in ('', None) for choice, _ in self.flatchoices) + blank_defined = any( + choice in ("", None) for choice, _ in self.flatchoices + ) if not blank_defined: choices = blank_choice + choices return choices @@ -868,8 +950,8 @@ class Field(RegisterLookupMixin): limit_choices_to = limit_choices_to or self.get_limit_choices_to() choice_func = operator.attrgetter( self.remote_field.get_related_field().attname - if hasattr(self.remote_field, 'get_related_field') - else 'pk' + if hasattr(self.remote_field, "get_related_field") + else "pk" ) qs = rel_model._default_manager.complex_filter(limit_choices_to) if ordering: @@ -896,6 +978,7 @@ class Field(RegisterLookupMixin): else: flat.append((choice, value)) return flat + flatchoices = property(_get_flatchoices) def save_form_data(self, instance, data): @@ -904,24 +987,25 @@ class Field(RegisterLookupMixin): def 
formfield(self, form_class=None, choices_form_class=None, **kwargs): """Return a django.forms.Field instance for this field.""" defaults = { - 'required': not self.blank, - 'label': capfirst(self.verbose_name), - 'help_text': self.help_text, + "required": not self.blank, + "label": capfirst(self.verbose_name), + "help_text": self.help_text, } if self.has_default(): if callable(self.default): - defaults['initial'] = self.default - defaults['show_hidden_initial'] = True + defaults["initial"] = self.default + defaults["show_hidden_initial"] = True else: - defaults['initial'] = self.get_default() + defaults["initial"] = self.get_default() if self.choices is not None: # Fields with choices get special treatment. - include_blank = (self.blank or - not (self.has_default() or 'initial' in kwargs)) - defaults['choices'] = self.get_choices(include_blank=include_blank) - defaults['coerce'] = self.to_python + include_blank = self.blank or not ( + self.has_default() or "initial" in kwargs + ) + defaults["choices"] = self.get_choices(include_blank=include_blank) + defaults["coerce"] = self.to_python if self.null: - defaults['empty_value'] = None + defaults["empty_value"] = None if choices_form_class is not None: form_class = choices_form_class else: @@ -930,9 +1014,19 @@ class Field(RegisterLookupMixin): # max_value) don't apply for choice fields, so be sure to only pass # the values that TypedChoiceField will understand. 
for k in list(kwargs): - if k not in ('coerce', 'empty_value', 'choices', 'required', - 'widget', 'label', 'initial', 'help_text', - 'error_messages', 'show_hidden_initial', 'disabled'): + if k not in ( + "coerce", + "empty_value", + "choices", + "required", + "widget", + "label", + "initial", + "help_text", + "error_messages", + "show_hidden_initial", + "disabled", + ): del kwargs[k] defaults.update(kwargs) if form_class is None: @@ -947,8 +1041,8 @@ class Field(RegisterLookupMixin): class BooleanField(Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value must be either True or False.'), - 'invalid_nullable': _('“%(value)s” value must be either True, False, or None.'), + "invalid": _("“%(value)s” value must be either True or False."), + "invalid_nullable": _("“%(value)s” value must be either True, False, or None."), } description = _("Boolean (Either True or False)") @@ -961,14 +1055,14 @@ class BooleanField(Field): if value in (True, False): # 1/0 are equal to True/False. bool() converts former to latter. 
return bool(value) - if value in ('t', 'True', '1'): + if value in ("t", "True", "1"): return True - if value in ('f', 'False', '0'): + if value in ("f", "False", "0"): return False raise exceptions.ValidationError( - self.error_messages['invalid_nullable' if self.null else 'invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid_nullable" if self.null else "invalid"], + code="invalid", + params={"value": value}, ) def get_prep_value(self, value): @@ -979,14 +1073,14 @@ class BooleanField(Field): def formfield(self, **kwargs): if self.choices is not None: - include_blank = not (self.has_default() or 'initial' in kwargs) - defaults = {'choices': self.get_choices(include_blank=include_blank)} + include_blank = not (self.has_default() or "initial" in kwargs) + defaults = {"choices": self.get_choices(include_blank=include_blank)} else: form_class = forms.NullBooleanField if self.null else forms.BooleanField # In HTML checkboxes, 'required' means "must be checked" which is # different from the choices case ("must select some value"). # required=False allows unchecked checkboxes. - defaults = {'form_class': form_class, 'required': False} + defaults = {"form_class": form_class, "required": False} return super().formfield(**{**defaults, **kwargs}) def select_format(self, compiler, sql, params): @@ -994,8 +1088,8 @@ class BooleanField(Field): # Filters that match everything are handled as empty strings in the # WHERE clause, but in SELECT or GROUP BY list they must use a # predicate that's always True. 
- if sql == '': - sql = '1' + if sql == "": + sql = "1" return sql, params @@ -1009,7 +1103,7 @@ class CharField(Field): self.validators.append(validators.MaxLengthValidator(self.max_length)) def check(self, **kwargs): - databases = kwargs.get('databases') or [] + databases = kwargs.get("databases") or [] return [ *super().check(**kwargs), *self._check_db_collation(databases), @@ -1022,16 +1116,19 @@ class CharField(Field): checks.Error( "CharFields must define a 'max_length' attribute.", obj=self, - id='fields.E120', + id="fields.E120", ) ] - elif (not isinstance(self.max_length, int) or isinstance(self.max_length, bool) or - self.max_length <= 0): + elif ( + not isinstance(self.max_length, int) + or isinstance(self.max_length, bool) + or self.max_length <= 0 + ): return [ checks.Error( "'max_length' must be a positive integer.", obj=self, - id='fields.E121', + id="fields.E121", ) ] else: @@ -1044,16 +1141,17 @@ class CharField(Field): continue connection = connections[db] if not ( - self.db_collation is None or - 'supports_collation_on_charfield' in self.model._meta.required_db_features or - connection.features.supports_collation_on_charfield + self.db_collation is None + or "supports_collation_on_charfield" + in self.model._meta.required_db_features + or connection.features.supports_collation_on_charfield ): errors.append( checks.Error( - '%s does not support a database collation on ' - 'CharFields.' % connection.display_name, + "%s does not support a database collation on " + "CharFields." % connection.display_name, obj=self, - id='fields.E190', + id="fields.E190", ), ) return errors @@ -1079,17 +1177,17 @@ class CharField(Field): # Passing max_length to forms.CharField means that the value's length # will be validated twice. This is considered acceptable since we want # the value in the form field (to pass into widget for example). 
- defaults = {'max_length': self.max_length} + defaults = {"max_length": self.max_length} # TODO: Handle multiple backends with different feature flags. if self.null and not connection.features.interprets_empty_strings_as_nulls: - defaults['empty_value'] = None + defaults["empty_value"] = None defaults.update(kwargs) return super().formfield(**defaults) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.db_collation: - kwargs['db_collation'] = self.db_collation + kwargs["db_collation"] = self.db_collation return name, path, args, kwargs @@ -1097,15 +1195,15 @@ class CommaSeparatedIntegerField(CharField): default_validators = [validators.validate_comma_separated_integer_list] description = _("Comma-separated integers") system_check_removed_details = { - 'msg': ( - 'CommaSeparatedIntegerField is removed except for support in ' - 'historical migrations.' + "msg": ( + "CommaSeparatedIntegerField is removed except for support in " + "historical migrations." ), - 'hint': ( - 'Use CharField(validators=[validate_comma_separated_integer_list]) ' - 'instead.' + "hint": ( + "Use CharField(validators=[validate_comma_separated_integer_list]) " + "instead." ), - 'id': 'fields.E901', + "id": "fields.E901", } @@ -1120,7 +1218,6 @@ def _get_naive_now(): class DateTimeCheckMixin: - def check(self, **kwargs): return [ *super().check(**kwargs), @@ -1132,8 +1229,14 @@ class DateTimeCheckMixin: # auto_now, auto_now_add, and default are mutually exclusive # options. 
The use of more than one of these options together # will trigger an Error - mutually_exclusive_options = [self.auto_now_add, self.auto_now, self.has_default()] - enabled_options = [option not in (None, False) for option in mutually_exclusive_options].count(True) + mutually_exclusive_options = [ + self.auto_now_add, + self.auto_now, + self.has_default(), + ] + enabled_options = [ + option not in (None, False) for option in mutually_exclusive_options + ].count(True) if enabled_options > 1: return [ checks.Error( @@ -1141,7 +1244,7 @@ class DateTimeCheckMixin: "are mutually exclusive. Only one of these options " "may be present.", obj=self, - id='fields.E160', + id="fields.E160", ) ] else: @@ -1173,15 +1276,15 @@ class DateTimeCheckMixin: if lower <= value <= upper: return [ checks.Warning( - 'Fixed default value provided.', + "Fixed default value provided.", hint=( - 'It seems you set a fixed date / time / datetime ' - 'value as default for this field. This may not be ' - 'what you want. If you want to have the current date ' - 'as default, use `django.utils.timezone.now`' + "It seems you set a fixed date / time / datetime " + "value as default for this field. This may not be " + "what you want. If you want to have the current date " + "as default, use `django.utils.timezone.now`" ), obj=self, - id='fields.W161', + id="fields.W161", ) ] return [] @@ -1190,19 +1293,24 @@ class DateTimeCheckMixin: class DateField(DateTimeCheckMixin, Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value has an invalid date format. It must be ' - 'in YYYY-MM-DD format.'), - 'invalid_date': _('“%(value)s” value has the correct format (YYYY-MM-DD) ' - 'but it is an invalid date.'), + "invalid": _( + "“%(value)s” value has an invalid date format. It must be " + "in YYYY-MM-DD format." + ), + "invalid_date": _( + "“%(value)s” value has the correct format (YYYY-MM-DD) " + "but it is an invalid date." 
+ ), } description = _("Date (without time)") - def __init__(self, verbose_name=None, name=None, auto_now=False, - auto_now_add=False, **kwargs): + def __init__( + self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs + ): self.auto_now, self.auto_now_add = auto_now, auto_now_add if auto_now or auto_now_add: - kwargs['editable'] = False - kwargs['blank'] = True + kwargs["editable"] = False + kwargs["blank"] = True super().__init__(verbose_name, name, **kwargs) def _check_fix_default_value(self): @@ -1227,12 +1335,12 @@ class DateField(DateTimeCheckMixin, Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.auto_now: - kwargs['auto_now'] = True + kwargs["auto_now"] = True if self.auto_now_add: - kwargs['auto_now_add'] = True + kwargs["auto_now_add"] = True if self.auto_now or self.auto_now_add: - del kwargs['editable'] - del kwargs['blank'] + del kwargs["editable"] + del kwargs["blank"] return name, path, args, kwargs def get_internal_type(self): @@ -1257,15 +1365,15 @@ class DateField(DateTimeCheckMixin, Field): return parsed except ValueError: raise exceptions.ValidationError( - self.error_messages['invalid_date'], - code='invalid_date', - params={'value': value}, + self.error_messages["invalid_date"], + code="invalid_date", + params={"value": value}, ) raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def pre_save(self, model_instance, add): @@ -1280,12 +1388,18 @@ class DateField(DateTimeCheckMixin, Field): super().contribute_to_class(cls, name, **kwargs) if not self.null: setattr( - cls, 'get_next_by_%s' % self.name, - partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=True) + cls, + "get_next_by_%s" % self.name, + partialmethod( + cls._get_next_or_previous_by_FIELD, field=self, is_next=True + ), ) setattr( - cls, 'get_previous_by_%s' % 
self.name, - partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=False) + cls, + "get_previous_by_%s" % self.name, + partialmethod( + cls._get_next_or_previous_by_FIELD, field=self, is_next=False + ), ) def get_prep_value(self, value): @@ -1300,25 +1414,33 @@ class DateField(DateTimeCheckMixin, Field): def value_to_string(self, obj): val = self.value_from_object(obj) - return '' if val is None else val.isoformat() + return "" if val is None else val.isoformat() def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.DateField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.DateField, + **kwargs, + } + ) class DateTimeField(DateField): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value has an invalid format. It must be in ' - 'YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format.'), - 'invalid_date': _("“%(value)s” value has the correct format " - "(YYYY-MM-DD) but it is an invalid date."), - 'invalid_datetime': _('“%(value)s” value has the correct format ' - '(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) ' - 'but it is an invalid date/time.'), + "invalid": _( + "“%(value)s” value has an invalid format. It must be in " + "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format." + ), + "invalid_date": _( + "“%(value)s” value has the correct format " + "(YYYY-MM-DD) but it is an invalid date." + ), + "invalid_datetime": _( + "“%(value)s” value has the correct format " + "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " + "but it is an invalid date/time." + ), } description = _("Date (with time)") @@ -1353,10 +1475,12 @@ class DateTimeField(DateField): # local time. This won't work during DST change, but we can't # do much about it, so we let the exceptions percolate up the # call stack. - warnings.warn("DateTimeField %s.%s received a naive datetime " - "(%s) while time zone support is active." 
% - (self.model.__name__, self.name, value), - RuntimeWarning) + warnings.warn( + "DateTimeField %s.%s received a naive datetime " + "(%s) while time zone support is active." + % (self.model.__name__, self.name, value), + RuntimeWarning, + ) default_timezone = timezone.get_default_timezone() value = timezone.make_aware(value, default_timezone) return value @@ -1367,9 +1491,9 @@ class DateTimeField(DateField): return parsed except ValueError: raise exceptions.ValidationError( - self.error_messages['invalid_datetime'], - code='invalid_datetime', - params={'value': value}, + self.error_messages["invalid_datetime"], + code="invalid_datetime", + params={"value": value}, ) try: @@ -1378,15 +1502,15 @@ class DateTimeField(DateField): return datetime.datetime(parsed.year, parsed.month, parsed.day) except ValueError: raise exceptions.ValidationError( - self.error_messages['invalid_date'], - code='invalid_date', - params={'value': value}, + self.error_messages["invalid_date"], + code="invalid_date", + params={"value": value}, ) raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def pre_save(self, model_instance, add): @@ -1408,13 +1532,14 @@ class DateTimeField(DateField): # time. This won't work during DST change, but we can't do much # about it, so we let the exceptions percolate up the call stack. try: - name = '%s.%s' % (self.model.__name__, self.name) + name = "%s.%s" % (self.model.__name__, self.name) except AttributeError: - name = '(unbound)' - warnings.warn("DateTimeField %s received a naive datetime (%s)" - " while time zone support is active." % - (name, value), - RuntimeWarning) + name = "(unbound)" + warnings.warn( + "DateTimeField %s received a naive datetime (%s)" + " while time zone support is active." 
% (name, value), + RuntimeWarning, + ) default_timezone = timezone.get_default_timezone() value = timezone.make_aware(value, default_timezone) return value @@ -1427,24 +1552,32 @@ class DateTimeField(DateField): def value_to_string(self, obj): val = self.value_from_object(obj) - return '' if val is None else val.isoformat() + return "" if val is None else val.isoformat() def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.DateTimeField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.DateTimeField, + **kwargs, + } + ) class DecimalField(Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value must be a decimal number.'), + "invalid": _("“%(value)s” value must be a decimal number."), } description = _("Decimal number") - def __init__(self, verbose_name=None, name=None, max_digits=None, - decimal_places=None, **kwargs): + def __init__( + self, + verbose_name=None, + name=None, + max_digits=None, + decimal_places=None, + **kwargs, + ): self.max_digits, self.decimal_places = max_digits, decimal_places super().__init__(verbose_name, name, **kwargs) @@ -1471,7 +1604,7 @@ class DecimalField(Field): checks.Error( "DecimalFields must define a 'decimal_places' attribute.", obj=self, - id='fields.E130', + id="fields.E130", ) ] except ValueError: @@ -1479,7 +1612,7 @@ class DecimalField(Field): checks.Error( "'decimal_places' must be a non-negative integer.", obj=self, - id='fields.E131', + id="fields.E131", ) ] else: @@ -1495,7 +1628,7 @@ class DecimalField(Field): checks.Error( "DecimalFields must define a 'max_digits' attribute.", obj=self, - id='fields.E132', + id="fields.E132", ) ] except ValueError: @@ -1503,7 +1636,7 @@ class DecimalField(Field): checks.Error( "'max_digits' must be a positive integer.", obj=self, - id='fields.E133', + id="fields.E133", ) ] else: @@ -1515,7 +1648,7 @@ class DecimalField(Field): checks.Error( "'max_digits' must be greater or equal to 
'decimal_places'.", obj=self, - id='fields.E134', + id="fields.E134", ) ] return [] @@ -1533,9 +1666,9 @@ class DecimalField(Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.max_digits is not None: - kwargs['max_digits'] = self.max_digits + kwargs["max_digits"] = self.max_digits if self.decimal_places is not None: - kwargs['decimal_places'] = self.decimal_places + kwargs["decimal_places"] = self.decimal_places return name, path, args, kwargs def get_internal_type(self): @@ -1547,34 +1680,38 @@ class DecimalField(Field): if isinstance(value, float): if math.isnan(value): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) return self.context.create_decimal_from_float(value) try: return decimal.Decimal(value) except (decimal.InvalidOperation, TypeError, ValueError): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def get_db_prep_save(self, value, connection): - return connection.ops.adapt_decimalfield_value(self.to_python(value), self.max_digits, self.decimal_places) + return connection.ops.adapt_decimalfield_value( + self.to_python(value), self.max_digits, self.decimal_places + ) def get_prep_value(self, value): value = super().get_prep_value(value) return self.to_python(value) def formfield(self, **kwargs): - return super().formfield(**{ - 'max_digits': self.max_digits, - 'decimal_places': self.decimal_places, - 'form_class': forms.DecimalField, - **kwargs, - }) + return super().formfield( + **{ + "max_digits": self.max_digits, + "decimal_places": self.decimal_places, + "form_class": forms.DecimalField, + **kwargs, + } + ) class DurationField(Field): @@ -1584,10 +1721,13 @@ class DurationField(Field): Use interval on PostgreSQL, INTERVAL DAY TO 
SECOND on Oracle, and bigint of microseconds on other databases. """ + empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value has an invalid format. It must be in ' - '[DD] [[HH:]MM:]ss[.uuuuuu] format.') + "invalid": _( + "“%(value)s” value has an invalid format. It must be in " + "[DD] [[HH:]MM:]ss[.uuuuuu] format." + ) } description = _("Duration") @@ -1608,9 +1748,9 @@ class DurationField(Field): return parsed raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def get_db_prep_value(self, value, connection, prepared=False): @@ -1628,13 +1768,15 @@ class DurationField(Field): def value_to_string(self, obj): val = self.value_from_object(obj) - return '' if val is None else duration_string(val) + return "" if val is None else duration_string(val) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.DurationField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.DurationField, + **kwargs, + } + ) class EmailField(CharField): @@ -1643,7 +1785,7 @@ class EmailField(CharField): def __init__(self, *args, **kwargs): # max_length=254 to be compliant with RFCs 3696 and 5321 - kwargs.setdefault('max_length', 254) + kwargs.setdefault("max_length", 254) super().__init__(*args, **kwargs) def deconstruct(self): @@ -1655,20 +1797,31 @@ class EmailField(CharField): def formfield(self, **kwargs): # As with CharField, this will cause email validation to be performed # twice. 
- return super().formfield(**{ - 'form_class': forms.EmailField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.EmailField, + **kwargs, + } + ) class FilePathField(Field): description = _("File path") - def __init__(self, verbose_name=None, name=None, path='', match=None, - recursive=False, allow_files=True, allow_folders=False, **kwargs): + def __init__( + self, + verbose_name=None, + name=None, + path="", + match=None, + recursive=False, + allow_files=True, + allow_folders=False, + **kwargs, + ): self.path, self.match, self.recursive = path, match, recursive self.allow_files, self.allow_folders = allow_files, allow_folders - kwargs.setdefault('max_length', 100) + kwargs.setdefault("max_length", 100) super().__init__(verbose_name, name, **kwargs) def check(self, **kwargs): @@ -1683,23 +1836,23 @@ class FilePathField(Field): checks.Error( "FilePathFields must have either 'allow_files' or 'allow_folders' set to True.", obj=self, - id='fields.E140', + id="fields.E140", ) ] return [] def deconstruct(self): name, path, args, kwargs = super().deconstruct() - if self.path != '': - kwargs['path'] = self.path + if self.path != "": + kwargs["path"] = self.path if self.match is not None: - kwargs['match'] = self.match + kwargs["match"] = self.match if self.recursive is not False: - kwargs['recursive'] = self.recursive + kwargs["recursive"] = self.recursive if self.allow_files is not True: - kwargs['allow_files'] = self.allow_files + kwargs["allow_files"] = self.allow_files if self.allow_folders is not False: - kwargs['allow_folders'] = self.allow_folders + kwargs["allow_folders"] = self.allow_folders if kwargs.get("max_length") == 100: del kwargs["max_length"] return name, path, args, kwargs @@ -1711,15 +1864,17 @@ class FilePathField(Field): return str(value) def formfield(self, **kwargs): - return super().formfield(**{ - 'path': self.path() if callable(self.path) else self.path, - 'match': self.match, - 'recursive': self.recursive, - 'form_class': 
forms.FilePathField, - 'allow_files': self.allow_files, - 'allow_folders': self.allow_folders, - **kwargs, - }) + return super().formfield( + **{ + "path": self.path() if callable(self.path) else self.path, + "match": self.match, + "recursive": self.recursive, + "form_class": forms.FilePathField, + "allow_files": self.allow_files, + "allow_folders": self.allow_folders, + **kwargs, + } + ) def get_internal_type(self): return "FilePathField" @@ -1728,7 +1883,7 @@ class FilePathField(Field): class FloatField(Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value must be a float.'), + "invalid": _("“%(value)s” value must be a float."), } description = _("Floating point number") @@ -1753,22 +1908,24 @@ class FloatField(Field): return float(value) except (TypeError, ValueError): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.FloatField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.FloatField, + **kwargs, + } + ) class IntegerField(Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value must be an integer.'), + "invalid": _("“%(value)s” value must be an integer."), } description = _("Integer") @@ -1782,10 +1939,11 @@ class IntegerField(Field): if self.max_length is not None: return [ checks.Warning( - "'max_length' is ignored when used with %s." % self.__class__.__name__, + "'max_length' is ignored when used with %s." 
+ % self.__class__.__name__, hint="Remove 'max_length' from field", obj=self, - id='fields.W122', + id="fields.W122", ) ] return [] @@ -1799,22 +1957,28 @@ class IntegerField(Field): min_value, max_value = connection.ops.integer_field_range(internal_type) if min_value is not None and not any( ( - isinstance(validator, validators.MinValueValidator) and ( + isinstance(validator, validators.MinValueValidator) + and ( validator.limit_value() if callable(validator.limit_value) else validator.limit_value - ) >= min_value - ) for validator in validators_ + ) + >= min_value + ) + for validator in validators_ ): validators_.append(validators.MinValueValidator(min_value)) if max_value is not None and not any( ( - isinstance(validator, validators.MaxValueValidator) and ( + isinstance(validator, validators.MaxValueValidator) + and ( validator.limit_value() if callable(validator.limit_value) else validator.limit_value - ) <= max_value - ) for validator in validators_ + ) + <= max_value + ) + for validator in validators_ ): validators_.append(validators.MaxValueValidator(max_value)) return validators_ @@ -1840,16 +2004,18 @@ class IntegerField(Field): return int(value) except (TypeError, ValueError): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.IntegerField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.IntegerField, + **kwargs, + } + ) class BigIntegerField(IntegerField): @@ -1860,39 +2026,41 @@ class BigIntegerField(IntegerField): return "BigIntegerField" def formfield(self, **kwargs): - return super().formfield(**{ - 'min_value': -BigIntegerField.MAX_BIGINT - 1, - 'max_value': BigIntegerField.MAX_BIGINT, - **kwargs, - }) + return super().formfield( + **{ + "min_value": -BigIntegerField.MAX_BIGINT - 1, + "max_value": 
BigIntegerField.MAX_BIGINT, + **kwargs, + } + ) class SmallIntegerField(IntegerField): - description = _('Small integer') + description = _("Small integer") def get_internal_type(self): - return 'SmallIntegerField' + return "SmallIntegerField" class IPAddressField(Field): empty_strings_allowed = False description = _("IPv4 address") system_check_removed_details = { - 'msg': ( - 'IPAddressField has been removed except for support in ' - 'historical migrations.' + "msg": ( + "IPAddressField has been removed except for support in " + "historical migrations." ), - 'hint': 'Use GenericIPAddressField instead.', - 'id': 'fields.E900', + "hint": "Use GenericIPAddressField instead.", + "id": "fields.E900", } def __init__(self, *args, **kwargs): - kwargs['max_length'] = 15 + kwargs["max_length"] = 15 super().__init__(*args, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs['max_length'] + del kwargs["max_length"] return name, path, args, kwargs def get_prep_value(self, value): @@ -1910,14 +2078,23 @@ class GenericIPAddressField(Field): description = _("IP address") default_error_messages = {} - def __init__(self, verbose_name=None, name=None, protocol='both', - unpack_ipv4=False, *args, **kwargs): + def __init__( + self, + verbose_name=None, + name=None, + protocol="both", + unpack_ipv4=False, + *args, + **kwargs, + ): self.unpack_ipv4 = unpack_ipv4 self.protocol = protocol - self.default_validators, invalid_error_message = \ - validators.ip_address_validators(protocol, unpack_ipv4) - self.default_error_messages['invalid'] = invalid_error_message - kwargs['max_length'] = 39 + ( + self.default_validators, + invalid_error_message, + ) = validators.ip_address_validators(protocol, unpack_ipv4) + self.default_error_messages["invalid"] = invalid_error_message + kwargs["max_length"] = 39 super().__init__(verbose_name, name, *args, **kwargs) def check(self, **kwargs): @@ -1927,13 +2104,13 @@ class GenericIPAddressField(Field): ] def 
_check_blank_and_null_values(self, **kwargs): - if not getattr(self, 'null', False) and getattr(self, 'blank', False): + if not getattr(self, "null", False) and getattr(self, "blank", False): return [ checks.Error( - 'GenericIPAddressFields cannot have blank=True if null=False, ' - 'as blank values are stored as nulls.', + "GenericIPAddressFields cannot have blank=True if null=False, " + "as blank values are stored as nulls.", obj=self, - id='fields.E150', + id="fields.E150", ) ] return [] @@ -1941,11 +2118,11 @@ class GenericIPAddressField(Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.unpack_ipv4 is not False: - kwargs['unpack_ipv4'] = self.unpack_ipv4 + kwargs["unpack_ipv4"] = self.unpack_ipv4 if self.protocol != "both": - kwargs['protocol'] = self.protocol + kwargs["protocol"] = self.protocol if kwargs.get("max_length") == 39: - del kwargs['max_length'] + del kwargs["max_length"] return name, path, args, kwargs def get_internal_type(self): @@ -1957,8 +2134,10 @@ class GenericIPAddressField(Field): if not isinstance(value, str): value = str(value) value = value.strip() - if ':' in value: - return clean_ipv6_address(value, self.unpack_ipv4, self.error_messages['invalid']) + if ":" in value: + return clean_ipv6_address( + value, self.unpack_ipv4, self.error_messages["invalid"] + ) return value def get_db_prep_value(self, value, connection, prepared=False): @@ -1970,7 +2149,7 @@ class GenericIPAddressField(Field): value = super().get_prep_value(value) if value is None: return None - if value and ':' in value: + if value and ":" in value: try: return clean_ipv6_address(value, self.unpack_ipv4) except exceptions.ValidationError: @@ -1978,44 +2157,46 @@ class GenericIPAddressField(Field): return str(value) def formfield(self, **kwargs): - return super().formfield(**{ - 'protocol': self.protocol, - 'form_class': forms.GenericIPAddressField, - **kwargs, - }) + return super().formfield( + **{ + "protocol": self.protocol, + 
"form_class": forms.GenericIPAddressField, + **kwargs, + } + ) class NullBooleanField(BooleanField): default_error_messages = { - 'invalid': _('“%(value)s” value must be either None, True or False.'), - 'invalid_nullable': _('“%(value)s” value must be either None, True or False.'), + "invalid": _("“%(value)s” value must be either None, True or False."), + "invalid_nullable": _("“%(value)s” value must be either None, True or False."), } description = _("Boolean (Either True, False or None)") system_check_removed_details = { - 'msg': ( - 'NullBooleanField is removed except for support in historical ' - 'migrations.' + "msg": ( + "NullBooleanField is removed except for support in historical " + "migrations." ), - 'hint': 'Use BooleanField(null=True) instead.', - 'id': 'fields.E903', + "hint": "Use BooleanField(null=True) instead.", + "id": "fields.E903", } def __init__(self, *args, **kwargs): - kwargs['null'] = True - kwargs['blank'] = True + kwargs["null"] = True + kwargs["blank"] = True super().__init__(*args, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs['null'] - del kwargs['blank'] + del kwargs["null"] + del kwargs["blank"] return name, path, args, kwargs class PositiveIntegerRelDbTypeMixin: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if not hasattr(cls, 'integer_field_class'): + if not hasattr(cls, "integer_field_class"): cls.integer_field_class = next( ( parent @@ -2041,16 +2222,18 @@ class PositiveIntegerRelDbTypeMixin: class PositiveBigIntegerField(PositiveIntegerRelDbTypeMixin, BigIntegerField): - description = _('Positive big integer') + description = _("Positive big integer") def get_internal_type(self): - return 'PositiveBigIntegerField' + return "PositiveBigIntegerField" def formfield(self, **kwargs): - return super().formfield(**{ - 'min_value': 0, - **kwargs, - }) + return super().formfield( + **{ + "min_value": 0, + **kwargs, + } + ) class 
PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): @@ -2060,10 +2243,12 @@ class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): return "PositiveIntegerField" def formfield(self, **kwargs): - return super().formfield(**{ - 'min_value': 0, - **kwargs, - }) + return super().formfield( + **{ + "min_value": 0, + **kwargs, + } + ) class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField): @@ -2073,17 +2258,21 @@ class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, SmallIntegerField return "PositiveSmallIntegerField" def formfield(self, **kwargs): - return super().formfield(**{ - 'min_value': 0, - **kwargs, - }) + return super().formfield( + **{ + "min_value": 0, + **kwargs, + } + ) class SlugField(CharField): default_validators = [validators.validate_slug] description = _("Slug (up to %(max_length)s)") - def __init__(self, *args, max_length=50, db_index=True, allow_unicode=False, **kwargs): + def __init__( + self, *args, max_length=50, db_index=True, allow_unicode=False, **kwargs + ): self.allow_unicode = allow_unicode if self.allow_unicode: self.default_validators = [validators.validate_unicode_slug] @@ -2092,24 +2281,26 @@ class SlugField(CharField): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if kwargs.get("max_length") == 50: - del kwargs['max_length'] + del kwargs["max_length"] if self.db_index is False: - kwargs['db_index'] = False + kwargs["db_index"] = False else: - del kwargs['db_index'] + del kwargs["db_index"] if self.allow_unicode is not False: - kwargs['allow_unicode'] = self.allow_unicode + kwargs["allow_unicode"] = self.allow_unicode return name, path, args, kwargs def get_internal_type(self): return "SlugField" def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.SlugField, - 'allow_unicode': self.allow_unicode, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.SlugField, + "allow_unicode": 
self.allow_unicode, + **kwargs, + } + ) class TextField(Field): @@ -2120,7 +2311,7 @@ class TextField(Field): self.db_collation = db_collation def check(self, **kwargs): - databases = kwargs.get('databases') or [] + databases = kwargs.get("databases") or [] return [ *super().check(**kwargs), *self._check_db_collation(databases), @@ -2133,16 +2324,17 @@ class TextField(Field): continue connection = connections[db] if not ( - self.db_collation is None or - 'supports_collation_on_textfield' in self.model._meta.required_db_features or - connection.features.supports_collation_on_textfield + self.db_collation is None + or "supports_collation_on_textfield" + in self.model._meta.required_db_features + or connection.features.supports_collation_on_textfield ): errors.append( checks.Error( - '%s does not support a database collation on ' - 'TextFields.' % connection.display_name, + "%s does not support a database collation on " + "TextFields." % connection.display_name, obj=self, - id='fields.E190', + id="fields.E190", ), ) return errors @@ -2163,35 +2355,42 @@ class TextField(Field): # Passing max_length to forms.CharField means that the value's length # will be validated twice. This is considered acceptable since we want # the value in the form field (to pass into widget for example). - return super().formfield(**{ - 'max_length': self.max_length, - **({} if self.choices is not None else {'widget': forms.Textarea}), - **kwargs, - }) + return super().formfield( + **{ + "max_length": self.max_length, + **({} if self.choices is not None else {"widget": forms.Textarea}), + **kwargs, + } + ) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.db_collation: - kwargs['db_collation'] = self.db_collation + kwargs["db_collation"] = self.db_collation return name, path, args, kwargs class TimeField(DateTimeCheckMixin, Field): empty_strings_allowed = False default_error_messages = { - 'invalid': _('“%(value)s” value has an invalid format. 
It must be in ' - 'HH:MM[:ss[.uuuuuu]] format.'), - 'invalid_time': _('“%(value)s” value has the correct format ' - '(HH:MM[:ss[.uuuuuu]]) but it is an invalid time.'), + "invalid": _( + "“%(value)s” value has an invalid format. It must be in " + "HH:MM[:ss[.uuuuuu]] format." + ), + "invalid_time": _( + "“%(value)s” value has the correct format " + "(HH:MM[:ss[.uuuuuu]]) but it is an invalid time." + ), } description = _("Time") - def __init__(self, verbose_name=None, name=None, auto_now=False, - auto_now_add=False, **kwargs): + def __init__( + self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs + ): self.auto_now, self.auto_now_add = auto_now, auto_now_add if auto_now or auto_now_add: - kwargs['editable'] = False - kwargs['blank'] = True + kwargs["editable"] = False + kwargs["blank"] = True super().__init__(verbose_name, name, **kwargs) def _check_fix_default_value(self): @@ -2223,8 +2422,8 @@ class TimeField(DateTimeCheckMixin, Field): if self.auto_now_add is not False: kwargs["auto_now_add"] = self.auto_now_add if self.auto_now or self.auto_now_add: - del kwargs['blank'] - del kwargs['editable'] + del kwargs["blank"] + del kwargs["editable"] return name, path, args, kwargs def get_internal_type(self): @@ -2247,15 +2446,15 @@ class TimeField(DateTimeCheckMixin, Field): return parsed except ValueError: raise exceptions.ValidationError( - self.error_messages['invalid_time'], - code='invalid_time', - params={'value': value}, + self.error_messages["invalid_time"], + code="invalid_time", + params={"value": value}, ) raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def pre_save(self, model_instance, add): @@ -2278,13 +2477,15 @@ class TimeField(DateTimeCheckMixin, Field): def value_to_string(self, obj): val = self.value_from_object(obj) - return '' if val is None else val.isoformat() + return "" if 
val is None else val.isoformat() def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.TimeField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.TimeField, + **kwargs, + } + ) class URLField(CharField): @@ -2292,30 +2493,32 @@ class URLField(CharField): description = _("URL") def __init__(self, verbose_name=None, name=None, **kwargs): - kwargs.setdefault('max_length', 200) + kwargs.setdefault("max_length", 200) super().__init__(verbose_name, name, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if kwargs.get("max_length") == 200: - del kwargs['max_length'] + del kwargs["max_length"] return name, path, args, kwargs def formfield(self, **kwargs): # As with CharField, this will cause URL validation to be performed # twice. - return super().formfield(**{ - 'form_class': forms.URLField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.URLField, + **kwargs, + } + ) class BinaryField(Field): description = _("Raw binary data") - empty_values = [None, b''] + empty_values = [None, b""] def __init__(self, *args, **kwargs): - kwargs.setdefault('editable', False) + kwargs.setdefault("editable", False) super().__init__(*args, **kwargs) if self.max_length is not None: self.validators.append(validators.MaxLengthValidator(self.max_length)) @@ -2330,7 +2533,7 @@ class BinaryField(Field): "BinaryField's default cannot be a string. 
Use bytes " "content instead.", obj=self, - id='fields.E170', + id="fields.E170", ) ] return [] @@ -2338,9 +2541,9 @@ class BinaryField(Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.editable: - kwargs['editable'] = True + kwargs["editable"] = True else: - del kwargs['editable'] + del kwargs["editable"] return name, path, args, kwargs def get_internal_type(self): @@ -2353,8 +2556,8 @@ class BinaryField(Field): if self.has_default() and not callable(self.default): return self.default default = super().get_default() - if default == '': - return b'' + if default == "": + return b"" return default def get_db_prep_value(self, value, connection, prepared=False): @@ -2365,29 +2568,29 @@ class BinaryField(Field): def value_to_string(self, obj): """Binary data is serialized as base64""" - return b64encode(self.value_from_object(obj)).decode('ascii') + return b64encode(self.value_from_object(obj)).decode("ascii") def to_python(self, value): # If it's a string, it should be base64-encoded data if isinstance(value, str): - return memoryview(b64decode(value.encode('ascii'))) + return memoryview(b64decode(value.encode("ascii"))) return value class UUIDField(Field): default_error_messages = { - 'invalid': _('“%(value)s” is not a valid UUID.'), + "invalid": _("“%(value)s” is not a valid UUID."), } - description = _('Universally unique identifier') + description = _("Universally unique identifier") empty_strings_allowed = False def __init__(self, verbose_name=None, **kwargs): - kwargs['max_length'] = 32 + kwargs["max_length"] = 32 super().__init__(verbose_name, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs['max_length'] + del kwargs["max_length"] return name, path, args, kwargs def get_internal_type(self): @@ -2409,29 +2612,31 @@ class UUIDField(Field): def to_python(self, value): if value is not None and not isinstance(value, uuid.UUID): - input_form = 'int' if isinstance(value, int) else 
'hex' + input_form = "int" if isinstance(value, int) else "hex" try: return uuid.UUID(**{input_form: value}) except (AttributeError, ValueError): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) return value def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.UUIDField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.UUIDField, + **kwargs, + } + ) class AutoFieldMixin: db_returning = True def __init__(self, *args, **kwargs): - kwargs['blank'] = True + kwargs["blank"] = True super().__init__(*args, **kwargs) def check(self, **kwargs): @@ -2444,9 +2649,9 @@ class AutoFieldMixin: if not self.primary_key: return [ checks.Error( - 'AutoFields must set primary_key=True.', + "AutoFields must set primary_key=True.", obj=self, - id='fields.E100', + id="fields.E100", ), ] else: @@ -2454,8 +2659,8 @@ class AutoFieldMixin: def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs['blank'] - kwargs['primary_key'] = True + del kwargs["blank"] + kwargs["primary_key"] = True return name, path, args, kwargs def validate(self, value, model_instance): @@ -2502,34 +2707,35 @@ class AutoFieldMeta(type): return (BigAutoField, SmallAutoField) def __instancecheck__(self, instance): - return isinstance(instance, self._subclasses) or super().__instancecheck__(instance) + return isinstance(instance, self._subclasses) or super().__instancecheck__( + instance + ) def __subclasscheck__(self, subclass): - return issubclass(subclass, self._subclasses) or super().__subclasscheck__(subclass) + return issubclass(subclass, self._subclasses) or super().__subclasscheck__( + subclass + ) class AutoField(AutoFieldMixin, IntegerField, metaclass=AutoFieldMeta): - def get_internal_type(self): - return 'AutoField' + return "AutoField" def rel_db_type(self, connection): return 
IntegerField().db_type(connection=connection) class BigAutoField(AutoFieldMixin, BigIntegerField): - def get_internal_type(self): - return 'BigAutoField' + return "BigAutoField" def rel_db_type(self, connection): return BigIntegerField().db_type(connection=connection) class SmallAutoField(AutoFieldMixin, SmallIntegerField): - def get_internal_type(self): - return 'SmallAutoField' + return "SmallAutoField" def rel_db_type(self, connection): return SmallIntegerField().db_type(connection=connection) diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 18900f7b85..33a1176ed6 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -24,7 +24,7 @@ class FieldFile(File): def __eq__(self, other): # Older code may be expecting FileField values to be simple strings. # By overriding the == operator, it can remain backwards compatibility. - if hasattr(other, 'name'): + if hasattr(other, "name"): return self.name == other.name return self.name == other @@ -37,12 +37,14 @@ class FieldFile(File): def _require_file(self): if not self: - raise ValueError("The '%s' attribute has no file associated with it." % self.field.name) + raise ValueError( + "The '%s' attribute has no file associated with it." 
% self.field.name + ) def _get_file(self): self._require_file() - if getattr(self, '_file', None) is None: - self._file = self.storage.open(self.name, 'rb') + if getattr(self, "_file", None) is None: + self._file = self.storage.open(self.name, "rb") return self._file def _set_file(self, file): @@ -70,13 +72,14 @@ class FieldFile(File): return self.file.size return self.storage.size(self.name) - def open(self, mode='rb'): + def open(self, mode="rb"): self._require_file() - if getattr(self, '_file', None) is None: + if getattr(self, "_file", None) is None: self.file = self.storage.open(self.name, mode) else: self.file.open(mode) return self + # open() doesn't alter the file's contents, but it does reset the pointer open.alters_data = True @@ -93,6 +96,7 @@ class FieldFile(File): # Save the object because it has changed, unless save is False if save: self.instance.save() + save.alters_data = True def delete(self, save=True): @@ -100,7 +104,7 @@ class FieldFile(File): return # Only close the file if it's already open, which we know by the # presence of self._file - if hasattr(self, '_file'): + if hasattr(self, "_file"): self.close() del self.file @@ -112,15 +116,16 @@ class FieldFile(File): if save: self.instance.save() + delete.alters_data = True @property def closed(self): - file = getattr(self, '_file', None) + file = getattr(self, "_file", None) return file is None or file.closed def close(self): - file = getattr(self, '_file', None) + file = getattr(self, "_file", None) if file is not None: file.close() @@ -129,12 +134,12 @@ class FieldFile(File): # the file's name. Everything else will be restored later, by # FileDescriptor below. 
return { - 'name': self.name, - 'closed': False, - '_committed': True, - '_file': None, - 'instance': self.instance, - 'field': self.field, + "name": self.name, + "closed": False, + "_committed": True, + "_file": None, + "instance": self.instance, + "field": self.field, } def __setstate__(self, state): @@ -156,6 +161,7 @@ class FileDescriptor(DeferredAttribute): >>> with open('/path/to/hello.world') as f: ... instance.file = File(f) """ + def __get__(self, instance, cls=None): if instance is None: return self @@ -198,7 +204,7 @@ class FileDescriptor(DeferredAttribute): # Finally, because of the (some would say boneheaded) way pickle works, # the underlying FieldFile might not actually itself have an associated # file. So we need to reset the details of the FieldFile in those cases. - elif isinstance(file, FieldFile) and not hasattr(file, 'field'): + elif isinstance(file, FieldFile) and not hasattr(file, "field"): file.instance = instance file.field = self.field file.storage = self.field.storage @@ -225,8 +231,10 @@ class FileField(Field): description = _("File") - def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs): - self._primary_key_set_explicitly = 'primary_key' in kwargs + def __init__( + self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs + ): + self._primary_key_set_explicitly = "primary_key" in kwargs self.storage = storage or default_storage if callable(self.storage): @@ -236,11 +244,15 @@ class FileField(Field): if not isinstance(self.storage, Storage): raise TypeError( "%s.storage must be a subclass/instance of %s.%s" - % (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__) + % ( + self.__class__.__qualname__, + Storage.__module__, + Storage.__qualname__, + ) ) self.upload_to = upload_to - kwargs.setdefault('max_length', 100) + kwargs.setdefault("max_length", 100) super().__init__(verbose_name, name, **kwargs) def check(self, **kwargs): @@ -254,23 +266,24 @@ class 
FileField(Field): if self._primary_key_set_explicitly: return [ checks.Error( - "'primary_key' is not a valid argument for a %s." % self.__class__.__name__, + "'primary_key' is not a valid argument for a %s." + % self.__class__.__name__, obj=self, - id='fields.E201', + id="fields.E201", ) ] else: return [] def _check_upload_to(self): - if isinstance(self.upload_to, str) and self.upload_to.startswith('/'): + if isinstance(self.upload_to, str) and self.upload_to.startswith("/"): return [ checks.Error( "%s's 'upload_to' argument must be a relative path, not an " "absolute path." % self.__class__.__name__, obj=self, - id='fields.E202', - hint='Remove the leading slash.', + id="fields.E202", + hint="Remove the leading slash.", ) ] else: @@ -280,9 +293,9 @@ class FileField(Field): name, path, args, kwargs = super().deconstruct() if kwargs.get("max_length") == 100: del kwargs["max_length"] - kwargs['upload_to'] = self.upload_to + kwargs["upload_to"] = self.upload_to if self.storage is not default_storage: - kwargs['storage'] = getattr(self, '_storage_callable', self.storage) + kwargs["storage"] = getattr(self, "_storage_callable", self.storage) return name, path, args, kwargs def get_internal_type(self): @@ -329,14 +342,16 @@ class FileField(Field): if data is not None: # This value will be converted to str and stored in the # database, so leaving False as-is is not acceptable. - setattr(instance, self.name, data or '') + setattr(instance, self.name, data or "") def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.FileField, - 'max_length': self.max_length, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.FileField, + "max_length": self.max_length, + **kwargs, + } + ) class ImageFileDescriptor(FileDescriptor): @@ -344,6 +359,7 @@ class ImageFileDescriptor(FileDescriptor): Just like the FileDescriptor, but for ImageFields. 
The only difference is assigning the width/height to the width_field/height_field, if appropriate. """ + def __set__(self, instance, value): previous_file = instance.__dict__.get(self.field.attname) super().__set__(instance, value) @@ -364,7 +380,7 @@ class ImageFileDescriptor(FileDescriptor): class ImageFieldFile(ImageFile, FieldFile): def delete(self, save=True): # Clear the image dimensions cache - if hasattr(self, '_dimensions_cache'): + if hasattr(self, "_dimensions_cache"): del self._dimensions_cache super().delete(save) @@ -374,7 +390,14 @@ class ImageField(FileField): descriptor_class = ImageFileDescriptor description = _("Image") - def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs): + def __init__( + self, + verbose_name=None, + name=None, + width_field=None, + height_field=None, + **kwargs, + ): self.width_field, self.height_field = width_field, height_field super().__init__(verbose_name, name, **kwargs) @@ -390,11 +413,13 @@ class ImageField(FileField): except ImportError: return [ checks.Error( - 'Cannot use ImageField because Pillow is not installed.', - hint=('Get Pillow at https://pypi.org/project/Pillow/ ' - 'or run command "python -m pip install Pillow".'), + "Cannot use ImageField because Pillow is not installed.", + hint=( + "Get Pillow at https://pypi.org/project/Pillow/ " + 'or run command "python -m pip install Pillow".' 
+ ), obj=self, - id='fields.E210', + id="fields.E210", ) ] else: @@ -403,9 +428,9 @@ class ImageField(FileField): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.width_field: - kwargs['width_field'] = self.width_field + kwargs["width_field"] = self.width_field if self.height_field: - kwargs['height_field'] = self.height_field + kwargs["height_field"] = self.height_field return name, path, args, kwargs def contribute_to_class(self, cls, name, **kwargs): @@ -445,9 +470,9 @@ class ImageField(FileField): if not file and not force: return - dimension_fields_filled = not( - (self.width_field and not getattr(instance, self.width_field)) or - (self.height_field and not getattr(instance, self.height_field)) + dimension_fields_filled = not ( + (self.width_field and not getattr(instance, self.width_field)) + or (self.height_field and not getattr(instance, self.height_field)) ) # When both dimension fields have values, we are most likely loading # data from the database or updating an image field that already had @@ -475,7 +500,9 @@ class ImageField(FileField): setattr(instance, self.height_field, height) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.ImageField, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.ImageField, + **kwargs, + } + ) diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index efb4e2f6ed..fdca700c9d 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -10,32 +10,36 @@ from django.utils.translation import gettext_lazy as _ from . 
import Field from .mixins import CheckFieldDefaultMixin -__all__ = ['JSONField'] +__all__ = ["JSONField"] class JSONField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False - description = _('A JSON object') + description = _("A JSON object") default_error_messages = { - 'invalid': _('Value must be valid JSON.'), + "invalid": _("Value must be valid JSON."), } - _default_hint = ('dict', '{}') + _default_hint = ("dict", "{}") def __init__( - self, verbose_name=None, name=None, encoder=None, decoder=None, + self, + verbose_name=None, + name=None, + encoder=None, + decoder=None, **kwargs, ): if encoder and not callable(encoder): - raise ValueError('The encoder parameter must be a callable object.') + raise ValueError("The encoder parameter must be a callable object.") if decoder and not callable(decoder): - raise ValueError('The decoder parameter must be a callable object.') + raise ValueError("The decoder parameter must be a callable object.") self.encoder = encoder self.decoder = decoder super().__init__(verbose_name, name, **kwargs) def check(self, **kwargs): errors = super().check(**kwargs) - databases = kwargs.get('databases') or [] + databases = kwargs.get("databases") or [] errors.extend(self._check_supported(databases)) return errors @@ -46,20 +50,19 @@ class JSONField(CheckFieldDefaultMixin, Field): continue connection = connections[db] if ( - self.model._meta.required_db_vendor and - self.model._meta.required_db_vendor != connection.vendor + self.model._meta.required_db_vendor + and self.model._meta.required_db_vendor != connection.vendor ): continue if not ( - 'supports_json_field' in self.model._meta.required_db_features or - connection.features.supports_json_field + "supports_json_field" in self.model._meta.required_db_features + or connection.features.supports_json_field ): errors.append( checks.Error( - '%s does not support JSONFields.' - % connection.display_name, + "%s does not support JSONFields." 
% connection.display_name, obj=self.model, - id='fields.E180', + id="fields.E180", ) ) return errors @@ -67,9 +70,9 @@ class JSONField(CheckFieldDefaultMixin, Field): def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self.encoder is not None: - kwargs['encoder'] = self.encoder + kwargs["encoder"] = self.encoder if self.decoder is not None: - kwargs['decoder'] = self.decoder + kwargs["decoder"] = self.decoder return name, path, args, kwargs def from_db_value(self, value, expression, connection): @@ -85,7 +88,7 @@ class JSONField(CheckFieldDefaultMixin, Field): return value def get_internal_type(self): - return 'JSONField' + return "JSONField" def get_prep_value(self, value): if value is None: @@ -104,64 +107,66 @@ class JSONField(CheckFieldDefaultMixin, Field): json.dumps(value, cls=self.encoder) except TypeError: raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def value_to_string(self, obj): return self.value_from_object(obj) def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.JSONField, - 'encoder': self.encoder, - 'decoder': self.decoder, - **kwargs, - }) + return super().formfield( + **{ + "form_class": forms.JSONField, + "encoder": self.encoder, + "decoder": self.decoder, + **kwargs, + } + ) def compile_json_path(key_transforms, include_root=True): - path = ['$'] if include_root else [] + path = ["$"] if include_root else [] for key_transform in key_transforms: try: num = int(key_transform) except ValueError: # non-integer - path.append('.') + path.append(".") path.append(json.dumps(key_transform)) else: - path.append('[%s]' % num) - return ''.join(path) + path.append("[%s]" % num) + return "".join(path) class DataContains(PostgresOperatorLookup): - lookup_name = 'contains' - postgres_operator = '@>' + lookup_name = "contains" + postgres_operator = "@>" def 
as_sql(self, compiler, connection): if not connection.features.supports_json_field_contains: raise NotSupportedError( - 'contains lookup is not supported on this database backend.' + "contains lookup is not supported on this database backend." ) lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) - return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params + return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params class ContainedBy(PostgresOperatorLookup): - lookup_name = 'contained_by' - postgres_operator = '<@' + lookup_name = "contained_by" + postgres_operator = "<@" def as_sql(self, compiler, connection): if not connection.features.supports_json_field_contains: raise NotSupportedError( - 'contained_by lookup is not supported on this database backend.' + "contained_by lookup is not supported on this database backend." ) lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(rhs_params) + tuple(lhs_params) - return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params + return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params class HasKeyLookup(PostgresOperatorLookup): @@ -170,11 +175,13 @@ class HasKeyLookup(PostgresOperatorLookup): def as_sql(self, compiler, connection, template=None): # Process JSON path from the left-hand side. if isinstance(self.lhs, KeyTransform): - lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection) + lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs( + compiler, connection + ) lhs_json_path = compile_json_path(lhs_key_transforms) else: lhs, lhs_params = self.process_lhs(compiler, connection) - lhs_json_path = '$' + lhs_json_path = "$" sql = template % lhs # Process JSON path from the right-hand side. 
rhs = self.rhs @@ -186,20 +193,27 @@ class HasKeyLookup(PostgresOperatorLookup): *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection) else: rhs_key_transforms = [key] - rhs_params.append('%s%s' % ( - lhs_json_path, - compile_json_path(rhs_key_transforms, include_root=False), - )) + rhs_params.append( + "%s%s" + % ( + lhs_json_path, + compile_json_path(rhs_key_transforms, include_root=False), + ) + ) # Add condition for each key. if self.logical_operator: - sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params)) + sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params)) return sql, tuple(lhs_params) + tuple(rhs_params) def as_mysql(self, compiler, connection): - return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)") + return self.as_sql( + compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)" + ) def as_oracle(self, compiler, connection): - sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')") + sql, params = self.as_sql( + compiler, connection, template="JSON_EXISTS(%s, '%%s')" + ) # Add paths directly into SQL because path expressions cannot be passed # as bind variables on Oracle. return sql % tuple(params), [] @@ -213,28 +227,30 @@ class HasKeyLookup(PostgresOperatorLookup): return super().as_postgresql(compiler, connection) def as_sqlite(self, compiler, connection): - return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL') + return self.as_sql( + compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL" + ) class HasKey(HasKeyLookup): - lookup_name = 'has_key' - postgres_operator = '?' + lookup_name = "has_key" + postgres_operator = "?" 
prepare_rhs = False class HasKeys(HasKeyLookup): - lookup_name = 'has_keys' - postgres_operator = '?&' - logical_operator = ' AND ' + lookup_name = "has_keys" + postgres_operator = "?&" + logical_operator = " AND " def get_prep_lookup(self): return [str(item) for item in self.rhs] class HasAnyKeys(HasKeys): - lookup_name = 'has_any_keys' - postgres_operator = '?|' - logical_operator = ' OR ' + lookup_name = "has_any_keys" + postgres_operator = "?|" + logical_operator = " OR " class CaseInsensitiveMixin: @@ -244,16 +260,17 @@ class CaseInsensitiveMixin: Because utf8mb4_bin is a binary collation, comparison of JSON values is case-sensitive. """ + def process_lhs(self, compiler, connection): lhs, lhs_params = super().process_lhs(compiler, connection) - if connection.vendor == 'mysql': - return 'LOWER(%s)' % lhs, lhs_params + if connection.vendor == "mysql": + return "LOWER(%s)" % lhs, lhs_params return lhs, lhs_params def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) - if connection.vendor == 'mysql': - return 'LOWER(%s)' % rhs, rhs_params + if connection.vendor == "mysql": + return "LOWER(%s)" % rhs, rhs_params return rhs, rhs_params @@ -263,9 +280,9 @@ class JSONExact(lookups.Exact): def process_rhs(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) # Treat None lookup values as null. 
- if rhs == '%s' and rhs_params == [None]: - rhs_params = ['null'] - if connection.vendor == 'mysql': + if rhs == "%s" and rhs_params == [None]: + rhs_params = ["null"] + if connection.vendor == "mysql": func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params) rhs = rhs % tuple(func) return rhs, rhs_params @@ -285,8 +302,8 @@ JSONField.register_lookup(JSONIContains) class KeyTransform(Transform): - postgres_operator = '->' - postgres_nested_operator = '#>' + postgres_operator = "->" + postgres_nested_operator = "#>" def __init__(self, key_name, *args, **kwargs): super().__init__(*args, **kwargs) @@ -299,41 +316,41 @@ class KeyTransform(Transform): key_transforms.insert(0, previous.key_name) previous = previous.lhs lhs, params = compiler.compile(previous) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": # Escape string-formatting. - key_transforms = [key.replace('%', '%%') for key in key_transforms] + key_transforms = [key.replace("%", "%%") for key in key_transforms] return lhs, params, key_transforms def as_mysql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) - return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,) def as_oracle(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) return ( - "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % - ((lhs, json_path) * 2) + "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" + % ((lhs, json_path) * 2) ), tuple(params) * 2 def as_postgresql(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) if len(key_transforms) > 1: - sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator) + sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator) return sql, tuple(params) + 
(key_transforms,) try: lookup = int(self.key_name) except ValueError: lookup = self.key_name - return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,) + return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,) def as_sqlite(self, compiler, connection): lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) json_path = compile_json_path(key_transforms) - datatype_values = ','.join([ - repr(datatype) for datatype in connection.ops.jsonfield_datatype_values - ]) + datatype_values = ",".join( + [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values] + ) return ( "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" @@ -341,8 +358,8 @@ class KeyTransform(Transform): class KeyTextTransform(KeyTransform): - postgres_operator = '->>' - postgres_nested_operator = '#>>' + postgres_operator = "->>" + postgres_nested_operator = "#>>" class KeyTransformTextLookupMixin: @@ -352,14 +369,16 @@ class KeyTransformTextLookupMixin: key values to text and performing the lookup on the resulting representation. """ + def __init__(self, key_transform, *args, **kwargs): if not isinstance(key_transform, KeyTransform): raise TypeError( - 'Transform should be an instance of KeyTransform in order to ' - 'use this lookup.' + "Transform should be an instance of KeyTransform in order to " + "use this lookup." ) key_text_transform = KeyTextTransform( - key_transform.key_name, *key_transform.source_expressions, + key_transform.key_name, + *key_transform.source_expressions, **key_transform.extra, ) super().__init__(key_text_transform, *args, **kwargs) @@ -376,12 +395,12 @@ class KeyTransformIsNull(lookups.IsNull): return sql, params # Column doesn't have a key or IS NULL. 
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection) - return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params) + return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params) def as_sqlite(self, compiler, connection): - template = 'JSON_TYPE(%s, %%s) IS NULL' + template = "JSON_TYPE(%s, %%s) IS NULL" if not self.rhs: - template = 'JSON_TYPE(%s, %%s) IS NOT NULL' + template = "JSON_TYPE(%s, %%s) IS NOT NULL" return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql( compiler, connection, @@ -392,26 +411,29 @@ class KeyTransformIsNull(lookups.IsNull): class KeyTransformIn(lookups.In): def resolve_expression_parameter(self, compiler, connection, sql, param): sql, params = super().resolve_expression_parameter( - compiler, connection, sql, param, + compiler, + connection, + sql, + param, ) if ( - not hasattr(param, 'as_sql') and - not connection.features.has_native_json_field + not hasattr(param, "as_sql") + and not connection.features.has_native_json_field ): - if connection.vendor == 'oracle': + if connection.vendor == "oracle": value = json.loads(param) sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" if isinstance(value, (list, dict)): - sql = sql % 'JSON_QUERY' + sql = sql % "JSON_QUERY" else: - sql = sql % 'JSON_VALUE' - elif connection.vendor == 'mysql' or ( - connection.vendor == 'sqlite' and - params[0] not in connection.ops.jsonfield_datatype_values + sql = sql % "JSON_VALUE" + elif connection.vendor == "mysql" or ( + connection.vendor == "sqlite" + and params[0] not in connection.ops.jsonfield_datatype_values ): sql = "JSON_EXTRACT(%s, '$')" - if connection.vendor == 'mysql' and connection.mysql_is_mariadb: - sql = 'JSON_UNQUOTE(%s)' % sql + if connection.vendor == "mysql" and connection.mysql_is_mariadb: + sql = "JSON_UNQUOTE(%s)" % sql return sql, params @@ -420,21 +442,21 @@ class KeyTransformExact(JSONExact): if isinstance(self.rhs, KeyTransform): return super(lookups.Exact, 
self).process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection) - if connection.vendor == 'oracle': + if connection.vendor == "oracle": func = [] sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" for value in rhs_params: value = json.loads(value) if isinstance(value, (list, dict)): - func.append(sql % 'JSON_QUERY') + func.append(sql % "JSON_QUERY") else: - func.append(sql % 'JSON_VALUE') + func.append(sql % "JSON_VALUE") rhs = rhs % tuple(func) - elif connection.vendor == 'sqlite': + elif connection.vendor == "sqlite": func = [] for value in rhs_params: if value in connection.ops.jsonfield_datatype_values: - func.append('%s') + func.append("%s") else: func.append("JSON_EXTRACT(%s, '$')") rhs = rhs % tuple(func) @@ -442,24 +464,28 @@ class KeyTransformExact(JSONExact): def as_oracle(self, compiler, connection): rhs, rhs_params = super().process_rhs(compiler, connection) - if rhs_params == ['null']: + if rhs_params == ["null"]: # Field has key and it's NULL. 
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name) has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection) - is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True) + is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True) is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection) return ( - '%s AND %s' % (has_key_sql, is_null_sql), + "%s AND %s" % (has_key_sql, is_null_sql), tuple(has_key_params) + tuple(is_null_params), ) return super().as_sql(compiler, connection) -class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact): +class KeyTransformIExact( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact +): pass -class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains): +class KeyTransformIContains( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains +): pass @@ -467,7 +493,9 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith): pass -class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith): +class KeyTransformIStartsWith( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith +): pass @@ -475,7 +503,9 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith): pass -class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith): +class KeyTransformIEndsWith( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith +): pass @@ -483,7 +513,9 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex): pass -class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex): +class KeyTransformIRegex( + CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex +): pass @@ -530,7 +562,6 @@ KeyTransform.register_lookup(KeyTransformGte) class KeyTransformFactory: - def __init__(self, key_name): self.key_name = key_name 
diff --git a/django/db/models/fields/mixins.py b/django/db/models/fields/mixins.py index 3afa8d9304..e7f282210e 100644 --- a/django/db/models/fields/mixins.py +++ b/django/db/models/fields/mixins.py @@ -29,22 +29,25 @@ class FieldCacheMixin: class CheckFieldDefaultMixin: - _default_hint = ('<valid default>', '<invalid default>') + _default_hint = ("<valid default>", "<invalid default>") def _check_default(self): - if self.has_default() and self.default is not None and not callable(self.default): + if ( + self.has_default() + and self.default is not None + and not callable(self.default) + ): return [ checks.Warning( "%s default should be a callable instead of an instance " - "so that it's not shared between all field instances." % ( - self.__class__.__name__, - ), + "so that it's not shared between all field instances." + % (self.__class__.__name__,), hint=( - 'Use a callable instead, e.g., use `%s` instead of ' - '`%s`.' % self._default_hint + "Use a callable instead, e.g., use `%s` instead of " + "`%s`." % self._default_hint ), obj=self, - id='fields.E010', + id="fields.E010", ) ] else: diff --git a/django/db/models/fields/proxy.py b/django/db/models/fields/proxy.py index 0ecf04a333..ac02e47a25 100644 --- a/django/db/models/fields/proxy.py +++ b/django/db/models/fields/proxy.py @@ -13,6 +13,6 @@ class OrderWrt(fields.IntegerField): """ def __init__(self, *args, **kwargs): - kwargs['name'] = '_order' - kwargs['editable'] = False + kwargs["name"] = "_order" + kwargs["editable"] = False super().__init__(*args, **kwargs) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 11407ac902..1cf447c6d4 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -19,19 +19,25 @@ from django.utils.translation import gettext_lazy as _ from . 
import Field from .mixins import FieldCacheMixin from .related_descriptors import ( - ForeignKeyDeferredAttribute, ForwardManyToOneDescriptor, - ForwardOneToOneDescriptor, ManyToManyDescriptor, - ReverseManyToOneDescriptor, ReverseOneToOneDescriptor, + ForeignKeyDeferredAttribute, + ForwardManyToOneDescriptor, + ForwardOneToOneDescriptor, + ManyToManyDescriptor, + ReverseManyToOneDescriptor, + ReverseOneToOneDescriptor, ) from .related_lookups import ( - RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn, - RelatedIsNull, RelatedLessThan, RelatedLessThanOrEqual, -) -from .reverse_related import ( - ForeignObjectRel, ManyToManyRel, ManyToOneRel, OneToOneRel, + RelatedExact, + RelatedGreaterThan, + RelatedGreaterThanOrEqual, + RelatedIn, + RelatedIsNull, + RelatedLessThan, + RelatedLessThanOrEqual, ) +from .reverse_related import ForeignObjectRel, ManyToManyRel, ManyToOneRel, OneToOneRel -RECURSIVE_RELATIONSHIP_CONSTANT = 'self' +RECURSIVE_RELATIONSHIP_CONSTANT = "self" def resolve_relation(scope_model, relation): @@ -119,19 +125,25 @@ class RelatedField(FieldCacheMixin, Field): def _check_related_name_is_valid(self): import keyword + related_name = self.remote_field.related_name if related_name is None: return [] - is_valid_id = not keyword.iskeyword(related_name) and related_name.isidentifier() - if not (is_valid_id or related_name.endswith('+')): + is_valid_id = ( + not keyword.iskeyword(related_name) and related_name.isidentifier() + ) + if not (is_valid_id or related_name.endswith("+")): return [ checks.Error( - "The name '%s' is invalid related_name for field %s.%s" % - (self.remote_field.related_name, self.model._meta.object_name, - self.name), + "The name '%s' is invalid related_name for field %s.%s" + % ( + self.remote_field.related_name, + self.model._meta.object_name, + self.name, + ), hint="Related name must be a valid Python identifier or end with a '+'", obj=self, - id='fields.E306', + id="fields.E306", ) ] return [] @@ -141,15 
+153,17 @@ class RelatedField(FieldCacheMixin, Field): return [] rel_query_name = self.related_query_name() errors = [] - if rel_query_name.endswith('_'): + if rel_query_name.endswith("_"): errors.append( checks.Error( "Reverse query name '%s' must not end with an underscore." % rel_query_name, - hint=("Add or change a related_name or related_query_name " - "argument for this field."), + hint=( + "Add or change a related_name or related_query_name " + "argument for this field." + ), obj=self, - id='fields.E308', + id="fields.E308", ) ) if LOOKUP_SEP in rel_query_name: @@ -157,10 +171,12 @@ class RelatedField(FieldCacheMixin, Field): checks.Error( "Reverse query name '%s' must not contain '%s'." % (rel_query_name, LOOKUP_SEP), - hint=("Add or change a related_name or related_query_name " - "argument for this field."), + hint=( + "Add or change a related_name or related_query_name " + "argument for this field." + ), obj=self, - id='fields.E309', + id="fields.E309", ) ) return errors @@ -168,29 +184,38 @@ class RelatedField(FieldCacheMixin, Field): def _check_relation_model_exists(self): rel_is_missing = self.remote_field.model not in self.opts.apps.get_models() rel_is_string = isinstance(self.remote_field.model, str) - model_name = self.remote_field.model if rel_is_string else self.remote_field.model._meta.object_name - if rel_is_missing and (rel_is_string or not self.remote_field.model._meta.swapped): + model_name = ( + self.remote_field.model + if rel_is_string + else self.remote_field.model._meta.object_name + ) + if rel_is_missing and ( + rel_is_string or not self.remote_field.model._meta.swapped + ): return [ checks.Error( "Field defines a relation with model '%s', which is either " "not installed, or is abstract." 
% model_name, obj=self, - id='fields.E300', + id="fields.E300", ) ] return [] def _check_referencing_to_swapped_model(self): - if (self.remote_field.model not in self.opts.apps.get_models() and - not isinstance(self.remote_field.model, str) and - self.remote_field.model._meta.swapped): + if ( + self.remote_field.model not in self.opts.apps.get_models() + and not isinstance(self.remote_field.model, str) + and self.remote_field.model._meta.swapped + ): return [ checks.Error( "Field defines a relation with the model '%s', which has " "been swapped out." % self.remote_field.model._meta.label, - hint="Update the relation to point at 'settings.%s'." % self.remote_field.model._meta.swappable, + hint="Update the relation to point at 'settings.%s'." + % self.remote_field.model._meta.swappable, obj=self, - id='fields.E301', + id="fields.E301", ) ] return [] @@ -227,7 +252,7 @@ class RelatedField(FieldCacheMixin, Field): rel_name = self.remote_field.get_accessor_name() # i. e. "model_set" rel_query_name = self.related_query_name() # i. e. "model" # i.e. "app_label.Model.field". - field_name = '%s.%s' % (opts.label, self.name) + field_name = "%s.%s" % (opts.label, self.name) # Check clashes between accessor or reverse query name of `field` # and any other field name -- i.e. accessor for Model.foreign is @@ -235,28 +260,35 @@ class RelatedField(FieldCacheMixin, Field): potential_clashes = rel_opts.fields + rel_opts.many_to_many for clash_field in potential_clashes: # i.e. "app_label.Target.model_set". 
- clash_name = '%s.%s' % (rel_opts.label, clash_field.name) + clash_name = "%s.%s" % (rel_opts.label, clash_field.name) if not rel_is_hidden and clash_field.name == rel_name: errors.append( checks.Error( f"Reverse accessor '{rel_opts.object_name}.{rel_name}' " f"for '{field_name}' clashes with field name " f"'{clash_name}'.", - hint=("Rename field '%s', or add/change a related_name " - "argument to the definition for field '%s'.") % (clash_name, field_name), + hint=( + "Rename field '%s', or add/change a related_name " + "argument to the definition for field '%s'." + ) + % (clash_name, field_name), obj=self, - id='fields.E302', + id="fields.E302", ) ) if clash_field.name == rel_query_name: errors.append( checks.Error( - "Reverse query name for '%s' clashes with field name '%s'." % (field_name, clash_name), - hint=("Rename field '%s', or add/change a related_name " - "argument to the definition for field '%s'.") % (clash_name, field_name), + "Reverse query name for '%s' clashes with field name '%s'." + % (field_name, clash_name), + hint=( + "Rename field '%s', or add/change a related_name " + "argument to the definition for field '%s'." + ) + % (clash_name, field_name), obj=self, - id='fields.E303', + id="fields.E303", ) ) @@ -266,7 +298,7 @@ class RelatedField(FieldCacheMixin, Field): potential_clashes = (r for r in rel_opts.related_objects if r.field is not self) for clash_field in potential_clashes: # i.e. "app_label.Model.m2m". - clash_name = '%s.%s' % ( + clash_name = "%s.%s" % ( clash_field.related_model._meta.label, clash_field.field.name, ) @@ -276,10 +308,13 @@ class RelatedField(FieldCacheMixin, Field): f"Reverse accessor '{rel_opts.object_name}.{rel_name}' " f"for '{field_name}' clashes with reverse accessor for " f"'{clash_name}'.", - hint=("Add or change a related_name argument " - "to the definition for '%s' or '%s'.") % (field_name, clash_name), + hint=( + "Add or change a related_name argument " + "to the definition for '%s' or '%s'." 
+ ) + % (field_name, clash_name), obj=self, - id='fields.E304', + id="fields.E304", ) ) @@ -288,10 +323,13 @@ class RelatedField(FieldCacheMixin, Field): checks.Error( "Reverse query name for '%s' clashes with reverse query name for '%s'." % (field_name, clash_name), - hint=("Add or change a related_name argument " - "to the definition for '%s' or '%s'.") % (field_name, clash_name), + hint=( + "Add or change a related_name argument " + "to the definition for '%s' or '%s'." + ) + % (field_name, clash_name), obj=self, - id='fields.E305', + id="fields.E305", ) ) @@ -315,32 +353,35 @@ class RelatedField(FieldCacheMixin, Field): related_name = self.opts.default_related_name if related_name: related_name = related_name % { - 'class': cls.__name__.lower(), - 'model_name': cls._meta.model_name.lower(), - 'app_label': cls._meta.app_label.lower() + "class": cls.__name__.lower(), + "model_name": cls._meta.model_name.lower(), + "app_label": cls._meta.app_label.lower(), } self.remote_field.related_name = related_name if self.remote_field.related_query_name: related_query_name = self.remote_field.related_query_name % { - 'class': cls.__name__.lower(), - 'app_label': cls._meta.app_label.lower(), + "class": cls.__name__.lower(), + "app_label": cls._meta.app_label.lower(), } self.remote_field.related_query_name = related_query_name def resolve_related_class(model, related, field): field.remote_field.model = related field.do_related_class(related, model) - lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self) + + lazy_related_operation( + resolve_related_class, cls, self.remote_field.model, field=self + ) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if self._limit_choices_to: - kwargs['limit_choices_to'] = self._limit_choices_to + kwargs["limit_choices_to"] = self._limit_choices_to if self._related_name is not None: - kwargs['related_name'] = self._related_name + kwargs["related_name"] = self._related_name if 
self._related_query_name is not None: - kwargs['related_query_name'] = self._related_query_name + kwargs["related_query_name"] = self._related_query_name return name, path, args, kwargs def get_forward_related_filter(self, obj): @@ -352,7 +393,7 @@ class RelatedField(FieldCacheMixin, Field): self.related_field.model. """ return { - '%s__%s' % (self.name, rh_field.name): getattr(obj, rh_field.attname) + "%s__%s" % (self.name, rh_field.name): getattr(obj, rh_field.attname) for _, rh_field in self.related_fields } @@ -391,9 +432,10 @@ class RelatedField(FieldCacheMixin, Field): return None def set_attributes_from_rel(self): - self.name = ( - self.name or - (self.remote_field.model._meta.model_name + '_' + self.remote_field.model._meta.pk.name) + self.name = self.name or ( + self.remote_field.model._meta.model_name + + "_" + + self.remote_field.model._meta.pk.name ) if self.verbose_name is None: self.verbose_name = self.remote_field.model._meta.verbose_name @@ -423,14 +465,16 @@ class RelatedField(FieldCacheMixin, Field): being constructed. """ defaults = {} - if hasattr(self.remote_field, 'get_related_field'): + if hasattr(self.remote_field, "get_related_field"): # If this is a callable, do not invoke it here. Just pass # it in the defaults for when the form class will later be # instantiated. limit_choices_to = self.remote_field.limit_choices_to - defaults.update({ - 'limit_choices_to': limit_choices_to, - }) + defaults.update( + { + "limit_choices_to": limit_choices_to, + } + ) defaults.update(kwargs) return super().formfield(**defaults) @@ -439,7 +483,11 @@ class RelatedField(FieldCacheMixin, Field): Define the name that can be used to identify this related object in a table-spanning query. 
""" - return self.remote_field.related_query_name or self.remote_field.related_name or self.opts.model_name + return ( + self.remote_field.related_query_name + or self.remote_field.related_name + or self.opts.model_name + ) @property def target_field(self): @@ -450,7 +498,8 @@ class RelatedField(FieldCacheMixin, Field): target_fields = self.path_infos[-1].target_fields if len(target_fields) > 1: raise exceptions.FieldError( - "The relation has multiple target fields, but only single target field was asked for") + "The relation has multiple target fields, but only single target field was asked for" + ) return target_fields[0] def get_cache_name(self): @@ -473,13 +522,25 @@ class ForeignObject(RelatedField): forward_related_accessor_class = ForwardManyToOneDescriptor rel_class = ForeignObjectRel - def __init__(self, to, on_delete, from_fields, to_fields, rel=None, related_name=None, - related_query_name=None, limit_choices_to=None, parent_link=False, - swappable=True, **kwargs): + def __init__( + self, + to, + on_delete, + from_fields, + to_fields, + rel=None, + related_name=None, + related_query_name=None, + limit_choices_to=None, + parent_link=False, + swappable=True, + **kwargs, + ): if rel is None: rel = self.rel_class( - self, to, + self, + to, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, @@ -502,8 +563,8 @@ class ForeignObject(RelatedField): def __copy__(self): obj = super().__copy__() # Remove any cached PathInfo values. - obj.__dict__.pop('path_infos', None) - obj.__dict__.pop('reverse_path_infos', None) + obj.__dict__.pop("path_infos", None) + obj.__dict__.pop("reverse_path_infos", None) return obj def check(self, **kwargs): @@ -530,7 +591,7 @@ class ForeignObject(RelatedField): "model '%s'." 
% (to_field, self.remote_field.model._meta.label), obj=self, - id='fields.E312', + id="fields.E312", ) ) return errors @@ -551,21 +612,22 @@ class ForeignObject(RelatedField): unique_foreign_fields = { frozenset([f.name]) for f in self.remote_field.model._meta.get_fields() - if getattr(f, 'unique', False) + if getattr(f, "unique", False) } - unique_foreign_fields.update({ - frozenset(ut) - for ut in self.remote_field.model._meta.unique_together - }) - unique_foreign_fields.update({ - frozenset(uc.fields) - for uc in self.remote_field.model._meta.total_unique_constraints - }) + unique_foreign_fields.update( + {frozenset(ut) for ut in self.remote_field.model._meta.unique_together} + ) + unique_foreign_fields.update( + { + frozenset(uc.fields) + for uc in self.remote_field.model._meta.total_unique_constraints + } + ) foreign_fields = {f.name for f in self.foreign_related_fields} has_unique_constraint = any(u <= foreign_fields for u in unique_foreign_fields) if not has_unique_constraint and len(self.foreign_related_fields) > 1: - field_combination = ', '.join( + field_combination = ", ".join( "'%s'" % rel_field.name for rel_field in self.foreign_related_fields ) model_name = self.remote_field.model.__name__ @@ -574,13 +636,13 @@ class ForeignObject(RelatedField): "No subset of the fields %s on model '%s' is unique." % (field_combination, model_name), hint=( - 'Mark a single field as unique=True or add a set of ' - 'fields to a unique constraint (via unique_together ' - 'or a UniqueConstraint (without condition) in the ' - 'model Meta.constraints).' + "Mark a single field as unique=True or add a set of " + "fields to a unique constraint (via unique_together " + "or a UniqueConstraint (without condition) in the " + "model Meta.constraints)." ), obj=self, - id='fields.E310', + id="fields.E310", ) ] elif not has_unique_constraint: @@ -591,12 +653,12 @@ class ForeignObject(RelatedField): "'%s.%s' must be unique because it is referenced by " "a foreign key." 
% (model_name, field_name), hint=( - 'Add unique=True to this field or add a ' - 'UniqueConstraint (without condition) in the model ' - 'Meta.constraints.' + "Add unique=True to this field or add a " + "UniqueConstraint (without condition) in the model " + "Meta.constraints." ), obj=self, - id='fields.E311', + id="fields.E311", ) ] else: @@ -604,44 +666,48 @@ class ForeignObject(RelatedField): def deconstruct(self): name, path, args, kwargs = super().deconstruct() - kwargs['on_delete'] = self.remote_field.on_delete - kwargs['from_fields'] = self.from_fields - kwargs['to_fields'] = self.to_fields + kwargs["on_delete"] = self.remote_field.on_delete + kwargs["from_fields"] = self.from_fields + kwargs["to_fields"] = self.to_fields if self.remote_field.parent_link: - kwargs['parent_link'] = self.remote_field.parent_link + kwargs["parent_link"] = self.remote_field.parent_link if isinstance(self.remote_field.model, str): - if '.' in self.remote_field.model: - app_label, model_name = self.remote_field.model.split('.') - kwargs['to'] = '%s.%s' % (app_label, model_name.lower()) + if "." in self.remote_field.model: + app_label, model_name = self.remote_field.model.split(".") + kwargs["to"] = "%s.%s" % (app_label, model_name.lower()) else: - kwargs['to'] = self.remote_field.model.lower() + kwargs["to"] = self.remote_field.model.lower() else: - kwargs['to'] = self.remote_field.model._meta.label_lower + kwargs["to"] = self.remote_field.model._meta.label_lower # If swappable is True, then see if we're actually pointing to the target # of a swap. 
swappable_setting = self.swappable_setting if swappable_setting is not None: # If it's already a settings reference, error - if hasattr(kwargs['to'], "setting_name"): - if kwargs['to'].setting_name != swappable_setting: + if hasattr(kwargs["to"], "setting_name"): + if kwargs["to"].setting_name != swappable_setting: raise ValueError( "Cannot deconstruct a ForeignKey pointing to a model " "that is swapped in place of more than one model (%s and %s)" - % (kwargs['to'].setting_name, swappable_setting) + % (kwargs["to"].setting_name, swappable_setting) ) # Set it - kwargs['to'] = SettingsReference( - kwargs['to'], + kwargs["to"] = SettingsReference( + kwargs["to"], swappable_setting, ) return name, path, args, kwargs def resolve_related_fields(self): if not self.from_fields or len(self.from_fields) != len(self.to_fields): - raise ValueError('Foreign Object from and to fields must be the same non-zero length') + raise ValueError( + "Foreign Object from and to fields must be the same non-zero length" + ) if isinstance(self.remote_field.model, str): - raise ValueError('Related model %r cannot be resolved' % self.remote_field.model) + raise ValueError( + "Related model %r cannot be resolved" % self.remote_field.model + ) related_fields = [] for index in range(len(self.from_fields)): from_field_name = self.from_fields[index] @@ -651,8 +717,11 @@ class ForeignObject(RelatedField): if from_field_name == RECURSIVE_RELATIONSHIP_CONSTANT else self.opts.get_field(from_field_name) ) - to_field = (self.remote_field.model._meta.pk if to_field_name is None - else self.remote_field.model._meta.get_field(to_field_name)) + to_field = ( + self.remote_field.model._meta.pk + if to_field_name is None + else self.remote_field.model._meta.get_field(to_field_name) + ) related_fields.append((from_field, to_field)) return related_fields @@ -670,7 +739,9 @@ class ForeignObject(RelatedField): @cached_property def foreign_related_fields(self): - return tuple(rhs_field for lhs_field, rhs_field in 
self.related_fields if rhs_field) + return tuple( + rhs_field for lhs_field, rhs_field in self.related_fields if rhs_field + ) def get_local_related_value(self, instance): return self.get_instance_value_for_fields(instance, self.local_related_fields) @@ -688,9 +759,11 @@ class ForeignObject(RelatedField): # instance.pk (that is, parent_ptr_id) when asked for instance.id. if field.primary_key: possible_parent_link = opts.get_ancestor_link(field.model) - if (not possible_parent_link or - possible_parent_link.primary_key or - possible_parent_link.model._meta.abstract): + if ( + not possible_parent_link + or possible_parent_link.primary_key + or possible_parent_link.model._meta.abstract + ): ret.append(instance.pk) continue ret.append(getattr(instance, field.attname)) @@ -702,7 +775,9 @@ class ForeignObject(RelatedField): def get_joining_columns(self, reverse_join=False): source = self.reverse_related_fields if reverse_join else self.related_fields - return tuple((lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source) + return tuple( + (lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source + ) def get_reverse_joining_columns(self): return self.get_joining_columns(reverse_join=True) @@ -740,15 +815,17 @@ class ForeignObject(RelatedField): """Get path from this field to the related model.""" opts = self.remote_field.model._meta from_opts = self.model._meta - return [PathInfo( - from_opts=from_opts, - to_opts=opts, - target_fields=self.foreign_related_fields, - join_field=self, - m2m=False, - direct=True, - filtered_relation=filtered_relation, - )] + return [ + PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=self.foreign_related_fields, + join_field=self, + m2m=False, + direct=True, + filtered_relation=filtered_relation, + ) + ] @cached_property def path_infos(self): @@ -758,15 +835,17 @@ class ForeignObject(RelatedField): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = 
self.remote_field.model._meta - return [PathInfo( - from_opts=from_opts, - to_opts=opts, - target_fields=(opts.pk,), - join_field=self.remote_field, - m2m=not self.unique, - direct=False, - filtered_relation=filtered_relation, - )] + return [ + PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + ) + ] @cached_property def reverse_path_infos(self): @@ -776,8 +855,8 @@ class ForeignObject(RelatedField): @functools.lru_cache(maxsize=None) def get_lookups(cls): bases = inspect.getmro(cls) - bases = bases[:bases.index(ForeignObject) + 1] - class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in bases] + bases = bases[: bases.index(ForeignObject) + 1] + class_lookups = [parent.__dict__.get("class_lookups", {}) for parent in bases] return cls.merge_dicts(class_lookups) def contribute_to_class(self, cls, name, private_only=False, **kwargs): @@ -787,13 +866,22 @@ class ForeignObject(RelatedField): def contribute_to_related_class(self, cls, related): # Internal FK's - i.e., those with a related name ending with '+' - # and swapped models don't get a related descriptor. - if not self.remote_field.is_hidden() and not related.related_model._meta.swapped: - setattr(cls._meta.concrete_model, related.get_accessor_name(), self.related_accessor_class(related)) + if ( + not self.remote_field.is_hidden() + and not related.related_model._meta.swapped + ): + setattr( + cls._meta.concrete_model, + related.get_accessor_name(), + self.related_accessor_class(related), + ) # While 'limit_choices_to' might be a callable, simply pass # it along for later - this is too early because it's still # model load time. 
if self.remote_field.limit_choices_to: - cls._meta.related_fkey_lookups.append(self.remote_field.limit_choices_to) + cls._meta.related_fkey_lookups.append( + self.remote_field.limit_choices_to + ) ForeignObject.register_lookup(RelatedIn) @@ -813,6 +901,7 @@ class ForeignKey(ForeignObject): By default ForeignKey will target the pk of the remote model but this behavior can be changed by using the ``to_field`` argument. """ + descriptor_class = ForeignKeyDeferredAttribute # Field flags many_to_many = False @@ -824,21 +913,33 @@ class ForeignKey(ForeignObject): empty_strings_allowed = False default_error_messages = { - 'invalid': _('%(model)s instance with %(field)s %(value)r does not exist.') + "invalid": _("%(model)s instance with %(field)s %(value)r does not exist.") } description = _("Foreign Key (type determined by related field)") - def __init__(self, to, on_delete, related_name=None, related_query_name=None, - limit_choices_to=None, parent_link=False, to_field=None, - db_constraint=True, **kwargs): + def __init__( + self, + to, + on_delete, + related_name=None, + related_query_name=None, + limit_choices_to=None, + parent_link=False, + to_field=None, + db_constraint=True, + **kwargs, + ): try: to._meta.model_name except AttributeError: if not isinstance(to, str): raise TypeError( - '%s(%r) is invalid. First parameter to ForeignKey must be ' - 'either a model, a model name, or the string %r' % ( - self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT, + "%s(%r) is invalid. First parameter to ForeignKey must be " + "either a model, a model name, or the string %r" + % ( + self.__class__.__name__, + to, + RECURSIVE_RELATIONSHIP_CONSTANT, ) ) else: @@ -847,17 +948,19 @@ class ForeignKey(ForeignObject): # be correct until contribute_to_class is called. Refs #12190. 
to_field = to_field or (to._meta.pk and to._meta.pk.name) if not callable(on_delete): - raise TypeError('on_delete must be callable.') + raise TypeError("on_delete must be callable.") - kwargs['rel'] = self.rel_class( - self, to, to_field, + kwargs["rel"] = self.rel_class( + self, + to, + to_field, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, parent_link=parent_link, on_delete=on_delete, ) - kwargs.setdefault('db_index', True) + kwargs.setdefault("db_index", True) super().__init__( to, @@ -879,54 +982,60 @@ class ForeignKey(ForeignObject): ] def _check_on_delete(self): - on_delete = getattr(self.remote_field, 'on_delete', None) + on_delete = getattr(self.remote_field, "on_delete", None) if on_delete == SET_NULL and not self.null: return [ checks.Error( - 'Field specifies on_delete=SET_NULL, but cannot be null.', - hint='Set null=True argument on the field, or change the on_delete rule.', + "Field specifies on_delete=SET_NULL, but cannot be null.", + hint="Set null=True argument on the field, or change the on_delete rule.", obj=self, - id='fields.E320', + id="fields.E320", ) ] elif on_delete == SET_DEFAULT and not self.has_default(): return [ checks.Error( - 'Field specifies on_delete=SET_DEFAULT, but has no default value.', - hint='Set a default value, or change the on_delete rule.', + "Field specifies on_delete=SET_DEFAULT, but has no default value.", + hint="Set a default value, or change the on_delete rule.", obj=self, - id='fields.E321', + id="fields.E321", ) ] else: return [] def _check_unique(self, **kwargs): - return [ - checks.Warning( - 'Setting unique=True on a ForeignKey has the same effect as using a OneToOneField.', - hint='ForeignKey(unique=True) is usually better served by a OneToOneField.', - obj=self, - id='fields.W342', - ) - ] if self.unique else [] + return ( + [ + checks.Warning( + "Setting unique=True on a ForeignKey has the same effect as using a OneToOneField.", + 
hint="ForeignKey(unique=True) is usually better served by a OneToOneField.", + obj=self, + id="fields.W342", + ) + ] + if self.unique + else [] + ) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - del kwargs['to_fields'] - del kwargs['from_fields'] + del kwargs["to_fields"] + del kwargs["from_fields"] # Handle the simpler arguments if self.db_index: - del kwargs['db_index'] + del kwargs["db_index"] else: - kwargs['db_index'] = False + kwargs["db_index"] = False if self.db_constraint is not True: - kwargs['db_constraint'] = self.db_constraint + kwargs["db_constraint"] = self.db_constraint # Rel needs more work. to_meta = getattr(self.remote_field.model, "_meta", None) if self.remote_field.field_name and ( - not to_meta or (to_meta.pk and self.remote_field.field_name != to_meta.pk.name)): - kwargs['to_field'] = self.remote_field.field_name + not to_meta + or (to_meta.pk and self.remote_field.field_name != to_meta.pk.name) + ): + kwargs["to_field"] = self.remote_field.field_name return name, path, args, kwargs def to_python(self, value): @@ -940,15 +1049,17 @@ class ForeignKey(ForeignObject): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = self.remote_field.model._meta - return [PathInfo( - from_opts=from_opts, - to_opts=opts, - target_fields=(opts.pk,), - join_field=self.remote_field, - m2m=not self.unique, - direct=False, - filtered_relation=filtered_relation, - )] + return [ + PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + ) + ] def validate(self, value, model_instance): if self.remote_field.parent_link: @@ -964,21 +1075,27 @@ class ForeignKey(ForeignObject): qs = qs.complex_filter(self.get_limit_choices_to()) if not qs.exists(): raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', + self.error_messages["invalid"], + 
code="invalid", params={ - 'model': self.remote_field.model._meta.verbose_name, 'pk': value, - 'field': self.remote_field.field_name, 'value': value, + "model": self.remote_field.model._meta.verbose_name, + "pk": value, + "field": self.remote_field.field_name, + "value": value, }, # 'pk' is included for backwards compatibility ) def resolve_related_fields(self): related_fields = super().resolve_related_fields() for from_field, to_field in related_fields: - if to_field and to_field.model != self.remote_field.model._meta.concrete_model: + if ( + to_field + and to_field.model != self.remote_field.model._meta.concrete_model + ): raise exceptions.FieldError( "'%s.%s' refers to field '%s' which is not local to model " - "'%s'." % ( + "'%s'." + % ( self.model._meta.label, self.name, to_field.name, @@ -988,7 +1105,7 @@ class ForeignKey(ForeignObject): return related_fields def get_attname(self): - return '%s_id' % self.name + return "%s_id" % self.name def get_attname_column(self): attname = self.get_attname() @@ -1003,9 +1120,13 @@ class ForeignKey(ForeignObject): return field_default def get_db_prep_save(self, value, connection): - if value is None or (value == '' and - (not self.target_field.empty_strings_allowed or - connection.features.interprets_empty_strings_as_nulls)): + if value is None or ( + value == "" + and ( + not self.target_field.empty_strings_allowed + or connection.features.interprets_empty_strings_as_nulls + ) + ): return None else: return self.target_field.get_db_prep_save(value, connection=connection) @@ -1023,16 +1144,20 @@ class ForeignKey(ForeignObject): def formfield(self, *, using=None, **kwargs): if isinstance(self.remote_field.model, str): - raise ValueError("Cannot create form field for %r yet, because " - "its related model %r has not been loaded yet" % - (self.name, self.remote_field.model)) - return super().formfield(**{ - 'form_class': forms.ModelChoiceField, - 'queryset': self.remote_field.model._default_manager.using(using), - 
'to_field_name': self.remote_field.field_name, - **kwargs, - 'blank': self.blank, - }) + raise ValueError( + "Cannot create form field for %r yet, because " + "its related model %r has not been loaded yet" + % (self.name, self.remote_field.model) + ) + return super().formfield( + **{ + "form_class": forms.ModelChoiceField, + "queryset": self.remote_field.model._default_manager.using(using), + "to_field_name": self.remote_field.field_name, + **kwargs, + "blank": self.blank, + } + ) def db_check(self, connection): return None @@ -1060,7 +1185,7 @@ class ForeignKey(ForeignObject): while isinstance(output_field, ForeignKey): output_field = output_field.target_field if output_field is self: - raise ValueError('Cannot resolve output_field.') + raise ValueError("Cannot resolve output_field.") return super().get_col(alias, output_field) @@ -1085,13 +1210,13 @@ class OneToOneField(ForeignKey): description = _("One-to-one relationship") def __init__(self, to, on_delete, to_field=None, **kwargs): - kwargs['unique'] = True + kwargs["unique"] = True super().__init__(to, on_delete, to_field=to_field, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() if "unique" in kwargs: - del kwargs['unique'] + del kwargs["unique"] return name, path, args, kwargs def formfield(self, **kwargs): @@ -1121,44 +1246,54 @@ def create_many_to_many_intermediary_model(field, klass): through._meta.managed = model._meta.managed or related._meta.managed to_model = resolve_relation(klass, field.remote_field.model) - name = '%s_%s' % (klass._meta.object_name, field.name) + name = "%s_%s" % (klass._meta.object_name, field.name) lazy_related_operation(set_managed, klass, to_model, name) to = make_model_tuple(to_model)[1] from_ = klass._meta.model_name if to == from_: - to = 'to_%s' % to - from_ = 'from_%s' % from_ + to = "to_%s" % to + from_ = "from_%s" % from_ - meta = type('Meta', (), { - 'db_table': field._get_m2m_db_table(klass._meta), - 'auto_created': klass, - 
'app_label': klass._meta.app_label, - 'db_tablespace': klass._meta.db_tablespace, - 'unique_together': (from_, to), - 'verbose_name': _('%(from)s-%(to)s relationship') % {'from': from_, 'to': to}, - 'verbose_name_plural': _('%(from)s-%(to)s relationships') % {'from': from_, 'to': to}, - 'apps': field.model._meta.apps, - }) + meta = type( + "Meta", + (), + { + "db_table": field._get_m2m_db_table(klass._meta), + "auto_created": klass, + "app_label": klass._meta.app_label, + "db_tablespace": klass._meta.db_tablespace, + "unique_together": (from_, to), + "verbose_name": _("%(from)s-%(to)s relationship") + % {"from": from_, "to": to}, + "verbose_name_plural": _("%(from)s-%(to)s relationships") + % {"from": from_, "to": to}, + "apps": field.model._meta.apps, + }, + ) # Construct and return the new class. - return type(name, (models.Model,), { - 'Meta': meta, - '__module__': klass.__module__, - from_: models.ForeignKey( - klass, - related_name='%s+' % name, - db_tablespace=field.db_tablespace, - db_constraint=field.remote_field.db_constraint, - on_delete=CASCADE, - ), - to: models.ForeignKey( - to_model, - related_name='%s+' % name, - db_tablespace=field.db_tablespace, - db_constraint=field.remote_field.db_constraint, - on_delete=CASCADE, - ) - }) + return type( + name, + (models.Model,), + { + "Meta": meta, + "__module__": klass.__module__, + from_: models.ForeignKey( + klass, + related_name="%s+" % name, + db_tablespace=field.db_tablespace, + db_constraint=field.remote_field.db_constraint, + on_delete=CASCADE, + ), + to: models.ForeignKey( + to_model, + related_name="%s+" % name, + db_tablespace=field.db_tablespace, + db_constraint=field.remote_field.db_constraint, + on_delete=CASCADE, + ), + }, + ) class ManyToManyField(RelatedField): @@ -1181,31 +1316,45 @@ class ManyToManyField(RelatedField): description = _("Many-to-many relationship") - def __init__(self, to, related_name=None, related_query_name=None, - limit_choices_to=None, symmetrical=None, through=None, - 
through_fields=None, db_constraint=True, db_table=None, - swappable=True, **kwargs): + def __init__( + self, + to, + related_name=None, + related_query_name=None, + limit_choices_to=None, + symmetrical=None, + through=None, + through_fields=None, + db_constraint=True, + db_table=None, + swappable=True, + **kwargs, + ): try: to._meta except AttributeError: if not isinstance(to, str): raise TypeError( - '%s(%r) is invalid. First parameter to ManyToManyField ' - 'must be either a model, a model name, or the string %r' % ( - self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT, + "%s(%r) is invalid. First parameter to ManyToManyField " + "must be either a model, a model name, or the string %r" + % ( + self.__class__.__name__, + to, + RECURSIVE_RELATIONSHIP_CONSTANT, ) ) if symmetrical is None: - symmetrical = (to == RECURSIVE_RELATIONSHIP_CONSTANT) + symmetrical = to == RECURSIVE_RELATIONSHIP_CONSTANT if through is not None and db_table is not None: raise ValueError( - 'Cannot specify a db_table if an intermediary model is used.' + "Cannot specify a db_table if an intermediary model is used." 
) - kwargs['rel'] = self.rel_class( - self, to, + kwargs["rel"] = self.rel_class( + self, + to, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, @@ -1214,7 +1363,7 @@ class ManyToManyField(RelatedField): through_fields=through_fields, db_constraint=db_constraint, ) - self.has_null_arg = 'null' in kwargs + self.has_null_arg = "null" in kwargs super().__init__( related_name=related_name, @@ -1239,9 +1388,9 @@ class ManyToManyField(RelatedField): if self.unique: return [ checks.Error( - 'ManyToManyFields cannot be unique.', + "ManyToManyFields cannot be unique.", obj=self, - id='fields.E330', + id="fields.E330", ) ] return [] @@ -1252,49 +1401,53 @@ class ManyToManyField(RelatedField): if self.has_null_arg: warnings.append( checks.Warning( - 'null has no effect on ManyToManyField.', + "null has no effect on ManyToManyField.", obj=self, - id='fields.W340', + id="fields.W340", ) ) if self._validators: warnings.append( checks.Warning( - 'ManyToManyField does not support validators.', + "ManyToManyField does not support validators.", obj=self, - id='fields.W341', + id="fields.W341", ) ) if self.remote_field.symmetrical and self._related_name: warnings.append( checks.Warning( - 'related_name has no effect on ManyToManyField ' + "related_name has no effect on ManyToManyField " 'with a symmetrical relationship, e.g. 
to "self".', obj=self, - id='fields.W345', + id="fields.W345", ) ) return warnings def _check_relationship_model(self, from_model=None, **kwargs): - if hasattr(self.remote_field.through, '_meta'): + if hasattr(self.remote_field.through, "_meta"): qualified_model_name = "%s.%s" % ( - self.remote_field.through._meta.app_label, self.remote_field.through.__name__) + self.remote_field.through._meta.app_label, + self.remote_field.through.__name__, + ) else: qualified_model_name = self.remote_field.through errors = [] - if self.remote_field.through not in self.opts.apps.get_models(include_auto_created=True): + if self.remote_field.through not in self.opts.apps.get_models( + include_auto_created=True + ): # The relationship model is not installed. errors.append( checks.Error( "Field specifies a many-to-many relation through model " "'%s', which has not been installed." % qualified_model_name, obj=self, - id='fields.E331', + id="fields.E331", ) ) @@ -1316,7 +1469,7 @@ class ManyToManyField(RelatedField): # Count foreign keys in intermediate model if self_referential: seen_self = sum( - from_model == getattr(field.remote_field, 'model', None) + from_model == getattr(field.remote_field, "model", None) for field in self.remote_field.through._meta.fields ) @@ -1327,41 +1480,46 @@ class ManyToManyField(RelatedField): "'%s', but it has more than two foreign keys " "to '%s', which is ambiguous. You must specify " "which two foreign keys Django should use via the " - "through_fields keyword argument." % (self, from_model_name), + "through_fields keyword argument." 
+ % (self, from_model_name), hint="Use through_fields to specify which two foreign keys Django should use.", obj=self.remote_field.through, - id='fields.E333', + id="fields.E333", ) ) else: # Count foreign keys in relationship model seen_from = sum( - from_model == getattr(field.remote_field, 'model', None) + from_model == getattr(field.remote_field, "model", None) for field in self.remote_field.through._meta.fields ) seen_to = sum( - to_model == getattr(field.remote_field, 'model', None) + to_model == getattr(field.remote_field, "model", None) for field in self.remote_field.through._meta.fields ) if seen_from > 1 and not self.remote_field.through_fields: errors.append( checks.Error( - ("The model is used as an intermediate model by " - "'%s', but it has more than one foreign key " - "from '%s', which is ambiguous. You must specify " - "which foreign key Django should use via the " - "through_fields keyword argument.") % (self, from_model_name), + ( + "The model is used as an intermediate model by " + "'%s', but it has more than one foreign key " + "from '%s', which is ambiguous. You must specify " + "which foreign key Django should use via the " + "through_fields keyword argument." + ) + % (self, from_model_name), hint=( - 'If you want to create a recursive relationship, ' + "If you want to create a recursive relationship, " 'use ManyToManyField("%s", through="%s").' - ) % ( + ) + % ( RECURSIVE_RELATIONSHIP_CONSTANT, relationship_model_name, ), obj=self, - id='fields.E334', + id="fields.E334", ) ) @@ -1374,14 +1532,15 @@ class ManyToManyField(RelatedField): "which foreign key Django should use via the " "through_fields keyword argument." % (self, to_model_name), hint=( - 'If you want to create a recursive relationship, ' + "If you want to create a recursive relationship, " 'use ManyToManyField("%s", through="%s").' 
- ) % ( + ) + % ( RECURSIVE_RELATIONSHIP_CONSTANT, relationship_model_name, ), obj=self, - id='fields.E335', + id="fields.E335", ) ) @@ -1389,11 +1548,10 @@ class ManyToManyField(RelatedField): errors.append( checks.Error( "The model is used as an intermediate model by " - "'%s', but it does not have a foreign key to '%s' or '%s'." % ( - self, from_model_name, to_model_name - ), + "'%s', but it does not have a foreign key to '%s' or '%s'." + % (self, from_model_name, to_model_name), obj=self.remote_field.through, - id='fields.E336', + id="fields.E336", ) ) @@ -1401,8 +1559,11 @@ class ManyToManyField(RelatedField): if self.remote_field.through_fields is not None: # Validate that we're given an iterable of at least two items # and that none of them is "falsy". - if not (len(self.remote_field.through_fields) >= 2 and - self.remote_field.through_fields[0] and self.remote_field.through_fields[1]): + if not ( + len(self.remote_field.through_fields) >= 2 + and self.remote_field.through_fields[0] + and self.remote_field.through_fields[1] + ): errors.append( checks.Error( "Field specifies 'through_fields' but does not provide " @@ -1410,7 +1571,7 @@ class ManyToManyField(RelatedField): "for the relation through model '%s'." % qualified_model_name, hint="Make sure you specify 'through_fields' as through_fields=('field1', 'field2')", obj=self, - id='fields.E337', + id="fields.E337", ) ) @@ -1424,20 +1585,34 @@ class ManyToManyField(RelatedField): "where the field is attached to." 
) - source, through, target = from_model, self.remote_field.through, self.remote_field.model - source_field_name, target_field_name = self.remote_field.through_fields[:2] + source, through, target = ( + from_model, + self.remote_field.through, + self.remote_field.model, + ) + source_field_name, target_field_name = self.remote_field.through_fields[ + :2 + ] - for field_name, related_model in ((source_field_name, source), - (target_field_name, target)): + for field_name, related_model in ( + (source_field_name, source), + (target_field_name, target), + ): possible_field_names = [] for f in through._meta.fields: - if hasattr(f, 'remote_field') and getattr(f.remote_field, 'model', None) == related_model: + if ( + hasattr(f, "remote_field") + and getattr(f.remote_field, "model", None) == related_model + ): possible_field_names.append(f.name) if possible_field_names: - hint = "Did you mean one of the following foreign keys to '%s': %s?" % ( - related_model._meta.object_name, - ', '.join(possible_field_names), + hint = ( + "Did you mean one of the following foreign keys to '%s': %s?" + % ( + related_model._meta.object_name, + ", ".join(possible_field_names), + ) ) else: hint = None @@ -1451,28 +1626,36 @@ class ManyToManyField(RelatedField): % (qualified_model_name, field_name), hint=hint, obj=self, - id='fields.E338', + id="fields.E338", ) ) else: - if not (hasattr(field, 'remote_field') and - getattr(field.remote_field, 'model', None) == related_model): + if not ( + hasattr(field, "remote_field") + and getattr(field.remote_field, "model", None) + == related_model + ): errors.append( checks.Error( - "'%s.%s' is not a foreign key to '%s'." % ( - through._meta.object_name, field_name, + "'%s.%s' is not a foreign key to '%s'." 
+ % ( + through._meta.object_name, + field_name, related_model._meta.object_name, ), hint=hint, obj=self, - id='fields.E339', + id="fields.E339", ) ) return errors def _check_table_uniqueness(self, **kwargs): - if isinstance(self.remote_field.through, str) or not self.remote_field.through._meta.managed: + if ( + isinstance(self.remote_field.through, str) + or not self.remote_field.through._meta.managed + ): return [] registered_tables = { model._meta.db_table: model @@ -1483,25 +1666,31 @@ class ManyToManyField(RelatedField): model = registered_tables.get(m2m_db_table) # The second condition allows multiple m2m relations on a model if # some point to a through model that proxies another through model. - if model and model._meta.concrete_model != self.remote_field.through._meta.concrete_model: + if ( + model + and model._meta.concrete_model + != self.remote_field.through._meta.concrete_model + ): if model._meta.auto_created: + def _get_field_name(model): for field in model._meta.auto_created._meta.many_to_many: if field.remote_field.through is model: return field.name + opts = model._meta.auto_created._meta - clashing_obj = '%s.%s' % (opts.label, _get_field_name(model)) + clashing_obj = "%s.%s" % (opts.label, _get_field_name(model)) else: clashing_obj = model._meta.label if settings.DATABASE_ROUTERS: - error_class, error_id = checks.Warning, 'fields.W344' + error_class, error_id = checks.Warning, "fields.W344" error_hint = ( - 'You have configured settings.DATABASE_ROUTERS. Verify ' - 'that the table of %r is correctly routed to a separate ' - 'database.' % clashing_obj + "You have configured settings.DATABASE_ROUTERS. Verify " + "that the table of %r is correctly routed to a separate " + "database." 
% clashing_obj ) else: - error_class, error_id = checks.Error, 'fields.E340' + error_class, error_id = checks.Error, "fields.E340" error_hint = None return [ error_class( @@ -1518,34 +1707,34 @@ class ManyToManyField(RelatedField): name, path, args, kwargs = super().deconstruct() # Handle the simpler arguments. if self.db_table is not None: - kwargs['db_table'] = self.db_table + kwargs["db_table"] = self.db_table if self.remote_field.db_constraint is not True: - kwargs['db_constraint'] = self.remote_field.db_constraint + kwargs["db_constraint"] = self.remote_field.db_constraint # Rel needs more work. if isinstance(self.remote_field.model, str): - kwargs['to'] = self.remote_field.model + kwargs["to"] = self.remote_field.model else: - kwargs['to'] = self.remote_field.model._meta.label - if getattr(self.remote_field, 'through', None) is not None: + kwargs["to"] = self.remote_field.model._meta.label + if getattr(self.remote_field, "through", None) is not None: if isinstance(self.remote_field.through, str): - kwargs['through'] = self.remote_field.through + kwargs["through"] = self.remote_field.through elif not self.remote_field.through._meta.auto_created: - kwargs['through'] = self.remote_field.through._meta.label + kwargs["through"] = self.remote_field.through._meta.label # If swappable is True, then see if we're actually pointing to the target # of a swap. swappable_setting = self.swappable_setting if swappable_setting is not None: # If it's already a settings reference, error. 
- if hasattr(kwargs['to'], "setting_name"): - if kwargs['to'].setting_name != swappable_setting: + if hasattr(kwargs["to"], "setting_name"): + if kwargs["to"].setting_name != swappable_setting: raise ValueError( "Cannot deconstruct a ManyToManyField pointing to a " "model that is swapped in place of more than one model " - "(%s and %s)" % (kwargs['to'].setting_name, swappable_setting) + "(%s and %s)" % (kwargs["to"].setting_name, swappable_setting) ) - kwargs['to'] = SettingsReference( - kwargs['to'], + kwargs["to"] = SettingsReference( + kwargs["to"], swappable_setting, ) return name, path, args, kwargs @@ -1605,7 +1794,7 @@ class ManyToManyField(RelatedField): elif self.db_table: return self.db_table else: - m2m_table_name = '%s_%s' % (utils.strip_quotes(opts.db_table), self.name) + m2m_table_name = "%s_%s" % (utils.strip_quotes(opts.db_table), self.name) return utils.truncate_name(m2m_table_name, connection.ops.max_name_length()) def _get_m2m_attr(self, related, attr): @@ -1613,7 +1802,7 @@ class ManyToManyField(RelatedField): Function that can be curried to provide the source accessor or DB column name for the m2m table. """ - cache_attr = '_m2m_%s_cache' % attr + cache_attr = "_m2m_%s_cache" % attr if hasattr(self, cache_attr): return getattr(self, cache_attr) if self.remote_field.through_fields is not None: @@ -1621,8 +1810,11 @@ class ManyToManyField(RelatedField): else: link_field_name = None for f in self.remote_field.through._meta.fields: - if (f.is_relation and f.remote_field.model == related.related_model and - (link_field_name is None or link_field_name == f.name)): + if ( + f.is_relation + and f.remote_field.model == related.related_model + and (link_field_name is None or link_field_name == f.name) + ): setattr(self, cache_attr, getattr(f, attr)) return getattr(self, cache_attr) @@ -1631,7 +1823,7 @@ class ManyToManyField(RelatedField): Function that can be curried to provide the related accessor or DB column name for the m2m table. 
""" - cache_attr = '_m2m_reverse_%s_cache' % attr + cache_attr = "_m2m_reverse_%s_cache" % attr if hasattr(self, cache_attr): return getattr(self, cache_attr) found = False @@ -1664,8 +1856,8 @@ class ManyToManyField(RelatedField): # automatically. The funky name reduces the chance of an accidental # clash. if self.remote_field.symmetrical and ( - self.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT or - self.remote_field.model == cls._meta.object_name + self.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT + or self.remote_field.model == cls._meta.object_name ): self.remote_field.related_name = "%s_rel_+" % name elif self.remote_field.is_hidden(): @@ -1673,7 +1865,7 @@ class ManyToManyField(RelatedField): # related_name with one generated from the m2m field name. Django # still uses backwards relations internally and we need to avoid # clashes between multiple m2m fields with related_name == '+'. - self.remote_field.related_name = '_%s_%s_%s_+' % ( + self.remote_field.related_name = "_%s_%s_%s_+" % ( cls._meta.app_label, cls.__name__.lower(), name, @@ -1687,11 +1879,17 @@ class ManyToManyField(RelatedField): # 3) The class owning the m2m field has been swapped out. if not cls._meta.abstract: if self.remote_field.through: + def resolve_through_model(_, model, field): field.remote_field.through = model - lazy_related_operation(resolve_through_model, cls, self.remote_field.through, field=self) + + lazy_related_operation( + resolve_through_model, cls, self.remote_field.through, field=self + ) elif not cls._meta.swapped: - self.remote_field.through = create_many_to_many_intermediary_model(self, cls) + self.remote_field.through = create_many_to_many_intermediary_model( + self, cls + ) # Add the descriptor for the m2m relation. 
setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False)) @@ -1702,19 +1900,30 @@ class ManyToManyField(RelatedField): def contribute_to_related_class(self, cls, related): # Internal M2Ms (i.e., those with a related name ending with '+') # and swapped models don't get a related descriptor. - if not self.remote_field.is_hidden() and not related.related_model._meta.swapped: - setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True)) + if ( + not self.remote_field.is_hidden() + and not related.related_model._meta.swapped + ): + setattr( + cls, + related.get_accessor_name(), + ManyToManyDescriptor(self.remote_field, reverse=True), + ) # Set up the accessors for the column names on the m2m table. - self.m2m_column_name = partial(self._get_m2m_attr, related, 'column') - self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, 'column') + self.m2m_column_name = partial(self._get_m2m_attr, related, "column") + self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, "column") - self.m2m_field_name = partial(self._get_m2m_attr, related, 'name') - self.m2m_reverse_field_name = partial(self._get_m2m_reverse_attr, related, 'name') + self.m2m_field_name = partial(self._get_m2m_attr, related, "name") + self.m2m_reverse_field_name = partial( + self._get_m2m_reverse_attr, related, "name" + ) - get_m2m_rel = partial(self._get_m2m_attr, related, 'remote_field') + get_m2m_rel = partial(self._get_m2m_attr, related, "remote_field") self.m2m_target_field_name = lambda: get_m2m_rel().field_name - get_m2m_reverse_rel = partial(self._get_m2m_reverse_attr, related, 'remote_field') + get_m2m_reverse_rel = partial( + self._get_m2m_reverse_attr, related, "remote_field" + ) self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name def set_attributes_from_rel(self): @@ -1728,17 +1937,17 @@ class ManyToManyField(RelatedField): def formfield(self, *, using=None, **kwargs): defaults = { - 
'form_class': forms.ModelMultipleChoiceField, - 'queryset': self.remote_field.model._default_manager.using(using), + "form_class": forms.ModelMultipleChoiceField, + "queryset": self.remote_field.model._default_manager.using(using), **kwargs, } # If initial is passed in, it's a list of related objects, but the # MultipleChoiceField takes a list of IDs. - if defaults.get('initial') is not None: - initial = defaults['initial'] + if defaults.get("initial") is not None: + initial = defaults["initial"] if callable(initial): initial = initial() - defaults['initial'] = [i.pk for i in initial] + defaults["initial"] = [i.pk for i in initial] return super().formfield(**defaults) def db_check(self, connection): diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 9c50ef16ce..3f67ed8166 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -74,7 +74,9 @@ from django.utils.functional import cached_property class ForeignKeyDeferredAttribute(DeferredAttribute): def __set__(self, instance, value): - if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(instance): + if instance.__dict__.get(self.field.attname) != value and self.field.is_cached( + instance + ): self.field.delete_cached_value(instance) instance.__dict__[self.field.attname] = value @@ -101,14 +103,16 @@ class ForwardManyToOneDescriptor: # related model might not be resolved yet; `self.field.model` might # still be a string model reference. 
return type( - 'RelatedObjectDoesNotExist', - (self.field.remote_field.model.DoesNotExist, AttributeError), { - '__module__': self.field.model.__module__, - '__qualname__': '%s.%s.RelatedObjectDoesNotExist' % ( + "RelatedObjectDoesNotExist", + (self.field.remote_field.model.DoesNotExist, AttributeError), + { + "__module__": self.field.model.__module__, + "__qualname__": "%s.%s.RelatedObjectDoesNotExist" + % ( self.field.model.__qualname__, self.field.name, ), - } + }, ) def is_cached(self, instance): @@ -135,9 +139,12 @@ class ForwardManyToOneDescriptor: # The check for len(...) == 1 is a special case that allows the query # to be join-less and smaller. Refs #21760. if remote_field.is_hidden() or len(self.field.foreign_related_fields) == 1: - query = {'%s__in' % related_field.name: {instance_attr(inst)[0] for inst in instances}} + query = { + "%s__in" + % related_field.name: {instance_attr(inst)[0] for inst in instances} + } else: - query = {'%s__in' % self.field.related_query_name(): instances} + query = {"%s__in" % self.field.related_query_name(): instances} queryset = queryset.filter(**query) # Since we're going to assign directly in the cache, @@ -146,7 +153,14 @@ class ForwardManyToOneDescriptor: for rel_obj in queryset: instance = instances_dict[rel_obj_attr(rel_obj)] remote_field.set_cached_value(rel_obj, instance) - return queryset, rel_obj_attr, instance_attr, True, self.field.get_cache_name(), False + return ( + queryset, + rel_obj_attr, + instance_attr, + True, + self.field.get_cache_name(), + False, + ) def get_object(self, instance): qs = self.get_queryset(instance=instance) @@ -173,7 +187,11 @@ class ForwardManyToOneDescriptor: rel_obj = self.field.get_cached_value(instance) except KeyError: has_value = None not in self.field.get_local_related_value(instance) - ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None + ancestor_link = ( + instance._meta.get_ancestor_link(self.field.model) + if has_value + else None + ) 
if ancestor_link and ancestor_link.is_cached(instance): # An ancestor link will exist if this field is defined on a # multi-table inheritance parent of the instance's class. @@ -211,9 +229,12 @@ class ForwardManyToOneDescriptor: - ``value`` is the ``parent`` instance on the right of the equal sign """ # An object must be an instance of the related class. - if value is not None and not isinstance(value, self.field.remote_field.model._meta.concrete_model): + if value is not None and not isinstance( + value, self.field.remote_field.model._meta.concrete_model + ): raise ValueError( - 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % ( + 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' + % ( value, instance._meta.object_name, self.field.name, @@ -222,11 +243,18 @@ class ForwardManyToOneDescriptor: ) elif value is not None: if instance._state.db is None: - instance._state.db = router.db_for_write(instance.__class__, instance=value) + instance._state.db = router.db_for_write( + instance.__class__, instance=value + ) if value._state.db is None: - value._state.db = router.db_for_write(value.__class__, instance=instance) + value._state.db = router.db_for_write( + value.__class__, instance=instance + ) if not router.allow_relation(value, instance): - raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value) + raise ValueError( + 'Cannot assign "%r": the current database router prevents this relation.' + % value + ) remote_field = self.field.remote_field # If we're setting the value of a OneToOneField to None, we need to clear @@ -314,12 +342,15 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor): opts = instance._meta # Inherited primary key fields from this object's base classes. 
inherited_pk_fields = [ - field for field in opts.concrete_fields + field + for field in opts.concrete_fields if field.primary_key and field.remote_field ] for field in inherited_pk_fields: rel_model_pk_name = field.remote_field.model._meta.pk.attname - raw_value = getattr(value, rel_model_pk_name) if value is not None else None + raw_value = ( + getattr(value, rel_model_pk_name) if value is not None else None + ) setattr(instance, rel_model_pk_name, raw_value) @@ -346,13 +377,15 @@ class ReverseOneToOneDescriptor: # The exception isn't created at initialization time for the sake of # consistency with `ForwardManyToOneDescriptor`. return type( - 'RelatedObjectDoesNotExist', - (self.related.related_model.DoesNotExist, AttributeError), { - '__module__': self.related.model.__module__, - '__qualname__': '%s.%s.RelatedObjectDoesNotExist' % ( + "RelatedObjectDoesNotExist", + (self.related.related_model.DoesNotExist, AttributeError), + { + "__module__": self.related.model.__module__, + "__qualname__": "%s.%s.RelatedObjectDoesNotExist" + % ( self.related.model.__qualname__, self.related.name, - ) + ), }, ) @@ -370,7 +403,7 @@ class ReverseOneToOneDescriptor: rel_obj_attr = self.related.field.get_local_related_value instance_attr = self.related.field.get_foreign_related_value instances_dict = {instance_attr(inst): inst for inst in instances} - query = {'%s__in' % self.related.field.name: instances} + query = {"%s__in" % self.related.field.name: instances} queryset = queryset.filter(**query) # Since we're going to assign directly in the cache, @@ -378,7 +411,14 @@ class ReverseOneToOneDescriptor: for rel_obj in queryset: instance = instances_dict[rel_obj_attr(rel_obj)] self.related.field.set_cached_value(rel_obj, instance) - return queryset, rel_obj_attr, instance_attr, True, self.related.get_cache_name(), False + return ( + queryset, + rel_obj_attr, + instance_attr, + True, + self.related.get_cache_name(), + False, + ) def __get__(self, instance, cls=None): """ @@ -419,10 
+459,8 @@ class ReverseOneToOneDescriptor: if rel_obj is None: raise self.RelatedObjectDoesNotExist( - "%s has no %s." % ( - instance.__class__.__name__, - self.related.get_accessor_name() - ) + "%s has no %s." + % (instance.__class__.__name__, self.related.get_accessor_name()) ) else: return rel_obj @@ -458,7 +496,8 @@ class ReverseOneToOneDescriptor: elif not isinstance(value, self.related.related_model): # An object must be an instance of the related class. raise ValueError( - 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % ( + 'Cannot assign "%r": "%s.%s" must be a "%s" instance.' + % ( value, instance._meta.object_name, self.related.get_accessor_name(), @@ -467,13 +506,23 @@ class ReverseOneToOneDescriptor: ) else: if instance._state.db is None: - instance._state.db = router.db_for_write(instance.__class__, instance=value) + instance._state.db = router.db_for_write( + instance.__class__, instance=value + ) if value._state.db is None: - value._state.db = router.db_for_write(value.__class__, instance=instance) + value._state.db = router.db_for_write( + value.__class__, instance=instance + ) if not router.allow_relation(value, instance): - raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value) + raise ValueError( + 'Cannot assign "%r": the current database router prevents this relation.' 
+ % value + ) - related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields) + related_pk = tuple( + getattr(instance, field.attname) + for field in self.related.field.foreign_related_fields + ) # Set the value of the related field to the value of the related object's related field for index, field in enumerate(self.related.field.local_related_fields): setattr(value, field.attname, related_pk[index]) @@ -548,13 +597,13 @@ class ReverseManyToOneDescriptor: def _get_set_deprecation_msg_params(self): return ( - 'reverse side of a related set', + "reverse side of a related set", self.rel.get_accessor_name(), ) def __set__(self, instance, value): raise TypeError( - 'Direct assignment to the %s is prohibited. Use %s.set() instead.' + "Direct assignment to the %s is prohibited. Use %s.set() instead." % self._get_set_deprecation_msg_params(), ) @@ -581,6 +630,7 @@ def create_reverse_many_to_one_manager(superclass, rel): manager = getattr(self.model, manager) manager_class = create_reverse_many_to_one_manager(manager.__class__, rel) return manager_class(self.instance) + do_not_call_in_templates = True def _apply_rel_filters(self, queryset): @@ -588,7 +638,9 @@ def create_reverse_many_to_one_manager(superclass, rel): Filter the queryset for the instance this manager is bound to. 
""" db = self._db or router.db_for_read(self.model, instance=self.instance) - empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls + empty_strings_as_null = connections[ + db + ].features.interprets_empty_strings_as_nulls queryset._add_hints(instance=self.instance) if self._db: queryset = queryset.using(self._db) @@ -596,7 +648,7 @@ def create_reverse_many_to_one_manager(superclass, rel): queryset = queryset.filter(**self.core_filters) for field in self.field.foreign_related_fields: val = getattr(self.instance, field.attname) - if val is None or (val == '' and empty_strings_as_null): + if val is None or (val == "" and empty_strings_as_null): return queryset.none() if self.field.many_to_one: # Guard against field-like objects such as GenericRelation @@ -608,24 +660,32 @@ def create_reverse_many_to_one_manager(superclass, rel): except FieldError: # The relationship has multiple target fields. Use a tuple # for related object id. - rel_obj_id = tuple([ - getattr(self.instance, target_field.attname) - for target_field in self.field.path_infos[-1].target_fields - ]) + rel_obj_id = tuple( + [ + getattr(self.instance, target_field.attname) + for target_field in self.field.path_infos[-1].target_fields + ] + ) else: rel_obj_id = getattr(self.instance, target_field.attname) - queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}} + queryset._known_related_objects = { + self.field: {rel_obj_id: self.instance} + } return queryset def _remove_prefetched_objects(self): try: - self.instance._prefetched_objects_cache.pop(self.field.remote_field.get_cache_name()) + self.instance._prefetched_objects_cache.pop( + self.field.remote_field.get_cache_name() + ) except (AttributeError, KeyError): pass # nothing to clear from cache def get_queryset(self): try: - return self.instance._prefetched_objects_cache[self.field.remote_field.get_cache_name()] + return self.instance._prefetched_objects_cache[ + 
self.field.remote_field.get_cache_name() + ] except (AttributeError, KeyError): queryset = super().get_queryset() return self._apply_rel_filters(queryset) @@ -640,7 +700,7 @@ def create_reverse_many_to_one_manager(superclass, rel): rel_obj_attr = self.field.get_local_related_value instance_attr = self.field.get_foreign_related_value instances_dict = {instance_attr(inst): inst for inst in instances} - query = {'%s__in' % self.field.name: instances} + query = {"%s__in" % self.field.name: instances} queryset = queryset.filter(**query) # Since we just bypassed this class' get_queryset(), we must manage @@ -658,9 +718,13 @@ def create_reverse_many_to_one_manager(superclass, rel): def check_and_update_obj(obj): if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected, got %r" % ( - self.model._meta.object_name, obj, - )) + raise TypeError( + "'%s' instance expected, got %r" + % ( + self.model._meta.object_name, + obj, + ) + ) setattr(obj, self.field.name, self.instance) if bulk: @@ -673,36 +737,43 @@ def create_reverse_many_to_one_manager(superclass, rel): "the object first." 
% obj ) pks.append(obj.pk) - self.model._base_manager.using(db).filter(pk__in=pks).update(**{ - self.field.name: self.instance, - }) + self.model._base_manager.using(db).filter(pk__in=pks).update( + **{ + self.field.name: self.instance, + } + ) else: with transaction.atomic(using=db, savepoint=False): for obj in objs: check_and_update_obj(obj) obj.save() + add.alters_data = True def create(self, **kwargs): kwargs[self.field.name] = self.instance db = router.db_for_write(self.model, instance=self.instance) return super(RelatedManager, self.db_manager(db)).create(**kwargs) + create.alters_data = True def get_or_create(self, **kwargs): kwargs[self.field.name] = self.instance db = router.db_for_write(self.model, instance=self.instance) return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) + get_or_create.alters_data = True def update_or_create(self, **kwargs): kwargs[self.field.name] = self.instance db = router.db_for_write(self.model, instance=self.instance) return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs) + update_or_create.alters_data = True # remove() and clear() are only provided if the ForeignKey can have a value of null. if rel.field.null: + def remove(self, *objs, bulk=True): if not objs: return @@ -710,9 +781,13 @@ def create_reverse_many_to_one_manager(superclass, rel): old_ids = set() for obj in objs: if not isinstance(obj, self.model): - raise TypeError("'%s' instance expected, got %r" % ( - self.model._meta.object_name, obj, - )) + raise TypeError( + "'%s' instance expected, got %r" + % ( + self.model._meta.object_name, + obj, + ) + ) # Is obj actually part of this descriptor set? if self.field.get_local_related_value(obj) == val: old_ids.add(obj.pk) @@ -721,10 +796,12 @@ def create_reverse_many_to_one_manager(superclass, rel): "%r is not related to %r." 
% (obj, self.instance) ) self._clear(self.filter(pk__in=old_ids), bulk) + remove.alters_data = True def clear(self, *, bulk=True): self._clear(self, bulk) + clear.alters_data = True def _clear(self, queryset, bulk): @@ -739,6 +816,7 @@ def create_reverse_many_to_one_manager(superclass, rel): for obj in queryset: setattr(obj, self.field.name, None) obj.save(update_fields=[self.field.name]) + _clear.alters_data = True def set(self, objs, *, bulk=True, clear=False): @@ -765,6 +843,7 @@ def create_reverse_many_to_one_manager(superclass, rel): self.add(*new_objs, bulk=bulk) else: self.add(*objs, bulk=bulk) + set.alters_data = True return RelatedManager @@ -822,7 +901,8 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor): def _get_set_deprecation_msg_params(self): return ( - '%s side of a many-to-many set' % ('reverse' if self.reverse else 'forward'), + "%s side of a many-to-many set" + % ("reverse" if self.reverse else "forward"), self.rel.get_accessor_name() if self.reverse else self.field.name, ) @@ -865,41 +945,51 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): self.core_filters = {} self.pk_field_names = {} for lh_field, rh_field in self.source_field.related_fields: - core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name) + core_filter_key = "%s__%s" % (self.query_field_name, rh_field.name) self.core_filters[core_filter_key] = getattr(instance, rh_field.attname) self.pk_field_names[lh_field.name] = rh_field.name self.related_val = self.source_field.get_foreign_related_value(instance) if None in self.related_val: - raise ValueError('"%r" needs to have a value for field "%s" before ' - 'this many-to-many relationship can be used.' % - (instance, self.pk_field_names[self.source_field_name])) + raise ValueError( + '"%r" needs to have a value for field "%s" before ' + "this many-to-many relationship can be used." 
+ % (instance, self.pk_field_names[self.source_field_name]) + ) # Even if this relation is not to pk, we require still pk value. # The wish is that the instance has been already saved to DB, # although having a pk value isn't a guarantee of that. if instance.pk is None: - raise ValueError("%r instance needs to have a primary key value before " - "a many-to-many relationship can be used." % - instance.__class__.__name__) + raise ValueError( + "%r instance needs to have a primary key value before " + "a many-to-many relationship can be used." + % instance.__class__.__name__ + ) def __call__(self, *, manager): manager = getattr(self.model, manager) - manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse) + manager_class = create_forward_many_to_many_manager( + manager.__class__, rel, reverse + ) return manager_class(instance=self.instance) + do_not_call_in_templates = True def _build_remove_filters(self, removed_vals): filters = Q((self.source_field_name, self.related_val)) # No need to add a subquery condition if removed_vals is a QuerySet without # filters. 
- removed_vals_filters = (not isinstance(removed_vals, QuerySet) or - removed_vals._has_filters()) + removed_vals_filters = ( + not isinstance(removed_vals, QuerySet) or removed_vals._has_filters() + ) if removed_vals_filters: - filters &= Q((f'{self.target_field_name}__in', removed_vals)) + filters &= Q((f"{self.target_field_name}__in", removed_vals)) if self.symmetrical: symmetrical_filters = Q((self.target_field_name, self.related_val)) if removed_vals_filters: - symmetrical_filters &= Q((f'{self.source_field_name}__in', removed_vals)) + symmetrical_filters &= Q( + (f"{self.source_field_name}__in", removed_vals) + ) filters |= symmetrical_filters return filters @@ -933,7 +1023,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): queryset._add_hints(instance=instances[0]) queryset = queryset.using(queryset._db or self._db) - query = {'%s__in' % self.query_field_name: instances} + query = {"%s__in" % self.query_field_name: instances} queryset = queryset._next_is_sticky().filter(**query) # M2M: need to annotate the query in order to get the primary model @@ -947,13 +1037,18 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): join_table = fk.model._meta.db_table connection = connections[queryset.db] qn = connection.ops.quote_name - queryset = queryset.extra(select={ - '_prefetch_related_val_%s' % f.attname: - '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields}) + queryset = queryset.extra( + select={ + "_prefetch_related_val_%s" + % f.attname: "%s.%s" + % (qn(join_table), qn(f.column)) + for f in fk.local_related_fields + } + ) return ( queryset, lambda result: tuple( - getattr(result, '_prefetch_related_val_%s' % f.attname) + getattr(result, "_prefetch_related_val_%s" % f.attname) for f in fk.local_related_fields ), lambda inst: tuple( @@ -970,7 +1065,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): db = router.db_for_write(self.through, instance=self.instance) with 
transaction.atomic(using=db, savepoint=False): self._add_items( - self.source_field_name, self.target_field_name, *objs, + self.source_field_name, + self.target_field_name, + *objs, through_defaults=through_defaults, ) # If this is a symmetrical m2m relation to self, add the mirror @@ -982,30 +1079,41 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): *objs, through_defaults=through_defaults, ) + add.alters_data = True def remove(self, *objs): self._remove_prefetched_objects() self._remove_items(self.source_field_name, self.target_field_name, *objs) + remove.alters_data = True def clear(self): db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): signals.m2m_changed.send( - sender=self.through, action="pre_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db, + sender=self.through, + action="pre_clear", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=None, + using=db, ) self._remove_prefetched_objects() filters = self._build_remove_filters(super().get_queryset().using(db)) self.through._default_manager.using(db).filter(filters).delete() signals.m2m_changed.send( - sender=self.through, action="post_clear", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=None, using=db, + sender=self.through, + action="post_clear", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=None, + using=db, ) + clear.alters_data = True def set(self, objs, *, clear=False, through_defaults=None): @@ -1019,7 +1127,11 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): self.clear() self.add(*objs, through_defaults=through_defaults) else: - old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True)) + old_ids = set( + self.using(db).values_list( + self.target_field.target_field.attname, flat=True + ) + ) new_objs = [] for obj in objs: @@ 
-1035,6 +1147,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): self.remove(*old_ids) self.add(*new_objs, through_defaults=through_defaults) + set.alters_data = True def create(self, *, through_defaults=None, **kwargs): @@ -1042,26 +1155,33 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs) self.add(new_obj, through_defaults=through_defaults) return new_obj + create.alters_data = True def get_or_create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) - obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs) + obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create( + **kwargs + ) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: self.add(obj, through_defaults=through_defaults) return obj, created + get_or_create.alters_data = True def update_or_create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) - obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs) + obj, created = super( + ManyRelatedManager, self.db_manager(db) + ).update_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: self.add(obj, through_defaults=through_defaults) return obj, created + update_or_create.alters_data = True def _get_target_ids(self, target_field_name, objs): @@ -1069,6 +1189,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): Return the set of ids of `objs` that the target field references. 
""" from django.db.models import Model + target_ids = set() target_field = self.through._meta.get_field(target_field_name) for obj in objs: @@ -1076,36 +1197,42 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): if not router.allow_relation(obj, self.instance): raise ValueError( 'Cannot add "%r": instance is on database "%s", ' - 'value is on database "%s"' % - (obj, self.instance._state.db, obj._state.db) + 'value is on database "%s"' + % (obj, self.instance._state.db, obj._state.db) ) target_id = target_field.get_foreign_related_value(obj)[0] if target_id is None: raise ValueError( - 'Cannot add "%r": the value for field "%s" is None' % - (obj, target_field_name) + 'Cannot add "%r": the value for field "%s" is None' + % (obj, target_field_name) ) target_ids.add(target_id) elif isinstance(obj, Model): raise TypeError( - "'%s' instance expected, got %r" % - (self.model._meta.object_name, obj) + "'%s' instance expected, got %r" + % (self.model._meta.object_name, obj) ) else: target_ids.add(target_field.get_prep_value(obj)) return target_ids - def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids): + def _get_missing_target_ids( + self, source_field_name, target_field_name, db, target_ids + ): """ Return the subset of ids of `objs` that aren't already assigned to this relationship. 
""" - vals = self.through._default_manager.using(db).values_list( - target_field_name, flat=True - ).filter(**{ - source_field_name: self.related_val[0], - '%s__in' % target_field_name: target_ids, - }) + vals = ( + self.through._default_manager.using(db) + .values_list(target_field_name, flat=True) + .filter( + **{ + source_field_name: self.related_val[0], + "%s__in" % target_field_name: target_ids, + } + ) + ) return target_ids.difference(vals) def _get_add_plan(self, db, source_field_name): @@ -1123,21 +1250,27 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): # user-defined intermediary models as they could have other fields # causing conflicts which must be surfaced. can_ignore_conflicts = ( - self.through._meta.auto_created is not False and - connections[db].features.supports_ignore_conflicts + self.through._meta.auto_created is not False + and connections[db].features.supports_ignore_conflicts ) # Don't send the signal when inserting duplicate data row # for symmetrical reverse entries. - must_send_signals = (self.reverse or source_field_name == self.source_field_name) and ( - signals.m2m_changed.has_listeners(self.through) - ) + must_send_signals = ( + self.reverse or source_field_name == self.source_field_name + ) and (signals.m2m_changed.has_listeners(self.through)) # Fast addition through bulk insertion can only be performed # if no m2m_changed listeners are connected for self.through # as they require the added set of ids to be provided via # pk_set. 
- return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals) + return ( + can_ignore_conflicts, + must_send_signals, + (can_ignore_conflicts and not must_send_signals), + ) - def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): + def _add_items( + self, source_field_name, target_field_name, *objs, through_defaults=None + ): # source_field_name: the PK fieldname in join table for the source object # target_field_name: the PK fieldname in join table for the target object # *objs - objects to add. Either object instances, or primary keys of object instances. @@ -1147,15 +1280,22 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): through_defaults = dict(resolve_callables(through_defaults or {})) target_ids = self._get_target_ids(target_field_name, objs) db = router.db_for_write(self.through, instance=self.instance) - can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name) + can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan( + db, source_field_name + ) if can_fast_add: - self.through._default_manager.using(db).bulk_create([ - self.through(**{ - '%s_id' % source_field_name: self.related_val[0], - '%s_id' % target_field_name: target_id, - }) - for target_id in target_ids - ], ignore_conflicts=True) + self.through._default_manager.using(db).bulk_create( + [ + self.through( + **{ + "%s_id" % source_field_name: self.related_val[0], + "%s_id" % target_field_name: target_id, + } + ) + for target_id in target_ids + ], + ignore_conflicts=True, + ) return missing_target_ids = self._get_missing_target_ids( @@ -1164,24 +1304,38 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): with transaction.atomic(using=db, savepoint=False): if must_send_signals: signals.m2m_changed.send( - sender=self.through, action='pre_add', - instance=self.instance, reverse=self.reverse, - model=self.model, 
pk_set=missing_target_ids, using=db, + sender=self.through, + action="pre_add", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=missing_target_ids, + using=db, ) # Add the ones that aren't there already. - self.through._default_manager.using(db).bulk_create([ - self.through(**through_defaults, **{ - '%s_id' % source_field_name: self.related_val[0], - '%s_id' % target_field_name: target_id, - }) - for target_id in missing_target_ids - ], ignore_conflicts=can_ignore_conflicts) + self.through._default_manager.using(db).bulk_create( + [ + self.through( + **through_defaults, + **{ + "%s_id" % source_field_name: self.related_val[0], + "%s_id" % target_field_name: target_id, + }, + ) + for target_id in missing_target_ids + ], + ignore_conflicts=can_ignore_conflicts, + ) if must_send_signals: signals.m2m_changed.send( - sender=self.through, action='post_add', - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=missing_target_ids, using=db, + sender=self.through, + action="post_add", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=missing_target_ids, + using=db, ) def _remove_items(self, source_field_name, target_field_name, *objs): @@ -1205,23 +1359,32 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): with transaction.atomic(using=db, savepoint=False): # Send a signal to the other end if need be. 
signals.m2m_changed.send( - sender=self.through, action="pre_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db, + sender=self.through, + action="pre_remove", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=old_ids, + using=db, ) target_model_qs = super().get_queryset() if target_model_qs._has_filters(): - old_vals = target_model_qs.using(db).filter(**{ - '%s__in' % self.target_field.target_field.attname: old_ids}) + old_vals = target_model_qs.using(db).filter( + **{"%s__in" % self.target_field.target_field.attname: old_ids} + ) else: old_vals = old_ids filters = self._build_remove_filters(old_vals) self.through._default_manager.using(db).filter(filters).delete() signals.m2m_changed.send( - sender=self.through, action="post_remove", - instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=old_ids, using=db, + sender=self.through, + action="post_remove", + instance=self.instance, + reverse=self.reverse, + model=self.model, + pk_set=old_ids, + using=db, ) return ManyRelatedManager diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index fd97757b14..1bad1cf416 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -1,5 +1,10 @@ from django.db.models.lookups import ( - Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan, + Exact, + GreaterThan, + GreaterThanOrEqual, + In, + IsNull, + LessThan, LessThanOrEqual, ) @@ -8,16 +13,21 @@ class MultiColSource: contains_aggregate = False def __init__(self, alias, targets, sources, field): - self.targets, self.sources, self.field, self.alias = targets, sources, field, alias + self.targets, self.sources, self.field, self.alias = ( + targets, + sources, + field, + alias, + ) self.output_field = self.field def __repr__(self): - return "{}({}, {})".format( - self.__class__.__name__, self.alias, self.field) + return "{}({}, 
{})".format(self.__class__.__name__, self.alias, self.field) def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), - self.targets, self.sources, self.field) + return self.__class__( + relabels.get(self.alias, self.alias), self.targets, self.sources, self.field + ) def get_lookup(self, lookup): return self.output_field.get_lookup(lookup) @@ -28,12 +38,15 @@ class MultiColSource: def get_normalized_value(value, lhs): from django.db.models import Model + if isinstance(value, Model): value_list = [] sources = lhs.output_field.path_infos[-1].target_fields for source in sources: while not isinstance(value, source.model) and source.remote_field: - source = source.remote_field.model._meta.get_field(source.remote_field.field_name) + source = source.remote_field.model._meta.get_field( + source.remote_field.field_name + ) try: value_list.append(getattr(value, source.attname)) except AttributeError: @@ -56,20 +69,21 @@ class RelatedIn(In): # case ForeignKey to IntegerField given value 'abc'. The # ForeignKey itself doesn't have validation for non-integers, # so we must run validation using the target field. - if hasattr(self.lhs.output_field, 'path_infos'): + if hasattr(self.lhs.output_field, "path_infos"): # Run the target field's get_prep_value. We can safely # assume there is only one as we don't get to the direct # value branch otherwise. 
- target_field = self.lhs.output_field.path_infos[-1].target_fields[-1] + target_field = self.lhs.output_field.path_infos[-1].target_fields[ + -1 + ] self.rhs = [target_field.get_prep_value(v) for v in self.rhs] - elif ( - not getattr(self.rhs, 'has_select_fields', True) and - not getattr(self.lhs.field.target_field, 'primary_key', False) + elif not getattr(self.rhs, "has_select_fields", True) and not getattr( + self.lhs.field.target_field, "primary_key", False ): self.rhs.clear_select_clause() if ( - getattr(self.lhs.output_field, 'primary_key', False) and - self.lhs.output_field.model == self.rhs.model + getattr(self.lhs.output_field, "primary_key", False) + and self.lhs.output_field.model == self.rhs.model ): # A case like # Restaurant.objects.filter(place__in=restaurant_qs), where @@ -87,7 +101,10 @@ class RelatedIn(In): # This clause is either a SubqueryConstraint (for values that need to be compiled to # SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses. 
from django.db.models.sql.where import ( - AND, OR, SubqueryConstraint, WhereNode, + AND, + OR, + SubqueryConstraint, + WhereNode, ) root_constraint = WhereNode(connector=OR) @@ -95,31 +112,41 @@ class RelatedIn(In): values = [get_normalized_value(value, self.lhs) for value in self.rhs] for value in values: value_constraint = WhereNode() - for source, target, val in zip(self.lhs.sources, self.lhs.targets, value): - lookup_class = target.get_lookup('exact') - lookup = lookup_class(target.get_col(self.lhs.alias, source), val) + for source, target, val in zip( + self.lhs.sources, self.lhs.targets, value + ): + lookup_class = target.get_lookup("exact") + lookup = lookup_class( + target.get_col(self.lhs.alias, source), val + ) value_constraint.add(lookup, AND) root_constraint.add(value_constraint, OR) else: root_constraint.add( SubqueryConstraint( - self.lhs.alias, [target.column for target in self.lhs.targets], - [source.name for source in self.lhs.sources], self.rhs), - AND) + self.lhs.alias, + [target.column for target in self.lhs.targets], + [source.name for source in self.lhs.sources], + self.rhs, + ), + AND, + ) return root_constraint.as_sql(compiler, connection) return super().as_sql(compiler, connection) class RelatedLookupMixin: def get_prep_lookup(self): - if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'): + if not isinstance(self.lhs, MultiColSource) and not hasattr( + self.rhs, "resolve_expression" + ): # If we get here, we are dealing with single-column relations. self.rhs = get_normalized_value(self.rhs, self.lhs)[0] # We need to run the related field's get_prep_value(). Consider case # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself # doesn't have validation for non-integers, so we must run validation # using the target field. - if self.prepare_rhs and hasattr(self.lhs.output_field, 'path_infos'): + if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"): # Get the target field. 
We can safely assume there is only one # as we don't get to the direct value branch otherwise. target_field = self.lhs.output_field.path_infos[-1].target_fields[-1] @@ -132,11 +159,15 @@ class RelatedLookupMixin: assert self.rhs_is_direct_value() self.rhs = get_normalized_value(self.rhs, self.lhs) from django.db.models.sql.where import AND, WhereNode + root_constraint = WhereNode() - for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs): + for target, source, val in zip( + self.lhs.targets, self.lhs.sources, self.rhs + ): lookup_class = target.get_lookup(self.lookup_name) root_constraint.add( - lookup_class(target.get_col(self.lhs.alias, source), val), AND) + lookup_class(target.get_col(self.lhs.alias, source), val), AND + ) return root_constraint.as_sql(compiler, connection) return super().as_sql(compiler, connection) diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py index 6f0c788bbd..2ff66f34d0 100644 --- a/django/db/models/fields/reverse_related.py +++ b/django/db/models/fields/reverse_related.py @@ -36,8 +36,16 @@ class ForeignObjectRel(FieldCacheMixin): null = True empty_strings_allowed = False - def __init__(self, field, to, related_name=None, related_query_name=None, - limit_choices_to=None, parent_link=False, on_delete=None): + def __init__( + self, + field, + to, + related_name=None, + related_query_name=None, + limit_choices_to=None, + parent_link=False, + on_delete=None, + ): self.field = field self.model = to self.related_name = related_name @@ -73,14 +81,17 @@ class ForeignObjectRel(FieldCacheMixin): """ target_fields = self.path_infos[-1].target_fields if len(target_fields) > 1: - raise exceptions.FieldError("Can't use target_field for multicolumn relations.") + raise exceptions.FieldError( + "Can't use target_field for multicolumn relations." 
+ ) return target_fields[0] @cached_property def related_model(self): if not self.field.model: raise AttributeError( - "This property can't be accessed before self.field.contribute_to_class has been called.") + "This property can't be accessed before self.field.contribute_to_class has been called." + ) return self.field.model @cached_property @@ -110,7 +121,7 @@ class ForeignObjectRel(FieldCacheMixin): return self.field.db_type def __repr__(self): - return '<%s: %s.%s>' % ( + return "<%s: %s.%s>" % ( type(self).__name__, self.related_model._meta.app_label, self.related_model._meta.model_name, @@ -147,12 +158,15 @@ class ForeignObjectRel(FieldCacheMixin): # created and doesn't exist in the .models module. # This is a reverse relation, so there is no reverse_path_infos to # delete. - state.pop('path_infos', None) + state.pop("path_infos", None) return state def get_choices( - self, include_blank=True, blank_choice=BLANK_CHOICE_DASH, - limit_choices_to=None, ordering=(), + self, + include_blank=True, + blank_choice=BLANK_CHOICE_DASH, + limit_choices_to=None, + ordering=(), ): """ Return choices with a default blank choices included, for use @@ -165,13 +179,11 @@ class ForeignObjectRel(FieldCacheMixin): qs = self.related_model._default_manager.complex_filter(limit_choices_to) if ordering: qs = qs.order_by(*ordering) - return (blank_choice if include_blank else []) + [ - (x.pk, str(x)) for x in qs - ] + return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs] def is_hidden(self): """Should the related object be hidden?""" - return bool(self.related_name) and self.related_name[-1] == '+' + return bool(self.related_name) and self.related_name[-1] == "+" def get_joining_columns(self): return self.field.get_reverse_joining_columns() @@ -204,7 +216,7 @@ class ForeignObjectRel(FieldCacheMixin): return None if self.related_name: return self.related_name - return opts.model_name + ('_set' if self.multiple else '') + return opts.model_name + ("_set" if 
self.multiple else "") def get_path_info(self, filtered_relation=None): if filtered_relation: @@ -239,10 +251,20 @@ class ManyToOneRel(ForeignObjectRel): reverse relations into actual fields. """ - def __init__(self, field, to, field_name, related_name=None, related_query_name=None, - limit_choices_to=None, parent_link=False, on_delete=None): + def __init__( + self, + field, + to, + field_name, + related_name=None, + related_query_name=None, + limit_choices_to=None, + parent_link=False, + on_delete=None, + ): super().__init__( - field, to, + field, + to, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, @@ -254,7 +276,7 @@ class ManyToOneRel(ForeignObjectRel): def __getstate__(self): state = super().__getstate__() - state.pop('related_model', None) + state.pop("related_model", None) return state @property @@ -267,7 +289,9 @@ class ManyToOneRel(ForeignObjectRel): """ field = self.model._meta.get_field(self.field_name) if not field.concrete: - raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name) + raise exceptions.FieldDoesNotExist( + "No related field named '%s'" % self.field_name + ) return field def set_field_name(self): @@ -282,10 +306,21 @@ class OneToOneRel(ManyToOneRel): flags for the reverse relation. """ - def __init__(self, field, to, field_name, related_name=None, related_query_name=None, - limit_choices_to=None, parent_link=False, on_delete=None): + def __init__( + self, + field, + to, + field_name, + related_name=None, + related_query_name=None, + limit_choices_to=None, + parent_link=False, + on_delete=None, + ): super().__init__( - field, to, field_name, + field, + to, + field_name, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, @@ -304,11 +339,21 @@ class ManyToManyRel(ForeignObjectRel): flags for the reverse relation. 
""" - def __init__(self, field, to, related_name=None, related_query_name=None, - limit_choices_to=None, symmetrical=True, through=None, - through_fields=None, db_constraint=True): + def __init__( + self, + field, + to, + related_name=None, + related_query_name=None, + limit_choices_to=None, + symmetrical=True, + through=None, + through_fields=None, + db_constraint=True, + ): super().__init__( - field, to, + field, + to, related_name=related_name, related_query_name=related_query_name, limit_choices_to=limit_choices_to, @@ -343,7 +388,7 @@ class ManyToManyRel(ForeignObjectRel): field = opts.get_field(self.through_fields[0]) else: for field in opts.fields: - rel = getattr(field, 'remote_field', None) + rel = getattr(field, "remote_field", None) if rel and rel.model == self.model: break return field.foreign_related_fields[0] diff --git a/django/db/models/functions/__init__.py b/django/db/models/functions/__init__.py index d687af135d..cd7c801894 100644 --- a/django/db/models/functions/__init__.py +++ b/django/db/models/functions/__init__.py @@ -1,46 +1,190 @@ -from .comparison import ( - Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf, -) +from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf from .datetime import ( - Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear, - ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek, - ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour, - TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek, + Extract, + ExtractDay, + ExtractHour, + ExtractIsoWeekDay, + ExtractIsoYear, + ExtractMinute, + ExtractMonth, + ExtractQuarter, + ExtractSecond, + ExtractWeek, + ExtractWeekDay, + ExtractYear, + Now, + Trunc, + TruncDate, + TruncDay, + TruncHour, + TruncMinute, + TruncMonth, + TruncQuarter, + TruncSecond, + TruncTime, + TruncWeek, TruncYear, ) from .math import ( - Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, 
Exp, Floor, Ln, Log, - Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan, + Abs, + ACos, + ASin, + ATan, + ATan2, + Ceil, + Cos, + Cot, + Degrees, + Exp, + Floor, + Ln, + Log, + Mod, + Pi, + Power, + Radians, + Random, + Round, + Sign, + Sin, + Sqrt, + Tan, ) from .text import ( - MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left, - Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad, - RTrim, StrIndex, Substr, Trim, Upper, + MD5, + SHA1, + SHA224, + SHA256, + SHA384, + SHA512, + Chr, + Concat, + ConcatPair, + Left, + Length, + Lower, + LPad, + LTrim, + Ord, + Repeat, + Replace, + Reverse, + Right, + RPad, + RTrim, + StrIndex, + Substr, + Trim, + Upper, ) from .window import ( - CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile, - PercentRank, Rank, RowNumber, + CumeDist, + DenseRank, + FirstValue, + Lag, + LastValue, + Lead, + NthValue, + Ntile, + PercentRank, + Rank, + RowNumber, ) __all__ = [ # comparison and conversion - 'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf', + "Cast", + "Coalesce", + "Collate", + "Greatest", + "JSONObject", + "Least", + "NullIf", # datetime - 'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth', - 'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay', - 'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc', - 'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth', - 'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear', + "Extract", + "ExtractDay", + "ExtractHour", + "ExtractMinute", + "ExtractMonth", + "ExtractQuarter", + "ExtractSecond", + "ExtractWeek", + "ExtractIsoWeekDay", + "ExtractWeekDay", + "ExtractIsoYear", + "ExtractYear", + "Now", + "Trunc", + "TruncDate", + "TruncDay", + "TruncHour", + "TruncMinute", + "TruncMonth", + "TruncQuarter", + "TruncSecond", + "TruncTime", + "TruncWeek", + "TruncYear", # math - 'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 
'Ceil', 'Cos', 'Cot', 'Degrees', - 'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random', - 'Round', 'Sign', 'Sin', 'Sqrt', 'Tan', + "Abs", + "ACos", + "ASin", + "ATan", + "ATan2", + "Ceil", + "Cos", + "Cot", + "Degrees", + "Exp", + "Floor", + "Ln", + "Log", + "Mod", + "Pi", + "Power", + "Radians", + "Random", + "Round", + "Sign", + "Sin", + "Sqrt", + "Tan", # text - 'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat', - 'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat', - 'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr', - 'Trim', 'Upper', + "MD5", + "SHA1", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "Chr", + "Concat", + "ConcatPair", + "Left", + "Length", + "Lower", + "LPad", + "LTrim", + "Ord", + "Repeat", + "Replace", + "Reverse", + "Right", + "RPad", + "RTrim", + "StrIndex", + "Substr", + "Trim", + "Upper", # window - 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', - 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', + "CumeDist", + "DenseRank", + "FirstValue", + "Lag", + "LastValue", + "Lead", + "NthValue", + "Ntile", + "PercentRank", + "Rank", + "RowNumber", ] diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index e5882de9c2..cc78834f20 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -7,38 +7,43 @@ from django.utils.regex_helper import _lazy_re_compile class Cast(Func): """Coerce an expression to a new field type.""" - function = 'CAST' - template = '%(function)s(%(expressions)s AS %(db_type)s)' + + function = "CAST" + template = "%(function)s(%(expressions)s AS %(db_type)s)" def __init__(self, expression, output_field): super().__init__(expression, output_field=output_field) def as_sql(self, compiler, connection, **extra_context): - extra_context['db_type'] = self.output_field.cast_db_type(connection) + extra_context["db_type"] = 
self.output_field.cast_db_type(connection) return super().as_sql(compiler, connection, **extra_context) def as_sqlite(self, compiler, connection, **extra_context): db_type = self.output_field.db_type(connection) - if db_type in {'datetime', 'time'}: + if db_type in {"datetime", "time"}: # Use strftime as datetime/time don't keep fractional seconds. - template = 'strftime(%%s, %(expressions)s)' - sql, params = super().as_sql(compiler, connection, template=template, **extra_context) - format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f' + template = "strftime(%%s, %(expressions)s)" + sql, params = super().as_sql( + compiler, connection, template=template, **extra_context + ) + format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f" params.insert(0, format_string) return sql, params - elif db_type == 'date': - template = 'date(%(expressions)s)' - return super().as_sql(compiler, connection, template=template, **extra_context) + elif db_type == "date": + template = "date(%(expressions)s)" + return super().as_sql( + compiler, connection, template=template, **extra_context + ) return self.as_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection, **extra_context): template = None output_type = self.output_field.get_internal_type() # MySQL doesn't support explicit cast to float. - if output_type == 'FloatField': - template = '(%(expressions)s + 0.0)' + if output_type == "FloatField": + template = "(%(expressions)s + 0.0)" # MariaDB doesn't support explicit cast to JSON. - elif output_type == 'JSONField' and connection.mysql_is_mariadb: + elif output_type == "JSONField" and connection.mysql_is_mariadb: template = "JSON_EXTRACT(%(expressions)s, '$')" return self.as_sql(compiler, connection, template=template, **extra_context) @@ -46,23 +51,31 @@ class Cast(Func): # CAST would be valid too, but the :: shortcut syntax is more readable. # 'expressions' is wrapped in parentheses in case it's a complex # expression. 
- return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) + return self.as_sql( + compiler, + connection, + template="(%(expressions)s)::%(db_type)s", + **extra_context, + ) def as_oracle(self, compiler, connection, **extra_context): - if self.output_field.get_internal_type() == 'JSONField': + if self.output_field.get_internal_type() == "JSONField": # Oracle doesn't support explicit cast to JSON. template = "JSON_QUERY(%(expressions)s, '$')" - return super().as_sql(compiler, connection, template=template, **extra_context) + return super().as_sql( + compiler, connection, template=template, **extra_context + ) return self.as_sql(compiler, connection, **extra_context) class Coalesce(Func): """Return, from left to right, the first non-null expression.""" - function = 'COALESCE' + + function = "COALESCE" def __init__(self, *expressions, **extra): if len(expressions) < 2: - raise ValueError('Coalesce must take at least two expressions') + raise ValueError("Coalesce must take at least two expressions") super().__init__(*expressions, **extra) @property @@ -76,29 +89,32 @@ class Coalesce(Func): def as_oracle(self, compiler, connection, **extra_context): # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), # so convert all fields to NCLOB when that type is expected. 
- if self.output_field.get_internal_type() == 'TextField': + if self.output_field.get_internal_type() == "TextField": clone = self.copy() - clone.set_source_expressions([ - Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions() - ]) + clone.set_source_expressions( + [ + Func(expression, function="TO_NCLOB") + for expression in self.get_source_expressions() + ] + ) return super(Coalesce, clone).as_sql(compiler, connection, **extra_context) return self.as_sql(compiler, connection, **extra_context) class Collate(Func): - function = 'COLLATE' - template = '%(expressions)s %(function)s %(collation)s' + function = "COLLATE" + template = "%(expressions)s %(function)s %(collation)s" # Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS - collation_re = _lazy_re_compile(r'^[\w\-]+$') + collation_re = _lazy_re_compile(r"^[\w\-]+$") def __init__(self, expression, collation): if not (collation and self.collation_re.match(collation)): - raise ValueError('Invalid collation name: %r.' % collation) + raise ValueError("Invalid collation name: %r." % collation) self.collation = collation super().__init__(expression) def as_sql(self, compiler, connection, **extra_context): - extra_context.setdefault('collation', connection.ops.quote_name(self.collation)) + extra_context.setdefault("collation", connection.ops.quote_name(self.collation)) return super().as_sql(compiler, connection, **extra_context) @@ -110,20 +126,21 @@ class Greatest(Func): On PostgreSQL, the maximum not-null expression is returned. On MySQL, Oracle, and SQLite, if any expression is null, null is returned. 
""" - function = 'GREATEST' + + function = "GREATEST" def __init__(self, *expressions, **extra): if len(expressions) < 2: - raise ValueError('Greatest must take at least two expressions') + raise ValueError("Greatest must take at least two expressions") super().__init__(*expressions, **extra) def as_sqlite(self, compiler, connection, **extra_context): """Use the MAX function on SQLite.""" - return super().as_sqlite(compiler, connection, function='MAX', **extra_context) + return super().as_sqlite(compiler, connection, function="MAX", **extra_context) class JSONObject(Func): - function = 'JSON_OBJECT' + function = "JSON_OBJECT" output_field = JSONField() def __init__(self, **fields): @@ -135,7 +152,7 @@ class JSONObject(Func): def as_sql(self, compiler, connection, **extra_context): if not connection.features.has_json_object_function: raise NotSupportedError( - 'JSONObject() is not supported on this database backend.' + "JSONObject() is not supported on this database backend." ) return super().as_sql(compiler, connection, **extra_context) @@ -143,21 +160,21 @@ class JSONObject(Func): return self.as_sql( compiler, connection, - function='JSONB_BUILD_OBJECT', + function="JSONB_BUILD_OBJECT", **extra_context, ) def as_oracle(self, compiler, connection, **extra_context): class ArgJoiner: def join(self, args): - args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])] - return ', '.join(args) + args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])] + return ", ".join(args) return self.as_sql( compiler, connection, arg_joiner=ArgJoiner(), - template='%(function)s(%(expressions)s RETURNING CLOB)', + template="%(function)s(%(expressions)s RETURNING CLOB)", **extra_context, ) @@ -170,24 +187,25 @@ class Least(Func): On PostgreSQL, return the minimum not-null expression. On MySQL, Oracle, and SQLite, if any expression is null, return null. 
""" - function = 'LEAST' + + function = "LEAST" def __init__(self, *expressions, **extra): if len(expressions) < 2: - raise ValueError('Least must take at least two expressions') + raise ValueError("Least must take at least two expressions") super().__init__(*expressions, **extra) def as_sqlite(self, compiler, connection, **extra_context): """Use the MIN function on SQLite.""" - return super().as_sqlite(compiler, connection, function='MIN', **extra_context) + return super().as_sqlite(compiler, connection, function="MIN", **extra_context) class NullIf(Func): - function = 'NULLIF' + function = "NULLIF" arity = 2 def as_oracle(self, compiler, connection, **extra_context): expression1 = self.get_source_expressions()[0] if isinstance(expression1, Value) and expression1.value is None: - raise ValueError('Oracle does not allow Value(None) for expression1.') + raise ValueError("Oracle does not allow Value(None) for expression1.") return super().as_sql(compiler, connection, **extra_context) diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index 07f884f78d..2d6ec7089e 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -3,10 +3,20 @@ from datetime import datetime from django.conf import settings from django.db.models.expressions import Func from django.db.models.fields import ( - DateField, DateTimeField, DurationField, Field, IntegerField, TimeField, + DateField, + DateTimeField, + DurationField, + Field, + IntegerField, + TimeField, ) from django.db.models.lookups import ( - Transform, YearExact, YearGt, YearGte, YearLt, YearLte, + Transform, + YearExact, + YearGt, + YearGte, + YearLt, + YearLte, ) from django.utils import timezone @@ -36,7 +46,7 @@ class Extract(TimezoneMixin, Transform): if self.lookup_name is None: self.lookup_name = lookup_name if self.lookup_name is None: - raise ValueError('lookup_name must be provided') + raise ValueError("lookup_name must be provided") 
self.tzinfo = tzinfo super().__init__(expression, **extra) @@ -47,14 +57,16 @@ class Extract(TimezoneMixin, Transform): tzname = self.get_tzname() sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) elif self.tzinfo is not None: - raise ValueError('tzinfo can only be used with DateTimeField.') + raise ValueError("tzinfo can only be used with DateTimeField.") elif isinstance(lhs_output_field, DateField): sql = connection.ops.date_extract_sql(self.lookup_name, sql) elif isinstance(lhs_output_field, TimeField): sql = connection.ops.time_extract_sql(self.lookup_name, sql) elif isinstance(lhs_output_field, DurationField): if not connection.features.has_native_duration_field: - raise ValueError('Extract requires native DurationField database support.') + raise ValueError( + "Extract requires native DurationField database support." + ) sql = connection.ops.time_extract_sql(self.lookup_name, sql) else: # resolve_expression has already validated the output_field so this @@ -62,24 +74,38 @@ class Extract(TimezoneMixin, Transform): assert False, "Tried to Extract from an invalid type." return sql, params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) - field = getattr(copy.lhs, 'output_field', None) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + copy = super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + field = getattr(copy.lhs, "output_field", None) if field is None: return copy if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)): raise ValueError( - 'Extract input expression must be DateField, DateTimeField, ' - 'TimeField, or DurationField.' + "Extract input expression must be DateField, DateTimeField, " + "TimeField, or DurationField." 
) # Passing dates to functions expecting datetimes is most likely a mistake. - if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'): + if type(field) == DateField and copy.lookup_name in ( + "hour", + "minute", + "second", + ): raise ValueError( - "Cannot extract time component '%s' from DateField '%s'." % (copy.lookup_name, field.name) + "Cannot extract time component '%s' from DateField '%s'." + % (copy.lookup_name, field.name) ) - if ( - isinstance(field, DurationField) and - copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter') + if isinstance(field, DurationField) and copy.lookup_name in ( + "year", + "iso_year", + "month", + "week", + "week_day", + "iso_week_day", + "quarter", ): raise ValueError( "Cannot extract component '%s' from DurationField '%s'." @@ -89,20 +115,21 @@ class Extract(TimezoneMixin, Transform): class ExtractYear(Extract): - lookup_name = 'year' + lookup_name = "year" class ExtractIsoYear(Extract): """Return the ISO-8601 week-numbering year.""" - lookup_name = 'iso_year' + + lookup_name = "iso_year" class ExtractMonth(Extract): - lookup_name = 'month' + lookup_name = "month" class ExtractDay(Extract): - lookup_name = 'day' + lookup_name = "day" class ExtractWeek(Extract): @@ -110,7 +137,8 @@ class ExtractWeek(Extract): Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the week. 
""" - lookup_name = 'week' + + lookup_name = "week" class ExtractWeekDay(Extract): @@ -119,28 +147,30 @@ class ExtractWeekDay(Extract): To replicate this in Python: (mydatetime.isoweekday() % 7) + 1 """ - lookup_name = 'week_day' + + lookup_name = "week_day" class ExtractIsoWeekDay(Extract): """Return Monday=1 through Sunday=7, based on ISO-8601.""" - lookup_name = 'iso_week_day' + + lookup_name = "iso_week_day" class ExtractQuarter(Extract): - lookup_name = 'quarter' + lookup_name = "quarter" class ExtractHour(Extract): - lookup_name = 'hour' + lookup_name = "hour" class ExtractMinute(Extract): - lookup_name = 'minute' + lookup_name = "minute" class ExtractSecond(Extract): - lookup_name = 'second' + lookup_name = "second" DateField.register_lookup(ExtractYear) @@ -174,14 +204,16 @@ ExtractIsoYear.register_lookup(YearLte) class Now(Func): - template = 'CURRENT_TIMESTAMP' + template = "CURRENT_TIMESTAMP" output_field = DateTimeField() def as_postgresql(self, compiler, connection, **extra_context): # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with # other databases. - return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context) + return self.as_sql( + compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context + ) class TruncBase(TimezoneMixin, Transform): @@ -190,7 +222,14 @@ class TruncBase(TimezoneMixin, Transform): # RemovedInDjango50Warning: when the deprecation ends, remove is_dst # argument. 
- def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra): + def __init__( + self, + expression, + output_field=None, + tzinfo=None, + is_dst=timezone.NOT_PASSED, + **extra, + ): self.tzinfo = tzinfo self.is_dst = is_dst super().__init__(expression, output_field=output_field, **extra) @@ -201,7 +240,7 @@ class TruncBase(TimezoneMixin, Transform): if isinstance(self.lhs.output_field, DateTimeField): tzname = self.get_tzname() elif self.tzinfo is not None: - raise ValueError('tzinfo can only be used with DateTimeField.') + raise ValueError("tzinfo can only be used with DateTimeField.") if isinstance(self.output_field, DateTimeField): sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) elif isinstance(self.output_field, DateField): @@ -209,11 +248,17 @@ class TruncBase(TimezoneMixin, Transform): elif isinstance(self.output_field, TimeField): sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname) else: - raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.') + raise ValueError( + "Trunc only valid on DateField, TimeField, or DateTimeField." + ) return sql, inner_params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + copy = super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) field = copy.lhs.output_field # DateTimeField is a subclass of DateField so this works for both. if not isinstance(field, (DateField, TimeField)): @@ -223,23 +268,46 @@ class TruncBase(TimezoneMixin, Transform): # If self.output_field was None, then accessing the field will trigger # the resolver to assign it to self.lhs.output_field. 
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)): - raise ValueError('output_field must be either DateField, TimeField, or DateTimeField') + raise ValueError( + "output_field must be either DateField, TimeField, or DateTimeField" + ) # Passing dates or times to functions expecting datetimes is most # likely a mistake. - class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None + class_output_field = ( + self.__class__.output_field + if isinstance(self.__class__.output_field, Field) + else None + ) output_field = class_output_field or copy.output_field - has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__ + has_explicit_output_field = ( + class_output_field or field.__class__ is not copy.output_field.__class__ + ) if type(field) == DateField and ( - isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')): - raise ValueError("Cannot truncate DateField '%s' to %s." % ( - field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField' - )) + isinstance(output_field, DateTimeField) + or copy.kind in ("hour", "minute", "second", "time") + ): + raise ValueError( + "Cannot truncate DateField '%s' to %s." + % ( + field.name, + output_field.__class__.__name__ + if has_explicit_output_field + else "DateTimeField", + ) + ) elif isinstance(field, TimeField) and ( - isinstance(output_field, DateTimeField) or - copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')): - raise ValueError("Cannot truncate TimeField '%s' to %s." % ( - field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField' - )) + isinstance(output_field, DateTimeField) + or copy.kind in ("year", "quarter", "month", "week", "day", "date") + ): + raise ValueError( + "Cannot truncate TimeField '%s' to %s." 
+ % ( + field.name, + output_field.__class__.__name__ + if has_explicit_output_field + else "DateTimeField", + ) + ) return copy def convert_value(self, value, expression, connection): @@ -251,8 +319,8 @@ class TruncBase(TimezoneMixin, Transform): value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst) elif not connection.features.has_zoneinfo_database: raise ValueError( - 'Database returned an invalid datetime value. Are time ' - 'zone definitions for your database installed?' + "Database returned an invalid datetime value. Are time " + "zone definitions for your database installed?" ) elif isinstance(value, datetime): if value is None: @@ -268,38 +336,46 @@ class Trunc(TruncBase): # RemovedInDjango50Warning: when the deprecation ends, remove is_dst # argument. - def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra): + def __init__( + self, + expression, + kind, + output_field=None, + tzinfo=None, + is_dst=timezone.NOT_PASSED, + **extra, + ): self.kind = kind super().__init__( - expression, output_field=output_field, tzinfo=tzinfo, - is_dst=is_dst, **extra + expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra ) class TruncYear(TruncBase): - kind = 'year' + kind = "year" class TruncQuarter(TruncBase): - kind = 'quarter' + kind = "quarter" class TruncMonth(TruncBase): - kind = 'month' + kind = "month" class TruncWeek(TruncBase): """Truncate to midnight on the Monday of the week.""" - kind = 'week' + + kind = "week" class TruncDay(TruncBase): - kind = 'day' + kind = "day" class TruncDate(TruncBase): - kind = 'date' - lookup_name = 'date' + kind = "date" + lookup_name = "date" output_field = DateField() def as_sql(self, compiler, connection): @@ -311,8 +387,8 @@ class TruncDate(TruncBase): class TruncTime(TruncBase): - kind = 'time' - lookup_name = 'time' + kind = "time" + lookup_name = "time" output_field = TimeField() def as_sql(self, compiler, connection): @@ -324,15 +400,15 @@ 
class TruncTime(TruncBase): class TruncHour(TruncBase): - kind = 'hour' + kind = "hour" class TruncMinute(TruncBase): - kind = 'minute' + kind = "minute" class TruncSecond(TruncBase): - kind = 'second' + kind = "second" DateTimeField.register_lookup(TruncDate) diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py index f939885263..8b5fd79c3a 100644 --- a/django/db/models/functions/math.py +++ b/django/db/models/functions/math.py @@ -4,37 +4,40 @@ from django.db.models.expressions import Func, Value from django.db.models.fields import FloatField, IntegerField from django.db.models.functions import Cast from django.db.models.functions.mixins import ( - FixDecimalInputMixin, NumericOutputFieldMixin, + FixDecimalInputMixin, + NumericOutputFieldMixin, ) from django.db.models.lookups import Transform class Abs(Transform): - function = 'ABS' - lookup_name = 'abs' + function = "ABS" + lookup_name = "abs" class ACos(NumericOutputFieldMixin, Transform): - function = 'ACOS' - lookup_name = 'acos' + function = "ACOS" + lookup_name = "acos" class ASin(NumericOutputFieldMixin, Transform): - function = 'ASIN' - lookup_name = 'asin' + function = "ASIN" + lookup_name = "asin" class ATan(NumericOutputFieldMixin, Transform): - function = 'ATAN' - lookup_name = 'atan' + function = "ATAN" + lookup_name = "atan" class ATan2(NumericOutputFieldMixin, Func): - function = 'ATAN2' + function = "ATAN2" arity = 2 def as_sqlite(self, compiler, connection, **extra_context): - if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0): + if not getattr( + connection.ops, "spatialite", False + ) or connection.ops.spatial_version >= (5, 0, 0): return self.as_sql(compiler, connection) # This function is usually ATan2(y, x), returning the inverse tangent # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0. @@ -42,67 +45,74 @@ class ATan2(NumericOutputFieldMixin, Func): # arguments are mixed between integer and float or decimal. 
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2 clone = self.copy() - clone.set_source_expressions([ - Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) - else expression for expression in self.get_source_expressions()[::-1] - ]) + clone.set_source_expressions( + [ + Cast(expression, FloatField()) + if isinstance(expression.output_field, IntegerField) + else expression + for expression in self.get_source_expressions()[::-1] + ] + ) return clone.as_sql(compiler, connection, **extra_context) class Ceil(Transform): - function = 'CEILING' - lookup_name = 'ceil' + function = "CEILING" + lookup_name = "ceil" def as_oracle(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='CEIL', **extra_context) + return super().as_sql(compiler, connection, function="CEIL", **extra_context) class Cos(NumericOutputFieldMixin, Transform): - function = 'COS' - lookup_name = 'cos' + function = "COS" + lookup_name = "cos" class Cot(NumericOutputFieldMixin, Transform): - function = 'COT' - lookup_name = 'cot' - - def as_oracle(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) - - -class Degrees(NumericOutputFieldMixin, Transform): - function = 'DEGREES' - lookup_name = 'degrees' + function = "COT" + lookup_name = "cot" def as_oracle(self, compiler, connection, **extra_context): return super().as_sql( - compiler, connection, - template='((%%(expressions)s) * 180 / %s)' % math.pi, - **extra_context + compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context + ) + + +class Degrees(NumericOutputFieldMixin, Transform): + function = "DEGREES" + lookup_name = "degrees" + + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql( + compiler, + connection, + template="((%%(expressions)s) * 180 / %s)" % math.pi, + **extra_context, ) class 
Exp(NumericOutputFieldMixin, Transform): - function = 'EXP' - lookup_name = 'exp' + function = "EXP" + lookup_name = "exp" class Floor(Transform): - function = 'FLOOR' - lookup_name = 'floor' + function = "FLOOR" + lookup_name = "floor" class Ln(NumericOutputFieldMixin, Transform): - function = 'LN' - lookup_name = 'ln' + function = "LN" + lookup_name = "ln" class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func): - function = 'LOG' + function = "LOG" arity = 2 def as_sqlite(self, compiler, connection, **extra_context): - if not getattr(connection.ops, 'spatialite', False): + if not getattr(connection.ops, "spatialite", False): return self.as_sql(compiler, connection) # This function is usually Log(b, x) returning the logarithm of x to # the base b, but on SpatiaLite it's Log(x, b). @@ -112,55 +122,60 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func): class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func): - function = 'MOD' + function = "MOD" arity = 2 class Pi(NumericOutputFieldMixin, Func): - function = 'PI' + function = "PI" arity = 0 def as_oracle(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) + return super().as_sql( + compiler, connection, template=str(math.pi), **extra_context + ) class Power(NumericOutputFieldMixin, Func): - function = 'POWER' + function = "POWER" arity = 2 class Radians(NumericOutputFieldMixin, Transform): - function = 'RADIANS' - lookup_name = 'radians' + function = "RADIANS" + lookup_name = "radians" def as_oracle(self, compiler, connection, **extra_context): return super().as_sql( - compiler, connection, - template='((%%(expressions)s) * %s / 180)' % math.pi, - **extra_context + compiler, + connection, + template="((%%(expressions)s) * %s / 180)" % math.pi, + **extra_context, ) class Random(NumericOutputFieldMixin, Func): - function = 'RANDOM' + function = "RANDOM" arity = 0 def as_mysql(self, compiler, connection, 
**extra_context): - return super().as_sql(compiler, connection, function='RAND', **extra_context) + return super().as_sql(compiler, connection, function="RAND", **extra_context) def as_oracle(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context) + return super().as_sql( + compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context + ) def as_sqlite(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='RAND', **extra_context) + return super().as_sql(compiler, connection, function="RAND", **extra_context) def get_group_by_cols(self, alias=None): return [] class Round(FixDecimalInputMixin, Transform): - function = 'ROUND' - lookup_name = 'round' + function = "ROUND" + lookup_name = "round" arity = None # Override Transform's arity=1 to enable passing precision. def __init__(self, expression, precision=0, **extra): @@ -169,7 +184,7 @@ class Round(FixDecimalInputMixin, Transform): def as_sqlite(self, compiler, connection, **extra_context): precision = self.get_source_expressions()[1] if isinstance(precision, Value) and precision.value < 0: - raise ValueError('SQLite does not support negative precision.') + raise ValueError("SQLite does not support negative precision.") return super().as_sqlite(compiler, connection, **extra_context) def _resolve_output_field(self): @@ -178,20 +193,20 @@ class Round(FixDecimalInputMixin, Transform): class Sign(Transform): - function = 'SIGN' - lookup_name = 'sign' + function = "SIGN" + lookup_name = "sign" class Sin(NumericOutputFieldMixin, Transform): - function = 'SIN' - lookup_name = 'sin' + function = "SIN" + lookup_name = "sin" class Sqrt(NumericOutputFieldMixin, Transform): - function = 'SQRT' - lookup_name = 'sqrt' + function = "SQRT" + lookup_name = "sqrt" class Tan(NumericOutputFieldMixin, Transform): - function = 'TAN' - lookup_name = 'tan' + function = "TAN" + lookup_name = "tan" diff 
--git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py index 00cfd1bc01..caf20e131d 100644 --- a/django/db/models/functions/mixins.py +++ b/django/db/models/functions/mixins.py @@ -5,7 +5,6 @@ from django.db.models.functions import Cast class FixDecimalInputMixin: - def as_postgresql(self, compiler, connection, **extra_context): # Cast FloatField to DecimalField as PostgreSQL doesn't support the # following function signatures: @@ -13,36 +12,42 @@ class FixDecimalInputMixin: # - MOD(double, double) output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000) clone = self.copy() - clone.set_source_expressions([ - Cast(expression, output_field) if isinstance(expression.output_field, FloatField) - else expression for expression in self.get_source_expressions() - ]) + clone.set_source_expressions( + [ + Cast(expression, output_field) + if isinstance(expression.output_field, FloatField) + else expression + for expression in self.get_source_expressions() + ] + ) return clone.as_sql(compiler, connection, **extra_context) class FixDurationInputMixin: - def as_mysql(self, compiler, connection, **extra_context): sql, params = super().as_sql(compiler, connection, **extra_context) - if self.output_field.get_internal_type() == 'DurationField': - sql = 'CAST(%s AS SIGNED)' % sql + if self.output_field.get_internal_type() == "DurationField": + sql = "CAST(%s AS SIGNED)" % sql return sql, params def as_oracle(self, compiler, connection, **extra_context): - if self.output_field.get_internal_type() == 'DurationField': + if self.output_field.get_internal_type() == "DurationField": expression = self.get_source_expressions()[0] options = self._get_repr_options() from django.db.backends.oracle.functions import ( - IntervalToSeconds, SecondsToInterval, + IntervalToSeconds, + SecondsToInterval, ) + return compiler.compile( - SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options)) + SecondsToInterval( + 
self.__class__(IntervalToSeconds(expression), **options) + ) ) return super().as_sql(compiler, connection, **extra_context) class NumericOutputFieldMixin: - def _resolve_output_field(self): source_fields = self.get_source_fields() if any(isinstance(s, DecimalField) for s in source_fields): diff --git a/django/db/models/functions/text.py b/django/db/models/functions/text.py index 4c52222ba1..a54ce8f19b 100644 --- a/django/db/models/functions/text.py +++ b/django/db/models/functions/text.py @@ -10,7 +10,7 @@ class MySQLSHA2Mixin: return super().as_sql( compiler, connection, - template='SHA2(%%(expressions)s, %s)' % self.function[3:], + template="SHA2(%%(expressions)s, %s)" % self.function[3:], **extra_content, ) @@ -40,25 +40,28 @@ class PostgreSQLSHAMixin: class Chr(Transform): - function = 'CHR' - lookup_name = 'chr' + function = "CHR" + lookup_name = "chr" def as_mysql(self, compiler, connection, **extra_context): return super().as_sql( - compiler, connection, function='CHAR', - template='%(function)s(%(expressions)s USING utf16)', - **extra_context + compiler, + connection, + function="CHAR", + template="%(function)s(%(expressions)s USING utf16)", + **extra_context, ) def as_oracle(self, compiler, connection, **extra_context): return super().as_sql( - compiler, connection, - template='%(function)s(%(expressions)s USING NCHAR_CS)', - **extra_context + compiler, + connection, + template="%(function)s(%(expressions)s USING NCHAR_CS)", + **extra_context, ) def as_sqlite(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='CHAR', **extra_context) + return super().as_sql(compiler, connection, function="CHAR", **extra_context) class ConcatPair(Func): @@ -66,29 +69,38 @@ class ConcatPair(Func): Concatenate two arguments together. This is used by `Concat` because not all backend databases support more than two arguments. 
""" - function = 'CONCAT' + + function = "CONCAT" def as_sqlite(self, compiler, connection, **extra_context): coalesced = self.coalesce() return super(ConcatPair, coalesced).as_sql( - compiler, connection, template='%(expressions)s', arg_joiner=' || ', - **extra_context + compiler, + connection, + template="%(expressions)s", + arg_joiner=" || ", + **extra_context, ) def as_mysql(self, compiler, connection, **extra_context): # Use CONCAT_WS with an empty separator so that NULLs are ignored. return super().as_sql( - compiler, connection, function='CONCAT_WS', + compiler, + connection, + function="CONCAT_WS", template="%(function)s('', %(expressions)s)", - **extra_context + **extra_context, ) def coalesce(self): # null on either side results in null for expression, wrap with coalesce c = self.copy() - c.set_source_expressions([ - Coalesce(expression, Value('')) for expression in c.get_source_expressions() - ]) + c.set_source_expressions( + [ + Coalesce(expression, Value("")) + for expression in c.get_source_expressions() + ] + ) return c @@ -98,12 +110,13 @@ class Concat(Func): null expression when any arguments are null will wrap each argument in coalesce functions to ensure a non-null result. 
""" + function = None template = "%(expressions)s" def __init__(self, *expressions, **extra): if len(expressions) < 2: - raise ValueError('Concat must take at least two expressions') + raise ValueError("Concat must take at least two expressions") paired = self._paired(expressions) super().__init__(paired, **extra) @@ -117,7 +130,7 @@ class Concat(Func): class Left(Func): - function = 'LEFT' + function = "LEFT" arity = 2 output_field = CharField() @@ -126,7 +139,7 @@ class Left(Func): expression: the name of a field, or an expression returning a string length: the number of characters to return from the start of the string """ - if not hasattr(length, 'resolve_expression'): + if not hasattr(length, "resolve_expression"): if length < 1: raise ValueError("'length' must be greater than 0.") super().__init__(expression, length, **extra) @@ -143,57 +156,68 @@ class Left(Func): class Length(Transform): """Return the number of characters in the expression.""" - function = 'LENGTH' - lookup_name = 'length' + + function = "LENGTH" + lookup_name = "length" output_field = IntegerField() def as_mysql(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context) + return super().as_sql( + compiler, connection, function="CHAR_LENGTH", **extra_context + ) class Lower(Transform): - function = 'LOWER' - lookup_name = 'lower' + function = "LOWER" + lookup_name = "lower" class LPad(Func): - function = 'LPAD' + function = "LPAD" output_field = CharField() - def __init__(self, expression, length, fill_text=Value(' '), **extra): - if not hasattr(length, 'resolve_expression') and length is not None and length < 0: + def __init__(self, expression, length, fill_text=Value(" "), **extra): + if ( + not hasattr(length, "resolve_expression") + and length is not None + and length < 0 + ): raise ValueError("'length' must be greater or equal to 0.") super().__init__(expression, length, fill_text, **extra) class 
LTrim(Transform): - function = 'LTRIM' - lookup_name = 'ltrim' + function = "LTRIM" + lookup_name = "ltrim" class MD5(OracleHashMixin, Transform): - function = 'MD5' - lookup_name = 'md5' + function = "MD5" + lookup_name = "md5" class Ord(Transform): - function = 'ASCII' - lookup_name = 'ord' + function = "ASCII" + lookup_name = "ord" output_field = IntegerField() def as_mysql(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='ORD', **extra_context) + return super().as_sql(compiler, connection, function="ORD", **extra_context) def as_sqlite(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='UNICODE', **extra_context) + return super().as_sql(compiler, connection, function="UNICODE", **extra_context) class Repeat(Func): - function = 'REPEAT' + function = "REPEAT" output_field = CharField() def __init__(self, expression, number, **extra): - if not hasattr(number, 'resolve_expression') and number is not None and number < 0: + if ( + not hasattr(number, "resolve_expression") + and number is not None + and number < 0 + ): raise ValueError("'number' must be greater or equal to 0.") super().__init__(expression, number, **extra) @@ -205,73 +229,76 @@ class Repeat(Func): class Replace(Func): - function = 'REPLACE' + function = "REPLACE" - def __init__(self, expression, text, replacement=Value(''), **extra): + def __init__(self, expression, text, replacement=Value(""), **extra): super().__init__(expression, text, replacement, **extra) class Reverse(Transform): - function = 'REVERSE' - lookup_name = 'reverse' + function = "REVERSE" + lookup_name = "reverse" def as_oracle(self, compiler, connection, **extra_context): # REVERSE in Oracle is undocumented and doesn't support multi-byte # strings. Use a special subquery instead. 
return super().as_sql( - compiler, connection, + compiler, + connection, template=( - '(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM ' - '(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s ' - 'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) ' - 'GROUP BY %(expressions)s)' + "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM " + "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s " + "FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) " + "GROUP BY %(expressions)s)" ), - **extra_context + **extra_context, ) class Right(Left): - function = 'RIGHT' + function = "RIGHT" def get_substr(self): - return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1)) + return Substr( + self.source_expressions[0], self.source_expressions[1] * Value(-1) + ) class RPad(LPad): - function = 'RPAD' + function = "RPAD" class RTrim(Transform): - function = 'RTRIM' - lookup_name = 'rtrim' + function = "RTRIM" + lookup_name = "rtrim" class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform): - function = 'SHA1' - lookup_name = 'sha1' + function = "SHA1" + lookup_name = "sha1" class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform): - function = 'SHA224' - lookup_name = 'sha224' + function = "SHA224" + lookup_name = "sha224" def as_oracle(self, compiler, connection, **extra_context): - raise NotSupportedError('SHA224 is not supported on Oracle.') + raise NotSupportedError("SHA224 is not supported on Oracle.") class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform): - function = 'SHA256' - lookup_name = 'sha256' + function = "SHA256" + lookup_name = "sha256" class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform): - function = 'SHA384' - lookup_name = 'sha384' + function = "SHA384" + lookup_name = "sha384" class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform): - function = 'SHA512' - lookup_name = 'sha512' + function = "SHA512" + lookup_name = "sha512" class StrIndex(Func): 
@@ -280,16 +307,17 @@ class StrIndex(Func): first occurrence of a substring inside another string, or 0 if the substring is not found. """ - function = 'INSTR' + + function = "INSTR" arity = 2 output_field = IntegerField() def as_postgresql(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='STRPOS', **extra_context) + return super().as_sql(compiler, connection, function="STRPOS", **extra_context) class Substr(Func): - function = 'SUBSTRING' + function = "SUBSTRING" output_field = CharField() def __init__(self, expression, pos, length=None, **extra): @@ -298,7 +326,7 @@ class Substr(Func): pos: an integer > 0, or an expression returning an integer length: an optional number of characters to return """ - if not hasattr(pos, 'resolve_expression'): + if not hasattr(pos, "resolve_expression"): if pos < 1: raise ValueError("'pos' must be greater than 0") expressions = [expression, pos] @@ -307,17 +335,17 @@ class Substr(Func): super().__init__(*expressions, **extra) def as_sqlite(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) + return super().as_sql(compiler, connection, function="SUBSTR", **extra_context) def as_oracle(self, compiler, connection, **extra_context): - return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) + return super().as_sql(compiler, connection, function="SUBSTR", **extra_context) class Trim(Transform): - function = 'TRIM' - lookup_name = 'trim' + function = "TRIM" + lookup_name = "trim" class Upper(Transform): - function = 'UPPER' - lookup_name = 'upper' + function = "UPPER" + lookup_name = "upper" diff --git a/django/db/models/functions/window.py b/django/db/models/functions/window.py index 84b2b24ffa..671017aba7 100644 --- a/django/db/models/functions/window.py +++ b/django/db/models/functions/window.py @@ -2,26 +2,35 @@ from django.db.models.expressions import Func from 
django.db.models.fields import FloatField, IntegerField __all__ = [ - 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', - 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', + "CumeDist", + "DenseRank", + "FirstValue", + "Lag", + "LastValue", + "Lead", + "NthValue", + "Ntile", + "PercentRank", + "Rank", + "RowNumber", ] class CumeDist(Func): - function = 'CUME_DIST' + function = "CUME_DIST" output_field = FloatField() window_compatible = True class DenseRank(Func): - function = 'DENSE_RANK' + function = "DENSE_RANK" output_field = IntegerField() window_compatible = True class FirstValue(Func): arity = 1 - function = 'FIRST_VALUE' + function = "FIRST_VALUE" window_compatible = True @@ -31,13 +40,12 @@ class LagLeadFunction(Func): def __init__(self, expression, offset=1, default=None, **extra): if expression is None: raise ValueError( - '%s requires a non-null source expression.' % - self.__class__.__name__ + "%s requires a non-null source expression." % self.__class__.__name__ ) if offset is None or offset <= 0: raise ValueError( - '%s requires a positive integer for the offset.' % - self.__class__.__name__ + "%s requires a positive integer for the offset." + % self.__class__.__name__ ) args = (expression, offset) if default is not None: @@ -50,28 +58,32 @@ class LagLeadFunction(Func): class Lag(LagLeadFunction): - function = 'LAG' + function = "LAG" class LastValue(Func): arity = 1 - function = 'LAST_VALUE' + function = "LAST_VALUE" window_compatible = True class Lead(LagLeadFunction): - function = 'LEAD' + function = "LEAD" class NthValue(Func): - function = 'NTH_VALUE' + function = "NTH_VALUE" window_compatible = True def __init__(self, expression, nth=1, **extra): if expression is None: - raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__) + raise ValueError( + "%s requires a non-null source expression." 
% self.__class__.__name__ + ) if nth is None or nth <= 0: - raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__) + raise ValueError( + "%s requires a positive integer as for nth." % self.__class__.__name__ + ) super().__init__(expression, nth, **extra) def _resolve_output_field(self): @@ -80,29 +92,29 @@ class NthValue(Func): class Ntile(Func): - function = 'NTILE' + function = "NTILE" output_field = IntegerField() window_compatible = True def __init__(self, num_buckets=1, **extra): if num_buckets <= 0: - raise ValueError('num_buckets must be greater than 0.') + raise ValueError("num_buckets must be greater than 0.") super().__init__(num_buckets, **extra) class PercentRank(Func): - function = 'PERCENT_RANK' + function = "PERCENT_RANK" output_field = FloatField() window_compatible = True class Rank(Func): - function = 'RANK' + function = "RANK" output_field = IntegerField() window_compatible = True class RowNumber(Func): - function = 'ROW_NUMBER' + function = "ROW_NUMBER" output_field = IntegerField() window_compatible = True diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py index e843f9a8cb..95b71ae5bf 100644 --- a/django/db/models/indexes.py +++ b/django/db/models/indexes.py @@ -5,11 +5,11 @@ from django.db.models.query_utils import Q from django.db.models.sql import Query from django.utils.functional import partition -__all__ = ['Index'] +__all__ = ["Index"] class Index: - suffix = 'idx' + suffix = "idx" # The max length of the name of the index (restricted to 30 for # cross-database compatibility with Oracle) max_name_length = 30 @@ -25,45 +25,47 @@ class Index: include=None, ): if opclasses and not name: - raise ValueError('An index must be named to use opclasses.') + raise ValueError("An index must be named to use opclasses.") if not isinstance(condition, (type(None), Q)): - raise ValueError('Index.condition must be a Q instance.') + raise ValueError("Index.condition must be a Q instance.") if condition 
and not name: - raise ValueError('An index must be named to use condition.') + raise ValueError("An index must be named to use condition.") if not isinstance(fields, (list, tuple)): - raise ValueError('Index.fields must be a list or tuple.') + raise ValueError("Index.fields must be a list or tuple.") if not isinstance(opclasses, (list, tuple)): - raise ValueError('Index.opclasses must be a list or tuple.') + raise ValueError("Index.opclasses must be a list or tuple.") if not expressions and not fields: raise ValueError( - 'At least one field or expression is required to define an index.' + "At least one field or expression is required to define an index." ) if expressions and fields: raise ValueError( - 'Index.fields and expressions are mutually exclusive.', + "Index.fields and expressions are mutually exclusive.", ) if expressions and not name: - raise ValueError('An index must be named to use expressions.') + raise ValueError("An index must be named to use expressions.") if expressions and opclasses: raise ValueError( - 'Index.opclasses cannot be used with expressions. Use ' - 'django.contrib.postgres.indexes.OpClass() instead.' + "Index.opclasses cannot be used with expressions. Use " + "django.contrib.postgres.indexes.OpClass() instead." ) if opclasses and len(fields) != len(opclasses): - raise ValueError('Index.fields and Index.opclasses must have the same number of elements.') + raise ValueError( + "Index.fields and Index.opclasses must have the same number of elements." 
+ ) if fields and not all(isinstance(field, str) for field in fields): - raise ValueError('Index.fields must contain only strings with field names.') + raise ValueError("Index.fields must contain only strings with field names.") if include and not name: - raise ValueError('A covering index must be named.') + raise ValueError("A covering index must be named.") if not isinstance(include, (type(None), list, tuple)): - raise ValueError('Index.include must be a list or tuple.') + raise ValueError("Index.include must be a list or tuple.") self.fields = list(fields) # A list of 2-tuple with the field name and ordering ('' or 'DESC'). self.fields_orders = [ - (field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '') + (field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "") for field_name in self.fields ] - self.name = name or '' + self.name = name or "" self.db_tablespace = db_tablespace self.opclasses = opclasses self.condition = condition @@ -86,8 +88,10 @@ class Index: sql, params = where.as_sql(compiler, schema_editor.connection) return sql % tuple(schema_editor.quote_value(p) for p in params) - def create_sql(self, model, schema_editor, using='', **kwargs): - include = [model._meta.get_field(field_name).column for field_name in self.include] + def create_sql(self, model, schema_editor, using="", **kwargs): + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] condition = self._get_condition_sql(model, schema_editor) if self.expressions: index_expressions = [] @@ -108,29 +112,36 @@ class Index: col_suffixes = [order[1] for order in self.fields_orders] expressions = None return schema_editor._create_index_sql( - model, fields=fields, name=self.name, using=using, - db_tablespace=self.db_tablespace, col_suffixes=col_suffixes, - opclasses=self.opclasses, condition=condition, include=include, - expressions=expressions, **kwargs, + model, + fields=fields, + name=self.name, + using=using, + 
db_tablespace=self.db_tablespace, + col_suffixes=col_suffixes, + opclasses=self.opclasses, + condition=condition, + include=include, + expressions=expressions, + **kwargs, ) def remove_sql(self, model, schema_editor, **kwargs): return schema_editor._delete_index_sql(model, self.name, **kwargs) def deconstruct(self): - path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) - path = path.replace('django.db.models.indexes', 'django.db.models') - kwargs = {'name': self.name} + path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) + path = path.replace("django.db.models.indexes", "django.db.models") + kwargs = {"name": self.name} if self.fields: - kwargs['fields'] = self.fields + kwargs["fields"] = self.fields if self.db_tablespace is not None: - kwargs['db_tablespace'] = self.db_tablespace + kwargs["db_tablespace"] = self.db_tablespace if self.opclasses: - kwargs['opclasses'] = self.opclasses + kwargs["opclasses"] = self.opclasses if self.condition: - kwargs['condition'] = self.condition + kwargs["condition"] = self.condition if self.include: - kwargs['include'] = self.include + kwargs["include"] = self.include return (path, self.expressions, kwargs) def clone(self): @@ -147,39 +158,44 @@ class Index: fit its size by truncating the excess length. """ _, table_name = split_identifier(model._meta.db_table) - column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders] + column_names = [ + model._meta.get_field(field_name).column + for field_name, order in self.fields_orders + ] column_names_with_order = [ - (('-%s' if order else '%s') % column_name) - for column_name, (field_name, order) in zip(column_names, self.fields_orders) + (("-%s" if order else "%s") % column_name) + for column_name, (field_name, order) in zip( + column_names, self.fields_orders + ) ] # The length of the parts of the name is based on the default max # length of 30 characters. 
hash_data = [table_name] + column_names_with_order + [self.suffix] - self.name = '%s_%s_%s' % ( + self.name = "%s_%s_%s" % ( table_name[:11], column_names[0][:7], - '%s_%s' % (names_digest(*hash_data, length=6), self.suffix), + "%s_%s" % (names_digest(*hash_data, length=6), self.suffix), ) if len(self.name) > self.max_name_length: raise ValueError( - 'Index too long for multiple database support. Is self.suffix ' - 'longer than 3 characters?' + "Index too long for multiple database support. Is self.suffix " + "longer than 3 characters?" ) - if self.name[0] == '_' or self.name[0].isdigit(): - self.name = 'D%s' % self.name[1:] + if self.name[0] == "_" or self.name[0].isdigit(): + self.name = "D%s" % self.name[1:] def __repr__(self): - return '<%s:%s%s%s%s%s%s%s>' % ( + return "<%s:%s%s%s%s%s%s%s>" % ( self.__class__.__qualname__, - '' if not self.fields else ' fields=%s' % repr(self.fields), - '' if not self.expressions else ' expressions=%s' % repr(self.expressions), - '' if not self.name else ' name=%s' % repr(self.name), - '' + "" if not self.fields else " fields=%s" % repr(self.fields), + "" if not self.expressions else " expressions=%s" % repr(self.expressions), + "" if not self.name else " name=%s" % repr(self.name), + "" if self.db_tablespace is None - else ' db_tablespace=%s' % repr(self.db_tablespace), - '' if self.condition is None else ' condition=%s' % self.condition, - '' if not self.include else ' include=%s' % repr(self.include), - '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses), + else " db_tablespace=%s" % repr(self.db_tablespace), + "" if self.condition is None else " condition=%s" % self.condition, + "" if not self.include else " include=%s" % repr(self.include), + "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), ) def __eq__(self, other): @@ -190,17 +206,20 @@ class Index: class IndexExpression(Func): """Order and wrap expressions for CREATE INDEX statements.""" - template = '%(expressions)s' + + template 
= "%(expressions)s" wrapper_classes = (OrderBy, Collate) def set_wrapper_classes(self, connection=None): # Some databases (e.g. MySQL) treats COLLATE as an indexed expression. if connection and connection.features.collate_as_index_expression: - self.wrapper_classes = tuple([ - wrapper_cls - for wrapper_cls in self.wrapper_classes - if wrapper_cls is not Collate - ]) + self.wrapper_classes = tuple( + [ + wrapper_cls + for wrapper_cls in self.wrapper_classes + if wrapper_cls is not Collate + ] + ) @classmethod def register_wrappers(cls, *wrapper_classes): @@ -224,16 +243,17 @@ class IndexExpression(Func): if len(wrapper_types) != len(set(wrapper_types)): raise ValueError( "Multiple references to %s can't be used in an indexed " - "expression." % ', '.join([ - wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes - ]) + "expression." + % ", ".join( + [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes] + ) ) - if expressions[1:len(wrappers) + 1] != wrappers: + if expressions[1 : len(wrappers) + 1] != wrappers: raise ValueError( - '%s must be topmost expressions in an indexed expression.' - % ', '.join([ - wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes - ]) + "%s must be topmost expressions in an indexed expression." + % ", ".join( + [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes] + ) ) # Wrap expressions in parentheses if they are not column references. root_expression = index_expressions[1] @@ -245,7 +265,7 @@ class IndexExpression(Func): for_save, ) if not isinstance(resolve_root_expression, Col): - root_expression = Func(root_expression, template='(%(expressions)s)') + root_expression = Func(root_expression, template="(%(expressions)s)") if wrappers: # Order wrappers and set their expressions. @@ -262,7 +282,9 @@ class IndexExpression(Func): else: # Use the root expression, if there are no wrappers. 
self.set_source_expressions([root_expression]) - return super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + return super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) def as_sqlite(self, compiler, connection, **extra_context): # Casting to numeric is unnecessary. diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 24bfb11c06..5db549e6bf 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -4,7 +4,12 @@ import math from django.core.exceptions import EmptyResultSet from django.db.models.expressions import Case, Expression, Func, Value, When from django.db.models.fields import ( - BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField, + BooleanField, + CharField, + DateTimeField, + Field, + IntegerField, + UUIDField, ) from django.db.models.query_utils import RegisterLookupMixin from django.utils.datastructures import OrderedSet @@ -21,18 +26,19 @@ class Lookup(Expression): self.lhs, self.rhs = lhs, rhs self.rhs = self.get_prep_lookup() self.lhs = self.get_prep_lhs() - if hasattr(self.lhs, 'get_bilateral_transforms'): + if hasattr(self.lhs, "get_bilateral_transforms"): bilateral_transforms = self.lhs.get_bilateral_transforms() else: bilateral_transforms = [] if bilateral_transforms: # Warn the user as soon as possible if they are trying to apply # a bilateral transformation on a nested QuerySet: that won't work. - from django.db.models.sql.query import ( # avoid circular import - Query, - ) + from django.db.models.sql.query import Query # avoid circular import + if isinstance(rhs, Query): - raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.") + raise NotImplementedError( + "Bilateral transformations on nested querysets are not implemented." 
+ ) self.bilateral_transforms = bilateral_transforms def apply_bilateral_transforms(self, value): @@ -41,7 +47,7 @@ class Lookup(Expression): return value def __repr__(self): - return f'{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})' + return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})" def batch_process_rhs(self, compiler, connection, rhs=None): if rhs is None: @@ -57,7 +63,7 @@ class Lookup(Expression): sqls_params.extend(sql_params) else: _, params = self.get_db_prep_lookup(rhs, connection) - sqls, sqls_params = ['%s'] * len(params), params + sqls, sqls_params = ["%s"] * len(params), params return sqls, sqls_params def get_source_expressions(self): @@ -72,31 +78,31 @@ class Lookup(Expression): self.lhs, self.rhs = new_exprs def get_prep_lookup(self): - if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'): + if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"): return self.rhs - if hasattr(self.lhs, 'output_field'): - if hasattr(self.lhs.output_field, 'get_prep_value'): + if hasattr(self.lhs, "output_field"): + if hasattr(self.lhs.output_field, "get_prep_value"): return self.lhs.output_field.get_prep_value(self.rhs) elif self.rhs_is_direct_value(): return Value(self.rhs) return self.rhs def get_prep_lhs(self): - if hasattr(self.lhs, 'resolve_expression'): + if hasattr(self.lhs, "resolve_expression"): return self.lhs return Value(self.lhs) def get_db_prep_lookup(self, value, connection): - return ('%s', [value]) + return ("%s", [value]) def process_lhs(self, compiler, connection, lhs=None): lhs = lhs or self.lhs - if hasattr(lhs, 'resolve_expression'): + if hasattr(lhs, "resolve_expression"): lhs = lhs.resolve_expression(compiler.query) sql, params = compiler.compile(lhs) if isinstance(lhs, Lookup): # Wrapped in parentheses to respect operator precedence. 
- sql = f'({sql})' + sql = f"({sql})" return sql, params def process_rhs(self, compiler, connection): @@ -108,19 +114,19 @@ class Lookup(Expression): value = Value(value, output_field=self.lhs.output_field) value = self.apply_bilateral_transforms(value) value = value.resolve_expression(compiler.query) - if hasattr(value, 'as_sql'): + if hasattr(value, "as_sql"): sql, params = compiler.compile(value) # Ensure expression is wrapped in parentheses to respect operator # precedence but avoid double wrapping as it can be misinterpreted # on some backends (e.g. subqueries on SQLite). - if sql and sql[0] != '(': - sql = '(%s)' % sql + if sql and sql[0] != "(": + sql = "(%s)" % sql return sql, params else: return self.get_db_prep_lookup(value, connection) def rhs_is_direct_value(self): - return not hasattr(self.rhs, 'as_sql') + return not hasattr(self.rhs, "as_sql") def get_group_by_cols(self, alias=None): cols = [] @@ -157,11 +163,17 @@ class Lookup(Expression): def __hash__(self): return hash(make_hashable(self.identity)) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): c = self.copy() c.is_summary = summarize - c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) - c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.lhs = self.lhs.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + c.rhs = self.rhs.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) return c def select_format(self, compiler, sql, params): @@ -169,7 +181,7 @@ class Lookup(Expression): # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP # BY list. 
if not compiler.connection.features.supports_boolean_expr_in_select_clause: - sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END' + sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" return sql, params @@ -178,6 +190,7 @@ class Transform(RegisterLookupMixin, Func): RegisterLookupMixin() is first so that get_lookup() and get_transform() first examine self and then check output_field. """ + bilateral = False arity = 1 @@ -186,7 +199,7 @@ class Transform(RegisterLookupMixin, Func): return self.get_source_expressions()[0] def get_bilateral_transforms(self): - if hasattr(self.lhs, 'get_bilateral_transforms'): + if hasattr(self.lhs, "get_bilateral_transforms"): bilateral_transforms = self.lhs.get_bilateral_transforms() else: bilateral_transforms = [] @@ -200,9 +213,10 @@ class BuiltinLookup(Lookup): lhs_sql, params = super().process_lhs(compiler, connection, lhs) field_internal_type = self.lhs.output_field.get_internal_type() db_type = self.lhs.output_field.db_type(connection=connection) - lhs_sql = connection.ops.field_cast_sql( - db_type, field_internal_type) % lhs_sql - lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql + lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql + lhs_sql = ( + connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql + ) return lhs_sql, list(params) def as_sql(self, compiler, connection): @@ -210,7 +224,7 @@ class BuiltinLookup(Lookup): rhs_sql, rhs_params = self.process_rhs(compiler, connection) params.extend(rhs_params) rhs_sql = self.get_rhs_op(connection, rhs_sql) - return '%s %s' % (lhs_sql, rhs_sql), params + return "%s %s" % (lhs_sql, rhs_sql), params def get_rhs_op(self, connection, rhs): return connection.operators[self.lookup_name] % rhs @@ -221,18 +235,22 @@ class FieldGetDbPrepValueMixin: Some lookups require Field.get_db_prep_value() to be called on their inputs. 
""" + get_db_prep_lookup_value_is_iterable = False def get_db_prep_lookup(self, value, connection): # For relational fields, use the 'target_field' attribute of the # output_field. - field = getattr(self.lhs.output_field, 'target_field', None) - get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value + field = getattr(self.lhs.output_field, "target_field", None) + get_db_prep_value = ( + getattr(field, "get_db_prep_value", None) + or self.lhs.output_field.get_db_prep_value + ) return ( - '%s', + "%s", [get_db_prep_value(v, connection, prepared=True) for v in value] - if self.get_db_prep_lookup_value_is_iterable else - [get_db_prep_value(value, connection, prepared=True)] + if self.get_db_prep_lookup_value_is_iterable + else [get_db_prep_value(value, connection, prepared=True)], ) @@ -241,18 +259,19 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): Some lookups require Field.get_db_prep_value() to be called on each value in an iterable. """ + get_db_prep_lookup_value_is_iterable = True def get_prep_lookup(self): - if hasattr(self.rhs, 'resolve_expression'): + if hasattr(self.rhs, "resolve_expression"): return self.rhs prepared_values = [] for rhs_value in self.rhs: - if hasattr(rhs_value, 'resolve_expression'): + if hasattr(rhs_value, "resolve_expression"): # An expression will be handled by the database but can coexist # alongside real values. 
pass - elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'): + elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"): rhs_value = self.lhs.output_field.get_prep_value(rhs_value) prepared_values.append(rhs_value) return prepared_values @@ -267,9 +286,9 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): def resolve_expression_parameter(self, compiler, connection, sql, param): params = [param] - if hasattr(param, 'resolve_expression'): + if hasattr(param, "resolve_expression"): param = param.resolve_expression(compiler.query) - if hasattr(param, 'as_sql'): + if hasattr(param, "as_sql"): sql, params = compiler.compile(param) return sql, params @@ -279,40 +298,44 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): # sql/param pair. Zip them to get sql and param pairs that refer to the # same argument and attempt to replace them with the result of # compiling the param step. - sql, params = zip(*( - self.resolve_expression_parameter(compiler, connection, sql, param) - for sql, param in zip(*pre_processed) - )) + sql, params = zip( + *( + self.resolve_expression_parameter(compiler, connection, sql, param) + for sql, param in zip(*pre_processed) + ) + ) params = itertools.chain.from_iterable(params) return sql, tuple(params) class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup): """Lookup defined by operators on PostgreSQL.""" + postgres_operator = None def as_postgresql(self, compiler, connection): lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) - return '%s %s %s' % (lhs, self.postgres_operator, rhs), params + return "%s %s %s" % (lhs, self.postgres_operator, rhs), params @Field.register_lookup class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): - lookup_name = 'exact' + lookup_name = "exact" def get_prep_lookup(self): from django.db.models.sql.query import Query # avoid 
circular import + if isinstance(self.rhs, Query): if self.rhs.has_limit_one(): if not self.rhs.has_select_fields: self.rhs.clear_select_clause() - self.rhs.add_fields(['pk']) + self.rhs.add_fields(["pk"]) else: raise ValueError( - 'The QuerySet value for an exact lookup must be limited to ' - 'one result using slicing.' + "The QuerySet value for an exact lookup must be limited to " + "one result using slicing." ) return super().get_prep_lookup() @@ -321,19 +344,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): # turns "boolfield__exact=True" into "WHERE boolean_field" instead of # "WHERE boolean_field = True" when allowed. if ( - isinstance(self.rhs, bool) and - getattr(self.lhs, 'conditional', False) and - connection.ops.conditional_expression_supported_in_where_clause(self.lhs) + isinstance(self.rhs, bool) + and getattr(self.lhs, "conditional", False) + and connection.ops.conditional_expression_supported_in_where_clause( + self.lhs + ) ): lhs_sql, params = self.process_lhs(compiler, connection) - template = '%s' if self.rhs else 'NOT %s' + template = "%s" if self.rhs else "NOT %s" return template % lhs_sql, params return super().as_sql(compiler, connection) @Field.register_lookup class IExact(BuiltinLookup): - lookup_name = 'iexact' + lookup_name = "iexact" prepare_rhs = False def process_rhs(self, qn, connection): @@ -345,22 +370,22 @@ class IExact(BuiltinLookup): @Field.register_lookup class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup): - lookup_name = 'gt' + lookup_name = "gt" @Field.register_lookup class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): - lookup_name = 'gte' + lookup_name = "gte" @Field.register_lookup class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup): - lookup_name = 'lt' + lookup_name = "lt" @Field.register_lookup class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): - lookup_name = 'lte' + lookup_name = "lte" class IntegerFieldFloatRounding: @@ -368,6 +393,7 @@ class IntegerFieldFloatRounding: 
Allow floats to work as query values for IntegerField. Without this, the decimal portion of the float would always be discarded. """ + def get_prep_lookup(self): if isinstance(self.rhs, float): self.rhs = math.ceil(self.rhs) @@ -386,19 +412,20 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan): @Field.register_lookup class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): - lookup_name = 'in' + lookup_name = "in" def get_prep_lookup(self): from django.db.models.sql.query import Query # avoid circular import + if isinstance(self.rhs, Query): self.rhs.clear_ordering(clear_default=True) if not self.rhs.has_select_fields: self.rhs.clear_select_clause() - self.rhs.add_fields(['pk']) + self.rhs.add_fields(["pk"]) return super().get_prep_lookup() def process_rhs(self, compiler, connection): - db_rhs = getattr(self.rhs, '_db', None) + db_rhs = getattr(self.rhs, "_db", None) if db_rhs is not None and db_rhs != connection.alias: raise ValueError( "Subqueries aren't allowed across different databases. Force " @@ -419,16 +446,20 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): # rhs should be an iterable; use batch_process_rhs() to # prepare/transform those values. 
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs) - placeholder = '(' + ', '.join(sqls) + ')' + placeholder = "(" + ", ".join(sqls) + ")" return (placeholder, sqls_params) return super().process_rhs(compiler, connection) def get_rhs_op(self, connection, rhs): - return 'IN %s' % rhs + return "IN %s" % rhs def as_sql(self, compiler, connection): max_in_list_size = connection.ops.max_in_list_size() - if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size: + if ( + self.rhs_is_direct_value() + and max_in_list_size + and len(self.rhs) > max_in_list_size + ): return self.split_parameter_list_as_sql(compiler, connection) return super().as_sql(compiler, connection) @@ -438,25 +469,25 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): max_in_list_size = connection.ops.max_in_list_size() lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.batch_process_rhs(compiler, connection) - in_clause_elements = ['('] + in_clause_elements = ["("] params = [] for offset in range(0, len(rhs_params), max_in_list_size): if offset > 0: - in_clause_elements.append(' OR ') - in_clause_elements.append('%s IN (' % lhs) + in_clause_elements.append(" OR ") + in_clause_elements.append("%s IN (" % lhs) params.extend(lhs_params) - sqls = rhs[offset: offset + max_in_list_size] - sqls_params = rhs_params[offset: offset + max_in_list_size] - param_group = ', '.join(sqls) + sqls = rhs[offset : offset + max_in_list_size] + sqls_params = rhs_params[offset : offset + max_in_list_size] + param_group = ", ".join(sqls) in_clause_elements.append(param_group) - in_clause_elements.append(')') + in_clause_elements.append(")") params.extend(sqls_params) - in_clause_elements.append(')') - return ''.join(in_clause_elements), params + in_clause_elements.append(")") + return "".join(in_clause_elements), params class PatternLookup(BuiltinLookup): - param_pattern = '%%%s%%' + param_pattern = "%%%s%%" prepare_rhs = False def 
get_rhs_op(self, connection, rhs): @@ -469,8 +500,10 @@ class PatternLookup(BuiltinLookup): # So, for Python values we don't need any special pattern, but for # SQL reference values or SQL transformations we need the correct # pattern added. - if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms: - pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc) + if hasattr(self.rhs, "as_sql") or self.bilateral_transforms: + pattern = connection.pattern_ops[self.lookup_name].format( + connection.pattern_esc + ) return pattern.format(rhs) else: return super().get_rhs_op(connection, rhs) @@ -478,45 +511,47 @@ class PatternLookup(BuiltinLookup): def process_rhs(self, qn, connection): rhs, params = super().process_rhs(qn, connection) if self.rhs_is_direct_value() and params and not self.bilateral_transforms: - params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0]) + params[0] = self.param_pattern % connection.ops.prep_for_like_query( + params[0] + ) return rhs, params @Field.register_lookup class Contains(PatternLookup): - lookup_name = 'contains' + lookup_name = "contains" @Field.register_lookup class IContains(Contains): - lookup_name = 'icontains' + lookup_name = "icontains" @Field.register_lookup class StartsWith(PatternLookup): - lookup_name = 'startswith' - param_pattern = '%s%%' + lookup_name = "startswith" + param_pattern = "%s%%" @Field.register_lookup class IStartsWith(StartsWith): - lookup_name = 'istartswith' + lookup_name = "istartswith" @Field.register_lookup class EndsWith(PatternLookup): - lookup_name = 'endswith' - param_pattern = '%%%s' + lookup_name = "endswith" + param_pattern = "%%%s" @Field.register_lookup class IEndsWith(EndsWith): - lookup_name = 'iendswith' + lookup_name = "iendswith" @Field.register_lookup class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): - lookup_name = 'range' + lookup_name = "range" def get_rhs_op(self, connection, rhs): return "BETWEEN %s AND %s" % (rhs[0], 
rhs[1]) @@ -524,13 +559,13 @@ class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): @Field.register_lookup class IsNull(BuiltinLookup): - lookup_name = 'isnull' + lookup_name = "isnull" prepare_rhs = False def as_sql(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError( - 'The QuerySet value for an isnull lookup must be True or False.' + "The QuerySet value for an isnull lookup must be True or False." ) sql, params = compiler.compile(self.lhs) if self.rhs: @@ -541,7 +576,7 @@ class IsNull(BuiltinLookup): @Field.register_lookup class Regex(BuiltinLookup): - lookup_name = 'regex' + lookup_name = "regex" prepare_rhs = False def as_sql(self, compiler, connection): @@ -556,21 +591,24 @@ class Regex(BuiltinLookup): @Field.register_lookup class IRegex(Regex): - lookup_name = 'iregex' + lookup_name = "iregex" class YearLookup(Lookup): def year_lookup_bounds(self, connection, year): from django.db.models.functions import ExtractIsoYear + iso_year = isinstance(self.lhs, ExtractIsoYear) output_field = self.lhs.lhs.output_field if isinstance(output_field, DateTimeField): bounds = connection.ops.year_lookup_bounds_for_datetime_field( - year, iso_year=iso_year, + year, + iso_year=iso_year, ) else: bounds = connection.ops.year_lookup_bounds_for_date_field( - year, iso_year=iso_year, + year, + iso_year=iso_year, ) return bounds @@ -585,7 +623,7 @@ class YearLookup(Lookup): rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql) start, finish = self.year_lookup_bounds(connection, self.rhs) params.extend(self.get_bound_params(start, finish)) - return '%s %s' % (lhs_sql, rhs_sql), params + return "%s %s" % (lhs_sql, rhs_sql), params return super().as_sql(compiler, connection) def get_direct_rhs_sql(self, connection, rhs): @@ -593,13 +631,13 @@ class YearLookup(Lookup): def get_bound_params(self, start, finish): raise NotImplementedError( - 'subclasses of YearLookup must provide a get_bound_params() method' + "subclasses of YearLookup must provide a 
get_bound_params() method" ) class YearExact(YearLookup, Exact): def get_direct_rhs_sql(self, connection, rhs): - return 'BETWEEN %s AND %s' + return "BETWEEN %s AND %s" def get_bound_params(self, start, finish): return (start, finish) @@ -630,12 +668,16 @@ class UUIDTextMixin: Strip hyphens from a value when filtering a UUIDField on backends without a native datatype for UUID. """ + def process_rhs(self, qn, connection): if not connection.features.has_native_uuid_field: from django.db.models.functions import Replace + if self.rhs_is_direct_value(): self.rhs = Value(self.rhs) - self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField()) + self.rhs = Replace( + self.rhs, Value("-"), Value(""), output_field=CharField() + ) rhs, params = super().process_rhs(qn, connection) return rhs, params diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 655dfcf8e7..0c0688828e 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -33,7 +33,7 @@ class BaseManager: def __str__(self): """Return "app_label.model_label.manager_name".""" - return '%s.%s' % (self.model._meta.label, self.name) + return "%s.%s" % (self.model._meta.label, self.name) def __class_getitem__(cls, *args, **kwargs): return cls @@ -46,12 +46,12 @@ class BaseManager: Raise a ValueError if the manager is dynamically generated. 
""" qs_class = self._queryset_class - if getattr(self, '_built_with_as_manager', False): + if getattr(self, "_built_with_as_manager", False): # using MyQuerySet.as_manager() return ( True, # as_manager None, # manager_class - '%s.%s' % (qs_class.__module__, qs_class.__name__), # qs_class + "%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class None, # args None, # kwargs ) @@ -69,7 +69,7 @@ class BaseManager: ) return ( False, # as_manager - '%s.%s' % (module_name, name), # manager_class + "%s.%s" % (module_name, name), # manager_class None, # qs_class self._constructor_args[0], # args self._constructor_args[1], # kwargs @@ -83,18 +83,21 @@ class BaseManager: def create_method(name, method): def manager_method(self, *args, **kwargs): return getattr(self.get_queryset(), name)(*args, **kwargs) + manager_method.__name__ = method.__name__ manager_method.__doc__ = method.__doc__ return manager_method new_methods = {} - for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction): + for name, method in inspect.getmembers( + queryset_class, predicate=inspect.isfunction + ): # Only copy missing methods. if hasattr(cls, name): continue # Only copy public methods or methods with the attribute `queryset_only=False`. - queryset_only = getattr(method, 'queryset_only', None) - if queryset_only or (queryset_only is None and name.startswith('_')): + queryset_only = getattr(method, "queryset_only", None) + if queryset_only or (queryset_only is None and name.startswith("_")): continue # Copy the method onto the manager. 
new_methods[name] = create_method(name, method) @@ -103,11 +106,15 @@ class BaseManager: @classmethod def from_queryset(cls, queryset_class, class_name=None): if class_name is None: - class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__) - return type(class_name, (cls,), { - '_queryset_class': queryset_class, - **cls._get_queryset_methods(queryset_class), - }) + class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__) + return type( + class_name, + (cls,), + { + "_queryset_class": queryset_class, + **cls._get_queryset_methods(queryset_class), + }, + ) def contribute_to_class(self, cls, name): self.name = self.name or name @@ -157,8 +164,8 @@ class BaseManager: def __eq__(self, other): return ( - isinstance(other, self.__class__) and - self._constructor_args == other._constructor_args + isinstance(other, self.__class__) + and self._constructor_args == other._constructor_args ) def __hash__(self): @@ -170,22 +177,24 @@ class Manager(BaseManager.from_queryset(QuerySet)): class ManagerDescriptor: - def __init__(self, manager): self.manager = manager def __get__(self, instance, cls=None): if instance is not None: - raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__) + raise AttributeError( + "Manager isn't accessible via %s instances" % cls.__name__ + ) if cls._meta.abstract: - raise AttributeError("Manager isn't available; %s is abstract" % ( - cls._meta.object_name, - )) + raise AttributeError( + "Manager isn't available; %s is abstract" % (cls._meta.object_name,) + ) if cls._meta.swapped: raise AttributeError( - "Manager isn't available; '%s' has been swapped for '%s'" % ( + "Manager isn't available; '%s' has been swapped for '%s'" + % ( cls._meta.label, cls._meta.swapped, ) diff --git a/django/db/models/options.py b/django/db/models/options.py index 6022099e3e..b95f9871b1 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -25,13 +25,32 @@ IMMUTABLE_WARNING = ( ) DEFAULT_NAMES = ( - 
'verbose_name', 'verbose_name_plural', 'db_table', 'ordering', - 'unique_together', 'permissions', 'get_latest_by', 'order_with_respect_to', - 'app_label', 'db_tablespace', 'abstract', 'managed', 'proxy', 'swappable', - 'auto_created', 'index_together', 'apps', 'default_permissions', - 'select_on_save', 'default_related_name', 'required_db_features', - 'required_db_vendor', 'base_manager_name', 'default_manager_name', - 'indexes', 'constraints', + "verbose_name", + "verbose_name_plural", + "db_table", + "ordering", + "unique_together", + "permissions", + "get_latest_by", + "order_with_respect_to", + "app_label", + "db_tablespace", + "abstract", + "managed", + "proxy", + "swappable", + "auto_created", + "index_together", + "apps", + "default_permissions", + "select_on_save", + "default_related_name", + "required_db_features", + "required_db_vendor", + "base_manager_name", + "default_manager_name", + "indexes", + "constraints", ) @@ -63,11 +82,17 @@ def make_immutable_fields_list(name, data): class Options: FORWARD_PROPERTIES = { - 'fields', 'many_to_many', 'concrete_fields', 'local_concrete_fields', - '_forward_fields_map', 'managers', 'managers_map', 'base_manager', - 'default_manager', + "fields", + "many_to_many", + "concrete_fields", + "local_concrete_fields", + "_forward_fields_map", + "managers", + "managers_map", + "base_manager", + "default_manager", } - REVERSE_PROPERTIES = {'related_objects', 'fields_map', '_relation_tree'} + REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"} default_apps = apps @@ -82,7 +107,7 @@ class Options: self.model_name = None self.verbose_name = None self.verbose_name_plural = None - self.db_table = '' + self.db_table = "" self.ordering = [] self._ordering_clash = False self.indexes = [] @@ -90,7 +115,7 @@ class Options: self.unique_together = [] self.index_together = [] self.select_on_save = False - self.default_permissions = ('add', 'change', 'delete', 'view') + self.default_permissions = ("add", "change", 
"delete", "view") self.permissions = [] self.object_name = None self.app_label = app_label @@ -130,11 +155,11 @@ class Options: @property def label(self): - return '%s.%s' % (self.app_label, self.object_name) + return "%s.%s" % (self.app_label, self.object_name) @property def label_lower(self): - return '%s.%s' % (self.app_label, self.model_name) + return "%s.%s" % (self.app_label, self.model_name) @property def app_config(self): @@ -163,7 +188,7 @@ class Options: # Ignore any private attributes that Django doesn't care about. # NOTE: We can't modify a dictionary's contents while looping # over it, so we loop over the *original* dictionary instead. - if name.startswith('_'): + if name.startswith("_"): del meta_attrs[name] for attr_name in DEFAULT_NAMES: if attr_name in meta_attrs: @@ -177,30 +202,34 @@ class Options: self.index_together = normalize_together(self.index_together) # App label/class name interpolation for names of constraints and # indexes. - if not getattr(cls._meta, 'abstract', False): - for attr_name in {'constraints', 'indexes'}: + if not getattr(cls._meta, "abstract", False): + for attr_name in {"constraints", "indexes"}: objs = getattr(self, attr_name, []) setattr(self, attr_name, self._format_names_with_class(cls, objs)) # verbose_name_plural is a special case because it uses a 's' # by default. if self.verbose_name_plural is None: - self.verbose_name_plural = format_lazy('{}s', self.verbose_name) + self.verbose_name_plural = format_lazy("{}s", self.verbose_name) # order_with_respect_and ordering are mutually exclusive. self._ordering_clash = bool(self.ordering and self.order_with_respect_to) # Any leftover attributes must be invalid. 
if meta_attrs != {}: - raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs)) + raise TypeError( + "'class Meta' got invalid attribute(s): %s" % ",".join(meta_attrs) + ) else: - self.verbose_name_plural = format_lazy('{}s', self.verbose_name) + self.verbose_name_plural = format_lazy("{}s", self.verbose_name) del self.meta # If the db_table wasn't provided, use the app_label + model_name. if not self.db_table: self.db_table = "%s_%s" % (self.app_label, self.model_name) - self.db_table = truncate_name(self.db_table, connection.ops.max_name_length()) + self.db_table = truncate_name( + self.db_table, connection.ops.max_name_length() + ) def _format_names_with_class(self, cls, objs): """App label/class name interpolation for object names.""" @@ -208,8 +237,8 @@ class Options: for obj in objs: obj = obj.clone() obj.name = obj.name % { - 'app_label': cls._meta.app_label.lower(), - 'class': cls.__name__.lower(), + "app_label": cls._meta.app_label.lower(), + "class": cls.__name__.lower(), } new_objs.append(obj) return new_objs @@ -217,19 +246,19 @@ class Options: def _get_default_pk_class(self): pk_class_path = getattr( self.app_config, - 'default_auto_field', + "default_auto_field", settings.DEFAULT_AUTO_FIELD, ) if self.app_config and self.app_config._is_default_auto_field_overridden: app_config_class = type(self.app_config) source = ( - f'{app_config_class.__module__}.' - f'{app_config_class.__qualname__}.default_auto_field' + f"{app_config_class.__module__}." 
+ f"{app_config_class.__qualname__}.default_auto_field" ) else: - source = 'DEFAULT_AUTO_FIELD' + source = "DEFAULT_AUTO_FIELD" if not pk_class_path: - raise ImproperlyConfigured(f'{source} must not be empty.') + raise ImproperlyConfigured(f"{source} must not be empty.") try: pk_class = import_string(pk_class_path) except ImportError as e: @@ -252,15 +281,20 @@ class Options: query = self.order_with_respect_to try: self.order_with_respect_to = next( - f for f in self._get_fields(reverse=False) + f + for f in self._get_fields(reverse=False) if f.name == query or f.attname == query ) except StopIteration: - raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, query)) + raise FieldDoesNotExist( + "%s has no field named '%s'" % (self.object_name, query) + ) - self.ordering = ('_order',) - if not any(isinstance(field, OrderWrt) for field in model._meta.local_fields): - model.add_to_class('_order', OrderWrt()) + self.ordering = ("_order",) + if not any( + isinstance(field, OrderWrt) for field in model._meta.local_fields + ): + model.add_to_class("_order", OrderWrt()) else: self.order_with_respect_to = None @@ -272,15 +306,17 @@ class Options: # Look for a local field with the same name as the # first parent link. 
If a local field has already been # created, use it instead of promoting the parent - already_created = [fld for fld in self.local_fields if fld.name == field.name] + already_created = [ + fld for fld in self.local_fields if fld.name == field.name + ] if already_created: field = already_created[0] field.primary_key = True self.setup_pk(field) else: pk_class = self._get_default_pk_class() - auto = pk_class(verbose_name='ID', primary_key=True, auto_created=True) - model.add_to_class('id', auto) + auto = pk_class(verbose_name="ID", primary_key=True, auto_created=True) + model.add_to_class("id", auto) def add_manager(self, manager): self.local_managers.append(manager) @@ -307,7 +343,11 @@ class Options: # ideally, we'd just ask for field.related_model. However, related_model # is a cached property, and all the models haven't been loaded yet, so # we need to make sure we don't cache a string reference. - if field.is_relation and hasattr(field.remote_field, 'model') and field.remote_field.model: + if ( + field.is_relation + and hasattr(field.remote_field, "model") + and field.remote_field.model + ): try: field.remote_field.model._meta._expire_cache(forward=False) except AttributeError: @@ -331,7 +371,7 @@ class Options: self.db_table = target._meta.db_table def __repr__(self): - return '<Options for %s>' % self.object_name + return "<Options for %s>" % self.object_name def __str__(self): return self.label_lower @@ -348,8 +388,10 @@ class Options: if self.required_db_vendor: return self.required_db_vendor == connection.vendor if self.required_db_features: - return all(getattr(connection.features, feat, False) - for feat in self.required_db_features) + return all( + getattr(connection.features, feat, False) + for feat in self.required_db_features + ) return True @property @@ -371,7 +413,7 @@ class Options: swapped_for = getattr(settings, self.swappable, None) if swapped_for: try: - swapped_label, swapped_object = swapped_for.split('.') + swapped_label, swapped_object = 
swapped_for.split(".") except ValueError: # setting not in the format app_label.model_name # raising ImproperlyConfigured here causes problems with @@ -379,7 +421,10 @@ class Options: # or as part of validation. return swapped_for - if '%s.%s' % (swapped_label, swapped_object.lower()) != self.label_lower: + if ( + "%s.%s" % (swapped_label, swapped_object.lower()) + != self.label_lower + ): return swapped_for return None @@ -387,7 +432,7 @@ class Options: def managers(self): managers = [] seen_managers = set() - bases = (b for b in self.model.mro() if hasattr(b, '_meta')) + bases = (b for b in self.model.mro() if hasattr(b, "_meta")) for depth, base in enumerate(bases): for manager in base._meta.local_managers: if manager.name in seen_managers: @@ -413,8 +458,8 @@ class Options: if not base_manager_name: # Get the first parent's base_manager_name if there's one. for parent in self.model.mro()[1:]: - if hasattr(parent, '_meta'): - if parent._base_manager.name != '_base_manager': + if hasattr(parent, "_meta"): + if parent._base_manager.name != "_base_manager": base_manager_name = parent._base_manager.name break @@ -423,14 +468,15 @@ class Options: return self.managers_map[base_manager_name] except KeyError: raise ValueError( - "%s has no manager named %r" % ( + "%s has no manager named %r" + % ( self.object_name, base_manager_name, ) ) manager = Manager() - manager.name = '_base_manager' + manager.name = "_base_manager" manager.model = self.model manager.auto_created = True return manager @@ -441,7 +487,7 @@ class Options: if not default_manager_name and not self.local_managers: # Get the first parent's default_manager_name if there's one. 
for parent in self.model.mro()[1:]: - if hasattr(parent, '_meta'): + if hasattr(parent, "_meta"): default_manager_name = parent._meta.default_manager_name break @@ -450,7 +496,8 @@ class Options: return self.managers_map[default_manager_name] except KeyError: raise ValueError( - "%s has no manager named %r" % ( + "%s has no manager named %r" + % ( self.object_name, default_manager_name, ) @@ -484,13 +531,20 @@ class Options: def is_not_a_generic_foreign_key(f): return not ( - f.is_relation and f.many_to_one and not (hasattr(f.remote_field, 'model') and f.remote_field.model) + f.is_relation + and f.many_to_one + and not (hasattr(f.remote_field, "model") and f.remote_field.model) ) return make_immutable_fields_list( "fields", - (f for f in self._get_fields(reverse=False) - if is_not_an_m2m_field(f) and is_not_a_generic_relation(f) and is_not_a_generic_foreign_key(f)) + ( + f + for f in self._get_fields(reverse=False) + if is_not_an_m2m_field(f) + and is_not_a_generic_relation(f) + and is_not_a_generic_foreign_key(f) + ), ) @cached_property @@ -530,7 +584,11 @@ class Options: """ return make_immutable_fields_list( "many_to_many", - (f for f in self._get_fields(reverse=False) if f.is_relation and f.many_to_many) + ( + f + for f in self._get_fields(reverse=False) + if f.is_relation and f.many_to_many + ), ) @cached_property @@ -544,10 +602,16 @@ class Options: combined with filtering of field properties is the public API for obtaining this field list. """ - all_related_fields = self._get_fields(forward=False, reverse=True, include_hidden=True) + all_related_fields = self._get_fields( + forward=False, reverse=True, include_hidden=True + ) return make_immutable_fields_list( "related_objects", - (obj for obj in all_related_fields if not obj.hidden or obj.field.many_to_many) + ( + obj + for obj in all_related_fields + if not obj.hidden or obj.field.many_to_many + ), ) @cached_property @@ -603,7 +667,9 @@ class Options: # field map. 
return self.fields_map[field_name] except KeyError: - raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, field_name)) + raise FieldDoesNotExist( + "%s has no field named '%s'" % (self.object_name, field_name) + ) def get_base_chain(self, model): """ @@ -672,15 +738,17 @@ class Options: final_field = opts.parents[int_model] targets = (final_field.remote_field.get_related_field(),) opts = int_model._meta - path.append(PathInfo( - from_opts=final_field.model._meta, - to_opts=opts, - target_fields=targets, - join_field=final_field, - m2m=False, - direct=True, - filtered_relation=None, - )) + path.append( + PathInfo( + from_opts=final_field.model._meta, + to_opts=opts, + target_fields=targets, + join_field=final_field, + m2m=False, + direct=True, + filtered_relation=None, + ) + ) return path def get_path_from_parent(self, parent): @@ -722,7 +790,8 @@ class Options: if opts.abstract: continue fields_with_relations = ( - f for f in opts._get_fields(reverse=False, include_parents=False) + f + for f in opts._get_fields(reverse=False, include_parents=False) if f.is_relation and f.related_model is not None ) for f in fields_with_relations: @@ -736,11 +805,13 @@ class Options: # __dict__ takes precedence over a data descriptor (such as # @cached_property). This means that the _meta._relation_tree is # only called if related_objects is not in __dict__. - related_objects = related_objects_graph[model._meta.concrete_model._meta.label] - model._meta.__dict__['_relation_tree'] = related_objects + related_objects = related_objects_graph[ + model._meta.concrete_model._meta.label + ] + model._meta.__dict__["_relation_tree"] = related_objects # It seems it is possible that self is not in all_models, so guard # against that with default for get(). 
- return self.__dict__.get('_relation_tree', EMPTY_RELATION_TREE) + return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE) @cached_property def _relation_tree(self): @@ -771,10 +842,18 @@ class Options: """ if include_parents is False: include_parents = PROXY_PARENTS - return self._get_fields(include_parents=include_parents, include_hidden=include_hidden) + return self._get_fields( + include_parents=include_parents, include_hidden=include_hidden + ) - def _get_fields(self, forward=True, reverse=True, include_parents=True, include_hidden=False, - seen_models=None): + def _get_fields( + self, + forward=True, + reverse=True, + include_parents=True, + include_hidden=False, + seen_models=None, + ): """ Internal helper function to return fields of the model. * If forward=True, then fields defined on this model are returned. @@ -787,7 +866,9 @@ class Options: parent chain to the model's concrete model. """ if include_parents not in (True, False, PROXY_PARENTS): - raise TypeError("Invalid argument for include_parents: %s" % (include_parents,)) + raise TypeError( + "Invalid argument for include_parents: %s" % (include_parents,) + ) # This helper function is used to allow recursion in ``get_fields()`` # implementation and to provide a fast way for Django's internals to # access specific subsets of fields. @@ -819,13 +900,22 @@ class Options: # fields from the same parent again. 
if parent in seen_models: continue - if (parent._meta.concrete_model != self.concrete_model and - include_parents == PROXY_PARENTS): + if ( + parent._meta.concrete_model != self.concrete_model + and include_parents == PROXY_PARENTS + ): continue for obj in parent._meta._get_fields( - forward=forward, reverse=reverse, include_parents=include_parents, - include_hidden=include_hidden, seen_models=seen_models): - if not getattr(obj, 'parent_link', False) or obj.model == self.concrete_model: + forward=forward, + reverse=reverse, + include_parents=include_parents, + include_hidden=include_hidden, + seen_models=seen_models, + ): + if ( + not getattr(obj, "parent_link", False) + or obj.model == self.concrete_model + ): fields.append(obj) if reverse and not self.proxy: # Tree is computed once and cached until the app cache is expired. @@ -867,9 +957,9 @@ class Options: constraint for constraint in self.constraints if ( - isinstance(constraint, UniqueConstraint) and - constraint.condition is None and - not constraint.contains_expressions + isinstance(constraint, UniqueConstraint) + and constraint.condition is None + and not constraint.contains_expressions ) ] @@ -890,6 +980,9 @@ class Options: Fields to be returned after a database insert. 
""" return [ - field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS) - if getattr(field, 'db_returning', False) + field + for field in self._get_fields( + forward=True, reverse=False, include_parents=PROXY_PARENTS + ) + if getattr(field, "db_returning", False) ] diff --git a/django/db/models/query.py b/django/db/models/query.py index 0bc6aec2f3..687fd8b4cd 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -11,8 +11,12 @@ import django from django.conf import settings from django.core import exceptions from django.db import ( - DJANGO_VERSION_PICKLE_KEY, IntegrityError, NotSupportedError, connections, - router, transaction, + DJANGO_VERSION_PICKLE_KEY, + IntegrityError, + NotSupportedError, + connections, + router, + transaction, ) from django.db.models import AutoField, DateField, DateTimeField, sql from django.db.models.constants import LOOKUP_SEP, OnConflict @@ -34,7 +38,9 @@ REPR_OUTPUT_SIZE = 20 class BaseIterable: - def __init__(self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE): + def __init__( + self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE + ): self.queryset = queryset self.chunked_fetch = chunked_fetch self.chunk_size = chunk_size @@ -49,25 +55,40 @@ class ModelIterable(BaseIterable): compiler = queryset.query.get_compiler(using=db) # Execute the query. This will also fill compiler.select, klass_info, # and annotations. 
- results = compiler.execute_sql(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size) - select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info, - compiler.annotation_col_map) - model_cls = klass_info['model'] - select_fields = klass_info['select_fields'] + results = compiler.execute_sql( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ) + select, klass_info, annotation_col_map = ( + compiler.select, + compiler.klass_info, + compiler.annotation_col_map, + ) + model_cls = klass_info["model"] + select_fields = klass_info["select_fields"] model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 - init_list = [f[0].target.attname - for f in select[model_fields_start:model_fields_end]] + init_list = [ + f[0].target.attname for f in select[model_fields_start:model_fields_end] + ] related_populators = get_related_populators(klass_info, select, db) known_related_objects = [ - (field, related_objs, operator.attrgetter(*[ - field.attname - if from_field == 'self' else - queryset.model._meta.get_field(from_field).attname - for from_field in field.from_fields - ])) for field, related_objs in queryset._known_related_objects.items() + ( + field, + related_objs, + operator.attrgetter( + *[ + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + for from_field in field.from_fields + ] + ), + ) + for field, related_objs in queryset._known_related_objects.items() ] for row in compiler.results_iter(results): - obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end]) + obj = model_cls.from_db( + db, init_list, row[model_fields_start:model_fields_end] + ) for rel_populator in related_populators: rel_populator.populate(row, obj) if annotation_col_map: @@ -107,7 +128,9 @@ class ValuesIterable(BaseIterable): *query.annotation_select, ] indexes = range(len(names)) - for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, 
chunk_size=self.chunk_size): + for row in compiler.results_iter( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ): yield {names[i]: row[i] for i in indexes} @@ -129,16 +152,25 @@ class ValuesListIterable(BaseIterable): *query.values_select, *query.annotation_select, ] - fields = [*queryset._fields, *(f for f in query.annotation_select if f not in queryset._fields)] + fields = [ + *queryset._fields, + *(f for f in query.annotation_select if f not in queryset._fields), + ] if fields != names: # Reorder according to fields. index_map = {name: idx for idx, name in enumerate(names)} rowfactory = operator.itemgetter(*[index_map[f] for f in fields]) return map( rowfactory, - compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size) + compiler.results_iter( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ), ) - return compiler.results_iter(tuple_expected=True, chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size) + return compiler.results_iter( + tuple_expected=True, + chunked_fetch=self.chunked_fetch, + chunk_size=self.chunk_size, + ) class NamedValuesListIterable(ValuesListIterable): @@ -153,7 +185,11 @@ class NamedValuesListIterable(ValuesListIterable): names = queryset._fields else: query = queryset.query - names = [*query.extra_select, *query.values_select, *query.annotation_select] + names = [ + *query.extra_select, + *query.values_select, + *query.annotation_select, + ] tuple_class = create_namedtuple_class(*names) new = tuple.__new__ for row in super().__iter__(): @@ -169,7 +205,9 @@ class FlatValuesListIterable(BaseIterable): def __iter__(self): queryset = self.queryset compiler = queryset.query.get_compiler(queryset.db) - for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size): + for row in compiler.results_iter( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ): yield row[0] @@ -209,9 +247,11 @@ class QuerySet: def as_manager(cls): # 
Address the circular dependency between `Queryset` and `Manager`. from django.db.models.manager import Manager + manager = Manager.from_queryset(cls)() manager._built_with_as_manager = True return manager + as_manager.queryset_only = True as_manager = classmethod(as_manager) @@ -223,7 +263,7 @@ class QuerySet: """Don't populate the QuerySet's cache.""" obj = self.__class__() for k, v in self.__dict__.items(): - if k == '_result_cache': + if k == "_result_cache": obj.__dict__[k] = None else: obj.__dict__[k] = copy.deepcopy(v, memo) @@ -254,10 +294,10 @@ class QuerySet: self.__dict__.update(state) def __repr__(self): - data = list(self[:REPR_OUTPUT_SIZE + 1]) + data = list(self[: REPR_OUTPUT_SIZE + 1]) if len(data) > REPR_OUTPUT_SIZE: data[-1] = "...(remaining elements truncated)..." - return '<%s %r>' % (self.__class__.__name__, data) + return "<%s %r>" % (self.__class__.__name__, data) def __len__(self): self._fetch_all() @@ -289,17 +329,17 @@ class QuerySet: """Retrieve an item or slice from the set of results.""" if not isinstance(k, (int, slice)): raise TypeError( - 'QuerySet indices must be integers or slices, not %s.' + "QuerySet indices must be integers or slices, not %s." 
% type(k).__name__ ) - if ( - (isinstance(k, int) and k < 0) or - (isinstance(k, slice) and ( - (k.start is not None and k.start < 0) or - (k.stop is not None and k.stop < 0) - )) + if (isinstance(k, int) and k < 0) or ( + isinstance(k, slice) + and ( + (k.start is not None and k.start < 0) + or (k.stop is not None and k.stop < 0) + ) ): - raise ValueError('Negative indexing is not supported.') + raise ValueError("Negative indexing is not supported.") if self._result_cache is not None: return self._result_cache[k] @@ -315,7 +355,7 @@ class QuerySet: else: stop = None qs.query.set_limits(start, stop) - return list(qs)[::k.step] if k.step else qs + return list(qs)[:: k.step] if k.step else qs qs = self._chain() qs.query.set_limits(k, k + 1) @@ -326,7 +366,7 @@ class QuerySet: return cls def __and__(self, other): - self._check_operator_queryset(other, '&') + self._check_operator_queryset(other, "&") self._merge_sanity_check(other) if isinstance(other, EmptyQuerySet): return other @@ -338,17 +378,21 @@ class QuerySet: return combined def __or__(self, other): - self._check_operator_queryset(other, '|') + self._check_operator_queryset(other, "|") self._merge_sanity_check(other) if isinstance(self, EmptyQuerySet): return other if isinstance(other, EmptyQuerySet): return self - query = self if self.query.can_filter() else self.model._base_manager.filter(pk__in=self.values('pk')) + query = ( + self + if self.query.can_filter() + else self.model._base_manager.filter(pk__in=self.values("pk")) + ) combined = query._chain() combined._merge_known_related_objects(other) if not other.query.can_filter(): - other = other.model._base_manager.filter(pk__in=other.values('pk')) + other = other.model._base_manager.filter(pk__in=other.values("pk")) combined.query.combine(other.query, sql.OR) return combined @@ -385,14 +429,16 @@ class QuerySet: # 'QuerySet.iterator() after prefetch_related().' 
# ) warnings.warn( - 'Using QuerySet.iterator() after prefetch_related() ' - 'without specifying chunk_size is deprecated.', + "Using QuerySet.iterator() after prefetch_related() " + "without specifying chunk_size is deprecated.", category=RemovedInDjango50Warning, stacklevel=2, ) elif chunk_size <= 0: - raise ValueError('Chunk size must be strictly positive.') - use_chunked_fetch = not connections[self.db].settings_dict.get('DISABLE_SERVER_SIDE_CURSORS') + raise ValueError("Chunk size must be strictly positive.") + use_chunked_fetch = not connections[self.db].settings_dict.get( + "DISABLE_SERVER_SIDE_CURSORS" + ) return self._iterator(use_chunked_fetch, chunk_size) def aggregate(self, *args, **kwargs): @@ -405,7 +451,9 @@ class QuerySet: """ if self.query.distinct_fields: raise NotImplementedError("aggregate() + distinct(fields) not implemented.") - self._validate_values_are_expressions((*args, *kwargs.values()), method_name='aggregate') + self._validate_values_are_expressions( + (*args, *kwargs.values()), method_name="aggregate" + ) for arg in args: # The default_alias property raises TypeError if default_alias # can't be set automatically or AttributeError if it isn't an @@ -423,7 +471,11 @@ class QuerySet: if not annotation.contains_aggregate: raise TypeError("%s is not an aggregate expression" % alias) for expr in annotation.get_source_expressions(): - if expr.contains_aggregate and isinstance(expr, Ref) and expr.refs in kwargs: + if ( + expr.contains_aggregate + and isinstance(expr, Ref) + and expr.refs in kwargs + ): name = expr.refs raise exceptions.FieldError( "Cannot compute %s('%s'): '%s' is an aggregate" @@ -451,14 +503,17 @@ class QuerySet: """ if self.query.combinator and (args or kwargs): raise NotSupportedError( - 'Calling QuerySet.get(...) with filters after %s() is not ' - 'supported.' % self.query.combinator + "Calling QuerySet.get(...) with filters after %s() is not " + "supported." 
% self.query.combinator ) clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs) if self.query.can_filter() and not self.query.distinct_fields: clone = clone.order_by() limit = None - if not clone.query.select_for_update or connections[clone.db].features.supports_select_for_update_with_limit: + if ( + not clone.query.select_for_update + or connections[clone.db].features.supports_select_for_update_with_limit + ): limit = MAX_GET_RESULTS clone.query.set_limits(high=limit) num = len(clone) @@ -466,13 +521,13 @@ class QuerySet: return clone._result_cache[0] if not num: raise self.model.DoesNotExist( - "%s matching query does not exist." % - self.model._meta.object_name + "%s matching query does not exist." % self.model._meta.object_name ) raise self.model.MultipleObjectsReturned( - 'get() returned more than one %s -- it returned %s!' % ( + "get() returned more than one %s -- it returned %s!" + % ( self.model._meta.object_name, - num if not limit or num < limit else 'more than %s' % (limit - 1), + num if not limit or num < limit else "more than %s" % (limit - 1), ) ) @@ -491,69 +546,77 @@ class QuerySet: if obj.pk is None: # Populate new PK values. obj.pk = obj._meta.pk.get_pk_value_on_save(obj) - obj._prepare_related_fields_for_save(operation_name='bulk_create') + obj._prepare_related_fields_for_save(operation_name="bulk_create") - def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields): + def _check_bulk_create_options( + self, ignore_conflicts, update_conflicts, update_fields, unique_fields + ): if ignore_conflicts and update_conflicts: raise ValueError( - 'ignore_conflicts and update_conflicts are mutually exclusive.' + "ignore_conflicts and update_conflicts are mutually exclusive." ) db_features = connections[self.db].features if ignore_conflicts: if not db_features.supports_ignore_conflicts: raise NotSupportedError( - 'This database backend does not support ignoring conflicts.' 
+ "This database backend does not support ignoring conflicts." ) return OnConflict.IGNORE elif update_conflicts: if not db_features.supports_update_conflicts: raise NotSupportedError( - 'This database backend does not support updating conflicts.' + "This database backend does not support updating conflicts." ) if not update_fields: raise ValueError( - 'Fields that will be updated when a row insertion fails ' - 'on conflicts must be provided.' + "Fields that will be updated when a row insertion fails " + "on conflicts must be provided." ) if unique_fields and not db_features.supports_update_conflicts_with_target: raise NotSupportedError( - 'This database backend does not support updating ' - 'conflicts with specifying unique fields that can trigger ' - 'the upsert.' + "This database backend does not support updating " + "conflicts with specifying unique fields that can trigger " + "the upsert." ) if not unique_fields and db_features.supports_update_conflicts_with_target: raise ValueError( - 'Unique fields that can trigger the upsert must be provided.' + "Unique fields that can trigger the upsert must be provided." ) # Updating primary keys and non-concrete fields is forbidden. update_fields = [self.model._meta.get_field(name) for name in update_fields] if any(not f.concrete or f.many_to_many for f in update_fields): raise ValueError( - 'bulk_create() can only be used with concrete fields in ' - 'update_fields.' + "bulk_create() can only be used with concrete fields in " + "update_fields." ) if any(f.primary_key for f in update_fields): raise ValueError( - 'bulk_create() cannot be used with primary keys in ' - 'update_fields.' + "bulk_create() cannot be used with primary keys in " + "update_fields." ) if unique_fields: # Primary key is allowed in unique_fields. 
unique_fields = [ self.model._meta.get_field(name) - for name in unique_fields if name != 'pk' + for name in unique_fields + if name != "pk" ] if any(not f.concrete or f.many_to_many for f in unique_fields): raise ValueError( - 'bulk_create() can only be used with concrete fields ' - 'in unique_fields.' + "bulk_create() can only be used with concrete fields " + "in unique_fields." ) return OnConflict.UPDATE return None def bulk_create( - self, objs, batch_size=None, ignore_conflicts=False, - update_conflicts=False, update_fields=None, unique_fields=None, + self, + objs, + batch_size=None, + ignore_conflicts=False, + update_conflicts=False, + update_fields=None, + unique_fields=None, ): """ Insert each of the instances into the database. Do *not* call @@ -575,7 +638,7 @@ class QuerySet: # Oracle as well, but the semantics for extracting the primary keys is # trickier so it's not done yet. if batch_size is not None and batch_size <= 0: - raise ValueError('Batch size must be a positive integer.') + raise ValueError("Batch size must be a positive integer.") # Check that the parents share the same concrete model with the our # model to detect the inheritance pattern ConcreteGrandParent -> # MultiTableParent -> ProxyChild. Simply checking self.model._meta.proxy @@ -625,7 +688,10 @@ class QuerySet: unique_fields=unique_fields, ) connection = connections[self.db] - if connection.features.can_return_rows_from_bulk_insert and on_conflict is None: + if ( + connection.features.can_return_rows_from_bulk_insert + and on_conflict is None + ): assert len(returned_columns) == len(objs_without_pk) for obj_without_pk, results in zip(objs_without_pk, returned_columns): for result, field in zip(results, opts.db_returning_fields): @@ -640,28 +706,30 @@ class QuerySet: Update the given fields in each of the given objects in the database. 
""" if batch_size is not None and batch_size < 0: - raise ValueError('Batch size must be a positive integer.') + raise ValueError("Batch size must be a positive integer.") if not fields: - raise ValueError('Field names must be given to bulk_update().') + raise ValueError("Field names must be given to bulk_update().") objs = tuple(objs) if any(obj.pk is None for obj in objs): - raise ValueError('All bulk_update() objects must have a primary key set.') + raise ValueError("All bulk_update() objects must have a primary key set.") fields = [self.model._meta.get_field(name) for name in fields] if any(not f.concrete or f.many_to_many for f in fields): - raise ValueError('bulk_update() can only be used with concrete fields.') + raise ValueError("bulk_update() can only be used with concrete fields.") if any(f.primary_key for f in fields): - raise ValueError('bulk_update() cannot be used with primary key fields.') + raise ValueError("bulk_update() cannot be used with primary key fields.") if not objs: return 0 for obj in objs: - obj._prepare_related_fields_for_save(operation_name='bulk_update', fields=fields) + obj._prepare_related_fields_for_save( + operation_name="bulk_update", fields=fields + ) # PK is used twice in the resulting update query, once in the filter # and once in the WHEN. Each field will also have one CAST. 
connection = connections[self.db] - max_batch_size = connection.ops.bulk_batch_size(['pk', 'pk'] + fields, objs) + max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size requires_casting = connection.features.requires_casted_case_in_updates - batches = (objs[i:i + batch_size] for i in range(0, len(objs), batch_size)) + batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size)) updates = [] for batch_objs in batches: update_kwargs = {} @@ -669,7 +737,7 @@ class QuerySet: when_statements = [] for obj in batch_objs: attr = getattr(obj, field.attname) - if not hasattr(attr, 'resolve_expression'): + if not hasattr(attr, "resolve_expression"): attr = Value(attr, output_field=field) when_statements.append(When(pk=obj.pk, then=attr)) case_statement = Case(*when_statements, output_field=field) @@ -682,6 +750,7 @@ class QuerySet: for pks, update_kwargs in updates: rows_updated += self.filter(pk__in=pks).update(**update_kwargs) return rows_updated + bulk_update.alters_data = True def get_or_create(self, defaults=None, **kwargs): @@ -748,10 +817,12 @@ class QuerySet: invalid_params.append(param) if invalid_params: raise exceptions.FieldError( - "Invalid field name(s) for model %s: '%s'." % ( + "Invalid field name(s) for model %s: '%s'." 
+ % ( self.model._meta.object_name, "', '".join(sorted(invalid_params)), - )) + ) + ) return params def _earliest(self, *fields): @@ -762,7 +833,7 @@ class QuerySet: if fields: order_by = fields else: - order_by = getattr(self.model._meta, 'get_latest_by') + order_by = getattr(self.model._meta, "get_latest_by") if order_by and not isinstance(order_by, (tuple, list)): order_by = (order_by,) if order_by is None: @@ -778,25 +849,25 @@ class QuerySet: def earliest(self, *fields): if self.query.is_sliced: - raise TypeError('Cannot change a query once a slice has been taken.') + raise TypeError("Cannot change a query once a slice has been taken.") return self._earliest(*fields) def latest(self, *fields): if self.query.is_sliced: - raise TypeError('Cannot change a query once a slice has been taken.') + raise TypeError("Cannot change a query once a slice has been taken.") return self.reverse()._earliest(*fields) def first(self): """Return the first object of a query or None if no match is found.""" - for obj in (self if self.ordered else self.order_by('pk'))[:1]: + for obj in (self if self.ordered else self.order_by("pk"))[:1]: return obj def last(self): """Return the last object of a query or None if no match is found.""" - for obj in (self.reverse() if self.ordered else self.order_by('-pk'))[:1]: + for obj in (self.reverse() if self.ordered else self.order_by("-pk"))[:1]: return obj - def in_bulk(self, id_list=None, *, field_name='pk'): + def in_bulk(self, id_list=None, *, field_name="pk"): """ Return a dictionary mapping each of the given IDs to the object with that ID. If `id_list` isn't provided, evaluate the entire QuerySet. 
@@ -810,16 +881,19 @@ class QuerySet: if len(constraint.fields) == 1 ] if ( - field_name != 'pk' and - not opts.get_field(field_name).unique and - field_name not in unique_fields and - self.query.distinct_fields != (field_name,) + field_name != "pk" + and not opts.get_field(field_name).unique + and field_name not in unique_fields + and self.query.distinct_fields != (field_name,) ): - raise ValueError("in_bulk()'s field_name must be a unique field but %r isn't." % field_name) + raise ValueError( + "in_bulk()'s field_name must be a unique field but %r isn't." + % field_name + ) if id_list is not None: if not id_list: return {} - filter_key = '{}__in'.format(field_name) + filter_key = "{}__in".format(field_name) batch_size = connections[self.db].features.max_query_params id_list = tuple(id_list) # If the database has a limit on the number of query parameters @@ -827,7 +901,7 @@ class QuerySet: if batch_size and batch_size < len(id_list): qs = () for offset in range(0, len(id_list), batch_size): - batch = id_list[offset:offset + batch_size] + batch = id_list[offset : offset + batch_size] qs += tuple(self.filter(**{filter_key: batch}).order_by()) else: qs = self.filter(**{filter_key: id_list}).order_by() @@ -837,11 +911,11 @@ class QuerySet: def delete(self): """Delete the records in the current QuerySet.""" - self._not_support_combined_queries('delete') + self._not_support_combined_queries("delete") if self.query.is_sliced: raise TypeError("Cannot use 'limit' or 'offset' with delete().") if self.query.distinct or self.query.distinct_fields: - raise TypeError('Cannot call delete() after .distinct().') + raise TypeError("Cannot call delete() after .distinct().") if self._fields is not None: raise TypeError("Cannot call delete() after .values() or .values_list()") @@ -880,6 +954,7 @@ class QuerySet: with cursor: return cursor.rowcount return 0 + _raw_delete.alters_data = True def update(self, **kwargs): @@ -887,9 +962,9 @@ class QuerySet: Update all elements in the 
current QuerySet, setting all the given fields to the appropriate values. """ - self._not_support_combined_queries('update') + self._not_support_combined_queries("update") if self.query.is_sliced: - raise TypeError('Cannot update a query once a slice has been taken.') + raise TypeError("Cannot update a query once a slice has been taken.") self._for_write = True query = self.query.chain(sql.UpdateQuery) query.add_update_values(kwargs) @@ -899,6 +974,7 @@ class QuerySet: rows = query.get_compiler(self.db).execute_sql(CURSOR) self._result_cache = None return rows + update.alters_data = True def _update(self, values): @@ -909,13 +985,14 @@ class QuerySet: useful at that level). """ if self.query.is_sliced: - raise TypeError('Cannot update a query once a slice has been taken.') + raise TypeError("Cannot update a query once a slice has been taken.") query = self.query.chain(sql.UpdateQuery) query.add_update_fields(values) # Clear any annotations so that they won't be present in subqueries. query.annotations = {} self._result_cache = None return query.get_compiler(self.db).execute_sql(CURSOR) + _update.alters_data = True _update.queryset_only = False @@ -926,10 +1003,10 @@ class QuerySet: def contains(self, obj): """Return True if the queryset contains an object.""" - self._not_support_combined_queries('contains') + self._not_support_combined_queries("contains") if self._fields is not None: raise TypeError( - 'Cannot call QuerySet.contains() after .values() or .values_list().' + "Cannot call QuerySet.contains() after .values() or .values_list()." ) try: if obj._meta.concrete_model != self.model._meta.concrete_model: @@ -937,9 +1014,7 @@ class QuerySet: except AttributeError: raise TypeError("'obj' must be a model instance.") if obj.pk is None: - raise ValueError( - 'QuerySet.contains() cannot be used on unsaved objects.' 
- ) + raise ValueError("QuerySet.contains() cannot be used on unsaved objects.") if self._result_cache is not None: return obj in self._result_cache return self.filter(pk=obj.pk).exists() @@ -959,7 +1034,13 @@ class QuerySet: def raw(self, raw_query, params=(), translations=None, using=None): if using is None: using = self.db - qs = RawQuerySet(raw_query, model=self.model, params=params, translations=translations, using=using) + qs = RawQuerySet( + raw_query, + model=self.model, + params=params, + translations=translations, + using=using, + ) qs._prefetch_related_lookups = self._prefetch_related_lookups[:] return qs @@ -981,15 +1062,19 @@ class QuerySet: if flat and named: raise TypeError("'flat' and 'named' can't be used together.") if flat and len(fields) > 1: - raise TypeError("'flat' is not valid when values_list is called with more than one field.") + raise TypeError( + "'flat' is not valid when values_list is called with more than one field." + ) - field_names = {f for f in fields if not hasattr(f, 'resolve_expression')} + field_names = {f for f in fields if not hasattr(f, "resolve_expression")} _fields = [] expressions = {} counter = 1 for field in fields: - if hasattr(field, 'resolve_expression'): - field_id_prefix = getattr(field, 'default_alias', field.__class__.__name__.lower()) + if hasattr(field, "resolve_expression"): + field_id_prefix = getattr( + field, "default_alias", field.__class__.__name__.lower() + ) while True: field_id = field_id_prefix + str(counter) counter += 1 @@ -1002,59 +1087,71 @@ class QuerySet: clone = self._values(*_fields, **expressions) clone._iterable_class = ( - NamedValuesListIterable if named - else FlatValuesListIterable if flat + NamedValuesListIterable + if named + else FlatValuesListIterable + if flat else ValuesListIterable ) return clone - def dates(self, field_name, kind, order='ASC'): + def dates(self, field_name, kind, order="ASC"): """ Return a list of date objects representing all available dates for the given 
field_name, scoped to 'kind'. """ - if kind not in ('year', 'month', 'week', 'day'): + if kind not in ("year", "month", "week", "day"): raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.") - if order not in ('ASC', 'DESC'): + if order not in ("ASC", "DESC"): raise ValueError("'order' must be either 'ASC' or 'DESC'.") - return self.annotate( - datefield=Trunc(field_name, kind, output_field=DateField()), - plain_field=F(field_name) - ).values_list( - 'datefield', flat=True - ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datefield') + return ( + self.annotate( + datefield=Trunc(field_name, kind, output_field=DateField()), + plain_field=F(field_name), + ) + .values_list("datefield", flat=True) + .distinct() + .filter(plain_field__isnull=False) + .order_by(("-" if order == "DESC" else "") + "datefield") + ) # RemovedInDjango50Warning: when the deprecation ends, remove is_dst # argument. - def datetimes(self, field_name, kind, order='ASC', tzinfo=None, is_dst=timezone.NOT_PASSED): + def datetimes( + self, field_name, kind, order="ASC", tzinfo=None, is_dst=timezone.NOT_PASSED + ): """ Return a list of datetime objects representing all available datetimes for the given field_name, scoped to 'kind'. """ - if kind not in ('year', 'month', 'week', 'day', 'hour', 'minute', 'second'): + if kind not in ("year", "month", "week", "day", "hour", "minute", "second"): raise ValueError( "'kind' must be one of 'year', 'month', 'week', 'day', " "'hour', 'minute', or 'second'." 
) - if order not in ('ASC', 'DESC'): + if order not in ("ASC", "DESC"): raise ValueError("'order' must be either 'ASC' or 'DESC'.") if settings.USE_TZ: if tzinfo is None: tzinfo = timezone.get_current_timezone() else: tzinfo = None - return self.annotate( - datetimefield=Trunc( - field_name, - kind, - output_field=DateTimeField(), - tzinfo=tzinfo, - is_dst=is_dst, - ), - plain_field=F(field_name) - ).values_list( - 'datetimefield', flat=True - ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datetimefield') + return ( + self.annotate( + datetimefield=Trunc( + field_name, + kind, + output_field=DateTimeField(), + tzinfo=tzinfo, + is_dst=is_dst, + ), + plain_field=F(field_name), + ) + .values_list("datetimefield", flat=True) + .distinct() + .filter(plain_field__isnull=False) + .order_by(("-" if order == "DESC" else "") + "datetimefield") + ) def none(self): """Return an empty QuerySet.""" @@ -1078,7 +1175,7 @@ class QuerySet: Return a new QuerySet instance with the args ANDed to the existing set. """ - self._not_support_combined_queries('filter') + self._not_support_combined_queries("filter") return self._filter_or_exclude(False, args, kwargs) def exclude(self, *args, **kwargs): @@ -1086,12 +1183,12 @@ class QuerySet: Return a new QuerySet instance with NOT (args) ANDed to the existing set. 
""" - self._not_support_combined_queries('exclude') + self._not_support_combined_queries("exclude") return self._filter_or_exclude(True, args, kwargs) def _filter_or_exclude(self, negate, args, kwargs): if (args or kwargs) and self.query.is_sliced: - raise TypeError('Cannot filter a query once a slice has been taken.') + raise TypeError("Cannot filter a query once a slice has been taken.") clone = self._chain() if self._defer_next_filter: self._defer_next_filter = False @@ -1129,7 +1226,9 @@ class QuerySet: # Clear limits and ordering so they can be reapplied clone.query.clear_ordering(force=True) clone.query.clear_limits() - clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs) + clone.query.combined_queries = (self.query,) + tuple( + qs.query for qs in other_qs + ) clone.query.combinator = combinator clone.query.combinator_all = all return clone @@ -1142,8 +1241,8 @@ class QuerySet: return self if len(qs) == 1: return qs[0] - return qs[0]._combinator_query('union', *qs[1:], all=all) - return self._combinator_query('union', *other_qs, all=all) + return qs[0]._combinator_query("union", *qs[1:], all=all) + return self._combinator_query("union", *other_qs, all=all) def intersection(self, *other_qs): # If any query is an EmptyQuerySet, return it. @@ -1152,13 +1251,13 @@ class QuerySet: for other in other_qs: if isinstance(other, EmptyQuerySet): return other - return self._combinator_query('intersection', *other_qs) + return self._combinator_query("intersection", *other_qs) def difference(self, *other_qs): # If the query is an EmptyQuerySet, return it. if isinstance(self, EmptyQuerySet): return self - return self._combinator_query('difference', *other_qs) + return self._combinator_query("difference", *other_qs) def select_for_update(self, nowait=False, skip_locked=False, of=(), no_key=False): """ @@ -1166,7 +1265,7 @@ class QuerySet: FOR UPDATE lock. 
""" if nowait and skip_locked: - raise ValueError('The nowait option cannot be used with skip_locked.') + raise ValueError("The nowait option cannot be used with skip_locked.") obj = self._chain() obj._for_write = True obj.query.select_for_update = True @@ -1185,9 +1284,11 @@ class QuerySet: If select_related(None) is called, clear the list. """ - self._not_support_combined_queries('select_related') + self._not_support_combined_queries("select_related") if self._fields is not None: - raise TypeError("Cannot call select_related() after .values() or .values_list()") + raise TypeError( + "Cannot call select_related() after .values() or .values_list()" + ) obj = self._chain() if fields == (None,): @@ -1207,7 +1308,7 @@ class QuerySet: When prefetch_related() is called more than once, append to the list of prefetch lookups. If prefetch_related(None) is called, clear the list. """ - self._not_support_combined_queries('prefetch_related') + self._not_support_combined_queries("prefetch_related") clone = self._chain() if lookups == (None,): clone._prefetch_related_lookups = () @@ -1217,7 +1318,9 @@ class QuerySet: lookup = lookup.prefetch_to lookup = lookup.split(LOOKUP_SEP, 1)[0] if lookup in self.query._filtered_relations: - raise ValueError('prefetch_related() is not supported with FilteredRelation.') + raise ValueError( + "prefetch_related() is not supported with FilteredRelation." + ) clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups return clone @@ -1226,26 +1329,29 @@ class QuerySet: Return a query set in which the returned objects have been annotated with extra data or aggregations. """ - self._not_support_combined_queries('annotate') + self._not_support_combined_queries("annotate") return self._annotate(args, kwargs, select=True) def alias(self, *args, **kwargs): """ Return a query set with added aliases for extra data or aggregations. 
""" - self._not_support_combined_queries('alias') + self._not_support_combined_queries("alias") return self._annotate(args, kwargs, select=False) def _annotate(self, args, kwargs, select=True): - self._validate_values_are_expressions(args + tuple(kwargs.values()), method_name='annotate') + self._validate_values_are_expressions( + args + tuple(kwargs.values()), method_name="annotate" + ) annotations = {} for arg in args: # The default_alias property may raise a TypeError. try: if arg.default_alias in kwargs: - raise ValueError("The named annotation '%s' conflicts with the " - "default name for another annotation." - % arg.default_alias) + raise ValueError( + "The named annotation '%s' conflicts with the " + "default name for another annotation." % arg.default_alias + ) except TypeError: raise TypeError("Complex annotations require an alias") annotations[arg.default_alias] = arg @@ -1254,20 +1360,29 @@ class QuerySet: clone = self._chain() names = self._fields if names is None: - names = set(chain.from_iterable( - (field.name, field.attname) if hasattr(field, 'attname') else (field.name,) - for field in self.model._meta.get_fields() - )) + names = set( + chain.from_iterable( + (field.name, field.attname) + if hasattr(field, "attname") + else (field.name,) + for field in self.model._meta.get_fields() + ) + ) for alias, annotation in annotations.items(): if alias in names: - raise ValueError("The annotation '%s' conflicts with a field on " - "the model." % alias) + raise ValueError( + "The annotation '%s' conflicts with a field on " + "the model." 
% alias + ) if isinstance(annotation, FilteredRelation): clone.query.add_filtered_relation(annotation, alias) else: clone.query.add_annotation( - annotation, alias, is_summary=False, select=select, + annotation, + alias, + is_summary=False, + select=select, ) for alias, annotation in clone.query.annotations.items(): if alias in annotations and annotation.contains_aggregate: @@ -1282,7 +1397,7 @@ class QuerySet: def order_by(self, *field_names): """Return a new QuerySet instance with the ordering changed.""" if self.query.is_sliced: - raise TypeError('Cannot reorder a query once a slice has been taken.') + raise TypeError("Cannot reorder a query once a slice has been taken.") obj = self._chain() obj.query.clear_ordering(force=True, clear_default=False) obj.query.add_ordering(*field_names) @@ -1292,19 +1407,28 @@ class QuerySet: """ Return a new QuerySet instance that will select only distinct results. """ - self._not_support_combined_queries('distinct') + self._not_support_combined_queries("distinct") if self.query.is_sliced: - raise TypeError('Cannot create distinct fields once a slice has been taken.') + raise TypeError( + "Cannot create distinct fields once a slice has been taken." 
+ ) obj = self._chain() obj.query.add_distinct_fields(*field_names) return obj - def extra(self, select=None, where=None, params=None, tables=None, - order_by=None, select_params=None): + def extra( + self, + select=None, + where=None, + params=None, + tables=None, + order_by=None, + select_params=None, + ): """Add extra SQL fragments to the query.""" - self._not_support_combined_queries('extra') + self._not_support_combined_queries("extra") if self.query.is_sliced: - raise TypeError('Cannot change a query once a slice has been taken.') + raise TypeError("Cannot change a query once a slice has been taken.") clone = self._chain() clone.query.add_extra(select, select_params, where, params, tables, order_by) return clone @@ -1312,7 +1436,7 @@ class QuerySet: def reverse(self): """Reverse the ordering of the QuerySet.""" if self.query.is_sliced: - raise TypeError('Cannot reverse a query once a slice has been taken.') + raise TypeError("Cannot reverse a query once a slice has been taken.") clone = self._chain() clone.query.standard_ordering = not clone.query.standard_ordering return clone @@ -1324,7 +1448,7 @@ class QuerySet: The only exception to this is if None is passed in as the only parameter, in which case removal all deferrals. """ - self._not_support_combined_queries('defer') + self._not_support_combined_queries("defer") if self._fields is not None: raise TypeError("Cannot call defer() after .values() or .values_list()") clone = self._chain() @@ -1340,7 +1464,7 @@ class QuerySet: method and that are not already specified as deferred are loaded immediately when the queryset is evaluated. 
""" - self._not_support_combined_queries('only') + self._not_support_combined_queries("only") if self._fields is not None: raise TypeError("Cannot call only() after .values() or .values_list()") if fields == (None,): @@ -1350,7 +1474,7 @@ class QuerySet: for field in fields: field = field.split(LOOKUP_SEP, 1)[0] if field in self.query._filtered_relations: - raise ValueError('only() is not supported with FilteredRelation.') + raise ValueError("only() is not supported with FilteredRelation.") clone = self._chain() clone.query.add_immediate_loading(fields) return clone @@ -1376,8 +1500,9 @@ class QuerySet: if self.query.extra_order_by or self.query.order_by: return True elif ( - self.query.default_ordering and - self.query.get_meta().ordering and + self.query.default_ordering + and self.query.get_meta().ordering + and # A default ordering doesn't affect GROUP BY queries. not self.query.group_by ): @@ -1397,8 +1522,15 @@ class QuerySet: ################### def _insert( - self, objs, fields, returning_fields=None, raw=False, using=None, - on_conflict=None, update_fields=None, unique_fields=None, + self, + objs, + fields, + returning_fields=None, + raw=False, + using=None, + on_conflict=None, + update_fields=None, + unique_fields=None, ): """ Insert a new record for the given model. 
This provides an interface to @@ -1415,11 +1547,17 @@ class QuerySet: ) query.insert_values(fields, objs, raw=raw) return query.get_compiler(using=using).execute_sql(returning_fields) + _insert.alters_data = True _insert.queryset_only = False def _batched_insert( - self, objs, fields, batch_size, on_conflict=None, update_fields=None, + self, + objs, + fields, + batch_size, + on_conflict=None, + update_fields=None, unique_fields=None, ): """ @@ -1431,12 +1569,16 @@ class QuerySet: batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size inserted_rows = [] bulk_return = connection.features.can_return_rows_from_bulk_insert - for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: + for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: if bulk_return and on_conflict is None: - inserted_rows.extend(self._insert( - item, fields=fields, using=self.db, - returning_fields=self.model._meta.db_returning_fields, - )) + inserted_rows.extend( + self._insert( + item, + fields=fields, + using=self.db, + returning_fields=self.model._meta.db_returning_fields, + ) + ) else: self._insert( item, @@ -1464,7 +1606,12 @@ class QuerySet: Return a copy of the current QuerySet. A lightweight alternative to deepcopy(). 
""" - c = self.__class__(model=self.model, query=self.query.chain(), using=self._db, hints=self._hints) + c = self.__class__( + model=self.model, + query=self.query.chain(), + using=self._db, + hints=self._hints, + ) c._sticky_filter = self._sticky_filter c._for_write = self._for_write c._prefetch_related_lookups = self._prefetch_related_lookups[:] @@ -1496,9 +1643,10 @@ class QuerySet: def _merge_sanity_check(self, other): """Check that two QuerySet classes may be merged.""" if self._fields is not None and ( - set(self.query.values_select) != set(other.query.values_select) or - set(self.query.extra_select) != set(other.query.extra_select) or - set(self.query.annotation_select) != set(other.query.annotation_select)): + set(self.query.values_select) != set(other.query.values_select) + or set(self.query.extra_select) != set(other.query.extra_select) + or set(self.query.annotation_select) != set(other.query.annotation_select) + ): raise TypeError( "Merging '%s' classes must involve the same values in each case." % self.__class__.__name__ @@ -1515,10 +1663,11 @@ class QuerySet: if self._fields and len(self._fields) > 1: # values() queryset can only be used as nested queries # if they are set up to select only a single field. - raise TypeError('Cannot use multi-field values as a filter value.') + raise TypeError("Cannot use multi-field values as a filter value.") query = self.query.resolve_expression(*args, **kwargs) query._db = self._db return query + resolve_expression.queryset_only = True def _add_hints(self, **hints): @@ -1538,25 +1687,28 @@ class QuerySet: @staticmethod def _validate_values_are_expressions(values, method_name): - invalid_args = sorted(str(arg) for arg in values if not hasattr(arg, 'resolve_expression')) + invalid_args = sorted( + str(arg) for arg in values if not hasattr(arg, "resolve_expression") + ) if invalid_args: raise TypeError( - 'QuerySet.%s() received non-expression(s): %s.' % ( + "QuerySet.%s() received non-expression(s): %s." 
+ % ( method_name, - ', '.join(invalid_args), + ", ".join(invalid_args), ) ) def _not_support_combined_queries(self, operation_name): if self.query.combinator: raise NotSupportedError( - 'Calling QuerySet.%s() after %s() is not supported.' + "Calling QuerySet.%s() after %s() is not supported." % (operation_name, self.query.combinator) ) def _check_operator_queryset(self, other, operator_): if self.query.combinator or other.query.combinator: - raise TypeError(f'Cannot use {operator_} operator with combined queryset.') + raise TypeError(f"Cannot use {operator_} operator with combined queryset.") class InstanceCheckMeta(type): @@ -1579,8 +1731,17 @@ class RawQuerySet: Provide an iterator which converts the results of raw SQL queries into annotated model instances. """ - def __init__(self, raw_query, model=None, query=None, params=(), - translations=None, using=None, hints=None): + + def __init__( + self, + raw_query, + model=None, + query=None, + params=(), + translations=None, + using=None, + hints=None, + ): self.raw_query = raw_query self.model = model self._db = using @@ -1595,10 +1756,17 @@ class RawQuerySet: def resolve_model_init_order(self): """Resolve the init field names and value positions.""" converter = connections[self.db].introspection.identifier_converter - model_init_fields = [f for f in self.model._meta.fields if converter(f.column) in self.columns] - annotation_fields = [(column, pos) for pos, column in enumerate(self.columns) - if column not in self.model_fields] - model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields] + model_init_fields = [ + f for f in self.model._meta.fields if converter(f.column) in self.columns + ] + annotation_fields = [ + (column, pos) + for pos, column in enumerate(self.columns) + if column not in self.model_fields + ] + model_init_order = [ + self.columns.index(converter(f.column)) for f in model_init_fields + ] model_init_names = [f.attname for f in model_init_fields] return 
model_init_names, model_init_order, annotation_fields @@ -1618,8 +1786,13 @@ class RawQuerySet: def _clone(self): """Same as QuerySet._clone()""" c = self.__class__( - self.raw_query, model=self.model, query=self.query, params=self.params, - translations=self.translations, using=self._db, hints=self._hints + self.raw_query, + model=self.model, + query=self.query, + params=self.params, + translations=self.translations, + using=self._db, + hints=self._hints, ) c._prefetch_related_lookups = self._prefetch_related_lookups[:] return c @@ -1646,20 +1819,24 @@ class RawQuerySet: # Cache some things for performance reasons outside the loop. db = self.db connection = connections[db] - compiler = connection.ops.compiler('SQLCompiler')(self.query, connection, db) + compiler = connection.ops.compiler("SQLCompiler")(self.query, connection, db) query = iter(self.query) try: - model_init_names, model_init_pos, annotation_fields = self.resolve_model_init_order() + ( + model_init_names, + model_init_pos, + annotation_fields, + ) = self.resolve_model_init_order() if self.model._meta.pk.attname not in model_init_names: raise exceptions.FieldDoesNotExist( - 'Raw query must include the primary key' + "Raw query must include the primary key" ) model_cls = self.model fields = [self.model_fields.get(c) for c in self.columns] - converters = compiler.get_converters([ - f.get_col(f.model._meta.db_table) if f else None for f in fields - ]) + converters = compiler.get_converters( + [f.get_col(f.model._meta.db_table) if f else None for f in fields] + ) if converters: query = compiler.apply_converters(query, converters) for values in query: @@ -1672,7 +1849,7 @@ class RawQuerySet: yield instance finally: # Done iterating the Query. If it has its own cursor, close it. 
- if hasattr(self.query, 'cursor') and self.query.cursor: + if hasattr(self.query, "cursor") and self.query.cursor: self.query.cursor.close() def __repr__(self): @@ -1689,9 +1866,11 @@ class RawQuerySet: def using(self, alias): """Select the database this RawQuerySet should execute against.""" return RawQuerySet( - self.raw_query, model=self.model, + self.raw_query, + model=self.model, query=self.query.chain(using=alias), - params=self.params, translations=self.translations, + params=self.params, + translations=self.translations, using=alias, ) @@ -1731,16 +1910,19 @@ class Prefetch: # `prefetch_to` is the path to the attribute that stores the result. self.prefetch_to = lookup if queryset is not None and ( - isinstance(queryset, RawQuerySet) or ( - hasattr(queryset, '_iterable_class') and - not issubclass(queryset._iterable_class, ModelIterable) + isinstance(queryset, RawQuerySet) + or ( + hasattr(queryset, "_iterable_class") + and not issubclass(queryset._iterable_class, ModelIterable) ) ): raise ValueError( - 'Prefetch querysets cannot use raw(), values(), and values_list().' + "Prefetch querysets cannot use raw(), values(), and values_list()." 
) if to_attr: - self.prefetch_to = LOOKUP_SEP.join(lookup.split(LOOKUP_SEP)[:-1] + [to_attr]) + self.prefetch_to = LOOKUP_SEP.join( + lookup.split(LOOKUP_SEP)[:-1] + [to_attr] + ) self.queryset = queryset self.to_attr = to_attr @@ -1752,7 +1934,7 @@ class Prefetch: # Prevent the QuerySet from being evaluated queryset._result_cache = [] queryset._prefetch_done = True - obj_dict['queryset'] = queryset + obj_dict["queryset"] = queryset return obj_dict def add_prefix(self, prefix): @@ -1760,7 +1942,7 @@ class Prefetch: self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to def get_current_prefetch_to(self, level): - return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[:level + 1]) + return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1]) def get_current_to_attr(self, level): parts = self.prefetch_to.split(LOOKUP_SEP) @@ -1805,7 +1987,7 @@ def prefetch_related_objects(model_instances, *related_lookups): # We need to be able to dynamically add to the list of prefetch_related # lookups that we look up (see below). So we need some book keeping to # ensure we don't do duplicate work. - done_queries = {} # dictionary of things like 'foo__bar': [results] + done_queries = {} # dictionary of things like 'foo__bar': [results] auto_lookups = set() # we add to this as we go through. followed_descriptors = set() # recursion protection @@ -1815,8 +1997,11 @@ def prefetch_related_objects(model_instances, *related_lookups): lookup = all_lookups.pop() if lookup.prefetch_to in done_queries: if lookup.queryset is not None: - raise ValueError("'%s' lookup was already seen with a different queryset. " - "You may need to adjust the ordering of your lookups." % lookup.prefetch_to) + raise ValueError( + "'%s' lookup was already seen with a different queryset. " + "You may need to adjust the ordering of your lookups." 
+ % lookup.prefetch_to + ) continue @@ -1842,7 +2027,7 @@ def prefetch_related_objects(model_instances, *related_lookups): # Since prefetching can re-use instances, it is possible to have # the same instance multiple times in obj_list, so obj might # already be prepared. - if not hasattr(obj, '_prefetched_objects_cache'): + if not hasattr(obj, "_prefetched_objects_cache"): try: obj._prefetched_objects_cache = {} except (AttributeError, TypeError): @@ -1862,20 +2047,30 @@ def prefetch_related_objects(model_instances, *related_lookups): # of prefetch_related), so what applies to first object applies to all. first_obj = obj_list[0] to_attr = lookup.get_current_to_attr(level)[0] - prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, through_attr, to_attr) + prefetcher, descriptor, attr_found, is_fetched = get_prefetcher( + first_obj, through_attr, to_attr + ) if not attr_found: - raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid " - "parameter to prefetch_related()" % - (through_attr, first_obj.__class__.__name__, lookup.prefetch_through)) + raise AttributeError( + "Cannot find '%s' on %s object, '%s' is an invalid " + "parameter to prefetch_related()" + % ( + through_attr, + first_obj.__class__.__name__, + lookup.prefetch_through, + ) + ) if level == len(through_attrs) - 1 and prefetcher is None: # Last one, this *must* resolve to something that supports # prefetching, otherwise there is no point adding it and the # developer asking for it has made a mistake. - raise ValueError("'%s' does not resolve to an item that supports " - "prefetching - this is an invalid parameter to " - "prefetch_related()." % lookup.prefetch_through) + raise ValueError( + "'%s' does not resolve to an item that supports " + "prefetching - this is an invalid parameter to " + "prefetch_related()." 
% lookup.prefetch_through + ) obj_to_fetch = None if prefetcher is not None: @@ -1892,9 +2087,15 @@ def prefetch_related_objects(model_instances, *related_lookups): # same relationships to stop infinite recursion. So, if we # are already on an automatically added lookup, don't add # the new lookups from relationships we've seen already. - if not (prefetch_to in done_queries and lookup in auto_lookups and descriptor in followed_descriptors): + if not ( + prefetch_to in done_queries + and lookup in auto_lookups + and descriptor in followed_descriptors + ): done_queries[prefetch_to] = obj_list - new_lookups = normalize_prefetch_lookups(reversed(additional_lookups), prefetch_to) + new_lookups = normalize_prefetch_lookups( + reversed(additional_lookups), prefetch_to + ) auto_lookups.update(new_lookups) all_lookups.extend(new_lookups) followed_descriptors.add(descriptor) @@ -1908,7 +2109,7 @@ def prefetch_related_objects(model_instances, *related_lookups): # that we can continue with nullable or reverse relations. new_obj_list = [] for obj in obj_list: - if through_attr in getattr(obj, '_prefetched_objects_cache', ()): + if through_attr in getattr(obj, "_prefetched_objects_cache", ()): # If related objects have been prefetched, use the # cache rather than the object's through_attr. new_obj = list(obj._prefetched_objects_cache.get(through_attr)) @@ -1940,6 +2141,7 @@ def get_prefetcher(instance, through_attr, to_attr): a function that takes an instance and returns a boolean that is True if the attribute has already been fetched for that instance) """ + def has_to_attr_attribute(instance): return hasattr(instance, to_attr) @@ -1957,7 +2159,7 @@ def get_prefetcher(instance, through_attr, to_attr): if rel_obj_descriptor: # singly related object, descriptor object has the # get_prefetch_queryset() method. 
- if hasattr(rel_obj_descriptor, 'get_prefetch_queryset'): + if hasattr(rel_obj_descriptor, "get_prefetch_queryset"): prefetcher = rel_obj_descriptor is_fetched = rel_obj_descriptor.is_cached else: @@ -1965,17 +2167,21 @@ def get_prefetcher(instance, through_attr, to_attr): # the attribute on the instance rather than the class to # support many related managers rel_obj = getattr(instance, through_attr) - if hasattr(rel_obj, 'get_prefetch_queryset'): + if hasattr(rel_obj, "get_prefetch_queryset"): prefetcher = rel_obj if through_attr != to_attr: # Special case cached_property instances because hasattr # triggers attribute computation and assignment. - if isinstance(getattr(instance.__class__, to_attr, None), cached_property): + if isinstance( + getattr(instance.__class__, to_attr, None), cached_property + ): + def has_cached_property(instance): return to_attr in instance.__dict__ is_fetched = has_cached_property else: + def in_prefetched_cache(instance): return through_attr in instance._prefetched_objects_cache @@ -2006,8 +2212,14 @@ def prefetch_one_level(instances, prefetcher, lookup, level): # The 'values to be matched' must be hashable as they will be used # in a dictionary. - rel_qs, rel_obj_attr, instance_attr, single, cache_name, is_descriptor = ( - prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))) + ( + rel_qs, + rel_obj_attr, + instance_attr, + single, + cache_name, + is_descriptor, + ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level)) # We have to handle the possibility that the QuerySet we just got back # contains some prefetch_related lookups. We don't want to trigger the # prefetch_related functionality by evaluating the query. Rather, we need @@ -2015,8 +2227,8 @@ def prefetch_one_level(instances, prefetcher, lookup, level): # Copy the lookups in case it is a Prefetch object which could be reused # later (happens in nested prefetch_related). 
additional_lookups = [ - copy.copy(additional_lookup) for additional_lookup - in getattr(rel_qs, '_prefetch_related_lookups', ()) + copy.copy(additional_lookup) + for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ()) ] if additional_lookups: # Don't need to clone because the manager should have given us a fresh @@ -2042,7 +2254,7 @@ def prefetch_one_level(instances, prefetcher, lookup, level): except exceptions.FieldDoesNotExist: pass else: - msg = 'to_attr={} conflicts with a field on the {} model.' + msg = "to_attr={} conflicts with a field on the {} model." raise ValueError(msg.format(to_attr, model.__name__)) # Whether or not we're prefetching the last part of the lookup. @@ -2098,6 +2310,7 @@ class RelatedPopulator: method gets row and from_obj as input and populates the select_related() model instance. """ + def __init__(self, klass_info, select, db): self.db = db # Pre-compute needed attributes. The attributes are: @@ -2123,32 +2336,40 @@ class RelatedPopulator: # - local_setter, remote_setter: Methods to set cached values on # the object being populated and on the remote object. Usually # these are Field.set_cached_value() methods. 
- select_fields = klass_info['select_fields'] - from_parent = klass_info['from_parent'] + select_fields = klass_info["select_fields"] + from_parent = klass_info["from_parent"] if not from_parent: self.cols_start = select_fields[0] self.cols_end = select_fields[-1] + 1 self.init_list = [ - f[0].target.attname for f in select[self.cols_start:self.cols_end] + f[0].target.attname for f in select[self.cols_start : self.cols_end] ] self.reorder_for_init = None else: - attname_indexes = {select[idx][0].target.attname: idx for idx in select_fields} - model_init_attnames = (f.attname for f in klass_info['model']._meta.concrete_fields) - self.init_list = [attname for attname in model_init_attnames if attname in attname_indexes] - self.reorder_for_init = operator.itemgetter(*[attname_indexes[attname] for attname in self.init_list]) + attname_indexes = { + select[idx][0].target.attname: idx for idx in select_fields + } + model_init_attnames = ( + f.attname for f in klass_info["model"]._meta.concrete_fields + ) + self.init_list = [ + attname for attname in model_init_attnames if attname in attname_indexes + ] + self.reorder_for_init = operator.itemgetter( + *[attname_indexes[attname] for attname in self.init_list] + ) - self.model_cls = klass_info['model'] + self.model_cls = klass_info["model"] self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) self.related_populators = get_related_populators(klass_info, select, self.db) - self.local_setter = klass_info['local_setter'] - self.remote_setter = klass_info['remote_setter'] + self.local_setter = klass_info["local_setter"] + self.remote_setter = klass_info["remote_setter"] def populate(self, row, from_obj): if self.reorder_for_init: obj_data = self.reorder_for_init(row) else: - obj_data = row[self.cols_start:self.cols_end] + obj_data = row[self.cols_start : self.cols_end] if obj_data[self.pk_idx] is None: obj = None else: @@ -2162,7 +2383,7 @@ class RelatedPopulator: def get_related_populators(klass_info, select, 
db): iterators = [] - related_klass_infos = klass_info.get('related_klass_infos', []) + related_klass_infos = klass_info.get("related_klass_infos", []) for rel_klass_info in related_klass_infos: rel_cls = RelatedPopulator(rel_klass_info, select, db) iterators.append(rel_cls) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 188b640850..6ea82b6520 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -17,7 +17,10 @@ from django.utils import tree # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. -PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation') +PathInfo = namedtuple( + "PathInfo", + "from_opts to_opts target_fields join_field m2m direct filtered_relation", +) def subclasses(cls): @@ -31,21 +34,26 @@ class Q(tree.Node): Encapsulate filters as objects that can then be combined logically (using `&` and `|`). 
""" + # Connection types - AND = 'AND' - OR = 'OR' + AND = "AND" + OR = "OR" default = AND conditional = True def __init__(self, *args, _connector=None, _negated=False, **kwargs): - super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated) + super().__init__( + children=[*args, *sorted(kwargs.items())], + connector=_connector, + negated=_negated, + ) def _combine(self, other, conn): - if not(isinstance(other, Q) or getattr(other, 'conditional', False) is True): + if not (isinstance(other, Q) or getattr(other, "conditional", False) is True): raise TypeError(other) if not self: - return other.copy() if hasattr(other, 'copy') else copy.copy(other) + return other.copy() if hasattr(other, "copy") else copy.copy(other) elif isinstance(other, Q) and not other: _, args, kwargs = self.deconstruct() return type(self)(*args, **kwargs) @@ -68,26 +76,31 @@ class Q(tree.Node): obj.negate() return obj - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): # We must promote any new joins to left outer joins so that when Q is # used as an expression, rows aren't filtered due to joins. 
clause, joins = query._add_q( - self, reuse, allow_joins=allow_joins, split_subq=False, + self, + reuse, + allow_joins=allow_joins, + split_subq=False, check_filterable=False, ) query.promote_joins(joins) return clause def deconstruct(self): - path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) - if path.startswith('django.db.models.query_utils'): - path = path.replace('django.db.models.query_utils', 'django.db.models') + path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) + if path.startswith("django.db.models.query_utils"): + path = path.replace("django.db.models.query_utils", "django.db.models") args = tuple(self.children) kwargs = {} if self.connector != self.default: - kwargs['_connector'] = self.connector + kwargs["_connector"] = self.connector if self.negated: - kwargs['_negated'] = True + kwargs["_negated"] = True return path, args, kwargs @@ -96,6 +109,7 @@ class DeferredAttribute: A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed. 
""" + def __init__(self, field): self.field = field @@ -132,7 +146,6 @@ class DeferredAttribute: class RegisterLookupMixin: - @classmethod def _get_lookup(cls, lookup_name): return cls.get_lookups().get(lookup_name, None) @@ -140,13 +153,16 @@ class RegisterLookupMixin: @classmethod @functools.lru_cache(maxsize=None) def get_lookups(cls): - class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(cls)] + class_lookups = [ + parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls) + ] return cls.merge_dicts(class_lookups) def get_lookup(self, lookup_name): from django.db.models.lookups import Lookup + found = self._get_lookup(lookup_name) - if found is None and hasattr(self, 'output_field'): + if found is None and hasattr(self, "output_field"): return self.output_field.get_lookup(lookup_name) if found is not None and not issubclass(found, Lookup): return None @@ -154,8 +170,9 @@ class RegisterLookupMixin: def get_transform(self, lookup_name): from django.db.models.lookups import Transform + found = self._get_lookup(lookup_name) - if found is None and hasattr(self, 'output_field'): + if found is None and hasattr(self, "output_field"): return self.output_field.get_transform(lookup_name) if found is not None and not issubclass(found, Transform): return None @@ -181,7 +198,7 @@ class RegisterLookupMixin: def register_lookup(cls, lookup, lookup_name=None): if lookup_name is None: lookup_name = lookup.lookup_name - if 'class_lookups' not in cls.__dict__: + if "class_lookups" not in cls.__dict__: cls.class_lookups = {} cls.class_lookups[lookup_name] = lookup cls._clear_cached_lookups() @@ -228,8 +245,8 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa if field.attname not in load_fields: if restricted and field.name in requested: msg = ( - 'Field %s.%s cannot be both deferred and traversed using ' - 'select_related at the same time.' 
+ "Field %s.%s cannot be both deferred and traversed using " + "select_related at the same time." ) % (field.model._meta.object_name, field.name) raise FieldError(msg) return True @@ -255,12 +272,14 @@ def check_rel_lookup_compatibility(model, target_opts, field): 1) model and opts match (where proxy inheritance is removed) 2) model is parent of opts' model or the other way around """ + def check(opts): return ( - model._meta.concrete_model == opts.concrete_model or - opts.concrete_model in model._meta.get_parent_list() or - model in opts.get_parent_list() + model._meta.concrete_model == opts.concrete_model + or opts.concrete_model in model._meta.get_parent_list() + or model in opts.get_parent_list() ) + # If the field is a primary key, then doing a query against the field's # model is ok, too. Consider the case: # class Restaurant(models.Model): @@ -270,9 +289,8 @@ def check_rel_lookup_compatibility(model, target_opts, field): # give Place's opts as the target opts, but Restaurant isn't compatible # with that. This logic applies only to primary keys, as when doing __in=qs, # we are going to turn this into __in=qs.values('pk') later on. 
- return ( - check(target_opts) or - (getattr(field, 'primary_key', False) and check(field.model._meta)) + return check(target_opts) or ( + getattr(field, "primary_key", False) and check(field.model._meta) ) @@ -281,11 +299,11 @@ class FilteredRelation: def __init__(self, relation_name, *, condition=Q()): if not relation_name: - raise ValueError('relation_name cannot be empty.') + raise ValueError("relation_name cannot be empty.") self.relation_name = relation_name self.alias = None if not isinstance(condition, Q): - raise ValueError('condition argument must be a Q() instance.') + raise ValueError("condition argument must be a Q() instance.") self.condition = condition self.path = [] @@ -293,9 +311,9 @@ class FilteredRelation: if not isinstance(other, self.__class__): return NotImplemented return ( - self.relation_name == other.relation_name and - self.alias == other.alias and - self.condition == other.condition + self.relation_name == other.relation_name + and self.alias == other.alias + and self.condition == other.condition ) def clone(self): @@ -309,7 +327,7 @@ class FilteredRelation: QuerySet.annotate() only accepts expression-like arguments (with a resolve_expression() method). """ - raise NotImplementedError('FilteredRelation.resolve_expression() is unused.') + raise NotImplementedError("FilteredRelation.resolve_expression() is unused.") def as_sql(self, compiler, connection): # Resolve the condition in Join.filtered_relation. diff --git a/django/db/models/signals.py b/django/db/models/signals.py index d14eaaf91d..a0720937af 100644 --- a/django/db/models/signals.py +++ b/django/db/models/signals.py @@ -11,6 +11,7 @@ class ModelSignal(Signal): Signal subclass that allows the sender to be lazily specified as a string of the `app_label.ModelName` form. 
""" + def _lazy_method(self, method, apps, receiver, sender, **kwargs): from django.db.models.options import Options @@ -24,8 +25,12 @@ class ModelSignal(Signal): def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None): self._lazy_method( - super().connect, apps, receiver, sender, - weak=weak, dispatch_uid=dispatch_uid, + super().connect, + apps, + receiver, + sender, + weak=weak, + dispatch_uid=dispatch_uid, ) def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None): diff --git a/django/db/models/sql/__init__.py b/django/db/models/sql/__init__.py index 5fa52f6a1f..2956e047b1 100644 --- a/django/db/models/sql/__init__.py +++ b/django/db/models/sql/__init__.py @@ -3,4 +3,4 @@ from django.db.models.sql.query import Query from django.db.models.sql.subqueries import * # NOQA from django.db.models.sql.where import AND, OR -__all__ = ['Query', 'AND', 'OR'] +__all__ = ["Query", "AND", "OR"] diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index d405a203ee..13a7ec7263 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -11,7 +11,12 @@ from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value from django.db.models.functions import Cast, Random from django.db.models.query_utils import select_related_descend from django.db.models.sql.constants import ( - CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE, + CURSOR, + GET_ITERATOR_CHUNK_SIZE, + MULTI, + NO_RESULTS, + ORDER_DIR, + SINGLE, ) from django.db.models.sql.query import Query, get_order_dir from django.db.transaction import TransactionManagementError @@ -23,7 +28,7 @@ from django.utils.regex_helper import _lazy_re_compile class SQLCompiler: # Multiline ordering SQL clause may appear from RawSQL. ordering_parts = _lazy_re_compile( - r'^(.*)\s(?:ASC|DESC).*', + r"^(.*)\s(?:ASC|DESC).*", re.MULTILINE | re.DOTALL, ) @@ -34,7 +39,7 @@ class SQLCompiler: # Some queries, e.g. 
coalesced aggregation, need to be executed even if # they would return an empty result set. self.elide_empty = elide_empty - self.quote_cache = {'*': '*'} + self.quote_cache = {"*": "*"} # The select, klass_info, and annotations are needed by QuerySet.iterator() # these are set as a side-effect of executing the query. Note that we calculate # separately a list of extra select columns needed for grammatical correctness @@ -46,9 +51,9 @@ class SQLCompiler: def __repr__(self): return ( - f'<{self.__class__.__qualname__} ' - f'model={self.query.model.__qualname__} ' - f'connection={self.connection!r} using={self.using!r}>' + f"<{self.__class__.__qualname__} " + f"model={self.query.model.__qualname__} " + f"connection={self.connection!r} using={self.using!r}>" ) def setup_query(self): @@ -118,16 +123,14 @@ class SQLCompiler: # when we have public API way of forcing the GROUP BY clause. # Converts string references to expressions. for expr in self.query.group_by: - if not hasattr(expr, 'as_sql'): + if not hasattr(expr, "as_sql"): expressions.append(self.query.resolve_ref(expr)) else: expressions.append(expr) # Note that even if the group_by is set, it is only the minimal # set to group by. So, we need to add cols in select, order_by, and # having into the select in any case. - ref_sources = { - expr.source for expr in expressions if isinstance(expr, Ref) - } + ref_sources = {expr.source for expr in expressions if isinstance(expr, Ref)} for expr, _, _ in select: # Skip members of the select clause that are already included # by reference. @@ -169,8 +172,10 @@ class SQLCompiler: for expr in expressions: # Is this a reference to query's base table primary key? If the # expression isn't a Col-like, then skip the expression. 
- if (getattr(expr, 'target', None) == self.query.model._meta.pk and - getattr(expr, 'alias', None) == self.query.base_table): + if ( + getattr(expr, "target", None) == self.query.model._meta.pk + and getattr(expr, "alias", None) == self.query.base_table + ): pk = expr break # If the main model's primary key is in the query, group by that @@ -178,13 +183,17 @@ class SQLCompiler: # that don't have a primary key included in the grouped columns. if pk: pk_aliases = { - expr.alias for expr in expressions - if hasattr(expr, 'target') and expr.target.primary_key + expr.alias + for expr in expressions + if hasattr(expr, "target") and expr.target.primary_key } expressions = [pk] + [ - expr for expr in expressions - if expr in having or ( - getattr(expr, 'alias', None) is not None and expr.alias not in pk_aliases + expr + for expr in expressions + if expr in having + or ( + getattr(expr, "alias", None) is not None + and expr.alias not in pk_aliases ) ] elif self.connection.features.allows_group_by_selected_pks: @@ -195,16 +204,21 @@ class SQLCompiler: # Unmanaged models are excluded because they could be representing # database views on which the optimization might not be allowed. 
pks = { - expr for expr in expressions + expr + for expr in expressions if ( - hasattr(expr, 'target') and - expr.target.primary_key and - self.connection.features.allows_group_by_selected_pks_on_model(expr.target.model) + hasattr(expr, "target") + and expr.target.primary_key + and self.connection.features.allows_group_by_selected_pks_on_model( + expr.target.model + ) ) } aliases = {expr.alias for expr in pks} expressions = [ - expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases + expr + for expr in expressions + if expr in pks or getattr(expr, "alias", None) not in aliases ] return expressions @@ -248,8 +262,8 @@ class SQLCompiler: select.append((col, None)) select_idx += 1 klass_info = { - 'model': self.query.model, - 'select_fields': select_list, + "model": self.query.model, + "select_fields": select_list, } for alias, annotation in self.query.annotation_select.items(): annotations[alias] = select_idx @@ -258,14 +272,16 @@ class SQLCompiler: if self.query.select_related: related_klass_infos = self.get_related_selections(select) - klass_info['related_klass_infos'] = related_klass_infos + klass_info["related_klass_infos"] = related_klass_infos def get_select_from_parent(klass_info): - for ki in klass_info['related_klass_infos']: - if ki['from_parent']: - ki['select_fields'] = (klass_info['select_fields'] + - ki['select_fields']) + for ki in klass_info["related_klass_infos"]: + if ki["from_parent"]: + ki["select_fields"] = ( + klass_info["select_fields"] + ki["select_fields"] + ) get_select_from_parent(ki) + get_select_from_parent(klass_info) ret = [] @@ -273,10 +289,12 @@ class SQLCompiler: try: sql, params = self.compile(col) except EmptyResultSet: - empty_result_set_value = getattr(col, 'empty_result_set_value', NotImplemented) + empty_result_set_value = getattr( + col, "empty_result_set_value", NotImplemented + ) if empty_result_set_value is NotImplemented: # Select a predicate that's always False. 
- sql, params = '0', () + sql, params = "0", () else: sql, params = self.compile(Value(empty_result_set_value)) else: @@ -297,12 +315,12 @@ class SQLCompiler: else: ordering = [] if self.query.standard_ordering: - default_order, _ = ORDER_DIR['ASC'] + default_order, _ = ORDER_DIR["ASC"] else: - default_order, _ = ORDER_DIR['DESC'] + default_order, _ = ORDER_DIR["DESC"] for field in ordering: - if hasattr(field, 'resolve_expression'): + if hasattr(field, "resolve_expression"): if isinstance(field, Value): # output_field must be resolved for constants. field = Cast(field, field.output_field) @@ -313,12 +331,12 @@ class SQLCompiler: field.reverse_ordering() yield field, False continue - if field == '?': # random + if field == "?": # random yield OrderBy(Random()), False continue col, order = get_order_dir(field, default_order) - descending = order == 'DESC' + descending = order == "DESC" if col in self.query.annotation_select: # Reference to expression in SELECT clause @@ -345,13 +363,15 @@ class SQLCompiler: yield OrderBy(expr, descending=descending), False continue - if '.' in field: + if "." in field: # This came in through an extra(order_by=...) addition. Pass it # on verbatim. - table, col = col.split('.', 1) + table, col = col.split(".", 1) yield ( OrderBy( - RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []), + RawSQL( + "%s.%s" % (self.quote_name_unless_alias(table), col), [] + ), descending=descending, ), False, @@ -361,7 +381,10 @@ class SQLCompiler: if self.query.extra and col in self.query.extra: if col in self.query.extra_select: yield ( - OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending), + OrderBy( + Ref(col, RawSQL(*self.query.extra[col])), + descending=descending, + ), True, ) else: @@ -378,7 +401,9 @@ class SQLCompiler: # 'col' is of the form 'field' or 'field1__field2' or # '-field1__field2__field', etc. 
yield from self.find_ordering_name( - field, self.query.get_meta(), default_order=default_order, + field, + self.query.get_meta(), + default_order=default_order, ) def get_order_by(self): @@ -409,19 +434,21 @@ class SQLCompiler: ): continue if src == sel_expr: - resolved.set_source_expressions([RawSQL('%d' % (idx + 1), ())]) + resolved.set_source_expressions([RawSQL("%d" % (idx + 1), ())]) break else: if col_alias: - raise DatabaseError('ORDER BY term does not match any column in the result set.') + raise DatabaseError( + "ORDER BY term does not match any column in the result set." + ) # Add column used in ORDER BY clause to the selected # columns and to each combined query. order_by_idx = len(self.query.select) + 1 - col_name = f'__orderbycol{order_by_idx}' + col_name = f"__orderbycol{order_by_idx}" for q in self.query.combined_queries: q.add_annotation(expr_src, col_name) self.query.add_select_col(resolved, col_name) - resolved.set_source_expressions([RawSQL(f'{order_by_idx}', ())]) + resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())]) sql, params = self.compile(resolved) # Don't add the same column twice, but the order direction is # not taken into account so we strip it. 
When this entire method @@ -453,9 +480,14 @@ class SQLCompiler: """ if name in self.quote_cache: return self.quote_cache[name] - if ((name in self.query.alias_map and name not in self.query.table_map) or - name in self.query.extra_select or ( - self.query.external_aliases.get(name) and name not in self.query.table_map)): + if ( + (name in self.query.alias_map and name not in self.query.table_map) + or name in self.query.extra_select + or ( + self.query.external_aliases.get(name) + and name not in self.query.table_map + ) + ): self.quote_cache[name] = name return name r = self.connection.ops.quote_name(name) @@ -463,7 +495,7 @@ class SQLCompiler: return r def compile(self, node): - vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) + vendor_impl = getattr(node, "as_" + self.connection.vendor, None) if vendor_impl: sql, params = vendor_impl(self, self.connection) else: @@ -474,14 +506,19 @@ class SQLCompiler: features = self.connection.features compilers = [ query.get_compiler(self.using, self.connection, self.elide_empty) - for query in self.query.combined_queries if not query.is_empty() + for query in self.query.combined_queries + if not query.is_empty() ] if not features.supports_slicing_ordering_in_compound: for query, compiler in zip(self.query.combined_queries, compilers): if query.low_mark or query.high_mark: - raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.') + raise DatabaseError( + "LIMIT/OFFSET not allowed in subqueries of compound statements." + ) if compiler.get_order_by(): - raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.') + raise DatabaseError( + "ORDER BY not allowed in subqueries of compound statements." + ) parts = () for compiler in compilers: try: @@ -490,41 +527,45 @@ class SQLCompiler: # the query on all combined queries, if not already set. 
if not compiler.query.values_select and self.query.values_select: compiler.query = compiler.query.clone() - compiler.query.set_values(( - *self.query.extra_select, - *self.query.values_select, - *self.query.annotation_select, - )) + compiler.query.set_values( + ( + *self.query.extra_select, + *self.query.values_select, + *self.query.annotation_select, + ) + ) part_sql, part_args = compiler.as_sql() if compiler.query.combinator: # Wrap in a subquery if wrapping in parentheses isn't # supported. if not features.supports_parentheses_in_compound: - part_sql = 'SELECT * FROM ({})'.format(part_sql) + part_sql = "SELECT * FROM ({})".format(part_sql) # Add parentheses when combining with compound query if not # already added for all compound queries. elif ( - self.query.subquery or - not features.supports_slicing_ordering_in_compound + self.query.subquery + or not features.supports_slicing_ordering_in_compound ): - part_sql = '({})'.format(part_sql) + part_sql = "({})".format(part_sql) parts += ((part_sql, part_args),) except EmptyResultSet: # Omit the empty queryset with UNION and with DIFFERENCE if the # first queryset is nonempty. 
- if combinator == 'union' or (combinator == 'difference' and parts): + if combinator == "union" or (combinator == "difference" and parts): continue raise if not parts: raise EmptyResultSet combinator_sql = self.connection.ops.set_operators[combinator] - if all and combinator == 'union': - combinator_sql += ' ALL' - braces = '{}' + if all and combinator == "union": + combinator_sql += " ALL" + braces = "{}" if not self.query.subquery and features.supports_slicing_ordering_in_compound: - braces = '({})' - sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts)) - result = [' {} '.format(combinator_sql).join(sql_parts)] + braces = "({})" + sql_parts, args_parts = zip( + *((braces.format(sql), args) for sql, args in parts) + ) + result = [" {} ".format(combinator_sql).join(sql_parts)] params = [] for part in args_parts: params.extend(part) @@ -543,27 +584,39 @@ class SQLCompiler: extra_select, order_by, group_by = self.pre_sql_setup() for_update_part = None # Is a LIMIT/OFFSET clause needed? - with_limit_offset = with_limits and (self.query.high_mark is not None or self.query.low_mark) + with_limit_offset = with_limits and ( + self.query.high_mark is not None or self.query.low_mark + ) combinator = self.query.combinator features = self.connection.features if combinator: - if not getattr(features, 'supports_select_{}'.format(combinator)): - raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) - result, params = self.get_combinator_sql(combinator, self.query.combinator_all) + if not getattr(features, "supports_select_{}".format(combinator)): + raise NotSupportedError( + "{} is not supported on this database backend.".format( + combinator + ) + ) + result, params = self.get_combinator_sql( + combinator, self.query.combinator_all + ) else: distinct_fields, distinct_params = self.get_distinct() # This must come after 'select', 'ordering', and 'distinct' # (see docstring of get_from_clause() for details). 
from_, f_params = self.get_from_clause() try: - where, w_params = self.compile(self.where) if self.where is not None else ('', []) + where, w_params = ( + self.compile(self.where) if self.where is not None else ("", []) + ) except EmptyResultSet: if self.elide_empty: raise # Use a predicate that's always False. - where, w_params = '0 = 1', [] - having, h_params = self.compile(self.having) if self.having is not None else ("", []) - result = ['SELECT'] + where, w_params = "0 = 1", [] + having, h_params = ( + self.compile(self.having) if self.having is not None else ("", []) + ) + result = ["SELECT"] params = [] if self.query.distinct: @@ -578,27 +631,38 @@ class SQLCompiler: col_idx = 1 for _, (s_sql, s_params), alias in self.select + extra_select: if alias: - s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) - elif with_col_aliases: - s_sql = '%s AS %s' % ( + s_sql = "%s AS %s" % ( s_sql, - self.connection.ops.quote_name('col%d' % col_idx), + self.connection.ops.quote_name(alias), + ) + elif with_col_aliases: + s_sql = "%s AS %s" % ( + s_sql, + self.connection.ops.quote_name("col%d" % col_idx), ) col_idx += 1 params.extend(s_params) out_cols.append(s_sql) - result += [', '.join(out_cols), 'FROM', *from_] + result += [", ".join(out_cols), "FROM", *from_] params.extend(f_params) - if self.query.select_for_update and self.connection.features.has_select_for_update: + if ( + self.query.select_for_update + and self.connection.features.has_select_for_update + ): if self.connection.get_autocommit(): - raise TransactionManagementError('select_for_update cannot be used outside of a transaction.') + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." 
+ ) - if with_limit_offset and not self.connection.features.supports_select_for_update_with_limit: + if ( + with_limit_offset + and not self.connection.features.supports_select_for_update_with_limit + ): raise NotSupportedError( - 'LIMIT/OFFSET is not supported with ' - 'select_for_update on this database backend.' + "LIMIT/OFFSET is not supported with " + "select_for_update on this database backend." ) nowait = self.query.select_for_update_nowait skip_locked = self.query.select_for_update_skip_locked @@ -607,16 +671,31 @@ class SQLCompiler: # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the # backend doesn't support it, raise NotSupportedError to # prevent a possible deadlock. - if nowait and not self.connection.features.has_select_for_update_nowait: - raise NotSupportedError('NOWAIT is not supported on this database backend.') - elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: - raise NotSupportedError('SKIP LOCKED is not supported on this database backend.') - elif of and not self.connection.features.has_select_for_update_of: - raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.') - elif no_key and not self.connection.features.has_select_for_no_key_update: + if ( + nowait + and not self.connection.features.has_select_for_update_nowait + ): raise NotSupportedError( - 'FOR NO KEY UPDATE is not supported on this ' - 'database backend.' + "NOWAIT is not supported on this database backend." + ) + elif ( + skip_locked + and not self.connection.features.has_select_for_update_skip_locked + ): + raise NotSupportedError( + "SKIP LOCKED is not supported on this database backend." + ) + elif of and not self.connection.features.has_select_for_update_of: + raise NotSupportedError( + "FOR UPDATE OF is not supported on this database backend." 
+ ) + elif ( + no_key + and not self.connection.features.has_select_for_no_key_update + ): + raise NotSupportedError( + "FOR NO KEY UPDATE is not supported on this " + "database backend." ) for_update_part = self.connection.ops.for_update_sql( nowait=nowait, @@ -629,7 +708,7 @@ class SQLCompiler: result.append(for_update_part) if where: - result.append('WHERE %s' % where) + result.append("WHERE %s" % where) params.extend(w_params) grouping = [] @@ -638,30 +717,39 @@ class SQLCompiler: params.extend(g_params) if grouping: if distinct_fields: - raise NotImplementedError('annotate() + distinct(fields) is not implemented.') + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented." + ) order_by = order_by or self.connection.ops.force_no_ordering() - result.append('GROUP BY %s' % ', '.join(grouping)) + result.append("GROUP BY %s" % ", ".join(grouping)) if self._meta_ordering: order_by = None if having: - result.append('HAVING %s' % having) + result.append("HAVING %s" % having) params.extend(h_params) if self.query.explain_info: - result.insert(0, self.connection.ops.explain_query_prefix( - self.query.explain_info.format, - **self.query.explain_info.options - )) + result.insert( + 0, + self.connection.ops.explain_query_prefix( + self.query.explain_info.format, + **self.query.explain_info.options, + ), + ) if order_by: ordering = [] for _, (o_sql, o_params, _) in order_by: ordering.append(o_sql) params.extend(o_params) - result.append('ORDER BY %s' % ', '.join(ordering)) + result.append("ORDER BY %s" % ", ".join(ordering)) if with_limit_offset: - result.append(self.connection.ops.limit_offset_sql(self.query.low_mark, self.query.high_mark)) + result.append( + self.connection.ops.limit_offset_sql( + self.query.low_mark, self.query.high_mark + ) + ) if for_update_part and not self.connection.features.for_update_after_from: result.append(for_update_part) @@ -677,23 +765,30 @@ class SQLCompiler: sub_params = [] for index, (select, _, alias) in 
enumerate(self.select, start=1): if not alias and with_col_aliases: - alias = 'col%d' % index + alias = "col%d" % index if alias: - sub_selects.append("%s.%s" % ( - self.connection.ops.quote_name('subquery'), - self.connection.ops.quote_name(alias), - )) + sub_selects.append( + "%s.%s" + % ( + self.connection.ops.quote_name("subquery"), + self.connection.ops.quote_name(alias), + ) + ) else: - select_clone = select.relabeled_clone({select.alias: 'subquery'}) - subselect, subparams = select_clone.as_sql(self, self.connection) + select_clone = select.relabeled_clone( + {select.alias: "subquery"} + ) + subselect, subparams = select_clone.as_sql( + self, self.connection + ) sub_selects.append(subselect) sub_params.extend(subparams) - return 'SELECT %s FROM (%s) subquery' % ( - ', '.join(sub_selects), - ' '.join(result), + return "SELECT %s FROM (%s) subquery" % ( + ", ".join(sub_selects), + " ".join(result), ), tuple(sub_params + params) - return ' '.join(result), tuple(params) + return " ".join(result), tuple(params) finally: # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) @@ -726,8 +821,13 @@ class SQLCompiler: # will assign None if the field belongs to this model. if model == opts.model: model = None - if from_parent and model is not None and issubclass( - from_parent._meta.concrete_model, model._meta.concrete_model): + if ( + from_parent + and model is not None + and issubclass( + from_parent._meta.concrete_model, model._meta.concrete_model + ) + ): # Avoid loading data for already loaded parents. # We end up here in the case select_related() resolution # proceeds from parent model to child model. 
In that case the @@ -736,8 +836,7 @@ class SQLCompiler: continue if field.model in only_load and field.attname not in only_load[field.model]: continue - alias = self.query.join_parent_model(opts, model, start_alias, - seen_models) + alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) result.append(column) return result @@ -755,7 +854,9 @@ class SQLCompiler: for name in self.query.distinct_fields: parts = name.split(LOOKUP_SEP) - _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None) + _, targets, alias, joins, path, _, transform_function = self._setup_joins( + parts, opts, None + ) targets, alias, _ = self.query.trim_joins(targets, joins, path) for target in targets: if name in self.query.annotation_select: @@ -766,46 +867,63 @@ class SQLCompiler: params.append(p) return result, params - def find_ordering_name(self, name, opts, alias=None, default_order='ASC', - already_seen=None): + def find_ordering_name( + self, name, opts, alias=None, default_order="ASC", already_seen=None + ): """ Return the table alias (the name might be ambiguous, the alias will not be) and column name for ordering by the given 'name' parameter. The 'name' is of the form 'field1__field2__...__fieldN'. """ name, order = get_order_dir(name, default_order) - descending = order == 'DESC' + descending = order == "DESC" pieces = name.split(LOOKUP_SEP) - field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias) + ( + field, + targets, + alias, + joins, + path, + opts, + transform_function, + ) = self._setup_joins(pieces, opts, alias) # If we get to this point and the field is a relation to another model, # append the default ordering for that model unless it is the pk # shortcut or the attribute name of the field that is specified. 
if ( - field.is_relation and - opts.ordering and - getattr(field, 'attname', None) != pieces[-1] and - name != 'pk' + field.is_relation + and opts.ordering + and getattr(field, "attname", None) != pieces[-1] + and name != "pk" ): # Firstly, avoid infinite loops. already_seen = already_seen or set() - join_tuple = tuple(getattr(self.query.alias_map[j], 'join_cols', None) for j in joins) + join_tuple = tuple( + getattr(self.query.alias_map[j], "join_cols", None) for j in joins + ) if join_tuple in already_seen: - raise FieldError('Infinite loop caused by ordering.') + raise FieldError("Infinite loop caused by ordering.") already_seen.add(join_tuple) results = [] for item in opts.ordering: - if hasattr(item, 'resolve_expression') and not isinstance(item, OrderBy): + if hasattr(item, "resolve_expression") and not isinstance( + item, OrderBy + ): item = item.desc() if descending else item.asc() if isinstance(item, OrderBy): results.append((item, False)) continue - results.extend(self.find_ordering_name(item, opts, alias, - order, already_seen)) + results.extend( + self.find_ordering_name(item, opts, alias, order, already_seen) + ) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) - return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets] + return [ + (OrderBy(transform_function(t, alias), descending=descending), False) + for t in targets + ] def _setup_joins(self, pieces, opts, alias): """ @@ -816,7 +934,9 @@ class SQLCompiler: match. Executing SQL where this is not true is an error. 
""" alias = alias or self.query.get_initial_alias() - field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias) + field, targets, opts, joins, path, transform_function = self.query.setup_joins( + pieces, opts, alias + ) alias = joins[-1] return field, targets, alias, joins, path, opts, transform_function @@ -850,25 +970,39 @@ class SQLCompiler: # Only add the alias if it's not already present (the table_alias() # call increments the refcount, so an alias refcount of one means # this is the only reference). - if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1: - result.append(', %s' % self.quote_name_unless_alias(alias)) + if ( + alias not in self.query.alias_map + or self.query.alias_refcount[alias] == 1 + ): + result.append(", %s" % self.quote_name_unless_alias(alias)) return result, params - def get_related_selections(self, select, opts=None, root_alias=None, cur_depth=1, - requested=None, restricted=None): + def get_related_selections( + self, + select, + opts=None, + root_alias=None, + cur_depth=1, + requested=None, + restricted=None, + ): """ Fill in the information needed for a select_related query. The current depth is measured as the number of connections away from the root model (for example, cur_depth=1 means we are looking at models with direct connections to the root model). 
""" + def _get_field_choices(): direct_choices = (f.name for f in opts.fields if f.is_relation) reverse_choices = ( f.field.related_query_name() - for f in opts.related_objects if f.field.unique + for f in opts.related_objects + if f.field.unique + ) + return chain( + direct_choices, reverse_choices, self.query._filtered_relations ) - return chain(direct_choices, reverse_choices, self.query._filtered_relations) related_klass_infos = [] if not restricted and cur_depth > self.query.max_depth: @@ -889,7 +1023,7 @@ class SQLCompiler: requested = self.query.select_related def get_related_klass_infos(klass_info, related_klass_infos): - klass_info['related_klass_infos'] = related_klass_infos + klass_info["related_klass_infos"] = related_klass_infos for f in opts.fields: field_model = f.model._meta.concrete_model @@ -903,37 +1037,48 @@ class SQLCompiler: if next or f.name in requested: raise FieldError( "Non-relational field given in select_related: '%s'. " - "Choices are: %s" % ( + "Choices are: %s" + % ( f.name, - ", ".join(_get_field_choices()) or '(none)', + ", ".join(_get_field_choices()) or "(none)", ) ) else: next = False - if not select_related_descend(f, restricted, requested, - only_load.get(field_model)): + if not select_related_descend( + f, restricted, requested, only_load.get(field_model) + ): continue klass_info = { - 'model': f.remote_field.model, - 'field': f, - 'reverse': False, - 'local_setter': f.set_cached_value, - 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None, - 'from_parent': False, + "model": f.remote_field.model, + "field": f, + "reverse": False, + "local_setter": f.set_cached_value, + "remote_setter": f.remote_field.set_cached_value + if f.unique + else lambda x, y: None, + "from_parent": False, } related_klass_infos.append(klass_info) select_fields = [] - _, _, _, joins, _, _ = self.query.setup_joins( - [f.name], opts, root_alias) + _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias) 
alias = joins[-1] - columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta) + columns = self.get_default_columns( + start_alias=alias, opts=f.remote_field.model._meta + ) for col in columns: select_fields.append(len(select)) select.append((col, None)) - klass_info['select_fields'] = select_fields + klass_info["select_fields"] = select_fields next_klass_infos = self.get_related_selections( - select, f.remote_field.model._meta, alias, cur_depth + 1, next, restricted) + select, + f.remote_field.model._meta, + alias, + cur_depth + 1, + next, + restricted, + ) get_related_klass_infos(klass_info, next_klass_infos) if restricted: @@ -943,36 +1088,40 @@ class SQLCompiler: if o.field.unique and not o.many_to_many ] for f, model in related_fields: - if not select_related_descend(f, restricted, requested, - only_load.get(model), reverse=True): + if not select_related_descend( + f, restricted, requested, only_load.get(model), reverse=True + ): continue related_field_name = f.related_query_name() fields_found.add(related_field_name) - join_info = self.query.setup_joins([related_field_name], opts, root_alias) + join_info = self.query.setup_joins( + [related_field_name], opts, root_alias + ) alias = join_info.joins[-1] from_parent = issubclass(model, opts.model) and model is not opts.model klass_info = { - 'model': model, - 'field': f, - 'reverse': True, - 'local_setter': f.remote_field.set_cached_value, - 'remote_setter': f.set_cached_value, - 'from_parent': from_parent, + "model": model, + "field": f, + "reverse": True, + "local_setter": f.remote_field.set_cached_value, + "remote_setter": f.set_cached_value, + "from_parent": from_parent, } related_klass_infos.append(klass_info) select_fields = [] columns = self.get_default_columns( - start_alias=alias, opts=model._meta, from_parent=opts.model) + start_alias=alias, opts=model._meta, from_parent=opts.model + ) for col in columns: select_fields.append(len(select)) select.append((col, None)) - 
klass_info['select_fields'] = select_fields + klass_info["select_fields"] = select_fields next = requested.get(f.related_query_name(), {}) next_klass_infos = self.get_related_selections( - select, model._meta, alias, cur_depth + 1, - next, restricted) + select, model._meta, alias, cur_depth + 1, next, restricted + ) get_related_klass_infos(klass_info, next_klass_infos) def local_setter(obj, from_obj): @@ -989,32 +1138,40 @@ class SQLCompiler: break if name in self.query._filtered_relations: fields_found.add(name) - f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias) + f, _, join_opts, joins, _, _ = self.query.setup_joins( + [name], opts, root_alias + ) model = join_opts.model alias = joins[-1] - from_parent = issubclass(model, opts.model) and model is not opts.model + from_parent = ( + issubclass(model, opts.model) and model is not opts.model + ) klass_info = { - 'model': model, - 'field': f, - 'reverse': True, - 'local_setter': local_setter, - 'remote_setter': partial(remote_setter, name), - 'from_parent': from_parent, + "model": model, + "field": f, + "reverse": True, + "local_setter": local_setter, + "remote_setter": partial(remote_setter, name), + "from_parent": from_parent, } related_klass_infos.append(klass_info) select_fields = [] columns = self.get_default_columns( - start_alias=alias, opts=model._meta, + start_alias=alias, + opts=model._meta, from_parent=opts.model, ) for col in columns: select_fields.append(len(select)) select.append((col, None)) - klass_info['select_fields'] = select_fields + klass_info["select_fields"] = select_fields next_requested = requested.get(name, {}) next_klass_infos = self.get_related_selections( - select, opts=model._meta, root_alias=alias, - cur_depth=cur_depth + 1, requested=next_requested, + select, + opts=model._meta, + root_alias=alias, + cur_depth=cur_depth + 1, + requested=next_requested, restricted=restricted, ) get_related_klass_infos(klass_info, next_klass_infos) @@ -1022,10 +1179,11 @@ 
class SQLCompiler: if fields_not_found: invalid_fields = ("'%s'" % s for s in fields_not_found) raise FieldError( - 'Invalid field name(s) given in select_related: %s. ' - 'Choices are: %s' % ( - ', '.join(invalid_fields), - ', '.join(_get_field_choices()) or '(none)', + "Invalid field name(s) given in select_related: %s. " + "Choices are: %s" + % ( + ", ".join(invalid_fields), + ", ".join(_get_field_choices()) or "(none)", ) ) return related_klass_infos @@ -1035,21 +1193,22 @@ class SQLCompiler: Return a quoted list of arguments for the SELECT FOR UPDATE OF part of the query. """ + def _get_parent_klass_info(klass_info): - concrete_model = klass_info['model']._meta.concrete_model + concrete_model = klass_info["model"]._meta.concrete_model for parent_model, parent_link in concrete_model._meta.parents.items(): parent_list = parent_model._meta.get_parent_list() yield { - 'model': parent_model, - 'field': parent_link, - 'reverse': False, - 'select_fields': [ + "model": parent_model, + "field": parent_link, + "reverse": False, + "select_fields": [ select_index - for select_index in klass_info['select_fields'] + for select_index in klass_info["select_fields"] # Selected columns from a model or its parents. if ( - self.select[select_index][0].target.model == parent_model or - self.select[select_index][0].target.model in parent_list + self.select[select_index][0].target.model == parent_model + or self.select[select_index][0].target.model in parent_list ) ], } @@ -1062,8 +1221,8 @@ class SQLCompiler: select_fields is filled recursively, so it also contains fields from the parent models. 
""" - concrete_model = klass_info['model']._meta.concrete_model - for select_index in klass_info['select_fields']: + concrete_model = klass_info["model"]._meta.concrete_model + for select_index in klass_info["select_fields"]: if self.select[select_index][0].target.model == concrete_model: return self.select[select_index][0] @@ -1074,10 +1233,10 @@ class SQLCompiler: parent_path, klass_info = queue.popleft() if parent_path is None: path = [] - yield 'self' + yield "self" else: - field = klass_info['field'] - if klass_info['reverse']: + field = klass_info["field"] + if klass_info["reverse"]: field = field.remote_field path = parent_path + [field.name] yield LOOKUP_SEP.join(path) @@ -1087,25 +1246,26 @@ class SQLCompiler: ) queue.extend( (path, klass_info) - for klass_info in klass_info.get('related_klass_infos', []) + for klass_info in klass_info.get("related_klass_infos", []) ) + if not self.klass_info: return [] result = [] invalid_names = [] for name in self.query.select_for_update_of: klass_info = self.klass_info - if name == 'self': + if name == "self": col = _get_first_selected_col_from_model(klass_info) else: for part in name.split(LOOKUP_SEP): klass_infos = ( - *klass_info.get('related_klass_infos', []), + *klass_info.get("related_klass_infos", []), *_get_parent_klass_info(klass_info), ) for related_klass_info in klass_infos: - field = related_klass_info['field'] - if related_klass_info['reverse']: + field = related_klass_info["field"] + if related_klass_info["reverse"]: field = field.remote_field if field.name == part: klass_info = related_klass_info @@ -1124,11 +1284,12 @@ class SQLCompiler: result.append(self.quote_name_unless_alias(col.alias)) if invalid_names: raise FieldError( - 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' - 'Only relational fields followed in the query are allowed. ' - 'Choices are: %s.' 
% ( - ', '.join(invalid_names), - ', '.join(_get_field_choices()), + "Invalid field name(s) given in select_for_update(of=(...)): %s. " + "Only relational fields followed in the query are allowed. " + "Choices are: %s." + % ( + ", ".join(invalid_names), + ", ".join(_get_field_choices()), ) ) return result @@ -1164,12 +1325,19 @@ class SQLCompiler: row[pos] = value yield row - def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False, - chunk_size=GET_ITERATOR_CHUNK_SIZE): + def results_iter( + self, + results=None, + tuple_expected=False, + chunked_fetch=False, + chunk_size=GET_ITERATOR_CHUNK_SIZE, + ): """Return an iterator over the results from executing this query.""" if results is None: - results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size) - fields = [s[0] for s in self.select[0:self.col_count]] + results = self.execute_sql( + MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size + ) + fields = [s[0] for s in self.select[0 : self.col_count]] converters = self.get_converters(fields) rows = chain.from_iterable(results) if converters: @@ -1185,7 +1353,9 @@ class SQLCompiler: """ return bool(self.execute_sql(SINGLE)) - def execute_sql(self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE): + def execute_sql( + self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE + ): """ Run the query against the database and return the result(s). 
The return value is a single data item if result_type is SINGLE, or an @@ -1226,7 +1396,7 @@ class SQLCompiler: try: val = cursor.fetchone() if val: - return val[0:self.col_count] + return val[0 : self.col_count] return val finally: # done with the cursor @@ -1236,7 +1406,8 @@ class SQLCompiler: return result = cursor_iter( - cursor, self.connection.features.empty_fetchmany_value, + cursor, + self.connection.features.empty_fetchmany_value, self.col_count if self.has_extra_select else None, chunk_size, ) @@ -1254,21 +1425,22 @@ class SQLCompiler: for index, select_col in enumerate(self.query.select): lhs_sql, lhs_params = self.compile(select_col) - rhs = '%s.%s' % (qn(alias), qn2(columns[index])) - self.query.where.add( - RawSQL('%s = %s' % (lhs_sql, rhs), lhs_params), 'AND') + rhs = "%s.%s" % (qn(alias), qn2(columns[index])) + self.query.where.add(RawSQL("%s = %s" % (lhs_sql, rhs), lhs_params), "AND") sql, params = self.as_sql() - return 'EXISTS (%s)' % sql, params + return "EXISTS (%s)" % sql, params def explain_query(self): result = list(self.execute_sql()) # Some backends return 1 item tuples with strings, and others return # tuples with integers and strings. Flatten them out into strings. - output_formatter = json.dumps if self.query.explain_info.format == 'json' else str + output_formatter = ( + json.dumps if self.query.explain_info.format == "json" else str + ) for row in result[0]: if not isinstance(row, str): - yield ' '.join(output_formatter(c) for c in row) + yield " ".join(output_formatter(c) for c in row) else: yield row @@ -1289,16 +1461,16 @@ class SQLInsertCompiler(SQLCompiler): if field is None: # A field value of None means the value is raw. sql, params = val, [] - elif hasattr(val, 'as_sql'): + elif hasattr(val, "as_sql"): # This is an expression, let's compile it. sql, params = self.compile(val) - elif hasattr(field, 'get_placeholder'): + elif hasattr(field, "get_placeholder"): # Some fields (e.g. 
geo fields) need special munging before # they can be inserted. sql, params = field.get_placeholder(val, self, self.connection), [val] else: # Return the common case for the placeholder - sql, params = '%s', [val] + sql, params = "%s", [val] # The following hook is only used by Oracle Spatial, which sometimes # needs to yield 'NULL' and [] as its placeholder and params instead @@ -1314,24 +1486,26 @@ class SQLInsertCompiler(SQLCompiler): Prepare a value to be used in a query by resolving it if it is an expression and otherwise calling the field's get_db_prep_save(). """ - if hasattr(value, 'resolve_expression'): - value = value.resolve_expression(self.query, allow_joins=False, for_save=True) + if hasattr(value, "resolve_expression"): + value = value.resolve_expression( + self.query, allow_joins=False, for_save=True + ) # Don't allow values containing Col expressions. They refer to # existing columns on a row, but in the case of insert the row # doesn't exist yet. if value.contains_column_references: raise ValueError( 'Failed to insert expression "%s" on %s. F() expressions ' - 'can only be used to update, not to insert.' % (value, field) + "can only be used to update, not to insert." % (value, field) ) if value.contains_aggregate: raise FieldError( - 'Aggregate functions are not allowed in this query ' - '(%s=%r).' % (field.name, value) + "Aggregate functions are not allowed in this query " + "(%s=%r)." % (field.name, value) ) if value.contains_over_clause: raise FieldError( - 'Window expressions are not allowed in this query (%s=%r).' + "Window expressions are not allowed in this query (%s=%r)." 
% (field.name, value) ) else: @@ -1390,25 +1564,32 @@ class SQLInsertCompiler(SQLCompiler): insert_statement = self.connection.ops.insert_statement( on_conflict=self.query.on_conflict, ) - result = ['%s %s' % (insert_statement, qn(opts.db_table))] + result = ["%s %s" % (insert_statement, qn(opts.db_table))] fields = self.query.fields or [opts.pk] - result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) + result.append("(%s)" % ", ".join(qn(f.column) for f in fields)) if self.query.fields: value_rows = [ - [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields] + [ + self.prepare_value(field, self.pre_save_val(field, obj)) + for field in fields + ] for obj in self.query.objs ] else: # An empty object. - value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs] + value_rows = [ + [self.connection.ops.pk_default_value()] for _ in self.query.objs + ] fields = [None] # Currently the backends just accept values when generating bulk # queries and generate their own placeholders. Doing that isn't # necessary and it should be possible to use placeholders and # expressions in bulk inserts too. 
- can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert) + can_bulk = ( + not self.returning_fields and self.connection.features.has_bulk_insert + ) placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) @@ -1418,9 +1599,14 @@ class SQLInsertCompiler(SQLCompiler): self.query.update_fields, self.query.unique_fields, ) - if self.returning_fields and self.connection.features.can_return_columns_from_insert: + if ( + self.returning_fields + and self.connection.features.can_return_columns_from_insert + ): if self.connection.features.can_return_rows_from_bulk_insert: - result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) + result.append( + self.connection.ops.bulk_insert_sql(fields, placeholder_rows) + ) params = param_rows else: result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) @@ -1429,7 +1615,9 @@ class SQLInsertCompiler(SQLCompiler): result.append(on_conflict_suffix_sql) # Skip empty r_sql to allow subclasses to customize behavior for # 3rd party backends. Refs #19096. 
- r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields) + r_sql, self.returning_params = self.connection.ops.return_insert_columns( + self.returning_fields + ) if r_sql: result.append(r_sql) params += [self.returning_params] @@ -1450,8 +1638,9 @@ class SQLInsertCompiler(SQLCompiler): def execute_sql(self, returning_fields=None): assert not ( - returning_fields and len(self.query.objs) != 1 and - not self.connection.features.can_return_rows_from_bulk_insert + returning_fields + and len(self.query.objs) != 1 + and not self.connection.features.can_return_rows_from_bulk_insert ) opts = self.query.get_meta() self.returning_fields = returning_fields @@ -1460,17 +1649,29 @@ class SQLInsertCompiler(SQLCompiler): cursor.execute(sql, params) if not self.returning_fields: return [] - if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: + if ( + self.connection.features.can_return_rows_from_bulk_insert + and len(self.query.objs) > 1 + ): rows = self.connection.ops.fetch_returned_insert_rows(cursor) elif self.connection.features.can_return_columns_from_insert: assert len(self.query.objs) == 1 - rows = [self.connection.ops.fetch_returned_insert_columns( - cursor, self.returning_params, - )] + rows = [ + self.connection.ops.fetch_returned_insert_columns( + cursor, + self.returning_params, + ) + ] else: - rows = [(self.connection.ops.last_insert_id( - cursor, opts.db_table, opts.pk.column, - ),)] + rows = [ + ( + self.connection.ops.last_insert_id( + cursor, + opts.db_table, + opts.pk.column, + ), + ) + ] cols = [field.get_col(opts.db_table) for field in self.returning_fields] converters = self.get_converters(cols) if converters: @@ -1489,7 +1690,7 @@ class SQLDeleteCompiler(SQLCompiler): def _expr_refs_base_model(cls, expr, base_model): if isinstance(expr, Query): return expr.model == base_model - if not hasattr(expr, 'get_source_expressions'): + if not hasattr(expr, "get_source_expressions"): return 
False return any( cls._expr_refs_base_model(source_expr, base_model) @@ -1500,17 +1701,17 @@ class SQLDeleteCompiler(SQLCompiler): def contains_self_reference_subquery(self): return any( self._expr_refs_base_model(expr, self.query.model) - for expr in chain(self.query.annotations.values(), self.query.where.children) + for expr in chain( + self.query.annotations.values(), self.query.where.children + ) ) def _as_sql(self, query): - result = [ - 'DELETE FROM %s' % self.quote_name_unless_alias(query.base_table) - ] + result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)] where, params = self.compile(query.where) if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(params) + result.append("WHERE %s" % where) + return " ".join(result), tuple(params) def as_sql(self): """ @@ -1523,16 +1724,14 @@ class SQLDeleteCompiler(SQLCompiler): innerq.__class__ = Query innerq.clear_select_clause() pk = self.query.model._meta.pk - innerq.select = [ - pk.get_col(self.query.get_initial_alias()) - ] + innerq.select = [pk.get_col(self.query.get_initial_alias())] outerq = Query(self.query.model) if not self.connection.features.update_can_self_select: # Force the materialization of the inner query to allow reference # to the target table on MySQL. 
sql, params = innerq.get_compiler(connection=self.connection).as_sql() - innerq = RawSQL('SELECT * FROM (%s) subquery' % sql, params) - outerq.add_filter('pk__in', innerq) + innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params) + outerq.add_filter("pk__in", innerq) return self._as_sql(outerq) @@ -1544,23 +1743,25 @@ class SQLUpdateCompiler(SQLCompiler): """ self.pre_sql_setup() if not self.query.values: - return '', () + return "", () qn = self.quote_name_unless_alias values, update_params = [], [] for field, model, val in self.query.values: - if hasattr(val, 'resolve_expression'): - val = val.resolve_expression(self.query, allow_joins=False, for_save=True) + if hasattr(val, "resolve_expression"): + val = val.resolve_expression( + self.query, allow_joins=False, for_save=True + ) if val.contains_aggregate: raise FieldError( - 'Aggregate functions are not allowed in this query ' - '(%s=%r).' % (field.name, val) + "Aggregate functions are not allowed in this query " + "(%s=%r)." % (field.name, val) ) if val.contains_over_clause: raise FieldError( - 'Window expressions are not allowed in this query ' - '(%s=%r).' % (field.name, val) + "Window expressions are not allowed in this query " + "(%s=%r)." % (field.name, val) ) - elif hasattr(val, 'prepare_database_save'): + elif hasattr(val, "prepare_database_save"): if field.remote_field: val = field.get_db_prep_save( val.prepare_database_save(field), @@ -1576,29 +1777,29 @@ class SQLUpdateCompiler(SQLCompiler): val = field.get_db_prep_save(val, connection=self.connection) # Getting the placeholder for the field. 
- if hasattr(field, 'get_placeholder'): + if hasattr(field, "get_placeholder"): placeholder = field.get_placeholder(val, self, self.connection) else: - placeholder = '%s' + placeholder = "%s" name = field.column - if hasattr(val, 'as_sql'): + if hasattr(val, "as_sql"): sql, params = self.compile(val) - values.append('%s = %s' % (qn(name), placeholder % sql)) + values.append("%s = %s" % (qn(name), placeholder % sql)) update_params.extend(params) elif val is not None: - values.append('%s = %s' % (qn(name), placeholder)) + values.append("%s = %s" % (qn(name), placeholder)) update_params.append(val) else: - values.append('%s = NULL' % qn(name)) + values.append("%s = NULL" % qn(name)) table = self.query.base_table result = [ - 'UPDATE %s SET' % qn(table), - ', '.join(values), + "UPDATE %s SET" % qn(table), + ", ".join(values), ] where, params = self.compile(self.query.where) if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(update_params + params) + result.append("WHERE %s" % where) + return " ".join(result), tuple(update_params + params) def execute_sql(self, result_type): """ @@ -1644,7 +1845,9 @@ class SQLUpdateCompiler(SQLCompiler): query.add_fields([query.get_meta().pk.name]) super().pre_sql_setup() - must_pre_select = count > 1 and not self.connection.features.update_can_self_select + must_pre_select = ( + count > 1 and not self.connection.features.update_can_self_select + ) # Now we adjust the current query: reset the where clause and get rid # of all the tables we don't need (since they're in the sub-select). @@ -1656,11 +1859,11 @@ class SQLUpdateCompiler(SQLCompiler): idents = [] for rows in query.get_compiler(self.using).execute_sql(MULTI): idents.extend(r[0] for r in rows) - self.query.add_filter('pk__in', idents) + self.query.add_filter("pk__in", idents) self.query.related_ids = idents else: # The fast path. Filters and updates in one query. 
- self.query.add_filter('pk__in', query) + self.query.add_filter("pk__in", query) self.query.reset_refcounts(refcounts_before) @@ -1677,13 +1880,14 @@ class SQLAggregateCompiler(SQLCompiler): sql.append(ann_sql) params.extend(ann_params) self.col_count = len(self.query.annotation_select) - sql = ', '.join(sql) + sql = ", ".join(sql) params = tuple(params) inner_query_sql, inner_query_params = self.query.inner_query.get_compiler( - self.using, elide_empty=self.elide_empty, + self.using, + elide_empty=self.elide_empty, ).as_sql(with_col_aliases=True) - sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql) + sql = "SELECT %s FROM (%s) subquery" % (sql, inner_query_sql) params = params + inner_query_params return sql, params diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index a1db61b9ff..fdfb2ea891 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -9,16 +9,16 @@ GET_ITERATOR_CHUNK_SIZE = 100 # Namedtuples for sql.* internal use. # How many results to expect from a cursor.execute call -MULTI = 'multi' -SINGLE = 'single' -CURSOR = 'cursor' -NO_RESULTS = 'no results' +MULTI = "multi" +SINGLE = "single" +CURSOR = "cursor" +NO_RESULTS = "no results" ORDER_DIR = { - 'ASC': ('ASC', 'DESC'), - 'DESC': ('DESC', 'ASC'), + "ASC": ("ASC", "DESC"), + "DESC": ("DESC", "ASC"), } # SQL join types. -INNER = 'INNER JOIN' -LOUTER = 'LEFT OUTER JOIN' +INNER = "INNER JOIN" +LOUTER = "LEFT OUTER JOIN" diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index e08b570350..f398074bf7 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -11,6 +11,7 @@ class MultiJoin(Exception): multi-valued join was attempted (if the caller wants to treat that exceptionally). """ + def __init__(self, names_pos, path_with_names): self.level = names_pos # The path travelled, this includes the path to the multijoin. 
@@ -38,8 +39,17 @@ class Join: - as_sql() - relabeled_clone() """ - def __init__(self, table_name, parent_alias, table_alias, join_type, - join_field, nullable, filtered_relation=None): + + def __init__( + self, + table_name, + parent_alias, + table_alias, + join_type, + join_field, + nullable, + filtered_relation=None, + ): # Join table self.table_name = table_name self.parent_alias = parent_alias @@ -69,35 +79,47 @@ class Join: # Add a join condition for each pair of joining columns. for lhs_col, rhs_col in self.join_cols: - join_conditions.append('%s.%s = %s.%s' % ( - qn(self.parent_alias), - qn2(lhs_col), - qn(self.table_alias), - qn2(rhs_col), - )) + join_conditions.append( + "%s.%s = %s.%s" + % ( + qn(self.parent_alias), + qn2(lhs_col), + qn(self.table_alias), + qn2(rhs_col), + ) + ) # Add a single condition inside parentheses for whatever # get_extra_restriction() returns. - extra_cond = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) + extra_cond = self.join_field.get_extra_restriction( + self.table_alias, self.parent_alias + ) if extra_cond: extra_sql, extra_params = compiler.compile(extra_cond) - join_conditions.append('(%s)' % extra_sql) + join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if self.filtered_relation: extra_sql, extra_params = compiler.compile(self.filtered_relation) if extra_sql: - join_conditions.append('(%s)' % extra_sql) + join_conditions.append("(%s)" % extra_sql) params.extend(extra_params) if not join_conditions: # This might be a rel on the other end of an actual declared field. - declared_field = getattr(self.join_field, 'field', self.join_field) + declared_field = getattr(self.join_field, "field", self.join_field) raise ValueError( "Join generated an empty ON clause. %s did not yield either " "joining columns or extra restrictions." 
% declared_field.__class__ ) - on_clause_sql = ' AND '.join(join_conditions) - alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) - sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql) + on_clause_sql = " AND ".join(join_conditions) + alias_str = ( + "" if self.table_alias == self.table_name else (" %s" % self.table_alias) + ) + sql = "%s %s%s ON (%s)" % ( + self.join_type, + qn(self.table_name), + alias_str, + on_clause_sql, + ) return sql, params def relabeled_clone(self, change_map): @@ -105,12 +127,19 @@ class Join: new_table_alias = change_map.get(self.table_alias, self.table_alias) if self.filtered_relation is not None: filtered_relation = self.filtered_relation.clone() - filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path] + filtered_relation.path = [ + change_map.get(p, p) for p in self.filtered_relation.path + ] else: filtered_relation = None return self.__class__( - self.table_name, new_parent_alias, new_table_alias, self.join_type, - self.join_field, self.nullable, filtered_relation=filtered_relation, + self.table_name, + new_parent_alias, + new_table_alias, + self.join_type, + self.join_field, + self.nullable, + filtered_relation=filtered_relation, ) @property @@ -153,6 +182,7 @@ class BaseTable: SELECT * FROM "foo" WHERE somecond could be generated by this class. 
""" + join_type = None parent_alias = None filtered_relation = None @@ -162,12 +192,16 @@ class BaseTable: self.table_alias = alias def as_sql(self, compiler, connection): - alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) + alias_str = ( + "" if self.table_alias == self.table_name else (" %s" % self.table_alias) + ) base_sql = compiler.quote_name_unless_alias(self.table_name) return base_sql + alias_str, [] def relabeled_clone(self, change_map): - return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) + return self.__class__( + self.table_name, change_map.get(self.table_alias, self.table_alias) + ) @property def identity(self): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 1dc770ae3a..242b2a1f3f 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -20,32 +20,37 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import ( - BaseExpression, Col, Exists, F, OuterRef, Ref, ResolvedOuterRef, + BaseExpression, + Col, + Exists, + F, + OuterRef, + Ref, + ResolvedOuterRef, ) from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup from django.db.models.query_utils import ( - Q, check_rel_lookup_compatibility, refs_expression, + Q, + check_rel_lookup_compatibility, + refs_expression, ) from django.db.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE -from django.db.models.sql.datastructures import ( - BaseTable, Empty, Join, MultiJoin, -) -from django.db.models.sql.where import ( - AND, OR, ExtraWhere, NothingNode, WhereNode, -) +from django.db.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin +from django.db.models.sql.where import AND, OR, ExtraWhere, NothingNode, 
WhereNode from django.utils.functional import cached_property from django.utils.tree import Node -__all__ = ['Query', 'RawQuery'] +__all__ = ["Query", "RawQuery"] def get_field_names_from_opts(opts): - return set(chain.from_iterable( - (f.name, f.attname) if f.concrete else (f.name,) - for f in opts.get_fields() - )) + return set( + chain.from_iterable( + (f.name, f.attname) if f.concrete else (f.name,) for f in opts.get_fields() + ) + ) def get_children_from_q(q): @@ -57,8 +62,8 @@ def get_children_from_q(q): JoinInfo = namedtuple( - 'JoinInfo', - ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function') + "JoinInfo", + ("final_field", "targets", "opts", "joins", "path", "transform_function"), ) @@ -87,8 +92,7 @@ class RawQuery: if self.cursor is None: self._execute_query() converter = connections[self.using].introspection.identifier_converter - return [converter(column_meta[0]) - for column_meta in self.cursor.description] + return [converter(column_meta[0]) for column_meta in self.cursor.description] def __iter__(self): # Always execute a new query for a new iterator. @@ -136,17 +140,17 @@ class RawQuery: self.cursor.execute(self.sql, params) -ExplainInfo = namedtuple('ExplainInfo', ('format', 'options')) +ExplainInfo = namedtuple("ExplainInfo", ("format", "options")) class Query(BaseExpression): """A single SQL query.""" - alias_prefix = 'T' + alias_prefix = "T" empty_result_set_value = None subq_aliases = frozenset([alias_prefix]) - compiler = 'SQLCompiler' + compiler = "SQLCompiler" base_table_class = BaseTable join_class = Join @@ -167,7 +171,7 @@ class Query(BaseExpression): # aliases too. # Map external tables to whether they are aliased. self.external_aliases = {} - self.table_map = {} # Maps table names to list of aliases. + self.table_map = {} # Maps table names to list of aliases. 
self.default_cols = True self.default_ordering = True self.standard_ordering = True @@ -240,13 +244,15 @@ class Query(BaseExpression): def output_field(self): if len(self.select) == 1: select = self.select[0] - return getattr(select, 'target', None) or select.field + return getattr(select, "target", None) or select.field elif len(self.annotation_select) == 1: return next(iter(self.annotation_select.values())).output_field @property def has_select_fields(self): - return bool(self.select or self.annotation_select_mask or self.extra_select_mask) + return bool( + self.select or self.annotation_select_mask or self.extra_select_mask + ) @cached_property def base_table(self): @@ -282,7 +288,9 @@ class Query(BaseExpression): raise ValueError("Need either using or connection") if using: connection = connections[using] - return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty) + return connection.ops.compiler(self.compiler)( + self, connection, using, elide_empty + ) def get_meta(self): """ @@ -311,9 +319,9 @@ class Query(BaseExpression): if self.annotation_select_mask is not None: obj.annotation_select_mask = self.annotation_select_mask.copy() if self.combined_queries: - obj.combined_queries = tuple([ - query.clone() for query in self.combined_queries - ]) + obj.combined_queries = tuple( + [query.clone() for query in self.combined_queries] + ) # _annotation_select_cache cannot be copied, as doing so breaks the # (necessary) state in which both annotations and # _annotation_select_cache point to the same underlying objects. @@ -329,7 +337,7 @@ class Query(BaseExpression): # Use deepcopy because select_related stores fields in nested # dicts. 
obj.select_related = copy.deepcopy(obj.select_related) - if 'subq_aliases' in self.__dict__: + if "subq_aliases" in self.__dict__: obj.subq_aliases = self.subq_aliases.copy() obj.used_aliases = self.used_aliases.copy() obj._filtered_relations = self._filtered_relations.copy() @@ -351,7 +359,7 @@ class Query(BaseExpression): if not obj.filter_is_sticky: obj.used_aliases = set() obj.filter_is_sticky = False - if hasattr(obj, '_setup_query'): + if hasattr(obj, "_setup_query"): obj._setup_query() return obj @@ -401,11 +409,13 @@ class Query(BaseExpression): break else: # An expression that is not selected the subquery. - if isinstance(expr, Col) or (expr.contains_aggregate and not expr.is_summary): + if isinstance(expr, Col) or ( + expr.contains_aggregate and not expr.is_summary + ): # Reference column or another aggregate. Select it # under a non-conflicting alias. col_cnt += 1 - col_alias = '__col%d' % col_cnt + col_alias = "__col%d" % col_cnt self.annotations[col_alias] = expr self.append_annotation_mask([col_alias]) new_expr = Ref(col_alias, expr) @@ -424,8 +434,8 @@ class Query(BaseExpression): if not self.annotation_select: return {} existing_annotations = [ - annotation for alias, annotation - in self.annotations.items() + annotation + for alias, annotation in self.annotations.items() if alias not in added_aggregate_names ] # Decide if we need to use a subquery. @@ -439,9 +449,15 @@ class Query(BaseExpression): # those operations must be done in a subquery so that the query # aggregates on the limit and/or distinct results instead of applying # the distinct and limit after the aggregation. 
- if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or - self.distinct or self.combinator): + if ( + isinstance(self.group_by, tuple) + or self.is_sliced + or existing_annotations + or self.distinct + or self.combinator + ): from django.db.models.sql.subqueries import AggregateQuery + inner_query = self.clone() inner_query.subquery = True outer_query = AggregateQuery(self.model, inner_query) @@ -459,15 +475,18 @@ class Query(BaseExpression): # clearing the select clause can alter results if distinct is # used. has_existing_aggregate_annotations = any( - annotation for annotation in existing_annotations - if getattr(annotation, 'contains_aggregate', True) + annotation + for annotation in existing_annotations + if getattr(annotation, "contains_aggregate", True) ) if inner_query.default_cols and has_existing_aggregate_annotations: - inner_query.group_by = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),) + inner_query.group_by = ( + self.model._meta.pk.get_col(inner_query.get_initial_alias()), + ) inner_query.default_cols = False - relabels = {t: 'subquery' for t in inner_query.alias_map} - relabels[None] = 'subquery' + relabels = {t: "subquery" for t in inner_query.alias_map} + relabels[None] = "subquery" # Remove any aggregates marked for reduction from the subquery # and move them to the outer AggregateQuery. col_cnt = 0 @@ -475,16 +494,24 @@ class Query(BaseExpression): annotation_select_mask = inner_query.annotation_select_mask if expression.is_summary: expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) - outer_query.annotations[alias] = expression.relabeled_clone(relabels) + outer_query.annotations[alias] = expression.relabeled_clone( + relabels + ) del inner_query.annotations[alias] annotation_select_mask.remove(alias) # Make sure the annotation_select wont use cached results. 
inner_query.set_annotation_mask(inner_query.annotation_select_mask) - if inner_query.select == () and not inner_query.default_cols and not inner_query.annotation_select_mask: + if ( + inner_query.select == () + and not inner_query.default_cols + and not inner_query.annotation_select_mask + ): # In case of Model.objects[0:3].count(), there would be no # field selected in the inner query, yet we must use a subquery. # So, make sure at least one field is selected. - inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),) + inner_query.select = ( + self.model._meta.pk.get_col(inner_query.get_initial_alias()), + ) else: outer_query = self self.select = () @@ -515,8 +542,8 @@ class Query(BaseExpression): Perform a COUNT() query using the current filter constraints. """ obj = self.clone() - obj.add_annotation(Count('*'), alias='__count', is_summary=True) - return obj.get_aggregation(using, ['__count'])['__count'] + obj.add_annotation(Count("*"), alias="__count", is_summary=True) + return obj.get_aggregation(using, ["__count"])["__count"] def has_filters(self): return self.where @@ -525,13 +552,17 @@ class Query(BaseExpression): q = self.clone() if not q.distinct: if q.group_by is True: - q.add_fields((f.attname for f in self.model._meta.concrete_fields), False) + q.add_fields( + (f.attname for f in self.model._meta.concrete_fields), False + ) # Disable GROUP BY aliases to avoid orphaning references to the # SELECT clause which is about to be cleared. 
q.set_group_by(allow_aliases=False) q.clear_select_clause() - if q.combined_queries and q.combinator == 'union': - limit_combined = connections[using].features.supports_slicing_ordering_in_compound + if q.combined_queries and q.combinator == "union": + limit_combined = connections[ + using + ].features.supports_slicing_ordering_in_compound q.combined_queries = tuple( combined_query.exists(using, limit=limit_combined) for combined_query in q.combined_queries @@ -539,8 +570,8 @@ class Query(BaseExpression): q.clear_ordering(force=True) if limit: q.set_limits(high=1) - q.add_extra({'a': 1}, None, None, None, None, None) - q.set_extra_mask(['a']) + q.add_extra({"a": 1}, None, None, None, None, None) + q.set_extra_mask(["a"]) return q def has_results(self, using): @@ -552,7 +583,7 @@ class Query(BaseExpression): q = self.clone() q.explain_info = ExplainInfo(format, options) compiler = q.get_compiler(using=using) - return '\n'.join(compiler.explain_query()) + return "\n".join(compiler.explain_query()) def combine(self, rhs, connector): """ @@ -564,13 +595,13 @@ class Query(BaseExpression): 'rhs' query. 
""" if self.model != rhs.model: - raise TypeError('Cannot combine queries on two different base models.') + raise TypeError("Cannot combine queries on two different base models.") if self.is_sliced: - raise TypeError('Cannot combine queries once a slice has been taken.') + raise TypeError("Cannot combine queries once a slice has been taken.") if self.distinct != rhs.distinct: - raise TypeError('Cannot combine a unique query with a non-unique query.') + raise TypeError("Cannot combine a unique query with a non-unique query.") if self.distinct_fields != rhs.distinct_fields: - raise TypeError('Cannot combine queries with different distinct fields.') + raise TypeError("Cannot combine queries with different distinct fields.") # If lhs and rhs shares the same alias prefix, it is possible to have # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up @@ -583,7 +614,7 @@ class Query(BaseExpression): # Work out how to relabel the rhs aliases, if necessary. change_map = {} - conjunction = (connector == AND) + conjunction = connector == AND # Determine which existing joins can be reused. When combining the # query with AND we must recreate all joins for m2m filters. When @@ -600,7 +631,8 @@ class Query(BaseExpression): reuse = set() if conjunction else set(self.alias_map) joinpromoter = JoinPromoter(connector, 2, False) joinpromoter.add_votes( - j for j in self.alias_map if self.alias_map[j].join_type == INNER) + j for j in self.alias_map if self.alias_map[j].join_type == INNER + ) rhs_votes = set() # Now, add the joins from rhs query into the new query (skipping base # table). @@ -649,7 +681,9 @@ class Query(BaseExpression): # really make sense (or return consistent value sets). Not worth # the extra complexity when you can write a real query instead. if self.extra and rhs.extra: - raise ValueError("When merging querysets using 'or', you cannot have extra(select=...) 
on both sides.") + raise ValueError( + "When merging querysets using 'or', you cannot have extra(select=...) on both sides." + ) self.extra.update(rhs.extra) extra_select_mask = set() if self.extra_select_mask is not None: @@ -767,11 +801,13 @@ class Query(BaseExpression): # Create a new alias for this table. if alias_list: - alias = '%s%d' % (self.alias_prefix, len(self.alias_map) + 1) + alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1) alias_list.append(alias) else: # The first occurrence of a table uses the table name directly. - alias = filtered_relation.alias if filtered_relation is not None else table_name + alias = ( + filtered_relation.alias if filtered_relation is not None else table_name + ) self.table_map[table_name] = [alias] self.alias_refcount[alias] = 1 return alias, True @@ -806,16 +842,19 @@ class Query(BaseExpression): # Only the first alias (skipped above) should have None join_type assert self.alias_map[alias].join_type is not None parent_alias = self.alias_map[alias].parent_alias - parent_louter = parent_alias and self.alias_map[parent_alias].join_type == LOUTER + parent_louter = ( + parent_alias and self.alias_map[parent_alias].join_type == LOUTER + ) already_louter = self.alias_map[alias].join_type == LOUTER - if ((self.alias_map[alias].nullable or parent_louter) and - not already_louter): + if (self.alias_map[alias].nullable or parent_louter) and not already_louter: self.alias_map[alias] = self.alias_map[alias].promote() # Join type of 'alias' changed, so re-examine all aliases that # refer to this one. aliases.extend( - join for join in self.alias_map - if self.alias_map[join].parent_alias == alias and join not in aliases + join + for join in self.alias_map + if self.alias_map[join].parent_alias == alias + and join not in aliases ) def demote_joins(self, aliases): @@ -861,10 +900,13 @@ class Query(BaseExpression): # "group by" and "where". 
self.where.relabel_aliases(change_map) if isinstance(self.group_by, tuple): - self.group_by = tuple([col.relabeled_clone(change_map) for col in self.group_by]) + self.group_by = tuple( + [col.relabeled_clone(change_map) for col in self.group_by] + ) self.select = tuple([col.relabeled_clone(change_map) for col in self.select]) self.annotations = self.annotations and { - key: col.relabeled_clone(change_map) for key, col in self.annotations.items() + key: col.relabeled_clone(change_map) + for key, col in self.annotations.items() } # 2. Rename the alias in the internal table/alias datastructures. @@ -895,6 +937,7 @@ class Query(BaseExpression): conflict. Even tables that previously had no alias will get an alias after this call. To prevent changing aliases use the exclude parameter. """ + def prefix_gen(): """ Generate a sequence of characters in alphabetical order: @@ -908,9 +951,9 @@ class Query(BaseExpression): prefix = chr(ord(self.alias_prefix) + 1) yield prefix for n in count(1): - seq = alphabet[alphabet.index(prefix):] if prefix else alphabet + seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet for s in product(seq, repeat=n): - yield ''.join(s) + yield "".join(s) prefix = None if self.alias_prefix != other_query.alias_prefix: @@ -928,17 +971,19 @@ class Query(BaseExpression): break if pos > local_recursion_limit: raise RecursionError( - 'Maximum recursion depth exceeded: too many subqueries.' + "Maximum recursion depth exceeded: too many subqueries." 
) self.subq_aliases = self.subq_aliases.union([self.alias_prefix]) other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases) if exclude is None: exclude = {} - self.change_aliases({ - alias: '%s%d' % (self.alias_prefix, pos) - for pos, alias in enumerate(self.alias_map) - if alias not in exclude - }) + self.change_aliases( + { + alias: "%s%d" % (self.alias_prefix, pos) + for pos, alias in enumerate(self.alias_map) + if alias not in exclude + } + ) def get_initial_alias(self): """ @@ -974,7 +1019,8 @@ class Query(BaseExpression): joins are created as LOUTER if the join is nullable. """ reuse_aliases = [ - a for a, j in self.alias_map.items() + a + for a, j in self.alias_map.items() if (reuse is None or a in reuse) and j.equals(join) ] if reuse_aliases: @@ -988,7 +1034,9 @@ class Query(BaseExpression): return reuse_alias # No reuse is possible, so we need a new alias. - alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation) + alias, _ = self.table_alias( + join.table_name, create=True, filtered_relation=join.filtered_relation + ) if join.join_type: if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: join_type = LOUTER @@ -1034,8 +1082,9 @@ class Query(BaseExpression): def add_annotation(self, annotation, alias, is_summary=False, select=True): """Add a single annotation expression to the Query.""" - annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None, - summarize=is_summary) + annotation = annotation.resolve_expression( + self, allow_joins=True, reuse=None, summarize=is_summary + ) if select: self.append_annotation_mask([alias]) else: @@ -1050,27 +1099,32 @@ class Query(BaseExpression): clone.where.resolve_expression(query, *args, **kwargs) # Resolve combined queries. 
if clone.combinator: - clone.combined_queries = tuple([ - combined_query.resolve_expression(query, *args, **kwargs) - for combined_query in clone.combined_queries - ]) + clone.combined_queries = tuple( + [ + combined_query.resolve_expression(query, *args, **kwargs) + for combined_query in clone.combined_queries + ] + ) for key, value in clone.annotations.items(): resolved = value.resolve_expression(query, *args, **kwargs) - if hasattr(resolved, 'external_aliases'): + if hasattr(resolved, "external_aliases"): resolved.external_aliases.update(clone.external_aliases) clone.annotations[key] = resolved # Outer query's aliases are considered external. for alias, table in query.alias_map.items(): clone.external_aliases[alias] = ( - (isinstance(table, Join) and table.join_field.related_model._meta.db_table != alias) or - (isinstance(table, BaseTable) and table.table_name != table.table_alias) + isinstance(table, Join) + and table.join_field.related_model._meta.db_table != alias + ) or ( + isinstance(table, BaseTable) and table.table_name != table.table_alias ) return clone def get_external_cols(self): exprs = chain(self.annotations.values(), self.where.children) return [ - col for col in self._gen_cols(exprs, include_external=True) + col + for col in self._gen_cols(exprs, include_external=True) if col.alias in self.external_aliases ] @@ -1086,19 +1140,21 @@ class Query(BaseExpression): # Some backends (e.g. Oracle) raise an error when a subquery contains # unnecessary ORDER BY clause. 
if ( - self.subquery and - not connection.features.ignores_unnecessary_order_by_in_subqueries + self.subquery + and not connection.features.ignores_unnecessary_order_by_in_subqueries ): self.clear_ordering(force=False) sql, params = self.get_compiler(connection=connection).as_sql() if self.subquery: - sql = '(%s)' % sql + sql = "(%s)" % sql return sql, params def resolve_lookup_value(self, value, can_reuse, allow_joins): - if hasattr(value, 'resolve_expression'): + if hasattr(value, "resolve_expression"): value = value.resolve_expression( - self, reuse=can_reuse, allow_joins=allow_joins, + self, + reuse=can_reuse, + allow_joins=allow_joins, ) elif isinstance(value, (list, tuple)): # The items of the iterable may be expressions and therefore need @@ -1108,7 +1164,7 @@ class Query(BaseExpression): for sub_value in value ) type_ = type(value) - if hasattr(type_, '_make'): # namedtuple + if hasattr(type_, "_make"): # namedtuple return type_(*values) return type_(values) return value @@ -1119,15 +1175,17 @@ class Query(BaseExpression): """ lookup_splitted = lookup.split(LOOKUP_SEP) if self.annotations: - expression, expression_lookups = refs_expression(lookup_splitted, self.annotations) + expression, expression_lookups = refs_expression( + lookup_splitted, self.annotations + ) if expression: return expression_lookups, (), expression _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) - field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] + field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)] if len(lookup_parts) > 1 and not field_parts: raise FieldError( - 'Invalid lookup "%s" for model %s".' % - (lookup, self.get_meta().model.__name__) + 'Invalid lookup "%s" for model %s".' + % (lookup, self.get_meta().model.__name__) ) return lookup_parts, field_parts, False @@ -1136,11 +1194,12 @@ class Query(BaseExpression): Check whether the object passed while querying is of the correct type. 
If not, raise a ValueError specifying the wrong object. """ - if hasattr(value, '_meta'): + if hasattr(value, "_meta"): if not check_rel_lookup_compatibility(value._meta.model, opts, field): raise ValueError( - 'Cannot query "%s": Must be "%s" instance.' % - (value, opts.object_name)) + 'Cannot query "%s": Must be "%s" instance.' + % (value, opts.object_name) + ) def check_related_objects(self, field, value, opts): """Check the type of object passed to query relations.""" @@ -1150,29 +1209,31 @@ class Query(BaseExpression): # opts would be Author's (from the author field) and value.model # would be Author.objects.all() queryset's .model (Author also). # The field is the related field on the lhs side. - if (isinstance(value, Query) and not value.has_select_fields and - not check_rel_lookup_compatibility(value.model, opts, field)): + if ( + isinstance(value, Query) + and not value.has_select_fields + and not check_rel_lookup_compatibility(value.model, opts, field) + ): raise ValueError( - 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' % - (value.model._meta.object_name, opts.object_name) + 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' + % (value.model._meta.object_name, opts.object_name) ) - elif hasattr(value, '_meta'): + elif hasattr(value, "_meta"): self.check_query_object_type(value, opts, field) - elif hasattr(value, '__iter__'): + elif hasattr(value, "__iter__"): for v in value: self.check_query_object_type(v, opts, field) def check_filterable(self, expression): """Raise an error if expression cannot be used in a WHERE clause.""" - if ( - hasattr(expression, 'resolve_expression') and - not getattr(expression, 'filterable', True) + if hasattr(expression, "resolve_expression") and not getattr( + expression, "filterable", True ): raise NotSupportedError( - expression.__class__.__name__ + ' is disallowed in the filter ' - 'clause.' + expression.__class__.__name__ + " is disallowed in the filter " + "clause." 
) - if hasattr(expression, 'get_source_expressions'): + if hasattr(expression, "get_source_expressions"): for expr in expression.get_source_expressions(): self.check_filterable(expr) @@ -1186,7 +1247,7 @@ class Query(BaseExpression): and get_transform(). """ # __exact is the default lookup if one isn't given. - *transforms, lookup_name = lookups or ['exact'] + *transforms, lookup_name = lookups or ["exact"] for name in transforms: lhs = self.try_transform(lhs, name) # First try get_lookup() so that the lookup takes precedence if the lhs @@ -1194,11 +1255,13 @@ class Query(BaseExpression): lookup_class = lhs.get_lookup(lookup_name) if not lookup_class: if lhs.field.is_relation: - raise FieldError('Related Field got invalid lookup: {}'.format(lookup_name)) + raise FieldError( + "Related Field got invalid lookup: {}".format(lookup_name) + ) # A lookup wasn't found. Try to interpret the name as a transform # and do an Exact lookup against it. lhs = self.try_transform(lhs, lookup_name) - lookup_name = 'exact' + lookup_name = "exact" lookup_class = lhs.get_lookup(lookup_name) if not lookup_class: return @@ -1207,20 +1270,20 @@ class Query(BaseExpression): # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value unless the lookup supports it. if lookup.rhs is None and not lookup.can_use_none_as_rhs: - if lookup_name not in ('exact', 'iexact'): + if lookup_name not in ("exact", "iexact"): raise ValueError("Cannot use None as a query value") - return lhs.get_lookup('isnull')(lhs, True) + return lhs.get_lookup("isnull")(lhs, True) # For Oracle '' is equivalent to null. The check must be done at this # stage because join promotion can't be done in the compiler. Using # DEFAULT_DB_ALIAS isn't nice but it's the best that can be done here. # A similar thing is done in is_nullable(), too. 
if ( - lookup_name == 'exact' and - lookup.rhs == '' and - connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls + lookup_name == "exact" + and lookup.rhs == "" + and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls ): - return lhs.get_lookup('isnull')(lhs, True) + return lhs.get_lookup("isnull")(lhs, True) return lookup @@ -1234,19 +1297,28 @@ class Query(BaseExpression): return transform_class(lhs) else: output_field = lhs.output_field.__class__ - suggested_lookups = difflib.get_close_matches(name, output_field.get_lookups()) + suggested_lookups = difflib.get_close_matches( + name, output_field.get_lookups() + ) if suggested_lookups: - suggestion = ', perhaps you meant %s?' % ' or '.join(suggested_lookups) + suggestion = ", perhaps you meant %s?" % " or ".join(suggested_lookups) else: - suggestion = '.' + suggestion = "." raise FieldError( "Unsupported lookup '%s' for %s or join on the field not " "permitted%s" % (name, output_field.__name__, suggestion) ) - def build_filter(self, filter_expr, branch_negated=False, current_negated=False, - can_reuse=None, allow_joins=True, split_subq=True, - check_filterable=True): + def build_filter( + self, + filter_expr, + branch_negated=False, + current_negated=False, + can_reuse=None, + allow_joins=True, + split_subq=True, + check_filterable=True, + ): """ Build a WhereNode for a single filter clause but don't add it to this Query. 
Query.add_q() will then add this filter to the where @@ -1284,12 +1356,12 @@ class Query(BaseExpression): split_subq=split_subq, check_filterable=check_filterable, ) - if hasattr(filter_expr, 'resolve_expression'): - if not getattr(filter_expr, 'conditional', False): - raise TypeError('Cannot filter against a non-conditional expression.') + if hasattr(filter_expr, "resolve_expression"): + if not getattr(filter_expr, "conditional", False): + raise TypeError("Cannot filter against a non-conditional expression.") condition = filter_expr.resolve_expression(self, allow_joins=allow_joins) if not isinstance(condition, Lookup): - condition = self.build_lookup(['exact'], condition, True) + condition = self.build_lookup(["exact"], condition, True) return WhereNode([condition], connector=AND), [] arg, value = filter_expr if not arg: @@ -1304,7 +1376,9 @@ class Query(BaseExpression): pre_joins = self.alias_refcount.copy() value = self.resolve_lookup_value(value, can_reuse, allow_joins) - used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} + used_joins = { + k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0) + } if check_filterable: self.check_filterable(value) @@ -1319,7 +1393,11 @@ class Query(BaseExpression): try: join_info = self.setup_joins( - parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, + parts, + opts, + alias, + can_reuse=can_reuse, + allow_many=allow_many, ) # Prevent iterator from being consumed by check_related_objects() @@ -1336,7 +1414,9 @@ class Query(BaseExpression): # Update used_joins before trimming since they are reused to determine # which joins could be later promoted to INNER. 
used_joins.update(join_info.joins) - targets, alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) + targets, alias, join_list = self.trim_joins( + join_info.targets, join_info.joins, join_info.path + ) if can_reuse is not None: can_reuse.update(join_list) @@ -1344,11 +1424,15 @@ class Query(BaseExpression): # No support for transforms for relational fields num_lookups = len(lookups) if num_lookups > 1: - raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) + raise FieldError( + "Related Field got invalid lookup: {}".format(lookups[0]) + ) if len(targets) == 1: col = self._get_col(targets[0], join_info.final_field, alias) else: - col = MultiColSource(alias, targets, join_info.targets, join_info.final_field) + col = MultiColSource( + alias, targets, join_info.targets, join_info.final_field + ) else: col = self._get_col(targets[0], join_info.final_field, alias) @@ -1356,10 +1440,16 @@ class Query(BaseExpression): lookup_type = condition.lookup_name clause = WhereNode([condition], connector=AND) - require_outer = lookup_type == 'isnull' and condition.rhs is True and not current_negated - if current_negated and (lookup_type != 'isnull' or condition.rhs is False) and condition.rhs is not None: + require_outer = ( + lookup_type == "isnull" and condition.rhs is True and not current_negated + ) + if ( + current_negated + and (lookup_type != "isnull" or condition.rhs is False) + and condition.rhs is not None + ): require_outer = True - if lookup_type != 'isnull': + if lookup_type != "isnull": # The condition added here will be SQL like this: # NOT (col IS NOT NULL), where the first NOT is added in # upper layers of code. The reason for addition is that if col @@ -1370,16 +1460,16 @@ class Query(BaseExpression): # <=> # NOT (col IS NOT NULL AND col = someval). 
if ( - self.is_nullable(targets[0]) or - self.alias_map[join_list[-1]].join_type == LOUTER + self.is_nullable(targets[0]) + or self.alias_map[join_list[-1]].join_type == LOUTER ): - lookup_class = targets[0].get_lookup('isnull') + lookup_class = targets[0].get_lookup("isnull") col = self._get_col(targets[0], join_info.targets[0], alias) clause.add(lookup_class(col, False), AND) # If someval is a nullable column, someval IS NOT NULL is # added. if isinstance(value, Col) and self.is_nullable(value.target): - lookup_class = value.target.get_lookup('isnull') + lookup_class = value.target.get_lookup("isnull") clause.add(lookup_class(value, False), AND) return clause, used_joins if not require_outer else () @@ -1397,7 +1487,9 @@ class Query(BaseExpression): # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if # rel_a doesn't produce any rows, then the whole condition must fail. # So, demotion is OK. - existing_inner = {a for a in self.alias_map if self.alias_map[a].join_type == INNER} + existing_inner = { + a for a in self.alias_map if self.alias_map[a].join_type == INNER + } clause, _ = self._add_q(q_object, self.used_aliases) if clause: self.where.add(clause, AND) @@ -1409,20 +1501,33 @@ class Query(BaseExpression): def clear_where(self): self.where = WhereNode() - def _add_q(self, q_object, used_aliases, branch_negated=False, - current_negated=False, allow_joins=True, split_subq=True, - check_filterable=True): + def _add_q( + self, + q_object, + used_aliases, + branch_negated=False, + current_negated=False, + allow_joins=True, + split_subq=True, + check_filterable=True, + ): """Add a Q-object to the current filter.""" connector = q_object.connector current_negated = current_negated ^ q_object.negated branch_negated = branch_negated or q_object.negated target_clause = WhereNode(connector=connector, negated=q_object.negated) - joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated) + joinpromoter = JoinPromoter( + 
q_object.connector, len(q_object.children), current_negated + ) for child in q_object.children: child_clause, needed_inner = self.build_filter( - child, can_reuse=used_aliases, branch_negated=branch_negated, - current_negated=current_negated, allow_joins=allow_joins, - split_subq=split_subq, check_filterable=check_filterable, + child, + can_reuse=used_aliases, + branch_negated=branch_negated, + current_negated=current_negated, + allow_joins=allow_joins, + split_subq=split_subq, + check_filterable=check_filterable, ) joinpromoter.add_votes(needed_inner) if child_clause: @@ -1430,7 +1535,9 @@ class Query(BaseExpression): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False): + def build_filtered_relation_q( + self, q_object, reuse, branch_negated=False, current_negated=False + ): """Add a FilteredRelation object to the current filter.""" connector = q_object.connector current_negated ^= q_object.negated @@ -1439,14 +1546,19 @@ class Query(BaseExpression): for child in q_object.children: if isinstance(child, Node): child_clause = self.build_filtered_relation_q( - child, reuse=reuse, branch_negated=branch_negated, + child, + reuse=reuse, + branch_negated=branch_negated, current_negated=current_negated, ) else: child_clause, _ = self.build_filter( - child, can_reuse=reuse, branch_negated=branch_negated, + child, + can_reuse=reuse, + branch_negated=branch_negated, current_negated=current_negated, - allow_joins=True, split_subq=False, + allow_joins=True, + split_subq=False, ) target_clause.add(child_clause, connector) return target_clause @@ -1454,7 +1566,9 @@ class Query(BaseExpression): def add_filtered_relation(self, filtered_relation, alias): filtered_relation.alias = alias lookups = dict(get_children_from_q(filtered_relation.condition)) - relation_lookup_parts, relation_field_parts, _ = 
self.solve_lookup_type(filtered_relation.relation_name) + relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type( + filtered_relation.relation_name + ) if relation_lookup_parts: raise ValueError( "FilteredRelation's relation_name cannot contain lookups " @@ -1498,7 +1612,7 @@ class Query(BaseExpression): path, names_with_path = [], [] for pos, name in enumerate(names): cur_names_with_path = (name, []) - if name == 'pk': + if name == "pk": name = opts.pk.name field = None @@ -1513,7 +1627,10 @@ class Query(BaseExpression): if LOOKUP_SEP in filtered_relation.relation_name: parts = filtered_relation.relation_name.split(LOOKUP_SEP) filtered_relation_path, field, _, _ = self.names_to_path( - parts, opts, allow_many, fail_on_missing, + parts, + opts, + allow_many, + fail_on_missing, ) path.extend(filtered_relation_path[:-1]) else: @@ -1540,13 +1657,17 @@ class Query(BaseExpression): # one step. pos -= 1 if pos == -1 or fail_on_missing: - available = sorted([ - *get_field_names_from_opts(opts), - *self.annotation_select, - *self._filtered_relations, - ]) - raise FieldError("Cannot resolve keyword '%s' into field. " - "Choices are: %s" % (name, ", ".join(available))) + available = sorted( + [ + *get_field_names_from_opts(opts), + *self.annotation_select, + *self._filtered_relations, + ] + ) + raise FieldError( + "Cannot resolve keyword '%s' into field. 
" + "Choices are: %s" % (name, ", ".join(available)) + ) break # Check if we need any joins for concrete inheritance cases (the # field lives in parent, but we are currently in one of its @@ -1557,7 +1678,7 @@ class Query(BaseExpression): path.extend(path_to_parent) cur_names_with_path[1].extend(path_to_parent) opts = path_to_parent[-1].to_opts - if hasattr(field, 'path_infos'): + if hasattr(field, "path_infos"): if filtered_relation: pathinfos = field.get_path_info(filtered_relation) else: @@ -1565,7 +1686,7 @@ class Query(BaseExpression): if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: - cur_names_with_path[1].extend(pathinfos[0:inner_pos + 1]) + cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1]) names_with_path.append(cur_names_with_path) raise MultiJoin(pos + 1, names_with_path) last = pathinfos[-1] @@ -1582,9 +1703,10 @@ class Query(BaseExpression): if fail_on_missing and pos + 1 != len(names): raise FieldError( "Cannot resolve keyword %r into field. Join on '%s'" - " not permitted." % (names[pos + 1], name)) + " not permitted." 
% (names[pos + 1], name) + ) break - return path, final_field, targets, names[pos + 1:] + return path, final_field, targets, names[pos + 1 :] def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ @@ -1631,7 +1753,10 @@ class Query(BaseExpression): for pivot in range(len(names), 0, -1): try: path, final_field, targets, rest = self.names_to_path( - names[:pivot], opts, allow_many, fail_on_missing=True, + names[:pivot], + opts, + allow_many, + fail_on_missing=True, ) except FieldError as exc: if pivot == 1: @@ -1646,6 +1771,7 @@ class Query(BaseExpression): transforms = names[pivot:] break for name in transforms: + def transform(field, alias, *, name, previous): try: wrapped = previous(field, alias) @@ -1656,7 +1782,10 @@ class Query(BaseExpression): raise last_field_exception else: raise - final_transformer = functools.partial(transform, name=name, previous=final_transformer) + + final_transformer = functools.partial( + transform, name=name, previous=final_transformer + ) # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. 
@@ -1673,8 +1802,13 @@ class Query(BaseExpression): else: nullable = True connection = self.join_class( - opts.db_table, alias, table_alias, INNER, join.join_field, - nullable, filtered_relation=filtered_relation, + opts.db_table, + alias, + table_alias, + INNER, + join.join_field, + nullable, + filtered_relation=filtered_relation, ) reuse = can_reuse if join.m2m else None alias = self.join(connection, reuse=reuse) @@ -1706,7 +1840,11 @@ class Query(BaseExpression): cur_targets = {t.column for t in targets} if not cur_targets.issubset(join_targets): break - targets_dict = {r[1].column: r[0] for r in info.join_field.related_fields if r[1].column in cur_targets} + targets_dict = { + r[1].column: r[0] + for r in info.join_field.related_fields + if r[1].column in cur_targets + } targets = tuple(targets_dict[t.column] for t in targets) self.unref_alias(joins.pop()) return targets, joins[-1], joins @@ -1716,9 +1854,11 @@ class Query(BaseExpression): for expr in exprs: if isinstance(expr, Col): yield expr - elif include_external and callable(getattr(expr, 'get_external_cols', None)): + elif include_external and callable( + getattr(expr, "get_external_cols", None) + ): yield from expr.get_external_cols() - elif hasattr(expr, 'get_source_expressions'): + elif hasattr(expr, "get_source_expressions"): yield from cls._gen_cols( expr.get_source_expressions(), include_external=include_external, @@ -1735,7 +1875,7 @@ class Query(BaseExpression): for alias in self._gen_col_aliases([annotation]): if isinstance(self.alias_map[alias], Join): raise FieldError( - 'Joined field references are not permitted in this query' + "Joined field references are not permitted in this query" ) if summarize: # Summarize currently means we are doing an aggregate() query @@ -1757,10 +1897,16 @@ class Query(BaseExpression): for transform in field_list[1:]: annotation = self.try_transform(annotation, transform) return annotation - join_info = self.setup_joins(field_list, self.get_meta(), 
self.get_initial_alias(), can_reuse=reuse) - targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) + join_info = self.setup_joins( + field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse + ) + targets, final_alias, join_list = self.trim_joins( + join_info.targets, join_info.joins, join_info.path + ) if not allow_joins and len(join_list) > 1: - raise FieldError('Joined field references are not permitted in this query') + raise FieldError( + "Joined field references are not permitted in this query" + ) if len(targets) > 1: raise FieldError( "Referencing multicolumn fields with F() objects isn't supported" @@ -1813,23 +1959,25 @@ class Query(BaseExpression): # Need to add a restriction so that outer query's filters are in effect for # the subquery, too. query.bump_prefix(self) - lookup_class = select_field.get_lookup('exact') + lookup_class = select_field.get_lookup("exact") # Note that the query.select[0].alias is different from alias # due to bump_prefix above. - lookup = lookup_class(pk.get_col(query.select[0].alias), - pk.get_col(alias)) + lookup = lookup_class(pk.get_col(query.select[0].alias), pk.get_col(alias)) query.where.add(lookup, AND) query.external_aliases[alias] = True - lookup_class = select_field.get_lookup('exact') + lookup_class = select_field.get_lookup("exact") lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix)) query.where.add(lookup, AND) condition, needed_inner = self.build_filter(Exists(query)) if contains_louter: or_null_condition, _ = self.build_filter( - ('%s__isnull' % trimmed_prefix, True), - current_negated=True, branch_negated=True, can_reuse=can_reuse) + ("%s__isnull" % trimmed_prefix, True), + current_negated=True, + branch_negated=True, + can_reuse=can_reuse, + ) condition.add(or_null_condition, OR) # Note that the end result will be: # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL. 
@@ -1907,8 +2055,8 @@ class Query(BaseExpression): self.values_select = () def add_select_col(self, col, name): - self.select += col, - self.values_select += name, + self.select += (col,) + self.values_select += (name,) def set_select(self, cols): self.default_cols = False @@ -1934,7 +2082,9 @@ class Query(BaseExpression): for name in field_names: # Join promotion note - we must not remove any rows here, so # if there is no existing joins, use outer join. - join_info = self.setup_joins(name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) + join_info = self.setup_joins( + name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m + ) targets, final_alias, joins = self.trim_joins( join_info.targets, join_info.joins, @@ -1957,12 +2107,18 @@ class Query(BaseExpression): "it." % name ) else: - names = sorted([ - *get_field_names_from_opts(opts), *self.extra, - *self.annotation_select, *self._filtered_relations - ]) - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(names))) + names = sorted( + [ + *get_field_names_from_opts(opts), + *self.extra, + *self.annotation_select, + *self._filtered_relations, + ] + ) + raise FieldError( + "Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(names)) + ) def add_ordering(self, *ordering): """ @@ -1976,9 +2132,9 @@ class Query(BaseExpression): errors = [] for item in ordering: if isinstance(item, str): - if item == '?': + if item == "?": continue - if item.startswith('-'): + if item.startswith("-"): item = item[1:] if item in self.annotations: continue @@ -1987,15 +2143,15 @@ class Query(BaseExpression): # names_to_path() validates the lookup. A descriptive # FieldError will be raise if it's not. 
self.names_to_path(item.split(LOOKUP_SEP), self.model._meta) - elif not hasattr(item, 'resolve_expression'): + elif not hasattr(item, "resolve_expression"): errors.append(item) - if getattr(item, 'contains_aggregate', False): + if getattr(item, "contains_aggregate", False): raise FieldError( - 'Using an aggregate in order_by() without also including ' - 'it in annotate() is not allowed: %s' % item + "Using an aggregate in order_by() without also including " + "it in annotate() is not allowed: %s" % item ) if errors: - raise FieldError('Invalid order_by arguments: %s' % errors) + raise FieldError("Invalid order_by arguments: %s" % errors) if ordering: self.order_by += ordering else: @@ -2008,7 +2164,9 @@ class Query(BaseExpression): If 'clear_default' is True, there will be no ordering in the resulting query (not even the model's default). """ - if not force and (self.is_sliced or self.distinct_fields or self.select_for_update): + if not force and ( + self.is_sliced or self.distinct_fields or self.select_for_update + ): return self.order_by = () self.extra_order_by = () @@ -2031,10 +2189,9 @@ class Query(BaseExpression): for join in list(self.alias_map.values())[1:]: # Skip base table. 
model = join.join_field.related_model if model not in seen_models: - column_names.update({ - field.column - for field in model._meta.local_concrete_fields - }) + column_names.update( + {field.column for field in model._meta.local_concrete_fields} + ) seen_models.add(model) group_by = list(self.select) @@ -2082,7 +2239,7 @@ class Query(BaseExpression): entry_params = [] pos = entry.find("%s") while pos != -1: - if pos == 0 or entry[pos - 1] != '%': + if pos == 0 or entry[pos - 1] != "%": entry_params.append(next(param_iter)) pos = entry.find("%s", pos + 2) select_pairs[name] = (entry, entry_params) @@ -2135,8 +2292,8 @@ class Query(BaseExpression): """ existing, defer = self.deferred_loading field_names = set(field_names) - if 'pk' in field_names: - field_names.remove('pk') + if "pk" in field_names: + field_names.remove("pk") field_names.add(self.get_meta().pk.name) if defer: @@ -2224,7 +2381,9 @@ class Query(BaseExpression): # Selected annotations must be known before setting the GROUP BY # clause. if self.group_by is True: - self.add_fields((f.attname for f in self.model._meta.concrete_fields), False) + self.add_fields( + (f.attname for f in self.model._meta.concrete_fields), False + ) # Disable GROUP BY aliases to avoid orphaning references to the # SELECT clause which is about to be cleared. 
self.set_group_by(allow_aliases=False) @@ -2254,7 +2413,8 @@ class Query(BaseExpression): return {} elif self.annotation_select_mask is not None: self._annotation_select_cache = { - k: v for k, v in self.annotations.items() + k: v + for k, v in self.annotations.items() if k in self.annotation_select_mask } return self._annotation_select_cache @@ -2269,8 +2429,7 @@ class Query(BaseExpression): return {} elif self.extra_select_mask is not None: self._extra_select_cache = { - k: v for k, v in self.extra.items() - if k in self.extra_select_mask + k: v for k, v in self.extra.items() if k in self.extra_select_mask } return self._extra_select_cache else: @@ -2297,8 +2456,7 @@ class Query(BaseExpression): # the lookup part of the query. That is, avoid trimming # joins generated for F() expressions. lookup_tables = [ - t for t in self.alias_map - if t in self._lookup_joins or t == self.base_table + t for t in self.alias_map if t in self._lookup_joins or t == self.base_table ] for trimmed_paths, path in enumerate(all_paths): if path.m2m: @@ -2317,8 +2475,7 @@ class Query(BaseExpression): break trimmed_prefix.append(name) paths_in_prefix -= len(path) - trimmed_prefix.append( - join_field.foreign_related_fields[0].name) + trimmed_prefix.append(join_field.foreign_related_fields[0].name) trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix) # Lets still see if we can trim the first join from the inner query # (that is, self). 
We can't do this for: @@ -2331,7 +2488,9 @@ class Query(BaseExpression): select_fields = [r[0] for r in join_field.related_fields] select_alias = lookup_tables[trimmed_paths + 1] self.unref_alias(lookup_tables[trimmed_paths]) - extra_restriction = join_field.get_extra_restriction(None, lookup_tables[trimmed_paths + 1]) + extra_restriction = join_field.get_extra_restriction( + None, lookup_tables[trimmed_paths + 1] + ) if extra_restriction: self.where.add(extra_restriction, AND) else: @@ -2367,12 +2526,12 @@ class Query(BaseExpression): # is_nullable() is needed to the compiler stage, but that is not easy # to do currently. return field.null or ( - field.empty_strings_allowed and - connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls + field.empty_strings_allowed + and connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls ) -def get_order_dir(field, default='ASC'): +def get_order_dir(field, default="ASC"): """ Return the field name and direction for an order specification. For example, '-foo' is returned as ('foo', 'DESC'). @@ -2381,7 +2540,7 @@ def get_order_dir(field, default='ASC'): prefix) should sort. The '-' prefix always sorts the opposite way. """ dirn = ORDER_DIR[default] - if field[0] == '-': + if field[0] == "-": return field[1:], dirn[1] return field, dirn[0] @@ -2428,8 +2587,8 @@ class JoinPromoter: def __repr__(self): return ( - f'{self.__class__.__qualname__}(connector={self.connector!r}, ' - f'num_children={self.num_children!r}, negated={self.negated!r})' + f"{self.__class__.__qualname__}(connector={self.connector!r}, " + f"num_children={self.num_children!r}, negated={self.negated!r})" ) def add_votes(self, votes): @@ -2461,7 +2620,7 @@ class JoinPromoter: # to rel_a would remove a valid match from the query. So, we need # to promote any existing INNER to LOUTER (it is possible this # promotion in turn will be demoted later on). 
- if self.effective_connector == 'OR' and votes < self.num_children: + if self.effective_connector == "OR" and votes < self.num_children: to_promote.add(table) # If connector is AND and there is a filter that can match only # when there is a joinable row, then use INNER. For example, in @@ -2473,8 +2632,9 @@ class JoinPromoter: # (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell) # then if rel_a doesn't produce any rows, the whole condition # can't match. Hence we can safely use INNER join. - if self.effective_connector == 'AND' or ( - self.effective_connector == 'OR' and votes == self.num_children): + if self.effective_connector == "AND" or ( + self.effective_connector == "OR" and votes == self.num_children + ): to_demote.add(table) # Finally, what happens in cases where we have: # (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0 diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f6a371a925..04063f73bc 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -3,18 +3,16 @@ Query subclasses which provide extra functionality beyond simple data retrieval. 
""" from django.core.exceptions import FieldError -from django.db.models.sql.constants import ( - CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, -) +from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS from django.db.models.sql.query import Query -__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery'] +__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] class DeleteQuery(Query): """A DELETE SQL query.""" - compiler = 'SQLDeleteCompiler' + compiler = "SQLDeleteCompiler" def do_query(self, table, where, using): self.alias_map = {table: self.alias_map[table]} @@ -38,17 +36,19 @@ class DeleteQuery(Query): for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.clear_where() self.add_filter( - f'{field.attname}__in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE], + f"{field.attname}__in", + pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE], + ) + num_deleted += self.do_query( + self.get_meta().db_table, self.where, using=using ) - num_deleted += self.do_query(self.get_meta().db_table, self.where, using=using) return num_deleted class UpdateQuery(Query): """An UPDATE SQL query.""" - compiler = 'SQLUpdateCompiler' + compiler = "SQLUpdateCompiler" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -72,7 +72,9 @@ class UpdateQuery(Query): self.add_update_values(values) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.clear_where() - self.add_filter('pk__in', pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]) + self.add_filter( + "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE] + ) self.get_compiler(using).execute_sql(NO_RESULTS) def add_update_values(self, values): @@ -84,12 +86,14 @@ class UpdateQuery(Query): values_seq = [] for name, val in values.items(): field = self.get_meta().get_field(name) - direct = not (field.auto_created and not field.concrete) or not field.concrete + direct = ( + not (field.auto_created and not 
field.concrete) or not field.concrete + ) model = field.model._meta.concrete_model if not direct or (field.is_relation and field.many_to_many): raise FieldError( - 'Cannot update model field %r (only non-relations and ' - 'foreign keys permitted).' % field + "Cannot update model field %r (only non-relations and " + "foreign keys permitted)." % field ) if model is not self.get_meta().concrete_model: self.add_related_update(model, field, val) @@ -104,7 +108,7 @@ class UpdateQuery(Query): called add_update_targets() to hint at the extra information here. """ for field, model, val in values_seq: - if hasattr(val, 'resolve_expression'): + if hasattr(val, "resolve_expression"): # Resolve expressions here so that annotations are no longer needed val = val.resolve_expression(self, allow_joins=False, for_save=True) self.values.append((field, model, val)) @@ -130,15 +134,17 @@ class UpdateQuery(Query): query = UpdateQuery(model) query.values = values if self.related_ids is not None: - query.add_filter('pk__in', self.related_ids) + query.add_filter("pk__in", self.related_ids) result.append(query) return result class InsertQuery(Query): - compiler = 'SQLInsertCompiler' + compiler = "SQLInsertCompiler" - def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs): + def __init__( + self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs + ): super().__init__(*args, **kwargs) self.fields = [] self.objs = [] @@ -158,7 +164,7 @@ class AggregateQuery(Query): elements in the provided list. 
""" - compiler = 'SQLAggregateCompiler' + compiler = "SQLAggregateCompiler" def __init__(self, model, inner_query): self.inner_query = inner_query diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 50ff13be75..532780fd98 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -7,8 +7,8 @@ from django.utils import tree from django.utils.functional import cached_property # Connection types -AND = 'AND' -OR = 'OR' +AND = "AND" +OR = "OR" class WhereNode(tree.Node): @@ -25,6 +25,7 @@ class WhereNode(tree.Node): relabeled_clone() method or relabel_aliases() and clone() methods and contains_aggregate attribute. """ + default = AND resolved = False conditional = True @@ -40,15 +41,15 @@ class WhereNode(tree.Node): in_negated = negated ^ self.negated # If the effective connector is OR and this node contains an aggregate, # then we need to push the whole branch to HAVING clause. - may_need_split = ( - (in_negated and self.connector == AND) or - (not in_negated and self.connector == OR)) + may_need_split = (in_negated and self.connector == AND) or ( + not in_negated and self.connector == OR + ) if may_need_split and self.contains_aggregate: return None, self where_parts = [] having_parts = [] for c in self.children: - if hasattr(c, 'split_having'): + if hasattr(c, "split_having"): where_part, having_part = c.split_having(in_negated) if where_part is not None: where_parts.append(where_part) @@ -58,8 +59,16 @@ class WhereNode(tree.Node): having_parts.append(c) else: where_parts.append(c) - having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None - where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None + having_node = ( + self.__class__(having_parts, self.connector, self.negated) + if having_parts + else None + ) + where_node = ( + self.__class__(where_parts, self.connector, self.negated) + if where_parts + else None + ) return where_node, 
having_node def as_sql(self, compiler, connection): @@ -94,24 +103,24 @@ class WhereNode(tree.Node): # counts. if empty_needed == 0: if self.negated: - return '', [] + return "", [] else: raise EmptyResultSet if full_needed == 0: if self.negated: raise EmptyResultSet else: - return '', [] - conn = ' %s ' % self.connector + return "", [] + conn = " %s " % self.connector sql_string = conn.join(result) if sql_string: if self.negated: # Some backends (Oracle at least) need parentheses # around the inner SQL in the negated case, even if the # inner SQL contains just a single expression. - sql_string = 'NOT (%s)' % sql_string + sql_string = "NOT (%s)" % sql_string elif len(result) > 1 or self.resolved: - sql_string = '(%s)' % sql_string + sql_string = "(%s)" % sql_string return sql_string, result_params def get_group_by_cols(self, alias=None): @@ -133,10 +142,10 @@ class WhereNode(tree.Node): mapping old (current) alias values to the new values. """ for pos, child in enumerate(self.children): - if hasattr(child, 'relabel_aliases'): + if hasattr(child, "relabel_aliases"): # For example another WhereNode child.relabel_aliases(change_map) - elif hasattr(child, 'relabeled_clone'): + elif hasattr(child, "relabeled_clone"): self.children[pos] = child.relabeled_clone(change_map) def clone(self): @@ -146,10 +155,12 @@ class WhereNode(tree.Node): value) tuples, or objects supporting .clone(). 
""" clone = self.__class__._new_instance( - children=None, connector=self.connector, negated=self.negated, + children=None, + connector=self.connector, + negated=self.negated, ) for child in self.children: - if hasattr(child, 'clone'): + if hasattr(child, "clone"): clone.children.append(child.clone()) else: clone.children.append(child) @@ -185,18 +196,18 @@ class WhereNode(tree.Node): @staticmethod def _resolve_leaf(expr, query, *args, **kwargs): - if hasattr(expr, 'resolve_expression'): + if hasattr(expr, "resolve_expression"): expr = expr.resolve_expression(query, *args, **kwargs) return expr @classmethod def _resolve_node(cls, node, query, *args, **kwargs): - if hasattr(node, 'children'): + if hasattr(node, "children"): for child in node.children: cls._resolve_node(child, query, *args, **kwargs) - if hasattr(node, 'lhs'): + if hasattr(node, "lhs"): node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs) - if hasattr(node, 'rhs'): + if hasattr(node, "rhs"): node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs) def resolve_expression(self, *args, **kwargs): @@ -208,6 +219,7 @@ class WhereNode(tree.Node): @cached_property def output_field(self): from django.db.models import BooleanField + return BooleanField() def select_format(self, compiler, sql, params): @@ -215,7 +227,7 @@ class WhereNode(tree.Node): # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP # BY list. 
if not compiler.connection.features.supports_boolean_expr_in_select_clause: - sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END' + sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" return sql, params def get_db_converters(self, connection): @@ -227,6 +239,7 @@ class WhereNode(tree.Node): class NothingNode: """A node that matches nothing.""" + contains_aggregate = False def as_sql(self, compiler=None, connection=None): diff --git a/django/db/models/utils.py b/django/db/models/utils.py index 949c528469..5521f3cca5 100644 --- a/django/db/models/utils.py +++ b/django/db/models/utils.py @@ -46,7 +46,7 @@ def create_namedtuple_class(*names): return unpickle_named_row, (names, tuple(self)) return type( - 'Row', - (namedtuple('Row', names),), - {'__reduce__': __reduce__, '__slots__': ()}, + "Row", + (namedtuple("Row", names),), + {"__reduce__": __reduce__, "__slots__": ()}, ) diff --git a/django/db/transaction.py b/django/db/transaction.py index b61785754f..b3c7b4bbaa 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -1,12 +1,17 @@ from contextlib import ContextDecorator, contextmanager from django.db import ( - DEFAULT_DB_ALIAS, DatabaseError, Error, ProgrammingError, connections, + DEFAULT_DB_ALIAS, + DatabaseError, + Error, + ProgrammingError, + connections, ) class TransactionManagementError(ProgrammingError): """Transaction management is used improperly.""" + pass @@ -132,6 +137,7 @@ def on_commit(func, using=None): # Decorators / context managers # ################################# + class Atomic(ContextDecorator): """ Guarantee the atomic execution of a given block. @@ -176,13 +182,13 @@ class Atomic(ContextDecorator): connection = get_connection(self.using) if ( - self.durable and - connection.atomic_blocks and - not connection.atomic_blocks[-1]._from_testcase + self.durable + and connection.atomic_blocks + and not connection.atomic_blocks[-1]._from_testcase ): raise RuntimeError( - 'A durable atomic block cannot be nested within another ' - 'atomic block.' 
+ "A durable atomic block cannot be nested within another " + "atomic block." ) if not connection.in_atomic_block: # Reset state when entering an outermost atomic block. @@ -206,7 +212,9 @@ class Atomic(ContextDecorator): else: connection.savepoint_ids.append(None) else: - connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True) + connection.set_autocommit( + False, force_begin_transaction_with_broken_autocommit=True + ) connection.in_atomic_block = True if connection.in_atomic_block: diff --git a/django/db/utils.py b/django/db/utils.py index 82498d1df6..7ef62ae5a2 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -3,14 +3,15 @@ from importlib import import_module from django.conf import settings from django.core.exceptions import ImproperlyConfigured + # For backwards compatibility with Django < 3.2 from django.utils.connection import ConnectionDoesNotExist # NOQA: F401 from django.utils.connection import BaseConnectionHandler from django.utils.functional import cached_property from django.utils.module_loading import import_string -DEFAULT_DB_ALIAS = 'default' -DJANGO_VERSION_PICKLE_KEY = '_django_version' +DEFAULT_DB_ALIAS = "default" +DJANGO_VERSION_PICKLE_KEY = "_django_version" class Error(Exception): @@ -70,15 +71,15 @@ class DatabaseErrorWrapper: if exc_type is None: return for dj_exc_type in ( - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, - DatabaseError, - InterfaceError, - Error, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, + DatabaseError, + InterfaceError, + Error, ): db_exc_type = getattr(self.wrapper.Database, dj_exc_type.__name__) if issubclass(exc_type, db_exc_type): @@ -95,6 +96,7 @@ class DatabaseErrorWrapper: def inner(*args, **kwargs): with self: return func(*args, **kwargs) + return inner @@ -104,20 +106,22 @@ def load_backend(backend_name): backend name, or raise an error if it 
doesn't exist. """ # This backend was renamed in Django 1.9. - if backend_name == 'django.db.backends.postgresql_psycopg2': - backend_name = 'django.db.backends.postgresql' + if backend_name == "django.db.backends.postgresql_psycopg2": + backend_name = "django.db.backends.postgresql" try: - return import_module('%s.base' % backend_name) + return import_module("%s.base" % backend_name) except ImportError as e_user: # The database backend wasn't found. Display a helpful error message # listing all built-in database backends. import django.db.backends + builtin_backends = [ - name for _, name, ispkg in pkgutil.iter_modules(django.db.backends.__path__) - if ispkg and name not in {'base', 'dummy'} + name + for _, name, ispkg in pkgutil.iter_modules(django.db.backends.__path__) + if ispkg and name not in {"base", "dummy"} ] - if backend_name not in ['django.db.backends.%s' % b for b in builtin_backends]: + if backend_name not in ["django.db.backends.%s" % b for b in builtin_backends]: backend_reprs = map(repr, sorted(builtin_backends)) raise ImproperlyConfigured( "%r isn't an available database backend or couldn't be " @@ -132,7 +136,7 @@ def load_backend(backend_name): class ConnectionHandler(BaseConnectionHandler): - settings_name = 'DATABASES' + settings_name = "DATABASES" # Connections needs to still be an actual thread local, as it's truly # thread-critical. Database backends should use @async_unsafe to protect # their code from async contexts, but this will give those contexts @@ -143,13 +147,13 @@ class ConnectionHandler(BaseConnectionHandler): def configure_settings(self, databases): databases = super().configure_settings(databases) if databases == {}: - databases[DEFAULT_DB_ALIAS] = {'ENGINE': 'django.db.backends.dummy'} + databases[DEFAULT_DB_ALIAS] = {"ENGINE": "django.db.backends.dummy"} elif DEFAULT_DB_ALIAS not in databases: raise ImproperlyConfigured( f"You must define a '{DEFAULT_DB_ALIAS}' database." 
) elif databases[DEFAULT_DB_ALIAS] == {}: - databases[DEFAULT_DB_ALIAS]['ENGINE'] = 'django.db.backends.dummy' + databases[DEFAULT_DB_ALIAS]["ENGINE"] = "django.db.backends.dummy" return databases @property @@ -166,17 +170,17 @@ class ConnectionHandler(BaseConnectionHandler): except KeyError: raise self.exception_class(f"The connection '{alias}' doesn't exist.") - conn.setdefault('ATOMIC_REQUESTS', False) - conn.setdefault('AUTOCOMMIT', True) - conn.setdefault('ENGINE', 'django.db.backends.dummy') - if conn['ENGINE'] == 'django.db.backends.' or not conn['ENGINE']: - conn['ENGINE'] = 'django.db.backends.dummy' - conn.setdefault('CONN_MAX_AGE', 0) - conn.setdefault('CONN_HEALTH_CHECKS', False) - conn.setdefault('OPTIONS', {}) - conn.setdefault('TIME_ZONE', None) - for setting in ['NAME', 'USER', 'PASSWORD', 'HOST', 'PORT']: - conn.setdefault(setting, '') + conn.setdefault("ATOMIC_REQUESTS", False) + conn.setdefault("AUTOCOMMIT", True) + conn.setdefault("ENGINE", "django.db.backends.dummy") + if conn["ENGINE"] == "django.db.backends." 
or not conn["ENGINE"]: + conn["ENGINE"] = "django.db.backends.dummy" + conn.setdefault("CONN_MAX_AGE", 0) + conn.setdefault("CONN_HEALTH_CHECKS", False) + conn.setdefault("OPTIONS", {}) + conn.setdefault("TIME_ZONE", None) + for setting in ["NAME", "USER", "PASSWORD", "HOST", "PORT"]: + conn.setdefault(setting, "") def prepare_test_settings(self, alias): """ @@ -187,13 +191,13 @@ class ConnectionHandler(BaseConnectionHandler): except KeyError: raise self.exception_class(f"The connection '{alias}' doesn't exist.") - test_settings = conn.setdefault('TEST', {}) + test_settings = conn.setdefault("TEST", {}) default_test_settings = [ - ('CHARSET', None), - ('COLLATION', None), - ('MIGRATE', True), - ('MIRROR', None), - ('NAME', None), + ("CHARSET", None), + ("COLLATION", None), + ("MIGRATE", True), + ("MIRROR", None), + ("NAME", None), ] for key, value in default_test_settings: test_settings.setdefault(key, value) @@ -202,7 +206,7 @@ class ConnectionHandler(BaseConnectionHandler): self.ensure_defaults(alias) self.prepare_test_settings(alias) db = self.databases[alias] - backend = load_backend(db['ENGINE']) + backend = load_backend(db["ENGINE"]) return backend.DatabaseWrapper(db, alias) def close_all(self): @@ -247,14 +251,15 @@ class ConnectionRouter: chosen_db = method(model, **hints) if chosen_db: return chosen_db - instance = hints.get('instance') + instance = hints.get("instance") if instance is not None and instance._state.db: return instance._state.db return DEFAULT_DB_ALIAS + return _route_db - db_for_read = _router_func('db_for_read') - db_for_write = _router_func('db_for_write') + db_for_read = _router_func("db_for_read") + db_for_write = _router_func("db_for_write") def allow_relation(self, obj1, obj2, **hints): for router in self.routers: diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py index d3630ae784..86eb1c3b20 100644 --- a/django/dispatch/dispatcher.py +++ b/django/dispatch/dispatcher.py @@ -4,11 +4,11 @@ import weakref from 
django.utils.inspect import func_accepts_kwargs -logger = logging.getLogger('django.dispatch') +logger = logging.getLogger("django.dispatch") def _make_id(target): - if hasattr(target, '__func__'): + if hasattr(target, "__func__"): return (id(target.__self__), id(target.__func__)) return id(target) @@ -28,6 +28,7 @@ class Signal: receivers { receiverkey (id) : weakref(receiver) } """ + def __init__(self, use_caching=False): """ Create a new signal. @@ -81,10 +82,12 @@ class Signal: # If DEBUG is on, check that we got a good receiver if settings.configured and settings.DEBUG: if not callable(receiver): - raise TypeError('Signal receivers must be callable.') + raise TypeError("Signal receivers must be callable.") # Check for **kwargs if not func_accepts_kwargs(receiver): - raise ValueError("Signal receivers must accept keyword arguments (**kwargs).") + raise ValueError( + "Signal receivers must accept keyword arguments (**kwargs)." + ) if dispatch_uid: lookup_key = (dispatch_uid, _make_id(sender)) @@ -95,7 +98,7 @@ class Signal: ref = weakref.ref receiver_object = receiver # Check for bound methods - if hasattr(receiver, '__self__') and hasattr(receiver, '__func__'): + if hasattr(receiver, "__self__") and hasattr(receiver, "__func__"): ref = weakref.WeakMethod receiver_object = receiver.__self__ receiver = ref(receiver) @@ -164,7 +167,10 @@ class Signal: Return a list of tuple pairs [(receiver, response), ... ]. """ - if not self.receivers or self.sender_receivers_cache.get(sender) is NO_RECEIVERS: + if ( + not self.receivers + or self.sender_receivers_cache.get(sender) is NO_RECEIVERS + ): return [] return [ @@ -191,7 +197,10 @@ class Signal: If any receiver raises an error (specifically any subclass of Exception), return the error instance as the result for that receiver. 
""" - if not self.receivers or self.sender_receivers_cache.get(sender) is NO_RECEIVERS: + if ( + not self.receivers + or self.sender_receivers_cache.get(sender) is NO_RECEIVERS + ): return [] # Call each receiver with whatever arguments it can accept. @@ -202,7 +211,7 @@ class Signal: response = receiver(signal=self, sender=sender, **named) except Exception as err: logger.error( - 'Error calling %s in Signal.send_robust() (%s)', + "Error calling %s in Signal.send_robust() (%s)", receiver.__qualname__, err, exc_info=err, @@ -217,8 +226,9 @@ class Signal: if self._dead_receivers: self._dead_receivers = False self.receivers = [ - r for r in self.receivers - if not(isinstance(r[1], weakref.ReferenceType) and r[1]() is None) + r + for r in self.receivers + if not (isinstance(r[1], weakref.ReferenceType) and r[1]() is None) ] def _live_receivers(self, sender): @@ -283,6 +293,7 @@ def receiver(signal, **kwargs): def signals_receiver(sender, **kwargs): ... """ + def _decorator(func): if isinstance(signal, (list, tuple)): for s in signal: @@ -290,4 +301,5 @@ def receiver(signal, **kwargs): else: signal.connect(func, **kwargs) return func + return _decorator diff --git a/django/forms/boundfield.py b/django/forms/boundfield.py index 62d1823d57..e83160645e 100644 --- a/django/forms/boundfield.py +++ b/django/forms/boundfield.py @@ -7,12 +7,13 @@ from django.utils.functional import cached_property from django.utils.html import format_html, html_safe from django.utils.translation import gettext_lazy as _ -__all__ = ('BoundField',) +__all__ = ("BoundField",) @html_safe class BoundField: "A Field plus data" + def __init__(self, form, field, name): self.form = form self.field = field @@ -24,7 +25,7 @@ class BoundField: self.label = pretty_name(name) else: self.label = self.field.label - self.help_text = field.help_text or '' + self.help_text = field.help_text or "" def __str__(self): """Render this field as an HTML widget.""" @@ -41,12 +42,14 @@ class BoundField: This property is 
cached so that only one database query occurs when rendering ModelChoiceFields. """ - id_ = self.field.widget.attrs.get('id') or self.auto_id - attrs = {'id': id_} if id_ else {} + id_ = self.field.widget.attrs.get("id") or self.auto_id + attrs = {"id": id_} if id_ else {} attrs = self.build_widget_attrs(attrs) return [ BoundWidget(self.field.widget, widget, self.form.renderer) - for widget in self.field.widget.subwidgets(self.html_name, self.value(), attrs=attrs) + for widget in self.field.widget.subwidgets( + self.html_name, self.value(), attrs=attrs + ) ] def __bool__(self): @@ -64,7 +67,7 @@ class BoundField: # from templates. if not isinstance(idx, (int, slice)): raise TypeError( - 'BoundField indices must be integers or slices, not %s.' + "BoundField indices must be integers or slices, not %s." % type(idx).__name__ ) return self.subwidgets[idx] @@ -74,7 +77,9 @@ class BoundField: """ Return an ErrorList (empty if there are no errors) for this field. """ - return self.form.errors.get(self.name, self.form.error_class(renderer=self.form.renderer)) + return self.form.errors.get( + self.name, self.form.error_class(renderer=self.form.renderer) + ) def as_widget(self, widget=None, attrs=None, only_initial=False): """ @@ -87,8 +92,10 @@ class BoundField: widget.is_localized = True attrs = attrs or {} attrs = self.build_widget_attrs(attrs, widget) - if self.auto_id and 'id' not in widget.attrs: - attrs.setdefault('id', self.html_initial_id if only_initial else self.auto_id) + if self.auto_id and "id" not in widget.attrs: + attrs.setdefault( + "id", self.html_initial_id if only_initial else self.auto_id + ) return widget.render( name=self.html_initial_name if only_initial else self.html_name, value=self.value(), @@ -134,7 +141,8 @@ class BoundField: if field.show_hidden_initial: hidden_widget = field.hidden_widget() initial_value = self.form._widget_data_value( - hidden_widget, self.html_initial_name, + hidden_widget, + self.html_initial_name, ) try: initial_value = 
field.to_python(initial_value) @@ -157,31 +165,34 @@ class BoundField: """ contents = contents or self.label if label_suffix is None: - label_suffix = (self.field.label_suffix if self.field.label_suffix is not None - else self.form.label_suffix) + label_suffix = ( + self.field.label_suffix + if self.field.label_suffix is not None + else self.form.label_suffix + ) # Only add the suffix if the label does not end in punctuation. # Translators: If found as last label character, these punctuation # characters will prevent the default label_suffix to be appended to the label - if label_suffix and contents and contents[-1] not in _(':?.!'): - contents = format_html('{}{}', contents, label_suffix) + if label_suffix and contents and contents[-1] not in _(":?.!"): + contents = format_html("{}{}", contents, label_suffix) widget = self.field.widget - id_ = widget.attrs.get('id') or self.auto_id + id_ = widget.attrs.get("id") or self.auto_id if id_: id_for_label = widget.id_for_label(id_) if id_for_label: - attrs = {**(attrs or {}), 'for': id_for_label} - if self.field.required and hasattr(self.form, 'required_css_class'): + attrs = {**(attrs or {}), "for": id_for_label} + if self.field.required and hasattr(self.form, "required_css_class"): attrs = attrs or {} - if 'class' in attrs: - attrs['class'] += ' ' + self.form.required_css_class + if "class" in attrs: + attrs["class"] += " " + self.form.required_css_class else: - attrs['class'] = self.form.required_css_class + attrs["class"] = self.form.required_css_class context = { - 'field': self, - 'label': contents, - 'attrs': attrs, - 'use_tag': bool(id_), - 'tag': tag or 'label', + "field": self, + "label": contents, + "attrs": attrs, + "use_tag": bool(id_), + "tag": tag or "label", } return self.form.render(self.form.template_name_label, context) @@ -195,20 +206,20 @@ class BoundField: label_suffix overrides the form's label_suffix. 
""" - return self.label_tag(contents, attrs, label_suffix, tag='legend') + return self.label_tag(contents, attrs, label_suffix, tag="legend") def css_classes(self, extra_classes=None): """ Return a string of space-separated CSS classes for this field. """ - if hasattr(extra_classes, 'split'): + if hasattr(extra_classes, "split"): extra_classes = extra_classes.split() extra_classes = set(extra_classes or []) - if self.errors and hasattr(self.form, 'error_css_class'): + if self.errors and hasattr(self.form, "error_css_class"): extra_classes.add(self.form.error_css_class) - if self.field.required and hasattr(self.form, 'required_css_class'): + if self.field.required and hasattr(self.form, "required_css_class"): extra_classes.add(self.form.required_css_class) - return ' '.join(extra_classes) + return " ".join(extra_classes) @property def is_hidden(self): @@ -222,11 +233,11 @@ class BoundField: associated Form has specified auto_id. Return an empty string otherwise. """ auto_id = self.form.auto_id # Boolean or string - if auto_id and '%s' in str(auto_id): + if auto_id and "%s" in str(auto_id): return auto_id % self.html_name elif auto_id: return self.html_name - return '' + return "" @property def id_for_label(self): @@ -236,7 +247,7 @@ class BoundField: it has a single widget or a MultiWidget. """ widget = self.field.widget - id_ = widget.attrs.get('id') or self.auto_id + id_ = widget.attrs.get("id") or self.auto_id return widget.id_for_label(id_) @cached_property @@ -246,25 +257,34 @@ class BoundField: def build_widget_attrs(self, attrs, widget=None): widget = widget or self.field.widget attrs = dict(attrs) # Copy attrs to avoid modifying the argument. - if widget.use_required_attribute(self.initial) and self.field.required and self.form.use_required_attribute: + if ( + widget.use_required_attribute(self.initial) + and self.field.required + and self.form.use_required_attribute + ): # MultiValueField has require_all_fields: if False, fall back # on subfields. 
if ( - hasattr(self.field, 'require_all_fields') and - not self.field.require_all_fields and - isinstance(self.field.widget, MultiWidget) + hasattr(self.field, "require_all_fields") + and not self.field.require_all_fields + and isinstance(self.field.widget, MultiWidget) ): for subfield, subwidget in zip(self.field.fields, widget.widgets): - subwidget.attrs['required'] = subwidget.use_required_attribute(self.initial) and subfield.required + subwidget.attrs["required"] = ( + subwidget.use_required_attribute(self.initial) + and subfield.required + ) else: - attrs['required'] = True + attrs["required"] = True if self.field.disabled: - attrs['disabled'] = True + attrs["disabled"] = True return attrs @property def widget_type(self): - return re.sub(r'widget$|input$', '', self.field.widget.__class__.__name__.lower()) + return re.sub( + r"widget$|input$", "", self.field.widget.__class__.__name__.lower() + ) @html_safe @@ -281,6 +301,7 @@ class BoundWidget: </label> {% endfor %} """ + def __init__(self, parent_widget, data, renderer): self.parent_widget = parent_widget self.data = data @@ -290,19 +311,19 @@ class BoundWidget: return self.tag(wrap_label=True) def tag(self, wrap_label=False): - context = {'widget': {**self.data, 'wrap_label': wrap_label}} + context = {"widget": {**self.data, "wrap_label": wrap_label}} return self.parent_widget._render(self.template_name, context, self.renderer) @property def template_name(self): - if 'template_name' in self.data: - return self.data['template_name'] + if "template_name" in self.data: + return self.data["template_name"] return self.parent_widget.template_name @property def id_for_label(self): - return self.data['attrs'].get('id') + return self.data["attrs"].get("id") @property def choice_label(self): - return self.data['label'] + return self.data["label"] diff --git a/django/forms/fields.py b/django/forms/fields.py index 65d6a9ec82..a7031936dd 100644 --- a/django/forms/fields.py +++ b/django/forms/fields.py @@ -19,45 +19,94 @@ 
from django.core.exceptions import ValidationError from django.forms.boundfield import BoundField from django.forms.utils import from_current_timezone, to_current_timezone from django.forms.widgets import ( - FILE_INPUT_CONTRADICTION, CheckboxInput, ClearableFileInput, DateInput, - DateTimeInput, EmailInput, FileInput, HiddenInput, MultipleHiddenInput, - NullBooleanSelect, NumberInput, Select, SelectMultiple, - SplitDateTimeWidget, SplitHiddenDateTimeWidget, Textarea, TextInput, - TimeInput, URLInput, + FILE_INPUT_CONTRADICTION, + CheckboxInput, + ClearableFileInput, + DateInput, + DateTimeInput, + EmailInput, + FileInput, + HiddenInput, + MultipleHiddenInput, + NullBooleanSelect, + NumberInput, + Select, + SelectMultiple, + SplitDateTimeWidget, + SplitHiddenDateTimeWidget, + Textarea, + TextInput, + TimeInput, + URLInput, ) from django.utils import formats from django.utils.dateparse import parse_datetime, parse_duration from django.utils.duration import duration_string from django.utils.ipv6 import clean_ipv6_address from django.utils.regex_helper import _lazy_re_compile -from django.utils.translation import gettext_lazy as _, ngettext_lazy +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import ngettext_lazy __all__ = ( - 'Field', 'CharField', 'IntegerField', - 'DateField', 'TimeField', 'DateTimeField', 'DurationField', - 'RegexField', 'EmailField', 'FileField', 'ImageField', 'URLField', - 'BooleanField', 'NullBooleanField', 'ChoiceField', 'MultipleChoiceField', - 'ComboField', 'MultiValueField', 'FloatField', 'DecimalField', - 'SplitDateTimeField', 'GenericIPAddressField', 'FilePathField', - 'JSONField', 'SlugField', 'TypedChoiceField', 'TypedMultipleChoiceField', - 'UUIDField', + "Field", + "CharField", + "IntegerField", + "DateField", + "TimeField", + "DateTimeField", + "DurationField", + "RegexField", + "EmailField", + "FileField", + "ImageField", + "URLField", + "BooleanField", + "NullBooleanField", + "ChoiceField", + 
"MultipleChoiceField", + "ComboField", + "MultiValueField", + "FloatField", + "DecimalField", + "SplitDateTimeField", + "GenericIPAddressField", + "FilePathField", + "JSONField", + "SlugField", + "TypedChoiceField", + "TypedMultipleChoiceField", + "UUIDField", ) class Field: widget = TextInput # Default widget to use when rendering this type of Field. - hidden_widget = HiddenInput # Default widget to use when rendering this as "hidden". + hidden_widget = ( + HiddenInput # Default widget to use when rendering this as "hidden". + ) default_validators = [] # Default set of validators # Add an 'invalid' entry to default_error_message if you want a specific # field error message not raised by the field validators. default_error_messages = { - 'required': _('This field is required.'), + "required": _("This field is required."), } empty_values = list(validators.EMPTY_VALUES) - def __init__(self, *, required=True, widget=None, label=None, initial=None, - help_text='', error_messages=None, show_hidden_initial=False, - validators=(), localize=False, disabled=False, label_suffix=None): + def __init__( + self, + *, + required=True, + widget=None, + label=None, + initial=None, + help_text="", + error_messages=None, + show_hidden_initial=False, + validators=(), + localize=False, + disabled=False, + label_suffix=None, + ): # required -- Boolean that specifies whether the field is required. # True by default. 
# widget -- A Widget class, or instance of a Widget class, that should @@ -109,7 +158,7 @@ class Field: messages = {} for c in reversed(self.__class__.__mro__): - messages.update(getattr(c, 'default_error_messages', {})) + messages.update(getattr(c, "default_error_messages", {})) messages.update(error_messages or {}) self.error_messages = messages @@ -125,7 +174,7 @@ class Field: def validate(self, value): if value in self.empty_values and self.required: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError(self.error_messages["required"], code="required") def run_validators(self, value): if value in self.empty_values: @@ -135,7 +184,7 @@ class Field: try: v(value) except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: + if hasattr(e, "code") and e.code in self.error_messages: e.message = self.error_messages[e.code] errors.extend(e.error_list) if errors: @@ -180,15 +229,15 @@ class Field: return False try: data = self.to_python(data) - if hasattr(self, '_coerce'): + if hasattr(self, "_coerce"): return self._coerce(data) != self._coerce(initial) except ValidationError: return True # For purposes of seeing whether something has changed, None is # the same as an empty string, if the data or initial value we get # is None, replace it with ''. 
- initial_value = initial if initial is not None else '' - data_value = data if data is not None else '' + initial_value = initial if initial is not None else "" + data_value = data if data is not None else "" return initial_value != data_value def get_bound_field(self, form, field_name): @@ -208,7 +257,9 @@ class Field: class CharField(Field): - def __init__(self, *, max_length=None, min_length=None, strip=True, empty_value='', **kwargs): + def __init__( + self, *, max_length=None, min_length=None, strip=True, empty_value="", **kwargs + ): self.max_length = max_length self.min_length = min_length self.strip = strip @@ -234,25 +285,25 @@ class CharField(Field): attrs = super().widget_attrs(widget) if self.max_length is not None and not widget.is_hidden: # The HTML attribute is maxlength, not max_length. - attrs['maxlength'] = str(self.max_length) + attrs["maxlength"] = str(self.max_length) if self.min_length is not None and not widget.is_hidden: # The HTML attribute is minlength, not min_length. - attrs['minlength'] = str(self.min_length) + attrs["minlength"] = str(self.min_length) return attrs class IntegerField(Field): widget = NumberInput default_error_messages = { - 'invalid': _('Enter a whole number.'), + "invalid": _("Enter a whole number."), } - re_decimal = _lazy_re_compile(r'\.0*\s*$') + re_decimal = _lazy_re_compile(r"\.0*\s*$") def __init__(self, *, max_value=None, min_value=None, **kwargs): self.max_value, self.min_value = max_value, min_value - if kwargs.get('localize') and self.widget == NumberInput: + if kwargs.get("localize") and self.widget == NumberInput: # Localized number input is not well supported on most browsers - kwargs.setdefault('widget', super().widget) + kwargs.setdefault("widget", super().widget) super().__init__(**kwargs) if max_value is not None: @@ -272,24 +323,24 @@ class IntegerField(Field): value = formats.sanitize_separators(value) # Strip trailing decimal and zeros. 
try: - value = int(self.re_decimal.sub('', str(value))) + value = int(self.re_decimal.sub("", str(value))) except (ValueError, TypeError): - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") return value def widget_attrs(self, widget): attrs = super().widget_attrs(widget) if isinstance(widget, NumberInput): if self.min_value is not None: - attrs['min'] = self.min_value + attrs["min"] = self.min_value if self.max_value is not None: - attrs['max'] = self.max_value + attrs["max"] = self.max_value return attrs class FloatField(IntegerField): default_error_messages = { - 'invalid': _('Enter a number.'), + "invalid": _("Enter a number."), } def to_python(self, value): @@ -305,7 +356,7 @@ class FloatField(IntegerField): try: value = float(value) except (ValueError, TypeError): - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") return value def validate(self, value): @@ -313,21 +364,29 @@ class FloatField(IntegerField): if value in self.empty_values: return if not math.isfinite(value): - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") def widget_attrs(self, widget): attrs = super().widget_attrs(widget) - if isinstance(widget, NumberInput) and 'step' not in widget.attrs: - attrs.setdefault('step', 'any') + if isinstance(widget, NumberInput) and "step" not in widget.attrs: + attrs.setdefault("step", "any") return attrs class DecimalField(IntegerField): default_error_messages = { - 'invalid': _('Enter a number.'), + "invalid": _("Enter a number."), } - def __init__(self, *, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): + def __init__( + self, + *, + max_value=None, + min_value=None, + max_digits=None, + decimal_places=None, + **kwargs, + ): self.max_digits, 
self.decimal_places = max_digits, decimal_places super().__init__(max_value=max_value, min_value=min_value, **kwargs) self.validators.append(validators.DecimalValidator(max_digits, decimal_places)) @@ -346,7 +405,7 @@ class DecimalField(IntegerField): try: value = Decimal(str(value)) except DecimalException: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") return value def validate(self, value): @@ -355,26 +414,25 @@ class DecimalField(IntegerField): return if not value.is_finite(): raise ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) def widget_attrs(self, widget): attrs = super().widget_attrs(widget) - if isinstance(widget, NumberInput) and 'step' not in widget.attrs: + if isinstance(widget, NumberInput) and "step" not in widget.attrs: if self.decimal_places is not None: # Use exponential notation for small values since they might # be parsed as 0 otherwise. 
ref #20765 step = str(Decimal(1).scaleb(-self.decimal_places)).lower() else: - step = 'any' - attrs.setdefault('step', step) + step = "any" + attrs.setdefault("step", step) return attrs class BaseTemporalField(Field): - def __init__(self, *, input_formats=None, **kwargs): super().__init__(**kwargs) if input_formats is not None: @@ -388,17 +446,17 @@ class BaseTemporalField(Field): return self.strptime(value, format) except (ValueError, TypeError): continue - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") def strptime(self, value, format): - raise NotImplementedError('Subclasses must define this method.') + raise NotImplementedError("Subclasses must define this method.") class DateField(BaseTemporalField): widget = DateInput - input_formats = formats.get_format_lazy('DATE_INPUT_FORMATS') + input_formats = formats.get_format_lazy("DATE_INPUT_FORMATS") default_error_messages = { - 'invalid': _('Enter a valid date.'), + "invalid": _("Enter a valid date."), } def to_python(self, value): @@ -420,10 +478,8 @@ class DateField(BaseTemporalField): class TimeField(BaseTemporalField): widget = TimeInput - input_formats = formats.get_format_lazy('TIME_INPUT_FORMATS') - default_error_messages = { - 'invalid': _('Enter a valid time.') - } + input_formats = formats.get_format_lazy("TIME_INPUT_FORMATS") + default_error_messages = {"invalid": _("Enter a valid time.")} def to_python(self, value): """ @@ -442,15 +498,15 @@ class TimeField(BaseTemporalField): class DateTimeFormatsIterator: def __iter__(self): - yield from formats.get_format('DATETIME_INPUT_FORMATS') - yield from formats.get_format('DATE_INPUT_FORMATS') + yield from formats.get_format("DATETIME_INPUT_FORMATS") + yield from formats.get_format("DATE_INPUT_FORMATS") class DateTimeField(BaseTemporalField): widget = DateTimeInput input_formats = DateTimeFormatsIterator() default_error_messages = { - 'invalid': _('Enter a valid 
date/time.'), + "invalid": _("Enter a valid date/time."), } def prepare_value(self, value): @@ -473,7 +529,7 @@ class DateTimeField(BaseTemporalField): try: result = parse_datetime(value.strip()) except ValueError: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") if not result: result = super().to_python(value) return from_current_timezone(result) @@ -484,8 +540,8 @@ class DateTimeField(BaseTemporalField): class DurationField(Field): default_error_messages = { - 'invalid': _('Enter a valid duration.'), - 'overflow': _('The number of days must be between {min_days} and {max_days}.') + "invalid": _("Enter a valid duration."), + "overflow": _("The number of days must be between {min_days} and {max_days}."), } def prepare_value(self, value): @@ -501,12 +557,15 @@ class DurationField(Field): try: value = parse_duration(str(value)) except OverflowError: - raise ValidationError(self.error_messages['overflow'].format( - min_days=datetime.timedelta.min.days, - max_days=datetime.timedelta.max.days, - ), code='overflow') + raise ValidationError( + self.error_messages["overflow"].format( + min_days=datetime.timedelta.min.days, + max_days=datetime.timedelta.max.days, + ), + code="overflow", + ) if value is None: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") return value @@ -515,7 +574,7 @@ class RegexField(CharField): """ regex can be either a string or a compiled regular expression object. 
""" - kwargs.setdefault('strip', False) + kwargs.setdefault("strip", False) super().__init__(**kwargs) self._set_regex(regex) @@ -526,7 +585,10 @@ class RegexField(CharField): if isinstance(regex, str): regex = re.compile(regex) self._regex = regex - if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: + if ( + hasattr(self, "_regex_validator") + and self._regex_validator in self.validators + ): self.validators.remove(self._regex_validator) self._regex_validator = validators.RegexValidator(regex=regex) self.validators.append(self._regex_validator) @@ -545,14 +607,17 @@ class EmailField(CharField): class FileField(Field): widget = ClearableFileInput default_error_messages = { - 'invalid': _("No file was submitted. Check the encoding type on the form."), - 'missing': _("No file was submitted."), - 'empty': _("The submitted file is empty."), - 'max_length': ngettext_lazy( - 'Ensure this filename has at most %(max)d character (it has %(length)d).', - 'Ensure this filename has at most %(max)d characters (it has %(length)d).', - 'max'), - 'contradiction': _('Please either submit a file or check the clear checkbox, not both.') + "invalid": _("No file was submitted. Check the encoding type on the form."), + "missing": _("No file was submitted."), + "empty": _("The submitted file is empty."), + "max_length": ngettext_lazy( + "Ensure this filename has at most %(max)d character (it has %(length)d).", + "Ensure this filename has at most %(max)d characters (it has %(length)d).", + "max", + ), + "contradiction": _( + "Please either submit a file or check the clear checkbox, not both." 
+ ), } def __init__(self, *, max_length=None, allow_empty_file=False, **kwargs): @@ -569,22 +634,26 @@ class FileField(Field): file_name = data.name file_size = data.size except AttributeError: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") if self.max_length is not None and len(file_name) > self.max_length: - params = {'max': self.max_length, 'length': len(file_name)} - raise ValidationError(self.error_messages['max_length'], code='max_length', params=params) + params = {"max": self.max_length, "length": len(file_name)} + raise ValidationError( + self.error_messages["max_length"], code="max_length", params=params + ) if not file_name: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") if not self.allow_empty_file and not file_size: - raise ValidationError(self.error_messages['empty'], code='empty') + raise ValidationError(self.error_messages["empty"], code="empty") return data def clean(self, data, initial=None): # If the widget got contradictory inputs, we raise a validation error if data is FILE_INPUT_CONTRADICTION: - raise ValidationError(self.error_messages['contradiction'], code='contradiction') + raise ValidationError( + self.error_messages["contradiction"], code="contradiction" + ) # False means the field value should be cleared; further validation is # not needed. if data is False: @@ -612,7 +681,7 @@ class FileField(Field): class ImageField(FileField): default_validators = [validators.validate_image_file_extension] default_error_messages = { - 'invalid_image': _( + "invalid_image": _( "Upload a valid image. The file you uploaded was either not an " "image or a corrupted image." ), @@ -631,13 +700,13 @@ class ImageField(FileField): # We need to get a file object for Pillow. We might have a path or we might # have to read the data into memory. 
- if hasattr(data, 'temporary_file_path'): + if hasattr(data, "temporary_file_path"): file = data.temporary_file_path() else: - if hasattr(data, 'read'): + if hasattr(data, "read"): file = BytesIO(data.read()) else: - file = BytesIO(data['content']) + file = BytesIO(data["content"]) try: # load() could spot a truncated JPEG, but it loads the entire @@ -654,24 +723,24 @@ class ImageField(FileField): except Exception as exc: # Pillow doesn't recognize it as an image. raise ValidationError( - self.error_messages['invalid_image'], - code='invalid_image', + self.error_messages["invalid_image"], + code="invalid_image", ) from exc - if hasattr(f, 'seek') and callable(f.seek): + if hasattr(f, "seek") and callable(f.seek): f.seek(0) return f def widget_attrs(self, widget): attrs = super().widget_attrs(widget) - if isinstance(widget, FileInput) and 'accept' not in widget.attrs: - attrs.setdefault('accept', 'image/*') + if isinstance(widget, FileInput) and "accept" not in widget.attrs: + attrs.setdefault("accept", "image/*") return attrs class URLField(CharField): widget = URLInput default_error_messages = { - 'invalid': _('Enter a valid URL.'), + "invalid": _("Enter a valid URL."), } default_validators = [validators.URLValidator()] @@ -679,7 +748,6 @@ class URLField(CharField): super().__init__(strip=True, **kwargs) def to_python(self, value): - def split_url(url): """ Return a list of url parts via urlparse.urlsplit(), or raise @@ -690,19 +758,19 @@ class URLField(CharField): except ValueError: # urlparse.urlsplit can raise a ValueError with some # misformatted URLs. 
- raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") value = super().to_python(value) if value: url_fields = split_url(value) if not url_fields[0]: # If no URL scheme given, assume http:// - url_fields[0] = 'http' + url_fields[0] = "http" if not url_fields[1]: # Assume that if no domain is provided, that the path segment # contains the domain. url_fields[1] = url_fields[2] - url_fields[2] = '' + url_fields[2] = "" # Rebuild the url_fields list, since the domain segment may now # contain the path too. url_fields = split_url(urlunsplit(url_fields)) @@ -719,7 +787,7 @@ class BooleanField(Field): # will submit for False. Also check for '0', since this is what # RadioSelect will provide. Because bool("True") == bool('1') == True, # we don't need to handle that explicitly. - if isinstance(value, str) and value.lower() in ('false', '0'): + if isinstance(value, str) and value.lower() in ("false", "0"): value = False else: value = bool(value) @@ -727,7 +795,7 @@ class BooleanField(Field): def validate(self, value): if not value and self.required: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError(self.error_messages["required"], code="required") def has_changed(self, initial, data): if self.disabled: @@ -742,6 +810,7 @@ class NullBooleanField(BooleanField): A field whose valid values are None, True, and False. Clean invalid values to None. """ + widget = NullBooleanSelect def to_python(self, value): @@ -753,9 +822,9 @@ class NullBooleanField(BooleanField): the Booleanfield, this field must check for True because it doesn't use the bool() function. 
""" - if value in (True, 'True', 'true', '1'): + if value in (True, "True", "true", "1"): return True - elif value in (False, 'False', 'false', '0'): + elif value in (False, "False", "false", "0"): return False else: return None @@ -775,7 +844,9 @@ class CallableChoiceIterator: class ChoiceField(Field): widget = Select default_error_messages = { - 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), + "invalid_choice": _( + "Select a valid choice. %(value)s is not one of the available choices." + ), } def __init__(self, *, choices=(), **kwargs): @@ -806,7 +877,7 @@ class ChoiceField(Field): def to_python(self, value): """Return a string.""" if value in self.empty_values: - return '' + return "" return str(value) def validate(self, value): @@ -814,9 +885,9 @@ class ChoiceField(Field): super().validate(value) if value and not self.valid_value(value): raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': value}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": value}, ) def valid_value(self, value): @@ -835,7 +906,7 @@ class ChoiceField(Field): class TypedChoiceField(ChoiceField): - def __init__(self, *, coerce=lambda val: val, empty_value='', **kwargs): + def __init__(self, *, coerce=lambda val: val, empty_value="", **kwargs): self.coerce = coerce self.empty_value = empty_value super().__init__(**kwargs) @@ -850,9 +921,9 @@ class TypedChoiceField(ChoiceField): value = self.coerce(value) except (ValueError, TypeError, ValidationError): raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': value}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": value}, ) return value @@ -865,28 +936,32 @@ class MultipleChoiceField(ChoiceField): hidden_widget = MultipleHiddenInput widget = SelectMultiple default_error_messages = { - 'invalid_choice': _('Select a valid 
choice. %(value)s is not one of the available choices.'), - 'invalid_list': _('Enter a list of values.'), + "invalid_choice": _( + "Select a valid choice. %(value)s is not one of the available choices." + ), + "invalid_list": _("Enter a list of values."), } def to_python(self, value): if not value: return [] elif not isinstance(value, (list, tuple)): - raise ValidationError(self.error_messages['invalid_list'], code='invalid_list') + raise ValidationError( + self.error_messages["invalid_list"], code="invalid_list" + ) return [str(val) for val in value] def validate(self, value): """Validate that the input is a list or tuple.""" if self.required and not value: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError(self.error_messages["required"], code="required") # Validate that each value in the value list is in self.choices. for val in value: if not self.valid_value(val): raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': val}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": val}, ) def has_changed(self, initial, data): @@ -906,7 +981,7 @@ class MultipleChoiceField(ChoiceField): class TypedMultipleChoiceField(MultipleChoiceField): def __init__(self, *, coerce=lambda val: val, **kwargs): self.coerce = coerce - self.empty_value = kwargs.pop('empty_value', []) + self.empty_value = kwargs.pop("empty_value", []) super().__init__(**kwargs) def _coerce(self, value): @@ -922,9 +997,9 @@ class TypedMultipleChoiceField(MultipleChoiceField): new_value.append(self.coerce(choice)) except (ValueError, TypeError, ValidationError): raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': choice}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": choice}, ) return new_value @@ -936,13 +1011,14 @@ class TypedMultipleChoiceField(MultipleChoiceField): if value != 
self.empty_value: super().validate(value) elif self.required: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError(self.error_messages["required"], code="required") class ComboField(Field): """ A Field whose clean() method calls multiple Field clean() methods. """ + def __init__(self, fields, **kwargs): super().__init__(**kwargs) # Set 'required' to False on the individual fields, because the @@ -980,17 +1056,17 @@ class MultiValueField(Field): You'll probably want to use this with MultiWidget. """ + default_error_messages = { - 'invalid': _('Enter a list of values.'), - 'incomplete': _('Enter a complete value.'), + "invalid": _("Enter a list of values."), + "incomplete": _("Enter a complete value."), } def __init__(self, fields, *, require_all_fields=True, **kwargs): self.require_all_fields = require_all_fields super().__init__(**kwargs) for f in fields: - f.error_messages.setdefault('incomplete', - self.error_messages['incomplete']) + f.error_messages.setdefault("incomplete", self.error_messages["incomplete"]) if self.disabled: f.disabled = True if self.require_all_fields: @@ -1024,11 +1100,13 @@ class MultiValueField(Field): if not value or isinstance(value, (list, tuple)): if not value or not [v for v in value if v not in self.empty_values]: if self.required: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError( + self.error_messages["required"], code="required" + ) else: return self.compress([]) else: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") for i, field in enumerate(self.fields): try: field_value = value[i] @@ -1039,13 +1117,15 @@ class MultiValueField(Field): # Raise a 'required' error if the MultiValueField is # required and any field is empty. 
if self.required: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError( + self.error_messages["required"], code="required" + ) elif field.required: # Otherwise, add an 'incomplete' error to the list of # collected errors and skip field cleaning, if a required # field is empty. - if field.error_messages['incomplete'] not in errors: - errors.append(field.error_messages['incomplete']) + if field.error_messages["incomplete"] not in errors: + errors.append(field.error_messages["incomplete"]) continue try: clean_data.append(field.clean(field_value)) @@ -1071,13 +1151,13 @@ class MultiValueField(Field): fields=(DateField(), TimeField()), this might return a datetime object created by combining the date and time in data_list. """ - raise NotImplementedError('Subclasses must implement this method.') + raise NotImplementedError("Subclasses must implement this method.") def has_changed(self, initial, data): if self.disabled: return False if initial is None: - initial = ['' for x in range(0, len(data))] + initial = ["" for x in range(0, len(data))] else: if not isinstance(initial, list): initial = self.widget.decompress(initial) @@ -1092,8 +1172,16 @@ class MultiValueField(Field): class FilePathField(ChoiceField): - def __init__(self, path, *, match=None, recursive=False, allow_files=True, - allow_folders=False, **kwargs): + def __init__( + self, + path, + *, + match=None, + recursive=False, + allow_files=True, + allow_folders=False, + **kwargs, + ): self.path, self.match, self.recursive = path, match, recursive self.allow_files, self.allow_folders = allow_files, allow_folders super().__init__(choices=(), **kwargs) @@ -1115,7 +1203,7 @@ class FilePathField(ChoiceField): self.choices.append((f, f.replace(path, "", 1))) if self.allow_folders: for f in sorted(dirs): - if f == '__pycache__': + if f == "__pycache__": continue if self.match is None or self.match_re.search(f): f = os.path.join(root, f) @@ -1124,12 +1212,12 @@ class 
FilePathField(ChoiceField): choices = [] with os.scandir(self.path) as entries: for f in entries: - if f.name == '__pycache__': + if f.name == "__pycache__": continue - if (( - (self.allow_files and f.is_file()) or - (self.allow_folders and f.is_dir()) - ) and (self.match is None or self.match_re.search(f.name))): + if ( + (self.allow_files and f.is_file()) + or (self.allow_folders and f.is_dir()) + ) and (self.match is None or self.match_re.search(f.name)): choices.append((f.path, f.name)) choices.sort(key=operator.itemgetter(1)) self.choices.extend(choices) @@ -1141,22 +1229,26 @@ class SplitDateTimeField(MultiValueField): widget = SplitDateTimeWidget hidden_widget = SplitHiddenDateTimeWidget default_error_messages = { - 'invalid_date': _('Enter a valid date.'), - 'invalid_time': _('Enter a valid time.'), + "invalid_date": _("Enter a valid date."), + "invalid_time": _("Enter a valid time."), } def __init__(self, *, input_date_formats=None, input_time_formats=None, **kwargs): errors = self.default_error_messages.copy() - if 'error_messages' in kwargs: - errors.update(kwargs['error_messages']) - localize = kwargs.get('localize', False) + if "error_messages" in kwargs: + errors.update(kwargs["error_messages"]) + localize = kwargs.get("localize", False) fields = ( - DateField(input_formats=input_date_formats, - error_messages={'invalid': errors['invalid_date']}, - localize=localize), - TimeField(input_formats=input_time_formats, - error_messages={'invalid': errors['invalid_time']}, - localize=localize), + DateField( + input_formats=input_date_formats, + error_messages={"invalid": errors["invalid_date"]}, + localize=localize, + ), + TimeField( + input_formats=input_time_formats, + error_messages={"invalid": errors["invalid_time"]}, + localize=localize, + ), ) super().__init__(fields, **kwargs) @@ -1165,25 +1257,31 @@ class SplitDateTimeField(MultiValueField): # Raise a validation error if time or date is empty # (possible if SplitDateTimeField has required=False). 
if data_list[0] in self.empty_values: - raise ValidationError(self.error_messages['invalid_date'], code='invalid_date') + raise ValidationError( + self.error_messages["invalid_date"], code="invalid_date" + ) if data_list[1] in self.empty_values: - raise ValidationError(self.error_messages['invalid_time'], code='invalid_time') + raise ValidationError( + self.error_messages["invalid_time"], code="invalid_time" + ) result = datetime.datetime.combine(*data_list) return from_current_timezone(result) return None class GenericIPAddressField(CharField): - def __init__(self, *, protocol='both', unpack_ipv4=False, **kwargs): + def __init__(self, *, protocol="both", unpack_ipv4=False, **kwargs): self.unpack_ipv4 = unpack_ipv4 - self.default_validators = validators.ip_address_validators(protocol, unpack_ipv4)[0] + self.default_validators = validators.ip_address_validators( + protocol, unpack_ipv4 + )[0] super().__init__(**kwargs) def to_python(self, value): if value in self.empty_values: - return '' + return "" value = value.strip() - if value and ':' in value: + if value and ":" in value: return clean_ipv6_address(value, self.unpack_ipv4) return value @@ -1200,7 +1298,7 @@ class SlugField(CharField): class UUIDField(CharField): default_error_messages = { - 'invalid': _('Enter a valid UUID.'), + "invalid": _("Enter a valid UUID."), } def prepare_value(self, value): @@ -1216,7 +1314,7 @@ class UUIDField(CharField): try: value = uuid.UUID(value) except ValueError: - raise ValidationError(self.error_messages['invalid'], code='invalid') + raise ValidationError(self.error_messages["invalid"], code="invalid") return value @@ -1230,7 +1328,7 @@ class JSONString(str): class JSONField(CharField): default_error_messages = { - 'invalid': _('Enter a valid JSON.'), + "invalid": _("Enter a valid JSON."), } widget = Textarea @@ -1250,9 +1348,9 @@ class JSONField(CharField): converted = json.loads(value, cls=self.decoder) except json.JSONDecodeError: raise ValidationError( - 
self.error_messages['invalid'], - code='invalid', - params={'value': value}, + self.error_messages["invalid"], + code="invalid", + params={"value": value}, ) if isinstance(converted, str): return JSONString(converted) @@ -1279,7 +1377,6 @@ class JSONField(CharField): return True # For purposes of seeing whether something has changed, True isn't the # same as 1 and the order of keys doesn't matter. - return ( - json.dumps(initial, sort_keys=True, cls=self.encoder) != - json.dumps(self.to_python(data), sort_keys=True, cls=self.encoder) + return json.dumps(initial, sort_keys=True, cls=self.encoder) != json.dumps( + self.to_python(data), sort_keys=True, cls=self.encoder ) diff --git a/django/forms/forms.py b/django/forms/forms.py index 589b4693fd..952b974130 100644 --- a/django/forms/forms.py +++ b/django/forms/forms.py @@ -19,15 +19,17 @@ from django.utils.translation import gettext as _ from .renderers import get_default_renderer -__all__ = ('BaseForm', 'Form') +__all__ = ("BaseForm", "Form") class DeclarativeFieldsMetaclass(MediaDefiningClass): """Collect Fields declared on the base classes.""" + def __new__(mcs, name, bases, attrs): # Collect fields from current class and remove them from attrs. - attrs['declared_fields'] = { - key: attrs.pop(key) for key, value in list(attrs.items()) + attrs["declared_fields"] = { + key: attrs.pop(key) + for key, value in list(attrs.items()) if isinstance(value, Field) } @@ -37,7 +39,7 @@ class DeclarativeFieldsMetaclass(MediaDefiningClass): declared_fields = {} for base in reversed(new_class.__mro__): # Collect fields from base class. - if hasattr(base, 'declared_fields'): + if hasattr(base, "declared_fields"): declared_fields.update(base.declared_fields) # Field shadowing. @@ -58,20 +60,32 @@ class BaseForm(RenderableFormMixin): improvements to the form API should be made to this class, not to the Form class. 
""" + default_renderer = None field_order = None prefix = None use_required_attribute = True - template_name = 'django/forms/default.html' - template_name_p = 'django/forms/p.html' - template_name_table = 'django/forms/table.html' - template_name_ul = 'django/forms/ul.html' - template_name_label = 'django/forms/label.html' + template_name = "django/forms/default.html" + template_name_p = "django/forms/p.html" + template_name_table = "django/forms/table.html" + template_name_ul = "django/forms/ul.html" + template_name_label = "django/forms/label.html" - def __init__(self, data=None, files=None, auto_id='id_%s', prefix=None, - initial=None, error_class=ErrorList, label_suffix=None, - empty_permitted=False, field_order=None, use_required_attribute=None, renderer=None): + def __init__( + self, + data=None, + files=None, + auto_id="id_%s", + prefix=None, + initial=None, + error_class=ErrorList, + label_suffix=None, + empty_permitted=False, + field_order=None, + use_required_attribute=None, + renderer=None, + ): self.is_bound = data is not None or files is not None self.data = MultiValueDict() if data is None else data self.files = MultiValueDict() if files is None else files @@ -81,7 +95,7 @@ class BaseForm(RenderableFormMixin): self.initial = initial or {} self.error_class = error_class # Translators: This is the default suffix added to form field labels - self.label_suffix = label_suffix if label_suffix is not None else _(':') + self.label_suffix = label_suffix if label_suffix is not None else _(":") self.empty_permitted = empty_permitted self._errors = None # Stores the errors after clean() has been called. @@ -99,8 +113,8 @@ class BaseForm(RenderableFormMixin): if self.empty_permitted and self.use_required_attribute: raise ValueError( - 'The empty_permitted and use_required_attribute arguments may ' - 'not both be True.' + "The empty_permitted and use_required_attribute arguments may " + "not both be True." ) # Initialize form renderer. 
Use a global default if not specified @@ -141,11 +155,11 @@ class BaseForm(RenderableFormMixin): is_valid = "Unknown" else: is_valid = self.is_bound and not self._errors - return '<%(cls)s bound=%(bound)s, valid=%(valid)s, fields=(%(fields)s)>' % { - 'cls': self.__class__.__name__, - 'bound': self.is_bound, - 'valid': is_valid, - 'fields': ';'.join(self.fields), + return "<%(cls)s bound=%(bound)s, valid=%(valid)s, fields=(%(fields)s)>" % { + "cls": self.__class__.__name__, + "bound": self.is_bound, + "valid": is_valid, + "fields": ";".join(self.fields), } def _bound_items(self): @@ -168,10 +182,11 @@ class BaseForm(RenderableFormMixin): field = self.fields[name] except KeyError: raise KeyError( - "Key '%s' not found in '%s'. Choices are: %s." % ( + "Key '%s' not found in '%s'. Choices are: %s." + % ( name, self.__class__.__name__, - ', '.join(sorted(self.fields)), + ", ".join(sorted(self.fields)), ) ) bound_field = field.get_bound_field(self, name) @@ -196,11 +211,11 @@ class BaseForm(RenderableFormMixin): Subclasses may wish to override. """ - return '%s-%s' % (self.prefix, field_name) if self.prefix else field_name + return "%s-%s" % (self.prefix, field_name) if self.prefix else field_name def add_initial_prefix(self, field_name): """Add an 'initial' prefix for checking dynamic initial values.""" - return 'initial-%s' % self.add_prefix(field_name) + return "initial-%s" % self.add_prefix(field_name) def _widget_data_value(self, widget, html_name): # value_from_datadict() gets the data from the data dictionaries. @@ -208,11 +223,13 @@ class BaseForm(RenderableFormMixin): # widgets split data over several HTML fields. return widget.value_from_datadict(self.data, self.files, html_name) - def _html_output(self, normal_row, error_row, row_ender, help_text_html, errors_on_separate_row): + def _html_output( + self, normal_row, error_row, row_ender, help_text_html, errors_on_separate_row + ): "Output HTML. Used by as_table(), as_ul(), as_p()." 
warnings.warn( - 'django.forms.BaseForm._html_output() is deprecated. ' - 'Please use .render() and .get_context() instead.', + "django.forms.BaseForm._html_output() is deprecated. " + "Please use .render() and .get_context() instead.", RemovedInDjango50Warning, stacklevel=2, ) @@ -222,13 +239,17 @@ class BaseForm(RenderableFormMixin): for name, bf in self._bound_items(): field = bf.field - html_class_attr = '' + html_class_attr = "" bf_errors = self.error_class(bf.errors) if bf.is_hidden: if bf_errors: top_errors.extend( - [_('(Hidden field %(name)s) %(error)s') % {'name': name, 'error': str(e)} - for e in bf_errors]) + [ + _("(Hidden field %(name)s) %(error)s") + % {"name": name, "error": str(e)} + for e in bf_errors + ] + ) hidden_fields.append(str(bf)) else: # Create a 'class="..."' attribute if the row should have any @@ -242,30 +263,33 @@ class BaseForm(RenderableFormMixin): if bf.label: label = conditional_escape(bf.label) - label = bf.label_tag(label) or '' + label = bf.label_tag(label) or "" else: - label = '' + label = "" if field.help_text: help_text = help_text_html % field.help_text else: - help_text = '' + help_text = "" - output.append(normal_row % { - 'errors': bf_errors, - 'label': label, - 'field': bf, - 'help_text': help_text, - 'html_class_attr': html_class_attr, - 'css_classes': css_classes, - 'field_name': bf.html_name, - }) + output.append( + normal_row + % { + "errors": bf_errors, + "label": label, + "field": bf, + "help_text": help_text, + "html_class_attr": html_class_attr, + "css_classes": css_classes, + "field_name": bf.html_name, + } + ) if top_errors: output.insert(0, error_row % top_errors) if hidden_fields: # Insert any hidden fields in the last row. - str_hidden = ''.join(hidden_fields) + str_hidden = "".join(hidden_fields) if output: last_row = output[-1] # Chop off the trailing row_ender (e.g. 
'</td></tr>') and @@ -275,22 +299,22 @@ class BaseForm(RenderableFormMixin): # that users write): if there are only top errors, we may # not be able to conscript the last row for our purposes, # so insert a new, empty row. - last_row = (normal_row % { - 'errors': '', - 'label': '', - 'field': '', - 'help_text': '', - 'html_class_attr': html_class_attr, - 'css_classes': '', - 'field_name': '', - }) + last_row = normal_row % { + "errors": "", + "label": "", + "field": "", + "help_text": "", + "html_class_attr": html_class_attr, + "css_classes": "", + "field_name": "", + } output.append(last_row) - output[-1] = last_row[:-len(row_ender)] + str_hidden + row_ender + output[-1] = last_row[: -len(row_ender)] + str_hidden + row_ender else: # If there aren't any rows in the output, just append the # hidden fields. output.append(str_hidden) - return mark_safe('\n'.join(output)) + return mark_safe("\n".join(output)) def get_context(self): fields = [] @@ -301,7 +325,8 @@ class BaseForm(RenderableFormMixin): if bf.is_hidden: if bf_errors: top_errors += [ - _('(Hidden field %(name)s) %(error)s') % {'name': name, 'error': str(e)} + _("(Hidden field %(name)s) %(error)s") + % {"name": name, "error": str(e)} for e in bf_errors ] hidden_fields.append(bf) @@ -310,18 +335,18 @@ class BaseForm(RenderableFormMixin): # RemovedInDjango50Warning. if not isinstance(errors_str, SafeString): warnings.warn( - f'Returning a plain string from ' - f'{self.error_class.__name__} is deprecated. Please ' - f'customize via the template system instead.', + f"Returning a plain string from " + f"{self.error_class.__name__} is deprecated. 
Please " + f"customize via the template system instead.", RemovedInDjango50Warning, ) errors_str = mark_safe(errors_str) fields.append((bf, errors_str)) return { - 'form': self, - 'fields': fields, - 'hidden_fields': hidden_fields, - 'errors': top_errors, + "form": self, + "fields": fields, + "hidden_fields": hidden_fields, + "errors": top_errors, } def non_field_errors(self): @@ -332,7 +357,7 @@ class BaseForm(RenderableFormMixin): """ return self.errors.get( NON_FIELD_ERRORS, - self.error_class(error_class='nonfield', renderer=self.renderer), + self.error_class(error_class="nonfield", renderer=self.renderer), ) def add_error(self, field, error): @@ -358,7 +383,7 @@ class BaseForm(RenderableFormMixin): # do the hard work of making sense of the input. error = ValidationError(error) - if hasattr(error, 'error_dict'): + if hasattr(error, "error_dict"): if field is not None: raise TypeError( "The argument `field` must be `None` when the `error` " @@ -373,9 +398,13 @@ class BaseForm(RenderableFormMixin): if field not in self.errors: if field != NON_FIELD_ERRORS and field not in self.fields: raise ValueError( - "'%s' has no field named '%s'." % (self.__class__.__name__, field)) + "'%s' has no field named '%s'." 
+ % (self.__class__.__name__, field) + ) if field == NON_FIELD_ERRORS: - self._errors[field] = self.error_class(error_class='nonfield', renderer=self.renderer) + self._errors[field] = self.error_class( + error_class="nonfield", renderer=self.renderer + ) else: self._errors[field] = self.error_class(renderer=self.renderer) self._errors[field].extend(error_list) @@ -384,8 +413,8 @@ class BaseForm(RenderableFormMixin): def has_error(self, field, code=None): return field in self.errors and ( - code is None or - any(error.code == code for error in self.errors.as_data()[field]) + code is None + or any(error.code == code for error in self.errors.as_data()[field]) ) def full_clean(self): @@ -415,8 +444,8 @@ class BaseForm(RenderableFormMixin): else: value = field.clean(value) self.cleaned_data[name] = value - if hasattr(self, 'clean_%s' % name): - value = getattr(self, 'clean_%s' % name)() + if hasattr(self, "clean_%s" % name): + value = getattr(self, "clean_%s" % name)() self.cleaned_data[name] = value except ValidationError as e: self.add_error(name, e) @@ -493,8 +522,10 @@ class BaseForm(RenderableFormMixin): value = value() # If this is an auto-generated default date, nix the microseconds # for standardized handling. See #22502. 
- if (isinstance(value, (datetime.datetime, datetime.time)) and - not field.widget.supports_microseconds): + if ( + isinstance(value, (datetime.datetime, datetime.time)) + and not field.widget.supports_microseconds + ): value = value.replace(microsecond=0) return value diff --git a/django/forms/formsets.py b/django/forms/formsets.py index 75b0646512..e5807e8688 100644 --- a/django/forms/formsets.py +++ b/django/forms/formsets.py @@ -5,17 +5,18 @@ from django.forms.renderers import get_default_renderer from django.forms.utils import ErrorList, RenderableFormMixin from django.forms.widgets import CheckboxInput, HiddenInput, NumberInput from django.utils.functional import cached_property -from django.utils.translation import gettext_lazy as _, ngettext +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import ngettext -__all__ = ('BaseFormSet', 'formset_factory', 'all_valid') +__all__ = ("BaseFormSet", "formset_factory", "all_valid") # special field names -TOTAL_FORM_COUNT = 'TOTAL_FORMS' -INITIAL_FORM_COUNT = 'INITIAL_FORMS' -MIN_NUM_FORM_COUNT = 'MIN_NUM_FORMS' -MAX_NUM_FORM_COUNT = 'MAX_NUM_FORMS' -ORDERING_FIELD_NAME = 'ORDER' -DELETION_FIELD_NAME = 'DELETE' +TOTAL_FORM_COUNT = "TOTAL_FORMS" +INITIAL_FORM_COUNT = "INITIAL_FORMS" +MIN_NUM_FORM_COUNT = "MIN_NUM_FORMS" +MAX_NUM_FORM_COUNT = "MAX_NUM_FORMS" +ORDERING_FIELD_NAME = "ORDER" +DELETION_FIELD_NAME = "DELETE" # default minimum number of forms in a formset DEFAULT_MIN_NUM = 0 @@ -30,6 +31,7 @@ class ManagementForm(Form): new forms via JavaScript, you should increment the count field of this form as well. """ + TOTAL_FORMS = IntegerField(widget=HiddenInput) INITIAL_FORMS = IntegerField(widget=HiddenInput) # MIN_NUM_FORM_COUNT and MAX_NUM_FORM_COUNT are output with the rest of the @@ -51,22 +53,31 @@ class BaseFormSet(RenderableFormMixin): """ A collection of instances of the same Form class. 
""" + deletion_widget = CheckboxInput ordering_widget = NumberInput default_error_messages = { - 'missing_management_form': _( - 'ManagementForm data is missing or has been tampered with. Missing fields: ' - '%(field_names)s. You may need to file a bug report if the issue persists.' + "missing_management_form": _( + "ManagementForm data is missing or has been tampered with. Missing fields: " + "%(field_names)s. You may need to file a bug report if the issue persists." ), } - template_name = 'django/forms/formsets/default.html' - template_name_p = 'django/forms/formsets/p.html' - template_name_table = 'django/forms/formsets/table.html' - template_name_ul = 'django/forms/formsets/ul.html' + template_name = "django/forms/formsets/default.html" + template_name_p = "django/forms/formsets/p.html" + template_name_table = "django/forms/formsets/table.html" + template_name_ul = "django/forms/formsets/ul.html" - def __init__(self, data=None, files=None, auto_id='id_%s', prefix=None, - initial=None, error_class=ErrorList, form_kwargs=None, - error_messages=None): + def __init__( + self, + data=None, + files=None, + auto_id="id_%s", + prefix=None, + initial=None, + error_class=ErrorList, + form_kwargs=None, + error_messages=None, + ): self.is_bound = data is not None or files is not None self.prefix = prefix or self.get_default_prefix() self.auto_id = auto_id @@ -80,7 +91,7 @@ class BaseFormSet(RenderableFormMixin): messages = {} for cls in reversed(type(self).__mro__): - messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(getattr(cls, "default_error_messages", {})) if error_messages is not None: messages.update(error_messages) self.error_messages = messages @@ -105,14 +116,14 @@ class BaseFormSet(RenderableFormMixin): def __repr__(self): if self._errors is None: - is_valid = 'Unknown' + is_valid = "Unknown" else: is_valid = ( - self.is_bound and - not self._non_form_errors and - not any(form_errors for form_errors in self._errors) + self.is_bound + 
and not self._non_form_errors + and not any(form_errors for form_errors in self._errors) ) - return '<%s: bound=%s valid=%s total_forms=%s>' % ( + return "<%s: bound=%s valid=%s total_forms=%s>" % ( self.__class__.__qualname__, self.is_bound, is_valid, @@ -123,7 +134,12 @@ class BaseFormSet(RenderableFormMixin): def management_form(self): """Return the ManagementForm instance for this FormSet.""" if self.is_bound: - form = ManagementForm(self.data, auto_id=self.auto_id, prefix=self.prefix, renderer=self.renderer) + form = ManagementForm( + self.data, + auto_id=self.auto_id, + prefix=self.prefix, + renderer=self.renderer, + ) form.full_clean() else: form = ManagementForm( @@ -146,7 +162,9 @@ class BaseFormSet(RenderableFormMixin): # count in the data; this is DoS protection to prevent clients # from forcing the server to instantiate arbitrary numbers of # forms - return min(self.management_form.cleaned_data[TOTAL_FORM_COUNT], self.absolute_max) + return min( + self.management_form.cleaned_data[TOTAL_FORM_COUNT], self.absolute_max + ) else: initial_forms = self.initial_form_count() total_forms = max(initial_forms, self.min_num) + self.extra @@ -188,27 +206,27 @@ class BaseFormSet(RenderableFormMixin): def _construct_form(self, i, **kwargs): """Instantiate and return the i-th form instance in a formset.""" defaults = { - 'auto_id': self.auto_id, - 'prefix': self.add_prefix(i), - 'error_class': self.error_class, + "auto_id": self.auto_id, + "prefix": self.add_prefix(i), + "error_class": self.error_class, # Don't render the HTML 'required' attribute as it may cause # incorrect validation for extra, optional, and deleted # forms in the formset. 
- 'use_required_attribute': False, - 'renderer': self.renderer, + "use_required_attribute": False, + "renderer": self.renderer, } if self.is_bound: - defaults['data'] = self.data - defaults['files'] = self.files - if self.initial and 'initial' not in kwargs: + defaults["data"] = self.data + defaults["files"] = self.files + if self.initial and "initial" not in kwargs: try: - defaults['initial'] = self.initial[i] + defaults["initial"] = self.initial[i] except IndexError: pass # Allow extra forms to be empty, unless they're part of # the minimum forms. if i >= self.initial_form_count() and i >= self.min_num: - defaults['empty_permitted'] = True + defaults["empty_permitted"] = True defaults.update(kwargs) form = self.form(**defaults) self.add_fields(form, i) @@ -217,18 +235,18 @@ class BaseFormSet(RenderableFormMixin): @property def initial_forms(self): """Return a list of all the initial forms in this formset.""" - return self.forms[:self.initial_form_count()] + return self.forms[: self.initial_form_count()] @property def extra_forms(self): """Return a list of all the extra forms in this formset.""" - return self.forms[self.initial_form_count():] + return self.forms[self.initial_form_count() :] @property def empty_form(self): form = self.form( auto_id=self.auto_id, - prefix=self.add_prefix('__prefix__'), + prefix=self.add_prefix("__prefix__"), empty_permitted=True, use_required_attribute=False, **self.get_form_kwargs(None), @@ -243,7 +261,9 @@ class BaseFormSet(RenderableFormMixin): Return a list of form.cleaned_data dicts for every form in self.forms. 
""" if not self.is_valid(): - raise AttributeError("'%s' object has no attribute 'cleaned_data'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'cleaned_data'" % self.__class__.__name__ + ) return [form.cleaned_data for form in self.forms] @property @@ -253,7 +273,7 @@ class BaseFormSet(RenderableFormMixin): return [] # construct _deleted_form_indexes which is just a list of form indexes # that have had their deletion widget set to True - if not hasattr(self, '_deleted_form_indexes'): + if not hasattr(self, "_deleted_form_indexes"): self._deleted_form_indexes = [] for i, form in enumerate(self.forms): # if this is an extra form and hasn't changed, don't consider it @@ -270,12 +290,14 @@ class BaseFormSet(RenderableFormMixin): Raise an AttributeError if ordering is not allowed. """ if not self.is_valid() or not self.can_order: - raise AttributeError("'%s' object has no attribute 'ordered_forms'" % self.__class__.__name__) + raise AttributeError( + "'%s' object has no attribute 'ordered_forms'" % self.__class__.__name__ + ) # Construct _ordering, which is a list of (form_index, order_field_value) # tuples. After constructing this list, we'll sort it by order_field_value # so we have a way to get to the form indexes in the order specified # by the form data. - if not hasattr(self, '_ordering'): + if not hasattr(self, "_ordering"): self._ordering = [] for i, form in enumerate(self.forms): # if this is an extra form and hasn't changed, don't consider it @@ -295,6 +317,7 @@ class BaseFormSet(RenderableFormMixin): if k[1] is None: return (1, 0) # +infinity, larger than any number return (0, k[1]) + self._ordering.sort(key=compare_ordering_key) # Return a list of form.cleaned_data dicts in the order specified by # the form data. 
@@ -302,7 +325,7 @@ class BaseFormSet(RenderableFormMixin): @classmethod def get_default_prefix(cls): - return 'form' + return "form" @classmethod def get_deletion_widget(cls): @@ -331,8 +354,9 @@ class BaseFormSet(RenderableFormMixin): def total_error_count(self): """Return the number of errors across all forms in the formset.""" - return len(self.non_form_errors()) +\ - sum(len(form_errors) for form_errors in self.errors) + return len(self.non_form_errors()) + sum( + len(form_errors) for form_errors in self.errors + ) def _should_delete_form(self, form): """Return whether or not the form was marked for deletion.""" @@ -346,10 +370,13 @@ class BaseFormSet(RenderableFormMixin): self.errors # List comprehension ensures is_valid() is called for all forms. # Forms due to be deleted shouldn't cause the formset to be invalid. - forms_valid = all([ - form.is_valid() for form in self.forms - if not (self.can_delete and self._should_delete_form(form)) - ]) + forms_valid = all( + [ + form.is_valid() + for form in self.forms + if not (self.can_delete and self._should_delete_form(form)) + ] + ) return forms_valid and not self.non_form_errors() def full_clean(self): @@ -358,7 +385,9 @@ class BaseFormSet(RenderableFormMixin): self._non_form_errors. """ self._errors = [] - self._non_form_errors = self.error_class(error_class='nonform', renderer=self.renderer) + self._non_form_errors = self.error_class( + error_class="nonform", renderer=self.renderer + ) empty_forms_count = 0 if not self.is_bound: # Stop further processing. 
@@ -366,14 +395,14 @@ class BaseFormSet(RenderableFormMixin): if not self.management_form.is_valid(): error = ValidationError( - self.error_messages['missing_management_form'], + self.error_messages["missing_management_form"], params={ - 'field_names': ', '.join( + "field_names": ", ".join( self.management_form.add_prefix(field_name) for field_name in self.management_form.errors ), }, - code='missing_management_form', + code="missing_management_form", ) self._non_form_errors.append(error) @@ -388,26 +417,43 @@ class BaseFormSet(RenderableFormMixin): continue self._errors.append(form_errors) try: - if (self.validate_max and - self.total_form_count() - len(self.deleted_forms) > self.max_num) or \ - self.management_form.cleaned_data[TOTAL_FORM_COUNT] > self.absolute_max: - raise ValidationError(ngettext( - "Please submit at most %d form.", - "Please submit at most %d forms.", self.max_num) % self.max_num, - code='too_many_forms', + if ( + self.validate_max + and self.total_form_count() - len(self.deleted_forms) > self.max_num + ) or self.management_form.cleaned_data[ + TOTAL_FORM_COUNT + ] > self.absolute_max: + raise ValidationError( + ngettext( + "Please submit at most %d form.", + "Please submit at most %d forms.", + self.max_num, + ) + % self.max_num, + code="too_many_forms", + ) + if ( + self.validate_min + and self.total_form_count() + - len(self.deleted_forms) + - empty_forms_count + < self.min_num + ): + raise ValidationError( + ngettext( + "Please submit at least %d form.", + "Please submit at least %d forms.", + self.min_num, + ) + % self.min_num, + code="too_few_forms", ) - if (self.validate_min and - self.total_form_count() - len(self.deleted_forms) - empty_forms_count < self.min_num): - raise ValidationError(ngettext( - "Please submit at least %d form.", - "Please submit at least %d forms.", self.min_num) % self.min_num, - code='too_few_forms') # Give self.clean() a chance to do cross-form validation. 
self.clean() except ValidationError as e: self._non_form_errors = self.error_class( e.error_list, - error_class='nonform', + error_class="nonform", renderer=self.renderer, ) @@ -431,26 +477,26 @@ class BaseFormSet(RenderableFormMixin): # Only pre-fill the ordering field for initial forms. if index is not None and index < initial_form_count: form.fields[ORDERING_FIELD_NAME] = IntegerField( - label=_('Order'), + label=_("Order"), initial=index + 1, required=False, widget=self.get_ordering_widget(), ) else: form.fields[ORDERING_FIELD_NAME] = IntegerField( - label=_('Order'), + label=_("Order"), required=False, widget=self.get_ordering_widget(), ) if self.can_delete and (self.can_delete_extra or index < initial_form_count): form.fields[DELETION_FIELD_NAME] = BooleanField( - label=_('Delete'), + label=_("Delete"), required=False, widget=self.get_deletion_widget(), ) def add_prefix(self, index): - return '%s-%s' % (self.prefix, index) + return "%s-%s" % (self.prefix, index) def is_multipart(self): """ @@ -472,13 +518,23 @@ class BaseFormSet(RenderableFormMixin): return self.empty_form.media def get_context(self): - return {'formset': self} + return {"formset": self} -def formset_factory(form, formset=BaseFormSet, extra=1, can_order=False, - can_delete=False, max_num=None, validate_max=False, - min_num=None, validate_min=False, absolute_max=None, - can_delete_extra=True, renderer=None): +def formset_factory( + form, + formset=BaseFormSet, + extra=1, + can_order=False, + can_delete=False, + max_num=None, + validate_max=False, + min_num=None, + validate_min=False, + absolute_max=None, + can_delete_extra=True, + renderer=None, +): """Return a FormSet for the given form class.""" if min_num is None: min_num = DEFAULT_MIN_NUM @@ -490,23 +546,21 @@ def formset_factory(form, formset=BaseFormSet, extra=1, can_order=False, if absolute_max is None: absolute_max = max_num + DEFAULT_MAX_NUM if max_num > absolute_max: - raise ValueError( - "'absolute_max' must be greater or equal to 
'max_num'." - ) + raise ValueError("'absolute_max' must be greater or equal to 'max_num'.") attrs = { - 'form': form, - 'extra': extra, - 'can_order': can_order, - 'can_delete': can_delete, - 'can_delete_extra': can_delete_extra, - 'min_num': min_num, - 'max_num': max_num, - 'absolute_max': absolute_max, - 'validate_min': validate_min, - 'validate_max': validate_max, - 'renderer': renderer or get_default_renderer(), + "form": form, + "extra": extra, + "can_order": can_order, + "can_delete": can_delete, + "can_delete_extra": can_delete_extra, + "min_num": min_num, + "max_num": max_num, + "absolute_max": absolute_max, + "validate_min": validate_min, + "validate_max": validate_max, + "renderer": renderer or get_default_renderer(), } - return type(form.__name__ + 'FormSet', (formset,), attrs) + return type(form.__name__ + "FormSet", (formset,), attrs) def all_valid(formsets): diff --git a/django/forms/models.py b/django/forms/models.py index 19a5cb142a..a55af8eeb6 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -5,26 +5,41 @@ and database field objects. 
from itertools import chain from django.core.exceptions import ( - NON_FIELD_ERRORS, FieldError, ImproperlyConfigured, ValidationError, + NON_FIELD_ERRORS, + FieldError, + ImproperlyConfigured, + ValidationError, ) from django.forms.fields import ChoiceField, Field from django.forms.forms import BaseForm, DeclarativeFieldsMetaclass from django.forms.formsets import BaseFormSet, formset_factory from django.forms.utils import ErrorList from django.forms.widgets import ( - HiddenInput, MultipleHiddenInput, RadioSelect, SelectMultiple, + HiddenInput, + MultipleHiddenInput, + RadioSelect, + SelectMultiple, ) from django.utils.text import capfirst, get_text_list -from django.utils.translation import gettext, gettext_lazy as _ +from django.utils.translation import gettext +from django.utils.translation import gettext_lazy as _ __all__ = ( - 'ModelForm', 'BaseModelForm', 'model_to_dict', 'fields_for_model', - 'ModelChoiceField', 'ModelMultipleChoiceField', 'ALL_FIELDS', - 'BaseModelFormSet', 'modelformset_factory', 'BaseInlineFormSet', - 'inlineformset_factory', 'modelform_factory', + "ModelForm", + "BaseModelForm", + "model_to_dict", + "fields_for_model", + "ModelChoiceField", + "ModelMultipleChoiceField", + "ALL_FIELDS", + "BaseModelFormSet", + "modelformset_factory", + "BaseInlineFormSet", + "inlineformset_factory", + "modelform_factory", ) -ALL_FIELDS = '__all__' +ALL_FIELDS = "__all__" def construct_instance(form, instance, fields=None, exclude=None): @@ -33,13 +48,17 @@ def construct_instance(form, instance, fields=None, exclude=None): ``cleaned_data``, but do not save the returned instance to the database. 
""" from django.db import models + opts = instance._meta cleaned_data = form.cleaned_data file_field_list = [] for f in opts.fields: - if not f.editable or isinstance(f, models.AutoField) \ - or f.name not in cleaned_data: + if ( + not f.editable + or isinstance(f, models.AutoField) + or f.name not in cleaned_data + ): continue if fields is not None and f.name not in fields: continue @@ -48,9 +67,11 @@ def construct_instance(form, instance, fields=None, exclude=None): # Leave defaults for fields that aren't in POST data, except for # checkbox inputs because they don't appear in POST data if not checked. if ( - f.has_default() and - form[f.name].field.widget.value_omitted_from_data(form.data, form.files, form.add_prefix(f.name)) and - cleaned_data.get(f.name) in form[f.name].field.empty_values + f.has_default() + and form[f.name].field.widget.value_omitted_from_data( + form.data, form.files, form.add_prefix(f.name) + ) + and cleaned_data.get(f.name) in form[f.name].field.empty_values ): continue # Defer saving file-type fields until after the other fields, so a @@ -68,6 +89,7 @@ def construct_instance(form, instance, fields=None, exclude=None): # ModelForms ################################################################# + def model_to_dict(instance, fields=None, exclude=None): """ Return a dict containing the data in ``instance`` suitable for passing as @@ -83,7 +105,7 @@ def model_to_dict(instance, fields=None, exclude=None): opts = instance._meta data = {} for f in chain(opts.concrete_fields, opts.private_fields, opts.many_to_many): - if not getattr(f, 'editable', False): + if not getattr(f, "editable", False): continue if fields is not None and f.name not in fields: continue @@ -96,23 +118,34 @@ def model_to_dict(instance, fields=None, exclude=None): def apply_limit_choices_to_to_formfield(formfield): """Apply limit_choices_to to the formfield's queryset if needed.""" from django.db.models import Exists, OuterRef, Q - if hasattr(formfield, 'queryset') and 
hasattr(formfield, 'get_limit_choices_to'): + + if hasattr(formfield, "queryset") and hasattr(formfield, "get_limit_choices_to"): limit_choices_to = formfield.get_limit_choices_to() if limit_choices_to: complex_filter = limit_choices_to if not isinstance(complex_filter, Q): complex_filter = Q(**limit_choices_to) - complex_filter &= Q(pk=OuterRef('pk')) + complex_filter &= Q(pk=OuterRef("pk")) # Use Exists() to avoid potential duplicates. formfield.queryset = formfield.queryset.filter( Exists(formfield.queryset.model._base_manager.filter(complex_filter)), ) -def fields_for_model(model, fields=None, exclude=None, widgets=None, - formfield_callback=None, localized_fields=None, - labels=None, help_texts=None, error_messages=None, - field_classes=None, *, apply_limit_choices_to=True): +def fields_for_model( + model, + fields=None, + exclude=None, + widgets=None, + formfield_callback=None, + localized_fields=None, + labels=None, + help_texts=None, + error_messages=None, + field_classes=None, + *, + apply_limit_choices_to=True, +): """ Return a dictionary containing form fields for the given model. 
@@ -148,14 +181,22 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, opts = model._meta # Avoid circular import from django.db.models import Field as ModelField - sortable_private_fields = [f for f in opts.private_fields if isinstance(f, ModelField)] - for f in sorted(chain(opts.concrete_fields, sortable_private_fields, opts.many_to_many)): - if not getattr(f, 'editable', False): - if (fields is not None and f.name in fields and - (exclude is None or f.name not in exclude)): + + sortable_private_fields = [ + f for f in opts.private_fields if isinstance(f, ModelField) + ] + for f in sorted( + chain(opts.concrete_fields, sortable_private_fields, opts.many_to_many) + ): + if not getattr(f, "editable", False): + if ( + fields is not None + and f.name in fields + and (exclude is None or f.name not in exclude) + ): raise FieldError( - "'%s' cannot be specified for %s model form as it is a non-editable field" % ( - f.name, model.__name__) + "'%s' cannot be specified for %s model form as it is a non-editable field" + % (f.name, model.__name__) ) continue if fields is not None and f.name not in fields: @@ -165,22 +206,24 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, kwargs = {} if widgets and f.name in widgets: - kwargs['widget'] = widgets[f.name] - if localized_fields == ALL_FIELDS or (localized_fields and f.name in localized_fields): - kwargs['localize'] = True + kwargs["widget"] = widgets[f.name] + if localized_fields == ALL_FIELDS or ( + localized_fields and f.name in localized_fields + ): + kwargs["localize"] = True if labels and f.name in labels: - kwargs['label'] = labels[f.name] + kwargs["label"] = labels[f.name] if help_texts and f.name in help_texts: - kwargs['help_text'] = help_texts[f.name] + kwargs["help_text"] = help_texts[f.name] if error_messages and f.name in error_messages: - kwargs['error_messages'] = error_messages[f.name] + kwargs["error_messages"] = error_messages[f.name] if field_classes and f.name in 
field_classes: - kwargs['form_class'] = field_classes[f.name] + kwargs["form_class"] = field_classes[f.name] if formfield_callback is None: formfield = f.formfield(**kwargs) elif not callable(formfield_callback): - raise TypeError('formfield_callback must be a function or callable') + raise TypeError("formfield_callback must be a function or callable") else: formfield = formfield_callback(f, **kwargs) @@ -192,7 +235,8 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, ignored.append(f.name) if fields: field_dict = { - f: field_dict.get(f) for f in fields + f: field_dict.get(f) + for f in fields if (not exclude or f not in exclude) and f not in ignored } return field_dict @@ -200,46 +244,49 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, class ModelFormOptions: def __init__(self, options=None): - self.model = getattr(options, 'model', None) - self.fields = getattr(options, 'fields', None) - self.exclude = getattr(options, 'exclude', None) - self.widgets = getattr(options, 'widgets', None) - self.localized_fields = getattr(options, 'localized_fields', None) - self.labels = getattr(options, 'labels', None) - self.help_texts = getattr(options, 'help_texts', None) - self.error_messages = getattr(options, 'error_messages', None) - self.field_classes = getattr(options, 'field_classes', None) + self.model = getattr(options, "model", None) + self.fields = getattr(options, "fields", None) + self.exclude = getattr(options, "exclude", None) + self.widgets = getattr(options, "widgets", None) + self.localized_fields = getattr(options, "localized_fields", None) + self.labels = getattr(options, "labels", None) + self.help_texts = getattr(options, "help_texts", None) + self.error_messages = getattr(options, "error_messages", None) + self.field_classes = getattr(options, "field_classes", None) class ModelFormMetaclass(DeclarativeFieldsMetaclass): def __new__(mcs, name, bases, attrs): base_formfield_callback = None for b in bases: - if 
hasattr(b, 'Meta') and hasattr(b.Meta, 'formfield_callback'): + if hasattr(b, "Meta") and hasattr(b.Meta, "formfield_callback"): base_formfield_callback = b.Meta.formfield_callback break - formfield_callback = attrs.pop('formfield_callback', base_formfield_callback) + formfield_callback = attrs.pop("formfield_callback", base_formfield_callback) new_class = super().__new__(mcs, name, bases, attrs) if bases == (BaseModelForm,): return new_class - opts = new_class._meta = ModelFormOptions(getattr(new_class, 'Meta', None)) + opts = new_class._meta = ModelFormOptions(getattr(new_class, "Meta", None)) # We check if a string was passed to `fields` or `exclude`, # which is likely to be a mistake where the user typed ('foo') instead # of ('foo',) - for opt in ['fields', 'exclude', 'localized_fields']: + for opt in ["fields", "exclude", "localized_fields"]: value = getattr(opts, opt) if isinstance(value, str) and value != ALL_FIELDS: - msg = ("%(model)s.Meta.%(opt)s cannot be a string. " - "Did you mean to type: ('%(value)s',)?" % { - 'model': new_class.__name__, - 'opt': opt, - 'value': value, - }) + msg = ( + "%(model)s.Meta.%(opt)s cannot be a string. " + "Did you mean to type: ('%(value)s',)?" + % { + "model": new_class.__name__, + "opt": opt, + "value": value, + } + ) raise TypeError(msg) if opts.model: @@ -257,9 +304,16 @@ class ModelFormMetaclass(DeclarativeFieldsMetaclass): opts.fields = None fields = fields_for_model( - opts.model, opts.fields, opts.exclude, opts.widgets, - formfield_callback, opts.localized_fields, opts.labels, - opts.help_texts, opts.error_messages, opts.field_classes, + opts.model, + opts.fields, + opts.exclude, + opts.widgets, + formfield_callback, + opts.localized_fields, + opts.labels, + opts.help_texts, + opts.error_messages, + opts.field_classes, # limit_choices_to will be applied during ModelForm.__init__(). 
apply_limit_choices_to=False, ) @@ -268,9 +322,8 @@ class ModelFormMetaclass(DeclarativeFieldsMetaclass): none_model_fields = {k for k, v in fields.items() if not v} missing_fields = none_model_fields.difference(new_class.declared_fields) if missing_fields: - message = 'Unknown field(s) (%s) specified for %s' - message = message % (', '.join(missing_fields), - opts.model.__name__) + message = "Unknown field(s) (%s) specified for %s" + message = message % (", ".join(missing_fields), opts.model.__name__) raise FieldError(message) # Override default model fields with any custom declared ones # (plus, include all the other declared fields). @@ -284,13 +337,23 @@ class ModelFormMetaclass(DeclarativeFieldsMetaclass): class BaseModelForm(BaseForm): - def __init__(self, data=None, files=None, auto_id='id_%s', prefix=None, - initial=None, error_class=ErrorList, label_suffix=None, - empty_permitted=False, instance=None, use_required_attribute=None, - renderer=None): + def __init__( + self, + data=None, + files=None, + auto_id="id_%s", + prefix=None, + initial=None, + error_class=ErrorList, + label_suffix=None, + empty_permitted=False, + instance=None, + use_required_attribute=None, + renderer=None, + ): opts = self._meta if opts.model is None: - raise ValueError('ModelForm has no model class specified.') + raise ValueError("ModelForm has no model class specified.") if instance is None: # if we didn't get an instance, instantiate a new one self.instance = opts.model() @@ -306,8 +369,15 @@ class BaseModelForm(BaseForm): # super will stop validate_unique from being called. 
self._validate_unique = False super().__init__( - data, files, auto_id, prefix, object_data, error_class, - label_suffix, empty_permitted, use_required_attribute=use_required_attribute, + data, + files, + auto_id, + prefix, + object_data, + error_class, + label_suffix, + empty_permitted, + use_required_attribute=use_required_attribute, renderer=renderer, ) for formfield in self.fields.values(): @@ -350,7 +420,11 @@ class BaseModelForm(BaseForm): else: form_field = self.fields[field] field_value = self.cleaned_data.get(field) - if not f.blank and not form_field.required and field_value in form_field.empty_values: + if ( + not f.blank + and not form_field.required + and field_value in form_field.empty_values + ): exclude.append(f.name) return exclude @@ -365,14 +439,17 @@ class BaseModelForm(BaseForm): # Allow the model generated by construct_instance() to raise # ValidationError and have them handled in the same way as others. - if hasattr(errors, 'error_dict'): + if hasattr(errors, "error_dict"): error_dict = errors.error_dict else: error_dict = {NON_FIELD_ERRORS: errors} for field, messages in error_dict.items(): - if (field == NON_FIELD_ERRORS and opts.error_messages and - NON_FIELD_ERRORS in opts.error_messages): + if ( + field == NON_FIELD_ERRORS + and opts.error_messages + and NON_FIELD_ERRORS in opts.error_messages + ): error_messages = opts.error_messages[NON_FIELD_ERRORS] elif field in self.fields: error_messages = self.fields[field].error_messages @@ -380,8 +457,10 @@ class BaseModelForm(BaseForm): continue for message in messages: - if (isinstance(message, ValidationError) and - message.code in error_messages): + if ( + isinstance(message, ValidationError) + and message.code in error_messages + ): message.message = error_messages[message.code] self.add_error(None, errors) @@ -403,7 +482,9 @@ class BaseModelForm(BaseForm): exclude.append(name) try: - self.instance = construct_instance(self, self.instance, opts.fields, opts.exclude) + self.instance = 
construct_instance( + self, self.instance, opts.fields, opts.exclude + ) except ValidationError as e: self._update_errors(e) @@ -439,7 +520,7 @@ class BaseModelForm(BaseForm): # private_fields here. (GenericRelation was previously a fake # m2m field). for f in chain(opts.many_to_many, opts.private_fields): - if not hasattr(f, 'save_form_data'): + if not hasattr(f, "save_form_data"): continue if fields and f.name not in fields: continue @@ -456,9 +537,10 @@ class BaseModelForm(BaseForm): """ if self.errors: raise ValueError( - "The %s could not be %s because the data didn't validate." % ( + "The %s could not be %s because the data didn't validate." + % ( self.instance._meta.object_name, - 'created' if self.instance._state.adding else 'changed', + "created" if self.instance._state.adding else "changed", ) ) if commit: @@ -478,10 +560,19 @@ class ModelForm(BaseModelForm, metaclass=ModelFormMetaclass): pass -def modelform_factory(model, form=ModelForm, fields=None, exclude=None, - formfield_callback=None, widgets=None, localized_fields=None, - labels=None, help_texts=None, error_messages=None, - field_classes=None): +def modelform_factory( + model, + form=ModelForm, + fields=None, + exclude=None, + formfield_callback=None, + widgets=None, + localized_fields=None, + labels=None, + help_texts=None, + error_messages=None, + field_classes=None, +): """ Return a ModelForm containing form fields for the given model. You can optionally pass a `form` argument to use as a starting point for @@ -517,41 +608,37 @@ def modelform_factory(model, form=ModelForm, fields=None, exclude=None, # inner class. # Build up a list of attributes that the Meta object will have. 
- attrs = {'model': model} + attrs = {"model": model} if fields is not None: - attrs['fields'] = fields + attrs["fields"] = fields if exclude is not None: - attrs['exclude'] = exclude + attrs["exclude"] = exclude if widgets is not None: - attrs['widgets'] = widgets + attrs["widgets"] = widgets if localized_fields is not None: - attrs['localized_fields'] = localized_fields + attrs["localized_fields"] = localized_fields if labels is not None: - attrs['labels'] = labels + attrs["labels"] = labels if help_texts is not None: - attrs['help_texts'] = help_texts + attrs["help_texts"] = help_texts if error_messages is not None: - attrs['error_messages'] = error_messages + attrs["error_messages"] = error_messages if field_classes is not None: - attrs['field_classes'] = field_classes + attrs["field_classes"] = field_classes # If parent form class already has an inner Meta, the Meta we're # creating needs to inherit from the parent's inner meta. - bases = (form.Meta,) if hasattr(form, 'Meta') else () - Meta = type('Meta', bases, attrs) + bases = (form.Meta,) if hasattr(form, "Meta") else () + Meta = type("Meta", bases, attrs) if formfield_callback: Meta.formfield_callback = staticmethod(formfield_callback) # Give this new form class a reasonable name. - class_name = model.__name__ + 'Form' + class_name = model.__name__ + "Form" # Class attributes for the new form class. - form_class_attrs = { - 'Meta': Meta, - 'formfield_callback': formfield_callback - } + form_class_attrs = {"Meta": Meta, "formfield_callback": formfield_callback} - if (getattr(Meta, 'fields', None) is None and - getattr(Meta, 'exclude', None) is None): + if getattr(Meta, "fields", None) is None and getattr(Meta, "exclude", None) is None: raise ImproperlyConfigured( "Calling modelform_factory without defining 'fields' or " "'exclude' explicitly is prohibited." 
@@ -563,20 +650,39 @@ def modelform_factory(model, form=ModelForm, fields=None, exclude=None, # ModelFormSets ############################################################## + class BaseModelFormSet(BaseFormSet): """ A ``FormSet`` for editing a queryset and/or adding new objects to it. """ + model = None # Set of fields that must be unique among forms of this set. unique_fields = set() - def __init__(self, data=None, files=None, auto_id='id_%s', prefix=None, - queryset=None, *, initial=None, **kwargs): + def __init__( + self, + data=None, + files=None, + auto_id="id_%s", + prefix=None, + queryset=None, + *, + initial=None, + **kwargs, + ): self.queryset = queryset self.initial_extra = initial - super().__init__(**{'data': data, 'files': files, 'auto_id': auto_id, 'prefix': prefix, **kwargs}) + super().__init__( + **{ + "data": data, + "files": files, + "auto_id": auto_id, + "prefix": prefix, + **kwargs, + } + ) def initial_form_count(self): """Return the number of forms that are required in this FormSet.""" @@ -585,7 +691,7 @@ class BaseModelFormSet(BaseFormSet): return super().initial_form_count() def _existing_object(self, pk): - if not hasattr(self, '_object_dict'): + if not hasattr(self, "_object_dict"): self._object_dict = {o.pk: o for o in self.get_queryset()} return self._object_dict.get(pk) @@ -602,7 +708,7 @@ class BaseModelFormSet(BaseFormSet): pk_required = i < self.initial_form_count() if pk_required: if self.is_bound: - pk_key = '%s-%s' % (self.add_prefix(i), self.model._meta.pk.name) + pk_key = "%s-%s" % (self.add_prefix(i), self.model._meta.pk.name) try: pk = self.data[pk_key] except KeyError: @@ -618,13 +724,13 @@ class BaseModelFormSet(BaseFormSet): # user may have tampered with POST data. 
pass else: - kwargs['instance'] = self._existing_object(pk) + kwargs["instance"] = self._existing_object(pk) else: - kwargs['instance'] = self.get_queryset()[i] + kwargs["instance"] = self.get_queryset()[i] elif self.initial_extra: # Set initial values for extra forms try: - kwargs['initial'] = self.initial_extra[i - self.initial_form_count()] + kwargs["initial"] = self.initial_extra[i - self.initial_form_count()] except IndexError: pass form = super()._construct_form(i, **kwargs) @@ -633,7 +739,7 @@ class BaseModelFormSet(BaseFormSet): return form def get_queryset(self): - if not hasattr(self, '_queryset'): + if not hasattr(self, "_queryset"): if self.queryset is not None: qs = self.queryset else: @@ -675,6 +781,7 @@ class BaseModelFormSet(BaseFormSet): def save_m2m(): for form in self.saved_forms: form.save_m2m() + self.save_m2m = save_m2m if self.edit_only: return self.save_existing_objects(commit) @@ -691,10 +798,16 @@ class BaseModelFormSet(BaseFormSet): all_unique_checks = set() all_date_checks = set() forms_to_delete = self.deleted_forms - valid_forms = [form for form in self.forms if form.is_valid() and form not in forms_to_delete] + valid_forms = [ + form + for form in self.forms + if form.is_valid() and form not in forms_to_delete + ] for form in valid_forms: exclude = form._get_validation_exclusions() - unique_checks, date_checks = form.instance._get_unique_checks(exclude=exclude) + unique_checks, date_checks = form.instance._get_unique_checks( + exclude=exclude + ) all_unique_checks.update(unique_checks) all_date_checks.update(date_checks) @@ -706,14 +819,15 @@ class BaseModelFormSet(BaseFormSet): # Get the data for the set of fields that must be unique among the forms. 
row_data = ( field if field in self.unique_fields else form.cleaned_data[field] - for field in unique_check if field in form.cleaned_data + for field in unique_check + if field in form.cleaned_data ) # Reduce Model instances to their primary key values row_data = tuple( - d._get_pk_val() if hasattr(d, '_get_pk_val') + d._get_pk_val() if hasattr(d, "_get_pk_val") # Prevent "unhashable type: list" errors later on. - else tuple(d) if isinstance(d, list) - else d for d in row_data + else tuple(d) if isinstance(d, list) else d + for d in row_data ) if row_data and None not in row_data: # if we've already seen it then we have a uniqueness failure @@ -737,10 +851,13 @@ class BaseModelFormSet(BaseFormSet): uclass, lookup, field, unique_for = date_check for form in valid_forms: # see if we have data for both fields - if (form.cleaned_data and form.cleaned_data[field] is not None and - form.cleaned_data[unique_for] is not None): + if ( + form.cleaned_data + and form.cleaned_data[field] is not None + and form.cleaned_data[unique_for] is not None + ): # if it's a date lookup we need to get the data for all the fields - if lookup == 'date': + if lookup == "date": date = form.cleaned_data[unique_for] date_data = (date.year, date.month, date.day) # otherwise it's just the attribute on the date/datetime @@ -771,7 +888,9 @@ class BaseModelFormSet(BaseFormSet): "field": unique_check[0], } else: - return gettext("Please correct the duplicate data for %(field)s, which must be unique.") % { + return gettext( + "Please correct the duplicate data for %(field)s, which must be unique." + ) % { "field": get_text_list(unique_check, _("and")), } @@ -780,9 +899,9 @@ class BaseModelFormSet(BaseFormSet): "Please correct the duplicate data for %(field_name)s " "which must be unique for the %(lookup)s in %(date_field)s." 
) % { - 'field_name': date_check[2], - 'date_field': date_check[3], - 'lookup': str(date_check[1]), + "field_name": date_check[2], + "date_field": date_check[3], + "lookup": str(date_check[1]), } def get_form_error(self): @@ -831,6 +950,7 @@ class BaseModelFormSet(BaseFormSet): def add_fields(self, form, index): """Add a hidden field for the object's primary key.""" from django.db.models import AutoField, ForeignKey, OneToOneField + self._pk_field = pk = self.model._meta.pk # If a pk isn't editable, then it won't be on the form, so we need to # add it here so we can tell which object is which when we get the @@ -840,11 +960,15 @@ class BaseModelFormSet(BaseFormSet): def pk_is_not_editable(pk): return ( - (not pk.editable) or (pk.auto_created or isinstance(pk, AutoField)) or ( - pk.remote_field and pk.remote_field.parent_link and - pk_is_not_editable(pk.remote_field.model._meta.pk) + (not pk.editable) + or (pk.auto_created or isinstance(pk, AutoField)) + or ( + pk.remote_field + and pk.remote_field.parent_link + and pk_is_not_editable(pk.remote_field.model._meta.pk) ) ) + if pk_is_not_editable(pk) or pk.name not in form.fields: if form.is_bound: # If we're adding the related instance, ignore its primary key @@ -868,37 +992,75 @@ class BaseModelFormSet(BaseFormSet): widget = form._meta.widgets.get(self._pk_field.name, HiddenInput) else: widget = HiddenInput - form.fields[self._pk_field.name] = ModelChoiceField(qs, initial=pk_value, required=False, widget=widget) + form.fields[self._pk_field.name] = ModelChoiceField( + qs, initial=pk_value, required=False, widget=widget + ) super().add_fields(form, index) -def modelformset_factory(model, form=ModelForm, formfield_callback=None, - formset=BaseModelFormSet, extra=1, can_delete=False, - can_order=False, max_num=None, fields=None, exclude=None, - widgets=None, validate_max=False, localized_fields=None, - labels=None, help_texts=None, error_messages=None, - min_num=None, validate_min=False, field_classes=None, - 
absolute_max=None, can_delete_extra=True, renderer=None, - edit_only=False): +def modelformset_factory( + model, + form=ModelForm, + formfield_callback=None, + formset=BaseModelFormSet, + extra=1, + can_delete=False, + can_order=False, + max_num=None, + fields=None, + exclude=None, + widgets=None, + validate_max=False, + localized_fields=None, + labels=None, + help_texts=None, + error_messages=None, + min_num=None, + validate_min=False, + field_classes=None, + absolute_max=None, + can_delete_extra=True, + renderer=None, + edit_only=False, +): """Return a FormSet class for the given Django model class.""" - meta = getattr(form, 'Meta', None) - if (getattr(meta, 'fields', fields) is None and - getattr(meta, 'exclude', exclude) is None): + meta = getattr(form, "Meta", None) + if ( + getattr(meta, "fields", fields) is None + and getattr(meta, "exclude", exclude) is None + ): raise ImproperlyConfigured( "Calling modelformset_factory without defining 'fields' or " "'exclude' explicitly is prohibited." 
) - form = modelform_factory(model, form=form, fields=fields, exclude=exclude, - formfield_callback=formfield_callback, - widgets=widgets, localized_fields=localized_fields, - labels=labels, help_texts=help_texts, - error_messages=error_messages, field_classes=field_classes) - FormSet = formset_factory(form, formset, extra=extra, min_num=min_num, max_num=max_num, - can_order=can_order, can_delete=can_delete, - validate_min=validate_min, validate_max=validate_max, - absolute_max=absolute_max, can_delete_extra=can_delete_extra, - renderer=renderer) + form = modelform_factory( + model, + form=form, + fields=fields, + exclude=exclude, + formfield_callback=formfield_callback, + widgets=widgets, + localized_fields=localized_fields, + labels=labels, + help_texts=help_texts, + error_messages=error_messages, + field_classes=field_classes, + ) + FormSet = formset_factory( + form, + formset, + extra=extra, + min_num=min_num, + max_num=max_num, + can_order=can_order, + can_delete=can_delete, + validate_min=validate_min, + validate_max=validate_max, + absolute_max=absolute_max, + can_delete_extra=can_delete_extra, + renderer=renderer, + ) FormSet.model = model FormSet.edit_only = edit_only return FormSet @@ -906,10 +1068,20 @@ def modelformset_factory(model, form=ModelForm, formfield_callback=None, # InlineFormSets ############################################################# + class BaseInlineFormSet(BaseModelFormSet): """A formset for child objects related to a parent.""" - def __init__(self, data=None, files=None, instance=None, - save_as_new=False, prefix=None, queryset=None, **kwargs): + + def __init__( + self, + data=None, + files=None, + instance=None, + save_as_new=False, + prefix=None, + queryset=None, + **kwargs, + ): if instance is None: self.instance = self.fk.remote_field.model() else: @@ -939,7 +1111,7 @@ class BaseInlineFormSet(BaseModelFormSet): def _construct_form(self, i, **kwargs): form = super()._construct_form(i, **kwargs) if self.save_as_new: - mutable = 
getattr(form.data, '_mutable', None) + mutable = getattr(form.data, "_mutable", None) # Allow modifying an immutable QueryDict. if mutable is not None: form.data._mutable = True @@ -955,13 +1127,13 @@ class BaseInlineFormSet(BaseModelFormSet): fk_value = self.instance.pk if self.fk.remote_field.field_name != self.fk.remote_field.model._meta.pk.name: fk_value = getattr(self.instance, self.fk.remote_field.field_name) - fk_value = getattr(fk_value, 'pk', fk_value) + fk_value = getattr(fk_value, "pk", fk_value) setattr(form.instance, self.fk.get_attname(), fk_value) return form @classmethod def get_default_prefix(cls): - return cls.fk.remote_field.get_accessor_name(model=cls.model).replace('+', '') + return cls.fk.remote_field.get_accessor_name(model=cls.model).replace("+", "") def save_new(self, form, commit=True): # Ensure the latest copy of the related instance is present on each @@ -974,26 +1146,28 @@ class BaseInlineFormSet(BaseModelFormSet): super().add_fields(form, index) if self._pk_field == self.fk: name = self._pk_field.name - kwargs = {'pk_field': True} + kwargs = {"pk_field": True} else: # The foreign key field might not be on the form, so we poke at the # Model field to get the label, since we need that for error messages. name = self.fk.name kwargs = { - 'label': getattr(form.fields.get(name), 'label', capfirst(self.fk.verbose_name)) + "label": getattr( + form.fields.get(name), "label", capfirst(self.fk.verbose_name) + ) } # The InlineForeignKeyField assumes that the foreign key relation is # based on the parent model's pk. If this isn't the case, set to_field # to correctly resolve the initial form value. if self.fk.remote_field.field_name != self.fk.remote_field.model._meta.pk.name: - kwargs['to_field'] = self.fk.remote_field.field_name + kwargs["to_field"] = self.fk.remote_field.field_name # If we're adding a new object, ignore a parent's auto-generated key # as it will be regenerated on the save request. 
if self.instance._state.adding: - if kwargs.get('to_field') is not None: - to_field = self.instance._meta.get_field(kwargs['to_field']) + if kwargs.get("to_field") is not None: + to_field = self.instance._meta.get_field(kwargs["to_field"]) else: to_field = self.instance._meta.pk if to_field.has_default(): @@ -1016,24 +1190,30 @@ def _get_foreign_key(parent_model, model, fk_name=None, can_fail=False): """ # avoid circular import from django.db.models import ForeignKey + opts = model._meta if fk_name: fks_to_parent = [f for f in opts.fields if f.name == fk_name] if len(fks_to_parent) == 1: fk = fks_to_parent[0] parent_list = parent_model._meta.get_parent_list() - if not isinstance(fk, ForeignKey) or ( - # ForeignKey to proxy models. - fk.remote_field.model._meta.proxy and - fk.remote_field.model._meta.proxy_for_model not in parent_list - ) or ( - # ForeignKey to concrete models. - not fk.remote_field.model._meta.proxy and - fk.remote_field.model != parent_model and - fk.remote_field.model not in parent_list + if ( + not isinstance(fk, ForeignKey) + or ( + # ForeignKey to proxy models. + fk.remote_field.model._meta.proxy + and fk.remote_field.model._meta.proxy_for_model not in parent_list + ) + or ( + # ForeignKey to concrete models. + not fk.remote_field.model._meta.proxy + and fk.remote_field.model != parent_model + and fk.remote_field.model not in parent_list + ) ): raise ValueError( - "fk_name '%s' is not a ForeignKey to '%s'." % (fk_name, parent_model._meta.label) + "fk_name '%s' is not a ForeignKey to '%s'." 
+ % (fk_name, parent_model._meta.label) ) elif not fks_to_parent: raise ValueError( @@ -1043,12 +1223,15 @@ def _get_foreign_key(parent_model, model, fk_name=None, can_fail=False): # Try to discover what the ForeignKey from model to parent_model is parent_list = parent_model._meta.get_parent_list() fks_to_parent = [ - f for f in opts.fields - if isinstance(f, ForeignKey) and ( - f.remote_field.model == parent_model or - f.remote_field.model in parent_list or ( - f.remote_field.model._meta.proxy and - f.remote_field.model._meta.proxy_for_model in parent_list + f + for f in opts.fields + if isinstance(f, ForeignKey) + and ( + f.remote_field.model == parent_model + or f.remote_field.model in parent_list + or ( + f.remote_field.model._meta.proxy + and f.remote_field.model._meta.proxy_for_model in parent_list ) ) ] @@ -1058,7 +1241,8 @@ def _get_foreign_key(parent_model, model, fk_name=None, can_fail=False): if can_fail: return raise ValueError( - "'%s' has no ForeignKey to '%s'." % ( + "'%s' has no ForeignKey to '%s'." + % ( model._meta.label, parent_model._meta.label, ) @@ -1066,7 +1250,8 @@ def _get_foreign_key(parent_model, model, fk_name=None, can_fail=False): else: raise ValueError( "'%s' has more than one ForeignKey to '%s'. You must specify " - "a 'fk_name' attribute." % ( + "a 'fk_name' attribute." 
+ % ( model._meta.label, parent_model._meta.label, ) @@ -1074,15 +1259,33 @@ def _get_foreign_key(parent_model, model, fk_name=None, can_fail=False): return fk -def inlineformset_factory(parent_model, model, form=ModelForm, - formset=BaseInlineFormSet, fk_name=None, - fields=None, exclude=None, extra=3, can_order=False, - can_delete=True, max_num=None, formfield_callback=None, - widgets=None, validate_max=False, localized_fields=None, - labels=None, help_texts=None, error_messages=None, - min_num=None, validate_min=False, field_classes=None, - absolute_max=None, can_delete_extra=True, renderer=None, - edit_only=False): +def inlineformset_factory( + parent_model, + model, + form=ModelForm, + formset=BaseInlineFormSet, + fk_name=None, + fields=None, + exclude=None, + extra=3, + can_order=False, + can_delete=True, + max_num=None, + formfield_callback=None, + widgets=None, + validate_max=False, + localized_fields=None, + labels=None, + help_texts=None, + error_messages=None, + min_num=None, + validate_min=False, + field_classes=None, + absolute_max=None, + can_delete_extra=True, + renderer=None, + edit_only=False, +): """ Return an ``InlineFormSet`` for the given kwargs. 
@@ -1094,28 +1297,28 @@ def inlineformset_factory(parent_model, model, form=ModelForm, if fk.unique: max_num = 1 kwargs = { - 'form': form, - 'formfield_callback': formfield_callback, - 'formset': formset, - 'extra': extra, - 'can_delete': can_delete, - 'can_order': can_order, - 'fields': fields, - 'exclude': exclude, - 'min_num': min_num, - 'max_num': max_num, - 'widgets': widgets, - 'validate_min': validate_min, - 'validate_max': validate_max, - 'localized_fields': localized_fields, - 'labels': labels, - 'help_texts': help_texts, - 'error_messages': error_messages, - 'field_classes': field_classes, - 'absolute_max': absolute_max, - 'can_delete_extra': can_delete_extra, - 'renderer': renderer, - 'edit_only': edit_only, + "form": form, + "formfield_callback": formfield_callback, + "formset": formset, + "extra": extra, + "can_delete": can_delete, + "can_order": can_order, + "fields": fields, + "exclude": exclude, + "min_num": min_num, + "max_num": max_num, + "widgets": widgets, + "validate_min": validate_min, + "validate_max": validate_max, + "localized_fields": localized_fields, + "labels": labels, + "help_texts": help_texts, + "error_messages": error_messages, + "field_classes": field_classes, + "absolute_max": absolute_max, + "can_delete_extra": can_delete_extra, + "renderer": renderer, + "edit_only": edit_only, } FormSet = modelformset_factory(model, **kwargs) FormSet.fk = fk @@ -1124,14 +1327,16 @@ def inlineformset_factory(parent_model, model, form=ModelForm, # Fields ##################################################################### + class InlineForeignKeyField(Field): """ A basic integer field that deals with validating the given value to a given parent instance in an inline. 
""" + widget = HiddenInput default_error_messages = { - 'invalid_choice': _('The inline value did not match the parent instance.'), + "invalid_choice": _("The inline value did not match the parent instance."), } def __init__(self, parent_instance, *args, pk_field=False, to_field=None, **kwargs): @@ -1158,7 +1363,9 @@ class InlineForeignKeyField(Field): else: orig = self.parent_instance.pk if str(value) != str(orig): - raise ValidationError(self.error_messages['invalid_choice'], code='invalid_choice') + raise ValidationError( + self.error_messages["invalid_choice"], code="invalid_choice" + ) return self.parent_instance def has_changed(self, initial, data): @@ -1215,34 +1422,50 @@ class ModelChoiceIterator: class ModelChoiceField(ChoiceField): """A ChoiceField whose choices are a model QuerySet.""" + # This class is a subclass of ChoiceField for purity, but it doesn't # actually use any of ChoiceField's implementation. default_error_messages = { - 'invalid_choice': _( - 'Select a valid choice. That choice is not one of the available choices.' + "invalid_choice": _( + "Select a valid choice. That choice is not one of the available choices." ), } iterator = ModelChoiceIterator - def __init__(self, queryset, *, empty_label="---------", - required=True, widget=None, label=None, initial=None, - help_text='', to_field_name=None, limit_choices_to=None, - blank=False, **kwargs): + def __init__( + self, + queryset, + *, + empty_label="---------", + required=True, + widget=None, + label=None, + initial=None, + help_text="", + to_field_name=None, + limit_choices_to=None, + blank=False, + **kwargs, + ): # Call Field instead of ChoiceField __init__() because we don't need # ChoiceField.__init__(). 
Field.__init__( - self, required=required, widget=widget, label=label, - initial=initial, help_text=help_text, **kwargs + self, + required=required, + widget=widget, + label=label, + initial=initial, + help_text=help_text, + **kwargs, ) - if ( - (required and initial is not None) or - (isinstance(self.widget, RadioSelect) and not blank) + if (required and initial is not None) or ( + isinstance(self.widget, RadioSelect) and not blank ): self.empty_label = None else: self.empty_label = empty_label self.queryset = queryset - self.limit_choices_to = limit_choices_to # limit the queryset later. + self.limit_choices_to = limit_choices_to # limit the queryset later. self.to_field_name = to_field_name def get_limit_choices_to(self): @@ -1284,7 +1507,7 @@ class ModelChoiceField(ChoiceField): def _get_choices(self): # If self._choices is set, then somebody must have manually set # the property self.choices. In this case, just return self._choices. - if hasattr(self, '_choices'): + if hasattr(self, "_choices"): return self._choices # Otherwise, execute the QuerySet in self.queryset to determine the @@ -1299,7 +1522,7 @@ class ModelChoiceField(ChoiceField): choices = property(_get_choices, ChoiceField._set_choices) def prepare_value(self, value): - if hasattr(value, '_meta'): + if hasattr(value, "_meta"): if self.to_field_name: return value.serializable_value(self.to_field_name) else: @@ -1310,15 +1533,15 @@ class ModelChoiceField(ChoiceField): if value in self.empty_values: return None try: - key = self.to_field_name or 'pk' + key = self.to_field_name or "pk" if isinstance(value, self.queryset.model): value = getattr(value, key) value = self.queryset.get(**{key: value}) except (ValueError, TypeError, self.queryset.model.DoesNotExist): raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': value}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": value}, ) return value @@ -1328,21 +1551,22 
@@ class ModelChoiceField(ChoiceField): def has_changed(self, initial, data): if self.disabled: return False - initial_value = initial if initial is not None else '' - data_value = data if data is not None else '' + initial_value = initial if initial is not None else "" + data_value = data if data is not None else "" return str(self.prepare_value(initial_value)) != str(data_value) class ModelMultipleChoiceField(ModelChoiceField): """A MultipleChoiceField whose choices are a model QuerySet.""" + widget = SelectMultiple hidden_widget = MultipleHiddenInput default_error_messages = { - 'invalid_list': _('Enter a list of values.'), - 'invalid_choice': _( - 'Select a valid choice. %(value)s is not one of the available choices.' + "invalid_list": _("Enter a list of values."), + "invalid_choice": _( + "Select a valid choice. %(value)s is not one of the available choices." ), - 'invalid_pk_value': _('“%(pk)s” is not a valid value.') + "invalid_pk_value": _("“%(pk)s” is not a valid value."), } def __init__(self, queryset, **kwargs): @@ -1356,13 +1580,13 @@ class ModelMultipleChoiceField(ModelChoiceField): def clean(self, value): value = self.prepare_value(value) if self.required and not value: - raise ValidationError(self.error_messages['required'], code='required') + raise ValidationError(self.error_messages["required"], code="required") elif not self.required and not value: return self.queryset.none() if not isinstance(value, (list, tuple)): raise ValidationError( - self.error_messages['invalid_list'], - code='invalid_list', + self.error_messages["invalid_list"], + code="invalid_list", ) qs = self._check_values(value) # Since this overrides the inherited ModelChoiceField.clean @@ -1376,7 +1600,7 @@ class ModelMultipleChoiceField(ModelChoiceField): corresponding objects. Raise a ValidationError if a given value is invalid (not a valid PK, not in the queryset, etc.) 
""" - key = self.to_field_name or 'pk' + key = self.to_field_name or "pk" # deduplicate given values to avoid creating many querysets or # requiring the database backend deduplicate efficiently. try: @@ -1384,33 +1608,35 @@ class ModelMultipleChoiceField(ModelChoiceField): except TypeError: # list of lists isn't hashable, for example raise ValidationError( - self.error_messages['invalid_list'], - code='invalid_list', + self.error_messages["invalid_list"], + code="invalid_list", ) for pk in value: try: self.queryset.filter(**{key: pk}) except (ValueError, TypeError): raise ValidationError( - self.error_messages['invalid_pk_value'], - code='invalid_pk_value', - params={'pk': pk}, + self.error_messages["invalid_pk_value"], + code="invalid_pk_value", + params={"pk": pk}, ) - qs = self.queryset.filter(**{'%s__in' % key: value}) + qs = self.queryset.filter(**{"%s__in" % key: value}) pks = {str(getattr(o, key)) for o in qs} for val in value: if str(val) not in pks: raise ValidationError( - self.error_messages['invalid_choice'], - code='invalid_choice', - params={'value': val}, + self.error_messages["invalid_choice"], + code="invalid_choice", + params={"value": val}, ) return qs def prepare_value(self, value): - if (hasattr(value, '__iter__') and - not isinstance(value, str) and - not hasattr(value, '_meta')): + if ( + hasattr(value, "__iter__") + and not isinstance(value, str) + and not hasattr(value, "_meta") + ): prepare_value = super().prepare_value return [prepare_value(v) for v in value] return super().prepare_value(value) @@ -1430,7 +1656,6 @@ class ModelMultipleChoiceField(ModelChoiceField): def modelform_defines_fields(form_class): - return hasattr(form_class, '_meta') and ( - form_class._meta.fields is not None or - form_class._meta.exclude is not None + return hasattr(form_class, "_meta") and ( + form_class._meta.fields is not None or form_class._meta.exclude is not None ) diff --git a/django/forms/renderers.py b/django/forms/renderers.py index 
ffb61600c2..88cf504653 100644 --- a/django/forms/renderers.py +++ b/django/forms/renderers.py @@ -16,7 +16,7 @@ def get_default_renderer(): class BaseRenderer: def get_template(self, template_name): - raise NotImplementedError('subclasses must implement get_template()') + raise NotImplementedError("subclasses must implement get_template()") def render(self, template_name, context, request=None): template = self.get_template(template_name) @@ -29,12 +29,14 @@ class EngineMixin: @cached_property def engine(self): - return self.backend({ - 'APP_DIRS': True, - 'DIRS': [Path(__file__).parent / self.backend.app_dirname], - 'NAME': 'djangoforms', - 'OPTIONS': {}, - }) + return self.backend( + { + "APP_DIRS": True, + "DIRS": [Path(__file__).parent / self.backend.app_dirname], + "NAME": "djangoforms", + "OPTIONS": {}, + } + ) class DjangoTemplates(EngineMixin, BaseRenderer): @@ -42,6 +44,7 @@ class DjangoTemplates(EngineMixin, BaseRenderer): Load Django templates from the built-in widget templates in django/forms/templates and from apps' 'templates' directory. """ + backend = DjangoTemplates @@ -50,9 +53,11 @@ class Jinja2(EngineMixin, BaseRenderer): Load Jinja2 templates from the built-in widget templates in django/forms/jinja2 and from apps' 'jinja2' directory. """ + @cached_property def backend(self): from django.template.backends.jinja2 import Jinja2 + return Jinja2 @@ -61,5 +66,6 @@ class TemplatesSetting(BaseRenderer): Load templates using template.loader.get_template() which is configured based on settings.TEMPLATES. 
""" + def get_template(self, template_name): return get_template(template_name) diff --git a/django/forms/utils.py b/django/forms/utils.py index 4af57e8586..7d3bb5ad48 100644 --- a/django/forms/utils.py +++ b/django/forms/utils.py @@ -13,8 +13,8 @@ from django.utils.translation import gettext_lazy as _ def pretty_name(name): """Convert 'first_name' to 'First name'.""" if not name: - return '' - return name.replace('_', ' ').capitalize() + return "" + return name.replace("_", " ").capitalize() def flatatt(attrs): @@ -37,23 +37,24 @@ def flatatt(attrs): elif value is not None: key_value_attrs.append((attr, value)) - return ( - format_html_join('', ' {}="{}"', sorted(key_value_attrs)) + - format_html_join('', ' {}', sorted(boolean_attrs)) + return format_html_join("", ' {}="{}"', sorted(key_value_attrs)) + format_html_join( + "", " {}", sorted(boolean_attrs) ) class RenderableMixin: def get_context(self): raise NotImplementedError( - 'Subclasses of RenderableMixin must provide a get_context() method.' + "Subclasses of RenderableMixin must provide a get_context() method." ) def render(self, template_name=None, context=None, renderer=None): - return mark_safe((renderer or self.renderer).render( - template_name or self.template_name, - context or self.get_context(), - )) + return mark_safe( + (renderer or self.renderer).render( + template_name or self.template_name, + context or self.get_context(), + ) + ) __str__ = render __html__ = render @@ -90,9 +91,10 @@ class ErrorDict(dict, RenderableErrorMixin): The dictionary keys are the field names, and the values are the errors. 
""" - template_name = 'django/forms/errors/dict/default.html' - template_name_text = 'django/forms/errors/dict/text.txt' - template_name_ul = 'django/forms/errors/dict/ul.html' + + template_name = "django/forms/errors/dict/default.html" + template_name_text = "django/forms/errors/dict/text.txt" + template_name_ul = "django/forms/errors/dict/ul.html" def __init__(self, *args, renderer=None, **kwargs): super().__init__(*args, **kwargs) @@ -106,8 +108,8 @@ class ErrorDict(dict, RenderableErrorMixin): def get_context(self): return { - 'errors': self.items(), - 'error_class': 'errorlist', + "errors": self.items(), + "error_class": "errorlist", } @@ -115,17 +117,18 @@ class ErrorList(UserList, list, RenderableErrorMixin): """ A collection of errors that knows how to display itself in various formats. """ - template_name = 'django/forms/errors/list/default.html' - template_name_text = 'django/forms/errors/list/text.txt' - template_name_ul = 'django/forms/errors/list/ul.html' + + template_name = "django/forms/errors/list/default.html" + template_name_text = "django/forms/errors/list/text.txt" + template_name_ul = "django/forms/errors/list/ul.html" def __init__(self, initlist=None, error_class=None, renderer=None): super().__init__(initlist) if error_class is None: - self.error_class = 'errorlist' + self.error_class = "errorlist" else: - self.error_class = 'errorlist {}'.format(error_class) + self.error_class = "errorlist {}".format(error_class) self.renderer = renderer or get_default_renderer() def as_data(self): @@ -140,16 +143,18 @@ class ErrorList(UserList, list, RenderableErrorMixin): errors = [] for error in self.as_data(): message = next(iter(error)) - errors.append({ - 'message': escape(message) if escape_html else message, - 'code': error.code or '', - }) + errors.append( + { + "message": escape(message) if escape_html else message, + "code": error.code or "", + } + ) return errors def get_context(self): return { - 'errors': self, - 'error_class': self.error_class, 
+ "errors": self, + "error_class": self.error_class, } def __repr__(self): @@ -179,6 +184,7 @@ class ErrorList(UserList, list, RenderableErrorMixin): # Utilities for time zone support in DateTimeField et al. + def from_current_timezone(value): """ When time zone support is enabled, convert naive datetimes @@ -187,19 +193,20 @@ def from_current_timezone(value): if settings.USE_TZ and value is not None and timezone.is_naive(value): current_timezone = timezone.get_current_timezone() try: - if ( - not timezone._is_pytz_zone(current_timezone) and - timezone._datetime_ambiguous_or_imaginary(value, current_timezone) - ): - raise ValueError('Ambiguous or non-existent time.') + if not timezone._is_pytz_zone( + current_timezone + ) and timezone._datetime_ambiguous_or_imaginary(value, current_timezone): + raise ValueError("Ambiguous or non-existent time.") return timezone.make_aware(value, current_timezone) except Exception as exc: raise ValidationError( - _('%(datetime)s couldn’t be interpreted ' - 'in time zone %(current_timezone)s; it ' - 'may be ambiguous or it may not exist.'), - code='ambiguous_timezone', - params={'datetime': value, 'current_timezone': current_timezone} + _( + "%(datetime)s couldn’t be interpreted " + "in time zone %(current_timezone)s; it " + "may be ambiguous or it may not exist." 
+ ), + code="ambiguous_timezone", + params={"datetime": value, "current_timezone": current_timezone}, ) from exc return value diff --git a/django/forms/widgets.py b/django/forms/widgets.py index 05667f8e44..8c5122ad1d 100644 --- a/django/forms/widgets.py +++ b/django/forms/widgets.py @@ -17,24 +17,41 @@ from django.utils.formats import get_format from django.utils.html import format_html, html_safe from django.utils.regex_helper import _lazy_re_compile from django.utils.safestring import mark_safe -from django.utils.topological_sort import ( - CyclicDependencyError, stable_topological_sort, -) +from django.utils.topological_sort import CyclicDependencyError, stable_topological_sort from django.utils.translation import gettext_lazy as _ from .renderers import get_default_renderer __all__ = ( - 'Media', 'MediaDefiningClass', 'Widget', 'TextInput', 'NumberInput', - 'EmailInput', 'URLInput', 'PasswordInput', 'HiddenInput', - 'MultipleHiddenInput', 'FileInput', 'ClearableFileInput', 'Textarea', - 'DateInput', 'DateTimeInput', 'TimeInput', 'CheckboxInput', 'Select', - 'NullBooleanSelect', 'SelectMultiple', 'RadioSelect', - 'CheckboxSelectMultiple', 'MultiWidget', 'SplitDateTimeWidget', - 'SplitHiddenDateTimeWidget', 'SelectDateWidget', + "Media", + "MediaDefiningClass", + "Widget", + "TextInput", + "NumberInput", + "EmailInput", + "URLInput", + "PasswordInput", + "HiddenInput", + "MultipleHiddenInput", + "FileInput", + "ClearableFileInput", + "Textarea", + "DateInput", + "DateTimeInput", + "TimeInput", + "CheckboxInput", + "Select", + "NullBooleanSelect", + "SelectMultiple", + "RadioSelect", + "CheckboxSelectMultiple", + "MultiWidget", + "SplitDateTimeWidget", + "SplitHiddenDateTimeWidget", + "SelectDateWidget", ) -MEDIA_TYPES = ('css', 'js') +MEDIA_TYPES = ("css", "js") class MediaOrderConflictWarning(RuntimeWarning): @@ -45,8 +62,8 @@ class MediaOrderConflictWarning(RuntimeWarning): class Media: def __init__(self, media=None, css=None, js=None): if media is not None: - 
css = getattr(media, 'css', {}) - js = getattr(media, 'js', []) + css = getattr(media, "css", {}) + js = getattr(media, "js", []) else: if css is None: css = {} @@ -56,7 +73,7 @@ class Media: self._js_lists = [js] def __repr__(self): - return 'Media(css=%r, js=%r)' % (self._css, self._js) + return "Media(css=%r, js=%r)" % (self._css, self._js) def __str__(self): return self.render() @@ -74,26 +91,35 @@ class Media: return self.merge(*self._js_lists) def render(self): - return mark_safe('\n'.join(chain.from_iterable(getattr(self, 'render_' + name)() for name in MEDIA_TYPES))) + return mark_safe( + "\n".join( + chain.from_iterable( + getattr(self, "render_" + name)() for name in MEDIA_TYPES + ) + ) + ) def render_js(self): return [ - format_html( - '<script src="{}"></script>', - self.absolute_path(path) - ) for path in self._js + format_html('<script src="{}"></script>', self.absolute_path(path)) + for path in self._js ] def render_css(self): # To keep rendering order consistent, we can't just iterate over items(). # We need to sort the keys, and iterate over the sorted list. media = sorted(self._css) - return chain.from_iterable([ - format_html( - '<link href="{}" media="{}" rel="stylesheet">', - self.absolute_path(path), medium - ) for path in self._css[medium] - ] for medium in media) + return chain.from_iterable( + [ + format_html( + '<link href="{}" media="{}" rel="stylesheet">', + self.absolute_path(path), + medium, + ) + for path in self._css[medium] + ] + for medium in media + ) def absolute_path(self, path): """ @@ -101,14 +127,14 @@ class Media: path. An absolute path will be returned unchanged while a relative path will be passed to django.templatetags.static.static(). 
""" - if path.startswith(('http://', 'https://', '/')): + if path.startswith(("http://", "https://", "/")): return path return static(path) def __getitem__(self, name): """Return a Media object that only contains media of the given type.""" if name in MEDIA_TYPES: - return Media(**{str(name): getattr(self, '_' + name)}) + return Media(**{str(name): getattr(self, "_" + name)}) raise KeyError('Unknown media type "%s"' % name) @staticmethod @@ -138,9 +164,10 @@ class Media: return stable_topological_sort(all_items, dependency_graph) except CyclicDependencyError: warnings.warn( - 'Detected duplicate Media files in an opposite order: {}'.format( - ', '.join(repr(list_) for list_ in lists) - ), MediaOrderConflictWarning, + "Detected duplicate Media files in an opposite order: {}".format( + ", ".join(repr(list_) for list_ in lists) + ), + MediaOrderConflictWarning, ) return list(all_items) @@ -167,9 +194,9 @@ def media_property(cls): base = Media() # Get the media definition for this class - definition = getattr(cls, 'Media', None) + definition = getattr(cls, "Media", None) if definition: - extend = getattr(definition, 'extend', True) + extend = getattr(definition, "extend", True) if extend: if extend is True: m = base @@ -180,6 +207,7 @@ def media_property(cls): return m + Media(definition) return Media(definition) return base + return property(_media) @@ -187,10 +215,11 @@ class MediaDefiningClass(type): """ Metaclass for classes that can have media definitions. 
""" + def __new__(mcs, name, bases, attrs): new_class = super().__new__(mcs, name, bases, attrs) - if 'media' not in attrs: + if "media" not in attrs: new_class.media = media_property(new_class) return new_class @@ -213,17 +242,17 @@ class Widget(metaclass=MediaDefiningClass): @property def is_hidden(self): - return self.input_type == 'hidden' if hasattr(self, 'input_type') else False + return self.input_type == "hidden" if hasattr(self, "input_type") else False def subwidgets(self, name, value, attrs=None): context = self.get_context(name, value, attrs) - yield context['widget'] + yield context["widget"] def format_value(self, value): """ Return a value as it should appear when rendered in a template. """ - if value == '' or value is None: + if value == "" or value is None: return None if self.is_localized: return formats.localize_input(value) @@ -231,13 +260,13 @@ class Widget(metaclass=MediaDefiningClass): def get_context(self, name, value, attrs): return { - 'widget': { - 'name': name, - 'is_hidden': self.is_hidden, - 'required': self.is_required, - 'value': self.format_value(value), - 'attrs': self.build_attrs(self.attrs, attrs), - 'template_name': self.template_name, + "widget": { + "name": name, + "is_hidden": self.is_hidden, + "required": self.is_required, + "value": self.format_value(value), + "attrs": self.build_attrs(self.attrs, attrs), + "template_name": self.template_name, }, } @@ -285,44 +314,45 @@ class Input(Widget): """ Base class for all <input> widgets. """ + input_type = None # Subclasses must define this. 
- template_name = 'django/forms/widgets/input.html' + template_name = "django/forms/widgets/input.html" def __init__(self, attrs=None): if attrs is not None: attrs = attrs.copy() - self.input_type = attrs.pop('type', self.input_type) + self.input_type = attrs.pop("type", self.input_type) super().__init__(attrs) def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) - context['widget']['type'] = self.input_type + context["widget"]["type"] = self.input_type return context class TextInput(Input): - input_type = 'text' - template_name = 'django/forms/widgets/text.html' + input_type = "text" + template_name = "django/forms/widgets/text.html" class NumberInput(Input): - input_type = 'number' - template_name = 'django/forms/widgets/number.html' + input_type = "number" + template_name = "django/forms/widgets/number.html" class EmailInput(Input): - input_type = 'email' - template_name = 'django/forms/widgets/email.html' + input_type = "email" + template_name = "django/forms/widgets/email.html" class URLInput(Input): - input_type = 'url' - template_name = 'django/forms/widgets/url.html' + input_type = "url" + template_name = "django/forms/widgets/url.html" class PasswordInput(Input): - input_type = 'password' - template_name = 'django/forms/widgets/password.html' + input_type = "password" + template_name = "django/forms/widgets/password.html" def __init__(self, attrs=None, render_value=False): super().__init__(attrs) @@ -335,8 +365,8 @@ class PasswordInput(Input): class HiddenInput(Input): - input_type = 'hidden' - template_name = 'django/forms/widgets/hidden.html' + input_type = "hidden" + template_name = "django/forms/widgets/hidden.html" class MultipleHiddenInput(HiddenInput): @@ -344,25 +374,26 @@ class MultipleHiddenInput(HiddenInput): Handle <input type="hidden"> for fields that have a list of values. 
""" - template_name = 'django/forms/widgets/multiple_hidden.html' + + template_name = "django/forms/widgets/multiple_hidden.html" def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) - final_attrs = context['widget']['attrs'] - id_ = context['widget']['attrs'].get('id') + final_attrs = context["widget"]["attrs"] + id_ = context["widget"]["attrs"].get("id") subwidgets = [] - for index, value_ in enumerate(context['widget']['value']): + for index, value_ in enumerate(context["widget"]["value"]): widget_attrs = final_attrs.copy() if id_: # An ID attribute was given. Add a numeric index as a suffix # so that the inputs don't all have the same ID attribute. - widget_attrs['id'] = '%s_%s' % (id_, index) + widget_attrs["id"] = "%s_%s" % (id_, index) widget = HiddenInput() widget.is_required = self.is_required - subwidgets.append(widget.get_context(name, value_, widget_attrs)['widget']) + subwidgets.append(widget.get_context(name, value_, widget_attrs)["widget"]) - context['widget']['subwidgets'] = subwidgets + context["widget"]["subwidgets"] = subwidgets return context def value_from_datadict(self, data, files, name): @@ -377,9 +408,9 @@ class MultipleHiddenInput(HiddenInput): class FileInput(Input): - input_type = 'file' + input_type = "file" needs_multipart_form = True - template_name = 'django/forms/widgets/file.html' + template_name = "django/forms/widgets/file.html" def format_value(self, value): """File input never renders a value.""" @@ -400,29 +431,29 @@ FILE_INPUT_CONTRADICTION = object() class ClearableFileInput(FileInput): - clear_checkbox_label = _('Clear') - initial_text = _('Currently') - input_text = _('Change') - template_name = 'django/forms/widgets/clearable_file_input.html' + clear_checkbox_label = _("Clear") + initial_text = _("Currently") + input_text = _("Change") + template_name = "django/forms/widgets/clearable_file_input.html" def clear_checkbox_name(self, name): """ Given the name of the file input, return 
the name of the clear checkbox input. """ - return name + '-clear' + return name + "-clear" def clear_checkbox_id(self, name): """ Given the name of the clear checkbox input, return the HTML id for it. """ - return name + '_id' + return name + "_id" def is_initial(self, value): """ Return whether value is considered to be initial value. """ - return bool(value and getattr(value, 'url', False)) + return bool(value and getattr(value, "url", False)) def format_value(self, value): """ @@ -435,20 +466,23 @@ class ClearableFileInput(FileInput): context = super().get_context(name, value, attrs) checkbox_name = self.clear_checkbox_name(name) checkbox_id = self.clear_checkbox_id(checkbox_name) - context['widget'].update({ - 'checkbox_name': checkbox_name, - 'checkbox_id': checkbox_id, - 'is_initial': self.is_initial(value), - 'input_text': self.input_text, - 'initial_text': self.initial_text, - 'clear_checkbox_label': self.clear_checkbox_label, - }) + context["widget"].update( + { + "checkbox_name": checkbox_name, + "checkbox_id": checkbox_id, + "is_initial": self.is_initial(value), + "input_text": self.input_text, + "initial_text": self.initial_text, + "clear_checkbox_label": self.clear_checkbox_label, + } + ) return context def value_from_datadict(self, data, files, name): upload = super().value_from_datadict(data, files, name) if not self.is_required and CheckboxInput().value_from_datadict( - data, files, self.clear_checkbox_name(name)): + data, files, self.clear_checkbox_name(name) + ): if upload: # If the user contradicts themselves (uploads a new file AND @@ -461,24 +495,24 @@ class ClearableFileInput(FileInput): def value_omitted_from_data(self, data, files, name): return ( - super().value_omitted_from_data(data, files, name) and - self.clear_checkbox_name(name) not in data + super().value_omitted_from_data(data, files, name) + and self.clear_checkbox_name(name) not in data ) class Textarea(Widget): - template_name = 'django/forms/widgets/textarea.html' + 
template_name = "django/forms/widgets/textarea.html" def __init__(self, attrs=None): # Use slightly better defaults than HTML's 20x2 box - default_attrs = {'cols': '40', 'rows': '10'} + default_attrs = {"cols": "40", "rows": "10"} if attrs: default_attrs.update(attrs) super().__init__(default_attrs) class DateTimeBaseInput(TextInput): - format_key = '' + format_key = "" supports_microseconds = False def __init__(self, attrs=None, format=None): @@ -486,32 +520,34 @@ class DateTimeBaseInput(TextInput): self.format = format or None def format_value(self, value): - return formats.localize_input(value, self.format or formats.get_format(self.format_key)[0]) + return formats.localize_input( + value, self.format or formats.get_format(self.format_key)[0] + ) class DateInput(DateTimeBaseInput): - format_key = 'DATE_INPUT_FORMATS' - template_name = 'django/forms/widgets/date.html' + format_key = "DATE_INPUT_FORMATS" + template_name = "django/forms/widgets/date.html" class DateTimeInput(DateTimeBaseInput): - format_key = 'DATETIME_INPUT_FORMATS' - template_name = 'django/forms/widgets/datetime.html' + format_key = "DATETIME_INPUT_FORMATS" + template_name = "django/forms/widgets/datetime.html" class TimeInput(DateTimeBaseInput): - format_key = 'TIME_INPUT_FORMATS' - template_name = 'django/forms/widgets/time.html' + format_key = "TIME_INPUT_FORMATS" + template_name = "django/forms/widgets/time.html" # Defined at module level so that CheckboxInput is picklable (#17976) def boolean_check(v): - return not (v is False or v is None or v == '') + return not (v is False or v is None or v == "") class CheckboxInput(Input): - input_type = 'checkbox' - template_name = 'django/forms/widgets/checkbox.html' + input_type = "checkbox" + template_name = "django/forms/widgets/checkbox.html" def __init__(self, attrs=None, check_test=None): super().__init__(attrs) @@ -521,13 +557,13 @@ class CheckboxInput(Input): def format_value(self, value): """Only return the 'value' attribute if value isn't 
empty.""" - if value is True or value is False or value is None or value == '': + if value is True or value is False or value is None or value == "": return return str(value) def get_context(self, name, value, attrs): if self.check_test(value): - attrs = {**(attrs or {}), 'checked': True} + attrs = {**(attrs or {}), "checked": True} return super().get_context(name, value, attrs) def value_from_datadict(self, data, files, name): @@ -537,7 +573,7 @@ class CheckboxInput(Input): return False value = data.get(name) # Translate true and false strings to boolean values. - values = {'true': True, 'false': False} + values = {"true": True, "false": False} if isinstance(value, str): value = values.get(value.lower(), value) return bool(value) @@ -554,7 +590,7 @@ class ChoiceWidget(Widget): template_name = None option_template_name = None add_id_index = True - checked_attribute = {'checked': True} + checked_attribute = {"checked": True} option_inherits_attrs = True def __init__(self, attrs=None, choices=()): @@ -591,7 +627,7 @@ class ChoiceWidget(Widget): for index, (option_value, option_label) in enumerate(self.choices): if option_value is None: - option_value = '' + option_value = "" subgroup = [] if isinstance(option_label, (list, tuple)): @@ -605,50 +641,62 @@ class ChoiceWidget(Widget): groups.append((group_name, subgroup, index)) for subvalue, sublabel in choices: - selected = ( - (not has_selected or self.allow_multiple_selected) and - str(subvalue) in value - ) + selected = (not has_selected or self.allow_multiple_selected) and str( + subvalue + ) in value has_selected |= selected - subgroup.append(self.create_option( - name, subvalue, sublabel, selected, index, - subindex=subindex, attrs=attrs, - )) + subgroup.append( + self.create_option( + name, + subvalue, + sublabel, + selected, + index, + subindex=subindex, + attrs=attrs, + ) + ) if subindex is not None: subindex += 1 return groups - def create_option(self, name, value, label, selected, index, subindex=None, 
attrs=None): + def create_option( + self, name, value, label, selected, index, subindex=None, attrs=None + ): index = str(index) if subindex is None else "%s_%s" % (index, subindex) - option_attrs = self.build_attrs(self.attrs, attrs) if self.option_inherits_attrs else {} + option_attrs = ( + self.build_attrs(self.attrs, attrs) if self.option_inherits_attrs else {} + ) if selected: option_attrs.update(self.checked_attribute) - if 'id' in option_attrs: - option_attrs['id'] = self.id_for_label(option_attrs['id'], index) + if "id" in option_attrs: + option_attrs["id"] = self.id_for_label(option_attrs["id"], index) return { - 'name': name, - 'value': value, - 'label': label, - 'selected': selected, - 'index': index, - 'attrs': option_attrs, - 'type': self.input_type, - 'template_name': self.option_template_name, - 'wrap_label': True, + "name": name, + "value": value, + "label": label, + "selected": selected, + "index": index, + "attrs": option_attrs, + "type": self.input_type, + "template_name": self.option_template_name, + "wrap_label": True, } def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) - context['widget']['optgroups'] = self.optgroups(name, context['widget']['value'], attrs) + context["widget"]["optgroups"] = self.optgroups( + name, context["widget"]["value"], attrs + ) return context - def id_for_label(self, id_, index='0'): + def id_for_label(self, id_, index="0"): """ Use an incremented id for each option where the main widget references the zero index. 
""" if id_ and self.add_id_index: - id_ = '%s_%s' % (id_, index) + id_ = "%s_%s" % (id_, index) return id_ def value_from_datadict(self, data, files, name): @@ -666,28 +714,28 @@ class ChoiceWidget(Widget): return [] if not isinstance(value, (tuple, list)): value = [value] - return [str(v) if v is not None else '' for v in value] + return [str(v) if v is not None else "" for v in value] class Select(ChoiceWidget): - input_type = 'select' - template_name = 'django/forms/widgets/select.html' - option_template_name = 'django/forms/widgets/select_option.html' + input_type = "select" + template_name = "django/forms/widgets/select.html" + option_template_name = "django/forms/widgets/select_option.html" add_id_index = False - checked_attribute = {'selected': True} + checked_attribute = {"selected": True} option_inherits_attrs = False def get_context(self, name, value, attrs): context = super().get_context(name, value, attrs) if self.allow_multiple_selected: - context['widget']['attrs']['multiple'] = True + context["widget"]["attrs"]["multiple"] = True return context @staticmethod def _choice_has_empty_value(choice): """Return True if the choice's value is empty string or None.""" value, _ = choice - return value is None or value == '' + return value is None or value == "" def use_required_attribute(self, initial): """ @@ -700,44 +748,52 @@ class Select(ChoiceWidget): return use_required_attribute first_choice = next(iter(self.choices), None) - return use_required_attribute and first_choice is not None and self._choice_has_empty_value(first_choice) + return ( + use_required_attribute + and first_choice is not None + and self._choice_has_empty_value(first_choice) + ) class NullBooleanSelect(Select): """ A Select Widget intended to be used with NullBooleanField. 
""" + def __init__(self, attrs=None): choices = ( - ('unknown', _('Unknown')), - ('true', _('Yes')), - ('false', _('No')), + ("unknown", _("Unknown")), + ("true", _("Yes")), + ("false", _("No")), ) super().__init__(attrs, choices) def format_value(self, value): try: return { - True: 'true', False: 'false', - 'true': 'true', 'false': 'false', + True: "true", + False: "false", + "true": "true", + "false": "false", # For backwards compatibility with Django < 2.2. - '2': 'true', '3': 'false', + "2": "true", + "3": "false", }[value] except KeyError: - return 'unknown' + return "unknown" def value_from_datadict(self, data, files, name): value = data.get(name) return { True: True, - 'True': True, - 'False': False, + "True": True, + "False": False, False: False, - 'true': True, - 'false': False, + "true": True, + "false": False, # For backwards compatibility with Django < 2.2. - '2': True, - '3': False, + "2": True, + "3": False, }.get(value) @@ -758,9 +814,9 @@ class SelectMultiple(Select): class RadioSelect(ChoiceWidget): - input_type = 'radio' - template_name = 'django/forms/widgets/radio.html' - option_template_name = 'django/forms/widgets/radio_option.html' + input_type = "radio" + template_name = "django/forms/widgets/radio.html" + option_template_name = "django/forms/widgets/radio_option.html" def id_for_label(self, id_, index=None): """ @@ -769,15 +825,15 @@ class RadioSelect(ChoiceWidget): the first input. 
""" if index is None: - return '' + return "" return super().id_for_label(id_, index) class CheckboxSelectMultiple(RadioSelect): allow_multiple_selected = True - input_type = 'checkbox' - template_name = 'django/forms/widgets/checkbox_select.html' - option_template_name = 'django/forms/widgets/checkbox_option.html' + input_type = "checkbox" + template_name = "django/forms/widgets/checkbox_select.html" + option_template_name = "django/forms/widgets/checkbox_option.html" def use_required_attribute(self, initial): # Don't use the 'required' attribute because browser validation would @@ -800,16 +856,15 @@ class MultiWidget(Widget): You'll probably want to use this class with MultiValueField. """ - template_name = 'django/forms/widgets/multiwidget.html' + + template_name = "django/forms/widgets/multiwidget.html" def __init__(self, widgets, attrs=None): if isinstance(widgets, dict): - self.widgets_names = [ - ('_%s' % name) if name else '' for name in widgets - ] + self.widgets_names = [("_%s" % name) if name else "" for name in widgets] widgets = widgets.values() else: - self.widgets_names = ['_%s' % i for i in range(len(widgets))] + self.widgets_names = ["_%s" % i for i in range(len(widgets))] self.widgets = [w() if isinstance(w, type) else w for w in widgets] super().__init__(attrs) @@ -827,11 +882,13 @@ class MultiWidget(Widget): if not isinstance(value, list): value = self.decompress(value) - final_attrs = context['widget']['attrs'] - input_type = final_attrs.pop('type', None) - id_ = final_attrs.get('id') + final_attrs = context["widget"]["attrs"] + input_type = final_attrs.pop("type", None) + id_ = final_attrs.get("id") subwidgets = [] - for i, (widget_name, widget) in enumerate(zip(self.widgets_names, self.widgets)): + for i, (widget_name, widget) in enumerate( + zip(self.widgets_names, self.widgets) + ): if input_type is not None: widget.input_type = input_type widget_name = name + widget_name @@ -841,15 +898,17 @@ class MultiWidget(Widget): widget_value = None 
if id_: widget_attrs = final_attrs.copy() - widget_attrs['id'] = '%s_%s' % (id_, i) + widget_attrs["id"] = "%s_%s" % (id_, i) else: widget_attrs = final_attrs - subwidgets.append(widget.get_context(widget_name, widget_value, widget_attrs)['widget']) - context['widget']['subwidgets'] = subwidgets + subwidgets.append( + widget.get_context(widget_name, widget_value, widget_attrs)["widget"] + ) + context["widget"]["subwidgets"] = subwidgets return context def id_for_label(self, id_): - return '' + return "" def value_from_datadict(self, data, files, name): return [ @@ -869,7 +928,7 @@ class MultiWidget(Widget): The given value can be assumed to be valid, but not necessarily non-empty. """ - raise NotImplementedError('Subclasses must implement this method.') + raise NotImplementedError("Subclasses must implement this method.") def _get_media(self): """ @@ -880,6 +939,7 @@ class MultiWidget(Widget): for w in self.widgets: media = media + w.media return media + media = property(_get_media) def __deepcopy__(self, memo): @@ -896,10 +956,18 @@ class SplitDateTimeWidget(MultiWidget): """ A widget that splits datetime input into two <input type="text"> boxes. """ - supports_microseconds = False - template_name = 'django/forms/widgets/splitdatetime.html' - def __init__(self, attrs=None, date_format=None, time_format=None, date_attrs=None, time_attrs=None): + supports_microseconds = False + template_name = "django/forms/widgets/splitdatetime.html" + + def __init__( + self, + attrs=None, + date_format=None, + time_format=None, + date_attrs=None, + time_attrs=None, + ): widgets = ( DateInput( attrs=attrs if date_attrs is None else date_attrs, @@ -923,12 +991,20 @@ class SplitHiddenDateTimeWidget(SplitDateTimeWidget): """ A widget that splits datetime input into two <input type="hidden"> inputs. 
""" - template_name = 'django/forms/widgets/splithiddendatetime.html' - def __init__(self, attrs=None, date_format=None, time_format=None, date_attrs=None, time_attrs=None): + template_name = "django/forms/widgets/splithiddendatetime.html" + + def __init__( + self, + attrs=None, + date_format=None, + time_format=None, + date_attrs=None, + time_attrs=None, + ): super().__init__(attrs, date_format, time_format, date_attrs, time_attrs) for widget in self.widgets: - widget.input_type = 'hidden' + widget.input_type = "hidden" class SelectDateWidget(Widget): @@ -938,14 +1014,15 @@ class SelectDateWidget(Widget): This also serves as an example of a Widget that has more than one HTML element and hence implements value_from_datadict. """ - none_value = ('', '---') - month_field = '%s_month' - day_field = '%s_day' - year_field = '%s_year' - template_name = 'django/forms/widgets/select_date.html' - input_type = 'select' + + none_value = ("", "---") + month_field = "%s_month" + day_field = "%s_day" + year_field = "%s_year" + template_name = "django/forms/widgets/select_date.html" + input_type = "select" select_widget = Select - date_re = _lazy_re_compile(r'(\d{4}|0)-(\d\d?)-(\d\d?)$') + date_re = _lazy_re_compile(r"(\d{4}|0)-(\d\d?)-(\d\d?)$") def __init__(self, attrs=None, years=None, months=None, empty_label=None): self.attrs = attrs or {} @@ -966,14 +1043,14 @@ class SelectDateWidget(Widget): # Optional string, list, or tuple to use as empty_label. 
if isinstance(empty_label, (list, tuple)): if not len(empty_label) == 3: - raise ValueError('empty_label list/tuple must have 3 elements.') + raise ValueError("empty_label list/tuple must have 3 elements.") - self.year_none_value = ('', empty_label[0]) - self.month_none_value = ('', empty_label[1]) - self.day_none_value = ('', empty_label[2]) + self.year_none_value = ("", empty_label[0]) + self.month_none_value = ("", empty_label[1]) + self.day_none_value = ("", empty_label[2]) else: if empty_label is not None: - self.none_value = ('', empty_label) + self.none_value = ("", empty_label) self.year_none_value = self.none_value self.month_none_value = self.none_value @@ -986,33 +1063,40 @@ class SelectDateWidget(Widget): if not self.is_required: year_choices.insert(0, self.year_none_value) year_name = self.year_field % name - date_context['year'] = self.select_widget(attrs, choices=year_choices).get_context( + date_context["year"] = self.select_widget( + attrs, choices=year_choices + ).get_context( name=year_name, - value=context['widget']['value']['year'], - attrs={**context['widget']['attrs'], 'id': 'id_%s' % year_name}, + value=context["widget"]["value"]["year"], + attrs={**context["widget"]["attrs"], "id": "id_%s" % year_name}, ) month_choices = list(self.months.items()) if not self.is_required: month_choices.insert(0, self.month_none_value) month_name = self.month_field % name - date_context['month'] = self.select_widget(attrs, choices=month_choices).get_context( + date_context["month"] = self.select_widget( + attrs, choices=month_choices + ).get_context( name=month_name, - value=context['widget']['value']['month'], - attrs={**context['widget']['attrs'], 'id': 'id_%s' % month_name}, + value=context["widget"]["value"]["month"], + attrs={**context["widget"]["attrs"], "id": "id_%s" % month_name}, ) day_choices = [(i, i) for i in range(1, 32)] if not self.is_required: day_choices.insert(0, self.day_none_value) day_name = self.day_field % name - date_context['day'] = 
self.select_widget(attrs, choices=day_choices,).get_context( + date_context["day"] = self.select_widget( + attrs, + choices=day_choices, + ).get_context( name=day_name, - value=context['widget']['value']['day'], - attrs={**context['widget']['attrs'], 'id': 'id_%s' % day_name}, + value=context["widget"]["value"]["day"], + attrs={**context["widget"]["attrs"], "id": "id_%s" % day_name}, ) subwidgets = [] for field in self._parse_date_fmt(): - subwidgets.append(date_context[field]['widget']) - context['widget']['subwidgets'] = subwidgets + subwidgets.append(date_context[field]["widget"]) + context["widget"]["subwidgets"] = subwidgets return context def format_value(self, value): @@ -1029,58 +1113,58 @@ class SelectDateWidget(Widget): if match: # Convert any zeros in the date to empty strings to match the # empty option value. - year, month, day = [int(val) or '' for val in match.groups()] + year, month, day = [int(val) or "" for val in match.groups()] else: - input_format = get_format('DATE_INPUT_FORMATS')[0] + input_format = get_format("DATE_INPUT_FORMATS")[0] try: d = datetime.datetime.strptime(value, input_format) except ValueError: pass else: year, month, day = d.year, d.month, d.day - return {'year': year, 'month': month, 'day': day} + return {"year": year, "month": month, "day": day} @staticmethod def _parse_date_fmt(): - fmt = get_format('DATE_FORMAT') + fmt = get_format("DATE_FORMAT") escaped = False for char in fmt: if escaped: escaped = False - elif char == '\\': + elif char == "\\": escaped = True - elif char in 'Yy': - yield 'year' - elif char in 'bEFMmNn': - yield 'month' - elif char in 'dj': - yield 'day' + elif char in "Yy": + yield "year" + elif char in "bEFMmNn": + yield "month" + elif char in "dj": + yield "day" def id_for_label(self, id_): for first_select in self._parse_date_fmt(): - return '%s_%s' % (id_, first_select) - return '%s_month' % id_ + return "%s_%s" % (id_, first_select) + return "%s_month" % id_ def value_from_datadict(self, data, 
files, name): y = data.get(self.year_field % name) m = data.get(self.month_field % name) d = data.get(self.day_field % name) - if y == m == d == '': + if y == m == d == "": return None if y is not None and m is not None and d is not None: - input_format = get_format('DATE_INPUT_FORMATS')[0] + input_format = get_format("DATE_INPUT_FORMATS")[0] input_format = formats.sanitize_strftime_format(input_format) try: date_value = datetime.date(int(y), int(m), int(d)) except ValueError: # Return pseudo-ISO dates with zeros for any unselected values, # e.g. '2017-0-23'. - return '%s-%s-%s' % (y or 0, m or 0, d or 0) + return "%s-%s-%s" % (y or 0, m or 0, d or 0) return date_value.strftime(input_format) return data.get(name) def value_omitted_from_data(self, data, files, name): return not any( - ('{}_{}'.format(name, interval) in data) - for interval in ('year', 'month', 'day') + ("{}_{}".format(name, interval) in data) + for interval in ("year", "month", "day") ) diff --git a/django/http/__init__.py b/django/http/__init__.py index 491239bf8a..4c997154d9 100644 --- a/django/http/__init__.py +++ b/django/http/__init__.py @@ -1,21 +1,48 @@ from django.http.cookie import SimpleCookie, parse_cookie from django.http.request import ( - HttpRequest, QueryDict, RawPostDataException, UnreadablePostError, + HttpRequest, + QueryDict, + RawPostDataException, + UnreadablePostError, ) from django.http.response import ( - BadHeaderError, FileResponse, Http404, HttpResponse, - HttpResponseBadRequest, HttpResponseForbidden, HttpResponseGone, - HttpResponseNotAllowed, HttpResponseNotFound, HttpResponseNotModified, - HttpResponsePermanentRedirect, HttpResponseRedirect, - HttpResponseServerError, JsonResponse, StreamingHttpResponse, + BadHeaderError, + FileResponse, + Http404, + HttpResponse, + HttpResponseBadRequest, + HttpResponseForbidden, + HttpResponseGone, + HttpResponseNotAllowed, + HttpResponseNotFound, + HttpResponseNotModified, + HttpResponsePermanentRedirect, + HttpResponseRedirect, + 
HttpResponseServerError, + JsonResponse, + StreamingHttpResponse, ) __all__ = [ - 'SimpleCookie', 'parse_cookie', 'HttpRequest', 'QueryDict', - 'RawPostDataException', 'UnreadablePostError', - 'HttpResponse', 'StreamingHttpResponse', 'HttpResponseRedirect', - 'HttpResponsePermanentRedirect', 'HttpResponseNotModified', - 'HttpResponseBadRequest', 'HttpResponseForbidden', 'HttpResponseNotFound', - 'HttpResponseNotAllowed', 'HttpResponseGone', 'HttpResponseServerError', - 'Http404', 'BadHeaderError', 'JsonResponse', 'FileResponse', + "SimpleCookie", + "parse_cookie", + "HttpRequest", + "QueryDict", + "RawPostDataException", + "UnreadablePostError", + "HttpResponse", + "StreamingHttpResponse", + "HttpResponseRedirect", + "HttpResponsePermanentRedirect", + "HttpResponseNotModified", + "HttpResponseBadRequest", + "HttpResponseForbidden", + "HttpResponseNotFound", + "HttpResponseNotAllowed", + "HttpResponseGone", + "HttpResponseServerError", + "Http404", + "BadHeaderError", + "JsonResponse", + "FileResponse", ] diff --git a/django/http/cookie.py b/django/http/cookie.py index b94d2b0386..dce0cfdf25 100644 --- a/django/http/cookie.py +++ b/django/http/cookie.py @@ -9,13 +9,13 @@ def parse_cookie(cookie): Return a dictionary parsed from a `Cookie:` header string. """ cookiedict = {} - for chunk in cookie.split(';'): - if '=' in chunk: - key, val = chunk.split('=', 1) + for chunk in cookie.split(";"): + if "=" in chunk: + key, val = chunk.split("=", 1) else: # Assume an empty name per # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 - key, val = '', chunk + key, val = "", chunk key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. 
diff --git a/django/http/multipartparser.py b/django/http/multipartparser.py index ef0b339d1b..13ed2fa4cd 100644 --- a/django/http/multipartparser.py +++ b/django/http/multipartparser.py @@ -13,15 +13,15 @@ from urllib.parse import unquote from django.conf import settings from django.core.exceptions import ( - RequestDataTooBig, SuspiciousMultipartForm, TooManyFieldsSent, -) -from django.core.files.uploadhandler import ( - SkipFile, StopFutureHandlers, StopUpload, + RequestDataTooBig, + SuspiciousMultipartForm, + TooManyFieldsSent, ) +from django.core.files.uploadhandler import SkipFile, StopFutureHandlers, StopUpload from django.utils.datastructures import MultiValueDict from django.utils.encoding import force_str -__all__ = ('MultiPartParser', 'MultiPartParserError', 'InputStreamExhausted') +__all__ = ("MultiPartParser", "MultiPartParserError", "InputStreamExhausted") class MultiPartParserError(Exception): @@ -32,6 +32,7 @@ class InputStreamExhausted(Exception): """ No more reads are allowed from this device. """ + pass @@ -47,6 +48,7 @@ class MultiPartParser: ``MultiValueDict.parse()`` reads the input stream in ``chunk_size`` chunks and returns a tuple of ``(MultiValueDict(POST), MultiValueDict(FILES))``. """ + def __init__(self, META, input_data, upload_handlers, encoding=None): """ Initialize the MultiPartParser object. @@ -62,23 +64,28 @@ class MultiPartParser: The encoding with which to treat the incoming data. """ # Content-Type should contain multipart and the boundary information. - content_type = META.get('CONTENT_TYPE', '') - if not content_type.startswith('multipart/'): - raise MultiPartParserError('Invalid Content-Type: %s' % content_type) + content_type = META.get("CONTENT_TYPE", "") + if not content_type.startswith("multipart/"): + raise MultiPartParserError("Invalid Content-Type: %s" % content_type) # Parse the header to get the boundary to split the parts. 
try: - ctypes, opts = parse_header(content_type.encode('ascii')) + ctypes, opts = parse_header(content_type.encode("ascii")) except UnicodeEncodeError: - raise MultiPartParserError('Invalid non-ASCII Content-Type in multipart: %s' % force_str(content_type)) - boundary = opts.get('boundary') + raise MultiPartParserError( + "Invalid non-ASCII Content-Type in multipart: %s" + % force_str(content_type) + ) + boundary = opts.get("boundary") if not boundary or not cgi.valid_boundary(boundary): - raise MultiPartParserError('Invalid boundary in multipart: %s' % force_str(boundary)) + raise MultiPartParserError( + "Invalid boundary in multipart: %s" % force_str(boundary) + ) # Content-Length should contain the length of the body we are about # to receive. try: - content_length = int(META.get('CONTENT_LENGTH', 0)) + content_length = int(META.get("CONTENT_LENGTH", 0)) except (ValueError, TypeError): content_length = 0 @@ -87,14 +94,14 @@ class MultiPartParser: raise MultiPartParserError("Invalid content length: %r" % content_length) if isinstance(boundary, str): - boundary = boundary.encode('ascii') + boundary = boundary.encode("ascii") self._boundary = boundary self._input_data = input_data # For compatibility with low-level network APIs (with 32-bit integers), # the chunk size should be < 2^31, but still divisible by 4. 
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] - self._chunk_size = min([2 ** 31 - 4] + possible_sizes) + self._chunk_size = min([2**31 - 4] + possible_sizes) self._meta = META self._encoding = encoding or settings.DEFAULT_CHARSET @@ -163,32 +170,36 @@ class MultiPartParser: uploaded_file = True try: - disposition = meta_data['content-disposition'][1] - field_name = disposition['name'].strip() + disposition = meta_data["content-disposition"][1] + field_name = disposition["name"].strip() except (KeyError, IndexError, AttributeError): continue - transfer_encoding = meta_data.get('content-transfer-encoding') + transfer_encoding = meta_data.get("content-transfer-encoding") if transfer_encoding is not None: transfer_encoding = transfer_encoding[0].strip() - field_name = force_str(field_name, encoding, errors='replace') + field_name = force_str(field_name, encoding, errors="replace") if item_type == FIELD: # Avoid storing more than DATA_UPLOAD_MAX_NUMBER_FIELDS. num_post_keys += 1 - if (settings.DATA_UPLOAD_MAX_NUMBER_FIELDS is not None and - settings.DATA_UPLOAD_MAX_NUMBER_FIELDS < num_post_keys): + if ( + settings.DATA_UPLOAD_MAX_NUMBER_FIELDS is not None + and settings.DATA_UPLOAD_MAX_NUMBER_FIELDS < num_post_keys + ): raise TooManyFieldsSent( - 'The number of GET/POST parameters exceeded ' - 'settings.DATA_UPLOAD_MAX_NUMBER_FIELDS.' + "The number of GET/POST parameters exceeded " + "settings.DATA_UPLOAD_MAX_NUMBER_FIELDS." ) # Avoid reading more than DATA_UPLOAD_MAX_MEMORY_SIZE. 
if settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None: - read_size = settings.DATA_UPLOAD_MAX_MEMORY_SIZE - num_bytes_read + read_size = ( + settings.DATA_UPLOAD_MAX_MEMORY_SIZE - num_bytes_read + ) # This is a post field, we can just set it in the post - if transfer_encoding == 'base64': + if transfer_encoding == "base64": raw_data = field_stream.read(size=read_size) num_bytes_read += len(raw_data) try: @@ -202,26 +213,34 @@ class MultiPartParser: # Add two here to make the check consistent with the # x-www-form-urlencoded check that includes '&='. num_bytes_read += len(field_name) + 2 - if (settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None and - num_bytes_read > settings.DATA_UPLOAD_MAX_MEMORY_SIZE): - raise RequestDataTooBig('Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE.') + if ( + settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None + and num_bytes_read > settings.DATA_UPLOAD_MAX_MEMORY_SIZE + ): + raise RequestDataTooBig( + "Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE." + ) - self._post.appendlist(field_name, force_str(data, encoding, errors='replace')) + self._post.appendlist( + field_name, force_str(data, encoding, errors="replace") + ) elif item_type == FILE: # This is a file, use the handler... 
- file_name = disposition.get('filename') + file_name = disposition.get("filename") if file_name: - file_name = force_str(file_name, encoding, errors='replace') + file_name = force_str(file_name, encoding, errors="replace") file_name = self.sanitize_file_name(file_name) if not file_name: continue - content_type, content_type_extra = meta_data.get('content-type', ('', {})) + content_type, content_type_extra = meta_data.get( + "content-type", ("", {}) + ) content_type = content_type.strip() - charset = content_type_extra.get('charset') + charset = content_type_extra.get("charset") try: - content_length = int(meta_data.get('content-length')[0]) + content_length = int(meta_data.get("content-length")[0]) except (IndexError, TypeError, ValueError): content_length = None @@ -231,14 +250,18 @@ class MultiPartParser: for handler in handlers: try: handler.new_file( - field_name, file_name, content_type, - content_length, charset, content_type_extra, + field_name, + file_name, + content_type, + content_length, + charset, + content_type_extra, ) except StopFutureHandlers: break for chunk in field_stream: - if transfer_encoding == 'base64': + if transfer_encoding == "base64": # We only special-case base64 transfer encoding # We should always decode base64 chunks by multiple of 4, # ignoring whitespace. @@ -257,7 +280,9 @@ class MultiPartParser: chunk = base64.b64decode(stripped_chunk) except Exception as exc: # Since this is only a chunk, any error is an unfixable error. - raise MultiPartParserError("Could not decode base64 data.") from exc + raise MultiPartParserError( + "Could not decode base64 data." + ) from exc for i, handler in enumerate(handlers): chunk_length = len(chunk) @@ -303,7 +328,10 @@ class MultiPartParser: file_obj = handler.file_complete(counters[i]) if file_obj: # If it returns a file object, then set the files dict. 
- self._files.appendlist(force_str(old_field_name, self._encoding, errors='replace'), file_obj) + self._files.appendlist( + force_str(old_field_name, self._encoding, errors="replace"), + file_obj, + ) break def sanitize_file_name(self, file_name): @@ -320,12 +348,12 @@ class MultiPartParser: resulting filename should still be considered as untrusted user input. """ file_name = html.unescape(file_name) - file_name = file_name.rsplit('/')[-1] - file_name = file_name.rsplit('\\')[-1] + file_name = file_name.rsplit("/")[-1] + file_name = file_name.rsplit("\\")[-1] # Remove non-printable characters. - file_name = ''.join([char for char in file_name if char.isprintable()]) + file_name = "".join([char for char in file_name if char.isprintable()]) - if file_name in {'', '.', '..'}: + if file_name in {"", ".", ".."}: return None return file_name @@ -336,7 +364,7 @@ class MultiPartParser: # FIXME: this currently assumes that upload handlers store the file as 'file' # We should document that... (Maybe add handler.free_file to complement new_file) for handler in self._upload_handlers: - if hasattr(handler, 'file'): + if hasattr(handler, "file"): handler.file.close() @@ -348,6 +376,7 @@ class LazyStream: LazyStream object will support iteration, reading, and keeping a "look-back" variable in case you need to "unget" some bytes. """ + def __init__(self, producer, length=None): """ Every LazyStream must have a producer when instantiated. @@ -357,7 +386,7 @@ class LazyStream: """ self._producer = producer self._empty = False - self._leftover = b'' + self._leftover = b"" self.length = length self.position = 0 self._remaining = length @@ -371,14 +400,14 @@ class LazyStream: remaining = self._remaining if size is None else size # do the whole thing in one shot if no limit was provided. 
if remaining is None: - yield b''.join(self) + yield b"".join(self) return # otherwise do some bookkeeping to return exactly enough # of the stream and stashing any extra content we get from # the producer while remaining != 0: - assert remaining > 0, 'remaining bytes to read should never go negative' + assert remaining > 0, "remaining bytes to read should never go negative" try: chunk = next(self) @@ -390,7 +419,7 @@ class LazyStream: remaining -= len(emitting) yield emitting - return b''.join(parts()) + return b"".join(parts()) def __next__(self): """ @@ -401,7 +430,7 @@ class LazyStream: """ if self._leftover: output = self._leftover - self._leftover = b'' + self._leftover = b"" else: output = next(self._producer) self._unget_history = [] @@ -442,10 +471,13 @@ class LazyStream: maliciously-malformed MIME request. """ self._unget_history = [num_bytes] + self._unget_history[:49] - number_equal = len([ - current_number for current_number in self._unget_history - if current_number == num_bytes - ]) + number_equal = len( + [ + current_number + for current_number in self._unget_history + if current_number == num_bytes + ] + ) if number_equal > 40: raise SuspiciousMultipartForm( @@ -460,6 +492,7 @@ class ChunkIter: An iterable that will yield chunks of data. Given a file-like object as the constructor, yield chunks of read operations from that object. """ + def __init__(self, flo, chunk_size=64 * 1024): self.flo = flo self.chunk_size = chunk_size @@ -482,6 +515,7 @@ class InterBoundaryIter: """ A Producer that will iterate over boundaries. 
""" + def __init__(self, stream, boundary): self._stream = stream self._boundary = boundary @@ -548,7 +582,7 @@ class BoundaryIter: if not chunks: raise StopIteration() - chunk = b''.join(chunks) + chunk = b"".join(chunks) boundary = self._find_boundary(chunk) if boundary: @@ -584,10 +618,10 @@ class BoundaryIter: next = index + len(self._boundary) # backup over CRLF last = max(0, end - 1) - if data[last:last + 1] == b'\n': + if data[last : last + 1] == b"\n": end -= 1 last = max(0, end - 1) - if data[last:last + 1] == b'\r': + if data[last : last + 1] == b"\r": end -= 1 return end, next @@ -613,12 +647,12 @@ def parse_boundary_stream(stream, max_header_size): # 'find' returns the top of these four bytes, so we'll # need to munch them later to prevent them from polluting # the payload. - header_end = chunk.find(b'\r\n\r\n') + header_end = chunk.find(b"\r\n\r\n") def _parse_header(line): main_value_pair, params = parse_header(line) try: - name, value = main_value_pair.split(':', 1) + name, value = main_value_pair.split(":", 1) except ValueError: raise ValueError("Invalid header: %r" % line) return name, (value, params) @@ -633,13 +667,13 @@ def parse_boundary_stream(stream, max_header_size): # here we place any excess chunk back onto the stream, as # well as throwing away the CRLFCRLF bytes from above. - stream.unget(chunk[header_end + 4:]) + stream.unget(chunk[header_end + 4 :]) TYPE = RAW outdict = {} # Eliminate blank lines - for line in header.split(b'\r\n'): + for line in header.split(b"\r\n"): # This terminology ("main value" and "dictionary of # parameters") is from the Python docs. 
try: @@ -647,9 +681,9 @@ def parse_boundary_stream(stream, max_header_size): except ValueError: continue - if name == 'content-disposition': + if name == "content-disposition": TYPE = FIELD - if params.get('filename'): + if params.get("filename"): TYPE = FILE outdict[name] = value, params @@ -663,7 +697,7 @@ def parse_boundary_stream(stream, max_header_size): class Parser: def __init__(self, stream, boundary): self._stream = stream - self._separator = b'--' + boundary + self._separator = b"--" + boundary def __iter__(self): boundarystream = InterBoundaryIter(self._stream, self._separator) @@ -679,24 +713,24 @@ def parse_header(line): Input (line): bytes, output: str for key/name, bytes for values which will be decoded later. """ - plist = _parse_header_params(b';' + line) - key = plist.pop(0).lower().decode('ascii') + plist = _parse_header_params(b";" + line) + key = plist.pop(0).lower().decode("ascii") pdict = {} for p in plist: - i = p.find(b'=') + i = p.find(b"=") if i >= 0: has_encoding = False - name = p[:i].strip().lower().decode('ascii') - if name.endswith('*'): + name = p[:i].strip().lower().decode("ascii") + if name.endswith("*"): # Lang/encoding embedded in the value (like "filename*=UTF-8''file.ext") # https://tools.ietf.org/html/rfc2231#section-4 name = name[:-1] if p.count(b"'") == 2: has_encoding = True - value = p[i + 1:].strip() + value = p[i + 1 :].strip() if len(value) >= 2 and value[:1] == value[-1:] == b'"': value = value[1:-1] - value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') if has_encoding: encoding, lang, value = value.split(b"'") value = unquote(value.decode(), encoding=encoding.decode()) @@ -706,11 +740,11 @@ def parse_header(line): def _parse_header_params(s): plist = [] - while s[:1] == b';': + while s[:1] == b";": s = s[1:] - end = s.find(b';') + end = s.find(b";") while end > 0 and s.count(b'"', 0, end) % 2: - end = s.find(b';', end + 1) + end = s.find(b";", end 
+ 1) if end < 0: end = len(s) f = s[:end] diff --git a/django/http/request.py b/django/http/request.py index 5971203261..d975aadf25 100644 --- a/django/http/request.py +++ b/django/http/request.py @@ -8,12 +8,17 @@ from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlsplit from django.conf import settings from django.core import signing from django.core.exceptions import ( - DisallowedHost, ImproperlyConfigured, RequestDataTooBig, TooManyFieldsSent, + DisallowedHost, + ImproperlyConfigured, + RequestDataTooBig, + TooManyFieldsSent, ) from django.core.files import uploadhandler from django.http.multipartparser import MultiPartParser, MultiPartParserError from django.utils.datastructures import ( - CaseInsensitiveMapping, ImmutableList, MultiValueDict, + CaseInsensitiveMapping, + ImmutableList, + MultiValueDict, ) from django.utils.encoding import escape_uri_path, iri_to_uri from django.utils.functional import cached_property @@ -23,7 +28,9 @@ from django.utils.regex_helper import _lazy_re_compile from .multipartparser import parse_header RAISE_ERROR = object() -host_validation_re = _lazy_re_compile(r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$") +host_validation_re = _lazy_re_compile( + r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9\.:]+\])(:[0-9]+)?$" +) class UnreadablePostError(OSError): @@ -36,6 +43,7 @@ class RawPostDataException(Exception): multipart/* POST data if it has been accessed via POST, FILES, etc.. 
""" + pass @@ -57,8 +65,8 @@ class HttpRequest: self.META = {} self.FILES = MultiValueDict() - self.path = '' - self.path_info = '' + self.path = "" + self.path_info = "" self.method = None self.resolver_match = None self.content_type = None @@ -66,8 +74,12 @@ class HttpRequest: def __repr__(self): if self.method is None or not self.get_full_path(): - return '<%s>' % self.__class__.__name__ - return '<%s: %s %r>' % (self.__class__.__name__, self.method, self.get_full_path()) + return "<%s>" % self.__class__.__name__ + return "<%s: %s %r>" % ( + self.__class__.__name__, + self.method, + self.get_full_path(), + ) @cached_property def headers(self): @@ -76,24 +88,25 @@ class HttpRequest: @cached_property def accepted_types(self): """Return a list of MediaType instances.""" - return parse_accept_header(self.headers.get('Accept', '*/*')) + return parse_accept_header(self.headers.get("Accept", "*/*")) def accepts(self, media_type): return any( - accepted_type.match(media_type) - for accepted_type in self.accepted_types + accepted_type.match(media_type) for accepted_type in self.accepted_types ) def _set_content_type_params(self, meta): """Set content_type, content_params, and encoding.""" - self.content_type, self.content_params = cgi.parse_header(meta.get('CONTENT_TYPE', '')) - if 'charset' in self.content_params: + self.content_type, self.content_params = cgi.parse_header( + meta.get("CONTENT_TYPE", "") + ) + if "charset" in self.content_params: try: - codecs.lookup(self.content_params['charset']) + codecs.lookup(self.content_params["charset"]) except LookupError: pass else: - self.encoding = self.content_params['charset'] + self.encoding = self.content_params["charset"] def _get_raw_host(self): """ @@ -101,17 +114,16 @@ class HttpRequest: allowed hosts protection, so may return an insecure host. """ # We try three options, in order of decreasing preference. 
- if settings.USE_X_FORWARDED_HOST and ( - 'HTTP_X_FORWARDED_HOST' in self.META): - host = self.META['HTTP_X_FORWARDED_HOST'] - elif 'HTTP_HOST' in self.META: - host = self.META['HTTP_HOST'] + if settings.USE_X_FORWARDED_HOST and ("HTTP_X_FORWARDED_HOST" in self.META): + host = self.META["HTTP_X_FORWARDED_HOST"] + elif "HTTP_HOST" in self.META: + host = self.META["HTTP_HOST"] else: # Reconstruct the host using the algorithm from PEP 333. - host = self.META['SERVER_NAME'] + host = self.META["SERVER_NAME"] server_port = self.get_port() - if server_port != ('443' if self.is_secure() else '80'): - host = '%s:%s' % (host, server_port) + if server_port != ("443" if self.is_secure() else "80"): + host = "%s:%s" % (host, server_port) return host def get_host(self): @@ -121,7 +133,7 @@ class HttpRequest: # Allow variants of localhost if ALLOWED_HOSTS is empty and DEBUG=True. allowed_hosts = settings.ALLOWED_HOSTS if settings.DEBUG and not allowed_hosts: - allowed_hosts = ['.localhost', '127.0.0.1', '[::1]'] + allowed_hosts = [".localhost", "127.0.0.1", "[::1]"] domain, port = split_domain_port(host) if domain and validate_host(domain, allowed_hosts): @@ -131,15 +143,17 @@ class HttpRequest: if domain: msg += " You may need to add %r to ALLOWED_HOSTS." % domain else: - msg += " The domain name provided is not valid according to RFC 1034/1035." + msg += ( + " The domain name provided is not valid according to RFC 1034/1035." 
+ ) raise DisallowedHost(msg) def get_port(self): """Return the port number for the request as a string.""" - if settings.USE_X_FORWARDED_PORT and 'HTTP_X_FORWARDED_PORT' in self.META: - port = self.META['HTTP_X_FORWARDED_PORT'] + if settings.USE_X_FORWARDED_PORT and "HTTP_X_FORWARDED_PORT" in self.META: + port = self.META["HTTP_X_FORWARDED_PORT"] else: - port = self.META['SERVER_PORT'] + port = self.META["SERVER_PORT"] return str(port) def get_full_path(self, force_append_slash=False): @@ -151,13 +165,15 @@ class HttpRequest: def _get_full_path(self, path, force_append_slash): # RFC 3986 requires query string arguments to be in the ASCII range. # Rather than crash if this doesn't happen, we encode defensively. - return '%s%s%s' % ( + return "%s%s%s" % ( escape_uri_path(path), - '/' if force_append_slash and not path.endswith('/') else '', - ('?' + iri_to_uri(self.META.get('QUERY_STRING', ''))) if self.META.get('QUERY_STRING', '') else '' + "/" if force_append_slash and not path.endswith("/") else "", + ("?" + iri_to_uri(self.META.get("QUERY_STRING", ""))) + if self.META.get("QUERY_STRING", "") + else "", ) - def get_signed_cookie(self, key, default=RAISE_ERROR, salt='', max_age=None): + def get_signed_cookie(self, key, default=RAISE_ERROR, salt="", max_age=None): """ Attempt to return a signed cookie. If the signature fails or the cookie has expired, raise an exception, unless the `default` argument @@ -172,7 +188,8 @@ class HttpRequest: raise try: value = signing.get_cookie_signer(salt=key + salt).unsign( - cookie_value, max_age=max_age) + cookie_value, max_age=max_age + ) except signing.BadSignature: if default is not RAISE_ERROR: return default @@ -192,7 +209,7 @@ class HttpRequest: if location is None: # Make it an absolute url (but schemeless and domainless) for the # edge case that the path starts with '//'. - location = '//%s' % self.get_full_path() + location = "//%s" % self.get_full_path() else: # Coerce lazy locations. 
location = str(location) @@ -201,12 +218,17 @@ class HttpRequest: # Handle the simple, most common case. If the location is absolute # and a scheme or host (netloc) isn't provided, skip an expensive # urljoin() as long as no path segments are '.' or '..'. - if (bits.path.startswith('/') and not bits.scheme and not bits.netloc and - '/./' not in bits.path and '/../' not in bits.path): + if ( + bits.path.startswith("/") + and not bits.scheme + and not bits.netloc + and "/./" not in bits.path + and "/../" not in bits.path + ): # If location starts with '//' but has no netloc, reuse the # schema and netloc from the current request. Strip the double # slashes and continue as if it wasn't specified. - if location.startswith('//'): + if location.startswith("//"): location = location[2:] location = self._current_scheme_host + location else: @@ -218,14 +240,14 @@ class HttpRequest: @cached_property def _current_scheme_host(self): - return '{}://{}'.format(self.scheme, self.get_host()) + return "{}://{}".format(self.scheme, self.get_host()) def _get_scheme(self): """ Hook for subclasses like WSGIRequest to implement. Return 'http' by default. """ - return 'http' + return "http" @property def scheme(self): @@ -234,15 +256,15 @@ class HttpRequest: header, secure_value = settings.SECURE_PROXY_SSL_HEADER except ValueError: raise ImproperlyConfigured( - 'The SECURE_PROXY_SSL_HEADER setting must be a tuple containing two values.' + "The SECURE_PROXY_SSL_HEADER setting must be a tuple containing two values." ) header_value = self.META.get(header) if header_value is not None: - return 'https' if header_value == secure_value else 'http' + return "https" if header_value == secure_value else "http" return self._get_scheme() def is_secure(self): - return self.scheme == 'https' + return self.scheme == "https" @property def encoding(self): @@ -256,14 +278,16 @@ class HttpRequest: next access (so that it is decoded correctly). 
""" self._encoding = val - if hasattr(self, 'GET'): + if hasattr(self, "GET"): del self.GET - if hasattr(self, '_post'): + if hasattr(self, "_post"): del self._post def _initialize_handlers(self): - self._upload_handlers = [uploadhandler.load_handler(handler, self) - for handler in settings.FILE_UPLOAD_HANDLERS] + self._upload_handlers = [ + uploadhandler.load_handler(handler, self) + for handler in settings.FILE_UPLOAD_HANDLERS + ] @property def upload_handlers(self): @@ -274,29 +298,38 @@ class HttpRequest: @upload_handlers.setter def upload_handlers(self, upload_handlers): - if hasattr(self, '_files'): - raise AttributeError("You cannot set the upload handlers after the upload has been processed.") + if hasattr(self, "_files"): + raise AttributeError( + "You cannot set the upload handlers after the upload has been processed." + ) self._upload_handlers = upload_handlers def parse_file_upload(self, META, post_data): """Return a tuple of (POST QueryDict, FILES MultiValueDict).""" self.upload_handlers = ImmutableList( self.upload_handlers, - warning="You cannot alter upload handlers after the upload has been processed." + warning="You cannot alter upload handlers after the upload has been processed.", ) parser = MultiPartParser(META, post_data, self.upload_handlers, self.encoding) return parser.parse() @property def body(self): - if not hasattr(self, '_body'): + if not hasattr(self, "_body"): if self._read_started: - raise RawPostDataException("You cannot access body after reading from request's data stream") + raise RawPostDataException( + "You cannot access body after reading from request's data stream" + ) # Limit the maximum request data size that will be handled in-memory. 
- if (settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None and - int(self.META.get('CONTENT_LENGTH') or 0) > settings.DATA_UPLOAD_MAX_MEMORY_SIZE): - raise RequestDataTooBig('Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE.') + if ( + settings.DATA_UPLOAD_MAX_MEMORY_SIZE is not None + and int(self.META.get("CONTENT_LENGTH") or 0) + > settings.DATA_UPLOAD_MAX_MEMORY_SIZE + ): + raise RequestDataTooBig( + "Request body exceeded settings.DATA_UPLOAD_MAX_MEMORY_SIZE." + ) try: self._body = self.read() @@ -311,15 +344,18 @@ class HttpRequest: def _load_post_and_files(self): """Populate self._post and self._files if the content-type is a form type""" - if self.method != 'POST': - self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict() + if self.method != "POST": + self._post, self._files = ( + QueryDict(encoding=self._encoding), + MultiValueDict(), + ) return - if self._read_started and not hasattr(self, '_body'): + if self._read_started and not hasattr(self, "_body"): self._mark_post_parse_error() return - if self.content_type == 'multipart/form-data': - if hasattr(self, '_body'): + if self.content_type == "multipart/form-data": + if hasattr(self, "_body"): # Use already read data data = BytesIO(self._body) else: @@ -333,13 +369,19 @@ class HttpRequest: # attempts to parse POST data again. 
self._mark_post_parse_error() raise - elif self.content_type == 'application/x-www-form-urlencoded': - self._post, self._files = QueryDict(self.body, encoding=self._encoding), MultiValueDict() + elif self.content_type == "application/x-www-form-urlencoded": + self._post, self._files = ( + QueryDict(self.body, encoding=self._encoding), + MultiValueDict(), + ) else: - self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict() + self._post, self._files = ( + QueryDict(encoding=self._encoding), + MultiValueDict(), + ) def close(self): - if hasattr(self, '_files'): + if hasattr(self, "_files"): for f in chain.from_iterable(list_[1] for list_ in self._files.lists()): f.close() @@ -366,16 +408,16 @@ class HttpRequest: raise UnreadablePostError(*e.args) from e def __iter__(self): - return iter(self.readline, b'') + return iter(self.readline, b"") def readlines(self): return list(self) class HttpHeaders(CaseInsensitiveMapping): - HTTP_PREFIX = 'HTTP_' + HTTP_PREFIX = "HTTP_" # PEP 333 gives two headers which aren't prepended with HTTP_. 
- UNPREFIXED_HEADERS = {'CONTENT_TYPE', 'CONTENT_LENGTH'} + UNPREFIXED_HEADERS = {"CONTENT_TYPE", "CONTENT_LENGTH"} def __init__(self, environ): headers = {} @@ -387,15 +429,15 @@ class HttpHeaders(CaseInsensitiveMapping): def __getitem__(self, key): """Allow header lookup using underscores in place of hyphens.""" - return super().__getitem__(key.replace('_', '-')) + return super().__getitem__(key.replace("_", "-")) @classmethod def parse_header_name(cls, header): if header.startswith(cls.HTTP_PREFIX): - header = header[len(cls.HTTP_PREFIX):] + header = header[len(cls.HTTP_PREFIX) :] elif header not in cls.UNPREFIXED_HEADERS: return None - return header.replace('_', '-').title() + return header.replace("_", "-").title() class QueryDict(MultiValueDict): @@ -421,11 +463,11 @@ class QueryDict(MultiValueDict): def __init__(self, query_string=None, mutable=False, encoding=None): super().__init__() self.encoding = encoding or settings.DEFAULT_CHARSET - query_string = query_string or '' + query_string = query_string or "" parse_qsl_kwargs = { - 'keep_blank_values': True, - 'encoding': self.encoding, - 'max_num_fields': settings.DATA_UPLOAD_MAX_NUMBER_FIELDS, + "keep_blank_values": True, + "encoding": self.encoding, + "max_num_fields": settings.DATA_UPLOAD_MAX_NUMBER_FIELDS, } if isinstance(query_string, bytes): # query_string normally contains URL-encoded data, a subset of ASCII. @@ -433,7 +475,7 @@ class QueryDict(MultiValueDict): query_string = query_string.decode(self.encoding) except UnicodeDecodeError: # ... but some user agents are misbehaving :-( - query_string = query_string.decode('iso-8859-1') + query_string = query_string.decode("iso-8859-1") try: for key, value in parse_qsl(query_string, **parse_qsl_kwargs): self.appendlist(key, value) @@ -443,18 +485,18 @@ class QueryDict(MultiValueDict): # the exception was raised by exceeding the value of max_num_fields # instead of fragile checks of exception message strings. 
raise TooManyFieldsSent( - 'The number of GET/POST parameters exceeded ' - 'settings.DATA_UPLOAD_MAX_NUMBER_FIELDS.' + "The number of GET/POST parameters exceeded " + "settings.DATA_UPLOAD_MAX_NUMBER_FIELDS." ) from e self._mutable = mutable @classmethod - def fromkeys(cls, iterable, value='', mutable=False, encoding=None): + def fromkeys(cls, iterable, value="", mutable=False, encoding=None): """ Return a new QueryDict with keys (may be repeated) from an iterable and values from value. """ - q = cls('', mutable=True, encoding=encoding) + q = cls("", mutable=True, encoding=encoding) for key in iterable: q.appendlist(key, value) if not mutable: @@ -486,13 +528,13 @@ class QueryDict(MultiValueDict): super().__delitem__(key) def __copy__(self): - result = self.__class__('', mutable=True, encoding=self.encoding) + result = self.__class__("", mutable=True, encoding=self.encoding) for key, value in self.lists(): result.setlist(key, value) return result def __deepcopy__(self, memo): - result = self.__class__('', mutable=True, encoding=self.encoding) + result = self.__class__("", mutable=True, encoding=self.encoding) memo[id(self)] = result for key, value in self.lists(): result.setlist(copy.deepcopy(key, memo), copy.deepcopy(value, memo)) @@ -554,48 +596,50 @@ class QueryDict(MultiValueDict): safe = safe.encode(self.encoding) def encode(k, v): - return '%s=%s' % ((quote(k, safe), quote(v, safe))) + return "%s=%s" % ((quote(k, safe), quote(v, safe))) + else: + def encode(k, v): return urlencode({k: v}) + for k, list_ in self.lists(): output.extend( encode(k.encode(self.encoding), str(v).encode(self.encoding)) for v in list_ ) - return '&'.join(output) + return "&".join(output) class MediaType: def __init__(self, media_type_raw_line): full_type, self.params = parse_header( - media_type_raw_line.encode('ascii') if media_type_raw_line else b'' + media_type_raw_line.encode("ascii") if media_type_raw_line else b"" ) - self.main_type, _, self.sub_type = full_type.partition('/') 
+ self.main_type, _, self.sub_type = full_type.partition("/") def __str__(self): - params_str = ''.join( - '; %s=%s' % (k, v.decode('ascii')) - for k, v in self.params.items() + params_str = "".join( + "; %s=%s" % (k, v.decode("ascii")) for k, v in self.params.items() ) - return '%s%s%s' % ( + return "%s%s%s" % ( self.main_type, - ('/%s' % self.sub_type) if self.sub_type else '', + ("/%s" % self.sub_type) if self.sub_type else "", params_str, ) def __repr__(self): - return '<%s: %s>' % (self.__class__.__qualname__, self) + return "<%s: %s>" % (self.__class__.__qualname__, self) @property def is_all_types(self): - return self.main_type == '*' and self.sub_type == '*' + return self.main_type == "*" and self.sub_type == "*" def match(self, other): if self.is_all_types: return True other = MediaType(other) - if self.main_type == other.main_type and self.sub_type in {'*', other.sub_type}: + if self.main_type == other.main_type and self.sub_type in {"*", other.sub_type}: return True return False @@ -612,7 +656,7 @@ def bytes_to_text(s, encoding): Return any non-bytes objects without change. """ if isinstance(s, bytes): - return str(s, encoding, 'replace') + return str(s, encoding, "replace") else: return s @@ -627,15 +671,15 @@ def split_domain_port(host): host = host.lower() if not host_validation_re.match(host): - return '', '' + return "", "" - if host[-1] == ']': + if host[-1] == "]": # It's an IPv6 address without a port. - return host, '' - bits = host.rsplit(':', 1) - domain, port = bits if len(bits) == 2 else (bits[0], '') + return host, "" + bits = host.rsplit(":", 1) + domain, port = bits if len(bits) == 2 else (bits[0], "") # Remove a trailing dot (if present) from the domain. - domain = domain[:-1] if domain.endswith('.') else domain + domain = domain[:-1] if domain.endswith(".") else domain return domain, port @@ -654,8 +698,10 @@ def validate_host(host, allowed_hosts): Return ``True`` for a valid host, ``False`` otherwise. 
""" - return any(pattern == '*' or is_same_domain(host, pattern) for pattern in allowed_hosts) + return any( + pattern == "*" or is_same_domain(host, pattern) for pattern in allowed_hosts + ) def parse_accept_header(header): - return [MediaType(token) for token in header.split(',') if token.strip()] + return [MediaType(token) for token in header.split(",") if token.strip()] diff --git a/django/http/response.py b/django/http/response.py index 847d824191..e40f2d169b 100644 --- a/django/http/response.py +++ b/django/http/response.py @@ -21,7 +21,9 @@ from django.utils.encoding import iri_to_uri from django.utils.http import http_date from django.utils.regex_helper import _lazy_re_compile -_charset_from_content_type_re = _lazy_re_compile(r';\s*charset=(?P<charset>[^\s;]+)', re.I) +_charset_from_content_type_re = _lazy_re_compile( + r";\s*charset=(?P<charset>[^\s;]+)", re.I +) class ResponseHeaders(CaseInsensitiveMapping): @@ -42,11 +44,12 @@ class ResponseHeaders(CaseInsensitiveMapping): """ if not isinstance(value, (bytes, str)): value = str(value) - if ( - (isinstance(value, bytes) and (b'\n' in value or b'\r' in value)) or - (isinstance(value, str) and ('\n' in value or '\r' in value)) + if (isinstance(value, bytes) and (b"\n" in value or b"\r" in value)) or ( + isinstance(value, str) and ("\n" in value or "\r" in value) ): - raise BadHeaderError("Header values can't contain newlines (got %r)" % value) + raise BadHeaderError( + "Header values can't contain newlines (got %r)" % value + ) try: if isinstance(value, str): # Ensure string is valid in given charset @@ -56,9 +59,9 @@ class ResponseHeaders(CaseInsensitiveMapping): value = value.decode(charset) except UnicodeError as e: if mime_encode: - value = Header(value, 'utf-8', maxlinelen=sys.maxsize).encode() + value = Header(value, "utf-8", maxlinelen=sys.maxsize).encode() else: - e.reason += ', HTTP response headers must be in %s format' % charset + e.reason += ", HTTP response headers must be in %s format" % 
charset raise return value @@ -66,8 +69,8 @@ class ResponseHeaders(CaseInsensitiveMapping): self.pop(key) def __setitem__(self, key, value): - key = self._convert_to_charset(key, 'ascii') - value = self._convert_to_charset(value, 'latin-1', mime_encode=True) + key = self._convert_to_charset(key, "ascii") + value = self._convert_to_charset(value, "latin-1", mime_encode=True) self._store[key.lower()] = (key, value) def pop(self, key, default=None): @@ -92,18 +95,20 @@ class HttpResponseBase: status_code = 200 - def __init__(self, content_type=None, status=None, reason=None, charset=None, headers=None): + def __init__( + self, content_type=None, status=None, reason=None, charset=None, headers=None + ): self.headers = ResponseHeaders(headers or {}) self._charset = charset - if content_type and 'Content-Type' in self.headers: + if content_type and "Content-Type" in self.headers: raise ValueError( "'headers' must not contain 'Content-Type' when the " "'content_type' parameter is provided." ) - if 'Content-Type' not in self.headers: + if "Content-Type" not in self.headers: if content_type is None: - content_type = 'text/html; charset=%s' % self.charset - self.headers['Content-Type'] = content_type + content_type = "text/html; charset=%s" % self.charset + self.headers["Content-Type"] = content_type self._resource_closers = [] # This parameter is set by the handler. It's necessary to preserve the # historical behavior of request_finished. 
@@ -114,10 +119,10 @@ class HttpResponseBase: try: self.status_code = int(status) except (ValueError, TypeError): - raise TypeError('HTTP status code must be an integer.') + raise TypeError("HTTP status code must be an integer.") if not 100 <= self.status_code <= 599: - raise ValueError('HTTP status code must be an integer from 100 to 599.') + raise ValueError("HTTP status code must be an integer from 100 to 599.") self._reason_phrase = reason @property @@ -126,7 +131,7 @@ class HttpResponseBase: return self._reason_phrase # Leave self._reason_phrase unset in order to use the default # reason phrase for status code. - return responses.get(self.status_code, 'Unknown Status Code') + return responses.get(self.status_code, "Unknown Status Code") @reason_phrase.setter def reason_phrase(self, value): @@ -136,11 +141,11 @@ class HttpResponseBase: def charset(self): if self._charset is not None: return self._charset - content_type = self.get('Content-Type', '') + content_type = self.get("Content-Type", "") matched = _charset_from_content_type_re.search(content_type) if matched: # Extract the charset and strip its double quotes - return matched['charset'].replace('"', '') + return matched["charset"].replace('"', "") return settings.DEFAULT_CHARSET @charset.setter @@ -149,16 +154,22 @@ class HttpResponseBase: def serialize_headers(self): """HTTP headers as a bytestring.""" - return b'\r\n'.join([ - key.encode('ascii') + b': ' + value.encode('latin-1') - for key, value in self.headers.items() - ]) + return b"\r\n".join( + [ + key.encode("ascii") + b": " + value.encode("latin-1") + for key, value in self.headers.items() + ] + ) __bytes__ = serialize_headers @property def _content_type_for_repr(self): - return ', "%s"' % self.headers['Content-Type'] if 'Content-Type' in self.headers else '' + return ( + ', "%s"' % self.headers["Content-Type"] + if "Content-Type" in self.headers + else "" + ) def __setitem__(self, header, value): self.headers[header] = value @@ -181,8 +192,18 @@ 
class HttpResponseBase: def get(self, header, alternate=None): return self.headers.get(header, alternate) - def set_cookie(self, key, value='', max_age=None, expires=None, path='/', - domain=None, secure=False, httponly=False, samesite=None): + def set_cookie( + self, + key, + value="", + max_age=None, + expires=None, + path="/", + domain=None, + secure=False, + httponly=False, + samesite=None, + ): """ Set a cookie. @@ -206,47 +227,51 @@ class HttpResponseBase: expires = None max_age = max(0, delta.days * 86400 + delta.seconds) else: - self.cookies[key]['expires'] = expires + self.cookies[key]["expires"] = expires else: - self.cookies[key]['expires'] = '' + self.cookies[key]["expires"] = "" if max_age is not None: - self.cookies[key]['max-age'] = int(max_age) + self.cookies[key]["max-age"] = int(max_age) # IE requires expires, so set it if hasn't been already. if not expires: - self.cookies[key]['expires'] = http_date(time.time() + max_age) + self.cookies[key]["expires"] = http_date(time.time() + max_age) if path is not None: - self.cookies[key]['path'] = path + self.cookies[key]["path"] = path if domain is not None: - self.cookies[key]['domain'] = domain + self.cookies[key]["domain"] = domain if secure: - self.cookies[key]['secure'] = True + self.cookies[key]["secure"] = True if httponly: - self.cookies[key]['httponly'] = True + self.cookies[key]["httponly"] = True if samesite: - if samesite.lower() not in ('lax', 'none', 'strict'): + if samesite.lower() not in ("lax", "none", "strict"): raise ValueError('samesite must be "lax", "none", or "strict".') - self.cookies[key]['samesite'] = samesite + self.cookies[key]["samesite"] = samesite def setdefault(self, key, value): """Set a header unless it has already been set.""" self.headers.setdefault(key, value) - def set_signed_cookie(self, key, value, salt='', **kwargs): + def set_signed_cookie(self, key, value, salt="", **kwargs): value = signing.get_cookie_signer(salt=key + salt).sign(value) return 
self.set_cookie(key, value, **kwargs) - def delete_cookie(self, key, path='/', domain=None, samesite=None): + def delete_cookie(self, key, path="/", domain=None, samesite=None): # Browsers can ignore the Set-Cookie header if the cookie doesn't use # the secure flag and: # - the cookie name starts with "__Host-" or "__Secure-", or # - the samesite is "none". - secure = ( - key.startswith(('__Secure-', '__Host-')) or - (samesite and samesite.lower() == 'none') + secure = key.startswith(("__Secure-", "__Host-")) or ( + samesite and samesite.lower() == "none" ) self.set_cookie( - key, max_age=0, path=path, domain=domain, secure=secure, - expires='Thu, 01 Jan 1970 00:00:00 GMT', samesite=samesite, + key, + max_age=0, + path=path, + domain=domain, + secure=secure, + expires="Thu, 01 Jan 1970 00:00:00 GMT", + samesite=samesite, ) # Common methods used by subclasses @@ -284,13 +309,15 @@ class HttpResponseBase: signals.request_finished.send(sender=self._handler_class) def write(self, content): - raise OSError('This %s instance is not writable' % self.__class__.__name__) + raise OSError("This %s instance is not writable" % self.__class__.__name__) def flush(self): pass def tell(self): - raise OSError('This %s instance cannot tell its position' % self.__class__.__name__) + raise OSError( + "This %s instance cannot tell its position" % self.__class__.__name__ + ) # These methods partially implement a stream-like object interface. 
# See https://docs.python.org/library/io.html#io.IOBase @@ -305,7 +332,7 @@ class HttpResponseBase: return False def writelines(self, lines): - raise OSError('This %s instance is not writable' % self.__class__.__name__) + raise OSError("This %s instance is not writable" % self.__class__.__name__) class HttpResponse(HttpResponseBase): @@ -317,37 +344,36 @@ class HttpResponse(HttpResponseBase): streaming = False - def __init__(self, content=b'', *args, **kwargs): + def __init__(self, content=b"", *args, **kwargs): super().__init__(*args, **kwargs) # Content is a bytestring. See the `content` property methods. self.content = content def __repr__(self): - return '<%(cls)s status_code=%(status_code)d%(content_type)s>' % { - 'cls': self.__class__.__name__, - 'status_code': self.status_code, - 'content_type': self._content_type_for_repr, + return "<%(cls)s status_code=%(status_code)d%(content_type)s>" % { + "cls": self.__class__.__name__, + "status_code": self.status_code, + "content_type": self._content_type_for_repr, } def serialize(self): """Full HTTP message, including headers, as a bytestring.""" - return self.serialize_headers() + b'\r\n\r\n' + self.content + return self.serialize_headers() + b"\r\n\r\n" + self.content __bytes__ = serialize @property def content(self): - return b''.join(self._container) + return b"".join(self._container) @content.setter def content(self, value): # Consume iterators upon assignment to allow repeated iteration. 
- if ( - hasattr(value, '__iter__') and - not isinstance(value, (bytes, memoryview, str)) + if hasattr(value, "__iter__") and not isinstance( + value, (bytes, memoryview, str) ): - content = b''.join(self.make_bytes(chunk) for chunk in value) - if hasattr(value, 'close'): + content = b"".join(self.make_bytes(chunk) for chunk in value) + if hasattr(value, "close"): try: value.close() except Exception: @@ -395,10 +421,10 @@ class StreamingHttpResponse(HttpResponseBase): self.streaming_content = streaming_content def __repr__(self): - return '<%(cls)s status_code=%(status_code)d%(content_type)s>' % { - 'cls': self.__class__.__qualname__, - 'status_code': self.status_code, - 'content_type': self._content_type_for_repr, + return "<%(cls)s status_code=%(status_code)d%(content_type)s>" % { + "cls": self.__class__.__qualname__, + "status_code": self.status_code, + "content_type": self._content_type_for_repr, } @property @@ -419,37 +445,40 @@ class StreamingHttpResponse(HttpResponseBase): def _set_streaming_content(self, value): # Ensure we can never iterate on "value" more than once. self._iterator = iter(value) - if hasattr(value, 'close'): + if hasattr(value, "close"): self._resource_closers.append(value.close) def __iter__(self): return self.streaming_content def getvalue(self): - return b''.join(self.streaming_content) + return b"".join(self.streaming_content) class FileResponse(StreamingHttpResponse): """ A streaming HTTP response class optimized for files. 
""" + block_size = 4096 - def __init__(self, *args, as_attachment=False, filename='', **kwargs): + def __init__(self, *args, as_attachment=False, filename="", **kwargs): self.as_attachment = as_attachment self.filename = filename - self._no_explicit_content_type = 'content_type' not in kwargs or kwargs['content_type'] is None + self._no_explicit_content_type = ( + "content_type" not in kwargs or kwargs["content_type"] is None + ) super().__init__(*args, **kwargs) def _set_streaming_content(self, value): - if not hasattr(value, 'read'): + if not hasattr(value, "read"): self.file_to_stream = None return super()._set_streaming_content(value) self.file_to_stream = filelike = value - if hasattr(filelike, 'close'): + if hasattr(filelike, "close"): self._resource_closers.append(filelike.close) - value = iter(lambda: filelike.read(self.block_size), b'') + value = iter(lambda: filelike.read(self.block_size), b"") self.set_headers(filelike) super()._set_streaming_content(value) @@ -458,22 +487,30 @@ class FileResponse(StreamingHttpResponse): Set some common response headers (Content-Length, Content-Type, and Content-Disposition) based on the `filelike` response content. 
""" - filename = getattr(filelike, 'name', '') - filename = filename if isinstance(filename, str) else '' - seekable = hasattr(filelike, 'seek') and (not hasattr(filelike, 'seekable') or filelike.seekable()) - if hasattr(filelike, 'tell'): + filename = getattr(filelike, "name", "") + filename = filename if isinstance(filename, str) else "" + seekable = hasattr(filelike, "seek") and ( + not hasattr(filelike, "seekable") or filelike.seekable() + ) + if hasattr(filelike, "tell"): if seekable: initial_position = filelike.tell() filelike.seek(0, io.SEEK_END) - self.headers['Content-Length'] = filelike.tell() - initial_position + self.headers["Content-Length"] = filelike.tell() - initial_position filelike.seek(initial_position) - elif hasattr(filelike, 'getbuffer'): - self.headers['Content-Length'] = filelike.getbuffer().nbytes - filelike.tell() + elif hasattr(filelike, "getbuffer"): + self.headers["Content-Length"] = ( + filelike.getbuffer().nbytes - filelike.tell() + ) elif os.path.exists(filename): - self.headers['Content-Length'] = os.path.getsize(filename) - filelike.tell() + self.headers["Content-Length"] = ( + os.path.getsize(filename) - filelike.tell() + ) elif seekable: - self.headers['Content-Length'] = sum(iter(lambda: len(filelike.read(self.block_size)), 0)) - filelike.seek(-int(self.headers['Content-Length']), io.SEEK_END) + self.headers["Content-Length"] = sum( + iter(lambda: len(filelike.read(self.block_size)), 0) + ) + filelike.seek(-int(self.headers["Content-Length"]), io.SEEK_END) filename = os.path.basename(self.filename or filename) if self._no_explicit_content_type: @@ -482,45 +519,54 @@ class FileResponse(StreamingHttpResponse): # Encoding isn't set to prevent browsers from automatically # uncompressing files. 
content_type = { - 'bzip2': 'application/x-bzip', - 'gzip': 'application/gzip', - 'xz': 'application/x-xz', + "bzip2": "application/x-bzip", + "gzip": "application/gzip", + "xz": "application/x-xz", }.get(encoding, content_type) - self.headers['Content-Type'] = content_type or 'application/octet-stream' + self.headers["Content-Type"] = ( + content_type or "application/octet-stream" + ) else: - self.headers['Content-Type'] = 'application/octet-stream' + self.headers["Content-Type"] = "application/octet-stream" if filename: - disposition = 'attachment' if self.as_attachment else 'inline' + disposition = "attachment" if self.as_attachment else "inline" try: - filename.encode('ascii') + filename.encode("ascii") file_expr = 'filename="{}"'.format(filename) except UnicodeEncodeError: file_expr = "filename*=utf-8''{}".format(quote(filename)) - self.headers['Content-Disposition'] = '{}; {}'.format(disposition, file_expr) + self.headers["Content-Disposition"] = "{}; {}".format( + disposition, file_expr + ) elif self.as_attachment: - self.headers['Content-Disposition'] = 'attachment' + self.headers["Content-Disposition"] = "attachment" class HttpResponseRedirectBase(HttpResponse): - allowed_schemes = ['http', 'https', 'ftp'] + allowed_schemes = ["http", "https", "ftp"] def __init__(self, redirect_to, *args, **kwargs): super().__init__(*args, **kwargs) - self['Location'] = iri_to_uri(redirect_to) + self["Location"] = iri_to_uri(redirect_to) parsed = urlparse(str(redirect_to)) if parsed.scheme and parsed.scheme not in self.allowed_schemes: - raise DisallowedRedirect("Unsafe redirect to URL with protocol '%s'" % parsed.scheme) + raise DisallowedRedirect( + "Unsafe redirect to URL with protocol '%s'" % parsed.scheme + ) - url = property(lambda self: self['Location']) + url = property(lambda self: self["Location"]) def __repr__(self): - return '<%(cls)s status_code=%(status_code)d%(content_type)s, url="%(url)s">' % { - 'cls': self.__class__.__name__, - 'status_code': 
self.status_code, - 'content_type': self._content_type_for_repr, - 'url': self.url, - } + return ( + '<%(cls)s status_code=%(status_code)d%(content_type)s, url="%(url)s">' + % { + "cls": self.__class__.__name__, + "status_code": self.status_code, + "content_type": self._content_type_for_repr, + "url": self.url, + } + ) class HttpResponseRedirect(HttpResponseRedirectBase): @@ -536,12 +582,14 @@ class HttpResponseNotModified(HttpResponse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - del self['content-type'] + del self["content-type"] @HttpResponse.content.setter def content(self, value): if value: - raise AttributeError("You cannot set content to a 304 (Not Modified) response") + raise AttributeError( + "You cannot set content to a 304 (Not Modified) response" + ) self._container = [] @@ -562,14 +610,14 @@ class HttpResponseNotAllowed(HttpResponse): def __init__(self, permitted_methods, *args, **kwargs): super().__init__(*args, **kwargs) - self['Allow'] = ', '.join(permitted_methods) + self["Allow"] = ", ".join(permitted_methods) def __repr__(self): - return '<%(cls)s [%(methods)s] status_code=%(status_code)d%(content_type)s>' % { - 'cls': self.__class__.__name__, - 'status_code': self.status_code, - 'content_type': self._content_type_for_repr, - 'methods': self['Allow'], + return "<%(cls)s [%(methods)s] status_code=%(status_code)d%(content_type)s>" % { + "cls": self.__class__.__name__, + "status_code": self.status_code, + "content_type": self._content_type_for_repr, + "methods": self["Allow"], } @@ -599,15 +647,21 @@ class JsonResponse(HttpResponse): :param json_dumps_params: A dictionary of kwargs passed to json.dumps(). 
""" - def __init__(self, data, encoder=DjangoJSONEncoder, safe=True, - json_dumps_params=None, **kwargs): + def __init__( + self, + data, + encoder=DjangoJSONEncoder, + safe=True, + json_dumps_params=None, + **kwargs, + ): if safe and not isinstance(data, dict): raise TypeError( - 'In order to allow non-dict objects to be serialized set the ' - 'safe parameter to False.' + "In order to allow non-dict objects to be serialized set the " + "safe parameter to False." ) if json_dumps_params is None: json_dumps_params = {} - kwargs.setdefault('content_type', 'application/json') + kwargs.setdefault("content_type", "application/json") data = json.dumps(data, cls=encoder, **json_dumps_params) super().__init__(content=data, **kwargs) diff --git a/django/middleware/cache.py b/django/middleware/cache.py index c191118183..0fdffe1bbe 100644 --- a/django/middleware/cache.py +++ b/django/middleware/cache.py @@ -46,7 +46,10 @@ More details about how the caching works: from django.conf import settings from django.core.cache import DEFAULT_CACHE_ALIAS, caches from django.utils.cache import ( - get_cache_key, get_max_age, has_vary_header, learn_cache_key, + get_cache_key, + get_max_age, + has_vary_header, + learn_cache_key, patch_response_headers, ) from django.utils.deprecation import MiddlewareMixin @@ -61,6 +64,7 @@ class UpdateCacheMiddleware(MiddlewareMixin): UpdateCacheMiddleware must be the first piece of middleware in MIDDLEWARE so that it'll get called last during the response phase. 
""" + def __init__(self, get_response): super().__init__(get_response) self.cache_timeout = settings.CACHE_MIDDLEWARE_SECONDS @@ -73,7 +77,7 @@ class UpdateCacheMiddleware(MiddlewareMixin): return caches[self.cache_alias] def _should_update_cache(self, request, response): - return hasattr(request, '_cache_update_cache') and request._cache_update_cache + return hasattr(request, "_cache_update_cache") and request._cache_update_cache def process_response(self, request, response): """Set the cache, if needed.""" @@ -86,11 +90,15 @@ class UpdateCacheMiddleware(MiddlewareMixin): # Don't cache responses that set a user-specific (and maybe security # sensitive) cookie in response to a cookie-less request. - if not request.COOKIES and response.cookies and has_vary_header(response, 'Cookie'): + if ( + not request.COOKIES + and response.cookies + and has_vary_header(response, "Cookie") + ): return response # Don't cache a response with 'Cache-Control: private' - if 'private' in response.get('Cache-Control', ()): + if "private" in response.get("Cache-Control", ()): return response # Page timeout takes precedence over the "max-age" and the default @@ -107,8 +115,10 @@ class UpdateCacheMiddleware(MiddlewareMixin): return response patch_response_headers(response, timeout) if timeout and response.status_code == 200: - cache_key = learn_cache_key(request, response, timeout, self.key_prefix, cache=self.cache) - if hasattr(response, 'render') and callable(response.render): + cache_key = learn_cache_key( + request, response, timeout, self.key_prefix, cache=self.cache + ) + if hasattr(response, "render") and callable(response.render): response.add_post_render_callback( lambda r: self.cache.set(cache_key, r, timeout) ) @@ -125,6 +135,7 @@ class FetchFromCacheMiddleware(MiddlewareMixin): FetchFromCacheMiddleware must be the last piece of middleware in MIDDLEWARE so that it'll get called last during the request phase. 
""" + def __init__(self, get_response): super().__init__(get_response) self.key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX @@ -139,19 +150,21 @@ class FetchFromCacheMiddleware(MiddlewareMixin): Check whether the page is already cached and return the cached version if available. """ - if request.method not in ('GET', 'HEAD'): + if request.method not in ("GET", "HEAD"): request._cache_update_cache = False return None # Don't bother checking the cache. # try and get the cached GET response - cache_key = get_cache_key(request, self.key_prefix, 'GET', cache=self.cache) + cache_key = get_cache_key(request, self.key_prefix, "GET", cache=self.cache) if cache_key is None: request._cache_update_cache = True return None # No cache information available, need to rebuild. response = self.cache.get(cache_key) # if it wasn't found and we are looking for a HEAD, try looking just for that - if response is None and request.method == 'HEAD': - cache_key = get_cache_key(request, self.key_prefix, 'HEAD', cache=self.cache) + if response is None and request.method == "HEAD": + cache_key = get_cache_key( + request, self.key_prefix, "HEAD", cache=self.cache + ) response = self.cache.get(cache_key) if response is None: @@ -170,6 +183,7 @@ class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware): Also used as the hook point for the cache decorator, which is generated using the decorator-from-middleware utility. """ + def __init__(self, get_response, cache_timeout=None, page_timeout=None, **kwargs): super().__init__(get_response) # We need to differentiate between "provided, but using default value", @@ -178,14 +192,14 @@ class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware): # we need to use middleware defaults. 
try: - key_prefix = kwargs['key_prefix'] + key_prefix = kwargs["key_prefix"] if key_prefix is None: - key_prefix = '' + key_prefix = "" self.key_prefix = key_prefix except KeyError: pass try: - cache_alias = kwargs['cache_alias'] + cache_alias = kwargs["cache_alias"] if cache_alias is None: cache_alias = DEFAULT_CACHE_ALIAS self.cache_alias = cache_alias diff --git a/django/middleware/clickjacking.py b/django/middleware/clickjacking.py index 0161f8eb8f..072f4148a2 100644 --- a/django/middleware/clickjacking.py +++ b/django/middleware/clickjacking.py @@ -21,16 +21,17 @@ class XFrameOptionsMiddleware(MiddlewareMixin): response from being loaded in a frame in any site, set X_FRAME_OPTIONS in your project's Django settings to 'DENY'. """ + def process_response(self, request, response): # Don't set it if it's already in the response - if response.get('X-Frame-Options') is not None: + if response.get("X-Frame-Options") is not None: return response # Don't set it if they used @xframe_options_exempt - if getattr(response, 'xframe_options_exempt', False): + if getattr(response, "xframe_options_exempt", False): return response - response.headers['X-Frame-Options'] = self.get_xframe_options_value( + response.headers["X-Frame-Options"] = self.get_xframe_options_value( request, response, ) @@ -44,4 +45,4 @@ class XFrameOptionsMiddleware(MiddlewareMixin): This method can be overridden if needed, allowing it to vary based on the request or response. 
""" - return getattr(settings, 'X_FRAME_OPTIONS', 'DENY').upper() + return getattr(settings, "X_FRAME_OPTIONS", "DENY").upper() diff --git a/django/middleware/common.py b/django/middleware/common.py index e42d05e255..c652374aec 100644 --- a/django/middleware/common.py +++ b/django/middleware/common.py @@ -38,16 +38,16 @@ class CommonMiddleware(MiddlewareMixin): """ # Check for denied User-Agents - user_agent = request.META.get('HTTP_USER_AGENT') + user_agent = request.META.get("HTTP_USER_AGENT") if user_agent is not None: for user_agent_regex in settings.DISALLOWED_USER_AGENTS: if user_agent_regex.search(user_agent): - raise PermissionDenied('Forbidden user agent') + raise PermissionDenied("Forbidden user agent") # Check for a redirect based on settings.PREPEND_WWW host = request.get_host() - must_prepend = settings.PREPEND_WWW and host and not host.startswith('www.') - redirect_url = ('%s://www.%s' % (request.scheme, host)) if must_prepend else '' + must_prepend = settings.PREPEND_WWW and host and not host.startswith("www.") + redirect_url = ("%s://www.%s" % (request.scheme, host)) if must_prepend else "" # Check if a slash should be appended if self.should_redirect_with_slash(request): @@ -65,13 +65,13 @@ class CommonMiddleware(MiddlewareMixin): Return True if settings.APPEND_SLASH is True and appending a slash to the request path turns an invalid path into a valid one. 
""" - if settings.APPEND_SLASH and not request.path_info.endswith('/'): - urlconf = getattr(request, 'urlconf', None) + if settings.APPEND_SLASH and not request.path_info.endswith("/"): + urlconf = getattr(request, "urlconf", None) if not is_valid_path(request.path_info, urlconf): - match = is_valid_path('%s/' % request.path_info, urlconf) + match = is_valid_path("%s/" % request.path_info, urlconf) if match: view = match.func - return getattr(view, 'should_append_slash', True) + return getattr(view, "should_append_slash", True) return False def get_full_path_with_slash(self, request): @@ -84,15 +84,16 @@ class CommonMiddleware(MiddlewareMixin): new_path = request.get_full_path(force_append_slash=True) # Prevent construction of scheme relative urls. new_path = escape_leading_slashes(new_path) - if settings.DEBUG and request.method in ('POST', 'PUT', 'PATCH'): + if settings.DEBUG and request.method in ("POST", "PUT", "PATCH"): raise RuntimeError( "You called this URL via %(method)s, but the URL doesn't end " "in a slash and you have APPEND_SLASH set. Django can't " "redirect to the slash URL while maintaining %(method)s data. " "Change your form to point to %(url)s (note the trailing " - "slash), or set APPEND_SLASH=False in your Django settings." % { - 'method': request.method, - 'url': request.get_host() + new_path, + "slash), or set APPEND_SLASH=False in your Django settings." + % { + "method": request.method, + "url": request.get_host() + new_path, } ) return new_path @@ -109,28 +110,32 @@ class CommonMiddleware(MiddlewareMixin): # Add the Content-Length header to non-streaming responses if not # already set. 
- if not response.streaming and not response.has_header('Content-Length'): - response.headers['Content-Length'] = str(len(response.content)) + if not response.streaming and not response.has_header("Content-Length"): + response.headers["Content-Length"] = str(len(response.content)) return response class BrokenLinkEmailsMiddleware(MiddlewareMixin): - def process_response(self, request, response): """Send broken link emails for relevant 404 NOT FOUND responses.""" if response.status_code == 404 and not settings.DEBUG: domain = request.get_host() path = request.get_full_path() - referer = request.META.get('HTTP_REFERER', '') + referer = request.META.get("HTTP_REFERER", "") if not self.is_ignorable_request(request, path, domain, referer): - ua = request.META.get('HTTP_USER_AGENT', '<none>') - ip = request.META.get('REMOTE_ADDR', '<none>') + ua = request.META.get("HTTP_USER_AGENT", "<none>") + ip = request.META.get("REMOTE_ADDR", "<none>") mail_managers( - "Broken %slink on %s" % ( - ('INTERNAL ' if self.is_internal_request(domain, referer) else ''), - domain + "Broken %slink on %s" + % ( + ( + "INTERNAL " + if self.is_internal_request(domain, referer) + else "" + ), + domain, ), "Referrer: %s\nRequested URL: %s\nUser agent: %s\n" "IP address: %s\n" % (referer, path, ua, ip), @@ -158,17 +163,17 @@ class BrokenLinkEmailsMiddleware(MiddlewareMixin): # APPEND_SLASH is enabled and the referer is equal to the current URL # without a trailing slash indicating an internal redirect. - if settings.APPEND_SLASH and uri.endswith('/') and referer == uri[:-1]: + if settings.APPEND_SLASH and uri.endswith("/") and referer == uri[:-1]: return True # A '?' in referer is identified as a search engine source. - if not self.is_internal_request(domain, referer) and '?' in referer: + if not self.is_internal_request(domain, referer) and "?" in referer: return True # The referer is equal to the current URL, ignoring the scheme (assumed # to be a poorly implemented bot). 
parsed_referer = urlparse(referer) - if parsed_referer.netloc in ['', domain] and parsed_referer.path == uri: + if parsed_referer.netloc in ["", domain] and parsed_referer.path == uri: return True return any(pattern.search(uri) for pattern in settings.IGNORABLE_404_URLS) diff --git a/django/middleware/csrf.py b/django/middleware/csrf.py index 6be68ebd76..94f580fa71 100644 --- a/django/middleware/csrf.py +++ b/django/middleware/csrf.py @@ -22,27 +22,29 @@ from django.utils.http import is_same_domain from django.utils.log import log_response from django.utils.regex_helper import _lazy_re_compile -logger = logging.getLogger('django.security.csrf') +logger = logging.getLogger("django.security.csrf") # This matches if any character is not in CSRF_ALLOWED_CHARS. -invalid_token_chars_re = _lazy_re_compile('[^a-zA-Z0-9]') +invalid_token_chars_re = _lazy_re_compile("[^a-zA-Z0-9]") REASON_BAD_ORIGIN = "Origin checking failed - %s does not match any trusted origins." REASON_NO_REFERER = "Referer checking failed - no Referer." REASON_BAD_REFERER = "Referer checking failed - %s does not match any trusted origins." REASON_NO_CSRF_COOKIE = "CSRF cookie not set." -REASON_CSRF_TOKEN_MISSING = 'CSRF token missing.' +REASON_CSRF_TOKEN_MISSING = "CSRF token missing." REASON_MALFORMED_REFERER = "Referer checking failed - Referer is malformed." -REASON_INSECURE_REFERER = "Referer checking failed - Referer is insecure while host is secure." +REASON_INSECURE_REFERER = ( + "Referer checking failed - Referer is insecure while host is secure." +) # The reason strings below are for passing to InvalidTokenFormat. They are # phrases without a subject because they can be in reference to either the CSRF # cookie or non-cookie token. 
-REASON_INCORRECT_LENGTH = 'has incorrect length' -REASON_INVALID_CHARACTERS = 'has invalid characters' +REASON_INCORRECT_LENGTH = "has incorrect length" +REASON_INVALID_CHARACTERS = "has invalid characters" CSRF_SECRET_LENGTH = 32 CSRF_TOKEN_LENGTH = 2 * CSRF_SECRET_LENGTH CSRF_ALLOWED_CHARS = string.ascii_letters + string.digits -CSRF_SESSION_KEY = '_csrftoken' +CSRF_SESSION_KEY = "_csrftoken" def _get_failure_view(): @@ -62,7 +64,7 @@ def _mask_cipher_secret(secret): mask = _get_new_csrf_string() chars = CSRF_ALLOWED_CHARS pairs = zip((chars.index(x) for x in secret), (chars.index(x) for x in mask)) - cipher = ''.join(chars[(x + y) % len(chars)] for x, y in pairs) + cipher = "".join(chars[(x + y) % len(chars)] for x, y in pairs) return mask + cipher @@ -76,21 +78,24 @@ def _unmask_cipher_token(token): token = token[CSRF_SECRET_LENGTH:] chars = CSRF_ALLOWED_CHARS pairs = zip((chars.index(x) for x in token), (chars.index(x) for x in mask)) - return ''.join(chars[x - y] for x, y in pairs) # Note negative values are ok + return "".join(chars[x - y] for x, y in pairs) # Note negative values are ok def _add_new_csrf_cookie(request): """Generate a new random CSRF_COOKIE value, and add it to request.META.""" csrf_secret = _get_new_csrf_string() - request.META.update({ - # RemovedInDjango50Warning: when the deprecation ends, replace - # with: 'CSRF_COOKIE': csrf_secret - 'CSRF_COOKIE': ( - _mask_cipher_secret(csrf_secret) - if settings.CSRF_COOKIE_MASKED else csrf_secret - ), - 'CSRF_COOKIE_NEEDS_UPDATE': True, - }) + request.META.update( + { + # RemovedInDjango50Warning: when the deprecation ends, replace + # with: 'CSRF_COOKIE': csrf_secret + "CSRF_COOKIE": ( + _mask_cipher_secret(csrf_secret) + if settings.CSRF_COOKIE_MASKED + else csrf_secret + ), + "CSRF_COOKIE_NEEDS_UPDATE": True, + } + ) return csrf_secret @@ -104,12 +109,12 @@ def get_token(request): header to the outgoing response. 
For this reason, you may need to use this function lazily, as is done by the csrf context processor. """ - if 'CSRF_COOKIE' in request.META: - csrf_secret = request.META['CSRF_COOKIE'] + if "CSRF_COOKIE" in request.META: + csrf_secret = request.META["CSRF_COOKIE"] # Since the cookie is being used, flag to send the cookie in # process_response() (even if the client already has it) in order to # renew the expiry timer. - request.META['CSRF_COOKIE_NEEDS_UPDATE'] = True + request.META["CSRF_COOKIE_NEEDS_UPDATE"] = True else: csrf_secret = _add_new_csrf_cookie(request) return _mask_cipher_secret(csrf_secret) @@ -171,19 +176,17 @@ class CsrfViewMiddleware(MiddlewareMixin): This middleware should be used in conjunction with the {% csrf_token %} template tag. """ + @cached_property def csrf_trusted_origins_hosts(self): return [ - urlparse(origin).netloc.lstrip('*') + urlparse(origin).netloc.lstrip("*") for origin in settings.CSRF_TRUSTED_ORIGINS ] @cached_property def allowed_origins_exact(self): - return { - origin for origin in settings.CSRF_TRUSTED_ORIGINS - if '*' not in origin - } + return {origin for origin in settings.CSRF_TRUSTED_ORIGINS if "*" not in origin} @cached_property def allowed_origin_subdomains(self): @@ -192,8 +195,12 @@ class CsrfViewMiddleware(MiddlewareMixin): subdomains of the netloc are allowed. 
""" allowed_origin_subdomains = defaultdict(list) - for parsed in (urlparse(origin) for origin in settings.CSRF_TRUSTED_ORIGINS if '*' in origin): - allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip('*')) + for parsed in ( + urlparse(origin) + for origin in settings.CSRF_TRUSTED_ORIGINS + if "*" in origin + ): + allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*")) return allowed_origin_subdomains # The _accept and _reject methods currently only exist for the sake of the @@ -208,7 +215,9 @@ class CsrfViewMiddleware(MiddlewareMixin): def _reject(self, request, reason): response = _get_failure_view()(request, reason=reason) log_response( - 'Forbidden (%s): %s', reason, request.path, + "Forbidden (%s): %s", + reason, + request.path, response=response, request=request, logger=logger, @@ -228,9 +237,9 @@ class CsrfViewMiddleware(MiddlewareMixin): csrf_secret = request.session.get(CSRF_SESSION_KEY) except AttributeError: raise ImproperlyConfigured( - 'CSRF_USE_SESSIONS is enabled, but request.session is not ' - 'set. SessionMiddleware must appear before CsrfViewMiddleware ' - 'in MIDDLEWARE.' + "CSRF_USE_SESSIONS is enabled, but request.session is not " + "set. SessionMiddleware must appear before CsrfViewMiddleware " + "in MIDDLEWARE." 
) else: try: @@ -249,12 +258,12 @@ class CsrfViewMiddleware(MiddlewareMixin): def _set_csrf_cookie(self, request, response): if settings.CSRF_USE_SESSIONS: - if request.session.get(CSRF_SESSION_KEY) != request.META['CSRF_COOKIE']: - request.session[CSRF_SESSION_KEY] = request.META['CSRF_COOKIE'] + if request.session.get(CSRF_SESSION_KEY) != request.META["CSRF_COOKIE"]: + request.session[CSRF_SESSION_KEY] = request.META["CSRF_COOKIE"] else: response.set_cookie( settings.CSRF_COOKIE_NAME, - request.META['CSRF_COOKIE'], + request.META["CSRF_COOKIE"], max_age=settings.CSRF_COOKIE_AGE, domain=settings.CSRF_COOKIE_DOMAIN, path=settings.CSRF_COOKIE_PATH, @@ -263,17 +272,17 @@ class CsrfViewMiddleware(MiddlewareMixin): samesite=settings.CSRF_COOKIE_SAMESITE, ) # Set the Vary header since content varies with the CSRF cookie. - patch_vary_headers(response, ('Cookie',)) + patch_vary_headers(response, ("Cookie",)) def _origin_verified(self, request): - request_origin = request.META['HTTP_ORIGIN'] + request_origin = request.META["HTTP_ORIGIN"] try: good_host = request.get_host() except DisallowedHost: pass else: - good_origin = '%s://%s' % ( - 'https' if request.is_secure() else 'http', + good_origin = "%s://%s" % ( + "https" if request.is_secure() else "http", good_host, ) if request_origin == good_origin: @@ -292,7 +301,7 @@ class CsrfViewMiddleware(MiddlewareMixin): ) def _check_referer(self, request): - referer = request.META.get('HTTP_REFERER') + referer = request.META.get("HTTP_REFERER") if referer is None: raise RejectRequest(REASON_NO_REFERER) @@ -302,11 +311,11 @@ class CsrfViewMiddleware(MiddlewareMixin): raise RejectRequest(REASON_MALFORMED_REFERER) # Make sure we have a valid URL for Referer. - if '' in (referer.scheme, referer.netloc): + if "" in (referer.scheme, referer.netloc): raise RejectRequest(REASON_MALFORMED_REFERER) # Ensure that our Referer is also secure. 
- if referer.scheme != 'https': + if referer.scheme != "https": raise RejectRequest(REASON_INSECURE_REFERER) if any( @@ -330,18 +339,18 @@ class CsrfViewMiddleware(MiddlewareMixin): raise RejectRequest(REASON_BAD_REFERER % referer.geturl()) else: server_port = request.get_port() - if server_port not in ('443', '80'): - good_referer = '%s:%s' % (good_referer, server_port) + if server_port not in ("443", "80"): + good_referer = "%s:%s" % (good_referer, server_port) if not is_same_domain(referer.netloc, good_referer): raise RejectRequest(REASON_BAD_REFERER % referer.geturl()) def _bad_token_message(self, reason, token_source): - if token_source != 'POST': + if token_source != "POST": # Assume it is a settings.CSRF_HEADER_NAME value. header_name = HttpHeaders.parse_header_name(token_source) - token_source = f'the {header_name!r} HTTP header' - return f'CSRF token from {token_source} {reason}.' + token_source = f"the {header_name!r} HTTP header" + return f"CSRF token from {token_source} {reason}." def _check_token(self, request): # Access csrf_secret via self._get_secret() as rotate_token() may have @@ -350,7 +359,7 @@ class CsrfViewMiddleware(MiddlewareMixin): try: csrf_secret = self._get_secret(request) except InvalidTokenFormat as exc: - raise RejectRequest(f'CSRF cookie {exc.reason}.') + raise RejectRequest(f"CSRF cookie {exc.reason}.") if csrf_secret is None: # No CSRF cookie. For POST requests, we insist on a CSRF cookie, @@ -359,10 +368,10 @@ class CsrfViewMiddleware(MiddlewareMixin): raise RejectRequest(REASON_NO_CSRF_COOKIE) # Check non-cookie token for match. - request_csrf_token = '' - if request.method == 'POST': + request_csrf_token = "" + if request.method == "POST": try: - request_csrf_token = request.POST.get('csrfmiddlewaretoken', '') + request_csrf_token = request.POST.get("csrfmiddlewaretoken", "") except UnreadablePostError: # Handle a broken connection before we've completed reading the # POST data. 
process_view shouldn't raise any exceptions, so @@ -370,7 +379,7 @@ class CsrfViewMiddleware(MiddlewareMixin): # listening, which they probably aren't because of the error). pass - if request_csrf_token == '': + if request_csrf_token == "": # Fall back to X-CSRFToken, to make things easier for AJAX, and # possible for PUT/DELETE. try: @@ -383,7 +392,7 @@ class CsrfViewMiddleware(MiddlewareMixin): raise RejectRequest(REASON_CSRF_TOKEN_MISSING) token_source = settings.CSRF_HEADER_NAME else: - token_source = 'POST' + token_source = "POST" try: _check_token_format(request_csrf_token) @@ -392,7 +401,7 @@ class CsrfViewMiddleware(MiddlewareMixin): raise RejectRequest(reason) if not _does_token_match(request_csrf_token, csrf_secret): - reason = self._bad_token_message('incorrect', token_source) + reason = self._bad_token_message("incorrect", token_source) raise RejectRequest(reason) def process_request(self, request): @@ -406,22 +415,22 @@ class CsrfViewMiddleware(MiddlewareMixin): # masked, this also causes it to be replaced with the unmasked # form, but only in cases where the secret is already getting # saved anyways. 
- request.META['CSRF_COOKIE'] = csrf_secret + request.META["CSRF_COOKIE"] = csrf_secret def process_view(self, request, callback, callback_args, callback_kwargs): - if getattr(request, 'csrf_processing_done', False): + if getattr(request, "csrf_processing_done", False): return None # Wait until request.META["CSRF_COOKIE"] has been manipulated before # bailing out, so that get_token still works - if getattr(callback, 'csrf_exempt', False): + if getattr(callback, "csrf_exempt", False): return None # Assume that anything not defined as 'safe' by RFC7231 needs protection - if request.method in ('GET', 'HEAD', 'OPTIONS', 'TRACE'): + if request.method in ("GET", "HEAD", "OPTIONS", "TRACE"): return self._accept(request) - if getattr(request, '_dont_enforce_csrf_checks', False): + if getattr(request, "_dont_enforce_csrf_checks", False): # Mechanism to turn off CSRF checks for test suite. It comes after # the creation of CSRF cookies, so that everything else continues # to work exactly the same (e.g. cookies are sent, etc.), but @@ -430,9 +439,11 @@ class CsrfViewMiddleware(MiddlewareMixin): # Reject the request if the Origin header doesn't match an allowed # value. - if 'HTTP_ORIGIN' in request.META: + if "HTTP_ORIGIN" in request.META: if not self._origin_verified(request): - return self._reject(request, REASON_BAD_ORIGIN % request.META['HTTP_ORIGIN']) + return self._reject( + request, REASON_BAD_ORIGIN % request.META["HTTP_ORIGIN"] + ) elif request.is_secure(): # If the Origin header wasn't provided, reject HTTPS requests if # the Referer header doesn't match an allowed value. 
@@ -464,7 +475,7 @@ class CsrfViewMiddleware(MiddlewareMixin): return self._accept(request) def process_response(self, request, response): - if request.META.get('CSRF_COOKIE_NEEDS_UPDATE'): + if request.META.get("CSRF_COOKIE_NEEDS_UPDATE"): self._set_csrf_cookie(request, response) # Unset the flag to prevent _set_csrf_cookie() from being # unnecessarily called again in process_response() by other @@ -473,6 +484,6 @@ class CsrfViewMiddleware(MiddlewareMixin): # CSRF_COOKIE_NEEDS_UPDATE is still respected in subsequent calls # e.g. in case rotate_token() is called in process_response() later # by custom middleware but before those subsequent calls. - request.META['CSRF_COOKIE_NEEDS_UPDATE'] = False + request.META["CSRF_COOKIE_NEEDS_UPDATE"] = False return response diff --git a/django/middleware/gzip.py b/django/middleware/gzip.py index 350466151d..6d27c1e335 100644 --- a/django/middleware/gzip.py +++ b/django/middleware/gzip.py @@ -3,7 +3,7 @@ from django.utils.deprecation import MiddlewareMixin from django.utils.regex_helper import _lazy_re_compile from django.utils.text import compress_sequence, compress_string -re_accepts_gzip = _lazy_re_compile(r'\bgzip\b') +re_accepts_gzip = _lazy_re_compile(r"\bgzip\b") class GZipMiddleware(MiddlewareMixin): @@ -12,18 +12,19 @@ class GZipMiddleware(MiddlewareMixin): Set the Vary header accordingly, so that caches will base their storage on the Accept-Encoding header. """ + def process_response(self, request, response): # It's not worth attempting to compress really short responses. if not response.streaming and len(response.content) < 200: return response # Avoid gzipping if we've already got a content-encoding. 
- if response.has_header('Content-Encoding'): + if response.has_header("Content-Encoding"): return response - patch_vary_headers(response, ('Accept-Encoding',)) + patch_vary_headers(response, ("Accept-Encoding",)) - ae = request.META.get('HTTP_ACCEPT_ENCODING', '') + ae = request.META.get("HTTP_ACCEPT_ENCODING", "") if not re_accepts_gzip.search(ae): return response @@ -31,21 +32,21 @@ class GZipMiddleware(MiddlewareMixin): # Delete the `Content-Length` header for streaming content, because # we won't know the compressed size until we stream it. response.streaming_content = compress_sequence(response.streaming_content) - del response.headers['Content-Length'] + del response.headers["Content-Length"] else: # Return the compressed content only if it's actually shorter. compressed_content = compress_string(response.content) if len(compressed_content) >= len(response.content): return response response.content = compressed_content - response.headers['Content-Length'] = str(len(response.content)) + response.headers["Content-Length"] = str(len(response.content)) # If there is a strong ETag, make it weak to fulfill the requirements # of RFC 7232 section-2.1 while also allowing conditional request # matches on ETags. 
- etag = response.get('ETag') + etag = response.get("ETag") if etag and etag.startswith('"'): - response.headers['ETag'] = 'W/' + etag - response.headers['Content-Encoding'] = 'gzip' + response.headers["ETag"] = "W/" + etag + response.headers["Content-Encoding"] = "gzip" return response diff --git a/django/middleware/http.py b/django/middleware/http.py index 4fcde85698..84c5466bb6 100644 --- a/django/middleware/http.py +++ b/django/middleware/http.py @@ -1,6 +1,4 @@ -from django.utils.cache import ( - cc_delim_re, get_conditional_response, set_response_etag, -) +from django.utils.cache import cc_delim_re, get_conditional_response, set_response_etag from django.utils.deprecation import MiddlewareMixin from django.utils.http import parse_http_date_safe @@ -11,18 +9,19 @@ class ConditionalGetMiddleware(MiddlewareMixin): Last-Modified header and the request has If-None-Match or If-Modified-Since, replace the response with HttpNotModified. Add an ETag header if needed. """ + def process_response(self, request, response): # It's too late to prevent an unsafe request with a 412 response, and # for a HEAD request, the response body is always empty so computing # an accurate ETag isn't possible. 
- if request.method != 'GET': + if request.method != "GET": return response - if self.needs_etag(response) and not response.has_header('ETag'): + if self.needs_etag(response) and not response.has_header("ETag"): set_response_etag(response) - etag = response.get('ETag') - last_modified = response.get('Last-Modified') + etag = response.get("ETag") + last_modified = response.get("Last-Modified") last_modified = last_modified and parse_http_date_safe(last_modified) if etag or last_modified: @@ -37,5 +36,5 @@ class ConditionalGetMiddleware(MiddlewareMixin): def needs_etag(self, response): """Return True if an ETag header should be added to response.""" - cache_control_headers = cc_delim_re.split(response.get('Cache-Control', '')) - return all(header.lower() != 'no-store' for header in cache_control_headers) + cache_control_headers = cc_delim_re.split(response.get("Cache-Control", "")) + return all(header.lower() != "no-store" for header in cache_control_headers) diff --git a/django/middleware/locale.py b/django/middleware/locale.py index d90fc84152..71db230da2 100644 --- a/django/middleware/locale.py +++ b/django/middleware/locale.py @@ -13,14 +13,24 @@ class LocaleMiddleware(MiddlewareMixin): current thread context. This allows pages to be dynamically translated to the language the user desires (if the language is available). 
""" + response_redirect_class = HttpResponseRedirect def process_request(self, request): - urlconf = getattr(request, 'urlconf', settings.ROOT_URLCONF) - i18n_patterns_used, prefixed_default_language = is_language_prefix_patterns_used(urlconf) - language = translation.get_language_from_request(request, check_path=i18n_patterns_used) + urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) + ( + i18n_patterns_used, + prefixed_default_language, + ) = is_language_prefix_patterns_used(urlconf) + language = translation.get_language_from_request( + request, check_path=i18n_patterns_used + ) language_from_path = translation.get_language_from_path(request.path_info) - if not language_from_path and i18n_patterns_used and not prefixed_default_language: + if ( + not language_from_path + and i18n_patterns_used + and not prefixed_default_language + ): language = settings.LANGUAGE_CODE translation.activate(language) request.LANGUAGE_CODE = translation.get_language() @@ -28,39 +38,43 @@ class LocaleMiddleware(MiddlewareMixin): def process_response(self, request, response): language = translation.get_language() language_from_path = translation.get_language_from_path(request.path_info) - urlconf = getattr(request, 'urlconf', settings.ROOT_URLCONF) - i18n_patterns_used, prefixed_default_language = is_language_prefix_patterns_used(urlconf) + urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) + ( + i18n_patterns_used, + prefixed_default_language, + ) = is_language_prefix_patterns_used(urlconf) - if (response.status_code == 404 and not language_from_path and - i18n_patterns_used and prefixed_default_language): + if ( + response.status_code == 404 + and not language_from_path + and i18n_patterns_used + and prefixed_default_language + ): # Maybe the language code is missing in the URL? Try adding the # language prefix and redirecting to that URL. 
- language_path = '/%s%s' % (language, request.path_info) + language_path = "/%s%s" % (language, request.path_info) path_valid = is_valid_path(language_path, urlconf) - path_needs_slash = ( - not path_valid and ( - settings.APPEND_SLASH and not language_path.endswith('/') and - is_valid_path('%s/' % language_path, urlconf) - ) + path_needs_slash = not path_valid and ( + settings.APPEND_SLASH + and not language_path.endswith("/") + and is_valid_path("%s/" % language_path, urlconf) ) if path_valid or path_needs_slash: script_prefix = get_script_prefix() # Insert language after the script prefix and before the # rest of the URL - language_url = request.get_full_path(force_append_slash=path_needs_slash).replace( - script_prefix, - '%s%s/' % (script_prefix, language), - 1 - ) + language_url = request.get_full_path( + force_append_slash=path_needs_slash + ).replace(script_prefix, "%s%s/" % (script_prefix, language), 1) # Redirect to the language-specific URL as detected by # get_language_from_request(). HTTP caches may cache this # redirect, so add the Vary header. 
redirect = self.response_redirect_class(language_url) - patch_vary_headers(redirect, ('Accept-Language', 'Cookie')) + patch_vary_headers(redirect, ("Accept-Language", "Cookie")) return redirect if not (i18n_patterns_used and language_from_path): - patch_vary_headers(response, ('Accept-Language',)) - response.headers.setdefault('Content-Language', language) + patch_vary_headers(response, ("Accept-Language",)) + response.headers.setdefault("Content-Language", language) return response diff --git a/django/middleware/security.py b/django/middleware/security.py index d2c2bf2d3f..1dd2204814 100644 --- a/django/middleware/security.py +++ b/django/middleware/security.py @@ -20,38 +20,47 @@ class SecurityMiddleware(MiddlewareMixin): def process_request(self, request): path = request.path.lstrip("/") - if (self.redirect and not request.is_secure() and - not any(pattern.search(path) - for pattern in self.redirect_exempt)): + if ( + self.redirect + and not request.is_secure() + and not any(pattern.search(path) for pattern in self.redirect_exempt) + ): host = self.redirect_host or request.get_host() return HttpResponsePermanentRedirect( "https://%s%s" % (host, request.get_full_path()) ) def process_response(self, request, response): - if (self.sts_seconds and request.is_secure() and - 'Strict-Transport-Security' not in response): + if ( + self.sts_seconds + and request.is_secure() + and "Strict-Transport-Security" not in response + ): sts_header = "max-age=%s" % self.sts_seconds if self.sts_include_subdomains: sts_header = sts_header + "; includeSubDomains" if self.sts_preload: sts_header = sts_header + "; preload" - response.headers['Strict-Transport-Security'] = sts_header + response.headers["Strict-Transport-Security"] = sts_header if self.content_type_nosniff: - response.headers.setdefault('X-Content-Type-Options', 'nosniff') + response.headers.setdefault("X-Content-Type-Options", "nosniff") if self.referrer_policy: # Support a comma-separated string or iterable of values 
to allow # fallback. - response.headers.setdefault('Referrer-Policy', ','.join( - [v.strip() for v in self.referrer_policy.split(',')] - if isinstance(self.referrer_policy, str) else self.referrer_policy - )) + response.headers.setdefault( + "Referrer-Policy", + ",".join( + [v.strip() for v in self.referrer_policy.split(",")] + if isinstance(self.referrer_policy, str) + else self.referrer_policy + ), + ) if self.cross_origin_opener_policy: response.setdefault( - 'Cross-Origin-Opener-Policy', + "Cross-Origin-Opener-Policy", self.cross_origin_opener_policy, ) return response diff --git a/django/shortcuts.py b/django/shortcuts.py index 9d213bc1ff..90ec1bedc5 100644 --- a/django/shortcuts.py +++ b/django/shortcuts.py @@ -4,14 +4,19 @@ of MVC. In other words, these functions/classes introduce controlled coupling for convenience's sake. """ from django.http import ( - Http404, HttpResponse, HttpResponsePermanentRedirect, HttpResponseRedirect, + Http404, + HttpResponse, + HttpResponsePermanentRedirect, + HttpResponseRedirect, ) from django.template import loader from django.urls import NoReverseMatch, reverse from django.utils.functional import Promise -def render(request, template_name, context=None, content_type=None, status=None, using=None): +def render( + request, template_name, context=None, content_type=None, status=None, using=None +): """ Return an HttpResponse whose content is filled with the result of calling django.template.loader.render_to_string() with the passed arguments. @@ -37,7 +42,9 @@ def redirect(to, *args, permanent=False, **kwargs): Issues a temporary redirect by default; pass permanent=True to issue a permanent redirect. """ - redirect_class = HttpResponsePermanentRedirect if permanent else HttpResponseRedirect + redirect_class = ( + HttpResponsePermanentRedirect if permanent else HttpResponseRedirect + ) return redirect_class(resolve_url(to, *args, **kwargs)) @@ -49,7 +56,7 @@ def _get_queryset(klass): the job. 
""" # If it is a model class or anything else with ._default_manager - if hasattr(klass, '_default_manager'): + if hasattr(klass, "_default_manager"): return klass._default_manager.all() return klass @@ -66,8 +73,10 @@ def get_object_or_404(klass, *args, **kwargs): one object is found. """ queryset = _get_queryset(klass) - if not hasattr(queryset, 'get'): - klass__name = klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + if not hasattr(queryset, "get"): + klass__name = ( + klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + ) raise ValueError( "First argument to get_object_or_404() must be a Model, Manager, " "or QuerySet, not '%s'." % klass__name @@ -75,7 +84,9 @@ def get_object_or_404(klass, *args, **kwargs): try: return queryset.get(*args, **kwargs) except queryset.model.DoesNotExist: - raise Http404('No %s matches the given query.' % queryset.model._meta.object_name) + raise Http404( + "No %s matches the given query." % queryset.model._meta.object_name + ) def get_list_or_404(klass, *args, **kwargs): @@ -87,15 +98,19 @@ def get_list_or_404(klass, *args, **kwargs): arguments and keyword arguments are used in the filter() query. """ queryset = _get_queryset(klass) - if not hasattr(queryset, 'filter'): - klass__name = klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + if not hasattr(queryset, "filter"): + klass__name = ( + klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + ) raise ValueError( "First argument to get_list_or_404() must be a Model, Manager, or " "QuerySet, not '%s'." % klass__name ) obj_list = list(queryset.filter(*args, **kwargs)) if not obj_list: - raise Http404('No %s matches the given query.' % queryset.model._meta.object_name) + raise Http404( + "No %s matches the given query." % queryset.model._meta.object_name + ) return obj_list @@ -113,7 +128,7 @@ def resolve_url(to, *args, **kwargs): * A URL, which will be returned as-is. 
""" # If it's a model, use get_absolute_url() - if hasattr(to, 'get_absolute_url'): + if hasattr(to, "get_absolute_url"): return to.get_absolute_url() if isinstance(to, Promise): @@ -122,7 +137,7 @@ def resolve_url(to, *args, **kwargs): to = str(to) # Handle relative URLs - if isinstance(to, str) and to.startswith(('./', '../')): + if isinstance(to, str) and to.startswith(("./", "../")): return to # Next try a reverse URL resolution. @@ -133,7 +148,7 @@ def resolve_url(to, *args, **kwargs): if callable(to): raise # If this doesn't "feel" like a URL, re-raise. - if '/' not in to and '.' not in to: + if "/" not in to and "." not in to: raise # Finally, fall back and assume it's a URL diff --git a/django/template/__init__.py b/django/template/__init__.py index 7414d0fef5..adb431c00d 100644 --- a/django/template/__init__.py +++ b/django/template/__init__.py @@ -46,26 +46,30 @@ from .utils import EngineHandler engines = EngineHandler() -__all__ = ('Engine', 'engines') +__all__ = ("Engine", "engines") # Django Template Language # Public exceptions -from .base import VariableDoesNotExist # NOQA isort:skip -from .context import Context, ContextPopException, RequestContext # NOQA isort:skip -from .exceptions import TemplateDoesNotExist, TemplateSyntaxError # NOQA isort:skip +from .base import VariableDoesNotExist # NOQA isort:skip +from .context import Context, ContextPopException, RequestContext # NOQA isort:skip +from .exceptions import TemplateDoesNotExist, TemplateSyntaxError # NOQA isort:skip # Template parts -from .base import ( # NOQA isort:skip - Node, NodeList, Origin, Template, Variable, +from .base import ( # NOQA isort:skip + Node, + NodeList, + Origin, + Template, + Variable, ) # Library management -from .library import Library # NOQA isort:skip +from .library import Library # NOQA isort:skip # Import the .autoreload module to trigger the registrations of signals. -from . import autoreload # NOQA isort:skip +from . 
import autoreload # NOQA isort:skip -__all__ += ('Template', 'Context', 'RequestContext') +__all__ += ("Template", "Context", "RequestContext") diff --git a/django/template/autoreload.py b/django/template/autoreload.py index 7242d68f2d..84c8554165 100644 --- a/django/template/autoreload.py +++ b/django/template/autoreload.py @@ -4,9 +4,7 @@ from django.dispatch import receiver from django.template import engines from django.template.backends.django import DjangoTemplates from django.utils._os import to_path -from django.utils.autoreload import ( - autoreload_started, file_changed, is_django_path, -) +from django.utils.autoreload import autoreload_started, file_changed, is_django_path def get_template_directories(): @@ -22,7 +20,7 @@ def get_template_directories(): items.update(cwd / to_path(dir) for dir in backend.engine.dirs) for loader in backend.engine.template_loaders: - if not hasattr(loader, 'get_dirs'): + if not hasattr(loader, "get_dirs"): continue items.update( cwd / to_path(directory) @@ -40,15 +38,15 @@ def reset_loaders(): loader.reset() -@receiver(autoreload_started, dispatch_uid='template_loaders_watch_changes') +@receiver(autoreload_started, dispatch_uid="template_loaders_watch_changes") def watch_for_template_changes(sender, **kwargs): for directory in get_template_directories(): - sender.watch_dir(directory, '**/*') + sender.watch_dir(directory, "**/*") -@receiver(file_changed, dispatch_uid='template_loaders_file_changed') +@receiver(file_changed, dispatch_uid="template_loaders_file_changed") def template_changed(sender, file_path, **kwargs): - if file_path.suffix == '.py': + if file_path.suffix == ".py": return for template_dir in get_template_directories(): if template_dir in file_path.parents: diff --git a/django/template/backends/base.py b/django/template/backends/base.py index f1fa142362..240733e6f4 100644 --- a/django/template/backends/base.py +++ b/django/template/backends/base.py @@ -1,6 +1,4 @@ -from django.core.exceptions import ( - 
ImproperlyConfigured, SuspiciousFileOperation, -) +from django.core.exceptions import ImproperlyConfigured, SuspiciousFileOperation from django.template.utils import get_app_template_dirs from django.utils._os import safe_join from django.utils.functional import cached_property @@ -18,18 +16,20 @@ class BaseEngine: `params` is a dict of configuration settings. """ params = params.copy() - self.name = params.pop('NAME') - self.dirs = list(params.pop('DIRS')) - self.app_dirs = params.pop('APP_DIRS') + self.name = params.pop("NAME") + self.dirs = list(params.pop("DIRS")) + self.app_dirs = params.pop("APP_DIRS") if params: raise ImproperlyConfigured( - "Unknown parameters: {}".format(", ".join(params))) + "Unknown parameters: {}".format(", ".join(params)) + ) @property def app_dirname(self): raise ImproperlyConfigured( "{} doesn't support loading templates from installed " - "applications.".format(self.__class__.__name__)) + "applications.".format(self.__class__.__name__) + ) def from_string(self, template_code): """ diff --git a/django/template/backends/django.py b/django/template/backends/django.py index 48257e5e4e..218e5e0bc1 100644 --- a/django/template/backends/django.py +++ b/django/template/backends/django.py @@ -13,16 +13,16 @@ from .base import BaseEngine class DjangoTemplates(BaseEngine): - app_dirname = 'templates' + app_dirname = "templates" def __init__(self, params): params = params.copy() - options = params.pop('OPTIONS').copy() - options.setdefault('autoescape', True) - options.setdefault('debug', settings.DEBUG) - options.setdefault('file_charset', 'utf-8') - libraries = options.get('libraries', {}) - options['libraries'] = self.get_templatetag_libraries(libraries) + options = params.pop("OPTIONS").copy() + options.setdefault("autoescape", True) + options.setdefault("debug", settings.DEBUG) + options.setdefault("file_charset", "utf-8") + libraries = options.get("libraries", {}) + options["libraries"] = self.get_templatetag_libraries(libraries) 
super().__init__(params) self.engine = Engine(self.dirs, self.app_dirs, **options) @@ -46,7 +46,6 @@ class DjangoTemplates(BaseEngine): class Template: - def __init__(self, template, backend): self.template = template self.backend = backend @@ -56,7 +55,9 @@ class Template: return self.template.origin def render(self, context=None, request=None): - context = make_context(context, request, autoescape=self.backend.engine.autoescape) + context = make_context( + context, request, autoescape=self.backend.engine.autoescape + ) try: return self.template.render(context) except TemplateDoesNotExist as exc: @@ -71,7 +72,7 @@ def copy_exception(exc, backend=None): """ backend = backend or exc.backend new = exc.__class__(*exc.args, tried=exc.tried, backend=backend, chain=exc.chain) - if hasattr(exc, 'template_debug'): + if hasattr(exc, "template_debug"): new.template_debug = exc.template_debug return new @@ -89,10 +90,9 @@ def get_template_tag_modules(): Yield (module_name, module_path) pairs for all installed template tag libraries. """ - candidates = ['django.templatetags'] + candidates = ["django.templatetags"] candidates.extend( - f'{app_config.name}.templatetags' - for app_config in apps.get_app_configs() + f"{app_config.name}.templatetags" for app_config in apps.get_app_configs() ) for candidate in candidates: @@ -102,9 +102,9 @@ def get_template_tag_modules(): # No templatetags package defined. This is safe to ignore. continue - if hasattr(pkg, '__path__'): + if hasattr(pkg, "__path__"): for name in get_package_libraries(pkg): - yield name[len(candidate) + 1:], name + yield name[len(candidate) + 1 :], name def get_installed_libraries(): @@ -115,8 +115,7 @@ def get_installed_libraries(): django.templatetags.i18n is stored as i18n. 
""" return { - module_name: full_name - for module_name, full_name in get_template_tag_modules() + module_name: full_name for module_name, full_name in get_template_tag_modules() } @@ -125,7 +124,7 @@ def get_package_libraries(pkg): Recursively yield template tag libraries defined in submodules of a package. """ - for entry in walk_packages(pkg.__path__, pkg.__name__ + '.'): + for entry in walk_packages(pkg.__path__, pkg.__name__ + "."): try: module = import_module(entry[1]) except ImportError as e: @@ -134,5 +133,5 @@ def get_package_libraries(pkg): "trying to load '%s': %s" % (entry[1], e) ) from e - if hasattr(module, 'register'): + if hasattr(module, "register"): yield entry[1] diff --git a/django/template/backends/dummy.py b/django/template/backends/dummy.py index 6be05ca614..692382b6b1 100644 --- a/django/template/backends/dummy.py +++ b/django/template/backends/dummy.py @@ -10,14 +10,13 @@ from .utils import csrf_input_lazy, csrf_token_lazy class TemplateStrings(BaseEngine): - app_dirname = 'template_strings' + app_dirname = "template_strings" def __init__(self, params): params = params.copy() - options = params.pop('OPTIONS').copy() + options = params.pop("OPTIONS").copy() if options: - raise ImproperlyConfigured( - "Unknown options: {}".format(", ".join(options))) + raise ImproperlyConfigured("Unknown options: {}".format(", ".join(options))) super().__init__(params) def from_string(self, template_code): @@ -27,26 +26,27 @@ class TemplateStrings(BaseEngine): tried = [] for template_file in self.iter_template_filenames(template_name): try: - with open(template_file, encoding='utf-8') as fp: + with open(template_file, encoding="utf-8") as fp: template_code = fp.read() except FileNotFoundError: - tried.append(( - Origin(template_file, template_name, self), - 'Source does not exist', - )) + tried.append( + ( + Origin(template_file, template_name, self), + "Source does not exist", + ) + ) else: return Template(template_code) raise 
TemplateDoesNotExist(template_name, tried=tried, backend=self) class Template(string.Template): - def render(self, context=None, request=None): if context is None: context = {} else: context = {k: conditional_escape(v) for k, v in context.items()} if request is not None: - context['csrf_input'] = csrf_input_lazy(request) - context['csrf_token'] = csrf_token_lazy(request) + context["csrf_input"] = csrf_input_lazy(request) + context["csrf_token"] = csrf_token_lazy(request) return self.safe_substitute(context) diff --git a/django/template/backends/jinja2.py b/django/template/backends/jinja2.py index f0540d389c..199d62b429 100644 --- a/django/template/backends/jinja2.py +++ b/django/template/backends/jinja2.py @@ -12,24 +12,25 @@ from .base import BaseEngine class Jinja2(BaseEngine): - app_dirname = 'jinja2' + app_dirname = "jinja2" def __init__(self, params): params = params.copy() - options = params.pop('OPTIONS').copy() + options = params.pop("OPTIONS").copy() super().__init__(params) - self.context_processors = options.pop('context_processors', []) + self.context_processors = options.pop("context_processors", []) - environment = options.pop('environment', 'jinja2.Environment') + environment = options.pop("environment", "jinja2.Environment") environment_cls = import_string(environment) - if 'loader' not in options: - options['loader'] = jinja2.FileSystemLoader(self.template_dirs) - options.setdefault('autoescape', True) - options.setdefault('auto_reload', settings.DEBUG) - options.setdefault('undefined', - jinja2.DebugUndefined if settings.DEBUG else jinja2.Undefined) + if "loader" not in options: + options["loader"] = jinja2.FileSystemLoader(self.template_dirs) + options.setdefault("autoescape", True) + options.setdefault("auto_reload", settings.DEBUG) + options.setdefault( + "undefined", jinja2.DebugUndefined if settings.DEBUG else jinja2.Undefined + ) self.env = environment_cls(**options) @@ -52,22 +53,23 @@ class Jinja2(BaseEngine): class Template: - def 
__init__(self, template, backend): self.template = template self.backend = backend self.origin = Origin( - name=template.filename, template_name=template.name, + name=template.filename, + template_name=template.name, ) def render(self, context=None, request=None): from .utils import csrf_input_lazy, csrf_token_lazy + if context is None: context = {} if request is not None: - context['request'] = request - context['csrf_input'] = csrf_input_lazy(request) - context['csrf_token'] = csrf_token_lazy(request) + context["request"] = request + context["csrf_input"] = csrf_input_lazy(request) + context["csrf_token"] = csrf_token_lazy(request) for context_processor in self.backend.template_context_processors: context.update(context_processor(request)) try: @@ -83,6 +85,7 @@ class Origin: A container to hold debug information as described in the template API documentation. """ + def __init__(self, name, template_name): self.name = name self.template_name = template_name @@ -101,24 +104,24 @@ def get_exception_info(exception): if exception_file.exists(): source = exception_file.read_text() if source is not None: - lines = list(enumerate(source.strip().split('\n'), start=1)) + lines = list(enumerate(source.strip().split("\n"), start=1)) during = lines[lineno - 1][1] total = len(lines) top = max(0, lineno - context_lines - 1) bottom = min(total, lineno + context_lines) else: - during = '' + during = "" lines = [] total = top = bottom = 0 return { - 'name': exception.filename, - 'message': exception.message, - 'source_lines': lines[top:bottom], - 'line': lineno, - 'before': '', - 'during': during, - 'after': '', - 'total': total, - 'top': top, - 'bottom': bottom, + "name": exception.filename, + "message": exception.message, + "source_lines": lines[top:bottom], + "line": lineno, + "before": "", + "during": during, + "after": "", + "total": total, + "top": top, + "bottom": bottom, } diff --git a/django/template/backends/utils.py b/django/template/backends/utils.py index 
1396ae7095..880959d6a1 100644 --- a/django/template/backends/utils.py +++ b/django/template/backends/utils.py @@ -7,7 +7,8 @@ from django.utils.safestring import SafeString def csrf_input(request): return format_html( '<input type="hidden" name="csrfmiddlewaretoken" value="{}">', - get_token(request)) + get_token(request), + ) csrf_input_lazy = lazy(csrf_input, SafeString, str) diff --git a/django/template/base.py b/django/template/base.py index caadf89970..00f8283507 100644 --- a/django/template/base.py +++ b/django/template/base.py @@ -60,37 +60,35 @@ from django.utils.formats import localize from django.utils.html import conditional_escape, escape from django.utils.regex_helper import _lazy_re_compile from django.utils.safestring import SafeData, SafeString, mark_safe -from django.utils.text import ( - get_text_list, smart_split, unescape_string_literal, -) +from django.utils.text import get_text_list, smart_split, unescape_string_literal from django.utils.timezone import template_localtime from django.utils.translation import gettext_lazy, pgettext_lazy from .exceptions import TemplateSyntaxError # template syntax constants -FILTER_SEPARATOR = '|' -FILTER_ARGUMENT_SEPARATOR = ':' -VARIABLE_ATTRIBUTE_SEPARATOR = '.' -BLOCK_TAG_START = '{%' -BLOCK_TAG_END = '%}' -VARIABLE_TAG_START = '{{' -VARIABLE_TAG_END = '}}' -COMMENT_TAG_START = '{#' -COMMENT_TAG_END = '#}' -SINGLE_BRACE_START = '{' -SINGLE_BRACE_END = '}' +FILTER_SEPARATOR = "|" +FILTER_ARGUMENT_SEPARATOR = ":" +VARIABLE_ATTRIBUTE_SEPARATOR = "." +BLOCK_TAG_START = "{%" +BLOCK_TAG_END = "%}" +VARIABLE_TAG_START = "{{" +VARIABLE_TAG_END = "}}" +COMMENT_TAG_START = "{#" +COMMENT_TAG_END = "#}" +SINGLE_BRACE_START = "{" +SINGLE_BRACE_END = "}" # what to report as the origin for templates that come from non-loader sources # (e.g. 
strings) -UNKNOWN_SOURCE = '<unknown source>' +UNKNOWN_SOURCE = "<unknown source>" # Match BLOCK_TAG_*, VARIABLE_TAG_*, and COMMENT_TAG_* tags and capture the # entire tag, including start/end delimiters. Using re.compile() is faster # than instantiating SimpleLazyObject with _lazy_re_compile(). -tag_re = re.compile(r'({%.*?%}|{{.*?}}|{#.*?#})') +tag_re = re.compile(r"({%.*?%}|{{.*?}}|{#.*?#})") -logger = logging.getLogger('django.template') +logger = logging.getLogger("django.template") class TokenType(Enum): @@ -101,7 +99,6 @@ class TokenType(Enum): class VariableDoesNotExist(Exception): - def __init__(self, msg, params=()): self.msg = msg self.params = params @@ -120,20 +117,21 @@ class Origin: return self.name def __repr__(self): - return '<%s name=%r>' % (self.__class__.__qualname__, self.name) + return "<%s name=%r>" % (self.__class__.__qualname__, self.name) def __eq__(self, other): return ( - isinstance(other, Origin) and - self.name == other.name and - self.loader == other.loader + isinstance(other, Origin) + and self.name == other.name + and self.loader == other.loader ) @property def loader_name(self): if self.loader: - return '%s.%s' % ( - self.loader.__module__, self.loader.__class__.__name__, + return "%s.%s" % ( + self.loader.__module__, + self.loader.__class__.__name__, ) @@ -145,6 +143,7 @@ class Template: # e.g. 
Template('...').render(Context({...})) if engine is None: from .engine import Engine + engine = Engine.get_default() if origin is None: origin = Origin(UNKNOWN_SOURCE) @@ -161,7 +160,7 @@ class Template: def __repr__(self): return '<%s template_string="%s...">' % ( self.__class__.__qualname__, - self.source[:20].replace('\n', ''), + self.source[:20].replace("\n", ""), ) def _render(self, context): @@ -191,7 +190,9 @@ class Template: tokens = lexer.tokenize() parser = Parser( - tokens, self.engine.template_libraries, self.engine.template_builtins, + tokens, + self.engine.template_libraries, + self.engine.template_builtins, self.origin, ) @@ -263,30 +264,30 @@ class Template: try: message = str(exception.args[0]) except (IndexError, UnicodeDecodeError): - message = '(Could not get exception message)' + message = "(Could not get exception message)" return { - 'message': message, - 'source_lines': source_lines[top:bottom], - 'before': before, - 'during': during, - 'after': after, - 'top': top, - 'bottom': bottom, - 'total': total, - 'line': line, - 'name': self.origin.name, - 'start': start, - 'end': end, + "message": message, + "source_lines": source_lines[top:bottom], + "before": before, + "during": during, + "after": after, + "top": top, + "bottom": bottom, + "total": total, + "line": line, + "name": self.origin.name, + "start": start, + "end": end, } def linebreak_iter(template_source): yield 0 - p = template_source.find('\n') + p = template_source.find("\n") while p >= 0: yield p + 1 - p = template_source.find('\n', p + 1) + p = template_source.find("\n", p + 1) yield len(template_source) + 1 @@ -316,8 +317,10 @@ class Token: def __repr__(self): token_name = self.token_type.name.capitalize() - return ('<%s token: "%s...">' % - (token_name, self.contents[:20].replace('\n', ''))) + return '<%s token: "%s...">' % ( + token_name, + self.contents[:20].replace("\n", ""), + ) def split_contents(self): split = [] @@ -325,12 +328,12 @@ class Token: for bit in bits: # 
Handle translation-marked template pieces if bit.startswith(('_("', "_('")): - sentinel = bit[2] + ')' + sentinel = bit[2] + ")" trans_bit = [bit] while not bit.endswith(sentinel): bit = next(bits) trans_bit.append(bit) - bit = ' '.join(trans_bit) + bit = " ".join(trans_bit) split.append(bit) return split @@ -343,7 +346,7 @@ class Lexer: def __repr__(self): return '<%s template_string="%s...", verbatim=%s>' % ( self.__class__.__qualname__, - self.template_string[:20].replace('\n', ''), + self.template_string[:20].replace("\n", ""), self.verbatim, ) @@ -357,7 +360,7 @@ class Lexer: for token_string in tag_re.split(self.template_string): if token_string: result.append(self.create_token(token_string, None, lineno, in_tag)) - lineno += token_string.count('\n') + lineno += token_string.count("\n") in_tag = not in_tag return result @@ -382,9 +385,9 @@ class Lexer: return Token(TokenType.TEXT, token_string, position, lineno) # Otherwise, the current verbatim block is ending. self.verbatim = False - elif content[:9] in ('verbatim', 'verbatim '): + elif content[:9] in ("verbatim", "verbatim "): # Then a verbatim block is starting. 
- self.verbatim = 'end%s' % content + self.verbatim = "end%s" % content return Token(TokenType.BLOCK, content, position, lineno) if not self.verbatim: content = token_string[2:-2].strip() @@ -425,7 +428,7 @@ class DebugLexer(Lexer): for token_string, position in self._tag_re_split(): if token_string: result.append(self.create_token(token_string, position, lineno, in_tag)) - lineno += token_string.count('\n') + lineno += token_string.count("\n") in_tag = not in_tag return result @@ -450,7 +453,7 @@ class Parser: self.origin = origin def __repr__(self): - return '<%s tokens=%r>' % (self.__class__.__qualname__, self.tokens) + return "<%s tokens=%r>" % (self.__class__.__qualname__, self.tokens) def parse(self, parse_until=None): """ @@ -472,7 +475,9 @@ class Parser: self.extend_nodelist(nodelist, TextNode(token.contents), token) elif token_type == 1: # TokenType.VAR if not token.contents: - raise self.error(token, 'Empty variable tag on line %d' % token.lineno) + raise self.error( + token, "Empty variable tag on line %d" % token.lineno + ) try: filter_expression = self.compile_filter(token.contents) except TemplateSyntaxError as e: @@ -483,7 +488,7 @@ class Parser: try: command = token.contents.split()[0] except IndexError: - raise self.error(token, 'Empty block tag on line %d' % token.lineno) + raise self.error(token, "Empty block tag on line %d" % token.lineno) if command in parse_until: # A matching token has been reached. Return control to # the caller. Put the token back on the token list so the @@ -524,7 +529,8 @@ class Parser: # Check that non-text nodes don't appear before an extends tag. if node.must_be_first and nodelist.contains_nontext: raise self.error( - token, '%r must be the first tag in the template.' % node, + token, + "%r must be the first tag in the template." 
% node, ) if not isinstance(node, TextNode): nodelist.contains_nontext = True @@ -543,7 +549,7 @@ class Parser: """ if not isinstance(e, Exception): e = TemplateSyntaxError(e) - if not hasattr(e, 'token'): + if not hasattr(e, "token"): e.token = token return e @@ -552,16 +558,17 @@ class Parser: raise self.error( token, "Invalid block tag on line %d: '%s', expected %s. Did you " - "forget to register or load this tag?" % ( + "forget to register or load this tag?" + % ( token.lineno, command, - get_text_list(["'%s'" % p for p in parse_until], 'or'), + get_text_list(["'%s'" % p for p in parse_until], "or"), ), ) raise self.error( token, "Invalid block tag on line %d: '%s'. Did you forget to register " - "or load this tag?" % (token.lineno, command) + "or load this tag?" % (token.lineno, command), ) def unclosed_block_tag(self, parse_until): @@ -569,7 +576,7 @@ class Parser: msg = "Unclosed tag on line %d: '%s'. Looking for one of: %s." % ( token.lineno, command, - ', '.join(parse_until), + ", ".join(parse_until), ) raise self.error(token, msg) @@ -608,10 +615,10 @@ constant_string = r""" %(strdq)s| %(strsq)s) """ % { - 'strdq': r'"[^"\\]*(?:\\.[^"\\]*)*"', # double-quoted string - 'strsq': r"'[^'\\]*(?:\\.[^'\\]*)*'", # single-quoted string - 'i18n_open': re.escape("_("), - 'i18n_close': re.escape(")"), + "strdq": r'"[^"\\]*(?:\\.[^"\\]*)*"', # double-quoted string + "strsq": r"'[^'\\]*(?:\\.[^'\\]*)*'", # single-quoted string + "i18n_open": re.escape("_("), + "i18n_close": re.escape(")"), } constant_string = constant_string.replace("\n", "") @@ -627,11 +634,11 @@ filter_raw_string = r""" ) )? 
)""" % { - 'constant': constant_string, - 'num': r'[-+\.]?\d[\d\.e]*', - 'var_chars': r'\w\.', - 'filter_sep': re.escape(FILTER_SEPARATOR), - 'arg_sep': re.escape(FILTER_ARGUMENT_SEPARATOR), + "constant": constant_string, + "num": r"[-+\.]?\d[\d\.e]*", + "var_chars": r"\w\.", + "filter_sep": re.escape(FILTER_SEPARATOR), + "arg_sep": re.escape(FILTER_ARGUMENT_SEPARATOR), } filter_re = _lazy_re_compile(filter_raw_string, re.VERBOSE) @@ -652,7 +659,7 @@ class FilterExpression: <Variable: 'variable'> """ - __slots__ = ('token', 'filters', 'var', 'is_var') + __slots__ = ("token", "filters", "var", "is_var") def __init__(self, token, parser): self.token = token @@ -663,12 +670,12 @@ class FilterExpression: for match in matches: start = match.start() if upto != start: - raise TemplateSyntaxError("Could not parse some characters: " - "%s|%s|%s" % - (token[:upto], token[upto:start], - token[start:])) + raise TemplateSyntaxError( + "Could not parse some characters: " + "%s|%s|%s" % (token[:upto], token[upto:start], token[start:]) + ) if var_obj is None: - var, constant = match['var'], match['constant'] + var, constant = match["var"], match["constant"] if constant: try: var_obj = Variable(constant).resolve({}) @@ -681,9 +688,9 @@ class FilterExpression: else: var_obj = Variable(var) else: - filter_name = match['filter_name'] + filter_name = match["filter_name"] args = [] - constant_arg, var_arg = match['constant_arg'], match['var_arg'] + constant_arg, var_arg = match["constant_arg"], match["var_arg"] if constant_arg: args.append((False, Variable(constant_arg).resolve({}))) elif var_arg: @@ -693,8 +700,10 @@ class FilterExpression: filters.append((filter_func, args)) upto = match.end() if upto != len(token): - raise TemplateSyntaxError("Could not parse the remainder: '%s' " - "from '%s'" % (token[upto:], token)) + raise TemplateSyntaxError( + "Could not parse the remainder: '%s' " + "from '%s'" % (token[upto:], token) + ) self.filters = filters self.var = var_obj @@ -710,7 
+719,7 @@ class FilterExpression: else: string_if_invalid = context.template.engine.string_if_invalid if string_if_invalid: - if '%s' in string_if_invalid: + if "%s" in string_if_invalid: return string_if_invalid % self.var else: return string_if_invalid @@ -725,13 +734,13 @@ class FilterExpression: arg_vals.append(mark_safe(arg)) else: arg_vals.append(arg.resolve(context)) - if getattr(func, 'expects_localtime', False): + if getattr(func, "expects_localtime", False): obj = template_localtime(obj, context.use_tz) - if getattr(func, 'needs_autoescape', False): + if getattr(func, "needs_autoescape", False): new_obj = func(obj, autoescape=context.autoescape, *arg_vals) else: new_obj = func(obj, *arg_vals) - if getattr(func, 'is_safe', False) and isinstance(obj, SafeData): + if getattr(func, "is_safe", False) and isinstance(obj, SafeData): obj = mark_safe(new_obj) else: obj = new_obj @@ -749,10 +758,12 @@ class FilterExpression: dlen = len(defaults or []) # Not enough OR Too many if plen < (alen - dlen) or plen > alen: - raise TemplateSyntaxError("%s requires %d arguments, %d provided" % - (name, alen - dlen, plen)) + raise TemplateSyntaxError( + "%s requires %d arguments, %d provided" % (name, alen - dlen, plen) + ) return True + args_check = staticmethod(args_check) def __str__(self): @@ -781,7 +792,7 @@ class Variable: (The example assumes VARIABLE_ATTRIBUTE_SEPARATOR is '.') """ - __slots__ = ('var', 'literal', 'lookups', 'translate', 'message_context') + __slots__ = ("var", "literal", "lookups", "translate", "message_context") def __init__(self, var): self.var = var @@ -791,8 +802,7 @@ class Variable: self.message_context = None if not isinstance(var, str): - raise TypeError( - "Variable must be a string or number, got %s" % type(var)) + raise TypeError("Variable must be a string or number, got %s" % type(var)) try: # First try to treat this variable as a number. 
# @@ -802,16 +812,16 @@ class Variable: # Try to interpret values containing a period or an 'e'/'E' # (possibly scientific notation) as a float; otherwise, try int. - if '.' in var or 'e' in var.lower(): + if "." in var or "e" in var.lower(): self.literal = float(var) # "2." is invalid - if var[-1] == '.': + if var[-1] == ".": raise ValueError else: self.literal = int(var) except ValueError: # A ValueError means that the variable isn't a number. - if var[0:2] == '_(' and var[-1] == ')': + if var[0:2] == "_(" and var[-1] == ")": # The result of the lookup should be translated at rendering # time. self.translate = True @@ -823,10 +833,11 @@ class Variable: except ValueError: # Otherwise we'll set self.lookups so that resolve() knows we're # dealing with a bonafide variable - if VARIABLE_ATTRIBUTE_SEPARATOR + '_' in var or var[0] == '_': - raise TemplateSyntaxError("Variables and attributes may " - "not begin with underscores: '%s'" % - var) + if VARIABLE_ATTRIBUTE_SEPARATOR + "_" in var or var[0] == "_": + raise TemplateSyntaxError( + "Variables and attributes may " + "not begin with underscores: '%s'" % var + ) self.lookups = tuple(var.split(VARIABLE_ATTRIBUTE_SEPARATOR)) def resolve(self, context): @@ -839,7 +850,7 @@ class Variable: value = self.literal if self.translate: is_safe = isinstance(value, SafeData) - msgid = value.replace('%', '%%') + msgid = value.replace("%", "%%") msgid = mark_safe(msgid) if is_safe else msgid if self.message_context: return pgettext_lazy(self.message_context, msgid) @@ -872,7 +883,9 @@ class Variable: except (TypeError, AttributeError, KeyError, ValueError, IndexError): try: # attribute lookup # Don't return class attributes if the class is the context: - if isinstance(current, BaseContext) and getattr(type(current), bit): + if isinstance(current, BaseContext) and getattr( + type(current), bit + ): raise AttributeError current = getattr(current, bit) except (TypeError, AttributeError): @@ -881,18 +894,20 @@ class Variable: raise 
try: # list-index lookup current = current[int(bit)] - except (IndexError, # list index out of range - ValueError, # invalid literal for int() - KeyError, # current is a dict without `int(bit)` key - TypeError): # unsubscriptable object + except ( + IndexError, # list index out of range + ValueError, # invalid literal for int() + KeyError, # current is a dict without `int(bit)` key + TypeError, + ): # unsubscriptable object raise VariableDoesNotExist( "Failed lookup for key [%s] in %r", (bit, current), ) # missing attribute if callable(current): - if getattr(current, 'do_not_call_in_templates', False): + if getattr(current, "do_not_call_in_templates", False): pass - elif getattr(current, 'alters_data', False): + elif getattr(current, "alters_data", False): current = context.template.engine.string_if_invalid else: try: # method call (assuming no args required) @@ -902,11 +917,13 @@ class Variable: try: signature.bind() except TypeError: # arguments *were* required - current = context.template.engine.string_if_invalid # invalid method call + current = ( + context.template.engine.string_if_invalid + ) # invalid method call else: raise except Exception as e: - template_name = getattr(context, 'template_name', None) or 'unknown' + template_name = getattr(context, "template_name", None) or "unknown" logger.debug( "Exception while resolving variable '%s' in template '%s'.", bit, @@ -914,7 +931,7 @@ class Variable: exc_info=True, ) - if getattr(e, 'silent_variable_failure', False): + if getattr(e, "silent_variable_failure", False): current = context.template.engine.string_if_invalid else: raise @@ -926,7 +943,7 @@ class Node: # Set this to True for nodes that must be first in the template (although # they can be preceded by text nodes. 
must_be_first = False - child_nodelists = ('nodelist',) + child_nodelists = ("nodelist",) token = None def render(self, context): @@ -947,14 +964,17 @@ class Node: except Exception as e: if context.template.engine.debug: # Store the actual node that caused the exception. - if not hasattr(e, '_culprit_node'): + if not hasattr(e, "_culprit_node"): e._culprit_node = self if ( - not hasattr(e, 'template_debug') and - context.render_context.template.origin == e._culprit_node.origin + not hasattr(e, "template_debug") + and context.render_context.template.origin == e._culprit_node.origin ): - e.template_debug = context.render_context.template.get_exception_info( - e, e._culprit_node.token, + e.template_debug = ( + context.render_context.template.get_exception_info( + e, + e._culprit_node.token, + ) ) raise @@ -982,9 +1002,7 @@ class NodeList(list): contains_nontext = False def render(self, context): - return SafeString(''.join([ - node.render_annotated(context) for node in self - ])) + return SafeString("".join([node.render_annotated(context) for node in self])) def get_nodes_by_type(self, nodetype): "Return a list of all nodes of the given type" @@ -1048,7 +1066,7 @@ class VariableNode(Node): # Unicode conversion can fail sometimes for reasons out of our # control (e.g. exception rendering). In that case, we fail # quietly. 
- return '' + return "" return render_value_in_context(output, context) @@ -1079,7 +1097,7 @@ def token_kwargs(bits, parser, support_legacy=False): if not kwarg_format: if not support_legacy: return {} - if len(bits) < 3 or bits[1] != 'as': + if len(bits) < 3 or bits[1] != "as": return {} kwargs = {} @@ -1091,13 +1109,13 @@ def token_kwargs(bits, parser, support_legacy=False): key, value = match.groups() del bits[:1] else: - if len(bits) < 3 or bits[1] != 'as': + if len(bits) < 3 or bits[1] != "as": return kwargs key, value = bits[2], bits[0] del bits[:3] kwargs[key] = parser.compile_filter(value) if bits and not kwarg_format: - if bits[0] != 'and': + if bits[0] != "and": return kwargs del bits[:1] return kwargs diff --git a/django/template/context.py b/django/template/context.py index f0a0cf2a00..ccf0b430dc 100644 --- a/django/template/context.py +++ b/django/template/context.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from copy import copy # Hard-coded processor for easier use of CSRF protection. 
-_builtin_context_processors = ('django.template.context_processors.csrf',) +_builtin_context_processors = ("django.template.context_processors.csrf",) class ContextPopException(Exception): @@ -29,7 +29,7 @@ class BaseContext: self._reset_dicts(dict_) def _reset_dicts(self, value=None): - builtins = {'True': True, 'False': False, 'None': None} + builtins = {"True": True, "False": False, "None": None} self.dicts = [builtins] if value is not None: self.dicts.append(value) @@ -132,6 +132,7 @@ class BaseContext: class Context(BaseContext): "A stack container for variable context" + def __init__(self, dict_=None, autoescape=True, use_l10n=None, use_tz=None): self.autoescape = autoescape self.use_l10n = use_l10n @@ -160,8 +161,8 @@ class Context(BaseContext): def update(self, other_dict): "Push other_dict to the stack of dictionaries in the Context" - if not hasattr(other_dict, '__getitem__'): - raise TypeError('other_dict must be a mapping (dictionary-like) object.') + if not hasattr(other_dict, "__getitem__"): + raise TypeError("other_dict must be a mapping (dictionary-like) object.") if isinstance(other_dict, BaseContext): other_dict = other_dict.dicts[1:].pop() return ContextDict(self, other_dict) @@ -182,6 +183,7 @@ class RenderContext(BaseContext): rendering of other templates as they would if they were stored in the normal template context. """ + template = None def __iter__(self): @@ -217,7 +219,16 @@ class RequestContext(Context): Additional processors can be specified as a list of callables using the "processors" keyword argument. 
""" - def __init__(self, request, dict_=None, processors=None, use_l10n=None, use_tz=None, autoescape=True): + + def __init__( + self, + request, + dict_=None, + processors=None, + use_l10n=None, + use_tz=None, + autoescape=True, + ): super().__init__(dict_, use_l10n=use_l10n, use_tz=use_tz, autoescape=autoescape) self.request = request self._processors = () if processors is None else tuple(processors) @@ -237,8 +248,7 @@ class RequestContext(Context): self.template = template # Set context processors according to the template engine's settings. - processors = (template.engine.template_context_processors + - self._processors) + processors = template.engine.template_context_processors + self._processors updates = {} for processor in processors: updates.update(processor(self.request)) @@ -255,7 +265,7 @@ class RequestContext(Context): new_context = super().new(values) # This is for backwards-compatibility: RequestContexts created via # Context.new don't include values from context processors. - if hasattr(new_context, '_processors_index'): + if hasattr(new_context, "_processors_index"): del new_context._processors_index return new_context @@ -265,7 +275,9 @@ def make_context(context, request=None, **kwargs): Create a suitable Context from a plain dict and optionally an HttpRequest. """ if context is not None and not isinstance(context, dict): - raise TypeError('context must be a dict rather than %s.' % context.__class__.__name__) + raise TypeError( + "context must be a dict rather than %s." 
% context.__class__.__name__ + ) if request is None: context = Context(context, **kwargs) else: diff --git a/django/template/context_processors.py b/django/template/context_processors.py index 25ac1f2661..32753032fc 100644 --- a/django/template/context_processors.py +++ b/django/template/context_processors.py @@ -19,17 +19,18 @@ def csrf(request): Context processor that provides a CSRF token, or the string 'NOTPROVIDED' if it has not been provided by either a view decorator or the middleware """ + def _get_val(): token = get_token(request) if token is None: # In order to be able to provide debugging info in the # case of misconfiguration, we use a sentinel value # instead of returning an empty dict. - return 'NOTPROVIDED' + return "NOTPROVIDED" else: return token - return {'csrf_token': SimpleLazyObject(_get_val)} + return {"csrf_token": SimpleLazyObject(_get_val)} def debug(request): @@ -37,46 +38,52 @@ def debug(request): Return context variables helpful for debugging. """ context_extras = {} - if settings.DEBUG and request.META.get('REMOTE_ADDR') in settings.INTERNAL_IPS: - context_extras['debug'] = True + if settings.DEBUG and request.META.get("REMOTE_ADDR") in settings.INTERNAL_IPS: + context_extras["debug"] = True from django.db import connections # Return a lazy reference that computes connection.queries on access, # to ensure it contains queries triggered after this function runs. 
- context_extras['sql_queries'] = lazy( - lambda: list(itertools.chain.from_iterable(connections[x].queries for x in connections)), - list + context_extras["sql_queries"] = lazy( + lambda: list( + itertools.chain.from_iterable( + connections[x].queries for x in connections + ) + ), + list, ) return context_extras def i18n(request): from django.utils import translation + return { - 'LANGUAGES': settings.LANGUAGES, - 'LANGUAGE_CODE': translation.get_language(), - 'LANGUAGE_BIDI': translation.get_language_bidi(), + "LANGUAGES": settings.LANGUAGES, + "LANGUAGE_CODE": translation.get_language(), + "LANGUAGE_BIDI": translation.get_language_bidi(), } def tz(request): from django.utils import timezone - return {'TIME_ZONE': timezone.get_current_timezone_name()} + + return {"TIME_ZONE": timezone.get_current_timezone_name()} def static(request): """ Add static-related context variables to the context. """ - return {'STATIC_URL': settings.STATIC_URL} + return {"STATIC_URL": settings.STATIC_URL} def media(request): """ Add media-related context variables to the context. 
""" - return {'MEDIA_URL': settings.MEDIA_URL} + return {"MEDIA_URL": settings.MEDIA_URL} def request(request): - return {'request': request} + return {"request": request} diff --git a/django/template/defaultfilters.py b/django/template/defaultfilters.py index f78a96e3eb..46334791c6 100644 --- a/django/template/defaultfilters.py +++ b/django/template/defaultfilters.py @@ -12,14 +12,14 @@ from urllib.parse import quote from django.utils import formats from django.utils.dateformat import format, time_format from django.utils.encoding import iri_to_uri -from django.utils.html import ( - avoid_wrapping, conditional_escape, escape, escapejs, - json_script as _json_script, linebreaks, strip_tags, urlize as _urlize, -) +from django.utils.html import avoid_wrapping, conditional_escape, escape, escapejs +from django.utils.html import json_script as _json_script +from django.utils.html import linebreaks, strip_tags +from django.utils.html import urlize as _urlize from django.utils.safestring import SafeData, mark_safe -from django.utils.text import ( - Truncator, normalize_newlines, phone2numeric, slugify as _slugify, wrap, -) +from django.utils.text import Truncator, normalize_newlines, phone2numeric +from django.utils.text import slugify as _slugify +from django.utils.text import wrap from django.utils.timesince import timesince, timeuntil from django.utils.translation import gettext, ngettext @@ -33,16 +33,18 @@ register = Library() # STRING DECORATOR # ####################### + def stringfilter(func): """ Decorator for filters which should only receive strings. The object passed as the first positional argument will be converted to a string. 
""" + @wraps(func) def _dec(first, *args, **kwargs): first = str(first) result = func(first, *args, **kwargs) - if isinstance(first, SafeData) and getattr(unwrap(func), 'is_safe', False): + if isinstance(first, SafeData) and getattr(unwrap(func), "is_safe", False): result = mark_safe(result) return result @@ -53,6 +55,7 @@ def stringfilter(func): # STRINGS # ################### + @register.filter(is_safe=True) @stringfilter def addslashes(value): @@ -61,7 +64,7 @@ def addslashes(value): example. Less useful for escaping JavaScript; use the ``escapejs`` filter instead. """ - return value.replace('\\', '\\\\').replace('"', '\\"').replace("'", "\\'") + return value.replace("\\", "\\\\").replace('"', '\\"').replace("'", "\\'") @register.filter(is_safe=True) @@ -135,14 +138,14 @@ def floatformat(text, arg=-1): use_l10n = True if isinstance(arg, str): last_char = arg[-1] - if arg[-2:] in {'gu', 'ug'}: + if arg[-2:] in {"gu", "ug"}: force_grouping = True use_l10n = False arg = arg[:-2] or -1 - elif last_char == 'g': + elif last_char == "g": force_grouping = True arg = arg[:-1] or -1 - elif last_char == 'u': + elif last_char == "u": use_l10n = False arg = arg[:-1] or -1 try: @@ -152,7 +155,7 @@ def floatformat(text, arg=-1): try: d = Decimal(str(float(text))) except (ValueError, InvalidOperation, TypeError): - return '' + return "" try: p = int(arg) except ValueError: @@ -164,12 +167,14 @@ def floatformat(text, arg=-1): return input_val if not m and p < 0: - return mark_safe(formats.number_format( - '%d' % (int(d)), - 0, - use_l10n=use_l10n, - force_grouping=force_grouping, - )) + return mark_safe( + formats.number_format( + "%d" % (int(d)), + 0, + use_l10n=use_l10n, + force_grouping=force_grouping, + ) + ) exp = Decimal(1).scaleb(-abs(p)) # Set the precision high enough to avoid an exception (#15789). 
@@ -184,17 +189,19 @@ def floatformat(text, arg=-1): sign, digits, exponent = rounded_d.as_tuple() digits = [str(digit) for digit in reversed(digits)] while len(digits) <= abs(exponent): - digits.append('0') - digits.insert(-exponent, '.') + digits.append("0") + digits.insert(-exponent, ".") if sign and rounded_d: - digits.append('-') - number = ''.join(reversed(digits)) - return mark_safe(formats.number_format( - number, - abs(p), - use_l10n=use_l10n, - force_grouping=force_grouping, - )) + digits.append("-") + number = "".join(reversed(digits)) + return mark_safe( + formats.number_format( + number, + abs(p), + use_l10n=use_l10n, + force_grouping=force_grouping, + ) + ) @register.filter(is_safe=True) @@ -208,7 +215,7 @@ def iriencode(value): @stringfilter def linenumbers(value, autoescape=True): """Display text with line numbers.""" - lines = value.split('\n') + lines = value.split("\n") # Find the maximum width of the line count, for use with zero padding # string format command width = str(len(str(len(lines)))) @@ -218,7 +225,7 @@ def linenumbers(value, autoescape=True): else: for i, line in enumerate(lines): lines[i] = ("%0" + width + "d. %s") % (i + 1, escape(line)) - return mark_safe('\n'.join(lines)) + return mark_safe("\n".join(lines)) @register.filter(is_safe=True) @@ -275,7 +282,7 @@ def stringformat(value, arg): def title(value): """Convert a string into titlecase.""" t = re.sub("([a-z])'([A-Z])", lambda m: m[0].lower(), value.title()) - return re.sub(r'\d([A-Z])', lambda m: m[0].lower(), t) + return re.sub(r"\d([A-Z])", lambda m: m[0].lower(), t) @register.filter(is_safe=True) @@ -314,7 +321,7 @@ def truncatewords(value, arg): length = int(arg) except ValueError: # Invalid literal for int(). return value # Fail silently. 
- return Truncator(value).words(length, truncate=' …') + return Truncator(value).words(length, truncate=" …") @register.filter(is_safe=True) @@ -328,7 +335,7 @@ def truncatewords_html(value, arg): length = int(arg) except ValueError: # invalid literal for int() return value # Fail silently. - return Truncator(value).words(length, html=True, truncate=' …') + return Truncator(value).words(length, html=True, truncate=" …") @register.filter(is_safe=False) @@ -351,7 +358,7 @@ def urlencode(value, safe=None): """ kwargs = {} if safe is not None: - kwargs['safe'] = safe + kwargs["safe"] = safe return quote(value, **kwargs) @@ -371,7 +378,9 @@ def urlizetrunc(value, limit, autoescape=True): Argument: Length to truncate URLs to. """ - return mark_safe(_urlize(value, trim_url_limit=int(limit), nofollow=True, autoescape=autoescape)) + return mark_safe( + _urlize(value, trim_url_limit=int(limit), nofollow=True, autoescape=autoescape) + ) @register.filter(is_safe=False) @@ -414,8 +423,8 @@ def center(value, arg): def cut(value, arg): """Remove all values of arg from the given string.""" safe = isinstance(value, SafeData) - value = value.replace(arg, '') - if safe and arg != ';': + value = value.replace(arg, "") + if safe and arg != ";": return mark_safe(value) return value @@ -424,6 +433,7 @@ def cut(value, arg): # HTML STRINGS # ################### + @register.filter("escape", is_safe=True) @stringfilter def escape_filter(value): @@ -465,7 +475,7 @@ def linebreaksbr(value, autoescape=True): value = normalize_newlines(value) if autoescape: value = escape(value) - return mark_safe(value.replace('\n', '<br>')) + return mark_safe(value.replace("\n", "<br>")) @register.filter(is_safe=True) @@ -496,6 +506,7 @@ def striptags(value): # LISTS # ################### + def _property_resolver(arg): """ When arg is convertible to float, behave like operator.itemgetter(arg) @@ -517,8 +528,8 @@ def _property_resolver(arg): try: float(arg) except ValueError: - if VARIABLE_ATTRIBUTE_SEPARATOR + 
'_' in arg or arg[0] == '_': - raise AttributeError('Access to private variables is forbidden.') + if VARIABLE_ATTRIBUTE_SEPARATOR + "_" in arg or arg[0] == "_": + raise AttributeError("Access to private variables is forbidden.") parts = arg.split(VARIABLE_ATTRIBUTE_SEPARATOR) def resolve(value): @@ -543,7 +554,7 @@ def dictsort(value, arg): try: return sorted(value, key=_property_resolver(arg)) except (AttributeError, TypeError): - return '' + return "" @register.filter(is_safe=False) @@ -555,7 +566,7 @@ def dictsortreversed(value, arg): try: return sorted(value, key=_property_resolver(arg), reverse=True) except (AttributeError, TypeError): - return '' + return "" @register.filter(is_safe=False) @@ -564,7 +575,7 @@ def first(value): try: return value[0] except IndexError: - return '' + return "" @register.filter(is_safe=True, needs_autoescape=True) @@ -585,7 +596,7 @@ def last(value): try: return value[-1] except IndexError: - return '' + return "" @register.filter(is_safe=False) @@ -603,7 +614,7 @@ def length_is(value, arg): try: return len(value) == int(arg) except (ValueError, TypeError): - return '' + return "" @register.filter(is_safe=True) @@ -619,7 +630,7 @@ def slice_filter(value, arg): """ try: bits = [] - for x in str(arg).split(':'): + for x in str(arg).split(":"): if not x: bits.append(None) else: @@ -655,6 +666,7 @@ def unordered_list(value, autoescape=True): if autoescape: escaper = conditional_escape else: + def escaper(x): return x @@ -683,16 +695,19 @@ def unordered_list(value, autoescape=True): pass def list_formatter(item_list, tabs=1): - indent = '\t' * tabs + indent = "\t" * tabs output = [] for item, children in walk_items(item_list): - sublist = '' + sublist = "" if children: - sublist = '\n%s<ul>\n%s\n%s</ul>\n%s' % ( - indent, list_formatter(children, tabs + 1), indent, indent) - output.append('%s<li>%s%s</li>' % ( - indent, escaper(item), sublist)) - return '\n'.join(output) + sublist = "\n%s<ul>\n%s\n%s</ul>\n%s" % ( + indent, + 
list_formatter(children, tabs + 1), + indent, + indent, + ) + output.append("%s<li>%s%s</li>" % (indent, escaper(item), sublist)) + return "\n".join(output) return mark_safe(list_formatter(value)) @@ -701,6 +716,7 @@ def unordered_list(value, autoescape=True): # INTEGERS # ################### + @register.filter(is_safe=False) def add(value, arg): """Add the arg to the value.""" @@ -710,7 +726,7 @@ def add(value, arg): try: return value + arg except Exception: - return '' + return "" @register.filter(is_safe=False) @@ -738,62 +754,64 @@ def get_digit(value, arg): # DATES # ################### + @register.filter(expects_localtime=True, is_safe=False) def date(value, arg=None): """Format a date according to the given format.""" - if value in (None, ''): - return '' + if value in (None, ""): + return "" try: return formats.date_format(value, arg) except AttributeError: try: return format(value, arg) except AttributeError: - return '' + return "" @register.filter(expects_localtime=True, is_safe=False) def time(value, arg=None): """Format a time according to the given format.""" - if value in (None, ''): - return '' + if value in (None, ""): + return "" try: return formats.time_format(value, arg) except (AttributeError, TypeError): try: return time_format(value, arg) except (AttributeError, TypeError): - return '' + return "" @register.filter("timesince", is_safe=False) def timesince_filter(value, arg=None): """Format a date as the time since that date (i.e. "4 days, 6 hours").""" if not value: - return '' + return "" try: if arg: return timesince(value, arg) return timesince(value) except (ValueError, TypeError): - return '' + return "" @register.filter("timeuntil", is_safe=False) def timeuntil_filter(value, arg=None): """Format a date as the time until that date (i.e. 
"4 days, 6 hours").""" if not value: - return '' + return "" try: return timeuntil(value, arg) except (ValueError, TypeError): - return '' + return "" ################### # LOGIC # ################### + @register.filter(is_safe=False) def default(value, arg): """If value is unavailable, use given default.""" @@ -832,8 +850,8 @@ def yesno(value, arg=None): """ if arg is None: # Translators: Please do not add spaces around commas. - arg = gettext('yes,no,maybe') - bits = arg.split(',') + arg = gettext("yes,no,maybe") + bits = arg.split(",") if len(bits) < 2: return value # Invalid arg. try: @@ -852,6 +870,7 @@ def yesno(value, arg=None): # MISC # ################### + @register.filter(is_safe=True) def filesizeformat(bytes_): """ @@ -861,7 +880,7 @@ def filesizeformat(bytes_): try: bytes_ = int(bytes_) except (TypeError, ValueError, UnicodeDecodeError): - value = ngettext("%(size)d byte", "%(size)d bytes", 0) % {'size': 0} + value = ngettext("%(size)d byte", "%(size)d bytes", 0) % {"size": 0} return avoid_wrapping(value) def filesize_number_format(value): @@ -878,7 +897,7 @@ def filesizeformat(bytes_): bytes_ = -bytes_ # Allow formatting of negative numbers. if bytes_ < KB: - value = ngettext("%(size)d byte", "%(size)d bytes", bytes_) % {'size': bytes_} + value = ngettext("%(size)d byte", "%(size)d bytes", bytes_) % {"size": bytes_} elif bytes_ < MB: value = gettext("%s KB") % filesize_number_format(bytes_ / KB) elif bytes_ < GB: @@ -896,7 +915,7 @@ def filesizeformat(bytes_): @register.filter(is_safe=False) -def pluralize(value, arg='s'): +def pluralize(value, arg="s"): """ Return a plural suffix if the value is not 1, '1', or an object of length 1. By default, use 's' as the suffix: @@ -918,11 +937,11 @@ def pluralize(value, arg='s'): * If value is 1, cand{{ value|pluralize:"y,ies" }} display "candy". * If value is 2, cand{{ value|pluralize:"y,ies" }} display "candies". 
""" - if ',' not in arg: - arg = ',' + arg - bits = arg.split(',') + if "," not in arg: + arg = "," + arg + bits = arg.split(",") if len(bits) > 2: - return '' + return "" singular_suffix, plural_suffix = bits[:2] try: @@ -934,7 +953,7 @@ def pluralize(value, arg='s'): return singular_suffix if len(value) == 1 else plural_suffix except TypeError: # len() of unsized object. pass - return '' + return "" @register.filter("phone2numeric", is_safe=True) diff --git a/django/template/defaulttags.py b/django/template/defaulttags.py index 9090b14e40..7762b94723 100644 --- a/django/template/defaulttags.py +++ b/django/template/defaulttags.py @@ -4,7 +4,8 @@ import sys import warnings from collections import namedtuple from datetime import datetime -from itertools import cycle as itertools_cycle, groupby +from itertools import cycle as itertools_cycle +from itertools import groupby from django.conf import settings from django.utils import timezone @@ -13,11 +14,23 @@ from django.utils.lorem_ipsum import paragraphs, words from django.utils.safestring import mark_safe from .base import ( - BLOCK_TAG_END, BLOCK_TAG_START, COMMENT_TAG_END, COMMENT_TAG_START, - FILTER_SEPARATOR, SINGLE_BRACE_END, SINGLE_BRACE_START, - VARIABLE_ATTRIBUTE_SEPARATOR, VARIABLE_TAG_END, VARIABLE_TAG_START, Node, - NodeList, TemplateSyntaxError, VariableDoesNotExist, kwarg_re, - render_value_in_context, token_kwargs, + BLOCK_TAG_END, + BLOCK_TAG_START, + COMMENT_TAG_END, + COMMENT_TAG_START, + FILTER_SEPARATOR, + SINGLE_BRACE_END, + SINGLE_BRACE_START, + VARIABLE_ATTRIBUTE_SEPARATOR, + VARIABLE_TAG_END, + VARIABLE_TAG_START, + Node, + NodeList, + TemplateSyntaxError, + VariableDoesNotExist, + kwarg_re, + render_value_in_context, + token_kwargs, ) from .context import Context from .defaultfilters import date @@ -29,6 +42,7 @@ register = Library() class AutoEscapeControlNode(Node): """Implement the actions of the autoescape tag.""" + def __init__(self, setting, nodelist): self.setting, self.nodelist = 
setting, nodelist @@ -47,19 +61,22 @@ class CommentNode(Node): child_nodelists = () def render(self, context): - return '' + return "" class CsrfTokenNode(Node): child_nodelists = () def render(self, context): - csrf_token = context.get('csrf_token') + csrf_token = context.get("csrf_token") if csrf_token: - if csrf_token == 'NOTPROVIDED': + if csrf_token == "NOTPROVIDED": return format_html("") else: - return format_html('<input type="hidden" name="csrfmiddlewaretoken" value="{}">', csrf_token) + return format_html( + '<input type="hidden" name="csrfmiddlewaretoken" value="{}">', + csrf_token, + ) else: # It's very probable that the token is missing because of # misconfiguration, so we raise a warning @@ -69,7 +86,7 @@ class CsrfTokenNode(Node): "did not provide the value. This is usually caused by not " "using RequestContext." ) - return '' + return "" class CycleNode(Node): @@ -87,7 +104,7 @@ class CycleNode(Node): if self.variable_name: context.set_upward(self.variable_name, value) if self.silent: - return '' + return "" return render_value_in_context(value, context) def reset(self, context): @@ -100,13 +117,14 @@ class CycleNode(Node): class DebugNode(Node): def render(self, context): if not settings.DEBUG: - return '' + return "" from pprint import pformat + output = [escape(pformat(val)) for val in context] - output.append('\n\n') + output.append("\n\n") output.append(escape(pformat(sys.modules))) - return ''.join(output) + return "".join(output) class FilterNode(Node): @@ -126,7 +144,7 @@ class FirstOfNode(Node): self.asvar = asvar def render(self, context): - first = '' + first = "" for var in self.vars: value = var.resolve(context, ignore_failures=True) if value: @@ -134,14 +152,16 @@ class FirstOfNode(Node): break if self.asvar: context[self.asvar] = first - return '' + return "" return first class ForNode(Node): - child_nodelists = ('nodelist_loop', 'nodelist_empty') + child_nodelists = ("nodelist_loop", "nodelist_empty") - def __init__(self, loopvars, 
sequence, is_reversed, nodelist_loop, nodelist_empty=None): + def __init__( + self, loopvars, sequence, is_reversed, nodelist_loop, nodelist_empty=None + ): self.loopvars, self.sequence = loopvars, sequence self.is_reversed = is_reversed self.nodelist_loop = nodelist_loop @@ -151,25 +171,25 @@ class ForNode(Node): self.nodelist_empty = nodelist_empty def __repr__(self): - reversed_text = ' reversed' if self.is_reversed else '' - return '<%s: for %s in %s, tail_len: %d%s>' % ( + reversed_text = " reversed" if self.is_reversed else "" + return "<%s: for %s in %s, tail_len: %d%s>" % ( self.__class__.__name__, - ', '.join(self.loopvars), + ", ".join(self.loopvars), self.sequence, len(self.nodelist_loop), reversed_text, ) def render(self, context): - if 'forloop' in context: - parentloop = context['forloop'] + if "forloop" in context: + parentloop = context["forloop"] else: parentloop = {} with context.push(): values = self.sequence.resolve(context, ignore_failures=True) if values is None: values = [] - if not hasattr(values, '__len__'): + if not hasattr(values, "__len__"): values = list(values) len_values = len(values) if len_values < 1: @@ -181,17 +201,17 @@ class ForNode(Node): unpack = num_loopvars > 1 # Create a forloop value in the context. We'll update counters on each # iteration just below. - loop_dict = context['forloop'] = {'parentloop': parentloop} + loop_dict = context["forloop"] = {"parentloop": parentloop} for i, item in enumerate(values): # Shortcuts for current loop iteration number. - loop_dict['counter0'] = i - loop_dict['counter'] = i + 1 + loop_dict["counter0"] = i + loop_dict["counter"] = i + 1 # Reverse counter iteration numbers. - loop_dict['revcounter'] = len_values - i - loop_dict['revcounter0'] = len_values - i - 1 + loop_dict["revcounter"] = len_values - i + loop_dict["revcounter0"] = len_values - i - 1 # Boolean values designating first and last times through loop. 
- loop_dict['first'] = (i == 0) - loop_dict['last'] = (i == len_values - 1) + loop_dict["first"] = i == 0 + loop_dict["last"] = i == len_values - 1 pop_context = False if unpack: @@ -204,8 +224,9 @@ class ForNode(Node): # Check loop variable count before unpacking if num_loopvars != len_item: raise ValueError( - "Need {} values to unpack in for loop; got {}. " - .format(num_loopvars, len_item), + "Need {} values to unpack in for loop; got {}. ".format( + num_loopvars, len_item + ), ) unpacked_vars = dict(zip(self.loopvars, item)) pop_context = True @@ -221,11 +242,11 @@ class ForNode(Node): # the context ending up in an inconsistent state when other # tags (e.g., include and with) push data to context. context.pop() - return mark_safe(''.join(nodelist)) + return mark_safe("".join(nodelist)) class IfChangedNode(Node): - child_nodelists = ('nodelist_true', 'nodelist_false') + child_nodelists = ("nodelist_true", "nodelist_false") def __init__(self, nodelist_true, nodelist_false, *varlist): self.nodelist_true, self.nodelist_false = nodelist_true, nodelist_false @@ -240,7 +261,9 @@ class IfChangedNode(Node): if self._varlist: # Consider multiple parameters. This behaves like an OR evaluation # of the multiple variables. - compare_to = [var.resolve(context, ignore_failures=True) for var in self._varlist] + compare_to = [ + var.resolve(context, ignore_failures=True) for var in self._varlist + ] else: # The "{% ifchanged %}" syntax (without any variables) compares # the rendered output. @@ -252,28 +275,27 @@ class IfChangedNode(Node): return nodelist_true_output or self.nodelist_true.render(context) elif self.nodelist_false: return self.nodelist_false.render(context) - return '' + return "" def _get_context_stack_frame(self, context): # The Context object behaves like a stack where each template tag can create a new scope. # Find the place where to store the state to detect changes. 
- if 'forloop' in context: + if "forloop" in context: # Ifchanged is bound to the local for loop. # When there is a loop-in-loop, the state is bound to the inner loop, # so it resets when the outer loop continues. - return context['forloop'] + return context["forloop"] else: # Using ifchanged outside loops. Effectively this is a no-op because the state is associated with 'self'. return context.render_context class IfNode(Node): - def __init__(self, conditions_nodelists): self.conditions_nodelists = conditions_nodelists def __repr__(self): - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def __iter__(self): for _, nodelist in self.conditions_nodelists: @@ -286,18 +308,18 @@ class IfNode(Node): def render(self, context): for condition, nodelist in self.conditions_nodelists: - if condition is not None: # if / elif clause + if condition is not None: # if / elif clause try: match = condition.eval(context) except VariableDoesNotExist: match = None - else: # else clause + else: # else clause match = True if match: return nodelist.render(context) - return '' + return "" class LoremNode(Node): @@ -309,16 +331,16 @@ class LoremNode(Node): count = int(self.count.resolve(context)) except (ValueError, TypeError): count = 1 - if self.method == 'w': + if self.method == "w": return words(count, common=self.common) else: paras = paragraphs(count, common=self.common) - if self.method == 'p': - paras = ['<p>%s</p>' % p for p in paras] - return '\n\n'.join(paras) + if self.method == "p": + paras = ["<p>%s</p>" % p for p in paras] + return "\n\n".join(paras) -GroupedResult = namedtuple('GroupedResult', ['grouper', 'list']) +GroupedResult = namedtuple("GroupedResult", ["grouper", "list"]) class RegroupNode(Node): @@ -337,22 +359,23 @@ class RegroupNode(Node): if obj_list is None: # target variable wasn't found in context; fail silently. 
context[self.var_name] = [] - return '' + return "" # List of dictionaries in the format: # {'grouper': 'key', 'list': [list of contents]}. context[self.var_name] = [ GroupedResult(grouper=key, list=list(val)) - for key, val in - groupby(obj_list, lambda obj: self.resolve_expression(obj, context)) + for key, val in groupby( + obj_list, lambda obj: self.resolve_expression(obj, context) + ) ] - return '' + return "" class LoadNode(Node): child_nodelists = () def render(self, context): - return '' + return "" class NowNode(Node): @@ -366,7 +389,7 @@ class NowNode(Node): if self.asvar: context[self.asvar] = formatted - return '' + return "" else: return formatted @@ -377,7 +400,7 @@ class ResetCycleNode(Node): def render(self, context): self.node.reset(context) - return '' + return "" class SpacelessNode(Node): @@ -386,26 +409,27 @@ class SpacelessNode(Node): def render(self, context): from django.utils.html import strip_spaces_between_tags + return strip_spaces_between_tags(self.nodelist.render(context).strip()) class TemplateTagNode(Node): mapping = { - 'openblock': BLOCK_TAG_START, - 'closeblock': BLOCK_TAG_END, - 'openvariable': VARIABLE_TAG_START, - 'closevariable': VARIABLE_TAG_END, - 'openbrace': SINGLE_BRACE_START, - 'closebrace': SINGLE_BRACE_END, - 'opencomment': COMMENT_TAG_START, - 'closecomment': COMMENT_TAG_END, + "openblock": BLOCK_TAG_START, + "closeblock": BLOCK_TAG_END, + "openvariable": VARIABLE_TAG_START, + "closevariable": VARIABLE_TAG_END, + "openbrace": SINGLE_BRACE_START, + "closebrace": SINGLE_BRACE_END, + "opencomment": COMMENT_TAG_START, + "closecomment": COMMENT_TAG_END, } def __init__(self, tagtype): self.tagtype = tagtype def render(self, context): - return self.mapping.get(self.tagtype, '') + return self.mapping.get(self.tagtype, "") class URLNode(Node): @@ -428,6 +452,7 @@ class URLNode(Node): def render(self, context): from django.urls import NoReverseMatch, reverse + args = [arg.resolve(context) for arg in self.args] kwargs = {k: 
v.resolve(context) for k, v in self.kwargs.items()} view_name = self.view_name.resolve(context) @@ -440,7 +465,7 @@ class URLNode(Node): current_app = None # Try to look up the URL. If it fails, raise NoReverseMatch unless the # {% url ... as var %} construct is used, in which case return nothing. - url = '' + url = "" try: url = reverse(view_name, args=args, kwargs=kwargs, current_app=current_app) except NoReverseMatch: @@ -449,7 +474,7 @@ class URLNode(Node): if self.asvar: context[self.asvar] = url - return '' + return "" else: if context.autoescape: url = conditional_escape(url) @@ -477,7 +502,7 @@ class WidthRatioNode(Node): max_value = self.max_expr.resolve(context) max_width = int(self.max_width.resolve(context)) except VariableDoesNotExist: - return '' + return "" except (ValueError, TypeError): raise TemplateSyntaxError("widthratio final argument must be a number") try: @@ -486,13 +511,13 @@ class WidthRatioNode(Node): ratio = (value / max_value) * max_width result = str(round(ratio)) except ZeroDivisionError: - result = '0' + result = "0" except (ValueError, TypeError, OverflowError): - result = '' + result = "" if self.asvar: context[self.asvar] = result - return '' + return "" else: return result @@ -507,7 +532,7 @@ class WithNode(Node): self.extra_context[name] = var def __repr__(self): - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def render(self, context): values = {key: val.resolve(context) for key, val in self.extra_context.items()} @@ -525,11 +550,11 @@ def autoescape(parser, token): if len(args) != 2: raise TemplateSyntaxError("'autoescape' tag requires exactly one argument.") arg = args[1] - if arg not in ('on', 'off'): + if arg not in ("on", "off"): raise TemplateSyntaxError("'autoescape' argument should be 'on' or 'off'") - nodelist = parser.parse(('endautoescape',)) + nodelist = parser.parse(("endautoescape",)) parser.delete_first_token() - return AutoEscapeControlNode((arg == 'on'), nodelist) + return 
AutoEscapeControlNode((arg == "on"), nodelist) @register.tag @@ -537,7 +562,7 @@ def comment(parser, token): """ Ignore everything between ``{% comment %}`` and ``{% endcomment %}``. """ - parser.skip_past('endcomment') + parser.skip_past("endcomment") return CommentNode() @@ -595,8 +620,10 @@ def cycle(parser, token): if len(args) == 2: # {% cycle foo %} case. name = args[1] - if not hasattr(parser, '_named_cycle_nodes'): - raise TemplateSyntaxError("No named cycles in template. '%s' is not defined" % name) + if not hasattr(parser, "_named_cycle_nodes"): + raise TemplateSyntaxError( + "No named cycles in template. '%s' is not defined" % name + ) if name not in parser._named_cycle_nodes: raise TemplateSyntaxError("Named cycle '%s' does not exist" % name) return parser._named_cycle_nodes[name] @@ -607,7 +634,10 @@ def cycle(parser, token): # {% cycle ... as foo [silent] %} case. if args[-3] == "as": if args[-1] != "silent": - raise TemplateSyntaxError("Only 'silent' flag is allowed after cycle's name, not '%s'." % args[-1]) + raise TemplateSyntaxError( + "Only 'silent' flag is allowed after cycle's name, not '%s'." + % args[-1] + ) as_form = True silent = True args = args[:-1] @@ -619,7 +649,7 @@ def cycle(parser, token): name = args[-1] values = [parser.compile_filter(arg) for arg in args[1:-2]] node = CycleNode(values, name, silent=silent) - if not hasattr(parser, '_named_cycle_nodes'): + if not hasattr(parser, "_named_cycle_nodes"): parser._named_cycle_nodes = {} parser._named_cycle_nodes[name] = node else: @@ -649,7 +679,7 @@ def debug(parser, token): return DebugNode() -@register.tag('filter') +@register.tag("filter") def do_filter(parser, token): """ Filter the contents of the block through variable filters. 
@@ -671,10 +701,13 @@ def do_filter(parser, token): _, rest = token.contents.split(None, 1) filter_expr = parser.compile_filter("var|%s" % (rest)) for func, unused in filter_expr.filters: - filter_name = getattr(func, '_filter_name', None) - if filter_name in ('escape', 'safe'): - raise TemplateSyntaxError('"filter %s" is not permitted. Use the "autoescape" tag instead.' % filter_name) - nodelist = parser.parse(('endfilter',)) + filter_name = getattr(func, "_filter_name", None) + if filter_name in ("escape", "safe"): + raise TemplateSyntaxError( + '"filter %s" is not permitted. Use the "autoescape" tag instead.' + % filter_name + ) + nodelist = parser.parse(("endfilter",)) parser.delete_first_token() return FilterNode(filter_expr, nodelist) @@ -722,13 +755,13 @@ def firstof(parser, token): if not bits: raise TemplateSyntaxError("'firstof' statement requires at least one argument") - if len(bits) >= 2 and bits[-2] == 'as': + if len(bits) >= 2 and bits[-2] == "as": asvar = bits[-1] bits = bits[:-2] return FirstOfNode([parser.compile_filter(bit) for bit in bits], asvar) -@register.tag('for') +@register.tag("for") def do_for(parser, token): """ Loop over each item in an array. 
@@ -797,14 +830,16 @@ def do_for(parser, token): "'for' statements should have at least four words: %s" % token.contents ) - is_reversed = bits[-1] == 'reversed' + is_reversed = bits[-1] == "reversed" in_index = -3 if is_reversed else -2 - if bits[in_index] != 'in': - raise TemplateSyntaxError("'for' statements should use the format" - " 'for x in y': %s" % token.contents) + if bits[in_index] != "in": + raise TemplateSyntaxError( + "'for' statements should use the format" + " 'for x in y': %s" % token.contents + ) - invalid_chars = frozenset((' ', '"', "'", FILTER_SEPARATOR)) - loopvars = re.split(r' *, *', ' '.join(bits[1:in_index])) + invalid_chars = frozenset((" ", '"', "'", FILTER_SEPARATOR)) + loopvars = re.split(r" *, *", " ".join(bits[1:in_index])) for var in loopvars: if not var or not invalid_chars.isdisjoint(var): raise TemplateSyntaxError( @@ -812,10 +847,15 @@ def do_for(parser, token): ) sequence = parser.compile_filter(bits[in_index + 1]) - nodelist_loop = parser.parse(('empty', 'endfor',)) + nodelist_loop = parser.parse( + ( + "empty", + "endfor", + ) + ) token = parser.next_token() - if token.contents == 'empty': - nodelist_empty = parser.parse(('endfor',)) + if token.contents == "empty": + nodelist_empty = parser.parse(("endfor",)) parser.delete_first_token() else: nodelist_empty = None @@ -845,7 +885,7 @@ class TemplateIfParser(IfParser): return TemplateLiteral(self.template_parser.compile_filter(value), value) -@register.tag('if') +@register.tag("if") def do_if(parser, token): """ Evaluate a variable, and if that variable is "true" (i.e., exists, is not @@ -907,27 +947,31 @@ def do_if(parser, token): # {% if ... %} bits = token.split_contents()[1:] condition = TemplateIfParser(parser, bits).parse() - nodelist = parser.parse(('elif', 'else', 'endif')) + nodelist = parser.parse(("elif", "else", "endif")) conditions_nodelists = [(condition, nodelist)] token = parser.next_token() # {% elif ... 
%} (repeatable) - while token.contents.startswith('elif'): + while token.contents.startswith("elif"): bits = token.split_contents()[1:] condition = TemplateIfParser(parser, bits).parse() - nodelist = parser.parse(('elif', 'else', 'endif')) + nodelist = parser.parse(("elif", "else", "endif")) conditions_nodelists.append((condition, nodelist)) token = parser.next_token() # {% else %} (optional) - if token.contents == 'else': - nodelist = parser.parse(('endif',)) + if token.contents == "else": + nodelist = parser.parse(("endif",)) conditions_nodelists.append((None, nodelist)) token = parser.next_token() # {% endif %} - if token.contents != 'endif': - raise TemplateSyntaxError('Malformed template tag at line {}: "{}"'.format(token.lineno, token.contents)) + if token.contents != "endif": + raise TemplateSyntaxError( + 'Malformed template tag at line {}: "{}"'.format( + token.lineno, token.contents + ) + ) return IfNode(conditions_nodelists) @@ -963,10 +1007,10 @@ def ifchanged(parser, token): {% endfor %} """ bits = token.split_contents() - nodelist_true = parser.parse(('else', 'endifchanged')) + nodelist_true = parser.parse(("else", "endifchanged")) token = parser.next_token() - if token.contents == 'else': - nodelist_false = parser.parse(('endifchanged',)) + if token.contents == "else": + nodelist_false = parser.parse(("endifchanged",)) parser.delete_first_token() else: nodelist_false = NodeList() @@ -979,8 +1023,10 @@ def find_library(parser, name): return parser.libraries[name] except KeyError: raise TemplateSyntaxError( - "'%s' is not a registered tag library. Must be one of:\n%s" % ( - name, "\n".join(sorted(parser.libraries)), + "'%s' is not a registered tag library. 
Must be one of:\n%s" + % ( + name, + "\n".join(sorted(parser.libraries)), ), ) @@ -1000,8 +1046,10 @@ def load_from_library(library, label, names): subset.filters[name] = library.filters[name] if found is False: raise TemplateSyntaxError( - "'%s' is not a valid tag or filter in tag library '%s'" % ( - name, label, + "'%s' is not a valid tag or filter in tag library '%s'" + % ( + name, + label, ), ) return subset @@ -1066,19 +1114,19 @@ def lorem(parser, token): bits = list(token.split_contents()) tagname = bits[0] # Random bit - common = bits[-1] != 'random' + common = bits[-1] != "random" if not common: bits.pop() # Method bit - if bits[-1] in ('w', 'p', 'b'): + if bits[-1] in ("w", "p", "b"): method = bits.pop() else: - method = 'b' + method = "b" # Count bit if len(bits) > 1: count = bits.pop() else: - count = '1' + count = "1" count = parser.compile_filter(count) if len(bits) != 1: raise TemplateSyntaxError("Incorrect format for %r tag" % tagname) @@ -1099,7 +1147,7 @@ def now(parser, token): """ bits = token.split_contents() asvar = None - if len(bits) == 4 and bits[-2] == 'as': + if len(bits) == 4 and bits[-2] == "as": asvar = bits[-1] bits = bits[:-2] if len(bits) != 2: @@ -1159,12 +1207,10 @@ def regroup(parser, token): if len(bits) != 6: raise TemplateSyntaxError("'regroup' tag takes five arguments") target = parser.compile_filter(bits[1]) - if bits[2] != 'by': + if bits[2] != "by": raise TemplateSyntaxError("second argument to 'regroup' tag must be 'by'") - if bits[4] != 'as': - raise TemplateSyntaxError( - "next-to-last argument to 'regroup' tag must be 'as'" - ) + if bits[4] != "as": + raise TemplateSyntaxError("next-to-last argument to 'regroup' tag must be 'as'") var_name = bits[5] # RegroupNode will take each item in 'target', put it in the context under # 'var_name', evaluate 'var_name'.'expression' in the current context, and @@ -1172,9 +1218,9 @@ def regroup(parser, token): # save the final result in the context under 'var_name', thus clearing the 
# temporary values. This hack is necessary because the template engine # doesn't provide a context-aware equivalent of Python's getattr. - expression = parser.compile_filter(var_name + - VARIABLE_ATTRIBUTE_SEPARATOR + - bits[3]) + expression = parser.compile_filter( + var_name + VARIABLE_ATTRIBUTE_SEPARATOR + bits[3] + ) return RegroupNode(target, expression, var_name) @@ -1230,7 +1276,7 @@ def spaceless(parser, token): </strong> {% endspaceless %} """ - nodelist = parser.parse(('endspaceless',)) + nodelist = parser.parse(("endspaceless",)) parser.delete_first_token() return SpacelessNode(nodelist) @@ -1264,9 +1310,10 @@ def templatetag(parser, token): raise TemplateSyntaxError("'templatetag' statement takes one argument") tag = bits[1] if tag not in TemplateTagNode.mapping: - raise TemplateSyntaxError("Invalid templatetag argument: '%s'." - " Must be one of: %s" % - (tag, list(TemplateTagNode.mapping))) + raise TemplateSyntaxError( + "Invalid templatetag argument: '%s'." + " Must be one of: %s" % (tag, list(TemplateTagNode.mapping)) + ) return TemplateTagNode(tag) @@ -1314,13 +1361,15 @@ def url(parser, token): """ bits = token.split_contents() if len(bits) < 2: - raise TemplateSyntaxError("'%s' takes at least one argument, a URL pattern name." % bits[0]) + raise TemplateSyntaxError( + "'%s' takes at least one argument, a URL pattern name." % bits[0] + ) viewname = parser.compile_filter(bits[1]) args = [] kwargs = {} asvar = None bits = bits[2:] - if len(bits) >= 2 and bits[-2] == 'as': + if len(bits) >= 2 and bits[-2] == "as": asvar = bits[-1] bits = bits[:-2] @@ -1355,7 +1404,7 @@ def verbatim(parser, token): ... 
{% endverbatim myblock %} """ - nodelist = parser.parse(('endverbatim',)) + nodelist = parser.parse(("endverbatim",)) parser.delete_first_token() return VerbatimNode(nodelist.render(Context())) @@ -1387,18 +1436,22 @@ def widthratio(parser, token): asvar = None elif len(bits) == 6: tag, this_value_expr, max_value_expr, max_width, as_, asvar = bits - if as_ != 'as': - raise TemplateSyntaxError("Invalid syntax in widthratio tag. Expecting 'as' keyword") + if as_ != "as": + raise TemplateSyntaxError( + "Invalid syntax in widthratio tag. Expecting 'as' keyword" + ) else: raise TemplateSyntaxError("widthratio takes at least three arguments") - return WidthRatioNode(parser.compile_filter(this_value_expr), - parser.compile_filter(max_value_expr), - parser.compile_filter(max_width), - asvar=asvar) + return WidthRatioNode( + parser.compile_filter(this_value_expr), + parser.compile_filter(max_value_expr), + parser.compile_filter(max_width), + asvar=asvar, + ) -@register.tag('with') +@register.tag("with") def do_with(parser, token): """ Add one or more values to the context (inside of this block) for caching @@ -1427,8 +1480,9 @@ def do_with(parser, token): "%r expected at least one variable assignment" % bits[0] ) if remaining_bits: - raise TemplateSyntaxError("%r received an invalid token: %r" % - (bits[0], remaining_bits[0])) - nodelist = parser.parse(('endwith',)) + raise TemplateSyntaxError( + "%r received an invalid token: %r" % (bits[0], remaining_bits[0]) + ) + nodelist = parser.parse(("endwith",)) parser.delete_first_token() return WithNode(None, None, nodelist, extra_context=extra_context) diff --git a/django/template/engine.py b/django/template/engine.py index 91e503f709..9e6f1e97da 100644 --- a/django/template/engine.py +++ b/django/template/engine.py @@ -12,28 +12,39 @@ from .library import import_library class Engine: default_builtins = [ - 'django.template.defaulttags', - 'django.template.defaultfilters', - 'django.template.loader_tags', + 
"django.template.defaulttags", + "django.template.defaultfilters", + "django.template.loader_tags", ] - def __init__(self, dirs=None, app_dirs=False, context_processors=None, - debug=False, loaders=None, string_if_invalid='', - file_charset='utf-8', libraries=None, builtins=None, autoescape=True): + def __init__( + self, + dirs=None, + app_dirs=False, + context_processors=None, + debug=False, + loaders=None, + string_if_invalid="", + file_charset="utf-8", + libraries=None, + builtins=None, + autoescape=True, + ): if dirs is None: dirs = [] if context_processors is None: context_processors = [] if loaders is None: - loaders = ['django.template.loaders.filesystem.Loader'] + loaders = ["django.template.loaders.filesystem.Loader"] if app_dirs: - loaders += ['django.template.loaders.app_directories.Loader'] + loaders += ["django.template.loaders.app_directories.Loader"] if not debug: - loaders = [('django.template.loaders.cached.Loader', loaders)] + loaders = [("django.template.loaders.cached.Loader", loaders)] else: if app_dirs: raise ImproperlyConfigured( - "app_dirs must not be set when loaders is defined.") + "app_dirs must not be set when loaders is defined." 
+ ) if libraries is None: libraries = {} if builtins is None: @@ -54,21 +65,21 @@ class Engine: def __repr__(self): return ( - '<%s:%s app_dirs=%s%s debug=%s loaders=%s string_if_invalid=%s ' - 'file_charset=%s%s%s autoescape=%s>' + "<%s:%s app_dirs=%s%s debug=%s loaders=%s string_if_invalid=%s " + "file_charset=%s%s%s autoescape=%s>" ) % ( self.__class__.__qualname__, - '' if not self.dirs else ' dirs=%s' % repr(self.dirs), + "" if not self.dirs else " dirs=%s" % repr(self.dirs), self.app_dirs, - '' + "" if not self.context_processors - else ' context_processors=%s' % repr(self.context_processors), + else " context_processors=%s" % repr(self.context_processors), self.debug, repr(self.loaders), repr(self.string_if_invalid), repr(self.file_charset), - '' if not self.libraries else ' libraries=%s' % repr(self.libraries), - '' if not self.builtins else ' builtins=%s' % repr(self.builtins), + "" if not self.libraries else " libraries=%s" % repr(self.libraries), + "" if not self.builtins else " builtins=%s" % repr(self.builtins), repr(self.autoescape), ) @@ -93,10 +104,11 @@ class Engine: # local imports are required to avoid import loops. from django.template import engines from django.template.backends.django import DjangoTemplates + for engine in engines.all(): if isinstance(engine, DjangoTemplates): return engine.engine - raise ImproperlyConfigured('No DjangoTemplates backend is configured.') + raise ImproperlyConfigured("No DjangoTemplates backend is configured.") @cached_property def template_context_processors(self): @@ -136,7 +148,8 @@ class Engine: return loader_class(self, *args) else: raise ImproperlyConfigured( - "Invalid value in template loaders configuration: %r" % loader) + "Invalid value in template loaders configuration: %r" % loader + ) def find_template(self, name, dirs=None, skip=None): tried = [] @@ -161,7 +174,7 @@ class Engine: handling template inheritance recursively. 
""" template, origin = self.find_template(template_name) - if not hasattr(template, 'render'): + if not hasattr(template, "render"): # template needs to be compiled template = Template(template, origin, template_name, engine=self) return template @@ -197,4 +210,4 @@ class Engine: not_found.append(exc.args[0]) continue # If we get here, none of the templates could be loaded - raise TemplateDoesNotExist(', '.join(not_found)) + raise TemplateDoesNotExist(", ".join(not_found)) diff --git a/django/template/exceptions.py b/django/template/exceptions.py index 97edc9eba4..2a9c92f779 100644 --- a/django/template/exceptions.py +++ b/django/template/exceptions.py @@ -24,6 +24,7 @@ class TemplateDoesNotExist(Exception): encapsulate multiple exceptions when loading templates from multiple engines. """ + def __init__(self, msg, tried=None, backend=None, chain=None): self.backend = backend if tried is None: @@ -39,4 +40,5 @@ class TemplateSyntaxError(Exception): """ The exception used for syntax errors during parsing or rendering. """ + pass diff --git a/django/template/library.py b/django/template/library.py index 06ea5f1ad8..fbec9484a1 100644 --- a/django/template/library.py +++ b/django/template/library.py @@ -20,6 +20,7 @@ class Library: The filter, simple_tag, and inclusion_tag methods provide a convenient way to register callables as tags. 
""" + def __init__(self): self.filters = {} self.tags = {} @@ -36,6 +37,7 @@ class Library: # @register.tag('somename') or @register.tag(name='somename') def dec(func): return self.tag(name, func) + return dec elif name is not None and compile_function is not None: # register.tag('somename', somefunc) @@ -43,8 +45,8 @@ class Library: return compile_function else: raise ValueError( - "Unsupported arguments to Library.tag: (%r, %r)" % - (name, compile_function), + "Unsupported arguments to Library.tag: (%r, %r)" + % (name, compile_function), ) def tag_function(self, func): @@ -63,6 +65,7 @@ class Library: # @register.filter() def dec(func): return self.filter_function(func, **flags) + return dec elif name is not None and filter_func is None: if callable(name): @@ -72,11 +75,12 @@ class Library: # @register.filter('somename') or @register.filter(name='somename') def dec(func): return self.filter(name, func, **flags) + return dec elif name is not None and filter_func is not None: # register.filter('somename', somefunc) self.filters[name] = filter_func - for attr in ('expects_localtime', 'is_safe', 'needs_autoescape'): + for attr in ("expects_localtime", "is_safe", "needs_autoescape"): if attr in flags: value = flags[attr] # set the flag on the filter for FilterExpression.resolve @@ -88,8 +92,8 @@ class Library: return filter_func else: raise ValueError( - "Unsupported arguments to Library.filter: (%r, %r)" % - (name, filter_func), + "Unsupported arguments to Library.filter: (%r, %r)" + % (name, filter_func), ) def filter_function(self, func, **flags): @@ -103,22 +107,40 @@ class Library: def hello(*args, **kwargs): return 'world' """ + def dec(func): - params, varargs, varkw, defaults, kwonly, kwonly_defaults, _ = getfullargspec(unwrap(func)) - function_name = (name or func.__name__) + ( + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + _, + ) = getfullargspec(unwrap(func)) + function_name = name or func.__name__ @functools.wraps(func) def 
compile_func(parser, token): bits = token.split_contents()[1:] target_var = None - if len(bits) >= 2 and bits[-2] == 'as': + if len(bits) >= 2 and bits[-2] == "as": target_var = bits[-1] bits = bits[:-2] args, kwargs = parse_bits( - parser, bits, params, varargs, varkw, defaults, - kwonly, kwonly_defaults, takes_context, function_name, + parser, + bits, + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + takes_context, + function_name, ) return SimpleNode(func, takes_context, args, kwargs, target_var) + self.tag(function_name, compile_func) return func @@ -140,22 +162,45 @@ class Library: choices = poll.choice_set.all() return {'choices': choices} """ + def dec(func): - params, varargs, varkw, defaults, kwonly, kwonly_defaults, _ = getfullargspec(unwrap(func)) + ( + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + _, + ) = getfullargspec(unwrap(func)) function_name = name or func.__name__ @functools.wraps(func) def compile_func(parser, token): bits = token.split_contents()[1:] args, kwargs = parse_bits( - parser, bits, params, varargs, varkw, defaults, - kwonly, kwonly_defaults, takes_context, function_name, + parser, + bits, + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + takes_context, + function_name, ) return InclusionNode( - func, takes_context, args, kwargs, filename, + func, + takes_context, + args, + kwargs, + filename, ) + self.tag(function_name, compile_func) return func + return dec @@ -165,6 +210,7 @@ class TagHelperNode(Node): Manages the positional and keyword arguments to be passed to the decorated function. 
""" + def __init__(self, func, takes_context, args, kwargs): self.func = func self.takes_context = takes_context @@ -191,14 +237,13 @@ class SimpleNode(TagHelperNode): output = self.func(*resolved_args, **resolved_kwargs) if self.target_var is not None: context[self.target_var] = output - return '' + return "" if context.autoescape: output = conditional_escape(output) return output class InclusionNode(TagHelperNode): - def __init__(self, func, takes_context, args, kwargs, filename): super().__init__(func, takes_context, args, kwargs) self.filename = filename @@ -216,7 +261,7 @@ class InclusionNode(TagHelperNode): if t is None: if isinstance(self.filename, Template): t = self.filename - elif isinstance(getattr(self.filename, 'template', None), Template): + elif isinstance(getattr(self.filename, "template", None), Template): t = self.filename.template elif not isinstance(self.filename, str) and is_iterable(self.filename): t = context.template.engine.select_template(self.filename) @@ -227,32 +272,42 @@ class InclusionNode(TagHelperNode): # Copy across the CSRF token, if present, because inclusion tags are # often used for forms, and we need instructions for using CSRF # protection to be as simple as possible. - csrf_token = context.get('csrf_token') + csrf_token = context.get("csrf_token") if csrf_token is not None: - new_context['csrf_token'] = csrf_token + new_context["csrf_token"] = csrf_token return t.render(new_context) -def parse_bits(parser, bits, params, varargs, varkw, defaults, - kwonly, kwonly_defaults, takes_context, name): +def parse_bits( + parser, + bits, + params, + varargs, + varkw, + defaults, + kwonly, + kwonly_defaults, + takes_context, + name, +): """ Parse bits for template tag helpers simple_tag and inclusion_tag, in particular by detecting syntax errors and by extracting positional and keyword arguments. 
""" if takes_context: - if params and params[0] == 'context': + if params and params[0] == "context": params = params[1:] else: raise TemplateSyntaxError( "'%s' is decorated with takes_context=True so it must " - "have a first argument of 'context'" % name) + "have a first argument of 'context'" % name + ) args = [] kwargs = {} unhandled_params = list(params) unhandled_kwargs = [ - kwarg for kwarg in kwonly - if not kwonly_defaults or kwarg not in kwonly_defaults + kwarg for kwarg in kwonly if not kwonly_defaults or kwarg not in kwonly_defaults ] for bit in bits: # First we try to extract a potential kwarg from the bit @@ -263,13 +318,14 @@ def parse_bits(parser, bits, params, varargs, varkw, defaults, if param not in params and param not in kwonly and varkw is None: # An unexpected keyword argument was supplied raise TemplateSyntaxError( - "'%s' received unexpected keyword argument '%s'" % - (name, param)) + "'%s' received unexpected keyword argument '%s'" % (name, param) + ) elif param in kwargs: # The keyword argument has already been supplied once raise TemplateSyntaxError( - "'%s' received multiple values for keyword argument '%s'" % - (name, param)) + "'%s' received multiple values for keyword argument '%s'" + % (name, param) + ) else: # All good, record the keyword argument kwargs[str(param)] = value @@ -284,7 +340,8 @@ def parse_bits(parser, bits, params, varargs, varkw, defaults, if kwargs: raise TemplateSyntaxError( "'%s' received some positional argument(s) after some " - "keyword argument(s)" % name) + "keyword argument(s)" % name + ) else: # Record the positional argument args.append(parser.compile_filter(bit)) @@ -294,17 +351,18 @@ def parse_bits(parser, bits, params, varargs, varkw, defaults, except IndexError: if varargs is None: raise TemplateSyntaxError( - "'%s' received too many positional arguments" % - name) + "'%s' received too many positional arguments" % name + ) if defaults is not None: # Consider the last n params handled, where n is the # 
number of defaults. - unhandled_params = unhandled_params[:-len(defaults)] + unhandled_params = unhandled_params[: -len(defaults)] if unhandled_params or unhandled_kwargs: # Some positional arguments were not supplied raise TemplateSyntaxError( - "'%s' did not receive value(s) for the argument(s): %s" % - (name, ", ".join("'%s'" % p for p in unhandled_params + unhandled_kwargs))) + "'%s' did not receive value(s) for the argument(s): %s" + % (name, ", ".join("'%s'" % p for p in unhandled_params + unhandled_kwargs)) + ) return args, kwargs diff --git a/django/template/loader.py b/django/template/loader.py index 2492aee760..9b108f3456 100644 --- a/django/template/loader.py +++ b/django/template/loader.py @@ -29,9 +29,9 @@ def select_template(template_name_list, using=None): """ if isinstance(template_name_list, str): raise TypeError( - 'select_template() takes an iterable of template names but got a ' - 'string: %r. Use get_template() if you want to load a single ' - 'template by name.' % template_name_list + "select_template() takes an iterable of template names but got a " + "string: %r. Use get_template() if you want to load a single " + "template by name." 
% template_name_list ) chain = [] @@ -44,7 +44,7 @@ def select_template(template_name_list, using=None): chain.append(e) if template_name_list: - raise TemplateDoesNotExist(', '.join(template_name_list), chain=chain) + raise TemplateDoesNotExist(", ".join(template_name_list), chain=chain) else: raise TemplateDoesNotExist("No template names provided") diff --git a/django/template/loader_tags.py b/django/template/loader_tags.py index 37cefaf9c7..bf4d4f0d74 100644 --- a/django/template/loader_tags.py +++ b/django/template/loader_tags.py @@ -3,14 +3,12 @@ from collections import defaultdict from django.utils.safestring import mark_safe -from .base import ( - Node, Template, TemplateSyntaxError, TextNode, Variable, token_kwargs, -) +from .base import Node, Template, TemplateSyntaxError, TextNode, Variable, token_kwargs from .library import Library register = Library() -BLOCK_CONTEXT_KEY = 'block_context' +BLOCK_CONTEXT_KEY = "block_context" class BlockContext: @@ -19,7 +17,7 @@ class BlockContext: self.blocks = defaultdict(list) def __repr__(self): - return f'<{self.__class__.__qualname__}: blocks={self.blocks!r}>' + return f"<{self.__class__.__qualname__}: blocks={self.blocks!r}>" def add_blocks(self, blocks): for name, block in blocks.items(): @@ -52,7 +50,7 @@ class BlockNode(Node): block_context = context.render_context.get(BLOCK_CONTEXT_KEY) with context.push(): if block_context is None: - context['block'] = self + context["block"] = self result = self.nodelist.render(context) else: push = block = block_context.pop(self.name) @@ -61,28 +59,30 @@ class BlockNode(Node): # Create new block so we can store context without thread-safety issues. 
block = type(self)(block.name, block.nodelist) block.context = context - context['block'] = block + context["block"] = block result = block.nodelist.render(context) if push is not None: block_context.push(self.name, push) return result def super(self): - if not hasattr(self, 'context'): + if not hasattr(self, "context"): raise TemplateSyntaxError( "'%s' object has no attribute 'context'. Did you use " "{{ block.super }} in a base template?" % self.__class__.__name__ ) render_context = self.context.render_context - if (BLOCK_CONTEXT_KEY in render_context and - render_context[BLOCK_CONTEXT_KEY].get_block(self.name) is not None): + if ( + BLOCK_CONTEXT_KEY in render_context + and render_context[BLOCK_CONTEXT_KEY].get_block(self.name) is not None + ): return mark_safe(self.render(self.context)) - return '' + return "" class ExtendsNode(Node): must_be_first = True - context_key = 'extends_context' + context_key = "extends_context" def __init__(self, nodelist, parent_name, template_dirs=None): self.nodelist = nodelist @@ -91,7 +91,7 @@ class ExtendsNode(Node): self.blocks = {n.name: n for n in nodelist.get_nodes_by_type(BlockNode)} def __repr__(self): - return '<%s: extends %s>' % (self.__class__.__name__, self.parent_name.token) + return "<%s: extends %s>" % (self.__class__.__name__, self.parent_name.token) def find_template(self, template_name, context): """ @@ -101,10 +101,12 @@ class ExtendsNode(Node): without extending the same template twice. """ history = context.render_context.setdefault( - self.context_key, [self.origin], + self.context_key, + [self.origin], ) template, origin = context.template.engine.find_template( - template_name, skip=history, + template_name, + skip=history, ) history.append(origin) return template @@ -113,15 +115,15 @@ class ExtendsNode(Node): parent = self.parent_name.resolve(context) if not parent: error_msg = "Invalid template name in 'extends' tag: %r." 
% parent - if self.parent_name.filters or\ - isinstance(self.parent_name.var, Variable): - error_msg += " Got this from the '%s' variable." %\ - self.parent_name.token + if self.parent_name.filters or isinstance(self.parent_name.var, Variable): + error_msg += ( + " Got this from the '%s' variable." % self.parent_name.token + ) raise TemplateSyntaxError(error_msg) if isinstance(parent, Template): # parent is a django.template.Template return parent - if isinstance(getattr(parent, 'template', None), Template): + if isinstance(getattr(parent, "template", None), Template): # parent is a django.template.backends.django.Template return parent.template return self.find_template(parent, context) @@ -142,8 +144,10 @@ class ExtendsNode(Node): # The ExtendsNode has to be the first non-text node. if not isinstance(node, TextNode): if not isinstance(node, ExtendsNode): - blocks = {n.name: n for n in - compiled_parent.nodelist.get_nodes_by_type(BlockNode)} + blocks = { + n.name: n + for n in compiled_parent.nodelist.get_nodes_by_type(BlockNode) + } block_context.add_blocks(blocks) break @@ -154,16 +158,18 @@ class ExtendsNode(Node): class IncludeNode(Node): - context_key = '__include_context' + context_key = "__include_context" - def __init__(self, template, *args, extra_context=None, isolated_context=False, **kwargs): + def __init__( + self, template, *args, extra_context=None, isolated_context=False, **kwargs + ): self.template = template self.extra_context = extra_context or {} self.isolated_context = isolated_context super().__init__(*args, **kwargs) def __repr__(self): - return f'<{self.__class__.__qualname__}: template={self.template!r}>' + return f"<{self.__class__.__qualname__}: template={self.template!r}>" def render(self, context): """ @@ -173,14 +179,16 @@ class IncludeNode(Node): """ template = self.template.resolve(context) # Does this quack like a Template? 
- if not callable(getattr(template, 'render', None)): + if not callable(getattr(template, "render", None)): # If not, try the cache and select_template(). template_name = template or () if isinstance(template_name, str): - template_name = (construct_relative_path( - self.origin.template_name, - template_name, - ),) + template_name = ( + construct_relative_path( + self.origin.template_name, + template_name, + ), + ) else: template_name = tuple(template_name) cache = context.render_context.dicts[0].setdefault(self, {}) @@ -189,11 +197,10 @@ class IncludeNode(Node): template = context.template.engine.select_template(template_name) cache[template_name] = template # Use the base.Template of a backends.django.Template. - elif hasattr(template, 'template'): + elif hasattr(template, "template"): template = template.template values = { - name: var.resolve(context) - for name, var in self.extra_context.items() + name: var.resolve(context) for name, var in self.extra_context.items() } if self.isolated_context: return template.render(context.new(values)) @@ -201,7 +208,7 @@ class IncludeNode(Node): return template.render(context) -@register.tag('block') +@register.tag("block") def do_block(parser, token): """ Define a block that can be overridden by child templates. @@ -215,17 +222,19 @@ def do_block(parser, token): # check for duplication. try: if block_name in parser.__loaded_blocks: - raise TemplateSyntaxError("'%s' tag with name '%s' appears more than once" % (bits[0], block_name)) + raise TemplateSyntaxError( + "'%s' tag with name '%s' appears more than once" % (bits[0], block_name) + ) parser.__loaded_blocks.append(block_name) except AttributeError: # parser.__loaded_blocks isn't a list yet parser.__loaded_blocks = [block_name] - nodelist = parser.parse(('endblock',)) + nodelist = parser.parse(("endblock",)) # This check is kept for backwards-compatibility. See #3100. 
endblock = parser.next_token() - acceptable_endblocks = ('endblock', 'endblock %s' % block_name) + acceptable_endblocks = ("endblock", "endblock %s" % block_name) if endblock.contents not in acceptable_endblocks: - parser.invalid_block_tag(endblock, 'endblock', acceptable_endblocks) + parser.invalid_block_tag(endblock, "endblock", acceptable_endblocks) return BlockNode(block_name, nodelist) @@ -235,37 +244,36 @@ def construct_relative_path(current_template_name, relative_name): Convert a relative path (starting with './' or '../') to the full template name based on the current_template_name. """ - new_name = relative_name.strip('\'"') - if not new_name.startswith(('./', '../')): + new_name = relative_name.strip("'\"") + if not new_name.startswith(("./", "../")): # relative_name is a variable or a literal that doesn't contain a # relative path. return relative_name new_name = posixpath.normpath( posixpath.join( - posixpath.dirname(current_template_name.lstrip('/')), + posixpath.dirname(current_template_name.lstrip("/")), new_name, ) ) - if new_name.startswith('../'): + if new_name.startswith("../"): raise TemplateSyntaxError( "The relative path '%s' points outside the file hierarchy that " "template '%s' is in." % (relative_name, current_template_name) ) - if current_template_name.lstrip('/') == new_name: + if current_template_name.lstrip("/") == new_name: raise TemplateSyntaxError( "The relative path '%s' was translated to template name '%s', the " "same template in which the tag appears." % (relative_name, current_template_name) ) has_quotes = ( - relative_name.startswith(('"', "'")) and - relative_name[0] == relative_name[-1] + relative_name.startswith(('"', "'")) and relative_name[0] == relative_name[-1] ) return f'"{new_name}"' if has_quotes else new_name -@register.tag('extends') +@register.tag("extends") def do_extends(parser, token): """ Signal that this template extends a parent template. 
@@ -283,11 +291,13 @@ def do_extends(parser, token): parent_name = parser.compile_filter(bits[1]) nodelist = parser.parse() if nodelist.get_nodes_by_type(ExtendsNode): - raise TemplateSyntaxError("'%s' cannot appear more than once in the same template" % bits[0]) + raise TemplateSyntaxError( + "'%s' cannot appear more than once in the same template" % bits[0] + ) return ExtendsNode(nodelist, parent_name) -@register.tag('include') +@register.tag("include") def do_include(parser, token): """ Load a template and render it with the current context. You can pass @@ -316,22 +326,26 @@ def do_include(parser, token): option = remaining_bits.pop(0) if option in options: raise TemplateSyntaxError( - 'The %r option was specified more than once.' % option + "The %r option was specified more than once." % option ) - if option == 'with': + if option == "with": value = token_kwargs(remaining_bits, parser, support_legacy=False) if not value: raise TemplateSyntaxError( '"with" in %r tag needs at least one keyword argument.' % bits[0] ) - elif option == 'only': + elif option == "only": value = True else: - raise TemplateSyntaxError('Unknown argument for %r tag: %r.' % - (bits[0], option)) + raise TemplateSyntaxError( + "Unknown argument for %r tag: %r." 
% (bits[0], option) + ) options[option] = value - isolated_context = options.get('only', False) - namemap = options.get('with', {}) + isolated_context = options.get("only", False) + namemap = options.get("with", {}) bits[1] = construct_relative_path(parser.origin.template_name, bits[1]) - return IncludeNode(parser.compile_filter(bits[1]), extra_context=namemap, - isolated_context=isolated_context) + return IncludeNode( + parser.compile_filter(bits[1]), + extra_context=namemap, + isolated_context=isolated_context, + ) diff --git a/django/template/loaders/app_directories.py b/django/template/loaders/app_directories.py index c9a8adf49c..0bd7dc8e25 100644 --- a/django/template/loaders/app_directories.py +++ b/django/template/loaders/app_directories.py @@ -9,6 +9,5 @@ from .filesystem import Loader as FilesystemLoader class Loader(FilesystemLoader): - def get_dirs(self): - return get_app_template_dirs('templates') + return get_app_template_dirs("templates") diff --git a/django/template/loaders/base.py b/django/template/loaders/base.py index b77ea9eca2..9168c26c54 100644 --- a/django/template/loaders/base.py +++ b/django/template/loaders/base.py @@ -2,7 +2,6 @@ from django.template import Template, TemplateDoesNotExist class Loader: - def __init__(self, engine): self.engine = engine @@ -17,17 +16,20 @@ class Loader: for origin in self.get_template_sources(template_name): if skip is not None and origin in skip: - tried.append((origin, 'Skipped to avoid recursion')) + tried.append((origin, "Skipped to avoid recursion")) continue try: contents = self.get_contents(origin) except TemplateDoesNotExist: - tried.append((origin, 'Source does not exist')) + tried.append((origin, "Source does not exist")) continue else: return Template( - contents, origin, origin.template_name, self.engine, + contents, + origin, + origin.template_name, + self.engine, ) raise TemplateDoesNotExist(template_name, tried=tried) @@ -38,7 +40,7 @@ class Loader: template name. 
""" raise NotImplementedError( - 'subclasses of Loader must provide a get_template_sources() method' + "subclasses of Loader must provide a get_template_sources() method" ) def reset(self): diff --git a/django/template/loaders/cached.py b/django/template/loaders/cached.py index bb47682d7f..4f40953831 100644 --- a/django/template/loaders/cached.py +++ b/django/template/loaders/cached.py @@ -12,7 +12,6 @@ from .base import Loader as BaseLoader class Loader(BaseLoader): - def __init__(self, engine, loaders): self.get_template_cache = {} self.loaders = engine.get_template_loaders(loaders) @@ -57,7 +56,9 @@ class Loader(BaseLoader): try: template = super().get_template(template_name, skip) except TemplateDoesNotExist as e: - self.get_template_cache[key] = copy_exception(e) if self.engine.debug else TemplateDoesNotExist + self.get_template_cache[key] = ( + copy_exception(e) if self.engine.debug else TemplateDoesNotExist + ) raise else: self.get_template_cache[key] = template @@ -80,17 +81,19 @@ class Loader(BaseLoader): y -> a -> a z -> a -> a """ - skip_prefix = '' + skip_prefix = "" if skip: - matching = [origin.name for origin in skip if origin.template_name == template_name] + matching = [ + origin.name for origin in skip if origin.template_name == template_name + ] if matching: skip_prefix = self.generate_hash(matching) - return '-'.join(s for s in (str(template_name), skip_prefix) if s) + return "-".join(s for s in (str(template_name), skip_prefix) if s) def generate_hash(self, values): - return hashlib.sha1('|'.join(values).encode()).hexdigest() + return hashlib.sha1("|".join(values).encode()).hexdigest() def reset(self): "Empty the template cache." 
diff --git a/django/template/loaders/filesystem.py b/django/template/loaders/filesystem.py index 2e49e3d6b3..a2474a3fad 100644 --- a/django/template/loaders/filesystem.py +++ b/django/template/loaders/filesystem.py @@ -10,7 +10,6 @@ from .base import Loader as BaseLoader class Loader(BaseLoader): - def __init__(self, engine, dirs=None): super().__init__(engine) self.dirs = dirs diff --git a/django/template/loaders/locmem.py b/django/template/loaders/locmem.py index 25d7672719..432de62b6c 100644 --- a/django/template/loaders/locmem.py +++ b/django/template/loaders/locmem.py @@ -8,7 +8,6 @@ from .base import Loader as BaseLoader class Loader(BaseLoader): - def __init__(self, engine, templates_dict): self.templates_dict = templates_dict super().__init__(engine) diff --git a/django/template/response.py b/django/template/response.py index 63d2f4a577..c38b95e9de 100644 --- a/django/template/response.py +++ b/django/template/response.py @@ -8,10 +8,18 @@ class ContentNotRenderedError(Exception): class SimpleTemplateResponse(HttpResponse): - rendering_attrs = ['template_name', 'context_data', '_post_render_callbacks'] + rendering_attrs = ["template_name", "context_data", "_post_render_callbacks"] - def __init__(self, template, context=None, content_type=None, status=None, - charset=None, using=None, headers=None): + def __init__( + self, + template, + context=None, + content_type=None, + status=None, + charset=None, + using=None, + headers=None, + ): # It would seem obvious to call these next two members 'template' and # 'context', but those names are reserved as part of the test Client # API. To avoid the name collision, we use different names. @@ -33,7 +41,7 @@ class SimpleTemplateResponse(HttpResponse): # content argument doesn't make sense here because it will be replaced # with rendered template so we always pass empty string in order to # prevent errors and provide shorter signature. 
- super().__init__('', content_type, status, charset=charset, headers=headers) + super().__init__("", content_type, status, charset=charset, headers=headers) # _is_rendered tracks whether the template and context has been baked # into a final response. @@ -50,7 +58,7 @@ class SimpleTemplateResponse(HttpResponse): obj_dict = self.__dict__.copy() if not self._is_rendered: raise ContentNotRenderedError( - 'The response content must be rendered before it can be pickled.' + "The response content must be rendered before it can be pickled." ) for attr in self.rendering_attrs: if attr in obj_dict: @@ -117,7 +125,7 @@ class SimpleTemplateResponse(HttpResponse): def __iter__(self): if not self._is_rendered: raise ContentNotRenderedError( - 'The response content must be rendered before it can be iterated over.' + "The response content must be rendered before it can be iterated over." ) return super().__iter__() @@ -125,7 +133,7 @@ class SimpleTemplateResponse(HttpResponse): def content(self): if not self._is_rendered: raise ContentNotRenderedError( - 'The response content must be rendered before it can be accessed.' + "The response content must be rendered before it can be accessed." 
) return super().content @@ -137,9 +145,20 @@ class SimpleTemplateResponse(HttpResponse): class TemplateResponse(SimpleTemplateResponse): - rendering_attrs = SimpleTemplateResponse.rendering_attrs + ['_request'] + rendering_attrs = SimpleTemplateResponse.rendering_attrs + ["_request"] - def __init__(self, request, template, context=None, content_type=None, - status=None, charset=None, using=None, headers=None): - super().__init__(template, context, content_type, status, charset, using, headers=headers) + def __init__( + self, + request, + template, + context=None, + content_type=None, + status=None, + charset=None, + using=None, + headers=None, + ): + super().__init__( + template, context, content_type, status, charset, using, headers=headers + ) self._request = request diff --git a/django/template/smartif.py b/django/template/smartif.py index 96a1af8db0..5b15a5a476 100644 --- a/django/template/smartif.py +++ b/django/template/smartif.py @@ -13,6 +13,7 @@ class TokenBase: Base class for operators and literals, mainly for debugging and for throwing syntax errors. """ + id = None # node/token type name value = None # used by literals first = second = None # used by tree nodes @@ -45,6 +46,7 @@ def infix(bp, func): Create an infix operator, given a binding power and a function that evaluates the node. """ + class Operator(TokenBase): lbp = bp @@ -70,6 +72,7 @@ def prefix(bp, func): Create a prefix operator, given a binding power and a function that evaluates the node. """ + class Operator(TokenBase): lbp = bp @@ -91,19 +94,19 @@ def prefix(bp, func): # We defer variable evaluation to the lambda to ensure that terms are # lazily evaluated using Python's boolean parsing logic. 
OPERATORS = { - 'or': infix(6, lambda context, x, y: x.eval(context) or y.eval(context)), - 'and': infix(7, lambda context, x, y: x.eval(context) and y.eval(context)), - 'not': prefix(8, lambda context, x: not x.eval(context)), - 'in': infix(9, lambda context, x, y: x.eval(context) in y.eval(context)), - 'not in': infix(9, lambda context, x, y: x.eval(context) not in y.eval(context)), - 'is': infix(10, lambda context, x, y: x.eval(context) is y.eval(context)), - 'is not': infix(10, lambda context, x, y: x.eval(context) is not y.eval(context)), - '==': infix(10, lambda context, x, y: x.eval(context) == y.eval(context)), - '!=': infix(10, lambda context, x, y: x.eval(context) != y.eval(context)), - '>': infix(10, lambda context, x, y: x.eval(context) > y.eval(context)), - '>=': infix(10, lambda context, x, y: x.eval(context) >= y.eval(context)), - '<': infix(10, lambda context, x, y: x.eval(context) < y.eval(context)), - '<=': infix(10, lambda context, x, y: x.eval(context) <= y.eval(context)), + "or": infix(6, lambda context, x, y: x.eval(context) or y.eval(context)), + "and": infix(7, lambda context, x, y: x.eval(context) and y.eval(context)), + "not": prefix(8, lambda context, x: not x.eval(context)), + "in": infix(9, lambda context, x, y: x.eval(context) in y.eval(context)), + "not in": infix(9, lambda context, x, y: x.eval(context) not in y.eval(context)), + "is": infix(10, lambda context, x, y: x.eval(context) is y.eval(context)), + "is not": infix(10, lambda context, x, y: x.eval(context) is not y.eval(context)), + "==": infix(10, lambda context, x, y: x.eval(context) == y.eval(context)), + "!=": infix(10, lambda context, x, y: x.eval(context) != y.eval(context)), + ">": infix(10, lambda context, x, y: x.eval(context) > y.eval(context)), + ">=": infix(10, lambda context, x, y: x.eval(context) >= y.eval(context)), + "<": infix(10, lambda context, x, y: x.eval(context) < y.eval(context)), + "<=": infix(10, lambda context, x, y: x.eval(context) <= 
y.eval(context)), } # Assign 'id' to each: @@ -115,6 +118,7 @@ class Literal(TokenBase): """ A basic self-resolvable object similar to a Django template variable. """ + # IfParser uses Literal in create_var, but TemplateIfParser overrides # create_var so that a proper implementation that actually resolves # variables, filters etc. is used. @@ -190,8 +194,9 @@ class IfParser: retval = self.expression() # Check that we have exhausted all the tokens if self.current_token is not EndToken: - raise self.error_class("Unused '%s' at end of if expression." % - self.current_token.display()) + raise self.error_class( + "Unused '%s' at end of if expression." % self.current_token.display() + ) return retval def expression(self, rbp=0): diff --git a/django/template/utils.py b/django/template/utils.py index ad7baba2f3..2b118f900e 100644 --- a/django/template/utils.py +++ b/django/template/utils.py @@ -33,31 +33,34 @@ class EngineHandler: try: # This will raise an exception if 'BACKEND' doesn't exist or # isn't a string containing at least one dot. - default_name = tpl['BACKEND'].rsplit('.', 2)[-2] + default_name = tpl["BACKEND"].rsplit(".", 2)[-2] except Exception: - invalid_backend = tpl.get('BACKEND', '<not defined>') + invalid_backend = tpl.get("BACKEND", "<not defined>") raise ImproperlyConfigured( "Invalid BACKEND for a template engine: {}. Check " - "your TEMPLATES setting.".format(invalid_backend)) + "your TEMPLATES setting.".format(invalid_backend) + ) tpl = { - 'NAME': default_name, - 'DIRS': [], - 'APP_DIRS': False, - 'OPTIONS': {}, + "NAME": default_name, + "DIRS": [], + "APP_DIRS": False, + "OPTIONS": {}, **tpl, } - templates[tpl['NAME']] = tpl - backend_names.append(tpl['NAME']) + templates[tpl["NAME"]] = tpl + backend_names.append(tpl["NAME"]) counts = Counter(backend_names) duplicates = [alias for alias, count in counts.most_common() if count > 1] if duplicates: raise ImproperlyConfigured( "Template engine aliases aren't unique, duplicates: {}. 
" - "Set a unique NAME for each engine in settings.TEMPLATES." - .format(", ".join(duplicates))) + "Set a unique NAME for each engine in settings.TEMPLATES.".format( + ", ".join(duplicates) + ) + ) return templates @@ -70,13 +73,14 @@ class EngineHandler: except KeyError: raise InvalidTemplateEngineError( "Could not find config for '{}' " - "in settings.TEMPLATES".format(alias)) + "in settings.TEMPLATES".format(alias) + ) # If importing or initializing the backend raises an exception, # self._engines[alias] isn't set and this code may get executed # again, so we must preserve the original params. See #24265. params = params.copy() - backend = params.pop('BACKEND') + backend = params.pop("BACKEND") engine_cls = import_string(backend) engine = engine_cls(params) diff --git a/django/templatetags/cache.py b/django/templatetags/cache.py index 9e402a1206..4a60cd8afd 100644 --- a/django/templatetags/cache.py +++ b/django/templatetags/cache.py @@ -1,8 +1,6 @@ from django.core.cache import InvalidCacheBackendError, caches from django.core.cache.utils import make_template_fragment_key -from django.template import ( - Library, Node, TemplateSyntaxError, VariableDoesNotExist, -) +from django.template import Library, Node, TemplateSyntaxError, VariableDoesNotExist register = Library() @@ -19,26 +17,34 @@ class CacheNode(Node): try: expire_time = self.expire_time_var.resolve(context) except VariableDoesNotExist: - raise TemplateSyntaxError('"cache" tag got an unknown variable: %r' % self.expire_time_var.var) + raise TemplateSyntaxError( + '"cache" tag got an unknown variable: %r' % self.expire_time_var.var + ) if expire_time is not None: try: expire_time = int(expire_time) except (ValueError, TypeError): - raise TemplateSyntaxError('"cache" tag got a non-integer timeout value: %r' % expire_time) + raise TemplateSyntaxError( + '"cache" tag got a non-integer timeout value: %r' % expire_time + ) if self.cache_name: try: cache_name = self.cache_name.resolve(context) except 
VariableDoesNotExist: - raise TemplateSyntaxError('"cache" tag got an unknown variable: %r' % self.cache_name.var) + raise TemplateSyntaxError( + '"cache" tag got an unknown variable: %r' % self.cache_name.var + ) try: fragment_cache = caches[cache_name] except InvalidCacheBackendError: - raise TemplateSyntaxError('Invalid cache name specified for cache tag: %r' % cache_name) + raise TemplateSyntaxError( + "Invalid cache name specified for cache tag: %r" % cache_name + ) else: try: - fragment_cache = caches['template_fragments'] + fragment_cache = caches["template_fragments"] except InvalidCacheBackendError: - fragment_cache = caches['default'] + fragment_cache = caches["default"] vary_on = [var.resolve(context) for var in self.vary_on] cache_key = make_template_fragment_key(self.fragment_name, vary_on) @@ -49,7 +55,7 @@ class CacheNode(Node): return value -@register.tag('cache') +@register.tag("cache") def do_cache(parser, token): """ This will cache the contents of a template fragment for a given amount @@ -75,18 +81,19 @@ def do_cache(parser, token): Each unique set of arguments will result in a unique cache entry. """ - nodelist = parser.parse(('endcache',)) + nodelist = parser.parse(("endcache",)) parser.delete_first_token() tokens = token.split_contents() if len(tokens) < 3: raise TemplateSyntaxError("'%r' tag requires at least 2 arguments." % tokens[0]) - if len(tokens) > 3 and tokens[-1].startswith('using='): - cache_name = parser.compile_filter(tokens[-1][len('using='):]) + if len(tokens) > 3 and tokens[-1].startswith("using="): + cache_name = parser.compile_filter(tokens[-1][len("using=") :]) tokens = tokens[:-1] else: cache_name = None return CacheNode( - nodelist, parser.compile_filter(tokens[1]), + nodelist, + parser.compile_filter(tokens[1]), tokens[2], # fragment_name can't be a variable. 
[parser.compile_filter(t) for t in tokens[3:]], cache_name, diff --git a/django/templatetags/i18n.py b/django/templatetags/i18n.py index efe5f28941..9724efbd91 100644 --- a/django/templatetags/i18n.py +++ b/django/templatetags/i18n.py @@ -15,8 +15,10 @@ class GetAvailableLanguagesNode(Node): self.variable = variable def render(self, context): - context[self.variable] = [(k, translation.gettext(v)) for k, v in settings.LANGUAGES] - return '' + context[self.variable] = [ + (k, translation.gettext(v)) for k, v in settings.LANGUAGES + ] + return "" class GetLanguageInfoNode(Node): @@ -27,7 +29,7 @@ class GetLanguageInfoNode(Node): def render(self, context): lang_code = self.lang_code.resolve(context) context[self.variable] = translation.get_language_info(lang_code) - return '' + return "" class GetLanguageInfoListNode(Node): @@ -46,7 +48,7 @@ class GetLanguageInfoListNode(Node): def render(self, context): langs = self.languages.resolve(context) context[self.variable] = [self.get_language_info(lang) for lang in langs] - return '' + return "" class GetCurrentLanguageNode(Node): @@ -55,7 +57,7 @@ class GetCurrentLanguageNode(Node): def render(self, context): context[self.variable] = translation.get_language() - return '' + return "" class GetCurrentLanguageBidiNode(Node): @@ -64,47 +66,54 @@ class GetCurrentLanguageBidiNode(Node): def render(self, context): context[self.variable] = translation.get_language_bidi() - return '' + return "" class TranslateNode(Node): child_nodelists = () - def __init__(self, filter_expression, noop, asvar=None, - message_context=None): + def __init__(self, filter_expression, noop, asvar=None, message_context=None): self.noop = noop self.asvar = asvar self.message_context = message_context self.filter_expression = filter_expression if isinstance(self.filter_expression.var, str): self.filter_expression.is_var = True - self.filter_expression.var = Variable("'%s'" % - self.filter_expression.var) + self.filter_expression.var = Variable("'%s'" % 
self.filter_expression.var) def render(self, context): self.filter_expression.var.translate = not self.noop if self.message_context: - self.filter_expression.var.message_context = ( - self.message_context.resolve(context)) + self.filter_expression.var.message_context = self.message_context.resolve( + context + ) output = self.filter_expression.resolve(context) value = render_value_in_context(output, context) # Restore percent signs. Percent signs in template text are doubled # so they are not interpreted as string format flags. is_safe = isinstance(value, SafeData) - value = value.replace('%%', '%') + value = value.replace("%%", "%") value = mark_safe(value) if is_safe else value if self.asvar: context[self.asvar] = value - return '' + return "" else: return value class BlockTranslateNode(Node): - - def __init__(self, extra_context, singular, plural=None, countervar=None, - counter=None, message_context=None, trimmed=False, asvar=None, - tag_name='blocktranslate'): + def __init__( + self, + extra_context, + singular, + plural=None, + countervar=None, + counter=None, + message_context=None, + trimmed=False, + asvar=None, + tag_name="blocktranslate", + ): self.extra_context = extra_context self.singular = singular self.plural = plural @@ -117,9 +126,9 @@ class BlockTranslateNode(Node): def __repr__(self): return ( - f'<{self.__class__.__qualname__}: ' - f'extra_context={self.extra_context!r} ' - f'singular={self.singular!r} plural={self.plural!r}>' + f"<{self.__class__.__qualname__}: " + f"extra_context={self.extra_context!r} " + f"singular={self.singular!r} plural={self.plural!r}>" ) def render_token_list(self, tokens): @@ -127,11 +136,11 @@ class BlockTranslateNode(Node): vars = [] for token in tokens: if token.token_type == TokenType.TEXT: - result.append(token.contents.replace('%', '%%')) + result.append(token.contents.replace("%", "%%")) elif token.token_type == TokenType.VAR: - result.append('%%(%s)s' % token.contents) + result.append("%%(%s)s" % 
token.contents) vars.append(token.contents) - msg = ''.join(result) + msg = "".join(result) if self.trimmed: msg = translation.trim_whitespace(msg) return msg, vars @@ -143,7 +152,9 @@ class BlockTranslateNode(Node): message_context = None # Update() works like a push(), so corresponding context.pop() is at # the end of function - context.update({var: val.resolve(context) for var, val in self.extra_context.items()}) + context.update( + {var: val.resolve(context) for var, val in self.extra_context.items()} + ) singular, vars = self.render_token_list(self.singular) if self.plural and self.countervar and self.counter: count = self.counter.resolve(context) @@ -155,8 +166,7 @@ class BlockTranslateNode(Node): context[self.countervar] = count plural, plural_vars = self.render_token_list(self.plural) if message_context: - result = translation.npgettext(message_context, singular, - plural, count) + result = translation.npgettext(message_context, singular, plural, count) else: result = translation.ngettext(singular, plural, count) vars.extend(plural_vars) @@ -171,7 +181,7 @@ class BlockTranslateNode(Node): if key in context: val = context[key] else: - val = default_value % key if '%s' in default_value else default_value + val = default_value % key if "%s" in default_value else default_value return render_value_in_context(val, context) data = {v: render_value(v) for v in vars} @@ -182,14 +192,14 @@ class BlockTranslateNode(Node): if nested: # Either string is malformed, or it's a bug raise TemplateSyntaxError( - '%r is unable to format string returned by gettext: %r ' - 'using %r' % (self.tag_name, result, data) + "%r is unable to format string returned by gettext: %r " + "using %r" % (self.tag_name, result, data) ) with translation.override(None): result = self.render(context, nested=True) if self.asvar: context[self.asvar] = result - return '' + return "" else: return result @@ -221,8 +231,10 @@ def do_get_available_languages(parser, token): """ # token.split_contents() 
isn't useful here because this tag doesn't accept variable as arguments args = token.contents.split() - if len(args) != 3 or args[1] != 'as': - raise TemplateSyntaxError("'get_available_languages' requires 'as variable' (got %r)" % args) + if len(args) != 3 or args[1] != "as": + raise TemplateSyntaxError( + "'get_available_languages' requires 'as variable' (got %r)" % args + ) return GetAvailableLanguagesNode(args[2]) @@ -242,8 +254,10 @@ def do_get_language_info(parser, token): {{ l.bidi|yesno:"bi-directional,uni-directional" }} """ args = token.split_contents() - if len(args) != 5 or args[1] != 'for' or args[3] != 'as': - raise TemplateSyntaxError("'%s' requires 'for string as variable' (got %r)" % (args[0], args[1:])) + if len(args) != 5 or args[1] != "for" or args[3] != "as": + raise TemplateSyntaxError( + "'%s' requires 'for string as variable' (got %r)" % (args[0], args[1:]) + ) return GetLanguageInfoNode(parser.compile_filter(args[2]), args[4]) @@ -267,30 +281,32 @@ def do_get_language_info_list(parser, token): {% endfor %} """ args = token.split_contents() - if len(args) != 5 or args[1] != 'for' or args[3] != 'as': - raise TemplateSyntaxError("'%s' requires 'for sequence as variable' (got %r)" % (args[0], args[1:])) + if len(args) != 5 or args[1] != "for" or args[3] != "as": + raise TemplateSyntaxError( + "'%s' requires 'for sequence as variable' (got %r)" % (args[0], args[1:]) + ) return GetLanguageInfoListNode(parser.compile_filter(args[2]), args[4]) @register.filter def language_name(lang_code): - return translation.get_language_info(lang_code)['name'] + return translation.get_language_info(lang_code)["name"] @register.filter def language_name_translated(lang_code): - english_name = translation.get_language_info(lang_code)['name'] + english_name = translation.get_language_info(lang_code)["name"] return translation.gettext(english_name) @register.filter def language_name_local(lang_code): - return translation.get_language_info(lang_code)['name_local'] + 
return translation.get_language_info(lang_code)["name_local"] @register.filter def language_bidi(lang_code): - return translation.get_language_info(lang_code)['bidi'] + return translation.get_language_info(lang_code)["bidi"] @register.tag("get_current_language") @@ -307,8 +323,10 @@ def do_get_current_language(parser, token): """ # token.split_contents() isn't useful here because this tag doesn't accept variable as arguments args = token.contents.split() - if len(args) != 3 or args[1] != 'as': - raise TemplateSyntaxError("'get_current_language' requires 'as variable' (got %r)" % args) + if len(args) != 3 or args[1] != "as": + raise TemplateSyntaxError( + "'get_current_language' requires 'as variable' (got %r)" % args + ) return GetCurrentLanguageNode(args[2]) @@ -327,8 +345,10 @@ def do_get_current_language_bidi(parser, token): """ # token.split_contents() isn't useful here because this tag doesn't accept variable as arguments args = token.contents.split() - if len(args) != 3 or args[1] != 'as': - raise TemplateSyntaxError("'get_current_language_bidi' requires 'as variable' (got %r)" % args) + if len(args) != 3 or args[1] != "as": + raise TemplateSyntaxError( + "'get_current_language_bidi' requires 'as variable' (got %r)" % args + ) return GetCurrentLanguageBidiNode(args[2]) @@ -384,7 +404,7 @@ def do_translate(parser, token): asvar = None message_context = None seen = set() - invalid_context = {'as', 'noop'} + invalid_context = {"as", "noop"} while remaining: option = remaining.pop(0) @@ -392,21 +412,23 @@ def do_translate(parser, token): raise TemplateSyntaxError( "The '%s' option was specified more than once." % option, ) - elif option == 'noop': + elif option == "noop": noop = True - elif option == 'context': + elif option == "context": try: value = remaining.pop(0) except IndexError: raise TemplateSyntaxError( - "No argument provided to the '%s' tag for the context option." % bits[0] + "No argument provided to the '%s' tag for the context option." 
+ % bits[0] ) if value in invalid_context: raise TemplateSyntaxError( - "Invalid argument '%s' provided to the '%s' tag for the context option" % (value, bits[0]), + "Invalid argument '%s' provided to the '%s' tag for the context option" + % (value, bits[0]), ) message_context = parser.compile_filter(value) - elif option == 'as': + elif option == "as": try: value = remaining.pop(0) except IndexError: @@ -417,8 +439,10 @@ def do_translate(parser, token): else: raise TemplateSyntaxError( "Unknown argument for '%s' tag: '%s'. The only options " - "available are 'noop', 'context' \"xxx\", and 'as VAR'." % ( - bits[0], option, + "available are 'noop', 'context' \"xxx\", and 'as VAR'." + % ( + bits[0], + option, ) ) seen.add(option) @@ -478,19 +502,21 @@ def do_block_translate(parser, token): option = remaining_bits.pop(0) if option in options: raise TemplateSyntaxError( - 'The %r option was specified more than once.' % option + "The %r option was specified more than once." % option ) - if option == 'with': + if option == "with": value = token_kwargs(remaining_bits, parser, support_legacy=True) if not value: raise TemplateSyntaxError( '"with" in %r tag needs at least one keyword argument.' % bits[0] ) - elif option == 'count': + elif option == "count": value = token_kwargs(remaining_bits, parser, support_legacy=True) if len(value) != 1: - raise TemplateSyntaxError('"count" in %r tag expected exactly ' - 'one keyword argument.' % bits[0]) + raise TemplateSyntaxError( + '"count" in %r tag expected exactly ' + "one keyword argument." % bits[0] + ) elif option == "context": try: value = remaining_bits.pop(0) @@ -506,23 +532,25 @@ def do_block_translate(parser, token): value = remaining_bits.pop(0) except IndexError: raise TemplateSyntaxError( - "No argument provided to the '%s' tag for the asvar option." % bits[0] + "No argument provided to the '%s' tag for the asvar option." + % bits[0] ) asvar = value else: - raise TemplateSyntaxError('Unknown argument for %r tag: %r.' 
% - (bits[0], option)) + raise TemplateSyntaxError( + "Unknown argument for %r tag: %r." % (bits[0], option) + ) options[option] = value - if 'count' in options: - countervar, counter = next(iter(options['count'].items())) + if "count" in options: + countervar, counter = next(iter(options["count"].items())) else: countervar, counter = None, None - if 'context' in options: - message_context = options['context'] + if "context" in options: + message_context = options["context"] else: message_context = None - extra_context = options.get('with', {}) + extra_context = options.get("with", {}) trimmed = options.get("trimmed", False) @@ -535,21 +563,34 @@ def do_block_translate(parser, token): else: break if countervar and counter: - if token.contents.strip() != 'plural': - raise TemplateSyntaxError("%r doesn't allow other block tags inside it" % bits[0]) + if token.contents.strip() != "plural": + raise TemplateSyntaxError( + "%r doesn't allow other block tags inside it" % bits[0] + ) while parser.tokens: token = parser.next_token() if token.token_type in (TokenType.VAR, TokenType.TEXT): plural.append(token) else: break - end_tag_name = 'end%s' % bits[0] + end_tag_name = "end%s" % bits[0] if token.contents.strip() != end_tag_name: - raise TemplateSyntaxError("%r doesn't allow other block tags (seen %r) inside it" % (bits[0], token.contents)) + raise TemplateSyntaxError( + "%r doesn't allow other block tags (seen %r) inside it" + % (bits[0], token.contents) + ) - return BlockTranslateNode(extra_context, singular, plural, countervar, - counter, message_context, trimmed=trimmed, - asvar=asvar, tag_name=bits[0]) + return BlockTranslateNode( + extra_context, + singular, + plural, + countervar, + counter, + message_context, + trimmed=trimmed, + asvar=asvar, + tag_name=bits[0], + ) @register.tag @@ -567,6 +608,6 @@ def language(parser, token): if len(bits) != 2: raise TemplateSyntaxError("'%s' takes one argument (language)" % bits[0]) language = parser.compile_filter(bits[1]) - 
nodelist = parser.parse(('endlanguage',)) + nodelist = parser.parse(("endlanguage",)) parser.delete_first_token() return LanguageNode(nodelist, language) diff --git a/django/templatetags/l10n.py b/django/templatetags/l10n.py index 9212753286..45e9a6ac50 100644 --- a/django/templatetags/l10n.py +++ b/django/templatetags/l10n.py @@ -28,7 +28,7 @@ class LocalizeNode(Node): self.use_l10n = use_l10n def __repr__(self): - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def render(self, context): old_setting = context.use_l10n @@ -38,7 +38,7 @@ class LocalizeNode(Node): return output -@register.tag('localize') +@register.tag("localize") def localize_tag(parser, token): """ Force or prevents localization of values, regardless of the value of @@ -54,10 +54,10 @@ def localize_tag(parser, token): bits = list(token.split_contents()) if len(bits) == 1: use_l10n = True - elif len(bits) > 2 or bits[1] not in ('on', 'off'): + elif len(bits) > 2 or bits[1] not in ("on", "off"): raise TemplateSyntaxError("%r argument should be 'on' or 'off'" % bits[0]) else: - use_l10n = bits[1] == 'on' - nodelist = parser.parse(('endlocalize',)) + use_l10n = bits[1] == "on" + nodelist = parser.parse(("endlocalize",)) parser.delete_first_token() return LocalizeNode(nodelist, use_l10n) diff --git a/django/templatetags/static.py b/django/templatetags/static.py index 4d1a05ee03..7a5147a0dd 100644 --- a/django/templatetags/static.py +++ b/django/templatetags/static.py @@ -9,14 +9,14 @@ register = template.Library() class PrefixNode(template.Node): - def __repr__(self): return "<PrefixNode for %r>" % self.name def __init__(self, varname=None, name=None): if name is None: raise template.TemplateSyntaxError( - "Prefix nodes must be given a name to return.") + "Prefix nodes must be given a name to return." 
+ ) self.varname = varname self.name = name @@ -27,9 +27,10 @@ class PrefixNode(template.Node): """ # token.split_contents() isn't useful here because tags using this method don't accept variable as arguments tokens = token.contents.split() - if len(tokens) > 1 and tokens[1] != 'as': + if len(tokens) > 1 and tokens[1] != "as": raise template.TemplateSyntaxError( - "First argument in '%s' must be 'as'" % tokens[0]) + "First argument in '%s' must be 'as'" % tokens[0] + ) if len(tokens) > 1: varname = tokens[2] else: @@ -41,9 +42,9 @@ class PrefixNode(template.Node): try: from django.conf import settings except ImportError: - prefix = '' + prefix = "" else: - prefix = iri_to_uri(getattr(settings, name, '')) + prefix = iri_to_uri(getattr(settings, name, "")) return prefix def render(self, context): @@ -51,7 +52,7 @@ class PrefixNode(template.Node): if self.varname is None: return prefix context[self.varname] = prefix - return '' + return "" @register.tag @@ -96,13 +97,14 @@ class StaticNode(template.Node): def __init__(self, varname=None, path=None): if path is None: raise template.TemplateSyntaxError( - "Static template nodes must be given a path to return.") + "Static template nodes must be given a path to return." 
+ ) self.path = path self.varname = varname def __repr__(self): return ( - f'{self.__class__.__name__}(varname={self.varname!r}, path={self.path!r})' + f"{self.__class__.__name__}(varname={self.varname!r}, path={self.path!r})" ) def url(self, context): @@ -116,12 +118,13 @@ class StaticNode(template.Node): if self.varname is None: return url context[self.varname] = url - return '' + return "" @classmethod def handle_simple(cls, path): - if apps.is_installed('django.contrib.staticfiles'): + if apps.is_installed("django.contrib.staticfiles"): from django.contrib.staticfiles.storage import staticfiles_storage + return staticfiles_storage.url(path) else: return urljoin(PrefixNode.handle_simple("STATIC_URL"), quote(path)) @@ -135,11 +138,12 @@ class StaticNode(template.Node): if len(bits) < 2: raise template.TemplateSyntaxError( - "'%s' takes at least one argument (path to file)" % bits[0]) + "'%s' takes at least one argument (path to file)" % bits[0] + ) path = parser.compile_filter(bits[1]) - if len(bits) >= 2 and bits[-2] == 'as': + if len(bits) >= 2 and bits[-2] == "as": varname = bits[3] else: varname = None @@ -147,7 +151,7 @@ class StaticNode(template.Node): return cls(varname, path) -@register.tag('static') +@register.tag("static") def do_static(parser, token): """ Join the given path with the STATIC_URL setting. 
diff --git a/django/templatetags/tz.py b/django/templatetags/tz.py index 489391a267..50810cece6 100644 --- a/django/templatetags/tz.py +++ b/django/templatetags/tz.py @@ -22,6 +22,7 @@ class UnknownTimezoneException(BaseException): def timezone_constructor(tzname): if settings.USE_DEPRECATED_PYTZ: import pytz + try: return pytz.timezone(tzname) except pytz.UnknownTimeZoneError: @@ -40,6 +41,7 @@ class datetimeobject(datetime): # Template filters + @register.filter def localtime(value): """ @@ -58,7 +60,7 @@ def utc(value): return do_timezone(value, timezone.utc) -@register.filter('timezone') +@register.filter("timezone") def do_timezone(value, arg): """ Convert a datetime to local time in a given time zone. @@ -68,7 +70,7 @@ def do_timezone(value, arg): Naive datetimes are assumed to be in local time in the default time zone. """ if not isinstance(value, datetime): - return '' + return "" # Obtain a timezone-aware datetime try: @@ -78,7 +80,7 @@ def do_timezone(value, arg): # Filters must never raise exceptions, and pytz' exceptions inherit # Exception directly, not a specific subclass. So catch everything. except Exception: - return '' + return "" # Obtain a tzinfo instance if isinstance(arg, tzinfo): @@ -87,27 +89,36 @@ def do_timezone(value, arg): try: tz = timezone_constructor(arg) except UnknownTimezoneException: - return '' + return "" else: - return '' + return "" result = timezone.localtime(value, tz) # HACK: the convert_to_local_time flag will prevent # automatic conversion of the value to local time. - result = datetimeobject(result.year, result.month, result.day, - result.hour, result.minute, result.second, - result.microsecond, result.tzinfo) + result = datetimeobject( + result.year, + result.month, + result.day, + result.hour, + result.minute, + result.second, + result.microsecond, + result.tzinfo, + ) result.convert_to_local_time = False return result # Template tags + class LocalTimeNode(Node): """ Template node class used by ``localtime_tag``. 
""" + def __init__(self, nodelist, use_tz): self.nodelist = nodelist self.use_tz = use_tz @@ -124,6 +135,7 @@ class TimezoneNode(Node): """ Template node class used by ``timezone_tag``. """ + def __init__(self, nodelist, tz): self.nodelist = nodelist self.tz = tz @@ -138,15 +150,16 @@ class GetCurrentTimezoneNode(Node): """ Template node class used by ``get_current_timezone_tag``. """ + def __init__(self, variable): self.variable = variable def render(self, context): context[self.variable] = timezone.get_current_timezone_name() - return '' + return "" -@register.tag('localtime') +@register.tag("localtime") def localtime_tag(parser, token): """ Force or prevent conversion of datetime objects to local time, @@ -159,17 +172,16 @@ def localtime_tag(parser, token): bits = token.split_contents() if len(bits) == 1: use_tz = True - elif len(bits) > 2 or bits[1] not in ('on', 'off'): - raise TemplateSyntaxError("%r argument should be 'on' or 'off'" % - bits[0]) + elif len(bits) > 2 or bits[1] not in ("on", "off"): + raise TemplateSyntaxError("%r argument should be 'on' or 'off'" % bits[0]) else: - use_tz = bits[1] == 'on' - nodelist = parser.parse(('endlocaltime',)) + use_tz = bits[1] == "on" + nodelist = parser.parse(("endlocaltime",)) parser.delete_first_token() return LocalTimeNode(nodelist, use_tz) -@register.tag('timezone') +@register.tag("timezone") def timezone_tag(parser, token): """ Enable a given time zone just for this block. 
@@ -186,10 +198,9 @@ def timezone_tag(parser, token): """ bits = token.split_contents() if len(bits) != 2: - raise TemplateSyntaxError("'%s' takes one argument (timezone)" % - bits[0]) + raise TemplateSyntaxError("'%s' takes one argument (timezone)" % bits[0]) tz = parser.compile_filter(bits[1]) - nodelist = parser.parse(('endtimezone',)) + nodelist = parser.parse(("endtimezone",)) parser.delete_first_token() return TimezoneNode(nodelist, tz) @@ -208,7 +219,7 @@ def get_current_timezone_tag(parser, token): """ # token.split_contents() isn't useful here because this tag doesn't accept variable as arguments args = token.contents.split() - if len(args) != 3 or args[1] != 'as': + if len(args) != 3 or args[1] != "as": raise TemplateSyntaxError( "'get_current_timezone' requires 'as variable' (got %r)" % args ) diff --git a/django/test/__init__.py b/django/test/__init__.py index d1f953a8dd..485298e8e7 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -1,21 +1,38 @@ """Django Unit Test framework.""" -from django.test.client import ( - AsyncClient, AsyncRequestFactory, Client, RequestFactory, -) +from django.test.client import AsyncClient, AsyncRequestFactory, Client, RequestFactory from django.test.testcases import ( - LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase, - skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature, + LiveServerTestCase, + SimpleTestCase, + TestCase, + TransactionTestCase, + skipIfDBFeature, + skipUnlessAnyDBFeature, + skipUnlessDBFeature, ) from django.test.utils import ( - ignore_warnings, modify_settings, override_settings, - override_system_checks, tag, + ignore_warnings, + modify_settings, + override_settings, + override_system_checks, + tag, ) __all__ = [ - 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory', - 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase', - 'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', - 'ignore_warnings', 
'modify_settings', 'override_settings', - 'override_system_checks', 'tag', + "AsyncClient", + "AsyncRequestFactory", + "Client", + "RequestFactory", + "TestCase", + "TransactionTestCase", + "SimpleTestCase", + "LiveServerTestCase", + "skipIfDBFeature", + "skipUnlessAnyDBFeature", + "skipUnlessDBFeature", + "ignore_warnings", + "modify_settings", + "override_settings", + "override_system_checks", + "tag", ] diff --git a/django/test/client.py b/django/test/client.py index af1090a740..a38e7dae13 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -16,9 +16,7 @@ from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.core.serializers.json import DjangoJSONEncoder -from django.core.signals import ( - got_request_exception, request_finished, request_started, -) +from django.core.signals import got_request_exception, request_finished, request_started from django.db import close_old_connections from django.http import HttpRequest, QueryDict, SimpleCookie from django.test import signals @@ -31,20 +29,26 @@ from django.utils.itercompat import is_iterable from django.utils.regex_helper import _lazy_re_compile __all__ = ( - 'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError', - 'RequestFactory', 'encode_file', 'encode_multipart', + "AsyncClient", + "AsyncRequestFactory", + "Client", + "RedirectCycleError", + "RequestFactory", + "encode_file", + "encode_multipart", ) -BOUNDARY = 'BoUnDaRyStRiNg' -MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY -CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w-]+);?') +BOUNDARY = "BoUnDaRyStRiNg" +MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % BOUNDARY +CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w-]+);?") # Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8 -JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json') 
+JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json") class RedirectCycleError(Exception): """The test client has been asked to follow a redirect loop.""" + def __init__(self, message, last_response): super().__init__(message) self.last_response = last_response @@ -58,6 +62,7 @@ class FakePayload: length. This makes sure that views can't do anything under the test client that wouldn't work in real life. """ + def __init__(self, content=None): self.__content = BytesIO() self.__len = 0 @@ -74,7 +79,9 @@ class FakePayload: self.read_started = True if num_bytes is None: num_bytes = self.__len or 0 - assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data." + assert ( + self.__len >= num_bytes + ), "Cannot read more than the available bytes from the HTTP incoming data." content = self.__content.read(num_bytes) self.__len -= num_bytes return content @@ -92,7 +99,7 @@ def closing_iterator_wrapper(iterable, close): yield from iterable finally: request_finished.disconnect(close_old_connections) - close() # will fire request_finished + close() # will fire request_finished request_finished.connect(close_old_connections) @@ -106,12 +113,12 @@ def conditional_content_removal(request, response): if response.streaming: response.streaming_content = [] else: - response.content = b'' - if request.method == 'HEAD': + response.content = b"" + if request.method == "HEAD": if response.streaming: response.streaming_content = [] else: - response.content = b'' + response.content = b"" return response @@ -121,6 +128,7 @@ class ClientHandler(BaseHandler): interface to compose requests, but return the raw HttpResponse object with the originating WSGIRequest attached to its ``wsgi_request`` attribute. 
""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): self.enforce_csrf_checks = enforce_csrf_checks super().__init__(*args, **kwargs) @@ -154,10 +162,11 @@ class ClientHandler(BaseHandler): # Emulate a WSGI server by calling the close method on completion. if response.streaming: response.streaming_content = closing_iterator_wrapper( - response.streaming_content, response.close) + response.streaming_content, response.close + ) else: request_finished.disconnect(close_old_connections) - response.close() # will fire request_finished + response.close() # will fire request_finished request_finished.connect(close_old_connections) return response @@ -165,6 +174,7 @@ class ClientHandler(BaseHandler): class AsyncClientHandler(BaseHandler): """An async version of ClientHandler.""" + def __init__(self, enforce_csrf_checks=True, *args, **kwargs): self.enforce_csrf_checks = enforce_csrf_checks super().__init__(*args, **kwargs) @@ -175,13 +185,15 @@ class AsyncClientHandler(BaseHandler): if self._middleware_chain is None: self.load_middleware(is_async=True) # Extract body file from the scope, if provided. - if '_body_file' in scope: - body_file = scope.pop('_body_file') + if "_body_file" in scope: + body_file = scope.pop("_body_file") else: - body_file = FakePayload('') + body_file = FakePayload("") request_started.disconnect(close_old_connections) - await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope) + await sync_to_async(request_started.send, thread_sensitive=False)( + sender=self.__class__, scope=scope + ) request_started.connect(close_old_connections) request = ASGIRequest(scope, body_file) # Sneaky little hack so that we can easily get round @@ -197,7 +209,9 @@ class AsyncClientHandler(BaseHandler): response.asgi_request = request # Emulate a server by calling the close method on completion. 
if response.streaming: - response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)( + response.streaming_content = await sync_to_async( + closing_iterator_wrapper, thread_sensitive=False + )( response.streaming_content, response.close, ) @@ -216,10 +230,10 @@ def store_rendered_templates(store, signal, sender, template, context, **kwargs) The context is copied so that it is an accurate representation at the time of rendering. """ - store.setdefault('templates', []).append(template) - if 'context' not in store: - store['context'] = ContextList() - store['context'].append(copy(context)) + store.setdefault("templates", []).append(template) + if "context" not in store: + store["context"] = ContextList() + store["context"].append(copy(context)) def encode_multipart(boundary, data): @@ -255,25 +269,33 @@ def encode_multipart(boundary, data): if is_file(item): lines.extend(encode_file(boundary, key, item)) else: - lines.extend(to_bytes(val) for val in [ - '--%s' % boundary, - 'Content-Disposition: form-data; name="%s"' % key, - '', - item - ]) + lines.extend( + to_bytes(val) + for val in [ + "--%s" % boundary, + 'Content-Disposition: form-data; name="%s"' % key, + "", + item, + ] + ) else: - lines.extend(to_bytes(val) for val in [ - '--%s' % boundary, - 'Content-Disposition: form-data; name="%s"' % key, - '', - value - ]) + lines.extend( + to_bytes(val) + for val in [ + "--%s" % boundary, + 'Content-Disposition: form-data; name="%s"' % key, + "", + value, + ] + ) - lines.extend([ - to_bytes('--%s--' % boundary), - b'', - ]) - return b'\r\n'.join(lines) + lines.extend( + [ + to_bytes("--%s--" % boundary), + b"", + ] + ) + return b"\r\n".join(lines) def encode_file(boundary, key, file): @@ -282,10 +304,10 @@ def encode_file(boundary, key, file): # file.name might not be a string. For example, it's an int for # tempfile.TemporaryFile(). 
- file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str) - filename = os.path.basename(file.name) if file_has_string_name else '' + file_has_string_name = hasattr(file, "name") and isinstance(file.name, str) + filename = os.path.basename(file.name) if file_has_string_name else "" - if hasattr(file, 'content_type'): + if hasattr(file, "content_type"): content_type = file.content_type elif filename: content_type = mimetypes.guess_type(filename)[0] @@ -293,15 +315,16 @@ def encode_file(boundary, key, file): content_type = None if content_type is None: - content_type = 'application/octet-stream' + content_type = "application/octet-stream" filename = filename or key return [ - to_bytes('--%s' % boundary), - to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"' - % (key, filename)), - to_bytes('Content-Type: %s' % content_type), - b'', - to_bytes(file.read()) + to_bytes("--%s" % boundary), + to_bytes( + 'Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename) + ), + to_bytes("Content-Type: %s" % content_type), + b"", + to_bytes(file.read()), ] @@ -318,6 +341,7 @@ class RequestFactory: Once you have a request object you can pass it to any view function, just as if that view had been hooked up using a URLconf. """ + def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults): self.json_encoder = json_encoder self.defaults = defaults @@ -333,24 +357,26 @@ class RequestFactory: # - REMOTE_ADDR: often useful, see #8551. 
# See https://www.python.org/dev/peps/pep-3333/#environ-variables return { - 'HTTP_COOKIE': '; '.join(sorted( - '%s=%s' % (morsel.key, morsel.coded_value) - for morsel in self.cookies.values() - )), - 'PATH_INFO': '/', - 'REMOTE_ADDR': '127.0.0.1', - 'REQUEST_METHOD': 'GET', - 'SCRIPT_NAME': '', - 'SERVER_NAME': 'testserver', - 'SERVER_PORT': '80', - 'SERVER_PROTOCOL': 'HTTP/1.1', - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': 'http', - 'wsgi.input': FakePayload(b''), - 'wsgi.errors': self.errors, - 'wsgi.multiprocess': True, - 'wsgi.multithread': False, - 'wsgi.run_once': False, + "HTTP_COOKIE": "; ".join( + sorted( + "%s=%s" % (morsel.key, morsel.coded_value) + for morsel in self.cookies.values() + ) + ), + "PATH_INFO": "/", + "REMOTE_ADDR": "127.0.0.1", + "REQUEST_METHOD": "GET", + "SCRIPT_NAME": "", + "SERVER_NAME": "testserver", + "SERVER_PORT": "80", + "SERVER_PROTOCOL": "HTTP/1.1", + "wsgi.version": (1, 0), + "wsgi.url_scheme": "http", + "wsgi.input": FakePayload(b""), + "wsgi.errors": self.errors, + "wsgi.multiprocess": True, + "wsgi.multithread": False, + "wsgi.run_once": False, **self.defaults, **request, } @@ -376,7 +402,9 @@ class RequestFactory: Return encoded JSON if data is a dict, list, or tuple and content_type is application/json. """ - should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple)) + should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance( + data, (dict, list, tuple) + ) return json.dumps(data, cls=self.json_encoder) if should_encode else data def _get_path(self, parsed): @@ -388,88 +416,128 @@ class RequestFactory: # Replace the behavior where non-ASCII values in the WSGI environ are # arbitrarily decoded with ISO-8859-1. # Refs comment in `get_bytes_from_wsgi()`. 
- return path.decode('iso-8859-1') + return path.decode("iso-8859-1") def get(self, path, data=None, secure=False, **extra): """Construct a GET request.""" data = {} if data is None else data - return self.generic('GET', path, secure=secure, **{ - 'QUERY_STRING': urlencode(data, doseq=True), - **extra, - }) + return self.generic( + "GET", + path, + secure=secure, + **{ + "QUERY_STRING": urlencode(data, doseq=True), + **extra, + }, + ) - def post(self, path, data=None, content_type=MULTIPART_CONTENT, - secure=False, **extra): + def post( + self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra + ): """Construct a POST request.""" data = self._encode_json({} if data is None else data, content_type) post_data = self._encode_data(data, content_type) - return self.generic('POST', path, post_data, content_type, - secure=secure, **extra) + return self.generic( + "POST", path, post_data, content_type, secure=secure, **extra + ) def head(self, path, data=None, secure=False, **extra): """Construct a HEAD request.""" data = {} if data is None else data - return self.generic('HEAD', path, secure=secure, **{ - 'QUERY_STRING': urlencode(data, doseq=True), - **extra, - }) + return self.generic( + "HEAD", + path, + secure=secure, + **{ + "QUERY_STRING": urlencode(data, doseq=True), + **extra, + }, + ) def trace(self, path, secure=False, **extra): """Construct a TRACE request.""" - return self.generic('TRACE', path, secure=secure, **extra) + return self.generic("TRACE", path, secure=secure, **extra) - def options(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def options( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): "Construct an OPTIONS request." 
- return self.generic('OPTIONS', path, data, content_type, - secure=secure, **extra) + return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra) - def put(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def put( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a PUT request.""" data = self._encode_json(data, content_type) - return self.generic('PUT', path, data, content_type, - secure=secure, **extra) + return self.generic("PUT", path, data, content_type, secure=secure, **extra) - def patch(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def patch( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a PATCH request.""" data = self._encode_json(data, content_type) - return self.generic('PATCH', path, data, content_type, - secure=secure, **extra) + return self.generic("PATCH", path, data, content_type, secure=secure, **extra) - def delete(self, path, data='', content_type='application/octet-stream', - secure=False, **extra): + def delete( + self, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct a DELETE request.""" data = self._encode_json(data, content_type) - return self.generic('DELETE', path, data, content_type, - secure=secure, **extra) + return self.generic("DELETE", path, data, content_type, secure=secure, **extra) - def generic(self, method, path, data='', - content_type='application/octet-stream', secure=False, - **extra): + def generic( + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, + ): """Construct an arbitrary HTTP request.""" parsed = urlparse(str(path)) # path can be lazy data = force_bytes(data, settings.DEFAULT_CHARSET) r = { - 'PATH_INFO': self._get_path(parsed), - 'REQUEST_METHOD': method, - 
'SERVER_PORT': '443' if secure else '80', - 'wsgi.url_scheme': 'https' if secure else 'http', + "PATH_INFO": self._get_path(parsed), + "REQUEST_METHOD": method, + "SERVER_PORT": "443" if secure else "80", + "wsgi.url_scheme": "https" if secure else "http", } if data: - r.update({ - 'CONTENT_LENGTH': str(len(data)), - 'CONTENT_TYPE': content_type, - 'wsgi.input': FakePayload(data), - }) + r.update( + { + "CONTENT_LENGTH": str(len(data)), + "CONTENT_TYPE": content_type, + "wsgi.input": FakePayload(data), + } + ) r.update(extra) # If QUERY_STRING is absent or empty, we want to extract it from the URL. - if not r.get('QUERY_STRING'): + if not r.get("QUERY_STRING"): # WSGI requires latin-1 encoded strings. See get_path_info(). - query_string = parsed[4].encode().decode('iso-8859-1') - r['QUERY_STRING'] = query_string + query_string = parsed[4].encode().decode("iso-8859-1") + r["QUERY_STRING"] = query_string return self.request(**r) @@ -487,30 +555,35 @@ class AsyncRequestFactory(RequestFactory): a) this makes ASGIRequest subclasses, and b) AsyncTestClient can subclass it. """ + def _base_scope(self, **request): """The base scope for a request.""" # This is a minimal valid ASGI scope, plus: # - headers['cookie'] for cookie support, # - 'client' often useful, see #8551. 
scope = { - 'asgi': {'version': '3.0'}, - 'type': 'http', - 'http_version': '1.1', - 'client': ['127.0.0.1', 0], - 'server': ('testserver', '80'), - 'scheme': 'http', - 'method': 'GET', - 'headers': [], + "asgi": {"version": "3.0"}, + "type": "http", + "http_version": "1.1", + "client": ["127.0.0.1", 0], + "server": ("testserver", "80"), + "scheme": "http", + "method": "GET", + "headers": [], **self.defaults, **request, } - scope['headers'].append(( - b'cookie', - b'; '.join(sorted( - ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii') - for morsel in self.cookies.values() - )), - )) + scope["headers"].append( + ( + b"cookie", + b"; ".join( + sorted( + ("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii") + for morsel in self.cookies.values() + ) + ), + ) + ) return scope def request(self, **request): @@ -518,45 +591,52 @@ class AsyncRequestFactory(RequestFactory): # This is synchronous, which means all methods on this class are. # AsyncClient, however, has an async request function, which makes all # its methods async. - if '_body_file' in request: - body_file = request.pop('_body_file') + if "_body_file" in request: + body_file = request.pop("_body_file") else: - body_file = FakePayload('') + body_file = FakePayload("") return ASGIRequest(self._base_scope(**request), body_file) def generic( - self, method, path, data='', content_type='application/octet-stream', - secure=False, **extra, + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra, ): """Construct an arbitrary HTTP request.""" parsed = urlparse(str(path)) # path can be lazy. 
data = force_bytes(data, settings.DEFAULT_CHARSET) s = { - 'method': method, - 'path': self._get_path(parsed), - 'server': ('127.0.0.1', '443' if secure else '80'), - 'scheme': 'https' if secure else 'http', - 'headers': [(b'host', b'testserver')], + "method": method, + "path": self._get_path(parsed), + "server": ("127.0.0.1", "443" if secure else "80"), + "scheme": "https" if secure else "http", + "headers": [(b"host", b"testserver")], } if data: - s['headers'].extend([ - (b'content-length', str(len(data)).encode('ascii')), - (b'content-type', content_type.encode('ascii')), - ]) - s['_body_file'] = FakePayload(data) - follow = extra.pop('follow', None) + s["headers"].extend( + [ + (b"content-length", str(len(data)).encode("ascii")), + (b"content-type", content_type.encode("ascii")), + ] + ) + s["_body_file"] = FakePayload(data) + follow = extra.pop("follow", None) if follow is not None: - s['follow'] = follow - if query_string := extra.pop('QUERY_STRING', None): - s['query_string'] = query_string - s['headers'] += [ - (key.lower().encode('ascii'), value.encode('latin1')) + s["follow"] = follow + if query_string := extra.pop("QUERY_STRING", None): + s["query_string"] = query_string + s["headers"] += [ + (key.lower().encode("ascii"), value.encode("latin1")) for key, value in extra.items() ] # If QUERY_STRING is absent or empty, we want to extract it from the # URL. - if not s.get('query_string'): - s['query_string'] = parsed[4] + if not s.get("query_string"): + s["query_string"] = parsed[4] return self.request(**s) @@ -564,6 +644,7 @@ class ClientMixin: """ Mixin with common methods between Client and AsyncClient. """ + def store_exc_info(self, **kwargs): """Store exceptions when they are generated by a view.""" self.exc_info = sys.exc_info() @@ -601,6 +682,7 @@ class ClientMixin: are incorrect. 
""" from django.contrib.auth import authenticate + user = authenticate(**credentials) if user: self._login(user) @@ -610,9 +692,10 @@ class ClientMixin: def force_login(self, user, backend=None): def get_backend(): from django.contrib.auth import load_backend + for backend_path in settings.AUTHENTICATION_BACKENDS: backend = load_backend(backend_path) - if hasattr(backend, 'get_user'): + if hasattr(backend, "get_user"): return backend_path if backend is None: @@ -637,17 +720,18 @@ class ClientMixin: session_cookie = settings.SESSION_COOKIE_NAME self.cookies[session_cookie] = request.session.session_key cookie_data = { - 'max-age': None, - 'path': '/', - 'domain': settings.SESSION_COOKIE_DOMAIN, - 'secure': settings.SESSION_COOKIE_SECURE or None, - 'expires': None, + "max-age": None, + "path": "/", + "domain": settings.SESSION_COOKIE_DOMAIN, + "secure": settings.SESSION_COOKIE_SECURE or None, + "expires": None, } self.cookies[session_cookie].update(cookie_data) def logout(self): """Log out the user by removing the cookies and session object.""" from django.contrib.auth import get_user, logout + request = HttpRequest() if self.session: request.session = self.session @@ -659,13 +743,15 @@ class ClientMixin: self.cookies = SimpleCookie() def _parse_json(self, response, **extra): - if not hasattr(response, '_json'): - if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')): + if not hasattr(response, "_json"): + if not JSON_CONTENT_TYPE_RE.match(response.get("Content-Type")): raise ValueError( 'Content-Type header is "%s", not "application/json"' - % response.get('Content-Type') + % response.get("Content-Type") ) - response._json = json.loads(response.content.decode(response.charset), **extra) + response._json = json.loads( + response.content.decode(response.charset), **extra + ) return response._json @@ -687,7 +773,10 @@ class Client(ClientMixin, RequestFactory): contexts and templates produced by a view, rather than the HTML rendered to the end-user. 
""" - def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + + def __init__( + self, enforce_csrf_checks=False, raise_request_exception=True, **defaults + ): super().__init__(**defaults) self.handler = ClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception @@ -723,13 +812,13 @@ class Client(ClientMixin, RequestFactory): response.client = self response.request = request # Add any rendered template detail to the response. - response.templates = data.get('templates', []) - response.context = data.get('context') + response.templates = data.get("templates", []) + response.context = data.get("context") response.json = partial(self._parse_json, response) # Attach the ResolverMatch instance to the response. - urlconf = getattr(response.wsgi_request, 'urlconf', None) + urlconf = getattr(response.wsgi_request, "urlconf", None) response.resolver_match = SimpleLazyObject( - lambda: resolve(request['PATH_INFO'], urlconf=urlconf), + lambda: resolve(request["PATH_INFO"], urlconf=urlconf), ) # Flatten a single context. 
Not really necessary anymore thanks to the # __getattr__ flattening in ContextList, but has some edge case @@ -749,13 +838,24 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def post(self, path, data=None, content_type=MULTIPART_CONTENT, - follow=False, secure=False, **extra): + def post( + self, + path, + data=None, + content_type=MULTIPART_CONTENT, + follow=False, + secure=False, + **extra, + ): """Request a response from the server using POST.""" self.extra = extra - response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().post( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response def head(self, path, data=None, follow=False, secure=False, **extra): @@ -766,43 +866,87 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def options(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def options( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Request a response from the server using OPTIONS.""" self.extra = extra - response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().options( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def put(self, path, data='', content_type='application/octet-stream', - 
follow=False, secure=False, **extra): + def put( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a resource to the server using PUT.""" self.extra = extra - response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().put( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def patch(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def patch( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a resource to the server using PATCH.""" self.extra = extra - response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().patch( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def delete(self, path, data='', content_type='application/octet-stream', - follow=False, secure=False, **extra): + def delete( + self, + path, + data="", + content_type="application/octet-stream", + follow=False, + secure=False, + **extra, + ): """Send a DELETE request to the server.""" self.extra = extra - response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) + response = super().delete( + path, data=data, content_type=content_type, secure=secure, **extra + ) if follow: - response = self._handle_redirects(response, data=data, content_type=content_type, **extra) + response = 
self._handle_redirects( + response, data=data, content_type=content_type, **extra + ) return response - def trace(self, path, data='', follow=False, secure=False, **extra): + def trace(self, path, data="", follow=False, secure=False, **extra): """Send a TRACE request to the server.""" self.extra = extra response = super().trace(path, data=data, secure=secure, **extra) @@ -810,7 +954,7 @@ class Client(ClientMixin, RequestFactory): response = self._handle_redirects(response, data=data, **extra) return response - def _handle_redirects(self, response, data='', content_type='', **extra): + def _handle_redirects(self, response, data="", content_type="", **extra): """ Follow any redirects by requesting responses from the server using GET. """ @@ -829,39 +973,46 @@ class Client(ClientMixin, RequestFactory): url = urlsplit(response_url) if url.scheme: - extra['wsgi.url_scheme'] = url.scheme + extra["wsgi.url_scheme"] = url.scheme if url.hostname: - extra['SERVER_NAME'] = url.hostname + extra["SERVER_NAME"] = url.hostname if url.port: - extra['SERVER_PORT'] = str(url.port) + extra["SERVER_PORT"] = str(url.port) path = url.path # RFC 2616: bare domains without path are treated as the root. if not path and url.netloc: - path = '/' + path = "/" # Prepend the request path to handle relative path redirects - if not path.startswith('/'): - path = urljoin(response.request['PATH_INFO'], path) + if not path.startswith("/"): + path = urljoin(response.request["PATH_INFO"], path) - if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT): + if response.status_code in ( + HTTPStatus.TEMPORARY_REDIRECT, + HTTPStatus.PERMANENT_REDIRECT, + ): # Preserve request method and query string (if needed) # post-redirect for 307/308 responses. 
- request_method = response.request['REQUEST_METHOD'].lower() - if request_method not in ('get', 'head'): - extra['QUERY_STRING'] = url.query + request_method = response.request["REQUEST_METHOD"].lower() + if request_method not in ("get", "head"): + extra["QUERY_STRING"] = url.query request_method = getattr(self, request_method) else: request_method = self.get data = QueryDict(url.query) content_type = None - response = request_method(path, data=data, content_type=content_type, follow=False, **extra) + response = request_method( + path, data=data, content_type=content_type, follow=False, **extra + ) response.redirect_chain = redirect_chain if redirect_chain[-1] in redirect_chain[:-1]: # Check that we're not redirecting to somewhere we've already # been to, to prevent loops. - raise RedirectCycleError("Redirect loop detected.", last_response=response) + raise RedirectCycleError( + "Redirect loop detected.", last_response=response + ) if len(redirect_chain) > 20: # Such a lengthy chain likely also means a loop, but one with # a growing path, changing view, or changing query argument; @@ -878,7 +1029,10 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): Does not currently support "follow" on its methods. """ - def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults): + + def __init__( + self, enforce_csrf_checks=False, raise_request_exception=True, **defaults + ): super().__init__(**defaults) self.handler = AsyncClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception @@ -892,19 +1046,19 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): query environment, which can be overridden using the arguments to the request. """ - if 'follow' in request: + if "follow" in request: raise NotImplementedError( - 'AsyncClient request methods do not accept the follow parameter.' + "AsyncClient request methods do not accept the follow parameter." 
) scope = self._base_scope(**request) # Curry a data dictionary into an instance of the template renderer # callback function. data = {} on_template_render = partial(store_rendered_templates, data) - signal_uid = 'template-render-%s' % id(request) + signal_uid = "template-render-%s" % id(request) signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) # Capture exceptions created by the handler. - exception_uid = 'request-exception-%s' % id(request) + exception_uid = "request-exception-%s" % id(request) got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid) try: response = await self.handler(scope) @@ -917,13 +1071,13 @@ class AsyncClient(ClientMixin, AsyncRequestFactory): response.client = self response.request = request # Add any rendered template detail to the response. - response.templates = data.get('templates', []) - response.context = data.get('context') + response.templates = data.get("templates", []) + response.context = data.get("context") response.json = partial(self._parse_json, response) # Attach the ResolverMatch instance to the response. - urlconf = getattr(response.asgi_request, 'urlconf', None) + urlconf = getattr(response.asgi_request, "urlconf", None) response.resolver_match = SimpleLazyObject( - lambda: resolve(request['path'], urlconf=urlconf), + lambda: resolve(request["path"], urlconf=urlconf), ) # Flatten a single context. Not really necessary anymore thanks to the # __getattr__ flattening in ContextList, but has some edge case diff --git a/django/test/html.py b/django/test/html.py index 07e986439b..87e213d651 100644 --- a/django/test/html.py +++ b/django/test/html.py @@ -7,32 +7,52 @@ from django.utils.regex_helper import _lazy_re_compile # ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020 # SPACE. 
# https://infra.spec.whatwg.org/#ascii-whitespace -ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+') +ASCII_WHITESPACE = _lazy_re_compile(r"[\t\n\f\r ]+") # https://html.spec.whatwg.org/#attributes-3 BOOLEAN_ATTRIBUTES = { - 'allowfullscreen', 'async', 'autofocus', 'autoplay', 'checked', 'controls', - 'default', 'defer ', 'disabled', 'formnovalidate', 'hidden', 'ismap', - 'itemscope', 'loop', 'multiple', 'muted', 'nomodule', 'novalidate', 'open', - 'playsinline', 'readonly', 'required', 'reversed', 'selected', + "allowfullscreen", + "async", + "autofocus", + "autoplay", + "checked", + "controls", + "default", + "defer", + "disabled", + "formnovalidate", + "hidden", + "ismap", + "itemscope", + "loop", + "multiple", + "muted", + "nomodule", + "novalidate", + "open", + "playsinline", + "readonly", + "required", + "reversed", + "selected", # Attributes for deprecated tags. - 'truespeed', + "truespeed", } def normalize_whitespace(string): - return ASCII_WHITESPACE.sub(' ', string) + return ASCII_WHITESPACE.sub(" ", string) def normalize_attributes(attributes): normalized = [] for name, value in attributes: - if name == 'class' and value: + if name == "class" and value: # Special case handling of 'class' attribute, so that comparisons # of DOM instances are not sensitive to ordering of classes. - value = ' '.join(sorted( - value for value in ASCII_WHITESPACE.split(value) if value - )) + value = " ".join( + sorted(value for value in ASCII_WHITESPACE.split(value) if value) + ) # Boolean attributes without a value is same as attribute with value # that equals the attributes name.
For example: # <input checked> == <input checked="checked"> @@ -40,7 +60,7 @@ def normalize_attributes(attributes): if not value or value == name: value = None elif value is None: - value = '' + value = "" normalized.append((name, value)) return normalized @@ -80,11 +100,11 @@ class Element: for i, child in enumerate(self.children): if isinstance(child, str): self.children[i] = child.strip() - elif hasattr(child, 'finalize'): + elif hasattr(child, "finalize"): child.finalize() def __eq__(self, element): - if not hasattr(element, 'name') or self.name != element.name: + if not hasattr(element, "name") or self.name != element.name: return False if self.attributes != element.attributes: return False @@ -142,21 +162,23 @@ class Element: return self.children[key] def __str__(self): - output = '<%s' % self.name + output = "<%s" % self.name for key, value in self.attributes: if value is not None: output += ' %s="%s"' % (key, value) else: - output += ' %s' % key + output += " %s" % key if self.children: - output += '>\n' - output += ''.join([ - html.escape(c) if isinstance(c, str) else str(c) - for c in self.children - ]) - output += '\n</%s>' % self.name + output += ">\n" + output += "".join( + [ + html.escape(c) if isinstance(c, str) else str(c) + for c in self.children + ] + ) + output += "\n</%s>" % self.name else: - output += '>' + output += ">" return output def __repr__(self): @@ -168,10 +190,9 @@ class RootElement(Element): super().__init__(None, ()) def __str__(self): - return ''.join([ - html.escape(c) if isinstance(c, str) else str(c) - for c in self.children - ]) + return "".join( + [html.escape(c) if isinstance(c, str) else str(c) for c in self.children] + ) class HTMLParseError(Exception): @@ -181,10 +202,23 @@ class HTMLParseError(Exception): class Parser(HTMLParser): # https://html.spec.whatwg.org/#void-elements SELF_CLOSING_TAGS = { - 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta', - 'param', 'source', 'track', 'wbr', + "area", + 
"base", + "br", + "col", + "embed", + "hr", + "img", + "input", + "link", + "meta", + "param", + "source", + "track", + "wbr", # Deprecated tags - 'frame', 'spacer', + "frame", + "spacer", } def __init__(self): @@ -201,9 +235,9 @@ class Parser(HTMLParser): position = self.element_positions[element] if position is None: position = self.getpos() - if hasattr(position, 'lineno'): + if hasattr(position, "lineno"): position = position.lineno, position.offset - return 'Line %d, Column %d' % position + return "Line %d, Column %d" % position @property def current(self): @@ -227,13 +261,13 @@ class Parser(HTMLParser): def handle_endtag(self, tag): if not self.open_tags: - self.error("Unexpected end tag `%s` (%s)" % ( - tag, self.format_position())) + self.error("Unexpected end tag `%s` (%s)" % (tag, self.format_position())) element = self.open_tags.pop() while element.name != tag: if not self.open_tags: - self.error("Unexpected end tag `%s` (%s)" % ( - tag, self.format_position())) + self.error( + "Unexpected end tag `%s` (%s)" % (tag, self.format_position()) + ) element = self.open_tags.pop() def handle_data(self, data): diff --git a/django/test/runner.py b/django/test/runner.py index 2e36514922..113d5216a6 100644 --- a/django/test/runner.py +++ b/django/test/runner.py @@ -20,11 +20,11 @@ from io import StringIO from django.core.management import call_command from django.db import connections from django.test import SimpleTestCase, TestCase -from django.test.utils import ( - NullTimeKeeper, TimeKeeper, iter_test_cases, - setup_databases as _setup_databases, setup_test_environment, - teardown_databases as _teardown_databases, teardown_test_environment, -) +from django.test.utils import NullTimeKeeper, TimeKeeper, iter_test_cases +from django.test.utils import setup_databases as _setup_databases +from django.test.utils import setup_test_environment +from django.test.utils import teardown_databases as _teardown_databases +from django.test.utils import 
teardown_test_environment from django.utils.crypto import new_hash from django.utils.datastructures import OrderedSet from django.utils.deprecation import RemovedInDjango50Warning @@ -42,7 +42,7 @@ except ImportError: class DebugSQLTextTestResult(unittest.TextTestResult): def __init__(self, stream, descriptions, verbosity): - self.logger = logging.getLogger('django.db.backends') + self.logger = logging.getLogger("django.db.backends") self.logger.setLevel(logging.DEBUG) self.debug_sql_stream = None super().__init__(stream, descriptions, verbosity) @@ -65,7 +65,7 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().addError(test, err) if self.debug_sql_stream is None: # Error before tests e.g. in setUpTestData(). - sql = '' + sql = "" else: self.debug_sql_stream.seek(0) sql = self.debug_sql_stream.read() @@ -80,7 +80,11 @@ class DebugSQLTextTestResult(unittest.TextTestResult): super().addSubTest(test, subtest, err) if err is not None: self.debug_sql_stream.seek(0) - errors = self.failures if issubclass(err[0], test.failureException) else self.errors + errors = ( + self.failures + if issubclass(err[0], test.failureException) + else self.errors + ) errors[-1] = errors[-1] + (self.debug_sql_stream.read(),) def printErrorList(self, flavour, errors): @@ -124,6 +128,7 @@ class DummyList: """ Dummy list class for faking storage of results in unittest.TestResult. """ + __slots__ = () def append(self, item): @@ -157,10 +162,10 @@ class RemoteTestResult(unittest.TestResult): # attributes. This is possible since they aren't used after unpickling # after being sent to ParallelTestSuite. 
state = self.__dict__.copy() - state.pop('_stdout_buffer', None) - state.pop('_stderr_buffer', None) - state.pop('_original_stdout', None) - state.pop('_original_stderr', None) + state.pop("_stdout_buffer", None) + state.pop("_stderr_buffer", None) + state.pop("_original_stdout", None) + state.pop("_original_stderr", None) return state @property @@ -176,7 +181,8 @@ class RemoteTestResult(unittest.TestResult): pickle.loads(pickle.dumps(obj)) def _print_unpicklable_subtest(self, test, subtest, pickle_exc): - print(""" + print( + """ Subtest failed: test: {} @@ -189,7 +195,10 @@ test runner cannot handle it cleanly. Here is the pickling error: You should re-run this test with --parallel=1 to reproduce the failure with a cleaner failure message. -""".format(test, subtest, pickle_exc)) +""".format( + test, subtest, pickle_exc + ) + ) def check_picklable(self, test, err): # Ensure that sys.exc_info() tuples are picklable. This displays a @@ -202,11 +211,16 @@ with a cleaner failure message. self._confirm_picklable(err) except Exception as exc: original_exc_txt = repr(err[1]) - original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + original_exc_txt = textwrap.fill( + original_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) pickle_exc_txt = repr(exc) - pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent=' ', subsequent_indent=' ') + pickle_exc_txt = textwrap.fill( + pickle_exc_txt, 75, initial_indent=" ", subsequent_indent=" " + ) if tblib is None: - print(""" + print( + """ {} failed: @@ -218,9 +232,13 @@ parallel test runner to handle this exception cleanly. 
In order to see the traceback, you should install tblib: python -m pip install tblib -""".format(test, original_exc_txt)) +""".format( + test, original_exc_txt + ) + ) else: - print(""" + print( + """ {} failed: @@ -235,7 +253,10 @@ Here's the error encountered while trying to pickle the exception: You should re-run this test with the --parallel=1 option to reproduce the failure and get a correct traceback. -""".format(test, original_exc_txt, pickle_exc_txt)) +""".format( + test, original_exc_txt, pickle_exc_txt + ) + ) raise def check_subtest_picklable(self, test, subtest): @@ -247,28 +268,28 @@ failure and get a correct traceback. def startTestRun(self): super().startTestRun() - self.events.append(('startTestRun',)) + self.events.append(("startTestRun",)) def stopTestRun(self): super().stopTestRun() - self.events.append(('stopTestRun',)) + self.events.append(("stopTestRun",)) def startTest(self, test): super().startTest(test) - self.events.append(('startTest', self.test_index)) + self.events.append(("startTest", self.test_index)) def stopTest(self, test): super().stopTest(test) - self.events.append(('stopTest', self.test_index)) + self.events.append(("stopTest", self.test_index)) def addError(self, test, err): self.check_picklable(test, err) - self.events.append(('addError', self.test_index, err)) + self.events.append(("addError", self.test_index, err)) super().addError(test, err) def addFailure(self, test, err): self.check_picklable(test, err) - self.events.append(('addFailure', self.test_index, err)) + self.events.append(("addFailure", self.test_index, err)) super().addFailure(test, err) def addSubTest(self, test, subtest, err): @@ -279,15 +300,15 @@ failure and get a correct traceback. # check_picklable() performs the tblib check. 
self.check_picklable(test, err) self.check_subtest_picklable(test, subtest) - self.events.append(('addSubTest', self.test_index, subtest, err)) + self.events.append(("addSubTest", self.test_index, subtest, err)) super().addSubTest(test, subtest, err) def addSuccess(self, test): - self.events.append(('addSuccess', self.test_index)) + self.events.append(("addSuccess", self.test_index)) super().addSuccess(test) def addSkip(self, test, reason): - self.events.append(('addSkip', self.test_index, reason)) + self.events.append(("addSkip", self.test_index, reason)) super().addSkip(test, reason) def addExpectedFailure(self, test, err): @@ -298,23 +319,23 @@ failure and get a correct traceback. if tblib is None: err = err[0], err[1], None self.check_picklable(test, err) - self.events.append(('addExpectedFailure', self.test_index, err)) + self.events.append(("addExpectedFailure", self.test_index, err)) super().addExpectedFailure(test, err) def addUnexpectedSuccess(self, test): - self.events.append(('addUnexpectedSuccess', self.test_index)) + self.events.append(("addUnexpectedSuccess", self.test_index)) super().addUnexpectedSuccess(test) def wasSuccessful(self): """Tells whether or not this result was a success.""" - failure_types = {'addError', 'addFailure', 'addSubTest', 'addUnexpectedSuccess'} + failure_types = {"addError", "addFailure", "addSubTest", "addUnexpectedSuccess"} return all(e[0] not in failure_types for e in self.events) def _exc_info_to_string(self, err, test): # Make this method no-op. It only powers the default unittest behavior # for recording errors, but this class pickles errors into 'events' # instead. - return '' + return "" class RemoteTestRunner: @@ -347,17 +368,17 @@ def get_max_test_processes(): """ # The current implementation of the parallel test runner requires # multiprocessing to start subprocesses with fork(). 
- if multiprocessing.get_start_method() != 'fork': + if multiprocessing.get_start_method() != "fork": return 1 try: - return int(os.environ['DJANGO_TEST_PROCESSES']) + return int(os.environ["DJANGO_TEST_PROCESSES"]) except KeyError: return multiprocessing.cpu_count() def parallel_type(value): """Parse value passed to the --parallel option.""" - if value == 'auto': + if value == "auto": return value try: return int(value) @@ -505,30 +526,30 @@ class Shuffler: """ # This doesn't need to be cryptographically strong, so use what's fastest. - hash_algorithm = 'md5' + hash_algorithm = "md5" @classmethod def _hash_text(cls, text): h = new_hash(cls.hash_algorithm, usedforsecurity=False) - h.update(text.encode('utf-8')) + h.update(text.encode("utf-8")) return h.hexdigest() def __init__(self, seed=None): if seed is None: # Limit seeds to 10 digits for simpler output. seed = random.randint(0, 10**10 - 1) - seed_source = 'generated' + seed_source = "generated" else: - seed_source = 'given' + seed_source = "given" self.seed = seed self.seed_source = seed_source @property def seed_display(self): - return f'{self.seed!r} ({self.seed_source})' + return f"{self.seed!r} ({self.seed_source})" def _hash_item(self, item, key): - text = '{}{}'.format(self.seed, key(item)) + text = "{}{}".format(self.seed, key(item)) return self._hash_text(text) def shuffle(self, items, key): @@ -544,8 +565,10 @@ class Shuffler: for item in items: hashed = self._hash_item(item, key) if hashed in hashes: - msg = 'item {!r} has same hash {!r} as item {!r}'.format( - item, hashed, hashes[hashed], + msg = "item {!r} has same hash {!r} as item {!r}".format( + item, + hashed, + hashes[hashed], ) raise RuntimeError(msg) hashes[hashed] = item @@ -561,12 +584,29 @@ class DiscoverRunner: test_loader = unittest.defaultTestLoader reorder_by = (TestCase, SimpleTestCase) - def __init__(self, pattern=None, top_level=None, verbosity=1, - interactive=True, failfast=False, keepdb=False, - reverse=False, debug_mode=False, 
debug_sql=False, parallel=0, - tags=None, exclude_tags=None, test_name_patterns=None, - pdb=False, buffer=False, enable_faulthandler=True, - timing=False, shuffle=False, logger=None, **kwargs): + def __init__( + self, + pattern=None, + top_level=None, + verbosity=1, + interactive=True, + failfast=False, + keepdb=False, + reverse=False, + debug_mode=False, + debug_sql=False, + parallel=0, + tags=None, + exclude_tags=None, + test_name_patterns=None, + pdb=False, + buffer=False, + enable_faulthandler=True, + timing=False, + shuffle=False, + logger=None, + **kwargs, + ): self.pattern = pattern self.top_level = top_level @@ -587,7 +627,9 @@ class DiscoverRunner: faulthandler.enable(file=sys.__stderr__.fileno()) self.pdb = pdb if self.pdb and self.parallel > 1: - raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.') + raise ValueError( + "You cannot use --pdb with parallel tests; pass --parallel=1 to use it." + ) self.buffer = buffer self.test_name_patterns = None self.time_keeper = TimeKeeper() if timing else NullTimeKeeper() @@ -595,7 +637,7 @@ class DiscoverRunner: # unittest does not export the _convert_select_pattern function # that converts command-line arguments to patterns. self.test_name_patterns = { - pattern if '*' in pattern else '*%s*' % pattern + pattern if "*" in pattern else "*%s*" % pattern for pattern in test_name_patterns } self.shuffle = shuffle @@ -605,73 +647,99 @@ class DiscoverRunner: @classmethod def add_arguments(cls, parser): parser.add_argument( - '-t', '--top-level-directory', dest='top_level', - help='Top level of project for unittest discovery.', + "-t", + "--top-level-directory", + dest="top_level", + help="Top level of project for unittest discovery.", ) parser.add_argument( - '-p', '--pattern', default="test*.py", - help='The test matching pattern. Defaults to test*.py.', + "-p", + "--pattern", + default="test*.py", + help="The test matching pattern. 
Defaults to test*.py.", ) parser.add_argument( - '--keepdb', action='store_true', - help='Preserves the test DB between runs.' + "--keepdb", action="store_true", help="Preserves the test DB between runs." ) parser.add_argument( - '--shuffle', nargs='?', default=False, type=int, metavar='SEED', - help='Shuffles test case order.', + "--shuffle", + nargs="?", + default=False, + type=int, + metavar="SEED", + help="Shuffles test case order.", ) parser.add_argument( - '-r', '--reverse', action='store_true', - help='Reverses test case order.', + "-r", + "--reverse", + action="store_true", + help="Reverses test case order.", ) parser.add_argument( - '--debug-mode', action='store_true', - help='Sets settings.DEBUG to True.', + "--debug-mode", + action="store_true", + help="Sets settings.DEBUG to True.", ) parser.add_argument( - '-d', '--debug-sql', action='store_true', - help='Prints logged SQL queries on failure.', + "-d", + "--debug-sql", + action="store_true", + help="Prints logged SQL queries on failure.", ) parser.add_argument( - '--parallel', nargs='?', const='auto', default=0, - type=parallel_type, metavar='N', + "--parallel", + nargs="?", + const="auto", + default=0, + type=parallel_type, + metavar="N", help=( - 'Run tests using up to N parallel processes. Use the value ' + "Run tests using up to N parallel processes. Use the value " '"auto" to run one test process for each processor core.' ), ) parser.add_argument( - '--tag', action='append', dest='tags', - help='Run only tests with the specified tag. Can be used multiple times.', + "--tag", + action="append", + dest="tags", + help="Run only tests with the specified tag. Can be used multiple times.", ) parser.add_argument( - '--exclude-tag', action='append', dest='exclude_tags', - help='Do not run tests with the specified tag. Can be used multiple times.', + "--exclude-tag", + action="append", + dest="exclude_tags", + help="Do not run tests with the specified tag. 
Can be used multiple times.", ) parser.add_argument( - '--pdb', action='store_true', - help='Runs a debugger (pdb, or ipdb if installed) on error or failure.' + "--pdb", + action="store_true", + help="Runs a debugger (pdb, or ipdb if installed) on error or failure.", ) parser.add_argument( - '-b', '--buffer', action='store_true', - help='Discard output from passing tests.', + "-b", + "--buffer", + action="store_true", + help="Discard output from passing tests.", ) parser.add_argument( - '--no-faulthandler', action='store_false', dest='enable_faulthandler', - help='Disables the Python faulthandler module during tests.', + "--no-faulthandler", + action="store_false", + dest="enable_faulthandler", + help="Disables the Python faulthandler module during tests.", ) parser.add_argument( - '--timing', action='store_true', + "--timing", + action="store_true", + help=("Output timings, including database set up and total run time."), + ) + parser.add_argument( + "-k", + action="append", + dest="test_name_patterns", help=( - 'Output timings, including database set up and total run time.' - ), - ) - parser.add_argument( - '-k', action='append', dest='test_name_patterns', - help=( - 'Only run test methods and classes that match the pattern ' - 'or substring. Can be used multiple times. Same as ' - 'unittest -k option.' + "Only run test methods and classes that match the pattern " + "or substring. Can be used multiple times. Same as " + "unittest -k option." 
), ) @@ -693,9 +761,7 @@ class DiscoverRunner: if level is None: level = logging.INFO if self.logger is None: - if self.verbosity <= 0 or ( - self.verbosity == 1 and level < logging.INFO - ): + if self.verbosity <= 0 or (self.verbosity == 1 and level < logging.INFO): return print(msg) else: @@ -709,7 +775,7 @@ class DiscoverRunner: if self.shuffle is False: return shuffler = Shuffler(seed=self.shuffle) - self.log(f'Using shuffle seed: {shuffler.seed_display}') + self.log(f"Using shuffle seed: {shuffler.seed_display}") self._shuffler = shuffler @contextmanager @@ -741,15 +807,15 @@ class DiscoverRunner: if os.path.exists(label_as_path): assert tests is None raise RuntimeError( - f'One of the test labels is a path to a file: {label!r}, ' - f'which is not supported. Use a dotted module name or ' - f'path to a directory instead.' + f"One of the test labels is a path to a file: {label!r}, " + f"which is not supported. Use a dotted module name or " + f"path to a directory instead." ) return tests kwargs = discover_kwargs.copy() if os.path.isdir(label_as_path) and not self.top_level: - kwargs['top_level_dir'] = find_top_level(label_as_path) + kwargs["top_level_dir"] = find_top_level(label_as_path) with self.load_with_patterns(): tests = self.test_loader.discover(start_dir=label, **kwargs) @@ -762,18 +828,18 @@ class DiscoverRunner: def build_suite(self, test_labels=None, extra_tests=None, **kwargs): if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) - test_labels = test_labels or ['.'] + test_labels = test_labels or ["."] extra_tests = extra_tests or [] discover_kwargs = {} if self.pattern is not None: - discover_kwargs['pattern'] = self.pattern + discover_kwargs["pattern"] = self.pattern if self.top_level is not None: - discover_kwargs['top_level_dir'] = self.top_level + discover_kwargs["top_level_dir"] = self.top_level self.setup_shuffler() 
all_tests = [] @@ -786,12 +852,12 @@ class DiscoverRunner: if self.tags or self.exclude_tags: if self.tags: self.log( - 'Including test tag(s): %s.' % ', '.join(sorted(self.tags)), + "Including test tag(s): %s." % ", ".join(sorted(self.tags)), level=logging.DEBUG, ) if self.exclude_tags: self.log( - 'Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)), + "Excluding test tag(s): %s." % ", ".join(sorted(self.exclude_tags)), level=logging.DEBUG, ) all_tests = filter_tests_by_tags(all_tests, self.tags, self.exclude_tags) @@ -800,13 +866,15 @@ class DiscoverRunner: # _FailedTest objects include things like test modules that couldn't be # found or that couldn't be loaded due to syntax errors. test_types = (unittest.loader._FailedTest, *self.reorder_by) - all_tests = list(reorder_tests( - all_tests, - test_types, - shuffler=self._shuffler, - reverse=self.reverse, - )) - self.log('Found %d test(s).' % len(all_tests)) + all_tests = list( + reorder_tests( + all_tests, + test_types, + shuffler=self._shuffler, + reverse=self.reverse, + ) + ) + self.log("Found %d test(s)." 
% len(all_tests)) suite = self.test_suite(all_tests) if self.parallel > 1: @@ -828,8 +896,13 @@ class DiscoverRunner: def setup_databases(self, **kwargs): return _setup_databases( - self.verbosity, self.interactive, time_keeper=self.time_keeper, keepdb=self.keepdb, - debug_sql=self.debug_sql, parallel=self.parallel, **kwargs + self.verbosity, + self.interactive, + time_keeper=self.time_keeper, + keepdb=self.keepdb, + debug_sql=self.debug_sql, + parallel=self.parallel, + **kwargs, ) def get_resultclass(self): @@ -840,16 +913,16 @@ class DiscoverRunner: def get_test_runner_kwargs(self): return { - 'failfast': self.failfast, - 'resultclass': self.get_resultclass(), - 'verbosity': self.verbosity, - 'buffer': self.buffer, + "failfast": self.failfast, + "resultclass": self.get_resultclass(), + "verbosity": self.verbosity, + "buffer": self.buffer, } def run_checks(self, databases): # Checks are run after database creation since some checks require # database access. - call_command('check', verbosity=self.verbosity, databases=databases) + call_command("check", verbosity=self.verbosity, databases=databases) def run_suite(self, suite, **kwargs): kwargs = self.get_test_runner_kwargs() @@ -859,7 +932,7 @@ class DiscoverRunner: finally: if self._shuffler is not None: seed_display = self._shuffler.seed_display - self.log(f'Used shuffle seed: {seed_display}') + self.log(f"Used shuffle seed: {seed_display}") def teardown_databases(self, old_config, **kwargs): """Destroy all the non-mirror databases.""" @@ -875,16 +948,18 @@ class DiscoverRunner: teardown_test_environment() def suite_result(self, suite, result, **kwargs): - return len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses) + return ( + len(result.failures) + len(result.errors) + len(result.unexpectedSuccesses) + ) def _get_databases(self, suite): databases = {} for test in iter_test_cases(suite): - test_databases = getattr(test, 'databases', None) - if test_databases == '__all__': + test_databases = 
getattr(test, "databases", None) + if test_databases == "__all__": test_databases = connections if test_databases: - serialized_rollback = getattr(test, 'serialized_rollback', False) + serialized_rollback = getattr(test, "serialized_rollback", False) databases.update( (alias, serialized_rollback or databases.get(alias, False)) for alias in test_databases @@ -896,7 +971,8 @@ class DiscoverRunner: unused_databases = [alias for alias in connections if alias not in databases] if unused_databases: self.log( - 'Skipping setup of unused database(s): %s.' % ', '.join(sorted(unused_databases)), + "Skipping setup of unused database(s): %s." + % ", ".join(sorted(unused_databases)), level=logging.DEBUG, ) return databases @@ -912,7 +988,7 @@ class DiscoverRunner: """ if extra_tests is not None: warnings.warn( - 'The extra_tests argument is deprecated.', + "The extra_tests argument is deprecated.", RemovedInDjango50Warning, stacklevel=2, ) @@ -920,10 +996,9 @@ class DiscoverRunner: suite = self.build_suite(test_labels, extra_tests) databases = self.get_databases(suite) serialized_aliases = set( - alias - for alias, serialize in databases.items() if serialize + alias for alias, serialize in databases.items() if serialize ) - with self.time_keeper.timed('Total database setup'): + with self.time_keeper.timed("Total database setup"): old_config = self.setup_databases( aliases=databases, serialized_aliases=serialized_aliases, @@ -937,7 +1012,7 @@ class DiscoverRunner: raise finally: try: - with self.time_keeper.timed('Total database teardown'): + with self.time_keeper.timed("Total database teardown"): self.teardown_databases(old_config) self.teardown_test_environment() except Exception: @@ -960,7 +1035,7 @@ def try_importing(label): except (ImportError, TypeError): return (False, False) - return (True, hasattr(mod, '__path__')) + return (True, hasattr(mod, "__path__")) def find_top_level(top_level): @@ -976,7 +1051,7 @@ def find_top_level(top_level): # top-level module or as a 
directory path, unittest unfortunately prefers # the latter. while True: - init_py = os.path.join(top_level, '__init__.py') + init_py = os.path.join(top_level, "__init__.py") if not os.path.exists(init_py): break try_next = os.path.dirname(top_level) @@ -988,7 +1063,7 @@ def find_top_level(top_level): def _class_shuffle_key(cls): - return f'{cls.__module__}.{cls.__qualname__}' + return f"{cls.__module__}.{cls.__qualname__}" def shuffle_tests(tests, shuffler): @@ -1073,9 +1148,7 @@ def partition_suite_by_case(suite): """Partition a test suite by test case, preserving the order of tests.""" suite_class = type(suite) all_tests = iter_test_cases(suite) - return [ - suite_class(tests) for _, tests in itertools.groupby(all_tests, type) - ] + return [suite_class(tests) for _, tests in itertools.groupby(all_tests, type)] def test_match_tags(test, tags, exclude_tags): @@ -1083,11 +1156,11 @@ def test_match_tags(test, tags, exclude_tags): # Tests that couldn't load always match to prevent tests from falsely # passing due e.g. to syntax errors. return True - test_tags = set(getattr(test, 'tags', [])) - test_fn_name = getattr(test, '_testMethodName', str(test)) + test_tags = set(getattr(test, "tags", [])) + test_fn_name = getattr(test, "_testMethodName", str(test)) if hasattr(test, test_fn_name): test_fn = getattr(test, test_fn_name) - test_fn_tags = list(getattr(test_fn, 'tags', [])) + test_fn_tags = list(getattr(test_fn, "tags", [])) test_tags = test_tags.union(test_fn_tags) if tags and test_tags.isdisjoint(tags): return False diff --git a/django/test/selenium.py b/django/test/selenium.py index 97a7840fea..aa714ad365 100644 --- a/django/test/selenium.py +++ b/django/test/selenium.py @@ -27,7 +27,9 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): """ test_class = super().__new__(cls, name, bases, attrs) # If the test class is either browser-specific or a test base, return it. 
- if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()): + if test_class.browser or not any( + name.startswith("test") and callable(value) for name, value in attrs.items() + ): return test_class elif test_class.browsers: # Reuse the created test class to make it browser-specific. @@ -37,7 +39,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): first_browser = test_class.browsers[0] test_class.browser = first_browser # Listen on an external interface if using a selenium hub. - host = test_class.host if not test_class.selenium_hub else '0.0.0.0' + host = test_class.host if not test_class.selenium_hub else "0.0.0.0" test_class.host = host test_class.external_host = cls.external_host # Create subclasses for each of the remaining browsers and expose @@ -49,16 +51,16 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): "%s%s" % (capfirst(browser), name), (test_class,), { - 'browser': browser, - 'host': host, - 'external_host': cls.external_host, - '__module__': test_class.__module__, - } + "browser": browser, + "host": host, + "external_host": cls.external_host, + "__module__": test_class.__module__, + }, ) setattr(module, browser_test_class.__name__, browser_test_class) return test_class # If no browsers were specified, skip this class (it'll still be discovered). 
- return unittest.skip('No browsers specified.')(test_class) + return unittest.skip("No browsers specified.")(test_class) @classmethod def import_webdriver(cls, browser): @@ -66,13 +68,12 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): @classmethod def import_options(cls, browser): - return import_string('selenium.webdriver.%s.options.Options' % browser) + return import_string("selenium.webdriver.%s.options.Options" % browser) @classmethod def get_capability(cls, browser): - from selenium.webdriver.common.desired_capabilities import ( - DesiredCapabilities, - ) + from selenium.webdriver.common.desired_capabilities import DesiredCapabilities + return getattr(DesiredCapabilities, browser.upper()) def create_options(self): @@ -87,6 +88,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): def create_webdriver(self): if self.selenium_hub: from selenium import webdriver + return webdriver.Remote( command_executor=self.selenium_hub, desired_capabilities=self.get_capability(self.browser), @@ -94,14 +96,14 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)): return self.import_webdriver(self.browser)(options=self.create_options()) -@tag('selenium') +@tag("selenium") class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase): implicit_wait = 10 external_host = None @classproperty def live_server_url(cls): - return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port) + return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port) @classproperty def allowed_host(cls): @@ -118,7 +120,7 @@ class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase): # quit() the WebDriver before attempting to terminate and join the # single-threaded LiveServerThread to avoid a dead lock if the browser # kept a connection alive. 
- if hasattr(cls, 'selenium'): + if hasattr(cls, "selenium"): cls.selenium.quit() super()._tearDownClassInternal() diff --git a/django/test/signals.py b/django/test/signals.py index c82b95013d..c874f220df 100644 --- a/django/test/signals.py +++ b/django/test/signals.py @@ -20,13 +20,14 @@ template_rendered = Signal() # except for cases where the receiver is related to a contrib app. # Settings that may not work well when using 'override_settings' (#19031) -COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'} +COMPLEX_OVERRIDE_SETTINGS = {"DATABASES"} @receiver(setting_changed) def clear_cache_handlers(*, setting, **kwargs): - if setting == 'CACHES': + if setting == "CACHES": from django.core.cache import caches, close_caches + close_caches() caches._settings = caches.settings = caches.configure_settings(None) caches._connections = Local() @@ -34,37 +35,41 @@ def clear_cache_handlers(*, setting, **kwargs): @receiver(setting_changed) def update_installed_apps(*, setting, **kwargs): - if setting == 'INSTALLED_APPS': + if setting == "INSTALLED_APPS": # Rebuild any AppDirectoriesFinder instance. from django.contrib.staticfiles.finders import get_finder + get_finder.cache_clear() # Rebuild management commands cache from django.core.management import get_commands + get_commands.cache_clear() # Rebuild get_app_template_dirs cache. from django.template.utils import get_app_template_dirs + get_app_template_dirs.cache_clear() # Rebuild translations cache. 
from django.utils.translation import trans_real + trans_real._translations = {} @receiver(setting_changed) def update_connections_time_zone(*, setting, **kwargs): - if setting == 'TIME_ZONE': + if setting == "TIME_ZONE": # Reset process time zone - if hasattr(time, 'tzset'): - if kwargs['value']: - os.environ['TZ'] = kwargs['value'] + if hasattr(time, "tzset"): + if kwargs["value"]: + os.environ["TZ"] = kwargs["value"] else: - os.environ.pop('TZ', None) + os.environ.pop("TZ", None) time.tzset() # Reset local time zone cache timezone.get_default_timezone.cache_clear() # Reset the database connections' time zone - if setting in {'TIME_ZONE', 'USE_TZ'}: + if setting in {"TIME_ZONE", "USE_TZ"}: for conn in connections.all(): try: del conn.timezone @@ -79,18 +84,19 @@ def update_connections_time_zone(*, setting, **kwargs): @receiver(setting_changed) def clear_routers_cache(*, setting, **kwargs): - if setting == 'DATABASE_ROUTERS': + if setting == "DATABASE_ROUTERS": router.routers = ConnectionRouter().routers @receiver(setting_changed) def reset_template_engines(*, setting, **kwargs): if setting in { - 'TEMPLATES', - 'DEBUG', - 'INSTALLED_APPS', + "TEMPLATES", + "DEBUG", + "INSTALLED_APPS", }: from django.template import engines + try: del engines.templates except AttributeError: @@ -98,40 +104,46 @@ def reset_template_engines(*, setting, **kwargs): engines._templates = None engines._engines = {} from django.template.engine import Engine + Engine.get_default.cache_clear() from django.forms.renderers import get_default_renderer + get_default_renderer.cache_clear() @receiver(setting_changed) def clear_serializers_cache(*, setting, **kwargs): - if setting == 'SERIALIZATION_MODULES': + if setting == "SERIALIZATION_MODULES": from django.core import serializers + serializers._serializers = {} @receiver(setting_changed) def language_changed(*, setting, **kwargs): - if setting in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}: + if setting in {"LANGUAGES", "LANGUAGE_CODE", 
"LOCALE_PATHS"}: from django.utils.translation import trans_real + trans_real._default = None trans_real._active = Local() - if setting in {'LANGUAGES', 'LOCALE_PATHS'}: + if setting in {"LANGUAGES", "LOCALE_PATHS"}: from django.utils.translation import trans_real + trans_real._translations = {} trans_real.check_for_language.cache_clear() @receiver(setting_changed) def localize_settings_changed(*, setting, **kwargs): - if setting in FORMAT_SETTINGS or setting == 'USE_THOUSAND_SEPARATOR': + if setting in FORMAT_SETTINGS or setting == "USE_THOUSAND_SEPARATOR": reset_format_cache() @receiver(setting_changed) def file_storage_changed(*, setting, **kwargs): - if setting == 'DEFAULT_FILE_STORAGE': + if setting == "DEFAULT_FILE_STORAGE": from django.core.files.storage import default_storage + default_storage._wrapped = empty @@ -141,15 +153,16 @@ def complex_setting_changed(*, enter, setting, **kwargs): # Considering the current implementation of the signals framework, # this stacklevel shows the line containing the override_settings call. 
warnings.warn( - f'Overriding setting {setting} can lead to unexpected behavior.', + f"Overriding setting {setting} can lead to unexpected behavior.", stacklevel=6, ) @receiver(setting_changed) def root_urlconf_changed(*, setting, **kwargs): - if setting == 'ROOT_URLCONF': + if setting == "ROOT_URLCONF": from django.urls import clear_url_caches, set_urlconf + clear_url_caches() set_urlconf(None) @@ -157,55 +170,64 @@ def root_urlconf_changed(*, setting, **kwargs): @receiver(setting_changed) def static_storage_changed(*, setting, **kwargs): if setting in { - 'STATICFILES_STORAGE', - 'STATIC_ROOT', - 'STATIC_URL', + "STATICFILES_STORAGE", + "STATIC_ROOT", + "STATIC_URL", }: from django.contrib.staticfiles.storage import staticfiles_storage + staticfiles_storage._wrapped = empty @receiver(setting_changed) def static_finders_changed(*, setting, **kwargs): if setting in { - 'STATICFILES_DIRS', - 'STATIC_ROOT', + "STATICFILES_DIRS", + "STATIC_ROOT", }: from django.contrib.staticfiles.finders import get_finder + get_finder.cache_clear() @receiver(setting_changed) def auth_password_validators_changed(*, setting, **kwargs): - if setting == 'AUTH_PASSWORD_VALIDATORS': + if setting == "AUTH_PASSWORD_VALIDATORS": from django.contrib.auth.password_validation import ( get_default_password_validators, ) + get_default_password_validators.cache_clear() @receiver(setting_changed) def user_model_swapped(*, setting, **kwargs): - if setting == 'AUTH_USER_MODEL': + if setting == "AUTH_USER_MODEL": apps.clear_cache() try: from django.contrib.auth import get_user_model + UserModel = get_user_model() except ImproperlyConfigured: # Some tests set an invalid AUTH_USER_MODEL. 
pass else: from django.contrib.auth import backends + backends.UserModel = UserModel from django.contrib.auth import forms + forms.UserModel = UserModel from django.contrib.auth.handlers import modwsgi + modwsgi.UserModel = UserModel from django.contrib.auth.management.commands import changepassword + changepassword.UserModel = UserModel from django.contrib.auth import views + views.UserModel = UserModel diff --git a/django/test/testcases.py b/django/test/testcases.py index d24a065790..d514d06c7a 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -15,7 +15,13 @@ from functools import wraps from unittest.suite import _DebugResult from unittest.util import safe_repr from urllib.parse import ( - parse_qsl, unquote, urlencode, urljoin, urlparse, urlsplit, urlunparse, + parse_qsl, + unquote, + urlencode, + urljoin, + urlparse, + urlsplit, + urlunparse, ) from urllib.request import url2pathname @@ -40,7 +46,10 @@ from django.test.client import AsyncClient, Client from django.test.html import HTMLParseError, parse_html from django.test.signals import template_rendered from django.test.utils import ( - CaptureQueriesContext, ContextList, compare_xml, modify_settings, + CaptureQueriesContext, + ContextList, + compare_xml, + modify_settings, override_settings, ) from django.utils.deprecation import RemovedInDjango50Warning @@ -48,8 +57,13 @@ from django.utils.functional import classproperty from django.utils.version import PY310 from django.views.static import serve -__all__ = ('TestCase', 'TransactionTestCase', - 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature') +__all__ = ( + "TestCase", + "TransactionTestCase", + "SimpleTestCase", + "skipIfDBFeature", + "skipUnlessDBFeature", +) def to_list(value): @@ -63,7 +77,7 @@ def assert_and_parse_html(self, html, user_msg, msg): try: dom = parse_html(html) except HTMLParseError as e: - standardMsg = '%s\n%s' % (msg, e) + standardMsg = "%s\n%s" % (msg, e) self.fail(self._formatMessage(user_msg, 
standardMsg)) return dom @@ -80,18 +94,22 @@ class _AssertNumQueriesContext(CaptureQueriesContext): return executed = len(self) self.test_case.assertEqual( - executed, self.num, - "%d queries executed, %d expected\nCaptured queries were:\n%s" % ( - executed, self.num, - '\n'.join( - '%d. %s' % (i, query['sql']) for i, query in enumerate(self.captured_queries, start=1) - ) - ) + executed, + self.num, + "%d queries executed, %d expected\nCaptured queries were:\n%s" + % ( + executed, + self.num, + "\n".join( + "%d. %s" % (i, query["sql"]) + for i, query in enumerate(self.captured_queries, start=1) + ), + ), ) class _AssertTemplateUsedContext: - def __init__(self, test_case, template_name, msg_prefix='', count=None): + def __init__(self, test_case, template_name, msg_prefix="", count=None): self.test_case = test_case self.template_name = template_name self.msg_prefix = msg_prefix @@ -108,7 +126,9 @@ class _AssertTemplateUsedContext: def test(self): self.test_case._assert_template_used( - self.template_name, self.rendered_template_names, self.msg_prefix, + self.template_name, + self.rendered_template_names, + self.msg_prefix, self.count, ) @@ -128,7 +148,7 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext): self.test_case.assertFalse( self.template_name in self.rendered_template_names, f"{self.msg_prefix}Template '{self.template_name}' was used " - f"unexpectedly in rendering the response" + f"unexpectedly in rendering the response", ) @@ -156,16 +176,16 @@ class SimpleTestCase(unittest.TestCase): databases = set() _disallowed_database_msg = ( - 'Database %(operation)s to %(alias)r are not allowed in SimpleTestCase ' - 'subclasses. Either subclass TestCase or TransactionTestCase to ensure ' - 'proper test isolation or add %(alias)r to %(test)s.databases to silence ' - 'this failure.' + "Database %(operation)s to %(alias)r are not allowed in SimpleTestCase " + "subclasses. 
Either subclass TestCase or TransactionTestCase to ensure " + "proper test isolation or add %(alias)r to %(test)s.databases to silence " + "this failure." ) _disallowed_connection_methods = [ - ('connect', 'connections'), - ('temporary_connection', 'connections'), - ('cursor', 'queries'), - ('chunked_cursor', 'queries'), + ("connect", "connections"), + ("temporary_connection", "connections"), + ("cursor", "queries"), + ("chunked_cursor", "queries"), ] @classmethod @@ -184,18 +204,21 @@ class SimpleTestCase(unittest.TestCase): @classmethod def _validate_databases(cls): - if cls.databases == '__all__': + if cls.databases == "__all__": return frozenset(connections) for alias in cls.databases: if alias not in connections: - message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % ( - cls.__module__, - cls.__qualname__, - alias, + message = ( + "%s.%s.databases refers to %r which is not defined in settings.DATABASES." + % ( + cls.__module__, + cls.__qualname__, + alias, + ) ) close_matches = get_close_matches(alias, list(connections)) if close_matches: - message += ' Did you mean %r?' % close_matches[0] + message += " Did you mean %r?" % close_matches[0] raise ImproperlyConfigured(message) return frozenset(cls.databases) @@ -208,9 +231,9 @@ class SimpleTestCase(unittest.TestCase): connection = connections[alias] for name, operation in cls._disallowed_connection_methods: message = cls._disallowed_database_msg % { - 'test': '%s.%s' % (cls.__module__, cls.__qualname__), - 'alias': alias, - 'operation': operation, + "test": "%s.%s" % (cls.__module__, cls.__qualname__), + "alias": alias, + "operation": operation, } method = getattr(connection, name) setattr(connection, name, _DatabaseFailure(method, message)) @@ -247,9 +270,8 @@ class SimpleTestCase(unittest.TestCase): instead of __call__() to run the test. 
""" testMethod = getattr(self, self._testMethodName) - skipped = ( - getattr(self.__class__, "__unittest_skip__", False) or - getattr(testMethod, "__unittest_skip__", False) + skipped = getattr(self.__class__, "__unittest_skip__", False) or getattr( + testMethod, "__unittest_skip__", False ) # Convert async test methods. @@ -305,9 +327,15 @@ class SimpleTestCase(unittest.TestCase): """ return modify_settings(**kwargs) - def assertRedirects(self, response, expected_url, status_code=302, - target_status_code=200, msg_prefix='', - fetch_redirect_response=True): + def assertRedirects( + self, + response, + expected_url, + status_code=302, + target_status_code=200, + msg_prefix="", + fetch_redirect_response=True, + ): """ Assert that a response redirected to a specific URL and that the redirect URL can be loaded. @@ -319,43 +347,50 @@ class SimpleTestCase(unittest.TestCase): if msg_prefix: msg_prefix += ": " - if hasattr(response, 'redirect_chain'): + if hasattr(response, "redirect_chain"): # The request was a followed redirect self.assertTrue( response.redirect_chain, - msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)" - % (response.status_code, status_code) + msg_prefix + + "Response didn't redirect as expected: Response code was %d (expected %d)" + % (response.status_code, status_code), ) self.assertEqual( - response.redirect_chain[0][1], status_code, - msg_prefix + "Initial response didn't redirect as expected: Response code was %d (expected %d)" - % (response.redirect_chain[0][1], status_code) + response.redirect_chain[0][1], + status_code, + msg_prefix + + "Initial response didn't redirect as expected: Response code was %d (expected %d)" + % (response.redirect_chain[0][1], status_code), ) url, status_code = response.redirect_chain[-1] self.assertEqual( - response.status_code, target_status_code, - msg_prefix + "Response didn't redirect as expected: Final Response code was %d (expected %d)" - % (response.status_code, 
target_status_code) + response.status_code, + target_status_code, + msg_prefix + + "Response didn't redirect as expected: Final Response code was %d (expected %d)" + % (response.status_code, target_status_code), ) else: # Not a followed redirect self.assertEqual( - response.status_code, status_code, - msg_prefix + "Response didn't redirect as expected: Response code was %d (expected %d)" - % (response.status_code, status_code) + response.status_code, + status_code, + msg_prefix + + "Response didn't redirect as expected: Response code was %d (expected %d)" + % (response.status_code, status_code), ) url = response.url scheme, netloc, path, query, fragment = urlsplit(url) # Prepend the request path to handle relative path redirects. - if not path.startswith('/'): - url = urljoin(response.request['PATH_INFO'], url) - path = urljoin(response.request['PATH_INFO'], path) + if not path.startswith("/"): + url = urljoin(response.request["PATH_INFO"], url) + path = urljoin(response.request["PATH_INFO"], path) if fetch_redirect_response: # netloc might be empty, or in cases where Django tests the @@ -375,21 +410,25 @@ class SimpleTestCase(unittest.TestCase): redirect_response = response.client.get( path, QueryDict(query), - secure=(scheme == 'https'), + secure=(scheme == "https"), **extra, ) self.assertEqual( - redirect_response.status_code, target_status_code, - msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" - % (path, redirect_response.status_code, target_status_code) + redirect_response.status_code, + target_status_code, + msg_prefix + + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" + % (path, redirect_response.status_code, target_status_code), ) self.assertURLEqual( - url, expected_url, - msg_prefix + "Response redirected to '%s', expected '%s'" % (url, expected_url) + url, + expected_url, + msg_prefix + + "Response redirected to '%s', expected '%s'" % (url, expected_url), ) - def 
assertURLEqual(self, url1, url2, msg_prefix=''): + def assertURLEqual(self, url1, url2, msg_prefix=""): """ Assert that two URLs are the same, ignoring the order of query string parameters except for parameters with the same name. @@ -397,35 +436,44 @@ class SimpleTestCase(unittest.TestCase): For example, /path/?x=1&y=2 is equal to /path/?y=2&x=1, but /path/?a=1&a=2 isn't equal to /path/?a=2&a=1. """ + def normalize(url): """Sort the URL's query string parameters.""" url = str(url) # Coerce reverse_lazy() URLs. scheme, netloc, path, params, query, fragment = urlparse(url) query_parts = sorted(parse_qsl(query)) - return urlunparse((scheme, netloc, path, params, urlencode(query_parts), fragment)) + return urlunparse( + (scheme, netloc, path, params, urlencode(query_parts), fragment) + ) self.assertEqual( - normalize(url1), normalize(url2), - msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2) + normalize(url1), + normalize(url2), + msg_prefix + "Expected '%s' to equal '%s'." % (url1, url2), ) def _assert_contains(self, response, text, status_code, msg_prefix, html): # If the response supports deferred rendering and hasn't been rendered # yet, then ensure that it does get rendered before proceeding further. 
- if hasattr(response, 'render') and callable(response.render) and not response.is_rendered: + if ( + hasattr(response, "render") + and callable(response.render) + and not response.is_rendered + ): response.render() if msg_prefix: msg_prefix += ": " self.assertEqual( - response.status_code, status_code, + response.status_code, + status_code, msg_prefix + "Couldn't retrieve content: Response code was %d" - " (expected %d)" % (response.status_code, status_code) + " (expected %d)" % (response.status_code, status_code), ) if response.streaming: - content = b''.join(response.streaming_content) + content = b"".join(response.streaming_content) else: content = response.content if not isinstance(text, bytes) or html: @@ -435,12 +483,18 @@ class SimpleTestCase(unittest.TestCase): else: text_repr = repr(text) if html: - content = assert_and_parse_html(self, content, None, "Response's content is not valid HTML:") - text = assert_and_parse_html(self, text, None, "Second argument is not valid HTML:") + content = assert_and_parse_html( + self, content, None, "Response's content is not valid HTML:" + ) + text = assert_and_parse_html( + self, text, None, "Second argument is not valid HTML:" + ) real_count = content.count(text) return (text_repr, real_count, msg_prefix) - def assertContains(self, response, text, count=None, status_code=200, msg_prefix='', html=False): + def assertContains( + self, response, text, count=None, status_code=200, msg_prefix="", html=False + ): """ Assert that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected) and that @@ -449,26 +503,37 @@ class SimpleTestCase(unittest.TestCase): if the text occurs at least once in the response. 
""" text_repr, real_count, msg_prefix = self._assert_contains( - response, text, status_code, msg_prefix, html) + response, text, status_code, msg_prefix, html + ) if count is not None: self.assertEqual( - real_count, count, - msg_prefix + "Found %d instances of %s in response (expected %d)" % (real_count, text_repr, count) + real_count, + count, + msg_prefix + + "Found %d instances of %s in response (expected %d)" + % (real_count, text_repr, count), ) else: - self.assertTrue(real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr) + self.assertTrue( + real_count != 0, msg_prefix + "Couldn't find %s in response" % text_repr + ) - def assertNotContains(self, response, text, status_code=200, msg_prefix='', html=False): + def assertNotContains( + self, response, text, status_code=200, msg_prefix="", html=False + ): """ Assert that a response indicates that some content was retrieved successfully, (i.e., the HTTP status code was as expected) and that ``text`` doesn't occur in the content of the response. """ text_repr, real_count, msg_prefix = self._assert_contains( - response, text, status_code, msg_prefix, html) + response, text, status_code, msg_prefix, html + ) - self.assertEqual(real_count, 0, msg_prefix + "Response should not contain %s" % text_repr) + self.assertEqual( + real_count, 0, msg_prefix + "Response should not contain %s" % text_repr + ) def _check_test_client_response(self, response, attribute, method_name): """ @@ -481,24 +546,26 @@ class SimpleTestCase(unittest.TestCase): "the Django test Client." ) - def assertFormError(self, response, form, field, errors, msg_prefix=''): + def assertFormError(self, response, form, field, errors, msg_prefix=""): """ Assert that a form used to render the response has a specific field error. 
""" - self._check_test_client_response(response, 'context', 'assertFormError') + self._check_test_client_response(response, "context", "assertFormError") if msg_prefix: msg_prefix += ": " # Put context(s) into a list to simplify processing. contexts = [] if response.context is None else to_list(response.context) if not contexts: - self.fail(msg_prefix + "Response did not use any contexts to render the response") + self.fail( + msg_prefix + "Response did not use any contexts to render the response" + ) if errors is None: warnings.warn( - 'Passing errors=None to assertFormError() is deprecated, use ' - 'errors=[] instead.', + "Passing errors=None to assertFormError() is deprecated, use " + "errors=[] instead.", RemovedInDjango50Warning, stacklevel=2, ) @@ -520,18 +587,20 @@ class SimpleTestCase(unittest.TestCase): err in field_errors, msg_prefix + "The field '%s' on form '%s' in" " context %d does not contain the error '%s'" - " (actual errors: %s)" % - (field, form, i, err, repr(field_errors)) + " (actual errors: %s)" + % (field, form, i, err, repr(field_errors)), ) elif field in context[form].fields: self.fail( - msg_prefix + "The field '%s' on form '%s' in context %d contains no errors" % - (field, form, i) + msg_prefix + + "The field '%s' on form '%s' in context %d contains no errors" + % (field, form, i) ) else: self.fail( - msg_prefix + "The form '%s' in context %d does not contain the field '%s'" % - (form, i, field) + msg_prefix + + "The form '%s' in context %d does not contain the field '%s'" + % (form, i, field) ) else: non_field_errors = context[form].non_field_errors() @@ -539,14 +608,17 @@ class SimpleTestCase(unittest.TestCase): err in non_field_errors, msg_prefix + "The form '%s' in context %d does not" " contain the non-field error '%s'" - " (actual errors: %s)" % - (form, i, err, non_field_errors or 'none') + " (actual errors: %s)" + % (form, i, err, non_field_errors or "none"), ) if not found_form: - self.fail(msg_prefix + "The form '%s' was not 
used to render the response" % form) + self.fail( + msg_prefix + "The form '%s' was not used to render the response" % form + ) - def assertFormsetError(self, response, formset, form_index, field, errors, - msg_prefix=''): + def assertFormsetError( + self, response, formset, form_index, field, errors, msg_prefix="" + ): """ Assert that a formset used to render the response has a specific error. @@ -556,7 +628,7 @@ class SimpleTestCase(unittest.TestCase): For non-form errors, specify ``form_index`` as None and the ``field`` as None. """ - self._check_test_client_response(response, 'context', 'assertFormsetError') + self._check_test_client_response(response, "context", "assertFormsetError") # Add punctuation to msg_prefix if msg_prefix: msg_prefix += ": " @@ -564,13 +636,15 @@ class SimpleTestCase(unittest.TestCase): # Put context(s) into a list to simplify processing. contexts = [] if response.context is None else to_list(response.context) if not contexts: - self.fail(msg_prefix + 'Response did not use any contexts to ' - 'render the response') + self.fail( + msg_prefix + "Response did not use any contexts to " + "render the response" + ) if errors is None: warnings.warn( - 'Passing errors=None to assertFormsetError() is deprecated, ' - 'use errors=[] instead.', + "Passing errors=None to assertFormsetError() is deprecated, " + "use errors=[] instead.", RemovedInDjango50Warning, stacklevel=2, ) @@ -581,7 +655,7 @@ class SimpleTestCase(unittest.TestCase): # Search all contexts for the error. 
found_formset = False for i, context in enumerate(contexts): - if formset not in context or not hasattr(context[formset], 'forms'): + if formset not in context or not hasattr(context[formset], "forms"): continue found_formset = True for err in errors: @@ -592,60 +666,68 @@ class SimpleTestCase(unittest.TestCase): err in field_errors, msg_prefix + "The field '%s' on formset '%s', " "form %d in context %d does not contain the " - "error '%s' (actual errors: %s)" % - (field, formset, form_index, i, err, repr(field_errors)) + "error '%s' (actual errors: %s)" + % (field, formset, form_index, i, err, repr(field_errors)), ) elif field in context[formset].forms[form_index].fields: self.fail( - msg_prefix + "The field '%s' on formset '%s', form %d in context %d contains no errors" + msg_prefix + + "The field '%s' on formset '%s', form %d in context %d contains no errors" % (field, formset, form_index, i) ) else: self.fail( - msg_prefix + "The formset '%s', form %d in context %d does not contain the field '%s'" + msg_prefix + + "The formset '%s', form %d in context %d does not contain the field '%s'" % (formset, form_index, i, field) ) elif form_index is not None: - non_field_errors = context[formset].forms[form_index].non_field_errors() + non_field_errors = ( + context[formset].forms[form_index].non_field_errors() + ) self.assertFalse( not non_field_errors, msg_prefix + "The formset '%s', form %d in context %d " - "does not contain any non-field errors." % (formset, form_index, i) + "does not contain any non-field errors." 
+ % (formset, form_index, i), ) self.assertTrue( err in non_field_errors, msg_prefix + "The formset '%s', form %d in context %d " "does not contain the non-field error '%s' (actual errors: %s)" - % (formset, form_index, i, err, repr(non_field_errors)) + % (formset, form_index, i, err, repr(non_field_errors)), ) else: non_form_errors = context[formset].non_form_errors() self.assertFalse( not non_form_errors, msg_prefix + "The formset '%s' in context %d does not " - "contain any non-form errors." % (formset, i) + "contain any non-form errors." % (formset, i), ) self.assertTrue( err in non_form_errors, msg_prefix + "The formset '%s' in context %d does not " "contain the non-form error '%s' (actual errors: %s)" - % (formset, i, err, repr(non_form_errors)) + % (formset, i, err, repr(non_form_errors)), ) if not found_formset: - self.fail(msg_prefix + "The formset '%s' was not used to render the response" % formset) + self.fail( + msg_prefix + + "The formset '%s' was not used to render the response" % formset + ) def _get_template_used(self, response, template_name, msg_prefix, method_name): if response is None and template_name is None: - raise TypeError('response and/or template_name argument must be provided') + raise TypeError("response and/or template_name argument must be provided") if msg_prefix: msg_prefix += ": " if template_name is not None and response is not None: - self._check_test_client_response(response, 'templates', method_name) + self._check_test_client_response(response, "templates", method_name) - if not hasattr(response, 'templates') or (response is None and template_name): + if not hasattr(response, "templates") or (response is None and template_name): if response: template_name = response response = None @@ -662,38 +744,49 @@ class SimpleTestCase(unittest.TestCase): template_name in template_names, msg_prefix + "Template '%s' was not a template used to render" " the response. 
Actual template(s) used: %s" - % (template_name, ', '.join(template_names)) + % (template_name, ", ".join(template_names)), ) if count is not None: self.assertEqual( - template_names.count(template_name), count, + template_names.count(template_name), + count, msg_prefix + "Template '%s' was expected to be rendered %d " "time(s) but was actually rendered %d time(s)." - % (template_name, count, template_names.count(template_name)) + % (template_name, count, template_names.count(template_name)), ) - def assertTemplateUsed(self, response=None, template_name=None, msg_prefix='', count=None): + def assertTemplateUsed( + self, response=None, template_name=None, msg_prefix="", count=None + ): """ Assert that the template with the provided name was used in rendering the response. Also usable as context manager. """ context_mgr_template, template_names, msg_prefix = self._get_template_used( - response, template_name, msg_prefix, 'assertTemplateUsed', + response, + template_name, + msg_prefix, + "assertTemplateUsed", ) if context_mgr_template: # Use assertTemplateUsed as context manager. - return _AssertTemplateUsedContext(self, context_mgr_template, msg_prefix, count) + return _AssertTemplateUsedContext( + self, context_mgr_template, msg_prefix, count + ) self._assert_template_used(template_name, template_names, msg_prefix, count) - def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=''): + def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=""): """ Assert that the template with the provided name was NOT used in rendering the response. Also usable as context manager. """ context_mgr_template, template_names, msg_prefix = self._get_template_used( - response, template_name, msg_prefix, 'assertTemplateNotUsed', + response, + template_name, + msg_prefix, + "assertTemplateNotUsed", ) if context_mgr_template: # Use assertTemplateNotUsed as context manager. 
@@ -701,20 +794,28 @@ class SimpleTestCase(unittest.TestCase): self.assertFalse( template_name in template_names, - msg_prefix + "Template '%s' was used unexpectedly in rendering the response" % template_name + msg_prefix + + "Template '%s' was used unexpectedly in rendering the response" + % template_name, ) @contextmanager - def _assert_raises_or_warns_cm(self, func, cm_attr, expected_exception, expected_message): + def _assert_raises_or_warns_cm( + self, func, cm_attr, expected_exception, expected_message + ): with func(expected_exception) as cm: yield cm self.assertIn(expected_message, str(getattr(cm, cm_attr))) - def _assertFooMessage(self, func, cm_attr, expected_exception, expected_message, *args, **kwargs): + def _assertFooMessage( + self, func, cm_attr, expected_exception, expected_message, *args, **kwargs + ): callable_obj = None if args: callable_obj, *args = args - cm = self._assert_raises_or_warns_cm(func, cm_attr, expected_exception, expected_message) + cm = self._assert_raises_or_warns_cm( + func, cm_attr, expected_exception, expected_message + ) # Assertion used in context manager fashion. if callable_obj is None: return cm @@ -722,7 +823,9 @@ class SimpleTestCase(unittest.TestCase): with cm: callable_obj(*args, **kwargs) - def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs): + def assertRaisesMessage( + self, expected_exception, expected_message, *args, **kwargs + ): """ Assert that expected_message is found in the message of a raised exception. @@ -734,8 +837,12 @@ class SimpleTestCase(unittest.TestCase): kwargs: Extra kwargs. """ return self._assertFooMessage( - self.assertRaises, 'exception', expected_exception, expected_message, - *args, **kwargs + self.assertRaises, + "exception", + expected_exception, + expected_message, + *args, + **kwargs, ) def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs): @@ -744,12 +851,17 @@ class SimpleTestCase(unittest.TestCase): assertRaises(). 
""" return self._assertFooMessage( - self.assertWarns, 'warning', expected_warning, expected_message, - *args, **kwargs + self.assertWarns, + "warning", + expected_warning, + expected_message, + *args, + **kwargs, ) # A similar method is available in Python 3.10+. if not PY310: + @contextmanager def assertNoLogs(self, logger, level=None): """ @@ -759,20 +871,29 @@ class SimpleTestCase(unittest.TestCase): if isinstance(level, int): level = logging.getLevelName(level) elif level is None: - level = 'INFO' + level = "INFO" try: with self.assertLogs(logger, level) as cm: yield except AssertionError as e: msg = e.args[0] - expected_msg = f'no logs of level {level} or higher triggered on {logger}' + expected_msg = ( + f"no logs of level {level} or higher triggered on {logger}" + ) if msg != expected_msg: raise e else: - self.fail(f'Unexpected logs found: {cm.output!r}') + self.fail(f"Unexpected logs found: {cm.output!r}") - def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None, - field_kwargs=None, empty_value=''): + def assertFieldOutput( + self, + fieldclass, + valid, + invalid, + field_args=None, + field_kwargs=None, + empty_value="", + ): """ Assert that a form field behaves correctly with various inputs. 
@@ -791,7 +912,7 @@ class SimpleTestCase(unittest.TestCase): if field_kwargs is None: field_kwargs = {} required = fieldclass(*field_args, **field_kwargs) - optional = fieldclass(*field_args, **{**field_kwargs, 'required': False}) + optional = fieldclass(*field_args, **{**field_kwargs, "required": False}) # test valid inputs for input, output in valid.items(): self.assertEqual(required.clean(input), output) @@ -806,7 +927,7 @@ class SimpleTestCase(unittest.TestCase): optional.clean(input) self.assertEqual(context_manager.exception.messages, errors) # test required inputs - error_required = [required.error_messages['required']] + error_required = [required.error_messages["required"]] for e in required.empty_values: with self.assertRaises(ValidationError) as context_manager: required.clean(e) @@ -814,7 +935,7 @@ class SimpleTestCase(unittest.TestCase): self.assertEqual(optional.clean(e), empty_value) # test that max_length and min_length are always accepted if issubclass(fieldclass, CharField): - field_kwargs.update({'min_length': 2, 'max_length': 20}) + field_kwargs.update({"min_length": 2, "max_length": 20}) self.assertIsInstance(fieldclass(*field_args, **field_kwargs), fieldclass) def assertHTMLEqual(self, html1, html2, msg=None): @@ -823,39 +944,57 @@ class SimpleTestCase(unittest.TestCase): Whitespace in most cases is ignored, and attribute ordering is not significant. The arguments must be valid HTML. 
""" - dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:') - dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:') + dom1 = assert_and_parse_html( + self, html1, msg, "First argument is not valid HTML:" + ) + dom2 = assert_and_parse_html( + self, html2, msg, "Second argument is not valid HTML:" + ) if dom1 != dom2: - standardMsg = '%s != %s' % ( - safe_repr(dom1, True), safe_repr(dom2, True)) - diff = ('\n' + '\n'.join(difflib.ndiff( - str(dom1).splitlines(), str(dom2).splitlines(), - ))) + standardMsg = "%s != %s" % (safe_repr(dom1, True), safe_repr(dom2, True)) + diff = "\n" + "\n".join( + difflib.ndiff( + str(dom1).splitlines(), + str(dom2).splitlines(), + ) + ) standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) def assertHTMLNotEqual(self, html1, html2, msg=None): """Assert that two HTML snippets are not semantically equivalent.""" - dom1 = assert_and_parse_html(self, html1, msg, 'First argument is not valid HTML:') - dom2 = assert_and_parse_html(self, html2, msg, 'Second argument is not valid HTML:') + dom1 = assert_and_parse_html( + self, html1, msg, "First argument is not valid HTML:" + ) + dom2 = assert_and_parse_html( + self, html2, msg, "Second argument is not valid HTML:" + ) if dom1 == dom2: - standardMsg = '%s == %s' % ( - safe_repr(dom1, True), safe_repr(dom2, True)) + standardMsg = "%s == %s" % (safe_repr(dom1, True), safe_repr(dom2, True)) self.fail(self._formatMessage(msg, standardMsg)) - def assertInHTML(self, needle, haystack, count=None, msg_prefix=''): - needle = assert_and_parse_html(self, needle, None, 'First argument is not valid HTML:') - haystack = assert_and_parse_html(self, haystack, None, 'Second argument is not valid HTML:') + def assertInHTML(self, needle, haystack, count=None, msg_prefix=""): + needle = assert_and_parse_html( + self, needle, None, "First argument is not valid HTML:" + ) + haystack = 
assert_and_parse_html( + self, haystack, None, "Second argument is not valid HTML:" + ) real_count = haystack.count(needle) if count is not None: self.assertEqual( - real_count, count, - msg_prefix + "Found %d instances of '%s' in response (expected %d)" % (real_count, needle, count) + real_count, + count, + msg_prefix + + "Found %d instances of '%s' in response (expected %d)" + % (real_count, needle, count), ) else: - self.assertTrue(real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle) + self.assertTrue( + real_count != 0, msg_prefix + "Couldn't find '%s' in response" % needle + ) def assertJSONEqual(self, raw, expected_data, msg=None): """ @@ -900,14 +1039,17 @@ class SimpleTestCase(unittest.TestCase): try: result = compare_xml(xml1, xml2) except Exception as e: - standardMsg = 'First or second argument is not valid XML\n%s' % e + standardMsg = "First or second argument is not valid XML\n%s" % e self.fail(self._formatMessage(msg, standardMsg)) else: if not result: - standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True)) - diff = ('\n' + '\n'.join( + standardMsg = "%s != %s" % ( + safe_repr(xml1, True), + safe_repr(xml2, True), + ) + diff = "\n" + "\n".join( difflib.ndiff(xml1.splitlines(), xml2.splitlines()) - )) + ) standardMsg = self._truncateMessage(standardMsg, diff) self.fail(self._formatMessage(msg, standardMsg)) @@ -920,11 +1062,14 @@ class SimpleTestCase(unittest.TestCase): try: result = compare_xml(xml1, xml2) except Exception as e: - standardMsg = 'First or second argument is not valid XML\n%s' % e + standardMsg = "First or second argument is not valid XML\n%s" % e self.fail(self._formatMessage(msg, standardMsg)) else: if result: - standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True)) + standardMsg = "%s == %s" % ( + safe_repr(xml1, True), + safe_repr(xml2, True), + ) self.fail(self._formatMessage(msg, standardMsg)) @@ -942,9 +1087,9 @@ class TransactionTestCase(SimpleTestCase): databases = 
{DEFAULT_DB_ALIAS} _disallowed_database_msg = ( - 'Database %(operation)s to %(alias)r are not allowed in this test. ' - 'Add %(alias)r to %(test)s.databases to ensure proper test isolation ' - 'and silence this failure.' + "Database %(operation)s to %(alias)r are not allowed in this test. " + "Add %(alias)r to %(test)s.databases to ensure proper test isolation " + "and silence this failure." ) # If transactions aren't available, Django will serialize the database @@ -966,7 +1111,7 @@ class TransactionTestCase(SimpleTestCase): apps.set_available_apps(self.available_apps) setting_changed.send( sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', + setting="INSTALLED_APPS", value=self.available_apps, enter=True, ) @@ -979,7 +1124,7 @@ class TransactionTestCase(SimpleTestCase): apps.unset_available_apps() setting_changed.send( sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', + setting="INSTALLED_APPS", value=settings.INSTALLED_APPS, enter=False, ) @@ -994,9 +1139,12 @@ class TransactionTestCase(SimpleTestCase): def _databases_names(cls, include_mirrors=True): # Only consider allowed database aliases, including mirrors or not. return [ - alias for alias in connections - if alias in cls.databases and ( - include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR'] + alias + for alias in connections + if alias in cls.databases + and ( + include_mirrors + or not connections[alias].settings_dict["TEST"]["MIRROR"] ) ] @@ -1004,7 +1152,8 @@ class TransactionTestCase(SimpleTestCase): conn = connections[db_name] if conn.features.supports_sequence_reset: sql_list = conn.ops.sequence_reset_by_name_sql( - no_style(), conn.introspection.sequence_list()) + no_style(), conn.introspection.sequence_list() + ) if sql_list: with transaction.atomic(using=db_name): with conn.cursor() as cursor: @@ -1018,7 +1167,9 @@ class TransactionTestCase(SimpleTestCase): self._reset_sequences(db_name) # Provide replica initial data from migrated apps, if needed. 
- if self.serialized_rollback and hasattr(connections[db_name], "_test_serialized_contents"): + if self.serialized_rollback and hasattr( + connections[db_name], "_test_serialized_contents" + ): if self.available_apps is not None: apps.unset_available_apps() connections[db_name].creation.deserialize_db_from_string( @@ -1030,8 +1181,9 @@ class TransactionTestCase(SimpleTestCase): if self.fixtures: # We have to use this slightly awkward syntax due to the fact # that we're using *args and **kwargs together. - call_command('loaddata', *self.fixtures, - **{'verbosity': 0, 'database': db_name}) + call_command( + "loaddata", *self.fixtures, **{"verbosity": 0, "database": db_name} + ) def _should_reload_connections(self): return True @@ -1058,10 +1210,12 @@ class TransactionTestCase(SimpleTestCase): finally: if self.available_apps is not None: apps.unset_available_apps() - setting_changed.send(sender=settings._wrapped.__class__, - setting='INSTALLED_APPS', - value=settings.INSTALLED_APPS, - enter=False) + setting_changed.send( + sender=settings._wrapped.__class__, + setting="INSTALLED_APPS", + value=settings.INSTALLED_APPS, + enter=False, + ) def _fixture_teardown(self): # Allow TRUNCATE ... CASCADE and don't emit the post_migrate signal @@ -1069,17 +1223,22 @@ class TransactionTestCase(SimpleTestCase): for db_name in self._databases_names(include_mirrors=False): # Flush the database inhibit_post_migrate = ( - self.available_apps is not None or - ( # Inhibit the post_migrate signal when using serialized + self.available_apps is not None + or ( # Inhibit the post_migrate signal when using serialized # rollback to avoid trying to recreate the serialized data. 
- self.serialized_rollback and - hasattr(connections[db_name], '_test_serialized_contents') + self.serialized_rollback + and hasattr(connections[db_name], "_test_serialized_contents") ) ) - call_command('flush', verbosity=0, interactive=False, - database=db_name, reset_sequences=False, - allow_cascade=self.available_apps is not None, - inhibit_post_migrate=inhibit_post_migrate) + call_command( + "flush", + verbosity=0, + interactive=False, + database=db_name, + reset_sequences=False, + allow_cascade=self.available_apps is not None, + inhibit_post_migrate=inhibit_post_migrate, + ) def assertQuerysetEqual(self, qs, values, transform=None, ordered=True, msg=None): values = list(values) @@ -1090,10 +1249,10 @@ class TransactionTestCase(SimpleTestCase): return self.assertDictEqual(Counter(items), Counter(values), msg=msg) # For example qs.iterator() could be passed as qs, but it does not # have 'ordered' attribute. - if len(values) > 1 and hasattr(qs, 'ordered') and not qs.ordered: + if len(values) > 1 and hasattr(qs, "ordered") and not qs.ordered: raise ValueError( - 'Trying to compare non-ordered queryset against more than one ' - 'ordered value.' + "Trying to compare non-ordered queryset against more than one " + "ordered value." ) return self.assertEqual(list(items), values, msg=msg) @@ -1113,7 +1272,11 @@ def connections_support_transactions(aliases=None): Return whether or not all (or specified) connections support transactions. """ - conns = connections.all() if aliases is None else (connections[alias] for alias in aliases) + conns = ( + connections.all() + if aliases is None + else (connections[alias] for alias in aliases) + ) return all(conn.features.supports_transactions for conn in conns) @@ -1128,7 +1291,8 @@ class TestData: Objects are deep copied using a memo kept on the test case instance in order to maintain their original relationships. 
""" - memo_attr = '_testdata_memo' + + memo_attr = "_testdata_memo" def __init__(self, name, data): self.name = name @@ -1151,7 +1315,7 @@ class TestData: return data def __repr__(self): - return '<TestData: name=%r, data=%r>' % (self.name, self.data) + return "<TestData: name=%r, data=%r>" % (self.name, self.data) class TestCase(TransactionTestCase): @@ -1167,6 +1331,7 @@ class TestCase(TransactionTestCase): On database backends with no transaction support, TestCase behaves as TransactionTestCase. """ + @classmethod def _enter_atomics(cls): """Open atomic blocks for multiple databases.""" @@ -1199,7 +1364,11 @@ class TestCase(TransactionTestCase): if cls.fixtures: for db_name in cls._databases_names(include_mirrors=False): try: - call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name}) + call_command( + "loaddata", + *cls.fixtures, + **{"verbosity": 0, "database": db_name}, + ) except Exception: cls._rollback_atomics(cls.cls_atomics) raise @@ -1239,7 +1408,7 @@ class TestCase(TransactionTestCase): return super()._fixture_setup() if self.reset_sequences: - raise TypeError('reset_sequences cannot be used on TestCase instances') + raise TypeError("reset_sequences cannot be used on TestCase instances") self.atomics = self._enter_atomics() def _fixture_teardown(self): @@ -1254,8 +1423,9 @@ class TestCase(TransactionTestCase): def _should_check_constraints(self, connection): return ( - connection.features.can_defer_constraint_checks and - not connection.needs_rollback and connection.is_usable() + connection.features.can_defer_constraint_checks + and not connection.needs_rollback + and connection.is_usable() ) @classmethod @@ -1281,6 +1451,7 @@ class TestCase(TransactionTestCase): class CheckCondition: """Descriptor class for deferred condition checking.""" + def __init__(self, *conditions): self.conditions = conditions @@ -1289,7 +1460,7 @@ class CheckCondition: def __get__(self, instance, cls=None): # Trigger access for all bases. 
- if any(getattr(base, '__unittest_skip__', False) for base in cls.__bases__): + if any(getattr(base, "__unittest_skip__", False) for base in cls.__bases__): return True for condition, reason in self.conditions: if condition(): @@ -1303,15 +1474,21 @@ class CheckCondition: def _deferredSkip(condition, reason, name): def decorator(test_func): nonlocal condition - if not (isinstance(test_func, type) and - issubclass(test_func, unittest.TestCase)): + if not ( + isinstance(test_func, type) and issubclass(test_func, unittest.TestCase) + ): + @wraps(test_func) def skip_wrapper(*args, **kwargs): - if (args and isinstance(args[0], unittest.TestCase) and - connection.alias not in getattr(args[0], 'databases', {})): + if ( + args + and isinstance(args[0], unittest.TestCase) + and connection.alias not in getattr(args[0], "databases", {}) + ): raise ValueError( "%s cannot be used on %s as %s doesn't allow queries " - "against the %r database." % ( + "against the %r database." + % ( name, args[0], args[0].__class__.__qualname__, @@ -1321,55 +1498,67 @@ def _deferredSkip(condition, reason, name): if condition(): raise unittest.SkipTest(reason) return test_func(*args, **kwargs) + test_item = skip_wrapper else: # Assume a class is decorated test_item = test_func - databases = getattr(test_item, 'databases', None) + databases = getattr(test_item, "databases", None) if not databases or connection.alias not in databases: # Defer raising to allow importing test class's module. def condition(): raise ValueError( "%s cannot be used on %s as it doesn't allow queries " - "against the '%s' database." % ( - name, test_item, connection.alias, + "against the '%s' database." + % ( + name, + test_item, + connection.alias, ) ) + # Retrieve the possibly existing value from the class's dict to # avoid triggering the descriptor. 
- skip = test_func.__dict__.get('__unittest_skip__') + skip = test_func.__dict__.get("__unittest_skip__") if isinstance(skip, CheckCondition): test_item.__unittest_skip__ = skip.add_condition(condition, reason) elif skip is not True: test_item.__unittest_skip__ = CheckCondition((condition, reason)) return test_item + return decorator def skipIfDBFeature(*features): """Skip a test if a database has at least one of the named features.""" return _deferredSkip( - lambda: any(getattr(connection.features, feature, False) for feature in features), + lambda: any( + getattr(connection.features, feature, False) for feature in features + ), "Database has feature(s) %s" % ", ".join(features), - 'skipIfDBFeature', + "skipIfDBFeature", ) def skipUnlessDBFeature(*features): """Skip a test unless a database has all the named features.""" return _deferredSkip( - lambda: not all(getattr(connection.features, feature, False) for feature in features), + lambda: not all( + getattr(connection.features, feature, False) for feature in features + ), "Database doesn't support feature(s): %s" % ", ".join(features), - 'skipUnlessDBFeature', + "skipUnlessDBFeature", ) def skipUnlessAnyDBFeature(*features): """Skip a test unless a database has any of the named features.""" return _deferredSkip( - lambda: not any(getattr(connection.features, feature, False) for feature in features), + lambda: not any( + getattr(connection.features, feature, False) for feature in features + ), "Database doesn't support any of the feature(s): %s" % ", ".join(features), - 'skipUnlessAnyDBFeature', + "skipUnlessAnyDBFeature", ) @@ -1378,6 +1567,7 @@ class QuietWSGIRequestHandler(WSGIRequestHandler): A WSGIRequestHandler that doesn't log to standard output any of the requests received, so as to not clutter the test result output. 
""" + def log_message(*args): pass @@ -1387,6 +1577,7 @@ class FSFilesHandler(WSGIHandler): WSGI middleware that intercepts calls to a directory, as defined by one of the *_ROOT settings, and serves those files, publishing them under *_URL. """ + def __init__(self, application): self.application = application self.base_url = urlparse(self.get_base_url()) @@ -1402,7 +1593,7 @@ class FSFilesHandler(WSGIHandler): def file_path(self, url): """Return the relative path to the file on disk for the given URL.""" - relative_url = url[len(self.base_url[2]):] + relative_url = url[len(self.base_url[2]) :] return url2pathname(relative_url) def get_response(self, request): @@ -1421,7 +1612,7 @@ class FSFilesHandler(WSGIHandler): # Emulate behavior of django.contrib.staticfiles.views.serve() when it # invokes staticfiles' finders functionality. # TODO: Modify if/when that internal API is refactored - final_rel_path = os_rel_path.replace('\\', '/').lstrip('/') + final_rel_path = os_rel_path.replace("\\", "/").lstrip("/") return serve(request, final_rel_path, document_root=self.get_base_dir()) def __call__(self, environ, start_response): @@ -1435,6 +1626,7 @@ class _StaticFilesHandler(FSFilesHandler): Handler for serving static files. A private class that is meant to be used solely as a convenience by LiveServerThread. """ + def get_base_dir(self): return settings.STATIC_ROOT @@ -1447,6 +1639,7 @@ class _MediaFilesHandler(FSFilesHandler): Handler for serving the media files. A private class that is meant to be used solely as a convenience by LiveServerThread. 
""" + def get_base_dir(self): return settings.MEDIA_ROOT @@ -1503,7 +1696,7 @@ class LiveServerThread(threading.Thread): ) def terminate(self): - if hasattr(self, 'httpd'): + if hasattr(self, "httpd"): # Stop the WSGI server self.httpd.shutdown() self.httpd.server_close() @@ -1521,14 +1714,15 @@ class LiveServerTestCase(TransactionTestCase): and each thread needs to commit all their transactions so that the other thread can see the changes. """ - host = 'localhost' + + host = "localhost" port = 0 server_thread_class = LiveServerThread static_handler = _StaticFilesHandler @classproperty def live_server_url(cls): - return 'http://%s:%s' % (cls.host, cls.server_thread.port) + return "http://%s:%s" % (cls.host, cls.server_thread.port) @classproperty def allowed_host(cls): @@ -1540,7 +1734,7 @@ class LiveServerTestCase(TransactionTestCase): for conn in connections.all(): # If using in-memory sqlite databases, pass the connections to # the server thread. - if conn.vendor == 'sqlite' and conn.is_in_memory_db(): + if conn.vendor == "sqlite" and conn.is_in_memory_db(): connections_override[conn.alias] = conn return connections_override @@ -1548,7 +1742,7 @@ class LiveServerTestCase(TransactionTestCase): def setUpClass(cls): super().setUpClass() cls._live_server_modified_settings = modify_settings( - ALLOWED_HOSTS={'append': cls.allowed_host}, + ALLOWED_HOSTS={"append": cls.allowed_host}, ) cls._live_server_modified_settings.enable() cls.addClassCleanup(cls._live_server_modified_settings.disable) @@ -1598,6 +1792,7 @@ class SerializeMixin: Place it early in the MRO in order to isolate setUpClass()/tearDownClass(). """ + lockfile = None def __init_subclass__(cls, /, **kwargs): @@ -1605,7 +1800,8 @@ class SerializeMixin: if cls.lockfile is None: raise ValueError( "{}.lockfile isn't set. 
Set it to a unique value " - "in the base class.".format(cls.__name__)) + "in the base class.".format(cls.__name__) + ) @classmethod def setUpClass(cls): diff --git a/django/test/utils.py b/django/test/utils.py index 6c2f566909..ac0fc34b08 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -35,15 +35,24 @@ except ImportError: __all__ = ( - 'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner', - 'CaptureQueriesContext', - 'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings', - 'override_system_checks', 'tag', - 'requires_tz_support', - 'setup_databases', 'setup_test_environment', 'teardown_test_environment', + "Approximate", + "ContextList", + "isolate_lru_cache", + "get_runner", + "CaptureQueriesContext", + "ignore_warnings", + "isolate_apps", + "modify_settings", + "override_settings", + "override_system_checks", + "tag", + "requires_tz_support", + "setup_databases", + "setup_test_environment", + "teardown_test_environment", ) -TZ_SUPPORT = hasattr(time, 'tzset') +TZ_SUPPORT = hasattr(time, "tzset") class Approximate: @@ -63,6 +72,7 @@ class ContextList(list): A wrapper that provides direct key access to context items contained in a list of context objects. """ + def __getitem__(self, key): if isinstance(key, str): for subcontext in self: @@ -110,7 +120,7 @@ def setup_test_environment(debug=None): Perform global pre-test setup, such as installing the instrumented template renderer and setting the email backend to the locmem email backend. """ - if hasattr(_TestState, 'saved_data'): + if hasattr(_TestState, "saved_data"): # Executing this function twice would overwrite the saved values. raise RuntimeError( "setup_test_environment() was already called and can't be called " @@ -125,13 +135,13 @@ def setup_test_environment(debug=None): saved_data.allowed_hosts = settings.ALLOWED_HOSTS # Add the default host of the test client. 
- settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver'] + settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"] saved_data.debug = settings.DEBUG settings.DEBUG = debug saved_data.email_backend = settings.EMAIL_BACKEND - settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend' + settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend" saved_data.template_render = Template._render Template._render = instrumented_test_render @@ -191,18 +201,17 @@ def setup_databases( # replace with: # serialize_alias = serialized_aliases is None or alias in serialized_aliases try: - serialize_alias = connection.settings_dict['TEST']['SERIALIZE'] + serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"] except KeyError: serialize_alias = ( - serialized_aliases is None or - alias in serialized_aliases + serialized_aliases is None or alias in serialized_aliases ) else: warnings.warn( - 'The SERIALIZE test database setting is ' - 'deprecated as it can be inferred from the ' - 'TestCase/TransactionTestCase.databases that ' - 'enable the serialized_rollback feature.', + "The SERIALIZE test database setting is " + "deprecated as it can be inferred from the " + "TestCase/TransactionTestCase.databases that " + "enable the serialized_rollback feature.", category=RemovedInDjango50Warning, ) connection.creation.create_test_db( @@ -221,12 +230,15 @@ def setup_databases( ) # Configure all other connections as mirrors of the first one else: - connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict) + connections[alias].creation.set_as_test_mirror( + connections[first_alias].settings_dict + ) # Configure the test mirrors. 
for alias, mirror_alias in mirrored_aliases.items(): connections[alias].creation.set_as_test_mirror( - connections[mirror_alias].settings_dict) + connections[mirror_alias].settings_dict + ) if debug_sql: for alias in connections: @@ -246,8 +258,8 @@ def iter_test_cases(tests): # Prevent an unfriendly RecursionError that can happen with # strings. raise TypeError( - f'Test {test!r} must be a test case or test suite not string ' - f'(was found in {tests!r}).' + f"Test {test!r} must be a test case or test suite not string " + f"(was found in {tests!r})." ) if isinstance(test, TestCase): yield test @@ -319,18 +331,18 @@ def get_unique_databases_and_mirrors(aliases=None): for alias in connections: connection = connections[alias] - test_settings = connection.settings_dict['TEST'] + test_settings = connection.settings_dict["TEST"] - if test_settings['MIRROR']: + if test_settings["MIRROR"]: # If the database is marked as a test mirror, save the alias. - mirrored_aliases[alias] = test_settings['MIRROR'] + mirrored_aliases[alias] = test_settings["MIRROR"] elif alias in aliases: # Store a tuple with DB parameters that uniquely identify it. # If we have two aliases with the same values for that tuple, # we only need to create the test database once. item = test_databases.setdefault( connection.creation.test_db_signature(), - (connection.settings_dict['NAME'], []), + (connection.settings_dict["NAME"], []), ) # The default database must be the first because data migrations # use the default alias by default. 
@@ -339,11 +351,16 @@ def get_unique_databases_and_mirrors(aliases=None): else: item[1].append(alias) - if 'DEPENDENCIES' in test_settings: - dependencies[alias] = test_settings['DEPENDENCIES'] + if "DEPENDENCIES" in test_settings: + dependencies[alias] = test_settings["DEPENDENCIES"] else: - if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig: - dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS]) + if ( + alias != DEFAULT_DB_ALIAS + and connection.creation.test_db_signature() != default_sig + ): + dependencies[alias] = test_settings.get( + "DEPENDENCIES", [DEFAULT_DB_ALIAS] + ) test_databases = dict(dependency_ordered(test_databases.items(), dependencies)) return test_databases, mirrored_aliases @@ -365,12 +382,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False): def get_runner(settings, test_runner_class=None): test_runner_class = test_runner_class or settings.TEST_RUNNER - test_path = test_runner_class.split('.') + test_path = test_runner_class.split(".") # Allow for relative paths if len(test_path) > 1: - test_module_name = '.'.join(test_path[:-1]) + test_module_name = ".".join(test_path[:-1]) else: - test_module_name = '.' + test_module_name = "." test_module = __import__(test_module_name, {}, {}, test_path[-1]) return getattr(test_module, test_path[-1]) @@ -387,6 +404,7 @@ class TestContextDecorator: `kwarg_name`: keyword argument passing the return value of enable() if used as a function decorator. 
""" + def __init__(self, attr_name=None, kwarg_name=None): self.attr_name = attr_name self.kwarg_name = kwarg_name @@ -416,7 +434,7 @@ class TestContextDecorator: cls.setUp = setUp return cls - raise TypeError('Can only decorate subclasses of unittest.TestCase') + raise TypeError("Can only decorate subclasses of unittest.TestCase") def decorate_callable(self, func): if asyncio.iscoroutinefunction(func): @@ -428,13 +446,16 @@ class TestContextDecorator: if self.kwarg_name: kwargs[self.kwarg_name] = context return await func(*args, **kwargs) + else: + @wraps(func) def inner(*args, **kwargs): with self as context: if self.kwarg_name: kwargs[self.kwarg_name] = context return func(*args, **kwargs) + return inner def __call__(self, decorated): @@ -442,7 +463,7 @@ class TestContextDecorator: return self.decorate_class(decorated) elif callable(decorated): return self.decorate_callable(decorated) - raise TypeError('Cannot decorate object of type %s' % type(decorated)) + raise TypeError("Cannot decorate object of type %s" % type(decorated)) class override_settings(TestContextDecorator): @@ -452,6 +473,7 @@ class override_settings(TestContextDecorator): with the ``with`` statement. In either event, entering/exiting are called before and after, respectively, the function/block is executed. """ + enable_exception = None def __init__(self, **kwargs): @@ -461,9 +483,9 @@ class override_settings(TestContextDecorator): def enable(self): # Keep this code at the beginning to leave the settings unchanged # in case it raises an exception because INSTALLED_APPS is invalid. 
- if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: try: - apps.set_installed_apps(self.options['INSTALLED_APPS']) + apps.set_installed_apps(self.options["INSTALLED_APPS"]) except Exception: apps.unset_installed_apps() raise @@ -476,14 +498,16 @@ class override_settings(TestContextDecorator): try: setting_changed.send( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=True, + setting=key, + value=new_value, + enter=True, ) except Exception as exc: self.enable_exception = exc self.disable() def disable(self): - if 'INSTALLED_APPS' in self.options: + if "INSTALLED_APPS" in self.options: apps.unset_installed_apps() settings._wrapped = self.wrapped del self.wrapped @@ -492,7 +516,9 @@ class override_settings(TestContextDecorator): new_value = getattr(settings, key, None) responses_for_setting = setting_changed.send_robust( sender=settings._wrapped.__class__, - setting=key, value=new_value, enter=False, + setting=key, + value=new_value, + enter=False, ) responses.extend(responses_for_setting) if self.enable_exception is not None: @@ -515,10 +541,12 @@ class override_settings(TestContextDecorator): def decorate_class(self, cls): from django.test import SimpleTestCase + if not issubclass(cls, SimpleTestCase): raise ValueError( "Only subclasses of Django SimpleTestCase can be decorated " - "with override_settings") + "with override_settings" + ) self.save_options(cls) return cls @@ -528,6 +556,7 @@ class modify_settings(override_settings): Like override_settings, but makes it possible to append, prepend, or remove items instead of redefining the entire list. """ + def __init__(self, *args, **kwargs): if args: # Hack used when instantiating from SimpleTestCase.setUpClass. @@ -543,8 +572,9 @@ class modify_settings(override_settings): test_func._modified_settings = self.operations else: # Duplicate list to prevent subclasses from altering their parent. 
- test_func._modified_settings = list( - test_func._modified_settings) + self.operations + test_func._modified_settings = ( + list(test_func._modified_settings) + self.operations + ) def enable(self): self.options = {} @@ -559,11 +589,11 @@ class modify_settings(override_settings): # items my be a single value or an iterable. if isinstance(items, str): items = [items] - if action == 'append': + if action == "append": value = value + [item for item in items if item not in value] - elif action == 'prepend': + elif action == "prepend": value = [item for item in items if item not in value] + value - elif action == 'remove': + elif action == "remove": value = [item for item in value if item not in items] else: raise ValueError("Unsupported action: %s" % action) @@ -577,8 +607,10 @@ class override_system_checks(TestContextDecorator): Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app, you also need to exclude its system checks. """ + def __init__(self, new_checks, deployment_checks=None): from django.core.checks.registry import registry + self.registry = registry self.new_checks = new_checks self.deployment_checks = deployment_checks @@ -588,12 +620,12 @@ class override_system_checks(TestContextDecorator): self.old_checks = self.registry.registered_checks self.registry.registered_checks = set() for check in self.new_checks: - self.registry.register(check, *getattr(check, 'tags', ())) + self.registry.register(check, *getattr(check, "tags", ())) self.old_deployment_checks = self.registry.deployment_checks if self.deployment_checks is not None: self.registry.deployment_checks = set() for check in self.deployment_checks: - self.registry.register(check, *getattr(check, 'tags', ()), deploy=True) + self.registry.register(check, *getattr(check, "tags", ()), deploy=True) def disable(self): self.registry.registered_checks = self.old_checks @@ -609,18 +641,18 @@ def compare_xml(want, got): Based on 
https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py """ - _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+') + _norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+") def norm_whitespace(v): - return _norm_whitespace_re.sub(' ', v) + return _norm_whitespace_re.sub(" ", v) def child_text(element): - return ''.join(c.data for c in element.childNodes - if c.nodeType == Node.TEXT_NODE) + return "".join( + c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE + ) def children(element): - return [c for c in element.childNodes - if c.nodeType == Node.ELEMENT_NODE] + return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE] def norm_child_text(element): return norm_whitespace(child_text(element)) @@ -639,7 +671,9 @@ def compare_xml(want, got): got_children = children(got_element) if len(want_children) != len(got_children): return False - return all(check_element(want, got) for want, got in zip(want_children, got_children)) + return all( + check_element(want, got) for want, got in zip(want_children, got_children) + ) def first_node(document): for node in document.childNodes: @@ -650,13 +684,13 @@ def compare_xml(want, got): ): return node - want = want.strip().replace('\\n', '\n') - got = got.strip().replace('\\n', '\n') + want = want.strip().replace("\\n", "\n") + got = got.strip().replace("\\n", "\n") # If the string is not a complete xml document, we may need to add a # root element. This allow us to compare fragments, like "<foo/><bar/>" - if not want.startswith('<?xml'): - wrapper = '<root>%s</root>' + if not want.startswith("<?xml"): + wrapper = "<root>%s</root>" want = wrapper % want got = wrapper % got @@ -671,6 +705,7 @@ class CaptureQueriesContext: """ Context manager that captures queries executed by the specified connection. 
""" + def __init__(self, connection): self.connection = connection @@ -685,7 +720,7 @@ class CaptureQueriesContext: @property def captured_queries(self): - return self.connection.queries[self.initial_queries:self.final_queries] + return self.connection.queries[self.initial_queries : self.final_queries] def __enter__(self): self.force_debug_cursor = self.connection.force_debug_cursor @@ -709,7 +744,7 @@ class CaptureQueriesContext: class ignore_warnings(TestContextDecorator): def __init__(self, **kwargs): self.ignore_kwargs = kwargs - if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs: + if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs: self.filter_func = warnings.filterwarnings else: self.filter_func = warnings.simplefilter @@ -718,7 +753,7 @@ class ignore_warnings(TestContextDecorator): def enable(self): self.catch_warnings = warnings.catch_warnings() self.catch_warnings.__enter__() - self.filter_func('ignore', **self.ignore_kwargs) + self.filter_func("ignore", **self.ignore_kwargs) def disable(self): self.catch_warnings.__exit__(*sys.exc_info()) @@ -732,7 +767,7 @@ class ignore_warnings(TestContextDecorator): requires_tz_support = skipUnless( TZ_SUPPORT, "This test relies on the ability to run a program in an arbitrary " - "time zone, but your operating system isn't able to do that." 
+ "time zone, but your operating system isn't able to do that.", ) @@ -775,9 +810,9 @@ def captured_output(stream_name): def captured_stdout(): """Capture the output of sys.stdout: - with captured_stdout() as stdout: - print("hello") - self.assertEqual(stdout.getvalue(), "hello\n") + with captured_stdout() as stdout: + print("hello") + self.assertEqual(stdout.getvalue(), "hello\n") """ return captured_output("stdout") @@ -785,9 +820,9 @@ def captured_stdout(): def captured_stderr(): """Capture the output of sys.stderr: - with captured_stderr() as stderr: - print("hello", file=sys.stderr) - self.assertEqual(stderr.getvalue(), "hello\n") + with captured_stderr() as stderr: + print("hello", file=sys.stderr) + self.assertEqual(stderr.getvalue(), "hello\n") """ return captured_output("stderr") @@ -795,12 +830,12 @@ def captured_stderr(): def captured_stdin(): """Capture the input to sys.stdin: - with captured_stdin() as stdin: - stdin.write('hello\n') - stdin.seek(0) - # call test code that consumes from sys.stdin - captured = input() - self.assertEqual(captured, "hello") + with captured_stdin() as stdin: + stdin.write('hello\n') + stdin.seek(0) + # call test code that consumes from sys.stdin + captured = input() + self.assertEqual(captured, "hello") """ return captured_output("stdin") @@ -828,18 +863,24 @@ def require_jinja2(test_func): Django template engine for a test or skip it if Jinja2 isn't available. 
""" test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func) - return override_settings(TEMPLATES=[{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'APP_DIRS': True, - }, { - 'BACKEND': 'django.template.backends.jinja2.Jinja2', - 'APP_DIRS': True, - 'OPTIONS': {'keep_trailing_newline': True}, - }])(test_func) + return override_settings( + TEMPLATES=[ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + }, + { + "BACKEND": "django.template.backends.jinja2.Jinja2", + "APP_DIRS": True, + "OPTIONS": {"keep_trailing_newline": True}, + }, + ] + )(test_func) class override_script_prefix(TestContextDecorator): """Decorator or context manager to temporary override the script prefix.""" + def __init__(self, prefix): self.prefix = prefix super().__init__() @@ -857,8 +898,9 @@ class LoggingCaptureMixin: Capture the output from the 'django' logger and store it on the class's logger_output attribute. """ + def setUp(self): - self.logger = logging.getLogger('django') + self.logger = logging.getLogger("django") self.old_stream = self.logger.handlers[0].stream self.logger_output = StringIO() self.logger.handlers[0].stream = self.logger_output @@ -883,6 +925,7 @@ class isolate_apps(TestContextDecorator): `kwarg_name`: keyword argument passing the isolated registry if used as a function decorator. 
""" + def __init__(self, *installed_apps, **kwargs): self.installed_apps = installed_apps super().__init__(**kwargs) @@ -890,11 +933,11 @@ class isolate_apps(TestContextDecorator): def enable(self): self.old_apps = Options.default_apps apps = Apps(self.installed_apps) - setattr(Options, 'default_apps', apps) + setattr(Options, "default_apps", apps) return apps def disable(self): - setattr(Options, 'default_apps', self.old_apps) + setattr(Options, "default_apps", self.old_apps) class TimeKeeper: @@ -914,7 +957,7 @@ class TimeKeeper: def print_results(self): for name, end_times in self.records.items(): for record_time in end_times: - record = '%s took %.3fs' % (name, record_time) + record = "%s took %.3fs" % (name, record_time) sys.stderr.write(record + os.linesep) @@ -929,12 +972,14 @@ class NullTimeKeeper: def tag(*tags): """Decorator to add tags to a test class or method.""" + def decorator(obj): - if hasattr(obj, 'tags'): + if hasattr(obj, "tags"): obj.tags = obj.tags.union(tags) else: - setattr(obj, 'tags', set(tags)) + setattr(obj, "tags", set(tags)) return obj + return decorator diff --git a/django/urls/__init__.py b/django/urls/__init__.py index e9e32ac5b9..9aaf4814f2 100644 --- a/django/urls/__init__.py +++ b/django/urls/__init__.py @@ -1,23 +1,53 @@ from .base import ( - clear_script_prefix, clear_url_caches, get_script_prefix, get_urlconf, - is_valid_path, resolve, reverse, reverse_lazy, set_script_prefix, - set_urlconf, translate_url, + clear_script_prefix, + clear_url_caches, + get_script_prefix, + get_urlconf, + is_valid_path, + resolve, + reverse, + reverse_lazy, + set_script_prefix, + set_urlconf, + translate_url, ) from .conf import include, path, re_path from .converters import register_converter from .exceptions import NoReverseMatch, Resolver404 from .resolvers import ( - LocalePrefixPattern, ResolverMatch, URLPattern, URLResolver, - get_ns_resolver, get_resolver, + LocalePrefixPattern, + ResolverMatch, + URLPattern, + URLResolver, + 
get_ns_resolver, + get_resolver, ) from .utils import get_callable, get_mod_func __all__ = [ - 'LocalePrefixPattern', 'NoReverseMatch', 'URLPattern', - 'URLResolver', 'Resolver404', 'ResolverMatch', 'clear_script_prefix', - 'clear_url_caches', 'get_callable', 'get_mod_func', 'get_ns_resolver', - 'get_resolver', 'get_script_prefix', 'get_urlconf', 'include', - 'is_valid_path', 'path', 're_path', 'register_converter', 'resolve', - 'reverse', 'reverse_lazy', 'set_script_prefix', 'set_urlconf', - 'translate_url', + "LocalePrefixPattern", + "NoReverseMatch", + "URLPattern", + "URLResolver", + "Resolver404", + "ResolverMatch", + "clear_script_prefix", + "clear_url_caches", + "get_callable", + "get_mod_func", + "get_ns_resolver", + "get_resolver", + "get_script_prefix", + "get_urlconf", + "include", + "is_valid_path", + "path", + "re_path", + "register_converter", + "resolve", + "reverse", + "reverse_lazy", + "set_script_prefix", + "set_urlconf", + "translate_url", ] diff --git a/django/urls/base.py b/django/urls/base.py index 8c26a3880b..647ef3e2de 100644 --- a/django/urls/base.py +++ b/django/urls/base.py @@ -36,16 +36,16 @@ def reverse(viewname, urlconf=None, args=None, kwargs=None, current_app=None): if not isinstance(viewname, str): view = viewname else: - *path, view = viewname.split(':') + *path, view = viewname.split(":") if current_app: - current_path = current_app.split(':') + current_path = current_app.split(":") current_path.reverse() else: current_path = None resolved_path = [] - ns_pattern = '' + ns_pattern = "" ns_converters = {} for ns in path: current_ns = current_path.pop() if current_path else None @@ -75,13 +75,15 @@ def reverse(viewname, urlconf=None, args=None, kwargs=None, current_app=None): except KeyError as key: if resolved_path: raise NoReverseMatch( - "%s is not a registered namespace inside '%s'" % - (key, ':'.join(resolved_path)) + "%s is not a registered namespace inside '%s'" + % (key, ":".join(resolved_path)) ) else: raise 
NoReverseMatch("%s is not a registered namespace" % key) if ns_pattern: - resolver = get_ns_resolver(ns_pattern, resolver, tuple(ns_converters.items())) + resolver = get_ns_resolver( + ns_pattern, resolver, tuple(ns_converters.items()) + ) return resolver._reverse_with_prefix(view, prefix, *args, **kwargs) @@ -99,8 +101,8 @@ def set_script_prefix(prefix): """ Set the script prefix for the current thread. """ - if not prefix.endswith('/'): - prefix += '/' + if not prefix.endswith("/"): + prefix += "/" _prefixes.value = prefix @@ -110,7 +112,7 @@ def get_script_prefix(): wishes to construct their own URLs manually (although accessing the request instance is normally going to be a lot cleaner). """ - return getattr(_prefixes, "value", '/') + return getattr(_prefixes, "value", "/") def clear_script_prefix(): @@ -168,12 +170,18 @@ def translate_url(url, lang_code): except Resolver404: pass else: - to_be_reversed = "%s:%s" % (match.namespace, match.url_name) if match.namespace else match.url_name + to_be_reversed = ( + "%s:%s" % (match.namespace, match.url_name) + if match.namespace + else match.url_name + ) with override(lang_code): try: url = reverse(to_be_reversed, args=match.args, kwargs=match.kwargs) except NoReverseMatch: pass else: - url = urlunsplit((parsed.scheme, parsed.netloc, url, parsed.query, parsed.fragment)) + url = urlunsplit( + (parsed.scheme, parsed.netloc, url, parsed.query, parsed.fragment) + ) return url diff --git a/django/urls/conf.py b/django/urls/conf.py index 40990d10f5..40708028a3 100644 --- a/django/urls/conf.py +++ b/django/urls/conf.py @@ -5,7 +5,11 @@ from importlib import import_module from django.core.exceptions import ImproperlyConfigured from .resolvers import ( - LocalePrefixPattern, RegexPattern, RoutePattern, URLPattern, URLResolver, + LocalePrefixPattern, + RegexPattern, + RoutePattern, + URLPattern, + URLResolver, ) @@ -18,13 +22,13 @@ def include(arg, namespace=None): except ValueError: if namespace: raise ImproperlyConfigured( - 
'Cannot override the namespace for a dynamic module that ' - 'provides a namespace.' + "Cannot override the namespace for a dynamic module that " + "provides a namespace." ) raise ImproperlyConfigured( - 'Passing a %d-tuple to include() is not supported. Pass a ' - '2-tuple containing the list of patterns and app_name, and ' - 'provide the namespace argument to include() instead.' % len(arg) + "Passing a %d-tuple to include() is not supported. Pass a " + "2-tuple containing the list of patterns and app_name, and " + "provide the namespace argument to include() instead." % len(arg) ) else: # No namespace hint - use manually provided namespace. @@ -32,24 +36,24 @@ def include(arg, namespace=None): if isinstance(urlconf_module, str): urlconf_module = import_module(urlconf_module) - patterns = getattr(urlconf_module, 'urlpatterns', urlconf_module) - app_name = getattr(urlconf_module, 'app_name', app_name) + patterns = getattr(urlconf_module, "urlpatterns", urlconf_module) + app_name = getattr(urlconf_module, "app_name", app_name) if namespace and not app_name: raise ImproperlyConfigured( - 'Specifying a namespace in include() without providing an app_name ' - 'is not supported. Set the app_name attribute in the included ' - 'module, or pass a 2-tuple containing the list of patterns and ' - 'app_name instead.', + "Specifying a namespace in include() without providing an app_name " + "is not supported. Set the app_name attribute in the included " + "module, or pass a 2-tuple containing the list of patterns and " + "app_name instead.", ) namespace = namespace or app_name # Make sure the patterns can be iterated through (without this, some # testcases will break). if isinstance(patterns, (list, tuple)): for url_pattern in patterns: - pattern = getattr(url_pattern, 'pattern', None) + pattern = getattr(url_pattern, "pattern", None) if isinstance(pattern, LocalePrefixPattern): raise ImproperlyConfigured( - 'Using i18n_patterns in an included URLconf is not allowed.' 
+ "Using i18n_patterns in an included URLconf is not allowed." ) return (urlconf_module, app_name, namespace) @@ -59,7 +63,7 @@ def _path(route, view, kwargs=None, name=None, Pattern=None): if kwargs is not None and not isinstance(kwargs, dict): raise TypeError( - f'kwargs argument must be a dict, but got {kwargs.__class__.__name__}.' + f"kwargs argument must be a dict, but got {kwargs.__class__.__name__}." ) if isinstance(view, (list, tuple)): # For include(...) processing. @@ -78,11 +82,13 @@ def _path(route, view, kwargs=None, name=None, Pattern=None): elif isinstance(view, View): view_cls_name = view.__class__.__name__ raise TypeError( - f'view must be a callable, pass {view_cls_name}.as_view(), not ' - f'{view_cls_name}().' + f"view must be a callable, pass {view_cls_name}.as_view(), not " + f"{view_cls_name}()." ) else: - raise TypeError('view must be a callable or a list/tuple in the case of include().') + raise TypeError( + "view must be a callable or a list/tuple in the case of include()." 
+ ) path = partial(_path, Pattern=RoutePattern) diff --git a/django/urls/converters.py b/django/urls/converters.py index bb8478e32f..8af3cbab25 100644 --- a/django/urls/converters.py +++ b/django/urls/converters.py @@ -3,7 +3,7 @@ from functools import lru_cache class IntConverter: - regex = '[0-9]+' + regex = "[0-9]+" def to_python(self, value): return int(value) @@ -13,7 +13,7 @@ class IntConverter: class StringConverter: - regex = '[^/]+' + regex = "[^/]+" def to_python(self, value): return value @@ -23,7 +23,7 @@ class StringConverter: class UUIDConverter: - regex = '[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}' + regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" def to_python(self, value): return uuid.UUID(value) @@ -33,19 +33,19 @@ class UUIDConverter: class SlugConverter(StringConverter): - regex = '[-a-zA-Z0-9_]+' + regex = "[-a-zA-Z0-9_]+" class PathConverter(StringConverter): - regex = '.+' + regex = ".+" DEFAULT_CONVERTERS = { - 'int': IntConverter(), - 'path': PathConverter(), - 'slug': SlugConverter(), - 'str': StringConverter(), - 'uuid': UUIDConverter(), + "int": IntConverter(), + "path": PathConverter(), + "slug": SlugConverter(), + "str": StringConverter(), + "uuid": UUIDConverter(), } diff --git a/django/urls/resolvers.py b/django/urls/resolvers.py index 2ef173c2c8..fee8ae9750 100644 --- a/django/urls/resolvers.py +++ b/django/urls/resolvers.py @@ -31,7 +31,17 @@ from .utils import get_callable class ResolverMatch: - def __init__(self, func, args, kwargs, url_name=None, app_names=None, namespaces=None, route=None, tried=None): + def __init__( + self, + func, + args, + kwargs, + url_name=None, + app_names=None, + namespaces=None, + route=None, + tried=None, + ): self.func = func self.args = args self.kwargs = kwargs @@ -42,21 +52,21 @@ class ResolverMatch: # If a URLRegexResolver doesn't have a namespace or app_name, it passes # in an empty value. 
self.app_names = [x for x in app_names if x] if app_names else [] - self.app_name = ':'.join(self.app_names) + self.app_name = ":".join(self.app_names) self.namespaces = [x for x in namespaces if x] if namespaces else [] - self.namespace = ':'.join(self.namespaces) + self.namespace = ":".join(self.namespaces) - if hasattr(func, 'view_class'): + if hasattr(func, "view_class"): func = func.view_class - if not hasattr(func, '__name__'): + if not hasattr(func, "__name__"): # A class-based view - self._func_path = func.__class__.__module__ + '.' + func.__class__.__name__ + self._func_path = func.__class__.__module__ + "." + func.__class__.__name__ else: # A function-based view - self._func_path = func.__module__ + '.' + func.__name__ + self._func_path = func.__module__ + "." + func.__name__ view_path = url_name or self._func_path - self.view_name = ':'.join(self.namespaces + [view_path]) + self.view_name = ":".join(self.namespaces + [view_path]) def __getitem__(self, index): return (self.func, self.args, self.kwargs)[index] @@ -67,15 +77,21 @@ class ResolverMatch: else: func = self._func_path return ( - 'ResolverMatch(func=%s, args=%r, kwargs=%r, url_name=%r, ' - 'app_names=%r, namespaces=%r, route=%r)' % ( - func, self.args, self.kwargs, self.url_name, - self.app_names, self.namespaces, self.route, + "ResolverMatch(func=%s, args=%r, kwargs=%r, url_name=%r, " + "app_names=%r, namespaces=%r, route=%r)" + % ( + func, + self.args, + self.kwargs, + self.url_name, + self.app_names, + self.namespaces, + self.route, ) ) def __reduce_ex__(self, protocol): - raise PicklingError(f'Cannot pickle {self.__class__.__qualname__}.') + raise PicklingError(f"Cannot pickle {self.__class__.__qualname__}.") def get_resolver(urlconf=None): @@ -86,7 +102,7 @@ def get_resolver(urlconf=None): @functools.lru_cache(maxsize=None) def _get_cached_resolver(urlconf=None): - return URLResolver(RegexPattern(r'^/'), urlconf) + return URLResolver(RegexPattern(r"^/"), urlconf) 
@functools.lru_cache(maxsize=None) @@ -97,7 +113,7 @@ def get_ns_resolver(ns_pattern, resolver, converters): pattern = RegexPattern(ns_pattern) pattern.converters = dict(converters) ns_resolver = URLResolver(pattern, resolver.url_patterns) - return URLResolver(RegexPattern(r'^/'), [ns_resolver]) + return URLResolver(RegexPattern(r"^/"), [ns_resolver]) class LocaleRegexDescriptor: @@ -115,8 +131,8 @@ class LocaleRegexDescriptor: # avoid per-language compilation. pattern = getattr(instance, self.attr) if isinstance(pattern, str): - instance.__dict__['regex'] = instance._compile(pattern) - return instance.__dict__['regex'] + instance.__dict__["regex"] = instance._compile(pattern) + return instance.__dict__["regex"] language_code = get_language() if language_code not in instance._regex_dict: instance._regex_dict[language_code] = instance._compile(str(pattern)) @@ -142,7 +158,9 @@ class CheckURLMixin: # Skip check as it can be useful to start a URL pattern with a slash # when APPEND_SLASH=False. return [] - if regex_pattern.startswith(('/', '^/', '^\\/')) and not regex_pattern.endswith('/'): + if regex_pattern.startswith(("/", "^/", "^\\/")) and not regex_pattern.endswith( + "/" + ): warning = Warning( "Your URL pattern {} has a route beginning with a '/'. Remove this " "slash as it is unnecessary. 
If this pattern is targeted in an " @@ -157,7 +175,7 @@ class CheckURLMixin: class RegexPattern(CheckURLMixin): - regex = LocaleRegexDescriptor('_regex') + regex = LocaleRegexDescriptor("_regex") def __init__(self, regex, name=None, is_endpoint=False): self._regex = regex @@ -169,7 +187,7 @@ class RegexPattern(CheckURLMixin): def match(self, path): match = ( self.regex.fullmatch(path) - if self._is_endpoint and self.regex.pattern.endswith('$') + if self._is_endpoint and self.regex.pattern.endswith("$") else self.regex.search(path) ) if match: @@ -179,7 +197,7 @@ class RegexPattern(CheckURLMixin): kwargs = match.groupdict() args = () if kwargs else match.groups() kwargs = {k: v for k, v in kwargs.items() if v is not None} - return path[match.end():], args, kwargs + return path[match.end() :], args, kwargs return None def check(self): @@ -191,13 +209,15 @@ class RegexPattern(CheckURLMixin): def _check_include_trailing_dollar(self): regex_pattern = self.regex.pattern - if regex_pattern.endswith('$') and not regex_pattern.endswith(r'\$'): - return [Warning( - "Your URL pattern {} uses include with a route ending with a '$'. " - "Remove the dollar from the route to avoid problems including " - "URLs.".format(self.describe()), - id='urls.W001', - )] + if regex_pattern.endswith("$") and not regex_pattern.endswith(r"\$"): + return [ + Warning( + "Your URL pattern {} uses include with a route ending with a '$'. " + "Remove the dollar from the route to avoid problems including " + "URLs.".format(self.describe()), + id="urls.W001", + ) + ] else: return [] @@ -215,7 +235,7 @@ class RegexPattern(CheckURLMixin): _PATH_PARAMETER_COMPONENT_RE = _lazy_re_compile( - r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>[^>]+)>' + r"<(?:(?P<converter>[^>:]+):)?(?P<parameter>[^>]+)>" ) @@ -227,7 +247,7 @@ def _route_to_regex(route, is_endpoint=False): and {'pk': <django.urls.converters.IntConverter>}. 
""" original_route = route - parts = ['^'] + parts = ["^"] converters = {} while True: match = _PATH_PARAMETER_COMPONENT_RE.search(route) @@ -239,34 +259,34 @@ def _route_to_regex(route, is_endpoint=False): "URL route '%s' cannot contain whitespace in angle brackets " "<…>." % original_route ) - parts.append(re.escape(route[:match.start()])) - route = route[match.end():] - parameter = match['parameter'] + parts.append(re.escape(route[: match.start()])) + route = route[match.end() :] + parameter = match["parameter"] if not parameter.isidentifier(): raise ImproperlyConfigured( "URL route '%s' uses parameter name %r which isn't a valid " "Python identifier." % (original_route, parameter) ) - raw_converter = match['converter'] + raw_converter = match["converter"] if raw_converter is None: # If a converter isn't specified, the default is `str`. - raw_converter = 'str' + raw_converter = "str" try: converter = get_converter(raw_converter) except KeyError as e: raise ImproperlyConfigured( - 'URL route %r uses invalid converter %r.' + "URL route %r uses invalid converter %r." 
% (original_route, raw_converter) ) from e converters[parameter] = converter - parts.append('(?P<' + parameter + '>' + converter.regex + ')') + parts.append("(?P<" + parameter + ">" + converter.regex + ")") if is_endpoint: - parts.append(r'\Z') - return ''.join(parts), converters + parts.append(r"\Z") + return "".join(parts), converters class RoutePattern(CheckURLMixin): - regex = LocaleRegexDescriptor('_route') + regex = LocaleRegexDescriptor("_route") def __init__(self, route, name=None, is_endpoint=False): self._route = route @@ -286,19 +306,21 @@ class RoutePattern(CheckURLMixin): kwargs[key] = converter.to_python(value) except ValueError: return None - return path[match.end():], (), kwargs + return path[match.end() :], (), kwargs return None def check(self): warnings = self._check_pattern_startswith_slash() route = self._route - if '(?P<' in route or route.startswith('^') or route.endswith('$'): - warnings.append(Warning( - "Your URL pattern {} has a route that contains '(?P<', begins " - "with a '^', or ends with a '$'. This was likely an oversight " - "when migrating to django.urls.path().".format(self.describe()), - id='2_0.W001', - )) + if "(?P<" in route or route.startswith("^") or route.endswith("$"): + warnings.append( + Warning( + "Your URL pattern {} has a route that contains '(?P<', begins " + "with a '^', or ends with a '$'. 
This was likely an oversight " + "when migrating to django.urls.path().".format(self.describe()), + id="2_0.W001", + ) + ) return warnings def _compile(self, route): @@ -322,14 +344,14 @@ class LocalePrefixPattern: def language_prefix(self): language_code = get_language() or settings.LANGUAGE_CODE if language_code == settings.LANGUAGE_CODE and not self.prefix_default_language: - return '' + return "" else: - return '%s/' % language_code + return "%s/" % language_code def match(self, path): language_prefix = self.language_prefix if path.startswith(language_prefix): - return path[len(language_prefix):], (), {} + return path[len(language_prefix) :], (), {} return None def check(self): @@ -350,7 +372,7 @@ class URLPattern: self.name = name def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.pattern.describe()) + return "<%s %s>" % (self.__class__.__name__, self.pattern.describe()) def check(self): warnings = self._check_pattern_name() @@ -377,15 +399,18 @@ class URLPattern: view = self.callback if inspect.isclass(view) and issubclass(view, View): - return [Error( - 'Your URL pattern %s has an invalid view, pass %s.as_view() ' - 'instead of %s.' % ( - self.pattern.describe(), - view.__name__, - view.__name__, - ), - id='urls.E009', - )] + return [ + Error( + "Your URL pattern %s has an invalid view, pass %s.as_view() " + "instead of %s." + % ( + self.pattern.describe(), + view.__name__, + view.__name__, + ), + id="urls.E009", + ) + ] return [] def resolve(self, path): @@ -394,7 +419,9 @@ class URLPattern: new_path, args, kwargs = match # Pass any extra_kwargs as **kwargs. 
kwargs.update(self.default_args) - return ResolverMatch(self.callback, args, kwargs, self.pattern.name, route=str(self.pattern)) + return ResolverMatch( + self.callback, args, kwargs, self.pattern.name, route=str(self.pattern) + ) @cached_property def lookup_str(self): @@ -405,15 +432,17 @@ class URLPattern: callback = self.callback if isinstance(callback, functools.partial): callback = callback.func - if hasattr(callback, 'view_class'): + if hasattr(callback, "view_class"): callback = callback.view_class - elif not hasattr(callback, '__name__'): + elif not hasattr(callback, "__name__"): return callback.__module__ + "." + callback.__class__.__name__ return callback.__module__ + "." + callback.__qualname__ class URLResolver: - def __init__(self, pattern, urlconf_name, default_kwargs=None, app_name=None, namespace=None): + def __init__( + self, pattern, urlconf_name, default_kwargs=None, app_name=None, namespace=None + ): self.pattern = pattern # urlconf_name is the dotted Python path to the module defining # urlpatterns. 
It may also be an object with an urlpatterns attribute @@ -435,12 +464,15 @@ class URLResolver: def __repr__(self): if isinstance(self.urlconf_name, list) and self.urlconf_name: # Don't bother to output the whole list, it can be huge - urlconf_repr = '<%s list>' % self.urlconf_name[0].__class__.__name__ + urlconf_repr = "<%s list>" % self.urlconf_name[0].__class__.__name__ else: urlconf_repr = repr(self.urlconf_name) - return '<%s %s (%s:%s) %s>' % ( - self.__class__.__name__, urlconf_repr, self.app_name, - self.namespace, self.pattern.describe(), + return "<%s %s (%s:%s) %s>" % ( + self.__class__.__name__, + urlconf_repr, + self.app_name, + self.namespace, + self.pattern.describe(), ) def check(self): @@ -458,11 +490,11 @@ class URLResolver: try: handler = self.resolve_error_handler(status_code) except (ImportError, ViewDoesNotExist) as e: - path = getattr(self.urlconf_module, 'handler%s' % status_code) + path = getattr(self.urlconf_module, "handler%s" % status_code) msg = ( "The custom handler{status_code} view '{path}' could not be imported." ).format(status_code=status_code, path=path) - messages.append(Error(msg, hint=str(e), id='urls.E008')) + messages.append(Error(msg, hint=str(e), id="urls.E008")) continue signature = inspect.signature(handler) args = [None] * num_parameters @@ -474,10 +506,10 @@ class URLResolver: "take the correct number of arguments ({args})." ).format( status_code=status_code, - path=handler.__module__ + '.' + handler.__qualname__, - args='request, exception' if num_parameters == 2 else 'request', + path=handler.__module__ + "." + handler.__qualname__, + args="request, exception" if num_parameters == 2 else "request", ) - messages.append(Error(msg, id='urls.E007')) + messages.append(Error(msg, id="urls.E007")) return messages def _populate(self): @@ -485,7 +517,7 @@ class URLResolver: # infinite recursion. Concurrent threads may call this at the same # time and will need to continue, so set 'populating' on a # thread-local variable. 
- if getattr(self._local, 'populating', False): + if getattr(self._local, "populating", False): return try: self._local.populating = True @@ -495,28 +527,45 @@ class URLResolver: language_code = get_language() for url_pattern in reversed(self.url_patterns): p_pattern = url_pattern.pattern.regex.pattern - if p_pattern.startswith('^'): + if p_pattern.startswith("^"): p_pattern = p_pattern[1:] if isinstance(url_pattern, URLPattern): self._callback_strs.add(url_pattern.lookup_str) bits = normalize(url_pattern.pattern.regex.pattern) lookups.appendlist( url_pattern.callback, - (bits, p_pattern, url_pattern.default_args, url_pattern.pattern.converters) + ( + bits, + p_pattern, + url_pattern.default_args, + url_pattern.pattern.converters, + ), ) if url_pattern.name is not None: lookups.appendlist( url_pattern.name, - (bits, p_pattern, url_pattern.default_args, url_pattern.pattern.converters) + ( + bits, + p_pattern, + url_pattern.default_args, + url_pattern.pattern.converters, + ), ) else: # url_pattern is a URLResolver. 
url_pattern._populate() if url_pattern.app_name: - apps.setdefault(url_pattern.app_name, []).append(url_pattern.namespace) + apps.setdefault(url_pattern.app_name, []).append( + url_pattern.namespace + ) namespaces[url_pattern.namespace] = (p_pattern, url_pattern) else: for name in url_pattern.reverse_dict: - for matches, pat, defaults, converters in url_pattern.reverse_dict.getlist(name): + for ( + matches, + pat, + defaults, + converters, + ) in url_pattern.reverse_dict.getlist(name): new_matches = normalize(p_pattern + pat) lookups.appendlist( name, @@ -524,10 +573,17 @@ class URLResolver: new_matches, p_pattern + pat, {**defaults, **url_pattern.default_kwargs}, - {**self.pattern.converters, **url_pattern.pattern.converters, **converters} - ) + { + **self.pattern.converters, + **url_pattern.pattern.converters, + **converters, + }, + ), ) - for namespace, (prefix, sub_pattern) in url_pattern.namespace_dict.items(): + for namespace, ( + prefix, + sub_pattern, + ) in url_pattern.namespace_dict.items(): current_converters = url_pattern.pattern.converters sub_pattern.pattern.converters.update(current_converters) namespaces[namespace] = (p_pattern + prefix, sub_pattern) @@ -574,7 +630,7 @@ class URLResolver: """Join two routes, without the starting ^ in the second route.""" if not route1: return route2 - if route2.startswith('^'): + if route2.startswith("^"): route2 = route2[1:] return route1 + route2 @@ -593,7 +649,7 @@ class URLResolver: try: sub_match = pattern.resolve(new_path) except Resolver404 as e: - self._extend_tried(tried, pattern, e.args[0].get('tried')) + self._extend_tried(tried, pattern, e.args[0].get("tried")) else: if sub_match: # Merge captured arguments in match with submatch @@ -605,7 +661,11 @@ class URLResolver: sub_match_args = sub_match.args if not sub_match_dict: sub_match_args = args + sub_match.args - current_route = '' if isinstance(pattern, URLPattern) else str(pattern.pattern) + current_route = ( + "" + if isinstance(pattern, URLPattern) + 
else str(pattern.pattern) + ) self._extend_tried(tried, pattern, sub_match.tried) return ResolverMatch( sub_match.func, @@ -618,8 +678,8 @@ class URLResolver: tried, ) tried.append([pattern]) - raise Resolver404({'tried': tried, 'path': new_path}) - raise Resolver404({'path': path}) + raise Resolver404({"tried": tried, "path": new_path}) + raise Resolver404({"path": path}) @cached_property def urlconf_module(self): @@ -645,16 +705,17 @@ class URLResolver: return patterns def resolve_error_handler(self, view_type): - callback = getattr(self.urlconf_module, 'handler%s' % view_type, None) + callback = getattr(self.urlconf_module, "handler%s" % view_type, None) if not callback: # No handler specified in file; use lazy import, since # django.conf.urls imports this file. from django.conf import urls - callback = getattr(urls, 'handler%s' % view_type) + + callback = getattr(urls, "handler%s" % view_type) return get_callable(callback) def reverse(self, lookup_view, *args, **kwargs): - return self._reverse_with_prefix(lookup_view, '', *args, **kwargs) + return self._reverse_with_prefix(lookup_view, "", *args, **kwargs) def _reverse_with_prefix(self, lookup_view, _prefix, *args, **kwargs): if args and kwargs: @@ -696,16 +757,22 @@ class URLResolver: # without quoting to build a decoded URL and look for a match. # Then, if we have a match, redo the substitution with quoted # arguments in order to return a properly encoded URL. 
- candidate_pat = _prefix.replace('%', '%%') + result - if re.search('^%s%s' % (re.escape(_prefix), pattern), candidate_pat % text_candidate_subs): + candidate_pat = _prefix.replace("%", "%%") + result + if re.search( + "^%s%s" % (re.escape(_prefix), pattern), + candidate_pat % text_candidate_subs, + ): # safe characters from `pchar` definition of RFC 3986 - url = quote(candidate_pat % text_candidate_subs, safe=RFC3986_SUBDELIMS + '/~:@') + url = quote( + candidate_pat % text_candidate_subs, + safe=RFC3986_SUBDELIMS + "/~:@", + ) # Don't allow construction of scheme relative urls. return escape_leading_slashes(url) # lookup_view can be URL name or callable, but callables are not # friendly in error messages. - m = getattr(lookup_view, '__module__', None) - n = getattr(lookup_view, '__name__', None) + m = getattr(lookup_view, "__module__", None) + n = getattr(lookup_view, "__name__", None) if m is not None and n is not None: lookup_view_s = "%s.%s" % (m, n) else: @@ -719,13 +786,15 @@ class URLResolver: arg_msg = "keyword arguments '%s'" % kwargs else: arg_msg = "no arguments" - msg = ( - "Reverse for '%s' with %s not found. %d pattern(s) tried: %s" % - (lookup_view_s, arg_msg, len(patterns), patterns) + msg = "Reverse for '%s' with %s not found. %d pattern(s) tried: %s" % ( + lookup_view_s, + arg_msg, + len(patterns), + patterns, ) else: msg = ( "Reverse for '%(view)s' not found. '%(view)s' is not " - "a valid view function or pattern name." % {'view': lookup_view_s} + "a valid view function or pattern name." 
% {"view": lookup_view_s} ) raise NoReverseMatch(msg) diff --git a/django/urls/utils.py b/django/urls/utils.py index e59ab9fdbd..60b46d9050 100644 --- a/django/urls/utils.py +++ b/django/urls/utils.py @@ -18,11 +18,15 @@ def get_callable(lookup_view): return lookup_view if not isinstance(lookup_view, str): - raise ViewDoesNotExist("'%s' is not a callable or a dot-notation path" % lookup_view) + raise ViewDoesNotExist( + "'%s' is not a callable or a dot-notation path" % lookup_view + ) mod_name, func_name = get_mod_func(lookup_view) if not func_name: # No '.' in lookup_view - raise ImportError("Could not import '%s'. The path must be fully qualified." % lookup_view) + raise ImportError( + "Could not import '%s'. The path must be fully qualified." % lookup_view + ) try: mod = import_module(mod_name) @@ -30,8 +34,8 @@ def get_callable(lookup_view): parentmod, submod = get_mod_func(mod_name) if submod and not module_has_submodule(import_module(parentmod), submod): raise ViewDoesNotExist( - "Could not import '%s'. Parent module %s does not exist." % - (lookup_view, mod_name) + "Could not import '%s'. Parent module %s does not exist." + % (lookup_view, mod_name) ) else: raise @@ -40,14 +44,14 @@ def get_callable(lookup_view): view_func = getattr(mod, func_name) except AttributeError: raise ViewDoesNotExist( - "Could not import '%s'. View does not exist in module %s." % - (lookup_view, mod_name) + "Could not import '%s'. View does not exist in module %s." + % (lookup_view, mod_name) ) else: if not callable(view_func): raise ViewDoesNotExist( - "Could not import '%s.%s'. View is not callable." % - (mod_name, func_name) + "Could not import '%s.%s'. View is not callable." 
+ % (mod_name, func_name) ) return view_func @@ -56,7 +60,7 @@ def get_mod_func(callback): # Convert 'django.views.news.stories.story_detail' to # ['django.views.news.stories', 'story_detail'] try: - dot = callback.rindex('.') + dot = callback.rindex(".") except ValueError: - return callback, '' - return callback[:dot], callback[dot + 1:] + return callback, "" + return callback[:dot], callback[dot + 1 :] diff --git a/django/utils/_os.py b/django/utils/_os.py index b2a2a9d426..f85fe09248 100644 --- a/django/utils/_os.py +++ b/django/utils/_os.py @@ -23,12 +23,15 @@ def safe_join(base, *paths): # safe_join("/dir", "/../d")) # b) The final path must be the same as the base path. # c) The base path must be the most root path (meaning either "/" or "C:\\") - if (not normcase(final_path).startswith(normcase(base_path + sep)) and - normcase(final_path) != normcase(base_path) and - dirname(normcase(base_path)) != normcase(base_path)): + if ( + not normcase(final_path).startswith(normcase(base_path + sep)) + and normcase(final_path) != normcase(base_path) + and dirname(normcase(base_path)) != normcase(base_path) + ): raise SuspiciousFileOperation( - 'The joined path ({}) is located outside of the base path ' - 'component ({})'.format(final_path, base_path)) + "The joined path ({}) is located outside of the base path " + "component ({})".format(final_path, base_path) + ) return final_path @@ -39,8 +42,8 @@ def symlinks_supported(): permissions). 
""" with tempfile.TemporaryDirectory() as temp_dir: - original_path = os.path.join(temp_dir, 'original') - symlink_path = os.path.join(temp_dir, 'symlink') + original_path = os.path.join(temp_dir, "original") + symlink_path = os.path.join(temp_dir, "symlink") os.makedirs(original_path) try: os.symlink(original_path, symlink_path) @@ -55,5 +58,5 @@ def to_path(value): if isinstance(value, Path): return value elif not isinstance(value, str): - raise TypeError('Invalid path type: %s' % type(value).__name__) + raise TypeError("Invalid path type: %s" % type(value).__name__) return Path(value) diff --git a/django/utils/archive.py b/django/utils/archive.py index d5a0cf0446..71ec2d0015 100644 --- a/django/utils/archive.py +++ b/django/utils/archive.py @@ -55,6 +55,7 @@ class Archive: """ The external API class that encapsulates an archive implementation. """ + def __init__(self, file): self._archive = self._archive_cls(file)(file) @@ -68,7 +69,8 @@ class Archive: filename = file.name except AttributeError: raise UnrecognizedArchiveFormat( - "File object not a recognized archive format.") + "File object not a recognized archive format." + ) base, tail_ext = os.path.splitext(filename.lower()) cls = extension_map.get(tail_ext) if not cls: @@ -76,7 +78,8 @@ class Archive: cls = extension_map.get(ext) if not cls: raise UnrecognizedArchiveFormat( - "Path not a recognized archive format: %s" % filename) + "Path not a recognized archive format: %s" % filename + ) return cls def __enter__(self): @@ -99,6 +102,7 @@ class BaseArchive: """ Base Archive class. Implementations should inherit this class. 
""" + @staticmethod def _copy_permissions(mode, filename): """ @@ -111,13 +115,15 @@ class BaseArchive: def split_leading_dir(self, path): path = str(path) - path = path.lstrip('/').lstrip('\\') - if '/' in path and (('\\' in path and path.find('/') < path.find('\\')) or '\\' not in path): - return path.split('/', 1) - elif '\\' in path: - return path.split('\\', 1) + path = path.lstrip("/").lstrip("\\") + if "/" in path and ( + ("\\" in path and path.find("/") < path.find("\\")) or "\\" not in path + ): + return path.split("/", 1) + elif "\\" in path: + return path.split("\\", 1) else: - return path, '' + return path, "" def has_leading_dir(self, paths): """ @@ -143,14 +149,17 @@ class BaseArchive: return filename def extract(self): - raise NotImplementedError('subclasses of BaseArchive must provide an extract() method') + raise NotImplementedError( + "subclasses of BaseArchive must provide an extract() method" + ) def list(self): - raise NotImplementedError('subclasses of BaseArchive must provide a list() method') + raise NotImplementedError( + "subclasses of BaseArchive must provide a list() method" + ) class TarArchive(BaseArchive): - def __init__(self, file): self._archive = tarfile.open(file) @@ -174,13 +183,15 @@ class TarArchive(BaseArchive): except (KeyError, AttributeError) as exc: # Some corrupt tar files seem to produce this # (specifically bad symlinks) - print("In the tar file %s the member %s is invalid: %s" % - (name, member.name, exc)) + print( + "In the tar file %s the member %s is invalid: %s" + % (name, member.name, exc) + ) else: dirname = os.path.dirname(filename) if dirname: os.makedirs(dirname, exist_ok=True) - with open(filename, 'wb') as outfile: + with open(filename, "wb") as outfile: shutil.copyfileobj(extracted, outfile) self._copy_permissions(member.mode, filename) finally: @@ -192,7 +203,6 @@ class TarArchive(BaseArchive): class ZipArchive(BaseArchive): - def __init__(self, file): self._archive = zipfile.ZipFile(file) @@ -210,14 
+220,14 @@ class ZipArchive(BaseArchive): if not name: continue filename = self.target_filename(to_path, name) - if name.endswith(('/', '\\')): + if name.endswith(("/", "\\")): # A directory os.makedirs(filename, exist_ok=True) else: dirname = os.path.dirname(filename) if dirname: os.makedirs(dirname, exist_ok=True) - with open(filename, 'wb') as outfile: + with open(filename, "wb") as outfile: outfile.write(data) # Convert ZipInfo.external_attr to mode mode = info.external_attr >> 16 @@ -227,11 +237,21 @@ class ZipArchive(BaseArchive): self._archive.close() -extension_map = dict.fromkeys(( - '.tar', - '.tar.bz2', '.tbz2', '.tbz', '.tz2', - '.tar.gz', '.tgz', '.taz', - '.tar.lzma', '.tlz', - '.tar.xz', '.txz', -), TarArchive) -extension_map['.zip'] = ZipArchive +extension_map = dict.fromkeys( + ( + ".tar", + ".tar.bz2", + ".tbz2", + ".tbz", + ".tz2", + ".tar.gz", + ".tgz", + ".taz", + ".tar.lzma", + ".tlz", + ".tar.xz", + ".txz", + ), + TarArchive, +) +extension_map[".zip"] = ZipArchive diff --git a/django/utils/asyncio.py b/django/utils/asyncio.py index b8e14f1f68..7e0b439db2 100644 --- a/django/utils/asyncio.py +++ b/django/utils/asyncio.py @@ -10,6 +10,7 @@ def async_unsafe(message): Decorator to mark functions as async-unsafe. Someone trying to access the function while in an async context will get an error message. """ + def decorator(func): @wraps(func) def inner(*args, **kwargs): @@ -19,15 +20,17 @@ def async_unsafe(message): except RuntimeError: pass else: - if not os.environ.get('DJANGO_ALLOW_ASYNC_UNSAFE'): + if not os.environ.get("DJANGO_ALLOW_ASYNC_UNSAFE"): raise SynchronousOnlyOperation(message) # Pass onward. return func(*args, **kwargs) + return inner + # If the message is actually a function, then be a no-arguments decorator. if callable(message): func = message - message = 'You cannot call this from an async context - use a thread or sync_to_async.' + message = "You cannot call this from an async context - use a thread or sync_to_async." 
return decorator(func) else: return decorator diff --git a/django/utils/autoreload.py b/django/utils/autoreload.py index 583c2be647..7b9219f4c1 100644 --- a/django/utils/autoreload.py +++ b/django/utils/autoreload.py @@ -24,9 +24,9 @@ from django.utils.version import get_version_tuple autoreload_started = Signal() file_changed = Signal() -DJANGO_AUTORELOAD_ENV = 'RUN_MAIN' +DJANGO_AUTORELOAD_ENV = "RUN_MAIN" -logger = logging.getLogger('django.utils.autoreload') +logger = logging.getLogger("django.utils.autoreload") # If an error is raised while importing a file, it's not placed in sys.modules. # This means that any future modifications aren't caught. Keep a list of these @@ -48,7 +48,7 @@ except ImportError: def is_django_module(module): """Return True if the given module is nested under Django.""" - return module.__name__.startswith('django.') + return module.__name__.startswith("django.") def is_django_path(path): @@ -67,7 +67,7 @@ def check_errors(fn): et, ev, tb = _exception - if getattr(ev, 'filename', None) is None: + if getattr(ev, "filename", None) is None: # get the filename from the last item in the stack filename = traceback.extract_tb(tb)[-1][0] else: @@ -97,7 +97,7 @@ def ensure_echo_on(): attr_list = termios.tcgetattr(sys.stdin) if not attr_list[3] & termios.ECHO: attr_list[3] |= termios.ECHO - if hasattr(signal, 'SIGTTOU'): + if hasattr(signal, "SIGTTOU"): old_handler = signal.signal(signal.SIGTTOU, signal.SIG_IGN) else: old_handler = None @@ -112,7 +112,11 @@ def iter_all_python_module_files(): # This ensures cached results are returned in the usual case that modules # aren't loaded on the fly. 
keys = sorted(sys.modules) - modules = tuple(m for m in map(sys.modules.__getitem__, keys) if not isinstance(m, weakref.ProxyTypes)) + modules = tuple( + m + for m in map(sys.modules.__getitem__, keys) + if not isinstance(m, weakref.ProxyTypes) + ) return iter_modules_and_files(modules, frozenset(_error_files)) @@ -126,21 +130,25 @@ def iter_modules_and_files(modules, extra_files): # cause issues here. if not isinstance(module, ModuleType): continue - if module.__name__ == '__main__': + if module.__name__ == "__main__": # __main__ (usually manage.py) doesn't always have a __spec__ set. # Handle this by falling back to using __file__, resolved below. # See https://docs.python.org/reference/import.html#main-spec # __file__ may not exists, e.g. when running ipdb debugger. - if hasattr(module, '__file__'): + if hasattr(module, "__file__"): sys_file_paths.append(module.__file__) continue - if getattr(module, '__spec__', None) is None: + if getattr(module, "__spec__", None) is None: continue spec = module.__spec__ # Modules could be loaded from places without a concrete location. If # this is the case, skip them. if spec.has_location: - origin = spec.loader.archive if isinstance(spec.loader, zipimporter) else spec.origin + origin = ( + spec.loader.archive + if isinstance(spec.loader, zipimporter) + else spec.origin + ) sys_file_paths.append(origin) results = set() @@ -217,49 +225,50 @@ def get_child_arguments(): on reloading. 
""" import __main__ + py_script = Path(sys.argv[0]) - args = [sys.executable] + ['-W%s' % o for o in sys.warnoptions] - if sys.implementation.name == 'cpython': + args = [sys.executable] + ["-W%s" % o for o in sys.warnoptions] + if sys.implementation.name == "cpython": args.extend( - f'-X{key}' if value is True else f'-X{key}={value}' + f"-X{key}" if value is True else f"-X{key}={value}" for key, value in sys._xoptions.items() ) # __spec__ is set when the server was started with the `-m` option, # see https://docs.python.org/3/reference/import.html#main-spec # __spec__ may not exist, e.g. when running in a Conda env. - if getattr(__main__, '__spec__', None) is not None: + if getattr(__main__, "__spec__", None) is not None: spec = __main__.__spec__ - if (spec.name == '__main__' or spec.name.endswith('.__main__')) and spec.parent: + if (spec.name == "__main__" or spec.name.endswith(".__main__")) and spec.parent: name = spec.parent else: name = spec.name - args += ['-m', name] + args += ["-m", name] args += sys.argv[1:] elif not py_script.exists(): # sys.argv[0] may not exist for several reasons on Windows. # It may exist with a .exe extension or have a -script.py suffix. - exe_entrypoint = py_script.with_suffix('.exe') + exe_entrypoint = py_script.with_suffix(".exe") if exe_entrypoint.exists(): # Should be executed directly, ignoring sys.executable. return [exe_entrypoint, *sys.argv[1:]] - script_entrypoint = py_script.with_name('%s-script.py' % py_script.name) + script_entrypoint = py_script.with_name("%s-script.py" % py_script.name) if script_entrypoint.exists(): # Should be executed as usual. return [*args, script_entrypoint, *sys.argv[1:]] - raise RuntimeError('Script %s does not exist.' % py_script) + raise RuntimeError("Script %s does not exist." 
% py_script) else: args += sys.argv return args def trigger_reload(filename): - logger.info('%s changed, reloading.', filename) + logger.info("%s changed, reloading.", filename) sys.exit(3) def restart_with_reloader(): - new_environ = {**os.environ, DJANGO_AUTORELOAD_ENV: 'true'} + new_environ = {**os.environ, DJANGO_AUTORELOAD_ENV: "true"} args = get_child_arguments() while True: p = subprocess.run(args, env=new_environ, close_fds=False) @@ -279,12 +288,12 @@ class BaseReloader: path = path.absolute() except FileNotFoundError: logger.debug( - 'Unable to watch directory %s as it cannot be resolved.', + "Unable to watch directory %s as it cannot be resolved.", path, exc_info=True, ) return - logger.debug('Watching dir %s with glob %s.', path, glob) + logger.debug("Watching dir %s with glob %s.", path, glob) self.directory_globs[path].add(glob) def watched_files(self, include_globs=True): @@ -314,11 +323,11 @@ class BaseReloader: if app_reg.ready_event.wait(timeout=0.1): return True else: - logger.debug('Main Django thread has terminated before apps are ready.') + logger.debug("Main Django thread has terminated before apps are ready.") return False def run(self, django_main_thread): - logger.debug('Waiting for apps ready_event.') + logger.debug("Waiting for apps ready_event.") self.wait_for_apps_ready(apps, django_main_thread) from django.urls import get_resolver @@ -330,7 +339,7 @@ class BaseReloader: # Loading the urlconf can result in errors during development. # If this occurs then swallow the error and continue. pass - logger.debug('Apps ready_event triggered. Sending autoreload_started signal.') + logger.debug("Apps ready_event triggered. Sending autoreload_started signal.") autoreload_started.send(sender=self) self.run_loop() @@ -351,15 +360,15 @@ class BaseReloader: testability of the reloader implementations by decoupling the work they do from the loop. 
""" - raise NotImplementedError('subclasses must implement tick().') + raise NotImplementedError("subclasses must implement tick().") @classmethod def check_availability(cls): - raise NotImplementedError('subclasses must implement check_availability().') + raise NotImplementedError("subclasses must implement check_availability().") def notify_file_changed(self, path): results = file_changed.send(sender=self, file_path=path) - logger.debug('%s notified as changed. Signal results: %s.', path, results) + logger.debug("%s notified as changed. Signal results: %s.", path, results) if not any(res[1] for res in results): trigger_reload(path) @@ -382,10 +391,15 @@ class StatReloader(BaseReloader): old_time = mtimes.get(filepath) mtimes[filepath] = mtime if old_time is None: - logger.debug('File %s first seen with mtime %s', filepath, mtime) + logger.debug("File %s first seen with mtime %s", filepath, mtime) continue elif mtime > old_time: - logger.debug('File %s previous mtime: %s, current mtime: %s', filepath, old_time, mtime) + logger.debug( + "File %s previous mtime: %s, current mtime: %s", + filepath, + old_time, + mtime, + ) self.notify_file_changed(filepath) time.sleep(self.SLEEP_TIME) @@ -418,7 +432,7 @@ class WatchmanReloader(BaseReloader): def __init__(self): self.roots = defaultdict(set) self.processed_request = threading.Event() - self.client_timeout = int(os.environ.get('DJANGO_WATCHMAN_TIMEOUT', 5)) + self.client_timeout = int(os.environ.get("DJANGO_WATCHMAN_TIMEOUT", 5)) super().__init__() @cached_property @@ -437,52 +451,63 @@ class WatchmanReloader(BaseReloader): # now, watching its parent, if possible, is sufficient. 
if not root.exists(): if not root.parent.exists(): - logger.warning('Unable to watch root dir %s as neither it or its parent exist.', root) + logger.warning( + "Unable to watch root dir %s as neither it or its parent exist.", + root, + ) return root = root.parent - result = self.client.query('watch-project', str(root.absolute())) - if 'warning' in result: - logger.warning('Watchman warning: %s', result['warning']) - logger.debug('Watchman watch-project result: %s', result) - return result['watch'], result.get('relative_path') + result = self.client.query("watch-project", str(root.absolute())) + if "warning" in result: + logger.warning("Watchman warning: %s", result["warning"]) + logger.debug("Watchman watch-project result: %s", result) + return result["watch"], result.get("relative_path") @functools.lru_cache def _get_clock(self, root): - return self.client.query('clock', root)['clock'] + return self.client.query("clock", root)["clock"] def _subscribe(self, directory, name, expression): root, rel_path = self._watch_root(directory) # Only receive notifications of files changing, filtering out other types # like special files: https://facebook.github.io/watchman/docs/type only_files_expression = [ - 'allof', - ['anyof', ['type', 'f'], ['type', 'l']], - expression + "allof", + ["anyof", ["type", "f"], ["type", "l"]], + expression, ] query = { - 'expression': only_files_expression, - 'fields': ['name'], - 'since': self._get_clock(root), - 'dedup_results': True, + "expression": only_files_expression, + "fields": ["name"], + "since": self._get_clock(root), + "dedup_results": True, } if rel_path: - query['relative_root'] = rel_path - logger.debug('Issuing watchman subscription %s, for root %s. Query: %s', name, root, query) - self.client.query('subscribe', root, name, query) + query["relative_root"] = rel_path + logger.debug( + "Issuing watchman subscription %s, for root %s. 
Query: %s", + name, + root, + query, + ) + self.client.query("subscribe", root, name, query) def _subscribe_dir(self, directory, filenames): if not directory.exists(): if not directory.parent.exists(): - logger.warning('Unable to watch directory %s as neither it or its parent exist.', directory) + logger.warning( + "Unable to watch directory %s as neither it or its parent exist.", + directory, + ) return - prefix = 'files-parent-%s' % directory.name - filenames = ['%s/%s' % (directory.name, filename) for filename in filenames] + prefix = "files-parent-%s" % directory.name + filenames = ["%s/%s" % (directory.name, filename) for filename in filenames] directory = directory.parent - expression = ['name', filenames, 'wholename'] + expression = ["name", filenames, "wholename"] else: - prefix = 'files' - expression = ['name', filenames] - self._subscribe(directory, '%s:%s' % (prefix, directory), expression) + prefix = "files" + expression = ["name", filenames] + self._subscribe(directory, "%s:%s" % (prefix, directory), expression) def _watch_glob(self, directory, patterns): """ @@ -493,19 +518,22 @@ class WatchmanReloader(BaseReloader): overwrite the named subscription, so it must include all possible glob expressions. 
""" - prefix = 'glob' + prefix = "glob" if not directory.exists(): if not directory.parent.exists(): - logger.warning('Unable to watch directory %s as neither it or its parent exist.', directory) + logger.warning( + "Unable to watch directory %s as neither it or its parent exist.", + directory, + ) return - prefix = 'glob-parent-%s' % directory.name - patterns = ['%s/%s' % (directory.name, pattern) for pattern in patterns] + prefix = "glob-parent-%s" % directory.name + patterns = ["%s/%s" % (directory.name, pattern) for pattern in patterns] directory = directory.parent - expression = ['anyof'] + expression = ["anyof"] for pattern in patterns: - expression.append(['match', pattern, 'wholename']) - self._subscribe(directory, '%s:%s' % (prefix, directory), expression) + expression.append(["match", pattern, "wholename"]) + self._subscribe(directory, "%s:%s" % (prefix, directory), expression) def watched_roots(self, watched_files): extra_directories = self.directory_globs.keys() @@ -516,8 +544,8 @@ class WatchmanReloader(BaseReloader): def _update_watches(self): watched_files = list(self.watched_files(include_globs=False)) found_roots = common_roots(self.watched_roots(watched_files)) - logger.debug('Watching %s files', len(watched_files)) - logger.debug('Found common roots: %s', found_roots) + logger.debug("Watching %s files", len(watched_files)) + logger.debug("Found common roots: %s", found_roots) # Setup initial roots for performance, shortest roots first. for root in sorted(found_roots): self._watch_root(root) @@ -527,7 +555,9 @@ class WatchmanReloader(BaseReloader): sorted_files = sorted(watched_files, key=lambda p: p.parent) for directory, group in itertools.groupby(sorted_files, key=lambda p: p.parent): # These paths need to be relative to the parent directory. 
- self._subscribe_dir(directory, [str(p.relative_to(directory)) for p in group]) + self._subscribe_dir( + directory, [str(p.relative_to(directory)) for p in group] + ) def update_watches(self): try: @@ -541,19 +571,19 @@ class WatchmanReloader(BaseReloader): subscription = self.client.getSubscription(sub) if not subscription: return - logger.debug('Watchman subscription %s has results.', sub) + logger.debug("Watchman subscription %s has results.", sub) for result in subscription: # When using watch-project, it's not simple to get the relative # directory without storing some specific state. Store the full # path to the directory in the subscription name, prefixed by its # type (glob, files). - root_directory = Path(result['subscription'].split(':', 1)[1]) - logger.debug('Found root directory %s', root_directory) - for file in result.get('files', []): + root_directory = Path(result["subscription"].split(":", 1)[1]) + logger.debug("Found root directory %s", root_directory) + for file in result.get("files", []): self.notify_file_changed(root_directory / file) def request_processed(self, **kwargs): - logger.debug('Request processed. Setting update_watches event.') + logger.debug("Request processed. 
Setting update_watches event.") self.processed_request.set() def tick(self): @@ -568,7 +598,7 @@ class WatchmanReloader(BaseReloader): except pywatchman.SocketTimeout: pass except pywatchman.WatchmanError as ex: - logger.debug('Watchman error: %s, checking server status.', ex) + logger.debug("Watchman error: %s, checking server status.", ex) self.check_server_status(ex) else: for sub in list(self.client.subs.keys()): @@ -584,7 +614,7 @@ class WatchmanReloader(BaseReloader): def check_server_status(self, inner_ex=None): """Return True if the server is available.""" try: - self.client.query('version') + self.client.query("version") except Exception: raise WatchmanUnavailable(str(inner_ex)) from inner_ex return True @@ -592,19 +622,19 @@ class WatchmanReloader(BaseReloader): @classmethod def check_availability(cls): if not pywatchman: - raise WatchmanUnavailable('pywatchman not installed.') + raise WatchmanUnavailable("pywatchman not installed.") client = pywatchman.client(timeout=0.1) try: result = client.capabilityCheck() except Exception: # The service is down? - raise WatchmanUnavailable('Cannot connect to the watchman service.') - version = get_version_tuple(result['version']) + raise WatchmanUnavailable("Cannot connect to the watchman service.") + version = get_version_tuple(result["version"]) # Watchman 4.9 includes multiple improvements to watching project # directories as well as case insensitive filesystems. 
- logger.debug('Watchman version %s', version) + logger.debug("Watchman version %s", version) if version < (4, 9): - raise WatchmanUnavailable('Watchman 4.9 or later is required.') + raise WatchmanUnavailable("Watchman 4.9 or later is required.") def get_reloader(): @@ -620,7 +650,9 @@ def start_django(reloader, main_func, *args, **kwargs): ensure_echo_on() main_func = check_errors(main_func) - django_main_thread = threading.Thread(target=main_func, args=args, kwargs=kwargs, name='django-main-thread') + django_main_thread = threading.Thread( + target=main_func, args=args, kwargs=kwargs, name="django-main-thread" + ) django_main_thread.daemon = True django_main_thread.start() @@ -631,16 +663,20 @@ def start_django(reloader, main_func, *args, **kwargs): # It's possible that the watchman service shuts down or otherwise # becomes unavailable. In that case, use the StatReloader. reloader = StatReloader() - logger.error('Error connecting to Watchman: %s', ex) - logger.info('Watching for file changes with %s', reloader.__class__.__name__) + logger.error("Error connecting to Watchman: %s", ex) + logger.info( + "Watching for file changes with %s", reloader.__class__.__name__ + ) def run_with_reloader(main_func, *args, **kwargs): signal.signal(signal.SIGTERM, lambda *args: sys.exit(0)) try: - if os.environ.get(DJANGO_AUTORELOAD_ENV) == 'true': + if os.environ.get(DJANGO_AUTORELOAD_ENV) == "true": reloader = get_reloader() - logger.info('Watching for file changes with %s', reloader.__class__.__name__) + logger.info( + "Watching for file changes with %s", reloader.__class__.__name__ + ) start_django(reloader, main_func, *args, **kwargs) else: exit_code = restart_with_reloader() diff --git a/django/utils/baseconv.py b/django/utils/baseconv.py index 21f1fb3b91..fcaab23f53 100644 --- a/django/utils/baseconv.py +++ b/django/utils/baseconv.py @@ -42,33 +42,37 @@ import warnings from django.utils.deprecation import RemovedInDjango50Warning warnings.warn( - 'The django.utils.baseconv 
module is deprecated.', + "The django.utils.baseconv module is deprecated.", category=RemovedInDjango50Warning, stacklevel=2, ) -BASE2_ALPHABET = '01' -BASE16_ALPHABET = '0123456789ABCDEF' -BASE56_ALPHABET = '23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz' -BASE36_ALPHABET = '0123456789abcdefghijklmnopqrstuvwxyz' -BASE62_ALPHABET = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' -BASE64_ALPHABET = BASE62_ALPHABET + '-_' +BASE2_ALPHABET = "01" +BASE16_ALPHABET = "0123456789ABCDEF" +BASE56_ALPHABET = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz" +BASE36_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz" +BASE62_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +BASE64_ALPHABET = BASE62_ALPHABET + "-_" class BaseConverter: - decimal_digits = '0123456789' + decimal_digits = "0123456789" - def __init__(self, digits, sign='-'): + def __init__(self, digits, sign="-"): self.sign = sign self.digits = digits if sign in self.digits: - raise ValueError('Sign character found in converter base digits.') + raise ValueError("Sign character found in converter base digits.") def __repr__(self): - return "<%s: base%s (%s)>" % (self.__class__.__name__, len(self.digits), self.digits) + return "<%s: base%s (%s)>" % ( + self.__class__.__name__, + len(self.digits), + self.digits, + ) def encode(self, i): - neg, value = self.convert(i, self.decimal_digits, self.digits, '-') + neg, value = self.convert(i, self.decimal_digits, self.digits, "-") if neg: return self.sign + value return value @@ -76,7 +80,7 @@ class BaseConverter: def decode(self, s): neg, value = self.convert(s, self.digits, self.decimal_digits, self.sign) if neg: - value = '-' + value + value = "-" + value return int(value) def convert(self, number, from_digits, to_digits, sign): @@ -95,7 +99,7 @@ class BaseConverter: if x == 0: res = to_digits[0] else: - res = '' + res = "" while x > 0: digit = x % len(to_digits) res = to_digits[digit] + res @@ -108,4 +112,4 @@ 
base16 = BaseConverter(BASE16_ALPHABET) base36 = BaseConverter(BASE36_ALPHABET) base56 = BaseConverter(BASE56_ALPHABET) base62 = BaseConverter(BASE62_ALPHABET) -base64 = BaseConverter(BASE64_ALPHABET, sign='$') +base64 = BaseConverter(BASE64_ALPHABET, sign="$") diff --git a/django/utils/cache.py b/django/utils/cache.py index c0e47e0e42..90292ce4da 100644 --- a/django/utils/cache.py +++ b/django/utils/cache.py @@ -23,15 +23,13 @@ from django.conf import settings from django.core.cache import caches from django.http import HttpResponse, HttpResponseNotModified from django.utils.crypto import md5 -from django.utils.http import ( - http_date, parse_etags, parse_http_date_safe, quote_etag, -) +from django.utils.http import http_date, parse_etags, parse_http_date_safe, quote_etag from django.utils.log import log_response from django.utils.regex_helper import _lazy_re_compile from django.utils.timezone import get_current_timezone_name from django.utils.translation import get_language -cc_delim_re = _lazy_re_compile(r'\s*,\s*') +cc_delim_re = _lazy_re_compile(r"\s*,\s*") def patch_cache_control(response, **kwargs): @@ -46,8 +44,9 @@ def patch_cache_control(response, **kwargs): * All other parameters are added with their value, after applying str() to it. """ + def dictitem(s): - t = s.split('=', 1) + t = s.split("=", 1) if len(t) > 1: return (t[0].lower(), t[1]) else: @@ -57,13 +56,13 @@ def patch_cache_control(response, **kwargs): if t[1] is True: return t[0] else: - return '%s=%s' % (t[0], t[1]) + return "%s=%s" % (t[0], t[1]) cc = defaultdict(set) - if response.get('Cache-Control'): - for field in cc_delim_re.split(response.headers['Cache-Control']): + if response.get("Cache-Control"): + for field in cc_delim_re.split(response.headers["Cache-Control"]): directive, value = dictitem(field) - if directive == 'no-cache': + if directive == "no-cache": # no-cache supports multiple field names. 
cc[directive].add(value) else: @@ -72,18 +71,18 @@ def patch_cache_control(response, **kwargs): # If there's already a max-age header but we're being asked to set a new # max-age, use the minimum of the two ages. In practice this happens when # a decorator and a piece of middleware both operate on a given view. - if 'max-age' in cc and 'max_age' in kwargs: - kwargs['max_age'] = min(int(cc['max-age']), kwargs['max_age']) + if "max-age" in cc and "max_age" in kwargs: + kwargs["max_age"] = min(int(cc["max-age"]), kwargs["max_age"]) # Allow overriding private caching and vice versa - if 'private' in cc and 'public' in kwargs: - del cc['private'] - elif 'public' in cc and 'private' in kwargs: - del cc['public'] + if "private" in cc and "public" in kwargs: + del cc["private"] + elif "public" in cc and "private" in kwargs: + del cc["public"] for (k, v) in kwargs.items(): - directive = k.replace('_', '-') - if directive == 'no-cache': + directive = k.replace("_", "-") + if directive == "no-cache": # no-cache supports multiple field names. cc[directive].add(v) else: @@ -98,8 +97,8 @@ def patch_cache_control(response, **kwargs): directives.extend([dictvalue(directive, value) for value in values]) else: directives.append(dictvalue(directive, values)) - cc = ', '.join(directives) - response.headers['Cache-Control'] = cc + cc = ", ".join(directives) + response.headers["Cache-Control"] = cc def get_max_age(response): @@ -107,18 +106,20 @@ def get_max_age(response): Return the max-age from the response Cache-Control header as an integer, or None if it wasn't found or wasn't an integer. 
""" - if not response.has_header('Cache-Control'): + if not response.has_header("Cache-Control"): return - cc = dict(_to_tuple(el) for el in cc_delim_re.split(response.headers['Cache-Control'])) + cc = dict( + _to_tuple(el) for el in cc_delim_re.split(response.headers["Cache-Control"]) + ) try: - return int(cc['max-age']) + return int(cc["max-age"]) except (ValueError, TypeError, KeyError): pass def set_response_etag(response): if not response.streaming and response.content: - response.headers['ETag'] = quote_etag( + response.headers["ETag"] = quote_etag( md5(response.content, usedforsecurity=False).hexdigest(), ) return response @@ -127,7 +128,8 @@ def set_response_etag(response): def _precondition_failed(request): response = HttpResponse(status=412) log_response( - 'Precondition Failed: %s', request.path, + "Precondition Failed: %s", + request.path, response=response, request=request, ) @@ -139,7 +141,15 @@ def _not_modified(request, response=None): if response: # Preserve the headers required by Section 4.1 of RFC 7232, as well as # Last-Modified. - for header in ('Cache-Control', 'Content-Location', 'Date', 'ETag', 'Expires', 'Last-Modified', 'Vary'): + for header in ( + "Cache-Control", + "Content-Location", + "Date", + "ETag", + "Expires", + "Last-Modified", + "Vary", + ): if header in response: new_response.headers[header] = response.headers[header] @@ -158,11 +168,13 @@ def get_conditional_response(request, etag=None, last_modified=None, response=No return response # Get HTTP request headers. 
- if_match_etags = parse_etags(request.META.get('HTTP_IF_MATCH', '')) - if_unmodified_since = request.META.get('HTTP_IF_UNMODIFIED_SINCE') - if_unmodified_since = if_unmodified_since and parse_http_date_safe(if_unmodified_since) - if_none_match_etags = parse_etags(request.META.get('HTTP_IF_NONE_MATCH', '')) - if_modified_since = request.META.get('HTTP_IF_MODIFIED_SINCE') + if_match_etags = parse_etags(request.META.get("HTTP_IF_MATCH", "")) + if_unmodified_since = request.META.get("HTTP_IF_UNMODIFIED_SINCE") + if_unmodified_since = if_unmodified_since and parse_http_date_safe( + if_unmodified_since + ) + if_none_match_etags = parse_etags(request.META.get("HTTP_IF_NONE_MATCH", "")) + if_modified_since = request.META.get("HTTP_IF_MODIFIED_SINCE") if_modified_since = if_modified_since and parse_http_date_safe(if_modified_since) # Step 1 of section 6 of RFC 7232: Test the If-Match precondition. @@ -170,23 +182,26 @@ def get_conditional_response(request, etag=None, last_modified=None, response=No return _precondition_failed(request) # Step 2: Test the If-Unmodified-Since precondition. - if (not if_match_etags and if_unmodified_since and - not _if_unmodified_since_passes(last_modified, if_unmodified_since)): + if ( + not if_match_etags + and if_unmodified_since + and not _if_unmodified_since_passes(last_modified, if_unmodified_since) + ): return _precondition_failed(request) # Step 3: Test the If-None-Match precondition. if if_none_match_etags and not _if_none_match_passes(etag, if_none_match_etags): - if request.method in ('GET', 'HEAD'): + if request.method in ("GET", "HEAD"): return _not_modified(request, response) else: return _precondition_failed(request) # Step 4: Test the If-Modified-Since precondition. 
if ( - not if_none_match_etags and - if_modified_since and - not _if_modified_since_passes(last_modified, if_modified_since) and - request.method in ('GET', 'HEAD') + not if_none_match_etags + and if_modified_since + and not _if_modified_since_passes(last_modified, if_modified_since) + and request.method in ("GET", "HEAD") ): return _not_modified(request, response) @@ -202,12 +217,12 @@ def _if_match_passes(target_etag, etags): if not target_etag: # If there isn't an ETag, then there can't be a match. return False - elif etags == ['*']: + elif etags == ["*"]: # The existence of an ETag means that there is "a current # representation for the target resource", even if the ETag is weak, # so there is a match to '*'. return True - elif target_etag.startswith('W/'): + elif target_etag.startswith("W/"): # A weak ETag can never strongly match another ETag. return False else: @@ -231,15 +246,15 @@ def _if_none_match_passes(target_etag, etags): if not target_etag: # If there isn't an ETag, then there isn't a match. return True - elif etags == ['*']: + elif etags == ["*"]: # The existence of an ETag means that there is "a current # representation for the target resource", so there is a match to '*'. return False else: # The comparison should be weak, so look for a match after stripping # off any weak indicators. 
- target_etag = target_etag.strip('W/') - etags = (etag.strip('W/') for etag in etags) + target_etag = target_etag.strip("W/") + etags = (etag.strip("W/") for etag in etags) return target_etag not in etags @@ -264,8 +279,8 @@ def patch_response_headers(response, cache_timeout=None): cache_timeout = settings.CACHE_MIDDLEWARE_SECONDS if cache_timeout < 0: cache_timeout = 0 # Can't have max-age negative - if not response.has_header('Expires'): - response.headers['Expires'] = http_date(time.time() + cache_timeout) + if not response.has_header("Expires"): + response.headers["Expires"] = http_date(time.time() + cache_timeout) patch_cache_control(response, max_age=cache_timeout) @@ -274,7 +289,9 @@ def add_never_cache_headers(response): Add headers to a response to indicate that a page should never be cached. """ patch_response_headers(response, cache_timeout=-1) - patch_cache_control(response, no_cache=True, no_store=True, must_revalidate=True, private=True) + patch_cache_control( + response, no_cache=True, no_store=True, must_revalidate=True, private=True + ) def patch_vary_headers(response, newheaders): @@ -287,28 +304,31 @@ def patch_vary_headers(response, newheaders): # Note that we need to keep the original order intact, because cache # implementations may rely on the order of the Vary contents in, say, # computing an MD5 hash. - if response.has_header('Vary'): - vary_headers = cc_delim_re.split(response.headers['Vary']) + if response.has_header("Vary"): + vary_headers = cc_delim_re.split(response.headers["Vary"]) else: vary_headers = [] # Use .lower() here so we treat headers as case-insensitive. 
existing_headers = {header.lower() for header in vary_headers} - additional_headers = [newheader for newheader in newheaders - if newheader.lower() not in existing_headers] + additional_headers = [ + newheader + for newheader in newheaders + if newheader.lower() not in existing_headers + ] vary_headers += additional_headers - if '*' in vary_headers: - response.headers['Vary'] = '*' + if "*" in vary_headers: + response.headers["Vary"] = "*" else: - response.headers['Vary'] = ', '.join(vary_headers) + response.headers["Vary"] = ", ".join(vary_headers) def has_vary_header(response, header_query): """ Check to see if the response has a given header name in its Vary header. """ - if not response.has_header('Vary'): + if not response.has_header("Vary"): return False - vary_headers = cc_delim_re.split(response.headers['Vary']) + vary_headers = cc_delim_re.split(response.headers["Vary"]) existing_headers = {header.lower() for header in vary_headers} return header_query.lower() in existing_headers @@ -319,9 +339,9 @@ def _i18n_cache_key_suffix(request, cache_key): # first check if LocaleMiddleware or another middleware added # LANGUAGE_CODE to request, then fall back to the active language # which in turn can also fall back to settings.LANGUAGE_CODE - cache_key += '.%s' % getattr(request, 'LANGUAGE_CODE', get_language()) + cache_key += ".%s" % getattr(request, "LANGUAGE_CODE", get_language()) if settings.USE_TZ: - cache_key += '.%s' % get_current_timezone_name() + cache_key += ".%s" % get_current_timezone_name() return cache_key @@ -332,21 +352,27 @@ def _generate_cache_key(request, method, headerlist, key_prefix): value = request.META.get(header) if value is not None: ctx.update(value.encode()) - url = md5(request.build_absolute_uri().encode('ascii'), usedforsecurity=False) - cache_key = 'views.decorators.cache.cache_page.%s.%s.%s.%s' % ( - key_prefix, method, url.hexdigest(), ctx.hexdigest()) + url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False) 
+ cache_key = "views.decorators.cache.cache_page.%s.%s.%s.%s" % ( + key_prefix, + method, + url.hexdigest(), + ctx.hexdigest(), + ) return _i18n_cache_key_suffix(request, cache_key) def _generate_cache_header_key(key_prefix, request): """Return a cache key for the header cache.""" - url = md5(request.build_absolute_uri().encode('ascii'), usedforsecurity=False) - cache_key = 'views.decorators.cache.cache_header.%s.%s' % ( - key_prefix, url.hexdigest()) + url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False) + cache_key = "views.decorators.cache.cache_header.%s.%s" % ( + key_prefix, + url.hexdigest(), + ) return _i18n_cache_key_suffix(request, cache_key) -def get_cache_key(request, key_prefix=None, method='GET', cache=None): +def get_cache_key(request, key_prefix=None, method="GET", cache=None): """ Return a cache key based on the request URL and query. It can be used in the request phase because it pulls the list of headers to take into @@ -388,17 +414,17 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach cache_key = _generate_cache_header_key(key_prefix, request) if cache is None: cache = caches[settings.CACHE_MIDDLEWARE_ALIAS] - if response.has_header('Vary'): + if response.has_header("Vary"): is_accept_language_redundant = settings.USE_I18N # If i18n is used, the generated cache key will be suffixed with the # current locale. Adding the raw value of Accept-Language is redundant # in that case and would result in storing the same content under # multiple keys in the cache. See #18191 for details. 
headerlist = [] - for header in cc_delim_re.split(response.headers['Vary']): - header = header.upper().replace('-', '_') - if header != 'ACCEPT_LANGUAGE' or not is_accept_language_redundant: - headerlist.append('HTTP_' + header) + for header in cc_delim_re.split(response.headers["Vary"]): + header = header.upper().replace("-", "_") + if header != "ACCEPT_LANGUAGE" or not is_accept_language_redundant: + headerlist.append("HTTP_" + header) headerlist.sort() cache.set(cache_key, headerlist, cache_timeout) return _generate_cache_key(request, request.method, headerlist, key_prefix) @@ -410,7 +436,7 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach def _to_tuple(s): - t = s.split('=', 1) + t = s.split("=", 1) if len(t) == 2: return t[0].lower(), t[1] return t[0].lower(), True diff --git a/django/utils/connection.py b/django/utils/connection.py index 72a02143fe..1b5895e8d3 100644 --- a/django/utils/connection.py +++ b/django/utils/connection.py @@ -8,8 +8,8 @@ class ConnectionProxy: """Proxy for accessing a connection object's attributes.""" def __init__(self, connections, alias): - self.__dict__['_connections'] = connections - self.__dict__['_alias'] = alias + self.__dict__["_connections"] = connections + self.__dict__["_alias"] = alias def __getattr__(self, item): return getattr(self._connections[self._alias], item) @@ -51,7 +51,7 @@ class BaseConnectionHandler: return settings def create_connection(self, alias): - raise NotImplementedError('Subclasses must implement create_connection().') + raise NotImplementedError("Subclasses must implement create_connection().") def __getitem__(self, alias): try: diff --git a/django/utils/crypto.py b/django/utils/crypto.py index 2af58fda6e..341cb742c1 100644 --- a/django/utils/crypto.py +++ b/django/utils/crypto.py @@ -12,10 +12,11 @@ from django.utils.inspect import func_supports_parameter class InvalidAlgorithm(ValueError): """Algorithm is not supported by hashlib.""" + pass -def 
salted_hmac(key_salt, value, secret=None, *, algorithm='sha1'): +def salted_hmac(key_salt, value, secret=None, *, algorithm="sha1"): """ Return the HMAC of 'value', using a key generated from key_salt and a secret (which defaults to settings.SECRET_KEY). Default algorithm is SHA1, @@ -32,8 +33,7 @@ def salted_hmac(key_salt, value, secret=None, *, algorithm='sha1'): hasher = getattr(hashlib, algorithm) except AttributeError as e: raise InvalidAlgorithm( - '%r is not an algorithm accepted by the hashlib module.' - % algorithm + "%r is not an algorithm accepted by the hashlib module." % algorithm ) from e # We need to generate a derived key from our base key. We can do this by # passing the key_salt and our base key through a pseudo-random function. @@ -45,7 +45,7 @@ def salted_hmac(key_salt, value, secret=None, *, algorithm='sha1'): return hmac.new(key, msg=force_bytes(value), digestmod=hasher) -RANDOM_STRING_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' +RANDOM_STRING_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" def get_random_string(length, allowed_chars=RANDOM_STRING_CHARS): @@ -59,7 +59,7 @@ def get_random_string(length, allowed_chars=RANDOM_STRING_CHARS): * length: 12, bit length =~ 71 bits * length: 22, bit length =~ 131 bits """ - return ''.join(secrets.choice(allowed_chars) for i in range(length)) + return "".join(secrets.choice(allowed_chars) for i in range(length)) def constant_time_compare(val1, val2): @@ -81,11 +81,12 @@ def pbkdf2(password, salt, iterations, dklen=0, digest=None): # detect whether the usedforsecurity argument is available as this fix may also # have been applied by downstream package maintainers to other versions in # their repositories. 
-if func_supports_parameter(hashlib.md5, 'usedforsecurity'): +if func_supports_parameter(hashlib.md5, "usedforsecurity"): md5 = hashlib.md5 new_hash = hashlib.new else: - def md5(data=b'', *, usedforsecurity=True): + + def md5(data=b"", *, usedforsecurity=True): return hashlib.md5(data) def new_hash(hash_algorithm, *, usedforsecurity=True): diff --git a/django/utils/datastructures.py b/django/utils/datastructures.py index 18c91aa5e8..b5858b8076 100644 --- a/django/utils/datastructures.py +++ b/django/utils/datastructures.py @@ -38,8 +38,8 @@ class OrderedSet: return len(self.dict) def __repr__(self): - data = repr(list(self.dict)) if self.dict else '' - return f'{self.__class__.__qualname__}({data})' + data = repr(list(self.dict)) if self.dict else "" + return f"{self.__class__.__qualname__}({data})" class MultiValueDictKeyError(KeyError): @@ -68,6 +68,7 @@ class MultiValueDict(dict): which returns a list for every key, even though most web forms submit single name-value pairs. """ + def __init__(self, key_to_list_mapping=()): super().__init__(key_to_list_mapping) @@ -92,24 +93,22 @@ class MultiValueDict(dict): super().__setitem__(key, [value]) def __copy__(self): - return self.__class__([ - (k, v[:]) - for k, v in self.lists() - ]) + return self.__class__([(k, v[:]) for k, v in self.lists()]) def __deepcopy__(self, memo): result = self.__class__() memo[id(self)] = result for key, value in dict.items(self): - dict.__setitem__(result, copy.deepcopy(key, memo), - copy.deepcopy(value, memo)) + dict.__setitem__( + result, copy.deepcopy(key, memo), copy.deepcopy(value, memo) + ) return result def __getstate__(self): - return {**self.__dict__, '_data': {k: self._getlist(k) for k in self}} + return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}} def __setstate__(self, obj_dict): - data = obj_dict.pop('_data', {}) + data = obj_dict.pop("_data", {}) for k, v in data.items(): self.setlist(k, v) self.__dict__.update(obj_dict) @@ -231,7 +230,7 @@ class 
ImmutableList(tuple): AttributeError: You cannot mutate this. """ - def __new__(cls, *args, warning='ImmutableList object is immutable.', **kwargs): + def __new__(cls, *args, warning="ImmutableList object is immutable.", **kwargs): self = tuple.__new__(cls, *args, **kwargs) self.warning = warning return self @@ -264,6 +263,7 @@ class DictWrapper(dict): Used by the SQL construction code to ensure that values are correctly quoted before being used. """ + def __init__(self, data, func, prefix): super().__init__(data) self.func = func @@ -277,7 +277,7 @@ class DictWrapper(dict): """ use_func = key.startswith(self.prefix) if use_func: - key = key[len(self.prefix):] + key = key[len(self.prefix) :] value = super().__getitem__(key) if use_func: return self.func(value) @@ -314,9 +314,7 @@ class CaseInsensitiveMapping(Mapping): def __eq__(self, other): return isinstance(other, Mapping) and { k.lower(): v for k, v in self.items() - } == { - k.lower(): v for k, v in other.items() - } + } == {k.lower(): v for k, v in other.items()} def __iter__(self): return (original_key for original_key, value in self._store.values()) @@ -335,11 +333,11 @@ class CaseInsensitiveMapping(Mapping): for i, elem in enumerate(data): if len(elem) != 2: raise ValueError( - 'dictionary update sequence element #{} has length {}; ' - '2 is required.'.format(i, len(elem)) + "dictionary update sequence element #{} has length {}; " + "2 is required.".format(i, len(elem)) ) if not isinstance(elem[0], str): raise ValueError( - 'Element key %r invalid, only strings are allowed' % elem[0] + "Element key %r invalid, only strings are allowed" % elem[0] ) yield elem diff --git a/django/utils/dateformat.py b/django/utils/dateformat.py index 70bfb32e01..541b8f8688 100644 --- a/django/utils/dateformat.py +++ b/django/utils/dateformat.py @@ -15,17 +15,24 @@ import datetime from email.utils import format_datetime as format_datetime_rfc5322 from django.utils.dates import ( - MONTHS, MONTHS_3, MONTHS_ALT, MONTHS_AP, 
WEEKDAYS, WEEKDAYS_ABBR, + MONTHS, + MONTHS_3, + MONTHS_ALT, + MONTHS_AP, + WEEKDAYS, + WEEKDAYS_ABBR, ) from django.utils.regex_helper import _lazy_re_compile from django.utils.timezone import ( - _datetime_ambiguous_or_imaginary, get_default_timezone, is_naive, + _datetime_ambiguous_or_imaginary, + get_default_timezone, + is_naive, make_aware, ) from django.utils.translation import gettext as _ -re_formatchars = _lazy_re_compile(r'(?<!\\)([aAbcdDeEfFgGhHiIjlLmMnNoOPrsStTUuwWyYzZ])') -re_escaped = _lazy_re_compile(r'\\(.)') +re_formatchars = _lazy_re_compile(r"(?<!\\)([aAbcdDeEfFgGhHiIjlLmMnNoOPrsStTUuwWyYzZ])") +re_escaped = _lazy_re_compile(r"\\(.)") class Formatter: @@ -40,12 +47,11 @@ class Formatter: ) pieces.append(str(getattr(self, piece)())) elif piece: - pieces.append(re_escaped.sub(r'\1', piece)) - return ''.join(pieces) + pieces.append(re_escaped.sub(r"\1", piece)) + return "".join(pieces) class TimeFormat(Formatter): - def __init__(self, obj): self.data = obj self.timezone = None @@ -61,22 +67,21 @@ class TimeFormat(Formatter): @property def _no_timezone_or_datetime_is_ambiguous_or_imaginary(self): - return ( - not self.timezone or - _datetime_ambiguous_or_imaginary(self.data, self.timezone) + return not self.timezone or _datetime_ambiguous_or_imaginary( + self.data, self.timezone ) def a(self): "'a.m.' 
or 'p.m.'" if self.data.hour > 11: - return _('p.m.') - return _('a.m.') + return _("p.m.") + return _("a.m.") def A(self): "'AM' or 'PM'" if self.data.hour > 11: - return _('PM') - return _('AM') + return _("PM") + return _("AM") def e(self): """ @@ -88,8 +93,8 @@ class TimeFormat(Formatter): return "" try: - if hasattr(self.data, 'tzinfo') and self.data.tzinfo: - return self.data.tzname() or '' + if hasattr(self.data, "tzinfo") and self.data.tzinfo: + return self.data.tzname() or "" except NotImplementedError: pass return "" @@ -103,7 +108,7 @@ class TimeFormat(Formatter): """ hour = self.data.hour % 12 or 12 minute = self.data.minute - return '%d:%02d' % (hour, minute) if minute else hour + return "%d:%02d" % (hour, minute) if minute else hour def g(self): "Hour, 12-hour format without leading zeros; i.e. '1' to '12'" @@ -115,15 +120,15 @@ class TimeFormat(Formatter): def h(self): "Hour, 12-hour format; i.e. '01' to '12'" - return '%02d' % (self.data.hour % 12 or 12) + return "%02d" % (self.data.hour % 12 or 12) def H(self): "Hour, 24-hour format; i.e. '00' to '23'" - return '%02d' % self.data.hour + return "%02d" % self.data.hour def i(self): "Minutes; i.e. '00' to '59'" - return '%02d' % self.data.minute + return "%02d" % self.data.minute def O(self): # NOQA: E743, E741 """ @@ -135,7 +140,7 @@ class TimeFormat(Formatter): return "" seconds = self.Z() - sign = '-' if seconds < 0 else '+' + sign = "-" if seconds < 0 else "+" seconds = abs(seconds) return "%s%02d%02d" % (sign, seconds // 3600, (seconds // 60) % 60) @@ -147,14 +152,14 @@ class TimeFormat(Formatter): Proprietary extension. """ if self.data.minute == 0 and self.data.hour == 0: - return _('midnight') + return _("midnight") if self.data.minute == 0 and self.data.hour == 12: - return _('noon') - return '%s %s' % (self.f(), self.a()) + return _("noon") + return "%s %s" % (self.f(), self.a()) def s(self): "Seconds; i.e. 
'00' to '59'" - return '%02d' % self.data.second + return "%02d" % self.data.second def T(self): """ @@ -169,7 +174,7 @@ class TimeFormat(Formatter): def u(self): "Microseconds; i.e. '000000' to '999999'" - return '%06d' % self.data.microsecond + return "%06d" % self.data.microsecond def Z(self): """ @@ -205,7 +210,7 @@ class DateFormat(TimeFormat): def d(self): "Day of the month, 2 digits with leading zeros; i.e. '01' to '31'" - return '%02d' % self.data.day + return "%02d" % self.data.day def D(self): "Day of the week, textual, 3 letters; e.g. 'Fri'" @@ -222,8 +227,8 @@ class DateFormat(TimeFormat): def I(self): # NOQA: E743, E741 "'1' if daylight saving time, '0' otherwise." if self._no_timezone_or_datetime_is_ambiguous_or_imaginary: - return '' - return '1' if self.timezone.dst(self.data) else '0' + return "" + return "1" if self.timezone.dst(self.data) else "0" def j(self): "Day of the month without leading zeros; i.e. '1' to '31'" @@ -239,7 +244,7 @@ class DateFormat(TimeFormat): def m(self): "Month; i.e. '01' to '12'" - return '%02d' % self.data.month + return "%02d" % self.data.month def M(self): "Month, textual, 3 letters; e.g. 'Jan'" @@ -273,19 +278,19 @@ class DateFormat(TimeFormat): def S(self): "English ordinal suffix for the day of the month, 2 characters; i.e. 'st', 'nd', 'rd' or 'th'" if self.data.day in (11, 12, 13): # Special case - return 'th' + return "th" last = self.data.day % 10 if last == 1: - return 'st' + return "st" if last == 2: - return 'nd' + return "nd" if last == 3: - return 'rd' - return 'th' + return "rd" + return "th" def t(self): "Number of days in the given month; i.e. '28' to '31'" - return '%02d' % calendar.monthrange(self.data.year, self.data.month)[1] + return "%02d" % calendar.monthrange(self.data.year, self.data.month)[1] def U(self): "Seconds since the Unix epoch (January 1 1970 00:00:00 GMT)" @@ -304,11 +309,11 @@ class DateFormat(TimeFormat): def y(self): """Year, 2 digits with leading zeros; e.g. 
'99'.""" - return '%02d' % (self.data.year % 100) + return "%02d" % (self.data.year % 100) def Y(self): """Year, 4 digits with leading zeros; e.g. '1999'.""" - return '%04d' % self.data.year + return "%04d" % self.data.year def z(self): """Day of the year, i.e. 1 to 366.""" diff --git a/django/utils/dateparse.py b/django/utils/dateparse.py index a137031b3f..2e6a260a4f 100644 --- a/django/utils/dateparse.py +++ b/django/utils/dateparse.py @@ -10,59 +10,57 @@ import datetime from django.utils.regex_helper import _lazy_re_compile from django.utils.timezone import get_fixed_timezone, utc -date_re = _lazy_re_compile( - r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$' -) +date_re = _lazy_re_compile(r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$") time_re = _lazy_re_compile( - r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})' - r'(?::(?P<second>\d{1,2})(?:[\.,](?P<microsecond>\d{1,6})\d{0,6})?)?$' + r"(?P<hour>\d{1,2}):(?P<minute>\d{1,2})" + r"(?::(?P<second>\d{1,2})(?:[\.,](?P<microsecond>\d{1,6})\d{0,6})?)?$" ) datetime_re = _lazy_re_compile( - r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})' - r'[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})' - r'(?::(?P<second>\d{1,2})(?:[\.,](?P<microsecond>\d{1,6})\d{0,6})?)?' - r'\s*(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$' + r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})" + r"[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})" + r"(?::(?P<second>\d{1,2})(?:[\.,](?P<microsecond>\d{1,6})\d{0,6})?)?" + r"\s*(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$" ) standard_duration_re = _lazy_re_compile( - r'^' - r'(?:(?P<days>-?\d+) (days?, )?)?' - r'(?P<sign>-?)' - r'((?:(?P<hours>\d+):)(?=\d+:\d+))?' - r'(?:(?P<minutes>\d+):)?' - r'(?P<seconds>\d+)' - r'(?:[\.,](?P<microseconds>\d{1,6})\d{0,6})?' - r'$' + r"^" + r"(?:(?P<days>-?\d+) (days?, )?)?" + r"(?P<sign>-?)" + r"((?:(?P<hours>\d+):)(?=\d+:\d+))?" + r"(?:(?P<minutes>\d+):)?" + r"(?P<seconds>\d+)" + r"(?:[\.,](?P<microseconds>\d{1,6})\d{0,6})?" 
+ r"$" ) # Support the sections of ISO 8601 date representation that are accepted by # timedelta iso8601_duration_re = _lazy_re_compile( - r'^(?P<sign>[-+]?)' - r'P' - r'(?:(?P<days>\d+([\.,]\d+)?)D)?' - r'(?:T' - r'(?:(?P<hours>\d+([\.,]\d+)?)H)?' - r'(?:(?P<minutes>\d+([\.,]\d+)?)M)?' - r'(?:(?P<seconds>\d+([\.,]\d+)?)S)?' - r')?' - r'$' + r"^(?P<sign>[-+]?)" + r"P" + r"(?:(?P<days>\d+([\.,]\d+)?)D)?" + r"(?:T" + r"(?:(?P<hours>\d+([\.,]\d+)?)H)?" + r"(?:(?P<minutes>\d+([\.,]\d+)?)M)?" + r"(?:(?P<seconds>\d+([\.,]\d+)?)S)?" + r")?" + r"$" ) # Support PostgreSQL's day-time interval format, e.g. "3 days 04:05:06". The # year-month and mixed intervals cannot be converted to a timedelta and thus # aren't accepted. postgres_interval_re = _lazy_re_compile( - r'^' - r'(?:(?P<days>-?\d+) (days? ?))?' - r'(?:(?P<sign>[-+])?' - r'(?P<hours>\d+):' - r'(?P<minutes>\d\d):' - r'(?P<seconds>\d\d)' - r'(?:\.(?P<microseconds>\d{1,6}))?' - r')?$' + r"^" + r"(?:(?P<days>-?\d+) (days? ?))?" + r"(?:(?P<sign>[-+])?" + r"(?P<hours>\d+):" + r"(?P<minutes>\d\d):" + r"(?P<seconds>\d\d)" + r"(?:\.(?P<microseconds>\d{1,6}))?" 
+ r")?$" ) @@ -98,7 +96,7 @@ def parse_time(value): except ValueError: if match := time_re.match(value): kw = match.groupdict() - kw['microsecond'] = kw['microsecond'] and kw['microsecond'].ljust(6, '0') + kw["microsecond"] = kw["microsecond"] and kw["microsecond"].ljust(6, "0") kw = {k: int(v) for k, v in kw.items() if v is not None} return datetime.time(**kw) @@ -117,14 +115,14 @@ def parse_datetime(value): except ValueError: if match := datetime_re.match(value): kw = match.groupdict() - kw['microsecond'] = kw['microsecond'] and kw['microsecond'].ljust(6, '0') - tzinfo = kw.pop('tzinfo') - if tzinfo == 'Z': + kw["microsecond"] = kw["microsecond"] and kw["microsecond"].ljust(6, "0") + tzinfo = kw.pop("tzinfo") + if tzinfo == "Z": tzinfo = utc elif tzinfo is not None: offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0 offset = 60 * int(tzinfo[1:3]) + offset_mins - if tzinfo[0] == '-': + if tzinfo[0] == "-": offset = -offset tzinfo = get_fixed_timezone(offset) kw = {k: int(v) for k, v in kw.items() if v is not None} @@ -140,17 +138,17 @@ def parse_duration(value): format. 
""" match = ( - standard_duration_re.match(value) or - iso8601_duration_re.match(value) or - postgres_interval_re.match(value) + standard_duration_re.match(value) + or iso8601_duration_re.match(value) + or postgres_interval_re.match(value) ) if match: kw = match.groupdict() - sign = -1 if kw.pop('sign', '+') == '-' else 1 - if kw.get('microseconds'): - kw['microseconds'] = kw['microseconds'].ljust(6, '0') - kw = {k: float(v.replace(',', '.')) for k, v in kw.items() if v is not None} - days = datetime.timedelta(kw.pop('days', .0) or .0) + sign = -1 if kw.pop("sign", "+") == "-" else 1 + if kw.get("microseconds"): + kw["microseconds"] = kw["microseconds"].ljust(6, "0") + kw = {k: float(v.replace(",", ".")) for k, v in kw.items() if v is not None} + days = datetime.timedelta(kw.pop("days", 0.0) or 0.0) if match.re == iso8601_duration_re: days *= sign return days + sign * datetime.timedelta(**kw) diff --git a/django/utils/dates.py b/django/utils/dates.py index d2e769db84..05ac3293df 100644 --- a/django/utils/dates.py +++ b/django/utils/dates.py @@ -1,49 +1,79 @@ "Commonly-used date structures" -from django.utils.translation import gettext_lazy as _, pgettext_lazy +from django.utils.translation import gettext_lazy as _ +from django.utils.translation import pgettext_lazy WEEKDAYS = { - 0: _('Monday'), 1: _('Tuesday'), 2: _('Wednesday'), 3: _('Thursday'), 4: _('Friday'), - 5: _('Saturday'), 6: _('Sunday') + 0: _("Monday"), + 1: _("Tuesday"), + 2: _("Wednesday"), + 3: _("Thursday"), + 4: _("Friday"), + 5: _("Saturday"), + 6: _("Sunday"), } WEEKDAYS_ABBR = { - 0: _('Mon'), 1: _('Tue'), 2: _('Wed'), 3: _('Thu'), 4: _('Fri'), - 5: _('Sat'), 6: _('Sun') + 0: _("Mon"), + 1: _("Tue"), + 2: _("Wed"), + 3: _("Thu"), + 4: _("Fri"), + 5: _("Sat"), + 6: _("Sun"), } MONTHS = { - 1: _('January'), 2: _('February'), 3: _('March'), 4: _('April'), 5: _('May'), 6: _('June'), - 7: _('July'), 8: _('August'), 9: _('September'), 10: _('October'), 11: _('November'), - 12: _('December') + 1: 
_("January"), + 2: _("February"), + 3: _("March"), + 4: _("April"), + 5: _("May"), + 6: _("June"), + 7: _("July"), + 8: _("August"), + 9: _("September"), + 10: _("October"), + 11: _("November"), + 12: _("December"), } MONTHS_3 = { - 1: _('jan'), 2: _('feb'), 3: _('mar'), 4: _('apr'), 5: _('may'), 6: _('jun'), - 7: _('jul'), 8: _('aug'), 9: _('sep'), 10: _('oct'), 11: _('nov'), 12: _('dec') + 1: _("jan"), + 2: _("feb"), + 3: _("mar"), + 4: _("apr"), + 5: _("may"), + 6: _("jun"), + 7: _("jul"), + 8: _("aug"), + 9: _("sep"), + 10: _("oct"), + 11: _("nov"), + 12: _("dec"), } MONTHS_AP = { # month names in Associated Press style - 1: pgettext_lazy('abbrev. month', 'Jan.'), - 2: pgettext_lazy('abbrev. month', 'Feb.'), - 3: pgettext_lazy('abbrev. month', 'March'), - 4: pgettext_lazy('abbrev. month', 'April'), - 5: pgettext_lazy('abbrev. month', 'May'), - 6: pgettext_lazy('abbrev. month', 'June'), - 7: pgettext_lazy('abbrev. month', 'July'), - 8: pgettext_lazy('abbrev. month', 'Aug.'), - 9: pgettext_lazy('abbrev. month', 'Sept.'), - 10: pgettext_lazy('abbrev. month', 'Oct.'), - 11: pgettext_lazy('abbrev. month', 'Nov.'), - 12: pgettext_lazy('abbrev. month', 'Dec.') + 1: pgettext_lazy("abbrev. month", "Jan."), + 2: pgettext_lazy("abbrev. month", "Feb."), + 3: pgettext_lazy("abbrev. month", "March"), + 4: pgettext_lazy("abbrev. month", "April"), + 5: pgettext_lazy("abbrev. month", "May"), + 6: pgettext_lazy("abbrev. month", "June"), + 7: pgettext_lazy("abbrev. month", "July"), + 8: pgettext_lazy("abbrev. month", "Aug."), + 9: pgettext_lazy("abbrev. month", "Sept."), + 10: pgettext_lazy("abbrev. month", "Oct."), + 11: pgettext_lazy("abbrev. month", "Nov."), + 12: pgettext_lazy("abbrev. month", "Dec."), } MONTHS_ALT = { # required for long date representation by some locales - 1: pgettext_lazy('alt. month', 'January'), - 2: pgettext_lazy('alt. month', 'February'), - 3: pgettext_lazy('alt. month', 'March'), - 4: pgettext_lazy('alt. month', 'April'), - 5: pgettext_lazy('alt. 
month', 'May'), - 6: pgettext_lazy('alt. month', 'June'), - 7: pgettext_lazy('alt. month', 'July'), - 8: pgettext_lazy('alt. month', 'August'), - 9: pgettext_lazy('alt. month', 'September'), - 10: pgettext_lazy('alt. month', 'October'), - 11: pgettext_lazy('alt. month', 'November'), - 12: pgettext_lazy('alt. month', 'December') + 1: pgettext_lazy("alt. month", "January"), + 2: pgettext_lazy("alt. month", "February"), + 3: pgettext_lazy("alt. month", "March"), + 4: pgettext_lazy("alt. month", "April"), + 5: pgettext_lazy("alt. month", "May"), + 6: pgettext_lazy("alt. month", "June"), + 7: pgettext_lazy("alt. month", "July"), + 8: pgettext_lazy("alt. month", "August"), + 9: pgettext_lazy("alt. month", "September"), + 10: pgettext_lazy("alt. month", "October"), + 11: pgettext_lazy("alt. month", "November"), + 12: pgettext_lazy("alt. month", "December"), } diff --git a/django/utils/datetime_safe.py b/django/utils/datetime_safe.py index e06887b706..817ddcf0fa 100644 --- a/django/utils/datetime_safe.py +++ b/django/utils/datetime_safe.py @@ -9,13 +9,14 @@ import time import warnings -from datetime import date as real_date, datetime as real_datetime +from datetime import date as real_date +from datetime import datetime as real_datetime from django.utils.deprecation import RemovedInDjango50Warning from django.utils.regex_helper import _lazy_re_compile warnings.warn( - 'The django.utils.datetime_safe module is deprecated.', + "The django.utils.datetime_safe module is deprecated.", category=RemovedInDjango50Warning, stacklevel=2, ) @@ -32,9 +33,16 @@ class datetime(real_datetime): @classmethod def combine(cls, date, time): - return cls(date.year, date.month, date.day, - time.hour, time.minute, time.second, - time.microsecond, time.tzinfo) + return cls( + date.year, + date.month, + date.day, + time.hour, + time.minute, + time.second, + time.microsecond, + time.tzinfo, + ) def date(self): return date(self.year, self.month, self.day) @@ -78,7 +86,9 @@ def strftime(dt, fmt): 
return super(type(dt), dt).strftime(fmt) illegal_formatting = _illegal_formatting.search(fmt) if illegal_formatting: - raise TypeError('strftime of dates before 1000 does not handle ' + illegal_formatting[0]) + raise TypeError( + "strftime of dates before 1000 does not handle " + illegal_formatting[0] + ) year = dt.year # For every non-leap year century, advance by @@ -104,5 +114,5 @@ def strftime(dt, fmt): s = s1 syear = "%04d" % dt.year for site in sites: - s = s[:site] + syear + s[site + 4:] + s = s[:site] + syear + s[site + 4 :] return s diff --git a/django/utils/deconstruct.py b/django/utils/deconstruct.py index 9d073902f3..ca87f7ecc9 100644 --- a/django/utils/deconstruct.py +++ b/django/utils/deconstruct.py @@ -10,6 +10,7 @@ def deconstructible(*args, path=None): The `path` kwarg specifies the import path. """ + def decorator(klass): def __new__(cls, *args, **kwargs): # We capture the arguments to make returning them trivial @@ -24,7 +25,7 @@ def deconstructible(*args, path=None): """ # Fallback version if path and type(obj) is klass: - module_name, _, name = path.rpartition('.') + module_name, _, name = path.rpartition(".") else: module_name = obj.__module__ name = obj.__class__.__name__ @@ -38,11 +39,12 @@ def deconstructible(*args, path=None): "body to use migrations.\n" "For more information, see " "https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values" - % (name, module_name, get_docs_version())) + % (name, module_name, get_docs_version()) + ) return ( path if path and type(obj) is klass - else f'{obj.__class__.__module__}.{name}', + else f"{obj.__class__.__module__}.{name}", obj._constructor_args[0], obj._constructor_args[1], ) diff --git a/django/utils/decorators.py b/django/utils/decorators.py index 69aca10a4d..e412bb15e1 100644 --- a/django/utils/decorators.py +++ b/django/utils/decorators.py @@ -6,7 +6,9 @@ from functools import partial, update_wrapper, wraps class classonlymethod(classmethod): def __get__(self, instance, 
cls=None): if instance is not None: - raise AttributeError("This method is available only on the class, not on instances.") + raise AttributeError( + "This method is available only on the class, not on instances." + ) return super().__get__(instance, cls) @@ -16,6 +18,7 @@ def _update_method_wrapper(_wrapper, decorator): @decorator def dummy(*args, **kwargs): pass + update_wrapper(_wrapper, dummy) @@ -24,7 +27,7 @@ def _multi_decorate(decorators, method): Decorate `method` with one or more function decorators. `decorators` can be a single decorator or an iterable of decorators. """ - if hasattr(decorators, '__iter__'): + if hasattr(decorators, "__iter__"): # Apply a list/tuple of decorators if 'decorators' is one. Decorator # functions are applied so that the call order is the same as the # order in which they appear in the iterable. @@ -50,7 +53,7 @@ def _multi_decorate(decorators, method): return _wrapper -def method_decorator(decorator, name=''): +def method_decorator(decorator, name=""): """ Convert a function decorator into a method decorator """ @@ -78,11 +81,11 @@ def method_decorator(decorator, name=''): # Don't worry about making _dec look similar to a list/tuple as it's rather # meaningless. - if not hasattr(decorator, '__iter__'): + if not hasattr(decorator, "__iter__"): update_wrapper(_dec, decorator) # Change the name to aid debugging. 
- obj = decorator if hasattr(decorator, '__name__') else decorator.__class__ - _dec.__name__ = 'method_decorator(%s)' % obj.__name__ + obj = decorator if hasattr(decorator, "__name__") else decorator.__class__ + _dec.__name__ = "method_decorator(%s)" % obj.__name__ return _dec @@ -118,37 +121,44 @@ def make_middleware_decorator(middleware_class): @wraps(view_func) def _wrapped_view(request, *args, **kwargs): - if hasattr(middleware, 'process_request'): + if hasattr(middleware, "process_request"): result = middleware.process_request(request) if result is not None: return result - if hasattr(middleware, 'process_view'): + if hasattr(middleware, "process_view"): result = middleware.process_view(request, view_func, args, kwargs) if result is not None: return result try: response = view_func(request, *args, **kwargs) except Exception as e: - if hasattr(middleware, 'process_exception'): + if hasattr(middleware, "process_exception"): result = middleware.process_exception(request, e) if result is not None: return result raise - if hasattr(response, 'render') and callable(response.render): - if hasattr(middleware, 'process_template_response'): - response = middleware.process_template_response(request, response) + if hasattr(response, "render") and callable(response.render): + if hasattr(middleware, "process_template_response"): + response = middleware.process_template_response( + request, response + ) # Defer running of process_response until after the template # has been rendered: - if hasattr(middleware, 'process_response'): + if hasattr(middleware, "process_response"): + def callback(response): return middleware.process_response(request, response) + response.add_post_render_callback(callback) else: - if hasattr(middleware, 'process_response'): + if hasattr(middleware, "process_response"): return middleware.process_response(request, response) return response + return _wrapped_view + return _decorator + return _make_decorator diff --git a/django/utils/deprecation.py 
b/django/utils/deprecation.py index 48209bcdf1..528783a5c1 100644 --- a/django/utils/deprecation.py +++ b/django/utils/deprecation.py @@ -14,7 +14,9 @@ class RemovedInDjango50Warning(PendingDeprecationWarning): class warn_about_renamed_method: - def __init__(self, class_name, old_method_name, new_method_name, deprecation_warning): + def __init__( + self, class_name, old_method_name, new_method_name, deprecation_warning + ): self.class_name = class_name self.old_method_name = old_method_name self.new_method_name = new_method_name @@ -23,10 +25,13 @@ class warn_about_renamed_method: def __call__(self, f): def wrapped(*args, **kwargs): warnings.warn( - "`%s.%s` is deprecated, use `%s` instead." % - (self.class_name, self.old_method_name, self.new_method_name), - self.deprecation_warning, 2) + "`%s.%s` is deprecated, use `%s` instead." + % (self.class_name, self.old_method_name, self.new_method_name), + self.deprecation_warning, + 2, + ) return f(*args, **kwargs) + return wrapped @@ -60,9 +65,11 @@ class RenameMethodsBase(type): # Define the new method if missing and complain about it if not new_method and old_method: warnings.warn( - "`%s.%s` method should be renamed `%s`." % - (class_name, old_method_name, new_method_name), - deprecation_warning, 2) + "`%s.%s` method should be renamed `%s`." + % (class_name, old_method_name, new_method_name), + deprecation_warning, + 2, + ) setattr(base, new_method_name, old_method) setattr(base, old_method_name, wrapper(old_method)) @@ -77,7 +84,8 @@ class DeprecationInstanceCheck(type): def __instancecheck__(self, instance): warnings.warn( "`%s` is deprecated, use `%s` instead." 
% (self.__name__, self.alternative), - self.deprecation_warning, 2 + self.deprecation_warning, + 2, ) return super().__instancecheck__(instance) @@ -88,17 +96,17 @@ class MiddlewareMixin: def __init__(self, get_response): if get_response is None: - raise ValueError('get_response must be provided.') + raise ValueError("get_response must be provided.") self.get_response = get_response self._async_check() super().__init__() def __repr__(self): - return '<%s get_response=%s>' % ( + return "<%s get_response=%s>" % ( self.__class__.__qualname__, getattr( self.get_response, - '__qualname__', + "__qualname__", self.get_response.__class__.__name__, ), ) @@ -120,10 +128,10 @@ class MiddlewareMixin: if self._is_coroutine: return self.__acall__(request) response = None - if hasattr(self, 'process_request'): + if hasattr(self, "process_request"): response = self.process_request(request) response = response or self.get_response(request) - if hasattr(self, 'process_response'): + if hasattr(self, "process_response"): response = self.process_response(request, response) return response @@ -133,13 +141,13 @@ class MiddlewareMixin: is running. 
""" response = None - if hasattr(self, 'process_request'): + if hasattr(self, "process_request"): response = await sync_to_async( self.process_request, thread_sensitive=True, )(request) response = response or await self.get_response(request) - if hasattr(self, 'process_response'): + if hasattr(self, "process_response"): response = await sync_to_async( self.process_response, thread_sensitive=True, diff --git a/django/utils/duration.py b/django/utils/duration.py index 466603d46c..8495af3fa8 100644 --- a/django/utils/duration.py +++ b/django/utils/duration.py @@ -19,25 +19,27 @@ def duration_string(duration): """Version of str(timedelta) which is not English specific.""" days, hours, minutes, seconds, microseconds = _get_duration_components(duration) - string = '{:02d}:{:02d}:{:02d}'.format(hours, minutes, seconds) + string = "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) if days: - string = '{} '.format(days) + string + string = "{} ".format(days) + string if microseconds: - string += '.{:06d}'.format(microseconds) + string += ".{:06d}".format(microseconds) return string def duration_iso_string(duration): if duration < datetime.timedelta(0): - sign = '-' + sign = "-" duration *= -1 else: - sign = '' + sign = "" days, hours, minutes, seconds, microseconds = _get_duration_components(duration) - ms = '.{:06d}'.format(microseconds) if microseconds else "" - return '{}P{}DT{:02d}H{:02d}M{:02d}{}S'.format(sign, days, hours, minutes, seconds, ms) + ms = ".{:06d}".format(microseconds) if microseconds else "" + return "{}P{}DT{:02d}H{:02d}M{:02d}{}S".format( + sign, days, hours, minutes, seconds, ms + ) def duration_microseconds(delta): diff --git a/django/utils/encoding.py b/django/utils/encoding.py index 19eb150ad7..89eac79dd4 100644 --- a/django/utils/encoding.py +++ b/django/utils/encoding.py @@ -13,10 +13,14 @@ class DjangoUnicodeDecodeError(UnicodeDecodeError): super().__init__(*args) def __str__(self): - return '%s. 
You passed in %r (%s)' % (super().__str__(), self.obj, type(self.obj)) + return "%s. You passed in %r (%s)" % ( + super().__str__(), + self.obj, + type(self.obj), + ) -def smart_str(s, encoding='utf-8', strings_only=False, errors='strict'): +def smart_str(s, encoding="utf-8", strings_only=False, errors="strict"): """ Return a string representing 's'. Treat bytestrings using the 'encoding' codec. @@ -30,7 +34,13 @@ def smart_str(s, encoding='utf-8', strings_only=False, errors='strict'): _PROTECTED_TYPES = ( - type(None), int, float, Decimal, datetime.datetime, datetime.date, datetime.time, + type(None), + int, + float, + Decimal, + datetime.datetime, + datetime.date, + datetime.time, ) @@ -43,7 +53,7 @@ def is_protected_type(obj): return isinstance(obj, _PROTECTED_TYPES) -def force_str(s, encoding='utf-8', strings_only=False, errors='strict'): +def force_str(s, encoding="utf-8", strings_only=False, errors="strict"): """ Similar to smart_str(), except that lazy instances are resolved to strings, rather than kept as lazy objects. @@ -65,7 +75,7 @@ def force_str(s, encoding='utf-8', strings_only=False, errors='strict'): return s -def smart_bytes(s, encoding='utf-8', strings_only=False, errors='strict'): +def smart_bytes(s, encoding="utf-8", strings_only=False, errors="strict"): """ Return a bytestring version of 's', encoded as specified in 'encoding'. @@ -77,7 +87,7 @@ def smart_bytes(s, encoding='utf-8', strings_only=False, errors='strict'): return force_bytes(s, encoding, strings_only, errors) -def force_bytes(s, encoding='utf-8', strings_only=False, errors='strict'): +def force_bytes(s, encoding="utf-8", strings_only=False, errors="strict"): """ Similar to smart_bytes, except that lazy instances are resolved to strings, rather than kept as lazy objects. @@ -86,10 +96,10 @@ def force_bytes(s, encoding='utf-8', strings_only=False, errors='strict'): """ # Handle the common case first for performance reasons. 
if isinstance(s, bytes): - if encoding == 'utf-8': + if encoding == "utf-8": return s else: - return s.decode('utf-8', errors).encode(encoding, errors) + return s.decode("utf-8", errors).encode(encoding, errors) if strings_only and is_protected_type(s): return s if isinstance(s, memoryview): @@ -136,15 +146,14 @@ _hextobyte = { (fmt % char).encode(): bytes((char,)) for ascii_range in _ascii_ranges for char in ascii_range - for fmt in ['%02x', '%02X'] + for fmt in ["%02x", "%02X"] } # And then everything above 128, because bytes ≥ 128 are part of multibyte # Unicode characters. -_hexdig = '0123456789ABCDEFabcdef' -_hextobyte.update({ - (a + b).encode(): bytes.fromhex(a + b) - for a in _hexdig[8:] for b in _hexdig -}) +_hexdig = "0123456789ABCDEFabcdef" +_hextobyte.update( + {(a + b).encode(): bytes.fromhex(a + b) for a in _hexdig[8:] for b in _hexdig} +) def uri_to_iri(uri): @@ -164,7 +173,7 @@ def uri_to_iri(uri): # second block, decode the first 2 bytes if they represent a hex code to # decode. The rest of the block is the part after '%AB', not containing # any '%'. Add that to the output without further processing. 
- bits = uri.split(b'%') + bits = uri.split(b"%") if len(bits) == 1: iri = uri else: @@ -177,9 +186,9 @@ def uri_to_iri(uri): append(hextobyte[item[:2]]) append(item[2:]) else: - append(b'%') + append(b"%") append(item) - iri = b''.join(parts) + iri = b"".join(parts) return repercent_broken_unicode(iri).decode() @@ -202,7 +211,7 @@ def escape_uri_path(path): def punycode(domain): """Return the Punycode of the given domain if it's non-ASCII.""" - return domain.encode('idna').decode('ascii') + return domain.encode("idna").decode("ascii") def repercent_broken_unicode(path): @@ -217,8 +226,8 @@ def repercent_broken_unicode(path): except UnicodeDecodeError as e: # CVE-2019-14235: A recursion shouldn't be used since the exception # handling uses massive amounts of memory - repercent = quote(path[e.start:e.end], safe=b"/#%[]=:;$&()+,!?*@'~") - path = path[:e.start] + repercent.encode() + path[e.end:] + repercent = quote(path[e.start : e.end], safe=b"/#%[]=:;$&()+,!?*@'~") + path = path[: e.start] + repercent.encode() + path[e.end :] else: return path @@ -245,10 +254,10 @@ def get_system_encoding(): #10335 and #5846. 
""" try: - encoding = locale.getdefaultlocale()[1] or 'ascii' + encoding = locale.getdefaultlocale()[1] or "ascii" codecs.lookup(encoding) except Exception: - encoding = 'ascii' + encoding = "ascii" return encoding diff --git a/django/utils/feedgenerator.py b/django/utils/feedgenerator.py index 857beff13a..a8f6bc5c06 100644 --- a/django/utils/feedgenerator.py +++ b/django/utils/feedgenerator.py @@ -40,7 +40,7 @@ def rfc2822_date(date): def rfc3339_date(date): if not isinstance(date, datetime.datetime): date = datetime.datetime.combine(date, datetime.time()) - return date.isoformat() + ('Z' if date.utcoffset() is None else '') + return date.isoformat() + ("Z" if date.utcoffset() is None else "") def get_tag_uri(url, date): @@ -50,68 +50,103 @@ def get_tag_uri(url, date): See https://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id """ bits = urlparse(url) - d = '' + d = "" if date is not None: - d = ',%s' % date.strftime('%Y-%m-%d') - return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment) + d = ",%s" % date.strftime("%Y-%m-%d") + return "tag:%s%s:%s/%s" % (bits.hostname, d, bits.path, bits.fragment) class SyndicationFeed: "Base class for all syndication feeds. 
Subclasses should provide write()" - def __init__(self, title, link, description, language=None, author_email=None, - author_name=None, author_link=None, subtitle=None, categories=None, - feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs): + + def __init__( + self, + title, + link, + description, + language=None, + author_email=None, + author_name=None, + author_link=None, + subtitle=None, + categories=None, + feed_url=None, + feed_copyright=None, + feed_guid=None, + ttl=None, + **kwargs, + ): def to_str(s): return str(s) if s is not None else s + categories = categories and [str(c) for c in categories] self.feed = { - 'title': to_str(title), - 'link': iri_to_uri(link), - 'description': to_str(description), - 'language': to_str(language), - 'author_email': to_str(author_email), - 'author_name': to_str(author_name), - 'author_link': iri_to_uri(author_link), - 'subtitle': to_str(subtitle), - 'categories': categories or (), - 'feed_url': iri_to_uri(feed_url), - 'feed_copyright': to_str(feed_copyright), - 'id': feed_guid or link, - 'ttl': to_str(ttl), + "title": to_str(title), + "link": iri_to_uri(link), + "description": to_str(description), + "language": to_str(language), + "author_email": to_str(author_email), + "author_name": to_str(author_name), + "author_link": iri_to_uri(author_link), + "subtitle": to_str(subtitle), + "categories": categories or (), + "feed_url": iri_to_uri(feed_url), + "feed_copyright": to_str(feed_copyright), + "id": feed_guid or link, + "ttl": to_str(ttl), **kwargs, } self.items = [] - def add_item(self, title, link, description, author_email=None, - author_name=None, author_link=None, pubdate=None, comments=None, - unique_id=None, unique_id_is_permalink=None, categories=(), - item_copyright=None, ttl=None, updateddate=None, enclosures=None, **kwargs): + def add_item( + self, + title, + link, + description, + author_email=None, + author_name=None, + author_link=None, + pubdate=None, + comments=None, + unique_id=None, + 
unique_id_is_permalink=None, + categories=(), + item_copyright=None, + ttl=None, + updateddate=None, + enclosures=None, + **kwargs, + ): """ Add an item to the feed. All args are expected to be strings except pubdate and updateddate, which are datetime.datetime objects, and enclosures, which is an iterable of instances of the Enclosure class. """ + def to_str(s): return str(s) if s is not None else s + categories = categories and [to_str(c) for c in categories] - self.items.append({ - 'title': to_str(title), - 'link': iri_to_uri(link), - 'description': to_str(description), - 'author_email': to_str(author_email), - 'author_name': to_str(author_name), - 'author_link': iri_to_uri(author_link), - 'pubdate': pubdate, - 'updateddate': updateddate, - 'comments': to_str(comments), - 'unique_id': to_str(unique_id), - 'unique_id_is_permalink': unique_id_is_permalink, - 'enclosures': enclosures or (), - 'categories': categories or (), - 'item_copyright': to_str(item_copyright), - 'ttl': to_str(ttl), - **kwargs, - }) + self.items.append( + { + "title": to_str(title), + "link": iri_to_uri(link), + "description": to_str(description), + "author_email": to_str(author_email), + "author_name": to_str(author_name), + "author_link": iri_to_uri(author_link), + "pubdate": pubdate, + "updateddate": updateddate, + "comments": to_str(comments), + "unique_id": to_str(unique_id), + "unique_id_is_permalink": unique_id_is_permalink, + "enclosures": enclosures or (), + "categories": categories or (), + "item_copyright": to_str(item_copyright), + "ttl": to_str(ttl), + **kwargs, + } + ) def num_items(self): return len(self.items) @@ -147,7 +182,9 @@ class SyndicationFeed: Output the feed in the given encoding to outfile, which is a file-like object. Subclasses should override this. 
""" - raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method') + raise NotImplementedError( + "subclasses of SyndicationFeed must provide a write() method" + ) def writeString(self, encoding): """ @@ -163,7 +200,7 @@ class SyndicationFeed: have either of these attributes this return the current UTC date/time. """ latest_date = None - date_keys = ('updateddate', 'pubdate') + date_keys = ("updateddate", "pubdate") for item in self.items: for date_key in date_keys: @@ -177,6 +214,7 @@ class SyndicationFeed: class Enclosure: """An RSS enclosure""" + def __init__(self, url, length, mime_type): "All args are expected to be strings" self.length, self.mime_type = length, mime_type @@ -184,7 +222,7 @@ class Enclosure: class RssFeed(SyndicationFeed): - content_type = 'application/rss+xml; charset=utf-8' + content_type = "application/rss+xml; charset=utf-8" def write(self, outfile, encoding): handler = SimplerXMLGenerator(outfile, encoding, short_empty_elements=True) @@ -198,31 +236,33 @@ class RssFeed(SyndicationFeed): def rss_attributes(self): return { - 'version': self._version, - 'xmlns:atom': 'http://www.w3.org/2005/Atom', + "version": self._version, + "xmlns:atom": "http://www.w3.org/2005/Atom", } def write_items(self, handler): for item in self.items: - handler.startElement('item', self.item_attributes(item)) + handler.startElement("item", self.item_attributes(item)) self.add_item_elements(handler, item) handler.endElement("item") def add_root_elements(self, handler): - handler.addQuickElement("title", self.feed['title']) - handler.addQuickElement("link", self.feed['link']) - handler.addQuickElement("description", self.feed['description']) - if self.feed['feed_url'] is not None: - handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']}) - if self.feed['language'] is not None: - handler.addQuickElement("language", self.feed['language']) - for cat in self.feed['categories']: + 
handler.addQuickElement("title", self.feed["title"]) + handler.addQuickElement("link", self.feed["link"]) + handler.addQuickElement("description", self.feed["description"]) + if self.feed["feed_url"] is not None: + handler.addQuickElement( + "atom:link", None, {"rel": "self", "href": self.feed["feed_url"]} + ) + if self.feed["language"] is not None: + handler.addQuickElement("language", self.feed["language"]) + for cat in self.feed["categories"]: handler.addQuickElement("category", cat) - if self.feed['feed_copyright'] is not None: - handler.addQuickElement("copyright", self.feed['feed_copyright']) + if self.feed["feed_copyright"] is not None: + handler.addQuickElement("copyright", self.feed["feed_copyright"]) handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date())) - if self.feed['ttl'] is not None: - handler.addQuickElement("ttl", self.feed['ttl']) + if self.feed["ttl"] is not None: + handler.addQuickElement("ttl", self.feed["ttl"]) def endChannelElement(self, handler): handler.endElement("channel") @@ -232,10 +272,10 @@ class RssUserland091Feed(RssFeed): _version = "0.91" def add_item_elements(self, handler, item): - handler.addQuickElement("title", item['title']) - handler.addQuickElement("link", item['link']) - if item['description'] is not None: - handler.addQuickElement("description", item['description']) + handler.addQuickElement("title", item["title"]) + handler.addQuickElement("link", item["link"]) + if item["description"] is not None: + handler.addQuickElement("description", item["description"]) class Rss201rev2Feed(RssFeed): @@ -243,93 +283,105 @@ class Rss201rev2Feed(RssFeed): _version = "2.0" def add_item_elements(self, handler, item): - handler.addQuickElement("title", item['title']) - handler.addQuickElement("link", item['link']) - if item['description'] is not None: - handler.addQuickElement("description", item['description']) + handler.addQuickElement("title", item["title"]) + handler.addQuickElement("link", item["link"]) + 
if item["description"] is not None: + handler.addQuickElement("description", item["description"]) # Author information. if item["author_name"] and item["author_email"]: - handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name'])) + handler.addQuickElement( + "author", "%s (%s)" % (item["author_email"], item["author_name"]) + ) elif item["author_email"]: handler.addQuickElement("author", item["author_email"]) elif item["author_name"]: handler.addQuickElement( - "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"} + "dc:creator", + item["author_name"], + {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}, ) - if item['pubdate'] is not None: - handler.addQuickElement("pubDate", rfc2822_date(item['pubdate'])) - if item['comments'] is not None: - handler.addQuickElement("comments", item['comments']) - if item['unique_id'] is not None: + if item["pubdate"] is not None: + handler.addQuickElement("pubDate", rfc2822_date(item["pubdate"])) + if item["comments"] is not None: + handler.addQuickElement("comments", item["comments"]) + if item["unique_id"] is not None: guid_attrs = {} - if isinstance(item.get('unique_id_is_permalink'), bool): - guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower() - handler.addQuickElement("guid", item['unique_id'], guid_attrs) - if item['ttl'] is not None: - handler.addQuickElement("ttl", item['ttl']) + if isinstance(item.get("unique_id_is_permalink"), bool): + guid_attrs["isPermaLink"] = str(item["unique_id_is_permalink"]).lower() + handler.addQuickElement("guid", item["unique_id"], guid_attrs) + if item["ttl"] is not None: + handler.addQuickElement("ttl", item["ttl"]) # Enclosure. 
- if item['enclosures']: - enclosures = list(item['enclosures']) + if item["enclosures"]: + enclosures = list(item["enclosures"]) if len(enclosures) > 1: raise ValueError( "RSS feed items may only have one enclosure, see " "http://www.rssboard.org/rss-profile#element-channel-item-enclosure" ) enclosure = enclosures[0] - handler.addQuickElement('enclosure', '', { - 'url': enclosure.url, - 'length': enclosure.length, - 'type': enclosure.mime_type, - }) + handler.addQuickElement( + "enclosure", + "", + { + "url": enclosure.url, + "length": enclosure.length, + "type": enclosure.mime_type, + }, + ) # Categories. - for cat in item['categories']: + for cat in item["categories"]: handler.addQuickElement("category", cat) class Atom1Feed(SyndicationFeed): # Spec: https://tools.ietf.org/html/rfc4287 - content_type = 'application/atom+xml; charset=utf-8' + content_type = "application/atom+xml; charset=utf-8" ns = "http://www.w3.org/2005/Atom" def write(self, outfile, encoding): handler = SimplerXMLGenerator(outfile, encoding, short_empty_elements=True) handler.startDocument() - handler.startElement('feed', self.root_attributes()) + handler.startElement("feed", self.root_attributes()) self.add_root_elements(handler) self.write_items(handler) handler.endElement("feed") def root_attributes(self): - if self.feed['language'] is not None: - return {"xmlns": self.ns, "xml:lang": self.feed['language']} + if self.feed["language"] is not None: + return {"xmlns": self.ns, "xml:lang": self.feed["language"]} else: return {"xmlns": self.ns} def add_root_elements(self, handler): - handler.addQuickElement("title", self.feed['title']) - handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']}) - if self.feed['feed_url'] is not None: - handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']}) - handler.addQuickElement("id", self.feed['id']) + handler.addQuickElement("title", self.feed["title"]) + handler.addQuickElement( + "link", "", {"rel": 
"alternate", "href": self.feed["link"]} + ) + if self.feed["feed_url"] is not None: + handler.addQuickElement( + "link", "", {"rel": "self", "href": self.feed["feed_url"]} + ) + handler.addQuickElement("id", self.feed["id"]) handler.addQuickElement("updated", rfc3339_date(self.latest_post_date())) - if self.feed['author_name'] is not None: + if self.feed["author_name"] is not None: handler.startElement("author", {}) - handler.addQuickElement("name", self.feed['author_name']) - if self.feed['author_email'] is not None: - handler.addQuickElement("email", self.feed['author_email']) - if self.feed['author_link'] is not None: - handler.addQuickElement("uri", self.feed['author_link']) + handler.addQuickElement("name", self.feed["author_name"]) + if self.feed["author_email"] is not None: + handler.addQuickElement("email", self.feed["author_email"]) + if self.feed["author_link"] is not None: + handler.addQuickElement("uri", self.feed["author_link"]) handler.endElement("author") - if self.feed['subtitle'] is not None: - handler.addQuickElement("subtitle", self.feed['subtitle']) - for cat in self.feed['categories']: + if self.feed["subtitle"] is not None: + handler.addQuickElement("subtitle", self.feed["subtitle"]) + for cat in self.feed["categories"]: handler.addQuickElement("category", "", {"term": cat}) - if self.feed['feed_copyright'] is not None: - handler.addQuickElement("rights", self.feed['feed_copyright']) + if self.feed["feed_copyright"] is not None: + handler.addQuickElement("rights", self.feed["feed_copyright"]) def write_items(self, handler): for item in self.items: @@ -338,52 +390,56 @@ class Atom1Feed(SyndicationFeed): handler.endElement("entry") def add_item_elements(self, handler, item): - handler.addQuickElement("title", item['title']) - handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"}) + handler.addQuickElement("title", item["title"]) + handler.addQuickElement("link", "", {"href": item["link"], "rel": "alternate"}) - if 
item['pubdate'] is not None: - handler.addQuickElement('published', rfc3339_date(item['pubdate'])) + if item["pubdate"] is not None: + handler.addQuickElement("published", rfc3339_date(item["pubdate"])) - if item['updateddate'] is not None: - handler.addQuickElement('updated', rfc3339_date(item['updateddate'])) + if item["updateddate"] is not None: + handler.addQuickElement("updated", rfc3339_date(item["updateddate"])) # Author information. - if item['author_name'] is not None: + if item["author_name"] is not None: handler.startElement("author", {}) - handler.addQuickElement("name", item['author_name']) - if item['author_email'] is not None: - handler.addQuickElement("email", item['author_email']) - if item['author_link'] is not None: - handler.addQuickElement("uri", item['author_link']) + handler.addQuickElement("name", item["author_name"]) + if item["author_email"] is not None: + handler.addQuickElement("email", item["author_email"]) + if item["author_link"] is not None: + handler.addQuickElement("uri", item["author_link"]) handler.endElement("author") # Unique ID. - if item['unique_id'] is not None: - unique_id = item['unique_id'] + if item["unique_id"] is not None: + unique_id = item["unique_id"] else: - unique_id = get_tag_uri(item['link'], item['pubdate']) + unique_id = get_tag_uri(item["link"], item["pubdate"]) handler.addQuickElement("id", unique_id) # Summary. - if item['description'] is not None: - handler.addQuickElement("summary", item['description'], {"type": "html"}) + if item["description"] is not None: + handler.addQuickElement("summary", item["description"], {"type": "html"}) # Enclosures. 
- for enclosure in item['enclosures']: - handler.addQuickElement('link', '', { - 'rel': 'enclosure', - 'href': enclosure.url, - 'length': enclosure.length, - 'type': enclosure.mime_type, - }) + for enclosure in item["enclosures"]: + handler.addQuickElement( + "link", + "", + { + "rel": "enclosure", + "href": enclosure.url, + "length": enclosure.length, + "type": enclosure.mime_type, + }, + ) # Categories. - for cat in item['categories']: + for cat in item["categories"]: handler.addQuickElement("category", "", {"term": cat}) # Rights. - if item['item_copyright'] is not None: - handler.addQuickElement("rights", item['item_copyright']) + if item["item_copyright"] is not None: + handler.addQuickElement("rights", item["item_copyright"]) # This isolates the decision of what the system default is, so calling code can diff --git a/django/utils/formats.py b/django/utils/formats.py index 3aef3bc23c..50b58445e4 100644 --- a/django/utils/formats.py +++ b/django/utils/formats.py @@ -8,9 +8,7 @@ from importlib import import_module from django.conf import settings from django.utils import dateformat, numberformat from django.utils.functional import lazy -from django.utils.translation import ( - check_for_language, get_language, to_locale, -) +from django.utils.translation import check_for_language, get_language, to_locale # format_cache is a mapping from (format_type, lang) to the format string. 
# By using the cache, it is possible to avoid running get_format_modules @@ -19,33 +17,35 @@ _format_cache = {} _format_modules_cache = {} ISO_INPUT_FORMATS = { - 'DATE_INPUT_FORMATS': ['%Y-%m-%d'], - 'TIME_INPUT_FORMATS': ['%H:%M:%S', '%H:%M:%S.%f', '%H:%M'], - 'DATETIME_INPUT_FORMATS': [ - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%d %H:%M:%S.%f', - '%Y-%m-%d %H:%M', - '%Y-%m-%d' + "DATE_INPUT_FORMATS": ["%Y-%m-%d"], + "TIME_INPUT_FORMATS": ["%H:%M:%S", "%H:%M:%S.%f", "%H:%M"], + "DATETIME_INPUT_FORMATS": [ + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M:%S.%f", + "%Y-%m-%d %H:%M", + "%Y-%m-%d", ], } -FORMAT_SETTINGS = frozenset([ - 'DECIMAL_SEPARATOR', - 'THOUSAND_SEPARATOR', - 'NUMBER_GROUPING', - 'FIRST_DAY_OF_WEEK', - 'MONTH_DAY_FORMAT', - 'TIME_FORMAT', - 'DATE_FORMAT', - 'DATETIME_FORMAT', - 'SHORT_DATE_FORMAT', - 'SHORT_DATETIME_FORMAT', - 'YEAR_MONTH_FORMAT', - 'DATE_INPUT_FORMATS', - 'TIME_INPUT_FORMATS', - 'DATETIME_INPUT_FORMATS', -]) +FORMAT_SETTINGS = frozenset( + [ + "DECIMAL_SEPARATOR", + "THOUSAND_SEPARATOR", + "NUMBER_GROUPING", + "FIRST_DAY_OF_WEEK", + "MONTH_DAY_FORMAT", + "TIME_FORMAT", + "DATE_FORMAT", + "DATETIME_FORMAT", + "SHORT_DATE_FORMAT", + "SHORT_DATETIME_FORMAT", + "YEAR_MONTH_FORMAT", + "DATE_INPUT_FORMATS", + "TIME_INPUT_FORMATS", + "DATETIME_INPUT_FORMATS", + ] +) def reset_format_cache(): @@ -72,16 +72,16 @@ def iter_format_modules(lang, format_module_path=None): if isinstance(format_module_path, str): format_module_path = [format_module_path] for path in format_module_path: - format_locations.append(path + '.%s') - format_locations.append('django.conf.locale.%s') + format_locations.append(path + ".%s") + format_locations.append("django.conf.locale.%s") locale = to_locale(lang) locales = [locale] - if '_' in locale: - locales.append(locale.split('_')[0]) + if "_" in locale: + locales.append(locale.split("_")[0]) for location in format_locations: for loc in locales: try: - yield import_module('%s.formats' % (location % loc)) + yield 
import_module("%s.formats" % (location % loc)) except ImportError: pass @@ -91,7 +91,9 @@ def get_format_modules(lang=None): if lang is None: lang = get_language() if lang not in _format_modules_cache: - _format_modules_cache[lang] = list(iter_format_modules(lang, settings.FORMAT_MODULE_PATH)) + _format_modules_cache[lang] = list( + iter_format_modules(lang, settings.FORMAT_MODULE_PATH) + ) return _format_modules_cache[lang] @@ -104,11 +106,14 @@ def get_format(format_type, lang=None, use_l10n=None): If use_l10n is provided and is not None, it forces the value to be localized (or not), overriding the value of settings.USE_L10N. """ - use_l10n = use_l10n or (use_l10n is None and ( - settings._USE_L10N_INTERNAL - if hasattr(settings, '_USE_L10N_INTERNAL') - else settings.USE_L10N - )) + use_l10n = use_l10n or ( + use_l10n is None + and ( + settings._USE_L10N_INTERNAL + if hasattr(settings, "_USE_L10N_INTERNAL") + else settings.USE_L10N + ) + ) if use_l10n and lang is None: lang = get_language() cache_key = (format_type, lang) @@ -152,7 +157,9 @@ def date_format(value, format=None, use_l10n=None): If use_l10n is provided and is not None, that will force the value to be localized (or not), overriding the value of settings.USE_L10N. """ - return dateformat.format(value, get_format(format or 'DATE_FORMAT', use_l10n=use_l10n)) + return dateformat.format( + value, get_format(format or "DATE_FORMAT", use_l10n=use_l10n) + ) def time_format(value, format=None, use_l10n=None): @@ -162,7 +169,9 @@ def time_format(value, format=None, use_l10n=None): If use_l10n is provided and is not None, it forces the value to be localized (or not), overriding the value of settings.USE_L10N. 
""" - return dateformat.time_format(value, get_format(format or 'TIME_FORMAT', use_l10n=use_l10n)) + return dateformat.time_format( + value, get_format(format or "TIME_FORMAT", use_l10n=use_l10n) + ) def number_format(value, decimal_pos=None, use_l10n=None, force_grouping=False): @@ -172,18 +181,21 @@ def number_format(value, decimal_pos=None, use_l10n=None, force_grouping=False): If use_l10n is provided and is not None, it forces the value to be localized (or not), overriding the value of settings.USE_L10N. """ - use_l10n = use_l10n or (use_l10n is None and ( - settings._USE_L10N_INTERNAL - if hasattr(settings, '_USE_L10N_INTERNAL') - else settings.USE_L10N - )) + use_l10n = use_l10n or ( + use_l10n is None + and ( + settings._USE_L10N_INTERNAL + if hasattr(settings, "_USE_L10N_INTERNAL") + else settings.USE_L10N + ) + ) lang = get_language() if use_l10n else None return numberformat.format( value, - get_format('DECIMAL_SEPARATOR', lang, use_l10n=use_l10n), + get_format("DECIMAL_SEPARATOR", lang, use_l10n=use_l10n), decimal_pos, - get_format('NUMBER_GROUPING', lang, use_l10n=use_l10n), - get_format('THOUSAND_SEPARATOR', lang, use_l10n=use_l10n), + get_format("NUMBER_GROUPING", lang, use_l10n=use_l10n), + get_format("THOUSAND_SEPARATOR", lang, use_l10n=use_l10n), force_grouping=force_grouping, use_l10n=use_l10n, ) @@ -206,11 +218,11 @@ def localize(value, use_l10n=None): return str(value) return number_format(value, use_l10n=use_l10n) elif isinstance(value, datetime.datetime): - return date_format(value, 'DATETIME_FORMAT', use_l10n=use_l10n) + return date_format(value, "DATETIME_FORMAT", use_l10n=use_l10n) elif isinstance(value, datetime.date): return date_format(value, use_l10n=use_l10n) elif isinstance(value, datetime.time): - return time_format(value, 'TIME_FORMAT', use_l10n=use_l10n) + return time_format(value, "TIME_FORMAT", use_l10n=use_l10n) return value @@ -226,15 +238,15 @@ def localize_input(value, default=None): elif isinstance(value, (decimal.Decimal, 
float, int)): return number_format(value) elif isinstance(value, datetime.datetime): - format = default or get_format('DATETIME_INPUT_FORMATS')[0] + format = default or get_format("DATETIME_INPUT_FORMATS")[0] format = sanitize_strftime_format(format) return value.strftime(format) elif isinstance(value, datetime.date): - format = default or get_format('DATE_INPUT_FORMATS')[0] + format = default or get_format("DATE_INPUT_FORMATS")[0] format = sanitize_strftime_format(format) return value.strftime(format) elif isinstance(value, datetime.time): - format = default or get_format('TIME_INPUT_FORMATS')[0] + format = default or get_format("TIME_INPUT_FORMATS")[0] return value.strftime(format) return value @@ -262,12 +274,12 @@ def sanitize_strftime_format(fmt): See https://bugs.python.org/issue13305 for more details. """ - if datetime.date(1, 1, 1).strftime('%Y') == '0001': + if datetime.date(1, 1, 1).strftime("%Y") == "0001": return fmt - mapping = {'C': 2, 'F': 10, 'G': 4, 'Y': 4} + mapping = {"C": 2, "F": 10, "G": 4, "Y": 4} return re.sub( - r'((?:^|[^%])(?:%%)*)%([CFGY])', - lambda m: r'%s%%0%s%s' % (m[1], mapping[m[2]], m[2]), + r"((?:^|[^%])(?:%%)*)%([CFGY])", + lambda m: r"%s%%0%s%s" % (m[1], mapping[m[2]], m[2]), fmt, ) @@ -279,19 +291,25 @@ def sanitize_separators(value): """ if isinstance(value, str): parts = [] - decimal_separator = get_format('DECIMAL_SEPARATOR') + decimal_separator = get_format("DECIMAL_SEPARATOR") if decimal_separator in value: value, decimals = value.split(decimal_separator, 1) parts.append(decimals) if settings.USE_THOUSAND_SEPARATOR: - thousand_sep = get_format('THOUSAND_SEPARATOR') - if thousand_sep == '.' and value.count('.') == 1 and len(value.split('.')[-1]) != 3: + thousand_sep = get_format("THOUSAND_SEPARATOR") + if ( + thousand_sep == "." 
+ and value.count(".") == 1 + and len(value.split(".")[-1]) != 3 + ): # Special case where we suspect a dot meant decimal separator (see #22171) pass else: for replacement in { - thousand_sep, unicodedata.normalize('NFKD', thousand_sep)}: - value = value.replace(replacement, '') + thousand_sep, + unicodedata.normalize("NFKD", thousand_sep), + }: + value = value.replace(replacement, "") parts.append(value) - value = '.'.join(reversed(parts)) + value = ".".join(reversed(parts)) return value diff --git a/django/utils/functional.py b/django/utils/functional.py index ea46cff20f..9e1be0fe0f 100644 --- a/django/utils/functional.py +++ b/django/utils/functional.py @@ -13,13 +13,14 @@ class cached_property: A cached property can be made out of an existing method: (e.g. ``url = cached_property(get_absolute_url)``). """ + name = None @staticmethod def func(instance): raise TypeError( - 'Cannot use cached_property instance without calling ' - '__set_name__() on it.' + "Cannot use cached_property instance without calling " + "__set_name__() on it." ) def __init__(self, func, name=None): @@ -33,7 +34,7 @@ class cached_property: stacklevel=2, ) self.real_func = func - self.__doc__ = getattr(func, '__doc__') + self.__doc__ = getattr(func, "__doc__") def __set_name__(self, owner, name): if self.name is None: @@ -62,6 +63,7 @@ class classproperty: Decorator that converts a method with a single cls argument into a property that can be accessed directly from the class. """ + def __init__(self, method=None): self.fget = method @@ -78,6 +80,7 @@ class Promise: Base class for the proxy class created in the closure of the lazy function. It's used to recognize promises in code. """ + pass @@ -96,6 +99,7 @@ def lazy(func, *resultclasses): called on the result of that function. The function is not evaluated until one of the methods on the result is called. 
""" + __prepared = False def __init__(self, args, kw): @@ -108,7 +112,7 @@ def lazy(func, *resultclasses): def __reduce__(self): return ( _lazy_proxy_unpickle, - (func, self.__args, self.__kw) + resultclasses + (func, self.__args, self.__kw) + resultclasses, ) def __repr__(self): @@ -129,7 +133,7 @@ def lazy(func, *resultclasses): cls._delegate_text = str in resultclasses if cls._delegate_bytes and cls._delegate_text: raise ValueError( - 'Cannot call lazy() with both bytes and text return types.' + "Cannot call lazy() with both bytes and text return types." ) if cls._delegate_text: cls.__str__ = cls.__text_cast @@ -144,6 +148,7 @@ def lazy(func, *resultclasses): # applies the given magic method of the result type. res = func(*self.__args, **self.__kw) return getattr(res, method_name)(*args, **kw) + return __wrapper__ def __text_cast(self): @@ -233,10 +238,15 @@ def keep_lazy(*resultclasses): @wraps(func) def wrapper(*args, **kwargs): - if any(isinstance(arg, Promise) for arg in itertools.chain(args, kwargs.values())): + if any( + isinstance(arg, Promise) + for arg in itertools.chain(args, kwargs.values()) + ): return lazy_func(*args, **kwargs) return func(*args, **kwargs) + return wrapper + return decorator @@ -255,6 +265,7 @@ def new_method_proxy(func): if self._wrapped is empty: self._setup() return func(self._wrapped, *args) + return inner @@ -297,7 +308,9 @@ class LazyObject: """ Must be implemented by subclasses to initialize the wrapped object. """ - raise NotImplementedError('subclasses of LazyObject must provide a _setup() method') + raise NotImplementedError( + "subclasses of LazyObject must provide a _setup() method" + ) # Because we have messed with __class__ below, we confuse pickle as to what # class we are pickling. We're going to have to initialize the wrapped @@ -376,6 +389,7 @@ class SimpleLazyObject(LazyObject): Designed for compound objects of unknown type. For builtins or objects of known type, use django.utils.functional.lazy. 
""" + def __init__(self, func): """ Pass in a callable that returns the object to be wrapped. @@ -385,7 +399,7 @@ class SimpleLazyObject(LazyObject): callable can be safely run more than once and will return the same value. """ - self.__dict__['_setupfunc'] = func + self.__dict__["_setupfunc"] = func super().__init__() def _setup(self): @@ -398,7 +412,7 @@ class SimpleLazyObject(LazyObject): repr_attr = self._setupfunc else: repr_attr = self._wrapped - return '<%s: %r>' % (type(self).__name__, repr_attr) + return "<%s: %r>" % (type(self).__name__, repr_attr) def __copy__(self): if self._wrapped is empty: diff --git a/django/utils/hashable.py b/django/utils/hashable.py index 7d137ccc2f..042e1a4373 100644 --- a/django/utils/hashable.py +++ b/django/utils/hashable.py @@ -8,10 +8,12 @@ def make_hashable(value): The returned value should generate the same hash for equal values. """ if isinstance(value, dict): - return tuple([ - (key, make_hashable(nested_value)) - for key, nested_value in sorted(value.items()) - ]) + return tuple( + [ + (key, make_hashable(nested_value)) + for key, nested_value in sorted(value.items()) + ] + ) # Try hash to avoid converting a hashable iterable (e.g. string, frozenset) # to a tuple. 
try: diff --git a/django/utils/html.py b/django/utils/html.py index be9f22312e..d228e4c7bc 100644 --- a/django/utils/html.py +++ b/django/utils/html.py @@ -4,9 +4,7 @@ import html import json import re from html.parser import HTMLParser -from urllib.parse import ( - parse_qsl, quote, unquote, urlencode, urlsplit, urlunsplit, -) +from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit, urlunsplit from django.utils.encoding import punycode from django.utils.functional import Promise, keep_lazy, keep_lazy_text @@ -30,22 +28,22 @@ def escape(text): _js_escapes = { - ord('\\'): '\\u005C', - ord('\''): '\\u0027', - ord('"'): '\\u0022', - ord('>'): '\\u003E', - ord('<'): '\\u003C', - ord('&'): '\\u0026', - ord('='): '\\u003D', - ord('-'): '\\u002D', - ord(';'): '\\u003B', - ord('`'): '\\u0060', - ord('\u2028'): '\\u2028', - ord('\u2029'): '\\u2029' + ord("\\"): "\\u005C", + ord("'"): "\\u0027", + ord('"'): "\\u0022", + ord(">"): "\\u003E", + ord("<"): "\\u003C", + ord("&"): "\\u0026", + ord("="): "\\u003D", + ord("-"): "\\u002D", + ord(";"): "\\u003B", + ord("`"): "\\u0060", + ord("\u2028"): "\\u2028", + ord("\u2029"): "\\u2029", } # Escape every ASCII character with a value less than 32. -_js_escapes.update((ord('%c' % z), '\\u%04X' % z) for z in range(32)) +_js_escapes.update((ord("%c" % z), "\\u%04X" % z) for z in range(32)) @keep_lazy(str, SafeString) @@ -55,9 +53,9 @@ def escapejs(value): _json_script_escapes = { - ord('>'): '\\u003E', - ord('<'): '\\u003C', - ord('&'): '\\u0026', + ord(">"): "\\u003E", + ord("<"): "\\u003C", + ord("&"): "\\u0026", } @@ -68,6 +66,7 @@ def json_script(value, element_id=None): the escaped JSON in a script tag. 
""" from django.core.serializers.json import DjangoJSONEncoder + json_str = json.dumps(value, cls=DjangoJSONEncoder).translate(_json_script_escapes) if element_id: template = '<script id="{}" type="application/json">{}</script>' @@ -87,7 +86,7 @@ def conditional_escape(text): """ if isinstance(text, Promise): text = str(text) - if hasattr(text, '__html__'): + if hasattr(text, "__html__"): return text.__html__() else: return escape(text) @@ -118,22 +117,23 @@ def format_html_join(sep, format_string, args_generator): format_html_join('\n', "<li>{} {}</li>", ((u.first_name, u.last_name) for u in users)) """ - return mark_safe(conditional_escape(sep).join( - format_html(format_string, *args) - for args in args_generator - )) + return mark_safe( + conditional_escape(sep).join( + format_html(format_string, *args) for args in args_generator + ) + ) @keep_lazy_text def linebreaks(value, autoescape=False): """Convert newlines into <p> and <br>s.""" value = normalize_newlines(value) - paras = re.split('\n{2,}', str(value)) + paras = re.split("\n{2,}", str(value)) if autoescape: - paras = ['<p>%s</p>' % escape(p).replace('\n', '<br>') for p in paras] + paras = ["<p>%s</p>" % escape(p).replace("\n", "<br>") for p in paras] else: - paras = ['<p>%s</p>' % p.replace('\n', '<br>') for p in paras] - return '\n\n'.join(paras) + paras = ["<p>%s</p>" % p.replace("\n", "<br>") for p in paras] + return "\n\n".join(paras) class MLStripper(HTMLParser): @@ -146,13 +146,13 @@ class MLStripper(HTMLParser): self.fed.append(d) def handle_entityref(self, name): - self.fed.append('&%s;' % name) + self.fed.append("&%s;" % name) def handle_charref(self, name): - self.fed.append('&#%s;' % name) + self.fed.append("&#%s;" % name) def get_data(self): - return ''.join(self.fed) + return "".join(self.fed) def _strip_once(value): @@ -171,9 +171,9 @@ def strip_tags(value): # Note: in typical case this loop executes _strip_once once. 
Loop condition # is redundant, but helps to reduce number of executions of _strip_once. value = str(value) - while '<' in value and '>' in value: + while "<" in value and ">" in value: new_value = _strip_once(value) - if value.count('<') == new_value.count('<'): + if value.count("<") == new_value.count("<"): # _strip_once wasn't able to detect more tags. break value = new_value @@ -183,17 +183,18 @@ def strip_tags(value): @keep_lazy_text def strip_spaces_between_tags(value): """Return the given HTML with spaces between tags removed.""" - return re.sub(r'>\s+<', '><', str(value)) + return re.sub(r">\s+<", "><", str(value)) def smart_urlquote(url): """Quote a URL if it isn't already quoted.""" + def unquote_quote(segment): segment = unquote(segment) # Tilde is part of RFC3986 Unreserved Characters # https://tools.ietf.org/html/rfc3986#section-2.3 # See also https://bugs.python.org/issue16285 - return quote(segment, safe=RFC3986_SUBDELIMS + RFC3986_GENDELIMS + '~') + return quote(segment, safe=RFC3986_SUBDELIMS + RFC3986_GENDELIMS + "~") # Handle IDN before quoting. try: @@ -210,8 +211,10 @@ def smart_urlquote(url): if query: # Separately unquoting key/value, so as to not mix querystring separators # included in query values. See #22267. - query_parts = [(unquote(q[0]), unquote(q[1])) - for q in parse_qsl(query, keep_blank_values=True)] + query_parts = [ + (unquote(q[0]), unquote(q[1])) + for q in parse_qsl(query, keep_blank_values=True) + ] # urlencode will take care of quoting query = urlencode(query_parts) @@ -230,17 +233,17 @@ class Urlizer: Links can have trailing punctuation (periods, commas, close-parens) and leading punctuation (opening parens) and it'll still do the right thing. """ - trailing_punctuation_chars = '.,:;!' - wrapping_punctuation = [('(', ')'), ('[', ']')] - simple_url_re = _lazy_re_compile(r'^https?://\[?\w', re.IGNORECASE) + trailing_punctuation_chars = ".,:;!" 
+ wrapping_punctuation = [("(", ")"), ("[", "]")] + + simple_url_re = _lazy_re_compile(r"^https?://\[?\w", re.IGNORECASE) simple_url_2_re = _lazy_re_compile( - r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)($|/.*)$', - re.IGNORECASE + r"^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)($|/.*)$", re.IGNORECASE ) - word_split_re = _lazy_re_compile(r'''([\s<>"']+)''') + word_split_re = _lazy_re_compile(r"""([\s<>"']+)""") - mailto_template = 'mailto:{local}@{domain}' + mailto_template = "mailto:{local}@{domain}" url_template = '<a href="{href}"{attrs}>{url}</a>' def __call__(self, text, trim_url_limit=None, nofollow=False, autoescape=False): @@ -256,39 +259,48 @@ class Urlizer: safe_input = isinstance(text, SafeData) words = self.word_split_re.split(str(text)) - return ''.join([ - self.handle_word( - word, - safe_input=safe_input, - trim_url_limit=trim_url_limit, - nofollow=nofollow, - autoescape=autoescape, - ) for word in words - ]) + return "".join( + [ + self.handle_word( + word, + safe_input=safe_input, + trim_url_limit=trim_url_limit, + nofollow=nofollow, + autoescape=autoescape, + ) + for word in words + ] + ) def handle_word( - self, word, *, safe_input, trim_url_limit=None, nofollow=False, autoescape=False, + self, + word, + *, + safe_input, + trim_url_limit=None, + nofollow=False, + autoescape=False, ): - if '.' in word or '@' in word or ':' in word: + if "." in word or "@" in word or ":" in word: # lead: Punctuation trimmed from the beginning of the word. # middle: State of the word. # trail: Punctuation trimmed from the end of the word. lead, middle, trail = self.trim_punctuation(word) # Make URL we want to point to. 
url = None - nofollow_attr = ' rel="nofollow"' if nofollow else '' + nofollow_attr = ' rel="nofollow"' if nofollow else "" if self.simple_url_re.match(middle): url = smart_urlquote(html.unescape(middle)) elif self.simple_url_2_re.match(middle): - url = smart_urlquote('http://%s' % html.unescape(middle)) - elif ':' not in middle and self.is_email_simple(middle): - local, domain = middle.rsplit('@', 1) + url = smart_urlquote("http://%s" % html.unescape(middle)) + elif ":" not in middle and self.is_email_simple(middle): + local, domain = middle.rsplit("@", 1) try: domain = punycode(domain) except UnicodeError: return word url = self.mailto_template.format(local=local, domain=domain) - nofollow_attr = '' + nofollow_attr = "" # Make link. if url: trimmed = self.trim_url(middle, limit=trim_url_limit) @@ -300,7 +312,7 @@ class Urlizer: attrs=nofollow_attr, url=trimmed, ) - return mark_safe(f'{lead}{middle}{trail}') + return mark_safe(f"{lead}{middle}{trail}") else: if safe_input: return mark_safe(word) @@ -315,14 +327,14 @@ class Urlizer: def trim_url(self, x, *, limit): if limit is None or len(x) <= limit: return x - return '%s…' % x[:max(0, limit - 1)] + return "%s…" % x[: max(0, limit - 1)] def trim_punctuation(self, word): """ Trim trailing and wrapping punctuation from `word`. Return the items of the new state. """ - lead, middle, trail = '', word, '' + lead, middle, trail = "", word, "" # Continue trimming until middle remains unchanged. trimmed_something = True while trimmed_something: @@ -330,15 +342,15 @@ class Urlizer: # Trim wrapping punctuation. for opening, closing in self.wrapping_punctuation: if middle.startswith(opening): - middle = middle[len(opening):] + middle = middle[len(opening) :] lead += opening trimmed_something = True # Keep parentheses at the end only if they're balanced. 
if ( - middle.endswith(closing) and - middle.count(closing) == middle.count(opening) + 1 + middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1 ): - middle = middle[:-len(closing)] + middle = middle[: -len(closing)] trail = closing + trail trimmed_something = True # Trim trailing punctuation (after trimming wrapping punctuation, @@ -357,15 +369,15 @@ class Urlizer: def is_email_simple(value): """Return True if value looks like an email address.""" # An @ must be in the middle of the value. - if '@' not in value or value.startswith('@') or value.endswith('@'): + if "@" not in value or value.startswith("@") or value.endswith("@"): return False try: - p1, p2 = value.split('@') + p1, p2 = value.split("@") except ValueError: # value contains more than one @. return False # Dot must be in p2 (e.g. example.com) - if '.' not in p2 or p2.startswith('.'): + if "." not in p2 or p2.startswith("."): return False return True @@ -375,7 +387,9 @@ urlizer = Urlizer() @keep_lazy_text def urlize(text, trim_url_limit=None, nofollow=False, autoescape=False): - return urlizer(text, trim_url_limit=trim_url_limit, nofollow=nofollow, autoescape=autoescape) + return urlizer( + text, trim_url_limit=trim_url_limit, nofollow=nofollow, autoescape=autoescape + ) def avoid_wrapping(value): @@ -391,12 +405,12 @@ def html_safe(klass): A decorator that defines the __html__ method. This helps non-Django templates to detect classes whose __str__ methods return SafeString. """ - if '__html__' in klass.__dict__: + if "__html__" in klass.__dict__: raise ValueError( "can't apply @html_safe to %s because it defines " "__html__()." % klass.__name__ ) - if '__str__' not in klass.__dict__: + if "__str__" not in klass.__dict__: raise ValueError( "can't apply @html_safe to %s because it doesn't " "define __str__()." 
% klass.__name__ diff --git a/django/utils/http.py b/django/utils/http.py index ab90f1e377..0292713235 100644 --- a/django/utils/http.py +++ b/django/utils/http.py @@ -5,33 +5,42 @@ import unicodedata from binascii import Error as BinasciiError from email.utils import formatdate from urllib.parse import ( - ParseResult, SplitResult, _coerce_args, _splitnetloc, _splitparams, - scheme_chars, urlencode as original_urlencode, uses_params, + ParseResult, + SplitResult, + _coerce_args, + _splitnetloc, + _splitparams, + scheme_chars, ) +from urllib.parse import urlencode as original_urlencode +from urllib.parse import uses_params from django.utils.datastructures import MultiValueDict from django.utils.regex_helper import _lazy_re_compile # based on RFC 7232, Appendix C -ETAG_MATCH = _lazy_re_compile(r''' +ETAG_MATCH = _lazy_re_compile( + r""" \A( # start of string and capture group (?:W/)? # optional weak indicator " # opening quote [^"]* # any sequence of non-quote characters " # end quote )\Z # end of string and capture group -''', re.X) +""", + re.X, +) -MONTHS = 'jan feb mar apr may jun jul aug sep oct nov dec'.split() -__D = r'(?P<day>[0-9]{2})' -__D2 = r'(?P<day>[ 0-9][0-9])' -__M = r'(?P<mon>\w{3})' -__Y = r'(?P<year>[0-9]{4})' -__Y2 = r'(?P<year>[0-9]{2})' -__T = r'(?P<hour>[0-9]{2}):(?P<min>[0-9]{2}):(?P<sec>[0-9]{2})' -RFC1123_DATE = _lazy_re_compile(r'^\w{3}, %s %s %s %s GMT$' % (__D, __M, __Y, __T)) -RFC850_DATE = _lazy_re_compile(r'^\w{6,9}, %s-%s-%s %s GMT$' % (__D, __M, __Y2, __T)) -ASCTIME_DATE = _lazy_re_compile(r'^\w{3} %s %s %s %s$' % (__M, __D2, __T, __Y)) +MONTHS = "jan feb mar apr may jun jul aug sep oct nov dec".split() +__D = r"(?P<day>[0-9]{2})" +__D2 = r"(?P<day>[ 0-9][0-9])" +__M = r"(?P<mon>\w{3})" +__Y = r"(?P<year>[0-9]{4})" +__Y2 = r"(?P<year>[0-9]{2})" +__T = r"(?P<hour>[0-9]{2}):(?P<min>[0-9]{2}):(?P<sec>[0-9]{2})" +RFC1123_DATE = _lazy_re_compile(r"^\w{3}, %s %s %s %s GMT$" % (__D, __M, __Y, __T)) +RFC850_DATE = 
_lazy_re_compile(r"^\w{6,9}, %s-%s-%s %s GMT$" % (__D, __M, __Y2, __T)) +ASCTIME_DATE = _lazy_re_compile(r"^\w{3} %s %s %s %s$" % (__M, __D2, __T, __Y)) RFC3986_GENDELIMS = ":/?#[]@" RFC3986_SUBDELIMS = "!$&'()*+,;=" @@ -44,7 +53,7 @@ def urlencode(query, doseq=False): """ if isinstance(query, MultiValueDict): query = query.lists() - elif hasattr(query, 'items'): + elif hasattr(query, "items"): query = query.items() query_params = [] for key, value in query: @@ -112,7 +121,7 @@ def parse_http_date(date): raise ValueError("%r is not in a valid HTTP date format" % date) try: tz = datetime.timezone.utc - year = int(m['year']) + year = int(m["year"]) if year < 100: current_year = datetime.datetime.now(tz=tz).year current_century = current_year - (current_year % 100) @@ -122,11 +131,11 @@ def parse_http_date(date): year += current_century - 100 else: year += current_century - month = MONTHS.index(m['mon'].lower()) + 1 - day = int(m['day']) - hour = int(m['hour']) - min = int(m['min']) - sec = int(m['sec']) + month = MONTHS.index(m["mon"].lower()) + 1 + day = int(m["day"]) + hour = int(m["hour"]) + min = int(m["min"]) + sec = int(m["sec"]) result = datetime.datetime(year, month, day, hour, min, sec, tzinfo=tz) return int(result.timestamp()) except Exception as exc: @@ -145,6 +154,7 @@ def parse_http_date_safe(date): # Base 36 functions: useful for generating compact URLs + def base36_to_int(s): """ Convert a base 36 string to an int. Raise ValueError if the input won't fit @@ -160,12 +170,12 @@ def base36_to_int(s): def int_to_base36(i): """Convert an integer to a base36 string.""" - char_set = '0123456789abcdefghijklmnopqrstuvwxyz' + char_set = "0123456789abcdefghijklmnopqrstuvwxyz" if i < 0: raise ValueError("Negative base36 conversion input.") if i < 36: return char_set[i] - b36 = '' + b36 = "" while i != 0: i, n = divmod(i, 36) b36 = char_set[n] + b36 @@ -177,7 +187,7 @@ def urlsafe_base64_encode(s): Encode a bytestring to a base64 string for use in URLs. 
Strip any trailing equal signs. """ - return base64.urlsafe_b64encode(s).rstrip(b'\n=').decode('ascii') + return base64.urlsafe_b64encode(s).rstrip(b"\n=").decode("ascii") def urlsafe_base64_decode(s): @@ -187,7 +197,7 @@ def urlsafe_base64_decode(s): """ s = s.encode() try: - return base64.urlsafe_b64decode(s.ljust(len(s) + len(s) % 4, b'=')) + return base64.urlsafe_b64decode(s.ljust(len(s) + len(s) % 4, b"=")) except (LookupError, BinasciiError) as e: raise ValueError(e) @@ -198,11 +208,11 @@ def parse_etags(etag_str): defined by RFC 7232. Return a list of quoted ETags, or ['*'] if all ETags should be matched. """ - if etag_str.strip() == '*': - return ['*'] + if etag_str.strip() == "*": + return ["*"] else: # Parse each ETag individually, and return any that are valid. - etag_matches = (ETAG_MATCH.match(etag.strip()) for etag in etag_str.split(',')) + etag_matches = (ETAG_MATCH.match(etag.strip()) for etag in etag_str.split(",")) return [match[1] for match in etag_matches if match] @@ -231,8 +241,9 @@ def is_same_domain(host, pattern): pattern = pattern.lower() return ( - pattern[0] == '.' and (host.endswith(pattern) or host == pattern[1:]) or - pattern == host + pattern[0] == "." + and (host.endswith(pattern) or host == pattern[1:]) + or pattern == host ) @@ -259,14 +270,15 @@ def url_has_allowed_host_and_scheme(url, allowed_hosts, require_https=False): allowed_hosts = {allowed_hosts} # Chrome treats \ completely as / in paths but it could be part of some # basic auth credentials so we need to check both URLs. 
- return ( - _url_has_allowed_host_and_scheme(url, allowed_hosts, require_https=require_https) and - _url_has_allowed_host_and_scheme(url.replace('\\', '/'), allowed_hosts, require_https=require_https) + return _url_has_allowed_host_and_scheme( + url, allowed_hosts, require_https=require_https + ) and _url_has_allowed_host_and_scheme( + url.replace("\\", "/"), allowed_hosts, require_https=require_https ) # Copied from urllib.parse.urlparse() but uses fixed urlsplit() function. -def _urlparse(url, scheme='', allow_fragments=True): +def _urlparse(url, scheme="", allow_fragments=True): """Parse a URL into 6 components: <scheme>://<netloc>/<path>;<params>?<query>#<fragment> Return a 6-tuple: (scheme, netloc, path, params, query, fragment). @@ -275,41 +287,42 @@ def _urlparse(url, scheme='', allow_fragments=True): url, scheme, _coerce_result = _coerce_args(url, scheme) splitresult = _urlsplit(url, scheme, allow_fragments) scheme, netloc, url, query, fragment = splitresult - if scheme in uses_params and ';' in url: + if scheme in uses_params and ";" in url: url, params = _splitparams(url) else: - params = '' + params = "" result = ParseResult(scheme, netloc, url, params, query, fragment) return _coerce_result(result) # Copied from urllib.parse.urlsplit() with # https://github.com/python/cpython/pull/661 applied. -def _urlsplit(url, scheme='', allow_fragments=True): +def _urlsplit(url, scheme="", allow_fragments=True): """Parse a URL into 5 components: <scheme>://<netloc>/<path>?<query>#<fragment> Return a 5-tuple: (scheme, netloc, path, query, fragment). Note that we don't break the components up in smaller bits (e.g. 
netloc is a single string) and we don't expand % escapes.""" url, scheme, _coerce_result = _coerce_args(url, scheme) - netloc = query = fragment = '' - i = url.find(':') + netloc = query = fragment = "" + i = url.find(":") if i > 0: for c in url[:i]: if c not in scheme_chars: break else: - scheme, url = url[:i].lower(), url[i + 1:] + scheme, url = url[:i].lower(), url[i + 1 :] - if url[:2] == '//': + if url[:2] == "//": netloc, url = _splitnetloc(url, 2) - if (('[' in netloc and ']' not in netloc) or - (']' in netloc and '[' not in netloc)): + if ("[" in netloc and "]" not in netloc) or ( + "]" in netloc and "[" not in netloc + ): raise ValueError("Invalid IPv6 URL") - if allow_fragments and '#' in url: - url, fragment = url.split('#', 1) - if '?' in url: - url, query = url.split('?', 1) + if allow_fragments and "#" in url: + url, fragment = url.split("#", 1) + if "?" in url: + url, query = url.split("?", 1) v = SplitResult(scheme, netloc, url, query, fragment) return _coerce_result(v) @@ -317,7 +330,7 @@ def _urlsplit(url, scheme='', allow_fragments=True): def _url_has_allowed_host_and_scheme(url, allowed_hosts, require_https=False): # Chrome considers any URL with more than two slashes to be absolute, but # urlparse is not so flexible. Treat any url with three slashes as unsafe. - if url.startswith('///'): + if url.startswith("///"): return False try: url_info = _urlparse(url) @@ -332,15 +345,16 @@ def _url_has_allowed_host_and_scheme(url, allowed_hosts, require_https=False): # Forbid URLs that start with control characters. Some browsers (like # Chrome) ignore quite a few control characters at the start of a # URL and might consider the URL as scheme relative. - if unicodedata.category(url[0])[0] == 'C': + if unicodedata.category(url[0])[0] == "C": return False scheme = url_info.scheme # Consider URLs without a scheme (e.g. //example.com/p) to be http. 
if not url_info.scheme and url_info.netloc: - scheme = 'http' - valid_schemes = ['https'] if require_https else ['http', 'https'] - return ((not url_info.netloc or url_info.netloc in allowed_hosts) and - (not scheme or scheme in valid_schemes)) + scheme = "http" + valid_schemes = ["https"] if require_https else ["http", "https"] + return (not url_info.netloc or url_info.netloc in allowed_hosts) and ( + not scheme or scheme in valid_schemes + ) def escape_leading_slashes(url): @@ -349,6 +363,6 @@ def escape_leading_slashes(url): escaped to prevent browsers from handling the path as schemaless and redirecting to another host. """ - if url.startswith('//'): - url = '/%2F{}'.format(url[2:]) + if url.startswith("//"): + url = "/%2F{}".format(url[2:]) return url diff --git a/django/utils/inspect.py b/django/utils/inspect.py index 7e062244e5..28418f7312 100644 --- a/django/utils/inspect.py +++ b/django/utils/inspect.py @@ -19,7 +19,8 @@ def _get_callable_parameters(meth_or_func): def get_func_args(func): params = _get_callable_parameters(func) return [ - param.name for param in params + param.name + for param in params if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ] @@ -35,12 +36,12 @@ def get_func_full_args(func): for param in params: name = param.name # Ignore 'self' - if name == 'self': + if name == "self": continue if param.kind == inspect.Parameter.VAR_POSITIONAL: - name = '*' + name + name = "*" + name elif param.kind == inspect.Parameter.VAR_KEYWORD: - name = '**' + name + name = "**" + name if param.default != inspect.Parameter.empty: args.append((name, param.default)) else: @@ -50,28 +51,21 @@ def get_func_full_args(func): def func_accepts_kwargs(func): """Return True if function 'func' accepts keyword arguments **kwargs.""" - return any( - p for p in _get_callable_parameters(func) - if p.kind == p.VAR_KEYWORD - ) + return any(p for p in _get_callable_parameters(func) if p.kind == p.VAR_KEYWORD) def func_accepts_var_args(func): """ Return True if 
function 'func' accepts positional arguments *args. """ - return any( - p for p in _get_callable_parameters(func) - if p.kind == p.VAR_POSITIONAL - ) + return any(p for p in _get_callable_parameters(func) if p.kind == p.VAR_POSITIONAL) def method_has_no_args(meth): """Return True if a method only accepts 'self'.""" - count = len([ - p for p in _get_callable_parameters(meth) - if p.kind == p.POSITIONAL_OR_KEYWORD - ]) + count = len( + [p for p in _get_callable_parameters(meth) if p.kind == p.POSITIONAL_OR_KEYWORD] + ) return count == 0 if inspect.ismethod(meth) else count == 1 diff --git a/django/utils/ipv6.py b/django/utils/ipv6.py index ddb8c8091d..88dd6ecb4b 100644 --- a/django/utils/ipv6.py +++ b/django/utils/ipv6.py @@ -4,8 +4,9 @@ from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ -def clean_ipv6_address(ip_str, unpack_ipv4=False, - error_message=_("This is not a valid IPv6 address.")): +def clean_ipv6_address( + ip_str, unpack_ipv4=False, error_message=_("This is not a valid IPv6 address.") +): """ Clean an IPv6 address string. @@ -25,12 +26,12 @@ def clean_ipv6_address(ip_str, unpack_ipv4=False, try: addr = ipaddress.IPv6Address(int(ipaddress.IPv6Address(ip_str))) except ValueError: - raise ValidationError(error_message, code='invalid') + raise ValidationError(error_message, code="invalid") if unpack_ipv4 and addr.ipv4_mapped: return str(addr.ipv4_mapped) elif addr.ipv4_mapped: - return '::ffff:%s' % str(addr.ipv4_mapped) + return "::ffff:%s" % str(addr.ipv4_mapped) return str(addr) diff --git a/django/utils/jslex.py b/django/utils/jslex.py index 8abf5f1126..93a1a2e972 100644 --- a/django/utils/jslex.py +++ b/django/utils/jslex.py @@ -7,6 +7,7 @@ class Tok: """ A specification for a token class. 
""" + num = 0 def __init__(self, name, regex, next=None): @@ -101,21 +102,34 @@ class JsLexer(Lexer): Tok("comment", r"/\*(.|\n)*?\*/"), Tok("linecomment", r"//.*?$"), Tok("ws", r"\s+"), - Tok("keyword", literals(""" + Tok( + "keyword", + literals( + """ break case catch class const continue debugger default delete do else enum export extends finally for function if import in instanceof new return super switch this throw try typeof var void while with - """, suffix=r"\b"), next='reg'), - Tok("reserved", literals("null true false", suffix=r"\b"), next='div'), - Tok("id", r""" + """, + suffix=r"\b", + ), + next="reg", + ), + Tok("reserved", literals("null true false", suffix=r"\b"), next="div"), + Tok( + "id", + r""" ([a-zA-Z_$ ]|\\u[0-9a-fA-Z]{4}) # first char ([a-zA-Z_$0-9]|\\u[0-9a-fA-F]{4})* # rest chars - """, next='div'), - Tok("hnum", r"0[xX][0-9a-fA-F]+", next='div'), + """, + next="div", + ), + Tok("hnum", r"0[xX][0-9a-fA-F]+", next="div"), Tok("onum", r"0[0-7]+"), - Tok("dnum", r""" + Tok( + "dnum", + r""" ( (0|[1-9][0-9]*) # DecimalIntegerLiteral \. # dot [0-9]* # DecimalDigits-opt @@ -128,15 +142,23 @@ class JsLexer(Lexer): (0|[1-9][0-9]*) # DecimalIntegerLiteral ([eE][-+]?[0-9]+)? # ExponentPart-opt ) - """, next='div'), - Tok("punct", literals(""" + """, + next="div", + ), + Tok( + "punct", + literals( + """ >>>= === !== >>> <<= >>= <= >= == != << >> && || += -= *= %= &= |= ^= - """), next="reg"), - Tok("punct", literals("++ -- ) ]"), next='div'), - Tok("punct", literals("{ } ( [ . ; , < > + - * % & | ^ ! ~ ? : ="), next='reg'), - Tok("string", r'"([^"\\]|(\\(.|\n)))*?"', next='div'), - Tok("string", r"'([^'\\]|(\\(.|\n)))*?'", next='div'), + """ + ), + next="reg", + ), + Tok("punct", literals("++ -- ) ]"), next="div"), + Tok("punct", literals("{ } ( [ . ; , < > + - * % & | ^ ! ~ ? 
: ="), next="reg"), + Tok("string", r'"([^"\\]|(\\(.|\n)))*?"', next="div"), + Tok("string", r"'([^'\\]|(\\(.|\n)))*?'", next="div"), ] both_after = [ @@ -145,13 +167,16 @@ class JsLexer(Lexer): states = { # slash will mean division - 'div': both_before + [ - Tok("punct", literals("/= /"), next='reg'), - ] + both_after, - + "div": both_before + + [ + Tok("punct", literals("/= /"), next="reg"), + ] + + both_after, # slash will mean regex - 'reg': both_before + [ - Tok("regex", + "reg": both_before + + [ + Tok( + "regex", r""" / # opening slash # First character is.. @@ -174,12 +199,15 @@ class JsLexer(Lexer): )* # many times / # closing slash [a-zA-Z0-9]* # trailing flags - """, next='div'), - ] + both_after, + """, + next="div", + ), + ] + + both_after, } def __init__(self): - super().__init__(self.states, 'reg') + super().__init__(self.states, "reg") def prepare_js_for_gettext(js): @@ -190,31 +218,32 @@ def prepare_js_for_gettext(js): What actually happens is that all the regex literals are replaced with "REGEX". """ + def escape_quotes(m): """Used in a regex to properly escape double quotes.""" s = m[0] if s == '"': - return r'\"' + return r"\"" else: return s lexer = JsLexer() c = [] for name, tok in lexer.lex(js): - if name == 'regex': + if name == "regex": # C doesn't grok regexes, and they aren't needed for gettext, # so just output a string instead. tok = '"REGEX"' - elif name == 'string': + elif name == "string": # C doesn't have single-quoted strings, so make all strings # double-quoted. if tok.startswith("'"): guts = re.sub(r"\\.|.", escape_quotes, tok[1:-1]) tok = '"' + guts + '"' - elif name == 'id': + elif name == "id": # C can't deal with Unicode escapes in identifiers. 
We don't # need them for gettext anyway, so replace them with something # innocuous tok = tok.replace("\\", "U") c.append(tok) - return ''.join(c) + return "".join(c) diff --git a/django/utils/log.py b/django/utils/log.py index 5a5decd531..fd0cc1bdc1 100644 --- a/django/utils/log.py +++ b/django/utils/log.py @@ -8,7 +8,7 @@ from django.core.mail import get_connection from django.core.management.color import color_style from django.utils.module_loading import import_string -request_logger = logging.getLogger('django.request') +request_logger = logging.getLogger("django.request") # Default logging for Django. This sends an email to the site admins on every # HTTP 500 error. Depending on DEBUG, all other log records are either sent to @@ -16,51 +16,51 @@ request_logger = logging.getLogger('django.request') # require_debug_true filter. This configuration is quoted in # docs/ref/logging.txt; please amend it there if edited here. DEFAULT_LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'filters': { - 'require_debug_false': { - '()': 'django.utils.log.RequireDebugFalse', + "version": 1, + "disable_existing_loggers": False, + "filters": { + "require_debug_false": { + "()": "django.utils.log.RequireDebugFalse", }, - 'require_debug_true': { - '()': 'django.utils.log.RequireDebugTrue', + "require_debug_true": { + "()": "django.utils.log.RequireDebugTrue", }, }, - 'formatters': { - 'django.server': { - '()': 'django.utils.log.ServerFormatter', - 'format': '[{server_time}] {message}', - 'style': '{', + "formatters": { + "django.server": { + "()": "django.utils.log.ServerFormatter", + "format": "[{server_time}] {message}", + "style": "{", } }, - 'handlers': { - 'console': { - 'level': 'INFO', - 'filters': ['require_debug_true'], - 'class': 'logging.StreamHandler', + "handlers": { + "console": { + "level": "INFO", + "filters": ["require_debug_true"], + "class": "logging.StreamHandler", }, - 'django.server': { - 'level': 'INFO', - 'class': 
'logging.StreamHandler', - 'formatter': 'django.server', + "django.server": { + "level": "INFO", + "class": "logging.StreamHandler", + "formatter": "django.server", + }, + "mail_admins": { + "level": "ERROR", + "filters": ["require_debug_false"], + "class": "django.utils.log.AdminEmailHandler", }, - 'mail_admins': { - 'level': 'ERROR', - 'filters': ['require_debug_false'], - 'class': 'django.utils.log.AdminEmailHandler' - } }, - 'loggers': { - 'django': { - 'handlers': ['console', 'mail_admins'], - 'level': 'INFO', + "loggers": { + "django": { + "handlers": ["console", "mail_admins"], + "level": "INFO", }, - 'django.server': { - 'handlers': ['django.server'], - 'level': 'INFO', - 'propagate': False, + "django.server": { + "handlers": ["django.server"], + "level": "INFO", + "propagate": False, }, - } + }, } @@ -87,22 +87,24 @@ class AdminEmailHandler(logging.Handler): super().__init__() self.include_html = include_html self.email_backend = email_backend - self.reporter_class = import_string(reporter_class or settings.DEFAULT_EXCEPTION_REPORTER) + self.reporter_class = import_string( + reporter_class or settings.DEFAULT_EXCEPTION_REPORTER + ) def emit(self, record): try: request = record.request - subject = '%s (%s IP): %s' % ( + subject = "%s (%s IP): %s" % ( record.levelname, - ('internal' if request.META.get('REMOTE_ADDR') in settings.INTERNAL_IPS - else 'EXTERNAL'), - record.getMessage() + ( + "internal" + if request.META.get("REMOTE_ADDR") in settings.INTERNAL_IPS + else "EXTERNAL" + ), + record.getMessage(), ) except Exception: - subject = '%s: %s' % ( - record.levelname, - record.getMessage() - ) + subject = "%s: %s" % (record.levelname, record.getMessage()) request = None subject = self.format_subject(subject) @@ -118,12 +120,17 @@ class AdminEmailHandler(logging.Handler): exc_info = (None, record.getMessage(), None) reporter = self.reporter_class(request, is_email=True, *exc_info) - message = "%s\n\n%s" % (self.format(no_exc_record), 
reporter.get_traceback_text()) + message = "%s\n\n%s" % ( + self.format(no_exc_record), + reporter.get_traceback_text(), + ) html_message = reporter.get_traceback_html() if self.include_html else None self.send_mail(subject, message, fail_silently=True, html_message=html_message) def send_mail(self, subject, message, *args, **kwargs): - mail.mail_admins(subject, message, *args, connection=self.connection(), **kwargs) + mail.mail_admins( + subject, message, *args, connection=self.connection(), **kwargs + ) def connection(self): return get_connection(backend=self.email_backend, fail_silently=True) @@ -132,7 +139,7 @@ class AdminEmailHandler(logging.Handler): """ Escape CR and LF characters. """ - return subject.replace('\n', '\\n').replace('\r', '\\r') + return subject.replace("\n", "\\n").replace("\r", "\\r") class CallbackFilter(logging.Filter): @@ -141,6 +148,7 @@ class CallbackFilter(logging.Filter): takes the record-to-be-logged as its only parameter) to decide whether to log a record. 
""" + def __init__(self, callback): self.callback = callback @@ -161,7 +169,7 @@ class RequireDebugTrue(logging.Filter): class ServerFormatter(logging.Formatter): - default_time_format = '%d/%b/%Y %H:%M:%S' + default_time_format = "%d/%b/%Y %H:%M:%S" def __init__(self, *args, **kwargs): self.style = color_style() @@ -169,7 +177,7 @@ class ServerFormatter(logging.Formatter): def format(self, record): msg = record.msg - status_code = getattr(record, 'status_code', None) + status_code = getattr(record, "status_code", None) if status_code: if 200 <= status_code < 300: @@ -189,17 +197,25 @@ class ServerFormatter(logging.Formatter): # Any 5XX, or any other status code msg = self.style.HTTP_SERVER_ERROR(msg) - if self.uses_server_time() and not hasattr(record, 'server_time'): + if self.uses_server_time() and not hasattr(record, "server_time"): record.server_time = self.formatTime(record, self.datefmt) record.msg = msg return super().format(record) def uses_server_time(self): - return self._fmt.find('{server_time}') >= 0 + return self._fmt.find("{server_time}") >= 0 -def log_response(message, *args, response=None, request=None, logger=request_logger, level=None, exception=None): +def log_response( + message, + *args, + response=None, + request=None, + logger=request_logger, + level=None, + exception=None, +): """ Log errors based on HttpResponse status. @@ -211,22 +227,23 @@ def log_response(message, *args, response=None, request=None, logger=request_log # the same response can be received in some cases, e.g., when the # response is the result of an exception and is logged when the exception # is caught, to record the exception. 
- if getattr(response, '_has_been_logged', False): + if getattr(response, "_has_been_logged", False): return if level is None: if response.status_code >= 500: - level = 'error' + level = "error" elif response.status_code >= 400: - level = 'warning' + level = "warning" else: - level = 'info' + level = "info" getattr(logger, level)( - message, *args, + message, + *args, extra={ - 'status_code': response.status_code, - 'request': request, + "status_code": response.status_code, + "request": request, }, exc_info=exception, ) diff --git a/django/utils/lorem_ipsum.py b/django/utils/lorem_ipsum.py index cfa675d70a..5cbc4e5a60 100644 --- a/django/utils/lorem_ipsum.py +++ b/django/utils/lorem_ipsum.py @@ -5,51 +5,220 @@ Utility functions for generating "lorem ipsum" Latin text. import random COMMON_P = ( - 'Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod ' - 'tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim ' - 'veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea ' - 'commodo consequat. Duis aute irure dolor in reprehenderit in voluptate ' - 'velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint ' - 'occaecat cupidatat non proident, sunt in culpa qui officia deserunt ' - 'mollit anim id est laborum.' + "Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod " + "tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim " + "veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea " + "commodo consequat. Duis aute irure dolor in reprehenderit in voluptate " + "velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint " + "occaecat cupidatat non proident, sunt in culpa qui officia deserunt " + "mollit anim id est laborum." 
) WORDS = ( - 'exercitationem', 'perferendis', 'perspiciatis', 'laborum', 'eveniet', - 'sunt', 'iure', 'nam', 'nobis', 'eum', 'cum', 'officiis', 'excepturi', - 'odio', 'consectetur', 'quasi', 'aut', 'quisquam', 'vel', 'eligendi', - 'itaque', 'non', 'odit', 'tempore', 'quaerat', 'dignissimos', - 'facilis', 'neque', 'nihil', 'expedita', 'vitae', 'vero', 'ipsum', - 'nisi', 'animi', 'cumque', 'pariatur', 'velit', 'modi', 'natus', - 'iusto', 'eaque', 'sequi', 'illo', 'sed', 'ex', 'et', 'voluptatibus', - 'tempora', 'veritatis', 'ratione', 'assumenda', 'incidunt', 'nostrum', - 'placeat', 'aliquid', 'fuga', 'provident', 'praesentium', 'rem', - 'necessitatibus', 'suscipit', 'adipisci', 'quidem', 'possimus', - 'voluptas', 'debitis', 'sint', 'accusantium', 'unde', 'sapiente', - 'voluptate', 'qui', 'aspernatur', 'laudantium', 'soluta', 'amet', - 'quo', 'aliquam', 'saepe', 'culpa', 'libero', 'ipsa', 'dicta', - 'reiciendis', 'nesciunt', 'doloribus', 'autem', 'impedit', 'minima', - 'maiores', 'repudiandae', 'ipsam', 'obcaecati', 'ullam', 'enim', - 'totam', 'delectus', 'ducimus', 'quis', 'voluptates', 'dolores', - 'molestiae', 'harum', 'dolorem', 'quia', 'voluptatem', 'molestias', - 'magni', 'distinctio', 'omnis', 'illum', 'dolorum', 'voluptatum', 'ea', - 'quas', 'quam', 'corporis', 'quae', 'blanditiis', 'atque', 'deserunt', - 'laboriosam', 'earum', 'consequuntur', 'hic', 'cupiditate', - 'quibusdam', 'accusamus', 'ut', 'rerum', 'error', 'minus', 'eius', - 'ab', 'ad', 'nemo', 'fugit', 'officia', 'at', 'in', 'id', 'quos', - 'reprehenderit', 'numquam', 'iste', 'fugiat', 'sit', 'inventore', - 'beatae', 'repellendus', 'magnam', 'recusandae', 'quod', 'explicabo', - 'doloremque', 'aperiam', 'consequatur', 'asperiores', 'commodi', - 'optio', 'dolor', 'labore', 'temporibus', 'repellat', 'veniam', - 'architecto', 'est', 'esse', 'mollitia', 'nulla', 'a', 'similique', - 'eos', 'alias', 'dolore', 'tenetur', 'deleniti', 'porro', 'facere', - 'maxime', 'corrupti', + "exercitationem", + 
"perferendis", + "perspiciatis", + "laborum", + "eveniet", + "sunt", + "iure", + "nam", + "nobis", + "eum", + "cum", + "officiis", + "excepturi", + "odio", + "consectetur", + "quasi", + "aut", + "quisquam", + "vel", + "eligendi", + "itaque", + "non", + "odit", + "tempore", + "quaerat", + "dignissimos", + "facilis", + "neque", + "nihil", + "expedita", + "vitae", + "vero", + "ipsum", + "nisi", + "animi", + "cumque", + "pariatur", + "velit", + "modi", + "natus", + "iusto", + "eaque", + "sequi", + "illo", + "sed", + "ex", + "et", + "voluptatibus", + "tempora", + "veritatis", + "ratione", + "assumenda", + "incidunt", + "nostrum", + "placeat", + "aliquid", + "fuga", + "provident", + "praesentium", + "rem", + "necessitatibus", + "suscipit", + "adipisci", + "quidem", + "possimus", + "voluptas", + "debitis", + "sint", + "accusantium", + "unde", + "sapiente", + "voluptate", + "qui", + "aspernatur", + "laudantium", + "soluta", + "amet", + "quo", + "aliquam", + "saepe", + "culpa", + "libero", + "ipsa", + "dicta", + "reiciendis", + "nesciunt", + "doloribus", + "autem", + "impedit", + "minima", + "maiores", + "repudiandae", + "ipsam", + "obcaecati", + "ullam", + "enim", + "totam", + "delectus", + "ducimus", + "quis", + "voluptates", + "dolores", + "molestiae", + "harum", + "dolorem", + "quia", + "voluptatem", + "molestias", + "magni", + "distinctio", + "omnis", + "illum", + "dolorum", + "voluptatum", + "ea", + "quas", + "quam", + "corporis", + "quae", + "blanditiis", + "atque", + "deserunt", + "laboriosam", + "earum", + "consequuntur", + "hic", + "cupiditate", + "quibusdam", + "accusamus", + "ut", + "rerum", + "error", + "minus", + "eius", + "ab", + "ad", + "nemo", + "fugit", + "officia", + "at", + "in", + "id", + "quos", + "reprehenderit", + "numquam", + "iste", + "fugiat", + "sit", + "inventore", + "beatae", + "repellendus", + "magnam", + "recusandae", + "quod", + "explicabo", + "doloremque", + "aperiam", + "consequatur", + "asperiores", + "commodi", + "optio", + "dolor", + 
"labore", + "temporibus", + "repellat", + "veniam", + "architecto", + "est", + "esse", + "mollitia", + "nulla", + "a", + "similique", + "eos", + "alias", + "dolore", + "tenetur", + "deleniti", + "porro", + "facere", + "maxime", + "corrupti", ) COMMON_WORDS = ( - 'lorem', 'ipsum', 'dolor', 'sit', 'amet', 'consectetur', - 'adipisicing', 'elit', 'sed', 'do', 'eiusmod', 'tempor', 'incididunt', - 'ut', 'labore', 'et', 'dolore', 'magna', 'aliqua', + "lorem", + "ipsum", + "dolor", + "sit", + "amet", + "consectetur", + "adipisicing", + "elit", + "sed", + "do", + "eiusmod", + "tempor", + "incididunt", + "ut", + "labore", + "et", + "dolore", + "magna", + "aliqua", ) @@ -62,10 +231,13 @@ def sentence(): """ # Determine the number of comma-separated sections and number of words in # each section for this sentence. - sections = [' '.join(random.sample(WORDS, random.randint(3, 12))) for i in range(random.randint(1, 5))] - s = ', '.join(sections) + sections = [ + " ".join(random.sample(WORDS, random.randint(3, 12))) + for i in range(random.randint(1, 5)) + ] + s = ", ".join(sections) # Convert to sentence case and add end punctuation. - return '%s%s%s' % (s[0].upper(), s[1:], random.choice('?.')) + return "%s%s%s" % (s[0].upper(), s[1:], random.choice("?.")) def paragraph(): @@ -74,7 +246,7 @@ def paragraph(): The paragraph consists of between 1 and 4 sentences, inclusive. 
""" - return ' '.join(sentence() for i in range(random.randint(1, 4))) + return " ".join(sentence() for i in range(random.randint(1, 4))) def paragraphs(count, common=True): @@ -111,4 +283,4 @@ def words(count, common=True): word_list += random.sample(WORDS, c) else: word_list = word_list[:count] - return ' '.join(word_list) + return " ".join(word_list) diff --git a/django/utils/module_loading.py b/django/utils/module_loading.py index bf099cba96..cb579e7f8c 100644 --- a/django/utils/module_loading.py +++ b/django/utils/module_loading.py @@ -8,9 +8,9 @@ from importlib.util import find_spec as importlib_find def cached_import(module_path, class_name): # Check whether module is loaded and fully initialized. if not ( - (module := sys.modules.get(module_path)) and - (spec := getattr(module, '__spec__', None)) and - getattr(spec, '_initializing', False) is False + (module := sys.modules.get(module_path)) + and (spec := getattr(module, "__spec__", None)) + and getattr(spec, "_initializing", False) is False ): module = import_module(module_path) return getattr(module, class_name) @@ -22,15 +22,16 @@ def import_string(dotted_path): last name in the path. Raise ImportError if the import failed. 
""" try: - module_path, class_name = dotted_path.rsplit('.', 1) + module_path, class_name = dotted_path.rsplit(".", 1) except ValueError as err: raise ImportError("%s doesn't look like a module path" % dotted_path) from err try: return cached_import(module_path, class_name) except AttributeError as err: - raise ImportError('Module "%s" does not define a "%s" attribute/class' % ( - module_path, class_name) + raise ImportError( + 'Module "%s" does not define a "%s" attribute/class' + % (module_path, class_name) ) from err @@ -46,7 +47,7 @@ def autodiscover_modules(*args, **kwargs): """ from django.apps import apps - register_to = kwargs.get('register_to') + register_to = kwargs.get("register_to") for app_config in apps.get_app_configs(): for module_to_search in args: # Attempt to import the app's module. @@ -54,7 +55,7 @@ def autodiscover_modules(*args, **kwargs): if register_to: before_import_registry = copy.copy(register_to._registry) - import_module('%s.%s' % (app_config.name, module_to_search)) + import_module("%s.%s" % (app_config.name, module_to_search)) except Exception: # Reset the registry to the state before the last import # as this import will have to reoccur on the next request and @@ -79,7 +80,7 @@ def module_has_submodule(package, module_name): # package isn't a package. return False - full_module_name = package_name + '.' + module_name + full_module_name = package_name + "." + module_name try: return importlib_find(full_module_name, package_path) is not None except ModuleNotFoundError: @@ -96,11 +97,11 @@ def module_dir(module): over several directories. """ # Convert to list because __path__ may not support indexing. 
- paths = list(getattr(module, '__path__', [])) + paths = list(getattr(module, "__path__", [])) if len(paths) == 1: return paths[0] else: - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename is not None: return os.path.dirname(filename) raise ValueError("Cannot determine directory containing %s" % module) diff --git a/django/utils/numberformat.py b/django/utils/numberformat.py index 3bfdb2ea52..488d6a77cd 100644 --- a/django/utils/numberformat.py +++ b/django/utils/numberformat.py @@ -4,8 +4,15 @@ from django.conf import settings from django.utils.safestring import mark_safe -def format(number, decimal_sep, decimal_pos=None, grouping=0, thousand_sep='', - force_grouping=False, use_l10n=None): +def format( + number, + decimal_sep, + decimal_pos=None, + grouping=0, + thousand_sep="", + force_grouping=False, + use_l10n=None, +): """ Get a number (as a number or string), and return it as a string, using formats defined as arguments: @@ -18,54 +25,61 @@ def format(number, decimal_sep, decimal_pos=None, grouping=0, thousand_sep='', module in locale.localeconv() LC_NUMERIC grouping (e.g. (3, 2, 0)). * thousand_sep: Thousand separator symbol (for example ",") """ - use_grouping = (use_l10n or (use_l10n is None and settings.USE_L10N)) and settings.USE_THOUSAND_SEPARATOR + use_grouping = ( + use_l10n or (use_l10n is None and settings.USE_L10N) + ) and settings.USE_THOUSAND_SEPARATOR use_grouping = use_grouping or force_grouping use_grouping = use_grouping and grouping != 0 # Make the common case fast if isinstance(number, int) and not use_grouping and not decimal_pos: return mark_safe(number) # sign - sign = '' + sign = "" # Treat potentially very large/small floats as Decimals. 
- if isinstance(number, float) and 'e' in str(number).lower(): + if isinstance(number, float) and "e" in str(number).lower(): number = Decimal(str(number)) if isinstance(number, Decimal): if decimal_pos is not None: # If the provided number is too small to affect any of the visible # decimal places, consider it equal to '0'. - cutoff = Decimal('0.' + '1'.rjust(decimal_pos, '0')) + cutoff = Decimal("0." + "1".rjust(decimal_pos, "0")) if abs(number) < cutoff: - number = Decimal('0') + number = Decimal("0") # Format values with more than 200 digits (an arbitrary cutoff) using # scientific notation to avoid high memory usage in {:f}'.format(). _, digits, exponent = number.as_tuple() if abs(exponent) + len(digits) > 200: - number = '{:e}'.format(number) - coefficient, exponent = number.split('e') + number = "{:e}".format(number) + coefficient, exponent = number.split("e") # Format the coefficient. coefficient = format( - coefficient, decimal_sep, decimal_pos, grouping, - thousand_sep, force_grouping, use_l10n, + coefficient, + decimal_sep, + decimal_pos, + grouping, + thousand_sep, + force_grouping, + use_l10n, ) - return '{}e{}'.format(coefficient, exponent) + return "{}e{}".format(coefficient, exponent) else: - str_number = '{:f}'.format(number) + str_number = "{:f}".format(number) else: str_number = str(number) - if str_number[0] == '-': - sign = '-' + if str_number[0] == "-": + sign = "-" str_number = str_number[1:] # decimal part - if '.' in str_number: - int_part, dec_part = str_number.split('.') + if "." 
in str_number: + int_part, dec_part = str_number.split(".") if decimal_pos is not None: dec_part = dec_part[:decimal_pos] else: - int_part, dec_part = str_number, '' + int_part, dec_part = str_number, "" if decimal_pos is not None: - dec_part = dec_part + ('0' * (decimal_pos - len(dec_part))) + dec_part = dec_part + ("0" * (decimal_pos - len(dec_part))) dec_part = dec_part and decimal_sep + dec_part # grouping if use_grouping: @@ -76,7 +90,7 @@ def format(number, decimal_sep, decimal_pos=None, grouping=0, thousand_sep='', # grouping is a single value intervals = [grouping, 0] active_interval = intervals.pop(0) - int_part_gd = '' + int_part_gd = "" cnt = 0 for digit in int_part[::-1]: if cnt and cnt == active_interval: diff --git a/django/utils/regex_helper.py b/django/utils/regex_helper.py index 8612475b96..9ee82e1a9b 100644 --- a/django/utils/regex_helper.py +++ b/django/utils/regex_helper.py @@ -73,23 +73,23 @@ def normalize(pattern): try: ch, escaped = next(pattern_iter) except StopIteration: - return [('', [])] + return [("", [])] try: while True: if escaped: result.append(ch) - elif ch == '.': + elif ch == ".": # Replace "any character" with an arbitrary representative. result.append(".") - elif ch == '|': + elif ch == "|": # FIXME: One day we'll should do this, but not in 1.0. - raise NotImplementedError('Awaiting Implementation') + raise NotImplementedError("Awaiting Implementation") elif ch == "^": pass - elif ch == '$': + elif ch == "$": break - elif ch == ')': + elif ch == ")": # This can only be the end of a non-capturing group, since all # other unescaped parentheses are handled by the grouping # section later (and the full group is handled there). @@ -99,17 +99,17 @@ def normalize(pattern): start = non_capturing_groups.pop() inner = NonCapture(result[start:]) result = result[:start] + [inner] - elif ch == '[': + elif ch == "[": # Replace ranges with the first character in the range. 
ch, escaped = next(pattern_iter) result.append(ch) ch, escaped = next(pattern_iter) - while escaped or ch != ']': + while escaped or ch != "]": ch, escaped = next(pattern_iter) - elif ch == '(': + elif ch == "(": # Some kind of group. ch, escaped = next(pattern_iter) - if ch != '?' or escaped: + if ch != "?" or escaped: # A positional group name = "_%d" % num_args num_args += 1 @@ -117,37 +117,39 @@ def normalize(pattern): walk_to_end(ch, pattern_iter) else: ch, escaped = next(pattern_iter) - if ch in '!=<': + if ch in "!=<": # All of these are ignorable. Walk to the end of the # group. walk_to_end(ch, pattern_iter) - elif ch == ':': + elif ch == ":": # Non-capturing group non_capturing_groups.append(len(result)) - elif ch != 'P': + elif ch != "P": # Anything else, other than a named group, is something # we cannot reverse. raise ValueError("Non-reversible reg-exp portion: '(?%s'" % ch) else: ch, escaped = next(pattern_iter) - if ch not in ('<', '='): - raise ValueError("Non-reversible reg-exp portion: '(?P%s'" % ch) + if ch not in ("<", "="): + raise ValueError( + "Non-reversible reg-exp portion: '(?P%s'" % ch + ) # We are in a named capturing group. Extra the name and # then skip to the end. - if ch == '<': - terminal_char = '>' + if ch == "<": + terminal_char = ">" # We are in a named backreference. else: - terminal_char = ')' + terminal_char = ")" name = [] ch, escaped = next(pattern_iter) while ch != terminal_char: name.append(ch) ch, escaped = next(pattern_iter) - param = ''.join(name) + param = "".join(name) # Named backreferences have already consumed the # parenthesis. - if terminal_char != ')': + if terminal_char != ")": result.append(Group((("%%(%s)s" % param), param))) walk_to_end(ch, pattern_iter) else: @@ -185,7 +187,7 @@ def normalize(pattern): pass except NotImplementedError: # A case of using the disjunctive form. No results for you! 
- return [('', [])] + return [("", [])] return list(zip(*flatten_result(result))) @@ -201,7 +203,7 @@ def next_char(input_iter): raw (unescaped) character or not. """ for ch in input_iter: - if ch != '\\': + if ch != "\\": yield ch, False continue ch = next(input_iter) @@ -217,16 +219,16 @@ def walk_to_end(ch, input_iter): this group, skipping over any nested groups and handling escaped parentheses correctly. """ - if ch == '(': + if ch == "(": nesting = 1 else: nesting = 0 for ch, escaped in input_iter: if escaped: continue - elif ch == '(': + elif ch == "(": nesting += 1 - elif ch == ')': + elif ch == ")": if not nesting: return nesting -= 1 @@ -241,30 +243,30 @@ def get_quantifier(ch, input_iter): either None or the next character from the input_iter if the next character is not part of the quantifier. """ - if ch in '*?+': + if ch in "*?+": try: ch2, escaped = next(input_iter) except StopIteration: ch2 = None - if ch2 == '?': + if ch2 == "?": ch2 = None - if ch == '+': + if ch == "+": return 1, ch2 return 0, ch2 quant = [] - while ch != '}': + while ch != "}": ch, escaped = next(input_iter) quant.append(ch) quant = quant[:-1] - values = ''.join(quant).split(',') + values = "".join(quant).split(",") # Consume the trailing '?', if necessary. try: ch, escaped = next(input_iter) except StopIteration: ch = None - if ch == '?': + if ch == "?": ch = None return int(values[0]), ch @@ -290,20 +292,20 @@ def flatten_result(source): Each of the two lists will be of the same length. 
""" if source is None: - return [''], [[]] + return [""], [[]] if isinstance(source, Group): if source[1] is None: params = [] else: params = [source[1]] return [source[0]], [params] - result = [''] + result = [""] result_args = [[]] pos = last = 0 for pos, elt in enumerate(source): if isinstance(elt, str): continue - piece = ''.join(source[last:pos]) + piece = "".join(source[last:pos]) if isinstance(elt, Group): piece += elt[0] param = elt[1] @@ -331,7 +333,7 @@ def flatten_result(source): result = new_result result_args = new_args if pos >= last: - piece = ''.join(source[last:]) + piece = "".join(source[last:]) for i in range(len(result)): result[i] += piece return result, result_args @@ -339,13 +341,13 @@ def flatten_result(source): def _lazy_re_compile(regex, flags=0): """Lazily compile a regex with flags.""" + def _compile(): # Compile the regex if it was not passed pre-compiled. if isinstance(regex, (str, bytes)): return re.compile(regex, flags) else: - assert not flags, ( - 'flags must be empty if regex is passed pre-compiled' - ) + assert not flags, "flags must be empty if regex is passed pre-compiled" return regex + return SimpleLazyObject(_compile) diff --git a/django/utils/safestring.py b/django/utils/safestring.py index c061717889..b7d1adff62 100644 --- a/django/utils/safestring.py +++ b/django/utils/safestring.py @@ -49,6 +49,7 @@ def _safety_decorator(safety_marker, func): @wraps(func) def wrapped(*args, **kwargs): return safety_marker(func(*args, **kwargs)) + return wrapped @@ -61,7 +62,7 @@ def mark_safe(s): Can be called multiple times on a single string. 
""" - if hasattr(s, '__html__'): + if hasattr(s, "__html__"): return s if callable(s): return _safety_decorator(mark_safe, s) diff --git a/django/utils/termcolors.py b/django/utils/termcolors.py index 089ae63ef0..3d1eee6e41 100644 --- a/django/utils/termcolors.py +++ b/django/utils/termcolors.py @@ -2,15 +2,21 @@ termcolors.py """ -color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') -foreground = {color_names[x]: '3%s' % x for x in range(8)} -background = {color_names[x]: '4%s' % x for x in range(8)} +color_names = ("black", "red", "green", "yellow", "blue", "magenta", "cyan", "white") +foreground = {color_names[x]: "3%s" % x for x in range(8)} +background = {color_names[x]: "4%s" % x for x in range(8)} -RESET = '0' -opt_dict = {'bold': '1', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'} +RESET = "0" +opt_dict = { + "bold": "1", + "underscore": "4", + "blink": "5", + "reverse": "7", + "conceal": "8", +} -def colorize(text='', opts=(), **kwargs): +def colorize(text="", opts=(), **kwargs): """ Return your text, enclosed in ANSI graphics codes. 
@@ -40,19 +46,19 @@ def colorize(text='', opts=(), **kwargs): print('this should not be red') """ code_list = [] - if text == '' and len(opts) == 1 and opts[0] == 'reset': - return '\x1b[%sm' % RESET + if text == "" and len(opts) == 1 and opts[0] == "reset": + return "\x1b[%sm" % RESET for k, v in kwargs.items(): - if k == 'fg': + if k == "fg": code_list.append(foreground[v]) - elif k == 'bg': + elif k == "bg": code_list.append(background[v]) for o in opts: if o in opt_dict: code_list.append(opt_dict[o]) - if 'noreset' not in opts: - text = '%s\x1b[%sm' % (text or '', RESET) - return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '') + if "noreset" not in opts: + text = "%s\x1b[%sm" % (text or "", RESET) + return "%s%s" % (("\x1b[%sm" % ";".join(code_list)), text or "") def make_style(opts=(), **kwargs): @@ -68,68 +74,68 @@ def make_style(opts=(), **kwargs): return lambda text: colorize(text, opts, **kwargs) -NOCOLOR_PALETTE = 'nocolor' -DARK_PALETTE = 'dark' -LIGHT_PALETTE = 'light' +NOCOLOR_PALETTE = "nocolor" +DARK_PALETTE = "dark" +LIGHT_PALETTE = "light" PALETTES = { NOCOLOR_PALETTE: { - 'ERROR': {}, - 'SUCCESS': {}, - 'WARNING': {}, - 'NOTICE': {}, - 'SQL_FIELD': {}, - 'SQL_COLTYPE': {}, - 'SQL_KEYWORD': {}, - 'SQL_TABLE': {}, - 'HTTP_INFO': {}, - 'HTTP_SUCCESS': {}, - 'HTTP_REDIRECT': {}, - 'HTTP_NOT_MODIFIED': {}, - 'HTTP_BAD_REQUEST': {}, - 'HTTP_NOT_FOUND': {}, - 'HTTP_SERVER_ERROR': {}, - 'MIGRATE_HEADING': {}, - 'MIGRATE_LABEL': {}, + "ERROR": {}, + "SUCCESS": {}, + "WARNING": {}, + "NOTICE": {}, + "SQL_FIELD": {}, + "SQL_COLTYPE": {}, + "SQL_KEYWORD": {}, + "SQL_TABLE": {}, + "HTTP_INFO": {}, + "HTTP_SUCCESS": {}, + "HTTP_REDIRECT": {}, + "HTTP_NOT_MODIFIED": {}, + "HTTP_BAD_REQUEST": {}, + "HTTP_NOT_FOUND": {}, + "HTTP_SERVER_ERROR": {}, + "MIGRATE_HEADING": {}, + "MIGRATE_LABEL": {}, }, DARK_PALETTE: { - 'ERROR': {'fg': 'red', 'opts': ('bold',)}, - 'SUCCESS': {'fg': 'green', 'opts': ('bold',)}, - 'WARNING': {'fg': 'yellow', 'opts': ('bold',)}, 
- 'NOTICE': {'fg': 'red'}, - 'SQL_FIELD': {'fg': 'green', 'opts': ('bold',)}, - 'SQL_COLTYPE': {'fg': 'green'}, - 'SQL_KEYWORD': {'fg': 'yellow'}, - 'SQL_TABLE': {'opts': ('bold',)}, - 'HTTP_INFO': {'opts': ('bold',)}, - 'HTTP_SUCCESS': {}, - 'HTTP_REDIRECT': {'fg': 'green'}, - 'HTTP_NOT_MODIFIED': {'fg': 'cyan'}, - 'HTTP_BAD_REQUEST': {'fg': 'red', 'opts': ('bold',)}, - 'HTTP_NOT_FOUND': {'fg': 'yellow'}, - 'HTTP_SERVER_ERROR': {'fg': 'magenta', 'opts': ('bold',)}, - 'MIGRATE_HEADING': {'fg': 'cyan', 'opts': ('bold',)}, - 'MIGRATE_LABEL': {'opts': ('bold',)}, + "ERROR": {"fg": "red", "opts": ("bold",)}, + "SUCCESS": {"fg": "green", "opts": ("bold",)}, + "WARNING": {"fg": "yellow", "opts": ("bold",)}, + "NOTICE": {"fg": "red"}, + "SQL_FIELD": {"fg": "green", "opts": ("bold",)}, + "SQL_COLTYPE": {"fg": "green"}, + "SQL_KEYWORD": {"fg": "yellow"}, + "SQL_TABLE": {"opts": ("bold",)}, + "HTTP_INFO": {"opts": ("bold",)}, + "HTTP_SUCCESS": {}, + "HTTP_REDIRECT": {"fg": "green"}, + "HTTP_NOT_MODIFIED": {"fg": "cyan"}, + "HTTP_BAD_REQUEST": {"fg": "red", "opts": ("bold",)}, + "HTTP_NOT_FOUND": {"fg": "yellow"}, + "HTTP_SERVER_ERROR": {"fg": "magenta", "opts": ("bold",)}, + "MIGRATE_HEADING": {"fg": "cyan", "opts": ("bold",)}, + "MIGRATE_LABEL": {"opts": ("bold",)}, }, LIGHT_PALETTE: { - 'ERROR': {'fg': 'red', 'opts': ('bold',)}, - 'SUCCESS': {'fg': 'green', 'opts': ('bold',)}, - 'WARNING': {'fg': 'yellow', 'opts': ('bold',)}, - 'NOTICE': {'fg': 'red'}, - 'SQL_FIELD': {'fg': 'green', 'opts': ('bold',)}, - 'SQL_COLTYPE': {'fg': 'green'}, - 'SQL_KEYWORD': {'fg': 'blue'}, - 'SQL_TABLE': {'opts': ('bold',)}, - 'HTTP_INFO': {'opts': ('bold',)}, - 'HTTP_SUCCESS': {}, - 'HTTP_REDIRECT': {'fg': 'green', 'opts': ('bold',)}, - 'HTTP_NOT_MODIFIED': {'fg': 'green'}, - 'HTTP_BAD_REQUEST': {'fg': 'red', 'opts': ('bold',)}, - 'HTTP_NOT_FOUND': {'fg': 'red'}, - 'HTTP_SERVER_ERROR': {'fg': 'magenta', 'opts': ('bold',)}, - 'MIGRATE_HEADING': {'fg': 'cyan', 'opts': ('bold',)}, - 
'MIGRATE_LABEL': {'opts': ('bold',)}, - } + "ERROR": {"fg": "red", "opts": ("bold",)}, + "SUCCESS": {"fg": "green", "opts": ("bold",)}, + "WARNING": {"fg": "yellow", "opts": ("bold",)}, + "NOTICE": {"fg": "red"}, + "SQL_FIELD": {"fg": "green", "opts": ("bold",)}, + "SQL_COLTYPE": {"fg": "green"}, + "SQL_KEYWORD": {"fg": "blue"}, + "SQL_TABLE": {"opts": ("bold",)}, + "HTTP_INFO": {"opts": ("bold",)}, + "HTTP_SUCCESS": {}, + "HTTP_REDIRECT": {"fg": "green", "opts": ("bold",)}, + "HTTP_NOT_MODIFIED": {"fg": "green"}, + "HTTP_BAD_REQUEST": {"fg": "red", "opts": ("bold",)}, + "HTTP_NOT_FOUND": {"fg": "red"}, + "HTTP_SERVER_ERROR": {"fg": "magenta", "opts": ("bold",)}, + "MIGRATE_HEADING": {"fg": "cyan", "opts": ("bold",)}, + "MIGRATE_LABEL": {"opts": ("bold",)}, + }, } DEFAULT_PALETTE = DARK_PALETTE @@ -169,39 +175,39 @@ def parse_color_setting(config_string): return PALETTES[DEFAULT_PALETTE] # Split the color configuration into parts - parts = config_string.lower().split(';') + parts = config_string.lower().split(";") palette = PALETTES[NOCOLOR_PALETTE].copy() for part in parts: if part in PALETTES: # A default palette has been specified palette.update(PALETTES[part]) - elif '=' in part: + elif "=" in part: # Process a palette defining string definition = {} # Break the definition into the role, # plus the list of specific instructions. # The role must be in upper case - role, instructions = part.split('=') + role, instructions = part.split("=") role = role.upper() - styles = instructions.split(',') + styles = instructions.split(",") styles.reverse() # The first instruction can contain a slash # to break apart fg/bg. 
- colors = styles.pop().split('/') + colors = styles.pop().split("/") colors.reverse() fg = colors.pop() if fg in color_names: - definition['fg'] = fg + definition["fg"] = fg if colors and colors[-1] in color_names: - definition['bg'] = colors[-1] + definition["bg"] = colors[-1] # All remaining instructions are options opts = tuple(s for s in styles if s in opt_dict) if opts: - definition['opts'] = opts + definition["opts"] = opts # The nocolor palette has all available roles. # Use that palette as the basis for determining diff --git a/django/utils/text.py b/django/utils/text.py index 915da50483..dcfe3fba0e 100644 --- a/django/utils/text.py +++ b/django/utils/text.py @@ -1,12 +1,14 @@ import re import unicodedata -from gzip import GzipFile, compress as gzip_compress +from gzip import GzipFile +from gzip import compress as gzip_compress from io import BytesIO from django.core.exceptions import SuspiciousFileOperation from django.utils.functional import SimpleLazyObject, keep_lazy_text, lazy from django.utils.regex_helper import _lazy_re_compile -from django.utils.translation import gettext as _, gettext_lazy, pgettext +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy, pgettext @keep_lazy_text @@ -20,11 +22,11 @@ def capfirst(x): # Set up regular expressions -re_words = _lazy_re_compile(r'<[^>]+?>|([^<>\s]+)', re.S) -re_chars = _lazy_re_compile(r'<[^>]+?>|(.)', re.S) -re_tag = _lazy_re_compile(r'<(/)?(\S+?)(?:(\s*/)|\s.*?)?>', re.S) -re_newlines = _lazy_re_compile(r'\r\n|\r') # Used in normalize_newlines -re_camel_case = _lazy_re_compile(r'(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))') +re_words = _lazy_re_compile(r"<[^>]+?>|([^<>\s]+)", re.S) +re_chars = _lazy_re_compile(r"<[^>]+?>|(.)", re.S) +re_tag = _lazy_re_compile(r"<(/)?(\S+?)(?:(\s*/)|\s.*?)?>", re.S) +re_newlines = _lazy_re_compile(r"\r\n|\r") # Used in normalize_newlines +re_camel_case = _lazy_re_compile(r"(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))") 
@keep_lazy_text @@ -39,46 +41,49 @@ def wrap(text, width): Don't wrap long words, thus the output text may have lines longer than ``width``. """ + def _generator(): for line in text.splitlines(True): # True keeps trailing linebreaks - max_width = min((line.endswith('\n') and width + 1 or width), width) + max_width = min((line.endswith("\n") and width + 1 or width), width) while len(line) > max_width: - space = line[:max_width + 1].rfind(' ') + 1 + space = line[: max_width + 1].rfind(" ") + 1 if space == 0: - space = line.find(' ') + 1 + space = line.find(" ") + 1 if space == 0: yield line - line = '' + line = "" break - yield '%s\n' % line[:space - 1] + yield "%s\n" % line[: space - 1] line = line[space:] - max_width = min((line.endswith('\n') and width + 1 or width), width) + max_width = min((line.endswith("\n") and width + 1 or width), width) if line: yield line - return ''.join(_generator()) + + return "".join(_generator()) class Truncator(SimpleLazyObject): """ An object used to truncate text, either by characters or words. """ + def __init__(self, text): super().__init__(lambda: str(text)) def add_truncation_text(self, text, truncate=None): if truncate is None: truncate = pgettext( - 'String to return when truncating text', - '%(truncated_text)s…') - if '%(truncated_text)s' in truncate: - return truncate % {'truncated_text': text} + "String to return when truncating text", "%(truncated_text)s…" + ) + if "%(truncated_text)s" in truncate: + return truncate % {"truncated_text": text} # The truncation text didn't contain the %(truncated_text)s string # replacement argument so just append it to the text. if text.endswith(truncate): # But don't append the truncation text if the current text already # ends in this. 
return text - return '%s%s' % (text, truncate) + return "%s%s" % (text, truncate) def chars(self, num, truncate=None, html=False): """ @@ -90,11 +95,11 @@ class Truncator(SimpleLazyObject): """ self._setup() length = int(num) - text = unicodedata.normalize('NFC', self._wrapped) + text = unicodedata.normalize("NFC", self._wrapped) # Calculate the length to truncate to (max length - end_text length) truncate_len = length - for char in self.add_truncation_text('', truncate): + for char in self.add_truncation_text("", truncate): if not unicodedata.combining(char): truncate_len -= 1 if truncate_len == 0: @@ -117,8 +122,7 @@ class Truncator(SimpleLazyObject): end_index = i if s_len > length: # Return the truncated string - return self.add_truncation_text(text[:end_index or 0], - truncate) + return self.add_truncation_text(text[: end_index or 0], truncate) # Return the original string since no truncation was necessary return text @@ -144,8 +148,8 @@ class Truncator(SimpleLazyObject): words = self._wrapped.split() if len(words) > length: words = words[:length] - return self.add_truncation_text(' '.join(words), truncate) - return ' '.join(words) + return self.add_truncation_text(" ".join(words), truncate) + return " ".join(words) def _truncate_html(self, length, truncate, text, truncate_len, words): """ @@ -156,11 +160,18 @@ class Truncator(SimpleLazyObject): Preserve newlines in the HTML. 
""" if words and length <= 0: - return '' + return "" html4_singlets = ( - 'br', 'col', 'link', 'base', 'img', - 'param', 'area', 'hr', 'input' + "br", + "col", + "link", + "base", + "img", + "param", + "area", + "hr", + "input", ) # Count non-HTML chars/words and keep note of open tags @@ -202,7 +213,7 @@ class Truncator(SimpleLazyObject): else: # SGML: An end tag closes, back to the matching start tag, # all unclosed intervening start tags with omitted end tags - open_tags = open_tags[i + 1:] + open_tags = open_tags[i + 1 :] else: # Add it to the start of the open tags list open_tags.insert(0, tagname) @@ -210,12 +221,12 @@ class Truncator(SimpleLazyObject): if current_len <= length: return text out = text[:end_text_pos] - truncate_text = self.add_truncation_text('', truncate) + truncate_text = self.add_truncation_text("", truncate) if truncate_text: out += truncate_text # Close any tags still open for tag in open_tags: - out += '</%s>' % tag + out += "</%s>" % tag # Return string return out @@ -230,15 +241,15 @@ def get_valid_filename(name): >>> get_valid_filename("john's portrait in 2004.jpg") 'johns_portrait_in_2004.jpg' """ - s = str(name).strip().replace(' ', '_') - s = re.sub(r'(?u)[^-\w.]', '', s) - if s in {'', '.', '..'}: + s = str(name).strip().replace(" ", "_") + s = re.sub(r"(?u)[^-\w.]", "", s) + if s in {"", ".", ".."}: raise SuspiciousFileOperation("Could not derive file name from '%s'" % name) return s @keep_lazy_text -def get_text_list(list_, last_word=gettext_lazy('or')): +def get_text_list(list_, last_word=gettext_lazy("or")): """ >>> get_text_list(['a', 'b', 'c', 'd']) 'a, b, c or d' @@ -252,31 +263,55 @@ def get_text_list(list_, last_word=gettext_lazy('or')): '' """ if not list_: - return '' + return "" if len(list_) == 1: return str(list_[0]) - return '%s %s %s' % ( + return "%s %s %s" % ( # Translators: This string is used as a separator between list elements - _(', ').join(str(i) for i in list_[:-1]), str(last_word), str(list_[-1]) + _(", 
").join(str(i) for i in list_[:-1]), + str(last_word), + str(list_[-1]), ) @keep_lazy_text def normalize_newlines(text): """Normalize CRLF and CR newlines to just LF.""" - return re_newlines.sub('\n', str(text)) + return re_newlines.sub("\n", str(text)) @keep_lazy_text def phone2numeric(phone): """Convert a phone number with letters into its numeric equivalent.""" char2number = { - 'a': '2', 'b': '2', 'c': '2', 'd': '3', 'e': '3', 'f': '3', 'g': '4', - 'h': '4', 'i': '4', 'j': '5', 'k': '5', 'l': '5', 'm': '6', 'n': '6', - 'o': '6', 'p': '7', 'q': '7', 'r': '7', 's': '7', 't': '8', 'u': '8', - 'v': '8', 'w': '9', 'x': '9', 'y': '9', 'z': '9', + "a": "2", + "b": "2", + "c": "2", + "d": "3", + "e": "3", + "f": "3", + "g": "4", + "h": "4", + "i": "4", + "j": "5", + "k": "5", + "l": "5", + "m": "6", + "n": "6", + "o": "6", + "p": "7", + "q": "7", + "r": "7", + "s": "7", + "t": "8", + "u": "8", + "v": "8", + "w": "9", + "x": "9", + "y": "9", + "z": "9", } - return ''.join(char2number.get(c, c) for c in phone.lower()) + return "".join(char2number.get(c, c) for c in phone.lower()) def compress_string(s): @@ -294,7 +329,7 @@ class StreamingBuffer(BytesIO): # Like compress_string, but for iterators of strings. def compress_sequence(sequence): buf = StreamingBuffer() - with GzipFile(mode='wb', compresslevel=6, fileobj=buf, mtime=0) as zfile: + with GzipFile(mode="wb", compresslevel=6, fileobj=buf, mtime=0) as zfile: # Output headers... yield buf.read() for item in sequence: @@ -307,7 +342,8 @@ def compress_sequence(sequence): # Expression to match some_token and some_token="with spaces" (and similarly # for single-quoted strings). 
-smart_split_re = _lazy_re_compile(r""" +smart_split_re = _lazy_re_compile( + r""" ((?: [^\s'"]* (?: @@ -315,7 +351,9 @@ smart_split_re = _lazy_re_compile(r""" [^\s'"]* )+ ) | \S+) -""", re.VERBOSE) +""", + re.VERBOSE, +) def smart_split(text): @@ -355,7 +393,7 @@ def unescape_string_literal(s): if not s or s[0] not in "\"'" or s[-1] != s[0]: raise ValueError("Not a string literal: %r" % s) quote = s[0] - return s[1:-1].replace(r'\%s' % quote, quote).replace(r'\\', '\\') + return s[1:-1].replace(r"\%s" % quote, quote).replace(r"\\", "\\") @keep_lazy_text @@ -368,18 +406,22 @@ def slugify(value, allow_unicode=False): """ value = str(value) if allow_unicode: - value = unicodedata.normalize('NFKC', value) + value = unicodedata.normalize("NFKC", value) else: - value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') - value = re.sub(r'[^\w\s-]', '', value.lower()) - return re.sub(r'[-\s]+', '-', value).strip('-_') + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") def camel_case_to_spaces(value): """ Split CamelCase and convert to lowercase. Strip surrounding whitespace. 
""" - return re_camel_case.sub(r' \1', value).strip().lower() + return re_camel_case.sub(r" \1", value).strip().lower() def _format_lazy(format_string, *args, **kwargs): diff --git a/django/utils/timesince.py b/django/utils/timesince.py index 157dcb72c2..8a8ffb8151 100644 --- a/django/utils/timesince.py +++ b/django/utils/timesince.py @@ -6,21 +6,21 @@ from django.utils.timezone import is_aware, utc from django.utils.translation import gettext, ngettext_lazy TIME_STRINGS = { - 'year': ngettext_lazy('%(num)d year', '%(num)d years', 'num'), - 'month': ngettext_lazy('%(num)d month', '%(num)d months', 'num'), - 'week': ngettext_lazy('%(num)d week', '%(num)d weeks', 'num'), - 'day': ngettext_lazy('%(num)d day', '%(num)d days', 'num'), - 'hour': ngettext_lazy('%(num)d hour', '%(num)d hours', 'num'), - 'minute': ngettext_lazy('%(num)d minute', '%(num)d minutes', 'num'), + "year": ngettext_lazy("%(num)d year", "%(num)d years", "num"), + "month": ngettext_lazy("%(num)d month", "%(num)d months", "num"), + "week": ngettext_lazy("%(num)d week", "%(num)d weeks", "num"), + "day": ngettext_lazy("%(num)d day", "%(num)d days", "num"), + "hour": ngettext_lazy("%(num)d hour", "%(num)d hours", "num"), + "minute": ngettext_lazy("%(num)d minute", "%(num)d minutes", "num"), } TIMESINCE_CHUNKS = ( - (60 * 60 * 24 * 365, 'year'), - (60 * 60 * 24 * 30, 'month'), - (60 * 60 * 24 * 7, 'week'), - (60 * 60 * 24, 'day'), - (60 * 60, 'hour'), - (60, 'minute'), + (60 * 60 * 24 * 365, "year"), + (60 * 60 * 24 * 30, "month"), + (60 * 60 * 24 * 7, "week"), + (60 * 60 * 24, "day"), + (60 * 60, "hour"), + (60, "minute"), ) @@ -47,7 +47,7 @@ def timesince(d, now=None, reversed=False, time_strings=None, depth=2): if time_strings is None: time_strings = TIME_STRINGS if depth <= 0: - raise ValueError('depth must be greater than 0.') + raise ValueError("depth must be greater than 0.") # Convert datetime.date to datetime.datetime for comparison. 
if not isinstance(d, datetime.datetime): d = datetime.datetime(d.year, d.month, d.day) @@ -73,13 +73,13 @@ def timesince(d, now=None, reversed=False, time_strings=None, depth=2): since = delta.days * 24 * 60 * 60 + delta.seconds if since <= 0: # d is in the future compared to now, stop processing. - return avoid_wrapping(time_strings['minute'] % {'num': 0}) + return avoid_wrapping(time_strings["minute"] % {"num": 0}) for i, (seconds, name) in enumerate(TIMESINCE_CHUNKS): count = since // seconds if count != 0: break else: - return avoid_wrapping(time_strings['minute'] % {'num': 0}) + return avoid_wrapping(time_strings["minute"] % {"num": 0}) result = [] current_depth = 0 while i < len(TIMESINCE_CHUNKS) and current_depth < depth: @@ -87,11 +87,11 @@ def timesince(d, now=None, reversed=False, time_strings=None, depth=2): count = since // seconds if count == 0: break - result.append(avoid_wrapping(time_strings[name] % {'num': count})) + result.append(avoid_wrapping(time_strings[name] % {"num": count})) since -= seconds * count current_depth += 1 i += 1 - return gettext(', ').join(result) + return gettext(", ").join(result) def timeuntil(d, now=None, time_strings=None, depth=2): diff --git a/django/utils/timezone.py b/django/utils/timezone.py index 9572c99bac..71b160448e 100644 --- a/django/utils/timezone.py +++ b/django/utils/timezone.py @@ -20,12 +20,21 @@ from django.conf import settings from django.utils.deprecation import RemovedInDjango50Warning __all__ = [ - 'utc', 'get_fixed_timezone', - 'get_default_timezone', 'get_default_timezone_name', - 'get_current_timezone', 'get_current_timezone_name', - 'activate', 'deactivate', 'override', - 'localtime', 'now', - 'is_aware', 'is_naive', 'make_aware', 'make_naive', + "utc", + "get_fixed_timezone", + "get_default_timezone", + "get_default_timezone_name", + "get_current_timezone", + "get_current_timezone_name", + "activate", + "deactivate", + "override", + "localtime", + "now", + "is_aware", + "is_naive", + "make_aware", 
+ "make_naive", ] # RemovedInDjango50Warning: sentinel for deprecation of is_dst parameters. @@ -39,8 +48,8 @@ def get_fixed_timezone(offset): """Return a tzinfo instance with a fixed offset from UTC.""" if isinstance(offset, timedelta): offset = offset.total_seconds() // 60 - sign = '-' if offset < 0 else '+' - hhmm = '%02d%02d' % divmod(abs(offset), 60) + sign = "-" if offset < 0 else "+" + hhmm = "%02d%02d" % divmod(abs(offset), 60) name = sign + hhmm return timezone(timedelta(minutes=offset), name) @@ -56,6 +65,7 @@ def get_default_timezone(): """ if settings.USE_DEPRECATED_PYTZ: import pytz + return pytz.timezone(settings.TIME_ZONE) return zoneinfo.ZoneInfo(settings.TIME_ZONE) @@ -86,6 +96,7 @@ def _get_timezone_name(timezone): """ return timezone.tzname(None) or str(timezone) + # Timezone selection functions. # These functions don't change os.environ['TZ'] and call time.tzset() @@ -104,6 +115,7 @@ def activate(timezone): elif isinstance(timezone, str): if settings.USE_DEPRECATED_PYTZ: import pytz + _active.value = pytz.timezone(timezone) else: _active.value = zoneinfo.ZoneInfo(timezone) @@ -133,11 +145,12 @@ class override(ContextDecorator): time zone name, or ``None``. If it is ``None``, Django enables the default time zone. """ + def __init__(self, timezone): self.timezone = timezone def __enter__(self): - self.old_timezone = getattr(_active, 'value', None) + self.old_timezone = getattr(_active, "value", None) if self.timezone is None: deactivate() else: @@ -152,6 +165,7 @@ class override(ContextDecorator): # Templates + def template_localtime(value, use_tz=None): """ Check if value is a datetime and converts it to local time if necessary. @@ -162,16 +176,17 @@ def template_localtime(value, use_tz=None): This function is designed for use by the template engine. 
""" should_convert = ( - isinstance(value, datetime) and - (settings.USE_TZ if use_tz is None else use_tz) and - not is_naive(value) and - getattr(value, 'convert_to_local_time', True) + isinstance(value, datetime) + and (settings.USE_TZ if use_tz is None else use_tz) + and not is_naive(value) + and getattr(value, "convert_to_local_time", True) ) return localtime(value) if should_convert else value # Utilities + def localtime(value=None, timezone=None): """ Convert an aware datetime.datetime to local time. @@ -215,6 +230,7 @@ def now(): # By design, these four functions don't perform any checks on their arguments. # The caller should ensure that they don't receive an invalid value like None. + def is_aware(value): """ Determine if a given datetime.datetime is aware. @@ -247,9 +263,9 @@ def make_aware(value, timezone=None, is_dst=NOT_PASSED): is_dst = None else: warnings.warn( - 'The is_dst argument to make_aware(), used by the Trunc() ' - 'database functions and QuerySet.datetimes(), is deprecated as it ' - 'has no effect with zoneinfo time zones.', + "The is_dst argument to make_aware(), used by the Trunc() " + "database functions and QuerySet.datetimes(), is deprecated as it " + "has no effect with zoneinfo time zones.", RemovedInDjango50Warning, ) if timezone is None: @@ -260,8 +276,7 @@ def make_aware(value, timezone=None, is_dst=NOT_PASSED): else: # Check that we won't overwrite the timezone of an aware datetime. if is_aware(value): - raise ValueError( - "make_aware expects a naive datetime, got %s" % value) + raise ValueError("make_aware expects a naive datetime, got %s" % value) # This may be wrong around DST changes! 
return value.replace(tzinfo=timezone) @@ -315,6 +330,7 @@ def _is_pytz_zone(tz): def _datetime_ambiguous_or_imaginary(dt, tz): if _is_pytz_zone(tz): import pytz + try: tz.utcoffset(dt) except (pytz.AmbiguousTimeError, pytz.NonExistentTimeError): diff --git a/django/utils/topological_sort.py b/django/utils/topological_sort.py index f7ce0e0d1d..66b6866ec8 100644 --- a/django/utils/topological_sort.py +++ b/django/utils/topological_sort.py @@ -17,14 +17,20 @@ def topological_sort_as_sets(dependency_graph): current = {node for node, deps in todo.items() if not deps} if not current: - raise CyclicDependencyError('Cyclic dependency in graph: {}'.format( - ', '.join(repr(x) for x in todo.items()))) + raise CyclicDependencyError( + "Cyclic dependency in graph: {}".format( + ", ".join(repr(x) for x in todo.items()) + ) + ) yield current # remove current from todo's nodes & dependencies - todo = {node: (dependencies - current) for node, dependencies in - todo.items() if node not in current} + todo = { + node: (dependencies - current) + for node, dependencies in todo.items() + if node not in current + } def stable_topological_sort(nodes, dependency_graph): diff --git a/django/utils/translation/__init__.py b/django/utils/translation/__init__.py index 29ac60ad1d..6b8cc73d5c 100644 --- a/django/utils/translation/__init__.py +++ b/django/utils/translation/__init__.py @@ -9,14 +9,27 @@ from django.utils.functional import lazy from django.utils.regex_helper import _lazy_re_compile __all__ = [ - 'activate', 'deactivate', 'override', 'deactivate_all', - 'get_language', 'get_language_from_request', - 'get_language_info', 'get_language_bidi', - 'check_for_language', 'to_language', 'to_locale', 'templatize', - 'gettext', 'gettext_lazy', 'gettext_noop', - 'ngettext', 'ngettext_lazy', - 'pgettext', 'pgettext_lazy', - 'npgettext', 'npgettext_lazy', + "activate", + "deactivate", + "override", + "deactivate_all", + "get_language", + "get_language_from_request", + "get_language_info", + 
"get_language_bidi", + "check_for_language", + "to_language", + "to_locale", + "templatize", + "gettext", + "gettext_lazy", + "gettext_noop", + "ngettext", + "ngettext_lazy", + "pgettext", + "pgettext_lazy", + "npgettext", + "npgettext_lazy", ] @@ -32,6 +45,7 @@ class TranslatorCommentWarning(SyntaxWarning): # replace the functions with their real counterparts (once we do access the # settings). + class Trans: """ The purpose of this class is to store the actual translation function upon @@ -47,13 +61,20 @@ class Trans: def __getattr__(self, real_name): from django.conf import settings + if settings.USE_I18N: from django.utils.translation import trans_real as trans from django.utils.translation.reloader import ( - translation_file_changed, watch_for_translation_changes, + translation_file_changed, + watch_for_translation_changes, + ) + + autoreload_started.connect( + watch_for_translation_changes, dispatch_uid="translation_file_changed" + ) + file_changed.connect( + translation_file_changed, dispatch_uid="translation_file_changed" ) - autoreload_started.connect(watch_for_translation_changes, dispatch_uid='translation_file_changed') - file_changed.connect(translation_file_changed, dispatch_uid='translation_file_changed') else: from django.utils.translation import trans_null as trans setattr(self, real_name, getattr(trans, real_name)) @@ -92,31 +113,33 @@ pgettext_lazy = lazy(pgettext, str) def lazy_number(func, resultclass, number=None, **kwargs): if isinstance(number, int): - kwargs['number'] = number + kwargs["number"] = number proxy = lazy(func, resultclass)(**kwargs) else: original_kwargs = kwargs.copy() class NumberAwareString(resultclass): def __bool__(self): - return bool(kwargs['singular']) + return bool(kwargs["singular"]) def _get_number_value(self, values): try: return values[number] except KeyError: raise KeyError( - "Your dictionary lacks key '%s\'. Please provide " + "Your dictionary lacks key '%s'. 
Please provide " "it, because it is required to determine whether " "string is singular or plural." % number ) def _translate(self, number_value): - kwargs['number'] = number_value + kwargs["number"] = number_value return func(**kwargs) def format(self, *args, **kwargs): - number_value = self._get_number_value(kwargs) if kwargs and number else args[0] + number_value = ( + self._get_number_value(kwargs) if kwargs and number else args[0] + ) return self._translate(number_value).format(*args, **kwargs) def __mod__(self, rhs): @@ -133,7 +156,10 @@ def lazy_number(func, resultclass, number=None, **kwargs): return translated proxy = lazy(lambda **kwargs: NumberAwareString(), NumberAwareString)(**kwargs) - proxy.__reduce__ = lambda: (_lazy_number_unpickle, (func, resultclass, number, original_kwargs)) + proxy.__reduce__ = lambda: ( + _lazy_number_unpickle, + (func, resultclass, number, original_kwargs), + ) return proxy @@ -146,7 +172,9 @@ def ngettext_lazy(singular, plural, number=None): def npgettext_lazy(context, singular, plural, number=None): - return lazy_number(npgettext, str, context=context, singular=singular, plural=plural, number=number) + return lazy_number( + npgettext, str, context=context, singular=singular, plural=plural, number=number + ) def activate(language): @@ -192,27 +220,27 @@ def check_for_language(lang_code): def to_language(locale): """Turn a locale name (en_US) into a language name (en-us).""" - p = locale.find('_') + p = locale.find("_") if p >= 0: - return locale[:p].lower() + '-' + locale[p + 1:].lower() + return locale[:p].lower() + "-" + locale[p + 1 :].lower() else: return locale.lower() def to_locale(language): """Turn a language name (en-us) into a locale name (en_US).""" - lang, _, country = language.lower().partition('-') + lang, _, country = language.lower().partition("-") if not country: return language[:3].lower() + language[3:] # A language with > 2 characters after the dash only has its first # character after the dash 
capitalized; e.g. sr-latn becomes sr_Latn. # A language with 2 characters after the dash has both characters # capitalized; e.g. en-us becomes en_US. - country, _, tail = country.partition('-') + country, _, tail = country.partition("-") country = country.title() if len(country) > 2 else country.upper() if tail: - country += '-' + tail - return lang + '_' + country + country += "-" + tail + return lang + "_" + country def get_language_from_request(request, check_path=False): @@ -229,6 +257,7 @@ def get_supported_language_variant(lang_code, *, strict=False): def templatize(src, **kwargs): from .template import templatize + return templatize(src, **kwargs) @@ -238,32 +267,35 @@ def deactivate_all(): def get_language_info(lang_code): from django.conf.locale import LANG_INFO + try: lang_info = LANG_INFO[lang_code] - if 'fallback' in lang_info and 'name' not in lang_info: - info = get_language_info(lang_info['fallback'][0]) + if "fallback" in lang_info and "name" not in lang_info: + info = get_language_info(lang_info["fallback"][0]) else: info = lang_info except KeyError: - if '-' not in lang_code: + if "-" not in lang_code: raise KeyError("Unknown language code %s." % lang_code) - generic_lang_code = lang_code.split('-')[0] + generic_lang_code = lang_code.split("-")[0] try: info = LANG_INFO[generic_lang_code] except KeyError: - raise KeyError("Unknown language code %s and %s." % (lang_code, generic_lang_code)) + raise KeyError( + "Unknown language code %s and %s." 
% (lang_code, generic_lang_code) + ) if info: - info['name_translated'] = gettext_lazy(info['name']) + info["name_translated"] = gettext_lazy(info["name"]) return info -trim_whitespace_re = _lazy_re_compile(r'\s*\n\s*') +trim_whitespace_re = _lazy_re_compile(r"\s*\n\s*") def trim_whitespace(s): - return trim_whitespace_re.sub(' ', s.strip()) + return trim_whitespace_re.sub(" ", s.strip()) def round_away_from_one(value): - return int(Decimal(value - 1).quantize(Decimal('0'), rounding=ROUND_UP)) + 1 + return int(Decimal(value - 1).quantize(Decimal("0"), rounding=ROUND_UP)) + 1 diff --git a/django/utils/translation/reloader.py b/django/utils/translation/reloader.py index d8afa89270..be05ccc860 100644 --- a/django/utils/translation/reloader.py +++ b/django/utils/translation/reloader.py @@ -11,23 +11,24 @@ def watch_for_translation_changes(sender, **kwargs): from django.conf import settings if settings.USE_I18N: - directories = [Path('locale')] + directories = [Path("locale")] directories.extend( - Path(config.path) / 'locale' + Path(config.path) / "locale" for config in apps.get_app_configs() if not is_django_module(config.module) ) directories.extend(Path(p) for p in settings.LOCALE_PATHS) for path in directories: - sender.watch_dir(path, '**/*.mo') + sender.watch_dir(path, "**/*.mo") def translation_file_changed(sender, file_path, **kwargs): """Clear the internal translations cache if a .mo file is modified.""" - if file_path.suffix == '.mo': + if file_path.suffix == ".mo": import gettext from django.utils.translation import trans_real + gettext._translations = {} trans_real._translations = {} trans_real._default = None diff --git a/django/utils/translation/template.py b/django/utils/translation/template.py index 588f538cb2..d7353a3028 100644 --- a/django/utils/translation/template.py +++ b/django/utils/translation/template.py @@ -6,9 +6,9 @@ from django.utils.regex_helper import _lazy_re_compile from . 
import TranslatorCommentWarning, trim_whitespace -TRANSLATOR_COMMENT_MARK = 'Translators' +TRANSLATOR_COMMENT_MARK = "Translators" -dot_re = _lazy_re_compile(r'\S') +dot_re = _lazy_re_compile(r"\S") def blankout(src, char): @@ -28,7 +28,9 @@ inline_re = _lazy_re_compile( # Match the optional context part r"""(\s+.*context\s+((?:"[^"]*?")|(?:'[^']*?')))?\s*""" ) -block_re = _lazy_re_compile(r"""^\s*blocktrans(?:late)?(\s+.*context\s+((?:"[^"]*?")|(?:'[^']*?')))?(?:\s+|$)""") +block_re = _lazy_re_compile( + r"""^\s*blocktrans(?:late)?(\s+.*context\s+((?:"[^"]*?")|(?:'[^']*?')))?(?:\s+|$)""" +) endblock_re = _lazy_re_compile(r"""^\s*endblocktrans(?:late)?$""") plural_re = _lazy_re_compile(r"""^\s*plural$""") constant_re = _lazy_re_compile(r"""_\(((?:".*?")|(?:'.*?'))\)""") @@ -40,7 +42,7 @@ def templatize(src, origin=None): does so by translating the Django translation tags into standard gettext function invocations. """ - out = StringIO('') + out = StringIO("") message_context = None intrans = False inplural = False @@ -52,27 +54,30 @@ def templatize(src, origin=None): lineno_comment_map = {} comment_lineno_cache = None # Adding the u prefix allows gettext to recognize the string (#26093). 
- raw_prefix = 'u' + raw_prefix = "u" def join_tokens(tokens, trim=False): - message = ''.join(tokens) + message = "".join(tokens) if trim: message = trim_whitespace(message) return message for t in Lexer(src).tokenize(): if incomment: - if t.token_type == TokenType.BLOCK and t.contents == 'endcomment': - content = ''.join(comment) + if t.token_type == TokenType.BLOCK and t.contents == "endcomment": + content = "".join(comment) translators_comment_start = None for lineno, line in enumerate(content.splitlines(True)): if line.lstrip().startswith(TRANSLATOR_COMMENT_MARK): translators_comment_start = lineno for lineno, line in enumerate(content.splitlines(True)): - if translators_comment_start is not None and lineno >= translators_comment_start: - out.write(' # %s' % line) + if ( + translators_comment_start is not None + and lineno >= translators_comment_start + ): + out.write(" # %s" % line) else: - out.write(' #\n') + out.write(" #\n") incomment = False comment = [] else: @@ -84,36 +89,44 @@ def templatize(src, origin=None): if endbmatch: if inplural: if message_context: - out.write(' npgettext({p}{!r}, {p}{!r}, {p}{!r},count) '.format( - message_context, - join_tokens(singular, trimmed), - join_tokens(plural, trimmed), - p=raw_prefix, - )) + out.write( + " npgettext({p}{!r}, {p}{!r}, {p}{!r},count) ".format( + message_context, + join_tokens(singular, trimmed), + join_tokens(plural, trimmed), + p=raw_prefix, + ) + ) else: - out.write(' ngettext({p}{!r}, {p}{!r}, count) '.format( - join_tokens(singular, trimmed), - join_tokens(plural, trimmed), - p=raw_prefix, - )) + out.write( + " ngettext({p}{!r}, {p}{!r}, count) ".format( + join_tokens(singular, trimmed), + join_tokens(plural, trimmed), + p=raw_prefix, + ) + ) for part in singular: - out.write(blankout(part, 'S')) + out.write(blankout(part, "S")) for part in plural: - out.write(blankout(part, 'P')) + out.write(blankout(part, "P")) else: if message_context: - out.write(' pgettext({p}{!r}, {p}{!r}) '.format( - 
message_context, - join_tokens(singular, trimmed), - p=raw_prefix, - )) + out.write( + " pgettext({p}{!r}, {p}{!r}) ".format( + message_context, + join_tokens(singular, trimmed), + p=raw_prefix, + ) + ) else: - out.write(' gettext({p}{!r}) '.format( - join_tokens(singular, trimmed), - p=raw_prefix, - )) + out.write( + " gettext({p}{!r}) ".format( + join_tokens(singular, trimmed), + p=raw_prefix, + ) + ) for part in singular: - out.write(blankout(part, 'S')) + out.write(blankout(part, "S")) message_context = None intrans = False inplural = False @@ -122,20 +135,20 @@ def templatize(src, origin=None): elif pluralmatch: inplural = True else: - filemsg = '' + filemsg = "" if origin: - filemsg = 'file %s, ' % origin + filemsg = "file %s, " % origin raise SyntaxError( "Translation blocks must not include other block tags: " "%s (%sline %d)" % (t.contents, filemsg, t.lineno) ) elif t.token_type == TokenType.VAR: if inplural: - plural.append('%%(%s)s' % t.contents) + plural.append("%%(%s)s" % t.contents) else: - singular.append('%%(%s)s' % t.contents) + singular.append("%%(%s)s" % t.contents) elif t.token_type == TokenType.TEXT: - contents = t.contents.replace('%', '%%') + contents = t.contents.replace("%", "%%") if inplural: plural.append(contents) else: @@ -144,13 +157,13 @@ def templatize(src, origin=None): # Handle comment tokens (`{# ... 
#}`) plus other constructs on # the same line: if comment_lineno_cache is not None: - cur_lineno = t.lineno + t.contents.count('\n') + cur_lineno = t.lineno + t.contents.count("\n") if comment_lineno_cache == cur_lineno: if t.token_type != TokenType.COMMENT: for c in lineno_comment_map[comment_lineno_cache]: - filemsg = '' + filemsg = "" if origin: - filemsg = 'file %s, ' % origin + filemsg = "file %s, " % origin warn_msg = ( "The translator-targeted comment '%s' " "(%sline %d) was ignored, because it wasn't " @@ -159,7 +172,9 @@ def templatize(src, origin=None): warnings.warn(warn_msg, TranslatorCommentWarning) lineno_comment_map[comment_lineno_cache] = [] else: - out.write('# %s' % ' | '.join(lineno_comment_map[comment_lineno_cache])) + out.write( + "# %s" % " | ".join(lineno_comment_map[comment_lineno_cache]) + ) comment_lineno_cache = None if t.token_type == TokenType.BLOCK: @@ -172,7 +187,7 @@ def templatize(src, origin=None): g = g.strip('"') elif g[0] == "'": g = g.strip("'") - g = g.replace('%', '%%') + g = g.replace("%", "%%") if imatch[2]: # A context is provided context_match = context_re.match(imatch[2]) @@ -181,15 +196,17 @@ def templatize(src, origin=None): message_context = message_context.strip('"') elif message_context[0] == "'": message_context = message_context.strip("'") - out.write(' pgettext({p}{!r}, {p}{!r}) '.format( - message_context, g, p=raw_prefix - )) + out.write( + " pgettext({p}{!r}, {p}{!r}) ".format( + message_context, g, p=raw_prefix + ) + ) message_context = None else: - out.write(' gettext({p}{!r}) '.format(g, p=raw_prefix)) + out.write(" gettext({p}{!r}) ".format(g, p=raw_prefix)) elif bmatch: for fmatch in constant_re.findall(t.contents): - out.write(' _(%s) ' % fmatch) + out.write(" _(%s) " % fmatch) if bmatch[1]: # A context is provided context_match = context_re.match(bmatch[1]) @@ -200,30 +217,30 @@ def templatize(src, origin=None): message_context = message_context.strip("'") intrans = True inplural = False - trimmed = 
'trimmed' in t.split_contents() + trimmed = "trimmed" in t.split_contents() singular = [] plural = [] elif cmatches: for cmatch in cmatches: - out.write(' _(%s) ' % cmatch) - elif t.contents == 'comment': + out.write(" _(%s) " % cmatch) + elif t.contents == "comment": incomment = True else: - out.write(blankout(t.contents, 'B')) + out.write(blankout(t.contents, "B")) elif t.token_type == TokenType.VAR: - parts = t.contents.split('|') + parts = t.contents.split("|") cmatch = constant_re.match(parts[0]) if cmatch: - out.write(' _(%s) ' % cmatch[1]) + out.write(" _(%s) " % cmatch[1]) for p in parts[1:]: - if p.find(':_(') >= 0: - out.write(' %s ' % p.split(':', 1)[1]) + if p.find(":_(") >= 0: + out.write(" %s " % p.split(":", 1)[1]) else: - out.write(blankout(p, 'F')) + out.write(blankout(p, "F")) elif t.token_type == TokenType.COMMENT: if t.contents.lstrip().startswith(TRANSLATOR_COMMENT_MARK): lineno_comment_map.setdefault(t.lineno, []).append(t.contents) comment_lineno_cache = t.lineno else: - out.write(blankout(t.contents, 'X')) + out.write(blankout(t.contents, "X")) return out.getvalue() diff --git a/django/utils/translation/trans_real.py b/django/utils/translation/trans_real.py index 97efd40c6d..52110af83a 100644 --- a/django/utils/translation/trans_real.py +++ b/django/utils/translation/trans_real.py @@ -32,18 +32,20 @@ CONTEXT_SEPARATOR = "\x04" # Format of Accept-Language header values. From RFC 2616, section 14.4 and 3.9 # and RFC 3066, section 2.1 -accept_language_re = _lazy_re_compile(r''' +accept_language_re = _lazy_re_compile( + r""" ([A-Za-z]{1,8}(?:-[A-Za-z0-9]{1,8})*|\*) # "en", "en-au", "x-y-z", "es-419", "*" (?:\s*;\s*q=(0(?:\.[0-9]{,3})?|1(?:\.0{,3})?))? # Optional "q=1.00", "q=0.8" (?:\s*,\s*|$) # Multiple accepts per header. 
- ''', re.VERBOSE) - -language_code_re = _lazy_re_compile( - r'^[a-z]{1,8}(?:-[a-z0-9]{1,8})*(?:@[a-z0-9]{1,20})?$', - re.IGNORECASE + """, + re.VERBOSE, ) -language_code_prefix_re = _lazy_re_compile(r'^/(\w+([@-]\w+){0,2})(/|$)') +language_code_re = _lazy_re_compile( + r"^[a-z]{1,8}(?:-[a-z0-9]{1,8})*(?:@[a-z0-9]{1,20})?$", re.IGNORECASE +) + +language_code_prefix_re = _lazy_re_compile(r"^/(\w+([@-]\w+){0,2})(/|$)") @receiver(setting_changed) @@ -52,7 +54,7 @@ def reset_cache(*, setting, **kwargs): Reset global state when LANGUAGES setting has been changed, as some languages should no longer be accepted. """ - if setting in ('LANGUAGES', 'LANGUAGE_CODE'): + if setting in ("LANGUAGES", "LANGUAGE_CODE"): check_for_language.cache_clear() get_languages.cache_clear() get_supported_language_variant.cache_clear() @@ -63,6 +65,7 @@ class TranslationCatalog: Simulate a dict for DjangoTranslation._catalog so as multiple catalogs with different plural equations are kept separate. """ + def __init__(self, trans=None): self._catalogs = [trans._catalog.copy()] if trans else [{}] self._plurals = [trans.plural] if trans else [lambda n: int(n != 1)] @@ -124,7 +127,8 @@ class DjangoTranslation(gettext_module.GNUTranslations): requested language and add a fallback to the default language, if it's different from the requested language. """ - domain = 'django' + + domain = "django" def __init__(self, language, domain=None, localedirs=None): """Create a GNUTranslations() using many locale directories""" @@ -140,10 +144,12 @@ class DjangoTranslation(gettext_module.GNUTranslations): # pluralization: anything except one is pluralized. 
self.plural = lambda n: int(n != 1) - if self.domain == 'django': + if self.domain == "django": if localedirs is not None: # A module-level cache is used for caching 'django' translations - warnings.warn("localedirs is ignored when domain is 'django'.", RuntimeWarning) + warnings.warn( + "localedirs is ignored when domain is 'django'.", RuntimeWarning + ) localedirs = None self._init_translation_catalog() @@ -155,9 +161,16 @@ class DjangoTranslation(gettext_module.GNUTranslations): self._add_installed_apps_translations() self._add_local_translations() - if self.__language == settings.LANGUAGE_CODE and self.domain == 'django' and self._catalog is None: + if ( + self.__language == settings.LANGUAGE_CODE + and self.domain == "django" + and self._catalog is None + ): # default lang should have at least one translation file available. - raise OSError('No translation files found for default language %s.' % settings.LANGUAGE_CODE) + raise OSError( + "No translation files found for default language %s." + % settings.LANGUAGE_CODE + ) self._add_fallback(localedirs) if self._catalog is None: # No catalogs found for this language, set an empty catalog. @@ -184,7 +197,7 @@ class DjangoTranslation(gettext_module.GNUTranslations): def _init_translation_catalog(self): """Create a base catalog using global django translations.""" settingsfile = sys.modules[settings.__module__].__file__ - localedir = os.path.join(os.path.dirname(settingsfile), 'locale') + localedir = os.path.join(os.path.dirname(settingsfile), "locale") translation = self._new_gnu_trans(localedir) self.merge(translation) @@ -196,9 +209,10 @@ class DjangoTranslation(gettext_module.GNUTranslations): raise AppRegistryNotReady( "The translation infrastructure cannot be initialized before the " "apps registry is ready. Check that you don't make non-lazy " - "gettext calls at import time.") + "gettext calls at import time." 
+ ) for app_config in app_configs: - localedir = os.path.join(app_config.path, 'locale') + localedir = os.path.join(app_config.path, "locale") if os.path.exists(localedir): translation = self._new_gnu_trans(localedir) self.merge(translation) @@ -213,9 +227,11 @@ class DjangoTranslation(gettext_module.GNUTranslations): """Set the GNUTranslations() fallback with the default language.""" # Don't set a fallback for the default language or any English variant # (as it's empty, so it'll ALWAYS fall back to the default language) - if self.__language == settings.LANGUAGE_CODE or self.__language.startswith('en'): + if self.__language == settings.LANGUAGE_CODE or self.__language.startswith( + "en" + ): return - if self.domain == 'django': + if self.domain == "django": # Get from cache default_translation = translation(settings.LANGUAGE_CODE) else: @@ -226,7 +242,7 @@ class DjangoTranslation(gettext_module.GNUTranslations): def merge(self, other): """Merge another translation into this catalog.""" - if not getattr(other, '_catalog', None): + if not getattr(other, "_catalog", None): return # NullTranslations() has no _catalog if self._catalog is None: # Take plural and _info from first catalog found (generally Django's). @@ -321,7 +337,7 @@ def get_language_bidi(): if lang is None: return False else: - base_lang = get_language().split('-')[0] + base_lang = get_language().split("-")[0] return base_lang in settings.LANGUAGES_BIDI @@ -349,7 +365,7 @@ def gettext(message): """ global _default - eol_message = message.replace('\r\n', '\n').replace('\r', '\n') + eol_message = message.replace("\r\n", "\n").replace("\r", "\n") if eol_message: _default = _default or translation(settings.LANGUAGE_CODE) @@ -359,7 +375,7 @@ def gettext(message): else: # Return an empty value of the corresponding type if an empty message # is given, instead of metadata, which is the default gettext behavior. 
- result = type(message)('') + result = type(message)("") if isinstance(message, SafeData): return mark_safe(result) @@ -404,13 +420,15 @@ def ngettext(singular, plural, number): Return a string of the translation of either the singular or plural, based on the number. """ - return do_ntranslate(singular, plural, number, 'ngettext') + return do_ntranslate(singular, plural, number, "ngettext") def npgettext(context, singular, plural, number): - msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, singular), - "%s%s%s" % (context, CONTEXT_SEPARATOR, plural), - number) + msgs_with_ctxt = ( + "%s%s%s" % (context, CONTEXT_SEPARATOR, singular), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural), + number, + ) result = ngettext(*msgs_with_ctxt) if CONTEXT_SEPARATOR in result: # Translation not found @@ -423,10 +441,11 @@ def all_locale_paths(): Return a list of paths to user-provides languages files. """ globalpath = os.path.join( - os.path.dirname(sys.modules[settings.__module__].__file__), 'locale') + os.path.dirname(sys.modules[settings.__module__].__file__), "locale" + ) app_paths = [] for app_config in apps.get_app_configs(): - locale_path = os.path.join(app_config.path, 'locale') + locale_path = os.path.join(app_config.path, "locale") if os.path.exists(locale_path): app_paths.append(locale_path) return [globalpath, *settings.LOCALE_PATHS, *app_paths] @@ -447,7 +466,7 @@ def check_for_language(lang_code): if lang_code is None or not language_code_re.search(lang_code): return False return any( - gettext_module.find('django', path, [to_locale(lang_code)]) is not None + gettext_module.find("django", path, [to_locale(lang_code)]) is not None for path in all_locale_paths() ) @@ -478,11 +497,11 @@ def get_supported_language_variant(lang_code, strict=False): # language codes i.e. 'zh-hant' and 'zh'. 
possible_lang_codes = [lang_code] try: - possible_lang_codes.extend(LANG_INFO[lang_code]['fallback']) + possible_lang_codes.extend(LANG_INFO[lang_code]["fallback"]) except KeyError: pass i = None - while (i := lang_code.rfind('-', 0, i)) > -1: + while (i := lang_code.rfind("-", 0, i)) > -1: possible_lang_codes.append(lang_code[:i]) generic_lang_code = possible_lang_codes[-1] supported_lang_codes = get_languages() @@ -493,7 +512,7 @@ def get_supported_language_variant(lang_code, strict=False): if not strict: # if fr-fr is not supported, try fr-ca. for supported_code in supported_lang_codes: - if supported_code.startswith(generic_lang_code + '-'): + if supported_code.startswith(generic_lang_code + "-"): return supported_code raise LookupError(lang_code) @@ -531,7 +550,11 @@ def get_language_from_request(request, check_path=False): return lang_code lang_code = request.COOKIES.get(settings.LANGUAGE_COOKIE_NAME) - if lang_code is not None and lang_code in get_languages() and check_for_language(lang_code): + if ( + lang_code is not None + and lang_code in get_languages() + and check_for_language(lang_code) + ): return lang_code try: @@ -539,9 +562,9 @@ def get_language_from_request(request, check_path=False): except LookupError: pass - accept = request.META.get('HTTP_ACCEPT_LANGUAGE', '') + accept = request.META.get("HTTP_ACCEPT_LANGUAGE", "") for accept_lang, unused in parse_accept_lang_header(accept): - if accept_lang == '*': + if accept_lang == "*": break if not language_code_re.search(accept_lang): @@ -571,7 +594,7 @@ def parse_accept_lang_header(lang_string): if pieces[-1]: return () for i in range(0, len(pieces) - 1, 3): - first, lang, priority = pieces[i:i + 3] + first, lang, priority = pieces[i : i + 3] if first: return () if priority: diff --git a/django/utils/tree.py b/django/utils/tree.py index a56442c32d..f67c90eae4 100644 --- a/django/utils/tree.py +++ b/django/utils/tree.py @@ -14,9 +14,10 @@ class Node: connection (the root) with the children being either 
leaf nodes or other Node instances. """ + # Standard connector type. Clients usually won't use this at all and # subclasses will usually override the value. - default = 'DEFAULT' + default = "DEFAULT" def __init__(self, children=None, connector=None, negated=False): """Construct a new Node. If no connector is given, use the default.""" @@ -41,8 +42,8 @@ class Node: return obj def __str__(self): - template = '(NOT (%s: %s))' if self.negated else '(%s: %s)' - return template % (self.connector, ', '.join(str(c) for c in self.children)) + template = "(NOT (%s: %s))" if self.negated else "(%s: %s)" + return template % (self.connector, ", ".join(str(c) for c in self.children)) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self) @@ -67,14 +68,21 @@ class Node: def __eq__(self, other): return ( - self.__class__ == other.__class__ and - self.connector == other.connector and - self.negated == other.negated and - self.children == other.children + self.__class__ == other.__class__ + and self.connector == other.connector + and self.negated == other.negated + and self.children == other.children ) def __hash__(self): - return hash((self.__class__, self.connector, self.negated, *make_hashable(self.children))) + return hash( + ( + self.__class__, + self.connector, + self.negated, + *make_hashable(self.children), + ) + ) def add(self, data, conn_type): """ @@ -94,9 +102,9 @@ class Node: self.children = [obj, data] return data elif ( - isinstance(data, Node) and - not data.negated and - (data.connector == conn_type or len(data) == 1) + isinstance(data, Node) + and not data.negated + and (data.connector == conn_type or len(data) == 1) ): # We can squash the other node's children directly into this node. 
# We are just doing (AB)(CD) == (ABCD) here, with the addition that diff --git a/django/utils/version.py b/django/utils/version.py index 7f4e9e3cce..1e20a86563 100644 --- a/django/utils/version.py +++ b/django/utils/version.py @@ -28,14 +28,14 @@ def get_version(version=None): main = get_main_version(version) - sub = '' - if version[3] == 'alpha' and version[4] == 0: + sub = "" + if version[3] == "alpha" and version[4] == 0: git_changeset = get_git_changeset() if git_changeset: - sub = '.dev%s' % git_changeset + sub = ".dev%s" % git_changeset - elif version[3] != 'final': - mapping = {'alpha': 'a', 'beta': 'b', 'rc': 'rc'} + elif version[3] != "final": + mapping = {"alpha": "a", "beta": "b", "rc": "rc"} sub = mapping[version[3]] + str(version[4]) return main + sub @@ -45,7 +45,7 @@ def get_main_version(version=None): """Return main version (X.Y[.Z]) from VERSION.""" version = get_complete_version(version) parts = 2 if version[2] == 0 else 3 - return '.'.join(str(x) for x in version[:parts]) + return ".".join(str(x) for x in version[:parts]) def get_complete_version(version=None): @@ -57,17 +57,17 @@ def get_complete_version(version=None): from django import VERSION as version else: assert len(version) == 5 - assert version[3] in ('alpha', 'beta', 'rc', 'final') + assert version[3] in ("alpha", "beta", "rc", "final") return version def get_docs_version(version=None): version = get_complete_version(version) - if version[3] != 'final': - return 'dev' + if version[3] != "final": + return "dev" else: - return '%d.%d' % version[:2] + return "%d.%d" % version[:2] @functools.lru_cache @@ -80,12 +80,15 @@ def get_git_changeset(): """ # Repository may not be found if __file__ is undefined, e.g. in a frozen # module. 
- if '__file__' not in globals(): + if "__file__" not in globals(): return None repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) git_log = subprocess.run( - 'git log --pretty=format:%ct --quiet -1 HEAD', - capture_output=True, shell=True, cwd=repo_dir, text=True, + "git log --pretty=format:%ct --quiet -1 HEAD", + capture_output=True, + shell=True, + cwd=repo_dir, + text=True, ) timestamp = git_log.stdout tz = datetime.timezone.utc @@ -93,10 +96,10 @@ def get_git_changeset(): timestamp = datetime.datetime.fromtimestamp(int(timestamp), tz=tz) except ValueError: return None - return timestamp.strftime('%Y%m%d%H%M%S') + return timestamp.strftime("%Y%m%d%H%M%S") -version_component_re = _lazy_re_compile(r'(\d+|[a-z]+|\.)') +version_component_re = _lazy_re_compile(r"(\d+|[a-z]+|\.)") def get_version_tuple(version): @@ -106,7 +109,7 @@ def get_version_tuple(version): """ version_numbers = [] for item in version_component_re.split(version): - if item and item != '.': + if item and item != ".": try: component = int(item) except ValueError: diff --git a/django/utils/xmlutils.py b/django/utils/xmlutils.py index e4607b9865..c3eb3ba6a3 100644 --- a/django/utils/xmlutils.py +++ b/django/utils/xmlutils.py @@ -21,10 +21,12 @@ class SimplerXMLGenerator(XMLGenerator): self.endElement(name) def characters(self, content): - if content and re.search(r'[\x00-\x08\x0B-\x0C\x0E-\x1F]', content): + if content and re.search(r"[\x00-\x08\x0B-\x0C\x0E-\x1F]", content): # Fail loudly when content has control chars (unsupported in XML 1.0) # See https://www.w3.org/International/questions/qa-controls - raise UnserializableContentError("Control characters are not supported in XML 1.0") + raise UnserializableContentError( + "Control characters are not supported in XML 1.0" + ) XMLGenerator.characters(self, content) def startElement(self, name, attrs): diff --git a/django/views/__init__.py b/django/views/__init__.py index 95b0c6b865..1440d43345 100644 --- 
a/django/views/__init__.py +++ b/django/views/__init__.py @@ -1,3 +1,3 @@ from django.views.generic.base import View -__all__ = ['View'] +__all__ = ["View"] diff --git a/django/views/csrf.py b/django/views/csrf.py index f40f776701..ec3c6d5d91 100644 --- a/django/views/csrf.py +++ b/django/views/csrf.py @@ -106,40 +106,46 @@ def csrf_failure(request, reason="", template_name=CSRF_FAILURE_TEMPLATE_NAME): Default view used when request fails CSRF protection """ from django.middleware.csrf import REASON_NO_CSRF_COOKIE, REASON_NO_REFERER + c = { - 'title': _("Forbidden"), - 'main': _("CSRF verification failed. Request aborted."), - 'reason': reason, - 'no_referer': reason == REASON_NO_REFERER, - 'no_referer1': _( - 'You are seeing this message because this HTTPS site requires a ' - '“Referer header” to be sent by your web browser, but none was ' - 'sent. This header is required for security reasons, to ensure ' - 'that your browser is not being hijacked by third parties.'), - 'no_referer2': _( - 'If you have configured your browser to disable “Referer” headers, ' - 'please re-enable them, at least for this site, or for HTTPS ' - 'connections, or for “same-origin” requests.'), - 'no_referer3': _( + "title": _("Forbidden"), + "main": _("CSRF verification failed. Request aborted."), + "reason": reason, + "no_referer": reason == REASON_NO_REFERER, + "no_referer1": _( + "You are seeing this message because this HTTPS site requires a " + "“Referer header” to be sent by your web browser, but none was " + "sent. This header is required for security reasons, to ensure " + "that your browser is not being hijacked by third parties." + ), + "no_referer2": _( + "If you have configured your browser to disable “Referer” headers, " + "please re-enable them, at least for this site, or for HTTPS " + "connections, or for “same-origin” requests." 
+ ), + "no_referer3": _( 'If you are using the <meta name="referrer" ' - 'content=\"no-referrer\"> tag or including the “Referrer-Policy: ' - 'no-referrer” header, please remove them. The CSRF protection ' - 'requires the “Referer” header to do strict referer checking. If ' - 'you’re concerned about privacy, use alternatives like ' - '<a rel=\"noreferrer\" …> for links to third-party sites.'), - 'no_cookie': reason == REASON_NO_CSRF_COOKIE, - 'no_cookie1': _( + 'content="no-referrer"> tag or including the “Referrer-Policy: ' + "no-referrer” header, please remove them. The CSRF protection " + "requires the “Referer” header to do strict referer checking. If " + "you’re concerned about privacy, use alternatives like " + '<a rel="noreferrer" …> for links to third-party sites.' + ), + "no_cookie": reason == REASON_NO_CSRF_COOKIE, + "no_cookie1": _( "You are seeing this message because this site requires a CSRF " "cookie when submitting forms. This cookie is required for " "security reasons, to ensure that your browser is not being " - "hijacked by third parties."), - 'no_cookie2': _( - 'If you have configured your browser to disable cookies, please ' - 're-enable them, at least for this site, or for “same-origin” ' - 'requests.'), - 'DEBUG': settings.DEBUG, - 'docs_version': get_docs_version(), - 'more': _("More information is available with DEBUG=True."), + "hijacked by third parties." + ), + "no_cookie2": _( + "If you have configured your browser to disable cookies, please " + "re-enable them, at least for this site, or for “same-origin” " + "requests." + ), + "DEBUG": settings.DEBUG, + "docs_version": get_docs_version(), + "more": _("More information is available with DEBUG=True."), } try: t = loader.get_template(template_name) @@ -151,4 +157,4 @@ def csrf_failure(request, reason="", template_name=CSRF_FAILURE_TEMPLATE_NAME): else: # Raise if a developer-specified template doesn't exist. 
raise - return HttpResponseForbidden(t.render(c), content_type='text/html') + return HttpResponseForbidden(t.render(c), content_type="text/html") diff --git a/django/views/debug.py b/django/views/debug.py index 34bd57b33d..e30fd65095 100644 --- a/django/views/debug.py +++ b/django/views/debug.py @@ -23,7 +23,7 @@ from django.utils.version import get_docs_version # works even if the template loader is broken. DEBUG_ENGINE = Engine( debug=True, - libraries={'i18n': 'django.templatetags.i18n'}, + libraries={"i18n": "django.templatetags.i18n"}, ) @@ -34,7 +34,7 @@ def builtin_template_path(name): Avoid calling this function at the module level or in a class-definition because __file__ may not exist, e.g. in frozen environments. """ - return Path(__file__).parent / 'templates' / name + return Path(__file__).parent / "templates" / name class ExceptionCycleWarning(UserWarning): @@ -48,6 +48,7 @@ class CallableSettingWrapper: * Not to break the debug page if the callable forbidding to set attributes (#23070). """ + def __init__(self, callable_setting): self._wrapped = callable_setting @@ -61,12 +62,14 @@ def technical_500_response(request, exc_type, exc_value, tb, status_code=500): the values returned from sys.exc_info() and friends. 
""" reporter = get_exception_reporter_class(request)(request, exc_type, exc_value, tb) - if request.accepts('text/html'): + if request.accepts("text/html"): html = reporter.get_traceback_html() - return HttpResponse(html, status=status_code, content_type='text/html') + return HttpResponse(html, status=status_code, content_type="text/html") else: text = reporter.get_traceback_text() - return HttpResponse(text, status=status_code, content_type='text/plain; charset=utf-8') + return HttpResponse( + text, status=status_code, content_type="text/plain; charset=utf-8" + ) @functools.lru_cache @@ -77,12 +80,16 @@ def get_default_exception_reporter_filter(): def get_exception_reporter_filter(request): default_filter = get_default_exception_reporter_filter() - return getattr(request, 'exception_reporter_filter', default_filter) + return getattr(request, "exception_reporter_filter", default_filter) def get_exception_reporter_class(request): - default_exception_reporter_class = import_string(settings.DEFAULT_EXCEPTION_REPORTER) - return getattr(request, 'exception_reporter_class', default_exception_reporter_class) + default_exception_reporter_class = import_string( + settings.DEFAULT_EXCEPTION_REPORTER + ) + return getattr( + request, "exception_reporter_class", default_exception_reporter_class + ) def get_caller(request): @@ -92,7 +99,7 @@ def get_caller(request): resolver_match = resolve(request.path) except Http404: pass - return '' if resolver_match is None else resolver_match._func_path + return "" if resolver_match is None else resolver_match._func_path class SafeExceptionReporterFilter: @@ -100,8 +107,11 @@ class SafeExceptionReporterFilter: Use annotations made by the sensitive_post_parameters and sensitive_variables decorators to filter out sensitive information. 
""" - cleansed_substitute = '********************' - hidden_settings = _lazy_re_compile('API|TOKEN|KEY|SECRET|PASS|SIGNATURE', flags=re.I) + + cleansed_substitute = "********************" + hidden_settings = _lazy_re_compile( + "API|TOKEN|KEY|SECRET|PASS|SIGNATURE", flags=re.I + ) def cleanse_setting(self, key, value): """ @@ -118,9 +128,9 @@ class SafeExceptionReporterFilter: elif isinstance(value, dict): cleansed = {k: self.cleanse_setting(k, v) for k, v in value.items()} elif isinstance(value, list): - cleansed = [self.cleanse_setting('', v) for v in value] + cleansed = [self.cleanse_setting("", v) for v in value] elif isinstance(value, tuple): - cleansed = tuple([self.cleanse_setting('', v) for v in value]) + cleansed = tuple([self.cleanse_setting("", v) for v in value]) else: cleansed = value @@ -144,7 +154,7 @@ class SafeExceptionReporterFilter: """ Return a dictionary of request.META with sensitive values redacted. """ - if not hasattr(request, 'META'): + if not hasattr(request, "META"): return {} return {k: self.cleanse_setting(k, v) for k, v in request.META.items()} @@ -163,7 +173,7 @@ class SafeExceptionReporterFilter: This mitigates leaking sensitive POST parameters if something like request.POST['nonexistent_key'] throws an exception (#21098). 
""" - sensitive_post_parameters = getattr(request, 'sensitive_post_parameters', []) + sensitive_post_parameters = getattr(request, "sensitive_post_parameters", []) if self.is_active(request) and sensitive_post_parameters: multivaluedict = multivaluedict.copy() for param in sensitive_post_parameters: @@ -179,10 +189,12 @@ class SafeExceptionReporterFilter: if request is None: return {} else: - sensitive_post_parameters = getattr(request, 'sensitive_post_parameters', []) + sensitive_post_parameters = getattr( + request, "sensitive_post_parameters", [] + ) if self.is_active(request) and sensitive_post_parameters: cleansed = request.POST.copy() - if sensitive_post_parameters == '__ALL__': + if sensitive_post_parameters == "__ALL__": # Cleanse all parameters. for k in cleansed: cleansed[k] = self.cleansed_substitute @@ -203,7 +215,7 @@ class SafeExceptionReporterFilter: # MultiValueDicts will have a return value. is_multivalue_dict = isinstance(value, MultiValueDict) except Exception as e: - return '{!r} while evaluating {!r}'.format(e, value) + return "{!r} while evaluating {!r}".format(e, value) if is_multivalue_dict: # Cleanse MultiValueDicts (request.POST is the one we usually care about) @@ -220,18 +232,20 @@ class SafeExceptionReporterFilter: current_frame = tb_frame.f_back sensitive_variables = None while current_frame is not None: - if (current_frame.f_code.co_name == 'sensitive_variables_wrapper' and - 'sensitive_variables_wrapper' in current_frame.f_locals): + if ( + current_frame.f_code.co_name == "sensitive_variables_wrapper" + and "sensitive_variables_wrapper" in current_frame.f_locals + ): # The sensitive_variables decorator was used, so we take note # of the sensitive variables' names. 
- wrapper = current_frame.f_locals['sensitive_variables_wrapper'] - sensitive_variables = getattr(wrapper, 'sensitive_variables', None) + wrapper = current_frame.f_locals["sensitive_variables_wrapper"] + sensitive_variables = getattr(wrapper, "sensitive_variables", None) break current_frame = current_frame.f_back cleansed = {} if self.is_active(request) and sensitive_variables: - if sensitive_variables == '__ALL__': + if sensitive_variables == "__ALL__": # Cleanse all variables for name in tb_frame.f_locals: cleansed[name] = self.cleansed_substitute @@ -249,14 +263,16 @@ class SafeExceptionReporterFilter: for name, value in tb_frame.f_locals.items(): cleansed[name] = self.cleanse_special_types(request, value) - if (tb_frame.f_code.co_name == 'sensitive_variables_wrapper' and - 'sensitive_variables_wrapper' in tb_frame.f_locals): + if ( + tb_frame.f_code.co_name == "sensitive_variables_wrapper" + and "sensitive_variables_wrapper" in tb_frame.f_locals + ): # For good measure, obfuscate the decorated function's arguments in # the sensitive_variables decorator's frame, in case the variables # associated with those arguments were meant to be obfuscated from # the decorated function's frame. 
- cleansed['func_args'] = self.cleansed_substitute - cleansed['func_kwargs'] = self.cleansed_substitute + cleansed["func_args"] = self.cleansed_substitute + cleansed["func_kwargs"] = self.cleansed_substitute return cleansed.items() @@ -266,11 +282,11 @@ class ExceptionReporter: @property def html_template_path(self): - return builtin_template_path('technical_500.html') + return builtin_template_path("technical_500.html") @property def text_template_path(self): - return builtin_template_path('technical_500.txt') + return builtin_template_path("technical_500.txt") def __init__(self, request, exc_type, exc_value, tb, is_email=False): self.request = request @@ -280,7 +296,7 @@ class ExceptionReporter: self.tb = tb self.is_email = is_email - self.template_info = getattr(self.exc_value, 'template_debug', None) + self.template_info = getattr(self.exc_value, "template_debug", None) self.template_does_not_exist = False self.postmortem = None @@ -289,7 +305,7 @@ class ExceptionReporter: Return an absolute URI from variables available in this request. Skip allowed hosts protection, so may return insecure URI. 
""" - return '{scheme}://{host}{path}'.format( + return "{scheme}://{host}{path}".format( scheme=self.request.scheme, host=self.request._get_raw_host(), path=self.request.get_full_path(), @@ -303,26 +319,27 @@ class ExceptionReporter: frames = self.get_traceback_frames() for i, frame in enumerate(frames): - if 'vars' in frame: + if "vars" in frame: frame_vars = [] - for k, v in frame['vars']: + for k, v in frame["vars"]: v = pprint(v) # Trim large blobs of data if len(v) > 4096: - v = '%s… <trimmed %d bytes string>' % (v[0:4096], len(v)) + v = "%s… <trimmed %d bytes string>" % (v[0:4096], len(v)) frame_vars.append((k, v)) - frame['vars'] = frame_vars + frame["vars"] = frame_vars frames[i] = frame - unicode_hint = '' + unicode_hint = "" if self.exc_type and issubclass(self.exc_type, UnicodeError): - start = getattr(self.exc_value, 'start', None) - end = getattr(self.exc_value, 'end', None) + start = getattr(self.exc_value, "start", None) + end = getattr(self.exc_value, "end", None) if start is not None and end is not None: unicode_str = self.exc_value.args[1] unicode_hint = force_str( - unicode_str[max(start - 5, 0):min(end + 5, len(unicode_str))], - 'ascii', errors='replace' + unicode_str[max(start - 5, 0) : min(end + 5, len(unicode_str))], + "ascii", + errors="replace", ) from django import get_version @@ -334,59 +351,61 @@ class ExceptionReporter: except Exception: # request.user may raise OperationalError if the database is # unavailable, for example. 
- user_str = '[unable to retrieve the current user]' + user_str = "[unable to retrieve the current user]" c = { - 'is_email': self.is_email, - 'unicode_hint': unicode_hint, - 'frames': frames, - 'request': self.request, - 'request_meta': self.filter.get_safe_request_meta(self.request), - 'user_str': user_str, - 'filtered_POST_items': list(self.filter.get_post_parameters(self.request).items()), - 'settings': self.filter.get_safe_settings(), - 'sys_executable': sys.executable, - 'sys_version_info': '%d.%d.%d' % sys.version_info[0:3], - 'server_time': timezone.now(), - 'django_version_info': get_version(), - 'sys_path': sys.path, - 'template_info': self.template_info, - 'template_does_not_exist': self.template_does_not_exist, - 'postmortem': self.postmortem, + "is_email": self.is_email, + "unicode_hint": unicode_hint, + "frames": frames, + "request": self.request, + "request_meta": self.filter.get_safe_request_meta(self.request), + "user_str": user_str, + "filtered_POST_items": list( + self.filter.get_post_parameters(self.request).items() + ), + "settings": self.filter.get_safe_settings(), + "sys_executable": sys.executable, + "sys_version_info": "%d.%d.%d" % sys.version_info[0:3], + "server_time": timezone.now(), + "django_version_info": get_version(), + "sys_path": sys.path, + "template_info": self.template_info, + "template_does_not_exist": self.template_does_not_exist, + "postmortem": self.postmortem, } if self.request is not None: - c['request_GET_items'] = self.request.GET.items() - c['request_FILES_items'] = self.request.FILES.items() - c['request_COOKIES_items'] = self.request.COOKIES.items() - c['request_insecure_uri'] = self._get_raw_insecure_uri() - c['raising_view_name'] = get_caller(self.request) + c["request_GET_items"] = self.request.GET.items() + c["request_FILES_items"] = self.request.FILES.items() + c["request_COOKIES_items"] = self.request.COOKIES.items() + c["request_insecure_uri"] = self._get_raw_insecure_uri() + c["raising_view_name"] = 
get_caller(self.request) # Check whether exception info is available if self.exc_type: - c['exception_type'] = self.exc_type.__name__ + c["exception_type"] = self.exc_type.__name__ if self.exc_value: - c['exception_value'] = str(self.exc_value) + c["exception_value"] = str(self.exc_value) if frames: - c['lastframe'] = frames[-1] + c["lastframe"] = frames[-1] return c def get_traceback_html(self): """Return HTML version of debug 500 HTTP error page.""" - with self.html_template_path.open(encoding='utf-8') as fh: + with self.html_template_path.open(encoding="utf-8") as fh: t = DEBUG_ENGINE.from_string(fh.read()) c = Context(self.get_traceback_data(), use_l10n=False) return t.render(c) def get_traceback_text(self): """Return plain text version of debug 500 HTTP error page.""" - with self.text_template_path.open(encoding='utf-8') as fh: + with self.text_template_path.open(encoding="utf-8") as fh: t = DEBUG_ENGINE.from_string(fh.read()) c = Context(self.get_traceback_data(), autoescape=False, use_l10n=False) return t.render(c) def _get_source(self, filename, loader, module_name): source = None - if hasattr(loader, 'get_source'): + if hasattr(loader, "get_source"): try: source = loader.get_source(module_name) except ImportError: @@ -395,13 +414,15 @@ class ExceptionReporter: source = source.splitlines() if source is None: try: - with open(filename, 'rb') as fp: + with open(filename, "rb") as fp: source = fp.read().splitlines() except OSError: pass return source - def _get_lines_from_file(self, filename, lineno, context_lines, loader=None, module_name=None): + def _get_lines_from_file( + self, filename, lineno, context_lines, loader=None, module_name=None + ): """ Return context_lines before and after lineno from file. Return (pre_context_lineno, pre_context, context_line, post_context). @@ -414,15 +435,15 @@ class ExceptionReporter: # apply tokenize.detect_encoding to decode the source into a # string, then we should do that ourselves. 
if isinstance(source[0], bytes): - encoding = 'ascii' + encoding = "ascii" for line in source[:2]: # File coding may be specified. Match pattern from PEP-263 # (https://www.python.org/dev/peps/pep-0263/) - match = re.search(br'coding[:=]\s*([-\w.]+)', line) + match = re.search(rb"coding[:=]\s*([-\w.]+)", line) if match: - encoding = match[1].decode('ascii') + encoding = match[1].decode("ascii") break - source = [str(sline, encoding, 'replace') for sline in source] + source = [str(sline, encoding, "replace") for sline in source] lower_bound = max(0, lineno - context_lines) upper_bound = lineno + context_lines @@ -430,15 +451,15 @@ class ExceptionReporter: try: pre_context = source[lower_bound:lineno] context_line = source[lineno] - post_context = source[lineno + 1:upper_bound] + post_context = source[lineno + 1 : upper_bound] except IndexError: return None, [], None, [] return lower_bound, pre_context, context_line, post_context def _get_explicit_or_implicit_cause(self, exc_value): - explicit = getattr(exc_value, '__cause__', None) - suppress_context = getattr(exc_value, '__suppress_context__', None) - implicit = getattr(exc_value, '__context__', None) + explicit = getattr(exc_value, "__cause__", None) + suppress_context = getattr(exc_value, "__suppress_context__", None) + implicit = getattr(exc_value, "__context__", None) return explicit or (None if suppress_context else implicit) def get_traceback_frames(self): @@ -476,47 +497,58 @@ class ExceptionReporter: def get_exception_traceback_frames(self, exc_value, tb): exc_cause = self._get_explicit_or_implicit_cause(exc_value) - exc_cause_explicit = getattr(exc_value, '__cause__', True) + exc_cause_explicit = getattr(exc_value, "__cause__", True) if tb is None: yield { - 'exc_cause': exc_cause, - 'exc_cause_explicit': exc_cause_explicit, - 'tb': None, - 'type': 'user', + "exc_cause": exc_cause, + "exc_cause_explicit": exc_cause_explicit, + "tb": None, + "type": "user", } while tb is not None: # Support for 
__traceback_hide__ which is used by a few libraries # to hide internal frames. - if tb.tb_frame.f_locals.get('__traceback_hide__'): + if tb.tb_frame.f_locals.get("__traceback_hide__"): tb = tb.tb_next continue filename = tb.tb_frame.f_code.co_filename function = tb.tb_frame.f_code.co_name lineno = tb.tb_lineno - 1 - loader = tb.tb_frame.f_globals.get('__loader__') - module_name = tb.tb_frame.f_globals.get('__name__') or '' - pre_context_lineno, pre_context, context_line, post_context = self._get_lines_from_file( - filename, lineno, 7, loader, module_name, + loader = tb.tb_frame.f_globals.get("__loader__") + module_name = tb.tb_frame.f_globals.get("__name__") or "" + ( + pre_context_lineno, + pre_context, + context_line, + post_context, + ) = self._get_lines_from_file( + filename, + lineno, + 7, + loader, + module_name, ) if pre_context_lineno is None: pre_context_lineno = lineno pre_context = [] - context_line = '<source code not available>' + context_line = "<source code not available>" post_context = [] yield { - 'exc_cause': exc_cause, - 'exc_cause_explicit': exc_cause_explicit, - 'tb': tb, - 'type': 'django' if module_name.startswith('django.') else 'user', - 'filename': filename, - 'function': function, - 'lineno': lineno + 1, - 'vars': self.filter.get_traceback_frame_variables(self.request, tb.tb_frame), - 'id': id(tb), - 'pre_context': pre_context, - 'context_line': context_line, - 'post_context': post_context, - 'pre_context_lineno': pre_context_lineno + 1, + "exc_cause": exc_cause, + "exc_cause_explicit": exc_cause_explicit, + "tb": tb, + "type": "django" if module_name.startswith("django.") else "user", + "filename": filename, + "function": function, + "lineno": lineno + 1, + "vars": self.filter.get_traceback_frame_variables( + self.request, tb.tb_frame + ), + "id": id(tb), + "pre_context": pre_context, + "context_line": context_line, + "post_context": post_context, + "pre_context_lineno": pre_context_lineno + 1, } tb = tb.tb_next @@ -524,52 +556,58 @@ 
class ExceptionReporter: def technical_404_response(request, exception): """Create a technical 404 error response. `exception` is the Http404.""" try: - error_url = exception.args[0]['path'] + error_url = exception.args[0]["path"] except (IndexError, TypeError, KeyError): error_url = request.path_info[1:] # Trim leading slash try: - tried = exception.args[0]['tried'] + tried = exception.args[0]["tried"] except (IndexError, TypeError, KeyError): resolved = True tried = request.resolver_match.tried if request.resolver_match else None else: resolved = False - if (not tried or ( # empty URLconf - request.path == '/' and - len(tried) == 1 and # default URLconf - len(tried[0]) == 1 and - getattr(tried[0][0], 'app_name', '') == getattr(tried[0][0], 'namespace', '') == 'admin' - )): + if not tried or ( # empty URLconf + request.path == "/" + and len(tried) == 1 + and len(tried[0]) == 1 # default URLconf + and getattr(tried[0][0], "app_name", "") + == getattr(tried[0][0], "namespace", "") + == "admin" + ): return default_urlconf(request) - urlconf = getattr(request, 'urlconf', settings.ROOT_URLCONF) + urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) if isinstance(urlconf, types.ModuleType): urlconf = urlconf.__name__ - with builtin_template_path('technical_404.html').open(encoding='utf-8') as fh: + with builtin_template_path("technical_404.html").open(encoding="utf-8") as fh: t = DEBUG_ENGINE.from_string(fh.read()) reporter_filter = get_default_exception_reporter_filter() - c = Context({ - 'urlconf': urlconf, - 'root_urlconf': settings.ROOT_URLCONF, - 'request_path': error_url, - 'urlpatterns': tried, - 'resolved': resolved, - 'reason': str(exception), - 'request': request, - 'settings': reporter_filter.get_safe_settings(), - 'raising_view_name': get_caller(request), - }) - return HttpResponseNotFound(t.render(c), content_type='text/html') + c = Context( + { + "urlconf": urlconf, + "root_urlconf": settings.ROOT_URLCONF, + "request_path": error_url, + 
"urlpatterns": tried, + "resolved": resolved, + "reason": str(exception), + "request": request, + "settings": reporter_filter.get_safe_settings(), + "raising_view_name": get_caller(request), + } + ) + return HttpResponseNotFound(t.render(c), content_type="text/html") def default_urlconf(request): """Create an empty URLconf 404 error response.""" - with builtin_template_path('default_urlconf.html').open(encoding='utf-8') as fh: + with builtin_template_path("default_urlconf.html").open(encoding="utf-8") as fh: t = DEBUG_ENGINE.from_string(fh.read()) - c = Context({ - 'version': get_docs_version(), - }) + c = Context( + { + "version": get_docs_version(), + } + ) - return HttpResponse(t.render(c), content_type='text/html') + return HttpResponse(t.render(c), content_type="text/html") diff --git a/django/views/decorators/cache.py b/django/views/decorators/cache.py index 417f3614f8..de61b6fe26 100644 --- a/django/views/decorators/cache.py +++ b/django/views/decorators/cache.py @@ -20,7 +20,9 @@ def cache_page(timeout, *, cache=None, key_prefix=None): into account on caching -- just like the middleware does. """ return decorator_from_middleware_with_args(CacheMiddleware)( - page_timeout=timeout, cache_alias=cache, key_prefix=key_prefix, + page_timeout=timeout, + cache_alias=cache, + key_prefix=key_prefix, ) @@ -29,7 +31,7 @@ def cache_control(**kwargs): @wraps(viewfunc) def _cache_controlled(request, *args, **kw): # Ensure argument looks like a request. - if not hasattr(request, 'META'): + if not hasattr(request, "META"): raise TypeError( "cache_control didn't receive an HttpRequest. If you are " "decorating a classmethod, be sure to use " @@ -38,7 +40,9 @@ def cache_control(**kwargs): response = viewfunc(request, *args, **kw) patch_cache_control(response, **kwargs) return response + return _cache_controlled + return _cache_controller @@ -46,10 +50,11 @@ def never_cache(view_func): """ Decorator that adds headers to a response so that it will never be cached. 
""" + @wraps(view_func) def _wrapped_view_func(request, *args, **kwargs): # Ensure argument looks like a request. - if not hasattr(request, 'META'): + if not hasattr(request, "META"): raise TypeError( "never_cache didn't receive an HttpRequest. If you are " "decorating a classmethod, be sure to use @method_decorator." @@ -57,4 +62,5 @@ def never_cache(view_func): response = view_func(request, *args, **kwargs) add_never_cache_headers(response) return response + return _wrapped_view_func diff --git a/django/views/decorators/clickjacking.py b/django/views/decorators/clickjacking.py index f8fc2d2b95..c59f4f2098 100644 --- a/django/views/decorators/clickjacking.py +++ b/django/views/decorators/clickjacking.py @@ -11,11 +11,13 @@ def xframe_options_deny(view_func): def some_view(request): ... """ + def wrapped_view(*args, **kwargs): resp = view_func(*args, **kwargs) - if resp.get('X-Frame-Options') is None: - resp['X-Frame-Options'] = 'DENY' + if resp.get("X-Frame-Options") is None: + resp["X-Frame-Options"] = "DENY" return resp + return wraps(view_func)(wrapped_view) @@ -29,11 +31,13 @@ def xframe_options_sameorigin(view_func): def some_view(request): ... """ + def wrapped_view(*args, **kwargs): resp = view_func(*args, **kwargs) - if resp.get('X-Frame-Options') is None: - resp['X-Frame-Options'] = 'SAMEORIGIN' + if resp.get("X-Frame-Options") is None: + resp["X-Frame-Options"] = "SAMEORIGIN" return resp + return wraps(view_func)(wrapped_view) @@ -46,8 +50,10 @@ def xframe_options_exempt(view_func): def some_view(request): ... 
""" + def wrapped_view(*args, **kwargs): resp = view_func(*args, **kwargs) resp.xframe_options_exempt = True return resp + return wraps(view_func)(wrapped_view) diff --git a/django/views/decorators/common.py b/django/views/decorators/common.py index 34b0e5a50e..8c84688122 100644 --- a/django/views/decorators/common.py +++ b/django/views/decorators/common.py @@ -10,5 +10,6 @@ def no_append_slash(view_func): # nicer if they don't have side effects, so return a new function. def wrapped_view(*args, **kwargs): return view_func(*args, **kwargs) + wrapped_view.should_append_slash = False return wraps(view_func)(wrapped_view) diff --git a/django/views/decorators/csrf.py b/django/views/decorators/csrf.py index 19d439a55a..4841089ca8 100644 --- a/django/views/decorators/csrf.py +++ b/django/views/decorators/csrf.py @@ -19,7 +19,7 @@ class _EnsureCsrfToken(CsrfViewMiddleware): requires_csrf_token = decorator_from_middleware(_EnsureCsrfToken) -requires_csrf_token.__name__ = 'requires_csrf_token' +requires_csrf_token.__name__ = "requires_csrf_token" requires_csrf_token.__doc__ = """ Use this decorator on views that need a correct csrf_token available to RequestContext, but without the CSRF protection that csrf_protect @@ -39,7 +39,7 @@ class _EnsureCsrfCookie(CsrfViewMiddleware): ensure_csrf_cookie = decorator_from_middleware(_EnsureCsrfCookie) -ensure_csrf_cookie.__name__ = 'ensure_csrf_cookie' +ensure_csrf_cookie.__name__ = "ensure_csrf_cookie" ensure_csrf_cookie.__doc__ = """ Use this decorator to ensure that a view sets a CSRF cookie, whether or not it uses the csrf_token template tag, or the CsrfViewMiddleware is used. @@ -52,5 +52,6 @@ def csrf_exempt(view_func): # if they don't have side effects, so return a new function. 
def wrapped_view(*args, **kwargs): return view_func(*args, **kwargs) + wrapped_view.csrf_exempt = True return wraps(view_func)(wrapped_view) diff --git a/django/views/decorators/debug.py b/django/views/decorators/debug.py index 312269baba..5cf8db891f 100644 --- a/django/views/decorators/debug.py +++ b/django/views/decorators/debug.py @@ -28,8 +28,8 @@ def sensitive_variables(*variables): """ if len(variables) == 1 and callable(variables[0]): raise TypeError( - 'sensitive_variables() must be called to use it as a decorator, ' - 'e.g., use @sensitive_variables(), not @sensitive_variables.' + "sensitive_variables() must be called to use it as a decorator, " + "e.g., use @sensitive_variables(), not @sensitive_variables." ) def decorator(func): @@ -38,9 +38,11 @@ def sensitive_variables(*variables): if variables: sensitive_variables_wrapper.sensitive_variables = variables else: - sensitive_variables_wrapper.sensitive_variables = '__ALL__' + sensitive_variables_wrapper.sensitive_variables = "__ALL__" return func(*func_args, **func_kwargs) + return sensitive_variables_wrapper + return decorator @@ -69,9 +71,9 @@ def sensitive_post_parameters(*parameters): """ if len(parameters) == 1 and callable(parameters[0]): raise TypeError( - 'sensitive_post_parameters() must be called to use it as a ' - 'decorator, e.g., use @sensitive_post_parameters(), not ' - '@sensitive_post_parameters.' + "sensitive_post_parameters() must be called to use it as a " + "decorator, e.g., use @sensitive_post_parameters(), not " + "@sensitive_post_parameters." 
) def decorator(view): @@ -86,7 +88,9 @@ def sensitive_post_parameters(*parameters): if parameters: request.sensitive_post_parameters = parameters else: - request.sensitive_post_parameters = '__ALL__' + request.sensitive_post_parameters = "__ALL__" return view(request, *args, **kwargs) + return sensitive_post_parameters_wrapper + return decorator diff --git a/django/views/decorators/http.py b/django/views/decorators/http.py index 28da8dd921..6f7578ef31 100644 --- a/django/views/decorators/http.py +++ b/django/views/decorators/http.py @@ -26,19 +26,24 @@ def require_http_methods(request_method_list): Note that request methods should be in uppercase. """ + def decorator(func): @wraps(func) def inner(request, *args, **kwargs): if request.method not in request_method_list: response = HttpResponseNotAllowed(request_method_list) log_response( - 'Method Not Allowed (%s): %s', request.method, request.path, + "Method Not Allowed (%s): %s", + request.method, + request.path, response=response, request=request, ) return response return func(request, *args, **kwargs) + return inner + return decorator @@ -49,7 +54,9 @@ require_POST = require_http_methods(["POST"]) require_POST.__doc__ = "Decorator to require that a view only accepts the POST method." require_safe = require_http_methods(["GET", "HEAD"]) -require_safe.__doc__ = "Decorator to require that a view only accepts safe methods: GET and HEAD." +require_safe.__doc__ = ( + "Decorator to require that a view only accepts safe methods: GET and HEAD." +) def condition(etag_func=None, last_modified_func=None): @@ -74,6 +81,7 @@ def condition(etag_func=None, last_modified_func=None): will add the generated ETag and Last-Modified headers to the response if the headers aren't already set and if the request's method is safe. 
""" + def decorator(func): @wraps(func) def inner(request, *args, **kwargs): @@ -102,15 +110,16 @@ def condition(etag_func=None, last_modified_func=None): # Set relevant headers on the response if they don't already exist # and if the request method is safe. - if request.method in ('GET', 'HEAD'): - if res_last_modified and not response.has_header('Last-Modified'): - response.headers['Last-Modified'] = http_date(res_last_modified) + if request.method in ("GET", "HEAD"): + if res_last_modified and not response.has_header("Last-Modified"): + response.headers["Last-Modified"] = http_date(res_last_modified) if res_etag: - response.headers.setdefault('ETag', res_etag) + response.headers.setdefault("ETag", res_etag) return response return inner + return decorator diff --git a/django/views/decorators/vary.py b/django/views/decorators/vary.py index 68b783ed3a..6098a0fbc0 100644 --- a/django/views/decorators/vary.py +++ b/django/views/decorators/vary.py @@ -14,13 +14,16 @@ def vary_on_headers(*headers): Note that the header names are not case-sensitive. """ + def decorator(func): @wraps(func) def inner_func(*args, **kwargs): response = func(*args, **kwargs) patch_vary_headers(response, headers) return response + return inner_func + return decorator @@ -33,9 +36,11 @@ def vary_on_cookie(func): def index(request): ... 
""" + @wraps(func) def inner_func(*args, **kwargs): response = func(*args, **kwargs) - patch_vary_headers(response, ('Cookie',)) + patch_vary_headers(response, ("Cookie",)) return response + return inner_func diff --git a/django/views/defaults.py b/django/views/defaults.py index e3d2c3dd29..273d5fa7d5 100644 --- a/django/views/defaults.py +++ b/django/views/defaults.py @@ -1,16 +1,18 @@ from urllib.parse import quote from django.http import ( - HttpResponseBadRequest, HttpResponseForbidden, HttpResponseNotFound, + HttpResponseBadRequest, + HttpResponseForbidden, + HttpResponseNotFound, HttpResponseServerError, ) from django.template import Context, Engine, TemplateDoesNotExist, loader from django.views.decorators.csrf import requires_csrf_token -ERROR_404_TEMPLATE_NAME = '404.html' -ERROR_403_TEMPLATE_NAME = '403.html' -ERROR_400_TEMPLATE_NAME = '400.html' -ERROR_500_TEMPLATE_NAME = '500.html' +ERROR_404_TEMPLATE_NAME = "404.html" +ERROR_403_TEMPLATE_NAME = "403.html" +ERROR_400_TEMPLATE_NAME = "400.html" +ERROR_500_TEMPLATE_NAME = "500.html" ERROR_PAGE_TEMPLATE = """ <!doctype html> <html lang="en"> @@ -54,13 +56,13 @@ def page_not_found(request, exception, template_name=ERROR_404_TEMPLATE_NAME): if isinstance(message, str): exception_repr = message context = { - 'request_path': quote(request.path), - 'exception': exception_repr, + "request_path": quote(request.path), + "exception": exception_repr, } try: template = loader.get_template(template_name) body = template.render(context, request) - content_type = None # Django will use 'text/html'. + content_type = None # Django will use 'text/html'. except TemplateDoesNotExist: if template_name != ERROR_404_TEMPLATE_NAME: # Reraise if it's a missing custom template. @@ -68,13 +70,14 @@ def page_not_found(request, exception, template_name=ERROR_404_TEMPLATE_NAME): # Render template (even though there are no substitutions) to allow # inspecting the context in tests. 
template = Engine().from_string( - ERROR_PAGE_TEMPLATE % { - 'title': 'Not Found', - 'details': 'The requested resource was not found on this server.', + ERROR_PAGE_TEMPLATE + % { + "title": "Not Found", + "details": "The requested resource was not found on this server.", }, ) body = template.render(Context(context)) - content_type = 'text/html' + content_type = "text/html" return HttpResponseNotFound(body, content_type=content_type) @@ -93,8 +96,8 @@ def server_error(request, template_name=ERROR_500_TEMPLATE_NAME): # Reraise if it's a missing custom template. raise return HttpResponseServerError( - ERROR_PAGE_TEMPLATE % {'title': 'Server Error (500)', 'details': ''}, - content_type='text/html', + ERROR_PAGE_TEMPLATE % {"title": "Server Error (500)", "details": ""}, + content_type="text/html", ) return HttpResponseServerError(template.render()) @@ -114,8 +117,8 @@ def bad_request(request, exception, template_name=ERROR_400_TEMPLATE_NAME): # Reraise if it's a missing custom template. raise return HttpResponseBadRequest( - ERROR_PAGE_TEMPLATE % {'title': 'Bad Request (400)', 'details': ''}, - content_type='text/html', + ERROR_PAGE_TEMPLATE % {"title": "Bad Request (400)", "details": ""}, + content_type="text/html", ) # No exception content is passed to the template, to not disclose any sensitive information. return HttpResponseBadRequest(template.render()) @@ -142,9 +145,9 @@ def permission_denied(request, exception, template_name=ERROR_403_TEMPLATE_NAME) # Reraise if it's a missing custom template. 
raise return HttpResponseForbidden( - ERROR_PAGE_TEMPLATE % {'title': '403 Forbidden', 'details': ''}, - content_type='text/html', + ERROR_PAGE_TEMPLATE % {"title": "403 Forbidden", "details": ""}, + content_type="text/html", ) return HttpResponseForbidden( - template.render(request=request, context={'exception': str(exception)}) + template.render(request=request, context={"exception": str(exception)}) ) diff --git a/django/views/generic/__init__.py b/django/views/generic/__init__.py index bc32403fc7..8514bae515 100644 --- a/django/views/generic/__init__.py +++ b/django/views/generic/__init__.py @@ -1,22 +1,39 @@ from django.views.generic.base import RedirectView, TemplateView, View from django.views.generic.dates import ( - ArchiveIndexView, DateDetailView, DayArchiveView, MonthArchiveView, - TodayArchiveView, WeekArchiveView, YearArchiveView, + ArchiveIndexView, + DateDetailView, + DayArchiveView, + MonthArchiveView, + TodayArchiveView, + WeekArchiveView, + YearArchiveView, ) from django.views.generic.detail import DetailView -from django.views.generic.edit import ( - CreateView, DeleteView, FormView, UpdateView, -) +from django.views.generic.edit import CreateView, DeleteView, FormView, UpdateView from django.views.generic.list import ListView __all__ = [ - 'View', 'TemplateView', 'RedirectView', 'ArchiveIndexView', - 'YearArchiveView', 'MonthArchiveView', 'WeekArchiveView', 'DayArchiveView', - 'TodayArchiveView', 'DateDetailView', 'DetailView', 'FormView', - 'CreateView', 'UpdateView', 'DeleteView', 'ListView', 'GenericViewError', + "View", + "TemplateView", + "RedirectView", + "ArchiveIndexView", + "YearArchiveView", + "MonthArchiveView", + "WeekArchiveView", + "DayArchiveView", + "TodayArchiveView", + "DateDetailView", + "DetailView", + "FormView", + "CreateView", + "UpdateView", + "DeleteView", + "ListView", + "GenericViewError", ] class GenericViewError(Exception): """A problem in a generic view.""" + pass diff --git a/django/views/generic/base.py 
b/django/views/generic/base.py index ca14156510..d45b1762e6 100644 --- a/django/views/generic/base.py +++ b/django/views/generic/base.py @@ -2,14 +2,17 @@ import logging from django.core.exceptions import ImproperlyConfigured from django.http import ( - HttpResponse, HttpResponseGone, HttpResponseNotAllowed, - HttpResponsePermanentRedirect, HttpResponseRedirect, + HttpResponse, + HttpResponseGone, + HttpResponseNotAllowed, + HttpResponsePermanentRedirect, + HttpResponseRedirect, ) from django.template.response import TemplateResponse from django.urls import reverse from django.utils.decorators import classonlymethod -logger = logging.getLogger('django.request') +logger = logging.getLogger("django.request") class ContextMixin: @@ -17,10 +20,11 @@ class ContextMixin: A default context mixin that passes the keyword arguments received by get_context_data() as the template context. """ + extra_context = None def get_context_data(self, **kwargs): - kwargs.setdefault('view', self) + kwargs.setdefault("view", self) if self.extra_context is not None: kwargs.update(self.extra_context) return kwargs @@ -32,7 +36,16 @@ class View: dispatch-by-method and simple sanity checking. """ - http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace'] + http_method_names = [ + "get", + "post", + "put", + "patch", + "delete", + "head", + "options", + "trace", + ] def __init__(self, **kwargs): """ @@ -50,23 +63,26 @@ class View: for key in initkwargs: if key in cls.http_method_names: raise TypeError( - 'The method name %s is not accepted as a keyword argument ' - 'to %s().' % (key, cls.__name__) + "The method name %s is not accepted as a keyword argument " + "to %s()." % (key, cls.__name__) ) if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r. as_view " - "only accepts arguments that are already " - "attributes of the class." % (cls.__name__, key)) + raise TypeError( + "%s() received an invalid keyword %r. 
as_view " + "only accepts arguments that are already " + "attributes of the class." % (cls.__name__, key) + ) def view(request, *args, **kwargs): self = cls(**initkwargs) self.setup(request, *args, **kwargs) - if not hasattr(self, 'request'): + if not hasattr(self, "request"): raise AttributeError( "%s instance has no 'request' attribute. Did you override " "setup() and forget to call super()?" % cls.__name__ ) return self.dispatch(request, *args, **kwargs) + view.view_class = cls view.view_initkwargs = initkwargs @@ -84,7 +100,7 @@ class View: def setup(self, request, *args, **kwargs): """Initialize attributes shared by all view methods.""" - if hasattr(self, 'get') and not hasattr(self, 'head'): + if hasattr(self, "get") and not hasattr(self, "head"): self.head = self.get self.request = request self.args = args @@ -95,23 +111,27 @@ class View: # defer to the error handler. Also defer to the error handler if the # request method isn't on the approved list. if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), self.http_method_not_allowed) + handler = getattr( + self, request.method.lower(), self.http_method_not_allowed + ) else: handler = self.http_method_not_allowed return handler(request, *args, **kwargs) def http_method_not_allowed(self, request, *args, **kwargs): logger.warning( - 'Method Not Allowed (%s): %s', request.method, request.path, - extra={'status_code': 405, 'request': request} + "Method Not Allowed (%s): %s", + request.method, + request.path, + extra={"status_code": 405, "request": request}, ) return HttpResponseNotAllowed(self._allowed_methods()) def options(self, request, *args, **kwargs): """Handle responding to requests for the OPTIONS HTTP verb.""" response = HttpResponse() - response.headers['Allow'] = ', '.join(self._allowed_methods()) - response.headers['Content-Length'] = '0' + response.headers["Allow"] = ", ".join(self._allowed_methods()) + response.headers["Content-Length"] = "0" return 
response def _allowed_methods(self): @@ -120,6 +140,7 @@ class View: class TemplateResponseMixin: """A mixin that can be used to render a template.""" + template_name = None template_engine = None response_class = TemplateResponse @@ -132,13 +153,13 @@ class TemplateResponseMixin: Pass response_kwargs to the constructor of the response class. """ - response_kwargs.setdefault('content_type', self.content_type) + response_kwargs.setdefault("content_type", self.content_type) return self.response_class( request=self.request, template=self.get_template_names(), context=context, using=self.template_engine, - **response_kwargs + **response_kwargs, ) def get_template_names(self): @@ -149,7 +170,8 @@ class TemplateResponseMixin: if self.template_name is None: raise ImproperlyConfigured( "TemplateResponseMixin requires either a definition of " - "'template_name' or an implementation of 'get_template_names()'") + "'template_name' or an implementation of 'get_template_names()'" + ) else: return [self.template_name] @@ -158,6 +180,7 @@ class TemplateView(TemplateResponseMixin, ContextMixin, View): """ Render a template. Pass keyword arguments from the URLconf to the context. 
""" + def get(self, request, *args, **kwargs): context = self.get_context_data(**kwargs) return self.render_to_response(context) @@ -165,6 +188,7 @@ class TemplateView(TemplateResponseMixin, ContextMixin, View): class RedirectView(View): """Provide a redirect on any GET request.""" + permanent = False url = None pattern_name = None @@ -183,7 +207,7 @@ class RedirectView(View): else: return None - args = self.request.META.get('QUERY_STRING', '') + args = self.request.META.get("QUERY_STRING", "") if args and self.query_string: url = "%s?%s" % (url, args) return url @@ -197,8 +221,7 @@ class RedirectView(View): return HttpResponseRedirect(url) else: logger.warning( - 'Gone: %s', request.path, - extra={'status_code': 410, 'request': request} + "Gone: %s", request.path, extra={"status_code": 410, "request": request} ) return HttpResponseGone() diff --git a/django/views/generic/dates.py b/django/views/generic/dates.py index 63151dd5a5..d2b776c122 100644 --- a/django/views/generic/dates.py +++ b/django/views/generic/dates.py @@ -9,16 +9,19 @@ from django.utils.functional import cached_property from django.utils.translation import gettext as _ from django.views.generic.base import View from django.views.generic.detail import ( - BaseDetailView, SingleObjectTemplateResponseMixin, + BaseDetailView, + SingleObjectTemplateResponseMixin, ) from django.views.generic.list import ( - MultipleObjectMixin, MultipleObjectTemplateResponseMixin, + MultipleObjectMixin, + MultipleObjectTemplateResponseMixin, ) class YearMixin: """Mixin for views manipulating year-based data.""" - year_format = '%Y' + + year_format = "%Y" year = None def get_year_format(self): @@ -33,21 +36,21 @@ class YearMixin: year = self.year if year is None: try: - year = self.kwargs['year'] + year = self.kwargs["year"] except KeyError: try: - year = self.request.GET['year'] + year = self.request.GET["year"] except KeyError: raise Http404(_("No year specified")) return year def get_next_year(self, date): """Get the 
next valid year.""" - return _get_next_prev(self, date, is_previous=False, period='year') + return _get_next_prev(self, date, is_previous=False, period="year") def get_previous_year(self, date): """Get the previous valid year.""" - return _get_next_prev(self, date, is_previous=True, period='year') + return _get_next_prev(self, date, is_previous=True, period="year") def _get_next_year(self, date): """ @@ -67,7 +70,8 @@ class YearMixin: class MonthMixin: """Mixin for views manipulating month-based data.""" - month_format = '%b' + + month_format = "%b" month = None def get_month_format(self): @@ -82,21 +86,21 @@ class MonthMixin: month = self.month if month is None: try: - month = self.kwargs['month'] + month = self.kwargs["month"] except KeyError: try: - month = self.request.GET['month'] + month = self.request.GET["month"] except KeyError: raise Http404(_("No month specified")) return month def get_next_month(self, date): """Get the next valid month.""" - return _get_next_prev(self, date, is_previous=False, period='month') + return _get_next_prev(self, date, is_previous=False, period="month") def get_previous_month(self, date): """Get the previous valid month.""" - return _get_next_prev(self, date, is_previous=True, period='month') + return _get_next_prev(self, date, is_previous=True, period="month") def _get_next_month(self, date): """ @@ -119,7 +123,8 @@ class MonthMixin: class DayMixin: """Mixin for views manipulating day-based data.""" - day_format = '%d' + + day_format = "%d" day = None def get_day_format(self): @@ -134,21 +139,21 @@ class DayMixin: day = self.day if day is None: try: - day = self.kwargs['day'] + day = self.kwargs["day"] except KeyError: try: - day = self.request.GET['day'] + day = self.request.GET["day"] except KeyError: raise Http404(_("No day specified")) return day def get_next_day(self, date): """Get the next valid day.""" - return _get_next_prev(self, date, is_previous=False, period='day') + return _get_next_prev(self, date, 
is_previous=False, period="day") def get_previous_day(self, date): """Get the previous valid day.""" - return _get_next_prev(self, date, is_previous=True, period='day') + return _get_next_prev(self, date, is_previous=True, period="day") def _get_next_day(self, date): """ @@ -165,7 +170,8 @@ class DayMixin: class WeekMixin: """Mixin for views manipulating week-based data.""" - week_format = '%U' + + week_format = "%U" week = None def get_week_format(self): @@ -180,21 +186,21 @@ class WeekMixin: week = self.week if week is None: try: - week = self.kwargs['week'] + week = self.kwargs["week"] except KeyError: try: - week = self.request.GET['week'] + week = self.request.GET["week"] except KeyError: raise Http404(_("No week specified")) return week def get_next_week(self, date): """Get the next valid week.""" - return _get_next_prev(self, date, is_previous=False, period='week') + return _get_next_prev(self, date, is_previous=False, period="week") def get_previous_week(self, date): """Get the previous valid week.""" - return _get_next_prev(self, date, is_previous=True, period='week') + return _get_next_prev(self, date, is_previous=True, period="week") def _get_next_week(self, date): """ @@ -218,9 +224,9 @@ class WeekMixin: The first day according to the week format is 0 and the last day is 6. 
""" week_format = self.get_week_format() - if week_format in {'%W', '%V'}: # week starts on Monday + if week_format in {"%W", "%V"}: # week starts on Monday return date.weekday() - elif week_format == '%U': # week starts on Sunday + elif week_format == "%U": # week starts on Sunday return (date.weekday() + 1) % 7 else: raise ValueError("unknown week format: %s" % week_format) @@ -228,13 +234,16 @@ class WeekMixin: class DateMixin: """Mixin class for views manipulating date-based data.""" + date_field = None allow_future = False def get_date_field(self): """Get the name of the date field to be used to filter by.""" if self.date_field is None: - raise ImproperlyConfigured("%s.date_field is required." % self.__class__.__name__) + raise ImproperlyConfigured( + "%s.date_field is required." % self.__class__.__name__ + ) return self.date_field def get_allow_future(self): @@ -282,8 +291,8 @@ class DateMixin: since = self._make_date_lookup_arg(date) until = self._make_date_lookup_arg(date + datetime.timedelta(days=1)) return { - '%s__gte' % date_field: since, - '%s__lt' % date_field: until, + "%s__gte" % date_field: since, + "%s__lt" % date_field: until, } else: # Skip self._make_date_lookup_arg, it's a no-op in this branch. 
@@ -292,28 +301,29 @@ class DateMixin: class BaseDateListView(MultipleObjectMixin, DateMixin, View): """Abstract base class for date-based views displaying a list of objects.""" + allow_empty = False - date_list_period = 'year' + date_list_period = "year" def get(self, request, *args, **kwargs): self.date_list, self.object_list, extra_context = self.get_dated_items() context = self.get_context_data( - object_list=self.object_list, - date_list=self.date_list, - **extra_context + object_list=self.object_list, date_list=self.date_list, **extra_context ) return self.render_to_response(context) def get_dated_items(self): """Obtain the list of dates and items.""" - raise NotImplementedError('A DateView must provide an implementation of get_dated_items()') + raise NotImplementedError( + "A DateView must provide an implementation of get_dated_items()" + ) def get_ordering(self): """ Return the field or fields to use for ordering the queryset; use the date field by default. """ - return '-%s' % self.get_date_field() if self.ordering is None else self.ordering + return "-%s" % self.get_date_field() if self.ordering is None else self.ordering def get_dated_queryset(self, **lookup): """ @@ -328,16 +338,19 @@ class BaseDateListView(MultipleObjectMixin, DateMixin, View): if not allow_future: now = timezone.now() if self.uses_datetime_field else timezone_today() - qs = qs.filter(**{'%s__lte' % date_field: now}) + qs = qs.filter(**{"%s__lte" % date_field: now}) if not allow_empty: # When pagination is enabled, it's better to do a cheap query # than to load the unpaginated queryset in memory. 
is_empty = not qs if paginate_by is None else not qs.exists() if is_empty: - raise Http404(_("No %(verbose_name_plural)s available") % { - 'verbose_name_plural': qs.model._meta.verbose_name_plural, - }) + raise Http404( + _("No %(verbose_name_plural)s available") + % { + "verbose_name_plural": qs.model._meta.verbose_name_plural, + } + ) return qs @@ -348,7 +361,7 @@ class BaseDateListView(MultipleObjectMixin, DateMixin, View): """ return self.date_list_period - def get_date_list(self, queryset, date_type=None, ordering='ASC'): + def get_date_list(self, queryset, date_type=None, ordering="ASC"): """ Get a date list by calling `queryset.dates/datetimes()`, checking along the way for empty lists that aren't allowed. @@ -364,8 +377,9 @@ class BaseDateListView(MultipleObjectMixin, DateMixin, View): date_list = queryset.dates(date_field, date_type, ordering) if date_list is not None and not date_list and not allow_empty: raise Http404( - _("No %(verbose_name_plural)s available") % { - 'verbose_name_plural': queryset.model._meta.verbose_name_plural, + _("No %(verbose_name_plural)s available") + % { + "verbose_name_plural": queryset.model._meta.verbose_name_plural, } ) @@ -376,12 +390,13 @@ class BaseArchiveIndexView(BaseDateListView): """ Base class for archives of date-based items. Requires a response mixin. 
""" - context_object_name = 'latest' + + context_object_name = "latest" def get_dated_items(self): """Return (date_list, items, extra_context) for this request.""" qs = self.get_dated_queryset() - date_list = self.get_date_list(qs, ordering='DESC') + date_list = self.get_date_list(qs, ordering="DESC") if not date_list: qs = qs.none() @@ -391,12 +406,14 @@ class BaseArchiveIndexView(BaseDateListView): class ArchiveIndexView(MultipleObjectTemplateResponseMixin, BaseArchiveIndexView): """Top-level archive of date-based items.""" - template_name_suffix = '_archive' + + template_name_suffix = "_archive" class BaseYearArchiveView(YearMixin, BaseDateListView): """List of objects published in a given year.""" - date_list_period = 'month' + + date_list_period = "month" make_object_list = False def get_dated_items(self): @@ -409,8 +426,8 @@ class BaseYearArchiveView(YearMixin, BaseDateListView): since = self._make_date_lookup_arg(date) until = self._make_date_lookup_arg(self._get_next_year(date)) lookup_kwargs = { - '%s__gte' % date_field: since, - '%s__lt' % date_field: until, + "%s__gte" % date_field: since, + "%s__lt" % date_field: until, } qs = self.get_dated_queryset(**lookup_kwargs) @@ -421,11 +438,15 @@ class BaseYearArchiveView(YearMixin, BaseDateListView): # to find information about the model. 
qs = qs.none() - return (date_list, qs, { - 'year': date, - 'next_year': self.get_next_year(date), - 'previous_year': self.get_previous_year(date), - }) + return ( + date_list, + qs, + { + "year": date, + "next_year": self.get_next_year(date), + "previous_year": self.get_previous_year(date), + }, + ) def get_make_object_list(self): """ @@ -437,12 +458,14 @@ class BaseYearArchiveView(YearMixin, BaseDateListView): class YearArchiveView(MultipleObjectTemplateResponseMixin, BaseYearArchiveView): """List of objects published in a given year.""" - template_name_suffix = '_archive_year' + + template_name_suffix = "_archive_year" class BaseMonthArchiveView(YearMixin, MonthMixin, BaseDateListView): """List of objects published in a given month.""" - date_list_period = 'day' + + date_list_period = "day" def get_dated_items(self): """Return (date_list, items, extra_context) for this request.""" @@ -450,29 +473,35 @@ class BaseMonthArchiveView(YearMixin, MonthMixin, BaseDateListView): month = self.get_month() date_field = self.get_date_field() - date = _date_from_string(year, self.get_year_format(), - month, self.get_month_format()) + date = _date_from_string( + year, self.get_year_format(), month, self.get_month_format() + ) since = self._make_date_lookup_arg(date) until = self._make_date_lookup_arg(self._get_next_month(date)) lookup_kwargs = { - '%s__gte' % date_field: since, - '%s__lt' % date_field: until, + "%s__gte" % date_field: since, + "%s__lt" % date_field: until, } qs = self.get_dated_queryset(**lookup_kwargs) date_list = self.get_date_list(qs) - return (date_list, qs, { - 'month': date, - 'next_month': self.get_next_month(date), - 'previous_month': self.get_previous_month(date), - }) + return ( + date_list, + qs, + { + "month": date, + "next_month": self.get_next_month(date), + "previous_month": self.get_previous_month(date), + }, + ) class MonthArchiveView(MultipleObjectTemplateResponseMixin, BaseMonthArchiveView): """List of objects published in a given month.""" 
- template_name_suffix = '_archive_month' + + template_name_suffix = "_archive_month" class BaseWeekArchiveView(YearMixin, WeekMixin, BaseDateListView): @@ -485,55 +514,71 @@ class BaseWeekArchiveView(YearMixin, WeekMixin, BaseDateListView): date_field = self.get_date_field() week_format = self.get_week_format() - week_choices = {'%W': '1', '%U': '0', '%V': '1'} + week_choices = {"%W": "1", "%U": "0", "%V": "1"} try: week_start = week_choices[week_format] except KeyError: - raise ValueError('Unknown week format %r. Choices are: %s' % ( - week_format, - ', '.join(sorted(week_choices)), - )) - year_format = self.get_year_format() - if week_format == '%V' and year_format != '%G': raise ValueError( - "ISO week directive '%s' is incompatible with the year " - "directive '%s'. Use the ISO year '%%G' instead." % ( - week_format, year_format, + "Unknown week format %r. Choices are: %s" + % ( + week_format, + ", ".join(sorted(week_choices)), ) ) - date = _date_from_string(year, year_format, week_start, '%w', week, week_format) + year_format = self.get_year_format() + if week_format == "%V" and year_format != "%G": + raise ValueError( + "ISO week directive '%s' is incompatible with the year " + "directive '%s'. Use the ISO year '%%G' instead." 
+ % ( + week_format, + year_format, + ) + ) + date = _date_from_string(year, year_format, week_start, "%w", week, week_format) since = self._make_date_lookup_arg(date) until = self._make_date_lookup_arg(self._get_next_week(date)) lookup_kwargs = { - '%s__gte' % date_field: since, - '%s__lt' % date_field: until, + "%s__gte" % date_field: since, + "%s__lt" % date_field: until, } qs = self.get_dated_queryset(**lookup_kwargs) - return (None, qs, { - 'week': date, - 'next_week': self.get_next_week(date), - 'previous_week': self.get_previous_week(date), - }) + return ( + None, + qs, + { + "week": date, + "next_week": self.get_next_week(date), + "previous_week": self.get_previous_week(date), + }, + ) class WeekArchiveView(MultipleObjectTemplateResponseMixin, BaseWeekArchiveView): """List of objects published in a given week.""" - template_name_suffix = '_archive_week' + + template_name_suffix = "_archive_week" class BaseDayArchiveView(YearMixin, MonthMixin, DayMixin, BaseDateListView): """List of objects published on a given day.""" + def get_dated_items(self): """Return (date_list, items, extra_context) for this request.""" year = self.get_year() month = self.get_month() day = self.get_day() - date = _date_from_string(year, self.get_year_format(), - month, self.get_month_format(), - day, self.get_day_format()) + date = _date_from_string( + year, + self.get_year_format(), + month, + self.get_month_format(), + day, + self.get_day_format(), + ) return self._get_dated_items(date) @@ -545,17 +590,22 @@ class BaseDayArchiveView(YearMixin, MonthMixin, DayMixin, BaseDateListView): lookup_kwargs = self._make_single_date_lookup(date) qs = self.get_dated_queryset(**lookup_kwargs) - return (None, qs, { - 'day': date, - 'previous_day': self.get_previous_day(date), - 'next_day': self.get_next_day(date), - 'previous_month': self.get_previous_month(date), - 'next_month': self.get_next_month(date) - }) + return ( + None, + qs, + { + "day": date, + "previous_day": 
self.get_previous_day(date), + "next_day": self.get_next_day(date), + "previous_month": self.get_previous_month(date), + "next_month": self.get_next_month(date), + }, + ) class DayArchiveView(MultipleObjectTemplateResponseMixin, BaseDayArchiveView): """List of objects published on a given day.""" + template_name_suffix = "_archive_day" @@ -569,6 +619,7 @@ class BaseTodayArchiveView(BaseDayArchiveView): class TodayArchiveView(MultipleObjectTemplateResponseMixin, BaseTodayArchiveView): """List of objects published today.""" + template_name_suffix = "_archive_day" @@ -577,26 +628,35 @@ class BaseDateDetailView(YearMixin, MonthMixin, DayMixin, DateMixin, BaseDetailV Detail view of a single object on a single date; this differs from the standard DetailView by accepting a year/month/day in the URL. """ + def get_object(self, queryset=None): """Get the object this request displays.""" year = self.get_year() month = self.get_month() day = self.get_day() - date = _date_from_string(year, self.get_year_format(), - month, self.get_month_format(), - day, self.get_day_format()) + date = _date_from_string( + year, + self.get_year_format(), + month, + self.get_month_format(), + day, + self.get_day_format(), + ) # Use a custom queryset if provided qs = self.get_queryset() if queryset is None else queryset if not self.get_allow_future() and date > datetime.date.today(): - raise Http404(_( - "Future %(verbose_name_plural)s not available because " - "%(class_name)s.allow_future is False." - ) % { - 'verbose_name_plural': qs.model._meta.verbose_name_plural, - 'class_name': self.__class__.__name__, - }) + raise Http404( + _( + "Future %(verbose_name_plural)s not available because " + "%(class_name)s.allow_future is False." + ) + % { + "verbose_name_plural": qs.model._meta.verbose_name_plural, + "class_name": self.__class__.__name__, + } + ) # Filter down a queryset from self.queryset using the date from the # URL. 
This'll get passed as the queryset to DetailView.get_object, @@ -612,10 +672,13 @@ class DateDetailView(SingleObjectTemplateResponseMixin, BaseDateDetailView): Detail view of a single object on a single date; this differs from the standard DetailView by accepting a year/month/day in the URL. """ - template_name_suffix = '_detail' + + template_name_suffix = "_detail" -def _date_from_string(year, year_format, month='', month_format='', day='', day_format='', delim='__'): +def _date_from_string( + year, year_format, month="", month_format="", day="", day_format="", delim="__" +): """ Get a datetime.date object given a format string and a year, month, and day (only year is mandatory). Raise a 404 for an invalid date. @@ -625,10 +688,13 @@ def _date_from_string(year, year_format, month='', month_format='', day='', day_ try: return datetime.datetime.strptime(datestr, format).date() except ValueError: - raise Http404(_('Invalid date string “%(datestr)s” given format “%(format)s”') % { - 'datestr': datestr, - 'format': format, - }) + raise Http404( + _("Invalid date string “%(datestr)s” given format “%(format)s”") + % { + "datestr": datestr, + "format": format, + } + ) def _get_next_prev(generic_view, date, is_previous, period): @@ -661,8 +727,8 @@ def _get_next_prev(generic_view, date, is_previous, period): allow_empty = generic_view.get_allow_empty() allow_future = generic_view.get_allow_future() - get_current = getattr(generic_view, '_get_current_%s' % period) - get_next = getattr(generic_view, '_get_next_%s' % period) + get_current = getattr(generic_view, "_get_current_%s" % period) + get_next = getattr(generic_view, "_get_next_%s" % period) # Bounds of the current interval start, end = get_current(date), get_next(date) @@ -686,10 +752,10 @@ def _get_next_prev(generic_view, date, is_previous, period): # Construct a lookup and an ordering depending on whether we're doing # a previous date or a next date lookup. 
if is_previous: - lookup = {'%s__lt' % date_field: generic_view._make_date_lookup_arg(start)} - ordering = '-%s' % date_field + lookup = {"%s__lt" % date_field: generic_view._make_date_lookup_arg(start)} + ordering = "-%s" % date_field else: - lookup = {'%s__gte' % date_field: generic_view._make_date_lookup_arg(end)} + lookup = {"%s__gte" % date_field: generic_view._make_date_lookup_arg(end)} ordering = date_field # Filter out objects in the future if appropriate. @@ -700,7 +766,7 @@ def _get_next_prev(generic_view, date, is_previous, period): now = timezone.now() else: now = timezone_today() - lookup['%s__lte' % date_field] = now + lookup["%s__lte" % date_field] = now qs = generic_view.get_queryset().filter(**lookup).order_by(ordering) diff --git a/django/views/generic/detail.py b/django/views/generic/detail.py index 1922114598..e4428c8036 100644 --- a/django/views/generic/detail.py +++ b/django/views/generic/detail.py @@ -9,12 +9,13 @@ class SingleObjectMixin(ContextMixin): """ Provide the ability to retrieve a single object for further manipulation. """ + model = None queryset = None - slug_field = 'slug' + slug_field = "slug" context_object_name = None - slug_url_kwarg = 'slug' - pk_url_kwarg = 'pk' + slug_url_kwarg = "slug" + pk_url_kwarg = "pk" query_pk_and_slug = False def get_object(self, queryset=None): @@ -51,8 +52,10 @@ class SingleObjectMixin(ContextMixin): # Get the single item from the filtered queryset obj = queryset.get() except queryset.model.DoesNotExist: - raise Http404(_("No %(verbose_name)s found matching the query") % - {'verbose_name': queryset.model._meta.verbose_name}) + raise Http404( + _("No %(verbose_name)s found matching the query") + % {"verbose_name": queryset.model._meta.verbose_name} + ) return obj def get_queryset(self): @@ -69,9 +72,7 @@ class SingleObjectMixin(ContextMixin): raise ImproperlyConfigured( "%(cls)s is missing a QuerySet. Define " "%(cls)s.model, %(cls)s.queryset, or override " - "%(cls)s.get_queryset()." 
% { - 'cls': self.__class__.__name__ - } + "%(cls)s.get_queryset()." % {"cls": self.__class__.__name__} ) return self.queryset.all() @@ -92,7 +93,7 @@ class SingleObjectMixin(ContextMixin): """Insert the single object into the context dict.""" context = {} if self.object: - context['object'] = self.object + context["object"] = self.object context_object_name = self.get_context_object_name(self.object) if context_object_name: context[context_object_name] = self.object @@ -102,6 +103,7 @@ class SingleObjectMixin(ContextMixin): class BaseDetailView(SingleObjectMixin, View): """A base view for displaying a single object.""" + def get(self, request, *args, **kwargs): self.object = self.get_object() context = self.get_context_data(object=self.object) @@ -110,7 +112,7 @@ class BaseDetailView(SingleObjectMixin, View): class SingleObjectTemplateResponseMixin(TemplateResponseMixin): template_name_field = None - template_name_suffix = '_detail' + template_name_suffix = "_detail" def get_template_names(self): """ @@ -141,17 +143,25 @@ class SingleObjectTemplateResponseMixin(TemplateResponseMixin): # only use this if the object in question is a model. 
if isinstance(self.object, models.Model): object_meta = self.object._meta - names.append("%s/%s%s.html" % ( - object_meta.app_label, - object_meta.model_name, - self.template_name_suffix - )) - elif getattr(self, 'model', None) is not None and issubclass(self.model, models.Model): - names.append("%s/%s%s.html" % ( - self.model._meta.app_label, - self.model._meta.model_name, - self.template_name_suffix - )) + names.append( + "%s/%s%s.html" + % ( + object_meta.app_label, + object_meta.model_name, + self.template_name_suffix, + ) + ) + elif getattr(self, "model", None) is not None and issubclass( + self.model, models.Model + ): + names.append( + "%s/%s%s.html" + % ( + self.model._meta.app_label, + self.model._meta.model_name, + self.template_name_suffix, + ) + ) # If we still haven't managed to find any template names, we should # re-raise the ImproperlyConfigured to alert the user. diff --git a/django/views/generic/edit.py b/django/views/generic/edit.py index 8cbdeb8d13..e1f37bc593 100644 --- a/django/views/generic/edit.py +++ b/django/views/generic/edit.py @@ -1,16 +1,20 @@ import warnings from django.core.exceptions import ImproperlyConfigured -from django.forms import Form, models as model_forms +from django.forms import Form +from django.forms import models as model_forms from django.http import HttpResponseRedirect from django.views.generic.base import ContextMixin, TemplateResponseMixin, View from django.views.generic.detail import ( - BaseDetailView, SingleObjectMixin, SingleObjectTemplateResponseMixin, + BaseDetailView, + SingleObjectMixin, + SingleObjectTemplateResponseMixin, ) class FormMixin(ContextMixin): """Provide a way to show and handle a form in a request.""" + initial = {} form_class = None success_url = None @@ -37,15 +41,17 @@ class FormMixin(ContextMixin): def get_form_kwargs(self): """Return the keyword arguments for instantiating the form.""" kwargs = { - 'initial': self.get_initial(), - 'prefix': self.get_prefix(), + "initial": 
self.get_initial(), + "prefix": self.get_prefix(), } - if self.request.method in ('POST', 'PUT'): - kwargs.update({ - 'data': self.request.POST, - 'files': self.request.FILES, - }) + if self.request.method in ("POST", "PUT"): + kwargs.update( + { + "data": self.request.POST, + "files": self.request.FILES, + } + ) return kwargs def get_success_url(self): @@ -64,13 +70,14 @@ class FormMixin(ContextMixin): def get_context_data(self, **kwargs): """Insert the form into the context dict.""" - if 'form' not in kwargs: - kwargs['form'] = self.get_form() + if "form" not in kwargs: + kwargs["form"] = self.get_form() return super().get_context_data(**kwargs) class ModelFormMixin(FormMixin, SingleObjectMixin): """Provide a way to show and handle a ModelForm in a request.""" + fields = None def get_form_class(self): @@ -85,7 +92,7 @@ class ModelFormMixin(FormMixin, SingleObjectMixin): if self.model is not None: # If a model has been explicitly provided, use it model = self.model - elif getattr(self, 'object', None) is not None: + elif getattr(self, "object", None) is not None: # If this view is operating on a single object, use # the class of that object model = self.object.__class__ @@ -105,8 +112,8 @@ class ModelFormMixin(FormMixin, SingleObjectMixin): def get_form_kwargs(self): """Return the keyword arguments for instantiating the form.""" kwargs = super().get_form_kwargs() - if hasattr(self, 'object'): - kwargs.update({'instance': self.object}) + if hasattr(self, "object"): + kwargs.update({"instance": self.object}) return kwargs def get_success_url(self): @@ -119,7 +126,8 @@ class ModelFormMixin(FormMixin, SingleObjectMixin): except AttributeError: raise ImproperlyConfigured( "No URL to redirect to. Either provide a url or define" - " a get_absolute_url method on the Model.") + " a get_absolute_url method on the Model." 
+ ) return url def form_valid(self, form): @@ -130,6 +138,7 @@ class ModelFormMixin(FormMixin, SingleObjectMixin): class ProcessFormView(View): """Render a form on GET and processes it on POST.""" + def get(self, request, *args, **kwargs): """Handle GET requests: instantiate a blank version of the form.""" return self.render_to_response(self.get_context_data()) @@ -165,6 +174,7 @@ class BaseCreateView(ModelFormMixin, ProcessFormView): Using this base class requires subclassing to provide a response mixin. """ + def get(self, request, *args, **kwargs): self.object = None return super().get(request, *args, **kwargs) @@ -178,7 +188,8 @@ class CreateView(SingleObjectTemplateResponseMixin, BaseCreateView): """ View for creating a new object, with a response rendered by a template. """ - template_name_suffix = '_form' + + template_name_suffix = "_form" class BaseUpdateView(ModelFormMixin, ProcessFormView): @@ -187,6 +198,7 @@ class BaseUpdateView(ModelFormMixin, ProcessFormView): Using this base class requires subclassing to provide a response mixin. """ + def get(self, request, *args, **kwargs): self.object = self.get_object() return super().get(request, *args, **kwargs) @@ -198,11 +210,13 @@ class BaseUpdateView(ModelFormMixin, ProcessFormView): class UpdateView(SingleObjectTemplateResponseMixin, BaseUpdateView): """View for updating an object, with a response rendered by a template.""" - template_name_suffix = '_form' + + template_name_suffix = "_form" class DeletionMixin: """Provide the ability to delete objects.""" + success_url = None def delete(self, request, *args, **kwargs): @@ -223,8 +237,7 @@ class DeletionMixin: if self.success_url: return self.success_url.format(**self.object.__dict__) else: - raise ImproperlyConfigured( - "No URL to redirect to. Provide a success_url.") + raise ImproperlyConfigured("No URL to redirect to. Provide a success_url.") # RemovedInDjango50Warning. 
@@ -238,16 +251,17 @@ class BaseDeleteView(DeletionMixin, FormMixin, BaseDetailView): Using this base class requires subclassing to provide a response mixin. """ + form_class = Form def __init__(self, *args, **kwargs): # RemovedInDjango50Warning. if self.__class__.delete is not DeletionMixin.delete: warnings.warn( - f'DeleteView uses FormMixin to handle POST requests. As a ' - f'consequence, any custom deletion logic in ' - f'{self.__class__.__name__}.delete() handler should be moved ' - f'to form_valid().', + f"DeleteView uses FormMixin to handle POST requests. As a " + f"consequence, any custom deletion logic in " + f"{self.__class__.__name__}.delete() handler should be moved " + f"to form_valid().", DeleteViewCustomDeleteWarning, stacklevel=2, ) @@ -276,4 +290,5 @@ class DeleteView(SingleObjectTemplateResponseMixin, BaseDeleteView): View for deleting an object retrieved with self.get_object(), with a response rendered by a template. """ - template_name_suffix = '_confirm_delete' + + template_name_suffix = "_confirm_delete" diff --git a/django/views/generic/list.py b/django/views/generic/list.py index 65bae2b498..830a8df630 100644 --- a/django/views/generic/list.py +++ b/django/views/generic/list.py @@ -8,6 +8,7 @@ from django.views.generic.base import ContextMixin, TemplateResponseMixin, View class MultipleObjectMixin(ContextMixin): """A mixin for views manipulating multiple objects.""" + allow_empty = True queryset = None model = None @@ -15,7 +16,7 @@ class MultipleObjectMixin(ContextMixin): paginate_orphans = 0 context_object_name = None paginator_class = Paginator - page_kwarg = 'page' + page_kwarg = "page" ordering = None def get_queryset(self): @@ -35,9 +36,7 @@ class MultipleObjectMixin(ContextMixin): raise ImproperlyConfigured( "%(cls)s is missing a QuerySet. Define " "%(cls)s.model, %(cls)s.queryset, or override " - "%(cls)s.get_queryset()." % { - 'cls': self.__class__.__name__ - } + "%(cls)s.get_queryset()." 
% {"cls": self.__class__.__name__} ) ordering = self.get_ordering() if ordering: @@ -54,25 +53,30 @@ class MultipleObjectMixin(ContextMixin): def paginate_queryset(self, queryset, page_size): """Paginate the queryset, if needed.""" paginator = self.get_paginator( - queryset, page_size, orphans=self.get_paginate_orphans(), - allow_empty_first_page=self.get_allow_empty()) + queryset, + page_size, + orphans=self.get_paginate_orphans(), + allow_empty_first_page=self.get_allow_empty(), + ) page_kwarg = self.page_kwarg page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 try: page_number = int(page) except ValueError: - if page == 'last': + if page == "last": page_number = paginator.num_pages else: - raise Http404(_('Page is not “last”, nor can it be converted to an int.')) + raise Http404( + _("Page is not “last”, nor can it be converted to an int.") + ) try: page = paginator.page(page_number) return (paginator, page, page.object_list, page.has_other_pages()) except InvalidPage as e: - raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { - 'page_number': page_number, - 'message': str(e) - }) + raise Http404( + _("Invalid page (%(page_number)s): %(message)s") + % {"page_number": page_number, "message": str(e)} + ) def get_paginate_by(self, queryset): """ @@ -80,12 +84,17 @@ class MultipleObjectMixin(ContextMixin): """ return self.paginate_by - def get_paginator(self, queryset, per_page, orphans=0, - allow_empty_first_page=True, **kwargs): + def get_paginator( + self, queryset, per_page, orphans=0, allow_empty_first_page=True, **kwargs + ): """Return an instance of the paginator for this view.""" return self.paginator_class( - queryset, per_page, orphans=orphans, - allow_empty_first_page=allow_empty_first_page, **kwargs) + queryset, + per_page, + orphans=orphans, + allow_empty_first_page=allow_empty_first_page, + **kwargs, + ) def get_paginate_orphans(self): """ @@ -105,8 +114,8 @@ class MultipleObjectMixin(ContextMixin): """Get the 
name of the item to be used in the context.""" if self.context_object_name: return self.context_object_name - elif hasattr(object_list, 'model'): - return '%s_list' % object_list.model._meta.model_name + elif hasattr(object_list, "model"): + return "%s_list" % object_list.model._meta.model_name else: return None @@ -116,19 +125,21 @@ class MultipleObjectMixin(ContextMixin): page_size = self.get_paginate_by(queryset) context_object_name = self.get_context_object_name(queryset) if page_size: - paginator, page, queryset, is_paginated = self.paginate_queryset(queryset, page_size) + paginator, page, queryset, is_paginated = self.paginate_queryset( + queryset, page_size + ) context = { - 'paginator': paginator, - 'page_obj': page, - 'is_paginated': is_paginated, - 'object_list': queryset + "paginator": paginator, + "page_obj": page, + "is_paginated": is_paginated, + "object_list": queryset, } else: context = { - 'paginator': None, - 'page_obj': None, - 'is_paginated': False, - 'object_list': queryset + "paginator": None, + "page_obj": None, + "is_paginated": False, + "object_list": queryset, } if context_object_name is not None: context[context_object_name] = queryset @@ -138,6 +149,7 @@ class MultipleObjectMixin(ContextMixin): class BaseListView(MultipleObjectMixin, View): """A base view for displaying a list of objects.""" + def get(self, request, *args, **kwargs): self.object_list = self.get_queryset() allow_empty = self.get_allow_empty() @@ -146,21 +158,27 @@ class BaseListView(MultipleObjectMixin, View): # When pagination is enabled and object_list is a queryset, # it's better to do a cheap query than to load the unpaginated # queryset in memory. 
- if self.get_paginate_by(self.object_list) is not None and hasattr(self.object_list, 'exists'): + if self.get_paginate_by(self.object_list) is not None and hasattr( + self.object_list, "exists" + ): is_empty = not self.object_list.exists() else: is_empty = not self.object_list if is_empty: - raise Http404(_('Empty list and “%(class_name)s.allow_empty” is False.') % { - 'class_name': self.__class__.__name__, - }) + raise Http404( + _("Empty list and “%(class_name)s.allow_empty” is False.") + % { + "class_name": self.__class__.__name__, + } + ) context = self.get_context_data() return self.render_to_response(context) class MultipleObjectTemplateResponseMixin(TemplateResponseMixin): """Mixin for responding with a template and list of objects.""" - template_name_suffix = '_list' + + template_name_suffix = "_list" def get_template_names(self): """ @@ -178,14 +196,18 @@ class MultipleObjectTemplateResponseMixin(TemplateResponseMixin): # app and model name. This name gets put at the end of the template # name list so that user-supplied names override the automatically- # generated ones. - if hasattr(self.object_list, 'model'): + if hasattr(self.object_list, "model"): opts = self.object_list.model._meta - names.append("%s/%s%s.html" % (opts.app_label, opts.model_name, self.template_name_suffix)) + names.append( + "%s/%s%s.html" + % (opts.app_label, opts.model_name, self.template_name_suffix) + ) elif not names: raise ImproperlyConfigured( "%(cls)s requires either a 'template_name' attribute " - "or a get_queryset() method that returns a QuerySet." % { - 'cls': self.__class__.__name__, + "or a get_queryset() method that returns a QuerySet." 
+ % { + "cls": self.__class__.__name__, } ) return names diff --git a/django/views/i18n.py b/django/views/i18n.py index 3a93c30d0a..b7eca510aa 100644 --- a/django/views/i18n.py +++ b/django/views/i18n.py @@ -14,7 +14,7 @@ from django.utils.translation import check_for_language, get_language from django.utils.translation.trans_real import DjangoTranslation from django.views.generic import View -LANGUAGE_QUERY_PARAMETER = 'language' +LANGUAGE_QUERY_PARAMETER = "language" def set_language(request): @@ -28,24 +28,23 @@ def set_language(request): redirect to the page in the request (the 'next' parameter) without changing any state. """ - next_url = request.POST.get('next', request.GET.get('next')) + next_url = request.POST.get("next", request.GET.get("next")) if ( - (next_url or request.accepts('text/html')) and - not url_has_allowed_host_and_scheme( - url=next_url, - allowed_hosts={request.get_host()}, - require_https=request.is_secure(), - ) + next_url or request.accepts("text/html") + ) and not url_has_allowed_host_and_scheme( + url=next_url, + allowed_hosts={request.get_host()}, + require_https=request.is_secure(), ): - next_url = request.META.get('HTTP_REFERER') + next_url = request.META.get("HTTP_REFERER") if not url_has_allowed_host_and_scheme( url=next_url, allowed_hosts={request.get_host()}, require_https=request.is_secure(), ): - next_url = '/' + next_url = "/" response = HttpResponseRedirect(next_url) if next_url else HttpResponse(status=204) - if request.method == 'POST': + if request.method == "POST": lang_code = request.POST.get(LANGUAGE_QUERY_PARAMETER) if lang_code and check_for_language(lang_code): if next_url: @@ -53,7 +52,8 @@ def set_language(request): if next_trans != next_url: response = HttpResponseRedirect(next_trans) response.set_cookie( - settings.LANGUAGE_COOKIE_NAME, lang_code, + settings.LANGUAGE_COOKIE_NAME, + lang_code, max_age=settings.LANGUAGE_COOKIE_AGE, path=settings.LANGUAGE_COOKIE_PATH, domain=settings.LANGUAGE_COOKIE_DOMAIN, @@ 
-67,11 +67,20 @@ def set_language(request): def get_formats(): """Return all formats strings required for i18n to work.""" FORMAT_SETTINGS = ( - 'DATE_FORMAT', 'DATETIME_FORMAT', 'TIME_FORMAT', - 'YEAR_MONTH_FORMAT', 'MONTH_DAY_FORMAT', 'SHORT_DATE_FORMAT', - 'SHORT_DATETIME_FORMAT', 'FIRST_DAY_OF_WEEK', 'DECIMAL_SEPARATOR', - 'THOUSAND_SEPARATOR', 'NUMBER_GROUPING', - 'DATE_INPUT_FORMATS', 'TIME_INPUT_FORMATS', 'DATETIME_INPUT_FORMATS' + "DATE_FORMAT", + "DATETIME_FORMAT", + "TIME_FORMAT", + "YEAR_MONTH_FORMAT", + "MONTH_DAY_FORMAT", + "SHORT_DATE_FORMAT", + "SHORT_DATETIME_FORMAT", + "FIRST_DAY_OF_WEEK", + "DECIMAL_SEPARATOR", + "THOUSAND_SEPARATOR", + "NUMBER_GROUPING", + "DATE_INPUT_FORMATS", + "TIME_INPUT_FORMATS", + "DATETIME_INPUT_FORMATS", ) return {attr: get_format(attr) for attr in FORMAT_SETTINGS} @@ -194,31 +203,37 @@ class JavaScriptCatalog(View): want to do that as JavaScript messages go to the djangojs domain. This might be needed if you deliver your JavaScript source from Django templates. """ - domain = 'djangojs' + + domain = "djangojs" packages = None def get(self, request, *args, **kwargs): locale = get_language() - domain = kwargs.get('domain', self.domain) + domain = kwargs.get("domain", self.domain) # If packages are not provided, default to all installed packages, as # DjangoTranslation without localedirs harvests them all. 
- packages = kwargs.get('packages', '') - packages = packages.split('+') if packages else self.packages + packages = kwargs.get("packages", "") + packages = packages.split("+") if packages else self.packages paths = self.get_paths(packages) if packages else None self.translation = DjangoTranslation(locale, domain=domain, localedirs=paths) context = self.get_context_data(**kwargs) return self.render_to_response(context) def get_paths(self, packages): - allowable_packages = {app_config.name: app_config for app_config in apps.get_app_configs()} - app_configs = [allowable_packages[p] for p in packages if p in allowable_packages] + allowable_packages = { + app_config.name: app_config for app_config in apps.get_app_configs() + } + app_configs = [ + allowable_packages[p] for p in packages if p in allowable_packages + ] if len(app_configs) < len(packages): excluded = [p for p in packages if p not in allowable_packages] raise ValueError( - 'Invalid package(s) provided to JavaScriptCatalog: %s' % ','.join(excluded) + "Invalid package(s) provided to JavaScriptCatalog: %s" + % ",".join(excluded) ) # paths of requested packages - return [os.path.join(app.path, 'locale') for app in app_configs] + return [os.path.join(app.path, "locale") for app in app_configs] @property def _num_plurals(self): @@ -226,7 +241,7 @@ class JavaScriptCatalog(View): Return the number of plurals for this catalog language, or 2 if no plural string is available. """ - match = re.search(r'nplurals=\s*(\d+)', self._plural_string or '') + match = re.search(r"nplurals=\s*(\d+)", self._plural_string or "") if match: return int(match[1]) return 2 @@ -237,10 +252,10 @@ class JavaScriptCatalog(View): Return the plural string (including nplurals) for this catalog language, or None if no plural string is available. 
""" - if '' in self.translation._catalog: - for line in self.translation._catalog[''].split('\n'): - if line.startswith('Plural-Forms:'): - return line.split(':', 1)[1].strip() + if "" in self.translation._catalog: + for line in self.translation._catalog[""].split("\n"): + if line.startswith("Plural-Forms:"): + return line.split(":", 1)[1].strip() return None def get_plural(self): @@ -249,7 +264,11 @@ class JavaScriptCatalog(View): # This should be a compiled function of a typical plural-form: # Plural-Forms: nplurals=3; plural=n%10==1 && n%100!=11 ? 0 : # n%10>=2 && n%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2; - plural = [el.strip() for el in plural.split(';') if el.strip().startswith('plural=')][0].split('=', 1)[1] + plural = [ + el.strip() + for el in plural.split(";") + if el.strip().startswith("plural=") + ][0].split("=", 1)[1] return plural def get_catalog(self): @@ -257,10 +276,14 @@ class JavaScriptCatalog(View): num_plurals = self._num_plurals catalog = {} trans_cat = self.translation._catalog - trans_fallback_cat = self.translation._fallback._catalog if self.translation._fallback else {} + trans_fallback_cat = ( + self.translation._fallback._catalog if self.translation._fallback else {} + ) seen_keys = set() - for key, value in itertools.chain(trans_cat.items(), trans_fallback_cat.items()): - if key == '' or key in seen_keys: + for key, value in itertools.chain( + trans_cat.items(), trans_fallback_cat.items() + ): + if key == "" or key in seen_keys: continue if isinstance(key, str): catalog[key] = value @@ -271,27 +294,33 @@ class JavaScriptCatalog(View): raise TypeError(key) seen_keys.add(key) for k, v in pdict.items(): - catalog[k] = [v.get(i, '') for i in range(num_plurals)] + catalog[k] = [v.get(i, "") for i in range(num_plurals)] return catalog def get_context_data(self, **kwargs): return { - 'catalog': self.get_catalog(), - 'formats': get_formats(), - 'plural': self.get_plural(), + "catalog": self.get_catalog(), + "formats": get_formats(), + 
"plural": self.get_plural(), } def render_to_response(self, context, **response_kwargs): def indent(s): - return s.replace('\n', '\n ') + return s.replace("\n", "\n ") template = Engine().from_string(js_catalog_template) - context['catalog_str'] = indent( - json.dumps(context['catalog'], sort_keys=True, indent=2) - ) if context['catalog'] else None - context['formats_str'] = indent(json.dumps(context['formats'], sort_keys=True, indent=2)) + context["catalog_str"] = ( + indent(json.dumps(context["catalog"], sort_keys=True, indent=2)) + if context["catalog"] + else None + ) + context["formats_str"] = indent( + json.dumps(context["formats"], sort_keys=True, indent=2) + ) - return HttpResponse(template.render(Context(context)), 'text/javascript; charset="utf-8"') + return HttpResponse( + template.render(Context(context)), 'text/javascript; charset="utf-8"' + ) class JSONCatalog(JavaScriptCatalog): @@ -311,5 +340,6 @@ class JSONCatalog(JavaScriptCatalog): "plural": '...' # Expression for plural forms, or null. 
} """ + def render_to_response(self, context, **response_kwargs): return JsonResponse(context) diff --git a/django/views/static.py b/django/views/static.py index 1d4900b1da..1c558a53ff 100644 --- a/django/views/static.py +++ b/django/views/static.py @@ -7,13 +7,12 @@ import posixpath import re from pathlib import Path -from django.http import ( - FileResponse, Http404, HttpResponse, HttpResponseNotModified, -) +from django.http import FileResponse, Http404, HttpResponse, HttpResponseNotModified from django.template import Context, Engine, TemplateDoesNotExist, loader from django.utils._os import safe_join from django.utils.http import http_date, parse_http_date -from django.utils.translation import gettext as _, gettext_lazy +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy def serve(request, path, document_root=None, show_indexes=False): @@ -32,22 +31,23 @@ def serve(request, path, document_root=None, show_indexes=False): but if you'd like to override it, you can create a template called ``static/directory_index.html``. """ - path = posixpath.normpath(path).lstrip('/') + path = posixpath.normpath(path).lstrip("/") fullpath = Path(safe_join(document_root, path)) if fullpath.is_dir(): if show_indexes: return directory_index(path, fullpath) raise Http404(_("Directory indexes are not allowed here.")) if not fullpath.exists(): - raise Http404(_('“%(path)s” does not exist') % {'path': fullpath}) + raise Http404(_("“%(path)s” does not exist") % {"path": fullpath}) # Respect the If-Modified-Since header. 
statobj = fullpath.stat() - if not was_modified_since(request.META.get('HTTP_IF_MODIFIED_SINCE'), - statobj.st_mtime, statobj.st_size): + if not was_modified_since( + request.META.get("HTTP_IF_MODIFIED_SINCE"), statobj.st_mtime, statobj.st_size + ): return HttpResponseNotModified() content_type, encoding = mimetypes.guess_type(str(fullpath)) - content_type = content_type or 'application/octet-stream' - response = FileResponse(fullpath.open('rb'), content_type=content_type) + content_type = content_type or "application/octet-stream" + response = FileResponse(fullpath.open("rb"), content_type=content_type) response.headers["Last-Modified"] = http_date(statobj.st_mtime) if encoding: response.headers["Content-Encoding"] = encoding @@ -82,26 +82,32 @@ template_translatable = gettext_lazy("Index of %(directory)s") def directory_index(path, fullpath): try: - t = loader.select_template([ - 'static/directory_index.html', - 'static/directory_index', - ]) + t = loader.select_template( + [ + "static/directory_index.html", + "static/directory_index", + ] + ) except TemplateDoesNotExist: - t = Engine(libraries={'i18n': 'django.templatetags.i18n'}).from_string(DEFAULT_DIRECTORY_INDEX_TEMPLATE) + t = Engine(libraries={"i18n": "django.templatetags.i18n"}).from_string( + DEFAULT_DIRECTORY_INDEX_TEMPLATE + ) c = Context() else: c = {} files = [] for f in fullpath.iterdir(): - if not f.name.startswith('.'): + if not f.name.startswith("."): url = str(f.relative_to(fullpath)) if f.is_dir(): - url += '/' + url += "/" files.append(url) - c.update({ - 'directory': path + '/', - 'file_list': files, - }) + c.update( + { + "directory": path + "/", + "file_list": files, + } + ) return HttpResponse(t.render(c)) @@ -122,8 +128,7 @@ def was_modified_since(header=None, mtime=0, size=0): try: if header is None: raise ValueError - matches = re.match(r"^([^;]+)(; length=([0-9]+))?$", header, - re.IGNORECASE) + matches = re.match(r"^([^;]+)(; length=([0-9]+))?$", header, re.IGNORECASE) header_mtime = 
parse_http_date(matches[1]) header_len = matches[3] if header_len and int(header_len) != size: diff --git a/docs/_ext/djangodocs.py b/docs/_ext/djangodocs.py index 2829d581cd..f3c4321499 100644 --- a/docs/_ext/djangodocs.py +++ b/docs/_ext/djangodocs.py @@ -8,7 +8,8 @@ import re from docutils import nodes from docutils.parsers.rst import Directive from docutils.statemachine import ViewList -from sphinx import addnodes, version_info as sphinx_version +from sphinx import addnodes +from sphinx import version_info as sphinx_version from sphinx.builders.html import StandaloneHTMLBuilder from sphinx.directives.code import CodeBlock from sphinx.domains.std import Cmdoption @@ -19,8 +20,7 @@ from sphinx.writers.html import HTMLTranslator logger = logging.getLogger(__name__) # RE for option descriptions without a '--' prefix -simple_option_desc_re = re.compile( - r'([-_a-zA-Z0-9]+)(\s*.*?)(?=,\s+(?:/|-|--)|$)') +simple_option_desc_re = re.compile(r"([-_a-zA-Z0-9]+)(\s*.*?)(?=,\s+(?:/|-|--)|$)") def setup(app): @@ -32,12 +32,12 @@ def setup(app): app.add_crossref_type( directivename="templatetag", rolename="ttag", - indextemplate="pair: %s; template tag" + indextemplate="pair: %s; template tag", ) app.add_crossref_type( directivename="templatefilter", rolename="tfilter", - indextemplate="pair: %s; template filter" + indextemplate="pair: %s; template filter", ) app.add_crossref_type( directivename="fieldlookup", @@ -50,13 +50,13 @@ def setup(app): indextemplate="pair: %s; django-admin command", parse_node=parse_django_admin_node, ) - app.add_directive('django-admin-option', Cmdoption) - app.add_config_value('django_next_version', '0.0', True) - app.add_directive('versionadded', VersionDirective) - app.add_directive('versionchanged', VersionDirective) + app.add_directive("django-admin-option", Cmdoption) + app.add_config_value("django_next_version", "0.0", True) + app.add_directive("versionadded", VersionDirective) + app.add_directive("versionchanged", VersionDirective) 
app.add_builder(DjangoStandaloneHTMLBuilder) - app.set_translator('djangohtml', DjangoHTMLTranslator) - app.set_translator('json', DjangoHTMLTranslator) + app.set_translator("djangohtml", DjangoHTMLTranslator) + app.set_translator("json", DjangoHTMLTranslator) app.add_node( ConsoleNode, html=(visit_console_html, None), @@ -65,10 +65,10 @@ def setup(app): text=(visit_console_dummy, depart_console_dummy), texinfo=(visit_console_dummy, depart_console_dummy), ) - app.add_directive('console', ConsoleDirective) - app.connect('html-page-context', html_page_context_hook) - app.add_role('default-role-error', default_role_error) - return {'parallel_read_safe': True} + app.add_directive("console", ConsoleDirective) + app.connect("html-page-context", html_page_context_hook) + app.add_role("default-role-error", default_role_error) + return {"parallel_read_safe": True} class VersionDirective(Directive): @@ -82,7 +82,9 @@ class VersionDirective(Directive): if len(self.arguments) > 1: msg = """Only one argument accepted for directive '{directive_name}::'. 
Comments should be provided as content, - not as an extra argument.""".format(directive_name=self.name) + not as an extra argument.""".format( + directive_name=self.name + ) raise self.error(msg) env = self.state.document.settings.env @@ -91,18 +93,18 @@ class VersionDirective(Directive): ret.append(node) if self.arguments[0] == env.config.django_next_version: - node['version'] = "Development version" + node["version"] = "Development version" else: - node['version'] = self.arguments[0] + node["version"] = self.arguments[0] - node['type'] = self.name + node["type"] = self.name if self.content: self.state.nested_parse(self.content, self.content_offset, node) try: - env.get_domain('changeset').note_changeset(node) + env.get_domain("changeset").note_changeset(node) except ExtensionError: # Sphinx < 1.8: Domain 'changeset' is not registered - env.note_versionchange(node['type'], node['version'], node, self.lineno) + env.note_versionchange(node["type"], node["version"], node, self.lineno) return ret @@ -120,23 +122,25 @@ class DjangoHTMLTranslator(HTMLTranslator): self._table_row_indices.append(0) else: self._table_row_index = 0 - self.body.append(self.starttag(node, 'table', CLASS='docutils')) + self.body.append(self.starttag(node, "table", CLASS="docutils")) def depart_table(self, node): self.compact_p = self.context.pop() if sphinx_version >= (4, 3): self._table_row_indices.pop() - self.body.append('</table>\n') + self.body.append("</table>\n") def visit_desc_parameterlist(self, node): - self.body.append('(') # by default sphinx puts <big> around the "(" + self.body.append("(") # by default sphinx puts <big> around the "(" self.first_param = 1 self.optional_param_level = 0 self.param_separator = node.child_text_separator - self.required_params_left = sum(isinstance(c, addnodes.desc_parameter) for c in node.children) + self.required_params_left = sum( + isinstance(c, addnodes.desc_parameter) for c in node.children + ) def depart_desc_parameterlist(self, node): - 
self.body.append(')') + self.body.append(")") # # Turn the "new in version" stuff (versionadded/versionchanged) into a @@ -148,20 +152,15 @@ class DjangoHTMLTranslator(HTMLTranslator): # that work. # version_text = { - 'versionchanged': 'Changed in Django %s', - 'versionadded': 'New in Django %s', + "versionchanged": "Changed in Django %s", + "versionadded": "New in Django %s", } def visit_versionmodified(self, node): - self.body.append( - self.starttag(node, 'div', CLASS=node['type']) - ) - version_text = self.version_text.get(node['type']) + self.body.append(self.starttag(node, "div", CLASS=node["type"])) + version_text = self.version_text.get(node["type"]) if version_text: - title = "%s%s" % ( - version_text % node['version'], - ":" if len(node) else "." - ) + title = "%s%s" % (version_text % node["version"], ":" if len(node) else ".") self.body.append('<span class="title">%s</span> ' % title) def depart_versionmodified(self, node): @@ -169,16 +168,16 @@ class DjangoHTMLTranslator(HTMLTranslator): # Give each section a unique ID -- nice for custom CSS hooks def visit_section(self, node): - old_ids = node.get('ids', []) - node['ids'] = ['s-' + i for i in old_ids] - node['ids'].extend(old_ids) + old_ids = node.get("ids", []) + node["ids"] = ["s-" + i for i in old_ids] + node["ids"].extend(old_ids) super().visit_section(node) - node['ids'] = old_ids + node["ids"] = old_ids def parse_django_admin_node(env, sig, signode): - command = sig.split(' ')[0] - env.ref_context['std:program'] = command + command = sig.split(" ")[0] + env.ref_context["std:program"] = command title = "django-admin %s" % sig signode += addnodes.desc_name(title, title) return command @@ -189,7 +188,7 @@ class DjangoStandaloneHTMLBuilder(StandaloneHTMLBuilder): Subclass to add some extra things we need. 
""" - name = 'djangohtml' + name = "djangohtml" def finish(self): super().finish() @@ -197,19 +196,21 @@ class DjangoStandaloneHTMLBuilder(StandaloneHTMLBuilder): xrefs = self.env.domaindata["std"]["objects"] templatebuiltins = { "ttags": [ - n for ((t, n), (k, a)) in xrefs.items() + n + for ((t, n), (k, a)) in xrefs.items() if t == "templatetag" and k == "ref/templates/builtins" ], "tfilters": [ - n for ((t, n), (k, a)) in xrefs.items() + n + for ((t, n), (k, a)) in xrefs.items() if t == "templatefilter" and k == "ref/templates/builtins" ], } outfilename = os.path.join(self.outdir, "templatebuiltins.js") - with open(outfilename, 'w') as fp: - fp.write('var django_template_builtins = ') + with open(outfilename, "w") as fp: + fp.write("var django_template_builtins = ") json.dump(templatebuiltins, fp) - fp.write(';\n') + fp.write(";\n") class ConsoleNode(nodes.literal_block): @@ -217,13 +218,14 @@ class ConsoleNode(nodes.literal_block): Custom node to override the visit/depart event handlers at registration time. Wrap a literal_block object and defer to it. """ - tagname = 'ConsoleNode' + + tagname = "ConsoleNode" def __init__(self, litblk_obj): self.wrapped = litblk_obj def __getattr__(self, attr): - if attr == 'wrapped': + if attr == "wrapped": return self.__dict__.wrapped return getattr(self.wrapped, attr) @@ -240,38 +242,43 @@ def depart_console_dummy(self, node): def visit_console_html(self, node): """Generate HTML for the console directive.""" - if self.builder.name in ('djangohtml', 'json') and node['win_console_text']: + if self.builder.name in ("djangohtml", "json") and node["win_console_text"]: # Put a mark on the document object signaling the fact the directive # has been used on it. 
self.document._console_directive_used_flag = True - uid = node['uid'] - self.body.append('''\ + uid = node["uid"] + self.body.append( + """\ <div class="console-block" id="console-block-%(id)s"> <input class="c-tab-unix" id="c-tab-%(id)s-unix" type="radio" name="console-%(id)s" checked> <label for="c-tab-%(id)s-unix" title="Linux/macOS">/</label> <input class="c-tab-win" id="c-tab-%(id)s-win" type="radio" name="console-%(id)s"> <label for="c-tab-%(id)s-win" title="Windows"></label> -<section class="c-content-unix" id="c-content-%(id)s-unix">\n''' % {'id': uid}) +<section class="c-content-unix" id="c-content-%(id)s-unix">\n""" + % {"id": uid} + ) try: self.visit_literal_block(node) except nodes.SkipNode: pass - self.body.append('</section>\n') + self.body.append("</section>\n") - self.body.append('<section class="c-content-win" id="c-content-%(id)s-win">\n' % {'id': uid}) - win_text = node['win_console_text'] - highlight_args = {'force': True} - linenos = node.get('linenos', False) + self.body.append( + '<section class="c-content-win" id="c-content-%(id)s-win">\n' % {"id": uid} + ) + win_text = node["win_console_text"] + highlight_args = {"force": True} + linenos = node.get("linenos", False) def warner(msg): self.builder.warn(msg, (self.builder.current_docname, node.line)) highlighted = self.highlighter.highlight_block( - win_text, 'doscon', warn=warner, linenos=linenos, **highlight_args + win_text, "doscon", warn=warner, linenos=linenos, **highlight_args ) self.body.append(highlighted) - self.body.append('</section>\n') - self.body.append('</div>\n') + self.body.append("</section>\n") + self.body.append("</div>\n") raise nodes.SkipNode else: self.visit_literal_block(node) @@ -283,54 +290,54 @@ class ConsoleDirective(CodeBlock): the second tab shows a Windows command line equivalent of the usual Unix-oriented examples. """ + required_arguments = 0 # The 'doscon' Pygments formatter needs a prompt like this. 
'>' alone # won't do it because then it simply paints the whole command line as a # gray comment with no highlighting at all. - WIN_PROMPT = r'...\> ' + WIN_PROMPT = r"...\> " def run(self): - def args_to_win(cmdline): changed = False out = [] for token in cmdline.split(): - if token[:2] == './': + if token[:2] == "./": token = token[2:] changed = True - elif token[:2] == '~/': - token = '%HOMEPATH%\\' + token[2:] + elif token[:2] == "~/": + token = "%HOMEPATH%\\" + token[2:] changed = True - elif token == 'make': - token = 'make.bat' + elif token == "make": + token = "make.bat" changed = True - if '://' not in token and 'git' not in cmdline: - out.append(token.replace('/', '\\')) + if "://" not in token and "git" not in cmdline: + out.append(token.replace("/", "\\")) changed = True else: out.append(token) if changed: - return ' '.join(out) + return " ".join(out) return cmdline def cmdline_to_win(line): - if line.startswith('# '): - return 'REM ' + args_to_win(line[2:]) - if line.startswith('$ # '): - return 'REM ' + args_to_win(line[4:]) - if line.startswith('$ ./manage.py'): - return 'manage.py ' + args_to_win(line[13:]) - if line.startswith('$ manage.py'): - return 'manage.py ' + args_to_win(line[11:]) - if line.startswith('$ ./runtests.py'): - return 'runtests.py ' + args_to_win(line[15:]) - if line.startswith('$ ./'): + if line.startswith("# "): + return "REM " + args_to_win(line[2:]) + if line.startswith("$ # "): + return "REM " + args_to_win(line[4:]) + if line.startswith("$ ./manage.py"): + return "manage.py " + args_to_win(line[13:]) + if line.startswith("$ manage.py"): + return "manage.py " + args_to_win(line[11:]) + if line.startswith("$ ./runtests.py"): + return "runtests.py " + args_to_win(line[15:]) + if line.startswith("$ ./"): return args_to_win(line[4:]) - if line.startswith('$ python3'): - return 'py ' + args_to_win(line[9:]) - if line.startswith('$ python'): - return 'py ' + args_to_win(line[8:]) - if line.startswith('$ '): + if 
line.startswith("$ python3"): + return "py " + args_to_win(line[9:]) + if line.startswith("$ python"): + return "py " + args_to_win(line[8:]) + if line.startswith("$ "): return args_to_win(line[2:]) return None @@ -349,23 +356,23 @@ class ConsoleDirective(CodeBlock): return None env = self.state.document.settings.env - self.arguments = ['console'] + self.arguments = ["console"] lit_blk_obj = super().run()[0] # Only do work when the djangohtml HTML Sphinx builder is being used, # invoke the default behavior for the rest. - if env.app.builder.name not in ('djangohtml', 'json'): + if env.app.builder.name not in ("djangohtml", "json"): return [lit_blk_obj] - lit_blk_obj['uid'] = str(env.new_serialno('console')) + lit_blk_obj["uid"] = str(env.new_serialno("console")) # Only add the tabbed UI if there is actually a Windows-specific # version of the CLI example. win_content = code_block_to_win(self.content) if win_content is None: - lit_blk_obj['win_console_text'] = None + lit_blk_obj["win_console_text"] = None else: self.content = win_content - lit_blk_obj['win_console_text'] = super().run()[0].rawsource + lit_blk_obj["win_console_text"] = super().run()[0].rawsource # Replace the literal_node object returned by Sphinx's CodeBlock with # the ConsoleNode wrapper. @@ -377,7 +384,9 @@ def html_page_context_hook(app, pagename, templatename, context, doctree): # control inclusion of console-tabs.css and activation of the JavaScript. # This way it's include only from HTML files rendered from reST files where # the ConsoleDirective is used. - context['include_console_assets'] = getattr(doctree, '_console_directive_used_flag', False) + context["include_console_assets"] = getattr( + doctree, "_console_directive_used_flag", False + ) def default_role_error( @@ -385,8 +394,7 @@ def default_role_error( ): msg = ( "Default role used (`single backticks`): %s. Did you mean to use two " - "backticks for ``code``, or miss an underscore for a `link`_ ?" 
- % rawtext + "backticks for ``code``, or miss an underscore for a `link`_ ?" % rawtext ) logger.warning(msg, location=(inliner.document.current_source, lineno)) return [nodes.Text(text)], [] diff --git a/docs/conf.py b/docs/conf.py index 81a8ce4a2a..41301337ee 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,13 +32,13 @@ sys.path.append(abspath(join(dirname(__file__), "_ext"))) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '1.6.0' +needs_sphinx = "1.6.0" # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ "djangodocs", - 'sphinx.ext.extlinks', + "sphinx.ext.extlinks", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinx.ext.autosectionlabel", @@ -53,35 +53,35 @@ autosectionlabel_maxdepth = 2 # Linkcheck settings. linkcheck_ignore = [ # Special-use addresses and domain names. (RFC 6761/6890) - r'^https?://(?:127\.0\.0\.1|\[::1\])(?::\d+)?/', - r'^https?://(?:[^/\.]+\.)*example\.(?:com|net|org)(?::\d+)?/', - r'^https?://(?:[^/\.]+\.)*(?:example|invalid|localhost|test)(?::\d+)?/', + r"^https?://(?:127\.0\.0\.1|\[::1\])(?::\d+)?/", + r"^https?://(?:[^/\.]+\.)*example\.(?:com|net|org)(?::\d+)?/", + r"^https?://(?:[^/\.]+\.)*(?:example|invalid|localhost|test)(?::\d+)?/", # Pages that are inaccessible because they require authentication. 
- r'^https://github\.com/[^/]+/[^/]+/fork', - r'^https://code\.djangoproject\.com/github/login', - r'^https://code\.djangoproject\.com/newticket', - r'^https://(?:code|www)\.djangoproject\.com/admin/', - r'^https://www\.djangoproject\.com/community/add/blogs/', - r'^https://www\.google\.com/webmasters/tools/ping', - r'^https://search\.google\.com/search-console/welcome', + r"^https://github\.com/[^/]+/[^/]+/fork", + r"^https://code\.djangoproject\.com/github/login", + r"^https://code\.djangoproject\.com/newticket", + r"^https://(?:code|www)\.djangoproject\.com/admin/", + r"^https://www\.djangoproject\.com/community/add/blogs/", + r"^https://www\.google\.com/webmasters/tools/ping", + r"^https://search\.google\.com/search-console/welcome", # Fragments used to dynamically switch content or populate fields. - r'^https://web\.libera\.chat/#', - r'^https://github\.com/[^#]+#L\d+-L\d+$', - r'^https://help\.apple\.com/itc/podcasts_connect/#/itc', + r"^https://web\.libera\.chat/#", + r"^https://github\.com/[^#]+#L\d+-L\d+$", + r"^https://help\.apple\.com/itc/podcasts_connect/#/itc", # Anchors on certain pages with missing a[name] attributes. - r'^https://tools\.ietf\.org/html/rfc1123\.html#section-', + r"^https://tools\.ietf\.org/html/rfc1123\.html#section-", ] # Spelling check needs an additional module that is not installed by default. # Add it only if spelling check is requested so docs can be generated without it. -if 'spelling' in sys.argv: +if "spelling" in sys.argv: extensions.append("sphinxcontrib.spelling") # Spelling language. -spelling_lang = 'en_US' +spelling_lang = "en_US" # Location of word list. -spelling_word_list_filename = 'spelling_wordlist' +spelling_word_list_filename = "spelling_wordlist" spelling_warning = True @@ -89,17 +89,17 @@ spelling_warning = True # templates_path = [] # The suffix of source filenames. -source_suffix = '.txt' +source_suffix = ".txt" # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. 
-master_doc = 'contents' +master_doc = "contents" # General substitutions. -project = 'Django' -copyright = 'Django Software Foundation and contributors' +project = "Django" +copyright = "Django Software Foundation and contributors" # The version info for the project you're documenting, acts as replacement for @@ -107,31 +107,32 @@ copyright = 'Django Software Foundation and contributors' # built documents. # # The short X.Y version. -version = '4.1' +version = "4.1" # The full version, including alpha/beta/rc tags. try: from django import VERSION, get_version except ImportError: release = version else: + def django_release(): pep440ver = get_version() - if VERSION[3:5] == ('alpha', 0) and 'dev' not in pep440ver: - return pep440ver + '.dev' + if VERSION[3:5] == ("alpha", 0) and "dev" not in pep440ver: + return pep440ver + ".dev" return pep440ver release = django_release() # The "development version" of Django -django_next_version = '4.1' +django_next_version = "4.1" extlinks = { - 'bpo': ('https://bugs.python.org/issue%s', 'bpo-'), - 'commit': ('https://github.com/django/django/commit/%s', ''), - 'cve': ('https://nvd.nist.gov/vuln/detail/CVE-%s', 'CVE-'), + "bpo": ("https://bugs.python.org/issue%s", "bpo-"), + "commit": ("https://github.com/django/django/commit/%s", ""), + "cve": ("https://nvd.nist.gov/vuln/detail/CVE-%s", "CVE-"), # A file or directory. GitHub redirects from blob to tree if needed. - 'source': ('https://github.com/django/django/blob/main/%s', ''), - 'ticket': ('https://code.djangoproject.com/ticket/%s', '#'), + "source": ("https://github.com/django/django/blob/main/%s", ""), + "ticket": ("https://code.djangoproject.com/ticket/%s", "#"), } # The language for content autogenerated by Sphinx. 
Refer to documentation @@ -139,17 +140,17 @@ extlinks = { # language = None # Location for .po/.mo translation files used when language is set -locale_dirs = ['locale/'] +locale_dirs = ["locale/"] # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: # today = '' # Else, today_fmt is used as the format for a strftime call. -today_fmt = '%B %d, %Y' +today_fmt = "%B %d, %Y" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build', '_theme', 'requirements.txt'] +exclude_patterns = ["_build", "_theme", "requirements.txt"] # The reST default role (used for this markup: `text`) to use for all documents. default_role = "default-role-error" @@ -166,21 +167,21 @@ add_module_names = False show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'trac' +pygments_style = "trac" # Links to Python's docs should reference the most recent version of the 3.x # branch, which is located at this URL. intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), - 'psycopg2': ('https://www.psycopg.org/docs/', None), + "python": ("https://docs.python.org/3/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), + "psycopg2": ("https://www.psycopg.org/docs/", None), } # Python's docs don't change every week. intersphinx_cache_limit = 90 # days # The 'versionadded' and 'versionchanged' directives are overridden. -suppress_warnings = ['app.add_directive'] +suppress_warnings = ["app.add_directive"] # -- Options for HTML output --------------------------------------------------- @@ -219,7 +220,7 @@ html_theme_path = ["_theme"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. 
-html_last_updated_fmt = '%b %d, %Y' +html_last_updated_fmt = "%b %d, %Y" # Content template for the index page. # html_index = '' @@ -258,7 +259,7 @@ html_additional_pages = {} # html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Djangodoc' +htmlhelp_basename = "Djangodoc" modindex_common_prefix = ["django."] @@ -273,14 +274,14 @@ rst_epilog = """ # -- Options for LaTeX output -------------------------------------------------- # Use XeLaTeX for Unicode support. -latex_engine = 'xelatex' +latex_engine = "xelatex" latex_use_xindy = False # Set font for CJK and fallbacks for unicode characters. latex_elements = { - 'fontpkg': r""" + "fontpkg": r""" \setmainfont{Symbola} """, - 'preamble': r""" + "preamble": r""" \usepackage{newunicodechar} \usepackage[UTF8]{ctex} \newunicodechar{π}{\ensuremath{\pi}} @@ -295,8 +296,13 @@ latex_elements = { # (source start file, target name, title, author, document class [howto/manual]). # latex_documents = [] latex_documents = [ - ('contents', 'django.tex', 'Django Documentation', - 'Django Software Foundation', 'manual'), + ( + "contents", + "django.tex", + "Django Documentation", + "Django Software Foundation", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of @@ -324,31 +330,41 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
-man_pages = [( - 'ref/django-admin', - 'django-admin', - 'Utility script for the Django web framework', - ['Django Software Foundation'], - 1 -)] +man_pages = [ + ( + "ref/django-admin", + "django-admin", + "Utility script for the Django web framework", + ["Django Software Foundation"], + 1, + ) +] # -- Options for Texinfo output ------------------------------------------------ # List of tuples (startdocname, targetname, title, author, dir_entry, # description, category, toctree_only) -texinfo_documents = [( - master_doc, "django", "", "", "Django", - "Documentation of the Django framework", "Web development", False -)] +texinfo_documents = [ + ( + master_doc, + "django", + "", + "", + "Django", + "Documentation of the Django framework", + "Web development", + False, + ) +] # -- Options for Epub output --------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project -epub_author = 'Django Software Foundation' -epub_publisher = 'Django Software Foundation' +epub_author = "Django Software Foundation" +epub_publisher = "Django Software Foundation" epub_copyright = copyright # The basename for the epub file. It defaults to the project name. @@ -358,7 +374,7 @@ epub_copyright = copyright # for small screen space, using the same theme for HTML and epub output is # usually not wise. This defaults to 'epub', a theme designed to save visual # space. -epub_theme = 'djangodocs-epub' +epub_theme = "djangodocs-epub" # The language of the text. It defaults to the language option # or en if the language is not set. @@ -375,7 +391,7 @@ epub_theme = 'djangodocs-epub' # epub_uid = '' # A tuple containing the cover image and cover page html template filenames. -epub_cover = ('', 'epub-cover.html') +epub_cover = ("", "epub-cover.html") # A sequence of (type, uri, title) tuples for the guide element of content.opf. 
# epub_guide = () diff --git a/scripts/manage_translations.py b/scripts/manage_translations.py index fd394e10ab..5b82011f20 100644 --- a/scripts/manage_translations.py +++ b/scripts/manage_translations.py @@ -26,7 +26,7 @@ import django from django.conf import settings from django.core.management import call_command -HAVE_JS = ['admin'] +HAVE_JS = ["admin"] def _get_locale_dirs(resources, include_core=True): @@ -35,33 +35,35 @@ def _get_locale_dirs(resources, include_core=True): optionally including the django core catalog. If resources list is not None, filter directories matching resources content. """ - contrib_dir = os.path.join(os.getcwd(), 'django', 'contrib') + contrib_dir = os.path.join(os.getcwd(), "django", "contrib") dirs = [] # Collect all locale directories for contrib_name in os.listdir(contrib_dir): - path = os.path.join(contrib_dir, contrib_name, 'locale') + path = os.path.join(contrib_dir, contrib_name, "locale") if os.path.isdir(path): dirs.append((contrib_name, path)) if contrib_name in HAVE_JS: dirs.append(("%s-js" % contrib_name, path)) if include_core: - dirs.insert(0, ('core', os.path.join(os.getcwd(), 'django', 'conf', 'locale'))) + dirs.insert(0, ("core", os.path.join(os.getcwd(), "django", "conf", "locale"))) # Filter by resources, if any if resources is not None: res_names = [d[0] for d in dirs] dirs = [ld for ld in dirs if ld[0] in resources] if len(resources) > len(dirs): - print("You have specified some unknown resources. " - "Available resource names are: %s" % (', '.join(res_names),)) + print( + "You have specified some unknown resources. 
" + "Available resource names are: %s" % (", ".join(res_names),) + ) exit(1) return dirs def _tx_resource_for_name(name): - """ Return the Transifex resource name """ - if name == 'core': + """Return the Transifex resource name""" + if name == "core": return "django.core" else: return "django.contrib-%s" % name @@ -71,10 +73,15 @@ def _check_diff(cat_name, base_path): """ Output the approximate number of changed/added strings in the en catalog. """ - po_path = '%(path)s/en/LC_MESSAGES/django%(ext)s.po' % { - 'path': base_path, 'ext': 'js' if cat_name.endswith('-js') else ''} - p = run("git diff -U0 %s | egrep '^[-+]msgid' | wc -l" % po_path, - capture_output=True, shell=True) + po_path = "%(path)s/en/LC_MESSAGES/django%(ext)s.po" % { + "path": base_path, + "ext": "js" if cat_name.endswith("-js") else "", + } + p = run( + "git diff -U0 %s | egrep '^[-+]msgid' | wc -l" % po_path, + capture_output=True, + shell=True, + ) num_changes = int(p.stdout.strip()) print("%d changed/added messages in '%s' catalog." 
% (num_changes, cat_name)) @@ -90,14 +97,14 @@ def update_catalogs(resources=None, languages=None): print("`update_catalogs` will always process all resources.") contrib_dirs = _get_locale_dirs(None, include_core=False) - os.chdir(os.path.join(os.getcwd(), 'django')) + os.chdir(os.path.join(os.getcwd(), "django")) print("Updating en catalogs for Django and contrib apps...") - call_command('makemessages', locale=['en']) + call_command("makemessages", locale=["en"]) print("Updating en JS catalogs for Django and contrib apps...") - call_command('makemessages', locale=['en'], domain='djangojs') + call_command("makemessages", locale=["en"], domain="djangojs") # Output changed stats - _check_diff('core', os.path.join(os.getcwd(), 'conf', 'locale')) + _check_diff("core", os.path.join(os.getcwd(), "conf", "locale")) for name, dir_ in contrib_dirs: _check_diff(name, dir_) @@ -113,26 +120,26 @@ def lang_stats(resources=None, languages=None): for name, dir_ in locale_dirs: print("\nShowing translations stats for '%s':" % name) - langs = sorted(d for d in os.listdir(dir_) if not d.startswith('_')) + langs = sorted(d for d in os.listdir(dir_) if not d.startswith("_")) for lang in langs: if languages and lang not in languages: continue # TODO: merge first with the latest en catalog - po_path = '{path}/{lang}/LC_MESSAGES/django{ext}.po'.format( - path=dir_, lang=lang, ext='js' if name.endswith('-js') else '' + po_path = "{path}/{lang}/LC_MESSAGES/django{ext}.po".format( + path=dir_, lang=lang, ext="js" if name.endswith("-js") else "" ) p = run( - ['msgfmt', '-vc', '-o', '/dev/null', po_path], + ["msgfmt", "-vc", "-o", "/dev/null", po_path], capture_output=True, - env={'LANG': 'C'}, - encoding='utf-8', + env={"LANG": "C"}, + encoding="utf-8", ) if p.returncode == 0: # msgfmt output stats on stderr - print('%s: %s' % (lang, p.stderr.strip())) + print("%s: %s" % (lang, p.stderr.strip())) else: print( - 'Errors happened when checking %s translation for %s:\n%s' + "Errors happened 
when checking %s translation for %s:\n%s" % (lang, name, p.stderr) ) @@ -147,23 +154,40 @@ def fetch(resources=None, languages=None): for name, dir_ in locale_dirs: # Transifex pull if languages is None: - run(['tx', 'pull', '-r', _tx_resource_for_name(name), '-a', '-f', '--minimum-perc=5']) - target_langs = sorted(d for d in os.listdir(dir_) if not d.startswith('_') and d != 'en') + run( + [ + "tx", + "pull", + "-r", + _tx_resource_for_name(name), + "-a", + "-f", + "--minimum-perc=5", + ] + ) + target_langs = sorted( + d for d in os.listdir(dir_) if not d.startswith("_") and d != "en" + ) else: for lang in languages: - run(['tx', 'pull', '-r', _tx_resource_for_name(name), '-f', '-l', lang]) + run(["tx", "pull", "-r", _tx_resource_for_name(name), "-f", "-l", lang]) target_langs = languages # msgcat to wrap lines and msgfmt for compilation of .mo file for lang in target_langs: - po_path = '%(path)s/%(lang)s/LC_MESSAGES/django%(ext)s.po' % { - 'path': dir_, 'lang': lang, 'ext': 'js' if name.endswith('-js') else ''} + po_path = "%(path)s/%(lang)s/LC_MESSAGES/django%(ext)s.po" % { + "path": dir_, + "lang": lang, + "ext": "js" if name.endswith("-js") else "", + } if not os.path.exists(po_path): - print("No %(lang)s translation for resource %(name)s" % { - 'lang': lang, 'name': name}) + print( + "No %(lang)s translation for resource %(name)s" + % {"lang": lang, "name": name} + ) continue - run(['msgcat', '--no-location', '-o', po_path, po_path]) - msgfmt = run(['msgfmt', '-c', '-o', '%s.mo' % po_path[:-3], po_path]) + run(["msgcat", "--no-location", "-o", po_path, po_path]) + msgfmt = run(["msgfmt", "-c", "-o", "%s.mo" % po_path[:-3], po_path]) if msgfmt.returncode != 0: errors.append((name, lang)) if errors: @@ -174,12 +198,22 @@ def fetch(resources=None, languages=None): if __name__ == "__main__": - RUNABLE_SCRIPTS = ('update_catalogs', 'lang_stats', 'fetch') + RUNABLE_SCRIPTS = ("update_catalogs", "lang_stats", "fetch") parser = ArgumentParser() - 
parser.add_argument('cmd', nargs=1, choices=RUNABLE_SCRIPTS) - parser.add_argument("-r", "--resources", action='append', help="limit operation to the specified resources") - parser.add_argument("-l", "--languages", action='append', help="limit operation to the specified languages") + parser.add_argument("cmd", nargs=1, choices=RUNABLE_SCRIPTS) + parser.add_argument( + "-r", + "--resources", + action="append", + help="limit operation to the specified resources", + ) + parser.add_argument( + "-l", + "--languages", + action="append", + help="limit operation to the specified languages", + ) options = parser.parse_args() eval(options.cmd[0])(options.resources, options.languages) diff --git a/setup.py b/setup.py index 43a5f69365..ef91130d47 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from setuptools import setup # Allow editable install into user site directory. # See https://github.com/pypa/pip/issues/7953. -site.ENABLE_USER_SITE = '--user' in sys.argv[1:] +site.ENABLE_USER_SITE = "--user" in sys.argv[1:] # Warn if we are installing over top of an existing installation. This can # cause issues where files that were deleted from a more recent Django are @@ -32,7 +32,8 @@ setup() if overlay_warning: - sys.stderr.write(""" + sys.stderr.write( + """ ======== WARNING! @@ -49,4 +50,6 @@ should manually remove the directory and re-install Django. -""" % {"existing_path": existing_path}) +""" + % {"existing_path": existing_path} + ) diff --git a/tests/absolute_url_overrides/tests.py b/tests/absolute_url_overrides/tests.py index 6879283083..b9a7e4cd64 100644 --- a/tests/absolute_url_overrides/tests.py +++ b/tests/absolute_url_overrides/tests.py @@ -3,35 +3,39 @@ from django.test import SimpleTestCase from django.test.utils import isolate_apps -@isolate_apps('absolute_url_overrides') +@isolate_apps("absolute_url_overrides") class AbsoluteUrlOverrideTests(SimpleTestCase): - def test_get_absolute_url(self): """ get_absolute_url() functions as a normal method. 
""" - def get_absolute_url(o): - return '/test-a/%s/' % o.pk - TestA = self._create_model_class('TestA', get_absolute_url) - self.assertTrue(hasattr(TestA, 'get_absolute_url')) - obj = TestA(pk=1, name='Foo') - self.assertEqual('/test-a/%s/' % obj.pk, obj.get_absolute_url()) + def get_absolute_url(o): + return "/test-a/%s/" % o.pk + + TestA = self._create_model_class("TestA", get_absolute_url) + + self.assertTrue(hasattr(TestA, "get_absolute_url")) + obj = TestA(pk=1, name="Foo") + self.assertEqual("/test-a/%s/" % obj.pk, obj.get_absolute_url()) def test_override_get_absolute_url(self): """ ABSOLUTE_URL_OVERRIDES should override get_absolute_url(). """ + def get_absolute_url(o): - return '/test-b/%s/' % o.pk + return "/test-b/%s/" % o.pk + with self.settings( ABSOLUTE_URL_OVERRIDES={ - 'absolute_url_overrides.testb': lambda o: '/overridden-test-b/%s/' % o.pk, + "absolute_url_overrides.testb": lambda o: "/overridden-test-b/%s/" + % o.pk, }, ): - TestB = self._create_model_class('TestB', get_absolute_url) - obj = TestB(pk=1, name='Foo') - self.assertEqual('/overridden-test-b/%s/' % obj.pk, obj.get_absolute_url()) + TestB = self._create_model_class("TestB", get_absolute_url) + obj = TestB(pk=1, name="Foo") + self.assertEqual("/overridden-test-b/%s/" % obj.pk, obj.get_absolute_url()) def test_insert_get_absolute_url(self): """ @@ -40,19 +44,19 @@ class AbsoluteUrlOverrideTests(SimpleTestCase): """ with self.settings( ABSOLUTE_URL_OVERRIDES={ - 'absolute_url_overrides.testc': lambda o: '/test-c/%s/' % o.pk, + "absolute_url_overrides.testc": lambda o: "/test-c/%s/" % o.pk, }, ): - TestC = self._create_model_class('TestC') - obj = TestC(pk=1, name='Foo') - self.assertEqual('/test-c/%s/' % obj.pk, obj.get_absolute_url()) + TestC = self._create_model_class("TestC") + obj = TestC(pk=1, name="Foo") + self.assertEqual("/test-c/%s/" % obj.pk, obj.get_absolute_url()) def _create_model_class(self, class_name, get_absolute_url_method=None): attrs = { - 'name': 
models.CharField(max_length=50), - '__module__': 'absolute_url_overrides', + "name": models.CharField(max_length=50), + "__module__": "absolute_url_overrides", } if get_absolute_url_method: - attrs['get_absolute_url'] = get_absolute_url_method + attrs["get_absolute_url"] = get_absolute_url_method return type(class_name, (models.Model,), attrs) diff --git a/tests/admin_autodiscover/tests.py b/tests/admin_autodiscover/tests.py index b15060ade8..66de74f2a5 100644 --- a/tests/admin_autodiscover/tests.py +++ b/tests/admin_autodiscover/tests.py @@ -7,11 +7,12 @@ class AdminAutoDiscoverTests(SimpleTestCase): Test for bug #8245 - don't raise an AlreadyRegistered exception when using autodiscover() and an admin.py module contains an error. """ + def test_double_call_autodiscover(self): # The first time autodiscover is called, we should get our real error. - with self.assertRaisesMessage(Exception, 'Bad admin module'): + with self.assertRaisesMessage(Exception, "Bad admin module"): admin.autodiscover() # Calling autodiscover again should raise the very same error it did # the first time, not an AlreadyRegistered error. 
- with self.assertRaisesMessage(Exception, 'Bad admin module'): + with self.assertRaisesMessage(Exception, "Bad admin module"): admin.autodiscover() diff --git a/tests/admin_changelist/admin.py b/tests/admin_changelist/admin.py index 929539ea88..67187f5b79 100644 --- a/tests/admin_changelist/admin.py +++ b/tests/admin_changelist/admin.py @@ -12,12 +12,14 @@ site.register(User, UserAdmin) class CustomPaginator(Paginator): def __init__(self, queryset, page_size, orphans=0, allow_empty_first_page=True): - super().__init__(queryset, 5, orphans=2, allow_empty_first_page=allow_empty_first_page) + super().__init__( + queryset, 5, orphans=2, allow_empty_first_page=allow_empty_first_page + ) class EventAdmin(admin.ModelAdmin): - date_hierarchy = 'date' - list_display = ['event_date_func'] + date_hierarchy = "date" + list_display = ["event_date_func"] @admin.display def event_date_func(self, event): @@ -31,21 +33,21 @@ site.register(Event, EventAdmin) class ParentAdmin(admin.ModelAdmin): - list_filter = ['child__name'] - search_fields = ['child__name'] - list_select_related = ['child'] + list_filter = ["child__name"] + search_fields = ["child__name"] + list_select_related = ["child"] class ParentAdminTwoSearchFields(admin.ModelAdmin): - list_filter = ['child__name'] - search_fields = ['child__name', 'child__age'] - list_select_related = ['child'] + list_filter = ["child__name"] + search_fields = ["child__name", "child__age"] + list_select_related = ["child"] class ChildAdmin(admin.ModelAdmin): - list_display = ['name', 'parent'] + list_display = ["name", "parent"] list_per_page = 10 - list_filter = ['parent', 'age'] + list_filter = ["parent", "age"] def get_queryset(self, request): return super().get_queryset(request).select_related("parent") @@ -56,32 +58,32 @@ class CustomPaginationAdmin(ChildAdmin): class FilteredChildAdmin(admin.ModelAdmin): - list_display = ['name', 'parent'] + list_display = ["name", "parent"] list_per_page = 10 def get_queryset(self, request): - 
return super().get_queryset(request).filter(name__contains='filtered') + return super().get_queryset(request).filter(name__contains="filtered") class BandAdmin(admin.ModelAdmin): - list_filter = ['genres'] + list_filter = ["genres"] class NrOfMembersFilter(admin.SimpleListFilter): - title = 'number of members' - parameter_name = 'nr_of_members_partition' + title = "number of members" + parameter_name = "nr_of_members_partition" def lookups(self, request, model_admin): return [ - ('5', '0 - 5'), - ('more', 'more than 5'), + ("5", "0 - 5"), + ("more", "more than 5"), ] def queryset(self, request, queryset): value = self.value() - if value == '5': + if value == "5": return queryset.filter(nr_of_members__lte=5) - if value == 'more': + if value == "more": return queryset.filter(nr_of_members__gt=5) @@ -93,44 +95,44 @@ site.register(Band, BandCallableFilterAdmin) class GroupAdmin(admin.ModelAdmin): - list_filter = ['members'] + list_filter = ["members"] class ConcertAdmin(admin.ModelAdmin): - list_filter = ['group__members'] - search_fields = ['group__members__name'] + list_filter = ["group__members"] + search_fields = ["group__members__name"] class QuartetAdmin(admin.ModelAdmin): - list_filter = ['members'] + list_filter = ["members"] class ChordsBandAdmin(admin.ModelAdmin): - list_filter = ['members'] + list_filter = ["members"] class InvitationAdmin(admin.ModelAdmin): - list_display = ('band', 'player') - list_select_related = ('player',) + list_display = ("band", "player") + list_select_related = ("player",) class DynamicListDisplayChildAdmin(admin.ModelAdmin): - list_display = ('parent', 'name', 'age') + list_display = ("parent", "name", "age") def get_list_display(self, request): my_list_display = super().get_list_display(request) - if request.user.username == 'noparents': + if request.user.username == "noparents": my_list_display = list(my_list_display) - my_list_display.remove('parent') + my_list_display.remove("parent") return my_list_display class 
DynamicListDisplayLinksChildAdmin(admin.ModelAdmin): - list_display = ('parent', 'name', 'age') - list_display_links = ['parent', 'name'] + list_display = ("parent", "name", "age") + list_display_links = ["parent", "name"] def get_list_display_links(self, request, list_display): - return ['age'] + return ["age"] site.register(Child, DynamicListDisplayChildAdmin) @@ -138,8 +140,8 @@ site.register(Child, DynamicListDisplayChildAdmin) class NoListDisplayLinksParentAdmin(admin.ModelAdmin): list_display_links = None - list_display = ['name'] - list_editable = ['name'] + list_display = ["name"] + list_editable = ["name"] actions_on_bottom = True @@ -148,8 +150,8 @@ site.register(Parent, NoListDisplayLinksParentAdmin) class SwallowAdmin(admin.ModelAdmin): actions = None # prevent ['action_checkbox'] + list(list_display) - list_display = ('origin', 'load', 'speed', 'swallowonetoone') - list_editable = ['load', 'speed'] + list_display = ("origin", "load", "speed", "swallowonetoone") + list_editable = ["load", "speed"] list_per_page = 3 @@ -157,29 +159,29 @@ site.register(Swallow, SwallowAdmin) class DynamicListFilterChildAdmin(admin.ModelAdmin): - list_filter = ('parent', 'name', 'age') + list_filter = ("parent", "name", "age") def get_list_filter(self, request): my_list_filter = super().get_list_filter(request) - if request.user.username == 'noparents': + if request.user.username == "noparents": my_list_filter = list(my_list_filter) - my_list_filter.remove('parent') + my_list_filter.remove("parent") return my_list_filter class DynamicSearchFieldsChildAdmin(admin.ModelAdmin): - search_fields = ('name',) + search_fields = ("name",) def get_search_fields(self, request): search_fields = super().get_search_fields(request) - search_fields += ('age',) + search_fields += ("age",) return search_fields class EmptyValueChildAdmin(admin.ModelAdmin): - empty_value_display = '-empty-' - list_display = ('name', 'age_display', 'age') + empty_value_display = "-empty-" + list_display = 
("name", "age_display", "age") - @admin.display(empty_value='†') + @admin.display(empty_value="†") def age_display(self, obj): return obj.age diff --git a/tests/admin_changelist/models.py b/tests/admin_changelist/models.py index 81d7fdfb3a..180c38428a 100644 --- a/tests/admin_changelist/models.py +++ b/tests/admin_changelist/models.py @@ -38,7 +38,7 @@ class Musician(models.Model): class Group(models.Model): name = models.CharField(max_length=30) - members = models.ManyToManyField(Musician, through='Membership') + members = models.ManyToManyField(Musician, through="Membership") def __str__(self): return self.name @@ -65,7 +65,7 @@ class ChordsMusician(Musician): class ChordsBand(models.Model): name = models.CharField(max_length=30) - members = models.ManyToManyField(ChordsMusician, through='Invitation') + members = models.ManyToManyField(ChordsMusician, through="Invitation") class Invitation(models.Model): @@ -81,7 +81,7 @@ class Swallow(models.Model): speed = models.FloatField() class Meta: - ordering = ('speed', 'load') + ordering = ("speed", "load") class SwallowOneToOne(models.Model): @@ -93,12 +93,13 @@ class UnorderedObject(models.Model): Model without any defined `Meta.ordering`. Refs #17198. """ + bool = models.BooleanField(default=True) class OrderedObjectManager(models.Manager): def get_queryset(self): - return super().get_queryset().order_by('number') + return super().get_queryset().order_by("number") class OrderedObject(models.Model): @@ -106,9 +107,10 @@ class OrderedObject(models.Model): Model with Manager that defines a default order. Refs #17198. 
""" + name = models.CharField(max_length=255) bool = models.BooleanField(default=True) - number = models.IntegerField(default=0, db_column='number_val') + number = models.IntegerField(default=0, db_column="number_val") objects = OrderedObjectManager() diff --git a/tests/admin_changelist/test_date_hierarchy.py b/tests/admin_changelist/test_date_hierarchy.py index a321650b32..a8c10f7cd8 100644 --- a/tests/admin_changelist/test_date_hierarchy.py +++ b/tests/admin_changelist/test_date_hierarchy.py @@ -5,7 +5,8 @@ from django.contrib.auth.models import User from django.test import RequestFactory, TestCase from django.utils.timezone import make_aware -from .admin import EventAdmin, site as custom_site +from .admin import EventAdmin +from .admin import site as custom_site from .models import Event @@ -14,34 +15,48 @@ class DateHierarchyTests(TestCase): @classmethod def setUpTestData(cls): - cls.superuser = User.objects.create_superuser(username='super', email='a@b.com', password='xxx') + cls.superuser = User.objects.create_superuser( + username="super", email="a@b.com", password="xxx" + ) def assertDateParams(self, query, expected_from_date, expected_to_date): - query = {'date__%s' % field: val for field, val in query.items()} - request = self.factory.get('/', query) + query = {"date__%s" % field: val for field, val in query.items()} + request = self.factory.get("/", query) request.user = self.superuser changelist = EventAdmin(Event, custom_site).get_changelist_instance(request) _, _, lookup_params, *_ = changelist.get_filters(request) - self.assertEqual(lookup_params['date__gte'], expected_from_date) - self.assertEqual(lookup_params['date__lt'], expected_to_date) + self.assertEqual(lookup_params["date__gte"], expected_from_date) + self.assertEqual(lookup_params["date__lt"], expected_to_date) def test_bounded_params(self): tests = ( - ({'year': 2017}, datetime(2017, 1, 1), datetime(2018, 1, 1)), - ({'year': 2017, 'month': 2}, datetime(2017, 2, 1), datetime(2017, 3, 1)), - 
({'year': 2017, 'month': 12}, datetime(2017, 12, 1), datetime(2018, 1, 1)), - ({'year': 2017, 'month': 12, 'day': 15}, datetime(2017, 12, 15), datetime(2017, 12, 16)), - ({'year': 2017, 'month': 12, 'day': 31}, datetime(2017, 12, 31), datetime(2018, 1, 1)), - ({'year': 2017, 'month': 2, 'day': 28}, datetime(2017, 2, 28), datetime(2017, 3, 1)), + ({"year": 2017}, datetime(2017, 1, 1), datetime(2018, 1, 1)), + ({"year": 2017, "month": 2}, datetime(2017, 2, 1), datetime(2017, 3, 1)), + ({"year": 2017, "month": 12}, datetime(2017, 12, 1), datetime(2018, 1, 1)), + ( + {"year": 2017, "month": 12, "day": 15}, + datetime(2017, 12, 15), + datetime(2017, 12, 16), + ), + ( + {"year": 2017, "month": 12, "day": 31}, + datetime(2017, 12, 31), + datetime(2018, 1, 1), + ), + ( + {"year": 2017, "month": 2, "day": 28}, + datetime(2017, 2, 28), + datetime(2017, 3, 1), + ), ) for query, expected_from_date, expected_to_date in tests: with self.subTest(query=query): self.assertDateParams(query, expected_from_date, expected_to_date) def test_bounded_params_with_time_zone(self): - with self.settings(USE_TZ=True, TIME_ZONE='Asia/Jerusalem'): + with self.settings(USE_TZ=True, TIME_ZONE="Asia/Jerusalem"): self.assertDateParams( - {'year': 2017, 'month': 2, 'day': 28}, + {"year": 2017, "month": 2, "day": 28}, make_aware(datetime(2017, 2, 28)), make_aware(datetime(2017, 3, 1)), ) @@ -49,31 +64,33 @@ class DateHierarchyTests(TestCase): def test_bounded_params_with_dst_time_zone(self): tests = [ # Northern hemisphere. - ('Asia/Jerusalem', 3), - ('Asia/Jerusalem', 10), + ("Asia/Jerusalem", 3), + ("Asia/Jerusalem", 10), # Southern hemisphere. 
- ('Pacific/Chatham', 4), - ('Pacific/Chatham', 9), + ("Pacific/Chatham", 4), + ("Pacific/Chatham", 9), ] for time_zone, month in tests: with self.subTest(time_zone=time_zone, month=month): with self.settings(USE_TZ=True, TIME_ZONE=time_zone): self.assertDateParams( - {'year': 2019, 'month': month}, + {"year": 2019, "month": month}, make_aware(datetime(2019, month, 1)), make_aware(datetime(2019, month + 1, 1)), ) def test_invalid_params(self): tests = ( - {'year': 'x'}, - {'year': 2017, 'month': 'x'}, - {'year': 2017, 'month': 12, 'day': 'x'}, - {'year': 2017, 'month': 13}, - {'year': 2017, 'month': 12, 'day': 32}, - {'year': 2017, 'month': 0}, - {'year': 2017, 'month': 12, 'day': 0}, + {"year": "x"}, + {"year": 2017, "month": "x"}, + {"year": 2017, "month": 12, "day": "x"}, + {"year": 2017, "month": 13}, + {"year": 2017, "month": 12, "day": 32}, + {"year": 2017, "month": 0}, + {"year": 2017, "month": 12, "day": 0}, ) for invalid_query in tests: - with self.subTest(query=invalid_query), self.assertRaises(IncorrectLookupParameters): + with self.subTest(query=invalid_query), self.assertRaises( + IncorrectLookupParameters + ): self.assertDateParams(invalid_query, None, None) diff --git a/tests/admin_changelist/tests.py b/tests/admin_changelist/tests.py index 86fcab531a..7d6deced7e 100644 --- a/tests/admin_changelist/tests.py +++ b/tests/admin_changelist/tests.py @@ -6,7 +6,12 @@ from django.contrib.admin.options import IncorrectLookupParameters from django.contrib.admin.templatetags.admin_list import pagination from django.contrib.admin.tests import AdminSeleniumTestCase from django.contrib.admin.views.main import ( - ALL_VAR, IS_POPUP_VAR, ORDER_VAR, PAGE_VAR, SEARCH_VAR, TO_FIELD_VAR, + ALL_VAR, + IS_POPUP_VAR, + ORDER_VAR, + PAGE_VAR, + SEARCH_VAR, + TO_FIELD_VAR, ) from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType @@ -18,36 +23,64 @@ from django.db.models.lookups import Contains, Exact from django.template 
import Context, Template, TemplateSyntaxError from django.test import TestCase, override_settings from django.test.client import RequestFactory -from django.test.utils import ( - CaptureQueriesContext, isolate_apps, register_lookup, -) +from django.test.utils import CaptureQueriesContext, isolate_apps, register_lookup from django.urls import reverse from django.utils import formats from .admin import ( - BandAdmin, ChildAdmin, ChordsBandAdmin, ConcertAdmin, - CustomPaginationAdmin, CustomPaginator, DynamicListDisplayChildAdmin, - DynamicListDisplayLinksChildAdmin, DynamicListFilterChildAdmin, - DynamicSearchFieldsChildAdmin, EmptyValueChildAdmin, EventAdmin, - FilteredChildAdmin, GroupAdmin, InvitationAdmin, - NoListDisplayLinksParentAdmin, ParentAdmin, ParentAdminTwoSearchFields, - QuartetAdmin, SwallowAdmin, site as custom_site, + BandAdmin, + ChildAdmin, + ChordsBandAdmin, + ConcertAdmin, + CustomPaginationAdmin, + CustomPaginator, + DynamicListDisplayChildAdmin, + DynamicListDisplayLinksChildAdmin, + DynamicListFilterChildAdmin, + DynamicSearchFieldsChildAdmin, + EmptyValueChildAdmin, + EventAdmin, + FilteredChildAdmin, + GroupAdmin, + InvitationAdmin, + NoListDisplayLinksParentAdmin, + ParentAdmin, + ParentAdminTwoSearchFields, + QuartetAdmin, + SwallowAdmin, ) +from .admin import site as custom_site from .models import ( - Band, CharPK, Child, ChordsBand, ChordsMusician, Concert, CustomIdUser, - Event, Genre, Group, Invitation, Membership, Musician, OrderedObject, - Parent, Quartet, Swallow, SwallowOneToOne, UnorderedObject, + Band, + CharPK, + Child, + ChordsBand, + ChordsMusician, + Concert, + CustomIdUser, + Event, + Genre, + Group, + Invitation, + Membership, + Musician, + OrderedObject, + Parent, + Quartet, + Swallow, + SwallowOneToOne, + UnorderedObject, ) def build_tbody_html(pk, href, extra_fields): return ( - '<tbody><tr>' + "<tbody><tr>" '<td class="action-checkbox">' '<input type="checkbox" name="_selected_action" value="{}" ' 
'class="action-select"></td>' '<th class="field-name"><a href="{}">name</a></th>' - '{}</tr></tbody>' + "{}</tr></tbody>" ).format(pk, href, extra_fields) @@ -57,10 +90,14 @@ class ChangeListTests(TestCase): @classmethod def setUpTestData(cls): - cls.superuser = User.objects.create_superuser(username='super', email='a@b.com', password='xxx') + cls.superuser = User.objects.create_superuser( + username="super", email="a@b.com", password="xxx" + ) def _create_superuser(self, username): - return User.objects.create_superuser(username=username, email='a@b.com', password='xxx') + return User.objects.create_superuser( + username=username, email="a@b.com", password="xxx" + ) def _mocked_authenticated_request(self, url, user): request = self.factory.get(url) @@ -69,36 +106,36 @@ class ChangeListTests(TestCase): def test_repr(self): m = ChildAdmin(Child, custom_site) - request = self.factory.get('/child/') + request = self.factory.get("/child/") request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertEqual(repr(cl), '<ChangeList: model=Child model_admin=ChildAdmin>') + self.assertEqual(repr(cl), "<ChangeList: model=Child model_admin=ChildAdmin>") def test_specified_ordering_by_f_expression(self): class OrderedByFBandAdmin(admin.ModelAdmin): - list_display = ['name', 'genres', 'nr_of_members'] + list_display = ["name", "genres", "nr_of_members"] ordering = ( - F('nr_of_members').desc(nulls_last=True), - Upper(F('name')).asc(), - F('genres').asc(), + F("nr_of_members").desc(nulls_last=True), + Upper(F("name")).asc(), + F("genres").asc(), ) m = OrderedByFBandAdmin(Band, custom_site) - request = self.factory.get('/band/') + request = self.factory.get("/band/") request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertEqual(cl.get_ordering_field_columns(), {3: 'desc', 2: 'asc'}) + self.assertEqual(cl.get_ordering_field_columns(), {3: "desc", 2: "asc"}) def test_specified_ordering_by_f_expression_without_asc_desc(self): class 
OrderedByFBandAdmin(admin.ModelAdmin): - list_display = ['name', 'genres', 'nr_of_members'] - ordering = (F('nr_of_members'), Upper('name'), F('genres')) + list_display = ["name", "genres", "nr_of_members"] + ordering = (F("nr_of_members"), Upper("name"), F("genres")) m = OrderedByFBandAdmin(Band, custom_site) - request = self.factory.get('/band/') + request = self.factory.get("/band/") request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertEqual(cl.get_ordering_field_columns(), {3: 'asc', 2: 'asc'}) + self.assertEqual(cl.get_ordering_field_columns(), {3: "asc", 2: "asc"}) def test_select_related_preserved(self): """ @@ -106,85 +143,85 @@ class ChangeListTests(TestCase): overwrite a custom select_related provided by ModelAdmin.get_queryset(). """ m = ChildAdmin(Child, custom_site) - request = self.factory.get('/child/') + request = self.factory.get("/child/") request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertEqual(cl.queryset.query.select_related, {'parent': {}}) + self.assertEqual(cl.queryset.query.select_related, {"parent": {}}) def test_select_related_preserved_when_multi_valued_in_search_fields(self): - parent = Parent.objects.create(name='Mary') - Child.objects.create(parent=parent, name='Danielle') - Child.objects.create(parent=parent, name='Daniel') + parent = Parent.objects.create(name="Mary") + Child.objects.create(parent=parent, name="Danielle") + Child.objects.create(parent=parent, name="Daniel") m = ParentAdmin(Parent, custom_site) - request = self.factory.get('/parent/', data={SEARCH_VAR: 'daniel'}) + request = self.factory.get("/parent/", data={SEARCH_VAR: "daniel"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertEqual(cl.queryset.count(), 1) # select_related is preserved. 
- self.assertEqual(cl.queryset.query.select_related, {'child': {}}) + self.assertEqual(cl.queryset.query.select_related, {"child": {}}) def test_select_related_as_tuple(self): ia = InvitationAdmin(Invitation, custom_site) - request = self.factory.get('/invitation/') + request = self.factory.get("/invitation/") request.user = self.superuser cl = ia.get_changelist_instance(request) - self.assertEqual(cl.queryset.query.select_related, {'player': {}}) + self.assertEqual(cl.queryset.query.select_related, {"player": {}}) def test_select_related_as_empty_tuple(self): ia = InvitationAdmin(Invitation, custom_site) ia.list_select_related = () - request = self.factory.get('/invitation/') + request = self.factory.get("/invitation/") request.user = self.superuser cl = ia.get_changelist_instance(request) self.assertIs(cl.queryset.query.select_related, False) def test_get_select_related_custom_method(self): class GetListSelectRelatedAdmin(admin.ModelAdmin): - list_display = ('band', 'player') + list_display = ("band", "player") def get_list_select_related(self, request): - return ('band', 'player') + return ("band", "player") ia = GetListSelectRelatedAdmin(Invitation, custom_site) - request = self.factory.get('/invitation/') + request = self.factory.get("/invitation/") request.user = self.superuser cl = ia.get_changelist_instance(request) - self.assertEqual(cl.queryset.query.select_related, {'player': {}, 'band': {}}) + self.assertEqual(cl.queryset.query.select_related, {"player": {}, "band": {}}) def test_many_search_terms(self): - parent = Parent.objects.create(name='Mary') - Child.objects.create(parent=parent, name='Danielle') - Child.objects.create(parent=parent, name='Daniel') + parent = Parent.objects.create(name="Mary") + Child.objects.create(parent=parent, name="Danielle") + Child.objects.create(parent=parent, name="Daniel") m = ParentAdmin(Parent, custom_site) - request = self.factory.get('/parent/', data={SEARCH_VAR: 'daniel ' * 80}) + request = 
self.factory.get("/parent/", data={SEARCH_VAR: "daniel " * 80}) request.user = self.superuser cl = m.get_changelist_instance(request) with CaptureQueriesContext(connection) as context: object_count = cl.queryset.count() self.assertEqual(object_count, 1) - self.assertEqual(context.captured_queries[0]['sql'].count('JOIN'), 1) + self.assertEqual(context.captured_queries[0]["sql"].count("JOIN"), 1) def test_related_field_multiple_search_terms(self): """ Searches over multi-valued relationships return rows from related models only when all searched fields match that row. """ - parent = Parent.objects.create(name='Mary') - Child.objects.create(parent=parent, name='Danielle', age=18) - Child.objects.create(parent=parent, name='Daniel', age=19) + parent = Parent.objects.create(name="Mary") + Child.objects.create(parent=parent, name="Danielle", age=18) + Child.objects.create(parent=parent, name="Daniel", age=19) m = ParentAdminTwoSearchFields(Parent, custom_site) - request = self.factory.get('/parent/', data={SEARCH_VAR: 'danielle 19'}) + request = self.factory.get("/parent/", data={SEARCH_VAR: "danielle 19"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertEqual(cl.queryset.count(), 0) - request = self.factory.get('/parent/', data={SEARCH_VAR: 'daniel 19'}) + request = self.factory.get("/parent/", data={SEARCH_VAR: "daniel 19"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertEqual(cl.queryset.count(), 1) @@ -194,78 +231,108 @@ class ChangeListTests(TestCase): Regression test for #14982: EMPTY_CHANGELIST_VALUE should be honored for relationship fields """ - new_child = Child.objects.create(name='name', parent=None) - request = self.factory.get('/child/') + new_child = Child.objects.create(name="name", parent=None) + request = self.factory.get("/child/") request.user = self.superuser m = ChildAdmin(Child, custom_site) cl = m.get_changelist_instance(request) cl.formset = None - template = Template('{% load 
admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}') - context = Context({'cl': cl, 'opts': Child._meta}) + template = Template( + "{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}" + ) + context = Context({"cl": cl, "opts": Child._meta}) table_output = template.render(context) - link = reverse('admin:admin_changelist_child_change', args=(new_child.id,)) - row_html = build_tbody_html(new_child.id, link, '<td class="field-parent nowrap">-</td>') - self.assertNotEqual(table_output.find(row_html), -1, 'Failed to find expected row element: %s' % table_output) + link = reverse("admin:admin_changelist_child_change", args=(new_child.id,)) + row_html = build_tbody_html( + new_child.id, link, '<td class="field-parent nowrap">-</td>' + ) + self.assertNotEqual( + table_output.find(row_html), + -1, + "Failed to find expected row element: %s" % table_output, + ) def test_result_list_set_empty_value_display_on_admin_site(self): """ Empty value display can be set on AdminSite. """ - new_child = Child.objects.create(name='name', parent=None) - request = self.factory.get('/child/') + new_child = Child.objects.create(name="name", parent=None) + request = self.factory.get("/child/") request.user = self.superuser # Set a new empty display value on AdminSite. - admin.site.empty_value_display = '???' + admin.site.empty_value_display = "???" 
m = ChildAdmin(Child, admin.site) cl = m.get_changelist_instance(request) cl.formset = None - template = Template('{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}') - context = Context({'cl': cl, 'opts': Child._meta}) + template = Template( + "{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}" + ) + context = Context({"cl": cl, "opts": Child._meta}) table_output = template.render(context) - link = reverse('admin:admin_changelist_child_change', args=(new_child.id,)) - row_html = build_tbody_html(new_child.id, link, '<td class="field-parent nowrap">???</td>') - self.assertNotEqual(table_output.find(row_html), -1, 'Failed to find expected row element: %s' % table_output) + link = reverse("admin:admin_changelist_child_change", args=(new_child.id,)) + row_html = build_tbody_html( + new_child.id, link, '<td class="field-parent nowrap">???</td>' + ) + self.assertNotEqual( + table_output.find(row_html), + -1, + "Failed to find expected row element: %s" % table_output, + ) def test_result_list_set_empty_value_display_in_model_admin(self): """ Empty value display can be set in ModelAdmin or individual fields. 
""" - new_child = Child.objects.create(name='name', parent=None) - request = self.factory.get('/child/') + new_child = Child.objects.create(name="name", parent=None) + request = self.factory.get("/child/") request.user = self.superuser m = EmptyValueChildAdmin(Child, admin.site) cl = m.get_changelist_instance(request) cl.formset = None - template = Template('{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}') - context = Context({'cl': cl, 'opts': Child._meta}) + template = Template( + "{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}" + ) + context = Context({"cl": cl, "opts": Child._meta}) table_output = template.render(context) - link = reverse('admin:admin_changelist_child_change', args=(new_child.id,)) + link = reverse("admin:admin_changelist_child_change", args=(new_child.id,)) row_html = build_tbody_html( new_child.id, link, '<td class="field-age_display">&dagger;</td>' - '<td class="field-age">-empty-</td>' + '<td class="field-age">-empty-</td>', + ) + self.assertNotEqual( + table_output.find(row_html), + -1, + "Failed to find expected row element: %s" % table_output, ) - self.assertNotEqual(table_output.find(row_html), -1, 'Failed to find expected row element: %s' % table_output) def test_result_list_html(self): """ Inclusion tag result_list generates a table when with default ModelAdmin settings. 
""" - new_parent = Parent.objects.create(name='parent') - new_child = Child.objects.create(name='name', parent=new_parent) - request = self.factory.get('/child/') + new_parent = Parent.objects.create(name="parent") + new_child = Child.objects.create(name="name", parent=new_parent) + request = self.factory.get("/child/") request.user = self.superuser m = ChildAdmin(Child, custom_site) cl = m.get_changelist_instance(request) cl.formset = None - template = Template('{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}') - context = Context({'cl': cl, 'opts': Child._meta}) + template = Template( + "{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}" + ) + context = Context({"cl": cl, "opts": Child._meta}) table_output = template.render(context) - link = reverse('admin:admin_changelist_child_change', args=(new_child.id,)) - row_html = build_tbody_html(new_child.id, link, '<td class="field-parent nowrap">%s</td>' % new_parent) - self.assertNotEqual(table_output.find(row_html), -1, 'Failed to find expected row element: %s' % table_output) + link = reverse("admin:admin_changelist_child_change", args=(new_child.id,)) + row_html = build_tbody_html( + new_child.id, link, '<td class="field-parent nowrap">%s</td>' % new_parent + ) + self.assertNotEqual( + table_output.find(row_html), + -1, + "Failed to find expected row element: %s" % table_output, + ) def test_result_list_editable_html(self): """ @@ -276,29 +343,33 @@ class ChangeListTests(TestCase): when list_editable is enabled are rendered in a div outside the table. 
""" - new_parent = Parent.objects.create(name='parent') - new_child = Child.objects.create(name='name', parent=new_parent) - request = self.factory.get('/child/') + new_parent = Parent.objects.create(name="parent") + new_child = Child.objects.create(name="name", parent=new_parent) + request = self.factory.get("/child/") request.user = self.superuser m = ChildAdmin(Child, custom_site) # Test with list_editable fields - m.list_display = ['id', 'name', 'parent'] - m.list_display_links = ['id'] - m.list_editable = ['name'] + m.list_display = ["id", "name", "parent"] + m.list_display_links = ["id"] + m.list_editable = ["name"] cl = m.get_changelist_instance(request) FormSet = m.get_changelist_formset(request) cl.formset = FormSet(queryset=cl.result_list) - template = Template('{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}') - context = Context({'cl': cl, 'opts': Child._meta}) + template = Template( + "{% load admin_list %}{% spaceless %}{% result_list cl %}{% endspaceless %}" + ) + context = Context({"cl": cl, "opts": Child._meta}) table_output = template.render(context) # make sure that hidden fields are in the correct place hiddenfields_div = ( '<div class="hiddenfields">' '<input type="hidden" name="form-0-id" value="%d" id="id_form-0-id">' - '</div>' + "</div>" ) % new_child.id - self.assertInHTML(hiddenfields_div, table_output, msg_prefix='Failed to find hidden fields') + self.assertInHTML( + hiddenfields_div, table_output, msg_prefix="Failed to find hidden fields" + ) # make sure that list editable fields are rendered in divs correctly editable_name_field = ( @@ -315,26 +386,26 @@ class ChangeListTests(TestCase): """ Regression test for #14312: list_editable with pagination """ - new_parent = Parent.objects.create(name='parent') + new_parent = Parent.objects.create(name="parent") for i in range(1, 201): - Child.objects.create(name='name %s' % i, parent=new_parent) - request = self.factory.get('/child/', data={'p': -1}) # Anything 
outside range + Child.objects.create(name="name %s" % i, parent=new_parent) + request = self.factory.get("/child/", data={"p": -1}) # Anything outside range request.user = self.superuser m = ChildAdmin(Child, custom_site) # Test with list_editable fields - m.list_display = ['id', 'name', 'parent'] - m.list_display_links = ['id'] - m.list_editable = ['name'] + m.list_display = ["id", "name", "parent"] + m.list_display_links = ["id"] + m.list_editable = ["name"] with self.assertRaises(IncorrectLookupParameters): m.get_changelist_instance(request) def test_custom_paginator(self): - new_parent = Parent.objects.create(name='parent') + new_parent = Parent.objects.create(name="parent") for i in range(1, 201): - Child.objects.create(name='name %s' % i, parent=new_parent) + Child.objects.create(name="name %s" % i, parent=new_parent) - request = self.factory.get('/child/') + request = self.factory.get("/child/") request.user = self.superuser m = CustomPaginationAdmin(Child, custom_site) @@ -347,14 +418,14 @@ class ChangeListTests(TestCase): Regression test for #13902: When using a ManyToMany in list_filter, results shouldn't appear more than once. Basic ManyToMany. """ - blues = Genre.objects.create(name='Blues') - band = Band.objects.create(name='B.B. King Review', nr_of_members=11) + blues = Genre.objects.create(name="Blues") + band = Band.objects.create(name="B.B. King Review", nr_of_members=11) band.genres.add(blues) band.genres.add(blues) m = BandAdmin(Band, custom_site) - request = self.factory.get('/band/', data={'genres': blues.pk}) + request = self.factory.get("/band/", data={"genres": blues.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -372,13 +443,13 @@ class ChangeListTests(TestCase): Regression test for #13902: When using a ManyToMany in list_filter, results shouldn't appear more than once. With an intermediate model. 
""" - lead = Musician.objects.create(name='Vox') - band = Group.objects.create(name='The Hype') - Membership.objects.create(group=band, music=lead, role='lead voice') - Membership.objects.create(group=band, music=lead, role='bass player') + lead = Musician.objects.create(name="Vox") + band = Group.objects.create(name="The Hype") + Membership.objects.create(group=band, music=lead, role="lead voice") + Membership.objects.create(group=band, music=lead, role="bass player") m = GroupAdmin(Group, custom_site) - request = self.factory.get('/group/', data={'members': lead.pk}) + request = self.factory.get("/group/", data={"members": lead.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -396,14 +467,14 @@ class ChangeListTests(TestCase): When using a ManyToMany in list_filter at the second level behind a ForeignKey, results shouldn't appear more than once. """ - lead = Musician.objects.create(name='Vox') - band = Group.objects.create(name='The Hype') - Concert.objects.create(name='Woodstock', group=band) - Membership.objects.create(group=band, music=lead, role='lead voice') - Membership.objects.create(group=band, music=lead, role='bass player') + lead = Musician.objects.create(name="Vox") + band = Group.objects.create(name="The Hype") + Concert.objects.create(name="Woodstock", group=band) + Membership.objects.create(group=band, music=lead, role="lead voice") + Membership.objects.create(group=band, music=lead, role="bass player") m = ConcertAdmin(Concert, custom_site) - request = self.factory.get('/concert/', data={'group__members': lead.pk}) + request = self.factory.get("/concert/", data={"group__members": lead.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -422,13 +493,13 @@ class ChangeListTests(TestCase): results shouldn't appear more than once. Model managed in the admin inherits from the one that defines the relationship. 
""" - lead = Musician.objects.create(name='John') - four = Quartet.objects.create(name='The Beatles') - Membership.objects.create(group=four, music=lead, role='lead voice') - Membership.objects.create(group=four, music=lead, role='guitar player') + lead = Musician.objects.create(name="John") + four = Quartet.objects.create(name="The Beatles") + Membership.objects.create(group=four, music=lead, role="lead voice") + Membership.objects.create(group=four, music=lead, role="guitar player") m = QuartetAdmin(Quartet, custom_site) - request = self.factory.get('/quartet/', data={'members': lead.pk}) + request = self.factory.get("/quartet/", data={"members": lead.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -447,13 +518,13 @@ class ChangeListTests(TestCase): results shouldn't appear more than once. Target of the relationship inherits from another. """ - lead = ChordsMusician.objects.create(name='Player A') - three = ChordsBand.objects.create(name='The Chords Trio') - Invitation.objects.create(band=three, player=lead, instrument='guitar') - Invitation.objects.create(band=three, player=lead, instrument='bass') + lead = ChordsMusician.objects.create(name="Player A") + three = ChordsBand.objects.create(name="The Chords Trio") + Invitation.objects.create(band=three, player=lead, instrument="guitar") + Invitation.objects.create(band=three, player=lead, instrument="bass") m = ChordsBandAdmin(ChordsBand, custom_site) - request = self.factory.get('/chordsband/', data={'members': lead.pk}) + request = self.factory.get("/chordsband/", data={"members": lead.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -471,13 +542,13 @@ class ChangeListTests(TestCase): Regressions tests for #15819: If a field listed in list_filters is a non-unique related object, results shouldn't appear more than once. 
""" - parent = Parent.objects.create(name='Mary') + parent = Parent.objects.create(name="Mary") # Two children with the same name - Child.objects.create(parent=parent, name='Daniel') - Child.objects.create(parent=parent, name='Daniel') + Child.objects.create(parent=parent, name="Daniel") + Child.objects.create(parent=parent, name="Daniel") m = ParentAdmin(Parent, custom_site) - request = self.factory.get('/parent/', data={'child__name': 'Daniel'}) + request = self.factory.get("/parent/", data={"child__name": "Daniel"}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -491,12 +562,12 @@ class ChangeListTests(TestCase): def test_changelist_search_form_validation(self): m = ConcertAdmin(Concert, custom_site) tests = [ - ({SEARCH_VAR: '\x00'}, 'Null characters are not allowed.'), - ({SEARCH_VAR: 'some\x00thing'}, 'Null characters are not allowed.'), + ({SEARCH_VAR: "\x00"}, "Null characters are not allowed."), + ({SEARCH_VAR: "some\x00thing"}, "Null characters are not allowed."), ] for case, error in tests: with self.subTest(case=case): - request = self.factory.get('/concert/', case) + request = self.factory.get("/concert/", case) request.user = self.superuser request._messages = CookieStorage(request) m.get_changelist_instance(request) @@ -509,12 +580,12 @@ class ChangeListTests(TestCase): Regressions tests for #15819: If a field listed in search_fields is a non-unique related object, Exists() must be applied. 
""" - parent = Parent.objects.create(name='Mary') - Child.objects.create(parent=parent, name='Danielle') - Child.objects.create(parent=parent, name='Daniel') + parent = Parent.objects.create(name="Mary") + Child.objects.create(parent=parent, name="Danielle") + Child.objects.create(parent=parent, name="Daniel") m = ParentAdmin(Parent, custom_site) - request = self.factory.get('/parent/', data={SEARCH_VAR: 'daniel'}) + request = self.factory.get("/parent/", data={SEARCH_VAR: "daniel"}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -531,14 +602,14 @@ class ChangeListTests(TestCase): ForeignKey, Exists() must be applied and results shouldn't appear more than once. """ - lead = Musician.objects.create(name='Vox') - band = Group.objects.create(name='The Hype') - Concert.objects.create(name='Woodstock', group=band) - Membership.objects.create(group=band, music=lead, role='lead voice') - Membership.objects.create(group=band, music=lead, role='bass player') + lead = Musician.objects.create(name="Vox") + band = Group.objects.create(name="The Hype") + Concert.objects.create(name="Woodstock", group=band) + Membership.objects.create(group=band, music=lead, role="lead voice") + Membership.objects.create(group=band, music=lead, role="bass player") m = ConcertAdmin(Concert, custom_site) - request = self.factory.get('/concert/', data={SEARCH_VAR: 'vox'}) + request = self.factory.get("/concert/", data={SEARCH_VAR: "vox"}) request.user = self.superuser cl = m.get_changelist_instance(request) @@ -554,139 +625,143 @@ class ChangeListTests(TestCase): All rows containing each of the searched words are returned, where each word must be in one of search_fields. 
""" - band_duo = Group.objects.create(name='Duo') - band_hype = Group.objects.create(name='The Hype') - mary = Musician.objects.create(name='Mary Halvorson') - jonathan = Musician.objects.create(name='Jonathan Finlayson') + band_duo = Group.objects.create(name="Duo") + band_hype = Group.objects.create(name="The Hype") + mary = Musician.objects.create(name="Mary Halvorson") + jonathan = Musician.objects.create(name="Jonathan Finlayson") band_duo.members.set([mary, jonathan]) - Concert.objects.create(name='Tiny desk concert', group=band_duo) - Concert.objects.create(name='Woodstock concert', group=band_hype) + Concert.objects.create(name="Tiny desk concert", group=band_duo) + Concert.objects.create(name="Woodstock concert", group=band_hype) # FK lookup. concert_model_admin = ConcertAdmin(Concert, custom_site) - concert_model_admin.search_fields = ['group__name', 'name'] + concert_model_admin.search_fields = ["group__name", "name"] # Reverse FK lookup. group_model_admin = GroupAdmin(Group, custom_site) - group_model_admin.search_fields = ['name', 'concert__name', 'members__name'] + group_model_admin.search_fields = ["name", "concert__name", "members__name"] for search_string, result_count in ( - ('Duo Concert', 1), - ('Tiny Desk Concert', 1), - ('Concert', 2), - ('Other Concert', 0), - ('Duo Woodstock', 0), + ("Duo Concert", 1), + ("Tiny Desk Concert", 1), + ("Concert", 2), + ("Other Concert", 0), + ("Duo Woodstock", 0), ): with self.subTest(search_string=search_string): # FK lookup. - request = self.factory.get('/concert/', data={SEARCH_VAR: search_string}) + request = self.factory.get( + "/concert/", data={SEARCH_VAR: search_string} + ) request.user = self.superuser - concert_changelist = concert_model_admin.get_changelist_instance(request) + concert_changelist = concert_model_admin.get_changelist_instance( + request + ) self.assertEqual(concert_changelist.queryset.count(), result_count) # Reverse FK lookup. 
- request = self.factory.get('/group/', data={SEARCH_VAR: search_string}) + request = self.factory.get("/group/", data={SEARCH_VAR: search_string}) request.user = self.superuser group_changelist = group_model_admin.get_changelist_instance(request) self.assertEqual(group_changelist.queryset.count(), result_count) # Many-to-many lookup. for search_string, result_count in ( - ('Finlayson Duo Tiny', 1), - ('Finlayson', 1), - ('Finlayson Hype', 0), - ('Jonathan Finlayson Duo', 1), - ('Mary Jonathan Duo', 0), - ('Oscar Finlayson Duo', 0), + ("Finlayson Duo Tiny", 1), + ("Finlayson", 1), + ("Finlayson Hype", 0), + ("Jonathan Finlayson Duo", 1), + ("Mary Jonathan Duo", 0), + ("Oscar Finlayson Duo", 0), ): with self.subTest(search_string=search_string): - request = self.factory.get('/group/', data={SEARCH_VAR: search_string}) + request = self.factory.get("/group/", data={SEARCH_VAR: search_string}) request.user = self.superuser group_changelist = group_model_admin.get_changelist_instance(request) self.assertEqual(group_changelist.queryset.count(), result_count) def test_pk_in_search_fields(self): - band = Group.objects.create(name='The Hype') - Concert.objects.create(name='Woodstock', group=band) + band = Group.objects.create(name="The Hype") + Concert.objects.create(name="Woodstock", group=band) m = ConcertAdmin(Concert, custom_site) - m.search_fields = ['group__pk'] + m.search_fields = ["group__pk"] - request = self.factory.get('/concert/', data={SEARCH_VAR: band.pk}) + request = self.factory.get("/concert/", data={SEARCH_VAR: band.pk}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertEqual(cl.queryset.count(), 1) - request = self.factory.get('/concert/', data={SEARCH_VAR: band.pk + 5}) + request = self.factory.get("/concert/", data={SEARCH_VAR: band.pk + 5}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertEqual(cl.queryset.count(), 0) def test_builtin_lookup_in_search_fields(self): - band = 
Group.objects.create(name='The Hype') - concert = Concert.objects.create(name='Woodstock', group=band) + band = Group.objects.create(name="The Hype") + concert = Concert.objects.create(name="Woodstock", group=band) m = ConcertAdmin(Concert, custom_site) - m.search_fields = ['name__iexact'] + m.search_fields = ["name__iexact"] - request = self.factory.get('/', data={SEARCH_VAR: 'woodstock'}) + request = self.factory.get("/", data={SEARCH_VAR: "woodstock"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, [concert]) - request = self.factory.get('/', data={SEARCH_VAR: 'wood'}) + request = self.factory.get("/", data={SEARCH_VAR: "wood"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, []) def test_custom_lookup_in_search_fields(self): - band = Group.objects.create(name='The Hype') - concert = Concert.objects.create(name='Woodstock', group=band) + band = Group.objects.create(name="The Hype") + concert = Concert.objects.create(name="Woodstock", group=band) m = ConcertAdmin(Concert, custom_site) - m.search_fields = ['group__name__cc'] - with register_lookup(Field, Contains, lookup_name='cc'): - request = self.factory.get('/', data={SEARCH_VAR: 'Hype'}) + m.search_fields = ["group__name__cc"] + with register_lookup(Field, Contains, lookup_name="cc"): + request = self.factory.get("/", data={SEARCH_VAR: "Hype"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, [concert]) - request = self.factory.get('/', data={SEARCH_VAR: 'Woodstock'}) + request = self.factory.get("/", data={SEARCH_VAR: "Woodstock"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, []) def test_spanning_relations_with_custom_lookup_in_search_fields(self): - hype = Group.objects.create(name='The Hype') - concert = Concert.objects.create(name='Woodstock', group=hype) - vox = 
Musician.objects.create(name='Vox', age=20) + hype = Group.objects.create(name="The Hype") + concert = Concert.objects.create(name="Woodstock", group=hype) + vox = Musician.objects.create(name="Vox", age=20) Membership.objects.create(music=vox, group=hype) # Register a custom lookup on IntegerField to ensure that field # traversing logic in ModelAdmin.get_search_results() works. - with register_lookup(IntegerField, Exact, lookup_name='exactly'): + with register_lookup(IntegerField, Exact, lookup_name="exactly"): m = ConcertAdmin(Concert, custom_site) - m.search_fields = ['group__members__age__exactly'] + m.search_fields = ["group__members__age__exactly"] - request = self.factory.get('/', data={SEARCH_VAR: '20'}) + request = self.factory.get("/", data={SEARCH_VAR: "20"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, [concert]) - request = self.factory.get('/', data={SEARCH_VAR: '21'}) + request = self.factory.get("/", data={SEARCH_VAR: "21"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, []) def test_custom_lookup_with_pk_shortcut(self): - self.assertEqual(CharPK._meta.pk.name, 'char_pk') # Not equal to 'pk'. + self.assertEqual(CharPK._meta.pk.name, "char_pk") # Not equal to 'pk'. 
m = admin.ModelAdmin(CustomIdUser, custom_site) - abc = CharPK.objects.create(char_pk='abc') - abcd = CharPK.objects.create(char_pk='abcd') + abc = CharPK.objects.create(char_pk="abc") + abcd = CharPK.objects.create(char_pk="abcd") m = admin.ModelAdmin(CharPK, custom_site) - m.search_fields = ['pk__exact'] + m.search_fields = ["pk__exact"] - request = self.factory.get('/', data={SEARCH_VAR: 'abc'}) + request = self.factory.get("/", data={SEARCH_VAR: "abc"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, [abc]) - request = self.factory.get('/', data={SEARCH_VAR: 'abcd'}) + request = self.factory.get("/", data={SEARCH_VAR: "abcd"}) request.user = self.superuser cl = m.get_changelist_instance(request) self.assertCountEqual(cl.queryset, [abcd]) @@ -697,29 +772,29 @@ class ChangeListTests(TestCase): the changelist's query shouldn't have Exists(). """ m = BandAdmin(Band, custom_site) - for lookup_params in ({}, {'name': 'test'}): - request = self.factory.get('/band/', lookup_params) + for lookup_params in ({}, {"name": "test"}): + request = self.factory.get("/band/", lookup_params) request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertNotIn(' EXISTS', str(cl.queryset.query)) + self.assertNotIn(" EXISTS", str(cl.queryset.query)) # A ManyToManyField in params does have Exists() applied. - request = self.factory.get('/band/', {'genres': '0'}) + request = self.factory.get("/band/", {"genres": "0"}) request.user = self.superuser cl = m.get_changelist_instance(request) - self.assertIn(' EXISTS', str(cl.queryset.query)) + self.assertIn(" EXISTS", str(cl.queryset.query)) def test_pagination(self): """ Regression tests for #12893: Pagination in admins changelist doesn't use queryset set by modeladmin. 
""" - parent = Parent.objects.create(name='anything') + parent = Parent.objects.create(name="anything") for i in range(1, 31): - Child.objects.create(name='name %s' % i, parent=parent) - Child.objects.create(name='filtered %s' % i, parent=parent) + Child.objects.create(name="name %s" % i, parent=parent) + Child.objects.create(name="filtered %s" % i, parent=parent) - request = self.factory.get('/child/') + request = self.factory.get("/child/") request.user = self.superuser # Test default queryset @@ -743,7 +818,7 @@ class ChangeListTests(TestCase): """ self.client.force_login(self.superuser) event = Event.objects.create(date=datetime.date.today()) - response = self.client.get(reverse('admin:admin_changelist_event_changelist')) + response = self.client.get(reverse("admin:admin_changelist_event_changelist")) self.assertContains(response, formats.localize(event.date)) self.assertNotContains(response, str(event.date)) @@ -751,52 +826,52 @@ class ChangeListTests(TestCase): """ Regression tests for #14206: dynamic list_display support. 
""" - parent = Parent.objects.create(name='parent') + parent = Parent.objects.create(name="parent") for i in range(10): - Child.objects.create(name='child %s' % i, parent=parent) + Child.objects.create(name="child %s" % i, parent=parent) - user_noparents = self._create_superuser('noparents') - user_parents = self._create_superuser('parents') + user_noparents = self._create_superuser("noparents") + user_parents = self._create_superuser("parents") # Test with user 'noparents' m = custom_site._registry[Child] - request = self._mocked_authenticated_request('/child/', user_noparents) + request = self._mocked_authenticated_request("/child/", user_noparents) response = m.changelist_view(request) - self.assertNotContains(response, 'Parent object') + self.assertNotContains(response, "Parent object") list_display = m.get_list_display(request) list_display_links = m.get_list_display_links(request, list_display) - self.assertEqual(list_display, ['name', 'age']) - self.assertEqual(list_display_links, ['name']) + self.assertEqual(list_display, ["name", "age"]) + self.assertEqual(list_display_links, ["name"]) # Test with user 'parents' m = DynamicListDisplayChildAdmin(Child, custom_site) - request = self._mocked_authenticated_request('/child/', user_parents) + request = self._mocked_authenticated_request("/child/", user_parents) response = m.changelist_view(request) - self.assertContains(response, 'Parent object') + self.assertContains(response, "Parent object") custom_site.unregister(Child) list_display = m.get_list_display(request) list_display_links = m.get_list_display_links(request, list_display) - self.assertEqual(list_display, ('parent', 'name', 'age')) - self.assertEqual(list_display_links, ['parent']) + self.assertEqual(list_display, ("parent", "name", "age")) + self.assertEqual(list_display_links, ["parent"]) # Test default implementation custom_site.register(Child, ChildAdmin) m = custom_site._registry[Child] - request = self._mocked_authenticated_request('/child/', 
user_noparents) + request = self._mocked_authenticated_request("/child/", user_noparents) response = m.changelist_view(request) - self.assertContains(response, 'Parent object') + self.assertContains(response, "Parent object") def test_show_all(self): - parent = Parent.objects.create(name='anything') + parent = Parent.objects.create(name="anything") for i in range(1, 31): - Child.objects.create(name='name %s' % i, parent=parent) - Child.objects.create(name='filtered %s' % i, parent=parent) + Child.objects.create(name="name %s" % i, parent=parent) + Child.objects.create(name="filtered %s" % i, parent=parent) # Add "show all" parameter to request - request = self.factory.get('/child/', data={ALL_VAR: ''}) + request = self.factory.get("/child/", data={ALL_VAR: ""}) request.user = self.superuser # Test valid "show all" request (number of total objects is under max) @@ -820,52 +895,52 @@ class ChangeListTests(TestCase): """ Regression tests for #16257: dynamic list_display_links support. """ - parent = Parent.objects.create(name='parent') + parent = Parent.objects.create(name="parent") for i in range(1, 10): - Child.objects.create(id=i, name='child %s' % i, parent=parent, age=i) + Child.objects.create(id=i, name="child %s" % i, parent=parent, age=i) m = DynamicListDisplayLinksChildAdmin(Child, custom_site) - superuser = self._create_superuser('superuser') - request = self._mocked_authenticated_request('/child/', superuser) + superuser = self._create_superuser("superuser") + request = self._mocked_authenticated_request("/child/", superuser) response = m.changelist_view(request) for i in range(1, 10): - link = reverse('admin:admin_changelist_child_change', args=(i,)) + link = reverse("admin:admin_changelist_child_change", args=(i,)) self.assertContains(response, '<a href="%s">%s</a>' % (link, i)) list_display = m.get_list_display(request) list_display_links = m.get_list_display_links(request, list_display) - self.assertEqual(list_display, ('parent', 'name', 'age')) - 
self.assertEqual(list_display_links, ['age']) + self.assertEqual(list_display, ("parent", "name", "age")) + self.assertEqual(list_display_links, ["age"]) def test_no_list_display_links(self): """#15185 -- Allow no links from the 'change list' view grid.""" - p = Parent.objects.create(name='parent') + p = Parent.objects.create(name="parent") m = NoListDisplayLinksParentAdmin(Parent, custom_site) - superuser = self._create_superuser('superuser') - request = self._mocked_authenticated_request('/parent/', superuser) + superuser = self._create_superuser("superuser") + request = self._mocked_authenticated_request("/parent/", superuser) response = m.changelist_view(request) - link = reverse('admin:admin_changelist_parent_change', args=(p.pk,)) + link = reverse("admin:admin_changelist_parent_change", args=(p.pk,)) self.assertNotContains(response, '<a href="%s">' % link) def test_clear_all_filters_link(self): self.client.force_login(self.superuser) - url = reverse('admin:auth_user_changelist') + url = reverse("admin:auth_user_changelist") response = self.client.get(url) - self.assertNotContains(response, '✖ Clear all filters') + self.assertNotContains(response, "✖ Clear all filters") link = '<a href="%s">✖ Clear all filters</a>' for data, href in ( - ({'is_staff__exact': '0'}, '?'), + ({"is_staff__exact": "0"}, "?"), ( - {'is_staff__exact': '0', 'username__startswith': 'test'}, - '?username__startswith=test', + {"is_staff__exact": "0", "username__startswith": "test"}, + "?username__startswith=test", ), ( - {'is_staff__exact': '0', SEARCH_VAR: 'test'}, - '?%s=test' % SEARCH_VAR, + {"is_staff__exact": "0", SEARCH_VAR: "test"}, + "?%s=test" % SEARCH_VAR, ), ( - {'is_staff__exact': '0', IS_POPUP_VAR: 'id'}, - '?%s=id' % IS_POPUP_VAR, + {"is_staff__exact": "0", IS_POPUP_VAR: "id"}, + "?%s=id" % IS_POPUP_VAR, ), ): with self.subTest(data=data): @@ -874,19 +949,19 @@ class ChangeListTests(TestCase): def test_clear_all_filters_link_callable_filter(self): 
self.client.force_login(self.superuser) - url = reverse('admin:admin_changelist_band_changelist') + url = reverse("admin:admin_changelist_band_changelist") response = self.client.get(url) - self.assertNotContains(response, '✖ Clear all filters') + self.assertNotContains(response, "✖ Clear all filters") link = '<a href="%s">✖ Clear all filters</a>' for data, href in ( - ({'nr_of_members_partition': '5'}, '?'), + ({"nr_of_members_partition": "5"}, "?"), ( - {'nr_of_members_partition': 'more', 'name__startswith': 'test'}, - '?name__startswith=test', + {"nr_of_members_partition": "more", "name__startswith": "test"}, + "?name__startswith=test", ), ( - {'nr_of_members_partition': '5', IS_POPUP_VAR: 'id'}, - '?%s=id' % IS_POPUP_VAR, + {"nr_of_members_partition": "5", IS_POPUP_VAR: "id"}, + "?%s=id" % IS_POPUP_VAR, ), ): with self.subTest(data=data): @@ -895,28 +970,28 @@ class ChangeListTests(TestCase): def test_no_clear_all_filters_link(self): self.client.force_login(self.superuser) - url = reverse('admin:auth_user_changelist') - link = '>✖ Clear all filters</a>' + url = reverse("admin:auth_user_changelist") + link = ">✖ Clear all filters</a>" for data in ( - {SEARCH_VAR: 'test'}, - {ORDER_VAR: '-1'}, - {TO_FIELD_VAR: 'id'}, - {PAGE_VAR: '1'}, - {IS_POPUP_VAR: '1'}, - {'username__startswith': 'test'}, + {SEARCH_VAR: "test"}, + {ORDER_VAR: "-1"}, + {TO_FIELD_VAR: "id"}, + {PAGE_VAR: "1"}, + {IS_POPUP_VAR: "1"}, + {"username__startswith": "test"}, ): with self.subTest(data=data): response = self.client.get(url, data=data) self.assertNotContains(response, link) def test_tuple_list_display(self): - swallow = Swallow.objects.create(origin='Africa', load='12.34', speed='22.2') - swallow2 = Swallow.objects.create(origin='Africa', load='12.34', speed='22.2') + swallow = Swallow.objects.create(origin="Africa", load="12.34", speed="22.2") + swallow2 = Swallow.objects.create(origin="Africa", load="12.34", speed="22.2") swallow_o2o = SwallowOneToOne.objects.create(swallow=swallow2) 
model_admin = SwallowAdmin(Swallow, custom_site) - superuser = self._create_superuser('superuser') - request = self._mocked_authenticated_request('/swallow/', superuser) + superuser = self._create_superuser("superuser") + request = self._mocked_authenticated_request("/swallow/", superuser) response = model_admin.changelist_view(request) # just want to ensure it doesn't blow up during rendering self.assertContains(response, str(swallow.origin)) @@ -924,7 +999,9 @@ class ChangeListTests(TestCase): self.assertContains(response, str(swallow.speed)) # Reverse one-to-one relations should work. self.assertContains(response, '<td class="field-swallowonetoone">-</td>') - self.assertContains(response, '<td class="field-swallowonetoone">%s</td>' % swallow_o2o) + self.assertContains( + response, '<td class="field-swallowonetoone">%s</td>' % swallow_o2o + ) def test_multiuser_edit(self): """ @@ -945,150 +1022,152 @@ class ChangeListTests(TestCase): # Setup the test to reflect the DB state after step 2 where User2 has # edited the first swallow object's speed from '4' to '1'. - a = Swallow.objects.create(origin='Swallow A', load=4, speed=1) - b = Swallow.objects.create(origin='Swallow B', load=2, speed=2) - c = Swallow.objects.create(origin='Swallow C', load=5, speed=5) - d = Swallow.objects.create(origin='Swallow D', load=9, speed=9) + a = Swallow.objects.create(origin="Swallow A", load=4, speed=1) + b = Swallow.objects.create(origin="Swallow B", load=2, speed=2) + c = Swallow.objects.create(origin="Swallow C", load=5, speed=5) + d = Swallow.objects.create(origin="Swallow D", load=9, speed=9) - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") self.client.force_login(superuser) - changelist_url = reverse('admin:admin_changelist_swallow_changelist') + changelist_url = reverse("admin:admin_changelist_swallow_changelist") # Send the POST from User1 for step 3. 
It's still using the changelist # ordering from before User2's edits in step 2. data = { - 'form-TOTAL_FORMS': '3', - 'form-INITIAL_FORMS': '3', - 'form-MIN_NUM_FORMS': '0', - 'form-MAX_NUM_FORMS': '1000', - 'form-0-uuid': str(d.pk), - 'form-1-uuid': str(c.pk), - 'form-2-uuid': str(a.pk), - 'form-0-load': '9.0', - 'form-0-speed': '9.0', - 'form-1-load': '5.0', - 'form-1-speed': '5.0', - 'form-2-load': '5.0', - 'form-2-speed': '4.0', - '_save': 'Save', + "form-TOTAL_FORMS": "3", + "form-INITIAL_FORMS": "3", + "form-MIN_NUM_FORMS": "0", + "form-MAX_NUM_FORMS": "1000", + "form-0-uuid": str(d.pk), + "form-1-uuid": str(c.pk), + "form-2-uuid": str(a.pk), + "form-0-load": "9.0", + "form-0-speed": "9.0", + "form-1-load": "5.0", + "form-1-speed": "5.0", + "form-2-load": "5.0", + "form-2-speed": "4.0", + "_save": "Save", } - response = self.client.post(changelist_url, data, follow=True, extra={'o': '-2'}) + response = self.client.post( + changelist_url, data, follow=True, extra={"o": "-2"} + ) # The object User1 edited in step 3 is displayed on the changelist and # has the correct edits applied. 
- self.assertContains(response, '1 swallow was changed successfully.') + self.assertContains(response, "1 swallow was changed successfully.") self.assertContains(response, a.origin) a.refresh_from_db() - self.assertEqual(a.load, float(data['form-2-load'])) - self.assertEqual(a.speed, float(data['form-2-speed'])) + self.assertEqual(a.load, float(data["form-2-load"])) + self.assertEqual(a.speed, float(data["form-2-speed"])) b.refresh_from_db() self.assertEqual(b.load, 2) self.assertEqual(b.speed, 2) c.refresh_from_db() - self.assertEqual(c.load, float(data['form-1-load'])) - self.assertEqual(c.speed, float(data['form-1-speed'])) + self.assertEqual(c.load, float(data["form-1-load"])) + self.assertEqual(c.speed, float(data["form-1-speed"])) d.refresh_from_db() - self.assertEqual(d.load, float(data['form-0-load'])) - self.assertEqual(d.speed, float(data['form-0-speed'])) + self.assertEqual(d.load, float(data["form-0-load"])) + self.assertEqual(d.speed, float(data["form-0-speed"])) # No new swallows were created. 
self.assertEqual(len(Swallow.objects.all()), 4) def test_get_edited_object_ids(self): - a = Swallow.objects.create(origin='Swallow A', load=4, speed=1) - b = Swallow.objects.create(origin='Swallow B', load=2, speed=2) - c = Swallow.objects.create(origin='Swallow C', load=5, speed=5) - superuser = self._create_superuser('superuser') + a = Swallow.objects.create(origin="Swallow A", load=4, speed=1) + b = Swallow.objects.create(origin="Swallow B", load=2, speed=2) + c = Swallow.objects.create(origin="Swallow C", load=5, speed=5) + superuser = self._create_superuser("superuser") self.client.force_login(superuser) - changelist_url = reverse('admin:admin_changelist_swallow_changelist') + changelist_url = reverse("admin:admin_changelist_swallow_changelist") m = SwallowAdmin(Swallow, custom_site) data = { - 'form-TOTAL_FORMS': '3', - 'form-INITIAL_FORMS': '3', - 'form-MIN_NUM_FORMS': '0', - 'form-MAX_NUM_FORMS': '1000', - 'form-0-uuid': str(a.pk), - 'form-1-uuid': str(b.pk), - 'form-2-uuid': str(c.pk), - 'form-0-load': '9.0', - 'form-0-speed': '9.0', - 'form-1-load': '5.0', - 'form-1-speed': '5.0', - 'form-2-load': '5.0', - 'form-2-speed': '4.0', - '_save': 'Save', + "form-TOTAL_FORMS": "3", + "form-INITIAL_FORMS": "3", + "form-MIN_NUM_FORMS": "0", + "form-MAX_NUM_FORMS": "1000", + "form-0-uuid": str(a.pk), + "form-1-uuid": str(b.pk), + "form-2-uuid": str(c.pk), + "form-0-load": "9.0", + "form-0-speed": "9.0", + "form-1-load": "5.0", + "form-1-speed": "5.0", + "form-2-load": "5.0", + "form-2-speed": "4.0", + "_save": "Save", } request = self.factory.post(changelist_url, data=data) - pks = m._get_edited_object_pks(request, prefix='form') + pks = m._get_edited_object_pks(request, prefix="form") self.assertEqual(sorted(pks), sorted([str(a.pk), str(b.pk), str(c.pk)])) def test_get_list_editable_queryset(self): - a = Swallow.objects.create(origin='Swallow A', load=4, speed=1) - Swallow.objects.create(origin='Swallow B', load=2, speed=2) + a = 
Swallow.objects.create(origin="Swallow A", load=4, speed=1) + Swallow.objects.create(origin="Swallow B", load=2, speed=2) data = { - 'form-TOTAL_FORMS': '2', - 'form-INITIAL_FORMS': '2', - 'form-MIN_NUM_FORMS': '0', - 'form-MAX_NUM_FORMS': '1000', - 'form-0-uuid': str(a.pk), - 'form-0-load': '10', - '_save': 'Save', + "form-TOTAL_FORMS": "2", + "form-INITIAL_FORMS": "2", + "form-MIN_NUM_FORMS": "0", + "form-MAX_NUM_FORMS": "1000", + "form-0-uuid": str(a.pk), + "form-0-load": "10", + "_save": "Save", } - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") self.client.force_login(superuser) - changelist_url = reverse('admin:admin_changelist_swallow_changelist') + changelist_url = reverse("admin:admin_changelist_swallow_changelist") m = SwallowAdmin(Swallow, custom_site) request = self.factory.post(changelist_url, data=data) - queryset = m._get_list_editable_queryset(request, prefix='form') + queryset = m._get_list_editable_queryset(request, prefix="form") self.assertEqual(queryset.count(), 1) - data['form-0-uuid'] = 'INVALD_PRIMARY_KEY' + data["form-0-uuid"] = "INVALD_PRIMARY_KEY" # The unfiltered queryset is returned if there's invalid data. 
request = self.factory.post(changelist_url, data=data) - queryset = m._get_list_editable_queryset(request, prefix='form') + queryset = m._get_list_editable_queryset(request, prefix="form") self.assertEqual(queryset.count(), 2) def test_get_list_editable_queryset_with_regex_chars_in_prefix(self): - a = Swallow.objects.create(origin='Swallow A', load=4, speed=1) - Swallow.objects.create(origin='Swallow B', load=2, speed=2) + a = Swallow.objects.create(origin="Swallow A", load=4, speed=1) + Swallow.objects.create(origin="Swallow B", load=2, speed=2) data = { - 'form$-TOTAL_FORMS': '2', - 'form$-INITIAL_FORMS': '2', - 'form$-MIN_NUM_FORMS': '0', - 'form$-MAX_NUM_FORMS': '1000', - 'form$-0-uuid': str(a.pk), - 'form$-0-load': '10', - '_save': 'Save', + "form$-TOTAL_FORMS": "2", + "form$-INITIAL_FORMS": "2", + "form$-MIN_NUM_FORMS": "0", + "form$-MAX_NUM_FORMS": "1000", + "form$-0-uuid": str(a.pk), + "form$-0-load": "10", + "_save": "Save", } - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") self.client.force_login(superuser) - changelist_url = reverse('admin:admin_changelist_swallow_changelist') + changelist_url = reverse("admin:admin_changelist_swallow_changelist") m = SwallowAdmin(Swallow, custom_site) request = self.factory.post(changelist_url, data=data) - queryset = m._get_list_editable_queryset(request, prefix='form$') + queryset = m._get_list_editable_queryset(request, prefix="form$") self.assertEqual(queryset.count(), 1) def test_changelist_view_list_editable_changed_objects_uses_filter(self): """list_editable edits use a filtered queryset to limit memory usage.""" - a = Swallow.objects.create(origin='Swallow A', load=4, speed=1) - Swallow.objects.create(origin='Swallow B', load=2, speed=2) + a = Swallow.objects.create(origin="Swallow A", load=4, speed=1) + Swallow.objects.create(origin="Swallow B", load=2, speed=2) data = { - 'form-TOTAL_FORMS': '2', - 'form-INITIAL_FORMS': '2', - 'form-MIN_NUM_FORMS': '0', - 
'form-MAX_NUM_FORMS': '1000', - 'form-0-uuid': str(a.pk), - 'form-0-load': '10', - '_save': 'Save', + "form-TOTAL_FORMS": "2", + "form-INITIAL_FORMS": "2", + "form-MIN_NUM_FORMS": "0", + "form-MAX_NUM_FORMS": "1000", + "form-0-uuid": str(a.pk), + "form-0-load": "10", + "_save": "Save", } - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") self.client.force_login(superuser) - changelist_url = reverse('admin:admin_changelist_swallow_changelist') + changelist_url = reverse("admin:admin_changelist_swallow_changelist") with CaptureQueriesContext(connection) as context: response = self.client.post(changelist_url, data=data) self.assertEqual(response.status_code, 200) - self.assertIn('WHERE', context.captured_queries[4]['sql']) - self.assertIn('IN', context.captured_queries[4]['sql']) + self.assertIn("WHERE", context.captured_queries[4]["sql"]) + self.assertIn("IN", context.captured_queries[4]["sql"]) # Check only the first few characters since the UUID may have dashes. - self.assertIn(str(a.pk)[:8], context.captured_queries[4]['sql']) + self.assertIn(str(a.pk)[:8], context.captured_queries[4]["sql"]) def test_deterministic_order_for_unordered_model(self): """ @@ -1096,7 +1175,7 @@ class ChangeListTests(TestCase): guarantee a deterministic order, even when the model doesn't have any default ordering defined (#17198). 
""" - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") for counter in range(1, 51): UnorderedObject.objects.create(id=counter, bool=True) @@ -1109,9 +1188,11 @@ class ChangeListTests(TestCase): model_admin = UnorderedObjectAdmin(UnorderedObject, custom_site) counter = 0 if ascending else 51 for page in range(1, 6): - request = self._mocked_authenticated_request('/unorderedobject/?p=%s' % page, superuser) + request = self._mocked_authenticated_request( + "/unorderedobject/?p=%s" % page, superuser + ) response = model_admin.changelist_view(request) - for result in response.context_data['cl'].result_list: + for result in response.context_data["cl"].result_list: counter += 1 if ascending else -1 self.assertEqual(result.id, counter) custom_site.unregister(UnorderedObject) @@ -1121,17 +1202,17 @@ class ChangeListTests(TestCase): # When an order field is defined but multiple records have the same # value for that field, make sure everything gets ordered by -pk as well. - UnorderedObjectAdmin.ordering = ['bool'] + UnorderedObjectAdmin.ordering = ["bool"] check_results_order() # When order fields are defined, including the pk itself, use them. - UnorderedObjectAdmin.ordering = ['bool', '-pk'] + UnorderedObjectAdmin.ordering = ["bool", "-pk"] check_results_order() - UnorderedObjectAdmin.ordering = ['bool', 'pk'] + UnorderedObjectAdmin.ordering = ["bool", "pk"] check_results_order(ascending=True) - UnorderedObjectAdmin.ordering = ['-id', 'bool'] + UnorderedObjectAdmin.ordering = ["-id", "bool"] check_results_order() - UnorderedObjectAdmin.ordering = ['id', 'bool'] + UnorderedObjectAdmin.ordering = ["id", "bool"] check_results_order(ascending=True) def test_deterministic_order_for_model_ordered_by_its_manager(self): @@ -1140,7 +1221,7 @@ class ChangeListTests(TestCase): guarantee a deterministic order, even when the model has a manager that defines a default ordering (#17198). 
""" - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") for counter in range(1, 51): OrderedObject.objects.create(id=counter, bool=True, number=counter) @@ -1153,9 +1234,11 @@ class ChangeListTests(TestCase): model_admin = OrderedObjectAdmin(OrderedObject, custom_site) counter = 0 if ascending else 51 for page in range(1, 6): - request = self._mocked_authenticated_request('/orderedobject/?p=%s' % page, superuser) + request = self._mocked_authenticated_request( + "/orderedobject/?p=%s" % page, superuser + ) response = model_admin.changelist_view(request) - for result in response.context_data['cl'].result_list: + for result in response.context_data["cl"].result_list: counter += 1 if ascending else -1 self.assertEqual(result.id, counter) custom_site.unregister(OrderedObject) @@ -1165,26 +1248,26 @@ class ChangeListTests(TestCase): # When an order field is defined but multiple records have the same # value for that field, make sure everything gets ordered by -pk as well. - OrderedObjectAdmin.ordering = ['bool'] + OrderedObjectAdmin.ordering = ["bool"] check_results_order() # When order fields are defined, including the pk itself, use them. 
- OrderedObjectAdmin.ordering = ['bool', '-pk'] + OrderedObjectAdmin.ordering = ["bool", "-pk"] check_results_order() - OrderedObjectAdmin.ordering = ['bool', 'pk'] + OrderedObjectAdmin.ordering = ["bool", "pk"] check_results_order(ascending=True) - OrderedObjectAdmin.ordering = ['-id', 'bool'] + OrderedObjectAdmin.ordering = ["-id", "bool"] check_results_order() - OrderedObjectAdmin.ordering = ['id', 'bool'] + OrderedObjectAdmin.ordering = ["id", "bool"] check_results_order(ascending=True) - @isolate_apps('admin_changelist') + @isolate_apps("admin_changelist") def test_total_ordering_optimization(self): class Related(models.Model): unique_field = models.BooleanField(unique=True) class Meta: - ordering = ('unique_field',) + ordering = ("unique_field",) class Model(models.Model): unique_field = models.BooleanField(unique=True) @@ -1198,64 +1281,69 @@ class ChangeListTests(TestCase): class Meta: unique_together = { - ('field', 'other_field'), - ('field', 'null_field'), - ('related', 'other_related_id'), + ("field", "other_field"), + ("field", "null_field"), + ("related", "other_related_id"), } class ModelAdmin(admin.ModelAdmin): def get_queryset(self, request): return Model.objects.none() - request = self._mocked_authenticated_request('/', self.superuser) - site = admin.AdminSite(name='admin') + request = self._mocked_authenticated_request("/", self.superuser) + site = admin.AdminSite(name="admin") model_admin = ModelAdmin(Model, site) change_list = model_admin.get_changelist_instance(request) tests = ( - ([], ['-pk']), + ([], ["-pk"]), # Unique non-nullable field. - (['unique_field'], ['unique_field']), - (['-unique_field'], ['-unique_field']), + (["unique_field"], ["unique_field"]), + (["-unique_field"], ["-unique_field"]), # Unique nullable field. - (['unique_nullable_field'], ['unique_nullable_field', '-pk']), + (["unique_nullable_field"], ["unique_nullable_field", "-pk"]), # Field. 
- (['field'], ['field', '-pk']), + (["field"], ["field", "-pk"]), # Related field introspection is not implemented. - (['related__unique_field'], ['related__unique_field', '-pk']), + (["related__unique_field"], ["related__unique_field", "-pk"]), # Related attname unique. - (['related_unique_id'], ['related_unique_id']), + (["related_unique_id"], ["related_unique_id"]), # Related ordering introspection is not implemented. - (['related_unique'], ['related_unique', '-pk']), + (["related_unique"], ["related_unique", "-pk"]), # Composite unique. - (['field', '-other_field'], ['field', '-other_field']), + (["field", "-other_field"], ["field", "-other_field"]), # Composite unique nullable. - (['-field', 'null_field'], ['-field', 'null_field', '-pk']), + (["-field", "null_field"], ["-field", "null_field", "-pk"]), # Composite unique and nullable. - (['-field', 'null_field', 'other_field'], ['-field', 'null_field', 'other_field']), + ( + ["-field", "null_field", "other_field"], + ["-field", "null_field", "other_field"], + ), # Composite unique attnames. - (['related_id', '-other_related_id'], ['related_id', '-other_related_id']), + (["related_id", "-other_related_id"], ["related_id", "-other_related_id"]), # Composite unique names. - (['related', '-other_related_id'], ['related', '-other_related_id', '-pk']), + (["related", "-other_related_id"], ["related", "-other_related_id", "-pk"]), ) # F() objects composite unique. - total_ordering = [F('field'), F('other_field').desc(nulls_last=True)] + total_ordering = [F("field"), F("other_field").desc(nulls_last=True)] # F() objects composite unique nullable. 
- non_total_ordering = [F('field'), F('null_field').desc(nulls_last=True)] + non_total_ordering = [F("field"), F("null_field").desc(nulls_last=True)] tests += ( (total_ordering, total_ordering), - (non_total_ordering, non_total_ordering + ['-pk']), + (non_total_ordering, non_total_ordering + ["-pk"]), ) for ordering, expected in tests: with self.subTest(ordering=ordering): - self.assertEqual(change_list._get_deterministic_ordering(ordering), expected) + self.assertEqual( + change_list._get_deterministic_ordering(ordering), expected + ) - @isolate_apps('admin_changelist') + @isolate_apps("admin_changelist") def test_total_ordering_optimization_meta_constraints(self): class Related(models.Model): unique_field = models.BooleanField(unique=True) class Meta: - ordering = ('unique_field',) + ordering = ("unique_field",) class Model(models.Model): field_1 = models.BooleanField() @@ -1274,28 +1362,28 @@ class ChangeListTests(TestCase): class Meta: constraints = [ *[ - models.UniqueConstraint(fields=fields, name=''.join(fields)) + models.UniqueConstraint(fields=fields, name="".join(fields)) for fields in ( - ['field_1'], - ['nullable_1'], - ['related_1'], - ['related_2_id'], - ['field_2', 'field_3'], - ['field_2', 'nullable_2'], - ['field_2', 'related_3'], - ['field_3', 'related_4_id'], + ["field_1"], + ["nullable_1"], + ["related_1"], + ["related_2_id"], + ["field_2", "field_3"], + ["field_2", "nullable_2"], + ["field_2", "related_3"], + ["field_3", "related_4_id"], ) ], - models.CheckConstraint(check=models.Q(id__gt=0), name='foo'), + models.CheckConstraint(check=models.Q(id__gt=0), name="foo"), models.UniqueConstraint( - fields=['field_5'], + fields=["field_5"], condition=models.Q(id__gt=10), - name='total_ordering_1', + name="total_ordering_1", ), models.UniqueConstraint( - fields=['field_6'], + fields=["field_6"], condition=models.Q(), - name='total_ordering', + name="total_ordering", ), ] @@ -1303,73 +1391,77 @@ class ChangeListTests(TestCase): def get_queryset(self, 
request): return Model.objects.none() - request = self._mocked_authenticated_request('/', self.superuser) - site = admin.AdminSite(name='admin') + request = self._mocked_authenticated_request("/", self.superuser) + site = admin.AdminSite(name="admin") model_admin = ModelAdmin(Model, site) change_list = model_admin.get_changelist_instance(request) tests = ( # Unique non-nullable field. - (['field_1'], ['field_1']), + (["field_1"], ["field_1"]), # Unique nullable field. - (['nullable_1'], ['nullable_1', '-pk']), + (["nullable_1"], ["nullable_1", "-pk"]), # Related attname unique. - (['related_1_id'], ['related_1_id']), - (['related_2_id'], ['related_2_id']), + (["related_1_id"], ["related_1_id"]), + (["related_2_id"], ["related_2_id"]), # Related ordering introspection is not implemented. - (['related_1'], ['related_1', '-pk']), + (["related_1"], ["related_1", "-pk"]), # Composite unique. - (['-field_2', 'field_3'], ['-field_2', 'field_3']), + (["-field_2", "field_3"], ["-field_2", "field_3"]), # Composite unique nullable. - (['field_2', '-nullable_2'], ['field_2', '-nullable_2', '-pk']), + (["field_2", "-nullable_2"], ["field_2", "-nullable_2", "-pk"]), # Composite unique and nullable. ( - ['field_2', '-nullable_2', 'field_3'], - ['field_2', '-nullable_2', 'field_3'], + ["field_2", "-nullable_2", "field_3"], + ["field_2", "-nullable_2", "field_3"], ), # Composite field and related field name. - (['field_2', '-related_3'], ['field_2', '-related_3', '-pk']), - (['field_3', 'related_4'], ['field_3', 'related_4', '-pk']), + (["field_2", "-related_3"], ["field_2", "-related_3", "-pk"]), + (["field_3", "related_4"], ["field_3", "related_4", "-pk"]), # Composite field and related field attname. 
- (['field_2', 'related_3_id'], ['field_2', 'related_3_id']), - (['field_3', '-related_4_id'], ['field_3', '-related_4_id']), + (["field_2", "related_3_id"], ["field_2", "related_3_id"]), + (["field_3", "-related_4_id"], ["field_3", "-related_4_id"]), # Partial unique constraint is ignored. - (['field_5'], ['field_5', '-pk']), + (["field_5"], ["field_5", "-pk"]), # Unique constraint with an empty condition. - (['field_6'], ['field_6']), + (["field_6"], ["field_6"]), ) for ordering, expected in tests: with self.subTest(ordering=ordering): - self.assertEqual(change_list._get_deterministic_ordering(ordering), expected) + self.assertEqual( + change_list._get_deterministic_ordering(ordering), expected + ) def test_dynamic_list_filter(self): """ Regression tests for ticket #17646: dynamic list_filter support. """ - parent = Parent.objects.create(name='parent') + parent = Parent.objects.create(name="parent") for i in range(10): - Child.objects.create(name='child %s' % i, parent=parent) + Child.objects.create(name="child %s" % i, parent=parent) - user_noparents = self._create_superuser('noparents') - user_parents = self._create_superuser('parents') + user_noparents = self._create_superuser("noparents") + user_parents = self._create_superuser("parents") # Test with user 'noparents' m = DynamicListFilterChildAdmin(Child, custom_site) - request = self._mocked_authenticated_request('/child/', user_noparents) + request = self._mocked_authenticated_request("/child/", user_noparents) response = m.changelist_view(request) - self.assertEqual(response.context_data['cl'].list_filter, ['name', 'age']) + self.assertEqual(response.context_data["cl"].list_filter, ["name", "age"]) # Test with user 'parents' m = DynamicListFilterChildAdmin(Child, custom_site) - request = self._mocked_authenticated_request('/child/', user_parents) + request = self._mocked_authenticated_request("/child/", user_parents) response = m.changelist_view(request) - 
self.assertEqual(response.context_data['cl'].list_filter, ('parent', 'name', 'age')) + self.assertEqual( + response.context_data["cl"].list_filter, ("parent", "name", "age") + ) def test_dynamic_search_fields(self): - child = self._create_superuser('child') + child = self._create_superuser("child") m = DynamicSearchFieldsChildAdmin(Child, custom_site) - request = self._mocked_authenticated_request('/child/', child) + request = self._mocked_authenticated_request("/child/", child) response = m.changelist_view(request) - self.assertEqual(response.context_data['cl'].search_fields, ('name', 'age')) + self.assertEqual(response.context_data["cl"].search_fields, ("name", "age")) def test_pagination_page_range(self): """ @@ -1378,7 +1470,7 @@ class ChangeListTests(TestCase): """ # instantiating and setting up ChangeList object m = GroupAdmin(Group, custom_site) - request = self.factory.get('/group/') + request = self.factory.get("/group/") request.user = self.superuser cl = m.get_changelist_instance(request) cl.list_per_page = 10 @@ -1401,144 +1493,148 @@ class ChangeListTests(TestCase): # assuming exactly `pages * cl.list_per_page` objects Group.objects.all().delete() for i in range(pages * cl.list_per_page): - Group.objects.create(name='test band') + Group.objects.create(name="test band") # setting page number and calculating page range cl.page_num = number cl.get_results(request) - self.assertEqual(list(pagination(cl)['page_range']), expected) + self.assertEqual(list(pagination(cl)["page_range"]), expected) def test_object_tools_displayed_no_add_permission(self): """ When ModelAdmin.has_add_permission() returns False, the object-tools block is still shown. 
""" - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") m = EventAdmin(Event, custom_site) - request = self._mocked_authenticated_request('/event/', superuser) + request = self._mocked_authenticated_request("/event/", superuser) self.assertFalse(m.has_add_permission(request)) response = m.changelist_view(request) self.assertIn('<ul class="object-tools">', response.rendered_content) # The "Add" button inside the object-tools shouldn't appear. - self.assertNotIn('Add ', response.rendered_content) + self.assertNotIn("Add ", response.rendered_content) def test_search_help_text(self): - superuser = self._create_superuser('superuser') + superuser = self._create_superuser("superuser") m = BandAdmin(Band, custom_site) # search_fields without search_help_text. - m.search_fields = ['name'] - request = self._mocked_authenticated_request('/band/', superuser) + m.search_fields = ["name"] + request = self._mocked_authenticated_request("/band/", superuser) response = m.changelist_view(request) - self.assertIsNone(response.context_data['cl'].search_help_text) + self.assertIsNone(response.context_data["cl"].search_help_text) self.assertNotContains(response, '<div class="help">') # search_fields with search_help_text. - m.search_help_text = 'Search help text' - request = self._mocked_authenticated_request('/band/', superuser) + m.search_help_text = "Search help text" + request = self._mocked_authenticated_request("/band/", superuser) response = m.changelist_view(request) - self.assertEqual(response.context_data['cl'].search_help_text, 'Search help text') + self.assertEqual( + response.context_data["cl"].search_help_text, "Search help text" + ) self.assertContains(response, '<div class="help">Search help text</div>') class GetAdminLogTests(TestCase): - def test_custom_user_pk_not_named_id(self): """ {% get_admin_log %} works if the user model's primary key isn't named 'id'. 
""" - context = Context({'user': CustomIdUser()}) - template = Template('{% load log %}{% get_admin_log 10 as admin_log for_user user %}') + context = Context({"user": CustomIdUser()}) + template = Template( + "{% load log %}{% get_admin_log 10 as admin_log for_user user %}" + ) # This template tag just logs. - self.assertEqual(template.render(context), '') + self.assertEqual(template.render(context), "") def test_no_user(self): """{% get_admin_log %} works without specifying a user.""" - user = User(username='jondoe', password='secret', email='super@example.com') + user = User(username="jondoe", password="secret", email="super@example.com") user.save() ct = ContentType.objects.get_for_model(User) LogEntry.objects.log_action(user.pk, ct.pk, user.pk, repr(user), 1) t = Template( - '{% load log %}' - '{% get_admin_log 100 as admin_log %}' - '{% for entry in admin_log %}' - '{{ entry|safe }}' - '{% endfor %}' + "{% load log %}" + "{% get_admin_log 100 as admin_log %}" + "{% for entry in admin_log %}" + "{{ entry|safe }}" + "{% endfor %}" ) - self.assertEqual(t.render(Context({})), 'Added “<User: jondoe>”.') + self.assertEqual(t.render(Context({})), "Added “<User: jondoe>”.") def test_missing_args(self): msg = "'get_admin_log' statements require two arguments" with self.assertRaisesMessage(TemplateSyntaxError, msg): - Template('{% load log %}{% get_admin_log 10 as %}') + Template("{% load log %}{% get_admin_log 10 as %}") def test_non_integer_limit(self): msg = "First argument to 'get_admin_log' must be an integer" with self.assertRaisesMessage(TemplateSyntaxError, msg): - Template('{% load log %}{% get_admin_log "10" as admin_log for_user user %}') + Template( + '{% load log %}{% get_admin_log "10" as admin_log for_user user %}' + ) def test_without_as(self): msg = "Second argument to 'get_admin_log' must be 'as'" with self.assertRaisesMessage(TemplateSyntaxError, msg): - Template('{% load log %}{% get_admin_log 10 ad admin_log for_user user %}') + Template("{% load 
log %}{% get_admin_log 10 ad admin_log for_user user %}") def test_without_for_user(self): msg = "Fourth argument to 'get_admin_log' must be 'for_user'" with self.assertRaisesMessage(TemplateSyntaxError, msg): - Template('{% load log %}{% get_admin_log 10 as admin_log foruser user %}') + Template("{% load log %}{% get_admin_log 10 as admin_log foruser user %}") -@override_settings(ROOT_URLCONF='admin_changelist.urls') +@override_settings(ROOT_URLCONF="admin_changelist.urls") class SeleniumTests(AdminSeleniumTestCase): - available_apps = ['admin_changelist'] + AdminSeleniumTestCase.available_apps + available_apps = ["admin_changelist"] + AdminSeleniumTestCase.available_apps def setUp(self): - User.objects.create_superuser(username='super', password='secret', email=None) + User.objects.create_superuser(username="super", password="secret", email=None) def test_add_row_selection(self): """ The status line for selected rows gets updated correctly (#22038). """ from selenium.webdriver.common.by import By - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:auth_user_changelist')) - form_id = '#changelist-form' + self.admin_login(username="super", password="secret") + self.selenium.get(self.live_server_url + reverse("admin:auth_user_changelist")) + + form_id = "#changelist-form" # Test amount of rows in the Changelist rows = self.selenium.find_elements( - By.CSS_SELECTOR, - '%s #result_list tbody tr' % form_id + By.CSS_SELECTOR, "%s #result_list tbody tr" % form_id ) self.assertEqual(len(rows), 1) row = rows[0] selection_indicator = self.selenium.find_element( - By.CSS_SELECTOR, - '%s .action-counter' % form_id + By.CSS_SELECTOR, "%s .action-counter" % form_id ) - all_selector = self.selenium.find_element(By.ID, 'action-toggle') + all_selector = self.selenium.find_element(By.ID, "action-toggle") row_selector = self.selenium.find_element( By.CSS_SELECTOR, - '%s #result_list tbody tr:first-child .action-select' % 
form_id + "%s #result_list tbody tr:first-child .action-select" % form_id, ) # Test current selection self.assertEqual(selection_indicator.text, "0 of 1 selected") - self.assertIs(all_selector.get_property('checked'), False) - self.assertEqual(row.get_attribute('class'), '') + self.assertIs(all_selector.get_property("checked"), False) + self.assertEqual(row.get_attribute("class"), "") # Select a row and check again row_selector.click() self.assertEqual(selection_indicator.text, "1 of 1 selected") - self.assertIs(all_selector.get_property('checked'), True) - self.assertEqual(row.get_attribute('class'), 'selected') + self.assertIs(all_selector.get_property("checked"), True) + self.assertEqual(row.get_attribute("class"), "selected") # Deselect a row and check again row_selector.click() self.assertEqual(selection_indicator.text, "0 of 1 selected") - self.assertIs(all_selector.get_property('checked'), False) - self.assertEqual(row.get_attribute('class'), '') + self.assertIs(all_selector.get_property("checked"), False) + self.assertEqual(row.get_attribute("class"), "") def test_modifier_allows_multiple_section(self): """ @@ -1549,89 +1645,105 @@ class SeleniumTests(AdminSeleniumTestCase): from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys - Parent.objects.bulk_create([Parent(name='parent%d' % i) for i in range(5)]) - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:admin_changelist_parent_changelist')) - checkboxes = self.selenium.find_elements(By.CSS_SELECTOR, 'tr input.action-select') + Parent.objects.bulk_create([Parent(name="parent%d" % i) for i in range(5)]) + self.admin_login(username="super", password="secret") + self.selenium.get( + self.live_server_url + reverse("admin:admin_changelist_parent_changelist") + ) + checkboxes = self.selenium.find_elements( + By.CSS_SELECTOR, "tr input.action-select" + ) self.assertEqual(len(checkboxes), 5) for c in checkboxes: - 
self.assertIs(c.get_property('checked'), False) + self.assertIs(c.get_property("checked"), False) # Check first row. Hold-shift and check next-to-last row. checkboxes[0].click() - ActionChains(self.selenium).key_down(Keys.SHIFT).click(checkboxes[-2]).key_up(Keys.SHIFT).perform() + ActionChains(self.selenium).key_down(Keys.SHIFT).click(checkboxes[-2]).key_up( + Keys.SHIFT + ).perform() for c in checkboxes[:-2]: - self.assertIs(c.get_property('checked'), True) - self.assertIs(checkboxes[-1].get_property('checked'), False) + self.assertIs(c.get_property("checked"), True) + self.assertIs(checkboxes[-1].get_property("checked"), False) def test_select_all_across_pages(self): from selenium.webdriver.common.by import By - Parent.objects.bulk_create([Parent(name='parent%d' % i) for i in range(101)]) - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:admin_changelist_parent_changelist')) - selection_indicator = self.selenium.find_element(By.CSS_SELECTOR, '.action-counter') - select_all_indicator = self.selenium.find_element(By.CSS_SELECTOR, '.actions .all') - question = self.selenium.find_element(By.CSS_SELECTOR, '.actions > .question') - clear = self.selenium.find_element(By.CSS_SELECTOR, '.actions > .clear') - select_all = self.selenium.find_element(By.ID, 'action-toggle') - select_across = self.selenium.find_elements(By.NAME, 'select_across') + Parent.objects.bulk_create([Parent(name="parent%d" % i) for i in range(101)]) + self.admin_login(username="super", password="secret") + self.selenium.get( + self.live_server_url + reverse("admin:admin_changelist_parent_changelist") + ) + + selection_indicator = self.selenium.find_element( + By.CSS_SELECTOR, ".action-counter" + ) + select_all_indicator = self.selenium.find_element( + By.CSS_SELECTOR, ".actions .all" + ) + question = self.selenium.find_element(By.CSS_SELECTOR, ".actions > .question") + clear = self.selenium.find_element(By.CSS_SELECTOR, ".actions > 
.clear") + select_all = self.selenium.find_element(By.ID, "action-toggle") + select_across = self.selenium.find_elements(By.NAME, "select_across") self.assertIs(question.is_displayed(), False) self.assertIs(clear.is_displayed(), False) - self.assertIs(select_all.get_property('checked'), False) + self.assertIs(select_all.get_property("checked"), False) for hidden_input in select_across: - self.assertEqual(hidden_input.get_property('value'), '0') + self.assertEqual(hidden_input.get_property("value"), "0") self.assertIs(selection_indicator.is_displayed(), True) - self.assertEqual(selection_indicator.text, '0 of 100 selected') + self.assertEqual(selection_indicator.text, "0 of 100 selected") self.assertIs(select_all_indicator.is_displayed(), False) select_all.click() self.assertIs(question.is_displayed(), True) self.assertIs(clear.is_displayed(), False) - self.assertIs(select_all.get_property('checked'), True) + self.assertIs(select_all.get_property("checked"), True) for hidden_input in select_across: - self.assertEqual(hidden_input.get_property('value'), '0') + self.assertEqual(hidden_input.get_property("value"), "0") self.assertIs(selection_indicator.is_displayed(), True) - self.assertEqual(selection_indicator.text, '100 of 100 selected') + self.assertEqual(selection_indicator.text, "100 of 100 selected") self.assertIs(select_all_indicator.is_displayed(), False) question.click() self.assertIs(question.is_displayed(), False) self.assertIs(clear.is_displayed(), True) - self.assertIs(select_all.get_property('checked'), True) + self.assertIs(select_all.get_property("checked"), True) for hidden_input in select_across: - self.assertEqual(hidden_input.get_property('value'), '1') + self.assertEqual(hidden_input.get_property("value"), "1") self.assertIs(selection_indicator.is_displayed(), False) self.assertIs(select_all_indicator.is_displayed(), True) clear.click() self.assertIs(question.is_displayed(), False) self.assertIs(clear.is_displayed(), False) - 
self.assertIs(select_all.get_property('checked'), False) + self.assertIs(select_all.get_property("checked"), False) for hidden_input in select_across: - self.assertEqual(hidden_input.get_property('value'), '0') + self.assertEqual(hidden_input.get_property("value"), "0") self.assertIs(selection_indicator.is_displayed(), True) - self.assertEqual(selection_indicator.text, '0 of 100 selected') + self.assertEqual(selection_indicator.text, "0 of 100 selected") self.assertIs(select_all_indicator.is_displayed(), False) def test_actions_warn_on_pending_edits(self): from selenium.webdriver.common.by import By - Parent.objects.create(name='foo') - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:admin_changelist_parent_changelist')) + Parent.objects.create(name="foo") - name_input = self.selenium.find_element(By.ID, 'id_form-0-name') + self.admin_login(username="super", password="secret") + self.selenium.get( + self.live_server_url + reverse("admin:admin_changelist_parent_changelist") + ) + + name_input = self.selenium.find_element(By.ID, "id_form-0-name") name_input.clear() - name_input.send_keys('bar') - self.selenium.find_element(By.ID, 'action-toggle').click() - self.selenium.find_element(By.NAME, 'index').click() # Go + name_input.send_keys("bar") + self.selenium.find_element(By.ID, "action-toggle").click() + self.selenium.find_element(By.NAME, "index").click() # Go alert = self.selenium.switch_to.alert try: self.assertEqual( alert.text, - 'You have unsaved changes on individual editable fields. If you ' - 'run an action, your unsaved changes will be lost.' + "You have unsaved changes on individual editable fields. 
If you " + "run an action, your unsaved changes will be lost.", ) finally: alert.dismiss() @@ -1640,25 +1752,27 @@ class SeleniumTests(AdminSeleniumTestCase): from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import Select - Parent.objects.create(name='parent') + Parent.objects.create(name="parent") - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:admin_changelist_parent_changelist')) + self.admin_login(username="super", password="secret") + self.selenium.get( + self.live_server_url + reverse("admin:admin_changelist_parent_changelist") + ) - name_input = self.selenium.find_element(By.ID, 'id_form-0-name') + name_input = self.selenium.find_element(By.ID, "id_form-0-name") name_input.clear() - name_input.send_keys('other name') - Select( - self.selenium.find_element(By.NAME, 'action') - ).select_by_value('delete_selected') - self.selenium.find_element(By.NAME, '_save').click() + name_input.send_keys("other name") + Select(self.selenium.find_element(By.NAME, "action")).select_by_value( + "delete_selected" + ) + self.selenium.find_element(By.NAME, "_save").click() alert = self.selenium.switch_to.alert try: self.assertEqual( alert.text, - 'You have selected an action, but you haven’t saved your ' - 'changes to individual fields yet. Please click OK to save. ' - 'You’ll need to re-run the action.', + "You have selected an action, but you haven’t saved your " + "changes to individual fields yet. Please click OK to save. 
" + "You’ll need to re-run the action.", ) finally: alert.dismiss() @@ -1667,22 +1781,24 @@ class SeleniumTests(AdminSeleniumTestCase): from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import Select - Parent.objects.create(name='parent') + Parent.objects.create(name="parent") - self.admin_login(username='super', password='secret') - self.selenium.get(self.live_server_url + reverse('admin:admin_changelist_parent_changelist')) + self.admin_login(username="super", password="secret") + self.selenium.get( + self.live_server_url + reverse("admin:admin_changelist_parent_changelist") + ) - Select( - self.selenium.find_element(By.NAME, 'action') - ).select_by_value('delete_selected') - self.selenium.find_element(By.NAME, '_save').click() + Select(self.selenium.find_element(By.NAME, "action")).select_by_value( + "delete_selected" + ) + self.selenium.find_element(By.NAME, "_save").click() alert = self.selenium.switch_to.alert try: self.assertEqual( alert.text, - 'You have selected an action, and you haven’t made any ' - 'changes on individual fields. You’re probably looking for ' - 'the Go button rather than the Save button.', + "You have selected an action, and you haven’t made any " + "changes on individual fields. You’re probably looking for " + "the Go button rather than the Save button.", ) finally: alert.dismiss() diff --git a/tests/admin_changelist/urls.py b/tests/admin_changelist/urls.py index be569cdca5..ad9b44c9fc 100644 --- a/tests/admin_changelist/urls.py +++ b/tests/admin_changelist/urls.py @@ -3,5 +3,5 @@ from django.urls import path from . 
import admin urlpatterns = [ - path('admin/', admin.site.urls), + path("admin/", admin.site.urls), ] diff --git a/tests/admin_checks/models.py b/tests/admin_checks/models.py index 3336ce878e..b9140206f1 100644 --- a/tests/admin_checks/models.py +++ b/tests/admin_checks/models.py @@ -17,7 +17,7 @@ class Song(models.Model): original_release = models.DateField(editable=False) class Meta: - ordering = ('title',) + ordering = ("title",) def __str__(self): return self.title @@ -41,7 +41,7 @@ class Book(models.Model): name = models.CharField(max_length=100) subtitle = models.CharField(max_length=100) price = models.FloatField() - authors = models.ManyToManyField(Author, through='AuthorsBooks') + authors = models.ManyToManyField(Author, through="AuthorsBooks") class AuthorsBooks(models.Model): @@ -63,4 +63,4 @@ class Influence(models.Model): content_type = models.ForeignKey(ContentType, models.CASCADE) object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') + content_object = GenericForeignKey("content_type", "object_id") diff --git a/tests/admin_checks/tests.py b/tests/admin_checks/tests.py index 67625c7c86..a2372ba0d2 100644 --- a/tests/admin_checks/tests.py +++ b/tests/admin_checks/tests.py @@ -9,9 +9,7 @@ from django.contrib.sessions.middleware import SessionMiddleware from django.core import checks from django.test import SimpleTestCase, override_settings -from .models import ( - Album, Author, Book, City, Influence, Song, State, TwoAlbumFKAndAnE, -) +from .models import Album, Author, Book, City, Influence, Song, State, TwoAlbumFKAndAnE class SongForm(forms.ModelForm): @@ -20,25 +18,29 @@ class SongForm(forms.ModelForm): class ValidFields(admin.ModelAdmin): form = SongForm - fields = ['title'] + fields = ["title"] class ValidFormFieldsets(admin.ModelAdmin): def get_form(self, request, obj=None, **kwargs): class ExtraFieldForm(SongForm): name = forms.CharField(max_length=50) + return ExtraFieldForm fieldsets = ( - 
(None, { - 'fields': ('name',), - }), + ( + None, + { + "fields": ("name",), + }, + ), ) class MyAdmin(admin.ModelAdmin): def check(self, **kwargs): - return ['error!'] + return ["error!"] class AuthenticationMiddlewareSubclass(AuthenticationMiddleware): @@ -58,27 +60,26 @@ class SessionMiddlewareSubclass(SessionMiddleware): @override_settings( - SILENCED_SYSTEM_CHECKS=['fields.W342'], # ForeignKey(unique=True) + SILENCED_SYSTEM_CHECKS=["fields.W342"], # ForeignKey(unique=True) INSTALLED_APPS=[ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.messages', - 'admin_checks', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.messages", + "admin_checks", ], ) class SystemChecksTestCase(SimpleTestCase): - def test_checks_are_performed(self): admin.site.register(Song, MyAdmin) try: errors = checks.run_checks() - expected = ['error!'] + expected = ["error!"] self.assertEqual(errors, expected) finally: admin.site.unregister(Song) - @override_settings(INSTALLED_APPS=['django.contrib.admin']) + @override_settings(INSTALLED_APPS=["django.contrib.admin"]) def test_apps_dependencies(self): errors = admin.checks.check_dependencies() expected = [ @@ -90,36 +91,41 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "'django.contrib.auth' must be in INSTALLED_APPS in order " "to use the admin application.", - id='admin.E405', + id="admin.E405", ), checks.Error( "'django.contrib.messages' must be in INSTALLED_APPS in order " "to use the admin application.", - id='admin.E406', + id="admin.E406", ), ] self.assertEqual(errors, expected) @override_settings(TEMPLATES=[]) def test_no_template_engines(self): - self.assertEqual(admin.checks.check_dependencies(), [ - checks.Error( - "A 'django.template.backends.django.DjangoTemplates' " - "instance must be configured in TEMPLATES in order to use " - "the admin application.", - id='admin.E403', - ) - ]) + self.assertEqual( + 
admin.checks.check_dependencies(), + [ + checks.Error( + "A 'django.template.backends.django.DjangoTemplates' " + "instance must be configured in TEMPLATES in order to use " + "the admin application.", + id="admin.E403", + ) + ], + ) @override_settings( - TEMPLATES=[{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [], - }, - }], + TEMPLATES=[ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [], + }, + } + ], ) def test_context_processor_dependencies(self): expected = [ @@ -127,20 +133,20 @@ class SystemChecksTestCase(SimpleTestCase): "'django.contrib.auth.context_processors.auth' must be " "enabled in DjangoTemplates (TEMPLATES) if using the default " "auth backend in order to use the admin application.", - id='admin.E402', + id="admin.E402", ), checks.Error( "'django.contrib.messages.context_processors.messages' must " "be enabled in DjangoTemplates (TEMPLATES) in order to use " "the admin application.", - id='admin.E404', + id="admin.E404", ), checks.Warning( "'django.template.context_processors.request' must be enabled " "in DjangoTemplates (TEMPLATES) in order to use the admin " "navigation sidebar.", - id='admin.W411', - ) + id="admin.W411", + ), ] self.assertEqual(admin.checks.check_dependencies(), expected) # The first error doesn't happen if @@ -150,45 +156,50 @@ class SystemChecksTestCase(SimpleTestCase): self.assertEqual(admin.checks.check_dependencies(), expected[1:]) @override_settings( - AUTHENTICATION_BACKENDS=['admin_checks.tests.ModelBackendSubclass'], - TEMPLATES=[{ - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.request', - 'django.contrib.messages.context_processors.messages', - ], - }, - }], + 
AUTHENTICATION_BACKENDS=["admin_checks.tests.ModelBackendSubclass"], + TEMPLATES=[ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.request", + "django.contrib.messages.context_processors.messages", + ], + }, + } + ], ) def test_context_processor_dependencies_model_backend_subclass(self): - self.assertEqual(admin.checks.check_dependencies(), [ - checks.Error( - "'django.contrib.auth.context_processors.auth' must be " - "enabled in DjangoTemplates (TEMPLATES) if using the default " - "auth backend in order to use the admin application.", - id='admin.E402', - ), - ]) + self.assertEqual( + admin.checks.check_dependencies(), + [ + checks.Error( + "'django.contrib.auth.context_processors.auth' must be " + "enabled in DjangoTemplates (TEMPLATES) if using the default " + "auth backend in order to use the admin application.", + id="admin.E402", + ), + ], + ) @override_settings( TEMPLATES=[ { - 'BACKEND': 'django.template.backends.dummy.TemplateStrings', - 'DIRS': [], - 'APP_DIRS': True, + "BACKEND": "django.template.backends.dummy.TemplateStrings", + "DIRS": [], + "APP_DIRS": True, }, { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, @@ -204,12 +215,12 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "'django.contrib.auth.middleware.AuthenticationMiddleware' " "must be in MIDDLEWARE in order to 
use the admin application.", - id='admin.E408', + id="admin.E408", ), checks.Error( "'django.contrib.messages.middleware.MessageMiddleware' " "must be in MIDDLEWARE in order to use the admin application.", - id='admin.E409', + id="admin.E409", ), checks.Error( "'django.contrib.sessions.middleware.SessionMiddleware' " @@ -220,25 +231,29 @@ class SystemChecksTestCase(SimpleTestCase): "before " "'django.contrib.auth.middleware.AuthenticationMiddleware'." ), - id='admin.E410', + id="admin.E410", ), ] self.assertEqual(errors, expected) - @override_settings(MIDDLEWARE=[ - 'admin_checks.tests.AuthenticationMiddlewareSubclass', - 'admin_checks.tests.MessageMiddlewareSubclass', - 'admin_checks.tests.SessionMiddlewareSubclass', - ]) + @override_settings( + MIDDLEWARE=[ + "admin_checks.tests.AuthenticationMiddlewareSubclass", + "admin_checks.tests.MessageMiddlewareSubclass", + "admin_checks.tests.SessionMiddlewareSubclass", + ] + ) def test_middleware_subclasses(self): self.assertEqual(admin.checks.check_dependencies(), []) - @override_settings(MIDDLEWARE=[ - 'django.contrib.does.not.Exist', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - ]) + @override_settings( + MIDDLEWARE=[ + "django.contrib.does.not.Exist", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + ] + ) def test_admin_check_ignores_import_error_in_middleware(self): self.assertEqual(admin.checks.check_dependencies(), []) @@ -250,7 +265,7 @@ class SystemChecksTestCase(SimpleTestCase): custom_site.register(Song, MyAdmin) try: errors = checks.run_checks() - expected = ['error!'] + expected = ["error!"] self.assertEqual(errors, expected) finally: custom_site.unregister(Song) @@ -261,7 +276,7 @@ class SystemChecksTestCase(SimpleTestCase): errors = 
super().check(**kwargs) author_admin = self.admin_site._registry.get(Author) if author_admin is None: - errors.append('AuthorAdmin missing!') + errors.append("AuthorAdmin missing!") return errors class MyAuthorAdmin(admin.ModelAdmin): @@ -285,35 +300,41 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'list_editable[0]' refers to 'original_release', " "which is not contained in 'list_display'.", obj=SongAdmin, - id='admin.E122', + id="admin.E122", ) ] self.assertEqual(errors, expected) def test_list_editable_not_a_list_or_tuple(self): class SongAdmin(admin.ModelAdmin): - list_editable = 'test' + list_editable = "test" - self.assertEqual(SongAdmin(Song, AdminSite()).check(), [ - checks.Error( - "The value of 'list_editable' must be a list or tuple.", - obj=SongAdmin, - id='admin.E120', - ) - ]) + self.assertEqual( + SongAdmin(Song, AdminSite()).check(), + [ + checks.Error( + "The value of 'list_editable' must be a list or tuple.", + obj=SongAdmin, + id="admin.E120", + ) + ], + ) def test_list_editable_missing_field(self): class SongAdmin(admin.ModelAdmin): - list_editable = ('test',) + list_editable = ("test",) - self.assertEqual(SongAdmin(Song, AdminSite()).check(), [ - checks.Error( - "The value of 'list_editable[0]' refers to 'test', which is " - "not a field of 'admin_checks.Song'.", - obj=SongAdmin, - id='admin.E121', - ) - ]) + self.assertEqual( + SongAdmin(Song, AdminSite()).check(), + [ + checks.Error( + "The value of 'list_editable[0]' refers to 'test', which is " + "not a field of 'admin_checks.Song'.", + obj=SongAdmin, + id="admin.E121", + ) + ], + ) def test_readonly_and_editable(self): class SongAdmin(admin.ModelAdmin): @@ -321,17 +342,21 @@ class SystemChecksTestCase(SimpleTestCase): list_display = ["pk", "original_release"] list_editable = ["original_release"] fieldsets = [ - (None, { - "fields": ["title", "original_release"], - }), + ( + None, + { + "fields": ["title", "original_release"], + }, + ), ] + errors = SongAdmin(Song, 
AdminSite()).check() expected = [ checks.Error( "The value of 'list_editable[0]' refers to 'original_release', " "which is not editable through the admin.", obj=SongAdmin, - id='admin.E125', + id="admin.E125", ) ] self.assertEqual(errors, expected) @@ -341,9 +366,12 @@ class SystemChecksTestCase(SimpleTestCase): list_display = ["pk", "title"] list_editable = ["title"] fieldsets = [ - (None, { - "fields": ["title", "original_release"], - }), + ( + None, + { + "fields": ["title", "original_release"], + }, + ), ] errors = SongAdmin(Song, AdminSite()).check() @@ -368,13 +396,12 @@ class SystemChecksTestCase(SimpleTestCase): """ The first fieldset's fields must be a list/tuple. """ + class NotATupleAdmin(admin.ModelAdmin): list_display = ["pk", "title"] list_editable = ["title"] fieldsets = [ - (None, { - "fields": "title" # not a tuple - }), + (None, {"fields": "title"}), # not a tuple ] errors = NotATupleAdmin(Song, AdminSite()).check() @@ -382,7 +409,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "The value of 'fieldsets[0][1]['fields']' must be a list or tuple.", obj=NotATupleAdmin, - id='admin.E008', + id="admin.E008", ) ] self.assertEqual(errors, expected) @@ -391,14 +418,11 @@ class SystemChecksTestCase(SimpleTestCase): """ The second fieldset's fields must be a list/tuple. 
""" + class NotATupleAdmin(admin.ModelAdmin): fieldsets = [ - (None, { - "fields": ("title",) - }), - ('foo', { - "fields": "author" # not a tuple - }), + (None, {"fields": ("title",)}), + ("foo", {"fields": "author"}), # not a tuple ] errors = NotATupleAdmin(Song, AdminSite()).check() @@ -406,7 +430,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "The value of 'fieldsets[1][1]['fields']' must be a list or tuple.", obj=NotATupleAdmin, - id='admin.E008', + id="admin.E008", ) ] self.assertEqual(errors, expected) @@ -415,29 +439,30 @@ class SystemChecksTestCase(SimpleTestCase): """ Tests for basic system checks of 'exclude' option values (#12689) """ + class ExcludedFields1(admin.ModelAdmin): - exclude = 'foo' + exclude = "foo" errors = ExcludedFields1(Book, AdminSite()).check() expected = [ checks.Error( "The value of 'exclude' must be a list or tuple.", obj=ExcludedFields1, - id='admin.E014', + id="admin.E014", ) ] self.assertEqual(errors, expected) def test_exclude_duplicate_values(self): class ExcludedFields2(admin.ModelAdmin): - exclude = ('name', 'name') + exclude = ("name", "name") errors = ExcludedFields2(Book, AdminSite()).check() expected = [ checks.Error( "The value of 'exclude' contains duplicate field(s).", obj=ExcludedFields2, - id='admin.E015', + id="admin.E015", ) ] self.assertEqual(errors, expected) @@ -445,7 +470,7 @@ class SystemChecksTestCase(SimpleTestCase): def test_exclude_in_inline(self): class ExcludedFieldsInline(admin.TabularInline): model = Song - exclude = 'foo' + exclude = "foo" class ExcludedFieldsAlbumAdmin(admin.ModelAdmin): model = Album @@ -456,7 +481,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "The value of 'exclude' must be a list or tuple.", obj=ExcludedFieldsInline, - id='admin.E014', + id="admin.E014", ) ] self.assertEqual(errors, expected) @@ -466,9 +491,10 @@ class SystemChecksTestCase(SimpleTestCase): Regression test for #9932 - exclude in InlineModelAdmin should not contain the ForeignKey 
field used in ModelAdmin.model """ + class SongInline(admin.StackedInline): model = Song - exclude = ['album'] + exclude = ["album"] class AlbumAdmin(admin.ModelAdmin): model = Album @@ -480,7 +506,7 @@ class SystemChecksTestCase(SimpleTestCase): "Cannot exclude the field 'album', because it is the foreign key " "to the parent model 'admin_checks.Album'.", obj=SongInline, - id='admin.E201', + id="admin.E201", ) ] self.assertEqual(errors, expected) @@ -490,6 +516,7 @@ class SystemChecksTestCase(SimpleTestCase): Regression test for #22034 - check that generic inlines don't look for normal ForeignKey relations. """ + class InfluenceInline(GenericStackedInline): model = Influence @@ -504,6 +531,7 @@ class SystemChecksTestCase(SimpleTestCase): A model without a GenericForeignKey raises problems if it's included in a GenericInlineModelAdmin definition. """ + class BookInline(GenericStackedInline): model = Book @@ -515,7 +543,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "'admin_checks.Book' has no GenericForeignKey.", obj=BookInline, - id='admin.E301', + id="admin.E301", ) ] self.assertEqual(errors, expected) @@ -525,9 +553,10 @@ class SystemChecksTestCase(SimpleTestCase): A GenericInlineModelAdmin errors if the ct_field points to a nonexistent field. """ + class InfluenceInline(GenericStackedInline): model = Influence - ct_field = 'nonexistent' + ct_field = "nonexistent" class SongAdmin(admin.ModelAdmin): inlines = [InfluenceInline] @@ -537,7 +566,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "'ct_field' references 'nonexistent', which is not a field on 'admin_checks.Influence'.", obj=InfluenceInline, - id='admin.E302', + id="admin.E302", ) ] self.assertEqual(errors, expected) @@ -547,9 +576,10 @@ class SystemChecksTestCase(SimpleTestCase): A GenericInlineModelAdmin errors if the ct_fk_field points to a nonexistent field. 
""" + class InfluenceInline(GenericStackedInline): model = Influence - ct_fk_field = 'nonexistent' + ct_fk_field = "nonexistent" class SongAdmin(admin.ModelAdmin): inlines = [InfluenceInline] @@ -559,7 +589,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "'ct_fk_field' references 'nonexistent', which is not a field on 'admin_checks.Influence'.", obj=InfluenceInline, - id='admin.E303', + id="admin.E303", ) ] self.assertEqual(errors, expected) @@ -569,9 +599,10 @@ class SystemChecksTestCase(SimpleTestCase): A GenericInlineModelAdmin raises problems if the ct_field points to a field that isn't part of a GenericForeignKey. """ + class InfluenceInline(GenericStackedInline): model = Influence - ct_field = 'name' + ct_field = "name" class SongAdmin(admin.ModelAdmin): inlines = [InfluenceInline] @@ -582,7 +613,7 @@ class SystemChecksTestCase(SimpleTestCase): "'admin_checks.Influence' has no GenericForeignKey using " "content type field 'name' and object ID field 'object_id'.", obj=InfluenceInline, - id='admin.E304', + id="admin.E304", ) ] self.assertEqual(errors, expected) @@ -592,9 +623,10 @@ class SystemChecksTestCase(SimpleTestCase): A GenericInlineModelAdmin raises problems if the ct_fk_field points to a field that isn't part of a GenericForeignKey. 
""" + class InfluenceInline(GenericStackedInline): model = Influence - ct_fk_field = 'name' + ct_fk_field = "name" class SongAdmin(admin.ModelAdmin): inlines = [InfluenceInline] @@ -605,14 +637,14 @@ class SystemChecksTestCase(SimpleTestCase): "'admin_checks.Influence' has no GenericForeignKey using " "content type field 'content_type' and object ID field 'name'.", obj=InfluenceInline, - id='admin.E304', + id="admin.E304", ) ] self.assertEqual(errors, expected) def test_app_label_in_admin_checks(self): class RawIdNonexistentAdmin(admin.ModelAdmin): - raw_id_fields = ('nonexistent',) + raw_id_fields = ("nonexistent",) errors = RawIdNonexistentAdmin(Album, AdminSite()).check() expected = [ @@ -620,7 +652,7 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'raw_id_fields[0]' refers to 'nonexistent', " "which is not a field of 'admin_checks.Album'.", obj=RawIdNonexistentAdmin, - id='admin.E002', + id="admin.E002", ) ] self.assertEqual(errors, expected) @@ -631,6 +663,7 @@ class SystemChecksTestCase(SimpleTestCase): given) make sure fk_name is honored or things blow up when there is more than one fk to the parent model. """ + class TwoAlbumFKAndAnEInline(admin.TabularInline): model = TwoAlbumFKAndAnE exclude = ("e",) @@ -656,7 +689,7 @@ class SystemChecksTestCase(SimpleTestCase): "to 'admin_checks.Album'. 
You must specify a 'fk_name' " "attribute.", obj=TwoAlbumFKAndAnEInline, - id='admin.E202', + id="admin.E202", ) ] self.assertEqual(errors, expected) @@ -719,9 +752,11 @@ class SystemChecksTestCase(SimpleTestCase): def __getattr__(self, item): if item == "dynamic_method": + @admin.display def method(obj): pass + return method raise AttributeError @@ -745,7 +780,7 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'readonly_fields[1]' is not a callable, an attribute " "of 'SongAdmin', or an attribute of 'admin_checks.Song'.", obj=SongAdmin, - id='admin.E035', + id="admin.E035", ) ] self.assertEqual(errors, expected) @@ -753,7 +788,7 @@ class SystemChecksTestCase(SimpleTestCase): def test_nonexistent_field_on_inline(self): class CityInline(admin.TabularInline): model = City - readonly_fields = ['i_dont_exist'] # Missing attribute + readonly_fields = ["i_dont_exist"] # Missing attribute errors = CityInline(State, AdminSite()).check() expected = [ @@ -761,22 +796,25 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'readonly_fields[0]' is not a callable, an attribute " "of 'CityInline', or an attribute of 'admin_checks.City'.", obj=CityInline, - id='admin.E035', + id="admin.E035", ) ] self.assertEqual(errors, expected) def test_readonly_fields_not_list_or_tuple(self): class SongAdmin(admin.ModelAdmin): - readonly_fields = 'test' + readonly_fields = "test" - self.assertEqual(SongAdmin(Song, AdminSite()).check(), [ - checks.Error( - "The value of 'readonly_fields' must be a list or tuple.", - obj=SongAdmin, - id='admin.E034', - ) - ]) + self.assertEqual( + SongAdmin(Song, AdminSite()).check(), + [ + checks.Error( + "The value of 'readonly_fields' must be a list or tuple.", + obj=SongAdmin, + id="admin.E034", + ) + ], + ) def test_extra(self): class SongAdmin(admin.ModelAdmin): @@ -802,8 +840,9 @@ class SystemChecksTestCase(SimpleTestCase): specifies the 'through' option is included in the 'fields' or the 'fieldsets' ModelAdmin options. 
""" + class BookAdmin(admin.ModelAdmin): - fields = ['authors'] + fields = ["authors"] errors = BookAdmin(Book, AdminSite()).check() expected = [ @@ -811,7 +850,7 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'fields' cannot include the ManyToManyField 'authors', " "because that field manually specifies a relationship model.", obj=BookAdmin, - id='admin.E013', + id="admin.E013", ) ] self.assertEqual(errors, expected) @@ -819,8 +858,8 @@ class SystemChecksTestCase(SimpleTestCase): def test_cannot_include_through(self): class FieldsetBookAdmin(admin.ModelAdmin): fieldsets = ( - ('Header 1', {'fields': ('name',)}), - ('Header 2', {'fields': ('authors',)}), + ("Header 1", {"fields": ("name",)}), + ("Header 2", {"fields": ("authors",)}), ) errors = FieldsetBookAdmin(Book, AdminSite()).check() @@ -829,23 +868,21 @@ class SystemChecksTestCase(SimpleTestCase): "The value of 'fieldsets[1][1][\"fields\"]' cannot include the ManyToManyField " "'authors', because that field manually specifies a relationship model.", obj=FieldsetBookAdmin, - id='admin.E013', + id="admin.E013", ) ] self.assertEqual(errors, expected) def test_nested_fields(self): class NestedFieldsAdmin(admin.ModelAdmin): - fields = ('price', ('name', 'subtitle')) + fields = ("price", ("name", "subtitle")) errors = NestedFieldsAdmin(Book, AdminSite()).check() self.assertEqual(errors, []) def test_nested_fieldsets(self): class NestedFieldsetAdmin(admin.ModelAdmin): - fieldsets = ( - ('Main', {'fields': ('price', ('name', 'subtitle'))}), - ) + fieldsets = (("Main", {"fields": ("price", ("name", "subtitle"))}),) errors = NestedFieldsetAdmin(Book, AdminSite()).check() self.assertEqual(errors, []) @@ -856,6 +893,7 @@ class SystemChecksTestCase(SimpleTestCase): is specified as a string, the admin should still be able use Model.m2m_field.through """ + class AuthorsInline(admin.TabularInline): model = Book.authors.through @@ -870,12 +908,13 @@ class SystemChecksTestCase(SimpleTestCase): Regression for 
ensuring ModelAdmin.fields can contain non-model fields that broke with r11737 """ + class SongForm(forms.ModelForm): extra_data = forms.CharField() class FieldsOnFormOnlyAdmin(admin.ModelAdmin): form = SongForm - fields = ['title', 'extra_data'] + fields = ["title", "extra_data"] errors = FieldsOnFormOnlyAdmin(Song, AdminSite()).check() self.assertEqual(errors, []) @@ -885,30 +924,31 @@ class SystemChecksTestCase(SimpleTestCase): Regression for ensuring ModelAdmin.field can handle first elem being a non-model field (test fix for UnboundLocalError introduced with r16225). """ + class SongForm(forms.ModelForm): extra_data = forms.CharField() class Meta: model = Song - fields = '__all__' + fields = "__all__" class FieldsOnFormOnlyAdmin(admin.ModelAdmin): form = SongForm - fields = ['extra_data', 'title'] + fields = ["extra_data", "title"] errors = FieldsOnFormOnlyAdmin(Song, AdminSite()).check() self.assertEqual(errors, []) def test_check_sublists_for_duplicates(self): class MyModelAdmin(admin.ModelAdmin): - fields = ['state', ['state']] + fields = ["state", ["state"]] errors = MyModelAdmin(Song, AdminSite()).check() expected = [ checks.Error( "The value of 'fields' contains duplicate field(s).", obj=MyModelAdmin, - id='admin.E006' + id="admin.E006", ) ] self.assertEqual(errors, expected) @@ -916,9 +956,7 @@ class SystemChecksTestCase(SimpleTestCase): def test_check_fieldset_sublists_for_duplicates(self): class MyModelAdmin(admin.ModelAdmin): fieldsets = [ - (None, { - 'fields': ['title', 'album', ('title', 'album')] - }), + (None, {"fields": ["title", "album", ("title", "album")]}), ] errors = MyModelAdmin(Song, AdminSite()).check() @@ -926,7 +964,7 @@ class SystemChecksTestCase(SimpleTestCase): checks.Error( "There are duplicate field(s) in 'fieldsets[0][1]'.", obj=MyModelAdmin, - id='admin.E012' + id="admin.E012", ) ] self.assertEqual(errors, expected) @@ -936,8 +974,9 @@ class SystemChecksTestCase(SimpleTestCase): Ensure list_filter can access reverse fields even 
when the app registry is not ready; refs #24146. """ + class BookAdminWithListFilter(admin.ModelAdmin): - list_filter = ['authorsbooks__featured'] + list_filter = ["authorsbooks__featured"] # Temporarily pretending apps are not ready yet. This issue can happen # if the value of 'list_filter' refers to a 'through__field'. diff --git a/tests/admin_custom_urls/models.py b/tests/admin_custom_urls/models.py index 8b91383b0f..ea3ea61f8b 100644 --- a/tests/admin_custom_urls/models.py +++ b/tests/admin_custom_urls/models.py @@ -21,7 +21,7 @@ class ActionAdmin(admin.ModelAdmin): The Action model has a CharField PK. """ - list_display = ('name', 'description') + list_display = ("name", "description") def remove_url(self, name): """ @@ -38,14 +38,15 @@ class ActionAdmin(admin.ModelAdmin): def wrap(view): def wrapper(*args, **kwargs): return self.admin_site.admin_view(view)(*args, **kwargs) + return update_wrapper(wrapper, view) info = self.model._meta.app_label, self.model._meta.model_name - view_name = '%s_%s_add' % info + view_name = "%s_%s_add" % info return [ - re_path('^!add/$', wrap(self.add_view), name=view_name), + re_path("^!add/$", wrap(self.add_view), name=view_name), ] + self.remove_url(view_name) @@ -54,14 +55,15 @@ class Person(models.Model): class PersonAdmin(admin.ModelAdmin): - def response_post_save_add(self, request, obj): return HttpResponseRedirect( - reverse('admin:admin_custom_urls_person_history', args=[obj.pk])) + reverse("admin:admin_custom_urls_person_history", args=[obj.pk]) + ) def response_post_save_change(self, request, obj): return HttpResponseRedirect( - reverse('admin:admin_custom_urls_person_delete', args=[obj.pk])) + reverse("admin:admin_custom_urls_person_delete", args=[obj.pk]) + ) class Car(models.Model): @@ -69,15 +71,17 @@ class Car(models.Model): class CarAdmin(admin.ModelAdmin): - def response_add(self, request, obj, post_url_continue=None): return super().response_add( - request, obj, - 
post_url_continue=reverse('admin:admin_custom_urls_car_history', args=[obj.pk]), + request, + obj, + post_url_continue=reverse( + "admin:admin_custom_urls_car_history", args=[obj.pk] + ), ) -site = admin.AdminSite(name='admin_custom_urls') +site = admin.AdminSite(name="admin_custom_urls") site.register(Action, ActionAdmin) site.register(Person, PersonAdmin) site.register(Car, CarAdmin) diff --git a/tests/admin_custom_urls/tests.py b/tests/admin_custom_urls/tests.py index ebd3291f2c..d401976ebb 100644 --- a/tests/admin_custom_urls/tests.py +++ b/tests/admin_custom_urls/tests.py @@ -8,7 +8,9 @@ from django.urls import reverse from .models import Action, Car, Person -@override_settings(ROOT_URLCONF='admin_custom_urls.urls',) +@override_settings( + ROOT_URLCONF="admin_custom_urls.urls", +) class AdminCustomUrlsTest(TestCase): """ Remember that: @@ -19,18 +21,22 @@ class AdminCustomUrlsTest(TestCase): @classmethod def setUpTestData(cls): - cls.superuser = User.objects.create_superuser(username='super', password='secret', email='super@example.com') - Action.objects.create(name='delete', description='Remove things.') - Action.objects.create(name='rename', description='Gives things other names.') - Action.objects.create(name='add', description='Add things.') - Action.objects.create(name='path/to/file/', description="An action with '/' in its name.") + cls.superuser = User.objects.create_superuser( + username="super", password="secret", email="super@example.com" + ) + Action.objects.create(name="delete", description="Remove things.") + Action.objects.create(name="rename", description="Gives things other names.") + Action.objects.create(name="add", description="Add things.") Action.objects.create( - name='path/to/html/document.html', - description='An action with a name similar to a HTML doc path.' + name="path/to/file/", description="An action with '/' in its name." 
) Action.objects.create( - name='javascript:alert(\'Hello world\');">Click here</a>', - description='An action with a name suspected of being a XSS attempt' + name="path/to/html/document.html", + description="An action with a name similar to a HTML doc path.", + ) + Action.objects.create( + name="javascript:alert('Hello world');\">Click here</a>", + description="An action with a name suspected of being a XSS attempt", ) def setUp(self): @@ -40,8 +46,8 @@ class AdminCustomUrlsTest(TestCase): """ Ensure GET on the add_view works. """ - add_url = reverse('admin_custom_urls:admin_custom_urls_action_add') - self.assertTrue(add_url.endswith('/!add/')) + add_url = reverse("admin_custom_urls:admin_custom_urls_action_add") + self.assertTrue(add_url.endswith("/!add/")) response = self.client.get(add_url) self.assertIsInstance(response, TemplateResponse) self.assertEqual(response.status_code, 200) @@ -51,7 +57,10 @@ class AdminCustomUrlsTest(TestCase): Ensure GET on the add_view plus specifying a field value in the query string works. """ - response = self.client.get(reverse('admin_custom_urls:admin_custom_urls_action_add'), {'name': 'My Action'}) + response = self.client.get( + reverse("admin_custom_urls:admin_custom_urls_action_add"), + {"name": "My Action"}, + ) self.assertContains(response, 'value="My Action"') def test_basic_add_POST(self): @@ -59,28 +68,33 @@ class AdminCustomUrlsTest(TestCase): Ensure POST on add_view works. 
""" post_data = { - IS_POPUP_VAR: '1', - "name": 'Action added through a popup', + IS_POPUP_VAR: "1", + "name": "Action added through a popup", "description": "Description of added action", } - response = self.client.post(reverse('admin_custom_urls:admin_custom_urls_action_add'), post_data) - self.assertContains(response, 'Action added through a popup') + response = self.client.post( + reverse("admin_custom_urls:admin_custom_urls_action_add"), post_data + ) + self.assertContains(response, "Action added through a popup") def test_admin_URLs_no_clash(self): # Should get the change_view for model instance with PK 'add', not show # the add_view - url = reverse('admin_custom_urls:%s_action_change' % Action._meta.app_label, args=(quote('add'),)) + url = reverse( + "admin_custom_urls:%s_action_change" % Action._meta.app_label, + args=(quote("add"),), + ) response = self.client.get(url) - self.assertContains(response, 'Change action') + self.assertContains(response, "Change action") # Should correctly get the change_view for the model instance with the # funny-looking PK (the one with a 'path/to/html/document.html' value) url = reverse( - 'admin_custom_urls:%s_action_change' % Action._meta.app_label, - args=(quote("path/to/html/document.html"),) + "admin_custom_urls:%s_action_change" % Action._meta.app_label, + args=(quote("path/to/html/document.html"),), ) response = self.client.get(url) - self.assertContains(response, 'Change action') + self.assertContains(response, "Change action") self.assertContains(response, 'value="path/to/html/document.html"') def test_post_save_add_redirect(self): @@ -88,12 +102,16 @@ class AdminCustomUrlsTest(TestCase): ModelAdmin.response_post_save_add() controls the redirection after the 'Save' button has been pressed when adding a new object. 
""" - post_data = {'name': 'John Doe'} + post_data = {"name": "John Doe"} self.assertEqual(Person.objects.count(), 0) - response = self.client.post(reverse('admin_custom_urls:admin_custom_urls_person_add'), post_data) + response = self.client.post( + reverse("admin_custom_urls:admin_custom_urls_person_add"), post_data + ) persons = Person.objects.all() self.assertEqual(len(persons), 1) - redirect_url = reverse('admin_custom_urls:admin_custom_urls_person_history', args=[persons[0].pk]) + redirect_url = reverse( + "admin_custom_urls:admin_custom_urls_person_history", args=[persons[0].pk] + ) self.assertRedirects(response, redirect_url) def test_post_save_change_redirect(self): @@ -101,21 +119,35 @@ class AdminCustomUrlsTest(TestCase): ModelAdmin.response_post_save_change() controls the redirection after the 'Save' button has been pressed when editing an existing object. """ - Person.objects.create(name='John Doe') + Person.objects.create(name="John Doe") self.assertEqual(Person.objects.count(), 1) person = Person.objects.all()[0] - post_url = reverse('admin_custom_urls:admin_custom_urls_person_change', args=[person.pk]) - response = self.client.post(post_url, {'name': 'Jack Doe'}) - self.assertRedirects(response, reverse('admin_custom_urls:admin_custom_urls_person_delete', args=[person.pk])) + post_url = reverse( + "admin_custom_urls:admin_custom_urls_person_change", args=[person.pk] + ) + response = self.client.post(post_url, {"name": "Jack Doe"}) + self.assertRedirects( + response, + reverse( + "admin_custom_urls:admin_custom_urls_person_delete", args=[person.pk] + ), + ) def test_post_url_continue(self): """ The ModelAdmin.response_add()'s parameter `post_url_continue` controls the redirection after an object has been created. 
""" - post_data = {'name': 'SuperFast', '_continue': '1'} + post_data = {"name": "SuperFast", "_continue": "1"} self.assertEqual(Car.objects.count(), 0) - response = self.client.post(reverse('admin_custom_urls:admin_custom_urls_car_add'), post_data) + response = self.client.post( + reverse("admin_custom_urls:admin_custom_urls_car_add"), post_data + ) cars = Car.objects.all() self.assertEqual(len(cars), 1) - self.assertRedirects(response, reverse('admin_custom_urls:admin_custom_urls_car_history', args=[cars[0].pk])) + self.assertRedirects( + response, + reverse( + "admin_custom_urls:admin_custom_urls_car_history", args=[cars[0].pk] + ), + ) diff --git a/tests/admin_custom_urls/urls.py b/tests/admin_custom_urls/urls.py index ade49b3957..9dd7574792 100644 --- a/tests/admin_custom_urls/urls.py +++ b/tests/admin_custom_urls/urls.py @@ -3,5 +3,5 @@ from django.urls import path from .models import site urlpatterns = [ - path('admin/', site.urls), + path("admin/", site.urls), ] diff --git a/tests/admin_default_site/apps.py b/tests/admin_default_site/apps.py index 92743c18d4..d2dde8784f 100644 --- a/tests/admin_default_site/apps.py +++ b/tests/admin_default_site/apps.py @@ -2,5 +2,5 @@ from django.contrib.admin.apps import SimpleAdminConfig class MyCustomAdminConfig(SimpleAdminConfig): - verbose_name = 'My custom default admin site.' - default_site = 'admin_default_site.sites.CustomAdminSite' + verbose_name = "My custom default admin site." 
+ default_site = "admin_default_site.sites.CustomAdminSite" diff --git a/tests/admin_default_site/tests.py b/tests/admin_default_site/tests.py index 9ce087549f..0566306668 100644 --- a/tests/admin_default_site/tests.py +++ b/tests/admin_default_site/tests.py @@ -5,16 +5,17 @@ from django.test import SimpleTestCase, override_settings from .sites import CustomAdminSite -@override_settings(INSTALLED_APPS=[ - 'admin_default_site.apps.MyCustomAdminConfig', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', -]) +@override_settings( + INSTALLED_APPS=[ + "admin_default_site.apps.MyCustomAdminConfig", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + ] +) class CustomAdminSiteTests(SimpleTestCase): - def setUp(self): # Reset admin.site since it may have already been instantiated by # another test app. @@ -25,12 +26,12 @@ class CustomAdminSiteTests(SimpleTestCase): admin.site = sites.site = self._old_site def test_use_custom_admin_site(self): - self.assertEqual(admin.site.__class__.__name__, 'CustomAdminSite') + self.assertEqual(admin.site.__class__.__name__, "CustomAdminSite") class DefaultAdminSiteTests(SimpleTestCase): def test_use_default_admin_site(self): - self.assertEqual(admin.site.__class__.__name__, 'AdminSite') + self.assertEqual(admin.site.__class__.__name__, "AdminSite") def test_repr(self): self.assertEqual(str(admin.site), "AdminSite(name='admin')") @@ -39,5 +40,5 @@ class DefaultAdminSiteTests(SimpleTestCase): class AdminSiteTests(SimpleTestCase): def test_repr(self): - admin_site = CustomAdminSite(name='other') + admin_site = CustomAdminSite(name="other") self.assertEqual(repr(admin_site), "CustomAdminSite(name='other')") diff --git a/tests/admin_docs/models.py b/tests/admin_docs/models.py index 06569d5c4c..7ae1ed8f95 100644 --- a/tests/admin_docs/models.py +++ 
b/tests/admin_docs/models.py @@ -36,10 +36,11 @@ class Person(models.Model): .. include:: admin_docs/evilfile.txt """ + first_name = models.CharField(max_length=200, help_text="The person's first name") last_name = models.CharField(max_length=200, help_text="The person's last name") company = models.ForeignKey(Company, models.CASCADE, help_text="place of work") - family = models.ForeignKey(Family, models.SET_NULL, related_name='+', null=True) + family = models.ForeignKey(Family, models.SET_NULL, related_name="+", null=True) groups = models.ManyToManyField(Group, help_text="has membership") def _get_full_name(self): @@ -55,13 +56,13 @@ class Person(models.Model): @property def a_property(self): - return 'a_property' + return "a_property" @cached_property def a_cached_property(self): - return 'a_cached_property' + return "a_cached_property" - def suffix_company_name(self, suffix='ltd'): + def suffix_company_name(self, suffix="ltd"): return self.company.name + suffix def add_image(self): diff --git a/tests/admin_docs/namespace_urls.py b/tests/admin_docs/namespace_urls.py index 719bf0ddf5..b8eb0351f2 100644 --- a/tests/admin_docs/namespace_urls.py +++ b/tests/admin_docs/namespace_urls.py @@ -3,12 +3,15 @@ from django.urls import include, path from . 
import views -backend_urls = ([ - path('something/', views.XViewClass.as_view(), name='something'), -], 'backend') +backend_urls = ( + [ + path("something/", views.XViewClass.as_view(), name="something"), + ], + "backend", +) urlpatterns = [ - path('admin/doc/', include('django.contrib.admindocs.urls')), - path('admin/', admin.site.urls), - path('api/backend/', include(backend_urls, namespace='backend')), + path("admin/doc/", include("django.contrib.admindocs.urls")), + path("admin/", admin.site.urls), + path("api/backend/", include(backend_urls, namespace="backend")), ] diff --git a/tests/admin_docs/test_middleware.py b/tests/admin_docs/test_middleware.py index 5d737f1bfd..3095074ee6 100644 --- a/tests/admin_docs/test_middleware.py +++ b/tests/admin_docs/test_middleware.py @@ -6,47 +6,48 @@ from .tests import AdminDocsTestCase, TestDataMixin class XViewMiddlewareTest(TestDataMixin, AdminDocsTestCase): - def test_xview_func(self): - user = User.objects.get(username='super') - response = self.client.head('/xview/func/') - self.assertNotIn('X-View', response) + user = User.objects.get(username="super") + response = self.client.head("/xview/func/") + self.assertNotIn("X-View", response) self.client.force_login(self.superuser) - response = self.client.head('/xview/func/') - self.assertIn('X-View', response) - self.assertEqual(response.headers['X-View'], 'admin_docs.views.xview') + response = self.client.head("/xview/func/") + self.assertIn("X-View", response) + self.assertEqual(response.headers["X-View"], "admin_docs.views.xview") user.is_staff = False user.save() - response = self.client.head('/xview/func/') - self.assertNotIn('X-View', response) + response = self.client.head("/xview/func/") + self.assertNotIn("X-View", response) user.is_staff = True user.is_active = False user.save() - response = self.client.head('/xview/func/') - self.assertNotIn('X-View', response) + response = self.client.head("/xview/func/") + self.assertNotIn("X-View", response) def 
test_xview_class(self): - user = User.objects.get(username='super') - response = self.client.head('/xview/class/') - self.assertNotIn('X-View', response) + user = User.objects.get(username="super") + response = self.client.head("/xview/class/") + self.assertNotIn("X-View", response) self.client.force_login(self.superuser) - response = self.client.head('/xview/class/') - self.assertIn('X-View', response) - self.assertEqual(response.headers['X-View'], 'admin_docs.views.XViewClass') + response = self.client.head("/xview/class/") + self.assertIn("X-View", response) + self.assertEqual(response.headers["X-View"], "admin_docs.views.XViewClass") user.is_staff = False user.save() - response = self.client.head('/xview/class/') - self.assertNotIn('X-View', response) + response = self.client.head("/xview/class/") + self.assertNotIn("X-View", response) user.is_staff = True user.is_active = False user.save() - response = self.client.head('/xview/class/') - self.assertNotIn('X-View', response) + response = self.client.head("/xview/class/") + self.assertNotIn("X-View", response) def test_callable_object_view(self): self.client.force_login(self.superuser) - response = self.client.head('/xview/callable_object/') - self.assertEqual(response.headers['X-View'], 'admin_docs.views.XViewCallableObject') + response = self.client.head("/xview/callable_object/") + self.assertEqual( + response.headers["X-View"], "admin_docs.views.XViewCallableObject" + ) @override_settings(MIDDLEWARE=[]) def test_no_auth_middleware(self): @@ -56,4 +57,4 @@ class XViewMiddlewareTest(TestDataMixin, AdminDocsTestCase): "'django.contrib.auth.middleware.AuthenticationMiddleware'." 
) with self.assertRaisesMessage(ImproperlyConfigured, msg): - self.client.head('/xview/func/') + self.client.head("/xview/func/") diff --git a/tests/admin_docs/test_utils.py b/tests/admin_docs/test_utils.py index 9ffc25392c..18c6769fad 100644 --- a/tests/admin_docs/test_utils.py +++ b/tests/admin_docs/test_utils.py @@ -1,7 +1,9 @@ import unittest from django.contrib.admindocs.utils import ( - docutils_is_available, parse_docstring, parse_rst, + docutils_is_available, + parse_docstring, + parse_rst, ) from django.test.utils import captured_stderr @@ -29,41 +31,42 @@ class TestUtils(AdminDocsSimpleTestCase): some_metadata: some data """ + def setUp(self): self.docstring = self.__doc__ def test_parse_docstring(self): title, description, metadata = parse_docstring(self.docstring) docstring_title = ( - 'This __doc__ output is required for testing. I copied this example from\n' - '`admindocs` documentation. (TITLE)' + "This __doc__ output is required for testing. I copied this example from\n" + "`admindocs` documentation. 
(TITLE)" ) docstring_description = ( - 'Display an individual :model:`myapp.MyModel`.\n\n' - '**Context**\n\n``RequestContext``\n\n``mymodel``\n' - ' An instance of :model:`myapp.MyModel`.\n\n' - '**Template:**\n\n:template:`myapp/my_template.html` ' - '(DESCRIPTION)' + "Display an individual :model:`myapp.MyModel`.\n\n" + "**Context**\n\n``RequestContext``\n\n``mymodel``\n" + " An instance of :model:`myapp.MyModel`.\n\n" + "**Template:**\n\n:template:`myapp/my_template.html` " + "(DESCRIPTION)" ) self.assertEqual(title, docstring_title) self.assertEqual(description, docstring_description) - self.assertEqual(metadata, {'some_metadata': 'some data'}) + self.assertEqual(metadata, {"some_metadata": "some data"}) def test_title_output(self): title, description, metadata = parse_docstring(self.docstring) - title_output = parse_rst(title, 'model', 'model:admindocs') - self.assertIn('TITLE', title_output) + title_output = parse_rst(title, "model", "model:admindocs") + self.assertIn("TITLE", title_output) title_rendered = ( - '<p>This __doc__ output is required for testing. I copied this ' + "<p>This __doc__ output is required for testing. I copied this " 'example from\n<a class="reference external" ' 'href="/admindocs/models/admindocs/">admindocs</a> documentation. 
' - '(TITLE)</p>\n' + "(TITLE)</p>\n" ) self.assertHTMLEqual(title_output, title_rendered) def test_description_output(self): title, description, metadata = parse_docstring(self.docstring) - description_output = parse_rst(description, 'model', 'model:admindocs') + description_output = parse_rst(description, "model", "model:admindocs") description_rendered = ( '<p>Display an individual <a class="reference external" ' 'href="/admindocs/models/myapp.mymodel/">myapp.MyModel</a>.</p>\n' @@ -71,35 +74,35 @@ class TestUtils(AdminDocsSimpleTestCase): 'RequestContext</tt></p>\n<dl class="docutils">\n<dt><tt class="' 'docutils literal">mymodel</tt></dt>\n<dd>An instance of <a class="' 'reference external" href="/admindocs/models/myapp.mymodel/">' - 'myapp.MyModel</a>.</dd>\n</dl>\n<p><strong>Template:</strong></p>' + "myapp.MyModel</a>.</dd>\n</dl>\n<p><strong>Template:</strong></p>" '\n<p><a class="reference external" href="/admindocs/templates/' 'myapp/my_template.html/">myapp/my_template.html</a> (DESCRIPTION)' - '</p>\n' + "</p>\n" ) self.assertHTMLEqual(description_output, description_rendered) def test_initial_header_level(self): - header = 'should be h3...\n\nHeader\n------\n' - output = parse_rst(header, 'header') - self.assertIn('<h3>Header</h3>', output) + header = "should be h3...\n\nHeader\n------\n" + output = parse_rst(header, "header") + self.assertIn("<h3>Header</h3>", output) def test_parse_rst(self): """ parse_rst() should use `cmsreference` as the default role. 
""" markup = '<p><a class="reference external" href="/admindocs/%s">title</a></p>\n' - self.assertEqual(parse_rst('`title`', 'model'), markup % 'models/title/') - self.assertEqual(parse_rst('`title`', 'view'), markup % 'views/title/') - self.assertEqual(parse_rst('`title`', 'template'), markup % 'templates/title/') - self.assertEqual(parse_rst('`title`', 'filter'), markup % 'filters/#title') - self.assertEqual(parse_rst('`title`', 'tag'), markup % 'tags/#title') + self.assertEqual(parse_rst("`title`", "model"), markup % "models/title/") + self.assertEqual(parse_rst("`title`", "view"), markup % "views/title/") + self.assertEqual(parse_rst("`title`", "template"), markup % "templates/title/") + self.assertEqual(parse_rst("`title`", "filter"), markup % "filters/#title") + self.assertEqual(parse_rst("`title`", "tag"), markup % "tags/#title") def test_parse_rst_with_docstring_no_leading_line_feed(self): - title, body, _ = parse_docstring('firstline\n\n second line') + title, body, _ = parse_docstring("firstline\n\n second line") with captured_stderr() as stderr: - self.assertEqual(parse_rst(title, ''), '<p>firstline</p>\n') - self.assertEqual(parse_rst(body, ''), '<p>second line</p>\n') - self.assertEqual(stderr.getvalue(), '') + self.assertEqual(parse_rst(title, ""), "<p>firstline</p>\n") + self.assertEqual(parse_rst(body, ""), "<p>second line</p>\n") + self.assertEqual(stderr.getvalue(), "") def test_publish_parts(self): """ @@ -108,8 +111,11 @@ class TestUtils(AdminDocsSimpleTestCase): ``cmsreference`` (#6681). """ import docutils - self.assertNotEqual(docutils.parsers.rst.roles.DEFAULT_INTERPRETED_ROLE, 'cmsreference') - source = 'reST, `interpreted text`, default role.' - markup = '<p>reST, <cite>interpreted text</cite>, default role.</p>\n' + + self.assertNotEqual( + docutils.parsers.rst.roles.DEFAULT_INTERPRETED_ROLE, "cmsreference" + ) + source = "reST, `interpreted text`, default role." 
+ markup = "<p>reST, <cite>interpreted text</cite>, default role.</p>\n" parts = docutils.core.publish_parts(source=source, writer_name="html4css1") - self.assertEqual(parts['fragment'], markup) + self.assertEqual(parts["fragment"], markup) diff --git a/tests/admin_docs/test_views.py b/tests/admin_docs/test_views.py index b786fb1930..c11b8f71c2 100644 --- a/tests/admin_docs/test_views.py +++ b/tests/admin_docs/test_views.py @@ -19,44 +19,50 @@ from .tests import AdminDocsTestCase, TestDataMixin @unittest.skipUnless(utils.docutils_is_available, "no docutils installed.") class AdminDocViewTests(TestDataMixin, AdminDocsTestCase): - def setUp(self): self.client.force_login(self.superuser) def test_index(self): - response = self.client.get(reverse('django-admindocs-docroot')) - self.assertContains(response, '<h1>Documentation</h1>', html=True) - self.assertContains(response, '<h1 id="site-name"><a href="/admin/">Django administration</a></h1>') + response = self.client.get(reverse("django-admindocs-docroot")) + self.assertContains(response, "<h1>Documentation</h1>", html=True) + self.assertContains( + response, + '<h1 id="site-name"><a href="/admin/">Django administration</a></h1>', + ) self.client.logout() - response = self.client.get(reverse('django-admindocs-docroot'), follow=True) + response = self.client.get(reverse("django-admindocs-docroot"), follow=True) # Should display the login screen - self.assertContains(response, '<input type="hidden" name="next" value="/admindocs/">', html=True) + self.assertContains( + response, '<input type="hidden" name="next" value="/admindocs/">', html=True + ) def test_bookmarklets(self): - response = self.client.get(reverse('django-admindocs-bookmarklets')) - self.assertContains(response, '/admindocs/views/') + response = self.client.get(reverse("django-admindocs-bookmarklets")) + self.assertContains(response, "/admindocs/views/") def test_templatetag_index(self): - response = self.client.get(reverse('django-admindocs-tags')) - 
self.assertContains(response, '<h3 id="built_in-extends">extends</h3>', html=True) + response = self.client.get(reverse("django-admindocs-tags")) + self.assertContains( + response, '<h3 id="built_in-extends">extends</h3>', html=True + ) def test_templatefilter_index(self): - response = self.client.get(reverse('django-admindocs-filters')) + response = self.client.get(reverse("django-admindocs-filters")) self.assertContains(response, '<h3 id="built_in-first">first</h3>', html=True) def test_view_index(self): - response = self.client.get(reverse('django-admindocs-views-index')) + response = self.client.get(reverse("django-admindocs-views-index")) self.assertContains( response, '<h3><a href="/admindocs/views/django.contrib.admindocs.views.BaseAdminDocsView/">/admindocs/</a></h3>', - html=True + html=True, ) - self.assertContains(response, 'Views by namespace test') - self.assertContains(response, 'Name: <code>test:func</code>.') + self.assertContains(response, "Views by namespace test") + self.assertContains(response, "Name: <code>test:func</code>.") self.assertContains( response, '<h3><a href="/admindocs/views/admin_docs.views.XViewCallableObject/">' - '/xview/callable_object_without_xview/</a></h3>', + "/xview/callable_object_without_xview/</a></h3>", html=True, ) @@ -64,27 +70,35 @@ class AdminDocViewTests(TestDataMixin, AdminDocsTestCase): """ Views that are methods are listed correctly. 
""" - response = self.client.get(reverse('django-admindocs-views-index')) + response = self.client.get(reverse("django-admindocs-views-index")) self.assertContains( response, '<h3><a href="/admindocs/views/django.contrib.admin.sites.AdminSite.index/">/admin/</a></h3>', - html=True + html=True, ) def test_view_detail(self): - url = reverse('django-admindocs-views-detail', args=['django.contrib.admindocs.views.BaseAdminDocsView']) + url = reverse( + "django-admindocs-views-detail", + args=["django.contrib.admindocs.views.BaseAdminDocsView"], + ) response = self.client.get(url) # View docstring - self.assertContains(response, 'Base view for admindocs views.') + self.assertContains(response, "Base view for admindocs views.") - @override_settings(ROOT_URLCONF='admin_docs.namespace_urls') + @override_settings(ROOT_URLCONF="admin_docs.namespace_urls") def test_namespaced_view_detail(self): - url = reverse('django-admindocs-views-detail', args=['admin_docs.views.XViewClass']) + url = reverse( + "django-admindocs-views-detail", args=["admin_docs.views.XViewClass"] + ) response = self.client.get(url) - self.assertContains(response, '<h1>admin_docs.views.XViewClass</h1>') + self.assertContains(response, "<h1>admin_docs.views.XViewClass</h1>") def test_view_detail_illegal_import(self): - url = reverse('django-admindocs-views-detail', args=['urlpatterns_reverse.nonimported_module.view']) + url = reverse( + "django-admindocs-views-detail", + args=["urlpatterns_reverse.nonimported_module.view"], + ) response = self.client.get(url) self.assertEqual(response.status_code, 404) self.assertNotIn("urlpatterns_reverse.nonimported_module", sys.modules) @@ -93,41 +107,55 @@ class AdminDocViewTests(TestDataMixin, AdminDocsTestCase): """ Views that are methods can be displayed. 
""" - url = reverse('django-admindocs-views-detail', args=['django.contrib.admin.sites.AdminSite.index']) + url = reverse( + "django-admindocs-views-detail", + args=["django.contrib.admin.sites.AdminSite.index"], + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) def test_model_index(self): - response = self.client.get(reverse('django-admindocs-models-index')) + response = self.client.get(reverse("django-admindocs-models-index")) self.assertContains( response, '<h2 id="app-auth">Authentication and Authorization (django.contrib.auth)</h2>', - html=True + html=True, ) def test_template_detail(self): - response = self.client.get(reverse('django-admindocs-templates', args=['admin_doc/template_detail.html'])) - self.assertContains(response, '<h1>Template: <q>admin_doc/template_detail.html</q></h1>', html=True) + response = self.client.get( + reverse( + "django-admindocs-templates", args=["admin_doc/template_detail.html"] + ) + ) + self.assertContains( + response, + "<h1>Template: <q>admin_doc/template_detail.html</q></h1>", + html=True, + ) def test_missing_docutils(self): utils.docutils_is_available = False try: - response = self.client.get(reverse('django-admindocs-docroot')) + response = self.client.get(reverse("django-admindocs-docroot")) self.assertContains( response, - '<h3>The admin documentation system requires Python’s ' + "<h3>The admin documentation system requires Python’s " '<a href="https://docutils.sourceforge.io/">docutils</a> ' - 'library.</h3>' - '<p>Please ask your administrators to install ' + "library.</h3>" + "<p>Please ask your administrators to install " '<a href="https://docutils.sourceforge.io/">docutils</a>.</p>', - html=True + html=True, + ) + self.assertContains( + response, + '<h1 id="site-name"><a href="/admin/">Django administration</a></h1>', ) - self.assertContains(response, '<h1 id="site-name"><a href="/admin/">Django administration</a></h1>') finally: utils.docutils_is_available = True - 
@modify_settings(INSTALLED_APPS={'remove': 'django.contrib.sites'}) - @override_settings(SITE_ID=None) # will restore SITE_ID after the test + @modify_settings(INSTALLED_APPS={"remove": "django.contrib.sites"}) + @override_settings(SITE_ID=None) # will restore SITE_ID after the test def test_no_sites_framework(self): """ Without the sites framework, should not access SITE_ID or Site @@ -135,73 +163,78 @@ class AdminDocViewTests(TestDataMixin, AdminDocsTestCase): """ Site.objects.all().delete() del settings.SITE_ID - response = self.client.get(reverse('django-admindocs-views-index')) - self.assertContains(response, 'View documentation') + response = self.client.get(reverse("django-admindocs-views-index")) + self.assertContains(response, "View documentation") def test_callable_urlconf(self): """ Index view should correctly resolve view patterns when ROOT_URLCONF is not a string. """ + def urlpatterns(): return ( - path('admin/doc/', include('django.contrib.admindocs.urls')), - path('admin/', admin.site.urls), + path("admin/doc/", include("django.contrib.admindocs.urls")), + path("admin/", admin.site.urls), ) with self.settings(ROOT_URLCONF=SimpleLazyObject(urlpatterns)): - response = self.client.get(reverse('django-admindocs-views-index')) + response = self.client.get(reverse("django-admindocs-views-index")) self.assertEqual(response.status_code, 200) -@unittest.skipUnless(utils.docutils_is_available, 'no docutils installed.') +@unittest.skipUnless(utils.docutils_is_available, "no docutils installed.") class AdminDocViewDefaultEngineOnly(TestDataMixin, AdminDocsTestCase): - def setUp(self): self.client.force_login(self.superuser) def test_template_detail_path_traversal(self): - cases = ['/etc/passwd', '../passwd'] + cases = ["/etc/passwd", "../passwd"] for fpath in cases: with self.subTest(path=fpath): response = self.client.get( - reverse('django-admindocs-templates', args=[fpath]), + reverse("django-admindocs-templates", args=[fpath]), ) 
self.assertEqual(response.status_code, 400) -@override_settings(TEMPLATES=[{ - 'NAME': 'ONE', - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'APP_DIRS': True, -}, { - 'NAME': 'TWO', - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'APP_DIRS': True, -}]) +@override_settings( + TEMPLATES=[ + { + "NAME": "ONE", + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + }, + { + "NAME": "TWO", + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + }, + ] +) @unittest.skipUnless(utils.docutils_is_available, "no docutils installed.") class AdminDocViewWithMultipleEngines(AdminDocViewTests): - def test_templatefilter_index(self): # Overridden because non-trivial TEMPLATES settings aren't supported # but the page shouldn't crash (#24125). - response = self.client.get(reverse('django-admindocs-filters')) - self.assertContains(response, '<title>Template filters', html=True) + response = self.client.get(reverse("django-admindocs-filters")) + self.assertContains(response, "Template filters", html=True) def test_templatetag_index(self): # Overridden because non-trivial TEMPLATES settings aren't supported # but the page shouldn't crash (#24125). 
- response = self.client.get(reverse('django-admindocs-tags')) - self.assertContains(response, 'Template tags', html=True) + response = self.client.get(reverse("django-admindocs-tags")) + self.assertContains(response, "Template tags", html=True) @unittest.skipUnless(utils.docutils_is_available, "no docutils installed.") class TestModelDetailView(TestDataMixin, AdminDocsTestCase): - def setUp(self): self.client.force_login(self.superuser) with captured_stderr() as self.docutils_stderr: - self.response = self.client.get(reverse('django-admindocs-models-detail', args=['admin_docs', 'Person'])) + self.response = self.client.get( + reverse("django-admindocs-models-detail", args=["admin_docs", "Person"]) + ) def test_method_excludes(self): """ @@ -235,42 +268,52 @@ class TestModelDetailView(TestDataMixin, AdminDocsTestCase): """ Methods with keyword arguments should have their arguments displayed. """ - self.assertContains(self.response, 'suffix='ltd'') + self.assertContains(self.response, "suffix='ltd'") def test_methods_with_multiple_arguments_display_arguments(self): """ Methods with multiple arguments should have all their arguments displayed, but omitting 'self'. 
""" - self.assertContains(self.response, "baz, rox, *some_args, **some_kwargs") + self.assertContains( + self.response, "baz, rox, *some_args, **some_kwargs" + ) def test_instance_of_property_methods_are_displayed(self): """Model properties are displayed as fields.""" - self.assertContains(self.response, 'a_property') + self.assertContains(self.response, "a_property") def test_instance_of_cached_property_methods_are_displayed(self): """Model cached properties are displayed as fields.""" - self.assertContains(self.response, 'a_cached_property') + self.assertContains(self.response, "a_cached_property") def test_method_data_types(self): company = Company.objects.create(name="Django") - person = Person.objects.create(first_name="Human", last_name="User", company=company) - self.assertEqual(get_return_data_type(person.get_status_count.__name__), 'Integer') - self.assertEqual(get_return_data_type(person.get_groups_list.__name__), 'List') + person = Person.objects.create( + first_name="Human", last_name="User", company=company + ) + self.assertEqual( + get_return_data_type(person.get_status_count.__name__), "Integer" + ) + self.assertEqual(get_return_data_type(person.get_groups_list.__name__), "List") def test_descriptions_render_correctly(self): """ The ``description`` field should render correctly for each field type. """ # help text in fields - self.assertContains(self.response, "first name - The person's first name") - self.assertContains(self.response, "last name - The person's last name") + self.assertContains( + self.response, "first name - The person's first name" + ) + self.assertContains( + self.response, "last name - The person's last name" + ) # method docstrings self.assertContains(self.response, "

Get the full name of the person

") link = '%s' - markup = '

the related %s object

' + markup = "

the related %s object

" company_markup = markup % (link % ("admin_docs.company", "admin_docs.Company")) # foreign keys @@ -282,18 +325,28 @@ class TestModelDetailView(TestDataMixin, AdminDocsTestCase): # many to many fields self.assertContains( self.response, - "number of related %s objects" % (link % ("admin_docs.group", "admin_docs.Group")) + "number of related %s objects" + % (link % ("admin_docs.group", "admin_docs.Group")), ) self.assertContains( self.response, - "all related %s objects" % (link % ("admin_docs.group", "admin_docs.Group")) + "all related %s objects" + % (link % ("admin_docs.group", "admin_docs.Group")), ) # "raw" and "include" directives are disabled - self.assertContains(self.response, '

"raw" directive disabled.

',) - self.assertContains(self.response, '.. raw:: html\n :file: admin_docs/evilfile.txt') - self.assertContains(self.response, '

"include" directive disabled.

',) - self.assertContains(self.response, '.. include:: admin_docs/evilfile.txt') + self.assertContains( + self.response, + "

"raw" directive disabled.

", + ) + self.assertContains( + self.response, ".. raw:: html\n :file: admin_docs/evilfile.txt" + ) + self.assertContains( + self.response, + "

"include" directive disabled.

", + ) + self.assertContains(self.response, ".. include:: admin_docs/evilfile.txt") out = self.docutils_stderr.getvalue() self.assertIn('"raw" directive disabled', out) self.assertIn('"include" directive disabled', out) @@ -301,15 +354,17 @@ class TestModelDetailView(TestDataMixin, AdminDocsTestCase): def test_model_with_many_to_one(self): link = '%s' response = self.client.get( - reverse('django-admindocs-models-detail', args=['admin_docs', 'company']) + reverse("django-admindocs-models-detail", args=["admin_docs", "company"]) ) self.assertContains( response, - "number of related %s objects" % (link % ("admin_docs.person", "admin_docs.Person")) + "number of related %s objects" + % (link % ("admin_docs.person", "admin_docs.Person")), ) self.assertContains( response, - "all related %s objects" % (link % ("admin_docs.person", "admin_docs.Person")) + "all related %s objects" + % (link % ("admin_docs.person", "admin_docs.Person")), ) def test_model_with_no_backward_relations_render_only_relevant_fields(self): @@ -317,8 +372,10 @@ class TestModelDetailView(TestDataMixin, AdminDocsTestCase): A model with ``related_name`` of `+` shouldn't show backward relationship links. """ - response = self.client.get(reverse('django-admindocs-models-detail', args=['admin_docs', 'family'])) - fields = response.context_data.get('fields') + response = self.client.get( + reverse("django-admindocs-models-detail", args=["admin_docs", "family"]) + ) + fields = response.context_data.get("fields") self.assertEqual(len(fields), 2) def test_model_docstring_renders_correctly(self): @@ -326,31 +383,40 @@ class TestModelDetailView(TestDataMixin, AdminDocsTestCase): '

Stores information about a person, related to myapp.Company.

' ) - subheading = '

Notes

' + subheading = "

Notes

" body = '

Use save_changes() when saving this object.

' model_body = ( '
company
Field storing ' - 'myapp.Company where the person works.
' + "myapp.Company where the person works." ) - self.assertContains(self.response, 'DESCRIPTION') + self.assertContains(self.response, "DESCRIPTION") self.assertContains(self.response, summary, html=True) self.assertContains(self.response, subheading, html=True) self.assertContains(self.response, body, html=True) self.assertContains(self.response, model_body, html=True) def test_model_detail_title(self): - self.assertContains(self.response, '

admin_docs.Person

', html=True) + self.assertContains(self.response, "

admin_docs.Person

", html=True) def test_app_not_found(self): - response = self.client.get(reverse('django-admindocs-models-detail', args=['doesnotexist', 'Person'])) - self.assertEqual(response.context['exception'], "App 'doesnotexist' not found") + response = self.client.get( + reverse("django-admindocs-models-detail", args=["doesnotexist", "Person"]) + ) + self.assertEqual(response.context["exception"], "App 'doesnotexist' not found") self.assertEqual(response.status_code, 404) def test_model_not_found(self): - response = self.client.get(reverse('django-admindocs-models-detail', args=['admin_docs', 'doesnotexist'])) - self.assertEqual(response.context['exception'], "Model 'doesnotexist' not found in app 'admin_docs'") + response = self.client.get( + reverse( + "django-admindocs-models-detail", args=["admin_docs", "doesnotexist"] + ) + ) + self.assertEqual( + response.context["exception"], + "Model 'doesnotexist' not found in app 'admin_docs'", + ) self.assertEqual(response.status_code, 404) @@ -370,123 +436,124 @@ class TestFieldType(unittest.TestCase): def test_builtin_fields(self): self.assertEqual( views.get_readable_field_data_type(fields.BooleanField()), - 'Boolean (Either True or False)' + "Boolean (Either True or False)", ) def test_custom_fields(self): - self.assertEqual(views.get_readable_field_data_type(CustomField()), 'A custom field type') + self.assertEqual( + views.get_readable_field_data_type(CustomField()), "A custom field type" + ) self.assertEqual( views.get_readable_field_data_type(DescriptionLackingField()), - 'Field of type: DescriptionLackingField' + "Field of type: DescriptionLackingField", ) class AdminDocViewFunctionsTests(SimpleTestCase): - def test_simplify_regex(self): tests = ( # Named and unnamed groups. 
- (r'^(?P\w+)/b/(?P\w+)/$', '//b//'), - (r'^(?P\w+)/b/(?P\w+)$', '//b/'), - (r'^(?P\w+)/b/(?P\w+)', '//b/'), - (r'^(?P\w+)/b/(\w+)$', '//b/'), - (r'^(?P\w+)/b/(\w+)', '//b/'), - (r'^(?P\w+)/b/((x|y)\w+)$', '//b/'), - (r'^(?P\w+)/b/((x|y)\w+)', '//b/'), - (r'^(?P(x|y))/b/(?P\w+)$', '//b/'), - (r'^(?P(x|y))/b/(?P\w+)', '//b/'), - (r'^(?P(x|y))/b/(?P\w+)ab', '//b/ab'), - (r'^(?P(x|y)(\(|\)))/b/(?P\w+)ab', '//b/ab'), + (r"^(?P\w+)/b/(?P\w+)/$", "//b//"), + (r"^(?P\w+)/b/(?P\w+)$", "//b/"), + (r"^(?P\w+)/b/(?P\w+)", "//b/"), + (r"^(?P\w+)/b/(\w+)$", "//b/"), + (r"^(?P\w+)/b/(\w+)", "//b/"), + (r"^(?P\w+)/b/((x|y)\w+)$", "//b/"), + (r"^(?P\w+)/b/((x|y)\w+)", "//b/"), + (r"^(?P(x|y))/b/(?P\w+)$", "//b/"), + (r"^(?P(x|y))/b/(?P\w+)", "//b/"), + (r"^(?P(x|y))/b/(?P\w+)ab", "//b/ab"), + (r"^(?P(x|y)(\(|\)))/b/(?P\w+)ab", "//b/ab"), # Non-capturing groups. - (r'^a(?:\w+)b', '/ab'), - (r'^a(?:(x|y))', '/a'), - (r'^(?:\w+(?:\w+))a', '/a'), - (r'^a(?:\w+)/b(?:\w+)', '/a/b'), - (r'(?P\w+)/b/(?:\w+)c(?:\w+)', '//b/c'), - (r'(?P\w+)/b/(\w+)/(?:\w+)c(?:\w+)', '//b//c'), + (r"^a(?:\w+)b", "/ab"), + (r"^a(?:(x|y))", "/a"), + (r"^(?:\w+(?:\w+))a", "/a"), + (r"^a(?:\w+)/b(?:\w+)", "/a/b"), + (r"(?P\w+)/b/(?:\w+)c(?:\w+)", "//b/c"), + (r"(?P\w+)/b/(\w+)/(?:\w+)c(?:\w+)", "//b//c"), # Single and repeated metacharacters. 
- (r'^a', '/a'), - (r'^^a', '/a'), - (r'^^^a', '/a'), - (r'a$', '/a'), - (r'a$$', '/a'), - (r'a$$$', '/a'), - (r'a?', '/a'), - (r'a??', '/a'), - (r'a???', '/a'), - (r'a*', '/a'), - (r'a**', '/a'), - (r'a***', '/a'), - (r'a+', '/a'), - (r'a++', '/a'), - (r'a+++', '/a'), - (r'\Aa', '/a'), - (r'\A\Aa', '/a'), - (r'\A\A\Aa', '/a'), - (r'a\Z', '/a'), - (r'a\Z\Z', '/a'), - (r'a\Z\Z\Z', '/a'), - (r'\ba', '/a'), - (r'\b\ba', '/a'), - (r'\b\b\ba', '/a'), - (r'a\B', '/a'), - (r'a\B\B', '/a'), - (r'a\B\B\B', '/a'), + (r"^a", "/a"), + (r"^^a", "/a"), + (r"^^^a", "/a"), + (r"a$", "/a"), + (r"a$$", "/a"), + (r"a$$$", "/a"), + (r"a?", "/a"), + (r"a??", "/a"), + (r"a???", "/a"), + (r"a*", "/a"), + (r"a**", "/a"), + (r"a***", "/a"), + (r"a+", "/a"), + (r"a++", "/a"), + (r"a+++", "/a"), + (r"\Aa", "/a"), + (r"\A\Aa", "/a"), + (r"\A\A\Aa", "/a"), + (r"a\Z", "/a"), + (r"a\Z\Z", "/a"), + (r"a\Z\Z\Z", "/a"), + (r"\ba", "/a"), + (r"\b\ba", "/a"), + (r"\b\b\ba", "/a"), + (r"a\B", "/a"), + (r"a\B\B", "/a"), + (r"a\B\B\B", "/a"), # Multiple mixed metacharacters. - (r'^a/?$', '/a/'), - (r'\Aa\Z', '/a'), - (r'\ba\B', '/a'), + (r"^a/?$", "/a/"), + (r"\Aa\Z", "/a"), + (r"\ba\B", "/a"), # Escaped single metacharacters. 
- (r'\^a', r'/^a'), - (r'\\^a', r'/\\a'), - (r'\\\^a', r'/\\^a'), - (r'\\\\^a', r'/\\\\a'), - (r'\\\\\^a', r'/\\\\^a'), - (r'a\$', r'/a$'), - (r'a\\$', r'/a\\'), - (r'a\\\$', r'/a\\$'), - (r'a\\\\$', r'/a\\\\'), - (r'a\\\\\$', r'/a\\\\$'), - (r'a\?', r'/a?'), - (r'a\\?', r'/a\\'), - (r'a\\\?', r'/a\\?'), - (r'a\\\\?', r'/a\\\\'), - (r'a\\\\\?', r'/a\\\\?'), - (r'a\*', r'/a*'), - (r'a\\*', r'/a\\'), - (r'a\\\*', r'/a\\*'), - (r'a\\\\*', r'/a\\\\'), - (r'a\\\\\*', r'/a\\\\*'), - (r'a\+', r'/a+'), - (r'a\\+', r'/a\\'), - (r'a\\\+', r'/a\\+'), - (r'a\\\\+', r'/a\\\\'), - (r'a\\\\\+', r'/a\\\\+'), - (r'\\Aa', r'/\Aa'), - (r'\\\Aa', r'/\\a'), - (r'\\\\Aa', r'/\\\Aa'), - (r'\\\\\Aa', r'/\\\\a'), - (r'\\\\\\Aa', r'/\\\\\Aa'), - (r'a\\Z', r'/a\Z'), - (r'a\\\Z', r'/a\\'), - (r'a\\\\Z', r'/a\\\Z'), - (r'a\\\\\Z', r'/a\\\\'), - (r'a\\\\\\Z', r'/a\\\\\Z'), + (r"\^a", r"/^a"), + (r"\\^a", r"/\\a"), + (r"\\\^a", r"/\\^a"), + (r"\\\\^a", r"/\\\\a"), + (r"\\\\\^a", r"/\\\\^a"), + (r"a\$", r"/a$"), + (r"a\\$", r"/a\\"), + (r"a\\\$", r"/a\\$"), + (r"a\\\\$", r"/a\\\\"), + (r"a\\\\\$", r"/a\\\\$"), + (r"a\?", r"/a?"), + (r"a\\?", r"/a\\"), + (r"a\\\?", r"/a\\?"), + (r"a\\\\?", r"/a\\\\"), + (r"a\\\\\?", r"/a\\\\?"), + (r"a\*", r"/a*"), + (r"a\\*", r"/a\\"), + (r"a\\\*", r"/a\\*"), + (r"a\\\\*", r"/a\\\\"), + (r"a\\\\\*", r"/a\\\\*"), + (r"a\+", r"/a+"), + (r"a\\+", r"/a\\"), + (r"a\\\+", r"/a\\+"), + (r"a\\\\+", r"/a\\\\"), + (r"a\\\\\+", r"/a\\\\+"), + (r"\\Aa", r"/\Aa"), + (r"\\\Aa", r"/\\a"), + (r"\\\\Aa", r"/\\\Aa"), + (r"\\\\\Aa", r"/\\\\a"), + (r"\\\\\\Aa", r"/\\\\\Aa"), + (r"a\\Z", r"/a\Z"), + (r"a\\\Z", r"/a\\"), + (r"a\\\\Z", r"/a\\\Z"), + (r"a\\\\\Z", r"/a\\\\"), + (r"a\\\\\\Z", r"/a\\\\\Z"), # Escaped mixed metacharacters. 
- (r'^a\?$', r'/a?'), - (r'^a\\?$', r'/a\\'), - (r'^a\\\?$', r'/a\\?'), - (r'^a\\\\?$', r'/a\\\\'), - (r'^a\\\\\?$', r'/a\\\\?'), + (r"^a\?$", r"/a?"), + (r"^a\\?$", r"/a\\"), + (r"^a\\\?$", r"/a\\?"), + (r"^a\\\\?$", r"/a\\\\"), + (r"^a\\\\\?$", r"/a\\\\?"), # Adjacent escaped metacharacters. - (r'^a\?\$', r'/a?$'), - (r'^a\\?\\$', r'/a\\\\'), - (r'^a\\\?\\\$', r'/a\\?\\$'), - (r'^a\\\\?\\\\$', r'/a\\\\\\\\'), - (r'^a\\\\\?\\\\\$', r'/a\\\\?\\\\$'), + (r"^a\?\$", r"/a?$"), + (r"^a\\?\\$", r"/a\\\\"), + (r"^a\\\?\\\$", r"/a\\?\\$"), + (r"^a\\\\?\\\\$", r"/a\\\\\\\\"), + (r"^a\\\\\?\\\\\$", r"/a\\\\?\\\\$"), # Complex examples with metacharacters and (un)named groups. - (r'^\b(?P\w+)\B/(\w+)?', '//'), - (r'^\A(?P\w+)\Z', '/'), + (r"^\b(?P\w+)\B/(\w+)?", "//"), + (r"^\A(?P\w+)\Z", "/"), ) for pattern, output in tests: with self.subTest(pattern=pattern): diff --git a/tests/admin_docs/tests.py b/tests/admin_docs/tests.py index d53cb80c94..3ab2970b90 100644 --- a/tests/admin_docs/tests.py +++ b/tests/admin_docs/tests.py @@ -1,23 +1,22 @@ from django.contrib.auth.models import User -from django.test import ( - SimpleTestCase, TestCase, modify_settings, override_settings, -) +from django.test import SimpleTestCase, TestCase, modify_settings, override_settings class TestDataMixin: - @classmethod def setUpTestData(cls): - cls.superuser = User.objects.create_superuser(username='super', password='secret', email='super@example.com') + cls.superuser = User.objects.create_superuser( + username="super", password="secret", email="super@example.com" + ) -@override_settings(ROOT_URLCONF='admin_docs.urls') -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.admindocs'}) +@override_settings(ROOT_URLCONF="admin_docs.urls") +@modify_settings(INSTALLED_APPS={"append": "django.contrib.admindocs"}) class AdminDocsSimpleTestCase(SimpleTestCase): pass -@override_settings(ROOT_URLCONF='admin_docs.urls') -@modify_settings(INSTALLED_APPS={'append': 'django.contrib.admindocs'}) 
+@override_settings(ROOT_URLCONF="admin_docs.urls") +@modify_settings(INSTALLED_APPS={"append": "django.contrib.admindocs"}) class AdminDocsTestCase(TestCase): pass diff --git a/tests/admin_docs/urls.py b/tests/admin_docs/urls.py index f535afc9f2..de23d9baf5 100644 --- a/tests/admin_docs/urls.py +++ b/tests/admin_docs/urls.py @@ -3,16 +3,19 @@ from django.urls import include, path from . import views -ns_patterns = ([ - path('xview/func/', views.xview_dec(views.xview), name='func'), -], 'test') +ns_patterns = ( + [ + path("xview/func/", views.xview_dec(views.xview), name="func"), + ], + "test", +) urlpatterns = [ - path('admin/', admin.site.urls), - path('admindocs/', include('django.contrib.admindocs.urls')), - path('', include(ns_patterns, namespace='test')), - path('xview/func/', views.xview_dec(views.xview)), - path('xview/class/', views.xview_dec(views.XViewClass.as_view())), - path('xview/callable_object/', views.xview_dec(views.XViewCallableObject())), - path('xview/callable_object_without_xview/', views.XViewCallableObject()), + path("admin/", admin.site.urls), + path("admindocs/", include("django.contrib.admindocs.urls")), + path("", include(ns_patterns, namespace="test")), + path("xview/func/", views.xview_dec(views.xview)), + path("xview/class/", views.xview_dec(views.XViewClass.as_view())), + path("xview/callable_object/", views.xview_dec(views.XViewCallableObject())), + path("xview/callable_object_without_xview/", views.XViewCallableObject()), ] diff --git a/tests/admin_filters/models.py b/tests/admin_filters/models.py index 90b9cab2ac..53b471dd90 100644 --- a/tests/admin_filters/models.py +++ b/tests/admin_filters/models.py @@ -1,7 +1,5 @@ from django.contrib.auth.models import User -from django.contrib.contenttypes.fields import ( - GenericForeignKey, GenericRelation, -) +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType from django.db import models @@ -13,30 
+11,35 @@ class Book(models.Model): User, models.SET_NULL, verbose_name="Verbose Author", - related_name='books_authored', - blank=True, null=True, + related_name="books_authored", + blank=True, + null=True, ) contributors = models.ManyToManyField( User, verbose_name="Verbose Contributors", - related_name='books_contributed', + related_name="books_contributed", blank=True, ) employee = models.ForeignKey( - 'Employee', + "Employee", models.SET_NULL, - verbose_name='Employee', - blank=True, null=True, + verbose_name="Employee", + blank=True, + null=True, ) is_best_seller = models.BooleanField(default=0, null=True) date_registered = models.DateField(null=True) - availability = models.BooleanField(choices=( - (False, 'Paid'), - (True, 'Free'), - (None, 'Obscure'), - ), null=True) + availability = models.BooleanField( + choices=( + (False, "Paid"), + (True, "Free"), + (None, "Obscure"), + ), + null=True, + ) # This field name is intentionally 2 characters long (#16080). - no = models.IntegerField(verbose_name='number', blank=True, null=True) + no = models.IntegerField(verbose_name="number", blank=True, null=True) def __str__(self): return self.title @@ -64,9 +67,11 @@ class Employee(models.Model): class TaggedItem(models.Model): tag = models.SlugField() - content_type = models.ForeignKey(ContentType, models.CASCADE, related_name='tagged_items') + content_type = models.ForeignKey( + ContentType, models.CASCADE, related_name="tagged_items" + ) object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') + content_object = GenericForeignKey("content_type", "object_id") def __str__(self): return self.tag @@ -77,11 +82,13 @@ class Bookmark(models.Model): tags = GenericRelation(TaggedItem) CHOICES = [ - ('a', 'A'), - (None, 'None'), - ('', '-'), + ("a", "A"), + (None, "None"), + ("", "-"), ] - none_or_null = models.CharField(max_length=20, choices=CHOICES, blank=True, null=True) + none_or_null = models.CharField( + max_length=20, 
choices=CHOICES, blank=True, null=True + ) def __str__(self): return self.url diff --git a/tests/admin_filters/tests.py b/tests/admin_filters/tests.py index 4f1ff26048..14cfb757b5 100644 --- a/tests/admin_filters/tests.py +++ b/tests/admin_filters/tests.py @@ -3,8 +3,13 @@ import sys import unittest from django.contrib.admin import ( - AllValuesFieldListFilter, BooleanFieldListFilter, EmptyFieldListFilter, - FieldListFilter, ModelAdmin, RelatedOnlyFieldListFilter, SimpleListFilter, + AllValuesFieldListFilter, + BooleanFieldListFilter, + EmptyFieldListFilter, + FieldListFilter, + ModelAdmin, + RelatedOnlyFieldListFilter, + SimpleListFilter, site, ) from django.contrib.admin.options import IncorrectLookupParameters @@ -13,9 +18,7 @@ from django.contrib.auth.models import User from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase, override_settings -from .models import ( - Book, Bookmark, Department, Employee, ImprovedBook, TaggedItem, -) +from .models import Book, Bookmark, Department, Employee, ImprovedBook, TaggedItem def select_by(dictlist, key, value): @@ -23,22 +26,21 @@ def select_by(dictlist, key, value): class DecadeListFilter(SimpleListFilter): - def lookups(self, request, model_admin): return ( - ('the 80s', "the 1980's"), - ('the 90s', "the 1990's"), - ('the 00s', "the 2000's"), - ('other', "other decades"), + ("the 80s", "the 1980's"), + ("the 90s", "the 1990's"), + ("the 00s", "the 2000's"), + ("other", "other decades"), ) def queryset(self, request, queryset): decade = self.value() - if decade == 'the 80s': + if decade == "the 80s": return queryset.filter(year__gte=1980, year__lte=1989) - if decade == 'the 90s': + if decade == "the 90s": return queryset.filter(year__gte=1990, year__lte=1999) - if decade == 'the 00s': + if decade == "the 00s": return queryset.filter(year__gte=2000, year__lte=2009) @@ -47,100 +49,103 @@ class NotNinetiesListFilter(SimpleListFilter): parameter_name = "book_year" def 
lookups(self, request, model_admin): - return ( - ('the 90s', "the 1990's"), - ) + return (("the 90s", "the 1990's"),) def queryset(self, request, queryset): - if self.value() == 'the 90s': + if self.value() == "the 90s": return queryset.filter(year__gte=1990, year__lte=1999) else: return queryset.exclude(year__gte=1990, year__lte=1999) class DecadeListFilterWithTitleAndParameter(DecadeListFilter): - title = 'publication decade' - parameter_name = 'publication-decade' + title = "publication decade" + parameter_name = "publication-decade" class DecadeListFilterWithoutTitle(DecadeListFilter): - parameter_name = 'publication-decade' + parameter_name = "publication-decade" class DecadeListFilterWithoutParameter(DecadeListFilter): - title = 'publication decade' + title = "publication decade" class DecadeListFilterWithNoneReturningLookups(DecadeListFilterWithTitleAndParameter): - def lookups(self, request, model_admin): pass class DecadeListFilterWithFailingQueryset(DecadeListFilterWithTitleAndParameter): - def queryset(self, request, queryset): raise 1 / 0 class DecadeListFilterWithQuerysetBasedLookups(DecadeListFilterWithTitleAndParameter): - def lookups(self, request, model_admin): qs = model_admin.get_queryset(request) if qs.filter(year__gte=1980, year__lte=1989).exists(): - yield ('the 80s', "the 1980's") + yield ("the 80s", "the 1980's") if qs.filter(year__gte=1990, year__lte=1999).exists(): - yield ('the 90s', "the 1990's") + yield ("the 90s", "the 1990's") if qs.filter(year__gte=2000, year__lte=2009).exists(): - yield ('the 00s', "the 2000's") + yield ("the 00s", "the 2000's") class DecadeListFilterParameterEndsWith__In(DecadeListFilter): - title = 'publication decade' - parameter_name = 'decade__in' # Ends with '__in" + title = "publication decade" + parameter_name = "decade__in" # Ends with '__in" class DecadeListFilterParameterEndsWith__Isnull(DecadeListFilter): - title = 'publication decade' - parameter_name = 'decade__isnull' # Ends with '__isnull" + title = 
"publication decade" + parameter_name = "decade__isnull" # Ends with '__isnull" class DepartmentListFilterLookupWithNonStringValue(SimpleListFilter): - title = 'department' - parameter_name = 'department' + title = "department" + parameter_name = "department" def lookups(self, request, model_admin): - return sorted({ - (employee.department.id, # Intentionally not a string (Refs #19318) - employee.department.code) - for employee in model_admin.get_queryset(request).all() - }) + return sorted( + { + ( + employee.department.id, # Intentionally not a string (Refs #19318) + employee.department.code, + ) + for employee in model_admin.get_queryset(request).all() + } + ) def queryset(self, request, queryset): if self.value(): return queryset.filter(department__id=self.value()) -class DepartmentListFilterLookupWithUnderscoredParameter(DepartmentListFilterLookupWithNonStringValue): - parameter_name = 'department__whatever' +class DepartmentListFilterLookupWithUnderscoredParameter( + DepartmentListFilterLookupWithNonStringValue +): + parameter_name = "department__whatever" class DepartmentListFilterLookupWithDynamicValue(DecadeListFilterWithTitleAndParameter): - def lookups(self, request, model_admin): - if self.value() == 'the 80s': - return (('the 90s', "the 1990's"),) - elif self.value() == 'the 90s': - return (('the 80s', "the 1980's"),) + if self.value() == "the 80s": + return (("the 90s", "the 1990's"),) + elif self.value() == "the 90s": + return (("the 80s", "the 1980's"),) else: - return (('the 80s', "the 1980's"), ('the 90s', "the 1990's"),) + return ( + ("the 80s", "the 1980's"), + ("the 90s", "the 1990's"), + ) class EmployeeNameCustomDividerFilter(FieldListFilter): - list_separator = '|' + list_separator = "|" def __init__(self, field, request, params, model, model_admin, field_path): - self.lookup_kwarg = '%s__in' % field_path + self.lookup_kwarg = "%s__in" % field_path super().__init__(field, request, params, model, model_admin, field_path) def 
expected_parameters(self): @@ -148,44 +153,51 @@ class EmployeeNameCustomDividerFilter(FieldListFilter): class CustomUserAdmin(UserAdmin): - list_filter = ('books_authored', 'books_contributed') + list_filter = ("books_authored", "books_contributed") class BookAdmin(ModelAdmin): - list_filter = ('year', 'author', 'contributors', 'is_best_seller', 'date_registered', 'no', 'availability') - ordering = ('-id',) + list_filter = ( + "year", + "author", + "contributors", + "is_best_seller", + "date_registered", + "no", + "availability", + ) + ordering = ("-id",) class BookAdminWithTupleBooleanFilter(BookAdmin): list_filter = ( - 'year', - 'author', - 'contributors', - ('is_best_seller', BooleanFieldListFilter), - 'date_registered', - 'no', - ('availability', BooleanFieldListFilter), + "year", + "author", + "contributors", + ("is_best_seller", BooleanFieldListFilter), + "date_registered", + "no", + ("availability", BooleanFieldListFilter), ) class BookAdminWithUnderscoreLookupAndTuple(BookAdmin): list_filter = ( - 'year', - ('author__email', AllValuesFieldListFilter), - 'contributors', - 'is_best_seller', - 'date_registered', - 'no', + "year", + ("author__email", AllValuesFieldListFilter), + "contributors", + "is_best_seller", + "date_registered", + "no", ) class BookAdminWithCustomQueryset(ModelAdmin): - def __init__(self, user, *args, **kwargs): self.user = user super().__init__(*args, **kwargs) - list_filter = ('year',) + list_filter = ("year",) def get_queryset(self, request): return super().get_queryset(request).filter(author=self.user) @@ -193,17 +205,20 @@ class BookAdminWithCustomQueryset(ModelAdmin): class BookAdminRelatedOnlyFilter(ModelAdmin): list_filter = ( - 'year', 'is_best_seller', 'date_registered', 'no', - ('author', RelatedOnlyFieldListFilter), - ('contributors', RelatedOnlyFieldListFilter), - ('employee__department', RelatedOnlyFieldListFilter), + "year", + "is_best_seller", + "date_registered", + "no", + ("author", RelatedOnlyFieldListFilter), + 
("contributors", RelatedOnlyFieldListFilter), + ("employee__department", RelatedOnlyFieldListFilter), ) - ordering = ('-id',) + ordering = ("-id",) class DecadeFilterBookAdmin(ModelAdmin): - list_filter = ('author', DecadeListFilterWithTitleAndParameter) - ordering = ('-id',) + list_filter = ("author", DecadeListFilterWithTitleAndParameter) + ordering = ("-id",) class NotNinetiesListFilterAdmin(ModelAdmin): @@ -239,13 +254,13 @@ class DecadeFilterBookAdminParameterEndsWith__Isnull(ModelAdmin): class EmployeeAdmin(ModelAdmin): - list_display = ['name', 'department'] - list_filter = ['department'] + list_display = ["name", "department"] + list_filter = ["department"] class EmployeeCustomDividerFilterAdmin(EmployeeAdmin): list_filter = [ - ('name', EmployeeNameCustomDividerFilter), + ("name", EmployeeNameCustomDividerFilter), ] @@ -262,21 +277,21 @@ class DepartmentFilterDynamicValueBookAdmin(EmployeeAdmin): class BookmarkAdminGenericRelation(ModelAdmin): - list_filter = ['tags__tag'] + list_filter = ["tags__tag"] class BookAdminWithEmptyFieldListFilter(ModelAdmin): list_filter = [ - ('author', EmptyFieldListFilter), - ('title', EmptyFieldListFilter), - ('improvedbook', EmptyFieldListFilter), + ("author", EmptyFieldListFilter), + ("title", EmptyFieldListFilter), + ("improvedbook", EmptyFieldListFilter), ] class DepartmentAdminWithEmptyFieldListFilter(ModelAdmin): list_filter = [ - ('description', EmptyFieldListFilter), - ('employee', EmptyFieldListFilter), + ("description", EmptyFieldListFilter), + ("employee", EmptyFieldListFilter), ] @@ -295,68 +310,83 @@ class ListFiltersTests(TestCase): cls.next_year = cls.today.replace(year=cls.today.year + 1, month=1, day=1) # Users - cls.alfred = User.objects.create_superuser('alfred', 'alfred@example.com', 'password') - cls.bob = User.objects.create_user('bob', 'bob@example.com') - cls.lisa = User.objects.create_user('lisa', 'lisa@example.com') + cls.alfred = User.objects.create_superuser( + "alfred", "alfred@example.com", 
"password" + ) + cls.bob = User.objects.create_user("bob", "bob@example.com") + cls.lisa = User.objects.create_user("lisa", "lisa@example.com") # Books cls.djangonaut_book = Book.objects.create( - title='Djangonaut: an art of living', year=2009, - author=cls.alfred, is_best_seller=True, date_registered=cls.today, + title="Djangonaut: an art of living", + year=2009, + author=cls.alfred, + is_best_seller=True, + date_registered=cls.today, availability=True, ) cls.bio_book = Book.objects.create( - title='Django: a biography', year=1999, author=cls.alfred, - is_best_seller=False, no=207, + title="Django: a biography", + year=1999, + author=cls.alfred, + is_best_seller=False, + no=207, availability=False, ) cls.django_book = Book.objects.create( - title='The Django Book', year=None, author=cls.bob, - is_best_seller=None, date_registered=cls.today, no=103, + title="The Django Book", + year=None, + author=cls.bob, + is_best_seller=None, + date_registered=cls.today, + no=103, availability=True, ) cls.guitar_book = Book.objects.create( - title='Guitar for dummies', year=2002, is_best_seller=True, + title="Guitar for dummies", + year=2002, + is_best_seller=True, date_registered=cls.one_week_ago, availability=None, ) cls.guitar_book.contributors.set([cls.bob, cls.lisa]) # Departments - cls.dev = Department.objects.create(code='DEV', description='Development') - cls.design = Department.objects.create(code='DSN', description='Design') + cls.dev = Department.objects.create(code="DEV", description="Development") + cls.design = Department.objects.create(code="DSN", description="Design") # Employees - cls.john = Employee.objects.create(name='John Blue', department=cls.dev) - cls.jack = Employee.objects.create(name='Jack Red', department=cls.design) + cls.john = Employee.objects.create(name="John Blue", department=cls.dev) + cls.jack = Employee.objects.create(name="Jack Red", department=cls.design) def test_choicesfieldlistfilter_has_none_choice(self): """ The last choice is for the 
None value. """ + class BookmarkChoicesAdmin(ModelAdmin): - list_display = ['none_or_null'] - list_filter = ['none_or_null'] + list_display = ["none_or_null"] + list_filter = ["none_or_null"] modeladmin = BookmarkChoicesAdmin(Bookmark, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[-1]['display'], 'None') - self.assertEqual(choices[-1]['query_string'], '?none_or_null__isnull=True') + self.assertEqual(choices[-1]["display"], "None") + self.assertEqual(choices[-1]["query_string"], "?none_or_null__isnull=True") def test_datefieldlistfilter(self): modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist(request) - request = self.request_factory.get('/', { - 'date_registered__gte': self.today, - 'date_registered__lt': self.tomorrow}, + request = self.request_factory.get( + "/", + {"date_registered__gte": self.today, "date_registered__lt": self.tomorrow}, ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -367,48 +397,62 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') + self.assertEqual(filterspec.title, "date registered") choice = select_by(filterspec.choices(changelist), "display", "Today") - self.assertIs(choice['selected'], True) + self.assertIs(choice["selected"], True) self.assertEqual( - choice['query_string'], - '?date_registered__gte=%s&date_registered__lt=%s' % ( + choice["query_string"], + "?date_registered__gte=%s&date_registered__lt=%s" + % ( self.today, self.tomorrow, - ) + ), ) - request = 
self.request_factory.get('/', { - 'date_registered__gte': self.today.replace(day=1), - 'date_registered__lt': self.next_month}, + request = self.request_factory.get( + "/", + { + "date_registered__gte": self.today.replace(day=1), + "date_registered__lt": self.next_month, + }, ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned queryset = changelist.get_queryset(request) - if (self.today.year, self.today.month) == (self.one_week_ago.year, self.one_week_ago.month): + if (self.today.year, self.today.month) == ( + self.one_week_ago.year, + self.one_week_ago.month, + ): # In case one week ago is in the same month. - self.assertEqual(list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book]) + self.assertEqual( + list(queryset), + [self.guitar_book, self.django_book, self.djangonaut_book], + ) else: self.assertEqual(list(queryset), [self.django_book, self.djangonaut_book]) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') + self.assertEqual(filterspec.title, "date registered") choice = select_by(filterspec.choices(changelist), "display", "This month") - self.assertIs(choice['selected'], True) + self.assertIs(choice["selected"], True) self.assertEqual( - choice['query_string'], - '?date_registered__gte=%s&date_registered__lt=%s' % ( + choice["query_string"], + "?date_registered__gte=%s&date_registered__lt=%s" + % ( self.today.replace(day=1), self.next_month, - ) + ), ) - request = self.request_factory.get('/', { - 'date_registered__gte': self.today.replace(month=1, day=1), - 'date_registered__lt': self.next_year}, + request = self.request_factory.get( + "/", + { + "date_registered__gte": self.today.replace(month=1, day=1), + "date_registered__lt": self.next_year, + }, ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -417,49 +461,59 @@ class 
ListFiltersTests(TestCase): queryset = changelist.get_queryset(request) if self.today.year == self.one_week_ago.year: # In case one week ago is in the same year. - self.assertEqual(list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book]) + self.assertEqual( + list(queryset), + [self.guitar_book, self.django_book, self.djangonaut_book], + ) else: self.assertEqual(list(queryset), [self.django_book, self.djangonaut_book]) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') + self.assertEqual(filterspec.title, "date registered") choice = select_by(filterspec.choices(changelist), "display", "This year") - self.assertIs(choice['selected'], True) + self.assertIs(choice["selected"], True) self.assertEqual( - choice['query_string'], - '?date_registered__gte=%s&date_registered__lt=%s' % ( + choice["query_string"], + "?date_registered__gte=%s&date_registered__lt=%s" + % ( self.today.replace(month=1, day=1), self.next_year, - ) + ), ) - request = self.request_factory.get('/', { - 'date_registered__gte': str(self.one_week_ago), - 'date_registered__lt': str(self.tomorrow), - }) + request = self.request_factory.get( + "/", + { + "date_registered__gte": str(self.one_week_ago), + "date_registered__lt": str(self.tomorrow), + }, + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned queryset = changelist.get_queryset(request) - self.assertEqual(list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book]) + self.assertEqual( + list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book] + ) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') + self.assertEqual(filterspec.title, "date registered") choice = select_by(filterspec.choices(changelist), "display", "Past 7 
days") - self.assertIs(choice['selected'], True) + self.assertIs(choice["selected"], True) self.assertEqual( - choice['query_string'], - '?date_registered__gte=%s&date_registered__lt=%s' % ( + choice["query_string"], + "?date_registered__gte=%s&date_registered__lt=%s" + % ( str(self.one_week_ago), str(self.tomorrow), - ) + ), ) # Null/not null queries - request = self.request_factory.get('/', {'date_registered__isnull': 'True'}) + request = self.request_factory.get("/", {"date_registered__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -470,31 +524,33 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') - choice = select_by(filterspec.choices(changelist), 'display', 'No date') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?date_registered__isnull=True') + self.assertEqual(filterspec.title, "date registered") + choice = select_by(filterspec.choices(changelist), "display", "No date") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?date_registered__isnull=True") - request = self.request_factory.get('/', {'date_registered__isnull': 'False'}) + request = self.request_factory.get("/", {"date_registered__isnull": "False"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned queryset = changelist.get_queryset(request) self.assertEqual(queryset.count(), 3) - self.assertEqual(list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book]) + self.assertEqual( + list(queryset), [self.guitar_book, self.django_book, self.djangonaut_book] + ) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][4] - self.assertEqual(filterspec.title, 'date registered') - choice = 
select_by(filterspec.choices(changelist), 'display', 'Has date') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?date_registered__isnull=False') + self.assertEqual(filterspec.title, "date registered") + choice = select_by(filterspec.choices(changelist), "display", "Has date") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?date_registered__isnull=False") @unittest.skipIf( - sys.platform == 'win32', + sys.platform == "win32", "Windows doesn't support setting a timezone that differs from the " - "system timezone." + "system timezone.", ) @override_settings(USE_TZ=True) def test_datefieldlistfilter_with_time_zone_support(self): @@ -504,7 +560,7 @@ class ListFiltersTests(TestCase): def test_allvaluesfieldlistfilter(self): modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/', {'year__isnull': 'True'}) + request = self.request_factory.get("/", {"year__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -514,26 +570,26 @@ class ListFiltersTests(TestCase): # Make sure the last choice is None and is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'year') + self.assertEqual(filterspec.title, "year") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[-1]['selected'], True) - self.assertEqual(choices[-1]['query_string'], '?year__isnull=True') + self.assertIs(choices[-1]["selected"], True) + self.assertEqual(choices[-1]["query_string"], "?year__isnull=True") - request = self.request_factory.get('/', {'year': '2002'}) + request = self.request_factory.get("/", {"year": "2002"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'year') + self.assertEqual(filterspec.title, "year") choices = 
list(filterspec.choices(changelist)) - self.assertIs(choices[2]['selected'], True) - self.assertEqual(choices[2]['query_string'], '?year=2002') + self.assertIs(choices[2]["selected"], True) + self.assertEqual(choices[2]["query_string"], "?year=2002") def test_allvaluesfieldlistfilter_custom_qs(self): # Make sure that correct filters are returned with custom querysets modeladmin = BookAdminWithCustomQueryset(self.alfred, Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -543,23 +599,27 @@ class ListFiltersTests(TestCase): # books written by alfred (which is the filtering criteria set by # BookAdminWithCustomQueryset.get_queryset()) self.assertEqual(3, len(choices)) - self.assertEqual(choices[0]['query_string'], '?') - self.assertEqual(choices[1]['query_string'], '?year=1999') - self.assertEqual(choices[2]['query_string'], '?year=2009') + self.assertEqual(choices[0]["query_string"], "?") + self.assertEqual(choices[1]["query_string"], "?year=1999") + self.assertEqual(choices[2]["query_string"], "?year=2009") def test_relatedfieldlistfilter_foreignkey(self): modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure that all users are present in the author's list filter filterspec = changelist.get_filters(request)[0][1] - expected = [(self.alfred.pk, 'alfred'), (self.bob.pk, 'bob'), (self.lisa.pk, 'lisa')] + expected = [ + (self.alfred.pk, "alfred"), + (self.bob.pk, "bob"), + (self.lisa.pk, "lisa"), + ] self.assertEqual(sorted(filterspec.lookup_choices), sorted(expected)) - request = self.request_factory.get('/', {'author__isnull': 'True'}) + request = self.request_factory.get("/", {"author__isnull": "True"}) request.user = self.alfred changelist = 
modeladmin.get_changelist_instance(request) @@ -569,119 +629,133 @@ class ListFiltersTests(TestCase): # Make sure the last choice is None and is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'Verbose Author') + self.assertEqual(filterspec.title, "Verbose Author") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[-1]['selected'], True) - self.assertEqual(choices[-1]['query_string'], '?author__isnull=True') + self.assertIs(choices[-1]["selected"], True) + self.assertEqual(choices[-1]["query_string"], "?author__isnull=True") - request = self.request_factory.get('/', {'author__id__exact': self.alfred.pk}) + request = self.request_factory.get("/", {"author__id__exact": self.alfred.pk}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'Verbose Author') + self.assertEqual(filterspec.title, "Verbose Author") # order of choices depends on User model, which has no order choice = select_by(filterspec.choices(changelist), "display", "alfred") - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?author__id__exact=%d' % self.alfred.pk) + self.assertIs(choice["selected"], True) + self.assertEqual( + choice["query_string"], "?author__id__exact=%d" % self.alfred.pk + ) def test_relatedfieldlistfilter_foreignkey_ordering(self): """RelatedFieldListFilter ordering respects ModelAdmin.ordering.""" + class EmployeeAdminWithOrdering(ModelAdmin): - ordering = ('name',) + ordering = ("name",) class BookAdmin(ModelAdmin): - list_filter = ('employee',) + list_filter = ("employee",) site.register(Employee, EmployeeAdminWithOrdering) self.addCleanup(lambda: site.unregister(Employee)) modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred 
changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - expected = [(self.jack.pk, 'Jack Red'), (self.john.pk, 'John Blue')] + expected = [(self.jack.pk, "Jack Red"), (self.john.pk, "John Blue")] self.assertEqual(filterspec.lookup_choices, expected) def test_relatedfieldlistfilter_foreignkey_ordering_reverse(self): class EmployeeAdminWithOrdering(ModelAdmin): - ordering = ('-name',) + ordering = ("-name",) class BookAdmin(ModelAdmin): - list_filter = ('employee',) + list_filter = ("employee",) site.register(Employee, EmployeeAdminWithOrdering) self.addCleanup(lambda: site.unregister(Employee)) modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - expected = [(self.john.pk, 'John Blue'), (self.jack.pk, 'Jack Red')] + expected = [(self.john.pk, "John Blue"), (self.jack.pk, "Jack Red")] self.assertEqual(filterspec.lookup_choices, expected) def test_relatedfieldlistfilter_foreignkey_default_ordering(self): """RelatedFieldListFilter ordering respects Model.ordering.""" - class BookAdmin(ModelAdmin): - list_filter = ('employee',) - self.addCleanup(setattr, Employee._meta, 'ordering', Employee._meta.ordering) - Employee._meta.ordering = ('name',) + class BookAdmin(ModelAdmin): + list_filter = ("employee",) + + self.addCleanup(setattr, Employee._meta, "ordering", Employee._meta.ordering) + Employee._meta.ordering = ("name",) modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - expected = [(self.jack.pk, 'Jack Red'), (self.john.pk, 'John Blue')] + expected = [(self.jack.pk, "Jack Red"), (self.john.pk, "John Blue")] 
self.assertEqual(filterspec.lookup_choices, expected) def test_relatedfieldlistfilter_manytomany(self): modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure that all users are present in the contrib's list filter filterspec = changelist.get_filters(request)[0][2] - expected = [(self.alfred.pk, 'alfred'), (self.bob.pk, 'bob'), (self.lisa.pk, 'lisa')] + expected = [ + (self.alfred.pk, "alfred"), + (self.bob.pk, "bob"), + (self.lisa.pk, "lisa"), + ] self.assertEqual(sorted(filterspec.lookup_choices), sorted(expected)) - request = self.request_factory.get('/', {'contributors__isnull': 'True'}) + request = self.request_factory.get("/", {"contributors__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned queryset = changelist.get_queryset(request) - self.assertEqual(list(queryset), [self.django_book, self.bio_book, self.djangonaut_book]) + self.assertEqual( + list(queryset), [self.django_book, self.bio_book, self.djangonaut_book] + ) # Make sure the last choice is None and is selected filterspec = changelist.get_filters(request)[0][2] - self.assertEqual(filterspec.title, 'Verbose Contributors') + self.assertEqual(filterspec.title, "Verbose Contributors") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[-1]['selected'], True) - self.assertEqual(choices[-1]['query_string'], '?contributors__isnull=True') + self.assertIs(choices[-1]["selected"], True) + self.assertEqual(choices[-1]["query_string"], "?contributors__isnull=True") - request = self.request_factory.get('/', {'contributors__id__exact': self.bob.pk}) + request = self.request_factory.get( + "/", {"contributors__id__exact": self.bob.pk} + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the 
correct choice is selected filterspec = changelist.get_filters(request)[0][2] - self.assertEqual(filterspec.title, 'Verbose Contributors') + self.assertEqual(filterspec.title, "Verbose Contributors") choice = select_by(filterspec.choices(changelist), "display", "bob") - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?contributors__id__exact=%d' % self.bob.pk) + self.assertIs(choice["selected"], True) + self.assertEqual( + choice["query_string"], "?contributors__id__exact=%d" % self.bob.pk + ) def test_relatedfieldlistfilter_reverse_relationships(self): modeladmin = CustomUserAdmin(User, site) # FK relationship ----- - request = self.request_factory.get('/', {'books_authored__isnull': 'True'}) + request = self.request_factory.get("/", {"books_authored__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -691,24 +765,30 @@ class ListFiltersTests(TestCase): # Make sure the last choice is None and is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'book') + self.assertEqual(filterspec.title, "book") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[-1]['selected'], True) - self.assertEqual(choices[-1]['query_string'], '?books_authored__isnull=True') + self.assertIs(choices[-1]["selected"], True) + self.assertEqual(choices[-1]["query_string"], "?books_authored__isnull=True") - request = self.request_factory.get('/', {'books_authored__id__exact': self.bio_book.pk}) + request = self.request_factory.get( + "/", {"books_authored__id__exact": self.bio_book.pk} + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'book') - choice = select_by(filterspec.choices(changelist), "display", self.bio_book.title) - self.assertIs(choice['selected'], True) - 
self.assertEqual(choice['query_string'], '?books_authored__id__exact=%d' % self.bio_book.pk) + self.assertEqual(filterspec.title, "book") + choice = select_by( + filterspec.choices(changelist), "display", self.bio_book.title + ) + self.assertIs(choice["selected"], True) + self.assertEqual( + choice["query_string"], "?books_authored__id__exact=%d" % self.bio_book.pk + ) # M2M relationship ----- - request = self.request_factory.get('/', {'books_contributed__isnull': 'True'}) + request = self.request_factory.get("/", {"books_contributed__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -718,21 +798,28 @@ class ListFiltersTests(TestCase): # Make sure the last choice is None and is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'book') + self.assertEqual(filterspec.title, "book") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[-1]['selected'], True) - self.assertEqual(choices[-1]['query_string'], '?books_contributed__isnull=True') + self.assertIs(choices[-1]["selected"], True) + self.assertEqual(choices[-1]["query_string"], "?books_contributed__isnull=True") - request = self.request_factory.get('/', {'books_contributed__id__exact': self.django_book.pk}) + request = self.request_factory.get( + "/", {"books_contributed__id__exact": self.django_book.pk} + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'book') - choice = select_by(filterspec.choices(changelist), "display", self.django_book.title) - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?books_contributed__id__exact=%d' % self.django_book.pk) + self.assertEqual(filterspec.title, "book") + choice = select_by( + filterspec.choices(changelist), "display", self.django_book.title + ) + 
self.assertIs(choice["selected"], True) + self.assertEqual( + choice["query_string"], + "?books_contributed__id__exact=%d" % self.django_book.pk, + ) # With one book, the list filter should appear because there is also a # (None) option. @@ -745,39 +832,37 @@ class ListFiltersTests(TestCase): self.assertEqual(len(filterspec), 0) def test_relatedfieldlistfilter_reverse_relationships_default_ordering(self): - self.addCleanup(setattr, Book._meta, 'ordering', Book._meta.ordering) - Book._meta.ordering = ('title',) + self.addCleanup(setattr, Book._meta, "ordering", Book._meta.ordering) + Book._meta.ordering = ("title",) modeladmin = CustomUserAdmin(User, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] expected = [ - (self.bio_book.pk, 'Django: a biography'), - (self.djangonaut_book.pk, 'Djangonaut: an art of living'), - (self.guitar_book.pk, 'Guitar for dummies'), - (self.django_book.pk, 'The Django Book') + (self.bio_book.pk, "Django: a biography"), + (self.djangonaut_book.pk, "Djangonaut: an art of living"), + (self.guitar_book.pk, "Guitar for dummies"), + (self.django_book.pk, "The Django Book"), ] self.assertEqual(filterspec.lookup_choices, expected) def test_relatedonlyfieldlistfilter_foreignkey(self): modeladmin = BookAdminRelatedOnlyFilter(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure that only actual authors are present in author's list filter filterspec = changelist.get_filters(request)[0][4] - expected = [(self.alfred.pk, 'alfred'), (self.bob.pk, 'bob')] + expected = [(self.alfred.pk, "alfred"), (self.bob.pk, "bob")] self.assertEqual(sorted(filterspec.lookup_choices), sorted(expected)) def 
test_relatedonlyfieldlistfilter_foreignkey_reverse_relationships(self): class EmployeeAdminReverseRelationship(ModelAdmin): - list_filter = ( - ('book', RelatedOnlyFieldListFilter), - ) + list_filter = (("book", RelatedOnlyFieldListFilter),) self.djangonaut_book.employee = self.john self.djangonaut_book.save() @@ -785,42 +870,42 @@ class ListFiltersTests(TestCase): self.django_book.save() modeladmin = EmployeeAdminReverseRelationship(Employee, site) - request = self.request_factory.get('/') - request.user = self.alfred - changelist = modeladmin.get_changelist_instance(request) - filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.lookup_choices, [ - (self.djangonaut_book.pk, 'Djangonaut: an art of living'), - (self.django_book.pk, 'The Django Book'), - ]) - - def test_relatedonlyfieldlistfilter_manytomany_reverse_relationships(self): - class UserAdminReverseRelationship(ModelAdmin): - list_filter = ( - ('books_contributed', RelatedOnlyFieldListFilter), - ) - - modeladmin = UserAdminReverseRelationship(User, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] self.assertEqual( filterspec.lookup_choices, - [(self.guitar_book.pk, 'Guitar for dummies')], + [ + (self.djangonaut_book.pk, "Djangonaut: an art of living"), + (self.django_book.pk, "The Django Book"), + ], + ) + + def test_relatedonlyfieldlistfilter_manytomany_reverse_relationships(self): + class UserAdminReverseRelationship(ModelAdmin): + list_filter = (("books_contributed", RelatedOnlyFieldListFilter),) + + modeladmin = UserAdminReverseRelationship(User, site) + request = self.request_factory.get("/") + request.user = self.alfred + changelist = modeladmin.get_changelist_instance(request) + filterspec = changelist.get_filters(request)[0][0] + self.assertEqual( + filterspec.lookup_choices, + 
[(self.guitar_book.pk, "Guitar for dummies")], ) def test_relatedonlyfieldlistfilter_foreignkey_ordering(self): """RelatedOnlyFieldListFilter ordering respects ModelAdmin.ordering.""" + class EmployeeAdminWithOrdering(ModelAdmin): - ordering = ('name',) + ordering = ("name",) class BookAdmin(ModelAdmin): - list_filter = ( - ('employee', RelatedOnlyFieldListFilter), - ) + list_filter = (("employee", RelatedOnlyFieldListFilter),) - albert = Employee.objects.create(name='Albert Green', department=self.dev) + albert = Employee.objects.create(name="Albert Green", department=self.dev) self.djangonaut_book.employee = albert self.djangonaut_book.save() self.bio_book.employee = self.jack @@ -830,46 +915,45 @@ class ListFiltersTests(TestCase): self.addCleanup(lambda: site.unregister(Employee)) modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - expected = [(albert.pk, 'Albert Green'), (self.jack.pk, 'Jack Red')] + expected = [(albert.pk, "Albert Green"), (self.jack.pk, "Jack Red")] self.assertEqual(filterspec.lookup_choices, expected) def test_relatedonlyfieldlistfilter_foreignkey_default_ordering(self): """RelatedOnlyFieldListFilter ordering respects Meta.ordering.""" - class BookAdmin(ModelAdmin): - list_filter = ( - ('employee', RelatedOnlyFieldListFilter), - ) - albert = Employee.objects.create(name='Albert Green', department=self.dev) + class BookAdmin(ModelAdmin): + list_filter = (("employee", RelatedOnlyFieldListFilter),) + + albert = Employee.objects.create(name="Albert Green", department=self.dev) self.djangonaut_book.employee = albert self.djangonaut_book.save() self.bio_book.employee = self.jack self.bio_book.save() - self.addCleanup(setattr, Employee._meta, 'ordering', Employee._meta.ordering) - Employee._meta.ordering = ('name',) + self.addCleanup(setattr, 
Employee._meta, "ordering", Employee._meta.ordering) + Employee._meta.ordering = ("name",) modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - expected = [(albert.pk, 'Albert Green'), (self.jack.pk, 'Jack Red')] + expected = [(albert.pk, "Albert Green"), (self.jack.pk, "Jack Red")] self.assertEqual(filterspec.lookup_choices, expected) def test_relatedonlyfieldlistfilter_underscorelookup_foreignkey(self): - Department.objects.create(code='TEST', description='Testing') + Department.objects.create(code="TEST", description="Testing") self.djangonaut_book.employee = self.john self.djangonaut_book.save() self.bio_book.employee = self.jack self.bio_book.save() modeladmin = BookAdminRelatedOnlyFilter(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -885,27 +969,27 @@ class ListFiltersTests(TestCase): def test_relatedonlyfieldlistfilter_manytomany(self): modeladmin = BookAdminRelatedOnlyFilter(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure that only actual contributors are present in contrib's list filter filterspec = changelist.get_filters(request)[0][5] - expected = [(self.bob.pk, 'bob'), (self.lisa.pk, 'lisa')] + expected = [(self.bob.pk, "bob"), (self.lisa.pk, "lisa")] self.assertEqual(sorted(filterspec.lookup_choices), sorted(expected)) def test_listfilter_genericrelation(self): - django_bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') - python_bookmark = Bookmark.objects.create(url='https://www.python.org/') - kernel_bookmark = 
Bookmark.objects.create(url='https://www.kernel.org/') + django_bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/") + python_bookmark = Bookmark.objects.create(url="https://www.python.org/") + kernel_bookmark = Bookmark.objects.create(url="https://www.kernel.org/") - TaggedItem.objects.create(content_object=django_bookmark, tag='python') - TaggedItem.objects.create(content_object=python_bookmark, tag='python') - TaggedItem.objects.create(content_object=kernel_bookmark, tag='linux') + TaggedItem.objects.create(content_object=django_bookmark, tag="python") + TaggedItem.objects.create(content_object=python_bookmark, tag="python") + TaggedItem.objects.create(content_object=kernel_bookmark, tag="linux") modeladmin = BookmarkAdminGenericRelation(Bookmark, site) - request = self.request_factory.get('/', {'tags__tag': 'python'}) + request = self.request_factory.get("/", {"tags__tag": "python"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) @@ -922,11 +1006,11 @@ class ListFiltersTests(TestCase): self.verify_booleanfieldlistfilter(modeladmin) def verify_booleanfieldlistfilter(self, modeladmin): - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) - request = self.request_factory.get('/', {'is_best_seller__exact': 0}) + request = self.request_factory.get("/", {"is_best_seller__exact": 0}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -936,12 +1020,12 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][3] - self.assertEqual(filterspec.title, 'is best seller') + self.assertEqual(filterspec.title, "is best seller") choice = select_by(filterspec.choices(changelist), "display", "No") - self.assertIs(choice['selected'], True) - 
self.assertEqual(choice['query_string'], '?is_best_seller__exact=0') + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?is_best_seller__exact=0") - request = self.request_factory.get('/', {'is_best_seller__exact': 1}) + request = self.request_factory.get("/", {"is_best_seller__exact": 1}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -951,12 +1035,12 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][3] - self.assertEqual(filterspec.title, 'is best seller') + self.assertEqual(filterspec.title, "is best seller") choice = select_by(filterspec.choices(changelist), "display", "Yes") - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?is_best_seller__exact=1') + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?is_best_seller__exact=1") - request = self.request_factory.get('/', {'is_best_seller__isnull': 'True'}) + request = self.request_factory.get("/", {"is_best_seller__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -966,10 +1050,10 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][3] - self.assertEqual(filterspec.title, 'is best seller') + self.assertEqual(filterspec.title, "is best seller") choice = select_by(filterspec.choices(changelist), "display", "Unknown") - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?is_best_seller__isnull=True') + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?is_best_seller__isnull=True") def test_booleanfieldlistfilter_choices(self): modeladmin = BookAdmin(Book, site) @@ -981,40 +1065,40 @@ class ListFiltersTests(TestCase): def verify_booleanfieldlistfilter_choices(self, modeladmin): # False. 
- request = self.request_factory.get('/', {'availability__exact': 0}) + request = self.request_factory.get("/", {"availability__exact": 0}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) self.assertEqual(list(queryset), [self.bio_book]) filterspec = changelist.get_filters(request)[0][6] - self.assertEqual(filterspec.title, 'availability') - choice = select_by(filterspec.choices(changelist), 'display', 'Paid') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?availability__exact=0') + self.assertEqual(filterspec.title, "availability") + choice = select_by(filterspec.choices(changelist), "display", "Paid") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?availability__exact=0") # True. - request = self.request_factory.get('/', {'availability__exact': 1}) + request = self.request_factory.get("/", {"availability__exact": 1}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) self.assertEqual(list(queryset), [self.django_book, self.djangonaut_book]) filterspec = changelist.get_filters(request)[0][6] - self.assertEqual(filterspec.title, 'availability') - choice = select_by(filterspec.choices(changelist), 'display', 'Free') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?availability__exact=1') + self.assertEqual(filterspec.title, "availability") + choice = select_by(filterspec.choices(changelist), "display", "Free") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?availability__exact=1") # None. 
- request = self.request_factory.get('/', {'availability__isnull': 'True'}) + request = self.request_factory.get("/", {"availability__isnull": "True"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) self.assertEqual(list(queryset), [self.guitar_book]) filterspec = changelist.get_filters(request)[0][6] - self.assertEqual(filterspec.title, 'availability') - choice = select_by(filterspec.choices(changelist), 'display', 'Obscure') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?availability__isnull=True') + self.assertEqual(filterspec.title, "availability") + choice = select_by(filterspec.choices(changelist), "display", "Obscure") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?availability__isnull=True") # All. - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) @@ -1023,10 +1107,10 @@ class ListFiltersTests(TestCase): [self.guitar_book, self.django_book, self.bio_book, self.djangonaut_book], ) filterspec = changelist.get_filters(request)[0][6] - self.assertEqual(filterspec.title, 'availability') - choice = select_by(filterspec.choices(changelist), 'display', 'All') - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?') + self.assertEqual(filterspec.title, "availability") + choice = select_by(filterspec.choices(changelist), "display", "All") + self.assertIs(choice["selected"], True) + self.assertEqual(choice["query_string"], "?") def test_fieldlistfilter_underscorelookup_tuple(self): """ @@ -1034,11 +1118,11 @@ class ListFiltersTests(TestCase): when fieldpath contains double underscore in value (#19182). 
""" modeladmin = BookAdminWithUnderscoreLookupAndTuple(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) - request = self.request_factory.get('/', {'author__email': 'alfred@example.com'}) + request = self.request_factory.get("/", {"author__email": "alfred@example.com"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1049,7 +1133,9 @@ class ListFiltersTests(TestCase): def test_fieldlistfilter_invalid_lookup_parameters(self): """Filtering by an invalid value.""" modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/', {'author__id__exact': 'StringNotInteger!'}) + request = self.request_factory.get( + "/", {"author__id__exact": "StringNotInteger!"} + ) request.user = self.alfred with self.assertRaises(IncorrectLookupParameters): modeladmin.get_changelist_instance(request) @@ -1058,24 +1144,24 @@ class ListFiltersTests(TestCase): modeladmin = DecadeFilterBookAdmin(Book, site) # Make sure that the first option is 'All' --------------------------- - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned queryset = changelist.get_queryset(request) - self.assertEqual(list(queryset), list(Book.objects.all().order_by('-id'))) + self.assertEqual(list(queryset), list(Book.objects.all().order_by("-id"))) # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[0]['display'], 'All') - self.assertIs(choices[0]['selected'], True) - self.assertEqual(choices[0]['query_string'], '?') + 
self.assertEqual(choices[0]["display"], "All") + self.assertIs(choices[0]["selected"], True) + self.assertEqual(choices[0]["query_string"], "?") # Look for books in the 1980s ---------------------------------------- - request = self.request_factory.get('/', {'publication-decade': 'the 80s'}) + request = self.request_factory.get("/", {"publication-decade": "the 80s"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1085,14 +1171,14 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[1]['display'], 'the 1980\'s') - self.assertIs(choices[1]['selected'], True) - self.assertEqual(choices[1]['query_string'], '?publication-decade=the+80s') + self.assertEqual(choices[1]["display"], "the 1980's") + self.assertIs(choices[1]["selected"], True) + self.assertEqual(choices[1]["query_string"], "?publication-decade=the+80s") # Look for books in the 1990s ---------------------------------------- - request = self.request_factory.get('/', {'publication-decade': 'the 90s'}) + request = self.request_factory.get("/", {"publication-decade": "the 90s"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1102,14 +1188,14 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[2]['display'], 'the 1990\'s') - self.assertIs(choices[2]['selected'], True) - self.assertEqual(choices[2]['query_string'], '?publication-decade=the+90s') + self.assertEqual(choices[2]["display"], "the 1990's") + 
self.assertIs(choices[2]["selected"], True) + self.assertEqual(choices[2]["query_string"], "?publication-decade=the+90s") # Look for books in the 2000s ---------------------------------------- - request = self.request_factory.get('/', {'publication-decade': 'the 00s'}) + request = self.request_factory.get("/", {"publication-decade": "the 00s"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1119,14 +1205,16 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[3]['display'], 'the 2000\'s') - self.assertIs(choices[3]['selected'], True) - self.assertEqual(choices[3]['query_string'], '?publication-decade=the+00s') + self.assertEqual(choices[3]["display"], "the 2000's") + self.assertIs(choices[3]["selected"], True) + self.assertEqual(choices[3]["query_string"], "?publication-decade=the+00s") # Combine multiple filters ------------------------------------------- - request = self.request_factory.get('/', {'publication-decade': 'the 00s', 'author__id__exact': self.alfred.pk}) + request = self.request_factory.get( + "/", {"publication-decade": "the 00s", "author__id__exact": self.alfred.pk} + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1136,29 +1224,34 @@ class ListFiltersTests(TestCase): # Make sure the correct choices are selected filterspec = changelist.get_filters(request)[0][1] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[3]['display'], 'the 2000\'s') - self.assertIs(choices[3]['selected'], True) + self.assertEqual(choices[3]["display"], "the 2000's") + 
self.assertIs(choices[3]["selected"], True) self.assertEqual( - choices[3]['query_string'], - '?author__id__exact=%s&publication-decade=the+00s' % self.alfred.pk + choices[3]["query_string"], + "?author__id__exact=%s&publication-decade=the+00s" % self.alfred.pk, ) filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'Verbose Author') + self.assertEqual(filterspec.title, "Verbose Author") choice = select_by(filterspec.choices(changelist), "display", "alfred") - self.assertIs(choice['selected'], True) - self.assertEqual(choice['query_string'], '?author__id__exact=%s&publication-decade=the+00s' % self.alfred.pk) + self.assertIs(choice["selected"], True) + self.assertEqual( + choice["query_string"], + "?author__id__exact=%s&publication-decade=the+00s" % self.alfred.pk, + ) def test_listfilter_without_title(self): """ Any filter must define a title. """ modeladmin = DecadeFilterBookAdminWithoutTitle(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred - msg = "The list filter 'DecadeListFilterWithoutTitle' does not specify a 'title'." + msg = ( + "The list filter 'DecadeListFilterWithoutTitle' does not specify a 'title'." + ) with self.assertRaisesMessage(ImproperlyConfigured, msg): modeladmin.get_changelist_instance(request) @@ -1167,7 +1260,7 @@ class ListFiltersTests(TestCase): Any SimpleListFilter must define a parameter_name. """ modeladmin = DecadeFilterBookAdminWithoutParameter(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred msg = "The list filter 'DecadeListFilterWithoutParameter' does not specify a 'parameter_name'." with self.assertRaisesMessage(ImproperlyConfigured, msg): @@ -1179,7 +1272,7 @@ class ListFiltersTests(TestCase): filter completely. 
""" modeladmin = DecadeFilterBookAdminWithNoneReturningLookups(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0] @@ -1191,40 +1284,40 @@ class ListFiltersTests(TestCase): the corresponding exception doesn't get swallowed (#17828). """ modeladmin = DecadeFilterBookAdminWithFailingQueryset(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred with self.assertRaises(ZeroDivisionError): modeladmin.get_changelist_instance(request) def test_simplelistfilter_with_queryset_based_lookups(self): modeladmin = DecadeFilterBookAdminWithQuerysetBasedLookups(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) self.assertEqual(len(choices), 3) - self.assertEqual(choices[0]['display'], 'All') - self.assertIs(choices[0]['selected'], True) - self.assertEqual(choices[0]['query_string'], '?') + self.assertEqual(choices[0]["display"], "All") + self.assertIs(choices[0]["selected"], True) + self.assertEqual(choices[0]["query_string"], "?") - self.assertEqual(choices[1]['display'], 'the 1990\'s') - self.assertIs(choices[1]['selected'], False) - self.assertEqual(choices[1]['query_string'], '?publication-decade=the+90s') + self.assertEqual(choices[1]["display"], "the 1990's") + self.assertIs(choices[1]["selected"], False) + self.assertEqual(choices[1]["query_string"], "?publication-decade=the+90s") - self.assertEqual(choices[2]['display'], 'the 2000\'s') - 
self.assertIs(choices[2]['selected'], False) - self.assertEqual(choices[2]['query_string'], '?publication-decade=the+00s') + self.assertEqual(choices[2]["display"], "the 2000's") + self.assertIs(choices[2]["selected"], False) + self.assertEqual(choices[2]["query_string"], "?publication-decade=the+00s") def test_two_characters_long_field(self): """ list_filter works with two-characters long field names (#16080). """ modeladmin = BookAdmin(Book, site) - request = self.request_factory.get('/', {'no': '207'}) + request = self.request_factory.get("/", {"no": "207"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1233,10 +1326,10 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), [self.bio_book]) filterspec = changelist.get_filters(request)[0][5] - self.assertEqual(filterspec.title, 'number') + self.assertEqual(filterspec.title, "number") choices = list(filterspec.choices(changelist)) - self.assertIs(choices[2]['selected'], True) - self.assertEqual(choices[2]['query_string'], '?no=207') + self.assertIs(choices[2]["selected"], True) + self.assertEqual(choices[2]["query_string"], "?no=207") def test_parameter_ends_with__in__or__isnull(self): """ @@ -1245,7 +1338,7 @@ class ListFiltersTests(TestCase): """ # When it ends with '__in' ----------------------------------------- modeladmin = DecadeFilterBookAdminParameterEndsWith__In(Book, site) - request = self.request_factory.get('/', {'decade__in': 'the 90s'}) + request = self.request_factory.get("/", {"decade__in": "the 90s"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1255,15 +1348,15 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[2]['display'], 'the 
1990\'s') - self.assertIs(choices[2]['selected'], True) - self.assertEqual(choices[2]['query_string'], '?decade__in=the+90s') + self.assertEqual(choices[2]["display"], "the 1990's") + self.assertIs(choices[2]["selected"], True) + self.assertEqual(choices[2]["query_string"], "?decade__in=the+90s") # When it ends with '__isnull' --------------------------------------- modeladmin = DecadeFilterBookAdminParameterEndsWith__Isnull(Book, site) - request = self.request_factory.get('/', {'decade__isnull': 'the 90s'}) + request = self.request_factory.get("/", {"decade__isnull": "the 90s"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1273,11 +1366,11 @@ class ListFiltersTests(TestCase): # Make sure the correct choice is selected filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'publication decade') + self.assertEqual(filterspec.title, "publication decade") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[2]['display'], 'the 1990\'s') - self.assertIs(choices[2]['selected'], True) - self.assertEqual(choices[2]['query_string'], '?decade__isnull=the+90s') + self.assertEqual(choices[2]["display"], "the 1990's") + self.assertIs(choices[2]["selected"], True) + self.assertEqual(choices[2]["query_string"], "?decade__isnull=the+90s") def test_lookup_with_non_string_value(self): """ @@ -1285,7 +1378,7 @@ class ListFiltersTests(TestCase): for lookups in SimpleListFilters (#19318). 
""" modeladmin = DepartmentFilterEmployeeAdmin(Employee, site) - request = self.request_factory.get('/', {'department': self.john.department.pk}) + request = self.request_factory.get("/", {"department": self.john.department.pk}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1294,11 +1387,13 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), [self.john]) filterspec = changelist.get_filters(request)[0][-1] - self.assertEqual(filterspec.title, 'department') + self.assertEqual(filterspec.title, "department") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[1]['display'], 'DEV') - self.assertIs(choices[1]['selected'], True) - self.assertEqual(choices[1]['query_string'], '?department=%s' % self.john.department.pk) + self.assertEqual(choices[1]["display"], "DEV") + self.assertIs(choices[1]["selected"], True) + self.assertEqual( + choices[1]["query_string"], "?department=%s" % self.john.department.pk + ) def test_lookup_with_non_string_value_underscored(self): """ @@ -1306,7 +1401,9 @@ class ListFiltersTests(TestCase): parameter_name attribute contains double-underscore value (#19182). 
""" modeladmin = DepartmentFilterUnderscoredEmployeeAdmin(Employee, site) - request = self.request_factory.get('/', {'department__whatever': self.john.department.pk}) + request = self.request_factory.get( + "/", {"department__whatever": self.john.department.pk} + ) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1315,11 +1412,14 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), [self.john]) filterspec = changelist.get_filters(request)[0][-1] - self.assertEqual(filterspec.title, 'department') + self.assertEqual(filterspec.title, "department") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[1]['display'], 'DEV') - self.assertIs(choices[1]['selected'], True) - self.assertEqual(choices[1]['query_string'], '?department__whatever=%s' % self.john.department.pk) + self.assertEqual(choices[1]["display"], "DEV") + self.assertIs(choices[1]["selected"], True) + self.assertEqual( + choices[1]["query_string"], + "?department__whatever=%s" % self.john.department.pk, + ) def test_fk_with_to_field(self): """ @@ -1327,7 +1427,7 @@ class ListFiltersTests(TestCase): """ modeladmin = EmployeeAdmin(Employee, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1336,24 +1436,24 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), [self.jack, self.john]) filterspec = changelist.get_filters(request)[0][-1] - self.assertEqual(filterspec.title, 'department') + self.assertEqual(filterspec.title, "department") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[0]['display'], 'All') - self.assertIs(choices[0]['selected'], True) - self.assertEqual(choices[0]['query_string'], '?') + self.assertEqual(choices[0]["display"], "All") + self.assertIs(choices[0]["selected"], True) + self.assertEqual(choices[0]["query_string"], "?") - 
self.assertEqual(choices[1]['display'], 'Development') - self.assertIs(choices[1]['selected'], False) - self.assertEqual(choices[1]['query_string'], '?department__code__exact=DEV') + self.assertEqual(choices[1]["display"], "Development") + self.assertIs(choices[1]["selected"], False) + self.assertEqual(choices[1]["query_string"], "?department__code__exact=DEV") - self.assertEqual(choices[2]['display'], 'Design') - self.assertIs(choices[2]['selected'], False) - self.assertEqual(choices[2]['query_string'], '?department__code__exact=DSN') + self.assertEqual(choices[2]["display"], "Design") + self.assertIs(choices[2]["selected"], False) + self.assertEqual(choices[2]["query_string"], "?department__code__exact=DSN") # Filter by Department=='Development' -------------------------------- - request = self.request_factory.get('/', {'department__code__exact': 'DEV'}) + request = self.request_factory.get("/", {"department__code__exact": "DEV"}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) @@ -1362,20 +1462,20 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), [self.john]) filterspec = changelist.get_filters(request)[0][-1] - self.assertEqual(filterspec.title, 'department') + self.assertEqual(filterspec.title, "department") choices = list(filterspec.choices(changelist)) - self.assertEqual(choices[0]['display'], 'All') - self.assertIs(choices[0]['selected'], False) - self.assertEqual(choices[0]['query_string'], '?') + self.assertEqual(choices[0]["display"], "All") + self.assertIs(choices[0]["selected"], False) + self.assertEqual(choices[0]["query_string"], "?") - self.assertEqual(choices[1]['display'], 'Development') - self.assertIs(choices[1]['selected'], True) - self.assertEqual(choices[1]['query_string'], '?department__code__exact=DEV') + self.assertEqual(choices[1]["display"], "Development") + self.assertIs(choices[1]["selected"], True) + self.assertEqual(choices[1]["query_string"], "?department__code__exact=DEV") - 
self.assertEqual(choices[2]['display'], 'Design') - self.assertIs(choices[2]['selected'], False) - self.assertEqual(choices[2]['query_string'], '?department__code__exact=DSN') + self.assertEqual(choices[2]["display"], "Design") + self.assertIs(choices[2]["selected"], False) + self.assertEqual(choices[2]["query_string"], "?department__code__exact=DSN") def test_lookup_with_dynamic_value(self): """ @@ -1387,18 +1487,23 @@ class ListFiltersTests(TestCase): request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'publication decade') - choices = tuple(c['display'] for c in filterspec.choices(changelist)) + self.assertEqual(filterspec.title, "publication decade") + choices = tuple(c["display"] for c in filterspec.choices(changelist)) self.assertEqual(choices, expected_displays) - _test_choices(self.request_factory.get('/', {}), - ("All", "the 1980's", "the 1990's")) + _test_choices( + self.request_factory.get("/", {}), ("All", "the 1980's", "the 1990's") + ) - _test_choices(self.request_factory.get('/', {'publication-decade': 'the 80s'}), - ("All", "the 1990's")) + _test_choices( + self.request_factory.get("/", {"publication-decade": "the 80s"}), + ("All", "the 1990's"), + ) - _test_choices(self.request_factory.get('/', {'publication-decade': 'the 90s'}), - ("All", "the 1980's")) + _test_choices( + self.request_factory.get("/", {"publication-decade": "the 90s"}), + ("All", "the 1980's"), + ) def test_list_filter_queryset_filtered_by_default(self): """ @@ -1406,16 +1511,16 @@ class ListFiltersTests(TestCase): full_result_count. 
""" modeladmin = NotNinetiesListFilterAdmin(Book, site) - request = self.request_factory.get('/', {}) + request = self.request_factory.get("/", {}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) changelist.get_results(request) self.assertEqual(changelist.full_result_count, 4) def test_emptylistfieldfilter(self): - empty_description = Department.objects.create(code='EMPT', description='') - none_description = Department.objects.create(code='NONE', description=None) - empty_title = Book.objects.create(title='', author=self.alfred) + empty_description = Department.objects.create(code="EMPT", description="") + none_description = Department.objects.create(code="NONE", description=None) + empty_title = Book.objects.create(title="", author=self.alfred) department_admin = DepartmentAdminWithEmptyFieldListFilter(Department, site) book_admin = BookAdminWithEmptyFieldListFilter(Book, site) @@ -1424,27 +1529,32 @@ class ListFiltersTests(TestCase): # Allows nulls and empty strings. ( department_admin, - {'description__isempty': '1'}, + {"description__isempty": "1"}, [empty_description, none_description], ), ( department_admin, - {'description__isempty': '0'}, + {"description__isempty": "0"}, [self.dev, self.design], ), # Allows nulls. - (book_admin, {'author__isempty': '1'}, [self.guitar_book]), + (book_admin, {"author__isempty": "1"}, [self.guitar_book]), ( book_admin, - {'author__isempty': '0'}, + {"author__isempty": "0"}, [self.django_book, self.bio_book, self.djangonaut_book, empty_title], ), # Allows empty strings. 
- (book_admin, {'title__isempty': '1'}, [empty_title]), + (book_admin, {"title__isempty": "1"}, [empty_title]), ( book_admin, - {'title__isempty': '0'}, - [self.django_book, self.bio_book, self.djangonaut_book, self.guitar_book], + {"title__isempty": "0"}, + [ + self.django_book, + self.bio_book, + self.djangonaut_book, + self.guitar_book, + ], ), ] for modeladmin, query_string, expected_result in tests: @@ -1452,7 +1562,7 @@ class ListFiltersTests(TestCase): modeladmin=modeladmin.__class__.__name__, query_string=query_string, ): - request = self.request_factory.get('/', query_string) + request = self.request_factory.get("/", query_string) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) @@ -1460,12 +1570,10 @@ class ListFiltersTests(TestCase): def test_emptylistfieldfilter_reverse_relationships(self): class UserAdminReverseRelationship(UserAdmin): - list_filter = ( - ('books_contributed', EmptyFieldListFilter), - ) + list_filter = (("books_contributed", EmptyFieldListFilter),) ImprovedBook.objects.create(book=self.guitar_book) - no_employees = Department.objects.create(code='NONE', description=None) + no_employees = Department.objects.create(code="NONE", description=None) book_admin = BookAdminWithEmptyFieldListFilter(Book, site) department_admin = DepartmentAdminWithEmptyFieldListFilter(Department, site) @@ -1475,23 +1583,23 @@ class ListFiltersTests(TestCase): # Reverse one-to-one relationship. ( book_admin, - {'improvedbook__isempty': '1'}, + {"improvedbook__isempty": "1"}, [self.django_book, self.bio_book, self.djangonaut_book], ), - (book_admin, {'improvedbook__isempty': '0'}, [self.guitar_book]), + (book_admin, {"improvedbook__isempty": "0"}, [self.guitar_book]), # Reverse foreign key relationship. 
- (department_admin, {'employee__isempty': '1'}, [no_employees]), - (department_admin, {'employee__isempty': '0'}, [self.dev, self.design]), + (department_admin, {"employee__isempty": "1"}, [no_employees]), + (department_admin, {"employee__isempty": "0"}, [self.dev, self.design]), # Reverse many-to-many relationship. - (user_admin, {'books_contributed__isempty': '1'}, [self.alfred]), - (user_admin, {'books_contributed__isempty': '0'}, [self.bob, self.lisa]), + (user_admin, {"books_contributed__isempty": "1"}, [self.alfred]), + (user_admin, {"books_contributed__isempty": "0"}, [self.bob, self.lisa]), ] for modeladmin, query_string, expected_result in tests: with self.subTest( modeladmin=modeladmin.__class__.__name__, query_string=query_string, ): - request = self.request_factory.get('/', query_string) + request = self.request_factory.get("/", query_string) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) @@ -1499,25 +1607,23 @@ class ListFiltersTests(TestCase): def test_emptylistfieldfilter_genericrelation(self): class BookmarkGenericRelation(ModelAdmin): - list_filter = ( - ('tags', EmptyFieldListFilter), - ) + list_filter = (("tags", EmptyFieldListFilter),) modeladmin = BookmarkGenericRelation(Bookmark, site) - django_bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') - python_bookmark = Bookmark.objects.create(url='https://www.python.org/') - none_tags = Bookmark.objects.create(url='https://www.kernel.org/') - TaggedItem.objects.create(content_object=django_bookmark, tag='python') - TaggedItem.objects.create(content_object=python_bookmark, tag='python') + django_bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/") + python_bookmark = Bookmark.objects.create(url="https://www.python.org/") + none_tags = Bookmark.objects.create(url="https://www.kernel.org/") + TaggedItem.objects.create(content_object=django_bookmark, tag="python") + 
TaggedItem.objects.create(content_object=python_bookmark, tag="python") tests = [ - ({'tags__isempty': '1'}, [none_tags]), - ({'tags__isempty': '0'}, [django_bookmark, python_bookmark]), + ({"tags__isempty": "1"}, [none_tags]), + ({"tags__isempty": "0"}, [django_bookmark, python_bookmark]), ] for query_string, expected_result in tests: with self.subTest(query_string=query_string): - request = self.request_factory.get('/', query_string) + request = self.request_factory.get("/", query_string) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) queryset = changelist.get_queryset(request) @@ -1525,32 +1631,32 @@ class ListFiltersTests(TestCase): def test_emptylistfieldfilter_choices(self): modeladmin = BookAdminWithEmptyFieldListFilter(Book, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) filterspec = changelist.get_filters(request)[0][0] - self.assertEqual(filterspec.title, 'Verbose Author') + self.assertEqual(filterspec.title, "Verbose Author") choices = list(filterspec.choices(changelist)) self.assertEqual(len(choices), 3) - self.assertEqual(choices[0]['display'], 'All') - self.assertIs(choices[0]['selected'], True) - self.assertEqual(choices[0]['query_string'], '?') + self.assertEqual(choices[0]["display"], "All") + self.assertIs(choices[0]["selected"], True) + self.assertEqual(choices[0]["query_string"], "?") - self.assertEqual(choices[1]['display'], 'Empty') - self.assertIs(choices[1]['selected'], False) - self.assertEqual(choices[1]['query_string'], '?author__isempty=1') + self.assertEqual(choices[1]["display"], "Empty") + self.assertIs(choices[1]["selected"], False) + self.assertEqual(choices[1]["query_string"], "?author__isempty=1") - self.assertEqual(choices[2]['display'], 'Not empty') - self.assertIs(choices[2]['selected'], False) - self.assertEqual(choices[2]['query_string'], '?author__isempty=0') + 
self.assertEqual(choices[2]["display"], "Not empty") + self.assertIs(choices[2]["selected"], False) + self.assertEqual(choices[2]["query_string"], "?author__isempty=0") def test_emptylistfieldfilter_non_empty_field(self): class EmployeeAdminWithEmptyFieldListFilter(ModelAdmin): - list_filter = [('department', EmptyFieldListFilter)] + list_filter = [("department", EmptyFieldListFilter)] modeladmin = EmployeeAdminWithEmptyFieldListFilter(Employee, site) - request = self.request_factory.get('/') + request = self.request_factory.get("/") request.user = self.alfred msg = ( "The list filter 'EmptyFieldListFilter' cannot be used with field " @@ -1561,7 +1667,7 @@ class ListFiltersTests(TestCase): def test_emptylistfieldfilter_invalid_lookup_parameters(self): modeladmin = BookAdminWithEmptyFieldListFilter(Book, site) - request = self.request_factory.get('/', {'author__isempty': 42}) + request = self.request_factory.get("/", {"author__isempty": 42}) request.user = self.alfred with self.assertRaises(IncorrectLookupParameters): modeladmin.get_changelist_instance(request) @@ -1570,12 +1676,12 @@ class ListFiltersTests(TestCase): """ Filter __in lookups with a custom divider. 
""" - jane = Employee.objects.create(name='Jane,Green', department=self.design) + jane = Employee.objects.create(name="Jane,Green", department=self.design) modeladmin = EmployeeCustomDividerFilterAdmin(Employee, site) employees = [jane, self.jack] request = self.request_factory.get( - '/', {'name__in': "|".join(e.name for e in employees)} + "/", {"name__in": "|".join(e.name for e in employees)} ) # test for lookup with custom divider request.user = self.alfred @@ -1585,7 +1691,7 @@ class ListFiltersTests(TestCase): self.assertEqual(list(queryset), employees) # test for lookup with comma in the lookup string - request = self.request_factory.get('/', {'name': jane.name}) + request = self.request_factory.get("/", {"name": jane.name}) request.user = self.alfred changelist = modeladmin.get_changelist_instance(request) # Make sure the correct queryset is returned diff --git a/tests/admin_inlines/admin.py b/tests/admin_inlines/admin.py index 116556db7e..0ec56d71b3 100644 --- a/tests/admin_inlines/admin.py +++ b/tests/admin_inlines/admin.py @@ -4,16 +4,57 @@ from django.core.exceptions import ValidationError from django.db import models from .models import ( - Author, BinaryTree, CapoFamiglia, Chapter, Child, ChildModel1, ChildModel2, - Class, Consigliere, Course, CourseProxy, CourseProxy1, CourseProxy2, - EditablePKBook, ExtraTerrestrial, Fashionista, FootNote, Holder, Holder2, - Holder3, Holder4, Holder5, Inner, Inner2, Inner3, Inner4Stacked, - Inner4Tabular, Inner5Stacked, Inner5Tabular, NonAutoPKBook, - NonAutoPKBookChild, Novel, NovelReadonlyChapter, OutfitItem, - ParentModelWithCustomPk, Person, Poll, Profile, ProfileCollection, - Question, ReadOnlyInline, ShoppingWeakness, ShowInlineChild, - ShowInlineParent, Sighting, SomeChildModel, SomeParentModel, SottoCapo, - Teacher, Title, TitleCollection, + Author, + BinaryTree, + CapoFamiglia, + Chapter, + Child, + ChildModel1, + ChildModel2, + Class, + Consigliere, + Course, + CourseProxy, + CourseProxy1, + CourseProxy2, + 
EditablePKBook, + ExtraTerrestrial, + Fashionista, + FootNote, + Holder, + Holder2, + Holder3, + Holder4, + Holder5, + Inner, + Inner2, + Inner3, + Inner4Stacked, + Inner4Tabular, + Inner5Stacked, + Inner5Tabular, + NonAutoPKBook, + NonAutoPKBookChild, + Novel, + NovelReadonlyChapter, + OutfitItem, + ParentModelWithCustomPk, + Person, + Poll, + Profile, + ProfileCollection, + Question, + ReadOnlyInline, + ShoppingWeakness, + ShowInlineChild, + ShowInlineParent, + Sighting, + SomeChildModel, + SomeParentModel, + SottoCapo, + Teacher, + Title, + TitleCollection, ) site = admin.AdminSite(name="admin") @@ -25,17 +66,17 @@ class BookInline(admin.TabularInline): class NonAutoPKBookTabularInline(admin.TabularInline): model = NonAutoPKBook - classes = ('collapse',) + classes = ("collapse",) class NonAutoPKBookChildTabularInline(admin.TabularInline): model = NonAutoPKBookChild - classes = ('collapse',) + classes = ("collapse",) class NonAutoPKBookStackedInline(admin.StackedInline): model = NonAutoPKBook - classes = ('collapse',) + classes = ("collapse",) class EditablePKBookTabularInline(admin.TabularInline): @@ -48,8 +89,11 @@ class EditablePKBookStackedInline(admin.StackedInline): class AuthorAdmin(admin.ModelAdmin): inlines = [ - BookInline, NonAutoPKBookTabularInline, NonAutoPKBookStackedInline, - EditablePKBookTabularInline, EditablePKBookStackedInline, + BookInline, + NonAutoPKBookTabularInline, + NonAutoPKBookStackedInline, + EditablePKBookTabularInline, + EditablePKBookStackedInline, NonAutoPKBookChildTabularInline, ] @@ -57,25 +101,24 @@ class AuthorAdmin(admin.ModelAdmin): class InnerInline(admin.StackedInline): model = Inner can_delete = False - readonly_fields = ('readonly',) # For bug #13174 tests. + readonly_fields = ("readonly",) # For bug #13174 tests. 
class HolderAdmin(admin.ModelAdmin): - class Media: - js = ('my_awesome_admin_scripts.js',) + js = ("my_awesome_admin_scripts.js",) class ReadOnlyInlineInline(admin.TabularInline): model = ReadOnlyInline - readonly_fields = ['name'] + readonly_fields = ["name"] class InnerInline2(admin.StackedInline): model = Inner2 class Media: - js = ('my_awesome_inline_scripts.js',) + js = ("my_awesome_inline_scripts.js",) class InnerInline2Tabular(admin.TabularInline): @@ -84,17 +127,17 @@ class InnerInline2Tabular(admin.TabularInline): class CustomNumberWidget(forms.NumberInput): class Media: - js = ('custom_number.js',) + js = ("custom_number.js",) class InnerInline3(admin.StackedInline): model = Inner3 formfield_overrides = { - models.IntegerField: {'widget': CustomNumberWidget}, + models.IntegerField: {"widget": CustomNumberWidget}, } class Media: - js = ('my_awesome_inline_scripts.js',) + js = ("my_awesome_inline_scripts.js",) class TitleForm(forms.ModelForm): @@ -131,12 +174,12 @@ class Holder4Admin(admin.ModelAdmin): class Inner5StackedInline(admin.StackedInline): model = Inner5Stacked - classes = ('collapse',) + classes = ("collapse",) class Inner5TabularInline(admin.TabularInline): model = Inner5Tabular - classes = ('collapse',) + classes = ("collapse",) class Holder5Admin(admin.ModelAdmin): @@ -153,7 +196,7 @@ class WeaknessForm(forms.ModelForm): class Meta: model = ShoppingWeakness - fields = '__all__' + fields = "__all__" class WeaknessInlineCustomForm(admin.TabularInline): @@ -166,7 +209,7 @@ class FootNoteForm(forms.ModelForm): class Meta: model = FootNote - fields = '__all__' + fields = "__all__" class FootNoteNonEditableInlineCustomForm(admin.TabularInline): @@ -179,25 +222,25 @@ class FootNoteNonEditableInlineCustomForm(admin.TabularInline): class QuestionInline(admin.TabularInline): model = Question - readonly_fields = ['call_me'] + readonly_fields = ["call_me"] def call_me(self, obj): - return 'Callable in QuestionInline' + return "Callable in QuestionInline" 
class PollAdmin(admin.ModelAdmin): inlines = [QuestionInline] def call_me(self, obj): - return 'Callable in PollAdmin' + return "Callable in PollAdmin" class ChapterInline(admin.TabularInline): model = Chapter - readonly_fields = ['call_me'] + readonly_fields = ["call_me"] def call_me(self, obj): - return 'Callable in ChapterInline' + return "Callable in ChapterInline" class NovelAdmin(admin.ModelAdmin): @@ -261,32 +304,31 @@ class SightingInline(admin.TabularInline): # admin and form for #18263 class SomeChildModelForm(forms.ModelForm): - class Meta: - fields = '__all__' + fields = "__all__" model = SomeChildModel widgets = { - 'position': forms.HiddenInput, + "position": forms.HiddenInput, } - labels = {'readonly_field': 'Label from ModelForm.Meta'} - help_texts = {'readonly_field': 'Help text from ModelForm.Meta'} + labels = {"readonly_field": "Label from ModelForm.Meta"} + help_texts = {"readonly_field": "Help text from ModelForm.Meta"} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields['name'].label = 'new label' + self.fields["name"].label = "new label" class SomeChildModelInline(admin.TabularInline): model = SomeChildModel form = SomeChildModelForm - readonly_fields = ('readonly_field',) + readonly_fields = ("readonly_field",) class StudentInline(admin.StackedInline): model = Child extra = 1 fieldsets = [ - ('Name', {'fields': ('name',), 'classes': ('collapse',)}), + ("Name", {"fields": ("name",), "classes": ("collapse",)}), ] @@ -306,7 +348,7 @@ class FashonistaStackedInline(admin.StackedInline): class ClassStackedHorizontal(admin.StackedInline): model = Class extra = 1 - filter_horizontal = ['person'] + filter_horizontal = ["person"] class ClassAdminStackedHorizontal(admin.ModelAdmin): @@ -316,7 +358,7 @@ class ClassAdminStackedHorizontal(admin.ModelAdmin): class ClassTabularHorizontal(admin.TabularInline): model = Class extra = 1 - filter_horizontal = ['person'] + filter_horizontal = ["person"] class 
ClassAdminTabularHorizontal(admin.ModelAdmin): @@ -326,7 +368,7 @@ class ClassAdminTabularHorizontal(admin.ModelAdmin): class ClassTabularVertical(admin.TabularInline): model = Class extra = 1 - filter_vertical = ['person'] + filter_vertical = ["person"] class ClassAdminTabularVertical(admin.ModelAdmin): @@ -336,7 +378,7 @@ class ClassAdminTabularVertical(admin.ModelAdmin): class ClassStackedVertical(admin.StackedInline): model = Class extra = 1 - filter_vertical = ['person'] + filter_vertical = ["person"] class ClassAdminStackedVertical(admin.ModelAdmin): @@ -346,13 +388,13 @@ class ClassAdminStackedVertical(admin.ModelAdmin): class ChildHiddenFieldForm(forms.ModelForm): class Meta: model = SomeChildModel - fields = ['name', 'position', 'parent'] - widgets = {'position': forms.HiddenInput} + fields = ["name", "position", "parent"] + widgets = {"position": forms.HiddenInput} def _post_clean(self): super()._post_clean() if self.instance is not None and self.instance.position == 1: - self.add_error(None, ValidationError('A non-field error')) + self.add_error(None, ValidationError("A non-field error")) class ChildHiddenFieldTabularInline(admin.TabularInline): @@ -363,13 +405,13 @@ class ChildHiddenFieldTabularInline(admin.TabularInline): class ChildHiddenFieldInFieldsGroupStackedInline(admin.StackedInline): model = SomeChildModel form = ChildHiddenFieldForm - fields = [('name', 'position')] + fields = [("name", "position")] class ChildHiddenFieldOnSingleLineStackedInline(admin.StackedInline): model = SomeChildModel form = ChildHiddenFieldForm - fields = ('name', 'position') + fields = ("name", "position") class ShowInlineChildInline(admin.StackedInline): @@ -399,7 +441,9 @@ site.register(Fashionista, inlines=[InlineWeakness]) site.register(Holder4, Holder4Admin) site.register(Holder5, Holder5Admin) site.register(Author, AuthorAdmin) -site.register(CapoFamiglia, inlines=[ConsigliereInline, SottoCapoInline, ReadOnlyInlineInline]) +site.register( + CapoFamiglia, 
inlines=[ConsigliereInline, SottoCapoInline, ReadOnlyInlineInline] +) site.register(ProfileCollection, inlines=[ProfileInline]) site.register(ParentModelWithCustomPk, inlines=[ChildModel1Inline, ChildModel2Inline]) site.register(BinaryTree, inlines=[BinaryTreeAdmin]) @@ -416,9 +460,9 @@ site.register(CourseProxy1, ClassAdminTabularVertical) site.register(CourseProxy2, ClassAdminTabularHorizontal) site.register(ShowInlineParent, ShowInlineParentAdmin) # Used to test hidden fields in tabular and stacked inlines. -site2 = admin.AdminSite(name='tabular_inline_hidden_field_admin') +site2 = admin.AdminSite(name="tabular_inline_hidden_field_admin") site2.register(SomeParentModel, inlines=[ChildHiddenFieldTabularInline]) -site3 = admin.AdminSite(name='stacked_inline_hidden_field_in_group_admin') +site3 = admin.AdminSite(name="stacked_inline_hidden_field_in_group_admin") site3.register(SomeParentModel, inlines=[ChildHiddenFieldInFieldsGroupStackedInline]) -site4 = admin.AdminSite(name='stacked_inline_hidden_field_on_single_line_admin') +site4 = admin.AdminSite(name="stacked_inline_hidden_field_on_single_line_admin") site4.register(SomeParentModel, inlines=[ChildHiddenFieldOnSingleLineStackedInline]) diff --git a/tests/admin_inlines/models.py b/tests/admin_inlines/models.py index 6125e06c28..47c5b91828 100644 --- a/tests/admin_inlines/models.py +++ b/tests/admin_inlines/models.py @@ -31,7 +31,7 @@ class Child(models.Model): parent = GenericForeignKey() def __str__(self): - return 'I am %s, a child of %s' % (self.name, self.parent) + return "I am %s, a child of %s" % (self.name, self.parent) class Book(models.Model): @@ -44,7 +44,7 @@ class Book(models.Model): class Author(models.Model): name = models.CharField(max_length=50) books = models.ManyToManyField(Book) - person = models.OneToOneField('Person', models.CASCADE, null=True) + person = models.OneToOneField("Person", models.CASCADE, null=True) class NonAutoPKBook(models.Model): @@ -80,7 +80,7 @@ class Inner(models.Model): 
readonly = models.CharField("Inner readonly label", max_length=1) def get_absolute_url(self): - return '/inner/' + return "/inner/" class Holder2(models.Model): @@ -100,6 +100,7 @@ class Inner3(models.Model): dummy = models.IntegerField() holder = models.ForeignKey(Holder3, models.CASCADE) + # Models for ticket #8190 @@ -113,7 +114,9 @@ class Inner4Stacked(models.Model): class Meta: constraints = [ - models.UniqueConstraint(fields=['dummy', 'holder'], name='unique_stacked_dummy_per_holder') + models.UniqueConstraint( + fields=["dummy", "holder"], name="unique_stacked_dummy_per_holder" + ) ] @@ -123,9 +126,12 @@ class Inner4Tabular(models.Model): class Meta: constraints = [ - models.UniqueConstraint(fields=['dummy', 'holder'], name='unique_tabular_dummy_per_holder') + models.UniqueConstraint( + fields=["dummy", "holder"], name="unique_tabular_dummy_per_holder" + ) ] + # Models for ticket #31441 @@ -135,7 +141,7 @@ class Holder5(models.Model): class Inner5Stacked(models.Model): name = models.CharField(max_length=10) - select = models.CharField(choices=(('1', 'One'), ('2', 'Two')), max_length=10) + select = models.CharField(choices=(("1", "One"), ("2", "Two")), max_length=10) text = models.TextField() dummy = models.IntegerField() holder = models.ForeignKey(Holder5, models.CASCADE) @@ -143,7 +149,7 @@ class Inner5Stacked(models.Model): class Inner5Tabular(models.Model): name = models.CharField(max_length=10) - select = models.CharField(choices=(('1', 'One'), ('2', 'Two')), max_length=10) + select = models.CharField(choices=(("1", "One"), ("2", "Two")), max_length=10) text = models.TextField() dummy = models.IntegerField() holder = models.ForeignKey(Holder5, models.CASCADE) @@ -162,13 +168,16 @@ class OutfitItem(models.Model): class Fashionista(models.Model): person = models.OneToOneField(Person, models.CASCADE, primary_key=True) - weaknesses = models.ManyToManyField(OutfitItem, through='ShoppingWeakness', blank=True) + weaknesses = models.ManyToManyField( + 
OutfitItem, through="ShoppingWeakness", blank=True + ) class ShoppingWeakness(models.Model): fashionista = models.ForeignKey(Fashionista, models.CASCADE) item = models.ForeignKey(OutfitItem, models.CASCADE) + # Models for #13510 @@ -177,10 +186,13 @@ class TitleCollection(models.Model): class Title(models.Model): - collection = models.ForeignKey(TitleCollection, models.SET_NULL, blank=True, null=True) + collection = models.ForeignKey( + TitleCollection, models.SET_NULL, blank=True, null=True + ) title1 = models.CharField(max_length=100) title2 = models.CharField(max_length=100) + # Models for #15424 @@ -198,7 +210,6 @@ class Novel(models.Model): class NovelReadonlyChapter(Novel): - class Meta: proxy = True @@ -212,9 +223,11 @@ class FootNote(models.Model): """ Model added for ticket 19838 """ + chapter = models.ForeignKey(Chapter, models.PROTECT) note = models.CharField(max_length=40) + # Models for #16838 @@ -223,22 +236,23 @@ class CapoFamiglia(models.Model): class Consigliere(models.Model): - name = models.CharField(max_length=100, help_text='Help text for Consigliere') - capo_famiglia = models.ForeignKey(CapoFamiglia, models.CASCADE, related_name='+') + name = models.CharField(max_length=100, help_text="Help text for Consigliere") + capo_famiglia = models.ForeignKey(CapoFamiglia, models.CASCADE, related_name="+") class SottoCapo(models.Model): name = models.CharField(max_length=100) - capo_famiglia = models.ForeignKey(CapoFamiglia, models.CASCADE, related_name='+') + capo_famiglia = models.ForeignKey(CapoFamiglia, models.CASCADE, related_name="+") class ReadOnlyInline(models.Model): - name = models.CharField(max_length=100, help_text='Help text for ReadOnlyInline') + name = models.CharField(max_length=100, help_text="Help text for ReadOnlyInline") capo_famiglia = models.ForeignKey(CapoFamiglia, models.CASCADE) # Models for #18433 + class ParentModelWithCustomPk(models.Model): my_own_pk = models.CharField(max_length=100, primary_key=True) name = 
models.CharField(max_length=100) @@ -250,7 +264,7 @@ class ChildModel1(models.Model): parent = models.ForeignKey(ParentModelWithCustomPk, models.CASCADE) def get_absolute_url(self): - return '/child_model1/' + return "/child_model1/" class ChildModel2(models.Model): @@ -259,13 +273,14 @@ class ChildModel2(models.Model): parent = models.ForeignKey(ParentModelWithCustomPk, models.CASCADE) def get_absolute_url(self): - return '/child_model2/' + return "/child_model2/" # Models for #19425 class BinaryTree(models.Model): name = models.CharField(max_length=100) - parent = models.ForeignKey('self', models.SET_NULL, null=True, blank=True) + parent = models.ForeignKey("self", models.SET_NULL, null=True, blank=True) + # Models for #19524 @@ -304,24 +319,21 @@ class Course(models.Model): class Class(models.Model): - person = models.ManyToManyField(Person, verbose_name='attendant') + person = models.ManyToManyField(Person, verbose_name="attendant") course = models.ForeignKey(Course, on_delete=models.CASCADE) class CourseProxy(Course): - class Meta: proxy = True class CourseProxy1(Course): - class Meta: proxy = True class CourseProxy2(Course): - class Meta: proxy = True @@ -340,22 +352,24 @@ class ProfileCollection(models.Model): class Profile(models.Model): - collection = models.ForeignKey(ProfileCollection, models.SET_NULL, blank=True, null=True) + collection = models.ForeignKey( + ProfileCollection, models.SET_NULL, blank=True, null=True + ) first_name = models.CharField(max_length=100) last_name = models.CharField(max_length=100) class VerboseNameProfile(Profile): class Meta: - verbose_name = 'Model with verbose name only' + verbose_name = "Model with verbose name only" class VerboseNamePluralProfile(Profile): class Meta: - verbose_name_plural = 'Model with verbose name plural only' + verbose_name_plural = "Model with verbose name plural only" class BothVerboseNameProfile(Profile): class Meta: - verbose_name = 'Model with both - name' - verbose_name_plural = 'Model with 
both - plural name' + verbose_name = "Model with both - name" + verbose_name_plural = "Model with both - plural name" diff --git a/tests/admin_inlines/test_templates.py b/tests/admin_inlines/test_templates.py index 5c8d7ce0c1..9ed4d65944 100644 --- a/tests/admin_inlines/test_templates.py +++ b/tests/admin_inlines/test_templates.py @@ -7,17 +7,19 @@ from django.test import SimpleTestCase class TestTemplates(SimpleTestCase): def test_javascript_escaping(self): context = { - 'inline_admin_formset': { - 'inline_formset_data': json.dumps({ - 'formset': {'prefix': 'my-prefix'}, - 'opts': {'verbose_name': 'verbose name\\'}, - }), + "inline_admin_formset": { + "inline_formset_data": json.dumps( + { + "formset": {"prefix": "my-prefix"}, + "opts": {"verbose_name": "verbose name\\"}, + } + ), }, } - output = render_to_string('admin/edit_inline/stacked.html', context) - self.assertIn('"prefix": "my-prefix"', output) - self.assertIn('"verbose_name": "verbose name\\\\"', output) + output = render_to_string("admin/edit_inline/stacked.html", context) + self.assertIn(""prefix": "my-prefix"", output) + self.assertIn(""verbose_name": "verbose name\\\\"", output) - output = render_to_string('admin/edit_inline/tabular.html', context) - self.assertIn('"prefix": "my-prefix"', output) - self.assertIn('"verbose_name": "verbose name\\\\"', output) + output = render_to_string("admin/edit_inline/tabular.html", context) + self.assertIn(""prefix": "my-prefix"", output) + self.assertIn(""verbose_name": "verbose name\\\\"", output) diff --git a/tests/admin_inlines/tests.py b/tests/admin_inlines/tests.py index 54bbec28aa..7db86f5663 100644 --- a/tests/admin_inlines/tests.py +++ b/tests/admin_inlines/tests.py @@ -6,27 +6,58 @@ from django.contrib.contenttypes.models import ContentType from django.test import RequestFactory, TestCase, override_settings from django.urls import reverse -from .admin import InnerInline, site as admin_site +from .admin import InnerInline +from .admin import site as 
admin_site from .models import ( - Author, BinaryTree, Book, BothVerboseNameProfile, Chapter, Child, - ChildModel1, ChildModel2, Fashionista, FootNote, Holder, Holder2, Holder3, - Holder4, Inner, Inner2, Inner3, Inner4Stacked, Inner4Tabular, Novel, - OutfitItem, Parent, ParentModelWithCustomPk, Person, Poll, Profile, - ProfileCollection, Question, ShowInlineParent, Sighting, SomeChildModel, - SomeParentModel, Teacher, VerboseNamePluralProfile, VerboseNameProfile, + Author, + BinaryTree, + Book, + BothVerboseNameProfile, + Chapter, + Child, + ChildModel1, + ChildModel2, + Fashionista, + FootNote, + Holder, + Holder2, + Holder3, + Holder4, + Inner, + Inner2, + Inner3, + Inner4Stacked, + Inner4Tabular, + Novel, + OutfitItem, + Parent, + ParentModelWithCustomPk, + Person, + Poll, + Profile, + ProfileCollection, + Question, + ShowInlineParent, + Sighting, + SomeChildModel, + SomeParentModel, + Teacher, + VerboseNamePluralProfile, + VerboseNameProfile, ) INLINE_CHANGELINK_HTML = 'class="inlinechangelink">Change' class TestDataMixin: - @classmethod def setUpTestData(cls): - cls.superuser = User.objects.create_superuser(username='super', email='super@example.com', password='secret') + cls.superuser = User.objects.create_superuser( + username="super", email="super@example.com", password="secret" + ) -@override_settings(ROOT_URLCONF='admin_inlines.urls') +@override_settings(ROOT_URLCONF="admin_inlines.urls") class TestInline(TestDataMixin, TestCase): factory = RequestFactory() @@ -36,22 +67,24 @@ class TestInline(TestDataMixin, TestCase): cls.holder = Holder.objects.create(dummy=13) Inner.objects.create(dummy=42, holder=cls.holder) - cls.parent = SomeParentModel.objects.create(name='a') - SomeChildModel.objects.create(name='b', position='0', parent=cls.parent) - SomeChildModel.objects.create(name='c', position='1', parent=cls.parent) + cls.parent = SomeParentModel.objects.create(name="a") + SomeChildModel.objects.create(name="b", position="0", parent=cls.parent) + 
SomeChildModel.objects.create(name="c", position="1", parent=cls.parent) cls.view_only_user = User.objects.create_user( - username='user', password='pwd', is_staff=True, + username="user", + password="pwd", + is_staff=True, ) parent_ct = ContentType.objects.get_for_model(SomeParentModel) child_ct = ContentType.objects.get_for_model(SomeChildModel) permission = Permission.objects.get( - codename='view_someparentmodel', + codename="view_someparentmodel", content_type=parent_ct, ) cls.view_only_user.user_permissions.add(permission) permission = Permission.objects.get( - codename='view_somechildmodel', + codename="view_somechildmodel", content_type=child_ct, ) cls.view_only_user.user_permissions.add(permission) @@ -64,61 +97,65 @@ class TestInline(TestDataMixin, TestCase): can_delete should be passed to inlineformset factory. """ response = self.client.get( - reverse('admin:admin_inlines_holder_change', args=(self.holder.id,)) + reverse("admin:admin_inlines_holder_change", args=(self.holder.id,)) ) - inner_formset = response.context['inline_admin_formsets'][0].formset + inner_formset = response.context["inline_admin_formsets"][0].formset expected = InnerInline.can_delete actual = inner_formset.can_delete - self.assertEqual(expected, actual, 'can_delete must be equal') + self.assertEqual(expected, actual, "can_delete must be equal") def test_readonly_stacked_inline_label(self): """Bug #13174.""" holder = Holder.objects.create(dummy=42) - Inner.objects.create(holder=holder, dummy=42, readonly='') + Inner.objects.create(holder=holder, dummy=42, readonly="") response = self.client.get( - reverse('admin:admin_inlines_holder_change', args=(holder.id,)) + reverse("admin:admin_inlines_holder_change", args=(holder.id,)) ) - self.assertContains(response, '') + self.assertContains(response, "") def test_many_to_many_inlines(self): "Autogenerated many-to-many inlines are displayed correctly (#13407)" - response = self.client.get(reverse('admin:admin_inlines_author_add')) + 
response = self.client.get(reverse("admin:admin_inlines_author_add")) # The heading for the m2m inline block uses the right text - self.assertContains(response, '

Author-book relationships

') + self.assertContains(response, "

Author-book relationships

") # The "add another" label is correct - self.assertContains(response, 'Add another Author-book relationship') + self.assertContains(response, "Add another Author-book relationship") # The '+' is dropped from the autogenerated form prefix (Author_books+) self.assertContains(response, 'id="id_Author_books-TOTAL_FORMS"') def test_inline_primary(self): - person = Person.objects.create(firstname='Imelda') - item = OutfitItem.objects.create(name='Shoes') + person = Person.objects.create(firstname="Imelda") + item = OutfitItem.objects.create(name="Shoes") # Imelda likes shoes, but can't carry her own bags. data = { - 'shoppingweakness_set-TOTAL_FORMS': 1, - 'shoppingweakness_set-INITIAL_FORMS': 0, - 'shoppingweakness_set-MAX_NUM_FORMS': 0, - '_save': 'Save', - 'person': person.id, - 'max_weight': 0, - 'shoppingweakness_set-0-item': item.id, + "shoppingweakness_set-TOTAL_FORMS": 1, + "shoppingweakness_set-INITIAL_FORMS": 0, + "shoppingweakness_set-MAX_NUM_FORMS": 0, + "_save": "Save", + "person": person.id, + "max_weight": 0, + "shoppingweakness_set-0-item": item.id, } - response = self.client.post(reverse('admin:admin_inlines_fashionista_add'), data) + response = self.client.post( + reverse("admin:admin_inlines_fashionista_add"), data + ) self.assertEqual(response.status_code, 302) - self.assertEqual(len(Fashionista.objects.filter(person__firstname='Imelda')), 1) + self.assertEqual(len(Fashionista.objects.filter(person__firstname="Imelda")), 1) def test_tabular_inline_column_css_class(self): """ Field names are included in the context to output a field-specific CSS class name in the column headers. """ - response = self.client.get(reverse('admin:admin_inlines_poll_add')) - text_field, call_me_field = list(response.context['inline_admin_formset'].fields()) + response = self.client.get(reverse("admin:admin_inlines_poll_add")) + text_field, call_me_field = list( + response.context["inline_admin_formset"].fields() + ) # Editable field. 
- self.assertEqual(text_field['name'], 'text') + self.assertEqual(text_field["name"], "text") self.assertContains(response, '') # Read-only field. - self.assertEqual(call_me_field['name'], 'call_me') + self.assertEqual(call_me_field["name"], "call_me") self.assertContains(response, '') def test_custom_form_tabular_inline_label(self): @@ -126,28 +163,32 @@ class TestInline(TestDataMixin, TestCase): A model form with a form field specified (TitleForm.title1) should have its label rendered in the tabular inline. """ - response = self.client.get(reverse('admin:admin_inlines_titlecollection_add')) - self.assertContains(response, 'Title1', html=True) + response = self.client.get(reverse("admin:admin_inlines_titlecollection_add")) + self.assertContains( + response, 'Title1', html=True + ) def test_custom_form_tabular_inline_extra_field_label(self): - response = self.client.get(reverse('admin:admin_inlines_outfititem_add')) - _, extra_field = list(response.context['inline_admin_formset'].fields()) - self.assertEqual(extra_field['label'], 'Extra field') + response = self.client.get(reverse("admin:admin_inlines_outfititem_add")) + _, extra_field = list(response.context["inline_admin_formset"].fields()) + self.assertEqual(extra_field["label"], "Extra field") def test_non_editable_custom_form_tabular_inline_extra_field_label(self): - response = self.client.get(reverse('admin:admin_inlines_chapter_add')) - _, extra_field = list(response.context['inline_admin_formset'].fields()) - self.assertEqual(extra_field['label'], 'Extra field') + response = self.client.get(reverse("admin:admin_inlines_chapter_add")) + _, extra_field = list(response.context["inline_admin_formset"].fields()) + self.assertEqual(extra_field["label"], "Extra field") def test_custom_form_tabular_inline_overridden_label(self): """ SomeChildModelForm.__init__() overrides the label of a form field. That label is displayed in the TabularInline. 
""" - response = self.client.get(reverse('admin:admin_inlines_someparentmodel_add')) - field = list(response.context['inline_admin_formset'].fields())[0] - self.assertEqual(field['label'], 'new label') - self.assertContains(response, 'New label', html=True) + response = self.client.get(reverse("admin:admin_inlines_someparentmodel_add")) + field = list(response.context["inline_admin_formset"].fields())[0] + self.assertEqual(field["label"], "new label") + self.assertContains( + response, 'New label', html=True + ) def test_tabular_non_field_errors(self): """ @@ -155,68 +196,72 @@ class TestInline(TestDataMixin, TestCase): for colspan. """ data = { - 'title_set-TOTAL_FORMS': 1, - 'title_set-INITIAL_FORMS': 0, - 'title_set-MAX_NUM_FORMS': 0, - '_save': 'Save', - 'title_set-0-title1': 'a title', - 'title_set-0-title2': 'a different title', + "title_set-TOTAL_FORMS": 1, + "title_set-INITIAL_FORMS": 0, + "title_set-MAX_NUM_FORMS": 0, + "_save": "Save", + "title_set-0-title1": "a title", + "title_set-0-title2": "a different title", } - response = self.client.post(reverse('admin:admin_inlines_titlecollection_add'), data) + response = self.client.post( + reverse("admin:admin_inlines_titlecollection_add"), data + ) # Here colspan is "4": two fields (title1 and title2), one hidden field and the delete checkbox. self.assertContains( response, '
    ' - '
  • The two titles must be the same
' + "
  • The two titles must be the same
  • ", ) def test_no_parent_callable_lookup(self): """Admin inline `readonly_field` shouldn't invoke parent ModelAdmin callable""" # Identically named callable isn't present in the parent ModelAdmin, # rendering of the add view shouldn't explode - response = self.client.get(reverse('admin:admin_inlines_novel_add')) + response = self.client.get(reverse("admin:admin_inlines_novel_add")) # View should have the child inlines section self.assertContains( response, - '
    Callable in QuestionInline

    ') + self.assertContains(response, "

    Callable in QuestionInline

    ") def test_help_text(self): """ The inlines' model field help texts are displayed when using both the stacked and tabular layouts. """ - response = self.client.get(reverse('admin:admin_inlines_holder4_add')) - self.assertContains(response, '
    Awesome stacked help text is awesome.
    ', 4) + response = self.client.get(reverse("admin:admin_inlines_holder4_add")) + self.assertContains( + response, '
    Awesome stacked help text is awesome.
    ', 4 + ) self.assertContains( response, '', - 1 + 1, ) # ReadOnly fields - response = self.client.get(reverse('admin:admin_inlines_capofamiglia_add')) + response = self.client.get(reverse("admin:admin_inlines_capofamiglia_add")) self.assertContains( response, '', - 1 + 1, ) def test_tabular_model_form_meta_readonly_field(self): @@ -224,22 +269,24 @@ class TestInline(TestDataMixin, TestCase): Tabular inlines use ModelForm.Meta.help_texts and labels for read-only fields. """ - response = self.client.get(reverse('admin:admin_inlines_someparentmodel_add')) + response = self.client.get(reverse("admin:admin_inlines_someparentmodel_add")) self.assertContains( response, '' + 'title="Help text from ModelForm.Meta">', ) - self.assertContains(response, 'Label from ModelForm.Meta') + self.assertContains(response, "Label from ModelForm.Meta") def test_inline_hidden_field_no_column(self): """#18263 -- Make sure hidden fields don't get a column in tabular inlines""" - parent = SomeParentModel.objects.create(name='a') - SomeChildModel.objects.create(name='b', position='0', parent=parent) - SomeChildModel.objects.create(name='c', position='1', parent=parent) - response = self.client.get(reverse('admin:admin_inlines_someparentmodel_change', args=(parent.pk,))) + parent = SomeParentModel.objects.create(name="a") + SomeChildModel.objects.create(name="b", position="0", parent=parent) + SomeChildModel.objects.create(name="c", position="1", parent=parent) + response = self.client.get( + reverse("admin:admin_inlines_someparentmodel_change", args=(parent.pk,)) + ) self.assertNotContains(response, '') self.assertInHTML( 'Position', response.rendered_content) - self.assertInHTML('

    0

    ', response.rendered_content) - self.assertInHTML('

    1

    ', response.rendered_content) + self.assertInHTML( + 'Position', + response.rendered_content, + ) + self.assertInHTML( + '

    0

    ', response.rendered_content + ) + self.assertInHTML( + '

    1

    ', response.rendered_content + ) def test_stacked_inline_hidden_field_with_view_only_permissions(self): """ @@ -269,12 +323,14 @@ class TestInline(TestDataMixin, TestCase): """ self.client.force_login(self.view_only_user) url = reverse( - 'stacked_inline_hidden_field_in_group_admin:admin_inlines_someparentmodel_change', + "stacked_inline_hidden_field_in_group_admin:admin_inlines_someparentmodel_change", args=(self.parent.pk,), ) response = self.client.get(url) # The whole line containing name + position fields is not hidden. - self.assertContains(response, '
    ') + self.assertContains( + response, '
    ' + ) # The div containing the position field is hidden. self.assertInHTML( '