diff --git a/CHANGES/12830.bugfix.rst b/CHANGES/12830.bugfix.rst new file mode 100644 index 00000000000..d44d76da404 --- /dev/null +++ b/CHANGES/12830.bugfix.rst @@ -0,0 +1 @@ +Bounded the number of parsed-but-unhandled pipelined HTTP/1 requests buffered per connection on the server; once the queue reaches an internal limit the parser stops emitting and the transport is paused, resuming as the request handler drains the queue, so a client keeping one handler busy can no longer accumulate an unbounded backlog of pipelined requests -- by :user:`bdraco`. diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 700e9db7f2e..825e5238b0b 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -325,6 +325,8 @@ cdef class HttpParser: list _messages bint _more_data_available bint _paused + Py_ssize_t _msg_in_flight + Py_ssize_t _max_msg_queue_size bint _eof_pending object _payload unsigned long long _content_length_expected @@ -361,6 +363,7 @@ cdef class HttpParser: size_t max_field_size=8190, payload_exception=None, bint response_with_body=True, bint read_until_eof=False, bint auto_decompress=True, + Py_ssize_t max_msg_queue_size=0, ): cparser.llhttp_settings_init(self._csettings) cparser.llhttp_init(self._cparser, mode, self._csettings) @@ -375,6 +378,8 @@ cdef class HttpParser: self._buf = bytearray() self._more_data_available = False self._paused = False + self._msg_in_flight = 0 + self._max_msg_queue_size = max_msg_queue_size self._eof_pending = False self._payload = None self._payload_error = 0 @@ -558,6 +563,11 @@ cdef class HttpParser: assert self._payload is not None self._paused = True + def message_consumed(self): + # Protocol drained a queued message; free a slot for parsing. + if self._msg_in_flight > 0: + self._msg_in_flight -= 1 + def feed_eof(self): cdef bytes desc @@ -680,12 +690,12 @@ cdef class HttpRequestParser(HttpParser): size_t max_line_size=8190, size_t max_headers=128, size_t max_field_size=8190, payload_exception=None, bint response_with_body=True, bint read_until_eof=False, - bint auto_decompress=True, + bint auto_decompress=True, Py_ssize_t max_msg_queue_size=0, ): self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer, max_line_size, max_headers, max_field_size, payload_exception, response_with_body, read_until_eof, - auto_decompress) + auto_decompress, max_msg_queue_size) cdef object _on_status_complete(self): cdef int idx1, idx2 @@ -894,6 +904,12 @@ cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: pyparser._last_error = exc return -1 else: + if pyparser._max_msg_queue_size: + pyparser._msg_in_flight += 1 + if pyparser._msg_in_flight >= pyparser._max_msg_queue_size: + # Queue full: pause llhttp between messages. feed_data() buffers + # the remainder as tail; resumes once the queue drains. + return cparser.HPE_PAUSED return 0 diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index f1f6edc3836..df3f8c089ac 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -8,6 +8,13 @@ if TYPE_CHECKING: from .http_parser import HttpParser +# Raised by transport.pause_reading()/resume_reading() when the transport +# does not support flow control; safe to ignore. +# NOTE: Catch these with a plain try/except/pass, never contextlib.suppress(): +# pause/resume run on the hot read path and suppress() is ~6x slower than +# try/except here (it builds a context manager and unpacks this tuple per call). +PAUSE_RESUME_READING_ERRORS = (AttributeError, NotImplementedError, RuntimeError) + class BaseProtocol(asyncio.Protocol): __slots__ = ( @@ -65,9 +72,15 @@ def pause_reading(self) -> None: if self.transport is not None: try: self.transport.pause_reading() - except (AttributeError, NotImplementedError, RuntimeError): + except PAUSE_RESUME_READING_ERRORS: + # Transport lacks flow control; nothing to pause. Intentionally + # ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress). pass + def _reading_paused_for_msg_queue(self) -> bool: + """Keep the transport paused for protocol-specific reasons (overridden).""" + return False + def resume_reading(self, resume_parser: bool = True) -> None: self._reading_paused = False @@ -77,10 +90,16 @@ def resume_reading(self, resume_parser: bool = True) -> None: # Reading may have been paused again in the above call if there was a lot of # compressed data still pending. - if not self._reading_paused and self.transport is not None: + if ( + not self._reading_paused + and not self._reading_paused_for_msg_queue() + and self.transport is not None + ): try: self.transport.resume_reading() - except (AttributeError, NotImplementedError, RuntimeError): + except PAUSE_RESUME_READING_ERRORS: + # Transport lacks flow control; nothing to resume. Intentionally + # ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress). pass self._reading_paused = False diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index cd4342677aa..8d48b877c0b 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -266,6 +266,7 @@ def __init__( response_with_body: bool = True, read_until_eof: bool = False, auto_decompress: bool = True, + max_msg_queue_size: int = 0, ) -> None: self.protocol = protocol self.loop = loop @@ -288,6 +289,9 @@ def __init__( self._auto_decompress = auto_decompress self._limit = limit self._headers_parser = HeadersParser(max_field_size, self.lax) + # Stop emitting messages once this many are queued unconsumed (0 = off). + self._max_msg_queue_size = max_msg_queue_size + self._msg_in_flight = 0 @abc.abstractmethod def parse_message(self, lines: list[bytes]) -> _MsgT: ... @@ -299,6 +303,11 @@ def pause_reading(self) -> None: assert self._payload_parser is not None self._payload_parser.pause_reading() + def message_consumed(self) -> None: + """Protocol drained a queued message; free a slot for parsing.""" + if self._msg_in_flight > 0: + self._msg_in_flight -= 1 + def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() @@ -340,6 +349,15 @@ def feed_data( # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: + if ( + self._max_msg_queue_size + and self._msg_in_flight >= self._max_msg_queue_size + ): + # Queue full: buffer the rest and stop. Safe pause point; + # any preceding body is consumed before the next request + # line. Resumes via feed_data(b"") when the queue drains. + self._tail = data[start_pos:] + break pos = data.find(SEP, start_pos) # consume \r\n if pos == start_pos and not self._lines: @@ -484,6 +502,8 @@ def get_content_length() -> int | None: payload = EMPTY_PAYLOAD messages.append((msg, payload)) + if self._max_msg_queue_size: + self._msg_in_flight += 1 should_close = msg.should_close else: self._tail = data[start_pos:] diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 1e297111bef..96cd9401b5d 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -14,7 +14,7 @@ from propcache import under_cached_property from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter -from .base_protocol import BaseProtocol +from .base_protocol import PAUSE_RESUME_READING_ERRORS, BaseProtocol from .helpers import DEFAULT_CHUNK_SIZE, ceil_timeout, frozen_dataclass_decorator from .http import ( HttpProcessingError, @@ -35,6 +35,11 @@ __all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError") +# Max parsed-but-unhandled pipelined requests buffered per connection before +# reading is paused. Bounds memory a client can pin by keeping one handler busy +# and pipelining behind it; reading resumes as the queue drains. +MAX_MSG_QUEUE_SIZE = 32 + if TYPE_CHECKING: import ssl @@ -168,6 +173,9 @@ class RequestHandler(BaseProtocol, Generic[_Request]): "_keepalive_timeout", "_lingering_time", "_messages", + "_max_msg_queue_size", + "_msg_queue_resume_size", + "_msg_queue_paused", "_message_tail", "_handler_waiter", "_waiter", @@ -206,6 +214,13 @@ def __init__( auto_decompress: bool = True, timeout_ceil_threshold: float = 5, ): + self._max_msg_queue_size = MAX_MSG_QUEUE_SIZE + # Low-water mark: resume reading once the queue drains to half the limit + # so we refill in batches instead of churning pause/resume per request. + self._msg_queue_resume_size = MAX_MSG_QUEUE_SIZE // 2 + # Set before super().__init__ so _reading_paused_for_msg_queue() is safe + # if BaseProtocol ever triggers a resume during init. + self._msg_queue_paused = False parser = HttpRequestParser( self, loop, @@ -215,6 +230,7 @@ def __init__( max_headers=max_headers, payload_exception=RequestPayloadError, auto_decompress=auto_decompress, + max_msg_queue_size=MAX_MSG_QUEUE_SIZE, ) super().__init__(loop, parser) @@ -461,6 +477,14 @@ def data_received(self, data: bytes) -> None: # don't set result twice waiter.set_result(None) + # Queue full: pause the transport (the parser already stopped + # emitting). start() resumes as it drains the queue. + if ( + not self._msg_queue_paused + and len(self._messages) >= self._max_msg_queue_size + ): + self._pause_msg_queue_reading() + self._upgraded = upgraded if upgraded and tail: self._message_tail = tail @@ -477,6 +501,36 @@ def data_received(self, data: bytes) -> None: if eof: self.close() + def _reading_paused_for_msg_queue(self) -> bool: + return self._msg_queue_paused + + def _pause_msg_queue_reading(self) -> None: + self._msg_queue_paused = True + if self.transport is not None: + try: + self.transport.pause_reading() + except PAUSE_RESUME_READING_ERRORS: + # Transport lacks flow control; nothing to pause. Intentionally + # ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress). + pass + + def _resume_msg_queue_reading(self) -> None: + if not self._upgraded: + # Reparse buffered pipelined requests while still marked paused so + # a refill past the limit does not re-pause an already-paused + # transport; only resume below once it stayed under the limit. + self.data_received(b"") + if len(self._messages) >= self._max_msg_queue_size: + return + self._msg_queue_paused = False + if not self._reading_paused and self.transport is not None: + try: + self.transport.resume_reading() + except PAUSE_RESUME_READING_ERRORS: + # Transport lacks flow control; nothing to resume. Intentionally + # ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress). + pass + def keep_alive(self, val: bool) -> None: """Set keep-alive connection mode. @@ -606,6 +660,18 @@ async def start(self) -> None: message, payload = self._messages.popleft() + # Free a parser slot; resume reading once drained to low water so + # pipelining keeps flowing while this request is handled. + # no branch: _parser is only None after connection_lost, whose path + # exits this loop, so the None case is not reachably exercisable. + if self._parser is not None: # pragma: no branch + self._parser.message_consumed() + if ( + self._msg_queue_paused + and len(self._messages) <= self._msg_queue_resume_size + ): + self._resume_msg_queue_reading() + # time is only fetched if logging is enabled as otherwise # its thrown away and never used. start = loop.time() if self._logging_enabled else None diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 75d3d0c8323..1a30d25a3a2 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -251,6 +251,7 @@ peername performant pickleable ping +pipelined pipelining pluggable plugin diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index f95e4cfacf1..8b6d5f52094 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -153,6 +153,78 @@ def test_c_parser_loaded() -> None: assert "RawResponseMessageC" in dir(aiohttp.http_parser) +_PIPELINED_GET = b"GET / HTTP/1.1\r\nHost: a\r\n\r\n" + + +def _build_request_parser( + request_cls: type[HttpRequestParser], + protocol: BaseProtocol, + loop: asyncio.AbstractEventLoop, + max_msg_queue_size: int, +) -> HttpRequestParser: + return request_cls( + protocol, + loop, + DEFAULT_CHUNK_SIZE, + max_line_size=8190, + max_headers=128, + max_field_size=8190, + max_msg_queue_size=max_msg_queue_size, + ) + + +def test_max_msg_queue_size_caps_emitted_messages( + request_cls: type[HttpRequestParser], + protocol: BaseProtocol, + event_loop: asyncio.AbstractEventLoop, +) -> None: + parser = _build_request_parser(request_cls, protocol, event_loop, 4) + messages, upgraded, _tail = parser.feed_data(_PIPELINED_GET * 10) + assert len(messages) == 4 + assert not upgraded + + +def test_max_msg_queue_size_resumes_after_consume( + request_cls: type[HttpRequestParser], + protocol: BaseProtocol, + event_loop: asyncio.AbstractEventLoop, +) -> None: + limit = 4 + total = 10 + parser = _build_request_parser(request_cls, protocol, event_loop, limit) + messages, _upgraded, _tail = parser.feed_data(_PIPELINED_GET * total) + seen = 0 + while messages: + assert len(messages) <= limit + seen += len(messages) + for _msg, _payload in messages: + parser.message_consumed() + messages, _upgraded, _tail = parser.feed_data(b"") + assert seen == total + + +def test_max_msg_queue_size_zero_is_unbounded( + request_cls: type[HttpRequestParser], + protocol: BaseProtocol, + event_loop: asyncio.AbstractEventLoop, +) -> None: + parser = _build_request_parser(request_cls, protocol, event_loop, 0) + messages, _upgraded, _tail = parser.feed_data(_PIPELINED_GET * 50) + assert len(messages) == 50 + + +def test_message_consumed_underflow_is_ignored( + request_cls: type[HttpRequestParser], + protocol: BaseProtocol, + event_loop: asyncio.AbstractEventLoop, +) -> None: + parser = _build_request_parser(request_cls, protocol, event_loop, 4) + # No message is in flight; consuming must not underflow the counter. + parser.message_consumed() + messages, _upgraded, _tail = parser.feed_data(_PIPELINED_GET * 4) + assert len(messages) == 4 + + def test_parse_headers(parser: HttpRequestParser) -> None: text = b"""GET /test HTTP/1.1\r Host: a\r diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index c760bff2f19..b8488588906 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -32,7 +32,7 @@ from aiohttp.helpers import DEFAULT_CHUNK_SIZE, HeadersDictProxy from aiohttp.streams import StreamReader from aiohttp.typedefs import Handler, Middleware -from aiohttp.web_protocol import RequestHandler +from aiohttp.web_protocol import MAX_MSG_QUEUE_SIZE, RequestHandler try: import brotlicffi as brotli @@ -1717,6 +1717,136 @@ async def handler(request: web.Request) -> web.StreamResponse: resp.release() +async def test_http1_pipelined_requests_are_count_limited( + aiohttp_server: AiohttpServer, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Requests pipelined behind a busy handler must not grow unbounded. + + A client can keep one handler active and pipeline many complete requests + behind it; the per-connection queue stays bounded by MAX_MSG_QUEUE_SIZE. + """ + pipelined_requests = 500 + slow_handler_started = asyncio.Event() + queue_observed = asyncio.Event() + max_queued = 0 + data_received = RequestHandler.data_received + + def observe_data_received(self: RequestHandler[web.Request], data: bytes) -> None: + nonlocal max_queued + data_received(self, data) + if self._request_in_progress and self._messages: + max_queued = max(max_queued, len(self._messages)) + queue_observed.set() + + monkeypatch.setattr(RequestHandler, "data_received", observe_data_received) + + async def slow_handler(request: web.Request) -> web.Response: + slow_handler_started.set() + await asyncio.sleep(0.5) + return web.Response(text="slow") + + async def fast_handler(request: web.Request) -> NoReturn: + # The pipelined requests are only counted, never handled: the test + # closes the connection while the slow handler still holds the loop. + assert False + + app = web.Application() + app.router.add_get("/slow", slow_handler) + app.router.add_get("/x", fast_handler) + server = await aiohttp_server(app) + + def raw_get(path: str) -> bytes: + return ( + f"GET {path} HTTP/1.1\r\nHost: localhost\r\n" + "Connection: keep-alive\r\n\r\n" + ).encode("ascii") + + reader, writer = await asyncio.open_connection(server.host, server.port) + try: + writer.write(raw_get("/slow")) + await writer.drain() + await asyncio.wait_for(slow_handler_started.wait(), 1) + + writer.write(raw_get("/x") * pipelined_requests) + await writer.drain() + await asyncio.wait_for(queue_observed.wait(), 1) + finally: + writer.close() + with suppress(ConnectionResetError, BrokenPipeError): + await writer.wait_closed() + + # Tight lower bound also catches over-aggressive pausing (e.g. clamping to 1). + assert MAX_MSG_QUEUE_SIZE // 2 < max_queued <= MAX_MSG_QUEUE_SIZE + + +async def test_http1_pipelined_queue_resumes_after_drain( + aiohttp_server: AiohttpServer, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A paused pipeline queue resumes reading once handlers drain it. + + Once enough requests are pipelined behind a busy handler to fill the queue, + reading is paused; as the handlers drain the queue past the low-water mark + reading must resume so the remaining buffered requests are still served. + """ + # Several times the limit so the queue refills and re-pauses while draining. + pipelined_requests = MAX_MSG_QUEUE_SIZE * 3 + first_started = asyncio.Event() + release_first = asyncio.Event() + resumed = asyncio.Event() + handled: list[str] = [] + all_handled = asyncio.Event() + + resume = RequestHandler._resume_msg_queue_reading + + def observe_resume(self: RequestHandler[web.Request]) -> None: + resume(self) + resumed.set() + + monkeypatch.setattr(RequestHandler, "_resume_msg_queue_reading", observe_resume) + + async def handler(request: web.Request) -> web.Response: + if request.path == "/first": + first_started.set() + await release_first.wait() + handled.append(request.path) + if len(handled) == pipelined_requests + 1: + all_handled.set() + return web.Response() + + app = web.Application() + app.router.add_get("/{tail:.*}", handler) + server = await aiohttp_server(app) + + def raw_get(path: str) -> bytes: + return ( + f"GET {path} HTTP/1.1\r\nHost: localhost\r\n" + "Connection: keep-alive\r\n\r\n" + ).encode("ascii") + + reader, writer = await asyncio.open_connection(server.host, server.port) + try: + writer.write(raw_get("/first")) + await writer.drain() + await asyncio.wait_for(first_started.wait(), 1) + + writer.write(b"".join(raw_get(f"/r{i}") for i in range(pipelined_requests))) + await writer.drain() + + # Let the busy handler finish so the queue drains and reading resumes. + release_first.set() + await asyncio.wait_for(resumed.wait(), 5) + # Every pipelined request is still served only if reading resumed. + await asyncio.wait_for(all_handled.wait(), 5) + finally: + writer.close() + with suppress(ConnectionResetError, BrokenPipeError): + await writer.wait_closed() + + assert len(handled) == pipelined_requests + 1 + + @pytest.mark.parametrize("decompressed_size", [4 * 1024 * 1024, 32 * 1024 * 1024]) async def test_unread_compressed_body_drain_is_bounded( aiohttp_server: AiohttpServer, diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 8936c1dfeeb..8968dea78b5 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -49,3 +49,96 @@ def test_data_received_calls_data_received_cb( cb.assert_called_once() dummy_reader[1].feed_data.assert_called_once_with(b"x") + + +def test_pause_msg_queue_reading_without_transport( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """Pausing with no transport still records the paused state.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + handler.transport = None + + handler._pause_msg_queue_reading() + + assert handler._msg_queue_paused is True + + +def test_resume_msg_queue_reading_after_upgrade_skips_reparse( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """Resume after an upgrade clears the pause and resumes without reparsing.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + transport = mock.Mock() + handler.transport = transport + handler._upgraded = True + handler._msg_queue_paused = True + handler._reading_paused = False + + with mock.patch.object(RequestHandler, "data_received") as data_received: + handler._resume_msg_queue_reading() + + data_received.assert_not_called() + assert handler._msg_queue_paused is False + transport.resume_reading.assert_called_once_with() + + +def test_resume_msg_queue_reading_without_transport( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """Resume clears the pause but does not touch a missing transport.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + handler.transport = None + handler._upgraded = True # skip the reparse branch + handler._msg_queue_paused = True + + handler._resume_msg_queue_reading() + + assert handler._msg_queue_paused is False + + +def test_resume_reading_stays_paused_for_msg_queue( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """Base resume_reading must not un-pause the transport while queue-paused.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + transport = mock.Mock() + handler.transport = transport + handler._msg_queue_paused = True + + handler.resume_reading() + + transport.resume_reading.assert_not_called() + + +def test_pause_msg_queue_reading_ignores_unsupported_transport( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """A transport without flow control raising on pause is ignored.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + # Bare asyncio.Transport.pause_reading() raises NotImplementedError. + handler.transport = asyncio.Transport() + + handler._pause_msg_queue_reading() + + assert handler._msg_queue_paused is True + + +def test_resume_msg_queue_reading_ignores_unsupported_transport( + event_loop: asyncio.AbstractEventLoop, + dummy_manager: Server[BaseRequest], +) -> None: + """A transport without flow control raising on resume is ignored.""" + handler = RequestHandler(dummy_manager, loop=event_loop) + # Bare asyncio.Transport.resume_reading() raises NotImplementedError. + handler.transport = asyncio.Transport() + handler._upgraded = True # skip the reparse branch + handler._msg_queue_paused = True + + handler._resume_msg_queue_reading() + + assert handler._msg_queue_paused is False