From c10fea8d264e3289c4c8f5c0b35468aca849b123 Mon Sep 17 00:00:00 2001 From: Greer Guthrie Date: Tue, 14 Apr 2026 15:12:45 -0500 Subject: [PATCH] fix(mcp): make server aliases explicit --- tests/test_toolsets.py | 1 + tests/tools/test_mcp_dynamic_discovery.py | 19 ++++++++ tests/tools/test_mcp_tool.py | 55 ++++++++++++++++++----- tools/mcp_tool.py | 8 ++++ tools/registry.py | 35 +++++++++++++-- toolsets.py | 51 ++++++++++++--------- 6 files changed, 133 insertions(+), 36 deletions(-) diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index a5e2c75bb..9a982bb5b 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -124,6 +124,7 @@ class TestValidateToolset: schema=_make_schema("mcp_dynserver_ping", "Ping"), handler=_dummy_handler, ) + reg.register_toolset_alias("dynserver", "mcp-dynserver") monkeypatch.setattr("tools.registry.registry", reg) diff --git a/tests/tools/test_mcp_dynamic_discovery.py b/tests/tools/test_mcp_dynamic_discovery.py index 991342bd0..891770319 100644 --- a/tests/tools/test_mcp_dynamic_discovery.py +++ b/tests/tools/test_mcp_dynamic_discovery.py @@ -136,6 +136,25 @@ class TestDeregister: # bar still in ts1, so check should remain assert "ts1" in reg._toolset_checks + def test_removes_toolset_alias_when_last_tool_is_removed(self): + reg = ToolRegistry() + reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x) + reg.register_toolset_alias("srv", "mcp-srv") + + reg.deregister("foo") + + assert reg.get_toolset_alias_target("srv") is None + + def test_preserves_toolset_alias_while_toolset_still_exists(self): + reg = ToolRegistry() + reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x) + reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x) + reg.register_toolset_alias("srv", "mcp-srv") + + reg.deregister("foo") + + assert reg.get_toolset_alias_target("srv") == "mcp-srv" + def test_noop_for_unknown_tool(self): reg = ToolRegistry() reg.deregister("nonexistent") # Should not raise diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index f5f15ea41..da46348ea 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -184,11 +184,7 @@ class TestToolHandler: def _patch_mcp_loop(self, coro_side_effect=None): """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" def fake_run(coro, timeout=30): - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() + return asyncio.run(coro) if coro_side_effect: return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) @@ -774,6 +770,42 @@ class TestShutdown: assert len(_servers) == 0 mock_server.shutdown.assert_called_once() + def test_shutdown_deregisters_registered_tools(self): + """shutdown_mcp_servers removes MCP tools and their raw alias.""" + import tools.mcp_tool as mcp_mod + from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers + from tools.registry import registry + from toolsets import resolve_toolset, validate_toolset + + _servers.clear() + registry.register( + name="mcp_test_ping", + toolset="mcp-test", + schema={ + "name": "mcp_test_ping", + "description": "Ping", + "parameters": {"type": "object", "properties": {}}, + }, + handler=lambda *_args, **_kwargs: "{}", + ) + registry.register_toolset_alias("test", "mcp-test") + + server = MCPServerTask("test") + server._registered_tool_names = ["mcp_test_ping"] + _servers["test"] = server + + mcp_mod._ensure_mcp_loop() + try: + assert validate_toolset("test") is True + assert "mcp_test_ping" in resolve_toolset("test") + shutdown_mcp_servers() + finally: + mcp_mod._mcp_loop = None + mcp_mod._mcp_thread = None + + assert "mcp_test_ping" not in registry.get_all_tool_names() + assert validate_toolset("test") is False + def test_shutdown_handles_errors(self): """shutdown_mcp_servers handles errors during close gracefully.""" import tools.mcp_tool as mcp_mod @@ -1177,7 +1209,11 @@ class TestConfigurableTimeouts: try: handler = _make_tool_handler("test_srv", "my_tool", 180) with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run: - mock_run.return_value = json.dumps({"result": "ok"}) + def fake_run(coro, timeout=30): + coro.close() + return json.dumps({"result": "ok"}) + + mock_run.side_effect = fake_run handler({}) # Verify timeout=180 was passed call_kwargs = mock_run.call_args @@ -1277,11 +1313,7 @@ class TestUtilityHandlers: def _patch_mcp_loop(self): """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" def fake_run(coro, timeout=30): - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() + return asyncio.run(coro) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) # -- list_resources -- @@ -3048,6 +3080,7 @@ class TestSanitizeMcpNameComponent: schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}}, handler=lambda *_args, **_kwargs: "{}", ) + reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa") with patch("tools.registry.registry", reg): assert validate_toolset("ai.exa/exa") is True diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 263e4408f..fa8b945ca 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1138,6 +1138,8 @@ class MCPServerTask: async def shutdown(self): """Signal the Task to exit and wait for clean resource teardown.""" + from tools.registry import registry + self._shutdown_event.set() if self._task and not self._task.done(): try: @@ -1152,6 +1154,9 @@ class MCPServerTask: await self._task except asyncio.CancelledError: pass + for tool_name in list(getattr(self, "_registered_tool_names", [])): + registry.deregister(tool_name) + self._registered_tool_names = [] self.session = None @@ -1916,6 +1921,9 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li ) registered_names.append(util_name) + if registered_names: + registry.register_toolset_alias(name, toolset_name) + return registered_names diff --git a/tools/registry.py b/tools/registry.py index b7351cb16..ebda77807 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -52,6 +52,7 @@ class ToolRegistry: def __init__(self): self._tools: Dict[str, ToolEntry] = {} self._toolset_checks: Dict[str, Callable] = {} + self._toolset_aliases: Dict[str, str] = {} # MCP dynamic refresh can mutate the registry while other threads are # reading tool metadata, so keep mutations serialized and readers on # stable snapshots. @@ -96,6 +97,27 @@ class ToolRegistry: if entry.toolset == toolset ) + def register_toolset_alias(self, alias: str, toolset: str) -> None: + """Register an explicit alias for a canonical toolset name.""" + with self._lock: + existing = self._toolset_aliases.get(alias) + if existing and existing != toolset: + logger.warning( + "Toolset alias collision: '%s' (%s) overwritten by %s", + alias, existing, toolset, + ) + self._toolset_aliases[alias] = toolset + + def get_registered_toolset_aliases(self) -> Dict[str, str]: + """Return a snapshot of ``{alias: canonical_toolset}`` mappings.""" + with self._lock: + return dict(self._toolset_aliases) + + def get_toolset_alias_target(self, alias: str) -> Optional[str]: + """Return the canonical toolset name for an alias, or None.""" + with self._lock: + return self._toolset_aliases.get(alias) + # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ @@ -164,11 +186,18 @@ class ToolRegistry: entry = self._tools.pop(name, None) if entry is None: return - # Drop the toolset check if this was the last tool in that toolset - if entry.toolset in self._toolset_checks and not any( + # Drop the toolset check and aliases if this was the last tool in + # that toolset. + toolset_still_exists = any( e.toolset == entry.toolset for e in self._tools.values() - ): + ) + if not toolset_still_exists: self._toolset_checks.pop(entry.toolset, None) + self._toolset_aliases = { + alias: target + for alias, target in self._toolset_aliases.items() + if target != entry.toolset + } logger.debug("Deregistered tool: %s", name) # ------------------------------------------------------------------ diff --git a/toolsets.py b/toolsets.py index 7c843fbfb..09ee8de09 100644 --- a/toolsets.py +++ b/toolsets.py @@ -420,14 +420,22 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]: registry_toolset = name description = f"Plugin toolset: {name}" + alias_target = registry.get_toolset_alias_target(name) if name not in _get_plugin_toolset_names(): - registry_toolset = _get_mcp_toolset_aliases().get(name) + registry_toolset = alias_target if not registry_toolset: return None description = f"MCP server '{name}' tools" - elif name.startswith("mcp-"): - description = f"MCP server '{name[4:]}' tools" + else: + reverse_aliases = { + canonical: alias + for alias, canonical in _get_registry_toolset_aliases().items() + if alias not in TOOLSETS + } + alias = reverse_aliases.get(name) + if alias: + description = f"MCP server '{alias}' tools" return { "description": description, @@ -525,16 +533,13 @@ def _get_plugin_toolset_names() -> Set[str]: return set() -def _get_mcp_toolset_aliases() -> Dict[str, str]: - """Map raw MCP server names to their live registry toolset names.""" - aliases = {} - for toolset_name in _get_plugin_toolset_names(): - if not toolset_name.startswith("mcp-"): - continue - alias = toolset_name[4:] - if alias and alias not in TOOLSETS: - aliases[alias] = toolset_name - return aliases +def _get_registry_toolset_aliases() -> Dict[str, str]: + """Return explicit toolset aliases registered in the live registry.""" + try: + from tools.registry import registry + return registry.get_registered_toolset_aliases() + except Exception: + return {} def get_all_toolsets() -> Dict[str, Dict[str, Any]]: @@ -547,12 +552,13 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]: Dict: All toolset definitions """ result = dict(TOOLSETS) + aliases = _get_registry_toolset_aliases() for ts_name in _get_plugin_toolset_names(): display_name = ts_name - if ts_name.startswith("mcp-"): - alias = ts_name[4:] - if alias and alias not in TOOLSETS: + for alias, canonical in aliases.items(): + if canonical == ts_name and alias not in TOOLSETS: display_name = alias + break if display_name in result: continue toolset = get_toolset(display_name) @@ -571,13 +577,14 @@ def get_toolset_names() -> List[str]: List[str]: List of toolset names """ names = set(TOOLSETS.keys()) + aliases = _get_registry_toolset_aliases() for ts_name in _get_plugin_toolset_names(): - if ts_name.startswith("mcp-"): - alias = ts_name[4:] - if alias and alias not in TOOLSETS: + for alias, canonical in aliases.items(): + if canonical == ts_name and alias not in TOOLSETS: names.add(alias) - continue - names.add(ts_name) + break + else: + names.add(ts_name) return sorted(names) @@ -600,7 +607,7 @@ def validate_toolset(name: str) -> bool: return True if name in _get_plugin_toolset_names(): return True - return name in _get_mcp_toolset_aliases() + return name in _get_registry_toolset_aliases() def create_custom_toolset(