diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py index 89306683b8..8a36b67eef 100644 --- a/django/core/servers/basehttp.py +++ b/django/core/servers/basehttp.py @@ -74,12 +74,24 @@ class WSGIServer(simple_server.WSGIServer): class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer): """A threaded version of the WSGIServer""" - pass + daemon_threads = True class ServerHandler(simple_server.ServerHandler): http_version = '1.1' + def cleanup_headers(self): + super().cleanup_headers() + # HTTP/1.1 requires us to support persistent connections, so + # explicitly send close if we do not know the content length to + # prevent clients from reusing the connection. + if 'Content-Length' not in self.headers: + self.headers['Connection'] = 'close' + # Mark the connection for closing if we set it as such above or + # if the application sent the header. + if self.headers.get('Connection') == 'close': + self.request_handler.close_connection = True + def handle_error(self): # Ignore broken pipe errors, otherwise pass on if not is_broken_pipe_error(): @@ -135,6 +147,16 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): return super().get_environ() def handle(self): + self.close_connection = True + self.handle_one_request() + while not self.close_connection: + self.handle_one_request() + try: + self.connection.shutdown(socket.SHUT_WR) + except (socket.error, AttributeError): + pass + + def handle_one_request(self): """Copy of WSGIRequestHandler.handle() but with different ServerHandler""" self.raw_requestline = self.rfile.readline(65537) if len(self.raw_requestline) > 65536: @@ -150,7 +172,7 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler): handler = ServerHandler( self.rfile, self.wfile, self.get_stderr(), self.get_environ() ) - handler.request_handler = self # backpointer for logging + handler.request_handler = self # backpointer for logging & connection closing handler.run(self.server.get_app()) diff --git a/tests/servers/tests.py b/tests/servers/tests.py index ce08eb4a3f..e38cb5eb07 100644 --- a/tests/servers/tests.py +++ b/tests/servers/tests.py @@ -4,8 +4,7 @@ Tests for django.core.servers. import errno import os import socket -import sys -from http.client import HTTPConnection, RemoteDisconnected +from http.client import HTTPConnection from urllib.error import HTTPError from urllib.parse import urlencode from urllib.request import urlopen @@ -57,29 +56,60 @@ class LiveServerViews(LiveServerBase): with self.urlopen('/example_view/') as f: self.assertEqual(f.version, 11) - @override_settings(MIDDLEWARE=[]) def test_closes_connection_without_content_length(self): """ - The server doesn't support keep-alive because Python's http.server - module that it uses hangs if a Content-Length header isn't set (for - example, if CommonMiddleware isn't enabled or if the response is a - StreamingHttpResponse) (#28440 / https://bugs.python.org/issue31076). + A HTTP 1.1 server is supposed to support keep-alive. Since our + development server is rather simple we support it only in cases where + we can detect a content length from the response. This should be doable + for all simple views and streaming responses where an iterable with + length of one is passed. The latter follows as result of `set_content_length` + from https://github.com/python/cpython/blob/master/Lib/wsgiref/handlers.py. + + If we cannot detect a content length we explicitly set the `Connection` + header to `close` to notify the client that we do not actually support + it. """ conn = HTTPConnection(LiveServerViews.server_thread.host, LiveServerViews.server_thread.port, timeout=1) try: - conn.request('GET', '/example_view/', headers={'Connection': 'keep-alive'}) - response = conn.getresponse().read() - conn.request('GET', '/example_view/', headers={'Connection': 'close'}) - # macOS may give ConnectionResetError. - with self.assertRaises((RemoteDisconnected, ConnectionResetError)): - try: - conn.getresponse() - except ConnectionAbortedError: - if sys.platform == 'win32': - self.skipTest('Ignore nondeterministic failure on Windows.') + conn.request('GET', '/streaming_example_view/', headers={'Connection': 'keep-alive'}) + response = conn.getresponse() + self.assertTrue(response.will_close) + self.assertEqual(response.read(), b'Iamastream') + self.assertEqual(response.status, 200) + self.assertEqual(response.getheader('Connection'), 'close') + + conn.request('GET', '/streaming_example_view/', headers={'Connection': 'close'}) + response = conn.getresponse() + self.assertTrue(response.will_close) + self.assertEqual(response.read(), b'Iamastream') + self.assertEqual(response.status, 200) + self.assertEqual(response.getheader('Connection'), 'close') + finally: + conn.close() + + def test_keep_alive_on_connection_with_content_length(self): + """ + See `test_closes_connection_without_content_length` for details. This + is a follow up test, which ensure that we do not close the connection + if not needed, hence allowing us to take advantage of keep-alive. + """ + conn = HTTPConnection(LiveServerViews.server_thread.host, LiveServerViews.server_thread.port) + try: + conn.request('GET', '/example_view/', headers={"Connection": "keep-alive"}) + response = conn.getresponse() + self.assertFalse(response.will_close) + self.assertEqual(response.read(), b'example view') + self.assertEqual(response.status, 200) + self.assertIsNone(response.getheader('Connection')) + + conn.request('GET', '/example_view/', headers={"Connection": "close"}) + response = conn.getresponse() + self.assertFalse(response.will_close) + self.assertEqual(response.read(), b'example view') + self.assertEqual(response.status, 200) + self.assertIsNone(response.getheader('Connection')) finally: conn.close() - self.assertEqual(response, b'example view') def test_404(self): with self.assertRaises(HTTPError) as err: diff --git a/tests/servers/urls.py b/tests/servers/urls.py index 4963bde357..9017161808 100644 --- a/tests/servers/urls.py +++ b/tests/servers/urls.py @@ -4,6 +4,7 @@ from . import views urlpatterns = [ url(r'^example_view/$', views.example_view), + url(r'^streaming_example_view/$', views.streaming_example_view), url(r'^model_view/$', views.model_view), url(r'^create_model_instance/$', views.create_model_instance), url(r'^environ_view/$', views.environ_view), diff --git a/tests/servers/views.py b/tests/servers/views.py index 3bae0834ab..078be67f46 100644 --- a/tests/servers/views.py +++ b/tests/servers/views.py @@ -1,6 +1,6 @@ from urllib.request import urlopen -from django.http import HttpResponse +from django.http import HttpResponse, StreamingHttpResponse from .models import Person @@ -9,6 +9,10 @@ def example_view(request): return HttpResponse('example view') +def streaming_example_view(request): + return StreamingHttpResponse((b'I', b'am', b'a', b'stream')) + + def model_view(request): people = Person.objects.all() return HttpResponse('\n'.join(person.name for person in people))