fix(mcp): stability fix pack — reload timeout, shutdown cleanup, event loop handler, OAuth non-blocking (#4757)

Four fixes for MCP server stability issues reported by community member
(terminal lockup, zombie processes, escape sequence pollution, startup hang):

1. MCP reload timeout guard (cli.py): _check_config_mcp_changes now runs
   _reload_mcp in a separate daemon thread with a 30s hard timeout. Previously,
   a hung MCP server could block the process_loop thread indefinitely, freezing
   the entire TUI (user can type but nothing happens, only Ctrl+D/Ctrl+\ work).

2. MCP stdio subprocess PID tracking (mcp_tool.py): Tracks child PIDs spawned
   by stdio_client via before/after snapshots of /proc children. On shutdown,
   _stop_mcp_loop force-kills any tracked PIDs that survived the SDK's graceful
   SIGTERM→SIGKILL cleanup. Prevents zombie MCP server processes from
   accumulating across sessions.

3. MCP event loop exception handler (mcp_tool.py): Installs
   _mcp_loop_exception_handler on the MCP background event loop — same pattern
   as the existing _suppress_closed_loop_errors on prompt_toolkit's loop.
   Suppresses benign 'Event loop is closed' RuntimeError from httpx transport
   __del__ during MCP shutdown. Salvaged from PR #2538 (acsezen).

4. MCP OAuth non-blocking (mcp_oauth.py): Replaces blocking input() call in
   _wait_for_callback with OAuthNonInteractiveError raise. Adds _is_interactive()
   TTY detection. In non-interactive environments, build_oauth_auth() still
   returns a provider (cached tokens + refresh work), but the callback handler
   raises immediately instead of blocking the MCP event loop for 120s. Re-raises
   OAuth setup failures in _run_http so failed servers are reported cleanly
   without blocking others. Salvaged from PRs #4521 (voidborne-d) and #4465
   (heathley).

Closes #2537, closes #4462
Related: #4128, #3436
This commit is contained in:
Teknium 2026-04-03 02:29:20 -07:00 committed by GitHub
parent f374ae4c61
commit cc54818d26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 431 additions and 12 deletions

View file

@ -842,13 +842,25 @@ class MCPServerTask:
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
sampling_kwargs["message_handler"] = self._make_message_handler()
# Snapshot child PIDs before spawning so we can track the new one.
pids_before = _snapshot_child_pids()
async with stdio_client(server_params) as (read_stream, write_stream):
# Capture the newly spawned subprocess PID for force-kill cleanup.
new_pids = _snapshot_child_pids() - pids_before
if new_pids:
with _lock:
_stdio_pids.update(new_pids)
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize()
self.session = session
await self._discover_tools()
self._ready.set()
await self._shutdown_event.wait()
# Context exited cleanly — subprocess was terminated by the SDK.
if new_pids:
with _lock:
_stdio_pids.difference_update(new_pids)
async def _run_http(self, config: dict):
"""Run the server using HTTP/StreamableHTTP transport."""
@ -863,7 +875,10 @@ class MCPServerTask:
headers = dict(config.get("headers") or {})
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK.
# If OAuth setup fails (e.g. non-interactive environment without
# cached tokens), re-raise so this server is reported as failed
# without blocking other MCP servers from connecting.
_oauth_auth = None
if self._auth_type == "oauth":
try:
@ -871,6 +886,7 @@ class MCPServerTask:
_oauth_auth = build_oauth_auth(self.name, url)
except Exception as exc:
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
raise
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
@ -1044,9 +1060,56 @@ _servers: Dict[str, MCPServerTask] = {}
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids.
_lock = threading.Lock()
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
# fails or times out. PIDs are added after connection and removed on
# normal server shutdown.
_stdio_pids: set = set()
def _snapshot_child_pids() -> set:
"""Return a set of current child process PIDs.
Uses /proc on Linux, falls back to psutil, then empty set.
Used by _run_stdio to identify the subprocess spawned by stdio_client.
"""
my_pid = os.getpid()
# Linux: read from /proc
try:
children_path = f"/proc/{my_pid}/task/{my_pid}/children"
with open(children_path) as f:
return {int(p) for p in f.read().split() if p.strip()}
except (FileNotFoundError, OSError, ValueError):
pass
# Fallback: psutil
try:
import psutil
return {c.pid for c in psutil.Process(my_pid).children()}
except Exception:
pass
return set()
def _mcp_loop_exception_handler(loop, context):
"""Suppress benign 'Event loop is closed' noise during shutdown.
When the MCP event loop is stopped and closed, httpx/httpcore async
transports may fire __del__ finalizers that call call_soon() on the
dead loop. asyncio catches that RuntimeError and routes it here.
We silence it because the connection is being torn down anyway; all
other exceptions are forwarded to the default handler.
"""
exc = context.get("exception")
if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc):
return # benign shutdown race — suppress
loop.default_exception_handler(context)
def _ensure_mcp_loop():
"""Start the background event loop thread if not already running."""
@ -1055,6 +1118,7 @@ def _ensure_mcp_loop():
if _mcp_loop is not None and _mcp_loop.is_running():
return
_mcp_loop = asyncio.new_event_loop()
_mcp_loop.set_exception_handler(_mcp_loop_exception_handler)
_mcp_thread = threading.Thread(
target=_mcp_loop.run_forever,
name="mcp-event-loop",
@ -2057,6 +2121,29 @@ def shutdown_mcp_servers():
_stop_mcp_loop()
def _kill_orphaned_mcp_children() -> None:
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown.
After the MCP event loop is stopped, stdio server subprocesses *should*
have been terminated by the SDK's context-manager cleanup. If the loop
was stuck or the shutdown timed out, orphaned children may remain.
Only kills PIDs tracked in ``_stdio_pids`` never arbitrary children.
"""
import signal as _signal
with _lock:
pids = list(_stdio_pids)
_stdio_pids.clear()
for pid in pids:
try:
os.kill(pid, _signal.SIGKILL)
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
except (ProcessLookupError, PermissionError, OSError):
pass # Already exited or inaccessible
def _stop_mcp_loop():
"""Stop the background event loop and join its thread."""
global _mcp_loop, _mcp_thread
@ -2069,4 +2156,10 @@ def _stop_mcp_loop():
loop.call_soon_threadsafe(loop.stop)
if thread is not None:
thread.join(timeout=5)
loop.close()
try:
loop.close()
except Exception:
pass
# After closing the loop, any stdio subprocesses that survived the
# graceful shutdown are now orphaned. Force-kill them.
_kill_orphaned_mcp_children()