fix(mcp): per-process PID isolation prevents cross-session crash on restart

- _stdio_pids: set → Dict[int,str] tracks pid→server_name
- SIGTERM-first with 2s grace before SIGKILL escalation
- getattr fallback to SIGTERM for the escalation signal on platforms without SIGKILL
- Updated tests for dict-based tracking and 3-phase kill sequence
This commit is contained in:
Jefferson 2026-04-23 20:01:17 +02:00 committed by Teknium
parent c7d023937c
commit 67c8f837fc
2 changed files with 51 additions and 23 deletions

View file

@ -77,7 +77,7 @@ class TestStdioPidTracking:
from tools.mcp_tool import _stdio_pids, _lock from tools.mcp_tool import _stdio_pids, _lock
with _lock: with _lock:
# Might have residual state from other tests, just check type # Might have residual state from other tests, just check type
assert isinstance(_stdio_pids, set) assert isinstance(_stdio_pids, dict)
def test_kill_orphaned_noop_when_empty(self): def test_kill_orphaned_noop_when_empty(self):
"""_kill_orphaned_mcp_children does nothing when no PIDs tracked.""" """_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
@ -96,7 +96,7 @@ class TestStdioPidTracking:
# Use a PID that definitely doesn't exist # Use a PID that definitely doesn't exist
fake_pid = 999999999 fake_pid = 999999999
with _lock: with _lock:
_stdio_pids.add(fake_pid) _stdio_pids[fake_pid] = "test"
# Should not raise (ProcessLookupError is caught) # Should not raise (ProcessLookupError is caught)
_kill_orphaned_mcp_children() _kill_orphaned_mcp_children()
@ -105,40 +105,49 @@ class TestStdioPidTracking:
assert fake_pid not in _stdio_pids assert fake_pid not in _stdio_pids
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch): def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
"""Unix-like platforms should keep using SIGKILL for orphan cleanup.""" """SIGTERM-first then SIGKILL after 2s for orphan cleanup."""
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
fake_pid = 424242 fake_pid = 424242
with _lock: with _lock:
_stdio_pids.clear() _stdio_pids.clear()
_stdio_pids.add(fake_pid) _stdio_pids[fake_pid] = "test"
fake_sigkill = 9 fake_sigkill = 9
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False) monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
with patch("tools.mcp_tool.os.kill") as mock_kill: with patch("tools.mcp_tool.os.kill") as mock_kill, \
patch("time.sleep") as mock_sleep:
_kill_orphaned_mcp_children() _kill_orphaned_mcp_children()
mock_kill.assert_called_once_with(fake_pid, fake_sigkill) # SIGTERM, then alive-check (signal 0), then SIGKILL
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
mock_kill.assert_any_call(fake_pid, 0) # alive check
mock_kill.assert_any_call(fake_pid, fake_sigkill)
assert mock_kill.call_count == 3
mock_sleep.assert_called_once_with(2)
with _lock: with _lock:
assert fake_pid not in _stdio_pids assert fake_pid not in _stdio_pids
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch): def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
"""Windows-like signal modules without SIGKILL should fall back to SIGTERM.""" """Without SIGKILL, SIGTERM is used for both phases."""
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
fake_pid = 434343 fake_pid = 434343
with _lock: with _lock:
_stdio_pids.clear() _stdio_pids.clear()
_stdio_pids.add(fake_pid) _stdio_pids[fake_pid] = "test"
monkeypatch.delattr(signal, "SIGKILL", raising=False) monkeypatch.delattr(signal, "SIGKILL", raising=False)
with patch("tools.mcp_tool.os.kill") as mock_kill: with patch("tools.mcp_tool.os.kill") as mock_kill, \
patch("time.sleep") as mock_sleep:
_kill_orphaned_mcp_children() _kill_orphaned_mcp_children()
mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM) # os.kill is mocked and never raises, so the alive check passes and phase 3 re-sends SIGTERM
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
assert mock_sleep.called
with _lock: with _lock:
assert fake_pid not in _stdio_pids assert fake_pid not in _stdio_pids

View file

@ -967,7 +967,8 @@ class MCPServerTask:
new_pids = _snapshot_child_pids() - pids_before new_pids = _snapshot_child_pids() - pids_before
if new_pids: if new_pids:
with _lock: with _lock:
_stdio_pids.update(new_pids) for _pid in new_pids:
_stdio_pids[_pid] = self.name
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize() await session.initialize()
self.session = session self.session = session
@ -980,7 +981,8 @@ class MCPServerTask:
# Context exited cleanly — subprocess was terminated by the SDK. # Context exited cleanly — subprocess was terminated by the SDK.
if new_pids: if new_pids:
with _lock: with _lock:
_stdio_pids.difference_update(new_pids) for _pid in new_pids:
_stdio_pids.pop(_pid, None)
async def _run_http(self, config: dict): async def _run_http(self, config: dict):
"""Run the server using HTTP/StreamableHTTP transport.""" """Run the server using HTTP/StreamableHTTP transport."""
@ -1484,7 +1486,7 @@ _lock = threading.Lock()
# them on shutdown if the graceful cleanup (SDK context-manager teardown) # them on shutdown if the graceful cleanup (SDK context-manager teardown)
# fails or times out. PIDs are added after connection and removed on # fails or times out. PIDs are added after connection and removed on
# normal server shutdown. # normal server shutdown.
_stdio_pids: set = set() _stdio_pids: Dict[int, str] = {} # pid -> server_name
def _snapshot_child_pids() -> set: def _snapshot_child_pids() -> set:
@ -2618,27 +2620,44 @@ def shutdown_mcp_servers():
def _kill_orphaned_mcp_children() -> None: def _kill_orphaned_mcp_children() -> None:
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown. """Graceful shutdown of MCP stdio subprocesses that survived loop cleanup.
After the MCP event loop is stopped, stdio server subprocesses *should* Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL.
have been terminated by the SDK's context-manager cleanup. If the loop This prevents shared-resource collisions when multiple hermes processes
was stuck or the shutdown timed out, orphaned children may remain. run on the same host (each has its own _stdio_pids dict).
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children. Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
""" """
import signal as _signal import signal as _signal
kill_signal = getattr(_signal, "SIGKILL", _signal.SIGTERM) import time as _time
with _lock: with _lock:
pids = list(_stdio_pids) pids = dict(_stdio_pids)
_stdio_pids.clear() _stdio_pids.clear()
for pid in pids: # Phase 1: SIGTERM (graceful)
for pid, server_name in pids.items():
try: try:
os.kill(pid, kill_signal) os.kill(pid, _signal.SIGTERM)
logger.debug("Force-killed orphaned MCP stdio process %d", pid) logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name)
except (ProcessLookupError, PermissionError, OSError): except (ProcessLookupError, PermissionError, OSError):
pass # Already exited or inaccessible pass
# Phase 2: Wait for graceful exit
_time.sleep(2)
# Phase 3: SIGKILL any survivors
_sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM)
for pid, server_name in pids.items():
try:
os.kill(pid, 0) # Check if still alive
os.kill(pid, _sigkill)
logger.warning(
"Force-killed MCP process %d (%s) after SIGTERM timeout",
pid, server_name,
)
except (ProcessLookupError, PermissionError, OSError):
pass # Good — exited after SIGTERM
def _stop_mcp_loop(): def _stop_mcp_loop():