1
0
mirror of https://github.com/django/django.git synced 2025-04-25 09:44:36 +00:00

Fixed #20038 -- Better error message for host validation.

This commit is contained in:
Baptiste Mispelon 2013-04-03 14:27:20 -06:00 committed by Carl Meyer
parent c8deaa9e7b
commit c250f9c99b
2 changed files with 87 additions and 25 deletions

View File

@ -4,7 +4,6 @@ import copy
import os import os
import re import re
import sys import sys
import warnings
from io import BytesIO from io import BytesIO
from pprint import pformat from pprint import pformat
try: try:
@ -66,11 +65,14 @@ class HttpRequest(object):
host = '%s:%s' % (host, server_port) host = '%s:%s' % (host, server_port)
allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS
if validate_host(host, allowed_hosts): domain, port = split_domain_port(host)
if domain and validate_host(domain, allowed_hosts):
return host return host
else: else:
raise SuspiciousOperation( msg = "Invalid HTTP_HOST header: %r." % host
"Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host) if domain:
msg += "You may need to add %r to ALLOWED_HOSTS." % domain
raise SuspiciousOperation(msg)
def get_full_path(self): def get_full_path(self):
# RFC 3986 requires query string arguments to be in the ASCII range. # RFC 3986 requires query string arguments to be in the ASCII range.
@ -454,9 +456,30 @@ def bytes_to_text(s, encoding):
return s return s
def split_domain_port(host):
"""
Return a (domain, port) tuple from a given host.
Returned domain is lower-cased. If the host is invalid, the domain will be
empty.
"""
host = host.lower()
if not host_validation_re.match(host):
return '', ''
if host[-1] == ']':
# It's an IPv6 address without a port.
return host, ''
bits = host.rsplit(':', 1)
if len(bits) == 2:
return tuple(bits)
return bits[0], ''
def validate_host(host, allowed_hosts): def validate_host(host, allowed_hosts):
""" """
Validate the given host header value for this site. Validate the given host for this site.
Check that the host looks valid and matches a host or host pattern in the Check that the host looks valid and matches a host or host pattern in the
given list of ``allowed_hosts``. Any pattern beginning with a period given list of ``allowed_hosts``. Any pattern beginning with a period
@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts):
``example.com`` and any subdomain), ``*`` matches anything, and anything ``example.com`` and any subdomain), ``*`` matches anything, and anything
else must match exactly. else must match exactly.
Note: This function assumes that the given host is lower-cased and has
already had the port, if any, stripped off.
Return ``True`` for a valid host, ``False`` otherwise. Return ``True`` for a valid host, ``False`` otherwise.
""" """
# All validation is case-insensitive
host = host.lower()
# Basic sanity check
if not host_validation_re.match(host):
return False
# Validate only the domain part.
if host[-1] == ']':
# It's an IPv6 address without a port.
domain = host
else:
domain = host.rsplit(':', 1)[0]
for pattern in allowed_hosts: for pattern in allowed_hosts:
pattern = pattern.lower() pattern = pattern.lower()
match = ( match = (
pattern == '*' or pattern == '*' or
pattern.startswith('.') and ( pattern.startswith('.') and (
domain.endswith(pattern) or domain == pattern[1:] host.endswith(pattern) or host == pattern[1:]
) or ) or
pattern == domain pattern == host
) )
if match: if match:
return True return True

View File

@ -11,16 +11,16 @@ from django.core import signals
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.core.handlers.wsgi import WSGIRequest, LimitedStream from django.core.handlers.wsgi import WSGIRequest, LimitedStream
from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError
from django.test import TransactionTestCase from django.test import SimpleTestCase, TransactionTestCase
from django.test.client import FakePayload from django.test.client import FakePayload
from django.test.utils import override_settings, str_prefix from django.test.utils import override_settings, str_prefix
from django.utils import six from django.utils import six
from django.utils import unittest from django.utils.unittest import skipIf
from django.utils.http import cookie_date, urlencode from django.utils.http import cookie_date, urlencode
from django.utils.timezone import utc from django.utils.timezone import utc
class RequestsTests(unittest.TestCase): class RequestsTests(SimpleTestCase):
def test_httprequest(self): def test_httprequest(self):
request = HttpRequest() request = HttpRequest()
self.assertEqual(list(request.GET.keys()), []) self.assertEqual(list(request.GET.keys()), [])
@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase):
self.assertEqual(request.get_host(), 'example.com') self.assertEqual(request.get_host(), 'example.com')
@override_settings(ALLOWED_HOSTS=[])
def test_get_host_suggestion_of_allowed_host(self):
"""get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS."""
msg_invalid_host = "Invalid HTTP_HOST header: %r."
msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS."
for host in [ # Valid-looking hosts
'example.com',
'12.34.56.78',
'[2001:19f0:feee::dead:beef:cafe]',
'xn--4ca9at.com', # Punnycode for öäü.com
]:
request = HttpRequest()
request.META = {'HTTP_HOST': host}
self.assertRaisesMessage(
SuspiciousOperation,
msg_suggestion % (host, host),
request.get_host
)
for domain, port in [ # Valid-looking hosts with a port number
('example.com', 80),
('12.34.56.78', 443),
('[2001:19f0:feee::dead:beef:cafe]', 8080),
]:
host = '%s:%s' % (domain, port)
request = HttpRequest()
request.META = {'HTTP_HOST': host}
self.assertRaisesMessage(
SuspiciousOperation,
msg_suggestion % (host, domain),
request.get_host
)
for host in [ # Invalid hosts
'example.com@evil.tld',
'example.com:dr.frankenstein@evil.tld',
'example.com:dr.frankenstein@evil.tld:80',
'example.com:80/badpath',
'example.com: recovermypassword.com',
]:
request = HttpRequest()
request.META = {'HTTP_HOST': host}
self.assertRaisesMessage(
SuspiciousOperation,
msg_invalid_host % host,
request.get_host
)
def test_near_expiration(self): def test_near_expiration(self):
"Cookie will expire when an near expiration time is provided" "Cookie will expire when an near expiration time is provided"
response = HttpResponse() response = HttpResponse()
@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase):
request.body request.body
@unittest.skipIf(connection.vendor == 'sqlite' @skipIf(connection.vendor == 'sqlite'
and connection.settings_dict['NAME'] in ('', ':memory:'), and connection.settings_dict['NAME'] in ('', ':memory:'),
"Cannot establish two connections to an in-memory SQLite database.") "Cannot establish two connections to an in-memory SQLite database.")
class DatabaseConnectionHandlingTests(TransactionTestCase): class DatabaseConnectionHandlingTests(TransactionTestCase):