From f20ccb650319a67ae3a0945eae36ab46cf57f752 Mon Sep 17 00:00:00 2001 From: Zhang-986 <3225483474@qq.com> Date: Mon, 8 Jun 2026 16:22:04 +0800 Subject: [PATCH] fix: preserve auth endpoint query params --- src/mcp/client/auth/oauth2.py | 11 +++++++-- tests/client/test_auth.py | 44 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 3c546fda2b..a3fef788ab 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -12,7 +12,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit import anyio import httpx @@ -353,7 +353,14 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if "offline_access" in self.context.client_metadata.scope.split(): auth_params["prompt"] = "consent" - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + auth_endpoint_parts = urlsplit(auth_endpoint) + authorization_query = urlencode( + [ + *parse_qsl(auth_endpoint_parts.query, keep_blank_values=True), + *auth_params.items(), + ] + ) + authorization_url = urlunsplit(auth_endpoint_parts._replace(query=authorization_query)) await self.context.redirect_handler(authorization_url) # Wait for callback diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c92..c75ed7b375 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -263,6 +263,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O class TestOAuthFlow: """Test OAuth flow methods.""" + @pytest.mark.anyio + async def test_authorization_endpoint_query_params_are_preserved( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """OAuth authorization endpoints may already carry provider-specific query params.""" + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_state + parsed = urlparse(url) + params = parse_qs(parsed.query) + + assert params["prompt"] == ["select_account"] + assert params["response_type"] == ["code"] + assert params["client_id"] == ["test_client"] + + captured_state = params.get("state", [None])[0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + auth_code, code_verifier = await provider._perform_authorization_code_grant() + + assert auth_code == "test_auth_code" + assert code_verifier + @pytest.mark.anyio async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage