fix: resolve thread safety issues and shutdown deadlock in MCP client

- Add threading.Lock protecting all shared state (_servers, _mcp_loop, _mcp_thread)
- Fix deadlock in shutdown_mcp_servers: _stop_mcp_loop was called inside
  a _lock block but also acquires _lock (non-reentrant)
- Fix race condition in _ensure_mcp_loop with concurrent callers
- Change idempotency to per-server (retry failed servers, skip connected)
- Dynamic toolset injection via startswith("hermes-") instead of hardcoded list
- Parallel shutdown via asyncio.gather instead of sequential loop
- Add tests for partial failure retry, parallel shutdown, dynamic injection
This commit is contained in:
0xbyt4 2026-03-02 22:08:32 +03:00
parent 151e8d896c
commit 11a2ecb936
2 changed files with 184 additions and 58 deletions

View file

@ -459,14 +459,13 @@ class TestMCPServerTask:
# ---------------------------------------------------------------------------
class TestToolsetInjection:
def test_mcp_tools_added_to_platform_toolsets(self):
"""Discovered MCP tools are injected into hermes-cli and platform toolsets."""
def test_mcp_tools_added_to_all_hermes_toolsets(self):
"""Discovered MCP tools are dynamically injected into all hermes-* toolsets."""
from tools.mcp_tool import MCPServerTask
mock_tools = [_make_mcp_tool("list_files", "List files")]
mock_session = MagicMock()
# Fresh _servers dict to bypass idempotency guard
fresh_servers = {}
async def fake_connect(name, config):
@ -476,12 +475,12 @@ class TestToolsetInjection:
return server
fake_toolsets = {
"hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []},
"hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []},
}
fake_config = {
"fs": {"command": "npx", "args": []},
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
"hermes-gateway": {"tools": [], "description": "GW", "includes": []},
"non-hermes": {"tools": [], "description": "other", "includes": []},
}
fake_config = {"fs": {"command": "npx", "args": []}}
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._servers", fresh_servers), \
@ -492,8 +491,12 @@ class TestToolsetInjection:
result = discover_mcp_tools()
assert "mcp_fs_list_files" in result
# All hermes-* toolsets get injection
assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"]
assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"]
assert "mcp_fs_list_files" in fake_toolsets["hermes-gateway"]["tools"]
# Non-hermes toolset should NOT get injection
assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"]
# Original tools preserved
assert "terminal" in fake_toolsets["hermes-cli"]["tools"]
@ -504,7 +507,6 @@ class TestToolsetInjection:
mock_tools = [_make_mcp_tool("ping", "Ping")]
mock_session = MagicMock()
# Fresh _servers dict to bypass idempotency guard
fresh_servers = {}
call_count = 0
@ -534,10 +536,62 @@ class TestToolsetInjection:
from tools.mcp_tool import discover_mcp_tools
result = discover_mcp_tools()
# Only good server's tool registered
assert "mcp_good_ping" in result
assert "mcp_broken_ping" not in result
assert call_count == 2 # Both were attempted
assert call_count == 2
def test_partial_failure_retry_on_second_call(self):
"""Failed servers are retried on subsequent discover_mcp_tools() calls."""
from tools.mcp_tool import MCPServerTask
mock_tools = [_make_mcp_tool("ping", "Ping")]
mock_session = MagicMock()
# Use a real dict so idempotency logic works correctly
fresh_servers = {}
call_count = 0
broken_fixed = False
async def flaky_connect(name, config):
nonlocal call_count
call_count += 1
if name == "broken" and not broken_fixed:
raise ConnectionError("cannot reach server")
server = MCPServerTask(name)
server.session = mock_session
server._tools = mock_tools
return server
fake_config = {
"broken": {"command": "bad"},
"good": {"command": "npx", "args": []},
}
fake_toolsets = {
"hermes-cli": {"tools": [], "description": "CLI", "includes": []},
}
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._servers", fresh_servers), \
patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
patch("toolsets.TOOLSETS", fake_toolsets):
from tools.mcp_tool import discover_mcp_tools
# First call: good connects, broken fails
result1 = discover_mcp_tools()
assert "mcp_good_ping" in result1
assert "mcp_broken_ping" not in result1
first_attempts = call_count
# "Fix" the broken server
broken_fixed = True
call_count = 0
# Second call: should retry broken, skip good
result2 = discover_mcp_tools()
assert "mcp_good_ping" in result2
assert "mcp_broken_ping" in result2
assert call_count == 1 # Only broken retried
# ---------------------------------------------------------------------------
@ -581,6 +635,7 @@ class TestShutdown:
_servers.clear()
mock_server = MagicMock()
mock_server.name = "test"
mock_server.shutdown = AsyncMock()
_servers["test"] = mock_server
@ -601,6 +656,7 @@ class TestShutdown:
_servers.clear()
mock_server = MagicMock()
mock_server.name = "broken"
mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
_servers["broken"] = mock_server
@ -612,3 +668,33 @@ class TestShutdown:
mcp_mod._mcp_thread = None
assert len(_servers) == 0
def test_shutdown_is_parallel(self):
"""Multiple servers are shut down in parallel via asyncio.gather."""
import tools.mcp_tool as mcp_mod
from tools.mcp_tool import shutdown_mcp_servers, _servers
import time
_servers.clear()
# 3 servers each taking 1s to shut down
for i in range(3):
mock_server = MagicMock()
mock_server.name = f"srv_{i}"
async def slow_shutdown():
await asyncio.sleep(1)
mock_server.shutdown = slow_shutdown
_servers[f"srv_{i}"] = mock_server
mcp_mod._ensure_mcp_loop()
try:
start = time.monotonic()
shutdown_mcp_servers()
elapsed = time.monotonic() - start
finally:
mcp_mod._mcp_loop = None
mcp_mod._mcp_thread = None
assert len(_servers) == 0
# Parallel: ~1s, not ~3s. Allow some margin.
assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"

View file

@ -32,6 +32,12 @@ Architecture:
On shutdown, each server Task is signalled to exit its ``async with``
block, ensuring the anyio cancel-scope cleanup happens in the *same*
Task that opened the connection (required by anyio).
Thread safety:
_servers and _mcp_loop/_mcp_thread are accessed from both the MCP
background thread and caller threads. All mutations are protected by
_lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
free-threading).
"""
import asyncio
@ -161,26 +167,32 @@ _servers: Dict[str, MCPServerTask] = {}
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
_lock = threading.Lock()
def _ensure_mcp_loop():
"""Start the background event loop thread if not already running."""
global _mcp_loop, _mcp_thread
if _mcp_loop is not None and _mcp_loop.is_running():
return
_mcp_loop = asyncio.new_event_loop()
_mcp_thread = threading.Thread(
target=_mcp_loop.run_forever,
name="mcp-event-loop",
daemon=True,
)
_mcp_thread.start()
with _lock:
if _mcp_loop is not None and _mcp_loop.is_running():
return
_mcp_loop = asyncio.new_event_loop()
_mcp_thread = threading.Thread(
target=_mcp_loop.run_forever,
name="mcp-event-loop",
daemon=True,
)
_mcp_thread.start()
def _run_on_mcp_loop(coro, timeout: float = 30):
"""Schedule a coroutine on the MCP event loop and block until done."""
if _mcp_loop is None or not _mcp_loop.is_running():
with _lock:
loop = _mcp_loop
if loop is None or not loop.is_running():
raise RuntimeError("MCP event loop is not running")
future = asyncio.run_coroutine_threadsafe(coro, _mcp_loop)
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result(timeout=timeout)
@ -236,7 +248,8 @@ def _make_tool_handler(server_name: str, tool_name: str):
"""
def _handler(args: dict, **kwargs) -> str:
server = _servers.get(server_name)
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
@ -272,7 +285,8 @@ def _make_check_fn(server_name: str):
"""Return a check function that verifies the MCP connection is alive."""
def _check() -> bool:
server = _servers.get(server_name)
with _lock:
server = _servers.get(server_name)
return server is not None and server.session is not None
return _check
@ -307,6 +321,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
}
def _existing_tool_names() -> List[str]:
"""Return tool names for all currently connected servers."""
names: List[str] = []
for sname, server in _servers.items():
for mcp_tool in server._tools:
schema = _convert_mcp_schema(sname, mcp_tool)
names.append(schema["name"])
return names
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
"""Connect to a single MCP server, discover tools, and register them.
@ -316,7 +340,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
from toolsets import create_custom_toolset
server = await _connect_server(name, config)
_servers[name] = server
with _lock:
_servers[name] = server
registered_names: List[str] = []
toolset_name = f"mcp-{name}"
@ -361,8 +386,8 @@ def discover_mcp_tools() -> List[str]:
Called from ``model_tools._discover_tools()``. Safe to call even when
the ``mcp`` package is not installed (returns empty list).
Idempotent: if servers are already connected, returns the existing
tool names without creating duplicate connections.
Idempotent for already-connected servers. If some servers failed on a
previous call, only the missing ones are retried.
Returns:
List of all registered MCP tool names.
@ -371,27 +396,25 @@ def discover_mcp_tools() -> List[str]:
logger.debug("MCP SDK not available -- skipping MCP tool discovery")
return []
# Already connected -- return existing tool names (idempotent)
if _servers:
existing: List[str] = []
for name, server in _servers.items():
for mcp_tool in server._tools:
schema = _convert_mcp_schema(name, mcp_tool)
existing.append(schema["name"])
return existing
servers = _load_mcp_config()
if not servers:
logger.debug("No MCP servers configured")
return []
# Only attempt servers that aren't already connected
with _lock:
new_servers = {k: v for k, v in servers.items() if k not in _servers}
if not new_servers:
return _existing_tool_names()
# Start the background event loop for MCP connections
_ensure_mcp_loop()
all_tools: List[str] = []
async def _discover_all():
for name, cfg in servers.items():
for name, cfg in new_servers.items():
try:
registered = await _discover_and_register_server(name, cfg)
all_tools.extend(registered)
@ -401,17 +424,16 @@ def discover_mcp_tools() -> List[str]:
_run_on_mcp_loop(_discover_all(), timeout=60)
if all_tools:
# Add MCP tools to hermes-cli and other platform toolsets
# Dynamically inject into all hermes-* platform toolsets
from toolsets import TOOLSETS
for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord",
"hermes-whatsapp", "hermes-slack"):
ts = TOOLSETS.get(ts_name)
if ts:
for ts_name, ts in TOOLSETS.items():
if ts_name.startswith("hermes-"):
for tool_name in all_tools:
if tool_name not in ts["tools"]:
ts["tools"].append(tool_name)
return all_tools
# Return ALL registered tools (existing + newly discovered)
return _existing_tool_names()
def shutdown_mcp_servers():
@ -419,24 +441,39 @@ def shutdown_mcp_servers():
Each server Task is signalled to exit its ``async with`` block so that
the anyio cancel-scope cleanup happens in the same Task that opened it.
All servers are shut down in parallel via ``asyncio.gather``.
"""
global _mcp_loop, _mcp_thread
with _lock:
if not _servers:
# No servers -- just stop the loop. _stop_mcp_loop() also
# acquires _lock, so we must release it first.
pass
else:
servers_snapshot = list(_servers.values())
# Fast path: nothing to shut down.
if not _servers:
_stop_mcp_loop()
return
async def _shutdown():
for name, server in list(_servers.items()):
try:
await server.shutdown()
except Exception as exc:
logger.debug("Error closing MCP server '%s': %s", name, exc)
_servers.clear()
results = await asyncio.gather(
*(server.shutdown() for server in servers_snapshot),
return_exceptions=True,
)
for server, result in zip(servers_snapshot, results):
if isinstance(result, Exception):
logger.debug(
"Error closing MCP server '%s': %s", server.name, result,
)
with _lock:
_servers.clear()
if _mcp_loop is not None and _mcp_loop.is_running():
with _lock:
loop = _mcp_loop
if loop is not None and loop.is_running():
try:
future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop)
future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
future.result(timeout=15)
except Exception as exc:
logger.debug("Error during MCP shutdown: %s", exc)
@ -447,10 +484,13 @@ def shutdown_mcp_servers():
def _stop_mcp_loop():
"""Stop the background event loop and join its thread."""
global _mcp_loop, _mcp_thread
if _mcp_loop is not None:
_mcp_loop.call_soon_threadsafe(_mcp_loop.stop)
if _mcp_thread is not None:
_mcp_thread.join(timeout=5)
_mcp_thread = None
_mcp_loop.close()
with _lock:
loop = _mcp_loop
thread = _mcp_thread
_mcp_loop = None
_mcp_thread = None
if loop is not None:
loop.call_soon_threadsafe(loop.stop)
if thread is not None:
thread.join(timeout=5)
loop.close()