1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #31224 -- Added support for asynchronous views and middleware.

This implements support for asynchronous views, asynchronous tests,
asynchronous middleware, and an asynchronous test client.
This commit is contained in:
Andrew Godwin
2020-02-12 15:15:00 -07:00
committed by Mariusz Felisiak
parent 3f7e4b16bf
commit fc0fa72ff4
30 changed files with 1344 additions and 214 deletions

View File

@@ -1,6 +1,8 @@
"""Django Unit Test framework."""
from django.test.client import Client, RequestFactory
from django.test.client import (
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
)
from django.test.testcases import (
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
@@ -11,8 +13,9 @@ from django.test.utils import (
)
__all__ = [
'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase',
'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature',
'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings',
'modify_settings', 'override_settings', 'override_system_checks', 'tag',
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
'ignore_warnings', 'modify_settings', 'override_settings',
'override_system_checks', 'tag',
]

View File

@@ -9,7 +9,10 @@ from importlib import import_module
from io import BytesIO
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
from asgiref.sync import sync_to_async
from django.conf import settings
from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
@@ -157,6 +160,52 @@ class ClientHandler(BaseHandler):
return response
class AsyncClientHandler(BaseHandler):
"""An async version of ClientHandler."""
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
self.enforce_csrf_checks = enforce_csrf_checks
super().__init__(*args, **kwargs)
async def __call__(self, scope):
# Set up middleware if needed. We couldn't do this earlier, because
# settings weren't available.
if self._middleware_chain is None:
self.load_middleware(is_async=True)
# Extract body file from the scope, if provided.
if '_body_file' in scope:
body_file = scope.pop('_body_file')
else:
body_file = FakePayload('')
request_started.disconnect(close_old_connections)
await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
request_started.connect(close_old_connections)
request = ASGIRequest(scope, body_file)
# Sneaky little hack so that we can easily get round
# CsrfViewMiddleware. This makes life easier, and is probably required
# for backwards compatibility with external tests against admin views.
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
# Request goes through middleware.
response = await self.get_response_async(request)
# Simulate behaviors of most Web servers.
conditional_content_removal(request, response)
# Attach the originating ASGI request to the response so that it could
# be later retrieved.
response.asgi_request = request
# Emulate a server by calling the close method on completion.
if response.streaming:
response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
response.streaming_content,
response.close,
)
else:
request_finished.disconnect(close_old_connections)
# Will fire request_finished.
await sync_to_async(response.close)()
request_finished.connect(close_old_connections)
return response
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
"""
Store templates and contexts that are rendered.
@@ -421,7 +470,194 @@ class RequestFactory:
return self.request(**r)
class Client(RequestFactory):
class AsyncRequestFactory(RequestFactory):
"""
Class that lets you create mock ASGI-like Request objects for use in
testing. Usage:
rf = AsyncRequestFactory()
get_request = await rf.get('/hello/')
post_request = await rf.post('/submit/', {'foo': 'bar'})
Once you have a request object you can pass it to any view function,
including synchronous ones. The reason we have a separate class here is:
a) this makes ASGIRequest subclasses, and
b) AsyncTestClient can subclass it.
"""
def _base_scope(self, **request):
"""The base scope for a request."""
# This is a minimal valid ASGI scope, plus:
# - headers['cookie'] for cookie support,
# - 'client' often useful, see #8551.
scope = {
'asgi': {'version': '3.0'},
'type': 'http',
'http_version': '1.1',
'client': ['127.0.0.1', 0],
'server': ('testserver', '80'),
'scheme': 'http',
'method': 'GET',
'headers': [],
**self.defaults,
**request,
}
scope['headers'].append((
b'cookie',
b'; '.join(sorted(
('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
for morsel in self.cookies.values()
)),
))
return scope
def request(self, **request):
"""Construct a generic request object."""
# This is synchronous, which means all methods on this class are.
# AsyncClient, however, has an async request function, which makes all
# its methods async.
if '_body_file' in request:
body_file = request.pop('_body_file')
else:
body_file = FakePayload('')
return ASGIRequest(self._base_scope(**request), body_file)
def generic(
self, method, path, data='', content_type='application/octet-stream',
secure=False, **extra,
):
"""Construct an arbitrary HTTP request."""
parsed = urlparse(str(path)) # path can be lazy.
data = force_bytes(data, settings.DEFAULT_CHARSET)
s = {
'method': method,
'path': self._get_path(parsed),
'server': ('127.0.0.1', '443' if secure else '80'),
'scheme': 'https' if secure else 'http',
'headers': [(b'host', b'testserver')],
}
if data:
s['headers'].extend([
(b'content-length', bytes(len(data))),
(b'content-type', content_type.encode('ascii')),
])
s['_body_file'] = FakePayload(data)
s.update(extra)
# If QUERY_STRING is absent or empty, we want to extract it from the
# URL.
if not s.get('query_string'):
s['query_string'] = parsed[4]
return self.request(**s)
class ClientMixin:
"""
Mixin with common methods between Client and AsyncClient.
"""
def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view."""
self.exc_info = sys.exc_info()
def check_exception(self, response):
"""
Look for a signaled exception, clear the current context exception
data, re-raise the signaled exception, and clear the signaled exception
from the local cache.
"""
response.exc_info = self.exc_info
if self.exc_info:
_, exc_value, _ = self.exc_info
self.exc_info = None
if self.raise_request_exception:
raise exc_value
@property
def session(self):
"""Return the current session variables."""
engine = import_module(settings.SESSION_ENGINE)
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
if cookie:
return engine.SessionStore(cookie.value)
session = engine.SessionStore()
session.save()
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
return session
def login(self, **credentials):
"""
Set the Factory to appear as if it has successfully logged into a site.
Return True if login is possible or False if the provided credentials
are incorrect.
"""
from django.contrib.auth import authenticate
user = authenticate(**credentials)
if user:
self._login(user)
return True
return False
def force_login(self, user, backend=None):
def get_backend():
from django.contrib.auth import load_backend
for backend_path in settings.AUTHENTICATION_BACKENDS:
backend = load_backend(backend_path)
if hasattr(backend, 'get_user'):
return backend_path
if backend is None:
backend = get_backend()
user.backend = backend
self._login(user, backend)
def _login(self, user, backend=None):
from django.contrib.auth import login
# Create a fake request to store login details.
request = HttpRequest()
if self.session:
request.session = self.session
else:
engine = import_module(settings.SESSION_ENGINE)
request.session = engine.SessionStore()
login(request, user, backend)
# Save the session values.
request.session.save()
# Set the cookie to represent the session.
session_cookie = settings.SESSION_COOKIE_NAME
self.cookies[session_cookie] = request.session.session_key
cookie_data = {
'max-age': None,
'path': '/',
'domain': settings.SESSION_COOKIE_DOMAIN,
'secure': settings.SESSION_COOKIE_SECURE or None,
'expires': None,
}
self.cookies[session_cookie].update(cookie_data)
def logout(self):
"""Log out the user by removing the cookies and session object."""
from django.contrib.auth import get_user, logout
request = HttpRequest()
if self.session:
request.session = self.session
request.user = get_user(request)
else:
engine = import_module(settings.SESSION_ENGINE)
request.session = engine.SessionStore()
logout(request)
self.cookies = SimpleCookie()
def _parse_json(self, response, **extra):
if not hasattr(response, '_json'):
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
raise ValueError(
'Content-Type header is "%s", not "application/json"'
% response.get('Content-Type')
)
response._json = json.loads(response.content.decode(response.charset), **extra)
return response._json
class Client(ClientMixin, RequestFactory):
"""
A class that can act as a client for testing purposes.
@@ -446,23 +682,6 @@ class Client(RequestFactory):
self.exc_info = None
self.extra = None
def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view."""
self.exc_info = sys.exc_info()
@property
def session(self):
"""Return the current session variables."""
engine = import_module(settings.SESSION_ENGINE)
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
if cookie:
return engine.SessionStore(cookie.value)
session = engine.SessionStore()
session.save()
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
return session
def request(self, **request):
"""
The master request method. Compose the environment dictionary and pass
@@ -486,15 +705,8 @@ class Client(RequestFactory):
finally:
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
got_request_exception.disconnect(dispatch_uid=exception_uid)
# Look for a signaled exception, clear the current context exception
# data, then re-raise the signaled exception. Also clear the signaled
# exception from the local cache.
response.exc_info = self.exc_info
if self.exc_info:
_, exc_value, _ = self.exc_info
self.exc_info = None
if self.raise_request_exception:
raise exc_value
# Check for signaled exceptions.
self.check_exception(response)
# Save the client and request that stimulated the response.
response.client = self
response.request = request
@@ -583,85 +795,6 @@ class Client(RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
def login(self, **credentials):
"""
Set the Factory to appear as if it has successfully logged into a site.
Return True if login is possible; False if the provided credentials
are incorrect.
"""
from django.contrib.auth import authenticate
user = authenticate(**credentials)
if user:
self._login(user)
return True
else:
return False
def force_login(self, user, backend=None):
def get_backend():
from django.contrib.auth import load_backend
for backend_path in settings.AUTHENTICATION_BACKENDS:
backend = load_backend(backend_path)
if hasattr(backend, 'get_user'):
return backend_path
if backend is None:
backend = get_backend()
user.backend = backend
self._login(user, backend)
def _login(self, user, backend=None):
from django.contrib.auth import login
engine = import_module(settings.SESSION_ENGINE)
# Create a fake request to store login details.
request = HttpRequest()
if self.session:
request.session = self.session
else:
request.session = engine.SessionStore()
login(request, user, backend)
# Save the session values.
request.session.save()
# Set the cookie to represent the session.
session_cookie = settings.SESSION_COOKIE_NAME
self.cookies[session_cookie] = request.session.session_key
cookie_data = {
'max-age': None,
'path': '/',
'domain': settings.SESSION_COOKIE_DOMAIN,
'secure': settings.SESSION_COOKIE_SECURE or None,
'expires': None,
}
self.cookies[session_cookie].update(cookie_data)
def logout(self):
"""Log out the user by removing the cookies and session object."""
from django.contrib.auth import get_user, logout
request = HttpRequest()
engine = import_module(settings.SESSION_ENGINE)
if self.session:
request.session = self.session
request.user = get_user(request)
else:
request.session = engine.SessionStore()
logout(request)
self.cookies = SimpleCookie()
def _parse_json(self, response, **extra):
if not hasattr(response, '_json'):
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
raise ValueError(
'Content-Type header is "{}", not "application/json"'
.format(response.get('Content-Type'))
)
response._json = json.loads(response.content.decode(response.charset), **extra)
return response._json
def _handle_redirects(self, response, data='', content_type='', **extra):
"""
Follow any redirects by requesting responses from the server using GET.
@@ -714,3 +847,66 @@ class Client(RequestFactory):
raise RedirectCycleError("Too many redirects.", last_response=response)
return response
class AsyncClient(ClientMixin, AsyncRequestFactory):
"""
An async version of Client that creates ASGIRequests and calls through an
async request path.
Does not currently support "follow" on its methods.
"""
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
super().__init__(**defaults)
self.handler = AsyncClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
self.exc_info = None
self.extra = None
async def request(self, **request):
"""
The master request method. Compose the scope dictionary and pass to the
handler, return the result of the handler. Assume defaults for the
query environment, which can be overridden using the arguments to the
request.
"""
if 'follow' in request:
raise NotImplementedError(
'AsyncClient request methods do not accept the follow '
'parameter.'
)
scope = self._base_scope(**request)
# Curry a data dictionary into an instance of the template renderer
# callback function.
data = {}
on_template_render = partial(store_rendered_templates, data)
signal_uid = 'template-render-%s' % id(request)
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
# Capture exceptions created by the handler.
exception_uid = 'request-exception-%s' % id(request)
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
try:
response = await self.handler(scope)
finally:
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
got_request_exception.disconnect(dispatch_uid=exception_uid)
# Check for signaled exceptions.
self.check_exception(response)
# Save the client and request that stimulated the response.
response.client = self
response.request = request
# Add any rendered template detail to the response.
response.templates = data.get('templates', [])
response.context = data.get('context')
response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response.
response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
# Flatten a single context. Not really necessary anymore thanks to the
# __getattr__ flattening in ContextList, but has some edge case
# backwards compatibility implications.
if response.context and len(response.context) == 1:
response.context = response.context[0]
# Update persistent cookie data.
if response.cookies:
self.cookies.update(response.cookies)
return response

View File

@@ -33,7 +33,7 @@ from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.forms.fields import CharField
from django.http import QueryDict
from django.http.request import split_domain_port, validate_host
from django.test.client import Client
from django.test.client import AsyncClient, Client
from django.test.html import HTMLParseError, parse_html
from django.test.signals import setting_changed, template_rendered
from django.test.utils import (
@@ -151,6 +151,7 @@ class SimpleTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
client_class = Client
async_client_class = AsyncClient
_overridden_settings = None
_modified_settings = None
@@ -292,6 +293,7 @@ class SimpleTestCase(unittest.TestCase):
* Clear the mail test outbox.
"""
self.client = self.client_class()
self.async_client = self.async_client_class()
mail.outbox = []
def _post_teardown(self):

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import re
import sys
@@ -362,12 +363,22 @@ class TestContextDecorator:
raise TypeError('Can only decorate subclasses of unittest.TestCase')
def decorate_callable(self, func):
@wraps(func)
def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return func(*args, **kwargs)
if asyncio.iscoroutinefunction(func):
# If the inner function is an async function, we must execute async
# as well so that the `with` statement executes at the right time.
@wraps(func)
async def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return await func(*args, **kwargs)
else:
@wraps(func)
def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return func(*args, **kwargs)
return inner
def __call__(self, decorated):