"""Unit tests for the ``itsdangerous`` serializers and helper functions.

The serializer test cases are layered through inheritance: each subclass
swaps ``serializer_class`` (and, for the pickle variants, the underlying
``serializer`` module) and thereby re-runs every inherited test against a
different serializer implementation.
"""
import time
import pickle
import hashlib
import unittest
from datetime import datetime

import itsdangerous as idmod
from itsdangerous import want_bytes, text_type, PY2


# Objects round-tripped by the generic dumps/loads tests.  The tests only
# read these values, so sharing one tuple between them is safe.
_SAMPLE_OBJECTS = (
    ['a', 'list'],
    'a string',
    u'a unicode string \u2019',
    {'a': 'dictionary'},
    42,
    42.5,
)


# Helper function for some unsafe string manipulation on encoded
# data.  This is required for Python 3 but would break on Python 2.
if PY2:
    def _coerce_string(reference_string, value):
        """Return *value* unchanged; no coercion is done on Python 2."""
        return value
else:
    def _coerce_string(reference_string, value):
        """Make *value* concatenable with *reference_string*.

        *value* must be a text string.  When *reference_string* is of a
        different type, *value* is encoded as UTF-8 before being returned;
        otherwise it is returned unchanged.
        """
        assert isinstance(value, text_type), 'rhs needs to be a string'
        if type(reference_string) is not type(value):
            value = value.encode('utf-8')
        return value


class UtilityTestCase(unittest.TestCase):
    """Tests for module-level helper functions."""

    def test_want_bytes(self):
        # Both byte and text input must come back as bytes.
        self.assertEqual(want_bytes(b"foobar"), b"foobar")
        self.assertEqual(want_bytes(u"foobar"), b"foobar")


class SerializerTestCase(unittest.TestCase):
    """Tests for ``Serializer``; subclasses override ``serializer_class``."""

    serializer_class = idmod.Serializer

    def make_serializer(self, *args, **kwargs):
        """Instantiate ``serializer_class`` with the given arguments."""
        return self.serializer_class(*args, **kwargs)

    def test_dumps_loads(self):
        # Every sample object must change when dumped and round-trip
        # unchanged through loads().
        s = self.make_serializer('Test')
        for o in _SAMPLE_OBJECTS:
            value = s.dumps(o)
            self.assertNotEqual(o, value)
            self.assertEqual(o, s.loads(value))

    def test_decode_detects_tampering(self):
        s = self.make_serializer('Test')

        # Each transform mangles the signed string in a different way.
        # The lambda parameter is named ``data`` so that it does not
        # shadow the serializer ``s`` above.
        transforms = (
            lambda data: data.upper(),
            lambda data: data + _coerce_string(data, 'a'),
            lambda data: _coerce_string(data, 'a') + data[1:],
            lambda data: data.replace(_coerce_string(data, '.'),
                                      _coerce_string(data, '')),
        )
        value = {
            'foo': 'bar',
            'baz': 1,
        }
        encoded = s.dumps(value)
        self.assertEqual(value, s.loads(encoded))
        for transform in transforms:
            self.assertRaises(
                idmod.BadSignature, s.loads, transform(encoded))

    def test_accepts_unicode(self):
        # Same round trip as test_dumps_loads; the sample objects include
        # a non-ASCII unicode string.
        s = self.make_serializer('Test')
        for o in _SAMPLE_OBJECTS:
            value = s.dumps(o)
            self.assertNotEqual(o, value)
            self.assertEqual(o, s.loads(value))

    def test_exception_attributes(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)

        try:
            # Appending a character invalidates the signature.
            s.loads(ts + _coerce_string(ts, 'x'))
        except idmod.BadSignature as e:
            # The exception carries the payload (everything before the
            # last '.'), which can still be decoded.
            self.assertEqual(want_bytes(e.payload),
                             want_bytes(ts).rsplit(b'.', 1)[0])
            self.assertEqual(s.load_payload(e.payload), value)
        else:
            self.fail('Did not get bad signature')

    def test_unsafe_load(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)
        # loads_unsafe() returns (signature_ok, payload) instead of raising.
        self.assertEqual(s.loads_unsafe(ts), (True, u'hello'))
        self.assertEqual(s.loads_unsafe(ts, salt='modified'),
                         (False, u'hello'))

    def test_load_unsafe_with_unicode_strings(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)
        self.assertEqual(s.loads_unsafe(ts), (True, u'hello'))
        self.assertEqual(s.loads_unsafe(ts, salt='modified'),
                         (False, u'hello'))

        try:
            s.loads(ts, salt='modified')
        except idmod.BadSignature as e:
            self.assertEqual(s.load_payload(e.payload), u'hello')
        else:
            # Without this branch the test would pass vacuously when
            # loads() accepts a token signed with a different salt.
            self.fail('Did not get bad signature')

    def test_signer_kwargs(self):
        secret_key = 'predictable-key'
        value = 'hello'
        # signer_kwargs are passed as keyword arguments to the serializer.
        s = self.make_serializer(secret_key, signer_kwargs=dict(
            digest_method=hashlib.md5,
            key_derivation='hmac'
        ))
        ts = s.dumps(value)
        self.assertEqual(s.loads(ts), u'hello')


class TimedSerializerTestCase(SerializerTestCase):
    """Re-runs the serializer tests against ``TimedSerializer``.

    ``time.time`` is replaced for the duration of each test so that it
    returns ``idmod.EPOCH``, and is restored afterwards.
    """

    serializer_class = idmod.TimedSerializer

    def setUp(self):
        # Keep the real clock so tearDown can restore it.
        self._time = time.time
        time.time = lambda: idmod.EPOCH

    def tearDown(self):
        time.time = self._time

    def test_decode_with_timeout(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)
        # The timed output must differ from the plain Serializer output.
        self.assertNotEqual(ts, idmod.Serializer(secret_key).dumps(value))

        self.assertEqual(s.loads(ts), value)
        # Advance the fake clock by ten seconds: max_age of 11 and 10 are
        # still accepted, 9 is not.
        time.time = lambda: idmod.EPOCH + 10
        self.assertEqual(s.loads(ts, max_age=11), value)
        self.assertEqual(s.loads(ts, max_age=10), value)
        self.assertRaises(
            idmod.SignatureExpired, s.loads, ts, max_age=9)

    def test_decode_return_timestamp(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)
        loaded, timestamp = s.loads(ts, return_timestamp=True)
        self.assertEqual(loaded, value)
        self.assertEqual(timestamp, datetime.utcfromtimestamp(time.time()))

    def test_exception_attributes(self):
        secret_key = 'predictable-key'
        value = u'hello'

        s = self.make_serializer(secret_key)
        ts = s.dumps(value)
        try:
            # A negative max_age makes the token count as expired.
            s.loads(ts, max_age=-1)
        except idmod.SignatureExpired as e:
            self.assertEqual(e.date_signed,
                             datetime.utcfromtimestamp(time.time()))
            # Two trailing '.'-separated fields are stripped here, one
            # more than in the untimed variant of this test.
            self.assertEqual(want_bytes(e.payload),
                             want_bytes(ts).rsplit(b'.', 2)[0])
            self.assertEqual(s.load_payload(e.payload), value)
        else:
            self.fail('Did not get expiration')


class JSONWebSignatureSerializerTestCase(SerializerTestCase):
    """Re-runs the serializer tests against ``JSONWebSignatureSerializer``."""

    serializer_class = idmod.JSONWebSignatureSerializer

    def test_decode_return_header(self):
        secret_key = 'predictable-key'
        value = u'hello'
        header = {"typ": "dummy"}

        s = self.make_serializer(secret_key)
        # The returned header is expected to be the supplied fields plus
        # the serializer's algorithm name under 'alg'.
        full_header = header.copy()
        full_header['alg'] = s.algorithm_name

        ts = s.dumps(value, header_fields=header)
        loaded, loaded_header = s.loads(ts, return_header=True)
        self.assertEqual(loaded, value)
        self.assertEqual(loaded_header, full_header)

    def test_hmac_algorithms(self):
        secret_key = 'predictable-key'
        value = u'hello'

        algorithms = ('HS256', 'HS384', 'HS512')
        for algorithm in algorithms:
            s = self.make_serializer(secret_key, algorithm_name=algorithm)
            ts = s.dumps(value)
            self.assertEqual(s.loads(ts), value)

    def test_none_algorithm(self):
        secret_key = 'predictable-key'
        value = u'hello'

        # Request the 'none' algorithm explicitly; with the default
        # algorithm this test would not cover what its name says.
        s = self.make_serializer(secret_key, algorithm_name='none')
        ts = s.dumps(value)
        self.assertEqual(s.loads(ts), value)

    def test_algorithm_mismatch(self):
        secret_key = 'predictable-key'
        value = u'hello'

        # Sign with HS256, then try to load with an HS384 serializer.
        s = self.make_serializer(secret_key, algorithm_name='HS256')
        ts = s.dumps(value)

        s = self.make_serializer(secret_key, algorithm_name='HS384')
        try:
            s.loads(ts)
        except idmod.BadSignature as e:
            self.assertEqual(s.load_payload(e.payload), value)
        else:
            self.fail('Did not get algorithm mismatch')


class TimedJSONWebSignatureSerializerTest(unittest.TestCase):
    """Tests for the expiry handling of ``TimedJSONWebSignatureSerializer``."""

    serializer_class = idmod.TimedJSONWebSignatureSerializer

    def test_token_contains_issue_date_and_expiry_time(self):
        s = self.serializer_class('secret')
        result = s.dumps({'es': 'geht'})
        # Decode once and check both header fields on the same header.
        header = s.loads(result, return_header=True)[1]
        self.assertTrue('exp' in header)
        self.assertTrue('iat' in header)

    def test_token_expires_at_given_expiry_time(self):
        s = self.serializer_class('secret')
        # Issue the token 3601 seconds in the past by overriding now().
        an_hour_ago = int(time.time()) - 3601
        s.now = lambda: an_hour_ago
        result = s.dumps({'foo': 'bar'})
        # A fresh serializer (real clock) must reject it as expired.
        s = self.serializer_class('secret')
        self.assertRaises(idmod.SignatureExpired, s.loads, result)

    def test_token_is_invalid_if_expiry_time_is_missing(self):
        # Tokens from the untimed serializer carry no expiry header.
        bad_s = idmod.JSONWebSignatureSerializer('secret')
        invalid_token_empty = bad_s.dumps({})
        s = self.serializer_class('secret')
        self.assertRaises(idmod.BadSignature, s.loads, invalid_token_empty)

    def test_token_is_invalid_if_expiry_time_is_negative(self):
        s = self.serializer_class('secret', expires_in=-123)
        result = s.dumps({'foo': 'bar'})
        self.assertRaises(idmod.BadSignature, s.loads, result)

    def test_creating_a_token_adds_the_expiry_date(self):
        expires_in_two_hours = 7200
        s = self.serializer_class('secret', expires_in=expires_in_two_hours)
        result, header = s.loads(s.dumps({'foo': 'bar'}), return_header=True)
        self.assertEqual(header['exp'] - header['iat'], expires_in_two_hours)


class URLSafeSerializerMixin(object):
    """Extra tests for the URL-safe serializers.

    Relies on ``make_serializer`` from the test case it is mixed into.
    """

    def test_is_base62(self):
        # Despite the name, the accepted alphabet is alphanumerics plus
        # '_', '-' and '.'.
        allowed = frozenset(b'0123456789abcdefghijklmnopqrstuvwxyz' +
                            b'ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.')
        s = self.make_serializer('Test')
        for o in _SAMPLE_OBJECTS:
            value = want_bytes(s.dumps(o))
            self.assertTrue(set(value).issubset(allowed))
            self.assertNotEqual(o, value)
            self.assertEqual(o, s.loads(value))

    def test_invalid_base64_does_not_fail_load_payload(self):
        # A malformed payload must surface as BadPayload.
        s = idmod.URLSafeSerializer('aha!')
        self.assertRaises(idmod.BadPayload, s.load_payload, b'kZ4m3du844lIN')


class PickleSerializerMixin(object):
    """Makes the mixed-in test case use ``pickle`` as the serializer."""

    def make_serializer(self, *args, **kwargs):
        """Build the serializer, defaulting ``serializer`` to ``pickle``."""
        kwargs.setdefault('serializer', pickle)
        return super(PickleSerializerMixin, self).make_serializer(
            *args, **kwargs)


class URLSafeSerializerTestCase(URLSafeSerializerMixin, SerializerTestCase):
    """Serializer tests against ``URLSafeSerializer``."""

    serializer_class = idmod.URLSafeSerializer


class URLSafeTimedSerializerTestCase(URLSafeSerializerMixin,
                                     TimedSerializerTestCase):
    """Timed serializer tests against ``URLSafeTimedSerializer``."""

    serializer_class = idmod.URLSafeTimedSerializer


class PickleSerializerTestCase(PickleSerializerMixin, SerializerTestCase):
    """Serializer tests with ``pickle`` as the serializer."""


class PickleTimedSerializerTestCase(PickleSerializerMixin,
                                    TimedSerializerTestCase):
    """Timed serializer tests with ``pickle`` as the serializer."""


class PickleURLSafeSerializerTestCase(PickleSerializerMixin,
                                      URLSafeSerializerTestCase):
    """URL-safe serializer tests with ``pickle`` as the serializer."""


class PickleURLSafeTimedSerializerTestCase(PickleSerializerMixin,
                                           URLSafeTimedSerializerTestCase):
    """URL-safe timed serializer tests with ``pickle`` as the serializer."""


if __name__ == '__main__':
    unittest.main()