mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Fixed #4476 -- Added a `follow` option to the test client request methods. This implements browser-like behavior for the test client, following redirect chains when a 30X response is received. Thanks to Marc Fargas and Keith Bussell for their work on this.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@9911 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import urllib
|
||||
from urlparse import urlparse, urlunparse
|
||||
from urlparse import urlparse, urlunparse, urlsplit
|
||||
import sys
|
||||
import os
|
||||
try:
|
||||
@@ -12,7 +12,7 @@ from django.contrib.auth import authenticate, login
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.signals import got_request_exception
|
||||
from django.http import SimpleCookie, HttpRequest
|
||||
from django.http import SimpleCookie, HttpRequest, QueryDict
|
||||
from django.template import TemplateDoesNotExist
|
||||
from django.test import signals
|
||||
from django.utils.functional import curry
|
||||
@@ -261,7 +261,7 @@ class Client(object):
|
||||
|
||||
return response
|
||||
|
||||
def get(self, path, data={}, **extra):
|
||||
def get(self, path, data={}, follow=False, **extra):
|
||||
"""
|
||||
Requests a response from the server using GET.
|
||||
"""
|
||||
@@ -275,9 +275,13 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def post(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
|
||||
def post(self, path, data={}, content_type=MULTIPART_CONTENT,
|
||||
follow=False, **extra):
|
||||
"""
|
||||
Requests a response from the server using POST.
|
||||
"""
|
||||
@@ -297,9 +301,12 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def head(self, path, data={}, **extra):
|
||||
def head(self, path, data={}, follow=False, **extra):
|
||||
"""
|
||||
Request a response from the server using HEAD.
|
||||
"""
|
||||
@@ -313,9 +320,12 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def options(self, path, data={}, **extra):
|
||||
def options(self, path, data={}, follow=False, **extra):
|
||||
"""
|
||||
Request a response from the server using OPTIONS.
|
||||
"""
|
||||
@@ -328,9 +338,13 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def put(self, path, data={}, content_type=MULTIPART_CONTENT, **extra):
|
||||
def put(self, path, data={}, content_type=MULTIPART_CONTENT,
|
||||
follow=False, **extra):
|
||||
"""
|
||||
Send a resource to the server using PUT.
|
||||
"""
|
||||
@@ -350,9 +364,12 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def delete(self, path, data={}, **extra):
|
||||
def delete(self, path, data={}, follow=False, **extra):
|
||||
"""
|
||||
Send a DELETE request to the server.
|
||||
"""
|
||||
@@ -365,7 +382,10 @@ class Client(object):
|
||||
}
|
||||
r.update(extra)
|
||||
|
||||
return self.request(**r)
|
||||
response = self.request(**r)
|
||||
if follow:
|
||||
response = self._handle_redirects(response)
|
||||
return response
|
||||
|
||||
def login(self, **credentials):
|
||||
"""
|
||||
@@ -416,3 +436,27 @@ class Client(object):
|
||||
session = __import__(settings.SESSION_ENGINE, {}, {}, ['']).SessionStore()
|
||||
session.delete(session_key=self.cookies[settings.SESSION_COOKIE_NAME].value)
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _handle_redirects(self, response):
|
||||
"Follows any redirects by requesting responses from the server using GET."
|
||||
|
||||
response.redirect_chain = []
|
||||
while response.status_code in (301, 302, 303, 307):
|
||||
url = response['Location']
|
||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||
|
||||
redirect_chain = response.redirect_chain
|
||||
redirect_chain.append((url, response.status_code))
|
||||
|
||||
# The test client doesn't handle external links,
|
||||
# but since the situation is simulated in test_client,
|
||||
# we fake things here by ignoring the netloc portion of the
|
||||
# redirected URL.
|
||||
response = self.get(path, QueryDict(query), follow=False)
|
||||
response.redirect_chain = redirect_chain
|
||||
|
||||
# Prevent loops
|
||||
if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
|
||||
break
|
||||
return response
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def disable_transaction_methods():
|
||||
transaction.savepoint_commit = nop
|
||||
transaction.savepoint_rollback = nop
|
||||
transaction.enter_transaction_management = nop
|
||||
transaction.leave_transaction_management = nop
|
||||
transaction.leave_transaction_management = nop
|
||||
|
||||
def restore_transaction_methods():
|
||||
transaction.commit = real_commit
|
||||
@@ -198,7 +198,7 @@ class DocTestRunner(doctest.DocTestRunner):
|
||||
# Rollback, in case of database errors. Otherwise they'd have
|
||||
# side effects on other tests.
|
||||
transaction.rollback_unless_managed()
|
||||
|
||||
|
||||
class TransactionTestCase(unittest.TestCase):
|
||||
def _pre_setup(self):
|
||||
"""Performs any pre-test setup. This includes:
|
||||
@@ -242,7 +242,7 @@ class TransactionTestCase(unittest.TestCase):
|
||||
import sys
|
||||
result.addError(self, sys.exc_info())
|
||||
return
|
||||
super(TransactionTestCase, self).__call__(result)
|
||||
super(TransactionTestCase, self).__call__(result)
|
||||
try:
|
||||
self._post_teardown()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
@@ -263,7 +263,7 @@ class TransactionTestCase(unittest.TestCase):
|
||||
def _fixture_teardown(self):
|
||||
pass
|
||||
|
||||
def _urlconf_teardown(self):
|
||||
def _urlconf_teardown(self):
|
||||
if hasattr(self, '_old_root_urlconf'):
|
||||
settings.ROOT_URLCONF = self._old_root_urlconf
|
||||
clear_url_caches()
|
||||
@@ -276,25 +276,48 @@ class TransactionTestCase(unittest.TestCase):
|
||||
Note that assertRedirects won't work for external links since it uses
|
||||
TestClient to do a request.
|
||||
"""
|
||||
self.assertEqual(response.status_code, status_code,
|
||||
("Response didn't redirect as expected: Response code was %d"
|
||||
" (expected %d)" % (response.status_code, status_code)))
|
||||
url = response['Location']
|
||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||
if hasattr(response, 'redirect_chain'):
|
||||
# The request was a followed redirect
|
||||
self.assertTrue(len(response.redirect_chain) > 0,
|
||||
("Response didn't redirect as expected: Response code was %d"
|
||||
" (expected %d)" % (response.status_code, status_code)))
|
||||
|
||||
self.assertEqual(response.redirect_chain[0][1], status_code,
|
||||
("Initial response didn't redirect as expected: Response code was %d"
|
||||
" (expected %d)" % (response.redirect_chain[0][1], status_code)))
|
||||
|
||||
url, status_code = response.redirect_chain[-1]
|
||||
|
||||
self.assertEqual(response.status_code, target_status_code,
|
||||
("Response didn't redirect as expected: Final Response code was %d"
|
||||
" (expected %d)" % (response.status_code, target_status_code)))
|
||||
|
||||
else:
|
||||
# Not a followed redirect
|
||||
self.assertEqual(response.status_code, status_code,
|
||||
("Response didn't redirect as expected: Response code was %d"
|
||||
" (expected %d)" % (response.status_code, status_code)))
|
||||
|
||||
url = response['Location']
|
||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||
|
||||
redirect_response = response.client.get(path, QueryDict(query))
|
||||
|
||||
# Get the redirection page, using the same client that was used
|
||||
# to obtain the original response.
|
||||
self.assertEqual(redirect_response.status_code, target_status_code,
|
||||
("Couldn't retrieve redirection page '%s': response code was %d"
|
||||
" (expected %d)") %
|
||||
(path, redirect_response.status_code, target_status_code))
|
||||
|
||||
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
|
||||
if not (e_scheme or e_netloc):
|
||||
expected_url = urlunsplit(('http', host or 'testserver', e_path,
|
||||
e_query, e_fragment))
|
||||
e_query, e_fragment))
|
||||
|
||||
self.assertEqual(url, expected_url,
|
||||
"Response redirected to '%s', expected '%s'" % (url, expected_url))
|
||||
|
||||
# Get the redirection page, using the same client that was used
|
||||
# to obtain the original response.
|
||||
redirect_response = response.client.get(path, QueryDict(query))
|
||||
self.assertEqual(redirect_response.status_code, target_status_code,
|
||||
("Couldn't retrieve redirection page '%s': response code was %d"
|
||||
" (expected %d)") %
|
||||
(path, redirect_response.status_code, target_status_code))
|
||||
|
||||
def assertContains(self, response, text, count=None, status_code=200):
|
||||
"""
|
||||
@@ -401,15 +424,15 @@ class TransactionTestCase(unittest.TestCase):
|
||||
class TestCase(TransactionTestCase):
|
||||
"""
|
||||
Does basically the same as TransactionTestCase, but surrounds every test
|
||||
with a transaction, monkey-patches the real transaction management routines to
|
||||
do nothing, and rollsback the test transaction at the end of the test. You have
|
||||
with a transaction, monkey-patches the real transaction management routines to
|
||||
do nothing, and rollsback the test transaction at the end of the test. You have
|
||||
to use TransactionTestCase, if you need transaction management inside a test.
|
||||
"""
|
||||
|
||||
def _fixture_setup(self):
|
||||
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
|
||||
return super(TestCase, self)._fixture_setup()
|
||||
|
||||
|
||||
transaction.enter_transaction_management()
|
||||
transaction.managed(True)
|
||||
disable_transaction_methods()
|
||||
@@ -426,7 +449,7 @@ class TestCase(TransactionTestCase):
|
||||
def _fixture_teardown(self):
|
||||
if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
|
||||
return super(TestCase, self)._fixture_teardown()
|
||||
|
||||
|
||||
restore_transaction_methods()
|
||||
transaction.rollback()
|
||||
transaction.leave_transaction_management()
|
||||
|
||||
Reference in New Issue
Block a user