Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ async def run_server():
```
"""

import json
import re
import sys
from contextlib import asynccontextmanager
from io import TextIOWrapper
from typing import Any, cast

import anyio
import anyio.lowlevel
Expand All @@ -28,6 +31,50 @@ async def run_server():
from mcp.shared._context_streams import create_context_streams
from mcp.shared.message import SessionMessage

_JSONRPC_ID_PATTERN = re.compile(r'"id"\s*:\s*(-?\d+|"[^"\\]*")')


def _request_id_from_raw_message(line: str) -> types.RequestId | None:
try:
raw_message: Any = json.loads(line)
except Exception:
raw_message = None

if not isinstance(raw_message, dict):
match = _JSONRPC_ID_PATTERN.search(line)
if not match:
return None

raw_request_id = match.group(1)
if raw_request_id.startswith('"'):
return json.loads(raw_request_id)
return int(raw_request_id)

raw_message_dict = cast(dict[str, Any], raw_message)
request_id = raw_message_dict.get("id")
if isinstance(request_id, str) or type(request_id) is int:
return request_id
return None


def _error_response_from_parse_failure(line: str, exc: Exception) -> SessionMessage:
request_id = _request_id_from_raw_message(line)
message = str(exc)
if "Invalid JSON" in message:
code = types.PARSE_ERROR
prefix = "Parse error"
else:
code = types.INVALID_REQUEST
prefix = "Invalid request"

return SessionMessage(
types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=types.ErrorData(code=code, message=f"{prefix}: {message}"),
)
)


@asynccontextmanager
async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None):
Expand All @@ -53,7 +100,8 @@ async def stdin_reader():
try:
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
except Exception as exc:
await read_stream_writer.send(exc)
error_response = _error_response_from_parse_failure(line, exc)
await write_stream.send(error_response)
continue

session_message = SessionMessage(message)
Expand Down
12 changes: 8 additions & 4 deletions tests/interaction/transports/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import tempfile
from pathlib import Path
from typing import TextIO, cast

import anyio
import pytest
Expand Down Expand Up @@ -67,7 +68,8 @@ async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess(
async def collect(params: LoggingMessageNotificationParams) -> None:
received.append(params)

with tempfile.TemporaryFile(mode="w+") as errlog:
with tempfile.TemporaryFile(mode="w+") as errlog_file:
errlog = cast(TextIO, errlog_file)
transport = stdio_client(
StdioServerParameters(
command=sys.executable,
Expand Down Expand Up @@ -98,9 +100,11 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
assert received == snapshot(
[LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")]
)
# The server writes this line only after its run loop returns on stdin close: seeing it proves
# a self-exit, not the terminate escalation. The capture itself proves stderr passthrough.
assert captured_stderr == snapshot("stdio-echo: clean exit\n")
# The server writes this line only after its run loop returns, which happens when stdin closes:
# seeing it proves the process exited on its own rather than via the transport's terminate
# escalation, without a timing-based assertion. The suffix check keeps the test stable if the
# child interpreter emits dependency warnings before the server's own stderr line.
assert captured_stderr.endswith("stdio-echo: clean exit\n")


@requirement("transport:stdio:stream-purity")
Expand Down
93 changes: 79 additions & 14 deletions tests/server/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import sys
import threading
from collections.abc import AsyncIterator
Expand All @@ -9,9 +10,17 @@
import pytest

from mcp.server.mcpserver import MCPServer
from mcp.server.stdio import stdio_server
from mcp.server.stdio import _error_response_from_parse_failure, _request_id_from_raw_message, stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import (
INVALID_REQUEST,
PARSE_ERROR,
JSONRPCError,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCResponse,
jsonrpc_message_adapter,
)


@pytest.mark.anyio
Expand Down Expand Up @@ -68,10 +77,10 @@ async def test_stdio_server_round_trips_messages_over_injected_streams() -> None

@pytest.mark.anyio
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None:
"""Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream.
"""Non-UTF-8 stdin bytes produce an error response without killing the stream.

Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream
exception; subsequent valid messages are still processed.
Invalid bytes are replaced with U+FFFD, then fail JSON parsing and are returned
as a JSON-RPC parse error. Subsequent valid messages are still processed.
"""
# \xff\xfe are invalid UTF-8 start bytes.
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
Expand All @@ -80,20 +89,76 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> Non
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
# stdio_server()'s default path wraps it with errors='replace'.
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
stdout = io.StringIO()

with anyio.fail_after(5):
async with stdio_server() as (read_stream, write_stream):
await write_stream.aclose()
async with stdio_server(stdout=anyio.AsyncFile(stdout)) as (read_stream, write_stream):
async with read_stream: # pragma: no branch
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> error response on stdout
first = await read_stream.receive()
assert isinstance(first, Exception)
assert isinstance(first, SessionMessage)
assert first.message == valid

await write_stream.aclose()

stdout.seek(0)
output = stdout.read()
error = jsonrpc_message_adapter.validate_json(output.strip())
assert isinstance(error, JSONRPCError)
assert error.id is None
assert error.error.code == PARSE_ERROR


@pytest.mark.anyio
async def test_stdio_server_parse_error_completes_id_bearing_request() -> None:
params: object = {"leaf": True}
for index in reversed(range(256)):
params = {f"p{index}": params}
line = json.dumps({"jsonrpc": "2.0", "id": 900256, "method": "ping", "params": params}) + "\n"

stdin = io.StringIO(line)
stdout = io.StringIO()

with anyio.fail_after(5):
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
read_stream,
write_stream,
):
async with read_stream:
with pytest.raises(anyio.EndOfStream):
await read_stream.receive()
await write_stream.aclose()

stdout.seek(0)
output_lines = stdout.readlines()
assert len(output_lines) == 1

response = jsonrpc_message_adapter.validate_json(output_lines[0].strip())
assert isinstance(response, JSONRPCError)
assert response.id == 900256
assert response.error.code == PARSE_ERROR
assert "Parse error" in response.error.message


def test_stdio_request_id_recovery_edges() -> None:
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":"abc","method":"ping","params":[') == "abc"
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":42,"method":"ping","params":[') == 42
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":-7,"method":1}') == -7
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":null,"method":1}') is None
assert _request_id_from_raw_message("[]") is None


def test_stdio_invalid_request_response_preserves_string_id() -> None:
line = '{"jsonrpc":"2.0","id":"bad-method","method":1}'
with pytest.raises(Exception) as exc_info:
jsonrpc_message_adapter.validate_json(line)

response = _error_response_from_parse_failure(line, exc_info.value)

# Second line: valid message still comes through
second = await read_stream.receive()
assert isinstance(second, SessionMessage)
assert second.message == valid
assert isinstance(response.message, JSONRPCError)
assert response.message.id == "bad-method"
assert response.message.error.code == INVALID_REQUEST
assert "Invalid request" in response.message.error.message


class _KeepOpenBytesIO(io.BytesIO):
Expand Down
Loading