From 365437e4aaf3756d1d921dad6012eb93d4303465 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 6 Jun 2026 18:35:12 -0700 Subject: [PATCH] fix(cua-driver): reconnect MCP stdio session once on ClosedResourceError after daemon restart (#40570) Salvaged from #40282; cleaned up, re-verified against main, tests added. Co-authored-by: jeeves-assistant --- tests/tools/test_computer_use.py | 72 +++++++++++++++++++++++++++++++ tools/computer_use/cua_backend.py | 35 ++++++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_computer_use.py b/tests/tools/test_computer_use.py index c60a5426f9c..7e7420db596 100644 --- a/tests/tools/test_computer_use.py +++ b/tests/tools/test_computer_use.py @@ -1204,6 +1204,78 @@ def _make_cua_backend_with_windows(windows: List[Dict[str, Any]]): return backend +class TestCuaDriverSessionReconnect: + def test_call_tool_reconnects_once_after_closed_resource(self): + """A daemon restart closes the cached MCP stdio channel; recover once.""" + import threading + from typing import Any, cast + from anyio import ClosedResourceError + from tools.computer_use.cua_backend import _CuaDriverSession + + class FakeBridge: + def __init__(self): + self.calls = [] + # 1st call_tool -> closed; aexit ok; aenter ok; retried call_tool ok. + self.effects = [ClosedResourceError(), None, None, {"ok": True}] + + def run(self, value, timeout=None): + self.calls.append((value, timeout)) + effect = self.effects.pop(0) + if isinstance(effect, Exception): + raise effect + return effect + + bridge = FakeBridge() + session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession)) + session._bridge = bridge + session._session = object() + session._exit_stack = None + session._lock = threading.Lock() + session._started = True + session._call_tool_async = lambda name, args: ("call", name, args) + session._aexit = lambda: ("aexit",) + session._aenter = lambda: ("aenter",) + + assert session.call_tool("list_apps", {}) == {"ok": True} + # Reconnect-once sequence: failed call -> aexit -> aenter -> retried call. + assert bridge.calls[0][0] == ("call", "list_apps", {}) + assert bridge.calls[1][0] == ("aexit",) + assert bridge.calls[2][0] == ("aenter",) + assert bridge.calls[3][0] == ("call", "list_apps", {}) + assert len(bridge.calls) == 4 + + def test_call_tool_does_not_retry_on_unrelated_error(self): + """Non-transport errors must propagate without a reconnect attempt.""" + import threading + from typing import Any, cast + from tools.computer_use.cua_backend import _CuaDriverSession + + class FakeBridge: + def __init__(self): + self.calls = [] + + def run(self, value, timeout=None): + self.calls.append((value, timeout)) + raise ValueError("boom") + + bridge = FakeBridge() + session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession)) + session._bridge = bridge + session._session = object() + session._exit_stack = None + session._lock = threading.Lock() + session._started = True + session._call_tool_async = lambda name, args: ("call", name, args) + session._aexit = lambda: ("aexit",) + session._aenter = lambda: ("aenter",) + + import pytest + with pytest.raises(ValueError): + session.call_tool("list_apps", {}) + # Exactly one attempt, no reconnect. + assert len(bridge.calls) == 1 + + class TestCaptureAppFilterNoMatch: """capture(app=X) must not silently fall back to the frontmost window when X matches nothing — on a non-English macOS, list_windows returns diff --git a/tools/computer_use/cua_backend.py b/tools/computer_use/cua_backend.py index 714ae6d3260..5ade2fdf85e 100644 --- a/tools/computer_use/cua_backend.py +++ b/tools/computer_use/cua_backend.py @@ -277,9 +277,42 @@ class _CuaDriverSession: result = await self._session.call_tool(name, args) return _extract_tool_result(result) + @staticmethod + def _is_closed_session_error(exc: Exception) -> bool: + """Return True for MCP/stdio failures that are recoverable by reconnecting.""" + name = exc.__class__.__name__ + module = getattr(exc.__class__, "__module__", "") + return ( + name in {"ClosedResourceError", "BrokenResourceError", "EndOfStream"} + or (module.startswith("anyio") and "Resource" in name) + or isinstance(exc, (BrokenPipeError, EOFError)) + ) + + def _restart_session_locked(self) -> None: + """Recreate the MCP session after the daemon/stdin transport was closed.""" + try: + if self._started: + self._bridge.run(self._aexit(), timeout=5.0) + except Exception as e: + logger.debug("cua-driver session cleanup before reconnect failed: %s", e) + self._started = False + self._bridge.run(self._aenter(), timeout=15.0) + self._started = True + def call_tool(self, name: str, args: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]: self._require_started() - return self._bridge.run(self._call_tool_async(name, args), timeout=timeout) + try: + return self._bridge.run(self._call_tool_async(name, args), timeout=timeout) + except Exception as e: + if not self._is_closed_session_error(e): + raise + # Daemon restart closes the cached stdio channel. Reconnect once and + # retry exactly one more time — never loop, to avoid hammering a + # genuinely dead daemon. + logger.warning("cua-driver MCP session closed during %s; reconnecting once", name) + with self._lock: + self._restart_session_locked() + return self._bridge.run(self._call_tool_async(name, args), timeout=timeout) def _extract_tool_result(mcp_result: Any) -> Dict[str, Any]: