Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,23 +476,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
await response(scope, receive, send)
return

# Check if this is an initialization request
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)

# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
is_initialization_request = False
if isinstance(message, JSONRPCRequest) and message.method == "initialize":
is_initialization_request = True
if not await self._validate_initialization_request(message, request, send):
return
elif not await self._validate_request_headers(request, send):
return

Expand Down Expand Up @@ -865,6 +853,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool

return True

async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool:
if not await self._validate_initialization_protocol_version(message, request, send):
return False

if not self.mcp_session_id:
return True

request_session_id = self._get_session_id(request)
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(request.scope, request.receive, send)
return False

return True

async def _validate_initialization_protocol_version(
self, message: JSONRPCRequest, request: Request, send: Send
) -> bool:
header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None
if (
header_protocol_version is not None
and body_protocol_version is not None
and header_protocol_version != body_protocol_version
):
response = self._create_error_response(
f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion",
HTTPStatus.BAD_REQUEST,
INVALID_REQUEST,
)
await response(request.scope, request.receive, send)
return False

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""Replays events that would have been sent after the specified event ID.

Expand Down
39 changes: 39 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,45 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) ->
assert response.status_code == 200


@pytest.mark.anyio
@pytest.mark.parametrize(
("header_version", "body_version"),
[
("2025-03-26", "2025-06-18"),
("2025-06-18", "2025-03-26"),
],
)
async def test_server_rejects_initialize_protocol_version_mismatch(
basic_app: Starlette, header_version: str, body_version: str
) -> None:
"""Initialize rejects conflicting protocol versions in header and body."""
init_request: dict[str, Any] = {
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": body_version,
"capabilities": {},
},
"id": "init-1",
}

async with make_client(basic_app) as client:
response = await client.post(
"/mcp",
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
MCP_PROTOCOL_VERSION_HEADER: header_version,
},
json=init_request,
)

assert response.status_code == 400
assert MCP_PROTOCOL_VERSION_HEADER in response.text
assert "protocolVersion" in response.text


@pytest.mark.anyio
async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None:
"""A request without a protocol version header is accepted for backwards compatibility."""
Expand Down
Loading