feat(gateway): auto-reload MCP connections on config change

Add a background _mcp_config_watcher() task to the gateway that polls
config.yaml every 30 seconds and auto-reloads MCP server connections
when the mcp_servers section changes.

This solves the problem where OAuth token refresh cron jobs update
Bearer tokens in config.yaml, but the running gateway keeps using
stale cached credentials until manually restarted.

The CLI already had this via _check_config_mcp_changes() — this ports
the same concept to the async gateway event loop.

Changes:
- GatewayRunner.__init__: initialize _mcp_config_mtime and
  _mcp_config_servers state from current config
- GatewayRunner.start(): launch _mcp_config_watcher as background task
- GatewayRunner._mcp_config_watcher(): async background task that:
  - Uses mtime fast-path to avoid unnecessary YAML reads
  - Deep-compares mcp_servers dict to detect header changes
  - Runs shutdown/discover in executor to avoid blocking event loop
  - Sleeps in 1s increments for responsive shutdown
  - 30s initial delay to let startup finish

Tests: 7 new tests covering no-change skip, header change detection,
non-MCP change skip, server add/remove, shutdown behavior, and
full integration test.
This commit is contained in:
Hermes Agent 2026-04-12 07:03:50 +00:00
parent 4cadfef8e3
commit 8cf7e80fc5
2 changed files with 330 additions and 1 deletions

View file

@@ -631,7 +631,18 @@ class GatewayRunner:
# Track background tasks to prevent garbage collection mid-execution # Track background tasks to prevent garbage collection mid-execution
self._background_tasks: set = set() self._background_tasks: set = set()
# MCP config watcher state — detect header changes (e.g. OAuth token refresh)
# Baseline snapshot used by _mcp_config_watcher: the config file's mtime
# (fast-path change check) and the mcp_servers section (deep comparison).
self._mcp_config_mtime: float = 0.0
self._mcp_config_servers: dict = {}
try:
    from hermes_cli.config import get_config_path, load_config
    cfg_path = get_config_path()
    if cfg_path.exists():
        self._mcp_config_mtime = cfg_path.stat().st_mtime
        cfg = load_config()
        # Normalize a missing/None section to {} so later == comparisons are simple.
        self._mcp_config_servers = cfg.get("mcp_servers") or {}
except Exception:
    # Best-effort: if the config can't be read at startup, keep the zero/empty
    # baseline — the watcher will then treat the first successful read as a change.
    pass
# -- Setup skill availability ---------------------------------------- # -- Setup skill availability ----------------------------------------
@@ -1690,10 +1701,111 @@ class GatewayRunner:
) )
asyncio.create_task(self._platform_reconnect_watcher()) asyncio.create_task(self._platform_reconnect_watcher())
# Start background MCP config watcher for auto-reloading on token refresh
asyncio.create_task(self._mcp_config_watcher())
logger.info("Press Ctrl+C to stop") logger.info("Press Ctrl+C to stop")
return True return True
async def _mcp_config_watcher(self, interval: int = 30, _initial_delay: int = 30) -> None:
"""Background task that detects MCP config changes and auto-reloads connections.
Polls config.yaml every ``interval`` seconds. When the ``mcp_servers``
section changes (e.g. OAuth token refresh updates the Authorization
header), triggers a full MCP shutdown + reconnect so the running
gateway picks up new credentials without a restart.
Mirrors the CLI's ``_check_config_mcp_changes`` but adapted for the
async gateway event loop.
"""
# Initial delay — let startup finish. Sleep in 1s increments for quick shutdown.
for _ in range(_initial_delay):
if not self._running:
return
await asyncio.sleep(1)
logger.info("MCP config watcher started (checking every %ds)", interval)
while self._running:
try:
from hermes_cli.config import get_config_path
import yaml as _yaml
cfg_path = get_config_path()
if not cfg_path.exists():
await asyncio.sleep(interval)
continue
try:
mtime = cfg_path.stat().st_mtime
except OSError:
await asyncio.sleep(interval)
continue
if mtime == self._mcp_config_mtime:
await asyncio.sleep(interval)
continue
# File changed — read and compare mcp_servers section
self._mcp_config_mtime = mtime
try:
with open(cfg_path, encoding="utf-8") as f:
new_cfg = _yaml.safe_load(f) or {}
except Exception:
await asyncio.sleep(interval)
continue
new_mcp = new_cfg.get("mcp_servers") or {}
if new_mcp == self._mcp_config_servers:
# Some other config section changed, not MCP
await asyncio.sleep(interval)
continue
self._mcp_config_servers = new_mcp
logger.info("MCP config change detected — reloading connections...")
# Perform the reload in a thread to avoid blocking the event loop
loop = asyncio.get_event_loop()
try:
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _servers, _lock
with _lock:
old_servers = set(_servers.keys())
await loop.run_in_executor(None, shutdown_mcp_servers)
new_tools = await loop.run_in_executor(None, discover_mcp_tools)
with _lock:
connected_servers = set(_servers.keys())
added = connected_servers - old_servers
removed = old_servers - connected_servers
reconnected = connected_servers & old_servers
parts = []
if reconnected:
parts.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}")
if added:
parts.append(f" Added: {', '.join(sorted(added))}")
if removed:
parts.append(f" Removed: {', '.join(sorted(removed))}")
parts.append(
f"🔧 {len(new_tools)} tool(s) from {len(connected_servers)} server(s)"
)
logger.info("MCP auto-reload complete: %s", "; ".join(parts))
except Exception as e:
logger.warning("MCP auto-reload failed: %s", e)
except Exception as e:
logger.debug("MCP config watcher error: %s", e)
# Sleep in 1-second increments so we respond quickly to shutdown
for _ in range(interval):
if not self._running:
return
await asyncio.sleep(1)
async def _session_expiry_watcher(self, interval: int = 300): async def _session_expiry_watcher(self, interval: int = 300):
"""Background task that proactively flushes memories for expired sessions. """Background task that proactively flushes memories for expired sessions.

View file

@@ -0,0 +1,217 @@
"""Tests for gateway MCP config watcher — auto-reload on mcp_servers changes."""
import asyncio
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from gateway.run import GatewayRunner
def _make_runner(tmp_path, mcp_servers=None):
    """Build a bare GatewayRunner (bypassing __init__) plus a seeded config.yaml.

    The runner carries only the MCP-watcher state the tests exercise:
    ``_running``, ``_mcp_config_servers``, and ``_mcp_config_mtime``.
    """
    servers = mcp_servers or {}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.dump({"mcp_servers": servers}))

    # __new__ skips __init__ and all of its gateway-startup side effects.
    stub = object.__new__(GatewayRunner)
    stub._running = True
    stub._mcp_config_servers = servers
    stub._mcp_config_mtime = config_path.stat().st_mtime
    return stub, config_path
class TestMCPConfigWatcher:
    """Tests for GatewayRunner._mcp_config_watcher change detection and reload."""

    @pytest.mark.asyncio
    async def test_no_change_does_not_reload(self, tmp_path):
        """If config file hasn't changed, no MCP reload should happen."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
        })
        reload_called = False

        # Re-implements one iteration of the watcher's detection logic inline
        # so the test can observe whether the reload branch is reached.
        async def fake_watcher_iteration():
            nonlocal reload_called
            from hermes_cli.config import get_config_path
            import yaml as _yaml
            cfg_path = cfg_file
            mtime = cfg_path.stat().st_mtime
            if mtime == runner._mcp_config_mtime:
                return  # No change — fast path
            runner._mcp_config_mtime = mtime
            with open(cfg_path, encoding="utf-8") as f:
                new_cfg = _yaml.safe_load(f) or {}
            new_mcp = new_cfg.get("mcp_servers") or {}
            if new_mcp == runner._mcp_config_servers:
                return
            reload_called = True

        await fake_watcher_iteration()
        # _make_runner wrote the file the mtime baseline was taken from, so
        # the mtime fast-path must short-circuit before any reload.
        assert not reload_called

    @pytest.mark.asyncio
    async def test_header_change_triggers_reload(self, tmp_path):
        """When Authorization header changes, reload should be triggered."""
        old_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old_token"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
        # Simulate token refresh updating the config
        new_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new_token"}}
        }
        cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
        # Force mtime to look different
        runner._mcp_config_mtime = 0.0
        reload_triggered = False
        # Simulate one iteration of the watcher's core logic
        mtime = cfg_file.stat().st_mtime
        assert mtime != runner._mcp_config_mtime
        runner._mcp_config_mtime = mtime
        with open(cfg_file, encoding="utf-8") as f:
            new_cfg = yaml.safe_load(f) or {}
        new_mcp = new_cfg.get("mcp_servers") or {}
        if new_mcp != runner._mcp_config_servers:
            reload_triggered = True
            runner._mcp_config_servers = new_mcp
        assert reload_triggered
        assert runner._mcp_config_servers == new_servers

    @pytest.mark.asyncio
    async def test_non_mcp_change_does_not_reload(self, tmp_path):
        """If a non-MCP section changes but mcp_servers stays the same, no reload."""
        servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer tok"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=servers)
        # Write same mcp_servers but change something else
        cfg_file.write_text(yaml.dump({
            "mcp_servers": servers,
            "some_other_setting": "changed"
        }))
        runner._mcp_config_mtime = 0.0  # force stale mtime
        mtime = cfg_file.stat().st_mtime
        runner._mcp_config_mtime = mtime
        with open(cfg_file, encoding="utf-8") as f:
            new_cfg = yaml.safe_load(f) or {}
        new_mcp = new_cfg.get("mcp_servers") or {}
        assert new_mcp == runner._mcp_config_servers  # Should be unchanged

    @pytest.mark.asyncio
    async def test_server_added_triggers_reload(self, tmp_path):
        """Adding a new MCP server to config triggers reload."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={})
        new_servers = {"github": {"url": "https://api.github.com/mcp"}}
        cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
        runner._mcp_config_mtime = 0.0
        mtime = cfg_file.stat().st_mtime
        runner._mcp_config_mtime = mtime
        with open(cfg_file, encoding="utf-8") as f:
            new_cfg = yaml.safe_load(f) or {}
        new_mcp = new_cfg.get("mcp_servers") or {}
        # Going from {} to one server must register as a change.
        assert new_mcp != runner._mcp_config_servers
        runner._mcp_config_servers = new_mcp
        assert runner._mcp_config_servers == new_servers

    @pytest.mark.asyncio
    async def test_server_removed_triggers_reload(self, tmp_path):
        """Removing an MCP server from config triggers reload."""
        runner, cfg_file = _make_runner(tmp_path, mcp_servers={
            "github": {"url": "https://api.github.com/mcp"}
        })
        cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
        runner._mcp_config_mtime = 0.0
        mtime = cfg_file.stat().st_mtime
        runner._mcp_config_mtime = mtime
        with open(cfg_file, encoding="utf-8") as f:
            new_cfg = yaml.safe_load(f) or {}
        new_mcp = new_cfg.get("mcp_servers") or {}
        assert new_mcp != runner._mcp_config_servers

    @pytest.mark.asyncio
    async def test_watcher_stops_on_shutdown(self, tmp_path):
        """Watcher loop exits when _running is set to False."""
        runner, cfg_file = _make_runner(tmp_path)
        runner._running = False
        # The watcher should return almost immediately
        # We test it doesn't hang by using a timeout
        try:
            await asyncio.wait_for(
                runner._mcp_config_watcher(interval=1, _initial_delay=0),
                timeout=5.0,
            )
        except asyncio.TimeoutError:
            pytest.fail("_mcp_config_watcher did not exit after _running=False")

    @pytest.mark.asyncio
    async def test_full_watcher_detects_change_and_reloads(self, tmp_path):
        """Integration test: watcher detects a header change and calls MCP reload."""
        old_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
        }
        runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
        # Prepare the config change that will happen during the watcher run
        new_servers = {
            "betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new"}}
        }
        shutdown_mock = MagicMock()
        discover_mock = MagicMock(return_value=[{"function": {"name": "test_tool"}}])
        servers_dict = {"betterstack": MagicMock()}
        # MagicMock supports the context-manager protocol, so it stands in for _lock.
        lock_mock = MagicMock()

        async def stop_after_reload():
            """Write the config change, wait for the watcher to pick it up, then stop."""
            await asyncio.sleep(0.5)
            cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
            # Wait enough time for the watcher to detect + reload
            await asyncio.sleep(4)
            runner._running = False

        # Patch targets match the watcher's function-local imports, which are
        # resolved fresh on each loop iteration.
        with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
             patch("tools.mcp_tool.shutdown_mcp_servers", shutdown_mock), \
             patch("tools.mcp_tool.discover_mcp_tools", discover_mock), \
             patch("tools.mcp_tool._servers", servers_dict), \
             patch("tools.mcp_tool._lock", lock_mock):
            stop_task = asyncio.create_task(stop_after_reload())
            try:
                await asyncio.wait_for(
                    runner._mcp_config_watcher(interval=1, _initial_delay=0),
                    timeout=10.0,
                )
            except asyncio.TimeoutError:
                runner._running = False
            await stop_task
        shutdown_mock.assert_called_once()
        discover_mock.assert_called_once()
        assert runner._mcp_config_servers == new_servers