fix: resolve thread safety issues and shutdown deadlock in MCP client

- Add threading.Lock protecting all shared state (_servers, _mcp_loop, _mcp_thread)
- Fix deadlock in shutdown_mcp_servers: _stop_mcp_loop was called inside
  a _lock block but also acquires _lock (non-reentrant)
- Fix race condition in _ensure_mcp_loop with concurrent callers
- Change idempotency to per-server (retry failed servers, skip connected)
- Dynamic toolset injection via startswith("hermes-") instead of hardcoded list
- Parallel shutdown via asyncio.gather instead of sequential loop
- Add tests for partial failure retry, parallel shutdown, dynamic injection
This commit is contained in:
0xbyt4 2026-03-02 22:08:32 +03:00
parent 151e8d896c
commit 11a2ecb936
2 changed files with 184 additions and 58 deletions

View file

@ -459,14 +459,13 @@ class TestMCPServerTask:
# ---------------------------------------------------------------------------
class TestToolsetInjection:
def test_mcp_tools_added_to_platform_toolsets(self):
"""Discovered MCP tools are injected into hermes-cli and platform toolsets."""
def test_mcp_tools_added_to_all_hermes_toolsets(self):
"""Discovered MCP tools are dynamically injected into all hermes-* toolsets."""
from tools.mcp_tool import MCPServerTask
mock_tools = [_make_mcp_tool("list_files", "List files")]
mock_session = MagicMock()
# Fresh _servers dict to bypass idempotency guard
fresh_servers = {}
async def fake_connect(name, config):
@ -476,12 +475,12 @@ class TestToolsetInjection:
return server
fake_toolsets = {
"hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []},
"hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []},
}
fake_config = {
"fs": {"command": "npx", "args": []},
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
"hermes-gateway": {"tools": [], "description": "GW", "includes": []},
"non-hermes": {"tools": [], "description": "other", "includes": []},
}
fake_config = {"fs": {"command": "npx", "args": []}}
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
patch("tools.mcp_tool._servers", fresh_servers), \
@ -492,8 +491,12 @@ class TestToolsetInjection:
result = discover_mcp_tools()
assert "mcp_fs_list_files" in result
# All hermes-* toolsets get injection
assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"]
assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"]
assert "mcp_fs_list_files" in fake_toolsets["hermes-gateway"]["tools"]
# Non-hermes toolset should NOT get injection
assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"]
# Original tools preserved
assert "terminal" in fake_toolsets["hermes-cli"]["tools"]
@ -504,7 +507,6 @@ class TestToolsetInjection:
mock_tools = [_make_mcp_tool("ping", "Ping")]
mock_session = MagicMock()
# Fresh _servers dict to bypass idempotency guard
fresh_servers = {}
call_count = 0
@ -534,10 +536,62 @@ class TestToolsetInjection:
from tools.mcp_tool import discover_mcp_tools
result = discover_mcp_tools()
# Only good server's tool registered
assert "mcp_good_ping" in result
assert "mcp_broken_ping" not in result
assert call_count == 2 # Both were attempted
assert call_count == 2
def test_partial_failure_retry_on_second_call(self):
    """Failed servers are retried on subsequent discover_mcp_tools() calls.

    First discovery: "good" connects, "broken" raises. After the fault is
    cleared, a second discovery must retry only the failed server and skip
    the already-connected one (per-server idempotency).
    """
    from tools.mcp_tool import MCPServerTask
    mock_tools = [_make_mcp_tool("ping", "Ping")]
    mock_session = MagicMock()
    # Use a real dict so idempotency logic works correctly
    fresh_servers = {}
    call_count = 0
    broken_fixed = False

    async def flaky_connect(name, config):
        # Simulates a server that is unreachable until the test flips
        # broken_fixed to True; every attempt is counted.
        nonlocal call_count
        call_count += 1
        if name == "broken" and not broken_fixed:
            raise ConnectionError("cannot reach server")
        server = MCPServerTask(name)
        server.session = mock_session
        server._tools = mock_tools
        return server

    fake_config = {
        "broken": {"command": "bad"},
        "good": {"command": "npx", "args": []},
    }
    fake_toolsets = {
        "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
    }
    with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
            patch("tools.mcp_tool._servers", fresh_servers), \
            patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
            patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
            patch("toolsets.TOOLSETS", fake_toolsets):
        from tools.mcp_tool import discover_mcp_tools
        # First call: good connects, broken fails
        result1 = discover_mcp_tools()
        assert "mcp_good_ping" in result1
        assert "mcp_broken_ping" not in result1
        # Was a dead store before: pin it down — both servers attempted once.
        first_attempts = call_count
        assert first_attempts == 2
        # "Fix" the broken server
        broken_fixed = True
        call_count = 0
        # Second call: should retry broken, skip good
        result2 = discover_mcp_tools()
        assert "mcp_good_ping" in result2
        assert "mcp_broken_ping" in result2
        assert call_count == 1  # Only broken retried
# ---------------------------------------------------------------------------
@ -581,6 +635,7 @@ class TestShutdown:
_servers.clear()
mock_server = MagicMock()
mock_server.name = "test"
mock_server.shutdown = AsyncMock()
_servers["test"] = mock_server
@ -601,6 +656,7 @@ class TestShutdown:
_servers.clear()
mock_server = MagicMock()
mock_server.name = "broken"
mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
_servers["broken"] = mock_server
@ -612,3 +668,33 @@ class TestShutdown:
mcp_mod._mcp_thread = None
assert len(_servers) == 0
def test_shutdown_is_parallel(self):
    """Multiple servers are shut down in parallel via asyncio.gather."""
    import tools.mcp_tool as mcp_mod
    from tools.mcp_tool import shutdown_mcp_servers, _servers
    import time
    # Start from a clean registry so only our mock servers are shut down.
    _servers.clear()
    # 3 servers each taking 1s to shut down
    for i in range(3):
        mock_server = MagicMock()
        mock_server.name = f"srv_{i}"
        # Each iteration binds a fresh coroutine function; it takes no
        # arguments, so assigning it as an instance attribute works when
        # the module awaits server.shutdown().
        async def slow_shutdown():
            await asyncio.sleep(1)
        mock_server.shutdown = slow_shutdown
        _servers[f"srv_{i}"] = mock_server
    # Presumably spins up the background event-loop thread the module
    # schedules shutdowns onto — TODO confirm against tools.mcp_tool.
    mcp_mod._ensure_mcp_loop()
    try:
        start = time.monotonic()
        shutdown_mcp_servers()
        elapsed = time.monotonic() - start
    finally:
        # Reset module globals even on failure so later tests get a
        # fresh loop/thread instead of the one created above.
        mcp_mod._mcp_loop = None
        mcp_mod._mcp_thread = None
    assert len(_servers) == 0
    # Parallel: ~1s, not ~3s. Allow some margin.
    assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"