mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: make tool registry reads thread-safe
This commit is contained in:
parent
6dc8f8e9c0
commit
c7e2fe655a
5 changed files with 341 additions and 62 deletions
|
|
@ -647,7 +647,7 @@ def get_plugin_toolsets() -> List[tuple]:
|
||||||
toolset_tools: Dict[str, List[str]] = {}
|
toolset_tools: Dict[str, List[str]] = {}
|
||||||
toolset_plugin: Dict[str, LoadedPlugin] = {}
|
toolset_plugin: Dict[str, LoadedPlugin] = {}
|
||||||
for tool_name in manager._plugin_tool_names:
|
for tool_name in manager._plugin_tool_names:
|
||||||
entry = registry._tools.get(tool_name)
|
entry = registry.get_entry(tool_name)
|
||||||
if not entry:
|
if not entry:
|
||||||
continue
|
continue
|
||||||
ts = entry.toolset
|
ts = entry.toolset
|
||||||
|
|
@ -656,7 +656,7 @@ def get_plugin_toolsets() -> List[tuple]:
|
||||||
# Map toolsets back to the plugin that registered them
|
# Map toolsets back to the plugin that registered them
|
||||||
for _name, loaded in manager._plugins.items():
|
for _name, loaded in manager._plugins.items():
|
||||||
for tool_name in loaded.tools_registered:
|
for tool_name in loaded.tools_registered:
|
||||||
entry = registry._tools.get(tool_name)
|
entry = registry.get_entry(tool_name)
|
||||||
if entry and entry.toolset in toolset_tools:
|
if entry and entry.toolset in toolset_tools:
|
||||||
toolset_plugin.setdefault(entry.toolset, loaded)
|
toolset_plugin.setdefault(entry.toolset, loaded)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""Tests for toolsets.py — toolset resolution, validation, and composition."""
|
"""Tests for toolsets.py — toolset resolution, validation, and composition."""
|
||||||
|
|
||||||
import pytest
|
from tools.registry import ToolRegistry
|
||||||
|
|
||||||
from toolsets import (
|
from toolsets import (
|
||||||
TOOLSETS,
|
TOOLSETS,
|
||||||
get_toolset,
|
get_toolset,
|
||||||
|
|
@ -15,6 +14,18 @@ from toolsets import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dummy_handler(args, **kwargs):
|
||||||
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_schema(name: str, description: str = "test tool"):
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestGetToolset:
|
class TestGetToolset:
|
||||||
def test_known_toolset(self):
|
def test_known_toolset(self):
|
||||||
ts = get_toolset("web")
|
ts = get_toolset("web")
|
||||||
|
|
@ -52,6 +63,25 @@ class TestResolveToolset:
|
||||||
def test_unknown_toolset_returns_empty(self):
|
def test_unknown_toolset_returns_empty(self):
|
||||||
assert resolve_toolset("nonexistent") == []
|
assert resolve_toolset("nonexistent") == []
|
||||||
|
|
||||||
|
def test_plugin_toolset_uses_registry_snapshot(self, monkeypatch):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="plugin_b",
|
||||||
|
toolset="plugin_example",
|
||||||
|
schema=_make_schema("plugin_b", "B"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="plugin_a",
|
||||||
|
toolset="plugin_example",
|
||||||
|
schema=_make_schema("plugin_a", "A"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("tools.registry.registry", reg)
|
||||||
|
|
||||||
|
assert resolve_toolset("plugin_example") == ["plugin_a", "plugin_b"]
|
||||||
|
|
||||||
def test_all_alias(self):
|
def test_all_alias(self):
|
||||||
tools = resolve_toolset("all")
|
tools = resolve_toolset("all")
|
||||||
assert len(tools) > 10 # Should resolve all tools from all toolsets
|
assert len(tools) > 10 # Should resolve all tools from all toolsets
|
||||||
|
|
@ -141,3 +171,20 @@ class TestToolsetConsistency:
|
||||||
# All platform toolsets should be identical
|
# All platform toolsets should be identical
|
||||||
for ts in tool_sets[1:]:
|
for ts in tool_sets[1:]:
|
||||||
assert ts == tool_sets[0]
|
assert ts == tool_sets[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginToolsets:
|
||||||
|
def test_get_all_toolsets_includes_plugin_toolset(self, monkeypatch):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="plugin_tool",
|
||||||
|
toolset="plugin_bundle",
|
||||||
|
schema=_make_schema("plugin_tool", "Plugin tool"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("tools.registry.registry", reg)
|
||||||
|
|
||||||
|
all_toolsets = get_all_toolsets()
|
||||||
|
assert "plugin_bundle" in all_toolsets
|
||||||
|
assert all_toolsets["plugin_bundle"]["tools"] == ["plugin_tool"]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Tests for the central tool registry."""
|
"""Tests for the central tool registry."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
|
||||||
from tools.registry import ToolRegistry
|
from tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
@ -167,6 +168,32 @@ class TestToolsetAvailability:
|
||||||
)
|
)
|
||||||
assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
|
assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
|
||||||
|
|
||||||
|
def test_get_registered_toolset_names(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
assert reg.get_registered_toolset_names() == ["alpha", "zeta"]
|
||||||
|
|
||||||
|
def test_get_tool_names_for_toolset(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"]
|
||||||
|
|
||||||
def test_handler_exception_returns_error(self):
|
def test_handler_exception_returns_error(self):
|
||||||
reg = ToolRegistry()
|
reg = ToolRegistry()
|
||||||
|
|
||||||
|
|
@ -301,6 +328,22 @@ class TestEmojiMetadata:
|
||||||
assert reg.get_emoji("t") == "⚡"
|
assert reg.get_emoji("t") == "⚡"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntryLookup:
|
||||||
|
def test_get_entry_returns_registered_entry(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler
|
||||||
|
)
|
||||||
|
entry = reg.get_entry("alpha")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry.name == "alpha"
|
||||||
|
assert entry.toolset == "core"
|
||||||
|
|
||||||
|
def test_get_entry_returns_none_for_unknown_tool(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
assert reg.get_entry("missing") is None
|
||||||
|
|
||||||
|
|
||||||
class TestSecretCaptureResultContract:
|
class TestSecretCaptureResultContract:
|
||||||
def test_secret_request_result_does_not_include_secret_value(self):
|
def test_secret_request_result_does_not_include_secret_value(self):
|
||||||
result = {
|
result = {
|
||||||
|
|
@ -309,3 +352,141 @@ class TestSecretCaptureResultContract:
|
||||||
"validated": False,
|
"validated": False,
|
||||||
}
|
}
|
||||||
assert "secret" not in json.dumps(result).lower()
|
assert "secret" not in json.dumps(result).lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadSafety:
|
||||||
|
def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(
|
||||||
|
name="alpha",
|
||||||
|
toolset="gated",
|
||||||
|
schema=_make_schema("alpha"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
check_fn=lambda: False,
|
||||||
|
)
|
||||||
|
|
||||||
|
entries, toolset_checks = reg._snapshot_state()
|
||||||
|
|
||||||
|
def snapshot_then_mutate():
|
||||||
|
reg.deregister("alpha")
|
||||||
|
return entries, toolset_checks
|
||||||
|
|
||||||
|
monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate)
|
||||||
|
|
||||||
|
toolsets = reg.get_available_toolsets()
|
||||||
|
assert toolsets["gated"]["available"] is False
|
||||||
|
assert toolsets["gated"]["tools"] == ["alpha"]
|
||||||
|
|
||||||
|
def test_check_tool_availability_tolerates_concurrent_register(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
check_started = threading.Event()
|
||||||
|
writer_done = threading.Event()
|
||||||
|
errors = []
|
||||||
|
result_holder = {}
|
||||||
|
writer_completed_during_check = {}
|
||||||
|
|
||||||
|
def blocking_check():
|
||||||
|
check_started.set()
|
||||||
|
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
reg.register(
|
||||||
|
name="alpha",
|
||||||
|
toolset="gated",
|
||||||
|
schema=_make_schema("alpha"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
check_fn=blocking_check,
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="beta",
|
||||||
|
toolset="plain",
|
||||||
|
schema=_make_schema("beta"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
try:
|
||||||
|
result_holder["value"] = reg.check_tool_availability()
|
||||||
|
except Exception as exc: # pragma: no cover - exercised on failure only
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
def writer():
|
||||||
|
assert check_started.wait(timeout=1)
|
||||||
|
reg.register(
|
||||||
|
name="gamma",
|
||||||
|
toolset="new",
|
||||||
|
schema=_make_schema("gamma"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
writer_done.set()
|
||||||
|
|
||||||
|
reader_thread = threading.Thread(target=reader)
|
||||||
|
writer_thread = threading.Thread(target=writer)
|
||||||
|
reader_thread.start()
|
||||||
|
writer_thread.start()
|
||||||
|
reader_thread.join(timeout=2)
|
||||||
|
writer_thread.join(timeout=2)
|
||||||
|
|
||||||
|
assert not reader_thread.is_alive()
|
||||||
|
assert not writer_thread.is_alive()
|
||||||
|
assert writer_completed_during_check["value"] is True
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
available, unavailable = result_holder["value"]
|
||||||
|
assert "gated" in available
|
||||||
|
assert "plain" in available
|
||||||
|
assert unavailable == []
|
||||||
|
|
||||||
|
def test_get_available_toolsets_tolerates_concurrent_deregister(self):
|
||||||
|
reg = ToolRegistry()
|
||||||
|
check_started = threading.Event()
|
||||||
|
writer_done = threading.Event()
|
||||||
|
errors = []
|
||||||
|
result_holder = {}
|
||||||
|
writer_completed_during_check = {}
|
||||||
|
|
||||||
|
def blocking_check():
|
||||||
|
check_started.set()
|
||||||
|
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
reg.register(
|
||||||
|
name="alpha",
|
||||||
|
toolset="gated",
|
||||||
|
schema=_make_schema("alpha"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
check_fn=blocking_check,
|
||||||
|
)
|
||||||
|
reg.register(
|
||||||
|
name="beta",
|
||||||
|
toolset="plain",
|
||||||
|
schema=_make_schema("beta"),
|
||||||
|
handler=_dummy_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
try:
|
||||||
|
result_holder["value"] = reg.get_available_toolsets()
|
||||||
|
except Exception as exc: # pragma: no cover - exercised on failure only
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
def writer():
|
||||||
|
assert check_started.wait(timeout=1)
|
||||||
|
reg.deregister("beta")
|
||||||
|
writer_done.set()
|
||||||
|
|
||||||
|
reader_thread = threading.Thread(target=reader)
|
||||||
|
writer_thread = threading.Thread(target=writer)
|
||||||
|
reader_thread.start()
|
||||||
|
writer_thread.start()
|
||||||
|
reader_thread.join(timeout=2)
|
||||||
|
writer_thread.join(timeout=2)
|
||||||
|
|
||||||
|
assert not reader_thread.is_alive()
|
||||||
|
assert not writer_thread.is_alive()
|
||||||
|
assert writer_completed_during_check["value"] is True
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
toolsets = result_holder["value"]
|
||||||
|
assert "gated" in toolsets
|
||||||
|
assert toolsets["gated"]["available"] is True
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ Import chain (circular-import safe):
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from typing import Callable, Dict, List, Optional, Set
|
from typing import Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -51,6 +52,49 @@ class ToolRegistry:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._tools: Dict[str, ToolEntry] = {}
|
self._tools: Dict[str, ToolEntry] = {}
|
||||||
self._toolset_checks: Dict[str, Callable] = {}
|
self._toolset_checks: Dict[str, Callable] = {}
|
||||||
|
# MCP dynamic refresh can mutate the registry while other threads are
|
||||||
|
# reading tool metadata, so keep mutations serialized and readers on
|
||||||
|
# stable snapshots.
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
|
||||||
|
"""Return a coherent snapshot of registry entries and toolset checks."""
|
||||||
|
with self._lock:
|
||||||
|
return list(self._tools.values()), dict(self._toolset_checks)
|
||||||
|
|
||||||
|
def _snapshot_entries(self) -> List[ToolEntry]:
|
||||||
|
"""Return a stable snapshot of registered tool entries."""
|
||||||
|
return self._snapshot_state()[0]
|
||||||
|
|
||||||
|
def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
|
||||||
|
"""Return a stable snapshot of toolset availability checks."""
|
||||||
|
return self._snapshot_state()[1]
|
||||||
|
|
||||||
|
def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
|
||||||
|
"""Run a toolset check, treating missing or failing checks as unavailable/available."""
|
||||||
|
if not check:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
return bool(check())
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_entry(self, name: str) -> Optional[ToolEntry]:
|
||||||
|
"""Return a registered tool entry by name, or None."""
|
||||||
|
with self._lock:
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def get_registered_toolset_names(self) -> List[str]:
|
||||||
|
"""Return sorted unique toolset names present in the registry."""
|
||||||
|
return sorted({entry.toolset for entry in self._snapshot_entries()})
|
||||||
|
|
||||||
|
def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
|
||||||
|
"""Return sorted tool names registered under a given toolset."""
|
||||||
|
return sorted(
|
||||||
|
entry.name for entry in self._snapshot_entries()
|
||||||
|
if entry.toolset == toolset
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Registration
|
# Registration
|
||||||
|
|
@ -70,27 +114,28 @@ class ToolRegistry:
|
||||||
max_result_size_chars: int | float | None = None,
|
max_result_size_chars: int | float | None = None,
|
||||||
):
|
):
|
||||||
"""Register a tool. Called at module-import time by each tool file."""
|
"""Register a tool. Called at module-import time by each tool file."""
|
||||||
existing = self._tools.get(name)
|
with self._lock:
|
||||||
if existing and existing.toolset != toolset:
|
existing = self._tools.get(name)
|
||||||
logger.warning(
|
if existing and existing.toolset != toolset:
|
||||||
"Tool name collision: '%s' (toolset '%s') is being "
|
logger.warning(
|
||||||
"overwritten by toolset '%s'",
|
"Tool name collision: '%s' (toolset '%s') is being "
|
||||||
name, existing.toolset, toolset,
|
"overwritten by toolset '%s'",
|
||||||
|
name, existing.toolset, toolset,
|
||||||
|
)
|
||||||
|
self._tools[name] = ToolEntry(
|
||||||
|
name=name,
|
||||||
|
toolset=toolset,
|
||||||
|
schema=schema,
|
||||||
|
handler=handler,
|
||||||
|
check_fn=check_fn,
|
||||||
|
requires_env=requires_env or [],
|
||||||
|
is_async=is_async,
|
||||||
|
description=description or schema.get("description", ""),
|
||||||
|
emoji=emoji,
|
||||||
|
max_result_size_chars=max_result_size_chars,
|
||||||
)
|
)
|
||||||
self._tools[name] = ToolEntry(
|
if check_fn and toolset not in self._toolset_checks:
|
||||||
name=name,
|
self._toolset_checks[toolset] = check_fn
|
||||||
toolset=toolset,
|
|
||||||
schema=schema,
|
|
||||||
handler=handler,
|
|
||||||
check_fn=check_fn,
|
|
||||||
requires_env=requires_env or [],
|
|
||||||
is_async=is_async,
|
|
||||||
description=description or schema.get("description", ""),
|
|
||||||
emoji=emoji,
|
|
||||||
max_result_size_chars=max_result_size_chars,
|
|
||||||
)
|
|
||||||
if check_fn and toolset not in self._toolset_checks:
|
|
||||||
self._toolset_checks[toolset] = check_fn
|
|
||||||
|
|
||||||
def deregister(self, name: str) -> None:
|
def deregister(self, name: str) -> None:
|
||||||
"""Remove a tool from the registry.
|
"""Remove a tool from the registry.
|
||||||
|
|
@ -99,14 +144,15 @@ class ToolRegistry:
|
||||||
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
|
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
|
||||||
when a server sends ``notifications/tools/list_changed``.
|
when a server sends ``notifications/tools/list_changed``.
|
||||||
"""
|
"""
|
||||||
entry = self._tools.pop(name, None)
|
with self._lock:
|
||||||
if entry is None:
|
entry = self._tools.pop(name, None)
|
||||||
return
|
if entry is None:
|
||||||
# Drop the toolset check if this was the last tool in that toolset
|
return
|
||||||
if entry.toolset in self._toolset_checks and not any(
|
# Drop the toolset check if this was the last tool in that toolset
|
||||||
e.toolset == entry.toolset for e in self._tools.values()
|
if entry.toolset in self._toolset_checks and not any(
|
||||||
):
|
e.toolset == entry.toolset for e in self._tools.values()
|
||||||
self._toolset_checks.pop(entry.toolset, None)
|
):
|
||||||
|
self._toolset_checks.pop(entry.toolset, None)
|
||||||
logger.debug("Deregistered tool: %s", name)
|
logger.debug("Deregistered tool: %s", name)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -121,8 +167,9 @@ class ToolRegistry:
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
check_results: Dict[Callable, bool] = {}
|
check_results: Dict[Callable, bool] = {}
|
||||||
|
entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
|
||||||
for name in sorted(tool_names):
|
for name in sorted(tool_names):
|
||||||
entry = self._tools.get(name)
|
entry = entries_by_name.get(name)
|
||||||
if not entry:
|
if not entry:
|
||||||
continue
|
continue
|
||||||
if entry.check_fn:
|
if entry.check_fn:
|
||||||
|
|
@ -153,7 +200,7 @@ class ToolRegistry:
|
||||||
* All exceptions are caught and returned as ``{"error": "..."}``
|
* All exceptions are caught and returned as ``{"error": "..."}``
|
||||||
for consistent error format.
|
for consistent error format.
|
||||||
"""
|
"""
|
||||||
entry = self._tools.get(name)
|
entry = self.get_entry(name)
|
||||||
if not entry:
|
if not entry:
|
||||||
return json.dumps({"error": f"Unknown tool: {name}"})
|
return json.dumps({"error": f"Unknown tool: {name}"})
|
||||||
try:
|
try:
|
||||||
|
|
@ -171,7 +218,7 @@ class ToolRegistry:
|
||||||
|
|
||||||
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
|
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
|
||||||
"""Return per-tool max result size, or *default* (or global default)."""
|
"""Return per-tool max result size, or *default* (or global default)."""
|
||||||
entry = self._tools.get(name)
|
entry = self.get_entry(name)
|
||||||
if entry and entry.max_result_size_chars is not None:
|
if entry and entry.max_result_size_chars is not None:
|
||||||
return entry.max_result_size_chars
|
return entry.max_result_size_chars
|
||||||
if default is not None:
|
if default is not None:
|
||||||
|
|
@ -181,7 +228,7 @@ class ToolRegistry:
|
||||||
|
|
||||||
def get_all_tool_names(self) -> List[str]:
|
def get_all_tool_names(self) -> List[str]:
|
||||||
"""Return sorted list of all registered tool names."""
|
"""Return sorted list of all registered tool names."""
|
||||||
return sorted(self._tools.keys())
|
return sorted(entry.name for entry in self._snapshot_entries())
|
||||||
|
|
||||||
def get_schema(self, name: str) -> Optional[dict]:
|
def get_schema(self, name: str) -> Optional[dict]:
|
||||||
"""Return a tool's raw schema dict, bypassing check_fn filtering.
|
"""Return a tool's raw schema dict, bypassing check_fn filtering.
|
||||||
|
|
@ -189,22 +236,22 @@ class ToolRegistry:
|
||||||
Useful for token estimation and introspection where availability
|
Useful for token estimation and introspection where availability
|
||||||
doesn't matter — only the schema content does.
|
doesn't matter — only the schema content does.
|
||||||
"""
|
"""
|
||||||
entry = self._tools.get(name)
|
entry = self.get_entry(name)
|
||||||
return entry.schema if entry else None
|
return entry.schema if entry else None
|
||||||
|
|
||||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||||
"""Return the toolset a tool belongs to, or None."""
|
"""Return the toolset a tool belongs to, or None."""
|
||||||
entry = self._tools.get(name)
|
entry = self.get_entry(name)
|
||||||
return entry.toolset if entry else None
|
return entry.toolset if entry else None
|
||||||
|
|
||||||
def get_emoji(self, name: str, default: str = "⚡") -> str:
|
def get_emoji(self, name: str, default: str = "⚡") -> str:
|
||||||
"""Return the emoji for a tool, or *default* if unset."""
|
"""Return the emoji for a tool, or *default* if unset."""
|
||||||
entry = self._tools.get(name)
|
entry = self.get_entry(name)
|
||||||
return (entry.emoji if entry and entry.emoji else default)
|
return (entry.emoji if entry and entry.emoji else default)
|
||||||
|
|
||||||
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
||||||
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
||||||
return {name: e.toolset for name, e in self._tools.items()}
|
return {entry.name: entry.toolset for entry in self._snapshot_entries()}
|
||||||
|
|
||||||
def is_toolset_available(self, toolset: str) -> bool:
|
def is_toolset_available(self, toolset: str) -> bool:
|
||||||
"""Check if a toolset's requirements are met.
|
"""Check if a toolset's requirements are met.
|
||||||
|
|
@ -212,28 +259,30 @@ class ToolRegistry:
|
||||||
Returns False (rather than crashing) when the check function raises
|
Returns False (rather than crashing) when the check function raises
|
||||||
an unexpected exception (e.g. network error, missing import, bad config).
|
an unexpected exception (e.g. network error, missing import, bad config).
|
||||||
"""
|
"""
|
||||||
check = self._toolset_checks.get(toolset)
|
with self._lock:
|
||||||
if not check:
|
check = self._toolset_checks.get(toolset)
|
||||||
return True
|
return self._evaluate_toolset_check(toolset, check)
|
||||||
try:
|
|
||||||
return bool(check())
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_toolset_requirements(self) -> Dict[str, bool]:
|
def check_toolset_requirements(self) -> Dict[str, bool]:
|
||||||
"""Return ``{toolset: available_bool}`` for every toolset."""
|
"""Return ``{toolset: available_bool}`` for every toolset."""
|
||||||
toolsets = set(e.toolset for e in self._tools.values())
|
entries, toolset_checks = self._snapshot_state()
|
||||||
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
|
toolsets = sorted({entry.toolset for entry in entries})
|
||||||
|
return {
|
||||||
|
toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
|
||||||
|
for toolset in toolsets
|
||||||
|
}
|
||||||
|
|
||||||
def get_available_toolsets(self) -> Dict[str, dict]:
|
def get_available_toolsets(self) -> Dict[str, dict]:
|
||||||
"""Return toolset metadata for UI display."""
|
"""Return toolset metadata for UI display."""
|
||||||
toolsets: Dict[str, dict] = {}
|
toolsets: Dict[str, dict] = {}
|
||||||
for entry in self._tools.values():
|
entries, toolset_checks = self._snapshot_state()
|
||||||
|
for entry in entries:
|
||||||
ts = entry.toolset
|
ts = entry.toolset
|
||||||
if ts not in toolsets:
|
if ts not in toolsets:
|
||||||
toolsets[ts] = {
|
toolsets[ts] = {
|
||||||
"available": self.is_toolset_available(ts),
|
"available": self._evaluate_toolset_check(
|
||||||
|
ts, toolset_checks.get(ts)
|
||||||
|
),
|
||||||
"tools": [],
|
"tools": [],
|
||||||
"description": "",
|
"description": "",
|
||||||
"requirements": [],
|
"requirements": [],
|
||||||
|
|
@ -248,13 +297,14 @@ class ToolRegistry:
|
||||||
def get_toolset_requirements(self) -> Dict[str, dict]:
|
def get_toolset_requirements(self) -> Dict[str, dict]:
|
||||||
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
||||||
result: Dict[str, dict] = {}
|
result: Dict[str, dict] = {}
|
||||||
for entry in self._tools.values():
|
entries, toolset_checks = self._snapshot_state()
|
||||||
|
for entry in entries:
|
||||||
ts = entry.toolset
|
ts = entry.toolset
|
||||||
if ts not in result:
|
if ts not in result:
|
||||||
result[ts] = {
|
result[ts] = {
|
||||||
"name": ts,
|
"name": ts,
|
||||||
"env_vars": [],
|
"env_vars": [],
|
||||||
"check_fn": self._toolset_checks.get(ts),
|
"check_fn": toolset_checks.get(ts),
|
||||||
"setup_url": None,
|
"setup_url": None,
|
||||||
"tools": [],
|
"tools": [],
|
||||||
}
|
}
|
||||||
|
|
@ -270,18 +320,19 @@ class ToolRegistry:
|
||||||
available = []
|
available = []
|
||||||
unavailable = []
|
unavailable = []
|
||||||
seen = set()
|
seen = set()
|
||||||
for entry in self._tools.values():
|
entries, toolset_checks = self._snapshot_state()
|
||||||
|
for entry in entries:
|
||||||
ts = entry.toolset
|
ts = entry.toolset
|
||||||
if ts in seen:
|
if ts in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(ts)
|
seen.add(ts)
|
||||||
if self.is_toolset_available(ts):
|
if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
|
||||||
available.append(ts)
|
available.append(ts)
|
||||||
else:
|
else:
|
||||||
unavailable.append({
|
unavailable.append({
|
||||||
"name": ts,
|
"name": ts,
|
||||||
"env_vars": entry.requires_env,
|
"env_vars": entry.requires_env,
|
||||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
"tools": [e.name for e in entries if e.toolset == ts],
|
||||||
})
|
})
|
||||||
return available, unavailable
|
return available, unavailable
|
||||||
|
|
||||||
|
|
|
||||||
10
toolsets.py
10
toolsets.py
|
|
@ -449,7 +449,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
|
||||||
if name in _get_plugin_toolset_names():
|
if name in _get_plugin_toolset_names():
|
||||||
try:
|
try:
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
return [e.name for e in registry._tools.values() if e.toolset == name]
|
return registry.get_tool_names_for_toolset(name)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
|
|
@ -495,9 +495,9 @@ def _get_plugin_toolset_names() -> Set[str]:
|
||||||
try:
|
try:
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
return {
|
return {
|
||||||
entry.toolset
|
toolset_name
|
||||||
for entry in registry._tools.values()
|
for toolset_name in registry.get_registered_toolset_names()
|
||||||
if entry.toolset not in TOOLSETS
|
if toolset_name not in TOOLSETS
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
return set()
|
return set()
|
||||||
|
|
@ -518,7 +518,7 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||||
if ts_name not in result:
|
if ts_name not in result:
|
||||||
try:
|
try:
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
tools = [e.name for e in registry._tools.values() if e.toolset == ts_name]
|
tools = registry.get_tool_names_for_toolset(ts_name)
|
||||||
result[ts_name] = {
|
result[ts_name] = {
|
||||||
"description": f"Plugin toolset: {ts_name}",
|
"description": f"Plugin toolset: {ts_name}",
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue