fix(mcp): forward OAuth auth and bump sse_read_timeout on SSE transport (#21323)

* fix(mcp): re-raise CancelledError explicitly in MCPServerTask.run

On Python 3.11+, `asyncio.CancelledError` inherits from `BaseException`
(not `Exception`), so the broad `except Exception as exc:` in
`MCPServerTask.run`'s transport loop did NOT catch it. Task cancellation
from gateway restart / explicit `task.cancel()` silently escaped past
the reconnect logic — the MCP server task died without going through
the shutdown/reconnect code paths that check `_shutdown_event`.

Add an explicit `except asyncio.CancelledError: raise` before the broad
catch so cancellation propagation is self-documenting rather than an
accident of exception hierarchy, and future sibling-site work (e.g.
distinguishing shutdown-cancel from transport-cancel) has an obvious
hook. Behavior on pre-3.8 Pythons where CancelledError WAS an Exception
subclass is also corrected: the old path would have caught it and
treated it as a connection failure worth retrying.

Closes #9930.

* fix(mcp): forward OAuth auth and bump sse_read_timeout on SSE transport

Two surgical correctness bugs in the SSE branch of MCPServerTask._run_http,
distilled from @amiller's PR #5981 that couldn't be cherry-picked wholesale
(branch too stale).

1. sse_read_timeout was set to the tool timeout (default 60s). That's the
   wrong dimension — it governs how long sse_client will wait between
   events on the SSE stream, not per-call latency. SSE servers routinely
   hold the stream idle for minutes between events; a 60s read timeout
   drops the connection after the first slow stretch (Router Teamwork,
   Supermemory on Cloudflare Workers idle-disconnect at ~60s). Bump to
   300s to match the Streamable HTTP path's httpx read timeout.

2. OAuth auth was built via get_manager().get_or_build_provider() but
   never forwarded to sse_client. SSE MCP servers behind OAuth 2.1 PKCE
   would silently fail with 401s on every request.

Keepalive (the other half of #5981) intentionally left for a follow-up —
it's a real improvement but a bigger change, and these two are obvious
corrections to ship now. Credits to @amiller.

Co-authored-by: Andrew Miller <socrates1024@gmail.com>

---------

Co-authored-by: Andrew Miller <socrates1024@gmail.com>
This commit is contained in:
Teknium 2026-05-07 07:08:04 -07:00 committed by GitHub
parent 4ee6c3349a
commit dd2dc2bddf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 229 additions and 6 deletions

View file

@ -0,0 +1,209 @@
"""Regression tests for SSE transport in ``MCPServerTask._run_http``.
Covers fixes distilled from @amiller's PR #5981 that couldn't be cherry-picked
due to stale-branch divergence:
1. ``sse_read_timeout`` is set to 300s (not the tool timeout). SSE servers
commonly hold the stream idle for minutes between events; a 60s read
timeout drops the connection after the first slow stretch. Original
observation: Router Teamwork / Supermemory on Cloudflare Workers dropping
at ~60s idle.
2. OAuth auth is forwarded to ``sse_client`` when configured. Previously the
code built ``_oauth_auth`` but never passed it to the SSE path, so SSE MCP
servers behind OAuth 2.1 PKCE would silently fail with 401s.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
async def _noop_initialize():
return None
def _build_server_with_sse(oauth: bool = False):
"""Stand up an MCPServerTask configured for SSE transport, with mocks
threaded through so ``_run_http`` can enter the SSE branch without a
real network call."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("sse-test")
server._auth_type = "oauth" if oauth else ""
server._sampling = None
return server
@pytest.fixture
def patch_sse_client():
"""Replace ``sse_client`` with a MagicMock that records its kwargs.
Returns the mock so tests can assert how ``_run_http`` called it.
"""
captured_kwargs: dict = {}
class _FakeStream:
def __init__(self):
self._read = AsyncMock()
self._write = AsyncMock()
async def __aenter__(self):
return (self._read, self._write)
async def __aexit__(self, *a):
return False
def fake_sse_client(**kwargs):
captured_kwargs.clear()
captured_kwargs.update(kwargs)
return _FakeStream()
class _FakeSession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
mock_session = MagicMock()
mock_session.initialize = AsyncMock()
return mock_session
async def __aexit__(self, *a):
return False
with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \
patch("tools.mcp_tool.ClientSession", new=_FakeSession):
yield captured_kwargs
class TestSSEReadTimeout:
def test_sse_read_timeout_is_300s_not_tool_timeout(self, patch_sse_client):
"""``sse_read_timeout`` must be 300s regardless of the configured
``timeout``. Using the tool timeout (60s default) causes Cloudflare-
Workers-style SSE MCP servers to drop the connection at ~60s idle."""
from tools.mcp_tool import MCPServerTask
server = _build_server_with_sse()
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"timeout": 60,
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
assert patch_sse_client.get("sse_read_timeout") == 300.0, (
f"sse_read_timeout = {patch_sse_client.get('sse_read_timeout')} "
f"(expected 300.0) — SSE idle disconnect regression"
)
def test_sse_read_timeout_still_300s_when_tool_timeout_is_large(self, patch_sse_client):
"""Even if user sets a large ``timeout``, ``sse_read_timeout`` stays
decoupled it's a transport-level budget for inter-event silence,
not a per-call budget."""
from tools.mcp_tool import MCPServerTask
server = _build_server_with_sse()
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"timeout": 600,
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
assert patch_sse_client.get("sse_read_timeout") == 300.0
class TestSSEOAuthForwarding:
def test_sse_client_receives_oauth_auth_when_configured(self, patch_sse_client):
"""If ``_auth_type == 'oauth'``, ``sse_client`` must receive the
constructed OAuth provider via ``auth=``. Previously the provider
was built but never forwarded to the SSE path."""
from tools.mcp_tool import MCPServerTask
server = _build_server_with_sse(oauth=True)
fake_oauth_provider = MagicMock(name="fake_oauth_provider")
fake_manager = MagicMock()
fake_manager.get_or_build_provider.return_value = fake_oauth_provider
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()), \
patch("tools.mcp_oauth_manager.get_manager", return_value=fake_manager):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"auth": "oauth",
"timeout": 60,
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
assert "auth" in patch_sse_client, (
"sse_client was NOT called with auth= — SSE OAuth forwarding regressed"
)
assert patch_sse_client["auth"] is fake_oauth_provider
def test_sse_client_omits_auth_when_no_oauth_configured(self, patch_sse_client):
"""Without OAuth, ``sse_client`` should not receive an ``auth=`` kwarg.
Passing ``None`` would be equally fine but the current code path only
sets it when configured lock that in."""
from tools.mcp_tool import MCPServerTask
server = _build_server_with_sse(oauth=False)
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"timeout": 60,
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
assert "auth" not in patch_sse_client, (
f"sse_client was called with auth= when no OAuth was configured: "
f"{patch_sse_client!r}"
)