fix: classify landed file mutations with diagnostics

This commit is contained in:
GodsBoy 2026-05-13 11:36:07 +02:00 committed by kshitij
parent 71c6dd0dcf
commit da0ddbf88a
8 changed files with 153 additions and 1 deletions

View file

@ -14,6 +14,7 @@ from difflib import unified_diff
from pathlib import Path from pathlib import Path
from utils import safe_json_loads from utils import safe_json_loads
from agent.tool_result_classification import file_mutation_result_landed
# ANSI escape codes for coloring tool failure indicators # ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m" _RED = "\033[31m"
@ -810,6 +811,8 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
""" """
if result is None: if result is None:
return False, "" return False, ""
if file_mutation_result_landed(tool_name, result):
return False, ""
if tool_name == "terminal": if tool_name == "terminal":
data = safe_json_loads(result) data = safe_json_loads(result)

View file

@ -14,6 +14,7 @@ from dataclasses import dataclass, field
from typing import Any, Mapping from typing import Any, Mapping
from utils import safe_json_loads from utils import safe_json_loads
from agent.tool_result_classification import file_mutation_result_landed
IDEMPOTENT_TOOL_NAMES = frozenset( IDEMPOTENT_TOOL_NAMES = frozenset(
@ -196,6 +197,8 @@ def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str
""" """
if result is None: if result is None:
return False, "" return False, ""
if file_mutation_result_landed(tool_name, result):
return False, ""
if tool_name == "terminal": if tool_name == "terminal":
data = safe_json_loads(result) data = safe_json_loads(result)

View file

@ -0,0 +1,26 @@
"""Shared helpers for classifying tool result payloads."""
from __future__ import annotations
import json
from typing import Any
# Tools whose JSON results this module knows how to inspect for proof of a write.
FILE_MUTATING_TOOL_NAMES = frozenset({"write_file", "patch"})


def file_mutation_result_landed(tool_name: str, result: Any) -> bool:
    """Return True when a file mutation result proves the write landed.

    Only top-level keys of the decoded payload are inspected, so nested
    diagnostics (for example a ``lint`` or ``lsp_diagnostics`` entry) do not
    influence the verdict.

    Args:
        tool_name: Name of the tool that produced ``result``.
        result: Raw tool output; anything other than a ``str`` is rejected.

    Returns:
        ``True`` only when ``tool_name`` is a known file-mutating tool, the
        payload decodes to a JSON object without a truthy top-level
        ``"error"``, and the tool-specific success marker is present:
        a ``"bytes_written"`` key for ``write_file``, or ``"success"`` being
        exactly ``True`` for ``patch``. ``False`` in every other case.
    """
    if tool_name not in FILE_MUTATING_TOOL_NAMES or not isinstance(result, str):
        return False
    try:
        data = json.loads(result.strip())
    except (ValueError, RecursionError):
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # pathologically nested input. Either way the payload proves nothing.
        return False
    if not isinstance(data, dict) or data.get("error"):
        return False
    if tool_name == "write_file":
        return "bytes_written" in data
    if tool_name == "patch":
        # Identity check: truthy non-bool values such as 1 or "true" do not count.
        return data.get("success") is True
    return False

View file

@ -181,6 +181,7 @@ from agent.tool_guardrails import (
append_toolguard_guidance, append_toolguard_guidance,
toolguard_synthetic_result, toolguard_synthetic_result,
) )
from agent.tool_result_classification import file_mutation_result_landed
from agent.trajectory import ( from agent.trajectory import (
convert_scratchpad_to_think, has_incomplete_scratchpad, convert_scratchpad_to_think, has_incomplete_scratchpad,
save_trajectory as _save_trajectory_to_file, save_trajectory as _save_trajectory_to_file,
@ -5347,7 +5348,8 @@ class AIAgent:
targets = _extract_file_mutation_targets(tool_name, args) targets = _extract_file_mutation_targets(tool_name, args)
if not targets: if not targets:
return return
if is_error: landed = file_mutation_result_landed(tool_name, result)
if is_error and not landed:
preview = _extract_error_preview(result) preview = _extract_error_preview(result)
for path in targets: for path in targets:
# Keep the FIRST error we saw for a given path unless we # Keep the FIRST error we saw for a given path unless we

View file

@ -1,6 +1,7 @@
"""Tests for agent/display.py — build_tool_preview() and inline diff previews.""" """Tests for agent/display.py — build_tool_preview() and inline diff previews."""
import os import os
import json
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -149,6 +150,27 @@ class TestCuteToolMessagePreviewLength:
assert path in line assert path in line
assert "..." not in line assert "..." not in line
def test_write_file_lint_error_result_is_not_marked_failed(self):
    """A write_file result with bytes_written and a nested lint error renders without an error tag."""
    lint_report = {"status": "error", "output": "SyntaxError: invalid syntax"}
    payload = json.dumps({"bytes_written": 12, "lint": lint_report})

    rendered = get_cute_tool_message(
        "write_file", {"path": "/tmp/a.py"}, 0.1, result=payload
    )

    assert "[error]" not in rendered
def test_patch_lsp_diagnostics_result_is_not_marked_failed(self):
    """A successful patch result carrying LSP diagnostics renders without an error tag."""
    payload = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }

    rendered = get_cute_tool_message(
        "patch", {"path": "/tmp/a.py"}, 0.1, result=json.dumps(payload)
    )

    assert "[error]" not in rendered
class TestEditDiffPreview: class TestEditDiffPreview:
def test_extract_edit_diff_for_patch(self): def test_extract_edit_diff_for_patch(self):

View file

@ -7,6 +7,7 @@ from agent.tool_guardrails import (
ToolCallGuardrailController, ToolCallGuardrailController,
ToolCallSignature, ToolCallSignature,
canonical_tool_args, canonical_tool_args,
classify_tool_failure,
) )
@ -131,6 +132,21 @@ def test_success_resets_exact_signature_failure_streak():
assert controller.before_call("web_search", args).action == "allow" assert controller.before_call("web_search", args).action == "allow"
def test_file_mutation_lint_error_result_is_not_a_tool_failure():
    """Landed write_file and patch results are classified as non-failures despite diagnostics."""
    not_a_failure = (False, "")

    write_payload = {
        "bytes_written": 12,
        "lint": {"status": "error", "output": "SyntaxError: invalid syntax"},
    }
    assert classify_tool_failure("write_file", json.dumps(write_payload)) == not_a_failure

    patch_payload = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }
    assert classify_tool_failure("patch", json.dumps(patch_payload)) == not_a_failure
def test_same_tool_varying_args_warns_by_default_without_halting(): def test_same_tool_varying_args_warns_by_default_without_halting():
controller = ToolCallGuardrailController( controller = ToolCallGuardrailController(
ToolCallGuardrailConfig(same_tool_failure_warn_after=2, same_tool_failure_halt_after=3) ToolCallGuardrailConfig(same_tool_failure_warn_after=2, same_tool_failure_halt_after=3)

View file

@ -0,0 +1,30 @@
"""Tests for shared tool result classification helpers."""
import json
from agent.tool_result_classification import file_mutation_result_landed
def test_write_file_with_nested_lint_error_counts_as_landed():
    """A nested lint error does not stop a write_file result with bytes_written from landing."""
    lint_report = {"status": "error", "output": "SyntaxError: invalid syntax"}
    payload = json.dumps({"bytes_written": 12, "lint": lint_report})

    assert file_mutation_result_landed("write_file", payload) is True
def test_patch_with_nested_lsp_diagnostics_counts_as_landed():
    """A patch result with success=True lands even when LSP diagnostics are attached."""
    payload = {
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    }

    landed = file_mutation_result_landed("patch", json.dumps(payload))

    assert landed is True
def test_top_level_file_mutation_error_does_not_count_as_landed():
    """A top-level error key overrides success=True, so the patch is not treated as landed."""
    payload = {"success": True, "error": "post-write verification failed"}

    landed = file_mutation_result_landed("patch", json.dumps(payload))

    assert landed is False

View file

@ -166,6 +166,56 @@ class TestRecordFileMutationResult:
) )
assert agent._turn_failed_file_mutations == {} assert agent._turn_failed_file_mutations == {}
def test_write_file_with_lint_error_counts_as_landed(self):
    """A landed write_file clears an earlier failure for the same path, even with is_error=True."""
    agent = _bare_agent()

    # First attempt: a top-level error payload is recorded as a failed mutation.
    failed_payload = json.dumps({"error": "write failed"})
    agent._record_file_mutation_result(
        "write_file",
        {"path": "/tmp/a.py", "content": "bad"},
        failed_payload,
        is_error=True,
    )
    assert "/tmp/a.py" in agent._turn_failed_file_mutations

    # Second attempt: bytes were written; the lint error is only a nested diagnostic.
    landed_payload = json.dumps({
        "bytes_written": 24,
        "lint": {"status": "error", "output": "SyntaxError: invalid syntax"},
    })
    agent._record_file_mutation_result(
        "write_file",
        {"path": "/tmp/a.py", "content": "def nope(:\n"},
        landed_payload,
        is_error=True,
    )
    assert agent._turn_failed_file_mutations == {}
def test_patch_with_lsp_diagnostics_counts_as_landed(self):
    """A successful patch clears an earlier failure for the same path, even with is_error=True."""
    agent = _bare_agent()

    # First attempt: a top-level error payload is recorded as a failed mutation.
    failed_payload = json.dumps({"error": "Could not find old_string"})
    agent._record_file_mutation_result(
        "patch",
        {"mode": "replace", "path": "/tmp/a.py", "old_string": "x", "new_string": "y"},
        failed_payload,
        is_error=True,
    )
    assert "/tmp/a.py" in agent._turn_failed_file_mutations

    # Second attempt: success is True; the LSP diagnostics are informational only.
    landed_payload = json.dumps({
        "success": True,
        "diff": "--- a/tmp.py\n+++ b/tmp.py\n",
        "files_modified": ["/tmp/a.py"],
        "lsp_diagnostics": "<diagnostics>ERROR [1:1] type mismatch</diagnostics>",
    })
    agent._record_file_mutation_result(
        "patch",
        {"mode": "replace", "path": "/tmp/a.py", "old_string": "x", "new_string": "y"},
        landed_payload,
        is_error=True,
    )
    assert agent._turn_failed_file_mutations == {}
def test_repeated_failure_keeps_first_error(self): def test_repeated_failure_keeps_first_error(self):
agent = _bare_agent() agent = _bare_agent()
agent._record_file_mutation_result( agent._record_file_mutation_result(