1
0
mirror of https://github.com/django/django.git synced 2025-10-25 14:46:09 +00:00

Fixed #20004 -- Moved non DB-related assertions to SimpleTestCase.

Thanks zalew for the suggestion and work on a patch.

Also updated, tweaked and fixed testing documentation.
This commit is contained in:
Ramiro Morales
2013-05-18 19:04:34 -03:00
parent 69523c1ba3
commit 0a50311063
9 changed files with 425 additions and 370 deletions

View File

@@ -231,6 +231,10 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
class SimpleTestCase(ut2.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
client_class = Client
_warn_txt = ("save_warnings_state/restore_warnings_state "
"django.test.*TestCase methods are deprecated. Use Python's "
"warnings.catch_warnings context manager instead.")
@@ -264,10 +268,31 @@ class SimpleTestCase(ut2.TestCase):
return
def _pre_setup(self):
pass
"""Performs any pre-test setup. This includes:
* If the Test Case class has a 'urls' member, replace the
ROOT_URLCONF with it.
* Clearing the mail test outbox.
"""
self.client = self.client_class()
self._urlconf_setup()
mail.outbox = []
def _urlconf_setup(self):
set_urlconf(None)
if hasattr(self, 'urls'):
self._old_root_urlconf = settings.ROOT_URLCONF
settings.ROOT_URLCONF = self.urls
clear_url_caches()
def _post_teardown(self):
pass
self._urlconf_teardown()
def _urlconf_teardown(self):
set_urlconf(None)
if hasattr(self, '_old_root_urlconf'):
settings.ROOT_URLCONF = self._old_root_urlconf
clear_url_caches()
def save_warnings_state(self):
"""
@@ -291,258 +316,6 @@ class SimpleTestCase(ut2.TestCase):
"""
return override_settings(**kwargs)
def assertRaisesMessage(self, expected_exception, expected_message,
callable_obj=None, *args, **kwargs):
"""
Asserts that the message in a raised exception matches the passed
value.
Args:
expected_exception: Exception class expected to be raised.
expected_message: expected error message string value.
callable_obj: Function to be called.
args: Extra args.
kwargs: Extra kwargs.
"""
return six.assertRaisesRegex(self, expected_exception,
re.escape(expected_message), callable_obj, *args, **kwargs)
def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
field_kwargs=None, empty_value=''):
"""
Asserts that a form field behaves correctly with various inputs.
Args:
fieldclass: the class of the field to be tested.
valid: a dictionary mapping valid inputs to their expected
cleaned values.
invalid: a dictionary mapping invalid inputs to one or more
raised error messages.
field_args: the args passed to instantiate the field
field_kwargs: the kwargs passed to instantiate the field
empty_value: the expected clean output for inputs in empty_values
"""
if field_args is None:
field_args = []
if field_kwargs is None:
field_kwargs = {}
required = fieldclass(*field_args, **field_kwargs)
optional = fieldclass(*field_args,
**dict(field_kwargs, required=False))
# test valid inputs
for input, output in valid.items():
self.assertEqual(required.clean(input), output)
self.assertEqual(optional.clean(input), output)
# test invalid inputs
for input, errors in invalid.items():
with self.assertRaises(ValidationError) as context_manager:
required.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
with self.assertRaises(ValidationError) as context_manager:
optional.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
# test required inputs
error_required = [force_text(required.error_messages['required'])]
for e in required.empty_values:
with self.assertRaises(ValidationError) as context_manager:
required.clean(e)
self.assertEqual(context_manager.exception.messages,
error_required)
self.assertEqual(optional.clean(e), empty_value)
# test that max_length and min_length are always accepted
if issubclass(fieldclass, CharField):
field_kwargs.update({'min_length':2, 'max_length':20})
self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs),
fieldclass))
def assertHTMLEqual(self, html1, html2, msg=None):
"""
Asserts that two HTML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid HTML.
"""
dom1 = assert_and_parse_html(self, html1, msg,
'First argument is not valid HTML:')
dom2 = assert_and_parse_html(self, html2, msg,
'Second argument is not valid HTML:')
if dom1 != dom2:
standardMsg = '%s != %s' % (
safe_repr(dom1, True), safe_repr(dom2, True))
diff = ('\n' + '\n'.join(difflib.ndiff(
six.text_type(dom1).splitlines(),
six.text_type(dom2).splitlines())))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
def assertHTMLNotEqual(self, html1, html2, msg=None):
"""Asserts that two HTML snippets are not semantically equivalent."""
dom1 = assert_and_parse_html(self, html1, msg,
'First argument is not valid HTML:')
dom2 = assert_and_parse_html(self, html2, msg,
'Second argument is not valid HTML:')
if dom1 == dom2:
standardMsg = '%s == %s' % (
safe_repr(dom1, True), safe_repr(dom2, True))
self.fail(self._formatMessage(msg, standardMsg))
def assertInHTML(self, needle, haystack, count = None, msg_prefix=''):
needle = assert_and_parse_html(self, needle, None,
'First argument is not valid HTML:')
haystack = assert_and_parse_html(self, haystack, None,
'Second argument is not valid HTML:')
real_count = haystack.count(needle)
if count is not None:
self.assertEqual(real_count, count,
msg_prefix + "Found %d instances of '%s' in response"
" (expected %d)" % (real_count, needle, count))
else:
self.assertTrue(real_count != 0,
msg_prefix + "Couldn't find '%s' in response" % needle)
def assertJSONEqual(self, raw, expected_data, msg=None):
try:
data = json.loads(raw)
except ValueError:
self.fail("First argument is not valid JSON: %r" % raw)
if isinstance(expected_data, six.string_types):
try:
expected_data = json.loads(expected_data)
except ValueError:
self.fail("Second argument is not valid JSON: %r" % expected_data)
self.assertEqual(data, expected_data, msg=msg)
def assertXMLEqual(self, xml1, xml2, msg=None):
"""
Asserts that two XML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = 'First or second argument is not valid XML\n%s' % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if not result:
standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
self.fail(self._formatMessage(msg, standardMsg))
def assertXMLNotEqual(self, xml1, xml2, msg=None):
"""
Asserts that two XML snippets are not semantically equivalent.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = 'First or second argument is not valid XML\n%s' % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if result:
standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
self.fail(self._formatMessage(msg, standardMsg))
class TransactionTestCase(SimpleTestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
client_class = Client
# Subclasses can ask for resetting of auto increment sequence before each
# test case
reset_sequences = False
def _pre_setup(self):
"""Performs any pre-test setup. This includes:
* Flushing the database.
* If the Test Case class has a 'fixtures' member, installing the
named fixtures.
* If the Test Case class has a 'urls' member, replace the
ROOT_URLCONF with it.
* Clearing the mail test outbox.
"""
self.client = self.client_class()
self._fixture_setup()
self._urlconf_setup()
mail.outbox = []
def _databases_names(self, include_mirrors=True):
# If the test case has a multi_db=True flag, act on all databases,
# including mirrors or not. Otherwise, just on the default DB.
if getattr(self, 'multi_db', False):
return [alias for alias in connections
if include_mirrors or not connections[alias].settings_dict['TEST_MIRROR']]
else:
return [DEFAULT_DB_ALIAS]
def _reset_sequences(self, db_name):
conn = connections[db_name]
if conn.features.supports_sequence_reset:
sql_list = \
conn.ops.sequence_reset_by_name_sql(no_style(),
conn.introspection.sequence_list())
if sql_list:
with transaction.commit_on_success_unless_managed(using=db_name):
cursor = conn.cursor()
for sql in sql_list:
cursor.execute(sql)
def _fixture_setup(self):
for db_name in self._databases_names(include_mirrors=False):
# Reset sequences
if self.reset_sequences:
self._reset_sequences(db_name)
if hasattr(self, 'fixtures'):
# We have to use this slightly awkward syntax due to the fact
# that we're using *args and **kwargs together.
call_command('loaddata', *self.fixtures,
**{'verbosity': 0, 'database': db_name, 'skip_validation': True})
def _urlconf_setup(self):
set_urlconf(None)
if hasattr(self, 'urls'):
self._old_root_urlconf = settings.ROOT_URLCONF
settings.ROOT_URLCONF = self.urls
clear_url_caches()
def _post_teardown(self):
""" Performs any post-test things. This includes:
* Putting back the original ROOT_URLCONF if it was changed.
* Force closing the connection, so that the next test gets
a clean cursor.
"""
self._fixture_teardown()
self._urlconf_teardown()
# Some DB cursors include SQL statements as part of cursor
# creation. If you have a test that does rollback, the effect
# of these statements is lost, which can effect the operation
# of tests (e.g., losing a timezone setting causing objects to
# be created with the wrong time).
# To make sure this doesn't happen, get a clean connection at the
# start of every test.
for conn in connections.all():
conn.close()
def _fixture_teardown(self):
for db in self._databases_names(include_mirrors=False):
call_command('flush', verbosity=0, interactive=False, database=db,
skip_validation=True, reset_sequences=False)
def _urlconf_teardown(self):
set_urlconf(None)
if hasattr(self, '_old_root_urlconf'):
settings.ROOT_URLCONF = self._old_root_urlconf
clear_url_caches()
def assertRedirects(self, response, expected_url, status_code=302,
target_status_code=200, host=None, msg_prefix=''):
"""Asserts that a response redirected to a specific URL, and that the
@@ -787,6 +560,236 @@ class TransactionTestCase(SimpleTestCase):
msg_prefix + "Template '%s' was used unexpectedly in rendering"
" the response" % template_name)
def assertRaisesMessage(self, expected_exception, expected_message,
callable_obj=None, *args, **kwargs):
"""
Asserts that the message in a raised exception matches the passed
value.
Args:
expected_exception: Exception class expected to be raised.
expected_message: expected error message string value.
callable_obj: Function to be called.
args: Extra args.
kwargs: Extra kwargs.
"""
return six.assertRaisesRegex(self, expected_exception,
re.escape(expected_message), callable_obj, *args, **kwargs)
def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
field_kwargs=None, empty_value=''):
"""
Asserts that a form field behaves correctly with various inputs.
Args:
fieldclass: the class of the field to be tested.
valid: a dictionary mapping valid inputs to their expected
cleaned values.
invalid: a dictionary mapping invalid inputs to one or more
raised error messages.
field_args: the args passed to instantiate the field
field_kwargs: the kwargs passed to instantiate the field
empty_value: the expected clean output for inputs in empty_values
"""
if field_args is None:
field_args = []
if field_kwargs is None:
field_kwargs = {}
required = fieldclass(*field_args, **field_kwargs)
optional = fieldclass(*field_args,
**dict(field_kwargs, required=False))
# test valid inputs
for input, output in valid.items():
self.assertEqual(required.clean(input), output)
self.assertEqual(optional.clean(input), output)
# test invalid inputs
for input, errors in invalid.items():
with self.assertRaises(ValidationError) as context_manager:
required.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
with self.assertRaises(ValidationError) as context_manager:
optional.clean(input)
self.assertEqual(context_manager.exception.messages, errors)
# test required inputs
error_required = [force_text(required.error_messages['required'])]
for e in required.empty_values:
with self.assertRaises(ValidationError) as context_manager:
required.clean(e)
self.assertEqual(context_manager.exception.messages,
error_required)
self.assertEqual(optional.clean(e), empty_value)
# test that max_length and min_length are always accepted
if issubclass(fieldclass, CharField):
field_kwargs.update({'min_length':2, 'max_length':20})
self.assertTrue(isinstance(fieldclass(*field_args, **field_kwargs),
fieldclass))
def assertHTMLEqual(self, html1, html2, msg=None):
"""
Asserts that two HTML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid HTML.
"""
dom1 = assert_and_parse_html(self, html1, msg,
'First argument is not valid HTML:')
dom2 = assert_and_parse_html(self, html2, msg,
'Second argument is not valid HTML:')
if dom1 != dom2:
standardMsg = '%s != %s' % (
safe_repr(dom1, True), safe_repr(dom2, True))
diff = ('\n' + '\n'.join(difflib.ndiff(
six.text_type(dom1).splitlines(),
six.text_type(dom2).splitlines())))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
def assertHTMLNotEqual(self, html1, html2, msg=None):
"""Asserts that two HTML snippets are not semantically equivalent."""
dom1 = assert_and_parse_html(self, html1, msg,
'First argument is not valid HTML:')
dom2 = assert_and_parse_html(self, html2, msg,
'Second argument is not valid HTML:')
if dom1 == dom2:
standardMsg = '%s == %s' % (
safe_repr(dom1, True), safe_repr(dom2, True))
self.fail(self._formatMessage(msg, standardMsg))
def assertInHTML(self, needle, haystack, count=None, msg_prefix=''):
needle = assert_and_parse_html(self, needle, None,
'First argument is not valid HTML:')
haystack = assert_and_parse_html(self, haystack, None,
'Second argument is not valid HTML:')
real_count = haystack.count(needle)
if count is not None:
self.assertEqual(real_count, count,
msg_prefix + "Found %d instances of '%s' in response"
" (expected %d)" % (real_count, needle, count))
else:
self.assertTrue(real_count != 0,
msg_prefix + "Couldn't find '%s' in response" % needle)
def assertJSONEqual(self, raw, expected_data, msg=None):
try:
data = json.loads(raw)
except ValueError:
self.fail("First argument is not valid JSON: %r" % raw)
if isinstance(expected_data, six.string_types):
try:
expected_data = json.loads(expected_data)
except ValueError:
self.fail("Second argument is not valid JSON: %r" % expected_data)
self.assertEqual(data, expected_data, msg=msg)
def assertXMLEqual(self, xml1, xml2, msg=None):
"""
Asserts that two XML snippets are semantically the same.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = 'First or second argument is not valid XML\n%s' % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if not result:
standardMsg = '%s != %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
self.fail(self._formatMessage(msg, standardMsg))
def assertXMLNotEqual(self, xml1, xml2, msg=None):
"""
Asserts that two XML snippets are not semantically equivalent.
Whitespace in most cases is ignored, and attribute ordering is not
significant. The passed-in arguments must be valid XML.
"""
try:
result = compare_xml(xml1, xml2)
except Exception as e:
standardMsg = 'First or second argument is not valid XML\n%s' % e
self.fail(self._formatMessage(msg, standardMsg))
else:
if result:
standardMsg = '%s == %s' % (safe_repr(xml1, True), safe_repr(xml2, True))
self.fail(self._formatMessage(msg, standardMsg))
class TransactionTestCase(SimpleTestCase):
# Subclasses can ask for resetting of auto increment sequence before each
# test case
reset_sequences = False
def _pre_setup(self):
"""Performs any pre-test setup. This includes:
* Flushing the database.
* If the Test Case class has a 'fixtures' member, installing the
named fixtures.
"""
super(TransactionTestCase, self)._pre_setup()
self._fixture_setup()
def _databases_names(self, include_mirrors=True):
# If the test case has a multi_db=True flag, act on all databases,
# including mirrors or not. Otherwise, just on the default DB.
if getattr(self, 'multi_db', False):
return [alias for alias in connections
if include_mirrors or not connections[alias].settings_dict['TEST_MIRROR']]
else:
return [DEFAULT_DB_ALIAS]
def _reset_sequences(self, db_name):
conn = connections[db_name]
if conn.features.supports_sequence_reset:
sql_list = \
conn.ops.sequence_reset_by_name_sql(no_style(),
conn.introspection.sequence_list())
if sql_list:
with transaction.commit_on_success_unless_managed(using=db_name):
cursor = conn.cursor()
for sql in sql_list:
cursor.execute(sql)
def _fixture_setup(self):
for db_name in self._databases_names(include_mirrors=False):
# Reset sequences
if self.reset_sequences:
self._reset_sequences(db_name)
if hasattr(self, 'fixtures'):
# We have to use this slightly awkward syntax due to the fact
# that we're using *args and **kwargs together.
call_command('loaddata', *self.fixtures,
**{'verbosity': 0, 'database': db_name, 'skip_validation': True})
def _post_teardown(self):
"""Performs any post-test things. This includes:
* Putting back the original ROOT_URLCONF if it was changed.
* Force closing the connection, so that the next test gets
a clean cursor.
"""
self._fixture_teardown()
super(TransactionTestCase, self)._post_teardown()
# Some DB cursors include SQL statements as part of cursor
# creation. If you have a test that does rollback, the effect
# of these statements is lost, which can effect the operation
# of tests (e.g., losing a timezone setting causing objects to
# be created with the wrong time).
# To make sure this doesn't happen, get a clean connection at the
# start of every test.
for conn in connections.all():
conn.close()
def _fixture_teardown(self):
for db_name in self._databases_names(include_mirrors=False):
call_command('flush', verbosity=0, interactive=False, database=db_name,
skip_validation=True, reset_sequences=False)
def assertQuerysetEqual(self, qs, values, transform=repr, ordered=True):
items = six.moves.map(transform, qs)
if not ordered:
@@ -841,14 +844,14 @@ class TestCase(TransactionTestCase):
# Remove this when the legacy transaction management goes away.
disable_transaction_methods()
for db in self._databases_names(include_mirrors=False):
for db_name in self._databases_names(include_mirrors=False):
if hasattr(self, 'fixtures'):
try:
call_command('loaddata', *self.fixtures,
**{
'verbosity': 0,
'commit': False,
'database': db,
'database': db_name,
'skip_validation': True,
})
except Exception: