fix(agent): add tool-call loop guardrails

This commit is contained in:
Mind-Dragon 2026-04-27 16:29:19 +02:00 committed by Teknium
parent 8d7500d80d
commit 58b89965c8
5 changed files with 944 additions and 108 deletions

View file

@ -14,6 +14,7 @@ from difflib import unified_diff
from pathlib import Path
from utils import safe_json_loads
from agent.tool_guardrails import classify_tool_failure
# ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m"
@ -808,30 +809,7 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
failures. On success, returns ``(False, "")``.
"""
if result is None:
return False, ""
if tool_name == "terminal":
data = safe_json_loads(result)
if isinstance(data, dict):
exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]"
return False, ""
# Memory-specific: distinguish "full" from real errors
if tool_name == "memory":
data = safe_json_loads(result)
if isinstance(data, dict):
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]"
# Generic heuristic for non-terminal tools
lower = result[:500].lower()
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
return True, " [error]"
return False, ""
return classify_tool_failure(tool_name, result)
def get_cute_tool_message(

381
agent/tool_guardrails.py Normal file
View file

@ -0,0 +1,381 @@
"""Pure tool-call loop guardrail primitives.
The controller in this module is intentionally side-effect free: it tracks
per-turn tool-call observations and returns decisions. Runtime code owns whether
those decisions become synthetic tool results or controlled turn halts.
"""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Mapping
from utils import safe_json_loads
# Read-only tools: re-running them with identical arguments has no side
# effects, so the guardrail may apply "no progress" detection to them.
IDEMPOTENT_TOOL_NAMES = frozenset((
    "read_file",
    "search_files",
    "web_search",
    "web_extract",
    "session_search",
    "browser_snapshot",
    "browser_console",
    "browser_get_images",
    "mcp_filesystem_read_file",
    "mcp_filesystem_read_text_file",
    "mcp_filesystem_read_multiple_files",
    "mcp_filesystem_list_directory",
    "mcp_filesystem_list_directory_with_sizes",
    "mcp_filesystem_directory_tree",
    "mcp_filesystem_get_file_info",
    "mcp_filesystem_search_files",
))

# Mutating tools: each invocation may change external state, so identical
# calls can legitimately return different results and are never treated as
# "no progress" repeats.
MUTATING_TOOL_NAMES = frozenset((
    "terminal",
    "execute_code",
    "write_file",
    "patch",
    "todo",
    "memory",
    "skill_manage",
    "browser_click",
    "browser_type",
    "browser_press",
    "browser_scroll",
    "browser_navigate",
    "send_message",
    "cronjob",
    "delegate_task",
    "process",
))
@dataclass(frozen=True)
class ToolCallGuardrailConfig:
    """Thresholds for per-turn tool-call loop detection.

    Counts refer to observations within a single turn; the controller is
    reset between turns via ``ToolCallGuardrailController.reset_for_turn``.
    """

    # Warn after this many failures of the exact same (tool, args) signature.
    exact_failure_warn_after: int = 2
    # Block the next identical call once this many exact failures were seen.
    exact_failure_block_after: int = 2
    # Warn after this many failures of the same tool, regardless of args.
    same_tool_failure_warn_after: int = 3
    # Halt the turn after this many same-tool failures.
    same_tool_failure_halt_after: int = 5
    # For idempotent (read-only) tools: warn after the same call returned an
    # identical result this many times...
    no_progress_warn_after: int = 2
    # ...and block a further identical call at this repeat count.
    no_progress_block_after: int = 2
    # default_factory keeps the lookup of the module-level constants lazy.
    idempotent_tools: frozenset[str] = field(default_factory=lambda: IDEMPOTENT_TOOL_NAMES)
    mutating_tools: frozenset[str] = field(default_factory=lambda: MUTATING_TOOL_NAMES)
@dataclass(frozen=True)
class ToolCallSignature:
    """Stable, non-reversible identity for a tool name plus canonical args."""

    tool_name: str
    # Hex SHA-256 of the canonical JSON form of the arguments; the raw
    # argument values are never stored on the signature itself.
    args_hash: str

    @classmethod
    def from_call(cls, tool_name: str, args: Mapping[str, Any] | None) -> "ToolCallSignature":
        """Build a signature from a tool name and its (possibly absent) args."""
        canonical_form = canonical_tool_args(args or {})
        digest = _sha256(canonical_form)
        return cls(tool_name=tool_name, args_hash=digest)

    def to_metadata(self) -> dict[str, str]:
        """Return public metadata without raw argument values."""
        return {"tool_name": self.tool_name, "args_hash": self.args_hash}
@dataclass(frozen=True)
class ToolGuardrailDecision:
    """Decision returned by the tool-call guardrail controller."""

    # One of: allow | warn | block | halt.
    action: str = "allow"
    code: str = "allow"
    message: str = ""
    tool_name: str = ""
    count: int = 0
    signature: ToolCallSignature | None = None

    @property
    def allows_execution(self) -> bool:
        """True when the tool call may still run ("allow" or "warn")."""
        return self.action in ("allow", "warn")

    @property
    def should_halt(self) -> bool:
        """True when the turn should stop ("block" or "halt")."""
        return self.action in ("block", "halt")

    def to_metadata(self) -> dict[str, Any]:
        """Return a JSON-safe summary of this decision."""
        meta: dict[str, Any] = {
            "action": self.action,
            "code": self.code,
            "message": self.message,
            "tool_name": self.tool_name,
            "count": self.count,
        }
        # Signatures are optional; only include one when present.
        if self.signature is not None:
            meta["signature"] = self.signature.to_metadata()
        return meta
def canonical_tool_args(args: Mapping[str, Any]) -> str:
    """Return sorted compact JSON for parsed tool arguments.

    Raises ``TypeError`` for non-mapping input. Key order is normalized and
    whitespace removed so logically identical calls serialize identically;
    non-JSON-serializable values fall back to ``str``.
    """
    if not isinstance(args, Mapping):
        raise TypeError(f"tool args must be a mapping, got {type(args).__name__}")
    serialized = json.dumps(
        args, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str
    )
    return serialized
def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
    """Classify a tool result using shared display/runtime semantics.

    Returns ``(failed, suffix)`` where ``suffix`` is a short display tag such
    as ``" [exit 1]"``, ``" [full]"``, or ``" [error]"``; ``(False, "")`` on
    success or when the result is absent.
    """
    if result is None:
        return False, ""
    if tool_name == "terminal":
        data = safe_json_loads(result)
        if isinstance(data, dict):
            # Non-zero exit codes take precedence and carry the code itself.
            exit_code = data.get("exit_code")
            if exit_code is not None and exit_code != 0:
                return True, f" [exit {exit_code}]"
            if data.get("success") is False or data.get("failed") is True:
                return True, " [error]"
            error = data.get("error")
            if error is not None and error != "":
                return True, " [error]"
        # NOTE(review): non-JSON terminal output is treated as success here,
        # matching the prior display-layer behavior — confirm intended.
        return False, ""
    data = safe_json_loads(result)
    if isinstance(data, dict):
        if tool_name == "memory":
            # Memory-specific: distinguish "full" from real errors. The
            # isinstance guard avoids a TypeError on non-string error values.
            error = data.get("error", "")
            if data.get("success") is False and isinstance(error, str) and "exceed the limit" in error:
                return True, " [full]"
        if data.get("success") is False or data.get("failed") is True:
            return True, " [error]"
        error = data.get("error")
        if error is not None and error != "":
            return True, " [error]"
        return False, ""
    # Non-JSON result: heuristic scan of the first 500 characters only.
    lower = result[:500].lower()
    if "traceback" in lower or lower.startswith("error:"):
        return True, " [error]"
    if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
        return True, " [error]"
    return False, ""
class ToolCallGuardrailController:
    """Per-turn controller for repeated failed/non-progressing tool calls.

    Pure state machine: ``before_call`` decides whether a tool call may
    execute, ``after_call`` records the observed outcome. The controller
    never executes tools or mutates conversation state itself; runtime code
    acts on the returned :class:`ToolGuardrailDecision` objects.
    """

    def __init__(self, config: ToolCallGuardrailConfig | None = None):
        self.config = config or ToolCallGuardrailConfig()
        self.reset_for_turn()

    def reset_for_turn(self) -> None:
        """Drop all per-turn counters; call at the start of each turn."""
        # Consecutive failures of the exact same (tool, canonical args) pair.
        self._exact_failure_counts: dict[ToolCallSignature, int] = {}
        # Failures per tool name this turn, regardless of arguments.
        self._same_tool_failure_counts: dict[str, int] = {}
        # Idempotent tools only: signature -> (last result hash, repeat count).
        self._no_progress: dict[ToolCallSignature, tuple[str, int]] = {}
        self._halt_decision: ToolGuardrailDecision | None = None

    @property
    def halt_decision(self) -> ToolGuardrailDecision | None:
        """First block decision recorded this turn, or ``None``."""
        return self._halt_decision

    def before_call(self, tool_name: str, args: Mapping[str, Any] | None) -> ToolGuardrailDecision:
        """Decide whether a tool call may execute.

        Blocks when the identical call has already failed too many times, or
        when an idempotent call keeps returning the same result. Block
        decisions are also recorded as the turn's halt decision.
        """
        signature = ToolCallSignature.from_call(tool_name, _coerce_args(args))
        exact_count = self._exact_failure_counts.get(signature, 0)
        if exact_count >= self.config.exact_failure_block_after:
            decision = ToolGuardrailDecision(
                action="block",
                code="repeated_exact_failure_block",
                message=(
                    f"Blocked {tool_name}: the same tool call failed {exact_count} "
                    "times with identical arguments. Stop retrying it unchanged; "
                    "change strategy or explain the blocker."
                ),
                tool_name=tool_name,
                count=exact_count,
                signature=signature,
            )
            self._halt_decision = decision
            return decision
        if self._is_idempotent(tool_name):
            record = self._no_progress.get(signature)
            if record is not None:
                # BUGFIX: the unused hash was previously unpacked into a local
                # named ``_result_hash``, shadowing the module-level helper of
                # the same name; use a distinct local name instead.
                _prev_hash, repeat_count = record
                if repeat_count >= self.config.no_progress_block_after:
                    decision = ToolGuardrailDecision(
                        action="block",
                        code="idempotent_no_progress_block",
                        message=(
                            f"Blocked {tool_name}: this read-only call returned the same "
                            f"result {repeat_count} times. Stop repeating it unchanged; "
                            "use the result already provided or try a different query."
                        ),
                        tool_name=tool_name,
                        count=repeat_count,
                        signature=signature,
                    )
                    self._halt_decision = decision
                    return decision
        return ToolGuardrailDecision(tool_name=tool_name, signature=signature)

    def after_call(
        self,
        tool_name: str,
        args: Mapping[str, Any] | None,
        result: str | None,
        *,
        failed: bool | None = None,
    ) -> ToolGuardrailDecision:
        """Record one finished tool call and return the resulting decision.

        ``failed`` overrides failure classification when the caller already
        knows the outcome; otherwise ``classify_tool_failure`` is consulted.
        Halt decisions are also recorded for ``halt_decision``.
        """
        args = _coerce_args(args)
        signature = ToolCallSignature.from_call(tool_name, args)
        if failed is None:
            failed, _ = classify_tool_failure(tool_name, result)
        if failed:
            exact_count = self._exact_failure_counts.get(signature, 0) + 1
            self._exact_failure_counts[signature] = exact_count
            # A failure invalidates any "same result" streak for this call.
            self._no_progress.pop(signature, None)
            same_count = self._same_tool_failure_counts.get(tool_name, 0) + 1
            self._same_tool_failure_counts[tool_name] = same_count
            if same_count >= self.config.same_tool_failure_halt_after:
                decision = ToolGuardrailDecision(
                    action="halt",
                    code="same_tool_failure_halt",
                    message=(
                        f"Stopped {tool_name}: it failed {same_count} times this turn. "
                        "Stop retrying the same failing tool path and choose a different approach."
                    ),
                    tool_name=tool_name,
                    count=same_count,
                    signature=signature,
                )
                self._halt_decision = decision
                return decision
            if exact_count >= self.config.exact_failure_warn_after:
                return ToolGuardrailDecision(
                    action="warn",
                    code="repeated_exact_failure_warning",
                    message=(
                        f"Tool guardrail: {tool_name} has failed {exact_count} times "
                        "with identical arguments. Do not retry it unchanged; inspect the "
                        "error and change strategy."
                    ),
                    tool_name=tool_name,
                    count=exact_count,
                    signature=signature,
                )
            if same_count >= self.config.same_tool_failure_warn_after:
                return ToolGuardrailDecision(
                    action="warn",
                    code="same_tool_failure_warning",
                    message=(
                        f"Tool guardrail: {tool_name} has failed {same_count} times "
                        "this turn. Change approach before retrying."
                    ),
                    tool_name=tool_name,
                    count=same_count,
                    signature=signature,
                )
            return ToolGuardrailDecision(tool_name=tool_name, count=exact_count, signature=signature)
        # Success: clear failure streaks for this signature and tool.
        self._exact_failure_counts.pop(signature, None)
        self._same_tool_failure_counts.pop(tool_name, None)
        if not self._is_idempotent(tool_name):
            self._no_progress.pop(signature, None)
            return ToolGuardrailDecision(tool_name=tool_name, signature=signature)
        # Idempotent success: track identical-result repeats for this signature.
        result_hash = _result_hash(result)
        previous = self._no_progress.get(signature)
        repeat_count = 1
        if previous is not None and previous[0] == result_hash:
            repeat_count = previous[1] + 1
        self._no_progress[signature] = (result_hash, repeat_count)
        if repeat_count >= self.config.no_progress_warn_after:
            return ToolGuardrailDecision(
                action="warn",
                code="idempotent_no_progress_warning",
                message=(
                    f"Tool guardrail: {tool_name} returned the same result "
                    f"{repeat_count} times. Use the result or change the query instead "
                    "of repeating it unchanged."
                ),
                tool_name=tool_name,
                count=repeat_count,
                signature=signature,
            )
        return ToolGuardrailDecision(tool_name=tool_name, count=repeat_count, signature=signature)

    def _is_idempotent(self, tool_name: str) -> bool:
        """True when the tool is read-only; mutating wins over idempotent."""
        if tool_name in self.config.mutating_tools:
            return False
        return tool_name in self.config.idempotent_tools
def toolguard_synthetic_result(decision: ToolGuardrailDecision) -> str:
    """Build a synthetic role=tool content string for a blocked tool call.

    The payload carries the human-readable message under ``error`` plus the
    full decision metadata under ``guardrail``.
    """
    payload = {
        "error": decision.message,
        "guardrail": decision.to_metadata(),
    }
    return json.dumps(payload, ensure_ascii=False)
def append_toolguard_guidance(result: str, decision: ToolGuardrailDecision) -> str:
    """Append runtime guidance to the current tool result content.

    Only warn/halt decisions carry inline guidance; anything else (including
    blocks, which get a synthetic result instead) returns ``result`` as-is.
    """
    if not decision.message or decision.action not in {"warn", "halt"}:
        return result
    note = (
        "\n\n[Tool guardrail: "
        f"{decision.code}; count={decision.count}; {decision.message}]"
    )
    return (result or "") + note
def _coerce_args(args: Mapping[str, Any] | None) -> Mapping[str, Any]:
return args if isinstance(args, Mapping) else {}
def _result_hash(result: str | None) -> str:
    """Hash a tool result, canonicalizing JSON so formatting noise is ignored.

    JSON-parseable results are re-serialized with sorted keys and compact
    separators before hashing; anything else is hashed as raw text.
    """
    text = result or ""
    parsed = safe_json_loads(text)
    if parsed is None:
        return _sha256(text)
    try:
        canonical = json.dumps(
            parsed,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ":"),
            default=str,
        )
    except TypeError:
        # Extremely defensive: fall back to repr-ish text on serialization failure.
        canonical = str(parsed)
    return _sha256(canonical)
def _sha256(value: str) -> str:
return hashlib.sha256(value.encode("utf-8")).hexdigest()

View file

@ -162,6 +162,12 @@ from agent.display import (
_detect_tool_failure,
get_tool_emoji as _get_tool_emoji,
)
from agent.tool_guardrails import (
ToolCallGuardrailController,
ToolGuardrailDecision,
append_toolguard_guidance,
toolguard_synthetic_result,
)
from agent.trajectory import (
convert_scratchpad_to_think, has_incomplete_scratchpad,
save_trajectory as _save_trajectory_to_file,
@ -1150,6 +1156,8 @@ class AIAgent:
# Tool execution state — allows _vprint during tool execution
# even when stream consumers are registered (no tokens streaming then)
self._executing_tools = False
self._tool_guardrails = ToolCallGuardrailController()
self._tool_guardrail_halt_decision: ToolGuardrailDecision | None = None
# Interrupt mechanism for breaking out of tool loops
self._interrupt_requested = False
@ -9107,6 +9115,44 @@ class AIAgent:
)
return compressed, new_system_prompt
def _set_tool_guardrail_halt(self, decision: ToolGuardrailDecision) -> None:
"""Record the first guardrail decision that should stop this turn."""
if decision.should_halt and self._tool_guardrail_halt_decision is None:
self._tool_guardrail_halt_decision = decision
def _toolguard_controlled_halt_response(self, decision: ToolGuardrailDecision) -> str:
tool = decision.tool_name or "a tool"
return (
f"I stopped retrying {tool} because it hit the tool-call guardrail "
f"({decision.code}) after {decision.count} repeated non-progressing "
"attempts. The last tool result explains the blocker; the next step is "
"to change strategy instead of repeating the same call."
)
def _append_guardrail_observation(
    self,
    tool_name: str,
    function_args: dict,
    function_result: str,
    *,
    failed: bool,
) -> str:
    """Feed one finished tool call into the guardrail controller.

    Returns the tool result, annotated with inline guidance for warn/halt
    decisions. Halt decisions are additionally recorded so the turn loop
    can stop after this batch completes.
    """
    outcome = self._tool_guardrails.after_call(
        tool_name,
        function_args,
        function_result,
        failed=failed,
    )
    annotated = function_result
    if outcome.action in {"warn", "halt"}:
        annotated = append_toolguard_guidance(annotated, outcome)
    if outcome.should_halt:
        self._set_tool_guardrail_halt(outcome)
    return annotated
def _guardrail_block_result(self, decision: ToolGuardrailDecision) -> str:
    """Synthesize the tool result for a guardrail-blocked call.

    Also records the decision so the turn loop halts after this batch.
    """
    self._set_tool_guardrail_halt(decision)
    synthetic = toolguard_synthetic_result(decision)
    return synthetic
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
"""Execute tool calls from the assistant message and append results to messages.
@ -9150,7 +9196,8 @@ class AIAgent:
)
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
tool_call_id: Optional[str] = None, messages: list = None) -> str:
tool_call_id: Optional[str] = None, messages: list = None,
pre_tool_block_checked: bool = False) -> str:
"""Invoke a single tool and return the result string. No display logic.
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
@ -9159,13 +9206,14 @@ class AIAgent:
"""
# Check plugin hooks for a block directive before executing anything.
block_message: Optional[str] = None
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
block_message = get_pre_tool_call_block_message(
function_name, function_args, task_id=effective_task_id or "",
)
except Exception:
pass
if not pre_tool_block_checked:
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
block_message = get_pre_tool_call_block_message(
function_name, function_args, task_id=effective_task_id or "",
)
except Exception:
pass
if block_message is not None:
return json.dumps({"error": block_message}, ensure_ascii=False)
@ -9317,13 +9365,31 @@ class AIAgent:
except Exception:
pass
parsed_calls.append((tool_call, function_name, function_args))
block_result = None
blocked_by_guardrail = False
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
block_message = get_pre_tool_call_block_message(
function_name, function_args, task_id=effective_task_id or "",
)
except Exception:
block_message = None
if block_message is not None:
block_result = json.dumps({"error": block_message}, ensure_ascii=False)
else:
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
if not guardrail_decision.allows_execution:
block_result = self._guardrail_block_result(guardrail_decision)
blocked_by_guardrail = True
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
# ── Logging / callbacks ──────────────────────────────────────────
tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls)
if not self.quiet_mode:
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
for i, (tc, name, args) in enumerate(parsed_calls, 1):
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
args_str = json.dumps(args, ensure_ascii=False)
if self.verbose_logging:
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
@ -9332,7 +9398,9 @@ class AIAgent:
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
for tc, name, args in parsed_calls:
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
if block_result is not None:
continue
if self.tool_progress_callback:
try:
preview = _build_tool_preview(name, args)
@ -9340,7 +9408,9 @@ class AIAgent:
except Exception as cb_err:
logging.debug(f"Tool progress callback error: {cb_err}")
for tc, name, args in parsed_calls:
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
if block_result is not None:
continue
if self.tool_start_callback:
try:
self.tool_start_callback(tc.id, name, args)
@ -9348,8 +9418,11 @@ class AIAgent:
logging.debug(f"Tool start callback error: {cb_err}")
# ── Concurrent execution ─────────────────────────────────────────
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag)
results = [None] * num_tools
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
if block_result is not None:
results[i] = (name, args, block_result, 0.0, True, True)
# Touch activity before launching workers so the gateway knows
# we're executing tools (not stuck).
@ -9404,7 +9477,14 @@ class AIAgent:
pass
start = time.time()
try:
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id, messages=messages)
result = self._invoke_tool(
function_name,
function_args,
effective_task_id,
tool_call.id,
messages=messages,
pre_tool_block_checked=True,
)
except Exception as tool_error:
result = f"Error executing tool '{function_name}': {tool_error}"
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
@ -9414,7 +9494,7 @@ class AIAgent:
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
else:
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
results[index] = (function_name, function_args, result, duration, is_error)
results[index] = (function_name, function_args, result, duration, is_error, False)
# Tear down worker-tid tracking. Clear any interrupt bit we may
# have set so the next task scheduled onto this recycled tid
# starts with a clean slate.
@ -9440,61 +9520,67 @@ class AIAgent:
spinner.start()
try:
max_workers = min(num_tools, _MAX_TOOL_WORKERS)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for i, (tc, name, args) in enumerate(parsed_calls):
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
ctx = contextvars.copy_context()
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
futures.append(f)
runnable_calls = [
(i, tc, name, args)
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
if block_result is None
]
futures = []
if runnable_calls:
max_workers = min(len(runnable_calls), _MAX_TOOL_WORKERS)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for i, tc, name, args in runnable_calls:
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
ctx = contextvars.copy_context()
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
futures.append(f)
# Wait for all to complete with periodic heartbeats so the
# gateway's inactivity monitor doesn't kill us during long
# concurrent tool batches. Also check for user interrupts
# so we don't block indefinitely when the user sends /stop
# or a new message during concurrent tool execution.
_conc_start = time.time()
_interrupt_logged = False
while True:
done, not_done = concurrent.futures.wait(
futures, timeout=5.0,
)
if not not_done:
break
# Check for interrupt — the per-thread interrupt signal
# already causes individual tools (terminal, execute_code)
# to abort, but tools without interrupt checks (web_search,
# read_file) will run to completion. Cancel any futures
# that haven't started yet so we don't block on them.
if self._interrupt_requested:
if not _interrupt_logged:
_interrupt_logged = True
self._vprint(
f"{self.log_prefix}⚡ Interrupt: cancelling "
f"{len(not_done)} pending concurrent tool(s)",
force=True,
)
for f in not_done:
f.cancel()
# Give already-running tools a moment to notice the
# per-thread interrupt signal and exit gracefully.
concurrent.futures.wait(not_done, timeout=3.0)
break
_conc_elapsed = int(time.time() - _conc_start)
# Heartbeat every ~30s (6 × 5s poll intervals)
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
_still_running = [
parsed_calls[futures.index(f)][1]
for f in not_done
if f in futures
]
self._touch_activity(
f"concurrent tools running ({_conc_elapsed}s, "
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
# Wait for all to complete with periodic heartbeats so the
# gateway's inactivity monitor doesn't kill us during long
# concurrent tool batches. Also check for user interrupts
# so we don't block indefinitely when the user sends /stop
# or a new message during concurrent tool execution.
_conc_start = time.time()
_interrupt_logged = False
while True:
done, not_done = concurrent.futures.wait(
futures, timeout=5.0,
)
if not not_done:
break
# Check for interrupt — the per-thread interrupt signal
# already causes individual tools (terminal, execute_code)
# to abort, but tools without interrupt checks (web_search,
# read_file) will run to completion. Cancel any futures
# that haven't started yet so we don't block on them.
if self._interrupt_requested:
if not _interrupt_logged:
_interrupt_logged = True
self._vprint(
f"{self.log_prefix}⚡ Interrupt: cancelling "
f"{len(not_done)} pending concurrent tool(s)",
force=True,
)
for f in not_done:
f.cancel()
# Give already-running tools a moment to notice the
# per-thread interrupt signal and exit gracefully.
concurrent.futures.wait(not_done, timeout=3.0)
break
_conc_elapsed = int(time.time() - _conc_start)
# Heartbeat every ~30s (6 × 5s poll intervals)
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
_still_running = [
parsed_calls[futures.index(f)][1]
for f in not_done
if f in futures
]
self._touch_activity(
f"concurrent tools running ({_conc_elapsed}s, "
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
)
finally:
if spinner:
# Build a summary message for the spinner stop
@ -9503,8 +9589,9 @@ class AIAgent:
spinner.stop(f"{completed}/{num_tools} tools completed in {total_dur:.1f}s total")
# ── Post-execution: display per-tool results ─────────────────────
for i, (tc, name, args) in enumerate(parsed_calls):
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
r = results[i]
blocked = False
if r is None:
# Tool was cancelled (interrupt) or thread didn't return
if self._interrupt_requested:
@ -9513,13 +9600,21 @@ class AIAgent:
function_result = f"Error executing tool '{name}': thread did not return a result"
tool_duration = 0.0
else:
function_name, function_args, function_result, tool_duration, is_error = r
function_name, function_args, function_result, tool_duration, is_error, blocked = r
if not blocked:
function_result = self._append_guardrail_observation(
function_name,
function_args,
function_result,
failed=is_error,
)
if is_error:
result_preview = function_result[:200] if len(function_result) > 200 else function_result
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
if self.tool_progress_callback:
if not blocked and self.tool_progress_callback:
try:
self.tool_progress_callback(
"tool.completed", function_name, None, None,
@ -9547,7 +9642,7 @@ class AIAgent:
self._current_tool = None
self._touch_activity(f"tool completed: {name} ({tool_duration:.1f}s)")
if self.tool_complete_callback:
if not blocked and self.tool_complete_callback:
try:
self.tool_complete_callback(tc.id, name, args, function_result)
except Exception as cb_err:
@ -9629,9 +9724,17 @@ class AIAgent:
except Exception:
pass
if _block_msg is not None:
# Tool blocked by plugin policy — skip counter resets.
# Execution is handled below in the tool dispatch chain.
_guardrail_block_decision: ToolGuardrailDecision | None = None
if _block_msg is None:
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
if not guardrail_decision.allows_execution:
_guardrail_block_decision = guardrail_decision
_execution_blocked = _block_msg is not None or _guardrail_block_decision is not None
if _execution_blocked:
# Tool blocked by plugin or guardrail policy — skip counters,
# callbacks, checkpointing, activity mutation, and real execution.
pass
else:
# Reset nudge counters when the relevant tool is actually used
@ -9649,35 +9752,35 @@ class AIAgent:
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
if _block_msg is None:
if not _execution_blocked:
self._current_tool = function_name
self._touch_activity(f"executing tool: {function_name}")
# Set activity callback for long-running tool execution (terminal
# commands, etc.) so the gateway's inactivity monitor doesn't kill
# the agent while a command is running.
if _block_msg is None:
if not _execution_blocked:
try:
from tools.environments.base import set_activity_callback
set_activity_callback(self._touch_activity)
except Exception:
pass
if _block_msg is None and self.tool_progress_callback:
if not _execution_blocked and self.tool_progress_callback:
try:
preview = _build_tool_preview(function_name, function_args)
self.tool_progress_callback("tool.started", function_name, preview, function_args)
except Exception as cb_err:
logging.debug(f"Tool progress callback error: {cb_err}")
if _block_msg is None and self.tool_start_callback:
if not _execution_blocked and self.tool_start_callback:
try:
self.tool_start_callback(tool_call.id, function_name, function_args)
except Exception as cb_err:
logging.debug(f"Tool start callback error: {cb_err}")
# Checkpoint: snapshot working dir before file-mutating tools
if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
if not _execution_blocked and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
try:
file_path = function_args.get("path", "")
if file_path:
@ -9689,7 +9792,7 @@ class AIAgent:
pass # never block tool execution
# Checkpoint before destructive terminal commands
if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled:
if not _execution_blocked and function_name == "terminal" and self._checkpoint_mgr.enabled:
try:
cmd = function_args.get("command", "")
if _is_destructive_command(cmd):
@ -9706,6 +9809,11 @@ class AIAgent:
# Tool blocked by plugin policy — return error without executing.
function_result = json.dumps({"error": _block_msg}, ensure_ascii=False)
tool_duration = 0.0
elif _guardrail_block_decision is not None:
# Tool blocked by tool-loop guardrail — synthesize exactly one
# tool result for the original tool_call_id without executing.
function_result = self._guardrail_block_result(_guardrail_block_decision)
tool_duration = 0.0
elif function_name == "todo":
from tools.todo_tool import todo_tool as _todo_tool
function_result = _todo_tool(
@ -9889,12 +9997,22 @@ class AIAgent:
# Log tool errors to the persistent error log so [error] tags
# in the UI always have a corresponding detailed entry on disk.
_is_error_result, _ = _detect_tool_failure(function_name, function_result)
if not _execution_blocked:
function_result = self._append_guardrail_observation(
function_name,
function_args,
function_result,
failed=_is_error_result,
)
result_preview = function_result if self.verbose_logging else (
function_result[:200] if len(function_result) > 200 else function_result
)
if _is_error_result:
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
else:
logger.info("tool %s completed (%.2fs, %d chars)", function_name, tool_duration, len(function_result))
if self.tool_progress_callback:
if not _execution_blocked and self.tool_progress_callback:
try:
self.tool_progress_callback(
"tool.completed", function_name, None, None,
@ -9910,7 +10028,7 @@ class AIAgent:
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
if self.tool_complete_callback:
if not _execution_blocked and self.tool_complete_callback:
try:
self.tool_complete_callback(tool_call.id, function_name, function_args, function_result)
except Exception as cb_err:
@ -10244,6 +10362,8 @@ class AIAgent:
self._last_content_tools_all_housekeeping = False
self._mute_post_response = False
self._unicode_sanitization_passes = 0
self._tool_guardrails.reset_for_turn()
self._tool_guardrail_halt_decision = None
# Pre-turn connection health check: detect and clean up dead TCP
# connections left over from provider outages or dropped streams.
@ -13041,6 +13161,16 @@ class AIAgent:
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
if self._tool_guardrail_halt_decision is not None:
decision = self._tool_guardrail_halt_decision
_turn_exit_reason = "guardrail_halt"
final_response = self._toolguard_controlled_halt_response(decision)
self._emit_status(
f"⚠️ Tool guardrail halted {decision.tool_name}: {decision.code}"
)
messages.append({"role": "assistant", "content": final_response})
break
# Reset per-turn retry counters after successful tool
# execution so a single truncation doesn't poison the
# entire conversation.
@ -13567,6 +13697,7 @@ class AIAgent:
"messages": messages,
"api_calls": api_call_count,
"completed": completed,
"turn_exit_reason": _turn_exit_reason,
"partial": False, # True only when stopped due to invalid tool calls
"interrupted": interrupted,
"response_previewed": getattr(self, "_response_was_previewed", False),
@ -13586,6 +13717,8 @@ class AIAgent:
"cost_status": self.session_cost_status,
"cost_source": self.session_cost_source,
}
if self._tool_guardrail_halt_decision is not None:
result["guardrail"] = self._tool_guardrail_halt_decision.to_metadata()
# If a /steer landed after the final assistant turn (no more tool
# batches to drain into), hand it back to the caller so it can be
# delivered as the next user turn instead of being silently lost.

View file

@ -0,0 +1,142 @@
"""Pure tool-call guardrail primitive tests."""
import json
from agent.tool_guardrails import (
ToolCallGuardrailConfig,
ToolCallGuardrailController,
ToolCallSignature,
canonical_tool_args,
)
def test_tool_call_signature_hashes_canonical_nested_unicode_args_without_exposing_raw_args():
    """Signatures are order-insensitive and never leak raw argument values.

    BUGFIX: the final assertion previously checked ``"" not in json.dumps(...)``,
    which is vacuously false for every string (the empty string is a substring
    of everything), so the test could never pass; the unicode sentinel value
    is restored so the leak check is meaningful.
    """
    args_a = {
        "z": [{"β": "ünïcode-secret", "a": 1}],
        "a": {"y": 2, "x": "secret-token-value"},
    }
    args_b = {
        "a": {"x": "secret-token-value", "y": 2},
        "z": [{"a": 1, "β": "ünïcode-secret"}],
    }
    # Key order must not affect the canonical form or the signature.
    assert canonical_tool_args(args_a) == canonical_tool_args(args_b)
    sig_a = ToolCallSignature.from_call("web_search", args_a)
    sig_b = ToolCallSignature.from_call("web_search", args_b)
    assert sig_a == sig_b
    assert len(sig_a.args_hash) == 64  # hex SHA-256
    metadata = sig_a.to_metadata()
    assert metadata == {"tool_name": "web_search", "args_hash": sig_a.args_hash}
    # Raw argument values must never appear in the public metadata.
    assert "secret-token-value" not in json.dumps(metadata)
    assert "ünïcode-secret" not in json.dumps(metadata)
def test_repeated_identical_failed_call_warns_then_blocks_before_third_execution():
    """An identical failing call warns on its second failure and is blocked before a third run."""
    config = ToolCallGuardrailConfig(
        exact_failure_warn_after=2,
        exact_failure_block_after=2,
        same_tool_failure_halt_after=99,
    )
    controller = ToolCallGuardrailController(config)
    call_args = {"query": "same"}

    # First failure: fully allowed, no guardrail reaction yet.
    assert controller.before_call("web_search", call_args).action == "allow"
    outcome = controller.after_call("web_search", call_args, '{"error":"boom"}', failed=True)
    assert outcome.action == "allow"

    # Second identical failure: still executed, but the after-hook warns.
    assert controller.before_call("web_search", call_args).action == "allow"
    outcome = controller.after_call("web_search", call_args, '{"error":"boom"}', failed=True)
    assert outcome.action == "warn"
    assert outcome.code == "repeated_exact_failure_warning"
    assert outcome.count == 2

    # Third attempt: vetoed before the tool would execute.
    decision = controller.before_call("web_search", call_args)
    assert decision.action == "block"
    assert decision.code == "repeated_exact_failure_block"
    assert decision.tool_name == "web_search"
    assert decision.count == 2
def test_success_resets_exact_signature_failure_streak():
    """A successful call clears the exact-signature failure streak."""
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(exact_failure_block_after=2, same_tool_failure_halt_after=99)
    )
    call_args = {"query": "same"}

    def record_failure():
        controller.after_call("web_search", call_args, '{"error":"boom"}', failed=True)

    record_failure()
    # A success in between must wipe the streak...
    controller.after_call("web_search", call_args, '{"ok":true}', failed=False)
    assert controller.before_call("web_search", call_args).action == "allow"
    # ...so the next failure counts as the first of a fresh streak.
    record_failure()
    assert controller.before_call("web_search", call_args).action == "allow"
def test_same_tool_varying_args_failure_streak_warns_then_halts_independent_of_exact_streak():
    """Distinct failing calls to one tool trip the per-tool warn/halt thresholds."""
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(
            exact_failure_block_after=99,
            same_tool_failure_warn_after=2,
            same_tool_failure_halt_after=3,
        )
    )

    def fail_with(command):
        # Every call fails, but each uses different args so only the
        # per-tool (not exact-signature) streak is exercised.
        return controller.after_call("terminal", {"command": command}, '{"exit_code":1}', failed=True)

    assert fail_with("cmd-1").action == "allow"

    warned = fail_with("cmd-2")
    assert warned.action == "warn"
    assert warned.code == "same_tool_failure_warning"

    halted = fail_with("cmd-3")
    assert halted.action == "halt"
    assert halted.code == "same_tool_failure_halt"
    assert halted.count == 3
def test_idempotent_no_progress_repeated_result_warns_then_blocks_future_repeat():
    """Identical results from an idempotent tool warn, then block the next repeat."""
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
    )
    call_args = {"path": "/tmp/same.txt"}
    unchanged_output = "same file contents"

    # First read: nothing suspicious yet.
    assert controller.before_call("read_file", call_args).action == "allow"
    assert controller.after_call("read_file", call_args, unchanged_output, failed=False).action == "allow"

    # Second identical read: executed, but flagged as making no progress.
    assert controller.before_call("read_file", call_args).action == "allow"
    warned = controller.after_call("read_file", call_args, unchanged_output, failed=False)
    assert warned.action == "warn"
    assert warned.code == "idempotent_no_progress_warning"

    # Third attempt is refused up front.
    decision = controller.before_call("read_file", call_args)
    assert decision.action == "block"
    assert decision.code == "idempotent_no_progress_block"
def test_mutating_or_unknown_tools_are_not_blocked_for_repeated_identical_success_output_by_default():
    """No-progress tracking applies only to idempotent tools, not mutating or unknown ones."""
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
    )
    write_args = {"path": "/tmp/x", "content": "x"}
    for _attempt in range(3):
        # write_file mutates state, so repeated identical "ok" results are fine.
        assert controller.before_call("write_file", write_args).action == "allow"
        assert controller.after_call("write_file", write_args, "ok", failed=False).action == "allow"
    # Tools outside the idempotent set are also left alone.
    assert controller.before_call("custom_tool", {"x": 1}).action == "allow"
    assert controller.after_call("custom_tool", {"x": 1}, "ok", failed=False).action == "allow"
def test_reset_for_turn_clears_bounded_guardrail_state():
    """reset_for_turn() drops both failure streaks and no-progress history."""
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(exact_failure_block_after=2, no_progress_block_after=2)
    )
    search_args = {"query": "same"}
    read_args = {"path": "/tmp/x"}

    # Trip both guardrails: two exact failures, two identical idempotent reads.
    for _ in range(2):
        controller.after_call("web_search", search_args, '{"error":"boom"}', failed=True)
    for _ in range(2):
        controller.after_call("read_file", read_args, "same", failed=False)
    assert controller.before_call("web_search", search_args).action == "block"
    assert controller.before_call("read_file", read_args).action == "block"

    controller.reset_for_turn()

    # All bounded state is gone; both calls are allowed again.
    assert controller.before_call("web_search", search_args).action == "allow"
    assert controller.before_call("read_file", read_args).action == "allow"

View file

@ -0,0 +1,202 @@
"""Runtime tests for tool-call loop guardrails."""
import json
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from run_agent import AIAgent
def _make_tool_defs(*names: str) -> list[dict]:
    """Build minimal OpenAI-style function tool definitions for the given names."""
    definitions = []
    for tool_name in names:
        definitions.append(
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": f"{tool_name} tool",
                    # Empty schema: these mocks never validate arguments.
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        )
    return definitions
def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
    """Create a SimpleNamespace mimicking an OpenAI tool-call object.

    A random 8-hex-char id is generated when no truthy ``call_id`` is given.
    """
    function_part = SimpleNamespace(name=name, arguments=arguments)
    tool_call_id = call_id if call_id else f"call_{uuid.uuid4().hex[:8]}"
    return SimpleNamespace(id=tool_call_id, type="function", function=function_part)
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
    """Create a SimpleNamespace mimicking a single-choice chat-completions response."""
    message = SimpleNamespace(content=content, tool_calls=tool_calls)
    return SimpleNamespace(
        choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
        model="test/model",
        usage=None,
    )
def _make_agent(*tool_names: str, max_iterations: int = 10) -> AIAgent:
    """Construct a quiet, offline AIAgent wired with mock tool definitions."""
    tool_defs = _make_tool_defs(*tool_names)
    # Patch away tool discovery, toolset checks, and the OpenAI client so the
    # constructor never touches the network or the real toolchain.
    with patch("run_agent.get_tool_definitions", return_value=tool_defs):
        with patch("run_agent.check_toolset_requirements", return_value={}):
            with patch("run_agent.OpenAI"):
                agent = AIAgent(
                    api_key="test-key-1234567890",
                    base_url="https://openrouter.ai/api/v1",
                    max_iterations=max_iterations,
                    quiet_mode=True,
                    skip_context_files=True,
                    skip_memory=True,
                )
    # Replace the client and strip optional behaviors so tests stay hermetic and fast.
    agent.client = MagicMock()
    agent._cached_system_prompt = "You are helpful."
    agent._use_prompt_caching = False
    agent.tool_delay = 0
    agent.compression_enabled = False
    agent.save_trajectories = False
    return agent
def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
    """Record ``count`` identical failed calls so guardrail thresholds are pre-tripped."""
    # The payload is identical every time, which is exactly what trips the
    # exact-signature failure streak.
    failure_payload = json.dumps({"error": "boom"})
    for _ in range(count):
        agent._tool_guardrails.after_call(tool_name, args, failure_payload, failed=True)
def test_sequential_path_blocks_repeated_exact_failure_before_execution():
    """A pre-tripped exact-failure streak must short-circuit the sequential tool path."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", args)

    start_events = []
    progress_events = []
    agent.tool_start_callback = lambda *a, **k: start_events.append((a, k))
    agent.tool_progress_callback = lambda *a, **k: progress_events.append((a, k))

    tool_call = _mock_tool_call("web_search", json.dumps(args), "c-block")
    assistant_msg = SimpleNamespace(content="", tool_calls=[tool_call])
    messages = []
    with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
        agent._execute_tool_calls_sequential(assistant_msg, messages, "task-1")

    # The real tool never runs and no UI callbacks fire.
    mock_hfc.assert_not_called()
    assert start_events == []
    assert progress_events == []
    # The model still receives exactly one synthetic tool result for the call id.
    assert len(messages) == 1
    synthetic = messages[0]
    assert synthetic["role"] == "tool"
    assert synthetic["tool_call_id"] == "c-block"
    assert "repeated_exact_failure_block" in synthetic["content"]
def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
    """Warning guidance is folded into the single tool result, not emitted separately."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    # One prior failure: the next identical failure should warn, not block.
    _seed_exact_failures(agent, "web_search", args, count=1)

    tool_call = _mock_tool_call("web_search", json.dumps(args), "c-warn")
    assistant_msg = SimpleNamespace(content="", tool_calls=[tool_call])
    messages = []
    failing_result = json.dumps({"error": "boom"})
    with patch("run_agent.handle_function_call", return_value=failing_result):
        agent._execute_tool_calls_sequential(assistant_msg, messages, "task-1")

    # Exactly one tool message; the guidance rides along inside its content.
    assert [m["role"] for m in messages] == ["tool"]
    assert messages[0]["tool_call_id"] == "c-warn"
    assert "Tool guardrail" in messages[0]["content"]
    assert "repeated_exact_failure_warning" in messages[0]["content"]
def test_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
    """Blocked calls are never submitted to workers, yet results keep the batch order."""
    agent = _make_agent("web_search")
    blocked_args = {"query": "blocked"}
    allowed_args = {"query": "allowed"}
    _seed_exact_failures(agent, "web_search", blocked_args)

    starts = []
    progress_events = []
    agent.tool_start_callback = lambda tool_call_id, name, args: starts.append((tool_call_id, name, args))
    agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))

    batch = [
        _mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
        _mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
    ]
    assistant_msg = SimpleNamespace(content="", tool_calls=batch)
    messages = []
    executed = []

    def fake_handle(name, args, task_id, **kwargs):
        executed.append((name, args, kwargs["tool_call_id"]))
        return json.dumps({"ok": args["query"]})

    with patch("run_agent.handle_function_call", side_effect=fake_handle):
        agent._execute_tool_calls_concurrent(assistant_msg, messages, "task-1")

    # Only the allowed call reaches the tool runner.
    assert executed == [("web_search", allowed_args, "c-allow")]
    # Tool results come back in the original batch order regardless of blocking.
    assert [m["tool_call_id"] for m in messages] == ["c-block", "c-allow"]
    assert "repeated_exact_failure_block" in messages[0]["content"]
    assert json.loads(messages[1]["content"]) == {"ok": "allowed"}
    # UI callbacks fire only for the executed call.
    assert starts == [("c-allow", "web_search", allowed_args)]
    started = [event for event in progress_events if event[0] == "tool.started"]
    completed = [event for event in progress_events if event[0] == "tool.completed"]
    assert started == [("tool.started", "web_search", allowed_args, {})]
    assert len(completed) == 1
    assert completed[0][1] == "web_search"
def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
    """A plugin-level block takes precedence and leaves guardrail counters untouched."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    tool_call = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
    assistant_msg = SimpleNamespace(content="", tool_calls=[tool_call])
    messages = []

    plugin_patch = patch(
        "hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"
    )
    with plugin_patch, patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
        agent._execute_tool_calls_sequential(assistant_msg, messages, "task-1")

    mock_hfc.assert_not_called()
    assert "plugin policy" in messages[0]["content"]
    # The guardrail must not have recorded the plugin block as one of its own.
    assert agent._tool_guardrails.before_call("web_search", args).action == "allow"
def test_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
    """End-to-end: repeated identical tool failures end the turn as a controlled
    guardrail halt (completed, no top-level "error" key) well before the
    iteration budget is exhausted.
    """
    agent = _make_agent("web_search", max_iterations=10)
    same_args = {"query": "same"}
    # Nine identical tool-call responses — far more than the guardrail should
    # ever let the loop consume.
    responses = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
        )
        for i in range(1, 10)
    ]
    agent.client.chat.completions.create.side_effect = responses
    # Every tool execution fails identically; persistence/cleanup hooks are
    # stubbed so the test stays side-effect free.
    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")
    # Only two real executions (block fires before the third); three API calls total.
    assert mock_hfc.call_count == 2
    assert result["api_calls"] == 3
    assert result["api_calls"] < agent.max_iterations
    assert result["turn_exit_reason"] == "guardrail_halt"
    # Controlled halt: the turn completes normally rather than erroring out.
    assert "error" not in result
    assert result["completed"] is True
    assert "stopped retrying" in result["final_response"]
    assert result["guardrail"]["code"] == "repeated_exact_failure_block"
    assert result["guardrail"]["tool_name"] == "web_search"
    # Transcript integrity: every assistant tool call has a matching tool result.
    assistant_tool_calls = [m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls")]
    for assistant_msg in assistant_tool_calls:
        call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
        following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids]
        assert len(following_results) == len(call_ids)