mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-02 02:01:47 +00:00
fix(agent): add tool-call loop guardrails
This commit is contained in:
parent
8d7500d80d
commit
58b89965c8
5 changed files with 944 additions and 108 deletions
|
|
@ -14,6 +14,7 @@ from difflib import unified_diff
|
|||
from pathlib import Path
|
||||
|
||||
from utils import safe_json_loads
|
||||
from agent.tool_guardrails import classify_tool_failure
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
|
|
@ -808,30 +809,7 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
|||
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
|
||||
failures. On success, returns ``(False, "")``.
|
||||
"""
|
||||
if result is None:
|
||||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
exit_code = data.get("exit_code")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
return True, f" [exit {exit_code}]"
|
||||
return False, ""
|
||||
|
||||
# Memory-specific: distinguish "full" from real errors
|
||||
if tool_name == "memory":
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
||||
return True, " [full]"
|
||||
|
||||
# Generic heuristic for non-terminal tools
|
||||
lower = result[:500].lower()
|
||||
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
|
||||
return True, " [error]"
|
||||
|
||||
return False, ""
|
||||
return classify_tool_failure(tool_name, result)
|
||||
|
||||
|
||||
def get_cute_tool_message(
|
||||
|
|
|
|||
381
agent/tool_guardrails.py
Normal file
381
agent/tool_guardrails.py
Normal file
|
|
@ -0,0 +1,381 @@
|
|||
"""Pure tool-call loop guardrail primitives.
|
||||
|
||||
The controller in this module is intentionally side-effect free: it tracks
|
||||
per-turn tool-call observations and returns decisions. Runtime code owns whether
|
||||
those decisions become synthetic tool results or controlled turn halts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping
|
||||
|
||||
from utils import safe_json_loads
|
||||
|
||||
|
||||
# Tools considered safe to repeat: re-invoking them with identical arguments
# is expected to produce the same observation and cause no external side
# effects. The controller uses this set for "no progress" detection (same
# call returning the same result over and over).
IDEMPOTENT_TOOL_NAMES = frozenset(
    {
        "read_file",
        "search_files",
        "web_search",
        "web_extract",
        "session_search",
        "browser_snapshot",
        "browser_console",
        "browser_get_images",
        "mcp_filesystem_read_file",
        "mcp_filesystem_read_text_file",
        "mcp_filesystem_read_multiple_files",
        "mcp_filesystem_list_directory",
        "mcp_filesystem_list_directory_with_sizes",
        "mcp_filesystem_directory_tree",
        "mcp_filesystem_get_file_info",
        "mcp_filesystem_search_files",
    }
)

# Tools that mutate external state (files, processes, browser, messaging).
# Membership here overrides the idempotent set: a mutating tool is never
# treated as idempotent, so repeated identical results do not trigger the
# no-progress guardrail for it.
MUTATING_TOOL_NAMES = frozenset(
    {
        "terminal",
        "execute_code",
        "write_file",
        "patch",
        "todo",
        "memory",
        "skill_manage",
        "browser_click",
        "browser_type",
        "browser_press",
        "browser_scroll",
        "browser_navigate",
        "send_message",
        "cronjob",
        "delegate_task",
        "process",
    }
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolCallGuardrailConfig:
    """Thresholds for per-turn tool-call loop detection.

    Counts are per agent turn; the controller resets them between turns.
    "Exact" thresholds key on (tool name, canonical args); "same tool"
    thresholds count failures of a tool regardless of arguments.
    """

    # Append a warning to the tool result after this many identical failures.
    exact_failure_warn_after: int = 2
    # Block the call pre-execution once identical failures reach this count.
    # Equal to warn_after: warn on the Nth failure, block the N+1th attempt.
    exact_failure_block_after: int = 2
    # Warn after this many failures of the same tool (any arguments).
    same_tool_failure_warn_after: int = 3
    # Halt the turn after this many failures of the same tool.
    same_tool_failure_halt_after: int = 5
    # For idempotent tools: warn once the same call has returned the same
    # result this many times consecutively.
    no_progress_warn_after: int = 2
    # ...and block the next identical call at this repeat count.
    no_progress_block_after: int = 2
    # Lambdas keep the module-level sets as lazy, shared defaults.
    idempotent_tools: frozenset[str] = field(default_factory=lambda: IDEMPOTENT_TOOL_NAMES)
    mutating_tools: frozenset[str] = field(default_factory=lambda: MUTATING_TOOL_NAMES)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolCallSignature:
    """Stable, non-reversible identity for a tool name plus canonical args."""

    tool_name: str
    args_hash: str

    @classmethod
    def from_call(cls, tool_name: str, args: Mapping[str, Any] | None) -> "ToolCallSignature":
        """Build a signature by hashing the canonical JSON form of *args*."""
        digest = _sha256(canonical_tool_args(args or {}))
        return cls(tool_name=tool_name, args_hash=digest)

    def to_metadata(self) -> dict[str, str]:
        """Return public metadata without raw argument values."""
        meta = {"tool_name": self.tool_name, "args_hash": self.args_hash}
        return meta
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolGuardrailDecision:
    """Decision returned by the tool-call guardrail controller."""

    action: str = "allow"  # allow | warn | block | halt
    code: str = "allow"
    message: str = ""
    tool_name: str = ""
    count: int = 0
    signature: ToolCallSignature | None = None

    @property
    def allows_execution(self) -> bool:
        """True when the runtime should still execute the call."""
        return self.action == "allow" or self.action == "warn"

    @property
    def should_halt(self) -> bool:
        """True when the turn should stop after this decision."""
        return self.action == "block" or self.action == "halt"

    def to_metadata(self) -> dict[str, Any]:
        """Serialize the decision for logging, without raw argument values."""
        meta: dict[str, Any] = {
            "action": self.action,
            "code": self.code,
            "message": self.message,
            "tool_name": self.tool_name,
            "count": self.count,
        }
        sig = self.signature
        if sig is not None:
            meta["signature"] = sig.to_metadata()
        return meta
|
||||
|
||||
|
||||
def canonical_tool_args(args: Mapping[str, Any]) -> str:
    """Return sorted compact JSON for parsed tool arguments."""
    if isinstance(args, Mapping):
        # sort_keys + compact separators give a canonical form; default=str
        # stringifies anything json cannot encode so hashing never fails.
        return json.dumps(
            args,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
            default=str,
        )
    raise TypeError(f"tool args must be a mapping, got {type(args).__name__}")
|
||||
|
||||
|
||||
def _dict_result_failure(data: dict) -> bool:
    """True when a parsed JSON tool result signals failure via common fields.

    Shared by the terminal and generic branches of classify_tool_failure:
    explicit ``success: false`` / ``failed: true`` flags, or a non-empty
    ``error`` field.
    """
    if data.get("success") is False or data.get("failed") is True:
        return True
    error = data.get("error")
    return error is not None and error != ""


def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
    """Classify a tool result using shared display/runtime semantics.

    Returns ``(failed, suffix)`` where *suffix* is a short display tag such
    as ``" [exit 1]"``, ``" [full]"``, or ``" [error]"``; ``(False, "")``
    on success or when *result* is None.
    """
    if result is None:
        return False, ""

    if tool_name == "terminal":
        data = safe_json_loads(result)
        if isinstance(data, dict):
            exit_code = data.get("exit_code")
            if exit_code is not None and exit_code != 0:
                return True, f" [exit {exit_code}]"
            if _dict_result_failure(data):
                return True, " [error]"
        # Terminal results never fall through to the text heuristics below.
        return False, ""

    data = safe_json_loads(result)
    if isinstance(data, dict):
        # Memory-specific: a "storage full" condition gets its own tag.
        if tool_name == "memory":
            error = data.get("error", "")
            if data.get("success") is False and isinstance(error, str) and "exceed the limit" in error:
                return True, " [full]"
        if _dict_result_failure(data):
            return True, " [error]"
        return False, ""

    # Plain-text heuristics for non-JSON results; only the first 500 chars
    # are scanned to bound the cost on large outputs.
    lower = result[:500].lower()
    if "traceback" in lower or lower.startswith("error:"):
        return True, " [error]"
    if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
        return True, " [error]"
    return False, ""
|
||||
|
||||
|
||||
class ToolCallGuardrailController:
    """Per-turn controller for repeated failed/non-progressing tool calls.

    Tracks three independent signals within a single agent turn:

    * exact-failure counts keyed by (tool, canonical-args) signature,
    * per-tool failure counts regardless of arguments,
    * consecutive identical results for idempotent (read-only) tools.

    ``before_call`` may block a call pre-execution; ``after_call`` records
    the outcome and may warn or halt. The controller is side-effect free
    beyond its own counters — the runtime decides what to do with decisions.
    """

    def __init__(self, config: ToolCallGuardrailConfig | None = None):
        self.config = config or ToolCallGuardrailConfig()
        self.reset_for_turn()

    def reset_for_turn(self) -> None:
        # All counters are per-turn state; the runtime is expected to call
        # this between turns.
        self._exact_failure_counts: dict[ToolCallSignature, int] = {}
        self._same_tool_failure_counts: dict[str, int] = {}
        # signature -> (hash of last result, consecutive identical-result count)
        self._no_progress: dict[ToolCallSignature, tuple[str, int]] = {}
        self._halt_decision: ToolGuardrailDecision | None = None

    @property
    def halt_decision(self) -> ToolGuardrailDecision | None:
        # The most recent block/halt decision recorded this turn, if any.
        return self._halt_decision

    def before_call(self, tool_name: str, args: Mapping[str, Any] | None) -> ToolGuardrailDecision:
        """Decide whether a tool call may execute at all.

        Blocks when the identical call has already failed too many times, or
        (for idempotent tools) already returned the same result too many
        times. Otherwise returns an "allow" decision.
        """
        signature = ToolCallSignature.from_call(tool_name, _coerce_args(args))

        exact_count = self._exact_failure_counts.get(signature, 0)
        if exact_count >= self.config.exact_failure_block_after:
            decision = ToolGuardrailDecision(
                action="block",
                code="repeated_exact_failure_block",
                message=(
                    f"Blocked {tool_name}: the same tool call failed {exact_count} "
                    "times with identical arguments. Stop retrying it unchanged; "
                    "change strategy or explain the blocker."
                ),
                tool_name=tool_name,
                count=exact_count,
                signature=signature,
            )
            self._halt_decision = decision
            return decision

        if self._is_idempotent(tool_name):
            record = self._no_progress.get(signature)
            if record is not None:
                # NOTE: the local name shadows the module-level _result_hash
                # helper within this method; only the repeat count is used here.
                _result_hash, repeat_count = record
                if repeat_count >= self.config.no_progress_block_after:
                    decision = ToolGuardrailDecision(
                        action="block",
                        code="idempotent_no_progress_block",
                        message=(
                            f"Blocked {tool_name}: this read-only call returned the same "
                            f"result {repeat_count} times. Stop repeating it unchanged; "
                            "use the result already provided or try a different query."
                        ),
                        tool_name=tool_name,
                        count=repeat_count,
                        signature=signature,
                    )
                    self._halt_decision = decision
                    return decision

        return ToolGuardrailDecision(tool_name=tool_name, signature=signature)

    def after_call(
        self,
        tool_name: str,
        args: Mapping[str, Any] | None,
        result: str | None,
        *,
        failed: bool | None = None,
    ) -> ToolGuardrailDecision:
        """Record one finished tool call and return a follow-up decision.

        When *failed* is not supplied by the caller it is derived from the
        result via classify_tool_failure. Failures bump both the exact and
        per-tool counters; successes reset them. Identical results from
        idempotent tools bump the no-progress counter instead.
        """
        args = _coerce_args(args)
        signature = ToolCallSignature.from_call(tool_name, args)
        if failed is None:
            failed, _ = classify_tool_failure(tool_name, result)

        if failed:
            exact_count = self._exact_failure_counts.get(signature, 0) + 1
            self._exact_failure_counts[signature] = exact_count
            # A failure invalidates any "same result" streak for this call.
            self._no_progress.pop(signature, None)

            same_count = self._same_tool_failure_counts.get(tool_name, 0) + 1
            self._same_tool_failure_counts[tool_name] = same_count

            # Halt takes priority over either warning.
            if same_count >= self.config.same_tool_failure_halt_after:
                decision = ToolGuardrailDecision(
                    action="halt",
                    code="same_tool_failure_halt",
                    message=(
                        f"Stopped {tool_name}: it failed {same_count} times this turn. "
                        "Stop retrying the same failing tool path and choose a different approach."
                    ),
                    tool_name=tool_name,
                    count=same_count,
                    signature=signature,
                )
                self._halt_decision = decision
                return decision

            if exact_count >= self.config.exact_failure_warn_after:
                return ToolGuardrailDecision(
                    action="warn",
                    code="repeated_exact_failure_warning",
                    message=(
                        f"Tool guardrail: {tool_name} has failed {exact_count} times "
                        "with identical arguments. Do not retry it unchanged; inspect the "
                        "error and change strategy."
                    ),
                    tool_name=tool_name,
                    count=exact_count,
                    signature=signature,
                )

            if same_count >= self.config.same_tool_failure_warn_after:
                return ToolGuardrailDecision(
                    action="warn",
                    code="same_tool_failure_warning",
                    message=(
                        f"Tool guardrail: {tool_name} has failed {same_count} times "
                        "this turn. Change approach before retrying."
                    ),
                    tool_name=tool_name,
                    count=same_count,
                    signature=signature,
                )

            return ToolGuardrailDecision(tool_name=tool_name, count=exact_count, signature=signature)

        # Success: clear the failure streaks for this exact call and tool.
        self._exact_failure_counts.pop(signature, None)
        self._same_tool_failure_counts.pop(tool_name, None)

        if not self._is_idempotent(tool_name):
            self._no_progress.pop(signature, None)
            return ToolGuardrailDecision(tool_name=tool_name, signature=signature)

        # Idempotent tool succeeded: track whether it is making progress.
        result_hash = _result_hash(result)
        previous = self._no_progress.get(signature)
        repeat_count = 1
        if previous is not None and previous[0] == result_hash:
            repeat_count = previous[1] + 1
        self._no_progress[signature] = (result_hash, repeat_count)

        if repeat_count >= self.config.no_progress_warn_after:
            return ToolGuardrailDecision(
                action="warn",
                code="idempotent_no_progress_warning",
                message=(
                    f"Tool guardrail: {tool_name} returned the same result "
                    f"{repeat_count} times. Use the result or change the query instead "
                    "of repeating it unchanged."
                ),
                tool_name=tool_name,
                count=repeat_count,
                signature=signature,
            )

        return ToolGuardrailDecision(tool_name=tool_name, count=repeat_count, signature=signature)

    def _is_idempotent(self, tool_name: str) -> bool:
        # The mutating set wins over the idempotent set if a name is in both.
        if tool_name in self.config.mutating_tools:
            return False
        return tool_name in self.config.idempotent_tools
|
||||
|
||||
|
||||
def toolguard_synthetic_result(decision: ToolGuardrailDecision) -> str:
    """Build a synthetic role=tool content string for a blocked tool call."""
    payload = {
        "error": decision.message,
        "guardrail": decision.to_metadata(),
    }
    return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_toolguard_guidance(result: str, decision: ToolGuardrailDecision) -> str:
    """Append runtime guidance to the current tool result content."""
    actionable = decision.action in {"warn", "halt"} and bool(decision.message)
    if not actionable:
        return result
    note = (
        f"\n\n[Tool guardrail: {decision.code}; "
        f"count={decision.count}; {decision.message}]"
    )
    base = result if result else ""
    return base + note
|
||||
|
||||
|
||||
def _coerce_args(args: Mapping[str, Any] | None) -> Mapping[str, Any]:
|
||||
return args if isinstance(args, Mapping) else {}
|
||||
|
||||
|
||||
def _result_hash(result: str | None) -> str:
    """Hash a tool result, canonicalizing JSON results first.

    JSON results are re-serialized with sorted keys so formatting differences
    do not defeat the "same result" comparison; non-JSON results are hashed
    as raw text.
    """
    text = result or ""
    parsed = safe_json_loads(text)
    if parsed is None:
        return _sha256(text)
    try:
        canonical = json.dumps(
            parsed,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
            default=str,
        )
    except TypeError:
        # default=str does not cover every failure mode (e.g. odd dict keys);
        # fall back to repr-style text so hashing never raises.
        canonical = str(parsed)
    return _sha256(canonical)
|
||||
|
||||
|
||||
def _sha256(value: str) -> str:
|
||||
return hashlib.sha256(value.encode("utf-8")).hexdigest()
|
||||
301
run_agent.py
301
run_agent.py
|
|
@ -162,6 +162,12 @@ from agent.display import (
|
|||
_detect_tool_failure,
|
||||
get_tool_emoji as _get_tool_emoji,
|
||||
)
|
||||
from agent.tool_guardrails import (
|
||||
ToolCallGuardrailController,
|
||||
ToolGuardrailDecision,
|
||||
append_toolguard_guidance,
|
||||
toolguard_synthetic_result,
|
||||
)
|
||||
from agent.trajectory import (
|
||||
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
||||
save_trajectory as _save_trajectory_to_file,
|
||||
|
|
@ -1150,6 +1156,8 @@ class AIAgent:
|
|||
# Tool execution state — allows _vprint during tool execution
|
||||
# even when stream consumers are registered (no tokens streaming then)
|
||||
self._executing_tools = False
|
||||
self._tool_guardrails = ToolCallGuardrailController()
|
||||
self._tool_guardrail_halt_decision: ToolGuardrailDecision | None = None
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
|
|
@ -9107,6 +9115,44 @@ class AIAgent:
|
|||
)
|
||||
return compressed, new_system_prompt
|
||||
|
||||
def _set_tool_guardrail_halt(self, decision: ToolGuardrailDecision) -> None:
|
||||
"""Record the first guardrail decision that should stop this turn."""
|
||||
if decision.should_halt and self._tool_guardrail_halt_decision is None:
|
||||
self._tool_guardrail_halt_decision = decision
|
||||
|
||||
def _toolguard_controlled_halt_response(self, decision: ToolGuardrailDecision) -> str:
|
||||
tool = decision.tool_name or "a tool"
|
||||
return (
|
||||
f"I stopped retrying {tool} because it hit the tool-call guardrail "
|
||||
f"({decision.code}) after {decision.count} repeated non-progressing "
|
||||
"attempts. The last tool result explains the blocker; the next step is "
|
||||
"to change strategy instead of repeating the same call."
|
||||
)
|
||||
|
||||
def _append_guardrail_observation(
    self,
    tool_name: str,
    function_args: dict,
    function_result: str,
    *,
    failed: bool,
) -> str:
    """Feed one finished tool call into the guardrail controller.

    Returns the (possibly annotated) tool result; any halt decision is
    recorded for the turn loop to act on.
    """
    verdict = self._tool_guardrails.after_call(
        tool_name, function_args, function_result, failed=failed,
    )
    annotated = function_result
    if verdict.action == "warn" or verdict.action == "halt":
        # Surface the guidance to the model inside the tool result itself.
        annotated = append_toolguard_guidance(annotated, verdict)
    if verdict.should_halt:
        self._set_tool_guardrail_halt(verdict)
    return annotated
|
||||
|
||||
def _guardrail_block_result(self, decision: ToolGuardrailDecision) -> str:
    """Record a halting guardrail decision and return its synthetic tool result."""
    self._set_tool_guardrail_halt(decision)
    return toolguard_synthetic_result(decision)
|
||||
|
||||
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls from the assistant message and append results to messages.
|
||||
|
||||
|
|
@ -9150,7 +9196,8 @@ class AIAgent:
|
|||
)
|
||||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
||||
tool_call_id: Optional[str] = None, messages: list = None) -> str:
|
||||
tool_call_id: Optional[str] = None, messages: list = None,
|
||||
pre_tool_block_checked: bool = False) -> str:
|
||||
"""Invoke a single tool and return the result string. No display logic.
|
||||
|
||||
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
||||
|
|
@ -9159,13 +9206,14 @@ class AIAgent:
|
|||
"""
|
||||
# Check plugin hooks for a block directive before executing anything.
|
||||
block_message: Optional[str] = None
|
||||
try:
|
||||
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||
block_message = get_pre_tool_call_block_message(
|
||||
function_name, function_args, task_id=effective_task_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if not pre_tool_block_checked:
|
||||
try:
|
||||
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||
block_message = get_pre_tool_call_block_message(
|
||||
function_name, function_args, task_id=effective_task_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if block_message is not None:
|
||||
return json.dumps({"error": block_message}, ensure_ascii=False)
|
||||
|
||||
|
|
@ -9317,13 +9365,31 @@ class AIAgent:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
parsed_calls.append((tool_call, function_name, function_args))
|
||||
block_result = None
|
||||
blocked_by_guardrail = False
|
||||
try:
|
||||
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||
block_message = get_pre_tool_call_block_message(
|
||||
function_name, function_args, task_id=effective_task_id or "",
|
||||
)
|
||||
except Exception:
|
||||
block_message = None
|
||||
|
||||
if block_message is not None:
|
||||
block_result = json.dumps({"error": block_message}, ensure_ascii=False)
|
||||
else:
|
||||
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
|
||||
if not guardrail_decision.allows_execution:
|
||||
block_result = self._guardrail_block_result(guardrail_decision)
|
||||
blocked_by_guardrail = True
|
||||
|
||||
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
|
||||
|
||||
# ── Logging / callbacks ──────────────────────────────────────────
|
||||
tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
|
||||
tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls)
|
||||
if not self.quiet_mode:
|
||||
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
|
||||
for i, (tc, name, args) in enumerate(parsed_calls, 1):
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
|
||||
args_str = json.dumps(args, ensure_ascii=False)
|
||||
if self.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
||||
|
|
@ -9332,7 +9398,9 @@ class AIAgent:
|
|||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
|
||||
for tc, name, args in parsed_calls:
|
||||
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||
if block_result is not None:
|
||||
continue
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(name, args)
|
||||
|
|
@ -9340,7 +9408,9 @@ class AIAgent:
|
|||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
for tc, name, args in parsed_calls:
|
||||
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||
if block_result is not None:
|
||||
continue
|
||||
if self.tool_start_callback:
|
||||
try:
|
||||
self.tool_start_callback(tc.id, name, args)
|
||||
|
|
@ -9348,8 +9418,11 @@ class AIAgent:
|
|||
logging.debug(f"Tool start callback error: {cb_err}")
|
||||
|
||||
# ── Concurrent execution ─────────────────────────────────────────
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag)
|
||||
results = [None] * num_tools
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
if block_result is not None:
|
||||
results[i] = (name, args, block_result, 0.0, True, True)
|
||||
|
||||
# Touch activity before launching workers so the gateway knows
|
||||
# we're executing tools (not stuck).
|
||||
|
|
@ -9404,7 +9477,14 @@ class AIAgent:
|
|||
pass
|
||||
start = time.time()
|
||||
try:
|
||||
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id, messages=messages)
|
||||
result = self._invoke_tool(
|
||||
function_name,
|
||||
function_args,
|
||||
effective_task_id,
|
||||
tool_call.id,
|
||||
messages=messages,
|
||||
pre_tool_block_checked=True,
|
||||
)
|
||||
except Exception as tool_error:
|
||||
result = f"Error executing tool '{function_name}': {tool_error}"
|
||||
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
|
||||
|
|
@ -9414,7 +9494,7 @@ class AIAgent:
|
|||
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
|
||||
else:
|
||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
||||
results[index] = (function_name, function_args, result, duration, is_error)
|
||||
results[index] = (function_name, function_args, result, duration, is_error, False)
|
||||
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
||||
# have set so the next task scheduled onto this recycled tid
|
||||
# starts with a clean slate.
|
||||
|
|
@ -9440,61 +9520,67 @@ class AIAgent:
|
|||
spinner.start()
|
||||
|
||||
try:
|
||||
max_workers = min(num_tools, _MAX_TOOL_WORKERS)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
for i, (tc, name, args) in enumerate(parsed_calls):
|
||||
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
|
||||
ctx = contextvars.copy_context()
|
||||
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
|
||||
futures.append(f)
|
||||
runnable_calls = [
|
||||
(i, tc, name, args)
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
|
||||
if block_result is None
|
||||
]
|
||||
futures = []
|
||||
if runnable_calls:
|
||||
max_workers = min(len(runnable_calls), _MAX_TOOL_WORKERS)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
for i, tc, name, args in runnable_calls:
|
||||
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
|
||||
ctx = contextvars.copy_context()
|
||||
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
|
||||
futures.append(f)
|
||||
|
||||
# Wait for all to complete with periodic heartbeats so the
|
||||
# gateway's inactivity monitor doesn't kill us during long
|
||||
# concurrent tool batches. Also check for user interrupts
|
||||
# so we don't block indefinitely when the user sends /stop
|
||||
# or a new message during concurrent tool execution.
|
||||
_conc_start = time.time()
|
||||
_interrupt_logged = False
|
||||
while True:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
futures, timeout=5.0,
|
||||
)
|
||||
if not not_done:
|
||||
break
|
||||
|
||||
# Check for interrupt — the per-thread interrupt signal
|
||||
# already causes individual tools (terminal, execute_code)
|
||||
# to abort, but tools without interrupt checks (web_search,
|
||||
# read_file) will run to completion. Cancel any futures
|
||||
# that haven't started yet so we don't block on them.
|
||||
if self._interrupt_requested:
|
||||
if not _interrupt_logged:
|
||||
_interrupt_logged = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚡ Interrupt: cancelling "
|
||||
f"{len(not_done)} pending concurrent tool(s)",
|
||||
force=True,
|
||||
)
|
||||
for f in not_done:
|
||||
f.cancel()
|
||||
# Give already-running tools a moment to notice the
|
||||
# per-thread interrupt signal and exit gracefully.
|
||||
concurrent.futures.wait(not_done, timeout=3.0)
|
||||
break
|
||||
|
||||
_conc_elapsed = int(time.time() - _conc_start)
|
||||
# Heartbeat every ~30s (6 × 5s poll intervals)
|
||||
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
|
||||
_still_running = [
|
||||
parsed_calls[futures.index(f)][1]
|
||||
for f in not_done
|
||||
if f in futures
|
||||
]
|
||||
self._touch_activity(
|
||||
f"concurrent tools running ({_conc_elapsed}s, "
|
||||
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
||||
# Wait for all to complete with periodic heartbeats so the
|
||||
# gateway's inactivity monitor doesn't kill us during long
|
||||
# concurrent tool batches. Also check for user interrupts
|
||||
# so we don't block indefinitely when the user sends /stop
|
||||
# or a new message during concurrent tool execution.
|
||||
_conc_start = time.time()
|
||||
_interrupt_logged = False
|
||||
while True:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
futures, timeout=5.0,
|
||||
)
|
||||
if not not_done:
|
||||
break
|
||||
|
||||
# Check for interrupt — the per-thread interrupt signal
|
||||
# already causes individual tools (terminal, execute_code)
|
||||
# to abort, but tools without interrupt checks (web_search,
|
||||
# read_file) will run to completion. Cancel any futures
|
||||
# that haven't started yet so we don't block on them.
|
||||
if self._interrupt_requested:
|
||||
if not _interrupt_logged:
|
||||
_interrupt_logged = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚡ Interrupt: cancelling "
|
||||
f"{len(not_done)} pending concurrent tool(s)",
|
||||
force=True,
|
||||
)
|
||||
for f in not_done:
|
||||
f.cancel()
|
||||
# Give already-running tools a moment to notice the
|
||||
# per-thread interrupt signal and exit gracefully.
|
||||
concurrent.futures.wait(not_done, timeout=3.0)
|
||||
break
|
||||
|
||||
_conc_elapsed = int(time.time() - _conc_start)
|
||||
# Heartbeat every ~30s (6 × 5s poll intervals)
|
||||
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
|
||||
_still_running = [
|
||||
parsed_calls[futures.index(f)][1]
|
||||
for f in not_done
|
||||
if f in futures
|
||||
]
|
||||
self._touch_activity(
|
||||
f"concurrent tools running ({_conc_elapsed}s, "
|
||||
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
||||
)
|
||||
finally:
|
||||
if spinner:
|
||||
# Build a summary message for the spinner stop
|
||||
|
|
@ -9503,8 +9589,9 @@ class AIAgent:
|
|||
spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")
|
||||
|
||||
# ── Post-execution: display per-tool results ─────────────────────
|
||||
for i, (tc, name, args) in enumerate(parsed_calls):
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
r = results[i]
|
||||
blocked = False
|
||||
if r is None:
|
||||
# Tool was cancelled (interrupt) or thread didn't return
|
||||
if self._interrupt_requested:
|
||||
|
|
@ -9513,13 +9600,21 @@ class AIAgent:
|
|||
function_result = f"Error executing tool '{name}': thread did not return a result"
|
||||
tool_duration = 0.0
|
||||
else:
|
||||
function_name, function_args, function_result, tool_duration, is_error = r
|
||||
function_name, function_args, function_result, tool_duration, is_error, blocked = r
|
||||
|
||||
if not blocked:
|
||||
function_result = self._append_guardrail_observation(
|
||||
function_name,
|
||||
function_args,
|
||||
function_result,
|
||||
failed=is_error,
|
||||
)
|
||||
|
||||
if is_error:
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||
|
||||
if self.tool_progress_callback:
|
||||
if not blocked and self.tool_progress_callback:
|
||||
try:
|
||||
self.tool_progress_callback(
|
||||
"tool.completed", function_name, None, None,
|
||||
|
|
@ -9547,7 +9642,7 @@ class AIAgent:
|
|||
self._current_tool = None
|
||||
self._touch_activity(f"tool completed: {name} ({tool_duration:.1f}s)")
|
||||
|
||||
if self.tool_complete_callback:
|
||||
if not blocked and self.tool_complete_callback:
|
||||
try:
|
||||
self.tool_complete_callback(tc.id, name, args, function_result)
|
||||
except Exception as cb_err:
|
||||
|
|
@ -9629,9 +9724,17 @@ class AIAgent:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
if _block_msg is not None:
|
||||
# Tool blocked by plugin policy — skip counter resets.
|
||||
# Execution is handled below in the tool dispatch chain.
|
||||
_guardrail_block_decision: ToolGuardrailDecision | None = None
|
||||
if _block_msg is None:
|
||||
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
|
||||
if not guardrail_decision.allows_execution:
|
||||
_guardrail_block_decision = guardrail_decision
|
||||
|
||||
_execution_blocked = _block_msg is not None or _guardrail_block_decision is not None
|
||||
|
||||
if _execution_blocked:
|
||||
# Tool blocked by plugin or guardrail policy — skip counters,
|
||||
# callbacks, checkpointing, activity mutation, and real execution.
|
||||
pass
|
||||
else:
|
||||
# Reset nudge counters when the relevant tool is actually used
|
||||
|
|
@ -9649,35 +9752,35 @@ class AIAgent:
|
|||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
|
||||
if _block_msg is None:
|
||||
if not _execution_blocked:
|
||||
self._current_tool = function_name
|
||||
self._touch_activity(f"executing tool: {function_name}")
|
||||
|
||||
# Set activity callback for long-running tool execution (terminal
|
||||
# commands, etc.) so the gateway's inactivity monitor doesn't kill
|
||||
# the agent while a command is running.
|
||||
if _block_msg is None:
|
||||
if not _execution_blocked:
|
||||
try:
|
||||
from tools.environments.base import set_activity_callback
|
||||
set_activity_callback(self._touch_activity)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _block_msg is None and self.tool_progress_callback:
|
||||
if not _execution_blocked and self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(function_name, function_args)
|
||||
self.tool_progress_callback("tool.started", function_name, preview, function_args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
if _block_msg is None and self.tool_start_callback:
|
||||
if not _execution_blocked and self.tool_start_callback:
|
||||
try:
|
||||
self.tool_start_callback(tool_call.id, function_name, function_args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool start callback error: {cb_err}")
|
||||
|
||||
# Checkpoint: snapshot working dir before file-mutating tools
|
||||
if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
|
||||
if not _execution_blocked and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
|
||||
try:
|
||||
file_path = function_args.get("path", "")
|
||||
if file_path:
|
||||
|
|
@ -9689,7 +9792,7 @@ class AIAgent:
|
|||
pass # never block tool execution
|
||||
|
||||
# Checkpoint before destructive terminal commands
|
||||
if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled:
|
||||
if not _execution_blocked and function_name == "terminal" and self._checkpoint_mgr.enabled:
|
||||
try:
|
||||
cmd = function_args.get("command", "")
|
||||
if _is_destructive_command(cmd):
|
||||
|
|
@ -9706,6 +9809,11 @@ class AIAgent:
|
|||
# Tool blocked by plugin policy — return error without executing.
|
||||
function_result = json.dumps({"error": _block_msg}, ensure_ascii=False)
|
||||
tool_duration = 0.0
|
||||
elif _guardrail_block_decision is not None:
|
||||
# Tool blocked by tool-loop guardrail — synthesize exactly one
|
||||
# tool result for the original tool_call_id without executing.
|
||||
function_result = self._guardrail_block_result(_guardrail_block_decision)
|
||||
tool_duration = 0.0
|
||||
elif function_name == "todo":
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
function_result = _todo_tool(
|
||||
|
|
@ -9889,12 +9997,22 @@ class AIAgent:
|
|||
# Log tool errors to the persistent error log so [error] tags
|
||||
# in the UI always have a corresponding detailed entry on disk.
|
||||
_is_error_result, _ = _detect_tool_failure(function_name, function_result)
|
||||
if not _execution_blocked:
|
||||
function_result = self._append_guardrail_observation(
|
||||
function_name,
|
||||
function_args,
|
||||
function_result,
|
||||
failed=_is_error_result,
|
||||
)
|
||||
result_preview = function_result if self.verbose_logging else (
|
||||
function_result[:200] if len(function_result) > 200 else function_result
|
||||
)
|
||||
if _is_error_result:
|
||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||
else:
|
||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, tool_duration, len(function_result))
|
||||
|
||||
if self.tool_progress_callback:
|
||||
if not _execution_blocked and self.tool_progress_callback:
|
||||
try:
|
||||
self.tool_progress_callback(
|
||||
"tool.completed", function_name, None, None,
|
||||
|
|
@ -9910,7 +10028,7 @@ class AIAgent:
|
|||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
|
||||
if self.tool_complete_callback:
|
||||
if not _execution_blocked and self.tool_complete_callback:
|
||||
try:
|
||||
self.tool_complete_callback(tool_call.id, function_name, function_args, function_result)
|
||||
except Exception as cb_err:
|
||||
|
|
@ -10244,6 +10362,8 @@ class AIAgent:
|
|||
self._last_content_tools_all_housekeeping = False
|
||||
self._mute_post_response = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
self._tool_guardrails.reset_for_turn()
|
||||
self._tool_guardrail_halt_decision = None
|
||||
|
||||
# Pre-turn connection health check: detect and clean up dead TCP
|
||||
# connections left over from provider outages or dropped streams.
|
||||
|
|
@ -13041,6 +13161,16 @@ class AIAgent:
|
|||
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
if self._tool_guardrail_halt_decision is not None:
|
||||
decision = self._tool_guardrail_halt_decision
|
||||
_turn_exit_reason = "guardrail_halt"
|
||||
final_response = self._toolguard_controlled_halt_response(decision)
|
||||
self._emit_status(
|
||||
f"⚠️ Tool guardrail halted {decision.tool_name}: {decision.code}"
|
||||
)
|
||||
messages.append({"role": "assistant", "content": final_response})
|
||||
break
|
||||
|
||||
# Reset per-turn retry counters after successful tool
|
||||
# execution so a single truncation doesn't poison the
|
||||
# entire conversation.
|
||||
|
|
@ -13567,6 +13697,7 @@ class AIAgent:
|
|||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed,
|
||||
"turn_exit_reason": _turn_exit_reason,
|
||||
"partial": False, # True only when stopped due to invalid tool calls
|
||||
"interrupted": interrupted,
|
||||
"response_previewed": getattr(self, "_response_was_previewed", False),
|
||||
|
|
@ -13586,6 +13717,8 @@ class AIAgent:
|
|||
"cost_status": self.session_cost_status,
|
||||
"cost_source": self.session_cost_source,
|
||||
}
|
||||
if self._tool_guardrail_halt_decision is not None:
|
||||
result["guardrail"] = self._tool_guardrail_halt_decision.to_metadata()
|
||||
# If a /steer landed after the final assistant turn (no more tool
|
||||
# batches to drain into), hand it back to the caller so it can be
|
||||
# delivered as the next user turn instead of being silently lost.
|
||||
|
|
|
|||
142
tests/agent/test_tool_guardrails.py
Normal file
142
tests/agent/test_tool_guardrails.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""Pure tool-call guardrail primitive tests."""
|
||||
|
||||
import json
|
||||
|
||||
from agent.tool_guardrails import (
|
||||
ToolCallGuardrailConfig,
|
||||
ToolCallGuardrailController,
|
||||
ToolCallSignature,
|
||||
canonical_tool_args,
|
||||
)
|
||||
|
||||
|
||||
def test_tool_call_signature_hashes_canonical_nested_unicode_args_without_exposing_raw_args():
|
||||
args_a = {
|
||||
"z": [{"β": "☤", "a": 1}],
|
||||
"a": {"y": 2, "x": "secret-token-value"},
|
||||
}
|
||||
args_b = {
|
||||
"a": {"x": "secret-token-value", "y": 2},
|
||||
"z": [{"a": 1, "β": "☤"}],
|
||||
}
|
||||
|
||||
assert canonical_tool_args(args_a) == canonical_tool_args(args_b)
|
||||
sig_a = ToolCallSignature.from_call("web_search", args_a)
|
||||
sig_b = ToolCallSignature.from_call("web_search", args_b)
|
||||
|
||||
assert sig_a == sig_b
|
||||
assert len(sig_a.args_hash) == 64
|
||||
metadata = sig_a.to_metadata()
|
||||
assert metadata == {"tool_name": "web_search", "args_hash": sig_a.args_hash}
|
||||
assert "secret-token-value" not in json.dumps(metadata)
|
||||
assert "☤" not in json.dumps(metadata)
|
||||
|
||||
|
||||
def test_repeated_identical_failed_call_warns_then_blocks_before_third_execution():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(
|
||||
exact_failure_warn_after=2,
|
||||
exact_failure_block_after=2,
|
||||
same_tool_failure_halt_after=99,
|
||||
)
|
||||
)
|
||||
args = {"query": "same"}
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
first = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert first.action == "allow"
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
second = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert second.action == "warn"
|
||||
assert second.code == "repeated_exact_failure_warning"
|
||||
assert second.count == 2
|
||||
|
||||
blocked = controller.before_call("web_search", args)
|
||||
assert blocked.action == "block"
|
||||
assert blocked.code == "repeated_exact_failure_block"
|
||||
assert blocked.tool_name == "web_search"
|
||||
assert blocked.count == 2
|
||||
|
||||
|
||||
def test_success_resets_exact_signature_failure_streak():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(exact_failure_block_after=2, same_tool_failure_halt_after=99)
|
||||
)
|
||||
args = {"query": "same"}
|
||||
|
||||
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("web_search", args, '{"ok":true}', failed=False)
|
||||
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||
assert controller.before_call("web_search", args).action == "allow"
|
||||
|
||||
|
||||
def test_same_tool_varying_args_failure_streak_warns_then_halts_independent_of_exact_streak():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(
|
||||
exact_failure_block_after=99,
|
||||
same_tool_failure_warn_after=2,
|
||||
same_tool_failure_halt_after=3,
|
||||
)
|
||||
)
|
||||
|
||||
first = controller.after_call("terminal", {"command": "cmd-1"}, '{"exit_code":1}', failed=True)
|
||||
assert first.action == "allow"
|
||||
second = controller.after_call("terminal", {"command": "cmd-2"}, '{"exit_code":1}', failed=True)
|
||||
assert second.action == "warn"
|
||||
assert second.code == "same_tool_failure_warning"
|
||||
third = controller.after_call("terminal", {"command": "cmd-3"}, '{"exit_code":1}', failed=True)
|
||||
assert third.action == "halt"
|
||||
assert third.code == "same_tool_failure_halt"
|
||||
assert third.count == 3
|
||||
|
||||
|
||||
def test_idempotent_no_progress_repeated_result_warns_then_blocks_future_repeat():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||
)
|
||||
args = {"path": "/tmp/same.txt"}
|
||||
result = "same file contents"
|
||||
|
||||
assert controller.before_call("read_file", args).action == "allow"
|
||||
assert controller.after_call("read_file", args, result, failed=False).action == "allow"
|
||||
assert controller.before_call("read_file", args).action == "allow"
|
||||
warn = controller.after_call("read_file", args, result, failed=False)
|
||||
assert warn.action == "warn"
|
||||
assert warn.code == "idempotent_no_progress_warning"
|
||||
|
||||
blocked = controller.before_call("read_file", args)
|
||||
assert blocked.action == "block"
|
||||
assert blocked.code == "idempotent_no_progress_block"
|
||||
|
||||
|
||||
def test_mutating_or_unknown_tools_are_not_blocked_for_repeated_identical_success_output_by_default():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
assert controller.before_call("write_file", {"path": "/tmp/x", "content": "x"}).action == "allow"
|
||||
assert controller.after_call("write_file", {"path": "/tmp/x", "content": "x"}, "ok", failed=False).action == "allow"
|
||||
assert controller.before_call("custom_tool", {"x": 1}).action == "allow"
|
||||
assert controller.after_call("custom_tool", {"x": 1}, "ok", failed=False).action == "allow"
|
||||
|
||||
|
||||
def test_reset_for_turn_clears_bounded_guardrail_state():
|
||||
controller = ToolCallGuardrailController(
|
||||
ToolCallGuardrailConfig(exact_failure_block_after=2, no_progress_block_after=2)
|
||||
)
|
||||
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||
|
||||
assert controller.before_call("web_search", {"query": "same"}).action == "block"
|
||||
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "block"
|
||||
|
||||
controller.reset_for_turn()
|
||||
|
||||
assert controller.before_call("web_search", {"query": "same"}).action == "allow"
|
||||
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "allow"
|
||||
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
"""Runtime tests for tool-call loop guardrails."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": f"{name} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
|
||||
def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
|
||||
return SimpleNamespace(
|
||||
id=call_id or f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=SimpleNamespace(name=name, arguments=arguments),
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
|
||||
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
return SimpleNamespace(choices=[choice], model="test/model", usage=None)
|
||||
|
||||
|
||||
def _make_agent(*tool_names: str, max_iterations: int = 10) -> AIAgent:
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.client = MagicMock()
|
||||
agent._cached_system_prompt = "You are helpful."
|
||||
agent._use_prompt_caching = False
|
||||
agent.tool_delay = 0
|
||||
agent.compression_enabled = False
|
||||
agent.save_trajectories = False
|
||||
return agent
|
||||
|
||||
|
||||
def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
|
||||
for _ in range(count):
|
||||
agent._tool_guardrails.after_call(
|
||||
tool_name,
|
||||
args,
|
||||
json.dumps({"error": "boom"}),
|
||||
failed=True,
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_path_blocks_repeated_exact_failure_before_execution():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
_seed_exact_failures(agent, "web_search", args)
|
||||
starts = []
|
||||
progress = []
|
||||
agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
|
||||
agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-block")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
mock_hfc.assert_not_called()
|
||||
assert starts == []
|
||||
assert progress == []
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "tool"
|
||||
assert messages[0]["tool_call_id"] == "c-block"
|
||||
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||
|
||||
|
||||
def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
_seed_exact_failures(agent, "web_search", args, count=1)
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-warn")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
assert [m["role"] for m in messages] == ["tool"]
|
||||
assert messages[0]["tool_call_id"] == "c-warn"
|
||||
assert "Tool guardrail" in messages[0]["content"]
|
||||
assert "repeated_exact_failure_warning" in messages[0]["content"]
|
||||
|
||||
|
||||
def test_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
|
||||
agent = _make_agent("web_search")
|
||||
blocked_args = {"query": "blocked"}
|
||||
allowed_args = {"query": "allowed"}
|
||||
_seed_exact_failures(agent, "web_search", blocked_args)
|
||||
starts = []
|
||||
progress_events = []
|
||||
agent.tool_start_callback = lambda tool_call_id, name, args: starts.append((tool_call_id, name, args))
|
||||
agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))
|
||||
calls = [
|
||||
_mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
|
||||
_mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
|
||||
]
|
||||
msg = SimpleNamespace(content="", tool_calls=calls)
|
||||
messages = []
|
||||
executed = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
executed.append((name, args, kwargs["tool_call_id"]))
|
||||
return json.dumps({"ok": args["query"]})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(msg, messages, "task-1")
|
||||
|
||||
assert executed == [("web_search", allowed_args, "c-allow")]
|
||||
assert [m["tool_call_id"] for m in messages] == ["c-block", "c-allow"]
|
||||
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||
assert json.loads(messages[1]["content"]) == {"ok": "allowed"}
|
||||
assert starts == [("c-allow", "web_search", allowed_args)]
|
||||
started_events = [event for event in progress_events if event[0] == "tool.started"]
|
||||
completed_events = [event for event in progress_events if event[0] == "tool.completed"]
|
||||
assert started_events == [("tool.started", "web_search", allowed_args, {})]
|
||||
assert len(completed_events) == 1
|
||||
assert completed_events[0][1] == "web_search"
|
||||
|
||||
|
||||
def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
|
||||
agent = _make_agent("web_search")
|
||||
args = {"query": "same"}
|
||||
tc = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
|
||||
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||
messages = []
|
||||
|
||||
with (
|
||||
patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
|
||||
patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc,
|
||||
):
|
||||
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||
|
||||
mock_hfc.assert_not_called()
|
||||
assert "plugin policy" in messages[0]["content"]
|
||||
assert agent._tool_guardrails.before_call("web_search", args).action == "allow"
|
||||
|
||||
|
||||
def test_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
|
||||
agent = _make_agent("web_search", max_iterations=10)
|
||||
same_args = {"query": "same"}
|
||||
responses = [
|
||||
_mock_response(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
|
||||
)
|
||||
for i in range(1, 10)
|
||||
]
|
||||
agent.client.chat.completions.create.side_effect = responses
|
||||
|
||||
with (
|
||||
patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("search repeatedly")
|
||||
|
||||
assert mock_hfc.call_count == 2
|
||||
assert result["api_calls"] == 3
|
||||
assert result["api_calls"] < agent.max_iterations
|
||||
assert result["turn_exit_reason"] == "guardrail_halt"
|
||||
assert "error" not in result
|
||||
assert result["completed"] is True
|
||||
assert "stopped retrying" in result["final_response"]
|
||||
assert result["guardrail"]["code"] == "repeated_exact_failure_block"
|
||||
assert result["guardrail"]["tool_name"] == "web_search"
|
||||
|
||||
assistant_tool_calls = [m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls")]
|
||||
for assistant_msg in assistant_tool_calls:
|
||||
call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
|
||||
following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids]
|
||||
assert len(following_results) == len(call_ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue