diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 3762eb6169..1604d4adb5 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -1200,6 +1200,92 @@ class TestHTTPConfig: asyncio.run(_test()) + def test_http_seeds_initial_protocol_header(self): + from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask + + server = MCPServerTask("remote") + captured = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class DummyTransportCtx: + async def __aenter__(self): + return MagicMock(), MagicMock(), (lambda: None) + + async def __aexit__(self, exc_type, exc, tb): + return False + + class DummySession: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def initialize(self): + return None + + class DummyLegacyTransportCtx: + def __init__(self, **kwargs): + captured["legacy_headers"] = kwargs.get("headers") + + async def __aenter__(self): + return MagicMock(), MagicMock(), (lambda: None) + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def _discover_tools(self): + self._shutdown_event.set() + + async def _run(config, *, new_http): + captured.clear() + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ + patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \ + patch("httpx.AsyncClient", DummyAsyncClient), \ + patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \ + patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \ + patch("tools.mcp_tool.ClientSession", DummySession), \ + patch.object(MCPServerTask, "_discover_tools", _discover_tools): + await server._run_http(config) + + asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True)) + assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION + + asyncio.run(_run({ + "url": "https://example.com/mcp", + "headers": {"mcp-protocol-version": "custom-version"}, + }, new_http=True)) + assert captured["headers"]["mcp-protocol-version"] == "custom-version" + + asyncio.run(_run({ + "url": "https://example.com/mcp", + "headers": {"MCP-Protocol-Version": "custom-version"}, + }, new_http=True)) + assert captured["headers"]["MCP-Protocol-Version"] == "custom-version" + assert "mcp-protocol-version" not in captured["headers"] + + asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False)) + assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION + + asyncio.run(_run({ + "url": "https://example.com/mcp", + "headers": {"MCP-Protocol-Version": "custom-version"}, + }, new_http=False)) + assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version" + assert "mcp-protocol-version" not in captured["legacy_headers"] + # --------------------------------------------------------------------------- # Reconnection logic diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2b7f681ed6..6827a21dfe 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -167,6 +167,10 @@ _MCP_HTTP_AVAILABLE = False _MCP_SAMPLING_TYPES = False _MCP_NOTIFICATION_TYPES = False _MCP_MESSAGE_HANDLER_SUPPORTED = False +# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION. +# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the +# HTTP transport path even on older-but-supported SDK versions. +LATEST_PROTOCOL_VERSION = "2025-03-26" try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -183,6 +187,10 @@ try: _MCP_NEW_HTTP = True except ImportError: _MCP_NEW_HTTP = False + try: + from mcp.types import LATEST_PROTOCOL_VERSION + except ImportError: + logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version") # Sampling types -- separated so older SDK versions don't break MCP support try: from mcp.types import ( @@ -1075,6 +1083,12 @@ class MCPServerTask: url = config["url"] headers = dict(config.get("headers") or {}) + # Some MCP servers require MCP-Protocol-Version on the initial + # initialize request and reject session-less POSTs otherwise. + # Seed it as a client-level default, but treat user overrides as + # case-insensitive so conventional casing is preserved. + if not any(key.lower() == "mcp-protocol-version" for key in headers): + headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) ssl_verify = config.get("ssl_verify", True)