mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Fixed #20038 -- Better error message for host validation.
This commit is contained in:
		
				
					committed by
					
						 Carl Meyer
						Carl Meyer
					
				
			
			
				
	
			
			
			
						parent
						
							c8deaa9e7b
						
					
				
				
					commit
					c250f9c99b
				
			| @@ -4,7 +4,6 @@ import copy | |||||||
| import os | import os | ||||||
| import re | import re | ||||||
| import sys | import sys | ||||||
| import warnings |  | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from pprint import pformat | from pprint import pformat | ||||||
| try: | try: | ||||||
| @@ -66,11 +65,14 @@ class HttpRequest(object): | |||||||
|                 host = '%s:%s' % (host, server_port) |                 host = '%s:%s' % (host, server_port) | ||||||
|  |  | ||||||
|         allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS |         allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS | ||||||
|         if validate_host(host, allowed_hosts): |         domain, port = split_domain_port(host) | ||||||
|  |         if domain and validate_host(domain, allowed_hosts): | ||||||
|             return host |             return host | ||||||
|         else: |         else: | ||||||
|             raise SuspiciousOperation( |             msg = "Invalid HTTP_HOST header: %r." % host | ||||||
|                 "Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host) |             if domain: | ||||||
|  |                 msg += "You may need to add %r to ALLOWED_HOSTS." % domain | ||||||
|  |             raise SuspiciousOperation(msg) | ||||||
|  |  | ||||||
|     def get_full_path(self): |     def get_full_path(self): | ||||||
|         # RFC 3986 requires query string arguments to be in the ASCII range. |         # RFC 3986 requires query string arguments to be in the ASCII range. | ||||||
| @@ -454,9 +456,30 @@ def bytes_to_text(s, encoding): | |||||||
|         return s |         return s | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_domain_port(host): | ||||||
|  |     """ | ||||||
|  |     Return a (domain, port) tuple from a given host. | ||||||
|  |  | ||||||
|  |     Returned domain is lower-cased. If the host is invalid, the domain will be | ||||||
|  |     empty. | ||||||
|  |     """ | ||||||
|  |     host = host.lower() | ||||||
|  |  | ||||||
|  |     if not host_validation_re.match(host): | ||||||
|  |         return '', '' | ||||||
|  |  | ||||||
|  |     if host[-1] == ']': | ||||||
|  |         # It's an IPv6 address without a port. | ||||||
|  |         return host, '' | ||||||
|  |     bits = host.rsplit(':', 1) | ||||||
|  |     if len(bits) == 2: | ||||||
|  |         return tuple(bits) | ||||||
|  |     return bits[0], '' | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_host(host, allowed_hosts): | def validate_host(host, allowed_hosts): | ||||||
|     """ |     """ | ||||||
|     Validate the given host header value for this site. |     Validate the given host for this site. | ||||||
|  |  | ||||||
|     Check that the host looks valid and matches a host or host pattern in the |     Check that the host looks valid and matches a host or host pattern in the | ||||||
|     given list of ``allowed_hosts``. Any pattern beginning with a period |     given list of ``allowed_hosts``. Any pattern beginning with a period | ||||||
| @@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts): | |||||||
|     ``example.com`` and any subdomain), ``*`` matches anything, and anything |     ``example.com`` and any subdomain), ``*`` matches anything, and anything | ||||||
|     else must match exactly. |     else must match exactly. | ||||||
|  |  | ||||||
|  |     Note: This function assumes that the given host is lower-cased and has | ||||||
|  |     already had the port, if any, stripped off. | ||||||
|  |  | ||||||
|     Return ``True`` for a valid host, ``False`` otherwise. |     Return ``True`` for a valid host, ``False`` otherwise. | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|     # All validation is case-insensitive |  | ||||||
|     host = host.lower() |  | ||||||
|  |  | ||||||
|     # Basic sanity check |  | ||||||
|     if not host_validation_re.match(host): |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     # Validate only the domain part. |  | ||||||
|     if host[-1] == ']': |  | ||||||
|         # It's an IPv6 address without a port. |  | ||||||
|         domain = host |  | ||||||
|     else: |  | ||||||
|         domain = host.rsplit(':', 1)[0] |  | ||||||
|  |  | ||||||
|     for pattern in allowed_hosts: |     for pattern in allowed_hosts: | ||||||
|         pattern = pattern.lower() |         pattern = pattern.lower() | ||||||
|         match = ( |         match = ( | ||||||
|             pattern == '*' or |             pattern == '*' or | ||||||
|             pattern.startswith('.') and ( |             pattern.startswith('.') and ( | ||||||
|                 domain.endswith(pattern) or domain == pattern[1:] |                 host.endswith(pattern) or host == pattern[1:] | ||||||
|                 ) or |                 ) or | ||||||
|             pattern == domain |             pattern == host | ||||||
|             ) |             ) | ||||||
|         if match: |         if match: | ||||||
|             return True |             return True | ||||||
|   | |||||||
| @@ -11,16 +11,16 @@ from django.core import signals | |||||||
| from django.core.exceptions import SuspiciousOperation | from django.core.exceptions import SuspiciousOperation | ||||||
| from django.core.handlers.wsgi import WSGIRequest, LimitedStream | from django.core.handlers.wsgi import WSGIRequest, LimitedStream | ||||||
| from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError | from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError | ||||||
| from django.test import TransactionTestCase | from django.test import SimpleTestCase, TransactionTestCase | ||||||
| from django.test.client import FakePayload | from django.test.client import FakePayload | ||||||
| from django.test.utils import override_settings, str_prefix | from django.test.utils import override_settings, str_prefix | ||||||
| from django.utils import six | from django.utils import six | ||||||
| from django.utils import unittest | from django.utils.unittest import skipIf | ||||||
| from django.utils.http import cookie_date, urlencode | from django.utils.http import cookie_date, urlencode | ||||||
| from django.utils.timezone import utc | from django.utils.timezone import utc | ||||||
|  |  | ||||||
|  |  | ||||||
| class RequestsTests(unittest.TestCase): | class RequestsTests(SimpleTestCase): | ||||||
|     def test_httprequest(self): |     def test_httprequest(self): | ||||||
|         request = HttpRequest() |         request = HttpRequest() | ||||||
|         self.assertEqual(list(request.GET.keys()), []) |         self.assertEqual(list(request.GET.keys()), []) | ||||||
| @@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase): | |||||||
|         self.assertEqual(request.get_host(), 'example.com') |         self.assertEqual(request.get_host(), 'example.com') | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     @override_settings(ALLOWED_HOSTS=[]) | ||||||
|  |     def test_get_host_suggestion_of_allowed_host(self): | ||||||
|  |         """get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS.""" | ||||||
|  |         msg_invalid_host = "Invalid HTTP_HOST header: %r." | ||||||
|  |         msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS." | ||||||
|  |  | ||||||
|  |         for host in [ # Valid-looking hosts | ||||||
|  |             'example.com', | ||||||
|  |             '12.34.56.78', | ||||||
|  |             '[2001:19f0:feee::dead:beef:cafe]', | ||||||
|  |             'xn--4ca9at.com', # Punnycode for öäü.com | ||||||
|  |         ]: | ||||||
|  |             request = HttpRequest() | ||||||
|  |             request.META = {'HTTP_HOST': host} | ||||||
|  |             self.assertRaisesMessage( | ||||||
|  |                 SuspiciousOperation, | ||||||
|  |                 msg_suggestion % (host, host), | ||||||
|  |                 request.get_host | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         for domain, port in [ # Valid-looking hosts with a port number | ||||||
|  |             ('example.com', 80), | ||||||
|  |             ('12.34.56.78', 443), | ||||||
|  |             ('[2001:19f0:feee::dead:beef:cafe]', 8080), | ||||||
|  |         ]: | ||||||
|  |             host = '%s:%s' % (domain, port) | ||||||
|  |             request = HttpRequest() | ||||||
|  |             request.META = {'HTTP_HOST': host} | ||||||
|  |             self.assertRaisesMessage( | ||||||
|  |                 SuspiciousOperation, | ||||||
|  |                 msg_suggestion % (host, domain), | ||||||
|  |                 request.get_host | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         for host in [ # Invalid hosts | ||||||
|  |             'example.com@evil.tld', | ||||||
|  |             'example.com:dr.frankenstein@evil.tld', | ||||||
|  |             'example.com:dr.frankenstein@evil.tld:80', | ||||||
|  |             'example.com:80/badpath', | ||||||
|  |             'example.com: recovermypassword.com', | ||||||
|  |         ]: | ||||||
|  |             request = HttpRequest() | ||||||
|  |             request.META = {'HTTP_HOST': host} | ||||||
|  |             self.assertRaisesMessage( | ||||||
|  |                 SuspiciousOperation, | ||||||
|  |                 msg_invalid_host % host, | ||||||
|  |                 request.get_host | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_near_expiration(self): |     def test_near_expiration(self): | ||||||
|         "Cookie will expire when an near expiration time is provided" |         "Cookie will expire when an near expiration time is provided" | ||||||
|         response = HttpResponse() |         response = HttpResponse() | ||||||
| @@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase): | |||||||
|             request.body |             request.body | ||||||
|  |  | ||||||
|  |  | ||||||
| @unittest.skipIf(connection.vendor == 'sqlite' | @skipIf(connection.vendor == 'sqlite' | ||||||
|         and connection.settings_dict['NAME'] in ('', ':memory:'), |         and connection.settings_dict['NAME'] in ('', ':memory:'), | ||||||
|         "Cannot establish two connections to an in-memory SQLite database.") |         "Cannot establish two connections to an in-memory SQLite database.") | ||||||
| class DatabaseConnectionHandlingTests(TransactionTestCase): | class DatabaseConnectionHandlingTests(TransactionTestCase): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user