mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: resolve thread safety issues and shutdown deadlock in MCP client
- Add threading.Lock protecting all shared state (_servers, _mcp_loop, _mcp_thread)
- Fix deadlock in shutdown_mcp_servers: _stop_mcp_loop was called inside
a _lock block but also acquires _lock (non-reentrant)
- Fix race condition in _ensure_mcp_loop with concurrent callers
- Change idempotency to per-server (retry failed servers, skip connected)
- Dynamic toolset injection via startswith("hermes-") instead of hardcoded list
- Parallel shutdown via asyncio.gather instead of sequential loop
- Add tests for partial failure retry, parallel shutdown, dynamic injection
This commit is contained in:
parent
151e8d896c
commit
11a2ecb936
2 changed files with 184 additions and 58 deletions
|
|
@ -459,14 +459,13 @@ class TestMCPServerTask:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolsetInjection:
|
||||
def test_mcp_tools_added_to_platform_toolsets(self):
|
||||
"""Discovered MCP tools are injected into hermes-cli and platform toolsets."""
|
||||
def test_mcp_tools_added_to_all_hermes_toolsets(self):
|
||||
"""Discovered MCP tools are dynamically injected into all hermes-* toolsets."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
mock_tools = [_make_mcp_tool("list_files", "List files")]
|
||||
mock_session = MagicMock()
|
||||
|
||||
# Fresh _servers dict to bypass idempotency guard
|
||||
fresh_servers = {}
|
||||
|
||||
async def fake_connect(name, config):
|
||||
|
|
@ -476,12 +475,12 @@ class TestToolsetInjection:
|
|||
return server
|
||||
|
||||
fake_toolsets = {
|
||||
"hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []},
|
||||
}
|
||||
fake_config = {
|
||||
"fs": {"command": "npx", "args": []},
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
"hermes-gateway": {"tools": [], "description": "GW", "includes": []},
|
||||
"non-hermes": {"tools": [], "description": "other", "includes": []},
|
||||
}
|
||||
fake_config = {"fs": {"command": "npx", "args": []}}
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._servers", fresh_servers), \
|
||||
|
|
@ -492,8 +491,12 @@ class TestToolsetInjection:
|
|||
result = discover_mcp_tools()
|
||||
|
||||
assert "mcp_fs_list_files" in result
|
||||
# All hermes-* toolsets get injection
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"]
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"]
|
||||
assert "mcp_fs_list_files" in fake_toolsets["hermes-gateway"]["tools"]
|
||||
# Non-hermes toolset should NOT get injection
|
||||
assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"]
|
||||
# Original tools preserved
|
||||
assert "terminal" in fake_toolsets["hermes-cli"]["tools"]
|
||||
|
||||
|
|
@ -504,7 +507,6 @@ class TestToolsetInjection:
|
|||
mock_tools = [_make_mcp_tool("ping", "Ping")]
|
||||
mock_session = MagicMock()
|
||||
|
||||
# Fresh _servers dict to bypass idempotency guard
|
||||
fresh_servers = {}
|
||||
call_count = 0
|
||||
|
||||
|
|
@ -534,10 +536,62 @@ class TestToolsetInjection:
|
|||
from tools.mcp_tool import discover_mcp_tools
|
||||
result = discover_mcp_tools()
|
||||
|
||||
# Only good server's tool registered
|
||||
assert "mcp_good_ping" in result
|
||||
assert "mcp_broken_ping" not in result
|
||||
assert call_count == 2 # Both were attempted
|
||||
assert call_count == 2
|
||||
|
||||
def test_partial_failure_retry_on_second_call(self):
    """Failed servers are retried on subsequent discover_mcp_tools() calls.

    First call: "good" connects while "broken" raises. After the broken
    server is "fixed", a second call must retry only the still-missing
    server and leave the already-connected one untouched.
    """
    from tools.mcp_tool import MCPServerTask

    mock_tools = [_make_mcp_tool("ping", "Ping")]
    mock_session = MagicMock()

    # Use a real dict so idempotency logic works correctly
    fresh_servers = {}
    call_count = 0
    broken_fixed = False

    async def flaky_connect(name, config):
        nonlocal call_count
        call_count += 1
        if name == "broken" and not broken_fixed:
            raise ConnectionError("cannot reach server")
        server = MCPServerTask(name)
        server.session = mock_session
        server._tools = mock_tools
        return server

    fake_config = {
        "broken": {"command": "bad"},
        "good": {"command": "npx", "args": []},
    }
    fake_toolsets = {
        "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
    }

    with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
            patch("tools.mcp_tool._servers", fresh_servers), \
            patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
            patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
            patch("toolsets.TOOLSETS", fake_toolsets):
        from tools.mcp_tool import discover_mcp_tools

        # First call: good connects, broken fails
        result1 = discover_mcp_tools()
        assert "mcp_good_ping" in result1
        assert "mcp_broken_ping" not in result1
        # Previously assigned but never checked -- actually assert that
        # BOTH configured servers were attempted on the first call.
        first_attempts = call_count
        assert first_attempts == 2

        # "Fix" the broken server
        broken_fixed = True
        call_count = 0

        # Second call: should retry broken, skip good
        result2 = discover_mcp_tools()
        assert "mcp_good_ping" in result2
        assert "mcp_broken_ping" in result2
        assert call_count == 1  # Only broken retried
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -581,6 +635,7 @@ class TestShutdown:
|
|||
|
||||
_servers.clear()
|
||||
mock_server = MagicMock()
|
||||
mock_server.name = "test"
|
||||
mock_server.shutdown = AsyncMock()
|
||||
_servers["test"] = mock_server
|
||||
|
||||
|
|
@ -601,6 +656,7 @@ class TestShutdown:
|
|||
|
||||
_servers.clear()
|
||||
mock_server = MagicMock()
|
||||
mock_server.name = "broken"
|
||||
mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
|
||||
_servers["broken"] = mock_server
|
||||
|
||||
|
|
@ -612,3 +668,33 @@ class TestShutdown:
|
|||
mcp_mod._mcp_thread = None
|
||||
|
||||
assert len(_servers) == 0
|
||||
|
||||
def test_shutdown_is_parallel(self):
    """Multiple servers are shut down in parallel via asyncio.gather.

    Three servers each take ~1s to shut down; a sequential shutdown
    would therefore take ~3s, a parallel one ~1s.
    """
    import tools.mcp_tool as mcp_mod
    from tools.mcp_tool import shutdown_mcp_servers, _servers
    import time

    _servers.clear()

    # 3 servers each taking 1s to shut down
    for i in range(3):
        mock_server = MagicMock()
        mock_server.name = f"srv_{i}"

        # Define the coroutine per-iteration; nothing loop-varying is
        # captured, but keep the closure self-contained for hygiene.
        async def slow_shutdown():
            await asyncio.sleep(1)

        mock_server.shutdown = slow_shutdown
        _servers[f"srv_{i}"] = mock_server

    mcp_mod._ensure_mcp_loop()
    try:
        start = time.monotonic()
        shutdown_mcp_servers()
        elapsed = time.monotonic() - start
    finally:
        # Fix: if shutdown_mcp_servers() raised before stopping the
        # loop, nulling the globals alone would leak a running daemon
        # thread into later tests -- stop the loop first.
        loop = mcp_mod._mcp_loop
        if loop is not None and loop.is_running():
            loop.call_soon_threadsafe(loop.stop)
        mcp_mod._mcp_loop = None
        mcp_mod._mcp_thread = None

    assert len(_servers) == 0
    # Parallel: ~1s, not ~3s. Allow some margin.
    assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
|
||||
|
|
|
|||
|
|
@ -32,6 +32,12 @@ Architecture:
|
|||
On shutdown, each server Task is signalled to exit its ``async with``
|
||||
block, ensuring the anyio cancel-scope cleanup happens in the *same*
|
||||
Task that opened the connection (required by anyio).
|
||||
|
||||
Thread safety:
|
||||
_servers and _mcp_loop/_mcp_thread are accessed from both the MCP
|
||||
background thread and caller threads. All mutations are protected by
|
||||
_lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
|
||||
free-threading).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -161,26 +167,32 @@ _servers: Dict[str, MCPServerTask] = {}
|
|||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def _ensure_mcp_loop():
    """Start the background event loop thread if not already running.

    Thread-safe: the check-then-create sequence runs entirely under
    ``_lock`` so two concurrent callers cannot both spawn a loop thread.

    NOTE(review): ``is_running()`` can briefly be False right after
    ``Thread.start()`` (before ``run_forever`` begins); a second caller
    in that window could spawn a duplicate loop. Consider tracking an
    explicit "started" flag -- TODO confirm whether this window matters
    for real callers.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        # Already have a live loop -- nothing to do.
        if _mcp_loop is not None and _mcp_loop.is_running():
            return
        _mcp_loop = asyncio.new_event_loop()
        _mcp_thread = threading.Thread(
            target=_mcp_loop.run_forever,
            name="mcp-event-loop",
            daemon=True,
        )
        _mcp_thread.start()
|
||||
|
||||
|
||||
def _run_on_mcp_loop(coro, timeout: float = 30):
    """Schedule a coroutine on the MCP event loop and block until done.

    Args:
        coro: Coroutine to execute on the background MCP loop.
        timeout: Seconds to wait for the result.

    Returns:
        Whatever the coroutine returns.

    Raises:
        RuntimeError: If the MCP event loop is not running.
    """
    # Snapshot the loop under the lock, then operate on the local
    # reference so a concurrent _stop_mcp_loop() cannot null the global
    # between our check and the run_coroutine_threadsafe() call.
    with _lock:
        loop = _mcp_loop
    if loop is None or not loop.is_running():
        raise RuntimeError("MCP event loop is not running")
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    return future.result(timeout=timeout)
|
||||
|
||||
|
||||
|
|
@ -236,7 +248,8 @@ def _make_tool_handler(server_name: str, tool_name: str):
|
|||
"""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
server = _servers.get(server_name)
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
|
|
@ -272,7 +285,8 @@ def _make_check_fn(server_name: str):
|
|||
"""Return a check function that verifies the MCP connection is alive."""
|
||||
|
||||
def _check() -> bool:
|
||||
server = _servers.get(server_name)
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
return server is not None and server.session is not None
|
||||
|
||||
return _check
|
||||
|
|
@ -307,6 +321,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
    """Return tool names for all currently connected servers.

    Thread-safe: snapshots ``_servers`` under ``_lock`` before
    iterating, consistent with this module's locking rule. Both visible
    call sites invoke this helper outside the lock, so acquiring it
    here cannot deadlock (the lock is non-reentrant).
    """
    names: List[str] = []
    # Fix: _servers is mutated from the MCP background thread; iterating
    # it unlocked races with connect/shutdown. Copy under the lock.
    with _lock:
        servers_snapshot = list(_servers.items())
    for sname, server in servers_snapshot:
        for mcp_tool in server._tools:
            schema = _convert_mcp_schema(sname, mcp_tool)
            names.append(schema["name"])
    return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
|
|
@ -316,7 +340,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
|||
from toolsets import create_custom_toolset
|
||||
|
||||
server = await _connect_server(name, config)
|
||||
_servers[name] = server
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
|
|
@ -361,8 +386,8 @@ def discover_mcp_tools() -> List[str]:
|
|||
Called from ``model_tools._discover_tools()``. Safe to call even when
|
||||
the ``mcp`` package is not installed (returns empty list).
|
||||
|
||||
Idempotent: if servers are already connected, returns the existing
|
||||
tool names without creating duplicate connections.
|
||||
Idempotent for already-connected servers. If some servers failed on a
|
||||
previous call, only the missing ones are retried.
|
||||
|
||||
Returns:
|
||||
List of all registered MCP tool names.
|
||||
|
|
@ -371,27 +396,25 @@ def discover_mcp_tools() -> List[str]:
|
|||
logger.debug("MCP SDK not available -- skipping MCP tool discovery")
|
||||
return []
|
||||
|
||||
# Already connected -- return existing tool names (idempotent)
|
||||
if _servers:
|
||||
existing: List[str] = []
|
||||
for name, server in _servers.items():
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(name, mcp_tool)
|
||||
existing.append(schema["name"])
|
||||
return existing
|
||||
|
||||
servers = _load_mcp_config()
|
||||
if not servers:
|
||||
logger.debug("No MCP servers configured")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected
|
||||
with _lock:
|
||||
new_servers = {k: v for k, v in servers.items() if k not in _servers}
|
||||
|
||||
if not new_servers:
|
||||
return _existing_tool_names()
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
|
||||
async def _discover_all():
|
||||
for name, cfg in servers.items():
|
||||
for name, cfg in new_servers.items():
|
||||
try:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
all_tools.extend(registered)
|
||||
|
|
@ -401,17 +424,16 @@ def discover_mcp_tools() -> List[str]:
|
|||
_run_on_mcp_loop(_discover_all(), timeout=60)
|
||||
|
||||
if all_tools:
|
||||
# Add MCP tools to hermes-cli and other platform toolsets
|
||||
# Dynamically inject into all hermes-* platform toolsets
|
||||
from toolsets import TOOLSETS
|
||||
for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord",
|
||||
"hermes-whatsapp", "hermes-slack"):
|
||||
ts = TOOLSETS.get(ts_name)
|
||||
if ts:
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in all_tools:
|
||||
if tool_name not in ts["tools"]:
|
||||
ts["tools"].append(tool_name)
|
||||
|
||||
return all_tools
|
||||
# Return ALL registered tools (existing + newly discovered)
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
|
||||
|
|
@ -419,24 +441,39 @@ def shutdown_mcp_servers():
|
|||
|
||||
Each server Task is signalled to exit its ``async with`` block so that
|
||||
the anyio cancel-scope cleanup happens in the same Task that opened it.
|
||||
All servers are shut down in parallel via ``asyncio.gather``.
|
||||
"""
|
||||
global _mcp_loop, _mcp_thread
|
||||
with _lock:
|
||||
if not _servers:
|
||||
# No servers -- just stop the loop. _stop_mcp_loop() also
|
||||
# acquires _lock, so we must release it first.
|
||||
pass
|
||||
else:
|
||||
servers_snapshot = list(_servers.values())
|
||||
|
||||
# Fast path: nothing to shut down.
|
||||
if not _servers:
|
||||
_stop_mcp_loop()
|
||||
return
|
||||
|
||||
async def _shutdown():
|
||||
for name, server in list(_servers.items()):
|
||||
try:
|
||||
await server.shutdown()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing MCP server '%s': %s", name, exc)
|
||||
_servers.clear()
|
||||
results = await asyncio.gather(
|
||||
*(server.shutdown() for server in servers_snapshot),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for server, result in zip(servers_snapshot, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(
|
||||
"Error closing MCP server '%s': %s", server.name, result,
|
||||
)
|
||||
with _lock:
|
||||
_servers.clear()
|
||||
|
||||
if _mcp_loop is not None and _mcp_loop.is_running():
|
||||
with _lock:
|
||||
loop = _mcp_loop
|
||||
if loop is not None and loop.is_running():
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop)
|
||||
future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
|
||||
future.result(timeout=15)
|
||||
except Exception as exc:
|
||||
logger.debug("Error during MCP shutdown: %s", exc)
|
||||
|
|
@ -447,10 +484,13 @@ def shutdown_mcp_servers():
|
|||
def _stop_mcp_loop():
    """Stop the background event loop and join its thread.

    Snapshots and clears the globals under ``_lock`` first, then does
    the blocking stop/join/close on the local references so the lock is
    never held across a join (avoids the shutdown deadlock).
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        loop = _mcp_loop
        thread = _mcp_thread
        _mcp_loop = None
        _mcp_thread = None
    if loop is not None:
        loop.call_soon_threadsafe(loop.stop)
        if thread is not None:
            thread.join(timeout=5)
        # Fix: only close once the loop has actually stopped -- if the
        # join above timed out the loop may still be running, and
        # close() on a running loop raises RuntimeError.
        if not loop.is_running():
            loop.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue