fix: make tool registry reads thread-safe

This commit is contained in:
Greer Guthrie 2026-04-14 00:46:41 -05:00 committed by Teknium
parent 6dc8f8e9c0
commit c7e2fe655a
5 changed files with 341 additions and 62 deletions

View file

@ -647,7 +647,7 @@ def get_plugin_toolsets() -> List[tuple]:
toolset_tools: Dict[str, List[str]] = {} toolset_tools: Dict[str, List[str]] = {}
toolset_plugin: Dict[str, LoadedPlugin] = {} toolset_plugin: Dict[str, LoadedPlugin] = {}
for tool_name in manager._plugin_tool_names: for tool_name in manager._plugin_tool_names:
entry = registry._tools.get(tool_name) entry = registry.get_entry(tool_name)
if not entry: if not entry:
continue continue
ts = entry.toolset ts = entry.toolset
@ -656,7 +656,7 @@ def get_plugin_toolsets() -> List[tuple]:
# Map toolsets back to the plugin that registered them # Map toolsets back to the plugin that registered them
for _name, loaded in manager._plugins.items(): for _name, loaded in manager._plugins.items():
for tool_name in loaded.tools_registered: for tool_name in loaded.tools_registered:
entry = registry._tools.get(tool_name) entry = registry.get_entry(tool_name)
if entry and entry.toolset in toolset_tools: if entry and entry.toolset in toolset_tools:
toolset_plugin.setdefault(entry.toolset, loaded) toolset_plugin.setdefault(entry.toolset, loaded)

View file

@ -1,7 +1,6 @@
"""Tests for toolsets.py — toolset resolution, validation, and composition.""" """Tests for toolsets.py — toolset resolution, validation, and composition."""
import pytest from tools.registry import ToolRegistry
from toolsets import ( from toolsets import (
TOOLSETS, TOOLSETS,
get_toolset, get_toolset,
@ -15,6 +14,18 @@ from toolsets import (
) )
def _dummy_handler(args, **kwargs):
return "{}"
def _make_schema(name: str, description: str = "test tool"):
return {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": {}},
}
class TestGetToolset: class TestGetToolset:
def test_known_toolset(self): def test_known_toolset(self):
ts = get_toolset("web") ts = get_toolset("web")
@ -52,6 +63,25 @@ class TestResolveToolset:
def test_unknown_toolset_returns_empty(self): def test_unknown_toolset_returns_empty(self):
assert resolve_toolset("nonexistent") == [] assert resolve_toolset("nonexistent") == []
def test_plugin_toolset_uses_registry_snapshot(self, monkeypatch):
reg = ToolRegistry()
reg.register(
name="plugin_b",
toolset="plugin_example",
schema=_make_schema("plugin_b", "B"),
handler=_dummy_handler,
)
reg.register(
name="plugin_a",
toolset="plugin_example",
schema=_make_schema("plugin_a", "A"),
handler=_dummy_handler,
)
monkeypatch.setattr("tools.registry.registry", reg)
assert resolve_toolset("plugin_example") == ["plugin_a", "plugin_b"]
def test_all_alias(self): def test_all_alias(self):
tools = resolve_toolset("all") tools = resolve_toolset("all")
assert len(tools) > 10 # Should resolve all tools from all toolsets assert len(tools) > 10 # Should resolve all tools from all toolsets
@ -141,3 +171,20 @@ class TestToolsetConsistency:
# All platform toolsets should be identical # All platform toolsets should be identical
for ts in tool_sets[1:]: for ts in tool_sets[1:]:
assert ts == tool_sets[0] assert ts == tool_sets[0]
class TestPluginToolsets:
def test_get_all_toolsets_includes_plugin_toolset(self, monkeypatch):
reg = ToolRegistry()
reg.register(
name="plugin_tool",
toolset="plugin_bundle",
schema=_make_schema("plugin_tool", "Plugin tool"),
handler=_dummy_handler,
)
monkeypatch.setattr("tools.registry.registry", reg)
all_toolsets = get_all_toolsets()
assert "plugin_bundle" in all_toolsets
assert all_toolsets["plugin_bundle"]["tools"] == ["plugin_tool"]

View file

@ -1,6 +1,7 @@
"""Tests for the central tool registry.""" """Tests for the central tool registry."""
import json import json
import threading
from tools.registry import ToolRegistry from tools.registry import ToolRegistry
@ -167,6 +168,32 @@ class TestToolsetAvailability:
) )
assert reg.get_all_tool_names() == ["a_tool", "z_tool"] assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
def test_get_registered_toolset_names(self):
reg = ToolRegistry()
reg.register(
name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
)
assert reg.get_registered_toolset_names() == ["alpha", "zeta"]
def test_get_tool_names_for_toolset(self):
reg = ToolRegistry()
reg.register(
name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler
)
assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"]
def test_handler_exception_returns_error(self): def test_handler_exception_returns_error(self):
reg = ToolRegistry() reg = ToolRegistry()
@ -301,6 +328,22 @@ class TestEmojiMetadata:
assert reg.get_emoji("t") == "" assert reg.get_emoji("t") == ""
class TestEntryLookup:
def test_get_entry_returns_registered_entry(self):
reg = ToolRegistry()
reg.register(
name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler
)
entry = reg.get_entry("alpha")
assert entry is not None
assert entry.name == "alpha"
assert entry.toolset == "core"
def test_get_entry_returns_none_for_unknown_tool(self):
reg = ToolRegistry()
assert reg.get_entry("missing") is None
class TestSecretCaptureResultContract: class TestSecretCaptureResultContract:
def test_secret_request_result_does_not_include_secret_value(self): def test_secret_request_result_does_not_include_secret_value(self):
result = { result = {
@ -309,3 +352,141 @@ class TestSecretCaptureResultContract:
"validated": False, "validated": False,
} }
assert "secret" not in json.dumps(result).lower() assert "secret" not in json.dumps(result).lower()
class TestThreadSafety:
def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch):
reg = ToolRegistry()
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=lambda: False,
)
entries, toolset_checks = reg._snapshot_state()
def snapshot_then_mutate():
reg.deregister("alpha")
return entries, toolset_checks
monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate)
toolsets = reg.get_available_toolsets()
assert toolsets["gated"]["available"] is False
assert toolsets["gated"]["tools"] == ["alpha"]
def test_check_tool_availability_tolerates_concurrent_register(self):
reg = ToolRegistry()
check_started = threading.Event()
writer_done = threading.Event()
errors = []
result_holder = {}
writer_completed_during_check = {}
def blocking_check():
check_started.set()
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
return True
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=blocking_check,
)
reg.register(
name="beta",
toolset="plain",
schema=_make_schema("beta"),
handler=_dummy_handler,
)
def reader():
try:
result_holder["value"] = reg.check_tool_availability()
except Exception as exc: # pragma: no cover - exercised on failure only
errors.append(exc)
def writer():
assert check_started.wait(timeout=1)
reg.register(
name="gamma",
toolset="new",
schema=_make_schema("gamma"),
handler=_dummy_handler,
)
writer_done.set()
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join(timeout=2)
writer_thread.join(timeout=2)
assert not reader_thread.is_alive()
assert not writer_thread.is_alive()
assert writer_completed_during_check["value"] is True
assert errors == []
available, unavailable = result_holder["value"]
assert "gated" in available
assert "plain" in available
assert unavailable == []
def test_get_available_toolsets_tolerates_concurrent_deregister(self):
reg = ToolRegistry()
check_started = threading.Event()
writer_done = threading.Event()
errors = []
result_holder = {}
writer_completed_during_check = {}
def blocking_check():
check_started.set()
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
return True
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=blocking_check,
)
reg.register(
name="beta",
toolset="plain",
schema=_make_schema("beta"),
handler=_dummy_handler,
)
def reader():
try:
result_holder["value"] = reg.get_available_toolsets()
except Exception as exc: # pragma: no cover - exercised on failure only
errors.append(exc)
def writer():
assert check_started.wait(timeout=1)
reg.deregister("beta")
writer_done.set()
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join(timeout=2)
writer_thread.join(timeout=2)
assert not reader_thread.is_alive()
assert not writer_thread.is_alive()
assert writer_completed_during_check["value"] is True
assert errors == []
toolsets = result_holder["value"]
assert "gated" in toolsets
assert toolsets["gated"]["available"] is True

View file

@ -16,6 +16,7 @@ Import chain (circular-import safe):
import json import json
import logging import logging
import threading
from typing import Callable, Dict, List, Optional, Set from typing import Callable, Dict, List, Optional, Set
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,6 +52,49 @@ class ToolRegistry:
def __init__(self): def __init__(self):
self._tools: Dict[str, ToolEntry] = {} self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {} self._toolset_checks: Dict[str, Callable] = {}
# MCP dynamic refresh can mutate the registry while other threads are
# reading tool metadata, so keep mutations serialized and readers on
# stable snapshots.
self._lock = threading.RLock()
def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
"""Return a coherent snapshot of registry entries and toolset checks."""
with self._lock:
return list(self._tools.values()), dict(self._toolset_checks)
def _snapshot_entries(self) -> List[ToolEntry]:
"""Return a stable snapshot of registered tool entries."""
return self._snapshot_state()[0]
def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
"""Return a stable snapshot of toolset availability checks."""
return self._snapshot_state()[1]
def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
"""Run a toolset check, treating missing or failing checks as unavailable/available."""
if not check:
return True
try:
return bool(check())
except Exception:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def get_entry(self, name: str) -> Optional[ToolEntry]:
"""Return a registered tool entry by name, or None."""
with self._lock:
return self._tools.get(name)
def get_registered_toolset_names(self) -> List[str]:
"""Return sorted unique toolset names present in the registry."""
return sorted({entry.toolset for entry in self._snapshot_entries()})
def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
"""Return sorted tool names registered under a given toolset."""
return sorted(
entry.name for entry in self._snapshot_entries()
if entry.toolset == toolset
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Registration # Registration
@ -70,27 +114,28 @@ class ToolRegistry:
max_result_size_chars: int | float | None = None, max_result_size_chars: int | float | None = None,
): ):
"""Register a tool. Called at module-import time by each tool file.""" """Register a tool. Called at module-import time by each tool file."""
existing = self._tools.get(name) with self._lock:
if existing and existing.toolset != toolset: existing = self._tools.get(name)
logger.warning( if existing and existing.toolset != toolset:
"Tool name collision: '%s' (toolset '%s') is being " logger.warning(
"overwritten by toolset '%s'", "Tool name collision: '%s' (toolset '%s') is being "
name, existing.toolset, toolset, "overwritten by toolset '%s'",
name, existing.toolset, toolset,
)
self._tools[name] = ToolEntry(
name=name,
toolset=toolset,
schema=schema,
handler=handler,
check_fn=check_fn,
requires_env=requires_env or [],
is_async=is_async,
description=description or schema.get("description", ""),
emoji=emoji,
max_result_size_chars=max_result_size_chars,
) )
self._tools[name] = ToolEntry( if check_fn and toolset not in self._toolset_checks:
name=name, self._toolset_checks[toolset] = check_fn
toolset=toolset,
schema=schema,
handler=handler,
check_fn=check_fn,
requires_env=requires_env or [],
is_async=is_async,
description=description or schema.get("description", ""),
emoji=emoji,
max_result_size_chars=max_result_size_chars,
)
if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn
def deregister(self, name: str) -> None: def deregister(self, name: str) -> None:
"""Remove a tool from the registry. """Remove a tool from the registry.
@ -99,14 +144,15 @@ class ToolRegistry:
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
when a server sends ``notifications/tools/list_changed``. when a server sends ``notifications/tools/list_changed``.
""" """
entry = self._tools.pop(name, None) with self._lock:
if entry is None: entry = self._tools.pop(name, None)
return if entry is None:
# Drop the toolset check if this was the last tool in that toolset return
if entry.toolset in self._toolset_checks and not any( # Drop the toolset check if this was the last tool in that toolset
e.toolset == entry.toolset for e in self._tools.values() if entry.toolset in self._toolset_checks and not any(
): e.toolset == entry.toolset for e in self._tools.values()
self._toolset_checks.pop(entry.toolset, None) ):
self._toolset_checks.pop(entry.toolset, None)
logger.debug("Deregistered tool: %s", name) logger.debug("Deregistered tool: %s", name)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -121,8 +167,9 @@ class ToolRegistry:
""" """
result = [] result = []
check_results: Dict[Callable, bool] = {} check_results: Dict[Callable, bool] = {}
entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
for name in sorted(tool_names): for name in sorted(tool_names):
entry = self._tools.get(name) entry = entries_by_name.get(name)
if not entry: if not entry:
continue continue
if entry.check_fn: if entry.check_fn:
@ -153,7 +200,7 @@ class ToolRegistry:
* All exceptions are caught and returned as ``{"error": "..."}`` * All exceptions are caught and returned as ``{"error": "..."}``
for consistent error format. for consistent error format.
""" """
entry = self._tools.get(name) entry = self.get_entry(name)
if not entry: if not entry:
return json.dumps({"error": f"Unknown tool: {name}"}) return json.dumps({"error": f"Unknown tool: {name}"})
try: try:
@ -171,7 +218,7 @@ class ToolRegistry:
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float: def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
"""Return per-tool max result size, or *default* (or global default).""" """Return per-tool max result size, or *default* (or global default)."""
entry = self._tools.get(name) entry = self.get_entry(name)
if entry and entry.max_result_size_chars is not None: if entry and entry.max_result_size_chars is not None:
return entry.max_result_size_chars return entry.max_result_size_chars
if default is not None: if default is not None:
@ -181,7 +228,7 @@ class ToolRegistry:
def get_all_tool_names(self) -> List[str]: def get_all_tool_names(self) -> List[str]:
"""Return sorted list of all registered tool names.""" """Return sorted list of all registered tool names."""
return sorted(self._tools.keys()) return sorted(entry.name for entry in self._snapshot_entries())
def get_schema(self, name: str) -> Optional[dict]: def get_schema(self, name: str) -> Optional[dict]:
"""Return a tool's raw schema dict, bypassing check_fn filtering. """Return a tool's raw schema dict, bypassing check_fn filtering.
@ -189,22 +236,22 @@ class ToolRegistry:
Useful for token estimation and introspection where availability Useful for token estimation and introspection where availability
doesn't matter — only the schema content does. doesn't matter — only the schema content does.
""" """
entry = self._tools.get(name) entry = self.get_entry(name)
return entry.schema if entry else None return entry.schema if entry else None
def get_toolset_for_tool(self, name: str) -> Optional[str]: def get_toolset_for_tool(self, name: str) -> Optional[str]:
"""Return the toolset a tool belongs to, or None.""" """Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name) entry = self.get_entry(name)
return entry.toolset if entry else None return entry.toolset if entry else None
def get_emoji(self, name: str, default: str = "") -> str: def get_emoji(self, name: str, default: str = "") -> str:
"""Return the emoji for a tool, or *default* if unset.""" """Return the emoji for a tool, or *default* if unset."""
entry = self._tools.get(name) entry = self.get_entry(name)
return (entry.emoji if entry and entry.emoji else default) return (entry.emoji if entry and entry.emoji else default)
def get_tool_to_toolset_map(self) -> Dict[str, str]: def get_tool_to_toolset_map(self) -> Dict[str, str]:
"""Return ``{tool_name: toolset_name}`` for every registered tool.""" """Return ``{tool_name: toolset_name}`` for every registered tool."""
return {name: e.toolset for name, e in self._tools.items()} return {entry.name: entry.toolset for entry in self._snapshot_entries()}
def is_toolset_available(self, toolset: str) -> bool: def is_toolset_available(self, toolset: str) -> bool:
"""Check if a toolset's requirements are met. """Check if a toolset's requirements are met.
@ -212,28 +259,30 @@ class ToolRegistry:
Returns False (rather than crashing) when the check function raises Returns False (rather than crashing) when the check function raises
an unexpected exception (e.g. network error, missing import, bad config). an unexpected exception (e.g. network error, missing import, bad config).
""" """
check = self._toolset_checks.get(toolset) with self._lock:
if not check: check = self._toolset_checks.get(toolset)
return True return self._evaluate_toolset_check(toolset, check)
try:
return bool(check())
except Exception:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def check_toolset_requirements(self) -> Dict[str, bool]: def check_toolset_requirements(self) -> Dict[str, bool]:
"""Return ``{toolset: available_bool}`` for every toolset.""" """Return ``{toolset: available_bool}`` for every toolset."""
toolsets = set(e.toolset for e in self._tools.values()) entries, toolset_checks = self._snapshot_state()
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)} toolsets = sorted({entry.toolset for entry in entries})
return {
toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
for toolset in toolsets
}
def get_available_toolsets(self) -> Dict[str, dict]: def get_available_toolsets(self) -> Dict[str, dict]:
"""Return toolset metadata for UI display.""" """Return toolset metadata for UI display."""
toolsets: Dict[str, dict] = {} toolsets: Dict[str, dict] = {}
for entry in self._tools.values(): entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset ts = entry.toolset
if ts not in toolsets: if ts not in toolsets:
toolsets[ts] = { toolsets[ts] = {
"available": self.is_toolset_available(ts), "available": self._evaluate_toolset_check(
ts, toolset_checks.get(ts)
),
"tools": [], "tools": [],
"description": "", "description": "",
"requirements": [], "requirements": [],
@ -248,13 +297,14 @@ class ToolRegistry:
def get_toolset_requirements(self) -> Dict[str, dict]: def get_toolset_requirements(self) -> Dict[str, dict]:
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat.""" """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
result: Dict[str, dict] = {} result: Dict[str, dict] = {}
for entry in self._tools.values(): entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset ts = entry.toolset
if ts not in result: if ts not in result:
result[ts] = { result[ts] = {
"name": ts, "name": ts,
"env_vars": [], "env_vars": [],
"check_fn": self._toolset_checks.get(ts), "check_fn": toolset_checks.get(ts),
"setup_url": None, "setup_url": None,
"tools": [], "tools": [],
} }
@ -270,18 +320,19 @@ class ToolRegistry:
available = [] available = []
unavailable = [] unavailable = []
seen = set() seen = set()
for entry in self._tools.values(): entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset ts = entry.toolset
if ts in seen: if ts in seen:
continue continue
seen.add(ts) seen.add(ts)
if self.is_toolset_available(ts): if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
available.append(ts) available.append(ts)
else: else:
unavailable.append({ unavailable.append({
"name": ts, "name": ts,
"env_vars": entry.requires_env, "env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts], "tools": [e.name for e in entries if e.toolset == ts],
}) })
return available, unavailable return available, unavailable

View file

@ -449,7 +449,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]:
if name in _get_plugin_toolset_names(): if name in _get_plugin_toolset_names():
try: try:
from tools.registry import registry from tools.registry import registry
return [e.name for e in registry._tools.values() if e.toolset == name] return registry.get_tool_names_for_toolset(name)
except Exception: except Exception:
pass pass
return [] return []
@ -495,9 +495,9 @@ def _get_plugin_toolset_names() -> Set[str]:
try: try:
from tools.registry import registry from tools.registry import registry
return { return {
entry.toolset toolset_name
for entry in registry._tools.values() for toolset_name in registry.get_registered_toolset_names()
if entry.toolset not in TOOLSETS if toolset_name not in TOOLSETS
} }
except Exception: except Exception:
return set() return set()
@ -518,7 +518,7 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]:
if ts_name not in result: if ts_name not in result:
try: try:
from tools.registry import registry from tools.registry import registry
tools = [e.name for e in registry._tools.values() if e.toolset == ts_name] tools = registry.get_tool_names_for_toolset(ts_name)
result[ts_name] = { result[ts_name] = {
"description": f"Plugin toolset: {ts_name}", "description": f"Plugin toolset: {ts_name}",
"tools": tools, "tools": tools,