mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
fix(mcp): forward OAuth auth and bump sse_read_timeout on SSE transport (#21323)
* fix(mcp): re-raise CancelledError explicitly in MCPServerTask.run On Python 3.8+, `asyncio.CancelledError` inherits from `BaseException` (not `Exception`), so the broad `except Exception as exc:` in `MCPServerTask.run`'s transport loop did NOT catch it. Task cancellation from gateway restart / explicit `task.cancel()` silently escaped past the reconnect logic — the MCP server task died without going through the shutdown/reconnect code paths that check `_shutdown_event`. Add an explicit `except asyncio.CancelledError: raise` before the broad catch so cancellation propagation is self-documenting rather than an accident of exception hierarchy, and future sibling-site work (e.g. distinguishing shutdown-cancel from transport-cancel) has an obvious hook. Behavior on pre-3.8 Pythons where CancelledError WAS an Exception subclass is also corrected: the old path would have caught it and treated it as a connection failure worth retrying. Closes #9930. * fix(mcp): forward OAuth auth and bump sse_read_timeout on SSE transport Two surgical correctness fixes in the SSE branch of MCPServerTask._run_http, distilled from @amiller's PR #5981 that couldn't be cherry-picked wholesale (branch too stale). 1. sse_read_timeout was set to the tool timeout (default 60s). That's the wrong dimension — it governs how long sse_client will wait between events on the SSE stream, not per-call latency. SSE servers routinely hold the stream idle for minutes between events; a 60s read timeout drops the connection after the first slow stretch (Router Teamwork, Supermemory on Cloudflare Workers idle-disconnect at ~60s). Bump to 300s to match the Streamable HTTP path's httpx read timeout. 2. OAuth auth was built via get_manager().get_or_build_provider() but never forwarded to sse_client. SSE MCP servers behind OAuth 2.1 PKCE would silently fail with 401s on every request. 
Keepalive (the other half of #5981) intentionally left for a follow-up — it's a real improvement but a bigger change, and these two are obvious corrections to ship now. Credits to @amiller. Co-authored-by: Andrew Miller <socrates1024@gmail.com> --------- Co-authored-by: Andrew Miller <socrates1024@gmail.com>
This commit is contained in:
parent
4ee6c3349a
commit
dd2dc2bddf
2 changed files with 229 additions and 6 deletions
209
tests/tools/test_mcp_sse_transport.py
Normal file
209
tests/tools/test_mcp_sse_transport.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
"""Regression tests for SSE transport in ``MCPServerTask._run_http``.
|
||||
|
||||
Covers fixes distilled from @amiller's PR #5981 that couldn't be cherry-picked
|
||||
due to stale-branch divergence:
|
||||
|
||||
1. ``sse_read_timeout`` is set to 300s (not the tool timeout). SSE servers
|
||||
commonly hold the stream idle for minutes between events; a 60s read
|
||||
timeout drops the connection after the first slow stretch. Original
|
||||
observation: Router Teamwork / Supermemory on Cloudflare Workers dropping
|
||||
at ~60s idle.
|
||||
|
||||
2. OAuth auth is forwarded to ``sse_client`` when configured. Previously the
|
||||
code built ``_oauth_auth`` but never passed it to the SSE path, so SSE MCP
|
||||
servers behind OAuth 2.1 PKCE would silently fail with 401s.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
async def _noop_initialize():
    """Do nothing and resolve to ``None``.

    Awaitable placeholder; none of the tests visible in this file reference it.
    """
    return None
|
||||
|
||||
|
||||
def _build_server_with_sse(oauth: bool = False):
    """Create a bare ``MCPServerTask`` named ``"sse-test"`` for the SSE tests.

    Only two private attributes are preset: ``_auth_type`` (``"oauth"`` when
    *oauth* is true, otherwise the empty string) and ``_sampling`` (``None``).
    Nothing is patched here; the individual tests and the ``patch_sse_client``
    fixture install the mocks that let ``_run_http`` run without a network.
    """
    # Imported lazily, the same way the tests in this module import it.
    from tools.mcp_tool import MCPServerTask

    task = MCPServerTask("sse-test")
    if oauth:
        task._auth_type = "oauth"
    else:
        task._auth_type = ""
    task._sampling = None
    return task
|
||||
|
||||
|
||||
@pytest.fixture
def patch_sse_client():
    """Swap ``sse_client`` and ``ClientSession`` in ``tools.mcp_tool`` for fakes.

    The fake ``sse_client`` is a plain function that records the keyword
    arguments of its most recent call in a dict. The fixture yields that dict
    so tests can assert how ``_run_http`` invoked ``sse_client``. Both patches
    are undone when the fixture is torn down.
    """
    recorded: dict = {}

    class _FakeSession:
        """Stand-in for ``ClientSession`` that ignores its constructor args."""

        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            # Hand back a session whose ``initialize`` can be awaited.
            session = MagicMock()
            session.initialize = AsyncMock()
            return session

        async def __aexit__(self, *a):
            # False: never suppress an exception raised inside the block.
            return False

    class _FakeStream:
        """Async context manager yielding a ``(read, write)`` pair of mocks."""

        def __init__(self):
            self._read = AsyncMock()
            self._write = AsyncMock()

        async def __aenter__(self):
            return (self._read, self._write)

        async def __aexit__(self, *a):
            return False

    def fake_sse_client(**kwargs):
        # Mutate in place so the dict already yielded to the test stays the
        # same object and always reflects the latest call.
        recorded.clear()
        recorded.update(kwargs)
        return _FakeStream()

    with patch("tools.mcp_tool.sse_client", new=fake_sse_client):
        with patch("tools.mcp_tool.ClientSession", new=_FakeSession):
            yield recorded
|
||||
|
||||
|
||||
class TestSSEReadTimeout:
    """``sse_client`` must be called with ``sse_read_timeout=300.0``."""

    @staticmethod
    def _drive_run_http(server, config):
        """Run ``server._run_http(config)`` on a fresh event loop.

        ``MCPServerTask._wait_for_lifecycle_event`` is patched to return
        ``"shutdown"`` and ``MCPServerTask._discover_tools`` is patched to an
        ``AsyncMock``. The call is capped at two seconds.
        """
        from tools.mcp_tool import MCPServerTask

        async def drive():
            with patch.object(
                MCPServerTask,
                "_wait_for_lifecycle_event",
                new=AsyncMock(return_value="shutdown"),
            ), patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
                try:
                    await asyncio.wait_for(server._run_http(config), timeout=2.0)
                except Exception:
                    # Deliberately best-effort: the tests only inspect the
                    # kwargs recorded by the fake ``sse_client``, so a timeout
                    # or any later failure inside ``_run_http`` is irrelevant.
                    # ``Exception`` already covers ``asyncio.TimeoutError``
                    # and ``StopAsyncIteration``.
                    pass

        asyncio.run(drive())

    def test_sse_read_timeout_is_300s_not_tool_timeout(self, patch_sse_client):
        """``sse_read_timeout`` must be 300s regardless of the configured
        ``timeout``. Using the tool timeout (60s default) causes Cloudflare-
        Workers-style SSE MCP servers to drop the connection at ~60s idle."""
        server = _build_server_with_sse()

        self._drive_run_http(server, {
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            "timeout": 60,
        })

        assert patch_sse_client.get("sse_read_timeout") == 300.0, (
            f"sse_read_timeout = {patch_sse_client.get('sse_read_timeout')} "
            "(expected 300.0) — SSE idle disconnect regression"
        )

    def test_sse_read_timeout_still_300s_when_tool_timeout_is_large(self, patch_sse_client):
        """Even if user sets a large ``timeout``, ``sse_read_timeout`` stays
        decoupled — it's a transport-level budget for inter-event silence,
        not a per-call budget."""
        server = _build_server_with_sse()

        self._drive_run_http(server, {
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            "timeout": 600,
        })

        assert patch_sse_client.get("sse_read_timeout") == 300.0
|
||||
|
||||
|
||||
class TestSSEOAuthForwarding:
    """``sse_client`` receives ``auth=`` only when OAuth is configured."""

    @staticmethod
    def _drive_run_http(server, config):
        """Run ``server._run_http(config)`` on a fresh event loop.

        ``MCPServerTask._wait_for_lifecycle_event`` is patched to return
        ``"shutdown"`` and ``MCPServerTask._discover_tools`` is patched to an
        ``AsyncMock``. The call is capped at two seconds.
        """
        from tools.mcp_tool import MCPServerTask

        async def drive():
            with patch.object(
                MCPServerTask,
                "_wait_for_lifecycle_event",
                new=AsyncMock(return_value="shutdown"),
            ), patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
                try:
                    await asyncio.wait_for(server._run_http(config), timeout=2.0)
                except Exception:
                    # Deliberately best-effort: the tests only inspect the
                    # kwargs recorded by the fake ``sse_client``, so a timeout
                    # or any later failure inside ``_run_http`` is irrelevant.
                    # ``Exception`` already covers ``asyncio.TimeoutError``
                    # and ``StopAsyncIteration``.
                    pass

        asyncio.run(drive())

    def test_sse_client_receives_oauth_auth_when_configured(self, patch_sse_client):
        """If ``_auth_type == 'oauth'``, ``sse_client`` must receive the
        constructed OAuth provider via ``auth=``. Previously the provider
        was built but never forwarded to the SSE path."""
        server = _build_server_with_sse(oauth=True)
        fake_oauth_provider = MagicMock(name="fake_oauth_provider")
        fake_manager = MagicMock()
        fake_manager.get_or_build_provider.return_value = fake_oauth_provider

        # The manager patch stays active for the whole synchronous drive call,
        # so ``_run_http`` sees the fake manager when it looks one up.
        with patch("tools.mcp_oauth_manager.get_manager", return_value=fake_manager):
            self._drive_run_http(server, {
                "url": "https://example.com/mcp/sse",
                "transport": "sse",
                "auth": "oauth",
                "timeout": 60,
            })

        assert "auth" in patch_sse_client, (
            "sse_client was NOT called with auth= — SSE OAuth forwarding regressed"
        )
        assert patch_sse_client["auth"] is fake_oauth_provider

    def test_sse_client_omits_auth_when_no_oauth_configured(self, patch_sse_client):
        """Without OAuth, ``sse_client`` should not receive an ``auth=`` kwarg.
        Passing ``None`` would be equally fine but the current code path only
        sets it when configured — lock that in."""
        server = _build_server_with_sse(oauth=False)

        self._drive_run_http(server, {
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            "timeout": 60,
        })

        assert "auth" not in patch_sse_client, (
            "sse_client was called with auth= when no OAuth was configured: "
            f"{patch_sse_client!r}"
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue