fix(guardrails): preserve display _detect_tool_failure semantics

The initial guardrail PR consolidated failure classification by pointing
display._detect_tool_failure at the new classify_tool_failure helper,
which was strictly broader: it flagged any JSON result with
"success": false / "failed": true / non-empty "error", plus plain-text
"traceback" and "error:" prefixes. That would uptick the user-visible
[error] tag on tools that return {"success": false} as a benign signal
(memory fullness, todo state, etc.) and feed the failure-streak counter
at the same time.

Restore display._detect_tool_failure to its pre-PR semantics verbatim.
Tighten classify_tool_failure (the guardrail's internal safety-fallback
used only when callers don't pass failed=) to match _detect_tool_failure
exactly, so the two never disagree. Production callers in run_agent.py
already pass an explicit failed= derived from _detect_tool_failure, so
the guardrail counter is driven by the same signal the CLI shows.
This commit is contained in:
Teknium 2026-04-30 20:42:44 -07:00
parent 0704589ceb
commit 8fa44b1724
2 changed files with 37 additions and 21 deletions

View file

@ -14,7 +14,6 @@ from difflib import unified_diff
from pathlib import Path from pathlib import Path
from utils import safe_json_loads from utils import safe_json_loads
from agent.tool_guardrails import classify_tool_failure
# ANSI escape codes for coloring tool failure indicators # ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m" _RED = "\033[31m"
@ -809,7 +808,30 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
failures. On success, returns ``(False, "")``. failures. On success, returns ``(False, "")``.
""" """
return classify_tool_failure(tool_name, result) if result is None:
return False, ""
if tool_name == "terminal":
data = safe_json_loads(result)
if isinstance(data, dict):
exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]"
return False, ""
# Memory-specific: distinguish "full" from real errors
if tool_name == "memory":
data = safe_json_loads(result)
if isinstance(data, dict):
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]"
# Generic heuristic for non-terminal tools
lower = result[:500].lower()
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
return True, " [error]"
return False, ""
def get_cute_tool_message( def get_cute_tool_message(

View file

@ -186,7 +186,14 @@ def canonical_tool_args(args: Mapping[str, Any]) -> str:
def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]: def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
"""Classify a tool result using shared display/runtime semantics.""" """Safety-fallback classifier used only when callers don't pass ``failed``.
Mirrors ``agent.display._detect_tool_failure`` exactly so the guardrail
never disagrees with the CLI's user-visible ``[error]`` tag. Production
callers in ``run_agent.py`` always pass an explicit ``failed=`` derived
from ``_detect_tool_failure``; this function exists so standalone callers
(tests, tooling) still get consistent behavior.
"""
if result is None: if result is None:
return False, "" return False, ""
@ -196,31 +203,18 @@ def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str
exit_code = data.get("exit_code") exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0: if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]" return True, f" [exit {exit_code}]"
if data.get("success") is False or data.get("failed") is True:
return True, " [error]"
error = data.get("error")
if error is not None and error != "":
return True, " [error]"
return False, "" return False, ""
data = safe_json_loads(result) if tool_name == "memory":
if isinstance(data, dict): data = safe_json_loads(result)
if tool_name == "memory": if isinstance(data, dict):
error = data.get("error", "") if data.get("success") is False and "exceed the limit" in data.get("error", ""):
if data.get("success") is False and isinstance(error, str) and "exceed the limit" in error:
return True, " [full]" return True, " [full]"
if data.get("success") is False or data.get("failed") is True:
return True, " [error]"
error = data.get("error")
if error is not None and error != "":
return True, " [error]"
return False, ""
lower = result[:500].lower() lower = result[:500].lower()
if "traceback" in lower or lower.startswith("error:"):
return True, " [error]"
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"): if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
return True, " [error]" return True, " [error]"
return False, "" return False, ""