diff --git a/tests/tools/test_mcp_sse_transport.py b/tests/tools/test_mcp_sse_transport.py
new file mode 100644
index 0000000000..d5f15260ac
--- /dev/null
+++ b/tests/tools/test_mcp_sse_transport.py
@@ -0,0 +1,183 @@
+"""Regression tests for SSE transport in ``MCPServerTask._run_http``.
+
+Covers fixes distilled from @amiller's PR #5981 that couldn't be cherry-picked
+due to stale-branch divergence:
+
+1. ``sse_read_timeout`` is set to 300s (not the tool timeout). SSE servers
+   commonly hold the stream idle for minutes between events; a 60s read
+   timeout drops the connection after the first slow stretch. Original
+   observation: Router Teamwork / Supermemory on Cloudflare Workers dropping
+   at ~60s idle.
+
+2. OAuth auth is forwarded to ``sse_client`` when configured. Previously the
+   code built ``_oauth_auth`` but never passed it to the SSE path, so SSE MCP
+   servers behind OAuth 2.1 PKCE would silently fail with 401s.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Dummy endpoint; the faked ``sse_client`` never contacts it.
+_SSE_URL = "https://example.com/mcp/sse"
+
+
+def _build_server_with_sse(oauth: bool = False):
+    """Stand up an MCPServerTask configured for SSE transport, with mocks
+    threaded through so ``_run_http`` can enter the SSE branch without a
+    real network call."""
+    from tools.mcp_tool import MCPServerTask
+
+    server = MCPServerTask("sse-test")
+    server._auth_type = "oauth" if oauth else ""
+    server._sampling = None
+    return server
+
+
+def _drive_run_http(server, config: dict, oauth_manager=None) -> None:
+    """Run ``server._run_http(config)`` once with its lifecycle hooks mocked.
+
+    ``_wait_for_lifecycle_event`` is patched to return ``"shutdown"`` and
+    ``_discover_tools`` to a no-op ``AsyncMock``. When ``oauth_manager`` is
+    given, ``tools.mcp_oauth_manager.get_manager`` is patched to return it.
+
+    Errors raised by ``_run_http`` (including the 2s ``wait_for`` timeout) are
+    deliberately suppressed: the tests only inspect the kwargs that reached
+    ``sse_client``, so each test must assert on those kwargs itself.
+    """
+    from tools.mcp_tool import MCPServerTask
+
+    async def drive():
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(patch.object(
+                MCPServerTask, "_wait_for_lifecycle_event",
+                new=AsyncMock(return_value="shutdown"),
+            ))
+            stack.enter_context(patch.object(
+                MCPServerTask, "_discover_tools", new=AsyncMock(),
+            ))
+            if oauth_manager is not None:
+                stack.enter_context(patch(
+                    "tools.mcp_oauth_manager.get_manager",
+                    return_value=oauth_manager,
+                ))
+            # Best-effort run: failures and the 2s timeout are swallowed on
+            # purpose; callers assert on the captured kwargs instead.
+            with contextlib.suppress(Exception):
+                await asyncio.wait_for(server._run_http(config), timeout=2.0)
+
+    asyncio.run(drive())
+
+
+@pytest.fixture
+def patch_sse_client():
+    """Replace ``sse_client`` with a fake that records its kwargs.
+
+    ``ClientSession`` is replaced too, so no real MCP session is opened.
+    Yields the dict of kwargs from the most recent ``sse_client`` call; it
+    stays empty if ``_run_http`` never reached ``sse_client``.
+    """
+    captured_kwargs: dict = {}
+
+    class _FakeStream:
+        def __init__(self):
+            self._read = AsyncMock()
+            self._write = AsyncMock()
+
+        async def __aenter__(self):
+            return (self._read, self._write)
+
+        async def __aexit__(self, *a):
+            return False
+
+    def fake_sse_client(**kwargs):
+        captured_kwargs.clear()
+        captured_kwargs.update(kwargs)
+        return _FakeStream()
+
+    class _FakeSession:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        async def __aenter__(self):
+            mock_session = MagicMock()
+            mock_session.initialize = AsyncMock()
+            return mock_session
+
+        async def __aexit__(self, *a):
+            return False
+
+    with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \
+            patch("tools.mcp_tool.ClientSession", new=_FakeSession):
+        yield captured_kwargs
+
+
+class TestSSEReadTimeout:
+    def test_sse_read_timeout_is_300s_not_tool_timeout(self, patch_sse_client):
+        """``sse_read_timeout`` must be 300s regardless of the configured
+        ``timeout``. Using the tool timeout (60s default) causes Cloudflare-
+        Workers-style SSE MCP servers to drop the connection at ~60s idle."""
+        _drive_run_http(
+            _build_server_with_sse(),
+            {"url": _SSE_URL, "transport": "sse", "timeout": 60},
+        )
+
+        assert patch_sse_client.get("sse_read_timeout") == 300.0, (
+            f"sse_read_timeout = {patch_sse_client.get('sse_read_timeout')} "
+            f"(expected 300.0) — SSE idle disconnect regression"
+        )
+
+    def test_sse_read_timeout_still_300s_when_tool_timeout_is_large(self, patch_sse_client):
+        """Even if user sets a large ``timeout``, ``sse_read_timeout`` stays
+        decoupled — it's a transport-level budget for inter-event silence,
+        not a per-call budget."""
+        _drive_run_http(
+            _build_server_with_sse(),
+            {"url": _SSE_URL, "transport": "sse", "timeout": 600},
+        )
+
+        assert patch_sse_client.get("sse_read_timeout") == 300.0
+
+
+class TestSSEOAuthForwarding:
+    def test_sse_client_receives_oauth_auth_when_configured(self, patch_sse_client):
+        """If ``_auth_type == 'oauth'``, ``sse_client`` must receive the
+        constructed OAuth provider via ``auth=``. Previously the provider
+        was built but never forwarded to the SSE path."""
+        fake_oauth_provider = MagicMock(name="fake_oauth_provider")
+        fake_manager = MagicMock()
+        fake_manager.get_or_build_provider.return_value = fake_oauth_provider
+
+        _drive_run_http(
+            _build_server_with_sse(oauth=True),
+            {"url": _SSE_URL, "transport": "sse", "auth": "oauth", "timeout": 60},
+            oauth_manager=fake_manager,
+        )
+
+        assert "auth" in patch_sse_client, (
+            "sse_client was NOT called with auth= — SSE OAuth forwarding regressed"
+        )
+        assert patch_sse_client["auth"] is fake_oauth_provider
+
+    def test_sse_client_omits_auth_when_no_oauth_configured(self, patch_sse_client):
+        """Without OAuth, ``sse_client`` should not receive an ``auth=`` kwarg.
+        Passing ``None`` would be equally fine but the current code path only
+        sets it when configured — lock that in."""
+        _drive_run_http(
+            _build_server_with_sse(oauth=False),
+            {"url": _SSE_URL, "transport": "sse", "timeout": 60},
+        )
+
+        # Guard against a vacuous pass: an empty capture means ``_run_http``
+        # failed before reaching ``sse_client``, which proves nothing about auth.
+        assert patch_sse_client, (
+            "sse_client was never called — cannot verify auth= omission"
+        )
+        assert "auth" not in patch_sse_client, (
+            f"sse_client was called with auth= when no OAuth was configured: "
+            f"{patch_sse_client!r}"
+        )
diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index d2eb5bd344..fcfc5dbadc 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1243,12 +1243,26 @@ class MCPServerTask: "mcp.client.sse.sse_client is not available. " "Upgrade the mcp package to get SSE support." ) - async with sse_client( - url=url, - headers=headers or None, - timeout=float(connect_timeout), - sse_read_timeout=float(config.get("timeout", _DEFAULT_TOOL_TIMEOUT)), - ) as (read_stream, write_stream): + # sse_read_timeout governs how long sse_client will wait between + # events on the SSE stream. Using the tool_timeout (default 60s) + # here is wrong: SSE servers commonly hold the stream idle for + # minutes between events, so a 60s read timeout drops the + # connection after the first slow stretch. 300s matches the + # Streamable HTTP code path's httpx read timeout below. Original + # observation from @amiller in PR #5981 (Router Teamwork, + # Supermemory on Cloudflare Workers idle-disconnect at ~60s). 
+ _sse_kwargs: dict = { + "url": url, + "headers": headers or None, + "timeout": float(connect_timeout), + "sse_read_timeout": 300.0, + } + if _oauth_auth is not None: + # Pass OAuth auth through to sse_client so SSE MCP servers + # behind OAuth 2.1 PKCE work. Previously built but never + # forwarded — SSE OAuth would silently fail with 401s. + _sse_kwargs["auth"] = _oauth_auth + async with sse_client(**_sse_kwargs) as (read_stream, write_stream): async with ClientSession( read_stream, write_stream, **sampling_kwargs ) as session: