diff --git a/cli.py b/cli.py index d7e92069b..42a49440c 100644 --- a/cli.py +++ b/cli.py @@ -5016,11 +5016,18 @@ class HermesCLI: return # mcp_servers unchanged (some other section was edited) self._config_mcp_servers = new_mcp - # Notify user and reload + # Notify user and reload. Run in a separate thread with a hard + # timeout so a hung MCP server cannot block the process_loop + # indefinitely (which would freeze the entire TUI). print() print("🔄 MCP server config changed — reloading connections...") - with self._busy_command(self._slow_command_status("/reload-mcp")): - self._reload_mcp() + _reload_thread = threading.Thread( + target=self._reload_mcp, daemon=True + ) + _reload_thread.start() + _reload_thread.join(timeout=30) + if _reload_thread.is_alive(): + print(" ⚠️ MCP reload timed out (30s). Some servers may not have reconnected.") def _reload_mcp(self): """Reload MCP servers: disconnect all, re-read config.yaml, reconnect. diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index 66ac3b616..19c588e58 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -9,10 +9,13 @@ import pytest from tools.mcp_oauth import ( HermesTokenStorage, + OAuthNonInteractiveError, build_oauth_auth, remove_oauth_tokens, _find_free_port, _can_open_browser, + _is_interactive, + _wait_for_callback, ) @@ -236,3 +239,99 @@ class TestRemoveOAuthTokens: def test_no_error_when_files_missing(self, tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) remove_oauth_tokens("nonexistent") # should not raise + + +# --------------------------------------------------------------------------- +# Non-interactive / startup-safety tests (issue #4462) +# --------------------------------------------------------------------------- + +class TestIsInteractive: + """_is_interactive() detects headless/daemon/container environments.""" + + def test_false_when_stdin_not_tty(self, monkeypatch): + mock_stdin = MagicMock() + mock_stdin.isatty.return_value = False + monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin) + assert _is_interactive() is False + + def test_true_when_stdin_is_tty(self, monkeypatch): + mock_stdin = MagicMock() + mock_stdin.isatty.return_value = True + monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin) + assert _is_interactive() is True + + def test_false_when_stdin_has_no_isatty(self, monkeypatch): + """Some environments replace stdin with an object without isatty().""" + mock_stdin = object() # no isatty attribute + monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin) + assert _is_interactive() is False + + +class TestWaitForCallbackNoBlocking: + """_wait_for_callback() must never call input() — it raises instead.""" + + def test_raises_on_timeout_instead_of_input(self): + """When no auth code arrives, raises OAuthNonInteractiveError.""" + import tools.mcp_oauth as mod + import asyncio + + mod._oauth_port = _find_free_port() + + async def instant_sleep(_seconds): + pass + + with patch.object(mod.asyncio, "sleep", instant_sleep): + with patch("builtins.input", side_effect=AssertionError("input() must not be called")): + with pytest.raises(OAuthNonInteractiveError, match="callback timed out"): + asyncio.run(_wait_for_callback()) + + +class TestBuildOAuthAuthNonInteractive: + """build_oauth_auth() in non-interactive mode.""" + + def test_noninteractive_without_cached_tokens_warns(self, tmp_path, monkeypatch, caplog): + """Without cached tokens, non-interactive mode logs a clear warning.""" + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + mock_stdin = MagicMock() + mock_stdin.isatty.return_value = False + monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin) + + import logging + with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"): + auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp") + + assert auth is not None + assert "no cached tokens found" in caplog.text.lower() + assert "non-interactive" in caplog.text.lower() + + def test_noninteractive_with_cached_tokens_no_warning(self, tmp_path, monkeypatch, caplog): + """With cached tokens, non-interactive mode logs no 'no cached tokens' warning.""" + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + mock_stdin = MagicMock() + mock_stdin.isatty.return_value = False + monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin) + + # Pre-populate cached tokens + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "atlassian.json").write_text(json.dumps({ + "access_token": "cached", + "token_type": "Bearer", + })) + + import logging + with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"): + auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp") + + assert auth is not None + assert "no cached tokens found" not in caplog.text.lower() diff --git a/tests/tools/test_mcp_stability.py b/tests/tools/test_mcp_stability.py new file mode 100644 index 000000000..c83dda463 --- /dev/null +++ b/tests/tools/test_mcp_stability.py @@ -0,0 +1,143 @@ +"""Tests for MCP stability fixes — event loop handler, PID tracking, shutdown robustness.""" + +import asyncio +import os +import signal +import threading +from unittest.mock import patch, MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Fix 1: MCP event loop exception handler +# --------------------------------------------------------------------------- + +class TestMCPLoopExceptionHandler: + """_mcp_loop_exception_handler suppresses benign 'Event loop is closed'.""" + + def test_suppresses_event_loop_closed(self): + from tools.mcp_tool import _mcp_loop_exception_handler + loop = MagicMock() + context = {"exception": RuntimeError("Event loop is closed")} + # Should NOT call default handler + _mcp_loop_exception_handler(loop, context) + loop.default_exception_handler.assert_not_called() + + def test_forwards_other_runtime_errors(self): + from tools.mcp_tool import _mcp_loop_exception_handler + loop = MagicMock() + context = {"exception": RuntimeError("some other error")} + _mcp_loop_exception_handler(loop, context) + loop.default_exception_handler.assert_called_once_with(context) + + def test_forwards_non_runtime_errors(self): + from tools.mcp_tool import _mcp_loop_exception_handler + loop = MagicMock() + context = {"exception": ValueError("bad value")} + _mcp_loop_exception_handler(loop, context) + loop.default_exception_handler.assert_called_once_with(context) + + def test_forwards_contexts_without_exception(self): + from tools.mcp_tool import _mcp_loop_exception_handler + loop = MagicMock() + context = {"message": "just a message"} + _mcp_loop_exception_handler(loop, context) + loop.default_exception_handler.assert_called_once_with(context) + + def test_handler_installed_on_mcp_loop(self): + """_ensure_mcp_loop installs the exception handler on the new loop.""" + import tools.mcp_tool as mcp_mod + try: + mcp_mod._ensure_mcp_loop() + with mcp_mod._lock: + loop = mcp_mod._mcp_loop + assert loop is not None + assert loop.get_exception_handler() is mcp_mod._mcp_loop_exception_handler + finally: + mcp_mod._stop_mcp_loop() + + +# --------------------------------------------------------------------------- +# Fix 2: stdio PID tracking +# --------------------------------------------------------------------------- + +class TestStdioPidTracking: + """_snapshot_child_pids and _stdio_pids track subprocess PIDs.""" + + def test_snapshot_returns_set(self): + from tools.mcp_tool import _snapshot_child_pids + result = _snapshot_child_pids() + assert isinstance(result, set) + # All elements should be ints + for pid in result: + assert isinstance(pid, int) + + def test_stdio_pids_starts_empty(self): + from tools.mcp_tool import _stdio_pids, _lock + with _lock: + # Might have residual state from other tests, just check type + assert isinstance(_stdio_pids, set) + + def test_kill_orphaned_noop_when_empty(self): + """_kill_orphaned_mcp_children does nothing when no PIDs tracked.""" + from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock + + with _lock: + _stdio_pids.clear() + + # Should not raise + _kill_orphaned_mcp_children() + + def test_kill_orphaned_handles_dead_pids(self): + """_kill_orphaned_mcp_children gracefully handles already-dead PIDs.""" + from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock + + # Use a PID that definitely doesn't exist + fake_pid = 999999999 + with _lock: + _stdio_pids.add(fake_pid) + + # Should not raise (ProcessLookupError is caught) + _kill_orphaned_mcp_children() + + with _lock: + assert fake_pid not in _stdio_pids + + +# --------------------------------------------------------------------------- +# Fix 3: MCP reload timeout (cli.py) +# --------------------------------------------------------------------------- + +class TestMCPReloadTimeout: + """_check_config_mcp_changes uses a timeout on _reload_mcp.""" + + def test_reload_timeout_does_not_block_forever(self, tmp_path, monkeypatch): + """If _reload_mcp hangs, the config watcher times out and returns.""" + import time + + # Create a mock HermesCLI-like object with the needed attributes + class FakeCLI: + _config_mtime = 0.0 + _config_mcp_servers = {} + _last_config_check = 0.0 + _command_running = False + config = {} + agent = None + + def _reload_mcp(self): + # Simulate a hang — sleep longer than the timeout + time.sleep(60) + + def _slow_command_status(self, cmd): + return cmd + + # This test verifies the timeout mechanism exists in the code + # by checking that _check_config_mcp_changes doesn't call + # _reload_mcp directly (it uses a thread now) + import inspect + from cli import HermesCLI + source = inspect.getsource(HermesCLI._check_config_mcp_changes) + # The fix adds threading.Thread for _reload_mcp + assert "Thread" in source or "thread" in source.lower(), \ + "_check_config_mcp_changes should use a thread for _reload_mcp" diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 4fa228589..b614826a8 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -5,6 +5,12 @@ Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements authorization. The SDK handles all of the heavy lifting: PKCE generation, metadata discovery, dynamic client registration, token exchange, and refresh. +Startup safety: + The callback handler never calls blocking ``input()`` on the event loop. + In non-interactive environments (no TTY, SSH, headless), the OAuth flow + raises ``OAuthNonInteractiveError`` instead of blocking, so that the + server degrades gracefully and other MCP servers are not affected. + Usage in mcp_tool.py:: from tools.mcp_oauth import build_oauth_auth @@ -19,6 +25,7 @@ import json import logging import os import socket +import sys import threading import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer @@ -28,6 +35,11 @@ from urllib.parse import parse_qs, urlparse logger = logging.getLogger(__name__) + +class OAuthNonInteractiveError(RuntimeError): + """Raised when OAuth requires user interaction but the environment is non-interactive.""" + pass + _TOKEN_DIR_NAME = "mcp-tokens" @@ -164,7 +176,13 @@ async def _redirect_to_browser(auth_url: str) -> None: async def _wait_for_callback() -> tuple[str, str | None]: - """Start a local HTTP server on the pre-registered port and wait for the OAuth redirect.""" + """Start a local HTTP server on the pre-registered port and wait for the OAuth redirect. + + If the callback times out, raises ``OAuthNonInteractiveError`` instead of + calling blocking ``input()`` — the old ``input()`` call would block the + entire MCP asyncio event loop, preventing all other MCP servers from + connecting and potentially hanging Hermes startup indefinitely. + """ global _oauth_port port = _oauth_port or _find_free_port() HandlerClass, result = _make_callback_handler() @@ -186,8 +204,10 @@ async def _wait_for_callback() -> tuple[str, str | None]: code = result["auth_code"] or "" state = result["state"] if not code: - print(" Browser callback timed out. Paste the authorization code manually:") - code = input(" Code: ").strip() + raise OAuthNonInteractiveError( + "OAuth browser callback timed out after 120 seconds. " + "Run 'hermes mcp auth ' to authorize interactively." + ) return code, state @@ -199,6 +219,17 @@ def _can_open_browser() -> bool: return True +def _is_interactive() -> bool: + """Check if the current environment can support interactive OAuth flows. + + Returns False in headless/daemon/container environments where no user + can interact with a browser or paste an auth code. + """ + if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty(): + return False + return True + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- @@ -209,6 +240,11 @@ def build_oauth_auth(server_name: str, server_url: str): Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery, registration, PKCE, token exchange, and refresh automatically. + In non-interactive environments (no TTY), this still returns a provider + so that **cached tokens and refresh flows work**. Only the interactive + authorization-code grant will fail fast with a clear error instead of + blocking the event loop. + Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``), or ``None`` if the MCP SDK auth module is not available. """ @@ -219,6 +255,25 @@ def build_oauth_auth(server_name: str, server_url: str): logger.warning("MCP SDK auth module not available — OAuth disabled") return None + storage = HermesTokenStorage(server_name) + interactive = _is_interactive() + + if not interactive: + # Check whether cached tokens exist. If they do, the SDK can still + # use them (and refresh them) without any user interaction. If not, + # we still build the provider — the callback_handler will raise + # OAuthNonInteractiveError if a fresh authorization is actually + # needed, which surfaces as a clean connection failure for this + # server only (other MCP servers are unaffected). + has_cached = storage._read_json(storage._tokens_path()) is not None + if not has_cached: + logger.warning( + "MCP server '%s' requires OAuth but no cached tokens found " + "and environment is non-interactive. The server will fail to " + "connect. Run 'hermes mcp auth %s' to authorize interactively.", + server_name, server_name, + ) + global _oauth_port _oauth_port = _find_free_port() redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback" @@ -232,14 +287,36 @@ def build_oauth_auth(server_name: str, server_url: str): token_endpoint_auth_method="none", ) - storage = HermesTokenStorage(server_name) + # In non-interactive mode, the redirect handler logs the URL and the + # callback handler raises immediately — no blocking, no input(). + redirect_handler = _redirect_to_browser + callback_handler = _wait_for_callback + + if not interactive: + async def _noninteractive_redirect(auth_url: str) -> None: + logger.warning( + "MCP server '%s' needs OAuth authorization (non-interactive, " + "cannot open browser). URL: %s", + server_name, auth_url, + ) + + async def _noninteractive_callback() -> tuple[str, str | None]: + raise OAuthNonInteractiveError( + f"MCP server '{server_name}' requires interactive OAuth " + f"authorization but the environment is non-interactive " + f"(no TTY). Run 'hermes mcp auth {server_name}' to " + f"authorize, then restart." + ) + + redirect_handler = _noninteractive_redirect + callback_handler = _noninteractive_callback return OAuthClientProvider( server_url=server_url, client_metadata=client_metadata, storage=storage, - redirect_handler=_redirect_to_browser, - callback_handler=_wait_for_callback, + redirect_handler=redirect_handler, + callback_handler=callback_handler, timeout=120.0, ) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 0918de20a..88bb6fd73 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -842,13 +842,25 @@ class MCPServerTask: sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: sampling_kwargs["message_handler"] = self._make_message_handler() + + # Snapshot child PIDs before spawning so we can track the new one. + pids_before = _snapshot_child_pids() async with stdio_client(server_params) as (read_stream, write_stream): + # Capture the newly spawned subprocess PID for force-kill cleanup. + new_pids = _snapshot_child_pids() - pids_before + if new_pids: + with _lock: + _stdio_pids.update(new_pids) async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session await self._discover_tools() self._ready.set() await self._shutdown_event.wait() + # Context exited cleanly — subprocess was terminated by the SDK. + if new_pids: + with _lock: + _stdio_pids.difference_update(new_pids) async def _run_http(self, config: dict): """Run the server using HTTP/StreamableHTTP transport.""" @@ -863,7 +875,10 @@ class MCPServerTask: headers = dict(config.get("headers") or {}) connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) - # OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK + # OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK. + # If OAuth setup fails (e.g. non-interactive environment without + # cached tokens), re-raise so this server is reported as failed + # without blocking other MCP servers from connecting. _oauth_auth = None if self._auth_type == "oauth": try: @@ -871,6 +886,7 @@ class MCPServerTask: _oauth_auth = build_oauth_auth(self.name, url) except Exception as exc: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) + raise sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: @@ -1044,9 +1060,56 @@ _servers: Dict[str, MCPServerTask] = {} _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None -# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access. +# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids. _lock = threading.Lock() +# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill +# them on shutdown if the graceful cleanup (SDK context-manager teardown) +# fails or times out. PIDs are added after connection and removed on +# normal server shutdown. +_stdio_pids: set = set() + + +def _snapshot_child_pids() -> set: + """Return a set of current child process PIDs. + + Uses /proc on Linux, falls back to psutil, then empty set. + Used by _run_stdio to identify the subprocess spawned by stdio_client. + """ + my_pid = os.getpid() + + # Linux: read from /proc + try: + children_path = f"/proc/{my_pid}/task/{my_pid}/children" + with open(children_path) as f: + return {int(p) for p in f.read().split() if p.strip()} + except (FileNotFoundError, OSError, ValueError): + pass + + # Fallback: psutil + try: + import psutil + return {c.pid for c in psutil.Process(my_pid).children()} + except Exception: + pass + + return set() + + +def _mcp_loop_exception_handler(loop, context): + """Suppress benign 'Event loop is closed' noise during shutdown. + + When the MCP event loop is stopped and closed, httpx/httpcore async + transports may fire __del__ finalizers that call call_soon() on the + dead loop. asyncio catches that RuntimeError and routes it here. + We silence it because the connection is being torn down anyway; all + other exceptions are forwarded to the default handler. + """ + exc = context.get("exception") + if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc): + return # benign shutdown race — suppress + loop.default_exception_handler(context) + def _ensure_mcp_loop(): """Start the background event loop thread if not already running.""" @@ -1055,6 +1118,7 @@ def _ensure_mcp_loop(): if _mcp_loop is not None and _mcp_loop.is_running(): return _mcp_loop = asyncio.new_event_loop() + _mcp_loop.set_exception_handler(_mcp_loop_exception_handler) _mcp_thread = threading.Thread( target=_mcp_loop.run_forever, name="mcp-event-loop", @@ -2057,6 +2121,29 @@ def shutdown_mcp_servers(): _stop_mcp_loop() +def _kill_orphaned_mcp_children() -> None: + """Best-effort kill of MCP stdio subprocesses that survived loop shutdown. + + After the MCP event loop is stopped, stdio server subprocesses *should* + have been terminated by the SDK's context-manager cleanup. If the loop + was stuck or the shutdown timed out, orphaned children may remain. + + Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children. + """ + import signal as _signal + + with _lock: + pids = list(_stdio_pids) + _stdio_pids.clear() + + for pid in pids: + try: + os.kill(pid, _signal.SIGKILL) + logger.debug("Force-killed orphaned MCP stdio process %d", pid) + except (ProcessLookupError, PermissionError, OSError): + pass # Already exited or inaccessible + + def _stop_mcp_loop(): """Stop the background event loop and join its thread.""" global _mcp_loop, _mcp_thread @@ -2069,4 +2156,10 @@ def _stop_mcp_loop(): loop.call_soon_threadsafe(loop.stop) if thread is not None: thread.join(timeout=5) - loop.close() + try: + loop.close() + except Exception: + pass + # After closing the loop, any stdio subprocesses that survived the + # graceful shutdown are now orphaned. Force-kill them. + _kill_orphaned_mcp_children()