diff --git a/tests/tools/test_mcp_stability.py b/tests/tools/test_mcp_stability.py index e3827f0a5..7a500dad5 100644 --- a/tests/tools/test_mcp_stability.py +++ b/tests/tools/test_mcp_stability.py @@ -77,7 +77,7 @@ class TestStdioPidTracking: from tools.mcp_tool import _stdio_pids, _lock with _lock: # Might have residual state from other tests, just check type - assert isinstance(_stdio_pids, set) + assert isinstance(_stdio_pids, dict) def test_kill_orphaned_noop_when_empty(self): """_kill_orphaned_mcp_children does nothing when no PIDs tracked.""" @@ -96,7 +96,7 @@ class TestStdioPidTracking: # Use a PID that definitely doesn't exist fake_pid = 999999999 with _lock: - _stdio_pids.add(fake_pid) + _stdio_pids[fake_pid] = "test" # Should not raise (ProcessLookupError is caught) _kill_orphaned_mcp_children() @@ -105,40 +105,49 @@ class TestStdioPidTracking: assert fake_pid not in _stdio_pids def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch): - """Unix-like platforms should keep using SIGKILL for orphan cleanup.""" + """SIGTERM-first then SIGKILL after 2s for orphan cleanup.""" from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock fake_pid = 424242 with _lock: _stdio_pids.clear() - _stdio_pids.add(fake_pid) + _stdio_pids[fake_pid] = "test" fake_sigkill = 9 monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False) - with patch("tools.mcp_tool.os.kill") as mock_kill: + with patch("tools.mcp_tool.os.kill") as mock_kill, \ + patch("time.sleep") as mock_sleep: _kill_orphaned_mcp_children() - mock_kill.assert_called_once_with(fake_pid, fake_sigkill) + # SIGTERM, then alive-check (signal 0), then SIGKILL + mock_kill.assert_any_call(fake_pid, signal.SIGTERM) + mock_kill.assert_any_call(fake_pid, 0) # alive check + mock_kill.assert_any_call(fake_pid, fake_sigkill) + assert mock_kill.call_count == 3 + mock_sleep.assert_called_once_with(2) with _lock: assert fake_pid not in _stdio_pids def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch): - """Windows-like signal modules without SIGKILL should fall back to SIGTERM.""" + """Without SIGKILL, SIGTERM is used for both phases.""" from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock fake_pid = 434343 with _lock: _stdio_pids.clear() - _stdio_pids.add(fake_pid) + _stdio_pids[fake_pid] = "test" monkeypatch.delattr(signal, "SIGKILL", raising=False) - with patch("tools.mcp_tool.os.kill") as mock_kill: + with patch("tools.mcp_tool.os.kill") as mock_kill, \ + patch("time.sleep") as mock_sleep: _kill_orphaned_mcp_children() - mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM) + # SIGTERM phase, alive check raises (process gone), no escalation + mock_kill.assert_any_call(fake_pid, signal.SIGTERM) + assert mock_sleep.called with _lock: assert fake_pid not in _stdio_pids diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2de479338..efef5ea91 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -967,7 +967,8 @@ class MCPServerTask: new_pids = _snapshot_child_pids() - pids_before if new_pids: with _lock: - _stdio_pids.update(new_pids) + for _pid in new_pids: + _stdio_pids[_pid] = self.name async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session @@ -980,7 +981,8 @@ class MCPServerTask: # Context exited cleanly — subprocess was terminated by the SDK. if new_pids: with _lock: - _stdio_pids.difference_update(new_pids) + for _pid in new_pids: + _stdio_pids.pop(_pid, None) async def _run_http(self, config: dict): """Run the server using HTTP/StreamableHTTP transport.""" @@ -1484,7 +1486,7 @@ _lock = threading.Lock() # them on shutdown if the graceful cleanup (SDK context-manager teardown) # fails or times out. PIDs are added after connection and removed on # normal server shutdown. -_stdio_pids: set = set() +_stdio_pids: Dict[int, str] = {} # pid -> server_name def _snapshot_child_pids() -> set: @@ -2618,27 +2620,44 @@ def shutdown_mcp_servers(): def _kill_orphaned_mcp_children() -> None: - """Best-effort kill of MCP stdio subprocesses that survived loop shutdown. + """Graceful shutdown of MCP stdio subprocesses that survived loop cleanup. - After the MCP event loop is stopped, stdio server subprocesses *should* - have been terminated by the SDK's context-manager cleanup. If the loop - was stuck or the shutdown timed out, orphaned children may remain. + Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL. + This prevents shared-resource collisions when multiple hermes processes + run on the same host (each has its own _stdio_pids dict). Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children. """ import signal as _signal - kill_signal = getattr(_signal, "SIGKILL", _signal.SIGTERM) + import time as _time with _lock: - pids = list(_stdio_pids) + pids = dict(_stdio_pids) _stdio_pids.clear() - for pid in pids: + # Phase 1: SIGTERM (graceful) + for pid, server_name in pids.items(): try: - os.kill(pid, kill_signal) - logger.debug("Force-killed orphaned MCP stdio process %d", pid) + os.kill(pid, _signal.SIGTERM) + logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name) except (ProcessLookupError, PermissionError, OSError): - pass # Already exited or inaccessible + pass + + # Phase 2: Wait for graceful exit + _time.sleep(2) + + # Phase 3: SIGKILL any survivors + _sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM) + for pid, server_name in pids.items(): + try: + os.kill(pid, 0) # Check if still alive + os.kill(pid, _sigkill) + logger.warning( + "Force-killed MCP process %d (%s) after SIGTERM timeout", + pid, server_name, + ) + except (ProcessLookupError, PermissionError, OSError): + pass # Good — exited after SIGTERM def _stop_mcp_loop():