mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix: keep mcp dynamic refresh tasks tracked
This commit is contained in:
parent
02ae152222
commit
1350d12b0b
2 changed files with 177 additions and 20 deletions
|
|
@ -706,6 +706,106 @@ class TestMCPServerTask:
|
||||||
|
|
||||||
asyncio.run(_test())
|
asyncio.run(_test())
|
||||||
|
|
||||||
|
def test_refresh_tools_deregisters_removed_tools(self):
    """Refreshing drops registry entries for tools the server no longer lists.

    The server starts out owning ``mcp_srv_old`` and ``mcp_srv_keep``.  The
    stubbed ``list_tools`` reply advertises ``keep`` and ``new``, so after
    the refresh ``mcp_srv_old`` must be gone from the registry while
    ``mcp_srv_keep`` and ``mcp_srv_new`` are present.
    """
    from tools.registry import ToolRegistry
    from tools.mcp_tool import MCPServerTask

    mock_registry = ToolRegistry()

    server = MCPServerTask("srv")
    server._config = {"command": "test"}
    server._tools = [_make_mcp_tool(tool) for tool in ("old", "keep")]
    server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]

    # The fresh listing no longer contains "old" and introduces "new".
    refreshed_listing = SimpleNamespace(
        tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")]
    )
    server.session = MagicMock()
    server.session.list_tools = AsyncMock(return_value=refreshed_listing)

    with patch("tools.registry.registry", mock_registry):
        # Seed the registry with the entries the server currently claims.
        for tool_name, description in (
            ("mcp_srv_old", "Old"),
            ("mcp_srv_keep", "Keep"),
        ):
            mock_registry.register(
                name=tool_name,
                toolset="mcp-srv",
                schema={"name": tool_name, "description": description},
                handler=lambda *_args, **_kwargs: "{}",
            )

        asyncio.run(server._refresh_tools())

    names = mock_registry.get_all_tool_names()
    assert "mcp_srv_old" not in names
    assert "mcp_srv_keep" in names
    assert "mcp_srv_new" in names

    # Besides the two server tools, the refresh is expected to track the
    # resource/prompt helper entries listed here.
    expected_tracked = {
        "mcp_srv_keep",
        "mcp_srv_new",
        "mcp_srv_list_resources",
        "mcp_srv_read_resource",
        "mcp_srv_list_prompts",
        "mcp_srv_get_prompt",
    }
    assert set(server._registered_tool_names) == expected_tracked
|
||||||
|
|
||||||
|
def test_schedule_tools_refresh_keeps_task_until_done(self):
    """Background refresh tasks are strongly referenced and then discarded.

    While the patched refresh is blocked, exactly one unfinished task must
    be held in ``_pending_refresh_tasks``; once it completes, the set must
    be empty again.
    """
    from tools.mcp_tool import MCPServerTask

    # Upper bound for each wait so a regression fails fast instead of
    # hanging the whole test run.
    wait_timeout = 5

    async def _test():
        started = asyncio.Event()
        finish = asyncio.Event()
        server = MCPServerTask("srv")

        async def fake_refresh(_server):
            # Signal that the background task is running, then block until
            # the test releases it.
            started.set()
            await finish.wait()

        with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
            server._schedule_tools_refresh()

            # If scheduling is broken the refresh never starts; time out
            # rather than waiting forever.
            await asyncio.wait_for(started.wait(), timeout=wait_timeout)
            assert len(server._pending_refresh_tasks) == 1
            task = next(iter(server._pending_refresh_tasks))
            assert not task.done()

            finish.set()
            await asyncio.wait_for(task, timeout=wait_timeout)
            # Yield once so a done callback still queued on the loop runs
            # before the bookkeeping set is inspected.
            await asyncio.sleep(0)
            assert server._pending_refresh_tasks == set()

    asyncio.run(_test())
|
||||||
|
|
||||||
|
def test_shutdown_cancels_pending_refresh_tasks(self):
    """shutdown() cancels in-flight background refresh tasks.

    A patched refresh that would otherwise sleep for an hour records when
    it is cancelled; after ``shutdown()`` the cancellation must have been
    observed and ``_pending_refresh_tasks`` must be empty.
    """
    from tools.mcp_tool import MCPServerTask

    async def _test():
        started = asyncio.Event()
        cancelled = asyncio.Event()
        server = MCPServerTask("srv")

        async def fake_refresh(_server):
            started.set()
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                # Record the cancellation, then let it propagate so the
                # task really ends up cancelled.
                cancelled.set()
                raise

        with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
            server._schedule_tools_refresh()
            # Bounded wait: if the refresh never starts, fail quickly
            # instead of hanging the test run.
            await asyncio.wait_for(started.wait(), timeout=5)

        await server.shutdown()

        assert cancelled.is_set()
        assert server._pending_refresh_tasks == set()

    asyncio.run(_test())
|
||||||
|
|
||||||
def test_empty_env_gets_safe_defaults(self):
|
def test_empty_env_gets_safe_defaults(self):
|
||||||
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
|
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
|
||||||
from tools.mcp_tool import MCPServerTask
|
from tools.mcp_tool import MCPServerTask
|
||||||
|
|
@ -1993,7 +2093,13 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ToolUseContent = _CompatType
|
ToolUseContent = _CompatType
|
||||||
|
|
||||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
from tools.mcp_tool import (
|
||||||
|
CreateMessageResultWithTools,
|
||||||
|
SamplingHandler,
|
||||||
|
SamplingToolsCapability,
|
||||||
|
ToolUseContent,
|
||||||
|
_safe_numeric,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -167,10 +167,22 @@ _MCP_HTTP_AVAILABLE = False
|
||||||
_MCP_SAMPLING_TYPES = False
|
_MCP_SAMPLING_TYPES = False
|
||||||
_MCP_NOTIFICATION_TYPES = False
|
_MCP_NOTIFICATION_TYPES = False
|
||||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||||
|
_MCP_NEW_HTTP = False
|
||||||
|
streamablehttp_client = None
|
||||||
|
streamable_http_client = None
|
||||||
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
|
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
|
||||||
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
|
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
|
||||||
# HTTP transport path even on older-but-supported SDK versions.
|
# HTTP transport path even on older-but-supported SDK versions.
|
||||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||||
|
|
||||||
|
|
||||||
|
class _CompatType:
    """Minimal attribute bag for MCP SDK types missing in older/newer builds.

    Bound in place of an optional ``mcp.types`` name when that import
    fails, so code that builds such a type from keyword arguments and reads
    the fields back keeps working.  Every fallback alias points at this one
    class, so ``isinstance`` checks cannot tell the aliases apart.
    """

    def __init__(self, **kwargs):
        """Store each keyword argument as an instance attribute."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return ``_CompatType(key=value, ...)`` for readable logs and failures."""
        fields = ", ".join(
            f"{key}={value!r}" for key, value in vars(self).items()
        )
        return f"{type(self).__name__}({fields})"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
@ -191,20 +203,28 @@ try:
|
||||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
|
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
|
||||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
# Sampling types -- import individually because SDK names changed across releases.
|
||||||
try:
|
try:
|
||||||
from mcp.types import (
|
from mcp.types import CreateMessageResult, ErrorData, SamplingCapability, TextContent
|
||||||
CreateMessageResult,
|
|
||||||
CreateMessageResultWithTools,
|
try:
|
||||||
ErrorData,
|
from mcp.types import CreateMessageResultWithTools
|
||||||
SamplingCapability,
|
except ImportError:
|
||||||
SamplingToolsCapability,
|
CreateMessageResultWithTools = _CompatType
|
||||||
TextContent,
|
|
||||||
ToolUseContent,
|
try:
|
||||||
)
|
from mcp.types import SamplingToolsCapability
|
||||||
|
except ImportError:
|
||||||
|
SamplingToolsCapability = _CompatType
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp.types import ToolUseContent
|
||||||
|
except ImportError:
|
||||||
|
ToolUseContent = _CompatType
|
||||||
|
|
||||||
_MCP_SAMPLING_TYPES = True
|
_MCP_SAMPLING_TYPES = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
logger.debug("MCP sampling base types not available -- sampling disabled")
|
||||||
# Notification types for dynamic tool discovery (tools/list_changed)
|
# Notification types for dynamic tool discovery (tools/list_changed)
|
||||||
try:
|
try:
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
|
@ -868,7 +888,7 @@ class MCPServerTask:
|
||||||
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
||||||
"_tools", "_error", "_config",
|
"_tools", "_error", "_config",
|
||||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||||
"_rpc_lock",
|
"_rpc_lock", "_pending_refresh_tasks",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
|
|
@ -895,8 +915,10 @@ class MCPServerTask:
|
||||||
# list_changed notifications during startup; if the notification
|
# list_changed notifications during startup; if the notification
|
||||||
# handler calls list_tools while a normal tool call is in flight, the
|
# handler calls list_tools while a normal tool call is in flight, the
|
||||||
# stream can wedge and the user-visible tool call times out. Serialize
|
# stream can wedge and the user-visible tool call times out. Serialize
|
||||||
# client-initiated RPCs per server.
|
# client-initiated RPCs per server. The lock is also applied to HTTP
|
||||||
|
# transports for conservative per-server ordering.
|
||||||
self._rpc_lock = asyncio.Lock()
|
self._rpc_lock = asyncio.Lock()
|
||||||
|
self._pending_refresh_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
def _is_http(self) -> bool:
|
def _is_http(self) -> bool:
|
||||||
"""Check if this server uses HTTP transport."""
|
"""Check if this server uses HTTP transport."""
|
||||||
|
|
@ -904,6 +926,21 @@ class MCPServerTask:
|
||||||
|
|
||||||
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
||||||
|
|
||||||
|
async def _refresh_tools_task(self):
    """Background-task wrapper around ``_refresh_tools`` that logs failures.

    Ordinary exceptions are logged with a traceback and swallowed, so a
    failed refresh does not surface as an unretrieved task exception.
    Cancellation is not swallowed and still propagates to the caller.
    """
    try:
        await self._refresh_tools()
    except Exception:
        # asyncio.CancelledError derives from BaseException, so it is not
        # caught here and cancellation passes straight through.
        logger.exception("MCP server '%s': dynamic tool refresh failed", self.name)
|
||||||
|
|
||||||
|
def _schedule_tools_refresh(self) -> None:
    """Start a tool refresh in the background without losing the task.

    The new task is stored in ``self._pending_refresh_tasks`` while it
    runs, and a done callback drops it from that set once it finishes.
    Must be called with an event loop running, as ``asyncio.create_task``
    requires.
    """
    refresh = asyncio.create_task(self._refresh_tools_task())
    # The event loop only keeps weak references to tasks, so hold a strong
    # one here until the task completes.
    pending = self._pending_refresh_tasks
    pending.add(refresh)
    refresh.add_done_callback(pending.discard)
|
||||||
|
|
||||||
def _make_message_handler(self):
|
def _make_message_handler(self):
|
||||||
"""Build a ``message_handler`` callback for ``ClientSession``.
|
"""Build a ``message_handler`` callback for ``ClientSession``.
|
||||||
|
|
||||||
|
|
@ -932,7 +969,7 @@ class MCPServerTask:
|
||||||
# subsequent tool calls time out. Do the refresh in
|
# subsequent tool calls time out. Do the refresh in
|
||||||
# a separate task and let the handler return
|
# a separate task and let the handler return
|
||||||
# promptly.
|
# promptly.
|
||||||
asyncio.create_task(self._refresh_tools())
|
self._schedule_tools_refresh()
|
||||||
case PromptListChangedNotification():
|
case PromptListChangedNotification():
|
||||||
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
||||||
case ResourceListChangedNotification():
|
case ResourceListChangedNotification():
|
||||||
|
|
@ -962,11 +999,20 @@ class MCPServerTask:
|
||||||
tools_result = await self.session.list_tools()
|
tools_result = await self.session.list_tools()
|
||||||
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||||
|
|
||||||
# 2. Re-register with fresh tool list. Avoid deregistering first:
|
# 2. Re-register with fresh tool list. Avoid nuke-and-repave for
|
||||||
# live agent turns already have tool-call IDs pointing at the
|
# all names: live agent turns may already have tool-call IDs
|
||||||
# existing handler functions. Replacing entries in-place is enough
|
# pointing at existing handler functions. Replacing entries
|
||||||
# for unchanged names and avoids transient "tool not connected" /
|
# in-place is enough for unchanged names and avoids transient
|
||||||
# stale-handler races during startup notifications.
|
# "tool not connected" / stale-handler races during startup
|
||||||
|
# notifications. Tools absent from the fresh list are no longer
|
||||||
|
# callable, so remove only those stale registry entries first.
|
||||||
|
stale_tool_names = old_tool_names - {
|
||||||
|
f"mcp_{sanitize_mcp_name_component(self.name)}_"
|
||||||
|
f"{sanitize_mcp_name_component(tool.name)}"
|
||||||
|
for tool in new_mcp_tools
|
||||||
|
}
|
||||||
|
for tool_name in stale_tool_names:
|
||||||
|
registry.deregister(tool_name)
|
||||||
|
|
||||||
# 3. Re-register with fresh tool list
|
# 3. Re-register with fresh tool list
|
||||||
self._tools = new_mcp_tools
|
self._tools = new_mcp_tools
|
||||||
|
|
@ -1383,6 +1429,11 @@ class MCPServerTask:
|
||||||
await self._task
|
await self._task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
if self._pending_refresh_tasks:
|
||||||
|
for task in list(self._pending_refresh_tasks):
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*self._pending_refresh_tasks, return_exceptions=True)
|
||||||
|
self._pending_refresh_tasks.clear()
|
||||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||||
registry.deregister(tool_name)
|
registry.deregister(tool_name)
|
||||||
self._registered_tool_names = []
|
self._registered_tool_names = []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue