From af08c43f3e82a313122d9fdb71c521cdfcd75a72 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 7 Jun 2026 09:51:11 -0700 Subject: [PATCH] fix: skip MCP preflight content-type probe on reconnect when already ready (#40604) Closes #40366. Salvaged from #40548; re-verified on main, tightened, tested. Co-authored-by: mohamedorigami-jpg --- tests/tools/test_mcp_tool.py | 74 ++++++++++++++++++++++++++++++++++++ tools/mcp_tool.py | 7 +++- 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index e2575664748..f7a19f4d921 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -1719,6 +1719,80 @@ class TestReconnection: asyncio.run(_test()) + def test_preflight_probe_runs_on_initial_http_connect(self): + """The content-type preflight probe fires on the first HTTP connect.""" + from tools.mcp_tool import MCPServerTask + + target_server = None + probe = AsyncMock() + + original_run_http = MCPServerTask._run_http + + async def patched_run_http(self_srv, config): + if target_server is not self_srv: + return await original_run_http(self_srv, config) + # First connect succeeds; signal shutdown so run() exits cleanly. + self_srv.session = MagicMock() + self_srv._tools = [] + self_srv._ready.set() + self_srv._shutdown_event.set() + await self_srv._shutdown_event.wait() + + async def _test(): + nonlocal target_server + server = MCPServerTask("http_srv") + target_server = server + + with patch.object(MCPServerTask, "_run_http", patched_run_http), \ + patch.object(MCPServerTask, "_preflight_content_type", probe), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await server.run({"url": "https://example.com/mcp"}) + + # Probe ran exactly once on the initial (pre-_ready) connect. + assert probe.await_count == 1 + + asyncio.run(_test()) + + def test_preflight_probe_skipped_when_already_ready(self): + """The probe must NOT re-run on reconnect (_ready already set). + + On reconnect (OAuth recovery / manual refresh) run() is re-entered + with _ready still set from the prior successful connect. Re-probing + the already-validated endpoint burns a redundant network round-trip, + so the guard must skip it. Regression test for #40548. + """ + from tools.mcp_tool import MCPServerTask + + target_server = None + probe = AsyncMock() + + original_run_http = MCPServerTask._run_http + + async def patched_run_http(self_srv, config): + if target_server is not self_srv: + return await original_run_http(self_srv, config) + self_srv.session = MagicMock() + self_srv._tools = [] + self_srv._shutdown_event.set() + await self_srv._shutdown_event.wait() + + async def _test(): + nonlocal target_server + server = MCPServerTask("http_srv") + target_server = server + # Simulate a reconnect: _ready was set by the prior connect. + server._ready.set() + + with patch.object(MCPServerTask, "_run_http", patched_run_http), \ + patch.object(MCPServerTask, "_preflight_content_type", probe), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await server.run({"url": "https://example.com/mcp"}) + + # Probe skipped because _ready was already set. + assert probe.await_count == 0 + + asyncio.run(_test()) + # --------------------------------------------------------------------------- # Configurable timeouts diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 3a817d57995..5c3c46c4db4 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1800,7 +1800,12 @@ class MCPServerTask: # before surfacing an opaque CancelledError. Probing here — once, # outside the SDK task group — fails fast and non-retryably with # an actionable message, mirroring the URL-validation path above. - if config.get("transport") != "sse": + # Skip the probe when _ready is already set: that only happens + # after a prior successful connect, so this run() invocation is a + # reconnect (OAuth recovery / manual refresh). The endpoint was + # already validated once; re-probing burns a redundant network + # round-trip against a known-good server on every reconnect. + if config.get("transport") != "sse" and not self._ready.is_set(): try: _probe_headers = dict(config.get("headers") or {}) await self._preflight_content_type(