fix(mcp): seed protocol header before HTTP initialize

This commit is contained in:
Matt Maximo 2026-04-15 13:34:39 -04:00 committed by Teknium
parent 983bbe2d40
commit 3ccda2aa05
2 changed files with 100 additions and 0 deletions

View file

@ -1200,6 +1200,92 @@ class TestHTTPConfig:
asyncio.run(_test())
def test_http_seeds_initial_protocol_header(self):
from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask
server = MCPServerTask("remote")
captured = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured.update(kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class DummyTransportCtx:
async def __aenter__(self):
return MagicMock(), MagicMock(), (lambda: None)
async def __aexit__(self, exc_type, exc, tb):
return False
class DummySession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def initialize(self):
return None
class DummyLegacyTransportCtx:
def __init__(self, **kwargs):
captured["legacy_headers"] = kwargs.get("headers")
async def __aenter__(self):
return MagicMock(), MagicMock(), (lambda: None)
async def __aexit__(self, exc_type, exc, tb):
return False
async def _discover_tools(self):
self._shutdown_event.set()
async def _run(config, *, new_http):
captured.clear()
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \
patch("httpx.AsyncClient", DummyAsyncClient), \
patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \
patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \
patch("tools.mcp_tool.ClientSession", DummySession), \
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
await server._run_http(config)
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True))
assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
asyncio.run(_run({
"url": "https://example.com/mcp",
"headers": {"mcp-protocol-version": "custom-version"},
}, new_http=True))
assert captured["headers"]["mcp-protocol-version"] == "custom-version"
asyncio.run(_run({
"url": "https://example.com/mcp",
"headers": {"MCP-Protocol-Version": "custom-version"},
}, new_http=True))
assert captured["headers"]["MCP-Protocol-Version"] == "custom-version"
assert "mcp-protocol-version" not in captured["headers"]
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False))
assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
asyncio.run(_run({
"url": "https://example.com/mcp",
"headers": {"MCP-Protocol-Version": "custom-version"},
}, new_http=False))
assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version"
assert "mcp-protocol-version" not in captured["legacy_headers"]
# ---------------------------------------------------------------------------
# Reconnection logic