diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py index fef5532e58..d08fb77a47 100644 --- a/django/core/servers/basehttp.py +++ b/django/core/servers/basehttp.py @@ -11,6 +11,7 @@ import logging import socket import socketserver import sys +from collections import deque from wsgiref import simple_server from django.core.exceptions import ImproperlyConfigured @@ -130,10 +131,18 @@ class ServerHandler(simple_server.ServerHandler): def cleanup_headers(self): super().cleanup_headers() + if ( + self.environ["REQUEST_METHOD"] == "HEAD" + and "Content-Length" in self.headers + ): + del self.headers["Content-Length"] # HTTP/1.1 requires support for persistent connections. Send 'close' if # the content length is unknown to prevent clients from reusing the # connection. - if "Content-Length" not in self.headers: + if ( + self.environ["REQUEST_METHOD"] != "HEAD" + and "Content-Length" not in self.headers + ): self.headers["Connection"] = "close" # Persistent connections require threading server. elif not isinstance(self.request_handler.server, socketserver.ThreadingMixIn): @@ -147,6 +156,22 @@ class ServerHandler(simple_server.ServerHandler): self.get_stdin().read() super().close() + def finish_response(self): + if self.environ["REQUEST_METHOD"] == "HEAD": + try: + deque(self.result, maxlen=0) # Consume iterator. + # Don't call self.finish_content() as, if the headers have not + # been sent and Content-Length isn't set, it'll default to "0" + # which will prevent omission of the Content-Length header with + # HEAD requests as permitted by RFC 9110 Section 9.3.2. + # Instead, send the headers, if not sent yet. + if not self.headers_sent: + self.send_headers() + finally: + self.close() + else: + super().finish_response() + class WSGIRequestHandler(simple_server.WSGIRequestHandler): protocol_version = "HTTP/1.1" diff --git a/tests/servers/test_basehttp.py b/tests/servers/test_basehttp.py index a837505feb..1e535e933e 100644 --- a/tests/servers/test_basehttp.py +++ b/tests/servers/test_basehttp.py @@ -1,4 +1,5 @@ from io import BytesIO +from socketserver import ThreadingMixIn from django.core.handlers.wsgi import WSGIRequest from django.core.servers.basehttp import WSGIRequestHandler, WSGIServer @@ -7,7 +8,7 @@ from django.test.client import RequestFactory from django.test.utils import captured_stderr -class Stub: +class Stub(ThreadingMixIn): def __init__(self, **kwargs): self.__dict__.update(kwargs) @@ -15,6 +16,13 @@ class Stub: self.makefile("wb").write(data) +class UnclosableBytesIO(BytesIO): + def close(self): + # WSGIRequestHandler closes the output file; we need to make this a + # no-op so we can still read its contents. + pass + + class WSGIRequestHandlerTestCase(SimpleTestCase): request_factory = RequestFactory() @@ -79,12 +87,6 @@ class WSGIRequestHandlerTestCase(SimpleTestCase): rfile.write(b"Other_Header: bad\r\n") rfile.seek(0) - # WSGIRequestHandler closes the output file; we need to make this a - # no-op so we can still read its contents. - class UnclosableBytesIO(BytesIO): - def close(self): - pass - wfile = UnclosableBytesIO() def makefile(mode, *a, **kw): @@ -106,6 +108,59 @@ class WSGIRequestHandlerTestCase(SimpleTestCase): self.assertEqual(body, b"HTTP_SOME_HEADER:good") + def test_no_body_returned_for_head_requests(self): + hello_world_body = b"Hello World" + content_length = len(hello_world_body) + + def test_app(environ, start_response): + """A WSGI app that returns a hello world.""" + start_response("200 OK", []) + return [hello_world_body] + + rfile = BytesIO(b"GET / HTTP/1.0\r\n") + rfile.seek(0) + + wfile = UnclosableBytesIO() + + def makefile(mode, *a, **kw): + if mode == "rb": + return rfile + elif mode == "wb": + return wfile + + request = Stub(makefile=makefile) + server = Stub(base_environ={}, get_app=lambda: test_app) + + # Prevent logging from appearing in test output. + with self.assertLogs("django.server", "INFO"): + # Instantiating a handler runs the request as side effect. + WSGIRequestHandler(request, "192.168.0.2", server) + + wfile.seek(0) + lines = list(wfile.readlines()) + body = lines[-1] + # The body is returned in a GET response. + self.assertEqual(body, hello_world_body) + self.assertIn(f"Content-Length: {content_length}\r\n".encode(), lines) + self.assertNotIn(b"Connection: close\r\n", lines) + + rfile = BytesIO(b"HEAD / HTTP/1.0\r\n") + rfile.seek(0) + wfile = UnclosableBytesIO() + + with self.assertLogs("django.server", "INFO"): + WSGIRequestHandler(request, "192.168.0.2", server) + + wfile.seek(0) + lines = list(wfile.readlines()) + body = lines[-1] + # The body is not returned in a HEAD response. + self.assertEqual(body, b"\r\n") + self.assertIs( + any([line.startswith(b"Content-Length:") for line in lines]), False + ) + self.assertNotIn(b"Connection: close\r\n", lines) + class WSGIServerTestCase(SimpleTestCase): request_factory = RequestFactory()