From c4eaa67e2b880db778c9fe6d9854fbdfcc16ecd2 Mon Sep 17 00:00:00 2001 From: Scott Halgrim Date: Tue, 8 Nov 2022 12:19:59 +0100 Subject: [PATCH] Fixed #34063 -- Fixed reading request body with async request factory and client. Co-authored-by: Kevan Swanberg Co-authored-by: Carlton Gibson --- django/test/client.py | 10 +++++++--- tests/test_client/tests.py | 18 ++++++++++++++++++ tests/test_client/views.py | 2 ++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/django/test/client.py b/django/test/client.py index 99e831aebd..8b926fc38d 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -14,7 +14,7 @@ from asgiref.sync import sync_to_async from django.conf import settings from django.core.handlers.asgi import ASGIRequest from django.core.handlers.base import BaseHandler -from django.core.handlers.wsgi import WSGIRequest +from django.core.handlers.wsgi import LimitedStream, WSGIRequest from django.core.serializers.json import DjangoJSONEncoder from django.core.signals import got_request_exception, request_finished, request_started from django.db import close_old_connections @@ -198,7 +198,8 @@ class AsyncClientHandler(BaseHandler): sender=self.__class__, scope=scope ) request_started.connect(close_old_connections) - request = ASGIRequest(scope, body_file) + # Wrap FakePayload body_file to allow large read() in test environment. + request = ASGIRequest(scope, LimitedStream(body_file, len(body_file))) # Sneaky little hack so that we can easily get round # CsrfViewMiddleware. This makes life easier, and is probably required # for backwards compatibility with external tests against admin views. @@ -598,7 +599,10 @@ class AsyncRequestFactory(RequestFactory): body_file = request.pop("_body_file") else: body_file = FakePayload("") - return ASGIRequest(self._base_scope(**request), body_file) + # Wrap FakePayload body_file to allow large read() in test environment. + return ASGIRequest( + self._base_scope(**request), LimitedStream(body_file, len(body_file)) + ) def generic( self, diff --git a/tests/test_client/tests.py b/tests/test_client/tests.py index 57dc22ea0c..5612ae4462 100644 --- a/tests/test_client/tests.py +++ b/tests/test_client/tests.py @@ -1103,6 +1103,14 @@ class AsyncClientTest(TestCase): response = await self.async_client.get("/get_view/", {"var": "val"}) self.assertContains(response, "This is a test. val is the value.") + async def test_post_data(self): + response = await self.async_client.post("/post_view/", {"value": 37}) + self.assertContains(response, "Data received: 37 is the value.") + + async def test_body_read_on_get_data(self): + response = await self.async_client.get("/post_view/") + self.assertContains(response, "Viewing GET page.") + @override_settings(ROOT_URLCONF="test_client.urls") class AsyncRequestFactoryTest(SimpleTestCase): @@ -1147,6 +1155,16 @@ class AsyncRequestFactoryTest(SimpleTestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.content, b'{"example": "data"}') + async def test_request_limited_read(self): + tests = ["GET", "POST"] + for method in tests: + with self.subTest(method=method): + request = self.request_factory.generic( + method, + "/somewhere", + ) + self.assertEqual(request.read(200), b"") + def test_request_factory_sets_headers(self): request = self.request_factory.get( "/somewhere/", diff --git a/tests/test_client/views.py b/tests/test_client/views.py index 773e9e4e98..01850257b5 100644 --- a/tests/test_client/views.py +++ b/tests/test_client/views.py @@ -90,6 +90,8 @@ def post_view(request): c = Context() else: t = Template("Viewing GET page.", name="Empty GET Template") + # Used by test_body_read_on_get_data. + request.read(200) c = Context() return HttpResponse(t.render(c))