From 9ab85e10f4e5650f36fcfa9378f9a9580ab08c54 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Tue, 19 May 2026 02:36:19 +0800 Subject: [PATCH] fix: reject initialize protocol version conflicts --- src/mcp/server/streamable_http.py | 60 ++++++++++++++++++++-------- tests/shared/test_streamable_http.py | 39 ++++++++++++++++++ 2 files changed, 82 insertions(+), 17 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 2cb4c0748..612c144e7 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -476,23 +476,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - # Check if this is an initialization request - is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" - - if is_initialization_request: - # Check if the server already has an established session - if self.mcp_session_id: - # Check if request has a session ID - request_session_id = self._get_session_id(request) - - # If request has a session ID but doesn't match, return 404 - if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover - response = self._create_error_response( - "Not Found: Invalid or expired session ID", - HTTPStatus.NOT_FOUND, - ) - await response(scope, receive, send) - return + is_initialization_request = False + if isinstance(message, JSONRPCRequest) and message.method == "initialize": + is_initialization_request = True + if not await self._validate_initialization_request(message, request, send): + return elif not await self._validate_request_headers(request, send): return @@ -865,6 +853,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True + async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool: + if not await self._validate_initialization_protocol_version(message, request, send): + return False + + if not self.mcp_session_id: + return True + + request_session_id = self._get_session_id(request) + if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _validate_initialization_protocol_version( + self, message: JSONRPCRequest, request: Request, send: Send + ) -> bool: + header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) + body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None + if ( + header_protocol_version is not None + and body_protocol_version is not None + and header_protocol_version != body_protocol_version + ): + response = self._create_error_response( + f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion", + HTTPStatus.BAD_REQUEST, + INVALID_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + return True + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b43a3361c..3bb9a0099 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1489,6 +1489,45 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) -> assert response.status_code == 200 +@pytest.mark.anyio +@pytest.mark.parametrize( + ("header_version", "body_version"), + [ + ("2025-03-26", "2025-06-18"), + ("2025-06-18", "2025-03-26"), + ], +) +async def test_server_rejects_initialize_protocol_version_mismatch( + basic_app: Starlette, header_version: str, body_version: str +) -> None: + """Initialize rejects conflicting protocol versions in header and body.""" + init_request: dict[str, Any] = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": body_version, + "capabilities": {}, + }, + "id": "init-1", + } + + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_PROTOCOL_VERSION_HEADER: header_version, + }, + json=init_request, + ) + + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text + assert "protocolVersion" in response.text + + @pytest.mark.anyio async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: """A request without a protocol version header is accepted for backwards compatibility."""