mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 02:11:48 +00:00
fix(agent): add tool-call loop guardrails
This commit is contained in:
parent
8d7500d80d
commit
58b89965c8
5 changed files with 944 additions and 108 deletions
142
tests/agent/test_tool_guardrails.py
Normal file
142
tests/agent/test_tool_guardrails.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""Pure tool-call guardrail primitive tests."""
|
||||
|
||||
import json
|
||||
|
||||
from agent.tool_guardrails import (
|
||||
ToolCallGuardrailConfig,
|
||||
ToolCallGuardrailController,
|
||||
ToolCallSignature,
|
||||
canonical_tool_args,
|
||||
)
|
||||
|
||||
|
||||
def test_tool_call_signature_hashes_canonical_nested_unicode_args_without_exposing_raw_args():
|
||||
args_a = {
|
||||
"z": [{"β": "☤", "a": 1}],
|
||||
"a": {"y": 2, "x": "secret-token-value"},
|
||||
}
|
||||
args_b = {
|
||||
"a": {"x": "secret-token-value", "y": 2},
|
||||
"z": [{"a": 1, "β": "☤"}],
|
||||
}
|
||||
|
||||
assert canonical_tool_args(args_a) == canonical_tool_args(args_b)
|
||||
sig_a = ToolCallSignature.from_call("web_search", args_a)
|
||||
sig_b = ToolCallSignature.from_call("web_search", args_b)
|
||||
|
||||
assert sig_a == sig_b
|
||||
assert len(sig_a.args_hash) == 64
|
||||
metadata = sig_a.to_metadata()
|
||||
assert metadata == {"tool_name": "web_search", "args_hash": sig_a.args_hash}
|
||||
assert "secret-token-value" not in json.dumps(metadata)
|
||||
assert "☤" not in json.dumps(metadata)
|
||||
|
||||
|
||||
def test_repeated_identical_failed_call_warns_then_blocks_before_third_execution():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(
|
||||
exact_failure_warn_after=2,
|
||||
exact_failure_block_after=2,
|
||||
same_tool_failure_halt_after=99,
|
||||
)
|
||||
)
|
||||
args = {"query": "same"}
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
first = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert first.action == "allow"
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
second = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert second.action == "warn"
|
||||
assert second.code == "repeated_exact_failure_warning"
|
||||
assert second.count == 2
|
||||
|
||||
blocked = controller.before_call("web_search", args)
|
||||
assert blocked.action == "block"
|
||||
assert blocked.code == "repeated_exact_failure_block"
|
||||
assert blocked.tool_name == "web_search"
|
||||
assert blocked.count == 2
|
||||
|
||||
|
||||
def test_success_resets_exact_signature_failure_streak():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(exact_failure_block_after=2, same_tool_failure_halt_after=99)
|
||||
)
|
||||
args = {"query": "same"}
|
||||
|
||||
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("web_search", args, '{"ok":true}', failed=False)
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
|
||||
|
||||
def test_same_tool_varying_args_failure_streak_warns_then_halts_independent_of_exact_streak():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(
|
||||
exact_failure_block_after=99,
|
||||
same_tool_failure_warn_after=2,
|
||||
same_tool_failure_halt_after=3,
|
||||
)
|
||||
)
|
||||
|
||||
first = controller.after_call("terminal", {"command": "cmd-1"}, '{"exit_code":1}', failed=True)
|
||||
assert first.action == "allow"
|
||||
second = controller.after_call("terminal", {"command": "cmd-2"}, '{"exit_code":1}', failed=True)
|
||||
assert second.action == "warn"
|
||||
assert second.code == "same_tool_failure_warning"
|
||||
third = controller.after_call("terminal", {"command": "cmd-3"}, '{"exit_code":1}', failed=True)
|
||||
assert third.action == "halt"
|
||||
assert third.code == "same_tool_failure_halt"
|
||||
assert third.count == 3
|
||||
|
||||
|
||||
def test_idempotent_no_progress_repeated_result_warns_then_blocks_future_repeat():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||
)
|
||||
args = {"path": "/tmp/same.txt"}
|
||||
result = "same file contents"
|
||||
|
||||
assert controller.before_call("read_file", args).action == "allow"
|
||||
assert controller.after_call("read_file", args, result, failed=False).action == "allow"
|
||||
assert controller.before_call("read_file", args).action == "allow"
|
||||
warn = controller.after_call("read_file", args, result, failed=False)
|
||||
assert warn.action == "warn"
|
||||
assert warn.code == "idempotent_no_progress_warning"
|
||||
|
||||
blocked = controller.before_call("read_file", args)
|
||||
assert blocked.action == "block"
|
||||
assert blocked.code == "idempotent_no_progress_block"
|
||||
|
||||
|
||||
def test_mutating_or_unknown_tools_are_not_blocked_for_repeated_identical_success_output_by_default():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
assert controller.before_call("write_file", {"path": "/tmp/x", "content": "x"}).action == "allow"
|
||||
assert controller.after_call("write_file", {"path": "/tmp/x", "content": "x"}, "ok", failed=False).action == "allow"
|
||||
assert controller.before_call("custom_tool", {"x": 1}).action == "allow"
|
||||
assert controller.after_call("custom_tool", {"x": 1}, "ok", failed=False).action == "allow"
|
||||
|
||||
|
||||
def test_reset_for_turn_clears_bounded_guardrail_state():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(exact_failure_block_after=2, no_progress_block_after=2)
|
||||
)
|
||||
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||
|
||||
assert controller.before_call("web_search", {"query": "same"}).action == "block"
|
||||
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "block"
|
||||
|
||||
controller.reset_for_turn()
|
||||
|
||||
assert controller.before_call("web_search", {"query": "same"}).action == "allow"
|
||||
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "allow"
|
||||
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
"""Runtime tests for tool-call loop guardrails."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": f"{name} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
|
||||
def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
|
||||
return SimpleNamespace(
|
||||
id=call_id or f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=SimpleNamespace(name=name, arguments=arguments),
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
|
||||
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
return SimpleNamespace(choices=[choice], model="test/model", usage=None)
|
||||
|
||||
|
||||
def _make_agent(*tool_names: str, max_iterations: int = 10) -> AIAgent:
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.client = MagicMock()
|
||||
agent._cached_system_prompt = "You are helpful."
|
||||
agent._use_prompt_caching = False
|
||||
agent.tool_delay = 0
|
||||
agent.compression_enabled = False
|
||||
agent.save_trajectories = False
|
||||
return agent
|
||||
|
||||
|
||||
def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
|
||||
for _ in range(count):
|
||||
agent._tool_guardrails.after_call(
|
||||
tool_name,
|
||||
args,
|
||||
json.dumps({"error": "boom"}),
|
||||
failed=True,
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_path_blocks_repeated_exact_failure_before_execution():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
_seed_exact_failures(agent, "web_search", args)
|
||||
starts = []
|
||||
progress = []
|
||||
agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
|
||||
agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-block")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
mock_hfc.assert_not_called()
|
||||
assert starts == []
|
||||
assert progress == []
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "tool"
|
||||
assert messages[0]["tool_call_id"] == "c-block"
|
||||
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||
|
||||
|
||||
def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
_seed_exact_failures(agent, "web_search", args, count=1)
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-warn")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
assert [m["role"] for m in messages] == ["tool"]
|
||||
assert messages[0]["tool_call_id"] == "c-warn"
|
||||
assert "Tool guardrail" in messages[0]["content"]
|
||||
assert "repeated_exact_failure_warning" in messages[0]["content"]
|
||||
|
||||
|
||||
def test_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
|
||||
agent = _make_agent("web_search")
|
||||
blocked_args = {"query": "blocked"}
|
||||
allowed_args = {"query": "allowed"}
|
||||
_seed_exact_failures(agent, "web_search", blocked_args)
|
||||
starts = []
|
||||
progress_events = []
|
||||
agent.tool_start_callback = lambda tool_call_id, name, args: starts.append((tool_call_id, name, args))
|
||||
agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))
|
||||
calls = [
|
||||
_mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
|
||||
_mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
|
||||
]
|
||||
msg = SimpleNamespace(content="", tool_calls=calls)
|
||||
messages = []
|
||||
executed = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
executed.append((name, args, kwargs["tool_call_id"]))
|
||||
return json.dumps({"ok": args["query"]})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(msg, messages, "task-1")
|
||||
|
||||
assert executed == [("web_search", allowed_args, "c-allow")]
|
||||
assert [m["tool_call_id"] for m in messages] == ["c-block", "c-allow"]
|
||||
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||
assert json.loads(messages[1]["content"]) == {"ok": "allowed"}
|
||||
assert starts == [("c-allow", "web_search", allowed_args)]
|
||||
started_events = [event for event in progress_events if event[0] == "tool.started"]
|
||||
completed_events = [event for event in progress_events if event[0] == "tool.completed"]
|
||||
assert started_events == [("tool.started", "web_search", allowed_args, {})]
|
||||
assert len(completed_events) == 1
|
||||
assert completed_events[0][1] == "web_search"
|
||||
|
||||
|
||||
def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with (
|
||||
patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
|
||||
patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc,
|
||||
):
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
mock_hfc.assert_not_called()
|
||||
assert "plugin policy" in messages[0]["content"]
|
||||
assert agent._tool_guardrails.before_call("web_search", args).action == "allow"
|
||||
|
||||
|
||||
def test_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
|
||||
agent = _make_agent("web_search", max_iterations=10)
|
||||
same_args = {"query": "same"}
|
||||
responses = [
|
||||
_mock_response(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
|
||||
)
|
||||
for i in range(1, 10)
|
||||
]
|
||||
agent.client.chat.completions.create.side_effect = responses
|
||||
|
||||
with (
|
||||
patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("search repeatedly")
|
||||
|
||||
assert mock_hfc.call_count == 2
|
||||
assert result["api_calls"] == 3
|
||||
assert result["api_calls"] < agent.max_iterations
|
||||
assert result["turn_exit_reason"] == "guardrail_halt"
|
||||
assert "error" not in result
|
||||
assert result["completed"] is True
|
||||
assert "stopped retrying" in result["final_response"]
|
||||
assert result["guardrail"]["code"] == "repeated_exact_failure_block"
|
||||
assert result["guardrail"]["tool_name"] == "web_search"
|
||||
|
||||
assistant_tool_calls = [m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls")]
|
||||
for assistant_msg in assistant_tool_calls:
|
||||
call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
|
||||
following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids]
|
||||
assert len(following_results) == len(call_ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue