diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 2e52272b32..065baf4a10 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -459,14 +459,13 @@ class TestMCPServerTask: # --------------------------------------------------------------------------- class TestToolsetInjection: - def test_mcp_tools_added_to_platform_toolsets(self): - """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" + def test_mcp_tools_added_to_all_hermes_toolsets(self): + """Discovered MCP tools are dynamically injected into all hermes-* toolsets.""" from tools.mcp_tool import MCPServerTask mock_tools = [_make_mcp_tool("list_files", "List files")] mock_session = MagicMock() - # Fresh _servers dict to bypass idempotency guard fresh_servers = {} async def fake_connect(name, config): @@ -476,12 +475,12 @@ class TestToolsetInjection: return server fake_toolsets = { - "hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []}, - "hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []}, - } - fake_config = { - "fs": {"command": "npx", "args": []}, + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + "hermes-gateway": {"tools": [], "description": "GW", "includes": []}, + "non-hermes": {"tools": [], "description": "other", "includes": []}, } + fake_config = {"fs": {"command": "npx", "args": []}} with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ patch("tools.mcp_tool._servers", fresh_servers), \ @@ -492,8 +491,12 @@ class TestToolsetInjection: result = discover_mcp_tools() assert "mcp_fs_list_files" in result + # All hermes-* toolsets get injection assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"] assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"] + assert "mcp_fs_list_files" in 
fake_toolsets["hermes-gateway"]["tools"] + # Non-hermes toolset should NOT get injection + assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"] # Original tools preserved assert "terminal" in fake_toolsets["hermes-cli"]["tools"] @@ -504,7 +507,6 @@ class TestToolsetInjection: mock_tools = [_make_mcp_tool("ping", "Ping")] mock_session = MagicMock() - # Fresh _servers dict to bypass idempotency guard fresh_servers = {} call_count = 0 @@ -534,10 +536,62 @@ class TestToolsetInjection: from tools.mcp_tool import discover_mcp_tools result = discover_mcp_tools() - # Only good server's tool registered assert "mcp_good_ping" in result assert "mcp_broken_ping" not in result - assert call_count == 2 # Both were attempted + assert call_count == 2 + + def test_partial_failure_retry_on_second_call(self): + """Failed servers are retried on subsequent discover_mcp_tools() calls.""" + from tools.mcp_tool import MCPServerTask + + mock_tools = [_make_mcp_tool("ping", "Ping")] + mock_session = MagicMock() + + # Use a real dict so idempotency logic works correctly + fresh_servers = {} + call_count = 0 + broken_fixed = False + + async def flaky_connect(name, config): + nonlocal call_count + call_count += 1 + if name == "broken" and not broken_fixed: + raise ConnectionError("cannot reach server") + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server + + fake_config = { + "broken": {"command": "bad"}, + "good": {"command": "npx", "args": []}, + } + fake_toolsets = { + "hermes-cli": {"tools": [], "description": "CLI", "includes": []}, + } + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", fresh_servers), \ + patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \ + patch("toolsets.TOOLSETS", fake_toolsets): + from tools.mcp_tool import discover_mcp_tools + + # First call: good connects, broken 
fails + result1 = discover_mcp_tools() + assert "mcp_good_ping" in result1 + assert "mcp_broken_ping" not in result1 + first_attempts = call_count + + # "Fix" the broken server + broken_fixed = True + call_count = 0 + + # Second call: should retry broken, skip good + result2 = discover_mcp_tools() + assert "mcp_good_ping" in result2 + assert "mcp_broken_ping" in result2 + assert call_count == 1 # Only broken retried # --------------------------------------------------------------------------- @@ -581,6 +635,7 @@ class TestShutdown: _servers.clear() mock_server = MagicMock() + mock_server.name = "test" mock_server.shutdown = AsyncMock() _servers["test"] = mock_server @@ -601,6 +656,7 @@ class TestShutdown: _servers.clear() mock_server = MagicMock() + mock_server.name = "broken" mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed")) _servers["broken"] = mock_server @@ -612,3 +668,33 @@ class TestShutdown: mcp_mod._mcp_thread = None assert len(_servers) == 0 + + def test_shutdown_is_parallel(self): + """Multiple servers are shut down in parallel via asyncio.gather.""" + import tools.mcp_tool as mcp_mod + from tools.mcp_tool import shutdown_mcp_servers, _servers + import time + + _servers.clear() + + # 3 servers each taking 1s to shut down + for i in range(3): + mock_server = MagicMock() + mock_server.name = f"srv_{i}" + async def slow_shutdown(): + await asyncio.sleep(1) + mock_server.shutdown = slow_shutdown + _servers[f"srv_{i}"] = mock_server + + mcp_mod._ensure_mcp_loop() + try: + start = time.monotonic() + shutdown_mcp_servers() + elapsed = time.monotonic() - start + finally: + mcp_mod._mcp_loop = None + mcp_mod._mcp_thread = None + + assert len(_servers) == 0 + # Parallel: ~1s, not ~3s. Allow some margin. 
+ assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)" diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 5cdce4a398..4ab55215b8 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -32,6 +32,12 @@ Architecture: On shutdown, each server Task is signalled to exit its ``async with`` block, ensuring the anyio cancel-scope cleanup happens in the *same* Task that opened the connection (required by anyio). + +Thread safety: + _servers and _mcp_loop/_mcp_thread are accessed from both the MCP + background thread and caller threads. All mutations are protected by + _lock so the code is safe regardless of GIL presence (e.g. Python 3.13+ + free-threading). """ import asyncio @@ -161,26 +167,32 @@ _servers: Dict[str, MCPServerTask] = {} _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None +# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access. +_lock = threading.Lock() + def _ensure_mcp_loop(): """Start the background event loop thread if not already running.""" global _mcp_loop, _mcp_thread - if _mcp_loop is not None and _mcp_loop.is_running(): - return - _mcp_loop = asyncio.new_event_loop() - _mcp_thread = threading.Thread( - target=_mcp_loop.run_forever, - name="mcp-event-loop", - daemon=True, - ) - _mcp_thread.start() + with _lock: + if _mcp_loop is not None and _mcp_loop.is_running(): + return + _mcp_loop = asyncio.new_event_loop() + _mcp_thread = threading.Thread( + target=_mcp_loop.run_forever, + name="mcp-event-loop", + daemon=True, + ) + _mcp_thread.start() def _run_on_mcp_loop(coro, timeout: float = 30): """Schedule a coroutine on the MCP event loop and block until done.""" - if _mcp_loop is None or not _mcp_loop.is_running(): + with _lock: + loop = _mcp_loop + if loop is None or not loop.is_running(): raise RuntimeError("MCP event loop is not running") - future = asyncio.run_coroutine_threadsafe(coro, _mcp_loop) + future = asyncio.run_coroutine_threadsafe(coro, 
loop) return future.result(timeout=timeout) @@ -236,7 +248,8 @@ def _make_tool_handler(server_name: str, tool_name: str): """ def _handler(args: dict, **kwargs) -> str: - server = _servers.get(server_name) + with _lock: + server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" @@ -272,7 +285,8 @@ def _make_check_fn(server_name: str): """Return a check function that verifies the MCP connection is alive.""" def _check() -> bool: - server = _servers.get(server_name) + with _lock: + server = _servers.get(server_name) return server is not None and server.session is not None return _check @@ -307,6 +321,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: } +def _existing_tool_names() -> List[str]: + """Return tool names for all currently connected servers.""" + names: List[str] = [] + for sname, server in _servers.items(): + for mcp_tool in server._tools: + schema = _convert_mcp_schema(sname, mcp_tool) + names.append(schema["name"]) + return names + + async def _discover_and_register_server(name: str, config: dict) -> List[str]: """Connect to a single MCP server, discover tools, and register them. @@ -316,7 +340,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: from toolsets import create_custom_toolset server = await _connect_server(name, config) - _servers[name] = server + with _lock: + _servers[name] = server registered_names: List[str] = [] toolset_name = f"mcp-{name}" @@ -361,8 +386,8 @@ def discover_mcp_tools() -> List[str]: Called from ``model_tools._discover_tools()``. Safe to call even when the ``mcp`` package is not installed (returns empty list). - Idempotent: if servers are already connected, returns the existing - tool names without creating duplicate connections. + Idempotent for already-connected servers. If some servers failed on a + previous call, only the missing ones are retried. 
Returns: List of all registered MCP tool names. @@ -371,27 +396,25 @@ def discover_mcp_tools() -> List[str]: logger.debug("MCP SDK not available -- skipping MCP tool discovery") return [] - # Already connected -- return existing tool names (idempotent) - if _servers: - existing: List[str] = [] - for name, server in _servers.items(): - for mcp_tool in server._tools: - schema = _convert_mcp_schema(name, mcp_tool) - existing.append(schema["name"]) - return existing - servers = _load_mcp_config() if not servers: logger.debug("No MCP servers configured") return [] + # Only attempt servers that aren't already connected + with _lock: + new_servers = {k: v for k, v in servers.items() if k not in _servers} + + if not new_servers: + return _existing_tool_names() + # Start the background event loop for MCP connections _ensure_mcp_loop() all_tools: List[str] = [] async def _discover_all(): - for name, cfg in servers.items(): + for name, cfg in new_servers.items(): try: registered = await _discover_and_register_server(name, cfg) all_tools.extend(registered) @@ -401,17 +424,16 @@ def discover_mcp_tools() -> List[str]: _run_on_mcp_loop(_discover_all(), timeout=60) if all_tools: - # Add MCP tools to hermes-cli and other platform toolsets + # Dynamically inject into all hermes-* platform toolsets from toolsets import TOOLSETS - for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord", - "hermes-whatsapp", "hermes-slack"): - ts = TOOLSETS.get(ts_name) - if ts: + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): for tool_name in all_tools: if tool_name not in ts["tools"]: ts["tools"].append(tool_name) - return all_tools + # Return ALL registered tools (existing + newly discovered) + return _existing_tool_names() def shutdown_mcp_servers(): @@ -419,24 +441,39 @@ def shutdown_mcp_servers(): Each server Task is signalled to exit its ``async with`` block so that the anyio cancel-scope cleanup happens in the same Task that opened it. 
+    All servers are shut down in parallel via ``asyncio.gather``.
     """
-    global _mcp_loop, _mcp_thread
+    # Snapshot the connected servers while holding the lock.  Taking the
+    # snapshot unconditionally guarantees `servers_snapshot` is always
+    # bound (an empty snapshot just means there is nothing to shut down)
+    # and avoids re-reading _servers without the lock.  _stop_mcp_loop()
+    # also acquires _lock, so the lock is released before that call.
+    with _lock:
+        servers_snapshot = list(_servers.values())
+
+    # Fast path: nothing to shut down.
-    if not _servers:
+    if not servers_snapshot:
         _stop_mcp_loop()
         return
 
     async def _shutdown():
-        for name, server in list(_servers.items()):
-            try:
-                await server.shutdown()
-            except Exception as exc:
-                logger.debug("Error closing MCP server '%s': %s", name, exc)
-        _servers.clear()
+        results = await asyncio.gather(
+            *(server.shutdown() for server in servers_snapshot),
+            return_exceptions=True,
+        )
+        for server, result in zip(servers_snapshot, results):
+            if isinstance(result, Exception):
+                logger.debug(
+                    "Error closing MCP server '%s': %s", server.name, result,
+                )
+        with _lock:
+            _servers.clear()
 
-    if _mcp_loop is not None and _mcp_loop.is_running():
+    with _lock:
+        loop = _mcp_loop
+    if loop is not None and loop.is_running():
         try:
-            future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop)
+            future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
             future.result(timeout=15)
         except Exception as exc:
             logger.debug("Error during MCP shutdown: %s", exc)
@@ -447,10 +484,15 @@ def shutdown_mcp_servers():
 def _stop_mcp_loop():
     """Stop the background event loop and join its thread."""
     global _mcp_loop, _mcp_thread
-    if _mcp_loop is not None:
-        _mcp_loop.call_soon_threadsafe(_mcp_loop.stop)
-    if _mcp_thread is not None:
-        _mcp_thread.join(timeout=5)
-        _mcp_thread = None
-        _mcp_loop.close()
-    _mcp_loop = None
+    # Clear the module globals under the lock, then tear down outside it
+    # (thread.join would otherwise block other threads on _lock).
+    with _lock:
+        loop = _mcp_loop
+        thread = _mcp_thread
+        _mcp_loop = None
+        _mcp_thread = None
+    if loop is not None:
+        loop.call_soon_threadsafe(loop.stop)
+        if thread is not None:
+            thread.join(timeout=5)
+        loop.close()