fix(mcp): make server aliases explicit

This commit is contained in:
Greer Guthrie 2026-04-14 15:12:45 -05:00 committed by Teknium
parent cda64a5961
commit c10fea8d26
6 changed files with 133 additions and 36 deletions

View file

@ -124,6 +124,7 @@ class TestValidateToolset:
schema=_make_schema("mcp_dynserver_ping", "Ping"), schema=_make_schema("mcp_dynserver_ping", "Ping"),
handler=_dummy_handler, handler=_dummy_handler,
) )
reg.register_toolset_alias("dynserver", "mcp-dynserver")
monkeypatch.setattr("tools.registry.registry", reg) monkeypatch.setattr("tools.registry.registry", reg)

View file

@ -136,6 +136,25 @@ class TestDeregister:
# bar still in ts1, so check should remain # bar still in ts1, so check should remain
assert "ts1" in reg._toolset_checks assert "ts1" in reg._toolset_checks
def test_removes_toolset_alias_when_last_tool_is_removed(self):
reg = ToolRegistry()
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
reg.register_toolset_alias("srv", "mcp-srv")
reg.deregister("foo")
assert reg.get_toolset_alias_target("srv") is None
def test_preserves_toolset_alias_while_toolset_still_exists(self):
reg = ToolRegistry()
reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x)
reg.register_toolset_alias("srv", "mcp-srv")
reg.deregister("foo")
assert reg.get_toolset_alias_target("srv") == "mcp-srv"
def test_noop_for_unknown_tool(self): def test_noop_for_unknown_tool(self):
reg = ToolRegistry() reg = ToolRegistry()
reg.deregister("nonexistent") # Should not raise reg.deregister("nonexistent") # Should not raise

View file

@ -184,11 +184,7 @@ class TestToolHandler:
def _patch_mcp_loop(self, coro_side_effect=None): def _patch_mcp_loop(self, coro_side_effect=None):
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" """Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
def fake_run(coro, timeout=30): def fake_run(coro, timeout=30):
loop = asyncio.new_event_loop() return asyncio.run(coro)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
if coro_side_effect: if coro_side_effect:
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
@ -774,6 +770,42 @@ class TestShutdown:
assert len(_servers) == 0 assert len(_servers) == 0
mock_server.shutdown.assert_called_once() mock_server.shutdown.assert_called_once()
def test_shutdown_deregisters_registered_tools(self):
"""shutdown_mcp_servers removes MCP tools and their raw alias."""
import tools.mcp_tool as mcp_mod
from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
from tools.registry import registry
from toolsets import resolve_toolset, validate_toolset
_servers.clear()
registry.register(
name="mcp_test_ping",
toolset="mcp-test",
schema={
"name": "mcp_test_ping",
"description": "Ping",
"parameters": {"type": "object", "properties": {}},
},
handler=lambda *_args, **_kwargs: "{}",
)
registry.register_toolset_alias("test", "mcp-test")
server = MCPServerTask("test")
server._registered_tool_names = ["mcp_test_ping"]
_servers["test"] = server
mcp_mod._ensure_mcp_loop()
try:
assert validate_toolset("test") is True
assert "mcp_test_ping" in resolve_toolset("test")
shutdown_mcp_servers()
finally:
mcp_mod._mcp_loop = None
mcp_mod._mcp_thread = None
assert "mcp_test_ping" not in registry.get_all_tool_names()
assert validate_toolset("test") is False
def test_shutdown_handles_errors(self): def test_shutdown_handles_errors(self):
"""shutdown_mcp_servers handles errors during close gracefully.""" """shutdown_mcp_servers handles errors during close gracefully."""
import tools.mcp_tool as mcp_mod import tools.mcp_tool as mcp_mod
@ -1177,7 +1209,11 @@ class TestConfigurableTimeouts:
try: try:
handler = _make_tool_handler("test_srv", "my_tool", 180) handler = _make_tool_handler("test_srv", "my_tool", 180)
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run: with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
mock_run.return_value = json.dumps({"result": "ok"}) def fake_run(coro, timeout=30):
coro.close()
return json.dumps({"result": "ok"})
mock_run.side_effect = fake_run
handler({}) handler({})
# Verify timeout=180 was passed # Verify timeout=180 was passed
call_kwargs = mock_run.call_args call_kwargs = mock_run.call_args
@ -1277,11 +1313,7 @@ class TestUtilityHandlers:
def _patch_mcp_loop(self): def _patch_mcp_loop(self):
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" """Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
def fake_run(coro, timeout=30): def fake_run(coro, timeout=30):
loop = asyncio.new_event_loop() return asyncio.run(coro)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
# -- list_resources -- # -- list_resources --
@ -3048,6 +3080,7 @@ class TestSanitizeMcpNameComponent:
schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}}, schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
handler=lambda *_args, **_kwargs: "{}", handler=lambda *_args, **_kwargs: "{}",
) )
reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa")
with patch("tools.registry.registry", reg): with patch("tools.registry.registry", reg):
assert validate_toolset("ai.exa/exa") is True assert validate_toolset("ai.exa/exa") is True

View file

@ -1138,6 +1138,8 @@ class MCPServerTask:
async def shutdown(self): async def shutdown(self):
"""Signal the Task to exit and wait for clean resource teardown.""" """Signal the Task to exit and wait for clean resource teardown."""
from tools.registry import registry
self._shutdown_event.set() self._shutdown_event.set()
if self._task and not self._task.done(): if self._task and not self._task.done():
try: try:
@ -1152,6 +1154,9 @@ class MCPServerTask:
await self._task await self._task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
for tool_name in list(getattr(self, "_registered_tool_names", [])):
registry.deregister(tool_name)
self._registered_tool_names = []
self.session = None self.session = None
@ -1916,6 +1921,9 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
) )
registered_names.append(util_name) registered_names.append(util_name)
if registered_names:
registry.register_toolset_alias(name, toolset_name)
return registered_names return registered_names

View file

@ -52,6 +52,7 @@ class ToolRegistry:
def __init__(self): def __init__(self):
self._tools: Dict[str, ToolEntry] = {} self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {} self._toolset_checks: Dict[str, Callable] = {}
self._toolset_aliases: Dict[str, str] = {}
# MCP dynamic refresh can mutate the registry while other threads are # MCP dynamic refresh can mutate the registry while other threads are
# reading tool metadata, so keep mutations serialized and readers on # reading tool metadata, so keep mutations serialized and readers on
# stable snapshots. # stable snapshots.
@ -96,6 +97,27 @@ class ToolRegistry:
if entry.toolset == toolset if entry.toolset == toolset
) )
def register_toolset_alias(self, alias: str, toolset: str) -> None:
"""Register an explicit alias for a canonical toolset name."""
with self._lock:
existing = self._toolset_aliases.get(alias)
if existing and existing != toolset:
logger.warning(
"Toolset alias collision: '%s' (%s) overwritten by %s",
alias, existing, toolset,
)
self._toolset_aliases[alias] = toolset
def get_registered_toolset_aliases(self) -> Dict[str, str]:
"""Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
with self._lock:
return dict(self._toolset_aliases)
def get_toolset_alias_target(self, alias: str) -> Optional[str]:
"""Return the canonical toolset name for an alias, or None."""
with self._lock:
return self._toolset_aliases.get(alias)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Registration # Registration
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -164,11 +186,18 @@ class ToolRegistry:
entry = self._tools.pop(name, None) entry = self._tools.pop(name, None)
if entry is None: if entry is None:
return return
# Drop the toolset check if this was the last tool in that toolset # Drop the toolset check and aliases if this was the last tool in
if entry.toolset in self._toolset_checks and not any( # that toolset.
toolset_still_exists = any(
e.toolset == entry.toolset for e in self._tools.values() e.toolset == entry.toolset for e in self._tools.values()
): )
if not toolset_still_exists:
self._toolset_checks.pop(entry.toolset, None) self._toolset_checks.pop(entry.toolset, None)
self._toolset_aliases = {
alias: target
for alias, target in self._toolset_aliases.items()
if target != entry.toolset
}
logger.debug("Deregistered tool: %s", name) logger.debug("Deregistered tool: %s", name)
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View file

@ -420,14 +420,22 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]:
registry_toolset = name registry_toolset = name
description = f"Plugin toolset: {name}" description = f"Plugin toolset: {name}"
alias_target = registry.get_toolset_alias_target(name)
if name not in _get_plugin_toolset_names(): if name not in _get_plugin_toolset_names():
registry_toolset = _get_mcp_toolset_aliases().get(name) registry_toolset = alias_target
if not registry_toolset: if not registry_toolset:
return None return None
description = f"MCP server '{name}' tools" description = f"MCP server '{name}' tools"
elif name.startswith("mcp-"): else:
description = f"MCP server '{name[4:]}' tools" reverse_aliases = {
canonical: alias
for alias, canonical in _get_registry_toolset_aliases().items()
if alias not in TOOLSETS
}
alias = reverse_aliases.get(name)
if alias:
description = f"MCP server '{alias}' tools"
return { return {
"description": description, "description": description,
@ -525,16 +533,13 @@ def _get_plugin_toolset_names() -> Set[str]:
return set() return set()
def _get_mcp_toolset_aliases() -> Dict[str, str]: def _get_registry_toolset_aliases() -> Dict[str, str]:
"""Map raw MCP server names to their live registry toolset names.""" """Return explicit toolset aliases registered in the live registry."""
aliases = {} try:
for toolset_name in _get_plugin_toolset_names(): from tools.registry import registry
if not toolset_name.startswith("mcp-"): return registry.get_registered_toolset_aliases()
continue except Exception:
alias = toolset_name[4:] return {}
if alias and alias not in TOOLSETS:
aliases[alias] = toolset_name
return aliases
def get_all_toolsets() -> Dict[str, Dict[str, Any]]: def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
@ -547,12 +552,13 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
Dict: All toolset definitions Dict: All toolset definitions
""" """
result = dict(TOOLSETS) result = dict(TOOLSETS)
aliases = _get_registry_toolset_aliases()
for ts_name in _get_plugin_toolset_names(): for ts_name in _get_plugin_toolset_names():
display_name = ts_name display_name = ts_name
if ts_name.startswith("mcp-"): for alias, canonical in aliases.items():
alias = ts_name[4:] if canonical == ts_name and alias not in TOOLSETS:
if alias and alias not in TOOLSETS:
display_name = alias display_name = alias
break
if display_name in result: if display_name in result:
continue continue
toolset = get_toolset(display_name) toolset = get_toolset(display_name)
@ -571,13 +577,14 @@ def get_toolset_names() -> List[str]:
List[str]: List of toolset names List[str]: List of toolset names
""" """
names = set(TOOLSETS.keys()) names = set(TOOLSETS.keys())
aliases = _get_registry_toolset_aliases()
for ts_name in _get_plugin_toolset_names(): for ts_name in _get_plugin_toolset_names():
if ts_name.startswith("mcp-"): for alias, canonical in aliases.items():
alias = ts_name[4:] if canonical == ts_name and alias not in TOOLSETS:
if alias and alias not in TOOLSETS:
names.add(alias) names.add(alias)
continue break
names.add(ts_name) else:
names.add(ts_name)
return sorted(names) return sorted(names)
@ -600,7 +607,7 @@ def validate_toolset(name: str) -> bool:
return True return True
if name in _get_plugin_toolset_names(): if name in _get_plugin_toolset_names():
return True return True
return name in _get_mcp_toolset_aliases() return name in _get_registry_toolset_aliases()
def create_custom_toolset( def create_custom_toolset(