1
0
mirror of https://github.com/django/django.git synced 2025-10-24 06:06:09 +00:00

Fixed #4476 -- Added a `follow` option to the test client request methods. This implements browser-like behavior for the test client, following redirect chains when a 30X response is received. Thanks to Marc Fargas and Keith Bussell for their work on this.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9911 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee
2009-02-27 13:14:59 +00:00
parent e20f09c2d0
commit e735fe7160
7 changed files with 298 additions and 76 deletions

View File

@@ -1,5 +1,5 @@
import urllib
from urlparse import urlparse, urlunparse
from urlparse import urlparse, urlunparse, urlsplit
import sys
import os
try:
@@ -12,7 +12,7 @@ from django.contrib.auth import authenticate, login
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.signals import got_request_exception
from django.http import SimpleCookie, HttpRequest
from django.http import SimpleCookie, HttpRequest, QueryDict
from django.template import TemplateDoesNotExist
from django.test import signals
from django.utils.functional import curry
@@ -261,7 +261,7 @@ class Client(object):
return response
def get(self, path, data={}, **extra):
def get(self, path, data={}, follow=False, **extra):
"""
Requests a response from the server using GET.
"""
@@ -275,9 +275,13 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
def post(self, path, data={}, content_type=MULTIPART_CONTENT,
follow=False, **extra):
"""
Requests a response from the server using POST.
"""
@@ -297,9 +301,12 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def head(self, path, data={}, **extra):
def head(self, path, data={}, follow=False, **extra):
"""
Request a response from the server using HEAD.
"""
@@ -313,9 +320,12 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def options(self, path, data={}, **extra):
def options(self, path, data={}, follow=False, **extra):
"""
Request a response from the server using OPTIONS.
"""
@@ -328,9 +338,13 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
def put(self, path, data={}, content_type=MULTIPART_CONTENT,
follow=False, **extra):
"""
Send a resource to the server using PUT.
"""
@@ -350,9 +364,12 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def delete(self, path, data={}, **extra):
def delete(self, path, data={}, follow=False, **extra):
"""
Send a DELETE request to the server.
"""
@@ -365,7 +382,10 @@ class Client(object):
}
r.update(extra)
return self.request(**r)
response = self.request(**r)
if follow:
response = self._handle_redirects(response)
return response
def login(self, **credentials):
"""
@@ -416,3 +436,27 @@ class Client(object):
session = __import__(settings.SESSION_ENGINE, {}, {}, ['']).SessionStore()
session.delete(session_key=self.cookies[settings.SESSION_COOKIE_NAME].value)
self.cookies = SimpleCookie()
def _handle_redirects(self, response):
"Follows any redirects by requesting responses from the server using GET."
response.redirect_chain = []
while response.status_code in (301, 302, 303, 307):
url = response['Location']
scheme, netloc, path, query, fragment = urlsplit(url)
redirect_chain = response.redirect_chain
redirect_chain.append((url, response.status_code))
# The test client doesn't handle external links,
# but since the situation is simulated in test_client,
# we fake things here by ignoring the netloc portion of the
# redirected URL.
response = self.get(path, QueryDict(query), follow=False)
response.redirect_chain = redirect_chain
# Prevent loops
if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
break
return response

View File

@@ -43,7 +43,7 @@ def disable_transaction_methods():
transaction.savepoint_commit = nop
transaction.savepoint_rollback = nop
transaction.enter_transaction_management = nop
transaction.leave_transaction_management = nop
transaction.leave_transaction_management = nop
def restore_transaction_methods():
transaction.commit = real_commit
@@ -198,7 +198,7 @@ class DocTestRunner(doctest.DocTestRunner):
# Rollback, in case of database errors. Otherwise they'd have
# side effects on other tests.
transaction.rollback_unless_managed()
class TransactionTestCase(unittest.TestCase):
def _pre_setup(self):
"""Performs any pre-test setup. This includes:
@@ -242,7 +242,7 @@ class TransactionTestCase(unittest.TestCase):
import sys
result.addError(self, sys.exc_info())
return
super(TransactionTestCase, self).__call__(result)
super(TransactionTestCase, self).__call__(result)
try:
self._post_teardown()
except (KeyboardInterrupt, SystemExit):
@@ -263,7 +263,7 @@ class TransactionTestCase(unittest.TestCase):
def _fixture_teardown(self):
pass
def _urlconf_teardown(self):
def _urlconf_teardown(self):
if hasattr(self, '_old_root_urlconf'):
settings.ROOT_URLCONF = self._old_root_urlconf
clear_url_caches()
@@ -276,25 +276,48 @@ class TransactionTestCase(unittest.TestCase):
Note that assertRedirects won't work for external links since it uses
TestClient to do a request.
"""
self.assertEqual(response.status_code, status_code,
("Response didn't redirect as expected: Response code was %d"
" (expected %d)" % (response.status_code, status_code)))
url = response['Location']
scheme, netloc, path, query, fragment = urlsplit(url)
if hasattr(response, 'redirect_chain'):
# The request was a followed redirect
self.assertTrue(len(response.redirect_chain) > 0,
("Response didn't redirect as expected: Response code was %d"
" (expected %d)" % (response.status_code, status_code)))
self.assertEqual(response.redirect_chain[0][1], status_code,
("Initial response didn't redirect as expected: Response code was %d"
" (expected %d)" % (response.redirect_chain[0][1], status_code)))
url, status_code = response.redirect_chain[-1]
self.assertEqual(response.status_code, target_status_code,
("Response didn't redirect as expected: Final Response code was %d"
" (expected %d)" % (response.status_code, target_status_code)))
else:
# Not a followed redirect
self.assertEqual(response.status_code, status_code,
("Response didn't redirect as expected: Response code was %d"
" (expected %d)" % (response.status_code, status_code)))
url = response['Location']
scheme, netloc, path, query, fragment = urlsplit(url)
redirect_response = response.client.get(path, QueryDict(query))
# Get the redirection page, using the same client that was used
# to obtain the original response.
self.assertEqual(redirect_response.status_code, target_status_code,
("Couldn't retrieve redirection page '%s': response code was %d"
" (expected %d)") %
(path, redirect_response.status_code, target_status_code))
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
if not (e_scheme or e_netloc):
expected_url = urlunsplit(('http', host or 'testserver', e_path,
e_query, e_fragment))
e_query, e_fragment))
self.assertEqual(url, expected_url,
"Response redirected to '%s', expected '%s'" % (url, expected_url))
# Get the redirection page, using the same client that was used
# to obtain the original response.
redirect_response = response.client.get(path, QueryDict(query))
self.assertEqual(redirect_response.status_code, target_status_code,
("Couldn't retrieve redirection page '%s': response code was %d"
" (expected %d)") %
(path, redirect_response.status_code, target_status_code))
def assertContains(self, response, text, count=None, status_code=200):
"""
@@ -401,15 +424,15 @@ class TransactionTestCase(unittest.TestCase):
class TestCase(TransactionTestCase):
"""
Does basically the same as TransactionTestCase, but surrounds every test
with a transaction, monkey-patches the real transaction management routines to
do nothing, and rollsback the test transaction at the end of the test. You have
with a transaction, monkey-patches the real transaction management routines to
do nothing, and rollsback the test transaction at the end of the test. You have
to use TransactionTestCase, if you need transaction management inside a test.
"""
def _fixture_setup(self):
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
return super(TestCase, self)._fixture_setup()
transaction.enter_transaction_management()
transaction.managed(True)
disable_transaction_methods()
@@ -426,7 +449,7 @@ class TestCase(TransactionTestCase):
def _fixture_teardown(self):
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
return super(TestCase, self)._fixture_teardown()
restore_transaction_methods()
transaction.rollback()
transaction.leave_transaction_management()