From 92b2dec918d7b73456899ecb5726cc03b86bd068 Mon Sep 17 00:00:00 2001
From: Claude Paroz <claude@2xlibre.net>
Date: Thu, 9 Aug 2012 20:08:47 +0200
Subject: [PATCH] [py3] Made signing infrastructure pass tests with Python 3

---
 django/core/signing.py                 | 19 +++++++++----------
 tests/regressiontests/signing/tests.py | 10 +++++++---
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/django/core/signing.py b/django/core/signing.py
index 9ab8c5b8b0..6fc76bc201 100644
--- a/django/core/signing.py
+++ b/django/core/signing.py
@@ -32,6 +32,8 @@ start of the base64 JSON.
 There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
 These functions make use of all of them.
 """
+from __future__ import unicode_literals
+
 import base64
 import json
 import time
@@ -41,7 +43,7 @@ from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
 from django.utils import baseconv
 from django.utils.crypto import constant_time_compare, salted_hmac
-from django.utils.encoding import force_text, smart_bytes
+from django.utils.encoding import smart_bytes
 from django.utils.importlib import import_module
 
 
@@ -60,12 +62,12 @@ class SignatureExpired(BadSignature):
 
 
 def b64_encode(s):
-    return base64.urlsafe_b64encode(s).strip('=')
+    return base64.urlsafe_b64encode(smart_bytes(s)).decode('ascii').strip('=')
 
 
 def b64_decode(s):
     pad = '=' * (-len(s) % 4)
-    return base64.urlsafe_b64decode(s + pad)
+    return base64.urlsafe_b64decode(smart_bytes(s + pad)).decode('ascii')
 
 
 def base64_hmac(salt, value, key):
@@ -121,7 +123,7 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer,
 
     if compress:
         # Avoid zlib dependency unless compress is being used
-        compressed = zlib.compress(data)
+        compressed = zlib.compress(smart_bytes(data))
         if len(compressed) < (len(data) - 1):
             data = compressed
             is_compressed = True
@@ -135,8 +137,7 @@ def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, ma
     """
     Reverse of dumps(), raises BadSignature if signature fails
     """
-    base64d = smart_bytes(
-        TimestampSigner(key, salt=salt).unsign(s, max_age=max_age))
+    base64d = TimestampSigner(key, salt=salt).unsign(s, max_age=max_age)
     decompress = False
     if base64d[0] == '.':
         # It's compressed; uncompress it first
@@ -159,16 +160,14 @@ class Signer(object):
         return base64_hmac(self.salt + 'signer', value, self.key)
 
     def sign(self, value):
-        value = smart_bytes(value)
         return '%s%s%s' % (value, self.sep, self.signature(value))
 
     def unsign(self, signed_value):
-        signed_value = smart_bytes(signed_value)
         if not self.sep in signed_value:
             raise BadSignature('No "%s" found in value' % self.sep)
         value, sig = signed_value.rsplit(self.sep, 1)
         if constant_time_compare(sig, self.signature(value)):
-            return force_text(value)
+            return value
         raise BadSignature('Signature "%s" does not match' % sig)
 
 
@@ -178,7 +177,7 @@ class TimestampSigner(Signer):
         return baseconv.base62.encode(int(time.time()))
 
     def sign(self, value):
-        value = smart_bytes('%s%s%s' % (value, self.sep, self.timestamp()))
+        value = '%s%s%s' % (value, self.sep, self.timestamp())
         return '%s%s%s' % (value, self.sep, self.signature(value))
 
     def unsign(self, value, max_age=None):
diff --git a/tests/regressiontests/signing/tests.py b/tests/regressiontests/signing/tests.py
index 2368405060..7145ec8b18 100644
--- a/tests/regressiontests/signing/tests.py
+++ b/tests/regressiontests/signing/tests.py
@@ -4,6 +4,7 @@ import time
 
 from django.core import signing
 from django.test import TestCase
+from django.utils import six
 from django.utils.encoding import force_text
 
 
@@ -69,15 +70,18 @@ class TestSigner(TestCase):
 
     def test_dumps_loads(self):
         "dumps and loads be reversible for any JSON serializable object"
-        objects = (
+        objects = [
             ['a', 'list'],
-            b'a string',
             'a unicode string \u2019',
             {'a': 'dictionary'},
-        )
+        ]
+        if not six.PY3:
+            objects.append(b'a byte string')
         for o in objects:
             self.assertNotEqual(o, signing.dumps(o))
             self.assertEqual(o, signing.loads(signing.dumps(o)))
+            self.assertNotEqual(o, signing.dumps(o, compress=True))
+            self.assertEqual(o, signing.loads(signing.dumps(o, compress=True)))
 
     def test_decode_detects_tampering(self):
         "loads should raise exception for tampered objects"