diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index fbe6422d50..a1f8db31ff 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -647,7 +647,7 @@ def get_plugin_toolsets() -> List[tuple]: toolset_tools: Dict[str, List[str]] = {} toolset_plugin: Dict[str, LoadedPlugin] = {} for tool_name in manager._plugin_tool_names: - entry = registry._tools.get(tool_name) + entry = registry.get_entry(tool_name) if not entry: continue ts = entry.toolset @@ -656,7 +656,7 @@ def get_plugin_toolsets() -> List[tuple]: # Map toolsets back to the plugin that registered them for _name, loaded in manager._plugins.items(): for tool_name in loaded.tools_registered: - entry = registry._tools.get(tool_name) + entry = registry.get_entry(tool_name) if entry and entry.toolset in toolset_tools: toolset_plugin.setdefault(entry.toolset, loaded) diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index 13c3450702..774bf98938 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -1,7 +1,6 @@ """Tests for toolsets.py — toolset resolution, validation, and composition.""" -import pytest - +from tools.registry import ToolRegistry from toolsets import ( TOOLSETS, get_toolset, @@ -15,6 +14,18 @@ from toolsets import ( ) +def _dummy_handler(args, **kwargs): + return "{}" + + +def _make_schema(name: str, description: str = "test tool"): + return { + "name": name, + "description": description, + "parameters": {"type": "object", "properties": {}}, + } + + class TestGetToolset: def test_known_toolset(self): ts = get_toolset("web") @@ -52,6 +63,25 @@ class TestResolveToolset: def test_unknown_toolset_returns_empty(self): assert resolve_toolset("nonexistent") == [] + def test_plugin_toolset_uses_registry_snapshot(self, monkeypatch): + reg = ToolRegistry() + reg.register( + name="plugin_b", + toolset="plugin_example", + schema=_make_schema("plugin_b", "B"), + handler=_dummy_handler, + ) + reg.register( + name="plugin_a", + toolset="plugin_example", + schema=_make_schema("plugin_a", "A"), + handler=_dummy_handler, + ) + + monkeypatch.setattr("tools.registry.registry", reg) + + assert resolve_toolset("plugin_example") == ["plugin_a", "plugin_b"] + def test_all_alias(self): tools = resolve_toolset("all") assert len(tools) > 10 # Should resolve all tools from all toolsets @@ -141,3 +171,20 @@ class TestToolsetConsistency: # All platform toolsets should be identical for ts in tool_sets[1:]: assert ts == tool_sets[0] + + +class TestPluginToolsets: + def test_get_all_toolsets_includes_plugin_toolset(self, monkeypatch): + reg = ToolRegistry() + reg.register( + name="plugin_tool", + toolset="plugin_bundle", + schema=_make_schema("plugin_tool", "Plugin tool"), + handler=_dummy_handler, + ) + + monkeypatch.setattr("tools.registry.registry", reg) + + all_toolsets = get_all_toolsets() + assert "plugin_bundle" in all_toolsets + assert all_toolsets["plugin_bundle"]["tools"] == ["plugin_tool"] diff --git a/tests/tools/test_registry.py b/tests/tools/test_registry.py index 455e9f48a8..6b2756886c 100644 --- a/tests/tools/test_registry.py +++ b/tests/tools/test_registry.py @@ -1,6 +1,7 @@ """Tests for the central tool registry.""" import json +import threading from tools.registry import ToolRegistry @@ -167,6 +168,32 @@ class TestToolsetAvailability: ) assert reg.get_all_tool_names() == ["a_tool", "z_tool"] + def test_get_registered_toolset_names(self): + reg = ToolRegistry() + reg.register( + name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler + ) + assert reg.get_registered_toolset_names() == ["alpha", "zeta"] + + def test_get_tool_names_for_toolset(self): + reg = ToolRegistry() + reg.register( + name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler + ) + assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"] + def test_handler_exception_returns_error(self): reg = ToolRegistry() @@ -301,6 +328,22 @@ class TestEmojiMetadata: assert reg.get_emoji("t") == "⚡" +class TestEntryLookup: + def test_get_entry_returns_registered_entry(self): + reg = ToolRegistry() + reg.register( + name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler + ) + entry = reg.get_entry("alpha") + assert entry is not None + assert entry.name == "alpha" + assert entry.toolset == "core" + + def test_get_entry_returns_none_for_unknown_tool(self): + reg = ToolRegistry() + assert reg.get_entry("missing") is None + + class TestSecretCaptureResultContract: def test_secret_request_result_does_not_include_secret_value(self): result = { @@ -309,3 +352,141 @@ class TestSecretCaptureResultContract: "validated": False, } assert "secret" not in json.dumps(result).lower() + + +class TestThreadSafety: + def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch): + reg = ToolRegistry() + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=lambda: False, + ) + + entries, toolset_checks = reg._snapshot_state() + + def snapshot_then_mutate(): + reg.deregister("alpha") + return entries, toolset_checks + + monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate) + + toolsets = reg.get_available_toolsets() + assert toolsets["gated"]["available"] is False + assert toolsets["gated"]["tools"] == ["alpha"] + + def test_check_tool_availability_tolerates_concurrent_register(self): + reg = ToolRegistry() + check_started = threading.Event() + writer_done = threading.Event() + errors = [] + result_holder = {} + writer_completed_during_check = {} + + def blocking_check(): + check_started.set() + writer_completed_during_check["value"] = writer_done.wait(timeout=1) + return True + + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=blocking_check, + ) + reg.register( + name="beta", + toolset="plain", + schema=_make_schema("beta"), + handler=_dummy_handler, + ) + + def reader(): + try: + result_holder["value"] = reg.check_tool_availability() + except Exception as exc: # pragma: no cover - exercised on failure only + errors.append(exc) + + def writer(): + assert check_started.wait(timeout=1) + reg.register( + name="gamma", + toolset="new", + schema=_make_schema("gamma"), + handler=_dummy_handler, + ) + writer_done.set() + + reader_thread = threading.Thread(target=reader) + writer_thread = threading.Thread(target=writer) + reader_thread.start() + writer_thread.start() + reader_thread.join(timeout=2) + writer_thread.join(timeout=2) + + assert not reader_thread.is_alive() + assert not writer_thread.is_alive() + assert writer_completed_during_check["value"] is True + assert errors == [] + + available, unavailable = result_holder["value"] + assert "gated" in available + assert "plain" in available + assert unavailable == [] + + def test_get_available_toolsets_tolerates_concurrent_deregister(self): + reg = ToolRegistry() + check_started = threading.Event() + writer_done = threading.Event() + errors = [] + result_holder = {} + writer_completed_during_check = {} + + def blocking_check(): + check_started.set() + writer_completed_during_check["value"] = writer_done.wait(timeout=1) + return True + + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=blocking_check, + ) + reg.register( + name="beta", + toolset="plain", + schema=_make_schema("beta"), + handler=_dummy_handler, + ) + + def reader(): + try: + result_holder["value"] = reg.get_available_toolsets() + except Exception as exc: # pragma: no cover - exercised on failure only + errors.append(exc) + + def writer(): + assert check_started.wait(timeout=1) + reg.deregister("beta") + writer_done.set() + + reader_thread = threading.Thread(target=reader) + writer_thread = threading.Thread(target=writer) + reader_thread.start() + writer_thread.start() + reader_thread.join(timeout=2) + writer_thread.join(timeout=2) + + assert not reader_thread.is_alive() + assert not writer_thread.is_alive() + assert writer_completed_during_check["value"] is True + assert errors == [] + + toolsets = result_holder["value"] + assert "gated" in toolsets + assert toolsets["gated"]["available"] is True diff --git a/tools/registry.py b/tools/registry.py index d3590a42c0..d6aff83486 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -16,6 +16,7 @@ Import chain (circular-import safe): import json import logging +import threading from typing import Callable, Dict, List, Optional, Set logger = logging.getLogger(__name__) @@ -51,6 +52,49 @@ class ToolRegistry: def __init__(self): self._tools: Dict[str, ToolEntry] = {} self._toolset_checks: Dict[str, Callable] = {} + # MCP dynamic refresh can mutate the registry while other threads are + # reading tool metadata, so keep mutations serialized and readers on + # stable snapshots. + self._lock = threading.RLock() + + def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]: + """Return a coherent snapshot of registry entries and toolset checks.""" + with self._lock: + return list(self._tools.values()), dict(self._toolset_checks) + + def _snapshot_entries(self) -> List[ToolEntry]: + """Return a stable snapshot of registered tool entries.""" + return self._snapshot_state()[0] + + def _snapshot_toolset_checks(self) -> Dict[str, Callable]: + """Return a stable snapshot of toolset availability checks.""" + return self._snapshot_state()[1] + + def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool: + """Run a toolset check, treating missing or failing checks as unavailable/available.""" + if not check: + return True + try: + return bool(check()) + except Exception: + logger.debug("Toolset %s check raised; marking unavailable", toolset) + return False + + def get_entry(self, name: str) -> Optional[ToolEntry]: + """Return a registered tool entry by name, or None.""" + with self._lock: + return self._tools.get(name) + + def get_registered_toolset_names(self) -> List[str]: + """Return sorted unique toolset names present in the registry.""" + return sorted({entry.toolset for entry in self._snapshot_entries()}) + + def get_tool_names_for_toolset(self, toolset: str) -> List[str]: + """Return sorted tool names registered under a given toolset.""" + return sorted( + entry.name for entry in self._snapshot_entries() + if entry.toolset == toolset + ) # ------------------------------------------------------------------ # Registration @@ -70,27 +114,28 @@ class ToolRegistry: max_result_size_chars: int | float | None = None, ): """Register a tool. Called at module-import time by each tool file.""" - existing = self._tools.get(name) - if existing and existing.toolset != toolset: - logger.warning( - "Tool name collision: '%s' (toolset '%s') is being " - "overwritten by toolset '%s'", - name, existing.toolset, toolset, + with self._lock: + existing = self._tools.get(name) + if existing and existing.toolset != toolset: + logger.warning( + "Tool name collision: '%s' (toolset '%s') is being " + "overwritten by toolset '%s'", + name, existing.toolset, toolset, + ) + self._tools[name] = ToolEntry( + name=name, + toolset=toolset, + schema=schema, + handler=handler, + check_fn=check_fn, + requires_env=requires_env or [], + is_async=is_async, + description=description or schema.get("description", ""), + emoji=emoji, + max_result_size_chars=max_result_size_chars, ) - self._tools[name] = ToolEntry( - name=name, - toolset=toolset, - schema=schema, - handler=handler, - check_fn=check_fn, - requires_env=requires_env or [], - is_async=is_async, - description=description or schema.get("description", ""), - emoji=emoji, - max_result_size_chars=max_result_size_chars, - ) - if check_fn and toolset not in self._toolset_checks: - self._toolset_checks[toolset] = check_fn + if check_fn and toolset not in self._toolset_checks: + self._toolset_checks[toolset] = check_fn def deregister(self, name: str) -> None: """Remove a tool from the registry. @@ -99,14 +144,15 @@ class ToolRegistry: same toolset. Used by MCP dynamic tool discovery to nuke-and-repave when a server sends ``notifications/tools/list_changed``. """ - entry = self._tools.pop(name, None) - if entry is None: - return - # Drop the toolset check if this was the last tool in that toolset - if entry.toolset in self._toolset_checks and not any( - e.toolset == entry.toolset for e in self._tools.values() - ): - self._toolset_checks.pop(entry.toolset, None) + with self._lock: + entry = self._tools.pop(name, None) + if entry is None: + return + # Drop the toolset check if this was the last tool in that toolset + if entry.toolset in self._toolset_checks and not any( + e.toolset == entry.toolset for e in self._tools.values() + ): + self._toolset_checks.pop(entry.toolset, None) logger.debug("Deregistered tool: %s", name) # ------------------------------------------------------------------ @@ -121,8 +167,9 @@ class ToolRegistry: """ result = [] check_results: Dict[Callable, bool] = {} + entries_by_name = {entry.name: entry for entry in self._snapshot_entries()} for name in sorted(tool_names): - entry = self._tools.get(name) + entry = entries_by_name.get(name) if not entry: continue if entry.check_fn: @@ -153,7 +200,7 @@ class ToolRegistry: * All exceptions are caught and returned as ``{"error": "..."}`` for consistent error format. """ - entry = self._tools.get(name) + entry = self.get_entry(name) if not entry: return json.dumps({"error": f"Unknown tool: {name}"}) try: @@ -171,7 +218,7 @@ class ToolRegistry: def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float: """Return per-tool max result size, or *default* (or global default).""" - entry = self._tools.get(name) + entry = self.get_entry(name) if entry and entry.max_result_size_chars is not None: return entry.max_result_size_chars if default is not None: @@ -181,7 +228,7 @@ class ToolRegistry: def get_all_tool_names(self) -> List[str]: """Return sorted list of all registered tool names.""" - return sorted(self._tools.keys()) + return sorted(entry.name for entry in self._snapshot_entries()) def get_schema(self, name: str) -> Optional[dict]: """Return a tool's raw schema dict, bypassing check_fn filtering. @@ -189,22 +236,22 @@ class ToolRegistry: Useful for token estimation and introspection where availability doesn't matter — only the schema content does. """ - entry = self._tools.get(name) + entry = self.get_entry(name) return entry.schema if entry else None def get_toolset_for_tool(self, name: str) -> Optional[str]: """Return the toolset a tool belongs to, or None.""" - entry = self._tools.get(name) + entry = self.get_entry(name) return entry.toolset if entry else None def get_emoji(self, name: str, default: str = "⚡") -> str: """Return the emoji for a tool, or *default* if unset.""" - entry = self._tools.get(name) + entry = self.get_entry(name) return (entry.emoji if entry and entry.emoji else default) def get_tool_to_toolset_map(self) -> Dict[str, str]: """Return ``{tool_name: toolset_name}`` for every registered tool.""" - return {name: e.toolset for name, e in self._tools.items()} + return {entry.name: entry.toolset for entry in self._snapshot_entries()} def is_toolset_available(self, toolset: str) -> bool: """Check if a toolset's requirements are met. @@ -212,28 +259,30 @@ class ToolRegistry: Returns False (rather than crashing) when the check function raises an unexpected exception (e.g. network error, missing import, bad config). """ - check = self._toolset_checks.get(toolset) - if not check: - return True - try: - return bool(check()) - except Exception: - logger.debug("Toolset %s check raised; marking unavailable", toolset) - return False + with self._lock: + check = self._toolset_checks.get(toolset) + return self._evaluate_toolset_check(toolset, check) def check_toolset_requirements(self) -> Dict[str, bool]: """Return ``{toolset: available_bool}`` for every toolset.""" - toolsets = set(e.toolset for e in self._tools.values()) - return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)} + entries, toolset_checks = self._snapshot_state() + toolsets = sorted({entry.toolset for entry in entries}) + return { + toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset)) + for toolset in toolsets + } def get_available_toolsets(self) -> Dict[str, dict]: """Return toolset metadata for UI display.""" toolsets: Dict[str, dict] = {} - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts not in toolsets: toolsets[ts] = { - "available": self.is_toolset_available(ts), + "available": self._evaluate_toolset_check( + ts, toolset_checks.get(ts) + ), "tools": [], "description": "", "requirements": [], @@ -248,13 +297,14 @@ class ToolRegistry: def get_toolset_requirements(self) -> Dict[str, dict]: """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat.""" result: Dict[str, dict] = {} - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts not in result: result[ts] = { "name": ts, "env_vars": [], - "check_fn": self._toolset_checks.get(ts), + "check_fn": toolset_checks.get(ts), "setup_url": None, "tools": [], } @@ -270,18 +320,19 @@ class ToolRegistry: available = [] unavailable = [] seen = set() - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts in seen: continue seen.add(ts) - if self.is_toolset_available(ts): + if self._evaluate_toolset_check(ts, toolset_checks.get(ts)): available.append(ts) else: unavailable.append({ "name": ts, "env_vars": entry.requires_env, - "tools": [e.name for e in self._tools.values() if e.toolset == ts], + "tools": [e.name for e in entries if e.toolset == ts], }) return available, unavailable diff --git a/toolsets.py b/toolsets.py index 57e03d2500..da7a2d2b2c 100644 --- a/toolsets.py +++ b/toolsets.py @@ -449,7 +449,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: if name in _get_plugin_toolset_names(): try: from tools.registry import registry - return [e.name for e in registry._tools.values() if e.toolset == name] + return registry.get_tool_names_for_toolset(name) except Exception: pass return [] @@ -495,9 +495,9 @@ def _get_plugin_toolset_names() -> Set[str]: try: from tools.registry import registry return { - entry.toolset - for entry in registry._tools.values() - if entry.toolset not in TOOLSETS + toolset_name + for toolset_name in registry.get_registered_toolset_names() + if toolset_name not in TOOLSETS } except Exception: return set() @@ -518,7 +518,7 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]: if ts_name not in result: try: from tools.registry import registry - tools = [e.name for e in registry._tools.values() if e.toolset == ts_name] + tools = registry.get_tool_names_for_toolset(ts_name) result[ts_name] = { "description": f"Plugin toolset: {ts_name}", "tools": tools,