mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(gateway): auto-reload MCP connections on config change
Add a background _mcp_config_watcher() task to the gateway that polls config.yaml every 30 seconds and auto-reloads MCP server connections when the mcp_servers section changes. This solves the problem where OAuth token refresh cron jobs update Bearer tokens in config.yaml, but the running gateway keeps using stale cached credentials until manually restarted. The CLI already had this via _check_config_mcp_changes() — this ports the same concept to the async gateway event loop. Changes: - GatewayRunner.__init__: initialize _mcp_config_mtime and _mcp_config_servers state from current config - GatewayRunner.start(): launch _mcp_config_watcher as background task - GatewayRunner._mcp_config_watcher(): async background task that: - Uses mtime fast-path to avoid unnecessary YAML reads - Deep-compares mcp_servers dict to detect header changes - Runs shutdown/discover in executor to avoid blocking event loop - Sleeps in 1s increments for responsive shutdown - 30s initial delay to let startup finish Tests: 7 new tests covering no-change skip, header change detection, non-MCP change skip, server add/remove, shutdown behavior, and full integration test.
This commit is contained in:
parent
4cadfef8e3
commit
8cf7e80fc5
2 changed files with 330 additions and 1 deletions
114
gateway/run.py
114
gateway/run.py
|
|
@ -631,7 +631,18 @@ class GatewayRunner:
|
||||||
# Track background tasks to prevent garbage collection mid-execution
|
# Track background tasks to prevent garbage collection mid-execution
|
||||||
self._background_tasks: set = set()
|
self._background_tasks: set = set()
|
||||||
|
|
||||||
|
# MCP config watcher state — detect header changes (e.g. OAuth token refresh)
|
||||||
|
self._mcp_config_mtime: float = 0.0
|
||||||
|
self._mcp_config_servers: dict = {}
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import get_config_path, load_config
|
||||||
|
cfg_path = get_config_path()
|
||||||
|
if cfg_path.exists():
|
||||||
|
self._mcp_config_mtime = cfg_path.stat().st_mtime
|
||||||
|
cfg = load_config()
|
||||||
|
self._mcp_config_servers = cfg.get("mcp_servers") or {}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# -- Setup skill availability ----------------------------------------
|
# -- Setup skill availability ----------------------------------------
|
||||||
|
|
@ -1690,10 +1701,111 @@ class GatewayRunner:
|
||||||
)
|
)
|
||||||
asyncio.create_task(self._platform_reconnect_watcher())
|
asyncio.create_task(self._platform_reconnect_watcher())
|
||||||
|
|
||||||
|
# Start background MCP config watcher for auto-reloading on token refresh
|
||||||
|
asyncio.create_task(self._mcp_config_watcher())
|
||||||
|
|
||||||
logger.info("Press Ctrl+C to stop")
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _mcp_config_watcher(self, interval: int = 30, _initial_delay: int = 30) -> None:
|
||||||
|
"""Background task that detects MCP config changes and auto-reloads connections.
|
||||||
|
|
||||||
|
Polls config.yaml every ``interval`` seconds. When the ``mcp_servers``
|
||||||
|
section changes (e.g. OAuth token refresh updates the Authorization
|
||||||
|
header), triggers a full MCP shutdown + reconnect so the running
|
||||||
|
gateway picks up new credentials without a restart.
|
||||||
|
|
||||||
|
Mirrors the CLI's ``_check_config_mcp_changes`` but adapted for the
|
||||||
|
async gateway event loop.
|
||||||
|
"""
|
||||||
|
# Initial delay — let startup finish. Sleep in 1s increments for quick shutdown.
|
||||||
|
for _ in range(_initial_delay):
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
logger.info("MCP config watcher started (checking every %ds)", interval)
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import get_config_path
|
||||||
|
import yaml as _yaml
|
||||||
|
|
||||||
|
cfg_path = get_config_path()
|
||||||
|
if not cfg_path.exists():
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
mtime = cfg_path.stat().st_mtime
|
||||||
|
except OSError:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtime == self._mcp_config_mtime:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# File changed — read and compare mcp_servers section
|
||||||
|
self._mcp_config_mtime = mtime
|
||||||
|
try:
|
||||||
|
with open(cfg_path, encoding="utf-8") as f:
|
||||||
|
new_cfg = _yaml.safe_load(f) or {}
|
||||||
|
except Exception:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp == self._mcp_config_servers:
|
||||||
|
# Some other config section changed, not MCP
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._mcp_config_servers = new_mcp
|
||||||
|
logger.info("MCP config change detected — reloading connections...")
|
||||||
|
|
||||||
|
# Perform the reload in a thread to avoid blocking the event loop
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _servers, _lock
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
old_servers = set(_servers.keys())
|
||||||
|
|
||||||
|
await loop.run_in_executor(None, shutdown_mcp_servers)
|
||||||
|
new_tools = await loop.run_in_executor(None, discover_mcp_tools)
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
connected_servers = set(_servers.keys())
|
||||||
|
|
||||||
|
added = connected_servers - old_servers
|
||||||
|
removed = old_servers - connected_servers
|
||||||
|
reconnected = connected_servers & old_servers
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if reconnected:
|
||||||
|
parts.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}")
|
||||||
|
if added:
|
||||||
|
parts.append(f"➕ Added: {', '.join(sorted(added))}")
|
||||||
|
if removed:
|
||||||
|
parts.append(f"➖ Removed: {', '.join(sorted(removed))}")
|
||||||
|
parts.append(
|
||||||
|
f"🔧 {len(new_tools)} tool(s) from {len(connected_servers)} server(s)"
|
||||||
|
)
|
||||||
|
logger.info("MCP auto-reload complete: %s", "; ".join(parts))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("MCP auto-reload failed: %s", e)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("MCP config watcher error: %s", e)
|
||||||
|
|
||||||
|
# Sleep in 1-second increments so we respond quickly to shutdown
|
||||||
|
for _ in range(interval):
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def _session_expiry_watcher(self, interval: int = 300):
|
async def _session_expiry_watcher(self, interval: int = 300):
|
||||||
"""Background task that proactively flushes memories for expired sessions.
|
"""Background task that proactively flushes memories for expired sessions.
|
||||||
|
|
||||||
|
|
|
||||||
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
217
tests/gateway/test_gateway_mcp_config_watcher.py
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
"""Tests for gateway MCP config watcher — auto-reload on mcp_servers changes."""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(tmp_path, mcp_servers=None):
|
||||||
|
"""Create a minimal GatewayRunner with mocked MCP config watcher state."""
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner._running = True
|
||||||
|
runner._mcp_config_servers = mcp_servers or {}
|
||||||
|
|
||||||
|
cfg_file = tmp_path / "config.yaml"
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": mcp_servers or {}}))
|
||||||
|
runner._mcp_config_mtime = cfg_file.stat().st_mtime
|
||||||
|
|
||||||
|
return runner, cfg_file
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPConfigWatcher:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_change_does_not_reload(self, tmp_path):
|
||||||
|
"""If config file hasn't changed, no MCP reload should happen."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
|
||||||
|
})
|
||||||
|
|
||||||
|
reload_called = False
|
||||||
|
|
||||||
|
async def fake_watcher_iteration():
|
||||||
|
nonlocal reload_called
|
||||||
|
from hermes_cli.config import get_config_path
|
||||||
|
import yaml as _yaml
|
||||||
|
|
||||||
|
cfg_path = cfg_file
|
||||||
|
mtime = cfg_path.stat().st_mtime
|
||||||
|
|
||||||
|
if mtime == runner._mcp_config_mtime:
|
||||||
|
return # No change — fast path
|
||||||
|
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_path, encoding="utf-8") as f:
|
||||||
|
new_cfg = _yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp == runner._mcp_config_servers:
|
||||||
|
return
|
||||||
|
|
||||||
|
reload_called = True
|
||||||
|
|
||||||
|
await fake_watcher_iteration()
|
||||||
|
assert not reload_called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_header_change_triggers_reload(self, tmp_path):
|
||||||
|
"""When Authorization header changes, reload should be triggered."""
|
||||||
|
old_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old_token"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
|
||||||
|
|
||||||
|
# Simulate token refresh updating the config
|
||||||
|
new_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new_token"}}
|
||||||
|
}
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
|
||||||
|
# Force mtime to look different
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
reload_triggered = False
|
||||||
|
|
||||||
|
# Simulate one iteration of the watcher's core logic
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
assert mtime != runner._mcp_config_mtime
|
||||||
|
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp != runner._mcp_config_servers:
|
||||||
|
reload_triggered = True
|
||||||
|
runner._mcp_config_servers = new_mcp
|
||||||
|
|
||||||
|
assert reload_triggered
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_mcp_change_does_not_reload(self, tmp_path):
|
||||||
|
"""If a non-MCP section changes but mcp_servers stays the same, no reload."""
|
||||||
|
servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer tok"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=servers)
|
||||||
|
|
||||||
|
# Write same mcp_servers but change something else
|
||||||
|
cfg_file.write_text(yaml.dump({
|
||||||
|
"mcp_servers": servers,
|
||||||
|
"some_other_setting": "changed"
|
||||||
|
}))
|
||||||
|
runner._mcp_config_mtime = 0.0 # force stale mtime
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp == runner._mcp_config_servers # Should be unchanged
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_added_triggers_reload(self, tmp_path):
|
||||||
|
"""Adding a new MCP server to config triggers reload."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={})
|
||||||
|
|
||||||
|
new_servers = {"github": {"url": "https://api.github.com/mcp"}}
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp != runner._mcp_config_servers
|
||||||
|
runner._mcp_config_servers = new_mcp
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_removed_triggers_reload(self, tmp_path):
|
||||||
|
"""Removing an MCP server from config triggers reload."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers={
|
||||||
|
"github": {"url": "https://api.github.com/mcp"}
|
||||||
|
})
|
||||||
|
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
|
||||||
|
runner._mcp_config_mtime = 0.0
|
||||||
|
|
||||||
|
mtime = cfg_file.stat().st_mtime
|
||||||
|
runner._mcp_config_mtime = mtime
|
||||||
|
with open(cfg_file, encoding="utf-8") as f:
|
||||||
|
new_cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
assert new_mcp != runner._mcp_config_servers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_watcher_stops_on_shutdown(self, tmp_path):
|
||||||
|
"""Watcher loop exits when _running is set to False."""
|
||||||
|
runner, cfg_file = _make_runner(tmp_path)
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
# The watcher should return almost immediately
|
||||||
|
# We test it doesn't hang by using a timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
runner._mcp_config_watcher(interval=1, _initial_delay=0),
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pytest.fail("_mcp_config_watcher did not exit after _running=False")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_watcher_detects_change_and_reloads(self, tmp_path):
|
||||||
|
"""Integration test: watcher detects a header change and calls MCP reload."""
|
||||||
|
old_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer old"}}
|
||||||
|
}
|
||||||
|
runner, cfg_file = _make_runner(tmp_path, mcp_servers=old_servers)
|
||||||
|
|
||||||
|
# Prepare the config change that will happen during the watcher run
|
||||||
|
new_servers = {
|
||||||
|
"betterstack": {"url": "https://mcp.betterstack.com", "headers": {"Authorization": "Bearer new"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdown_mock = MagicMock()
|
||||||
|
discover_mock = MagicMock(return_value=[{"function": {"name": "test_tool"}}])
|
||||||
|
servers_dict = {"betterstack": MagicMock()}
|
||||||
|
lock_mock = MagicMock()
|
||||||
|
|
||||||
|
async def stop_after_reload():
|
||||||
|
"""Write the config change, wait for the watcher to pick it up, then stop."""
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": new_servers}))
|
||||||
|
# Wait enough time for the watcher to detect + reload
|
||||||
|
await asyncio.sleep(4)
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
|
||||||
|
patch("tools.mcp_tool.shutdown_mcp_servers", shutdown_mock), \
|
||||||
|
patch("tools.mcp_tool.discover_mcp_tools", discover_mock), \
|
||||||
|
patch("tools.mcp_tool._servers", servers_dict), \
|
||||||
|
patch("tools.mcp_tool._lock", lock_mock):
|
||||||
|
|
||||||
|
stop_task = asyncio.create_task(stop_after_reload())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
runner._mcp_config_watcher(interval=1, _initial_delay=0),
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
runner._running = False
|
||||||
|
|
||||||
|
await stop_task
|
||||||
|
|
||||||
|
shutdown_mock.assert_called_once()
|
||||||
|
discover_mock.assert_called_once()
|
||||||
|
assert runner._mcp_config_servers == new_servers
|
||||||
Loading…
Add table
Add a link
Reference in a new issue