279 lines
8.9 KiB
Python
279 lines
8.9 KiB
Python
import time
|
|
import pickle
|
|
import hashlib
|
|
import unittest
|
|
from datetime import datetime
|
|
|
|
import itsdangerous as idmod
|
|
from itsdangerous import want_bytes, text_type, PY2
|
|
|
|
|
|
# Helper function for some unsafe string manipulation on encoded
|
|
# data. This is required for Python 3 but would break on Python 2
|
|
if PY2:
|
|
def _coerce_string(reference_string, value):
|
|
return value
|
|
else:
|
|
def _coerce_string(reference_string, value):
|
|
assert isinstance(value, text_type), 'rhs needs to be a string'
|
|
if type(reference_string) != type(value):
|
|
value = value.encode('utf-8')
|
|
return value
|
|
|
|
|
|
class UtilityTestCase(unittest.TestCase):
|
|
def test_want_bytes(self):
|
|
self.assertEqual(want_bytes(b"foobar"), b"foobar")
|
|
self.assertEqual(want_bytes(u"foobar"), b"foobar")
|
|
|
|
|
|
class SerializerTestCase(unittest.TestCase):
|
|
serializer_class = idmod.Serializer
|
|
|
|
def make_serializer(self, *args, **kwargs):
|
|
return self.serializer_class(*args, **kwargs)
|
|
|
|
def test_dumps_loads(self):
|
|
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
|
|
{'a': 'dictionary'}, 42, 42.5)
|
|
s = self.make_serializer('Test')
|
|
for o in objects:
|
|
value = s.dumps(o)
|
|
self.assertNotEqual(o, value)
|
|
self.assertEqual(o, s.loads(value))
|
|
|
|
def test_decode_detects_tampering(self):
|
|
s = self.make_serializer('Test')
|
|
|
|
transforms = (
|
|
lambda s: s.upper(),
|
|
lambda s: s + _coerce_string(s, 'a'),
|
|
lambda s: _coerce_string(s, 'a') + s[1:],
|
|
lambda s: s.replace(_coerce_string(s, '.'), _coerce_string(s, '')),
|
|
)
|
|
value = {
|
|
'foo': 'bar',
|
|
'baz': 1,
|
|
}
|
|
encoded = s.dumps(value)
|
|
self.assertEqual(value, s.loads(encoded))
|
|
for transform in transforms:
|
|
self.assertRaises(
|
|
idmod.BadSignature, s.loads, transform(encoded))
|
|
|
|
def test_accepts_unicode(self):
|
|
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
|
|
{'a': 'dictionary'}, 42, 42.5)
|
|
s = self.make_serializer('Test')
|
|
for o in objects:
|
|
value = s.dumps(o)
|
|
self.assertNotEqual(o, value)
|
|
self.assertEqual(o, s.loads(value))
|
|
|
|
def test_exception_attributes(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
|
|
try:
|
|
s.loads(ts + _coerce_string(ts, 'x'))
|
|
except idmod.BadSignature as e:
|
|
self.assertEqual(want_bytes(e.payload),
|
|
want_bytes(ts).rsplit(b'.', 1)[0])
|
|
self.assertEqual(s.load_payload(e.payload), value)
|
|
else:
|
|
self.fail('Did not get bad signature')
|
|
|
|
def test_unsafe_load(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
self.assertEqual(s.loads_unsafe(ts), (True, u'hello'))
|
|
self.assertEqual(s.loads_unsafe(ts, salt='modified'), (False, u'hello'))
|
|
|
|
def test_load_unsafe_with_unicode_strings(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
self.assertEqual(s.loads_unsafe(ts), (True, u'hello'))
|
|
self.assertEqual(s.loads_unsafe(ts, salt='modified'), (False, u'hello'))
|
|
|
|
try:
|
|
s.loads(ts, salt='modified')
|
|
except idmod.BadSignature as e:
|
|
self.assertEqual(s.load_payload(e.payload), u'hello')
|
|
|
|
def test_signer_kwargs(self):
|
|
secret_key = 'predictable-key'
|
|
value = 'hello'
|
|
s = self.make_serializer(secret_key, signer_kwargs=dict(
|
|
digest_method=hashlib.md5,
|
|
key_derivation='hmac'
|
|
))
|
|
ts = s.dumps(value)
|
|
self.assertEqual(s.loads(ts), u'hello')
|
|
|
|
|
|
class TimedSerializerTestCase(SerializerTestCase):
|
|
serializer_class = idmod.TimedSerializer
|
|
|
|
def setUp(self):
|
|
self._time = time.time
|
|
time.time = lambda: idmod.EPOCH
|
|
|
|
def tearDown(self):
|
|
time.time = self._time
|
|
|
|
def test_decode_with_timeout(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
self.assertNotEqual(ts, idmod.Serializer(secret_key).dumps(value))
|
|
|
|
self.assertEqual(s.loads(ts), value)
|
|
time.time = lambda: idmod.EPOCH + 10
|
|
self.assertEqual(s.loads(ts, max_age=11), value)
|
|
self.assertEqual(s.loads(ts, max_age=10), value)
|
|
self.assertRaises(
|
|
idmod.SignatureExpired, s.loads, ts, max_age=9)
|
|
|
|
def test_decode_return_timestamp(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
loaded, timestamp = s.loads(ts, return_timestamp=True)
|
|
self.assertEqual(loaded, value)
|
|
self.assertEqual(timestamp, datetime.utcfromtimestamp(time.time()))
|
|
|
|
def test_exception_attributes(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
try:
|
|
s.loads(ts, max_age=-1)
|
|
except idmod.SignatureExpired as e:
|
|
self.assertEqual(e.date_signed,
|
|
datetime.utcfromtimestamp(time.time()))
|
|
self.assertEqual(want_bytes(e.payload),
|
|
want_bytes(ts).rsplit(b'.', 2)[0])
|
|
self.assertEqual(s.load_payload(e.payload), value)
|
|
else:
|
|
self.fail('Did not get expiration')
|
|
|
|
|
|
class JSONWebSignatureSerializerTestCase(SerializerTestCase):
|
|
serializer_class = idmod.JSONWebSignatureSerializer
|
|
|
|
def test_decode_return_header(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
header = {"typ": "dummy"}
|
|
|
|
s = self.make_serializer(secret_key)
|
|
full_header = header.copy()
|
|
full_header['alg'] = s.algorithm_name
|
|
|
|
ts = s.dumps(value, header_fields=header)
|
|
loaded, loaded_header = s.loads(ts, return_header=True)
|
|
self.assertEqual(loaded, value)
|
|
self.assertEqual(loaded_header, full_header)
|
|
|
|
def test_hmac_algorithms(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
algorithms = ('HS256', 'HS384', 'HS512')
|
|
for algorithm in algorithms:
|
|
s = self.make_serializer(secret_key, algorithm_name=algorithm)
|
|
ts = s.dumps(value)
|
|
self.assertEqual(s.loads(ts), value)
|
|
|
|
def test_none_algorithm(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key)
|
|
ts = s.dumps(value)
|
|
self.assertEqual(s.loads(ts), value)
|
|
|
|
def test_algorithm_mismatch(self):
|
|
secret_key = 'predictable-key'
|
|
value = u'hello'
|
|
|
|
s = self.make_serializer(secret_key, algorithm_name='HS256')
|
|
ts = s.dumps(value)
|
|
|
|
s = self.make_serializer(secret_key, algorithm_name='HS384')
|
|
try:
|
|
s.loads(ts)
|
|
except idmod.BadSignature as e:
|
|
self.assertEqual(s.load_payload(e.payload), value)
|
|
else:
|
|
self.fail('Did not get algorithm mismatch')
|
|
|
|
|
|
class URLSafeSerializerMixin(object):
|
|
|
|
def test_is_base62(self):
|
|
allowed = frozenset(b'0123456789abcdefghijklmnopqrstuvwxyz' +
|
|
b'ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.')
|
|
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
|
|
{'a': 'dictionary'}, 42, 42.5)
|
|
s = self.make_serializer('Test')
|
|
for o in objects:
|
|
value = want_bytes(s.dumps(o))
|
|
self.assertTrue(set(value).issubset(set(allowed)))
|
|
self.assertNotEqual(o, value)
|
|
self.assertEqual(o, s.loads(value))
|
|
|
|
def test_invalid_base64_does_not_fail_load_payload(self):
|
|
s = idmod.URLSafeSerializer('aha!')
|
|
self.assertRaises(idmod.BadPayload, s.load_payload, b'kZ4m3du844lIN')
|
|
|
|
|
|
class PickleSerializerMixin(object):
|
|
|
|
def make_serializer(self, *args, **kwargs):
|
|
kwargs.setdefault('serializer', pickle)
|
|
return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs)
|
|
|
|
|
|
class URLSafeSerializerTestCase(URLSafeSerializerMixin, SerializerTestCase):
|
|
serializer_class = idmod.URLSafeSerializer
|
|
|
|
|
|
class URLSafeTimedSerializerTestCase(URLSafeSerializerMixin, TimedSerializerTestCase):
|
|
serializer_class = idmod.URLSafeTimedSerializer
|
|
|
|
|
|
class PickleSerializerTestCase(PickleSerializerMixin, SerializerTestCase):
|
|
pass
|
|
|
|
|
|
class PickleTimedSerializerTestCase(PickleSerializerMixin, TimedSerializerTestCase):
|
|
pass
|
|
|
|
|
|
class PickleURLSafeSerializerTestCase(PickleSerializerMixin, URLSafeSerializerTestCase):
|
|
pass
|
|
|
|
|
|
class PickleURLSafeTimedSerializerTestCase(PickleSerializerMixin, URLSafeTimedSerializerTestCase):
|
|
pass
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|