mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): seed protocol header before HTTP initialize
This commit is contained in:
parent
983bbe2d40
commit
3ccda2aa05
2 changed files with 100 additions and 0 deletions
|
|
@ -1200,6 +1200,92 @@ class TestHTTPConfig:
|
|||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_http_seeds_initial_protocol_header(self):
|
||||
from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask
|
||||
|
||||
server = MCPServerTask("remote")
|
||||
captured = {}
|
||||
|
||||
class DummyAsyncClient:
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyTransportCtx:
|
||||
async def __aenter__(self):
|
||||
return MagicMock(), MagicMock(), (lambda: None)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def initialize(self):
|
||||
return None
|
||||
|
||||
class DummyLegacyTransportCtx:
|
||||
def __init__(self, **kwargs):
|
||||
captured["legacy_headers"] = kwargs.get("headers")
|
||||
|
||||
async def __aenter__(self):
|
||||
return MagicMock(), MagicMock(), (lambda: None)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def _discover_tools(self):
|
||||
self._shutdown_event.set()
|
||||
|
||||
async def _run(config, *, new_http):
|
||||
captured.clear()
|
||||
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \
|
||||
patch("httpx.AsyncClient", DummyAsyncClient), \
|
||||
patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \
|
||||
patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \
|
||||
patch("tools.mcp_tool.ClientSession", DummySession), \
|
||||
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
|
||||
await server._run_http(config)
|
||||
|
||||
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True))
|
||||
assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"mcp-protocol-version": "custom-version"},
|
||||
}, new_http=True))
|
||||
assert captured["headers"]["mcp-protocol-version"] == "custom-version"
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||
}, new_http=True))
|
||||
assert captured["headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||
assert "mcp-protocol-version" not in captured["headers"]
|
||||
|
||||
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False))
|
||||
assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||
}, new_http=False))
|
||||
assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||
assert "mcp-protocol-version" not in captured["legacy_headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnection logic
|
||||
|
|
|
|||
|
|
@ -167,6 +167,10 @@ _MCP_HTTP_AVAILABLE = False
|
|||
_MCP_SAMPLING_TYPES = False
|
||||
_MCP_NOTIFICATION_TYPES = False
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
|
||||
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
|
||||
# HTTP transport path even on older-but-supported SDK versions.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
|
@ -183,6 +187,10 @@ try:
|
|||
_MCP_NEW_HTTP = True
|
||||
except ImportError:
|
||||
_MCP_NEW_HTTP = False
|
||||
try:
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
except ImportError:
|
||||
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
|
||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||
try:
|
||||
from mcp.types import (
|
||||
|
|
@ -1075,6 +1083,12 @@ class MCPServerTask:
|
|||
|
||||
url = config["url"]
|
||||
headers = dict(config.get("headers") or {})
|
||||
# Some MCP servers require MCP-Protocol-Version on the initial
|
||||
# initialize request and reject session-less POSTs otherwise.
|
||||
# Seed it as a client-level default, but treat user overrides as
|
||||
# case-insensitive so conventional casing is preserved.
|
||||
if not any(key.lower() == "mcp-protocol-version" for key in headers):
|
||||
headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
ssl_verify = config.get("ssl_verify", True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue