diff --git a/agent/display.py b/agent/display.py index 474595d76c..43b35ed301 100644 --- a/agent/display.py +++ b/agent/display.py @@ -14,6 +14,7 @@ from difflib import unified_diff from pathlib import Path from utils import safe_json_loads +from agent.tool_guardrails import classify_tool_failure # ANSI escape codes for coloring tool failure indicators _RED = "\033[31m" @@ -808,30 +809,7 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str] like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic failures. On success, returns ``(False, "")``. """ - if result is None: - return False, "" - - if tool_name == "terminal": - data = safe_json_loads(result) - if isinstance(data, dict): - exit_code = data.get("exit_code") - if exit_code is not None and exit_code != 0: - return True, f" [exit {exit_code}]" - return False, "" - - # Memory-specific: distinguish "full" from real errors - if tool_name == "memory": - data = safe_json_loads(result) - if isinstance(data, dict): - if data.get("success") is False and "exceed the limit" in data.get("error", ""): - return True, " [full]" - - # Generic heuristic for non-terminal tools - lower = result[:500].lower() - if '"error"' in lower or '"failed"' in lower or result.startswith("Error"): - return True, " [error]" - - return False, "" + return classify_tool_failure(tool_name, result) def get_cute_tool_message( diff --git a/agent/tool_guardrails.py b/agent/tool_guardrails.py new file mode 100644 index 0000000000..c8a7aa009a --- /dev/null +++ b/agent/tool_guardrails.py @@ -0,0 +1,381 @@ +"""Pure tool-call loop guardrail primitives. + +The controller in this module is intentionally side-effect free: it tracks +per-turn tool-call observations and returns decisions. Runtime code owns whether +those decisions become synthetic tool results or controlled turn halts. 
+""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from typing import Any, Mapping + +from utils import safe_json_loads + + +IDEMPOTENT_TOOL_NAMES = frozenset( + { + "read_file", + "search_files", + "web_search", + "web_extract", + "session_search", + "browser_snapshot", + "browser_console", + "browser_get_images", + "mcp_filesystem_read_file", + "mcp_filesystem_read_text_file", + "mcp_filesystem_read_multiple_files", + "mcp_filesystem_list_directory", + "mcp_filesystem_list_directory_with_sizes", + "mcp_filesystem_directory_tree", + "mcp_filesystem_get_file_info", + "mcp_filesystem_search_files", + } +) + +MUTATING_TOOL_NAMES = frozenset( + { + "terminal", + "execute_code", + "write_file", + "patch", + "todo", + "memory", + "skill_manage", + "browser_click", + "browser_type", + "browser_press", + "browser_scroll", + "browser_navigate", + "send_message", + "cronjob", + "delegate_task", + "process", + } +) + + +@dataclass(frozen=True) +class ToolCallGuardrailConfig: + """Thresholds for per-turn tool-call loop detection.""" + + exact_failure_warn_after: int = 2 + exact_failure_block_after: int = 2 + same_tool_failure_warn_after: int = 3 + same_tool_failure_halt_after: int = 5 + no_progress_warn_after: int = 2 + no_progress_block_after: int = 2 + idempotent_tools: frozenset[str] = field(default_factory=lambda: IDEMPOTENT_TOOL_NAMES) + mutating_tools: frozenset[str] = field(default_factory=lambda: MUTATING_TOOL_NAMES) + + +@dataclass(frozen=True) +class ToolCallSignature: + """Stable, non-reversible identity for a tool name plus canonical args.""" + + tool_name: str + args_hash: str + + @classmethod + def from_call(cls, tool_name: str, args: Mapping[str, Any] | None) -> "ToolCallSignature": + canonical = canonical_tool_args(args or {}) + return cls(tool_name=tool_name, args_hash=_sha256(canonical)) + + def to_metadata(self) -> dict[str, str]: + """Return public metadata without raw argument values.""" + 
return {"tool_name": self.tool_name, "args_hash": self.args_hash} + + +@dataclass(frozen=True) +class ToolGuardrailDecision: + """Decision returned by the tool-call guardrail controller.""" + + action: str = "allow" # allow | warn | block | halt + code: str = "allow" + message: str = "" + tool_name: str = "" + count: int = 0 + signature: ToolCallSignature | None = None + + @property + def allows_execution(self) -> bool: + return self.action in {"allow", "warn"} + + @property + def should_halt(self) -> bool: + return self.action in {"block", "halt"} + + def to_metadata(self) -> dict[str, Any]: + data: dict[str, Any] = { + "action": self.action, + "code": self.code, + "message": self.message, + "tool_name": self.tool_name, + "count": self.count, + } + if self.signature is not None: + data["signature"] = self.signature.to_metadata() + return data + + +def canonical_tool_args(args: Mapping[str, Any]) -> str: + """Return sorted compact JSON for parsed tool arguments.""" + if not isinstance(args, Mapping): + raise TypeError(f"tool args must be a mapping, got {type(args).__name__}") + return json.dumps( + args, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + default=str, + ) + + +def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]: + """Classify a tool result using shared display/runtime semantics.""" + if result is None: + return False, "" + + if tool_name == "terminal": + data = safe_json_loads(result) + if isinstance(data, dict): + exit_code = data.get("exit_code") + if exit_code is not None and exit_code != 0: + return True, f" [exit {exit_code}]" + if data.get("success") is False or data.get("failed") is True: + return True, " [error]" + error = data.get("error") + if error is not None and error != "": + return True, " [error]" + return False, "" + + data = safe_json_loads(result) + if isinstance(data, dict): + if tool_name == "memory": + error = data.get("error", "") + if data.get("success") is False and 
isinstance(error, str) and "exceed the limit" in error: + return True, " [full]" + if data.get("success") is False or data.get("failed") is True: + return True, " [error]" + error = data.get("error") + if error is not None and error != "": + return True, " [error]" + return False, "" + + lower = result[:500].lower() + if "traceback" in lower or lower.startswith("error:"): + return True, " [error]" + if '"error"' in lower or '"failed"' in lower or result.startswith("Error"): + return True, " [error]" + return False, "" + + +class ToolCallGuardrailController: + """Per-turn controller for repeated failed/non-progressing tool calls.""" + + def __init__(self, config: ToolCallGuardrailConfig | None = None): + self.config = config or ToolCallGuardrailConfig() + self.reset_for_turn() + + def reset_for_turn(self) -> None: + self._exact_failure_counts: dict[ToolCallSignature, int] = {} + self._same_tool_failure_counts: dict[str, int] = {} + self._no_progress: dict[ToolCallSignature, tuple[str, int]] = {} + self._halt_decision: ToolGuardrailDecision | None = None + + @property + def halt_decision(self) -> ToolGuardrailDecision | None: + return self._halt_decision + + def before_call(self, tool_name: str, args: Mapping[str, Any] | None) -> ToolGuardrailDecision: + signature = ToolCallSignature.from_call(tool_name, _coerce_args(args)) + + exact_count = self._exact_failure_counts.get(signature, 0) + if exact_count >= self.config.exact_failure_block_after: + decision = ToolGuardrailDecision( + action="block", + code="repeated_exact_failure_block", + message=( + f"Blocked {tool_name}: the same tool call failed {exact_count} " + "times with identical arguments. Stop retrying it unchanged; " + "change strategy or explain the blocker." 
+ ), + tool_name=tool_name, + count=exact_count, + signature=signature, + ) + self._halt_decision = decision + return decision + + if self._is_idempotent(tool_name): + record = self._no_progress.get(signature) + if record is not None: + _result_hash, repeat_count = record + if repeat_count >= self.config.no_progress_block_after: + decision = ToolGuardrailDecision( + action="block", + code="idempotent_no_progress_block", + message=( + f"Blocked {tool_name}: this read-only call returned the same " + f"result {repeat_count} times. Stop repeating it unchanged; " + "use the result already provided or try a different query." + ), + tool_name=tool_name, + count=repeat_count, + signature=signature, + ) + self._halt_decision = decision + return decision + + return ToolGuardrailDecision(tool_name=tool_name, signature=signature) + + def after_call( + self, + tool_name: str, + args: Mapping[str, Any] | None, + result: str | None, + *, + failed: bool | None = None, + ) -> ToolGuardrailDecision: + args = _coerce_args(args) + signature = ToolCallSignature.from_call(tool_name, args) + if failed is None: + failed, _ = classify_tool_failure(tool_name, result) + + if failed: + exact_count = self._exact_failure_counts.get(signature, 0) + 1 + self._exact_failure_counts[signature] = exact_count + self._no_progress.pop(signature, None) + + same_count = self._same_tool_failure_counts.get(tool_name, 0) + 1 + self._same_tool_failure_counts[tool_name] = same_count + + if same_count >= self.config.same_tool_failure_halt_after: + decision = ToolGuardrailDecision( + action="halt", + code="same_tool_failure_halt", + message=( + f"Stopped {tool_name}: it failed {same_count} times this turn. " + "Stop retrying the same failing tool path and choose a different approach." 
+ ), + tool_name=tool_name, + count=same_count, + signature=signature, + ) + self._halt_decision = decision + return decision + + if exact_count >= self.config.exact_failure_warn_after: + return ToolGuardrailDecision( + action="warn", + code="repeated_exact_failure_warning", + message=( + f"Tool guardrail: {tool_name} has failed {exact_count} times " + "with identical arguments. Do not retry it unchanged; inspect the " + "error and change strategy." + ), + tool_name=tool_name, + count=exact_count, + signature=signature, + ) + + if same_count >= self.config.same_tool_failure_warn_after: + return ToolGuardrailDecision( + action="warn", + code="same_tool_failure_warning", + message=( + f"Tool guardrail: {tool_name} has failed {same_count} times " + "this turn. Change approach before retrying." + ), + tool_name=tool_name, + count=same_count, + signature=signature, + ) + + return ToolGuardrailDecision(tool_name=tool_name, count=exact_count, signature=signature) + + self._exact_failure_counts.pop(signature, None) + self._same_tool_failure_counts.pop(tool_name, None) + + if not self._is_idempotent(tool_name): + self._no_progress.pop(signature, None) + return ToolGuardrailDecision(tool_name=tool_name, signature=signature) + + result_hash = _result_hash(result) + previous = self._no_progress.get(signature) + repeat_count = 1 + if previous is not None and previous[0] == result_hash: + repeat_count = previous[1] + 1 + self._no_progress[signature] = (result_hash, repeat_count) + + if repeat_count >= self.config.no_progress_warn_after: + return ToolGuardrailDecision( + action="warn", + code="idempotent_no_progress_warning", + message=( + f"Tool guardrail: {tool_name} returned the same result " + f"{repeat_count} times. Use the result or change the query instead " + "of repeating it unchanged." 
+ ), + tool_name=tool_name, + count=repeat_count, + signature=signature, + ) + + return ToolGuardrailDecision(tool_name=tool_name, count=repeat_count, signature=signature) + + def _is_idempotent(self, tool_name: str) -> bool: + if tool_name in self.config.mutating_tools: + return False + return tool_name in self.config.idempotent_tools + + +def toolguard_synthetic_result(decision: ToolGuardrailDecision) -> str: + """Build a synthetic role=tool content string for a blocked tool call.""" + return json.dumps( + { + "error": decision.message, + "guardrail": decision.to_metadata(), + }, + ensure_ascii=False, + ) + + +def append_toolguard_guidance(result: str, decision: ToolGuardrailDecision) -> str: + """Append runtime guidance to the current tool result content.""" + if decision.action not in {"warn", "halt"} or not decision.message: + return result + suffix = ( + "\n\n[Tool guardrail: " + f"{decision.code}; count={decision.count}; {decision.message}]" + ) + return (result or "") + suffix + + +def _coerce_args(args: Mapping[str, Any] | None) -> Mapping[str, Any]: + return args if isinstance(args, Mapping) else {} + + +def _result_hash(result: str | None) -> str: + parsed = safe_json_loads(result or "") + if parsed is not None: + try: + canonical = json.dumps( + parsed, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + default=str, + ) + except TypeError: + canonical = str(parsed) + else: + canonical = result or "" + return _sha256(canonical) + + +def _sha256(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest() diff --git a/run_agent.py b/run_agent.py index 2645a14a60..20b396f01e 100644 --- a/run_agent.py +++ b/run_agent.py @@ -162,6 +162,12 @@ from agent.display import ( _detect_tool_failure, get_tool_emoji as _get_tool_emoji, ) +from agent.tool_guardrails import ( + ToolCallGuardrailController, + ToolGuardrailDecision, + append_toolguard_guidance, + toolguard_synthetic_result, +) from agent.trajectory import ( 
convert_scratchpad_to_think, has_incomplete_scratchpad, save_trajectory as _save_trajectory_to_file, @@ -1150,6 +1156,8 @@ class AIAgent: # Tool execution state — allows _vprint during tool execution # even when stream consumers are registered (no tokens streaming then) self._executing_tools = False + self._tool_guardrails = ToolCallGuardrailController() + self._tool_guardrail_halt_decision: ToolGuardrailDecision | None = None # Interrupt mechanism for breaking out of tool loops self._interrupt_requested = False @@ -9107,6 +9115,44 @@ class AIAgent: ) return compressed, new_system_prompt + def _set_tool_guardrail_halt(self, decision: ToolGuardrailDecision) -> None: + """Record the first guardrail decision that should stop this turn.""" + if decision.should_halt and self._tool_guardrail_halt_decision is None: + self._tool_guardrail_halt_decision = decision + + def _toolguard_controlled_halt_response(self, decision: ToolGuardrailDecision) -> str: + tool = decision.tool_name or "a tool" + return ( + f"I stopped retrying {tool} because it hit the tool-call guardrail " + f"({decision.code}) after {decision.count} repeated non-progressing " + "attempts. The last tool result explains the blocker; the next step is " + "to change strategy instead of repeating the same call." 
+ ) + + def _append_guardrail_observation( + self, + tool_name: str, + function_args: dict, + function_result: str, + *, + failed: bool, + ) -> str: + decision = self._tool_guardrails.after_call( + tool_name, + function_args, + function_result, + failed=failed, + ) + if decision.action in {"warn", "halt"}: + function_result = append_toolguard_guidance(function_result, decision) + if decision.should_halt: + self._set_tool_guardrail_halt(decision) + return function_result + + def _guardrail_block_result(self, decision: ToolGuardrailDecision) -> str: + self._set_tool_guardrail_halt(decision) + return toolguard_synthetic_result(decision) + def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: """Execute tool calls from the assistant message and append results to messages. @@ -9150,7 +9196,8 @@ class AIAgent: ) def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str, - tool_call_id: Optional[str] = None, messages: list = None) -> str: + tool_call_id: Optional[str] = None, messages: list = None, + pre_tool_block_checked: bool = False) -> str: """Invoke a single tool and return the result string. No display logic. Handles both agent-level tools (todo, memory, etc.) and registry-dispatched @@ -9159,13 +9206,14 @@ class AIAgent: """ # Check plugin hooks for a block directive before executing anything. 
block_message: Optional[str] = None - try: - from hermes_cli.plugins import get_pre_tool_call_block_message - block_message = get_pre_tool_call_block_message( - function_name, function_args, task_id=effective_task_id or "", - ) - except Exception: - pass + if not pre_tool_block_checked: + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + pass if block_message is not None: return json.dumps({"error": block_message}, ensure_ascii=False) @@ -9317,13 +9365,31 @@ class AIAgent: except Exception: pass - parsed_calls.append((tool_call, function_name, function_args)) + block_result = None + blocked_by_guardrail = False + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + block_message = None + + if block_message is not None: + block_result = json.dumps({"error": block_message}, ensure_ascii=False) + else: + guardrail_decision = self._tool_guardrails.before_call(function_name, function_args) + if not guardrail_decision.allows_execution: + block_result = self._guardrail_block_result(guardrail_decision) + blocked_by_guardrail = True + + parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail)) # ── Logging / callbacks ────────────────────────────────────────── - tool_names_str = ", ".join(name for _, name, _ in parsed_calls) + tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls) if not self.quiet_mode: print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}") - for i, (tc, name, args) in enumerate(parsed_calls, 1): + for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1): args_str = json.dumps(args, ensure_ascii=False) if self.verbose_logging: print(f" 📞 Tool 
{i}: {name}({list(args.keys())})") @@ -9332,7 +9398,9 @@ class AIAgent: args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}") - for tc, name, args in parsed_calls: + for tc, name, args, block_result, blocked_by_guardrail in parsed_calls: + if block_result is not None: + continue if self.tool_progress_callback: try: preview = _build_tool_preview(name, args) @@ -9340,7 +9408,9 @@ class AIAgent: except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") - for tc, name, args in parsed_calls: + for tc, name, args, block_result, blocked_by_guardrail in parsed_calls: + if block_result is not None: + continue if self.tool_start_callback: try: self.tool_start_callback(tc.id, name, args) @@ -9348,8 +9418,11 @@ class AIAgent: logging.debug(f"Tool start callback error: {cb_err}") # ── Concurrent execution ───────────────────────────────────────── - # Each slot holds (function_name, function_args, function_result, duration, error_flag) + # Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag) results = [None] * num_tools + for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls): + if block_result is not None: + results[i] = (name, args, block_result, 0.0, True, True) # Touch activity before launching workers so the gateway knows # we're executing tools (not stuck). 
@@ -9404,7 +9477,14 @@ class AIAgent: pass start = time.time() try: - result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id, messages=messages) + result = self._invoke_tool( + function_name, + function_args, + effective_task_id, + tool_call.id, + messages=messages, + pre_tool_block_checked=True, + ) except Exception as tool_error: result = f"Error executing tool '{function_name}': {tool_error}" logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True) @@ -9414,7 +9494,7 @@ class AIAgent: logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200]) else: logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result)) - results[index] = (function_name, function_args, result, duration, is_error) + results[index] = (function_name, function_args, result, duration, is_error, False) # Tear down worker-tid tracking. Clear any interrupt bit we may # have set so the next task scheduled onto this recycled tid # starts with a clean slate. @@ -9440,61 +9520,67 @@ class AIAgent: spinner.start() try: - max_workers = min(num_tools, _MAX_TOOL_WORKERS) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for i, (tc, name, args) in enumerate(parsed_calls): - # Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread. - ctx = contextvars.copy_context() - f = executor.submit(ctx.run, _run_tool, i, tc, name, args) - futures.append(f) + runnable_calls = [ + (i, tc, name, args) + for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls) + if block_result is None + ] + futures = [] + if runnable_calls: + max_workers = min(len(runnable_calls), _MAX_TOOL_WORKERS) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + for i, tc, name, args in runnable_calls: + # Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread. 
+ ctx = contextvars.copy_context() + f = executor.submit(ctx.run, _run_tool, i, tc, name, args) + futures.append(f) - # Wait for all to complete with periodic heartbeats so the - # gateway's inactivity monitor doesn't kill us during long - # concurrent tool batches. Also check for user interrupts - # so we don't block indefinitely when the user sends /stop - # or a new message during concurrent tool execution. - _conc_start = time.time() - _interrupt_logged = False - while True: - done, not_done = concurrent.futures.wait( - futures, timeout=5.0, - ) - if not not_done: - break - - # Check for interrupt — the per-thread interrupt signal - # already causes individual tools (terminal, execute_code) - # to abort, but tools without interrupt checks (web_search, - # read_file) will run to completion. Cancel any futures - # that haven't started yet so we don't block on them. - if self._interrupt_requested: - if not _interrupt_logged: - _interrupt_logged = True - self._vprint( - f"{self.log_prefix}⚡ Interrupt: cancelling " - f"{len(not_done)} pending concurrent tool(s)", - force=True, - ) - for f in not_done: - f.cancel() - # Give already-running tools a moment to notice the - # per-thread interrupt signal and exit gracefully. - concurrent.futures.wait(not_done, timeout=3.0) - break - - _conc_elapsed = int(time.time() - _conc_start) - # Heartbeat every ~30s (6 × 5s poll intervals) - if _conc_elapsed > 0 and _conc_elapsed % 30 < 6: - _still_running = [ - parsed_calls[futures.index(f)][1] - for f in not_done - if f in futures - ] - self._touch_activity( - f"concurrent tools running ({_conc_elapsed}s, " - f"{len(not_done)} remaining: {', '.join(_still_running[:3])})" + # Wait for all to complete with periodic heartbeats so the + # gateway's inactivity monitor doesn't kill us during long + # concurrent tool batches. Also check for user interrupts + # so we don't block indefinitely when the user sends /stop + # or a new message during concurrent tool execution. 
+ _conc_start = time.time() + _interrupt_logged = False + while True: + done, not_done = concurrent.futures.wait( + futures, timeout=5.0, ) + if not not_done: + break + + # Check for interrupt — the per-thread interrupt signal + # already causes individual tools (terminal, execute_code) + # to abort, but tools without interrupt checks (web_search, + # read_file) will run to completion. Cancel any futures + # that haven't started yet so we don't block on them. + if self._interrupt_requested: + if not _interrupt_logged: + _interrupt_logged = True + self._vprint( + f"{self.log_prefix}⚡ Interrupt: cancelling " + f"{len(not_done)} pending concurrent tool(s)", + force=True, + ) + for f in not_done: + f.cancel() + # Give already-running tools a moment to notice the + # per-thread interrupt signal and exit gracefully. + concurrent.futures.wait(not_done, timeout=3.0) + break + + _conc_elapsed = int(time.time() - _conc_start) + # Heartbeat every ~30s (6 × 5s poll intervals) + if _conc_elapsed > 0 and _conc_elapsed % 30 < 6: + _still_running = [ + parsed_calls[futures.index(f)][1] + for f in not_done + if f in futures + ] + self._touch_activity( + f"concurrent tools running ({_conc_elapsed}s, " + f"{len(not_done)} remaining: {', '.join(_still_running[:3])})" + ) finally: if spinner: # Build a summary message for the spinner stop @@ -9503,8 +9589,9 @@ class AIAgent: spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total") # ── Post-execution: display per-tool results ───────────────────── - for i, (tc, name, args) in enumerate(parsed_calls): + for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls): r = results[i] + blocked = False if r is None: # Tool was cancelled (interrupt) or thread didn't return if self._interrupt_requested: @@ -9513,13 +9600,21 @@ class AIAgent: function_result = f"Error executing tool '{name}': thread did not return a result" tool_duration = 0.0 else: - function_name, function_args, 
function_result, tool_duration, is_error = r + function_name, function_args, function_result, tool_duration, is_error, blocked = r + + if not blocked: + function_result = self._append_guardrail_observation( + function_name, + function_args, + function_result, + failed=is_error, + ) if is_error: result_preview = function_result[:200] if len(function_result) > 200 else function_result logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview) - if self.tool_progress_callback: + if not blocked and self.tool_progress_callback: try: self.tool_progress_callback( "tool.completed", function_name, None, None, @@ -9547,7 +9642,7 @@ class AIAgent: self._current_tool = None self._touch_activity(f"tool completed: {name} ({tool_duration:.1f}s)") - if self.tool_complete_callback: + if not blocked and self.tool_complete_callback: try: self.tool_complete_callback(tc.id, name, args, function_result) except Exception as cb_err: @@ -9629,9 +9724,17 @@ class AIAgent: except Exception: pass - if _block_msg is not None: - # Tool blocked by plugin policy — skip counter resets. - # Execution is handled below in the tool dispatch chain. + _guardrail_block_decision: ToolGuardrailDecision | None = None + if _block_msg is None: + guardrail_decision = self._tool_guardrails.before_call(function_name, function_args) + if not guardrail_decision.allows_execution: + _guardrail_block_decision = guardrail_decision + + _execution_blocked = _block_msg is not None or _guardrail_block_decision is not None + + if _execution_blocked: + # Tool blocked by plugin or guardrail policy — skip counters, + # callbacks, checkpointing, activity mutation, and real execution. pass else: # Reset nudge counters when the relevant tool is actually used @@ -9649,35 +9752,35 @@ class AIAgent: args_preview = args_str[:self.log_prefix_chars] + "..." 
if len(args_str) > self.log_prefix_chars else args_str print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") - if _block_msg is None: + if not _execution_blocked: self._current_tool = function_name self._touch_activity(f"executing tool: {function_name}") # Set activity callback for long-running tool execution (terminal # commands, etc.) so the gateway's inactivity monitor doesn't kill # the agent while a command is running. - if _block_msg is None: + if not _execution_blocked: try: from tools.environments.base import set_activity_callback set_activity_callback(self._touch_activity) except Exception: pass - if _block_msg is None and self.tool_progress_callback: + if not _execution_blocked and self.tool_progress_callback: try: preview = _build_tool_preview(function_name, function_args) self.tool_progress_callback("tool.started", function_name, preview, function_args) except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") - if _block_msg is None and self.tool_start_callback: + if not _execution_blocked and self.tool_start_callback: try: self.tool_start_callback(tool_call.id, function_name, function_args) except Exception as cb_err: logging.debug(f"Tool start callback error: {cb_err}") # Checkpoint: snapshot working dir before file-mutating tools - if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: + if not _execution_blocked and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: try: file_path = function_args.get("path", "") if file_path: @@ -9689,7 +9792,7 @@ class AIAgent: pass # never block tool execution # Checkpoint before destructive terminal commands - if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled: + if not _execution_blocked and function_name == "terminal" and self._checkpoint_mgr.enabled: try: cmd = function_args.get("command", "") if _is_destructive_command(cmd): @@ -9706,6 +9809,11 
@@ class AIAgent: # Tool blocked by plugin policy — return error without executing. function_result = json.dumps({"error": _block_msg}, ensure_ascii=False) tool_duration = 0.0 + elif _guardrail_block_decision is not None: + # Tool blocked by tool-loop guardrail — synthesize exactly one + # tool result for the original tool_call_id without executing. + function_result = self._guardrail_block_result(_guardrail_block_decision) + tool_duration = 0.0 elif function_name == "todo": from tools.todo_tool import todo_tool as _todo_tool function_result = _todo_tool( @@ -9889,12 +9997,22 @@ class AIAgent: # Log tool errors to the persistent error log so [error] tags # in the UI always have a corresponding detailed entry on disk. _is_error_result, _ = _detect_tool_failure(function_name, function_result) + if not _execution_blocked: + function_result = self._append_guardrail_observation( + function_name, + function_args, + function_result, + failed=_is_error_result, + ) + result_preview = function_result if self.verbose_logging else ( + function_result[:200] if len(function_result) > 200 else function_result + ) if _is_error_result: logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview) else: logger.info("tool %s completed (%.2fs, %d chars)", function_name, tool_duration, len(function_result)) - if self.tool_progress_callback: + if not _execution_blocked and self.tool_progress_callback: try: self.tool_progress_callback( "tool.completed", function_name, None, None, @@ -9910,7 +10028,7 @@ class AIAgent: logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s") logging.debug(f"Tool result ({len(function_result)} chars): {function_result}") - if self.tool_complete_callback: + if not _execution_blocked and self.tool_complete_callback: try: self.tool_complete_callback(tool_call.id, function_name, function_args, function_result) except Exception as cb_err: @@ -10244,6 +10362,8 @@ class AIAgent: 
self._last_content_tools_all_housekeeping = False self._mute_post_response = False self._unicode_sanitization_passes = 0 + self._tool_guardrails.reset_for_turn() + self._tool_guardrail_halt_decision = None # Pre-turn connection health check: detect and clean up dead TCP # connections left over from provider outages or dropped streams. @@ -13041,6 +13161,16 @@ class AIAgent: self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count) + if self._tool_guardrail_halt_decision is not None: + decision = self._tool_guardrail_halt_decision + _turn_exit_reason = "guardrail_halt" + final_response = self._toolguard_controlled_halt_response(decision) + self._emit_status( + f"⚠️ Tool guardrail halted {decision.tool_name}: {decision.code}" + ) + messages.append({"role": "assistant", "content": final_response}) + break + # Reset per-turn retry counters after successful tool # execution so a single truncation doesn't poison the # entire conversation. @@ -13567,6 +13697,7 @@ class AIAgent: "messages": messages, "api_calls": api_call_count, "completed": completed, + "turn_exit_reason": _turn_exit_reason, "partial": False, # True only when stopped due to invalid tool calls "interrupted": interrupted, "response_previewed": getattr(self, "_response_was_previewed", False), @@ -13586,6 +13717,8 @@ class AIAgent: "cost_status": self.session_cost_status, "cost_source": self.session_cost_source, } + if self._tool_guardrail_halt_decision is not None: + result["guardrail"] = self._tool_guardrail_halt_decision.to_metadata() # If a /steer landed after the final assistant turn (no more tool # batches to drain into), hand it back to the caller so it can be # delivered as the next user turn instead of being silently lost. 
"""Tests for the pure tool-call guardrail primitives.

Behavior-focused coverage of signature hashing, the warn/block/halt
escalation ladders, and per-turn state reset — no agent runtime involved.
"""

import json

from agent.tool_guardrails import (
    ToolCallGuardrailConfig,
    ToolCallGuardrailController,
    ToolCallSignature,
    canonical_tool_args,
)


def test_tool_call_signature_hashes_canonical_nested_unicode_args_without_exposing_raw_args():
    # Same logical payload, different key ordering at every nesting level.
    ordering_one = {
        "z": [{"β": "☤", "a": 1}],
        "a": {"y": 2, "x": "secret-token-value"},
    }
    ordering_two = {
        "a": {"x": "secret-token-value", "y": 2},
        "z": [{"a": 1, "β": "☤"}],
    }

    # Canonicalization must be order-insensitive…
    assert canonical_tool_args(ordering_one) == canonical_tool_args(ordering_two)

    # …so both orderings collapse to the same signature.
    first_sig = ToolCallSignature.from_call("web_search", ordering_one)
    second_sig = ToolCallSignature.from_call("web_search", ordering_two)
    assert first_sig == second_sig

    # SHA-256 hex digest length.
    assert len(first_sig.args_hash) == 64

    # Metadata exposes only the tool name and hash — never raw argument values.
    metadata = first_sig.to_metadata()
    assert metadata == {"tool_name": "web_search", "args_hash": first_sig.args_hash}
    serialized = json.dumps(metadata)
    assert "secret-token-value" not in serialized
    assert "☤" not in serialized


def test_repeated_identical_failed_call_warns_then_blocks_before_third_execution():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(
            exact_failure_warn_after=2,
            exact_failure_block_after=2,
            same_tool_failure_halt_after=99,
        )
    )
    query = {"query": "same"}

    # First identical failure: allowed, no escalation yet.
    assert controller.before_call("web_search", query).action == "allow"
    assert controller.after_call("web_search", query, '{"error":"boom"}', failed=True).action == "allow"

    # Second identical failure: still executes, but surfaces a warning.
    assert controller.before_call("web_search", query).action == "allow"
    warn_decision = controller.after_call("web_search", query, '{"error":"boom"}', failed=True)
    assert warn_decision.action == "warn"
    assert warn_decision.code == "repeated_exact_failure_warning"
    assert warn_decision.count == 2

    # A third attempt with the same signature is blocked before execution.
    block_decision = controller.before_call("web_search", query)
    assert block_decision.action == "block"
    assert block_decision.code == "repeated_exact_failure_block"
    assert block_decision.tool_name == "web_search"
    assert block_decision.count == 2


def test_success_resets_exact_signature_failure_streak():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(exact_failure_block_after=2, same_tool_failure_halt_after=99)
    )
    query = {"query": "same"}

    # One failure followed by a success wipes the exact-signature streak…
    controller.after_call("web_search", query, '{"error":"boom"}', failed=True)
    controller.after_call("web_search", query, '{"ok":true}', failed=False)

    # …so subsequent failures start counting from one again.
    assert controller.before_call("web_search", query).action == "allow"
    controller.after_call("web_search", query, '{"error":"boom"}', failed=True)
    assert controller.before_call("web_search", query).action == "allow"


def test_same_tool_varying_args_failure_streak_warns_then_halts_independent_of_exact_streak():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(
            exact_failure_block_after=99,
            same_tool_failure_warn_after=2,
            same_tool_failure_halt_after=3,
        )
    )

    # Three failures of the same tool with *different* args each time.
    decisions = [
        controller.after_call("terminal", {"command": f"cmd-{n}"}, '{"exit_code":1}', failed=True)
        for n in (1, 2, 3)
    ]

    assert decisions[0].action == "allow"
    assert decisions[1].action == "warn"
    assert decisions[1].code == "same_tool_failure_warning"
    assert decisions[2].action == "halt"
    assert decisions[2].code == "same_tool_failure_halt"
    assert decisions[2].count == 3


def test_idempotent_no_progress_repeated_result_warns_then_blocks_future_repeat():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
    )
    read_args = {"path": "/tmp/same.txt"}
    unchanged_result = "same file contents"

    # First identical read: nothing to report.
    assert controller.before_call("read_file", read_args).action == "allow"
    assert controller.after_call("read_file", read_args, unchanged_result, failed=False).action == "allow"

    # Second identical result: warn that no progress is being made.
    assert controller.before_call("read_file", read_args).action == "allow"
    warn_decision = controller.after_call("read_file", read_args, unchanged_result, failed=False)
    assert warn_decision.action == "warn"
    assert warn_decision.code == "idempotent_no_progress_warning"

    # A third identical call is blocked before it runs.
    block_decision = controller.before_call("read_file", read_args)
    assert block_decision.action == "block"
    assert block_decision.code == "idempotent_no_progress_block"


def test_mutating_or_unknown_tools_are_not_blocked_for_repeated_identical_success_output_by_default():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
    )

    # Mutating tools (write_file) and unregistered tools (custom_tool) are
    # exempt from no-progress escalation even when output never changes.
    for _ in range(3):
        assert controller.before_call("write_file", {"path": "/tmp/x", "content": "x"}).action == "allow"
        write_decision = controller.after_call(
            "write_file", {"path": "/tmp/x", "content": "x"}, "ok", failed=False
        )
        assert write_decision.action == "allow"
        assert controller.before_call("custom_tool", {"x": 1}).action == "allow"
        assert controller.after_call("custom_tool", {"x": 1}, "ok", failed=False).action == "allow"


def test_reset_for_turn_clears_bounded_guardrail_state():
    controller = ToolCallGuardrailController(
        ToolCallGuardrailConfig(exact_failure_block_after=2, no_progress_block_after=2)
    )

    # Drive both escalation ladders up to their block thresholds.
    controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
    controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
    controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
    controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)

    assert controller.before_call("web_search", {"query": "same"}).action == "block"
    assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "block"

    # A turn boundary discards all accumulated state.
    controller.reset_for_turn()

    assert controller.before_call("web_search", {"query": "same"}).action == "allow"
    assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "allow"
"""Runtime integration tests for the tool-call loop guardrails."""

import json
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from run_agent import AIAgent


def _make_tool_defs(*names: str) -> list[dict]:
    # Minimal OpenAI-style function schemas, one per requested tool name.
    return [
        {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": f"{tool_name} tool",
                "parameters": {"type": "object", "properties": {}},
            },
        }
        for tool_name in names
    ]


def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
    # Mirror the SDK tool-call object shape with plain namespaces; a falsy
    # call_id gets a random suffix so ids never collide across calls.
    return SimpleNamespace(
        id=call_id or f"call_{uuid.uuid4().hex[:8]}",
        type="function",
        function=SimpleNamespace(name=name, arguments=arguments),
    )


def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
    # Shape-compatible stand-in for a chat-completions response.
    message = SimpleNamespace(content=content, tool_calls=tool_calls)
    return SimpleNamespace(
        choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
        model="test/model",
        usage=None,
    )


def _make_agent(*tool_names: str, max_iterations: int = 10) -> AIAgent:
    # Construct an AIAgent with every external surface (tool definitions,
    # toolset checks, OpenAI client construction) patched out.
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(
            api_key="test-key-1234567890",
            base_url="https://openrouter.ai/api/v1",
            max_iterations=max_iterations,
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        agent.client = MagicMock()
        agent._cached_system_prompt = "You are helpful."
        agent._use_prompt_caching = False
        agent.tool_delay = 0
        agent.compression_enabled = False
        agent.save_trajectories = False
        return agent


def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
    # Pre-load the guardrail controller with identical failed observations.
    failure_payload = json.dumps({"error": "boom"})
    for _ in range(count):
        agent._tool_guardrails.after_call(tool_name, args, failure_payload, failed=True)


def test_sequential_path_blocks_repeated_exact_failure_before_execution():
    agent = _make_agent("web_search")
    query_args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", query_args)

    start_events = []
    progress_events = []
    agent.tool_start_callback = lambda *a, **k: start_events.append((a, k))
    agent.tool_progress_callback = lambda *a, **k: progress_events.append((a, k))

    assistant_msg = SimpleNamespace(
        content="",
        tool_calls=[_mock_tool_call("web_search", json.dumps(query_args), "c-block")],
    )
    conversation = []

    with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as handle_mock:
        agent._execute_tool_calls_sequential(assistant_msg, conversation, "task-1")

    # The real tool never ran and no lifecycle callbacks fired…
    handle_mock.assert_not_called()
    assert start_events == []
    assert progress_events == []

    # …yet exactly one synthetic tool result answers the original call id.
    assert len(conversation) == 1
    assert conversation[0]["role"] == "tool"
    assert conversation[0]["tool_call_id"] == "c-block"
    assert "repeated_exact_failure_block" in conversation[0]["content"]


def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
    agent = _make_agent("web_search")
    query_args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", query_args, count=1)

    assistant_msg = SimpleNamespace(
        content="",
        tool_calls=[_mock_tool_call("web_search", json.dumps(query_args), "c-warn")],
    )
    conversation = []

    with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
        agent._execute_tool_calls_sequential(assistant_msg, conversation, "task-1")

    # Warning guidance rides inside the single tool message — no extra messages.
    assert [message["role"] for message in conversation] == ["tool"]
    assert conversation[0]["tool_call_id"] == "c-warn"
    assert "Tool guardrail" in conversation[0]["content"]
    assert "repeated_exact_failure_warning" in conversation[0]["content"]


def test_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
    agent = _make_agent("web_search")
    blocked_args = {"query": "blocked"}
    allowed_args = {"query": "allowed"}
    _seed_exact_failures(agent, "web_search", blocked_args)

    start_events = []
    progress_events = []
    agent.tool_start_callback = lambda tool_call_id, name, args: start_events.append((tool_call_id, name, args))
    agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))

    assistant_msg = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
            _mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
        ],
    )
    conversation = []
    executed_calls = []

    def fake_handle(name, args, task_id, **kwargs):
        executed_calls.append((name, args, kwargs["tool_call_id"]))
        return json.dumps({"ok": args["query"]})

    with patch("run_agent.handle_function_call", side_effect=fake_handle):
        agent._execute_tool_calls_concurrent(assistant_msg, conversation, "task-1")

    # Only the allowed call reached the executor.
    assert executed_calls == [("web_search", allowed_args, "c-allow")]

    # Results preserve original call order: blocked synthetic result first.
    assert [message["tool_call_id"] for message in conversation] == ["c-block", "c-allow"]
    assert "repeated_exact_failure_block" in conversation[0]["content"]
    assert json.loads(conversation[1]["content"]) == {"ok": "allowed"}

    # Lifecycle callbacks fire only for the executed call.
    assert start_events == [("c-allow", "web_search", allowed_args)]
    started = [event for event in progress_events if event[0] == "tool.started"]
    completed = [event for event in progress_events if event[0] == "tool.completed"]
    assert started == [("tool.started", "web_search", allowed_args, {})]
    assert len(completed) == 1
    assert completed[0][1] == "web_search"


def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
    agent = _make_agent("web_search")
    query_args = {"query": "same"}
    assistant_msg = SimpleNamespace(
        content="",
        tool_calls=[_mock_tool_call("web_search", json.dumps(query_args), "c-plugin")],
    )
    conversation = []

    with (
        patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
        patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as handle_mock,
    ):
        agent._execute_tool_calls_sequential(assistant_msg, conversation, "task-1")

    handle_mock.assert_not_called()
    assert "plugin policy" in conversation[0]["content"]
    # The guardrail never observed the plugin-blocked call, so it still allows it.
    assert agent._tool_guardrails.before_call("web_search", query_args).action == "allow"


def test_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
    agent = _make_agent("web_search", max_iterations=10)
    same_args = {"query": "same"}
    agent.client.chat.completions.create.side_effect = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{index}")],
        )
        for index in range(1, 10)
    ]

    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as handle_mock,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")

    # Two real executions; the third round is halted by the guardrail well
    # before the iteration budget is exhausted.
    assert handle_mock.call_count == 2
    assert result["api_calls"] == 3
    assert result["api_calls"] < agent.max_iterations

    # Controlled halt: a completed turn with guardrail metadata, no error key.
    assert result["turn_exit_reason"] == "guardrail_halt"
    assert "error" not in result
    assert result["completed"] is True
    assert "stopped retrying" in result["final_response"]
    assert result["guardrail"]["code"] == "repeated_exact_failure_block"
    assert result["guardrail"]["tool_name"] == "web_search"

    # Every assistant tool call still has exactly one matching tool result.
    assistant_tool_calls = [
        message
        for message in result["messages"]
        if message.get("role") == "assistant" and message.get("tool_calls")
    ]
    for assistant_message in assistant_tool_calls:
        call_ids = [tool_call["id"] for tool_call in assistant_message["tool_calls"]]
        matching_results = [
            message
            for message in result["messages"]
            if message.get("role") == "tool" and message.get("tool_call_id") in call_ids
        ]
        assert len(matching_results) == len(call_ids)