mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #31224 -- Added support for asynchronous views and middleware.
This implements support for asynchronous views, asynchronous tests, asynchronous middleware, and an asynchronous test client.
This commit is contained in:
committed by
Mariusz Felisiak
parent
3f7e4b16bf
commit
fc0fa72ff4
@@ -1,6 +1,8 @@
|
||||
"""Django Unit Test framework."""
|
||||
|
||||
from django.test.client import Client, RequestFactory
|
||||
from django.test.client import (
|
||||
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
|
||||
)
|
||||
from django.test.testcases import (
|
||||
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
|
||||
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||
@@ -11,8 +13,9 @@ from django.test.utils import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase',
|
||||
'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature',
|
||||
'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings',
|
||||
'modify_settings', 'override_settings', 'override_system_checks', 'tag',
|
||||
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
|
||||
'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
|
||||
'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
|
||||
'ignore_warnings', 'modify_settings', 'override_settings',
|
||||
'override_system_checks', 'tag',
|
||||
]
|
||||
|
||||
@@ -9,7 +9,10 @@ from importlib import import_module
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
@@ -157,6 +160,52 @@ class ClientHandler(BaseHandler):
|
||||
return response
|
||||
|
||||
|
||||
class AsyncClientHandler(BaseHandler):
|
||||
"""An async version of ClientHandler."""
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def __call__(self, scope):
|
||||
# Set up middleware if needed. We couldn't do this earlier, because
|
||||
# settings weren't available.
|
||||
if self._middleware_chain is None:
|
||||
self.load_middleware(is_async=True)
|
||||
# Extract body file from the scope, if provided.
|
||||
if '_body_file' in scope:
|
||||
body_file = scope.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload('')
|
||||
|
||||
request_started.disconnect(close_old_connections)
|
||||
await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
|
||||
request_started.connect(close_old_connections)
|
||||
request = ASGIRequest(scope, body_file)
|
||||
# Sneaky little hack so that we can easily get round
|
||||
# CsrfViewMiddleware. This makes life easier, and is probably required
|
||||
# for backwards compatibility with external tests against admin views.
|
||||
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
||||
# Request goes through middleware.
|
||||
response = await self.get_response_async(request)
|
||||
# Simulate behaviors of most Web servers.
|
||||
conditional_content_removal(request, response)
|
||||
# Attach the originating ASGI request to the response so that it could
|
||||
# be later retrieved.
|
||||
response.asgi_request = request
|
||||
# Emulate a server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
|
||||
response.streaming_content,
|
||||
response.close,
|
||||
)
|
||||
else:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
# Will fire request_finished.
|
||||
await sync_to_async(response.close)()
|
||||
request_finished.connect(close_old_connections)
|
||||
return response
|
||||
|
||||
|
||||
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
|
||||
"""
|
||||
Store templates and contexts that are rendered.
|
||||
@@ -421,7 +470,194 @@ class RequestFactory:
|
||||
return self.request(**r)
|
||||
|
||||
|
||||
class Client(RequestFactory):
|
||||
class AsyncRequestFactory(RequestFactory):
|
||||
"""
|
||||
Class that lets you create mock ASGI-like Request objects for use in
|
||||
testing. Usage:
|
||||
|
||||
rf = AsyncRequestFactory()
|
||||
get_request = await rf.get('/hello/')
|
||||
post_request = await rf.post('/submit/', {'foo': 'bar'})
|
||||
|
||||
Once you have a request object you can pass it to any view function,
|
||||
including synchronous ones. The reason we have a separate class here is:
|
||||
a) this makes ASGIRequest subclasses, and
|
||||
b) AsyncTestClient can subclass it.
|
||||
"""
|
||||
def _base_scope(self, **request):
|
||||
"""The base scope for a request."""
|
||||
# This is a minimal valid ASGI scope, plus:
|
||||
# - headers['cookie'] for cookie support,
|
||||
# - 'client' often useful, see #8551.
|
||||
scope = {
|
||||
'asgi': {'version': '3.0'},
|
||||
'type': 'http',
|
||||
'http_version': '1.1',
|
||||
'client': ['127.0.0.1', 0],
|
||||
'server': ('testserver', '80'),
|
||||
'scheme': 'http',
|
||||
'method': 'GET',
|
||||
'headers': [],
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
scope['headers'].append((
|
||||
b'cookie',
|
||||
b'; '.join(sorted(
|
||||
('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
))
|
||||
return scope
|
||||
|
||||
def request(self, **request):
|
||||
"""Construct a generic request object."""
|
||||
# This is synchronous, which means all methods on this class are.
|
||||
# AsyncClient, however, has an async request function, which makes all
|
||||
# its methods async.
|
||||
if '_body_file' in request:
|
||||
body_file = request.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload('')
|
||||
return ASGIRequest(self._base_scope(**request), body_file)
|
||||
|
||||
def generic(
|
||||
self, method, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra,
|
||||
):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy.
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
s = {
|
||||
'method': method,
|
||||
'path': self._get_path(parsed),
|
||||
'server': ('127.0.0.1', '443' if secure else '80'),
|
||||
'scheme': 'https' if secure else 'http',
|
||||
'headers': [(b'host', b'testserver')],
|
||||
}
|
||||
if data:
|
||||
s['headers'].extend([
|
||||
(b'content-length', bytes(len(data))),
|
||||
(b'content-type', content_type.encode('ascii')),
|
||||
])
|
||||
s['_body_file'] = FakePayload(data)
|
||||
s.update(extra)
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the
|
||||
# URL.
|
||||
if not s.get('query_string'):
|
||||
s['query_string'] = parsed[4]
|
||||
return self.request(**s)
|
||||
|
||||
|
||||
class ClientMixin:
|
||||
"""
|
||||
Mixin with common methods between Client and AsyncClient.
|
||||
"""
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
self.exc_info = sys.exc_info()
|
||||
|
||||
def check_exception(self, response):
|
||||
"""
|
||||
Look for a signaled exception, clear the current context exception
|
||||
data, re-raise the signaled exception, and clear the signaled exception
|
||||
from the local cache.
|
||||
"""
|
||||
response.exc_info = self.exc_info
|
||||
if self.exc_info:
|
||||
_, exc_value, _ = self.exc_info
|
||||
self.exc_info = None
|
||||
if self.raise_request_exception:
|
||||
raise exc_value
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Return the current session variables."""
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if cookie:
|
||||
return engine.SessionStore(cookie.value)
|
||||
session = engine.SessionStore()
|
||||
session.save()
|
||||
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
return session
|
||||
|
||||
def login(self, **credentials):
|
||||
"""
|
||||
Set the Factory to appear as if it has successfully logged into a site.
|
||||
|
||||
Return True if login is possible or False if the provided credentials
|
||||
are incorrect.
|
||||
"""
|
||||
from django.contrib.auth import authenticate
|
||||
user = authenticate(**credentials)
|
||||
if user:
|
||||
self._login(user)
|
||||
return True
|
||||
return False
|
||||
|
||||
def force_login(self, user, backend=None):
|
||||
def get_backend():
|
||||
from django.contrib.auth import load_backend
|
||||
for backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
if hasattr(backend, 'get_user'):
|
||||
return backend_path
|
||||
|
||||
if backend is None:
|
||||
backend = get_backend()
|
||||
user.backend = backend
|
||||
self._login(user, backend)
|
||||
|
||||
def _login(self, user, backend=None):
|
||||
from django.contrib.auth import login
|
||||
# Create a fake request to store login details.
|
||||
request = HttpRequest()
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
else:
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
request.session = engine.SessionStore()
|
||||
login(request, user, backend)
|
||||
# Save the session values.
|
||||
request.session.save()
|
||||
# Set the cookie to represent the session.
|
||||
session_cookie = settings.SESSION_COOKIE_NAME
|
||||
self.cookies[session_cookie] = request.session.session_key
|
||||
cookie_data = {
|
||||
'max-age': None,
|
||||
'path': '/',
|
||||
'domain': settings.SESSION_COOKIE_DOMAIN,
|
||||
'secure': settings.SESSION_COOKIE_SECURE or None,
|
||||
'expires': None,
|
||||
}
|
||||
self.cookies[session_cookie].update(cookie_data)
|
||||
|
||||
def logout(self):
|
||||
"""Log out the user by removing the cookies and session object."""
|
||||
from django.contrib.auth import get_user, logout
|
||||
request = HttpRequest()
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
request.user = get_user(request)
|
||||
else:
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
request.session = engine.SessionStore()
|
||||
logout(request)
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _parse_json(self, response, **extra):
|
||||
if not hasattr(response, '_json'):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
||||
raise ValueError(
|
||||
'Content-Type header is "%s", not "application/json"'
|
||||
% response.get('Content-Type')
|
||||
)
|
||||
response._json = json.loads(response.content.decode(response.charset), **extra)
|
||||
return response._json
|
||||
|
||||
|
||||
class Client(ClientMixin, RequestFactory):
|
||||
"""
|
||||
A class that can act as a client for testing purposes.
|
||||
|
||||
@@ -446,23 +682,6 @@ class Client(RequestFactory):
|
||||
self.exc_info = None
|
||||
self.extra = None
|
||||
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
self.exc_info = sys.exc_info()
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Return the current session variables."""
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if cookie:
|
||||
return engine.SessionStore(cookie.value)
|
||||
|
||||
session = engine.SessionStore()
|
||||
session.save()
|
||||
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
return session
|
||||
|
||||
def request(self, **request):
|
||||
"""
|
||||
The master request method. Compose the environment dictionary and pass
|
||||
@@ -486,15 +705,8 @@ class Client(RequestFactory):
|
||||
finally:
|
||||
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
||||
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
||||
# Look for a signaled exception, clear the current context exception
|
||||
# data, then re-raise the signaled exception. Also clear the signaled
|
||||
# exception from the local cache.
|
||||
response.exc_info = self.exc_info
|
||||
if self.exc_info:
|
||||
_, exc_value, _ = self.exc_info
|
||||
self.exc_info = None
|
||||
if self.raise_request_exception:
|
||||
raise exc_value
|
||||
# Check for signaled exceptions.
|
||||
self.check_exception(response)
|
||||
# Save the client and request that stimulated the response.
|
||||
response.client = self
|
||||
response.request = request
|
||||
@@ -583,85 +795,6 @@ class Client(RequestFactory):
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def login(self, **credentials):
|
||||
"""
|
||||
Set the Factory to appear as if it has successfully logged into a site.
|
||||
|
||||
Return True if login is possible; False if the provided credentials
|
||||
are incorrect.
|
||||
"""
|
||||
from django.contrib.auth import authenticate
|
||||
user = authenticate(**credentials)
|
||||
if user:
|
||||
self._login(user)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def force_login(self, user, backend=None):
|
||||
def get_backend():
|
||||
from django.contrib.auth import load_backend
|
||||
for backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
if hasattr(backend, 'get_user'):
|
||||
return backend_path
|
||||
if backend is None:
|
||||
backend = get_backend()
|
||||
user.backend = backend
|
||||
self._login(user, backend)
|
||||
|
||||
def _login(self, user, backend=None):
|
||||
from django.contrib.auth import login
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
|
||||
# Create a fake request to store login details.
|
||||
request = HttpRequest()
|
||||
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
else:
|
||||
request.session = engine.SessionStore()
|
||||
login(request, user, backend)
|
||||
|
||||
# Save the session values.
|
||||
request.session.save()
|
||||
|
||||
# Set the cookie to represent the session.
|
||||
session_cookie = settings.SESSION_COOKIE_NAME
|
||||
self.cookies[session_cookie] = request.session.session_key
|
||||
cookie_data = {
|
||||
'max-age': None,
|
||||
'path': '/',
|
||||
'domain': settings.SESSION_COOKIE_DOMAIN,
|
||||
'secure': settings.SESSION_COOKIE_SECURE or None,
|
||||
'expires': None,
|
||||
}
|
||||
self.cookies[session_cookie].update(cookie_data)
|
||||
|
||||
def logout(self):
|
||||
"""Log out the user by removing the cookies and session object."""
|
||||
from django.contrib.auth import get_user, logout
|
||||
|
||||
request = HttpRequest()
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
request.user = get_user(request)
|
||||
else:
|
||||
request.session = engine.SessionStore()
|
||||
logout(request)
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _parse_json(self, response, **extra):
|
||||
if not hasattr(response, '_json'):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
||||
raise ValueError(
|
||||
'Content-Type header is "{}", not "application/json"'
|
||||
.format(response.get('Content-Type'))
|
||||
)
|
||||
response._json = json.loads(response.content.decode(response.charset), **extra)
|
||||
return response._json
|
||||
|
||||
def _handle_redirects(self, response, data='', content_type='', **extra):
|
||||
"""
|
||||
Follow any redirects by requesting responses from the server using GET.
|
||||
@@ -714,3 +847,66 @@ class Client(RequestFactory):
|
||||
raise RedirectCycleError("Too many redirects.", last_response=response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
"""
|
||||
An async version of Client that creates ASGIRequests and calls through an
|
||||
async request path.
|
||||
|
||||
Does not currently support "follow" on its methods.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = AsyncClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
self.extra = None
|
||||
|
||||
async def request(self, **request):
|
||||
"""
|
||||
The master request method. Compose the scope dictionary and pass to the
|
||||
handler, return the result of the handler. Assume defaults for the
|
||||
query environment, which can be overridden using the arguments to the
|
||||
request.
|
||||
"""
|
||||
if 'follow' in request:
|
||||
raise NotImplementedError(
|
||||
'AsyncClient request methods do not accept the follow '
|
||||
'parameter.'
|
||||
)
|
||||
scope = self._base_scope(**request)
|
||||
# Curry a data dictionary into an instance of the template renderer
|
||||
# callback function.
|
||||
data = {}
|
||||
on_template_render = partial(store_rendered_templates, data)
|
||||
signal_uid = 'template-render-%s' % id(request)
|
||||
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
||||
# Capture exceptions created by the handler.
|
||||
exception_uid = 'request-exception-%s' % id(request)
|
||||
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
||||
try:
|
||||
response = await self.handler(scope)
|
||||
finally:
|
||||
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
||||
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
||||
# Check for signaled exceptions.
|
||||
self.check_exception(response)
|
||||
# Save the client and request that stimulated the response.
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
if response.context and len(response.context) == 1:
|
||||
response.context = response.context[0]
|
||||
# Update persistent cookie data.
|
||||
if response.cookies:
|
||||
self.cookies.update(response.cookies)
|
||||
return response
|
||||
|
||||
@@ -33,7 +33,7 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
||||
from django.forms.fields import CharField
|
||||
from django.http import QueryDict
|
||||
from django.http.request import split_domain_port, validate_host
|
||||
from django.test.client import Client
|
||||
from django.test.client import AsyncClient, Client
|
||||
from django.test.html import HTMLParseError, parse_html
|
||||
from django.test.signals import setting_changed, template_rendered
|
||||
from django.test.utils import (
|
||||
@@ -151,6 +151,7 @@ class SimpleTestCase(unittest.TestCase):
|
||||
# The class we'll use for the test client self.client.
|
||||
# Can be overridden in derived classes.
|
||||
client_class = Client
|
||||
async_client_class = AsyncClient
|
||||
_overridden_settings = None
|
||||
_modified_settings = None
|
||||
|
||||
@@ -292,6 +293,7 @@ class SimpleTestCase(unittest.TestCase):
|
||||
* Clear the mail test outbox.
|
||||
"""
|
||||
self.client = self.client_class()
|
||||
self.async_client = self.async_client_class()
|
||||
mail.outbox = []
|
||||
|
||||
def _post_teardown(self):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
@@ -362,12 +363,22 @@ class TestContextDecorator:
|
||||
raise TypeError('Can only decorate subclasses of unittest.TestCase')
|
||||
|
||||
def decorate_callable(self, func):
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return func(*args, **kwargs)
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# If the inner function is an async function, we must execute async
|
||||
# as well so that the `with` statement executes at the right time.
|
||||
@wraps(func)
|
||||
async def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return func(*args, **kwargs)
|
||||
return inner
|
||||
|
||||
def __call__(self, decorated):
|
||||
|
||||
Reference in New Issue
Block a user