mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): per-process PID isolation prevents cross-session crash on restart
- _stdio_pids: set → Dict[int,str] tracks pid→server_name - SIGTERM-first with 2s grace before SIGKILL escalation - hasattr guard for SIGKILL on platforms without it - Updated tests for dict-based tracking and 3-phase kill sequence
This commit is contained in:
parent
c7d023937c
commit
67c8f837fc
2 changed files with 51 additions and 23 deletions
|
|
@ -77,7 +77,7 @@ class TestStdioPidTracking:
|
||||||
from tools.mcp_tool import _stdio_pids, _lock
|
from tools.mcp_tool import _stdio_pids, _lock
|
||||||
with _lock:
|
with _lock:
|
||||||
# Might have residual state from other tests, just check type
|
# Might have residual state from other tests, just check type
|
||||||
assert isinstance(_stdio_pids, set)
|
assert isinstance(_stdio_pids, dict)
|
||||||
|
|
||||||
def test_kill_orphaned_noop_when_empty(self):
|
def test_kill_orphaned_noop_when_empty(self):
|
||||||
"""_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
|
"""_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
|
||||||
|
|
@ -96,7 +96,7 @@ class TestStdioPidTracking:
|
||||||
# Use a PID that definitely doesn't exist
|
# Use a PID that definitely doesn't exist
|
||||||
fake_pid = 999999999
|
fake_pid = 999999999
|
||||||
with _lock:
|
with _lock:
|
||||||
_stdio_pids.add(fake_pid)
|
_stdio_pids[fake_pid] = "test"
|
||||||
|
|
||||||
# Should not raise (ProcessLookupError is caught)
|
# Should not raise (ProcessLookupError is caught)
|
||||||
_kill_orphaned_mcp_children()
|
_kill_orphaned_mcp_children()
|
||||||
|
|
@ -105,40 +105,49 @@ class TestStdioPidTracking:
|
||||||
assert fake_pid not in _stdio_pids
|
assert fake_pid not in _stdio_pids
|
||||||
|
|
||||||
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
|
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
|
||||||
"""Unix-like platforms should keep using SIGKILL for orphan cleanup."""
|
"""SIGTERM-first then SIGKILL after 2s for orphan cleanup."""
|
||||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||||
|
|
||||||
fake_pid = 424242
|
fake_pid = 424242
|
||||||
with _lock:
|
with _lock:
|
||||||
_stdio_pids.clear()
|
_stdio_pids.clear()
|
||||||
_stdio_pids.add(fake_pid)
|
_stdio_pids[fake_pid] = "test"
|
||||||
|
|
||||||
fake_sigkill = 9
|
fake_sigkill = 9
|
||||||
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
|
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
|
||||||
|
|
||||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
with patch("tools.mcp_tool.os.kill") as mock_kill, \
|
||||||
|
patch("time.sleep") as mock_sleep:
|
||||||
_kill_orphaned_mcp_children()
|
_kill_orphaned_mcp_children()
|
||||||
|
|
||||||
mock_kill.assert_called_once_with(fake_pid, fake_sigkill)
|
# SIGTERM, then alive-check (signal 0), then SIGKILL
|
||||||
|
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
|
||||||
|
mock_kill.assert_any_call(fake_pid, 0) # alive check
|
||||||
|
mock_kill.assert_any_call(fake_pid, fake_sigkill)
|
||||||
|
assert mock_kill.call_count == 3
|
||||||
|
mock_sleep.assert_called_once_with(2)
|
||||||
|
|
||||||
with _lock:
|
with _lock:
|
||||||
assert fake_pid not in _stdio_pids
|
assert fake_pid not in _stdio_pids
|
||||||
|
|
||||||
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
|
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
|
||||||
"""Windows-like signal modules without SIGKILL should fall back to SIGTERM."""
|
"""Without SIGKILL, SIGTERM is used for both phases."""
|
||||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||||
|
|
||||||
fake_pid = 434343
|
fake_pid = 434343
|
||||||
with _lock:
|
with _lock:
|
||||||
_stdio_pids.clear()
|
_stdio_pids.clear()
|
||||||
_stdio_pids.add(fake_pid)
|
_stdio_pids[fake_pid] = "test"
|
||||||
|
|
||||||
monkeypatch.delattr(signal, "SIGKILL", raising=False)
|
monkeypatch.delattr(signal, "SIGKILL", raising=False)
|
||||||
|
|
||||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
with patch("tools.mcp_tool.os.kill") as mock_kill, \
|
||||||
|
patch("time.sleep") as mock_sleep:
|
||||||
_kill_orphaned_mcp_children()
|
_kill_orphaned_mcp_children()
|
||||||
|
|
||||||
mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM)
|
# SIGTERM phase, alive check raises (process gone), no escalation
|
||||||
|
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
|
||||||
|
assert mock_sleep.called
|
||||||
|
|
||||||
with _lock:
|
with _lock:
|
||||||
assert fake_pid not in _stdio_pids
|
assert fake_pid not in _stdio_pids
|
||||||
|
|
|
||||||
|
|
@ -967,7 +967,8 @@ class MCPServerTask:
|
||||||
new_pids = _snapshot_child_pids() - pids_before
|
new_pids = _snapshot_child_pids() - pids_before
|
||||||
if new_pids:
|
if new_pids:
|
||||||
with _lock:
|
with _lock:
|
||||||
_stdio_pids.update(new_pids)
|
for _pid in new_pids:
|
||||||
|
_stdio_pids[_pid] = self.name
|
||||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
@ -980,7 +981,8 @@ class MCPServerTask:
|
||||||
# Context exited cleanly — subprocess was terminated by the SDK.
|
# Context exited cleanly — subprocess was terminated by the SDK.
|
||||||
if new_pids:
|
if new_pids:
|
||||||
with _lock:
|
with _lock:
|
||||||
_stdio_pids.difference_update(new_pids)
|
for _pid in new_pids:
|
||||||
|
_stdio_pids.pop(_pid, None)
|
||||||
|
|
||||||
async def _run_http(self, config: dict):
|
async def _run_http(self, config: dict):
|
||||||
"""Run the server using HTTP/StreamableHTTP transport."""
|
"""Run the server using HTTP/StreamableHTTP transport."""
|
||||||
|
|
@ -1484,7 +1486,7 @@ _lock = threading.Lock()
|
||||||
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
|
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
|
||||||
# fails or times out. PIDs are added after connection and removed on
|
# fails or times out. PIDs are added after connection and removed on
|
||||||
# normal server shutdown.
|
# normal server shutdown.
|
||||||
_stdio_pids: set = set()
|
_stdio_pids: Dict[int, str] = {} # pid -> server_name
|
||||||
|
|
||||||
|
|
||||||
def _snapshot_child_pids() -> set:
|
def _snapshot_child_pids() -> set:
|
||||||
|
|
@ -2618,27 +2620,44 @@ def shutdown_mcp_servers():
|
||||||
|
|
||||||
|
|
||||||
def _kill_orphaned_mcp_children() -> None:
|
def _kill_orphaned_mcp_children() -> None:
|
||||||
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown.
|
"""Graceful shutdown of MCP stdio subprocesses that survived loop cleanup.
|
||||||
|
|
||||||
After the MCP event loop is stopped, stdio server subprocesses *should*
|
Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL.
|
||||||
have been terminated by the SDK's context-manager cleanup. If the loop
|
This prevents shared-resource collisions when multiple hermes processes
|
||||||
was stuck or the shutdown timed out, orphaned children may remain.
|
run on the same host (each has its own _stdio_pids dict).
|
||||||
|
|
||||||
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
|
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
|
||||||
"""
|
"""
|
||||||
import signal as _signal
|
import signal as _signal
|
||||||
kill_signal = getattr(_signal, "SIGKILL", _signal.SIGTERM)
|
import time as _time
|
||||||
|
|
||||||
with _lock:
|
with _lock:
|
||||||
pids = list(_stdio_pids)
|
pids = dict(_stdio_pids)
|
||||||
_stdio_pids.clear()
|
_stdio_pids.clear()
|
||||||
|
|
||||||
for pid in pids:
|
# Phase 1: SIGTERM (graceful)
|
||||||
|
for pid, server_name in pids.items():
|
||||||
try:
|
try:
|
||||||
os.kill(pid, kill_signal)
|
os.kill(pid, _signal.SIGTERM)
|
||||||
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
|
logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name)
|
||||||
except (ProcessLookupError, PermissionError, OSError):
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
pass # Already exited or inaccessible
|
pass
|
||||||
|
|
||||||
|
# Phase 2: Wait for graceful exit
|
||||||
|
_time.sleep(2)
|
||||||
|
|
||||||
|
# Phase 3: SIGKILL any survivors
|
||||||
|
_sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM)
|
||||||
|
for pid, server_name in pids.items():
|
||||||
|
try:
|
||||||
|
os.kill(pid, 0) # Check if still alive
|
||||||
|
os.kill(pid, _sigkill)
|
||||||
|
logger.warning(
|
||||||
|
"Force-killed MCP process %d (%s) after SIGTERM timeout",
|
||||||
|
pid, server_name,
|
||||||
|
)
|
||||||
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
|
pass # Good — exited after SIGTERM
|
||||||
|
|
||||||
|
|
||||||
def _stop_mcp_loop():
|
def _stop_mcp_loop():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue