mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): stability fix pack — reload timeout, shutdown cleanup, event loop handler, OAuth non-blocking (#4757)
Four fixes for MCP server stability issues reported by community member (terminal lockup, zombie processes, escape sequence pollution, startup hang): 1. MCP reload timeout guard (cli.py): _check_config_mcp_changes now runs _reload_mcp in a separate daemon thread with a 30s hard timeout. Previously, a hung MCP server could block the process_loop thread indefinitely, freezing the entire TUI (user can type but nothing happens, only Ctrl+D/Ctrl+\ work). 2. MCP stdio subprocess PID tracking (mcp_tool.py): Tracks child PIDs spawned by stdio_client via before/after snapshots of /proc children. On shutdown, _stop_mcp_loop force-kills any tracked PIDs that survived the SDK's graceful SIGTERM→SIGKILL cleanup. Prevents zombie MCP server processes from accumulating across sessions. 3. MCP event loop exception handler (mcp_tool.py): Installs _mcp_loop_exception_handler on the MCP background event loop — same pattern as the existing _suppress_closed_loop_errors on prompt_toolkit's loop. Suppresses benign 'Event loop is closed' RuntimeError from httpx transport __del__ during MCP shutdown. Salvaged from PR #2538 (acsezen). 4. MCP OAuth non-blocking (mcp_oauth.py): Replaces blocking input() call in _wait_for_callback with OAuthNonInteractiveError raise. Adds _is_interactive() TTY detection. In non-interactive environments, build_oauth_auth() still returns a provider (cached tokens + refresh work), but the callback handler raises immediately instead of blocking the MCP event loop for 120s. Re-raises OAuth setup failures in _run_http so failed servers are reported cleanly without blocking others. Salvaged from PRs #4521 (voidborne-d) and #4465 (heathley). Closes #2537, closes #4462 Related: #4128, #3436
This commit is contained in:
parent
f374ae4c61
commit
cc54818d26
5 changed files with 431 additions and 12 deletions
13
cli.py
13
cli.py
|
|
@ -5016,11 +5016,18 @@ class HermesCLI:
|
|||
return # mcp_servers unchanged (some other section was edited)
|
||||
|
||||
self._config_mcp_servers = new_mcp
|
||||
# Notify user and reload
|
||||
# Notify user and reload. Run in a separate thread with a hard
|
||||
# timeout so a hung MCP server cannot block the process_loop
|
||||
# indefinitely (which would freeze the entire TUI).
|
||||
print()
|
||||
print("🔄 MCP server config changed — reloading connections...")
|
||||
with self._busy_command(self._slow_command_status("/reload-mcp")):
|
||||
self._reload_mcp()
|
||||
_reload_thread = threading.Thread(
|
||||
target=self._reload_mcp, daemon=True
|
||||
)
|
||||
_reload_thread.start()
|
||||
_reload_thread.join(timeout=30)
|
||||
if _reload_thread.is_alive():
|
||||
print(" ⚠️ MCP reload timed out (30s). Some servers may not have reconnected.")
|
||||
|
||||
def _reload_mcp(self):
|
||||
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
|
||||
|
|
|
|||
|
|
@ -9,10 +9,13 @@ import pytest
|
|||
|
||||
from tools.mcp_oauth import (
|
||||
HermesTokenStorage,
|
||||
OAuthNonInteractiveError,
|
||||
build_oauth_auth,
|
||||
remove_oauth_tokens,
|
||||
_find_free_port,
|
||||
_can_open_browser,
|
||||
_is_interactive,
|
||||
_wait_for_callback,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -236,3 +239,99 @@ class TestRemoveOAuthTokens:
|
|||
def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
remove_oauth_tokens("nonexistent") # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-interactive / startup-safety tests (issue #4462)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsInteractive:
|
||||
"""_is_interactive() detects headless/daemon/container environments."""
|
||||
|
||||
def test_false_when_stdin_not_tty(self, monkeypatch):
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is False
|
||||
|
||||
def test_true_when_stdin_is_tty(self, monkeypatch):
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is True
|
||||
|
||||
def test_false_when_stdin_has_no_isatty(self, monkeypatch):
|
||||
"""Some environments replace stdin with an object without isatty()."""
|
||||
mock_stdin = object() # no isatty attribute
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is False
|
||||
|
||||
|
||||
class TestWaitForCallbackNoBlocking:
|
||||
"""_wait_for_callback() must never call input() — it raises instead."""
|
||||
|
||||
def test_raises_on_timeout_instead_of_input(self):
|
||||
"""When no auth code arrives, raises OAuthNonInteractiveError."""
|
||||
import tools.mcp_oauth as mod
|
||||
import asyncio
|
||||
|
||||
mod._oauth_port = _find_free_port()
|
||||
|
||||
async def instant_sleep(_seconds):
|
||||
pass
|
||||
|
||||
with patch.object(mod.asyncio, "sleep", instant_sleep):
|
||||
with patch("builtins.input", side_effect=AssertionError("input() must not be called")):
|
||||
with pytest.raises(OAuthNonInteractiveError, match="callback timed out"):
|
||||
asyncio.run(_wait_for_callback())
|
||||
|
||||
|
||||
class TestBuildOAuthAuthNonInteractive:
|
||||
"""build_oauth_auth() in non-interactive mode."""
|
||||
|
||||
def test_noninteractive_without_cached_tokens_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Without cached tokens, non-interactive mode logs a clear warning."""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
|
||||
import logging
|
||||
with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"):
|
||||
auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp")
|
||||
|
||||
assert auth is not None
|
||||
assert "no cached tokens found" in caplog.text.lower()
|
||||
assert "non-interactive" in caplog.text.lower()
|
||||
|
||||
def test_noninteractive_with_cached_tokens_no_warning(self, tmp_path, monkeypatch, caplog):
|
||||
"""With cached tokens, non-interactive mode logs no 'no cached tokens' warning."""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
|
||||
# Pre-populate cached tokens
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "atlassian.json").write_text(json.dumps({
|
||||
"access_token": "cached",
|
||||
"token_type": "Bearer",
|
||||
}))
|
||||
|
||||
import logging
|
||||
with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"):
|
||||
auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp")
|
||||
|
||||
assert auth is not None
|
||||
assert "no cached tokens found" not in caplog.text.lower()
|
||||
|
|
|
|||
143
tests/tools/test_mcp_stability.py
Normal file
143
tests/tools/test_mcp_stability.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Tests for MCP stability fixes — event loop handler, PID tracking, shutdown robustness."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 1: MCP event loop exception handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPLoopExceptionHandler:
|
||||
"""_mcp_loop_exception_handler suppresses benign 'Event loop is closed'."""
|
||||
|
||||
def test_suppresses_event_loop_closed(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": RuntimeError("Event loop is closed")}
|
||||
# Should NOT call default handler
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_not_called()
|
||||
|
||||
def test_forwards_other_runtime_errors(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": RuntimeError("some other error")}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_forwards_non_runtime_errors(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": ValueError("bad value")}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_forwards_contexts_without_exception(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"message": "just a message"}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_handler_installed_on_mcp_loop(self):
|
||||
"""_ensure_mcp_loop installs the exception handler on the new loop."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
try:
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
with mcp_mod._lock:
|
||||
loop = mcp_mod._mcp_loop
|
||||
assert loop is not None
|
||||
assert loop.get_exception_handler() is mcp_mod._mcp_loop_exception_handler
|
||||
finally:
|
||||
mcp_mod._stop_mcp_loop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 2: stdio PID tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStdioPidTracking:
|
||||
"""_snapshot_child_pids and _stdio_pids track subprocess PIDs."""
|
||||
|
||||
def test_snapshot_returns_set(self):
|
||||
from tools.mcp_tool import _snapshot_child_pids
|
||||
result = _snapshot_child_pids()
|
||||
assert isinstance(result, set)
|
||||
# All elements should be ints
|
||||
for pid in result:
|
||||
assert isinstance(pid, int)
|
||||
|
||||
def test_stdio_pids_starts_empty(self):
|
||||
from tools.mcp_tool import _stdio_pids, _lock
|
||||
with _lock:
|
||||
# Might have residual state from other tests, just check type
|
||||
assert isinstance(_stdio_pids, set)
|
||||
|
||||
def test_kill_orphaned_noop_when_empty(self):
|
||||
"""_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
|
||||
# Should not raise
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
def test_kill_orphaned_handles_dead_pids(self):
|
||||
"""_kill_orphaned_mcp_children gracefully handles already-dead PIDs."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
# Use a PID that definitely doesn't exist
|
||||
fake_pid = 999999999
|
||||
with _lock:
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
# Should not raise (ProcessLookupError is caught)
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 3: MCP reload timeout (cli.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPReloadTimeout:
|
||||
"""_check_config_mcp_changes uses a timeout on _reload_mcp."""
|
||||
|
||||
def test_reload_timeout_does_not_block_forever(self, tmp_path, monkeypatch):
|
||||
"""If _reload_mcp hangs, the config watcher times out and returns."""
|
||||
import time
|
||||
|
||||
# Create a mock HermesCLI-like object with the needed attributes
|
||||
class FakeCLI:
|
||||
_config_mtime = 0.0
|
||||
_config_mcp_servers = {}
|
||||
_last_config_check = 0.0
|
||||
_command_running = False
|
||||
config = {}
|
||||
agent = None
|
||||
|
||||
def _reload_mcp(self):
|
||||
# Simulate a hang — sleep longer than the timeout
|
||||
time.sleep(60)
|
||||
|
||||
def _slow_command_status(self, cmd):
|
||||
return cmd
|
||||
|
||||
# This test verifies the timeout mechanism exists in the code
|
||||
# by checking that _check_config_mcp_changes doesn't call
|
||||
# _reload_mcp directly (it uses a thread now)
|
||||
import inspect
|
||||
from cli import HermesCLI
|
||||
source = inspect.getsource(HermesCLI._check_config_mcp_changes)
|
||||
# The fix adds threading.Thread for _reload_mcp
|
||||
assert "Thread" in source or "thread" in source.lower(), \
|
||||
"_check_config_mcp_changes should use a thread for _reload_mcp"
|
||||
|
|
@ -5,6 +5,12 @@ Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
|||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Startup safety:
|
||||
The callback handler never calls blocking ``input()`` on the event loop.
|
||||
In non-interactive environments (no TTY, SSH, headless), the OAuth flow
|
||||
raises ``OAuthNonInteractiveError`` instead of blocking, so that the
|
||||
server degrades gracefully and other MCP servers are not affected.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
|
|
@ -19,6 +25,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
|
@ -28,6 +35,11 @@ from urllib.parse import parse_qs, urlparse
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthNonInteractiveError(RuntimeError):
|
||||
"""Raised when OAuth requires user interaction but the environment is non-interactive."""
|
||||
pass
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
|
||||
|
||||
|
|
@ -164,7 +176,13 @@ async def _redirect_to_browser(auth_url: str) -> None:
|
|||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect.
|
||||
|
||||
If the callback times out, raises ``OAuthNonInteractiveError`` instead of
|
||||
calling blocking ``input()`` — the old ``input()`` call would block the
|
||||
entire MCP asyncio event loop, preventing all other MCP servers from
|
||||
connecting and potentially hanging Hermes startup indefinitely.
|
||||
"""
|
||||
global _oauth_port
|
||||
port = _oauth_port or _find_free_port()
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
|
|
@ -186,8 +204,10 @@ async def _wait_for_callback() -> tuple[str, str | None]:
|
|||
code = result["auth_code"] or ""
|
||||
state = result["state"]
|
||||
if not code:
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
raise OAuthNonInteractiveError(
|
||||
"OAuth browser callback timed out after 120 seconds. "
|
||||
"Run 'hermes mcp auth <server-name>' to authorize interactively."
|
||||
)
|
||||
return code, state
|
||||
|
||||
|
||||
|
|
@ -199,6 +219,17 @@ def _can_open_browser() -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _is_interactive() -> bool:
|
||||
"""Check if the current environment can support interactive OAuth flows.
|
||||
|
||||
Returns False in headless/daemon/container environments where no user
|
||||
can interact with a browser or paste an auth code.
|
||||
"""
|
||||
if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -209,6 +240,11 @@ def build_oauth_auth(server_name: str, server_url: str):
|
|||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
In non-interactive environments (no TTY), this still returns a provider
|
||||
so that **cached tokens and refresh flows work**. Only the interactive
|
||||
authorization-code grant will fail fast with a clear error instead of
|
||||
blocking the event loop.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
|
|
@ -219,6 +255,25 @@ def build_oauth_auth(server_name: str, server_url: str):
|
|||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
interactive = _is_interactive()
|
||||
|
||||
if not interactive:
|
||||
# Check whether cached tokens exist. If they do, the SDK can still
|
||||
# use them (and refresh them) without any user interaction. If not,
|
||||
# we still build the provider — the callback_handler will raise
|
||||
# OAuthNonInteractiveError if a fresh authorization is actually
|
||||
# needed, which surfaces as a clean connection failure for this
|
||||
# server only (other MCP servers are unaffected).
|
||||
has_cached = storage._read_json(storage._tokens_path()) is not None
|
||||
if not has_cached:
|
||||
logger.warning(
|
||||
"MCP server '%s' requires OAuth but no cached tokens found "
|
||||
"and environment is non-interactive. The server will fail to "
|
||||
"connect. Run 'hermes mcp auth %s' to authorize interactively.",
|
||||
server_name, server_name,
|
||||
)
|
||||
|
||||
global _oauth_port
|
||||
_oauth_port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||
|
|
@ -232,14 +287,36 @@ def build_oauth_auth(server_name: str, server_url: str):
|
|||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
# In non-interactive mode, the redirect handler logs the URL and the
|
||||
# callback handler raises immediately — no blocking, no input().
|
||||
redirect_handler = _redirect_to_browser
|
||||
callback_handler = _wait_for_callback
|
||||
|
||||
if not interactive:
|
||||
async def _noninteractive_redirect(auth_url: str) -> None:
|
||||
logger.warning(
|
||||
"MCP server '%s' needs OAuth authorization (non-interactive, "
|
||||
"cannot open browser). URL: %s",
|
||||
server_name, auth_url,
|
||||
)
|
||||
|
||||
async def _noninteractive_callback() -> tuple[str, str | None]:
|
||||
raise OAuthNonInteractiveError(
|
||||
f"MCP server '{server_name}' requires interactive OAuth "
|
||||
f"authorization but the environment is non-interactive "
|
||||
f"(no TTY). Run 'hermes mcp auth {server_name}' to "
|
||||
f"authorize, then restart."
|
||||
)
|
||||
|
||||
redirect_handler = _noninteractive_redirect
|
||||
callback_handler = _noninteractive_callback
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_to_browser,
|
||||
callback_handler=_wait_for_callback,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -842,13 +842,25 @@ class MCPServerTask:
|
|||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
|
||||
# Snapshot child PIDs before spawning so we can track the new one.
|
||||
pids_before = _snapshot_child_pids()
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
# Capture the newly spawned subprocess PID for force-kill cleanup.
|
||||
new_pids = _snapshot_child_pids() - pids_before
|
||||
if new_pids:
|
||||
with _lock:
|
||||
_stdio_pids.update(new_pids)
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
# Context exited cleanly — subprocess was terminated by the SDK.
|
||||
if new_pids:
|
||||
with _lock:
|
||||
_stdio_pids.difference_update(new_pids)
|
||||
|
||||
async def _run_http(self, config: dict):
|
||||
"""Run the server using HTTP/StreamableHTTP transport."""
|
||||
|
|
@ -863,7 +875,10 @@ class MCPServerTask:
|
|||
headers = dict(config.get("headers") or {})
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK.
|
||||
# If OAuth setup fails (e.g. non-interactive environment without
|
||||
# cached tokens), re-raise so this server is reported as failed
|
||||
# without blocking other MCP servers from connecting.
|
||||
_oauth_auth = None
|
||||
if self._auth_type == "oauth":
|
||||
try:
|
||||
|
|
@ -871,6 +886,7 @@ class MCPServerTask:
|
|||
_oauth_auth = build_oauth_auth(self.name, url)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
raise
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
|
|
@ -1044,9 +1060,56 @@ _servers: Dict[str, MCPServerTask] = {}
|
|||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids.
|
||||
_lock = threading.Lock()
|
||||
|
||||
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
|
||||
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
|
||||
# fails or times out. PIDs are added after connection and removed on
|
||||
# normal server shutdown.
|
||||
_stdio_pids: set = set()
|
||||
|
||||
|
||||
def _snapshot_child_pids() -> set:
|
||||
"""Return a set of current child process PIDs.
|
||||
|
||||
Uses /proc on Linux, falls back to psutil, then empty set.
|
||||
Used by _run_stdio to identify the subprocess spawned by stdio_client.
|
||||
"""
|
||||
my_pid = os.getpid()
|
||||
|
||||
# Linux: read from /proc
|
||||
try:
|
||||
children_path = f"/proc/{my_pid}/task/{my_pid}/children"
|
||||
with open(children_path) as f:
|
||||
return {int(p) for p in f.read().split() if p.strip()}
|
||||
except (FileNotFoundError, OSError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: psutil
|
||||
try:
|
||||
import psutil
|
||||
return {c.pid for c in psutil.Process(my_pid).children()}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
def _mcp_loop_exception_handler(loop, context):
|
||||
"""Suppress benign 'Event loop is closed' noise during shutdown.
|
||||
|
||||
When the MCP event loop is stopped and closed, httpx/httpcore async
|
||||
transports may fire __del__ finalizers that call call_soon() on the
|
||||
dead loop. asyncio catches that RuntimeError and routes it here.
|
||||
We silence it because the connection is being torn down anyway; all
|
||||
other exceptions are forwarded to the default handler.
|
||||
"""
|
||||
exc = context.get("exception")
|
||||
if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc):
|
||||
return # benign shutdown race — suppress
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
|
||||
def _ensure_mcp_loop():
|
||||
"""Start the background event loop thread if not already running."""
|
||||
|
|
@ -1055,6 +1118,7 @@ def _ensure_mcp_loop():
|
|||
if _mcp_loop is not None and _mcp_loop.is_running():
|
||||
return
|
||||
_mcp_loop = asyncio.new_event_loop()
|
||||
_mcp_loop.set_exception_handler(_mcp_loop_exception_handler)
|
||||
_mcp_thread = threading.Thread(
|
||||
target=_mcp_loop.run_forever,
|
||||
name="mcp-event-loop",
|
||||
|
|
@ -2057,6 +2121,29 @@ def shutdown_mcp_servers():
|
|||
_stop_mcp_loop()
|
||||
|
||||
|
||||
def _kill_orphaned_mcp_children() -> None:
|
||||
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown.
|
||||
|
||||
After the MCP event loop is stopped, stdio server subprocesses *should*
|
||||
have been terminated by the SDK's context-manager cleanup. If the loop
|
||||
was stuck or the shutdown timed out, orphaned children may remain.
|
||||
|
||||
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
|
||||
"""
|
||||
import signal as _signal
|
||||
|
||||
with _lock:
|
||||
pids = list(_stdio_pids)
|
||||
_stdio_pids.clear()
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGKILL)
|
||||
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass # Already exited or inaccessible
|
||||
|
||||
|
||||
def _stop_mcp_loop():
|
||||
"""Stop the background event loop and join its thread."""
|
||||
global _mcp_loop, _mcp_thread
|
||||
|
|
@ -2069,4 +2156,10 @@ def _stop_mcp_loop():
|
|||
loop.call_soon_threadsafe(loop.stop)
|
||||
if thread is not None:
|
||||
thread.join(timeout=5)
|
||||
loop.close()
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
# After closing the loop, any stdio subprocesses that survived the
|
||||
# graceful shutdown are now orphaned. Force-kill them.
|
||||
_kill_orphaned_mcp_children()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue