From 52b054824e899db40ba48f908a9a00dadc56cb89 Mon Sep 17 00:00:00 2001
From: Alexandre Spaeth <Alexerson@users.noreply.github.com>
Date: Wed, 15 Feb 2023 15:16:51 -0800
Subject: [PATCH] Fixed #34342, Refs #33735 -- Fixed test client handling of
 async streaming responses.

Bug in 0bd2c0c9015b53c41394a1c0989afbfd94dc2830.

Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
---
 django/test/client.py   | 35 ++++++++++++++++++++++++++---------
 tests/handlers/tests.py | 17 +++++++++++++++++
 tests/handlers/urls.py  |  1 +
 tests/handlers/views.py |  9 +++++++++
 4 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/django/test/client.py b/django/test/client.py
index c699eb9264..cf63265faa 100644
--- a/django/test/client.py
+++ b/django/test/client.py
@@ -116,6 +116,16 @@ def closing_iterator_wrapper(iterable, close):
         request_finished.connect(close_old_connections)
 
 
+async def aclosing_iterator_wrapper(iterable, close):
+    try:
+        async for chunk in iterable:
+            yield chunk
+    finally:
+        request_finished.disconnect(close_old_connections)
+        close()  # will fire request_finished
+        request_finished.connect(close_old_connections)
+
+
 def conditional_content_removal(request, response):
     """
     Simulate the behavior of most web servers by removing the content of
@@ -174,9 +184,14 @@ class ClientHandler(BaseHandler):
 
         # Emulate a WSGI server by calling the close method on completion.
         if response.streaming:
-            response.streaming_content = closing_iterator_wrapper(
-                response.streaming_content, response.close
-            )
+            if response.is_async:
+                response.streaming_content = aclosing_iterator_wrapper(
+                    response.streaming_content, response.close
+                )
+            else:
+                response.streaming_content = closing_iterator_wrapper(
+                    response.streaming_content, response.close
+                )
         else:
             request_finished.disconnect(close_old_connections)
             response.close()  # will fire request_finished
@@ -223,12 +238,14 @@ class AsyncClientHandler(BaseHandler):
         response.asgi_request = request
         # Emulate a server by calling the close method on completion.
         if response.streaming:
-            response.streaming_content = await sync_to_async(
-                closing_iterator_wrapper, thread_sensitive=False
-            )(
-                response.streaming_content,
-                response.close,
-            )
+            if response.is_async:
+                response.streaming_content = aclosing_iterator_wrapper(
+                    response.streaming_content, response.close
+                )
+            else:
+                response.streaming_content = closing_iterator_wrapper(
+                    response.streaming_content, response.close
+                )
         else:
             request_finished.disconnect(close_old_connections)
             # Will fire request_finished.
diff --git a/tests/handlers/tests.py b/tests/handlers/tests.py
index 0df481c2fc..0348b8e5d6 100644
--- a/tests/handlers/tests.py
+++ b/tests/handlers/tests.py
@@ -253,6 +253,16 @@ class HandlerRequestTests(SimpleTestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(b"".join(list(response)), b"streaming content")
 
+    def test_async_streaming(self):
+        response = self.client.get("/async_streaming/")
+        self.assertEqual(response.status_code, 200)
+        msg = (
+            "StreamingHttpResponse must consume asynchronous iterators in order to "
+            "serve them synchronously. Use a synchronous iterator instead."
+        )
+        with self.assertWarnsMessage(Warning, msg):
+            self.assertEqual(b"".join(list(response)), b"streaming content")
+
 
 class ScriptNameTests(SimpleTestCase):
     def test_get_script_name(self):
@@ -329,3 +339,10 @@ class AsyncHandlerRequestTests(SimpleTestCase):
             self.assertEqual(
                 b"".join([chunk async for chunk in response]), b"streaming content"
             )
+
+    async def test_async_streaming(self):
+        response = await self.async_client.get("/async_streaming/")
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(
+            b"".join([chunk async for chunk in response]), b"streaming content"
+        )
diff --git a/tests/handlers/urls.py b/tests/handlers/urls.py
index 73d99c7edf..a0efece602 100644
--- a/tests/handlers/urls.py
+++ b/tests/handlers/urls.py
@@ -8,6 +8,7 @@ urlpatterns = [
     path("no_response_fbv/", views.no_response),
     path("no_response_cbv/", views.NoResponse()),
     path("streaming/", views.streaming),
+    path("async_streaming/", views.async_streaming),
     path("in_transaction/", views.in_transaction),
     path("not_in_transaction/", views.not_in_transaction),
     path("not_in_transaction_using_none/", views.not_in_transaction_using_none),
diff --git a/tests/handlers/views.py b/tests/handlers/views.py
index 351eb65264..95d663323d 100644
--- a/tests/handlers/views.py
+++ b/tests/handlers/views.py
@@ -65,6 +65,15 @@ async def async_regular(request):
     return HttpResponse(b"regular content")
 
 
+async def async_streaming(request):
+    async def async_streaming_generator():
+        yield b"streaming"
+        yield b" "
+        yield b"content"
+
+    return StreamingHttpResponse(async_streaming_generator())
+
+
 class CoroutineClearingView:
     def __call__(self, request):
         """Return an unawaited coroutine (common error for async views)."""