From 24d26b7daa46f92d304f47ee93f9b82342f0a4db Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Mon, 14 Nov 2022 16:54:47 +0400 Subject: [PATCH 1/4] Standardize HTTPResponse.read(X) behavior regardless of compression Co-authored-by: Franek Magiera (cherry picked from commit c35033f6cc54106ca66ef2d48a9e3564d4fb0e07) --- changelog/2128.removal.rst | 7 + src/urllib3/response.py | 169 +++++++++++++++++----- test/test_response.py | 147 ++++++++++++++++--- test/with_dummyserver/test_socketlevel.py | 9 +- 4 files changed, 271 insertions(+), 61 deletions(-) create mode 100644 changelog/2128.removal.rst diff --git a/changelog/2128.removal.rst b/changelog/2128.removal.rst new file mode 100644 index 00000000..abf412f9 --- /dev/null +++ b/changelog/2128.removal.rst @@ -0,0 +1,7 @@ +Standardized :meth:`~urllib3.response.HTTPResponse.read` to respect the semantics of BufferedIOBase regardless of compression. Specifically, this method: + +* only returns an empty bytes object to indicate EOF (that is, the response has been fully consumed), +* never returns more bytes than requested, +* can issue any number of system calls: zero, one or multiple. + +If you want each :meth:`~urllib3.response.HTTPResponse.read` call to issue a single system call, you need to disable decompression by setting ``decode_content=False``. diff --git a/src/urllib3/response.py b/src/urllib3/response.py index c112690b..92d23a46 100644 --- a/src/urllib3/response.py +++ b/src/urllib3/response.py @@ -1,8 +1,10 @@ from __future__ import absolute_import from contextlib import contextmanager +import collections import zlib import io import logging +import sys from socket import timeout as SocketTimeout from socket import error as SocketError @@ -121,6 +123,63 @@ def _get_decoder(mode): return DeflateDecoder() +class BytesQueueBuffer: + """Memory-efficient bytes buffer + + To return decoded data in read() and still follow the BufferedIOBase API, we need a + buffer to always return the correct amount of bytes. + + This buffer should be filled using calls to put() + + Our maximum memory usage is determined by the sum of the size of: + + * self.buffer, which contains the full data + * the largest chunk that we will copy in get() + + The worst case scenario is a single chunk, in which case we'll make a full copy of + the data inside get(). + """ + + def __init__(self): + self.buffer = collections.deque() + self._size = 0 + + def __len__(self): + return self._size + + def put(self, data): + self.buffer.append(data) + self._size += len(data) + + def get(self, n): + if not self.buffer: + raise RuntimeError("buffer is empty") + elif n < 0: + raise ValueError("n should be > 0") + + fetched = 0 + ret = io.BytesIO() + while fetched < n: + remaining = n - fetched + chunk = self.buffer.popleft() + chunk_length = len(chunk) + if remaining < chunk_length: + left_chunk, right_chunk = chunk[:remaining], chunk[remaining:] + ret.write(left_chunk) + self.buffer.appendleft(right_chunk) + self._size -= remaining + break + else: + ret.write(chunk) + self._size -= chunk_length + fetched += chunk_length + + if not self.buffer: + break + + return ret.getvalue() + + class HTTPResponse(io.IOBase): """ HTTP Response container. @@ -204,6 +263,9 @@ class HTTPResponse(io.IOBase): # Determine length of response self.length_remaining = self._init_length(request_method) + # Used to return the correct amount of bytes for partial read()s + self._decoded_buffer = BytesQueueBuffer() + # If requested, preload the body. 
if preload_content and not self._body: self._body = self.read(decode_content=decode_content) @@ -401,6 +463,49 @@ class HTTPResponse(io.IOBase): if self._original_response and self._original_response.isclosed(): self.release_conn() + def _raw_read(self, amt=None): + """ + Reads `amt` of bytes from the socket. + """ + if self._fp is None: + return None # type: ignore[return-value] + + fp_closed = getattr(self._fp, "closed", False) + + with self._error_catcher(): + if amt is None: + # cStringIO doesn't like amt=None + data = self._fp.read() if not fp_closed else b"" + else: + data = self._fp.read(amt) if not fp_closed else b"" + if amt is not None and amt != 0 and not data: + # Platform-specific: Buggy versions of Python. + # Close the connection when no data is returned + # + # This is redundant to what httplib/http.client _should_ + # already do. However, versions of python released before + # December 15, 2012 (http://bugs.python.org/issue16298) do + # not properly close the connection in all cases. There is + # no harm in redundantly calling close. + self._fp.close() + if ( + self.enforce_content_length + and self.length_remaining is not None + and self.length_remaining != 0 + ): + # This is an edge case that httplib failed to cover due + # to concerns of backward compatibility. We're + # addressing it here to make sure IncompleteRead is + # raised during streaming, so all calls with incorrect + # Content-Length are caught. + raise IncompleteRead(self._fp_bytes_read, self.length_remaining) + + if data: + self._fp_bytes_read += len(data) + if self.length_remaining is not None: + self.length_remaining -= len(data) + return data + def read(self, amt=None, decode_content=None, cache_content=False): """ Similar to :meth:`httplib.HTTPResponse.read`, but with two additional @@ -426,47 +531,43 @@ class HTTPResponse(io.IOBase): if decode_content is None: decode_content = self.decode_content - if self._fp is None: - return + if amt is not None: + cache_content = False + + if len(self._decoded_buffer) >= amt: + return self._decoded_buffer.get(amt) + + data = self._raw_read(amt) flush_decoder = False - data = None + if amt is None: + flush_decoder = True + elif amt != 0 and not data: + flush_decoder = True - with self._error_catcher(): - if amt is None: - # cStringIO doesn't like amt=None - data = self._fp.read() - flush_decoder = True - else: - cache_content = False - data = self._fp.read(amt) - if amt != 0 and not data: # Platform-specific: Buggy versions of Python. - # Close the connection when no data is returned - # - # This is redundant to what httplib/http.client _should_ - # already do. However, versions of python released before - # December 15, 2012 (http://bugs.python.org/issue16298) do - # not properly close the connection in all cases. There is - # no harm in redundantly calling close. - self._fp.close() - flush_decoder = True - if self.enforce_content_length and self.length_remaining not in (0, None): - # This is an edge case that httplib failed to cover due - # to concerns of backward compatibility. We're - # addressing it here to make sure IncompleteRead is - # raised during streaming, so all calls with incorrect - # Content-Length are caught. 
- raise IncompleteRead(self._fp_bytes_read, self.length_remaining) - - if data: - self._fp_bytes_read += len(data) - if self.length_remaining is not None: - self.length_remaining -= len(data) + if not data and len(self._decoded_buffer) == 0: + return data + if amt is None: data = self._decode(data, decode_content, flush_decoder) - if cache_content: self._body = data + else: + # do not waste memory on buffer when not decoding + if not decode_content: + return data + + decoded_data = self._decode(data, decode_content, flush_decoder) + self._decoded_buffer.put(decoded_data) + + while len(self._decoded_buffer) < amt and data: + # TODO make sure to initially read enough data to get past the headers + # For example, the GZ file header takes 10 bytes, we don't want to read + # it one byte at a time + data = self._raw_read(amt) + decoded_data = self._decode(data, decode_content, flush_decoder) + self._decoded_buffer.put(decoded_data) + data = self._decoded_buffer.get(amt) return data diff --git a/test/test_response.py b/test/test_response.py index ae70dff9..a84a4885 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,4 +1,6 @@ import socket +import ssl +import sys import zlib from io import BytesIO, BufferedReader @@ -6,7 +8,7 @@ from io import BytesIO, BufferedReader import pytest import mock -from urllib3.response import HTTPResponse +from urllib3.response import HTTPResponse, BytesQueueBuffer from urllib3.exceptions import ( DecodeError, ResponseNotChunked, ProtocolError, InvalidHeader ) @@ -16,6 +18,56 @@ from urllib3.util.response import is_fp_closed from base64 import b64decode + +class TestBytesQueueBuffer: + def test_single_chunk(self): + buffer = BytesQueueBuffer() + assert len(buffer) == 0 + with pytest.raises(RuntimeError, match="buffer is empty"): + assert buffer.get(10) + + buffer.put(b"foo") + with pytest.raises(ValueError, match="n should be > 0"): + buffer.get(-1) + + assert buffer.get(1) == b"f" + assert buffer.get(2) == b"oo" + with pytest.raises(RuntimeError, match="buffer is empty"): + assert buffer.get(10) + + def test_read_too_much(self): + buffer = BytesQueueBuffer() + buffer.put(b"foo") + assert buffer.get(100) == b"foo" + + def test_multiple_chunks(self): + buffer = BytesQueueBuffer() + buffer.put(b"foo") + buffer.put(b"bar") + buffer.put(b"baz") + assert len(buffer) == 9 + + assert buffer.get(1) == b"f" + assert len(buffer) == 8 + assert buffer.get(4) == b"ooba" + assert len(buffer) == 4 + assert buffer.get(4) == b"rbaz" + assert len(buffer) == 0 + + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="pytest-memray requires Python 3.8+" + ) + @pytest.mark.limit_memory("12.5 MB") # assert that we're not doubling memory usage + def test_memory_usage(self): + # Allocate 10 1MiB chunks + buffer = BytesQueueBuffer() + for i in range(10): + # This allocates 2MiB, putting the max at around 12MiB. Not sure why. 
+ buffer.put(bytes(2**20)) + + assert len(buffer.get(10 * 2**20)) == 10 * 2**20 + + # A known random (i.e, not-too-compressible) payload generated with: # "".join(random.choice(string.printable) for i in xrange(512)) # .encode("zlib").encode("base64") @@ -119,12 +171,7 @@ class TestResponse(object): r = HTTPResponse(fp, headers={'content-encoding': 'deflate'}, preload_content=False) - assert r.read(3) == b'' - # Buffer in case we need to switch to the raw stream - assert r._decoder._data is not None assert r.read(1) == b'f' - # Now that we've decoded data, we just stream through the decoder - assert r._decoder._data is None assert r.read(2) == b'oo' assert r.read() == b'' assert r.read() == b'' @@ -138,10 +185,7 @@ class TestResponse(object): r = HTTPResponse(fp, headers={'content-encoding': 'deflate'}, preload_content=False) - assert r.read(1) == b'' assert r.read(1) == b'f' - # Once we've decoded data, we just stream to the decoder; no buffering - assert r._decoder._data is None assert r.read(2) == b'oo' assert r.read() == b'' assert r.read() == b'' @@ -155,7 +199,6 @@ class TestResponse(object): r = HTTPResponse(fp, headers={'content-encoding': 'gzip'}, preload_content=False) - assert r.read(11) == b'' assert r.read(1) == b'f' assert r.read(2) == b'oo' assert r.read() == b'' @@ -240,6 +283,23 @@ class TestResponse(object): assert r.data == b'foo' + def test_read_multi_decoding_deflate_deflate(self): + msg = b"foobarbaz" * 42 + data = zlib.compress(zlib.compress(msg)) + + fp = BytesIO(data) + r = HTTPResponse( + fp, headers={"content-encoding": "deflate, deflate"}, preload_content=False + ) + + assert r.read(3) == b"foo" + assert r.read(3) == b"bar" + assert r.read(3) == b"baz" + assert r.read(9) == b"foobarbaz" + assert r.read(9 * 3) == b"foobarbaz" * 3 + assert r.read(9 * 37) == b"foobarbaz" * 37 + assert r.read() == b"" + def test_body_blob(self): resp = HTTPResponse(b'foo') assert resp.data == b'foo' @@ -361,8 +421,8 @@ class TestResponse(object): preload_content=False) stream = resp.stream(2) - assert next(stream) == b'f' - assert next(stream) == b'oo' + assert next(stream) == b'fo' + assert next(stream) == b'o' with pytest.raises(StopIteration): next(stream) @@ -390,6 +450,7 @@ class TestResponse(object): # Ensure that ``tell()`` returns the correct number of bytes when # part-way through streaming compressed content. NUMBER_OF_READS = 10 + PART_SIZE = 64 class MockCompressedDataReading(BytesIO): """ @@ -416,7 +477,7 @@ class TestResponse(object): fp = MockCompressedDataReading(ZLIB_PAYLOAD, payload_part_size) resp = HTTPResponse(fp, headers={'content-encoding': 'deflate'}, preload_content=False) - stream = resp.stream() + stream = resp.stream(PART_SIZE) parts_positions = [(part, resp.tell()) for part in stream] end_of_stream = resp.tell() @@ -431,12 +492,28 @@ class TestResponse(object): assert uncompressed_data == payload # Check that the positions in the stream are correct - expected = [(i+1)*payload_part_size for i in range(NUMBER_OF_READS)] - assert expected == list(positions) + # It is difficult to determine programatically what the positions + # returned by `tell` will be because the `HTTPResponse.read` method may + # call socket `read` a couple of times if it doesn't have enough data + # in the buffer or not call socket `read` at all if it has enough. All + # this depends on the message, how it was compressed, what is + # `PART_SIZE` and `payload_part_size`. + # So for simplicity the expected values are hardcoded. 
+ expected = (92, 184, 230, 276, 322, 368, 414, 460) + assert expected == positions # Check that the end of the stream is in the correct place assert len(ZLIB_PAYLOAD) == end_of_stream + # Check that all parts have expected length + expected_last_part_size = len(uncompressed_data) % PART_SIZE + whole_parts = len(uncompressed_data) // PART_SIZE + if expected_last_part_size == 0: + expected_lengths = [PART_SIZE] * whole_parts + else: + expected_lengths = [PART_SIZE] * whole_parts + [expected_last_part_size] + assert expected_lengths == [len(part) for part in parts] + def test_deflate_streaming(self): data = zlib.compress(b'foo') @@ -445,8 +522,8 @@ class TestResponse(object): preload_content=False) stream = resp.stream(2) - assert next(stream) == b'f' - assert next(stream) == b'oo' + assert next(stream) == b'fo' + assert next(stream) == b'o' with pytest.raises(StopIteration): next(stream) @@ -460,8 +537,8 @@ class TestResponse(object): preload_content=False) stream = resp.stream(2) - assert next(stream) == b'f' - assert next(stream) == b'oo' + assert next(stream) == b'fo' + assert next(stream) == b'o' with pytest.raises(StopIteration): next(stream) @@ -473,6 +550,38 @@ class TestResponse(object): with pytest.raises(StopIteration): next(stream) + @pytest.mark.parametrize( + "preload_content, amt", + [(True, None), (False, None), (False, 10 * 2**20)], + ) + @pytest.mark.limit_memory("25 MB") + def test_buffer_memory_usage_decode_one_chunk( + self, preload_content, amt + ): + content_length = 10 * 2**20 # 10 MiB + fp = BytesIO(zlib.compress(bytes(content_length))) + resp = HTTPResponse( + fp, + preload_content=preload_content, + headers={"content-encoding": "deflate"}, + ) + data = resp.data if preload_content else resp.read(amt) + assert len(data) == content_length + + @pytest.mark.parametrize( + "preload_content, amt", + [(True, None), (False, None), (False, 10 * 2**20)], + ) + @pytest.mark.limit_memory("10.5 MB") + def test_buffer_memory_usage_no_decoding( + self, preload_content, amt + ): + content_length = 10 * 2**20 # 10 MiB + fp = BytesIO(bytes(content_length)) + resp = HTTPResponse(fp, preload_content=preload_content, decode_content=False) + data = resp.data if preload_content else resp.read(amt) + assert len(data) == content_length + def test_length_no_header(self): fp = BytesIO(b'12345') resp = HTTPResponse(fp, preload_content=False) diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index 8f2cf9e9..24990eac 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -1445,15 +1445,8 @@ class TestBadContentLength(SocketDummyServerTestCase): get_response = conn.request('GET', url='/', preload_content=False, enforce_content_length=True) data = get_response.stream(100) - # Read "good" data before we try to read again. - # This won't trigger till generator is exhausted. 
- next(data) - try: + with pytest.raises(ProtocolError, match="12 bytes read, 10 more expected"): next(data) - self.assertFail() - except ProtocolError as e: - self.assertIn('12 bytes read, 10 more expected', str(e)) - done_event.set() def test_enforce_content_length_no_body(self): -- 2.52.0 From adae14464c96a72570f7db22ea6053c5dd6c97e9 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Wed, 3 May 2023 15:46:21 -0500 Subject: [PATCH 2/4] Continue reading the response stream if there is buffered decompressed data (cherry picked from commit 4714836a667eb4837d005eb89d34fae60b9dc6cc) --- changelog/3009.bugfix | 3 ++ src/urllib3/response.py | 2 +- test/with_dummyserver/test_socketlevel.py | 49 +++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 changelog/3009.bugfix diff --git a/changelog/3009.bugfix b/changelog/3009.bugfix new file mode 100644 index 00000000..61f54a49 --- /dev/null +++ b/changelog/3009.bugfix @@ -0,0 +1,3 @@ +Fixed ``HTTPResponse.stream()`` to continue yielding bytes if buffered decompressed data +was still available to be read even if the underlying socket is closed. This prevents +a compressed response from being truncated. diff --git a/src/urllib3/response.py b/src/urllib3/response.py index 92d23a46..97278564 100644 --- a/src/urllib3/response.py +++ b/src/urllib3/response.py @@ -591,7 +591,7 @@ class HTTPResponse(io.IOBase): for line in self.read_chunked(amt, decode_content=decode_content): yield line else: - while not is_fp_closed(self._fp): + while not is_fp_closed(self._fp) or len(self._decoded_buffer) > 0: data = self.read(amt=amt, decode_content=decode_content) if data: diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index 24990eac..e2c40893 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -32,6 +32,7 @@ from threading import Event import select import socket import ssl +import zlib import pytest @@ -1415,6 +1416,54 @@ class TestStream(SocketDummyServerTestCase): done_event.set() + def test_large_compressed_stream(self): + done_event = Event() + expected_total_length = 296085 + + def socket_handler(listener: socket.socket): + compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) + data = compress.compress(b"x" * expected_total_length) + data += compress.flush() + + sock = listener.accept()[0] + + buf = b"" + while not buf.endswith(b"\r\n\r\n"): + buf += sock.recv(65536) + + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: %d\r\n" + b"Content-Encoding: gzip\r\n" + b"\r\n" % (len(data),) + data + ) + + done_event.wait(5) + sock.close() + + self._start_server(socket_handler) + + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("GET", "/", timeout=0.01, preload_content=False) + + # Chunks must all be equal or less than 10240 + # and only the last chunk is allowed to be smaller + # than 10240. 
+ total_length = 0 + chunks_smaller_than_10240 = 0 + for chunk in r.stream(10240, decode_content=True): + assert 0 < len(chunk) <= 10240 + if len(chunk) < 10240: + chunks_smaller_than_10240 += 1 + else: + assert chunks_smaller_than_10240 == 0 + total_length += len(chunk) + + assert chunks_smaller_than_10240 == 1 + assert expected_total_length == total_length + + done_event.set() + class TestBadContentLength(SocketDummyServerTestCase): def test_enforce_content_length_get(self): -- 2.52.0 From 51e483688d776140cf268ced21b677db4f4366a4 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Sun, 30 Apr 2023 00:29:17 +0400 Subject: [PATCH 3/4] Fix HTTPResponse.read(0) when underlying buffer is empty (#2998) (cherry picked from commit 02ae65a45654bb3ced12b9ad22278c11e214aaf8) --- changelog/2998.bugfix.rst | 1 + src/urllib3/response.py | 2 ++ test/test_response.py | 3 +++ 3 files changed, 6 insertions(+) create mode 100644 changelog/2998.bugfix.rst diff --git a/changelog/2998.bugfix.rst b/changelog/2998.bugfix.rst new file mode 100644 index 00000000..584f309a --- /dev/null +++ b/changelog/2998.bugfix.rst @@ -0,0 +1 @@ +Fixed ``HTTPResponse.read(0)`` call when underlying buffer is empty. diff --git a/src/urllib3/response.py b/src/urllib3/response.py index 97278564..aaba4e50 100644 --- a/src/urllib3/response.py +++ b/src/urllib3/response.py @@ -152,6 +152,8 @@ class BytesQueueBuffer: self._size += len(data) def get(self, n): + if n == 0: + return b"" if not self.buffer: raise RuntimeError("buffer is empty") elif n < 0: diff --git a/test/test_response.py b/test/test_response.py index a84a4885..3f3427a8 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -26,6 +26,8 @@ class TestBytesQueueBuffer: with pytest.raises(RuntimeError, match="buffer is empty"): assert buffer.get(10) + assert buffer.get(0) == b"" + buffer.put(b"foo") with pytest.raises(ValueError, match="n should be > 0"): buffer.get(-1) @@ -143,6 +145,7 @@ class TestResponse(object): fp = BytesIO(b'foo') r = HTTPResponse(fp, preload_content=False) + assert r.read(0) == b'' assert r.read(1) == b'f' assert r.read(2) == b'oo' assert r.read() == b'' -- 2.52.0 From 88c9cfbd87a4b0e6db3b7d2a7fa3cfd20d4ca34d Mon Sep 17 00:00:00 2001 From: Illia Volochii Date: Fri, 5 Dec 2025 16:40:41 +0200 Subject: [PATCH 4/4] Security fix for CVE-2025-66471 (cherry picked from commit c19571de34c47de3a766541b041637ba5f716ed7) --- CHANGES.rst | 9 ++ docs/advanced-usage.rst | 4 +- docs/user-guide.rst | 4 +- src/urllib3/response.py | 177 +++++++++++++++++++++++++++++++++------- test/test_response.py | 164 +++++++++++++++++++++++++++++++++++++ 5 files changed, 324 insertions(+), 34 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 1e264e28..cf315aaf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,15 @@ Changes ======= +Backports +--------- + +- Fixed a security issue where streaming API could improperly handle highly + compressed HTTP content ("decompression bombs") leading to excessive resource + consumption even when a small amount of data was requested. Reading small + chunks of compressed data is safer and much more efficient now. + + 1.24.2 (2019-04-17) ------------------- diff --git a/docs/advanced-usage.rst b/docs/advanced-usage.rst index 283e9a45..d13b569a 100644 --- a/docs/advanced-usage.rst +++ b/docs/advanced-usage.rst @@ -52,8 +52,8 @@ multi-threaded applications. 
Streaming and IO ---------------- -When dealing with large responses it's often better to stream the response -content:: +When dealing with responses of large or unknown length, +it's often better to stream the response content:: >>> import urllib3 >>> http = urllib3.PoolManager() diff --git a/docs/user-guide.rst b/docs/user-guide.rst index 11c94f3e..9f22caf7 100644 --- a/docs/user-guide.rst +++ b/docs/user-guide.rst @@ -74,8 +74,8 @@ to a byte string representing the response content:: >>> r.data b'\xaa\xa5H?\x95\xe9\x9b\x11' -.. note:: For larger responses, it's sometimes better to :ref:`stream ` - the response. +.. note:: For responses of large or unknown length, it's sometimes better to + :ref:`stream ` the response. .. _request_data: diff --git a/src/urllib3/response.py b/src/urllib3/response.py index aaba4e50..60bd98c9 100644 --- a/src/urllib3/response.py +++ b/src/urllib3/response.py @@ -5,13 +5,14 @@ import zlib import io import logging import sys +import warnings from socket import timeout as SocketTimeout from socket import error as SocketError from ._collections import HTTPHeaderDict from .exceptions import ( BodyNotHttplibCompatible, ProtocolError, DecodeError, ReadTimeoutError, - ResponseNotChunked, IncompleteRead, InvalidHeader + ResponseNotChunked, IncompleteRead, InvalidHeader, DependencyWarning ) from .packages.six import string_types as basestring, PY3 from .packages.six.moves import http_client as httplib @@ -25,33 +26,60 @@ class DeflateDecoder(object): def __init__(self): self._first_try = True - self._data = b'' + self._first_try_data = b"" + self._unfed_data = b"" self._obj = zlib.decompressobj() def __getattr__(self, name): return getattr(self._obj, name) - def decompress(self, data): - if not data: + def decompress(self, data, max_length=-1): + data = self._unfed_data + data + self._unfed_data = b"" + if not data and not self._obj.unconsumed_tail: return data + original_max_length = max_length + if original_max_length < 0: + max_length = 0 + elif original_max_length == 0: + # We should not pass 0 to the zlib decompressor because 0 is + # the default value that will make zlib decompress without a + # length limit. + # Data should be stored for subsequent calls. + self._unfed_data = data + return b"" + # Subsequent calls always reuse `self._obj`. zlib requires + # passing the unconsumed tail if decompression is to continue. if not self._first_try: - return self._obj.decompress(data) + return self._obj.decompress( + self._obj.unconsumed_tail + data, max_length=max_length + ) - self._data += data + # First call tries with RFC 1950 ZLIB format. + self._first_try_data += data try: - decompressed = self._obj.decompress(data) + decompressed = self._obj.decompress(data, max_length=max_length) if decompressed: self._first_try = False - self._data = None + self._first_try_data = b"" return decompressed + # On failure, it falls back to RFC 1951 DEFLATE format. 
except zlib.error: self._first_try = False self._obj = zlib.decompressobj(-zlib.MAX_WBITS) try: - return self.decompress(self._data) + return self.decompress( + self._first_try_data, max_length=original_max_length + ) finally: - self._data = None + self._first_try_data = b"" + + @property + def has_unconsumed_tail(self): + return bool(self._unfed_data) or ( + bool(self._obj.unconsumed_tail) and not self._first_try + ) class GzipDecoderState(object): @@ -66,30 +94,64 @@ class GzipDecoder(object): def __init__(self): self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) self._state = GzipDecoderState.FIRST_MEMBER + self._unconsumed_tail = b"" def __getattr__(self, name): return getattr(self._obj, name) - def decompress(self, data): + def decompress(self, data, max_length=-1): ret = bytearray() - if self._state == GzipDecoderState.SWALLOW_DATA or not data: + if self._state == GzipDecoderState.SWALLOW_DATA: return bytes(ret) + + if max_length == 0: + # We should not pass 0 to the zlib decompressor because 0 is + # the default value that will make zlib decompress without a + # length limit. + # Data should be stored for subsequent calls. + self._unconsumed_tail += data + return b"" + + # zlib requires passing the unconsumed tail to the subsequent + # call if decompression is to continue. + data = self._unconsumed_tail + data + if not data and self._obj.eof: + return bytes(ret) + while True: try: - ret += self._obj.decompress(data) + ret += self._obj.decompress( + data, max_length=max(max_length - len(ret), 0) + ) except zlib.error: previous_state = self._state # Ignore data after the first error self._state = GzipDecoderState.SWALLOW_DATA + self._unconsumed_tail = b"" if previous_state == GzipDecoderState.OTHER_MEMBERS: # Allow trailing garbage acceptable in other gzip clients return bytes(ret) raise - data = self._obj.unused_data + + self._unconsumed_tail = data = ( + self._obj.unconsumed_tail or self._obj.unused_data + ) + if max_length > 0 and len(ret) >= max_length: + break + if not data: return bytes(ret) - self._state = GzipDecoderState.OTHER_MEMBERS - self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) + # When the end of a gzip member is reached, a new decompressor + # must be created for unused (possibly future) data. + if self._obj.eof: + self._state = GzipDecoderState.OTHER_MEMBERS + self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) + + return bytes(ret) + + @property + def has_unconsumed_tail(self): + return bool(self._unconsumed_tail) class MultiDecoder(object): @@ -107,10 +169,35 @@ class MultiDecoder(object): def flush(self): return self._decoders[0].flush() - def decompress(self, data): - for d in reversed(self._decoders): - data = d.decompress(data) - return data + def decompress(self, data, max_length=-1): + if max_length <= 0: + for d in reversed(self._decoders): + data = d.decompress(data) + return data + + ret = bytearray() + # Every while loop iteration goes through all decoders once. + # It exits when enough data is read or no more data can be read. + # It is possible that the while loop iteration does not produce + # any data because we retrieve up to `max_length` from every + # decoder, and the amount of bytes may be insufficient for the + # next decoder to produce enough/any output. + while True: + any_data = False + for d in reversed(self._decoders): + data = d.decompress(data, max_length=max_length - len(ret)) + if data: + any_data = True + # We should not break when no data is returned because + # next decoders may produce data even with empty input. 
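+                # (a decoder can still emit bytes that an earlier max_length
+                # limit left buffered in its unconsumed tail)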
+ ret += data + if not any_data or len(ret) >= max_length: + return bytes(ret) + data = b"" + + @property + def has_unconsumed_tail(self): + return any(d.has_unconsumed_tail for d in self._decoders) def _get_decoder(mode): @@ -135,9 +222,6 @@ class BytesQueueBuffer: * self.buffer, which contains the full data * the largest chunk that we will copy in get() - - The worst case scenario is a single chunk, in which case we'll make a full copy of - the data inside get(). """ def __init__(self): @@ -159,6 +243,10 @@ class BytesQueueBuffer: elif n < 0: raise ValueError("n should be > 0") + if len(self.buffer[0]) == n and isinstance(self.buffer[0], bytes): + self._size -= n + return self.buffer.popleft() + fetched = 0 ret = io.BytesIO() while fetched < n: @@ -379,13 +467,16 @@ class HTTPResponse(io.IOBase): if len(encodings): self._decoder = _get_decoder(content_encoding) - def _decode(self, data, decode_content, flush_decoder): + def _decode(self, data, decode_content, flush_decoder, max_length=None): """ Decode the data passed in and potentially flush the decoder. """ + if max_length is None or flush_decoder: + max_length = -1 + try: if decode_content and self._decoder: - data = self._decoder.decompress(data) + data = self._decoder.decompress(data, max_length=max_length) except (IOError, zlib.error) as e: content_encoding = self.headers.get('content-encoding', '').lower() raise DecodeError( @@ -536,6 +627,14 @@ class HTTPResponse(io.IOBase): if amt is not None: cache_content = False + if self._decoder and self._decoder.has_unconsumed_tail: + decoded_data = self._decode( + b"", + decode_content, + flush_decoder=False, + max_length=amt - len(self._decoded_buffer), + ) + self._decoded_buffer.put(decoded_data) if len(self._decoded_buffer) >= amt: return self._decoded_buffer.get(amt) @@ -547,7 +646,11 @@ class HTTPResponse(io.IOBase): elif amt != 0 and not data: flush_decoder = True - if not data and len(self._decoded_buffer) == 0: + if ( + not data + and len(self._decoded_buffer) == 0 + and not (self._decoder and self._decoder.has_unconsumed_tail) + ): return data if amt is None: @@ -559,7 +662,12 @@ class HTTPResponse(io.IOBase): if not decode_content: return data - decoded_data = self._decode(data, decode_content, flush_decoder) + decoded_data = self._decode( + data, + decode_content, + flush_decoder, + max_length=amt - len(self._decoded_buffer), + ) self._decoded_buffer.put(decoded_data) while len(self._decoded_buffer) < amt and data: @@ -567,7 +675,12 @@ class HTTPResponse(io.IOBase): # For example, the GZ file header takes 10 bytes, we don't want to read # it one byte at a time data = self._raw_read(amt) - decoded_data = self._decode(data, decode_content, flush_decoder) + decoded_data = self._decode( + data, + decode_content, + flush_decoder, + max_length=amt - len(self._decoded_buffer), + ) self._decoded_buffer.put(decoded_data) data = self._decoded_buffer.get(amt) @@ -593,7 +706,11 @@ class HTTPResponse(io.IOBase): for line in self.read_chunked(amt, decode_content=decode_content): yield line else: - while not is_fp_closed(self._fp) or len(self._decoded_buffer) > 0: + while ( + not is_fp_closed(self._fp) + or len(self._decoded_buffer) > 0 + or (self._decoder and self._decoder.has_unconsumed_tail) + ): data = self.read(amt=amt, decode_content=decode_content) if data: @@ -771,7 +888,7 @@ class HTTPResponse(io.IOBase): break chunk = self._handle_chunk(amt) decoded = self._decode(chunk, decode_content=decode_content, - flush_decoder=False) + flush_decoder=False, max_length=amt) if decoded: yield 
decoded diff --git a/test/test_response.py b/test/test_response.py index 3f3427a8..347276e8 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,3 +1,4 @@ +import gzip import socket import ssl import sys @@ -19,6 +20,14 @@ from urllib3.util.response import is_fp_closed from base64 import b64decode +def deflate2_compress(data): + compressor = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS) + return compressor.compress(data) + compressor.flush() + + +_zstd_available = False + + class TestBytesQueueBuffer: def test_single_chunk(self): buffer = BytesQueueBuffer() @@ -254,6 +263,161 @@ class TestResponse(object): assert r.data == b'foofoofoo' + _test_compressor_params = [ + ("deflate1", ("deflate", zlib.compress)), + ("deflate2", ("deflate", deflate2_compress)), + ("gzip", ("gzip", gzip.compress)), + ] + if _zstd_available: + _test_compressor_params.append(("zstd", ("zstd", zstd_compress))) + else: + _test_compressor_params.append(("zstd", None)) + + @pytest.mark.parametrize("read_method", ("read",)) + @pytest.mark.parametrize( + "data", + [d[1] for d in _test_compressor_params], + ids=[d[0] for d in _test_compressor_params], + ) + def test_read_with_all_data_already_in_decompressor( + self, + request, + read_method, + data, + ): + if data is None: + pytest.skip(f"Proper {request.node.callspec.id} decoder is not available") + original_data = b"bar" * 1000 + name, compress_func = data + compressed_data = compress_func(original_data) + fp = mock.Mock(read=mock.Mock(return_value=b"")) + r = HTTPResponse(fp, headers={"content-encoding": name}, preload_content=False) + # Put all data in the decompressor's buffer. + r._init_decoder() + assert r._decoder is not None # for mypy + decoded = r._decoder.decompress(compressed_data, max_length=0) + if name == "br": + # It's known that some Brotli libraries do not respect + # `max_length`. + r._decoded_buffer.put(decoded) + else: + assert decoded == b"" + # Read the data via `HTTPResponse`. + read = getattr(r, read_method) + assert read(0) == b"" + assert read(2500) == original_data[:2500] + assert read(500) == original_data[2500:] + assert read(0) == b"" + assert read() == b"" + + @pytest.mark.parametrize( + "delta", + ( + 0, # First read from socket returns all compressed data. + -1, # First read from socket returns all but one byte of compressed data. + ), + ) + @pytest.mark.parametrize("read_method", ("read",)) + @pytest.mark.parametrize( + "data", + [d[1] for d in _test_compressor_params], + ids=[d[0] for d in _test_compressor_params], + ) + def test_decode_with_max_length_close_to_compressed_data_size( + self, + request, + delta, + read_method, + data, + ): + """ + Test decoding when the first read from the socket returns all or + almost all the compressed data, but then it has to be + decompressed in a couple of read calls. + """ + if data is None: + pytest.skip(f"Proper {request.node.callspec.id} decoder is not available") + + original_data = b"foo" * 1000 + name, compress_func = data + compressed_data = compress_func(original_data) + fp = BytesIO(compressed_data) + r = HTTPResponse(fp, headers={"content-encoding": name}, preload_content=False) + initial_limit = len(compressed_data) + delta + read = getattr(r, read_method) + initial_chunk = read(amt=initial_limit, decode_content=True) + assert len(initial_chunk) == initial_limit + assert ( + len(read(amt=len(original_data), decode_content=True)) + == len(original_data) - initial_limit + ) + + # Prepare 50 MB of compressed data outside of the test measuring + # memory usage. 
+ _test_memory_usage_decode_with_max_length_params = [ + ( + params[0], + (params[1][0], params[1][1](b"A" * (50 * 2**20))) if params[1] else None, + ) + for params in _test_compressor_params + ] + + @pytest.mark.parametrize( + "data", + [d[1] for d in _test_memory_usage_decode_with_max_length_params], + ids=[d[0] for d in _test_memory_usage_decode_with_max_length_params], + ) + @pytest.mark.parametrize("read_method", ("read", "read_chunked", "stream")) + # Decoders consume different amounts of memory during decompression. + # We set the 10 MB limit to ensure that the whole decompressed data + # is not stored unnecessarily. + # + # FYI, the following consumption was observed for the test with + # `read` on CPython 3.14.0: + # - deflate: 2.3 MiB + # - deflate2: 2.1 MiB + # - gzip: 2.1 MiB + # - brotli: + # - brotli v1.2.0: 9 MiB + # - brotlicffi v1.2.0.0: 6 MiB + # - brotlipy v0.7.0: 105.8 MiB + # - zstd: 4.5 MiB + @pytest.mark.limit_memory("10 MB", current_thread_only=True) + def test_memory_usage_decode_with_max_length( + self, + request, + read_method, + data, + ): + if data is None: + pytest.skip(f"Proper {request.node.callspec.id} decoder is not available") + + name, compressed_data = data + limit = 1024 * 1024 # 1 MiB + if read_method in ("read_chunked", "stream"): + httplib_r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + httplib_r.fp = MockChunkedEncodingResponse([compressed_data]) # type: ignore[assignment] + r = HTTPResponse( + httplib_r, + preload_content=False, + headers={"transfer-encoding": "chunked", "content-encoding": name}, + ) + next(getattr(r, read_method)(amt=limit, decode_content=True)) + else: + fp = BytesIO(compressed_data) + r = HTTPResponse( + fp, headers={"content-encoding": name}, preload_content=False + ) + getattr(r, read_method)(amt=limit, decode_content=True) + + # Check that the internal decoded buffer is empty unless brotli + # is used. + # Google's brotli library does not fully respect the output + # buffer limit: https://github.com/google/brotli/issues/1396 + # And unmaintained brotlipy cannot limit the output buffer size. + if name != "br" or brotli.__name__ == "brotlicffi": + assert len(r._decoded_buffer) == 0 + def test_multi_decoding_deflate_deflate(self): data = zlib.compress(zlib.compress(b'foo')) -- 2.52.0
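
Taken together, patches 1 and 3 give ``HTTPResponse.read()`` BufferedIOBase semantics for compressed bodies. The following is only a sketch of the resulting behaviour, mirroring the patterns used in ``test_response.py``: the response is built directly over a ``BytesIO``, as the tests do, rather than over a live connection::

    import zlib
    from io import BytesIO

    from urllib3.response import HTTPResponse

    fp = BytesIO(zlib.compress(b"foobarbaz"))
    r = HTTPResponse(fp, headers={"content-encoding": "deflate"},
                     preload_content=False)

    assert r.read(0) == b""     # no-op even before anything is buffered (patch 3)
    assert r.read(3) == b"foo"  # never returns more bytes than requested
    assert r.read(3) == b"bar"
    assert r.read() == b"baz"   # drain whatever remains
    assert r.read() == b""      # b"" only once the response is fully consumed

As noted in changelog/2128.removal.rst, passing ``decode_content=False`` restores the old one-``read()``-one-syscall behaviour when that is needed.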
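
The CVE-2025-66471 fix (patch 4) matters most for highly compressible bodies: with ``max_length`` plumbed through the decoders, a small compressed payload is inflated in bounded pieces instead of all at once, and patch 2 keeps ``stream()`` yielding buffered decoded data even after the underlying socket is exhausted. A hedged sketch along the lines of ``test_large_compressed_stream`` and ``test_memory_usage_decode_with_max_length``; the payload and chunk sizes below are arbitrary::

    import zlib
    from io import BytesIO

    from urllib3.response import HTTPResponse

    # ~10 MiB of zeros compresses down to a few KiB, i.e. "bomb"-shaped input
    body = zlib.compress(bytes(10 * 2**20))
    r = HTTPResponse(BytesIO(body), headers={"content-encoding": "deflate"},
                     preload_content=False)

    total = 0
    for chunk in r.stream(2**20, decode_content=True):
        # each decoded chunk is capped at the requested 1 MiB, so peak memory
        # tracks the chunk size rather than the decompressed size
        assert 0 < len(chunk) <= 2**20
        total += len(chunk)
    assert total == 10 * 2**20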