mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): make server aliases explicit
This commit is contained in:
parent
cda64a5961
commit
c10fea8d26
6 changed files with 133 additions and 36 deletions
|
|
@ -124,6 +124,7 @@ class TestValidateToolset:
|
||||||
schema=_make_schema("mcp_dynserver_ping", "Ping"),
|
schema=_make_schema("mcp_dynserver_ping", "Ping"),
|
||||||
handler=_dummy_handler,
|
handler=_dummy_handler,
|
||||||
)
|
)
|
||||||
|
reg.register_toolset_alias("dynserver", "mcp-dynserver")
|
||||||
|
|
||||||
monkeypatch.setattr("tools.registry.registry", reg)
|
monkeypatch.setattr("tools.registry.registry", reg)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,25 @@ class TestDeregister:
|
||||||
# bar still in ts1, so check should remain
|
# bar still in ts1, so check should remain
|
||||||
assert "ts1" in reg._toolset_checks
|
assert "ts1" in reg._toolset_checks
|
||||||
|
|
||||||
|
def test_removes_toolset_alias_when_last_tool_is_removed(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||||
|
reg.register_toolset_alias("srv", "mcp-srv")
|
||||||
|
|
||||||
|
reg.deregister("foo")
|
||||||
|
|
||||||
|
assert reg.get_toolset_alias_target("srv") is None
|
||||||
|
|
||||||
|
def test_preserves_toolset_alias_while_toolset_still_exists(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||||
|
reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||||
|
reg.register_toolset_alias("srv", "mcp-srv")
|
||||||
|
|
||||||
|
reg.deregister("foo")
|
||||||
|
|
||||||
|
assert reg.get_toolset_alias_target("srv") == "mcp-srv"
|
||||||
|
|
||||||
def test_noop_for_unknown_tool(self):
|
def test_noop_for_unknown_tool(self):
|
||||||
reg = ToolRegistry()
|
reg = ToolRegistry()
|
||||||
reg.deregister("nonexistent") # Should not raise
|
reg.deregister("nonexistent") # Should not raise
|
||||||
|
|
|
||||||
|
|
@ -184,11 +184,7 @@ class TestToolHandler:
|
||||||
def _patch_mcp_loop(self, coro_side_effect=None):
|
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||||
def fake_run(coro, timeout=30):
|
def fake_run(coro, timeout=30):
|
||||||
loop = asyncio.new_event_loop()
|
return asyncio.run(coro)
|
||||||
try:
|
|
||||||
return loop.run_until_complete(coro)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
if coro_side_effect:
|
if coro_side_effect:
|
||||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||||
|
|
@ -774,6 +770,42 @@ class TestShutdown:
|
||||||
assert len(_servers) == 0
|
assert len(_servers) == 0
|
||||||
mock_server.shutdown.assert_called_once()
|
mock_server.shutdown.assert_called_once()
|
||||||
|
|
||||||
|
def test_shutdown_deregisters_registered_tools(self):
|
||||||
|
"""shutdown_mcp_servers removes MCP tools and their raw alias."""
|
||||||
|
import tools.mcp_tool as mcp_mod
|
||||||
|
from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
|
||||||
|
from tools.registry import registry
|
||||||
|
from toolsets import resolve_toolset, validate_toolset
|
||||||
|
|
||||||
|
_servers.clear()
|
||||||
|
registry.register(
|
||||||
|
name="mcp_test_ping",
|
||||||
|
toolset="mcp-test",
|
||||||
|
schema={
|
||||||
|
"name": "mcp_test_ping",
|
||||||
|
"description": "Ping",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
handler=lambda *_args, **_kwargs: "{}",
|
||||||
|
)
|
||||||
|
registry.register_toolset_alias("test", "mcp-test")
|
||||||
|
|
||||||
|
server = MCPServerTask("test")
|
||||||
|
server._registered_tool_names = ["mcp_test_ping"]
|
||||||
|
_servers["test"] = server
|
||||||
|
|
||||||
|
mcp_mod._ensure_mcp_loop()
|
||||||
|
try:
|
||||||
|
assert validate_toolset("test") is True
|
||||||
|
assert "mcp_test_ping" in resolve_toolset("test")
|
||||||
|
shutdown_mcp_servers()
|
||||||
|
finally:
|
||||||
|
mcp_mod._mcp_loop = None
|
||||||
|
mcp_mod._mcp_thread = None
|
||||||
|
|
||||||
|
assert "mcp_test_ping" not in registry.get_all_tool_names()
|
||||||
|
assert validate_toolset("test") is False
|
||||||
|
|
||||||
def test_shutdown_handles_errors(self):
|
def test_shutdown_handles_errors(self):
|
||||||
"""shutdown_mcp_servers handles errors during close gracefully."""
|
"""shutdown_mcp_servers handles errors during close gracefully."""
|
||||||
import tools.mcp_tool as mcp_mod
|
import tools.mcp_tool as mcp_mod
|
||||||
|
|
@ -1177,7 +1209,11 @@ class TestConfigurableTimeouts:
|
||||||
try:
|
try:
|
||||||
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
||||||
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
||||||
mock_run.return_value = json.dumps({"result": "ok"})
|
def fake_run(coro, timeout=30):
|
||||||
|
coro.close()
|
||||||
|
return json.dumps({"result": "ok"})
|
||||||
|
|
||||||
|
mock_run.side_effect = fake_run
|
||||||
handler({})
|
handler({})
|
||||||
# Verify timeout=180 was passed
|
# Verify timeout=180 was passed
|
||||||
call_kwargs = mock_run.call_args
|
call_kwargs = mock_run.call_args
|
||||||
|
|
@ -1277,11 +1313,7 @@ class TestUtilityHandlers:
|
||||||
def _patch_mcp_loop(self):
|
def _patch_mcp_loop(self):
|
||||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||||
def fake_run(coro, timeout=30):
|
def fake_run(coro, timeout=30):
|
||||||
loop = asyncio.new_event_loop()
|
return asyncio.run(coro)
|
||||||
try:
|
|
||||||
return loop.run_until_complete(coro)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||||
|
|
||||||
# -- list_resources --
|
# -- list_resources --
|
||||||
|
|
@ -3048,6 +3080,7 @@ class TestSanitizeMcpNameComponent:
|
||||||
schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
|
schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
|
||||||
handler=lambda *_args, **_kwargs: "{}",
|
handler=lambda *_args, **_kwargs: "{}",
|
||||||
)
|
)
|
||||||
|
reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa")
|
||||||
|
|
||||||
with patch("tools.registry.registry", reg):
|
with patch("tools.registry.registry", reg):
|
||||||
assert validate_toolset("ai.exa/exa") is True
|
assert validate_toolset("ai.exa/exa") is True
|
||||||
|
|
|
||||||
|
|
@ -1138,6 +1138,8 @@ class MCPServerTask:
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""Signal the Task to exit and wait for clean resource teardown."""
|
"""Signal the Task to exit and wait for clean resource teardown."""
|
||||||
|
from tools.registry import registry
|
||||||
|
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
if self._task and not self._task.done():
|
if self._task and not self._task.done():
|
||||||
try:
|
try:
|
||||||
|
|
@ -1152,6 +1154,9 @@ class MCPServerTask:
|
||||||
await self._task
|
await self._task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||||
|
registry.deregister(tool_name)
|
||||||
|
self._registered_tool_names = []
|
||||||
self.session = None
|
self.session = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1916,6 +1921,9 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
||||||
)
|
)
|
||||||
registered_names.append(util_name)
|
registered_names.append(util_name)
|
||||||
|
|
||||||
|
if registered_names:
|
||||||
|
registry.register_toolset_alias(name, toolset_name)
|
||||||
|
|
||||||
return registered_names
|
return registered_names
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ class ToolRegistry:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._tools: Dict[str, ToolEntry] = {}
|
self._tools: Dict[str, ToolEntry] = {}
|
||||||
self._toolset_checks: Dict[str, Callable] = {}
|
self._toolset_checks: Dict[str, Callable] = {}
|
||||||
|
self._toolset_aliases: Dict[str, str] = {}
|
||||||
# MCP dynamic refresh can mutate the registry while other threads are
|
# MCP dynamic refresh can mutate the registry while other threads are
|
||||||
# reading tool metadata, so keep mutations serialized and readers on
|
# reading tool metadata, so keep mutations serialized and readers on
|
||||||
# stable snapshots.
|
# stable snapshots.
|
||||||
|
|
@ -96,6 +97,27 @@ class ToolRegistry:
|
||||||
if entry.toolset == toolset
|
if entry.toolset == toolset
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def register_toolset_alias(self, alias: str, toolset: str) -> None:
|
||||||
|
"""Register an explicit alias for a canonical toolset name."""
|
||||||
|
with self._lock:
|
||||||
|
existing = self._toolset_aliases.get(alias)
|
||||||
|
if existing and existing != toolset:
|
||||||
|
logger.warning(
|
||||||
|
"Toolset alias collision: '%s' (%s) overwritten by %s",
|
||||||
|
alias, existing, toolset,
|
||||||
|
)
|
||||||
|
self._toolset_aliases[alias] = toolset
|
||||||
|
|
||||||
|
def get_registered_toolset_aliases(self) -> Dict[str, str]:
|
||||||
|
"""Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
|
||||||
|
with self._lock:
|
||||||
|
return dict(self._toolset_aliases)
|
||||||
|
|
||||||
|
def get_toolset_alias_target(self, alias: str) -> Optional[str]:
|
||||||
|
"""Return the canonical toolset name for an alias, or None."""
|
||||||
|
with self._lock:
|
||||||
|
return self._toolset_aliases.get(alias)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Registration
|
# Registration
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -164,11 +186,18 @@ class ToolRegistry:
|
||||||
entry = self._tools.pop(name, None)
|
entry = self._tools.pop(name, None)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
return
|
return
|
||||||
# Drop the toolset check if this was the last tool in that toolset
|
# Drop the toolset check and aliases if this was the last tool in
|
||||||
if entry.toolset in self._toolset_checks and not any(
|
# that toolset.
|
||||||
|
toolset_still_exists = any(
|
||||||
e.toolset == entry.toolset for e in self._tools.values()
|
e.toolset == entry.toolset for e in self._tools.values()
|
||||||
):
|
)
|
||||||
|
if not toolset_still_exists:
|
||||||
self._toolset_checks.pop(entry.toolset, None)
|
self._toolset_checks.pop(entry.toolset, None)
|
||||||
|
self._toolset_aliases = {
|
||||||
|
alias: target
|
||||||
|
for alias, target in self._toolset_aliases.items()
|
||||||
|
if target != entry.toolset
|
||||||
|
}
|
||||||
logger.debug("Deregistered tool: %s", name)
|
logger.debug("Deregistered tool: %s", name)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
|
||||||
51
toolsets.py
51
toolsets.py
|
|
@ -420,14 +420,22 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
|
||||||
registry_toolset = name
|
registry_toolset = name
|
||||||
description = f"Plugin toolset: {name}"
|
description = f"Plugin toolset: {name}"
|
||||||
|
alias_target = registry.get_toolset_alias_target(name)
|
||||||
|
|
||||||
if name not in _get_plugin_toolset_names():
|
if name not in _get_plugin_toolset_names():
|
||||||
registry_toolset = _get_mcp_toolset_aliases().get(name)
|
registry_toolset = alias_target
|
||||||
if not registry_toolset:
|
if not registry_toolset:
|
||||||
return None
|
return None
|
||||||
description = f"MCP server '{name}' tools"
|
description = f"MCP server '{name}' tools"
|
||||||
elif name.startswith("mcp-"):
|
else:
|
||||||
description = f"MCP server '{name[4:]}' tools"
|
reverse_aliases = {
|
||||||
|
canonical: alias
|
||||||
|
for alias, canonical in _get_registry_toolset_aliases().items()
|
||||||
|
if alias not in TOOLSETS
|
||||||
|
}
|
||||||
|
alias = reverse_aliases.get(name)
|
||||||
|
if alias:
|
||||||
|
description = f"MCP server '{alias}' tools"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"description": description,
|
"description": description,
|
||||||
|
|
@ -525,16 +533,13 @@ def _get_plugin_toolset_names() -> Set[str]:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
|
||||||
def _get_mcp_toolset_aliases() -> Dict[str, str]:
|
def _get_registry_toolset_aliases() -> Dict[str, str]:
|
||||||
"""Map raw MCP server names to their live registry toolset names."""
|
"""Return explicit toolset aliases registered in the live registry."""
|
||||||
aliases = {}
|
try:
|
||||||
for toolset_name in _get_plugin_toolset_names():
|
from tools.registry import registry
|
||||||
if not toolset_name.startswith("mcp-"):
|
return registry.get_registered_toolset_aliases()
|
||||||
continue
|
except Exception:
|
||||||
alias = toolset_name[4:]
|
return {}
|
||||||
if alias and alias not in TOOLSETS:
|
|
||||||
aliases[alias] = toolset_name
|
|
||||||
return aliases
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||||
|
|
@ -547,12 +552,13 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||||
Dict: All toolset definitions
|
Dict: All toolset definitions
|
||||||
"""
|
"""
|
||||||
result = dict(TOOLSETS)
|
result = dict(TOOLSETS)
|
||||||
|
aliases = _get_registry_toolset_aliases()
|
||||||
for ts_name in _get_plugin_toolset_names():
|
for ts_name in _get_plugin_toolset_names():
|
||||||
display_name = ts_name
|
display_name = ts_name
|
||||||
if ts_name.startswith("mcp-"):
|
for alias, canonical in aliases.items():
|
||||||
alias = ts_name[4:]
|
if canonical == ts_name and alias not in TOOLSETS:
|
||||||
if alias and alias not in TOOLSETS:
|
|
||||||
display_name = alias
|
display_name = alias
|
||||||
|
break
|
||||||
if display_name in result:
|
if display_name in result:
|
||||||
continue
|
continue
|
||||||
toolset = get_toolset(display_name)
|
toolset = get_toolset(display_name)
|
||||||
|
|
@ -571,13 +577,14 @@ def get_toolset_names() -> List[str]:
|
||||||
List[str]: List of toolset names
|
List[str]: List of toolset names
|
||||||
"""
|
"""
|
||||||
names = set(TOOLSETS.keys())
|
names = set(TOOLSETS.keys())
|
||||||
|
aliases = _get_registry_toolset_aliases()
|
||||||
for ts_name in _get_plugin_toolset_names():
|
for ts_name in _get_plugin_toolset_names():
|
||||||
if ts_name.startswith("mcp-"):
|
for alias, canonical in aliases.items():
|
||||||
alias = ts_name[4:]
|
if canonical == ts_name and alias not in TOOLSETS:
|
||||||
if alias and alias not in TOOLSETS:
|
|
||||||
names.add(alias)
|
names.add(alias)
|
||||||
continue
|
break
|
||||||
names.add(ts_name)
|
else:
|
||||||
|
names.add(ts_name)
|
||||||
return sorted(names)
|
return sorted(names)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -600,7 +607,7 @@ def validate_toolset(name: str) -> bool:
|
||||||
return True
|
return True
|
||||||
if name in _get_plugin_toolset_names():
|
if name in _get_plugin_toolset_names():
|
||||||
return True
|
return True
|
||||||
return name in _get_mcp_toolset_aliases()
|
return name in _get_registry_toolset_aliases()
|
||||||
|
|
||||||
|
|
||||||
def create_custom_toolset(
|
def create_custom_toolset(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue