fix(cua-driver): reconnect MCP stdio session once on ClosedResourceError after daemon restart (#40570)

Salvaged from #40282; cleaned up, re-verified against main, tests added.

Co-authored-by: jeeves-assistant <jeeves-assistant@users.noreply.github.com>
This commit is contained in:
Teknium 2026-06-06 18:35:12 -07:00 committed by GitHub
parent 97524344ad
commit 365437e4aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 106 additions and 1 deletions

View file

@ -1204,6 +1204,78 @@ def _make_cua_backend_with_windows(windows: List[Dict[str, Any]]):
return backend
class TestCuaDriverSessionReconnect:
def test_call_tool_reconnects_once_after_closed_resource(self):
"""A daemon restart closes the cached MCP stdio channel; recover once."""
import threading
from typing import Any, cast
from anyio import ClosedResourceError
from tools.computer_use.cua_backend import _CuaDriverSession
class FakeBridge:
def __init__(self):
self.calls = []
# 1st call_tool -> closed; aexit ok; aenter ok; retried call_tool ok.
self.effects = [ClosedResourceError(), None, None, {"ok": True}]
def run(self, value, timeout=None):
self.calls.append((value, timeout))
effect = self.effects.pop(0)
if isinstance(effect, Exception):
raise effect
return effect
bridge = FakeBridge()
session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession))
session._bridge = bridge
session._session = object()
session._exit_stack = None
session._lock = threading.Lock()
session._started = True
session._call_tool_async = lambda name, args: ("call", name, args)
session._aexit = lambda: ("aexit",)
session._aenter = lambda: ("aenter",)
assert session.call_tool("list_apps", {}) == {"ok": True}
# Reconnect-once sequence: failed call -> aexit -> aenter -> retried call.
assert bridge.calls[0][0] == ("call", "list_apps", {})
assert bridge.calls[1][0] == ("aexit",)
assert bridge.calls[2][0] == ("aenter",)
assert bridge.calls[3][0] == ("call", "list_apps", {})
assert len(bridge.calls) == 4
def test_call_tool_does_not_retry_on_unrelated_error(self):
"""Non-transport errors must propagate without a reconnect attempt."""
import threading
from typing import Any, cast
from tools.computer_use.cua_backend import _CuaDriverSession
class FakeBridge:
def __init__(self):
self.calls = []
def run(self, value, timeout=None):
self.calls.append((value, timeout))
raise ValueError("boom")
bridge = FakeBridge()
session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession))
session._bridge = bridge
session._session = object()
session._exit_stack = None
session._lock = threading.Lock()
session._started = True
session._call_tool_async = lambda name, args: ("call", name, args)
session._aexit = lambda: ("aexit",)
session._aenter = lambda: ("aenter",)
import pytest
with pytest.raises(ValueError):
session.call_tool("list_apps", {})
# Exactly one attempt, no reconnect.
assert len(bridge.calls) == 1
class TestCaptureAppFilterNoMatch:
"""capture(app=X) must not silently fall back to the frontmost window
when X matches nothing on a non-English macOS, list_windows returns

View file

@ -277,9 +277,42 @@ class _CuaDriverSession:
result = await self._session.call_tool(name, args)
return _extract_tool_result(result)
@staticmethod
def _is_closed_session_error(exc: Exception) -> bool:
"""Return True for MCP/stdio failures that are recoverable by reconnecting."""
name = exc.__class__.__name__
module = getattr(exc.__class__, "__module__", "")
return (
name in {"ClosedResourceError", "BrokenResourceError", "EndOfStream"}
or (module.startswith("anyio") and "Resource" in name)
or isinstance(exc, (BrokenPipeError, EOFError))
)
def _restart_session_locked(self) -> None:
"""Recreate the MCP session after the daemon/stdin transport was closed."""
try:
if self._started:
self._bridge.run(self._aexit(), timeout=5.0)
except Exception as e:
logger.debug("cua-driver session cleanup before reconnect failed: %s", e)
self._started = False
self._bridge.run(self._aenter(), timeout=15.0)
self._started = True
def call_tool(self, name: str, args: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
self._require_started()
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
try:
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
except Exception as e:
if not self._is_closed_session_error(e):
raise
# Daemon restart closes the cached stdio channel. Reconnect once and
# retry exactly one more time — never loop, to avoid hammering a
# genuinely dead daemon.
logger.warning("cua-driver MCP session closed during %s; reconnecting once", name)
with self._lock:
self._restart_session_locked()
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
def _extract_tool_result(mcp_result: Any) -> Dict[str, Any]: