mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-16 09:31:37 +00:00
fix(cua-driver): reconnect MCP stdio session once on ClosedResourceError after daemon restart (#40570)
Salvaged from #40282; cleaned up, re-verified against main, tests added. Co-authored-by: jeeves-assistant <jeeves-assistant@users.noreply.github.com>
This commit is contained in:
parent
97524344ad
commit
365437e4aa
2 changed files with 106 additions and 1 deletions
|
|
@ -1204,6 +1204,78 @@ def _make_cua_backend_with_windows(windows: List[Dict[str, Any]]):
|
|||
return backend
|
||||
|
||||
|
||||
class TestCuaDriverSessionReconnect:
|
||||
def test_call_tool_reconnects_once_after_closed_resource(self):
|
||||
"""A daemon restart closes the cached MCP stdio channel; recover once."""
|
||||
import threading
|
||||
from typing import Any, cast
|
||||
from anyio import ClosedResourceError
|
||||
from tools.computer_use.cua_backend import _CuaDriverSession
|
||||
|
||||
class FakeBridge:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
# 1st call_tool -> closed; aexit ok; aenter ok; retried call_tool ok.
|
||||
self.effects = [ClosedResourceError(), None, None, {"ok": True}]
|
||||
|
||||
def run(self, value, timeout=None):
|
||||
self.calls.append((value, timeout))
|
||||
effect = self.effects.pop(0)
|
||||
if isinstance(effect, Exception):
|
||||
raise effect
|
||||
return effect
|
||||
|
||||
bridge = FakeBridge()
|
||||
session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession))
|
||||
session._bridge = bridge
|
||||
session._session = object()
|
||||
session._exit_stack = None
|
||||
session._lock = threading.Lock()
|
||||
session._started = True
|
||||
session._call_tool_async = lambda name, args: ("call", name, args)
|
||||
session._aexit = lambda: ("aexit",)
|
||||
session._aenter = lambda: ("aenter",)
|
||||
|
||||
assert session.call_tool("list_apps", {}) == {"ok": True}
|
||||
# Reconnect-once sequence: failed call -> aexit -> aenter -> retried call.
|
||||
assert bridge.calls[0][0] == ("call", "list_apps", {})
|
||||
assert bridge.calls[1][0] == ("aexit",)
|
||||
assert bridge.calls[2][0] == ("aenter",)
|
||||
assert bridge.calls[3][0] == ("call", "list_apps", {})
|
||||
assert len(bridge.calls) == 4
|
||||
|
||||
def test_call_tool_does_not_retry_on_unrelated_error(self):
|
||||
"""Non-transport errors must propagate without a reconnect attempt."""
|
||||
import threading
|
||||
from typing import Any, cast
|
||||
from tools.computer_use.cua_backend import _CuaDriverSession
|
||||
|
||||
class FakeBridge:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def run(self, value, timeout=None):
|
||||
self.calls.append((value, timeout))
|
||||
raise ValueError("boom")
|
||||
|
||||
bridge = FakeBridge()
|
||||
session = cast(Any, _CuaDriverSession.__new__(_CuaDriverSession))
|
||||
session._bridge = bridge
|
||||
session._session = object()
|
||||
session._exit_stack = None
|
||||
session._lock = threading.Lock()
|
||||
session._started = True
|
||||
session._call_tool_async = lambda name, args: ("call", name, args)
|
||||
session._aexit = lambda: ("aexit",)
|
||||
session._aenter = lambda: ("aenter",)
|
||||
|
||||
import pytest
|
||||
with pytest.raises(ValueError):
|
||||
session.call_tool("list_apps", {})
|
||||
# Exactly one attempt, no reconnect.
|
||||
assert len(bridge.calls) == 1
|
||||
|
||||
|
||||
class TestCaptureAppFilterNoMatch:
|
||||
"""capture(app=X) must not silently fall back to the frontmost window
|
||||
when X matches nothing — on a non-English macOS, list_windows returns
|
||||
|
|
|
|||
|
|
@ -277,9 +277,42 @@ class _CuaDriverSession:
|
|||
result = await self._session.call_tool(name, args)
|
||||
return _extract_tool_result(result)
|
||||
|
||||
@staticmethod
|
||||
def _is_closed_session_error(exc: Exception) -> bool:
|
||||
"""Return True for MCP/stdio failures that are recoverable by reconnecting."""
|
||||
name = exc.__class__.__name__
|
||||
module = getattr(exc.__class__, "__module__", "")
|
||||
return (
|
||||
name in {"ClosedResourceError", "BrokenResourceError", "EndOfStream"}
|
||||
or (module.startswith("anyio") and "Resource" in name)
|
||||
or isinstance(exc, (BrokenPipeError, EOFError))
|
||||
)
|
||||
|
||||
def _restart_session_locked(self) -> None:
|
||||
"""Recreate the MCP session after the daemon/stdin transport was closed."""
|
||||
try:
|
||||
if self._started:
|
||||
self._bridge.run(self._aexit(), timeout=5.0)
|
||||
except Exception as e:
|
||||
logger.debug("cua-driver session cleanup before reconnect failed: %s", e)
|
||||
self._started = False
|
||||
self._bridge.run(self._aenter(), timeout=15.0)
|
||||
self._started = True
|
||||
|
||||
def call_tool(self, name: str, args: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
|
||||
self._require_started()
|
||||
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
|
||||
try:
|
||||
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
|
||||
except Exception as e:
|
||||
if not self._is_closed_session_error(e):
|
||||
raise
|
||||
# Daemon restart closes the cached stdio channel. Reconnect once and
|
||||
# retry exactly one more time — never loop, to avoid hammering a
|
||||
# genuinely dead daemon.
|
||||
logger.warning("cua-driver MCP session closed during %s; reconnecting once", name)
|
||||
with self._lock:
|
||||
self._restart_session_locked()
|
||||
return self._bridge.run(self._call_tool_async(name, args), timeout=timeout)
|
||||
|
||||
|
||||
def _extract_tool_result(mcp_result: Any) -> Dict[str, Any]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue