mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): seed protocol header before HTTP initialize
This commit is contained in:
parent
983bbe2d40
commit
3ccda2aa05
2 changed files with 100 additions and 0 deletions
|
|
@ -1200,6 +1200,92 @@ class TestHTTPConfig:
|
||||||
|
|
||||||
asyncio.run(_test())
|
asyncio.run(_test())
|
||||||
|
|
||||||
|
def test_http_seeds_initial_protocol_header(self):
|
||||||
|
from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask
|
||||||
|
|
||||||
|
server = MCPServerTask("remote")
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class DummyAsyncClient:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class DummyTransportCtx:
|
||||||
|
async def __aenter__(self):
|
||||||
|
return MagicMock(), MagicMock(), (lambda: None)
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class DummySession:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyLegacyTransportCtx:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured["legacy_headers"] = kwargs.get("headers")
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return MagicMock(), MagicMock(), (lambda: None)
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _discover_tools(self):
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
async def _run(config, *, new_http):
|
||||||
|
captured.clear()
|
||||||
|
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
|
||||||
|
patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \
|
||||||
|
patch("httpx.AsyncClient", DummyAsyncClient), \
|
||||||
|
patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \
|
||||||
|
patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \
|
||||||
|
patch("tools.mcp_tool.ClientSession", DummySession), \
|
||||||
|
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
|
||||||
|
await server._run_http(config)
|
||||||
|
|
||||||
|
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True))
|
||||||
|
assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||||
|
|
||||||
|
asyncio.run(_run({
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": {"mcp-protocol-version": "custom-version"},
|
||||||
|
}, new_http=True))
|
||||||
|
assert captured["headers"]["mcp-protocol-version"] == "custom-version"
|
||||||
|
|
||||||
|
asyncio.run(_run({
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||||
|
}, new_http=True))
|
||||||
|
assert captured["headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||||
|
assert "mcp-protocol-version" not in captured["headers"]
|
||||||
|
|
||||||
|
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False))
|
||||||
|
assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||||
|
|
||||||
|
asyncio.run(_run({
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||||
|
}, new_http=False))
|
||||||
|
assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||||
|
assert "mcp-protocol-version" not in captured["legacy_headers"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Reconnection logic
|
# Reconnection logic
|
||||||
|
|
|
||||||
|
|
@ -167,6 +167,10 @@ _MCP_HTTP_AVAILABLE = False
|
||||||
_MCP_SAMPLING_TYPES = False
|
_MCP_SAMPLING_TYPES = False
|
||||||
_MCP_NOTIFICATION_TYPES = False
|
_MCP_NOTIFICATION_TYPES = False
|
||||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||||
|
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
|
||||||
|
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
|
||||||
|
# HTTP transport path even on older-but-supported SDK versions.
|
||||||
|
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||||
try:
|
try:
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
@ -183,6 +187,10 @@ try:
|
||||||
_MCP_NEW_HTTP = True
|
_MCP_NEW_HTTP = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_MCP_NEW_HTTP = False
|
_MCP_NEW_HTTP = False
|
||||||
|
try:
|
||||||
|
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||||
|
except ImportError:
|
||||||
|
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
|
||||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||||
try:
|
try:
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
|
@ -1075,6 +1083,12 @@ class MCPServerTask:
|
||||||
|
|
||||||
url = config["url"]
|
url = config["url"]
|
||||||
headers = dict(config.get("headers") or {})
|
headers = dict(config.get("headers") or {})
|
||||||
|
# Some MCP servers require MCP-Protocol-Version on the initial
|
||||||
|
# initialize request and reject session-less POSTs otherwise.
|
||||||
|
# Seed it as a client-level default, but treat user overrides as
|
||||||
|
# case-insensitive so conventional casing is preserved.
|
||||||
|
if not any(key.lower() == "mcp-protocol-version" for key in headers):
|
||||||
|
headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
|
||||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||||
ssl_verify = config.get("ssl_verify", True)
|
ssl_verify = config.get("ssl_verify", True)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue