mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix: make tool registry reads thread-safe
This commit is contained in:
parent
6dc8f8e9c0
commit
c7e2fe655a
5 changed files with 341 additions and 62 deletions
|
|
@ -16,6 +16,7 @@ Import chain (circular-import safe):
|
|||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Callable, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -51,6 +52,49 @@ class ToolRegistry:
|
|||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
# MCP dynamic refresh can mutate the registry while other threads are
|
||||
# reading tool metadata, so keep mutations serialized and readers on
|
||||
# stable snapshots.
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
|
||||
"""Return a coherent snapshot of registry entries and toolset checks."""
|
||||
with self._lock:
|
||||
return list(self._tools.values()), dict(self._toolset_checks)
|
||||
|
||||
def _snapshot_entries(self) -> List[ToolEntry]:
|
||||
"""Return a stable snapshot of registered tool entries."""
|
||||
return self._snapshot_state()[0]
|
||||
|
||||
def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
|
||||
"""Return a stable snapshot of toolset availability checks."""
|
||||
return self._snapshot_state()[1]
|
||||
|
||||
def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
|
||||
"""Run a toolset check, treating missing or failing checks as unavailable/available."""
|
||||
if not check:
|
||||
return True
|
||||
try:
|
||||
return bool(check())
|
||||
except Exception:
|
||||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||
return False
|
||||
|
||||
def get_entry(self, name: str) -> Optional[ToolEntry]:
|
||||
"""Return a registered tool entry by name, or None."""
|
||||
with self._lock:
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_registered_toolset_names(self) -> List[str]:
|
||||
"""Return sorted unique toolset names present in the registry."""
|
||||
return sorted({entry.toolset for entry in self._snapshot_entries()})
|
||||
|
||||
def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
|
||||
"""Return sorted tool names registered under a given toolset."""
|
||||
return sorted(
|
||||
entry.name for entry in self._snapshot_entries()
|
||||
if entry.toolset == toolset
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
|
|
@ -70,27 +114,28 @@ class ToolRegistry:
|
|||
max_result_size_chars: int | float | None = None,
|
||||
):
|
||||
"""Register a tool. Called at module-import time by each tool file."""
|
||||
existing = self._tools.get(name)
|
||||
if existing and existing.toolset != toolset:
|
||||
logger.warning(
|
||||
"Tool name collision: '%s' (toolset '%s') is being "
|
||||
"overwritten by toolset '%s'",
|
||||
name, existing.toolset, toolset,
|
||||
with self._lock:
|
||||
existing = self._tools.get(name)
|
||||
if existing and existing.toolset != toolset:
|
||||
logger.warning(
|
||||
"Tool name collision: '%s' (toolset '%s') is being "
|
||||
"overwritten by toolset '%s'",
|
||||
name, existing.toolset, toolset,
|
||||
)
|
||||
self._tools[name] = ToolEntry(
|
||||
name=name,
|
||||
toolset=toolset,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
requires_env=requires_env or [],
|
||||
is_async=is_async,
|
||||
description=description or schema.get("description", ""),
|
||||
emoji=emoji,
|
||||
max_result_size_chars=max_result_size_chars,
|
||||
)
|
||||
self._tools[name] = ToolEntry(
|
||||
name=name,
|
||||
toolset=toolset,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
requires_env=requires_env or [],
|
||||
is_async=is_async,
|
||||
description=description or schema.get("description", ""),
|
||||
emoji=emoji,
|
||||
max_result_size_chars=max_result_size_chars,
|
||||
)
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
|
||||
def deregister(self, name: str) -> None:
|
||||
"""Remove a tool from the registry.
|
||||
|
|
@ -99,14 +144,15 @@ class ToolRegistry:
|
|||
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
|
||||
when a server sends ``notifications/tools/list_changed``.
|
||||
"""
|
||||
entry = self._tools.pop(name, None)
|
||||
if entry is None:
|
||||
return
|
||||
# Drop the toolset check if this was the last tool in that toolset
|
||||
if entry.toolset in self._toolset_checks and not any(
|
||||
e.toolset == entry.toolset for e in self._tools.values()
|
||||
):
|
||||
self._toolset_checks.pop(entry.toolset, None)
|
||||
with self._lock:
|
||||
entry = self._tools.pop(name, None)
|
||||
if entry is None:
|
||||
return
|
||||
# Drop the toolset check if this was the last tool in that toolset
|
||||
if entry.toolset in self._toolset_checks and not any(
|
||||
e.toolset == entry.toolset for e in self._tools.values()
|
||||
):
|
||||
self._toolset_checks.pop(entry.toolset, None)
|
||||
logger.debug("Deregistered tool: %s", name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -121,8 +167,9 @@ class ToolRegistry:
|
|||
"""
|
||||
result = []
|
||||
check_results: Dict[Callable, bool] = {}
|
||||
entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
|
||||
for name in sorted(tool_names):
|
||||
entry = self._tools.get(name)
|
||||
entry = entries_by_name.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
if entry.check_fn:
|
||||
|
|
@ -153,7 +200,7 @@ class ToolRegistry:
|
|||
* All exceptions are caught and returned as ``{"error": "..."}``
|
||||
for consistent error format.
|
||||
"""
|
||||
entry = self._tools.get(name)
|
||||
entry = self.get_entry(name)
|
||||
if not entry:
|
||||
return json.dumps({"error": f"Unknown tool: {name}"})
|
||||
try:
|
||||
|
|
@ -171,7 +218,7 @@ class ToolRegistry:
|
|||
|
||||
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
|
||||
"""Return per-tool max result size, or *default* (or global default)."""
|
||||
entry = self._tools.get(name)
|
||||
entry = self.get_entry(name)
|
||||
if entry and entry.max_result_size_chars is not None:
|
||||
return entry.max_result_size_chars
|
||||
if default is not None:
|
||||
|
|
@ -181,7 +228,7 @@ class ToolRegistry:
|
|||
|
||||
def get_all_tool_names(self) -> List[str]:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
return sorted(entry.name for entry in self._snapshot_entries())
|
||||
|
||||
def get_schema(self, name: str) -> Optional[dict]:
|
||||
"""Return a tool's raw schema dict, bypassing check_fn filtering.
|
||||
|
|
@ -189,22 +236,22 @@ class ToolRegistry:
|
|||
Useful for token estimation and introspection where availability
|
||||
doesn't matter — only the schema content does.
|
||||
"""
|
||||
entry = self._tools.get(name)
|
||||
entry = self.get_entry(name)
|
||||
return entry.schema if entry else None
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
entry = self.get_entry(name)
|
||||
return entry.toolset if entry else None
|
||||
|
||||
def get_emoji(self, name: str, default: str = "⚡") -> str:
|
||||
"""Return the emoji for a tool, or *default* if unset."""
|
||||
entry = self._tools.get(name)
|
||||
entry = self.get_entry(name)
|
||||
return (entry.emoji if entry and entry.emoji else default)
|
||||
|
||||
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
||||
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
||||
return {name: e.toolset for name, e in self._tools.items()}
|
||||
return {entry.name: entry.toolset for entry in self._snapshot_entries()}
|
||||
|
||||
def is_toolset_available(self, toolset: str) -> bool:
|
||||
"""Check if a toolset's requirements are met.
|
||||
|
|
@ -212,28 +259,30 @@ class ToolRegistry:
|
|||
Returns False (rather than crashing) when the check function raises
|
||||
an unexpected exception (e.g. network error, missing import, bad config).
|
||||
"""
|
||||
check = self._toolset_checks.get(toolset)
|
||||
if not check:
|
||||
return True
|
||||
try:
|
||||
return bool(check())
|
||||
except Exception:
|
||||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||
return False
|
||||
with self._lock:
|
||||
check = self._toolset_checks.get(toolset)
|
||||
return self._evaluate_toolset_check(toolset, check)
|
||||
|
||||
def check_toolset_requirements(self) -> Dict[str, bool]:
|
||||
"""Return ``{toolset: available_bool}`` for every toolset."""
|
||||
toolsets = set(e.toolset for e in self._tools.values())
|
||||
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
|
||||
entries, toolset_checks = self._snapshot_state()
|
||||
toolsets = sorted({entry.toolset for entry in entries})
|
||||
return {
|
||||
toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
|
||||
for toolset in toolsets
|
||||
}
|
||||
|
||||
def get_available_toolsets(self) -> Dict[str, dict]:
|
||||
"""Return toolset metadata for UI display."""
|
||||
toolsets: Dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
entries, toolset_checks = self._snapshot_state()
|
||||
for entry in entries:
|
||||
ts = entry.toolset
|
||||
if ts not in toolsets:
|
||||
toolsets[ts] = {
|
||||
"available": self.is_toolset_available(ts),
|
||||
"available": self._evaluate_toolset_check(
|
||||
ts, toolset_checks.get(ts)
|
||||
),
|
||||
"tools": [],
|
||||
"description": "",
|
||||
"requirements": [],
|
||||
|
|
@ -248,13 +297,14 @@ class ToolRegistry:
|
|||
def get_toolset_requirements(self) -> Dict[str, dict]:
|
||||
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
||||
result: Dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
entries, toolset_checks = self._snapshot_state()
|
||||
for entry in entries:
|
||||
ts = entry.toolset
|
||||
if ts not in result:
|
||||
result[ts] = {
|
||||
"name": ts,
|
||||
"env_vars": [],
|
||||
"check_fn": self._toolset_checks.get(ts),
|
||||
"check_fn": toolset_checks.get(ts),
|
||||
"setup_url": None,
|
||||
"tools": [],
|
||||
}
|
||||
|
|
@ -270,18 +320,19 @@ class ToolRegistry:
|
|||
available = []
|
||||
unavailable = []
|
||||
seen = set()
|
||||
for entry in self._tools.values():
|
||||
entries, toolset_checks = self._snapshot_state()
|
||||
for entry in entries:
|
||||
ts = entry.toolset
|
||||
if ts in seen:
|
||||
continue
|
||||
seen.add(ts)
|
||||
if self.is_toolset_available(ts):
|
||||
if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
|
||||
available.append(ts)
|
||||
else:
|
||||
unavailable.append({
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
"tools": [e.name for e in entries if e.toolset == ts],
|
||||
})
|
||||
return available, unavailable
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue