mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): make server aliases explicit
This commit is contained in:
parent
cda64a5961
commit
c10fea8d26
6 changed files with 133 additions and 36 deletions
|
|
@ -124,6 +124,7 @@ class TestValidateToolset:
|
|||
schema=_make_schema("mcp_dynserver_ping", "Ping"),
|
||||
handler=_dummy_handler,
|
||||
)
|
||||
reg.register_toolset_alias("dynserver", "mcp-dynserver")
|
||||
|
||||
monkeypatch.setattr("tools.registry.registry", reg)
|
||||
|
||||
|
|
|
|||
|
|
@ -136,6 +136,25 @@ class TestDeregister:
|
|||
# bar still in ts1, so check should remain
|
||||
assert "ts1" in reg._toolset_checks
|
||||
|
||||
def test_removes_toolset_alias_when_last_tool_is_removed(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register_toolset_alias("srv", "mcp-srv")
|
||||
|
||||
reg.deregister("foo")
|
||||
|
||||
assert reg.get_toolset_alias_target("srv") is None
|
||||
|
||||
def test_preserves_toolset_alias_while_toolset_still_exists(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x)
|
||||
reg.register_toolset_alias("srv", "mcp-srv")
|
||||
|
||||
reg.deregister("foo")
|
||||
|
||||
assert reg.get_toolset_alias_target("srv") == "mcp-srv"
|
||||
|
||||
def test_noop_for_unknown_tool(self):
|
||||
reg = ToolRegistry()
|
||||
reg.deregister("nonexistent") # Should not raise
|
||||
|
|
|
|||
|
|
@ -184,11 +184,7 @@ class TestToolHandler:
|
|||
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
return asyncio.run(coro)
|
||||
if coro_side_effect:
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
|
@ -774,6 +770,42 @@ class TestShutdown:
|
|||
assert len(_servers) == 0
|
||||
mock_server.shutdown.assert_called_once()
|
||||
|
||||
def test_shutdown_deregisters_registered_tools(self):
|
||||
"""shutdown_mcp_servers removes MCP tools and their raw alias."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
|
||||
from tools.registry import registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
|
||||
_servers.clear()
|
||||
registry.register(
|
||||
name="mcp_test_ping",
|
||||
toolset="mcp-test",
|
||||
schema={
|
||||
"name": "mcp_test_ping",
|
||||
"description": "Ping",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
handler=lambda *_args, **_kwargs: "{}",
|
||||
)
|
||||
registry.register_toolset_alias("test", "mcp-test")
|
||||
|
||||
server = MCPServerTask("test")
|
||||
server._registered_tool_names = ["mcp_test_ping"]
|
||||
_servers["test"] = server
|
||||
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
try:
|
||||
assert validate_toolset("test") is True
|
||||
assert "mcp_test_ping" in resolve_toolset("test")
|
||||
shutdown_mcp_servers()
|
||||
finally:
|
||||
mcp_mod._mcp_loop = None
|
||||
mcp_mod._mcp_thread = None
|
||||
|
||||
assert "mcp_test_ping" not in registry.get_all_tool_names()
|
||||
assert validate_toolset("test") is False
|
||||
|
||||
def test_shutdown_handles_errors(self):
|
||||
"""shutdown_mcp_servers handles errors during close gracefully."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
|
|
@ -1177,7 +1209,11 @@ class TestConfigurableTimeouts:
|
|||
try:
|
||||
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
||||
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
||||
mock_run.return_value = json.dumps({"result": "ok"})
|
||||
def fake_run(coro, timeout=30):
|
||||
coro.close()
|
||||
return json.dumps({"result": "ok"})
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
handler({})
|
||||
# Verify timeout=180 was passed
|
||||
call_kwargs = mock_run.call_args
|
||||
|
|
@ -1277,11 +1313,7 @@ class TestUtilityHandlers:
|
|||
def _patch_mcp_loop(self):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
return asyncio.run(coro)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
# -- list_resources --
|
||||
|
|
@ -3048,6 +3080,7 @@ class TestSanitizeMcpNameComponent:
|
|||
schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda *_args, **_kwargs: "{}",
|
||||
)
|
||||
reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa")
|
||||
|
||||
with patch("tools.registry.registry", reg):
|
||||
assert validate_toolset("ai.exa/exa") is True
|
||||
|
|
|
|||
|
|
@ -1138,6 +1138,8 @@ class MCPServerTask:
|
|||
|
||||
async def shutdown(self):
|
||||
"""Signal the Task to exit and wait for clean resource teardown."""
|
||||
from tools.registry import registry
|
||||
|
||||
self._shutdown_event.set()
|
||||
if self._task and not self._task.done():
|
||||
try:
|
||||
|
|
@ -1152,6 +1154,9 @@ class MCPServerTask:
|
|||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||
registry.deregister(tool_name)
|
||||
self._registered_tool_names = []
|
||||
self.session = None
|
||||
|
||||
|
||||
|
|
@ -1916,6 +1921,9 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
|||
)
|
||||
registered_names.append(util_name)
|
||||
|
||||
if registered_names:
|
||||
registry.register_toolset_alias(name, toolset_name)
|
||||
|
||||
return registered_names
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class ToolRegistry:
|
|||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
self._toolset_aliases: Dict[str, str] = {}
|
||||
# MCP dynamic refresh can mutate the registry while other threads are
|
||||
# reading tool metadata, so keep mutations serialized and readers on
|
||||
# stable snapshots.
|
||||
|
|
@ -96,6 +97,27 @@ class ToolRegistry:
|
|||
if entry.toolset == toolset
|
||||
)
|
||||
|
||||
def register_toolset_alias(self, alias: str, toolset: str) -> None:
|
||||
"""Register an explicit alias for a canonical toolset name."""
|
||||
with self._lock:
|
||||
existing = self._toolset_aliases.get(alias)
|
||||
if existing and existing != toolset:
|
||||
logger.warning(
|
||||
"Toolset alias collision: '%s' (%s) overwritten by %s",
|
||||
alias, existing, toolset,
|
||||
)
|
||||
self._toolset_aliases[alias] = toolset
|
||||
|
||||
def get_registered_toolset_aliases(self) -> Dict[str, str]:
|
||||
"""Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
|
||||
with self._lock:
|
||||
return dict(self._toolset_aliases)
|
||||
|
||||
def get_toolset_alias_target(self, alias: str) -> Optional[str]:
|
||||
"""Return the canonical toolset name for an alias, or None."""
|
||||
with self._lock:
|
||||
return self._toolset_aliases.get(alias)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -164,11 +186,18 @@ class ToolRegistry:
|
|||
entry = self._tools.pop(name, None)
|
||||
if entry is None:
|
||||
return
|
||||
# Drop the toolset check if this was the last tool in that toolset
|
||||
if entry.toolset in self._toolset_checks and not any(
|
||||
# Drop the toolset check and aliases if this was the last tool in
|
||||
# that toolset.
|
||||
toolset_still_exists = any(
|
||||
e.toolset == entry.toolset for e in self._tools.values()
|
||||
):
|
||||
)
|
||||
if not toolset_still_exists:
|
||||
self._toolset_checks.pop(entry.toolset, None)
|
||||
self._toolset_aliases = {
|
||||
alias: target
|
||||
for alias, target in self._toolset_aliases.items()
|
||||
if target != entry.toolset
|
||||
}
|
||||
logger.debug("Deregistered tool: %s", name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
51
toolsets.py
51
toolsets.py
|
|
@ -420,14 +420,22 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]:
|
|||
|
||||
registry_toolset = name
|
||||
description = f"Plugin toolset: {name}"
|
||||
alias_target = registry.get_toolset_alias_target(name)
|
||||
|
||||
if name not in _get_plugin_toolset_names():
|
||||
registry_toolset = _get_mcp_toolset_aliases().get(name)
|
||||
registry_toolset = alias_target
|
||||
if not registry_toolset:
|
||||
return None
|
||||
description = f"MCP server '{name}' tools"
|
||||
elif name.startswith("mcp-"):
|
||||
description = f"MCP server '{name[4:]}' tools"
|
||||
else:
|
||||
reverse_aliases = {
|
||||
canonical: alias
|
||||
for alias, canonical in _get_registry_toolset_aliases().items()
|
||||
if alias not in TOOLSETS
|
||||
}
|
||||
alias = reverse_aliases.get(name)
|
||||
if alias:
|
||||
description = f"MCP server '{alias}' tools"
|
||||
|
||||
return {
|
||||
"description": description,
|
||||
|
|
@ -525,16 +533,13 @@ def _get_plugin_toolset_names() -> Set[str]:
|
|||
return set()
|
||||
|
||||
|
||||
def _get_mcp_toolset_aliases() -> Dict[str, str]:
|
||||
"""Map raw MCP server names to their live registry toolset names."""
|
||||
aliases = {}
|
||||
for toolset_name in _get_plugin_toolset_names():
|
||||
if not toolset_name.startswith("mcp-"):
|
||||
continue
|
||||
alias = toolset_name[4:]
|
||||
if alias and alias not in TOOLSETS:
|
||||
aliases[alias] = toolset_name
|
||||
return aliases
|
||||
def _get_registry_toolset_aliases() -> Dict[str, str]:
|
||||
"""Return explicit toolset aliases registered in the live registry."""
|
||||
try:
|
||||
from tools.registry import registry
|
||||
return registry.get_registered_toolset_aliases()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
|
|
@ -547,12 +552,13 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
|||
Dict: All toolset definitions
|
||||
"""
|
||||
result = dict(TOOLSETS)
|
||||
aliases = _get_registry_toolset_aliases()
|
||||
for ts_name in _get_plugin_toolset_names():
|
||||
display_name = ts_name
|
||||
if ts_name.startswith("mcp-"):
|
||||
alias = ts_name[4:]
|
||||
if alias and alias not in TOOLSETS:
|
||||
for alias, canonical in aliases.items():
|
||||
if canonical == ts_name and alias not in TOOLSETS:
|
||||
display_name = alias
|
||||
break
|
||||
if display_name in result:
|
||||
continue
|
||||
toolset = get_toolset(display_name)
|
||||
|
|
@ -571,13 +577,14 @@ def get_toolset_names() -> List[str]:
|
|||
List[str]: List of toolset names
|
||||
"""
|
||||
names = set(TOOLSETS.keys())
|
||||
aliases = _get_registry_toolset_aliases()
|
||||
for ts_name in _get_plugin_toolset_names():
|
||||
if ts_name.startswith("mcp-"):
|
||||
alias = ts_name[4:]
|
||||
if alias and alias not in TOOLSETS:
|
||||
for alias, canonical in aliases.items():
|
||||
if canonical == ts_name and alias not in TOOLSETS:
|
||||
names.add(alias)
|
||||
continue
|
||||
names.add(ts_name)
|
||||
break
|
||||
else:
|
||||
names.add(ts_name)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
|
|
@ -600,7 +607,7 @@ def validate_toolset(name: str) -> bool:
|
|||
return True
|
||||
if name in _get_plugin_toolset_names():
|
||||
return True
|
||||
return name in _get_mcp_toolset_aliases()
|
||||
return name in _get_registry_toolset_aliases()
|
||||
|
||||
|
||||
def create_custom_toolset(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue