fix(mcp): per-process PID isolation prevents cross-session crash on restart

- _stdio_pids: set → Dict[int,str] tracks pid→server_name
- SIGTERM-first with 2s grace before SIGKILL escalation
- hasattr guard for SIGKILL on platforms without it
- Updated tests for dict-based tracking and 3-phase kill sequence
This commit is contained in:
Jefferson 2026-04-23 20:01:17 +02:00 committed by Teknium
parent c7d023937c
commit 67c8f837fc
2 changed files with 51 additions and 23 deletions

View file

@ -967,7 +967,8 @@ class MCPServerTask:
new_pids = _snapshot_child_pids() - pids_before
if new_pids:
with _lock:
_stdio_pids.update(new_pids)
for _pid in new_pids:
_stdio_pids[_pid] = self.name
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize()
self.session = session
@ -980,7 +981,8 @@ class MCPServerTask:
# Context exited cleanly — subprocess was terminated by the SDK.
if new_pids:
with _lock:
_stdio_pids.difference_update(new_pids)
for _pid in new_pids:
_stdio_pids.pop(_pid, None)
async def _run_http(self, config: dict):
"""Run the server using HTTP/StreamableHTTP transport."""
@ -1484,7 +1486,7 @@ _lock = threading.Lock()
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
# fails or times out. PIDs are added after connection and removed on
# normal server shutdown.
_stdio_pids: set = set()
_stdio_pids: Dict[int, str] = {} # pid -> server_name
def _snapshot_child_pids() -> set:
@ -2618,27 +2620,44 @@ def shutdown_mcp_servers():
def _kill_orphaned_mcp_children() -> None:
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown.
"""Graceful shutdown of MCP stdio subprocesses that survived loop cleanup.
After the MCP event loop is stopped, stdio server subprocesses *should*
have been terminated by the SDK's context-manager cleanup. If the loop
was stuck or the shutdown timed out, orphaned children may remain.
Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL.
This prevents shared-resource collisions when multiple hermes processes
run on the same host (each has its own _stdio_pids dict).
Only kills PIDs tracked in ``_stdio_pids`` never arbitrary children.
"""
import signal as _signal
kill_signal = getattr(_signal, "SIGKILL", _signal.SIGTERM)
import time as _time
with _lock:
pids = list(_stdio_pids)
pids = dict(_stdio_pids)
_stdio_pids.clear()
for pid in pids:
# Phase 1: SIGTERM (graceful)
for pid, server_name in pids.items():
try:
os.kill(pid, kill_signal)
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
os.kill(pid, _signal.SIGTERM)
logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name)
except (ProcessLookupError, PermissionError, OSError):
pass # Already exited or inaccessible
pass
# Phase 2: Wait for graceful exit
_time.sleep(2)
# Phase 3: SIGKILL any survivors
_sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM)
for pid, server_name in pids.items():
try:
os.kill(pid, 0) # Check if still alive
os.kill(pid, _sigkill)
logger.warning(
"Force-killed MCP process %d (%s) after SIGTERM timeout",
pid, server_name,
)
except (ProcessLookupError, PermissionError, OSError):
pass # Good — exited after SIGTERM
def _stop_mcp_loop():