fix: make tool registry reads thread-safe

This commit is contained in:
Greer Guthrie 2026-04-14 00:46:41 -05:00 committed by Teknium
parent 6dc8f8e9c0
commit c7e2fe655a
5 changed files with 341 additions and 62 deletions

View file

@ -1,6 +1,7 @@
"""Tests for the central tool registry."""
import json
import threading
from tools.registry import ToolRegistry
@ -167,6 +168,32 @@ class TestToolsetAvailability:
)
assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
def test_get_registered_toolset_names(self):
reg = ToolRegistry()
reg.register(
name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
)
assert reg.get_registered_toolset_names() == ["alpha", "zeta"]
def test_get_tool_names_for_toolset(self):
reg = ToolRegistry()
reg.register(
name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
)
reg.register(
name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler
)
assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"]
def test_handler_exception_returns_error(self):
reg = ToolRegistry()
@ -301,6 +328,22 @@ class TestEmojiMetadata:
assert reg.get_emoji("t") == ""
class TestEntryLookup:
def test_get_entry_returns_registered_entry(self):
reg = ToolRegistry()
reg.register(
name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler
)
entry = reg.get_entry("alpha")
assert entry is not None
assert entry.name == "alpha"
assert entry.toolset == "core"
def test_get_entry_returns_none_for_unknown_tool(self):
reg = ToolRegistry()
assert reg.get_entry("missing") is None
class TestSecretCaptureResultContract:
def test_secret_request_result_does_not_include_secret_value(self):
result = {
@ -309,3 +352,141 @@ class TestSecretCaptureResultContract:
"validated": False,
}
assert "secret" not in json.dumps(result).lower()
class TestThreadSafety:
def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch):
reg = ToolRegistry()
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=lambda: False,
)
entries, toolset_checks = reg._snapshot_state()
def snapshot_then_mutate():
reg.deregister("alpha")
return entries, toolset_checks
monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate)
toolsets = reg.get_available_toolsets()
assert toolsets["gated"]["available"] is False
assert toolsets["gated"]["tools"] == ["alpha"]
def test_check_tool_availability_tolerates_concurrent_register(self):
reg = ToolRegistry()
check_started = threading.Event()
writer_done = threading.Event()
errors = []
result_holder = {}
writer_completed_during_check = {}
def blocking_check():
check_started.set()
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
return True
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=blocking_check,
)
reg.register(
name="beta",
toolset="plain",
schema=_make_schema("beta"),
handler=_dummy_handler,
)
def reader():
try:
result_holder["value"] = reg.check_tool_availability()
except Exception as exc: # pragma: no cover - exercised on failure only
errors.append(exc)
def writer():
assert check_started.wait(timeout=1)
reg.register(
name="gamma",
toolset="new",
schema=_make_schema("gamma"),
handler=_dummy_handler,
)
writer_done.set()
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join(timeout=2)
writer_thread.join(timeout=2)
assert not reader_thread.is_alive()
assert not writer_thread.is_alive()
assert writer_completed_during_check["value"] is True
assert errors == []
available, unavailable = result_holder["value"]
assert "gated" in available
assert "plain" in available
assert unavailable == []
def test_get_available_toolsets_tolerates_concurrent_deregister(self):
reg = ToolRegistry()
check_started = threading.Event()
writer_done = threading.Event()
errors = []
result_holder = {}
writer_completed_during_check = {}
def blocking_check():
check_started.set()
writer_completed_during_check["value"] = writer_done.wait(timeout=1)
return True
reg.register(
name="alpha",
toolset="gated",
schema=_make_schema("alpha"),
handler=_dummy_handler,
check_fn=blocking_check,
)
reg.register(
name="beta",
toolset="plain",
schema=_make_schema("beta"),
handler=_dummy_handler,
)
def reader():
try:
result_holder["value"] = reg.get_available_toolsets()
except Exception as exc: # pragma: no cover - exercised on failure only
errors.append(exc)
def writer():
assert check_started.wait(timeout=1)
reg.deregister("beta")
writer_done.set()
reader_thread = threading.Thread(target=reader)
writer_thread = threading.Thread(target=writer)
reader_thread.start()
writer_thread.start()
reader_thread.join(timeout=2)
writer_thread.join(timeout=2)
assert not reader_thread.is_alive()
assert not writer_thread.is_alive()
assert writer_completed_during_check["value"] is True
assert errors == []
toolsets = result_holder["value"]
assert "gated" in toolsets
assert toolsets["gated"]["available"] is True