From 90477a3b6e73da69740e00b8161f53fea19b831f Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Tue, 5 Mar 2024 16:57:17 -0500 Subject: [PATCH] Address potential DoS with high compression ratio Fixes CVE-2024-28102 Signed-off-by: Simo Sorce --- jwcrypto/jwe.py | 7 +++++++ jwcrypto/tests.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 9412881..5df500b 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -9,5 +9,8 @@ from jwcrypto.jwa import JWA +# Limit the size of compressed data we are willing to decompress by default. +default_max_compressed_size = 256 * 1024 + # RFC 7516 - 4.1 # name: (description, supported?) @@ -374,6 +377,10 @@ def _decrypt(self, key, ppe): compress = jh.get('zip', None) if compress == 'DEF': + if len(data) > default_max_compressed_size: + raise InvalidJWEData( + 'Compressed data exceeds maximum allowed ' + 'size' + f' ({default_max_compressed_size})') self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS) elif compress is None: self.plaintext = data diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index bb2ff10..59049f8 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -1196,6 +1196,32 @@ def test_pbes2_hs256_aeskw_custom_params(self): check.deserialize(enc, key) self.assertEqual(b'plain', check.payload) + def test_jwe_decompression_max(self): + key = jwk.JWK(kty='oct', k=base64url_encode(b'A' * (128 // 8))) + payload = '{"u": "' + "u" * 400000000 + '", "uu":"' \ + + "u" * 400000000 + '"}' + protected_header = { + "alg": "A128KW", + "enc": "A128GCM", + "typ": "JWE", + "zip": "DEF", + } + enc = jwe.JWE(payload.encode('utf-8'), + recipient=key, + protected=protected_header).serialize(compact=True) + with self.assertRaises(jwe.InvalidJWEData): + check = jwe.JWE() + check.deserialize(enc) + check.decrypt(key) + + defmax = jwe.default_max_compressed_size + self.addCleanup(setattr, jwe, 'default_max_compressed_size', defmax) + jwe.default_max_compressed_size = 1000000000 + # ensure we can raise the limit and decrypt + check = jwe.JWE() + 
check.deserialize(enc) + check.decrypt(key) + class JWATests(unittest.TestCase): def test_jwa_create(self):