mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix: classify landed file mutations with diagnostics
This commit is contained in:
parent
71c6dd0dcf
commit
da0ddbf88a
8 changed files with 153 additions and 1 deletions
|
|
@ -14,6 +14,7 @@ from difflib import unified_diff
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from utils import safe_json_loads
|
from utils import safe_json_loads
|
||||||
|
from agent.tool_result_classification import file_mutation_result_landed
|
||||||
|
|
||||||
# ANSI escape codes for coloring tool failure indicators
|
# ANSI escape codes for coloring tool failure indicators
|
||||||
_RED = "\033[31m"
|
_RED = "\033[31m"
|
||||||
|
|
@ -810,6 +811,8 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
||||||
"""
|
"""
|
||||||
if result is None:
|
if result is None:
|
||||||
return False, ""
|
return False, ""
|
||||||
|
if file_mutation_result_landed(tool_name, result):
|
||||||
|
return False, ""
|
||||||
|
|
||||||
if tool_name == "terminal":
|
if tool_name == "terminal":
|
||||||
data = safe_json_loads(result)
|
data = safe_json_loads(result)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
|
|
||||||
from utils import safe_json_loads
|
from utils import safe_json_loads
|
||||||
|
from agent.tool_result_classification import file_mutation_result_landed
|
||||||
|
|
||||||
|
|
||||||
IDEMPOTENT_TOOL_NAMES = frozenset(
|
IDEMPOTENT_TOOL_NAMES = frozenset(
|
||||||
|
|
@ -196,6 +197,8 @@ def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str
|
||||||
"""
|
"""
|
||||||
if result is None:
|
if result is None:
|
||||||
return False, ""
|
return False, ""
|
||||||
|
if file_mutation_result_landed(tool_name, result):
|
||||||
|
return False, ""
|
||||||
|
|
||||||
if tool_name == "terminal":
|
if tool_name == "terminal":
|
||||||
data = safe_json_loads(result)
|
data = safe_json_loads(result)
|
||||||
|
|
|
||||||
26
agent/tool_result_classification.py
Normal file
26
agent/tool_result_classification.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Shared helpers for classifying tool result payloads."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# Tools whose results describe a change to a file on disk.
FILE_MUTATING_TOOL_NAMES = frozenset({"write_file", "patch"})


def file_mutation_result_landed(tool_name: str, result: Any) -> bool:
    """Return True when a file mutation result proves the write landed.

    Only top-level keys of the decoded JSON object are inspected, so nested
    diagnostics (for example a ``lint`` report or ``lsp_diagnostics`` text
    that mentions errors) never turn a landed write into a failure.

    Args:
        tool_name: Name of the tool that produced ``result``.
        result: Raw tool result; anything other than a JSON string that
            decodes to an object is treated as "not landed".

    Returns:
        True if ``tool_name`` is a file-mutating tool and the payload has no
        truthy top-level ``"error"`` and carries the tool's success marker
        (``"bytes_written"`` present for ``write_file``; ``"success"`` being
        exactly ``True`` for ``patch``). False otherwise.
    """
    if tool_name not in FILE_MUTATING_TOOL_NAMES or not isinstance(result, str):
        return False
    try:
        data = json.loads(result.strip())
    except (ValueError, RecursionError):
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # pathologically nested payloads. Anything else is a genuine bug and
        # should surface instead of being silently reported as "not landed".
        return False
    if not isinstance(data, dict) or data.get("error"):
        return False
    if tool_name == "write_file":
        return "bytes_written" in data
    if tool_name == "patch":
        return data.get("success") is True
    return False
|
||||||
|
|
@ -181,6 +181,7 @@ from agent.tool_guardrails import (
|
||||||
append_toolguard_guidance,
|
append_toolguard_guidance,
|
||||||
toolguard_synthetic_result,
|
toolguard_synthetic_result,
|
||||||
)
|
)
|
||||||
|
from agent.tool_result_classification import file_mutation_result_landed
|
||||||
from agent.trajectory import (
|
from agent.trajectory import (
|
||||||
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
||||||
save_trajectory as _save_trajectory_to_file,
|
save_trajectory as _save_trajectory_to_file,
|
||||||
|
|
@ -5347,7 +5348,8 @@ class AIAgent:
|
||||||
targets = _extract_file_mutation_targets(tool_name, args)
|
targets = _extract_file_mutation_targets(tool_name, args)
|
||||||
if not targets:
|
if not targets:
|
||||||
return
|
return
|
||||||
if is_error:
|
landed = file_mutation_result_landed(tool_name, result)
|
||||||
|
if is_error and not landed:
|
||||||
preview = _extract_error_preview(result)
|
preview = _extract_error_preview(result)
|
||||||
for path in targets:
|
for path in targets:
|
||||||
# Keep the FIRST error we saw for a given path unless we
|
# Keep the FIRST error we saw for a given path unless we
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Tests for agent/display.py — build_tool_preview() and inline diff previews."""
|
"""Tests for agent/display.py — build_tool_preview() and inline diff previews."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
@ -149,6 +150,27 @@ class TestCuteToolMessagePreviewLength:
|
||||||
assert path in line
|
assert path in line
|
||||||
assert "..." not in line
|
assert "..." not in line
|
||||||
|
|
||||||
|
def test_write_file_lint_error_result_is_not_marked_failed(self):
    """A write_file result with bytes_written plus a lint error renders without "[error]"."""
    lint_report = {"status": "error", "output": "SyntaxError: invalid syntax"}
    payload = json.dumps({"bytes_written": 12, "lint": lint_report})

    rendered = get_cute_tool_message(
        "write_file",
        {"path": "/tmp/a.py"},
        0.1,
        result=payload,
    )

    assert "[error]" not in rendered
|
||||||
|
|
||||||
|
def test_patch_lsp_diagnostics_result_is_not_marked_failed(self):
    """A successful patch result carrying LSP diagnostics renders without "[error]"."""
    patch_payload = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }
    payload = json.dumps(patch_payload)

    rendered = get_cute_tool_message(
        "patch",
        {"path": "/tmp/a.py"},
        0.1,
        result=payload,
    )

    assert "[error]" not in rendered
|
||||||
|
|
||||||
|
|
||||||
class TestEditDiffPreview:
|
class TestEditDiffPreview:
|
||||||
def test_extract_edit_diff_for_patch(self):
|
def test_extract_edit_diff_for_patch(self):
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from agent.tool_guardrails import (
|
||||||
ToolCallGuardrailController,
|
ToolCallGuardrailController,
|
||||||
ToolCallSignature,
|
ToolCallSignature,
|
||||||
canonical_tool_args,
|
canonical_tool_args,
|
||||||
|
classify_tool_failure,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -131,6 +132,21 @@ def test_success_resets_exact_signature_failure_streak():
|
||||||
assert controller.before_call("web_search", args).action == "allow"
|
assert controller.before_call("web_search", args).action == "allow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_mutation_lint_error_result_is_not_a_tool_failure():
    """Landed write_file and patch results with diagnostics classify as (False, "")."""
    landed_results = {
        "write_file": json.dumps({
            "bytes_written": 12,
            "lint": {"status": "error", "output": "SyntaxError: invalid syntax"},
        }),
        "patch": json.dumps({
            "success": True,
            "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
            "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
        }),
    }

    # Dict order is insertion order, so write_file is checked before patch.
    for tool_name, payload in landed_results.items():
        assert classify_tool_failure(tool_name, payload) == (False, "")
|
||||||
|
|
||||||
|
|
||||||
def test_same_tool_varying_args_warns_by_default_without_halting():
|
def test_same_tool_varying_args_warns_by_default_without_halting():
|
||||||
controller = ToolCallGuardrailController(
|
controller = ToolCallGuardrailController(
|
||||||
ToolCallGuardrailConfig(same_tool_failure_warn_after=2, same_tool_failure_halt_after=3)
|
ToolCallGuardrailConfig(same_tool_failure_warn_after=2, same_tool_failure_halt_after=3)
|
||||||
|
|
|
||||||
30
tests/agent/test_tool_result_classification.py
Normal file
30
tests/agent/test_tool_result_classification.py
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
"""Tests for shared tool result classification helpers."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from agent.tool_result_classification import file_mutation_result_landed
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_file_with_nested_lint_error_counts_as_landed():
    """A nested lint error does not stop a write_file result from counting as landed."""
    lint_report = {"status": "error", "output": "SyntaxError: invalid syntax"}
    payload = json.dumps({"bytes_written": 12, "lint": lint_report})

    landed = file_mutation_result_landed("write_file", payload)

    assert landed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_with_nested_lsp_diagnostics_counts_as_landed():
    """LSP diagnostics text does not stop a successful patch from counting as landed."""
    patch_payload = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }

    landed = file_mutation_result_landed("patch", json.dumps(patch_payload))

    assert landed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_level_file_mutation_error_does_not_count_as_landed():
    """A top-level "error" key overrides "success": True for a patch result."""
    patch_payload = {"success": True, "error": "post-write verification failed"}

    landed = file_mutation_result_landed("patch", json.dumps(patch_payload))

    assert landed is False
|
||||||
|
|
@ -166,6 +166,56 @@ class TestRecordFileMutationResult:
|
||||||
)
|
)
|
||||||
assert agent._turn_failed_file_mutations == {}
|
assert agent._turn_failed_file_mutations == {}
|
||||||
|
|
||||||
|
def test_write_file_with_lint_error_counts_as_landed(self):
    """A landed write_file flagged is_error clears the path's earlier recorded failure."""
    agent = _bare_agent()

    # First call: a real failure, which must be recorded for the path.
    failed_payload = json.dumps({"error": "write failed"})
    agent._record_file_mutation_result(
        "write_file",
        {"path": "/tmp/a.py", "content": "bad"},
        failed_payload,
        is_error=True,
    )
    assert "/tmp/a.py" in agent._turn_failed_file_mutations

    # Second call: bytes were written, only the lint report says "error".
    lint_report = {"status": "error", "output": "SyntaxError: invalid syntax"}
    landed_payload = json.dumps({"bytes_written": 24, "lint": lint_report})
    agent._record_file_mutation_result(
        "write_file",
        {"path": "/tmp/a.py", "content": "def nope(:\n"},
        landed_payload,
        is_error=True,
    )

    assert agent._turn_failed_file_mutations == {}
|
||||||
|
|
||||||
|
def test_patch_with_lsp_diagnostics_counts_as_landed(self):
    """A successful patch flagged is_error clears the path's earlier recorded failure."""
    agent = _bare_agent()

    # First call: a real failure, which must be recorded for the path.
    failed_payload = json.dumps({"error": "Could not find old_string"})
    agent._record_file_mutation_result(
        "patch",
        {"mode": "replace", "path": "/tmp/a.py", "old_string": "x", "new_string": "y"},
        failed_payload,
        is_error=True,
    )
    assert "/tmp/a.py" in agent._turn_failed_file_mutations

    # Second call: the patch succeeded and only carries LSP diagnostics.
    landed_fields = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "files_modified": ["/tmp/a.py"],
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }
    landed_payload = json.dumps(landed_fields)
    agent._record_file_mutation_result(
        "patch",
        {"mode": "replace", "path": "/tmp/a.py", "old_string": "x", "new_string": "y"},
        landed_payload,
        is_error=True,
    )

    assert agent._turn_failed_file_mutations == {}
|
||||||
|
|
||||||
def test_repeated_failure_keeps_first_error(self):
|
def test_repeated_failure_keeps_first_error(self):
|
||||||
agent = _bare_agent()
|
agent = _bare_agent()
|
||||||
agent._record_file_mutation_result(
|
agent._record_file_mutation_result(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue