Track tool-call validity vs attempts; shape reward accordingly

- AgentResult now includes tool-call metrics: attempted, schema_valid,
  executed_ok, exec_error
- HermesAgentLoop normalizes args robustly without crashing, but
  distinguishes schema-valid args (dict) from coerced formats
  (stringified JSON, plain strings)
- SweSmithOracleEnv reward shaping now prefers schema-valid tool calls
  while still giving small credit for attempted tool use
This commit is contained in:
Shannon Sands 2026-02-14 09:17:05 +10:00
parent 35b2250b36
commit 499490d06a
2 changed files with 133 additions and 56 deletions

View file

@ -57,6 +57,12 @@ class AgentResult:
# Tool errors encountered during the loop
tool_errors: List[ToolError] = field(default_factory=list)
# Tool-call metrics (for reward shaping + debugging)
tool_calls_attempted: int = 0 # Valid tool name + attempted dispatch
tool_calls_schema_valid: int = 0 # Arguments matched schema (no coercion)
tool_calls_executed_ok: int = 0 # Tool ran and returned no error
tool_calls_exec_error: int = 0 # Unknown tool / exception / tool returned error
def _extract_reasoning_from_message(message) -> Optional[str]:
"""
@ -225,6 +231,60 @@ class HermesAgentLoop:
logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages))
return messages
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool):
"""Normalize tool arguments into a dict.
Returns:
(args_dict, schema_valid)
schema_valid is True only when the arguments decode directly into a dict
(i.e. no double-decoding and no coercion/wrapping was needed).
This lets us keep the environment robust (never crash due to args format)
while still scoring down malformed tool-call argument formats.
"""
try:
decoded = json.loads(tool_args_raw)
except json.JSONDecodeError:
return {}, False
# Canonical case: decoded is already a dict
if isinstance(decoded, dict):
# For terminal tool, require a command key
if tool_name == "terminal":
cmd = decoded.get("command")
if isinstance(cmd, str) and cmd.strip():
return decoded, True
# Common alternate key
if isinstance(decoded.get("input"), str):
return {"command": decoded.get("input")}, False
return decoded, False
return decoded, True
# Common drift case: decoded is a JSON string of an object
if isinstance(decoded, str):
s = decoded.strip()
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
try:
decoded2 = json.loads(s)
except json.JSONDecodeError:
decoded2 = None
if isinstance(decoded2, dict):
# Terminal tool: ensure command
if tool_name == "terminal" and isinstance(decoded2.get("command"), str):
return decoded2, False
return decoded2, False
# Plain string (not JSON) — coerce to expected shape
if tool_name == "terminal":
return {"command": decoded}, False
return {"input": decoded}, False
# Other JSON types (list/number/etc.) — wrap
if tool_name == "terminal":
return {"command": str(decoded)}, False
return {"input": decoded}, False
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
"""
Execute the full agent loop using standard OpenAI tool calling.
@ -239,6 +299,12 @@ class HermesAgentLoop:
reasoning_per_turn = []
tool_errors: List[ToolError] = []
# Metrics to separate "attempted tool use" from "schema-valid tool use"
tool_calls_attempted = 0
tool_calls_schema_valid = 0
tool_calls_executed_ok = 0
tool_calls_exec_error = 0
for turn in range(self.max_turns):
# Truncate context if approaching limit
messages = self._truncate_context(messages)
@ -270,6 +336,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
if not response or not response.choices:
@ -281,6 +351,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
assistant_msg = response.choices[0].message
@ -339,42 +413,16 @@ class HermesAgentLoop:
"Model called unknown tool '%s' on turn %d",
tool_name, turn + 1,
)
tool_calls_exec_error += 1
else:
# Parse arguments and dispatch
try:
args = json.loads(tool_args_raw)
# Guard against double-encoded JSON strings
# Model sometimes outputs '{"command": "ls"}' as a JSON string
# so json.loads produces the string '{"command": "ls"}' not a dict
if isinstance(args, str):
try:
args2 = json.loads(args)
if isinstance(args2, dict):
args = args2
elif isinstance(args2, str):
# Triple-encoded... just wrap it
if tool_name == "terminal":
args = {"command": args2}
else:
args = {"input": args2}
else:
args = {"input": args2}
except (json.JSONDecodeError, TypeError):
# Plain string, not JSON - wrap it
if tool_name == "terminal":
args = {"command": args}
else:
args = {"input": args}
logger.debug(
"Tool args for '%s' decoded from string: %s",
tool_name, tool_args_raw[:200],
)
except json.JSONDecodeError:
args = {}
logger.warning(
"Invalid JSON in tool call arguments for '%s': %s",
tool_name, tool_args_raw[:200],
)
tool_calls_attempted += 1
# Normalize args into a dict so we never crash due to formatting.
# Track schema_valid separately so reward shaping can penalize
# non-canonical formats (e.g. stringified JSON).
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
if schema_valid:
tool_calls_schema_valid += 1
try:
if tool_name == "terminal":
@ -382,7 +430,7 @@ class HermesAgentLoop:
backend = os.getenv("TERMINAL_ENV", "local")
if self.tool_handler:
backend = "sandbox"
cmd_preview = args.get("command", "")[:80]
cmd_preview = str(args.get("command", ""))[:80]
print(f" 🖥️ [{backend}] $ {cmd_preview}")
if self.tool_handler:
@ -403,6 +451,7 @@ class HermesAgentLoop:
),
)
except Exception as e:
tool_calls_exec_error += 1
tool_result = json.dumps(
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
)
@ -416,22 +465,34 @@ class HermesAgentLoop:
"Tool '%s' execution failed on turn %d: %s",
tool_name, turn + 1, e,
)
else:
# Count tool result errors (if tool returns structured JSON error)
tool_err = False
try:
result_data = json.loads(tool_result)
if isinstance(result_data, dict):
err = result_data.get("error")
if err:
tool_err = True
# Also check if the tool returned an error in its JSON result
try:
result_data = json.loads(tool_result)
if isinstance(result_data, dict):
err = result_data.get("error")
exit_code = result_data.get("exit_code")
if err and exit_code and exit_code < 0:
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err),
tool_result=tool_result[:500],
))
except (json.JSONDecodeError, TypeError):
pass
# Keep existing behavior: treat negative exit_code as tool error
exit_code = result_data.get("exit_code")
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
tool_err = True
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err) if err else "nonzero exit_code",
tool_result=tool_result[:500],
))
except (json.JSONDecodeError, TypeError):
# Non-JSON tool output — assume ok
pass
if tool_err:
tool_calls_exec_error += 1
else:
tool_calls_executed_ok += 1
# Add tool response to conversation
messages.append(
@ -469,6 +530,10 @@ class HermesAgentLoop:
finished_naturally=True,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
# Hit max turns without the model stopping
@ -480,6 +545,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
def _get_managed_state(self) -> Optional[Dict[str, Any]]: