diff --git a/environments/agent_loop.py b/environments/agent_loop.py index f9f17739e4..ffb9e61881 100644 --- a/environments/agent_loop.py +++ b/environments/agent_loop.py @@ -57,6 +57,12 @@ class AgentResult: # Tool errors encountered during the loop tool_errors: List[ToolError] = field(default_factory=list) + # Tool-call metrics (for reward shaping + debugging) + tool_calls_attempted: int = 0 # Valid tool name + attempted dispatch + tool_calls_schema_valid: int = 0 # Arguments matched schema (no coercion) + tool_calls_executed_ok: int = 0 # Tool ran and returned no error + tool_calls_exec_error: int = 0 # Unknown tool / exception / tool returned error + def _extract_reasoning_from_message(message) -> Optional[str]: """ @@ -225,6 +231,60 @@ class HermesAgentLoop: logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages)) return messages + def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool): + """Normalize tool arguments into a dict. + + Returns: + (args_dict, schema_valid) + + schema_valid is True only when the arguments decode directly into a dict + (i.e. no double-decoding and no coercion/wrapping was needed). + + This lets us keep the environment robust (never crash due to args format) + while still scoring down malformed tool-call argument formats. + """ + try: + decoded = json.loads(tool_args_raw) + except json.JSONDecodeError: + return {}, False + + # Canonical case: decoded is already a dict + if isinstance(decoded, dict): + # For terminal tool, require a command key + if tool_name == "terminal": + cmd = decoded.get("command") + if isinstance(cmd, str) and cmd.strip(): + return decoded, True + # Common alternate key + if isinstance(decoded.get("input"), str): + return {"command": decoded.get("input")}, False + return decoded, False + return decoded, True + + # Common drift case: decoded is a JSON string of an object + if isinstance(decoded, str): + s = decoded.strip() + if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")): + try: + decoded2 = json.loads(s) + except json.JSONDecodeError: + decoded2 = None + if isinstance(decoded2, dict): + # Terminal tool: ensure command + if tool_name == "terminal" and isinstance(decoded2.get("command"), str): + return decoded2, False + return decoded2, False + + # Plain string (not JSON) — coerce to expected shape + if tool_name == "terminal": + return {"command": decoded}, False + return {"input": decoded}, False + + # Other JSON types (list/number/etc.) — wrap + if tool_name == "terminal": + return {"command": str(decoded)}, False + return {"input": decoded}, False + async def run(self, messages: List[Dict[str, Any]]) -> AgentResult: """ Execute the full agent loop using standard OpenAI tool calling. @@ -239,6 +299,12 @@ class HermesAgentLoop: reasoning_per_turn = [] tool_errors: List[ToolError] = [] + # Metrics to separate "attempted tool use" from "schema-valid tool use" + tool_calls_attempted = 0 + tool_calls_schema_valid = 0 + tool_calls_executed_ok = 0 + tool_calls_exec_error = 0 + for turn in range(self.max_turns): # Truncate context if approaching limit messages = self._truncate_context(messages) @@ -270,6 +336,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) if not response or not response.choices: @@ -281,6 +351,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) assistant_msg = response.choices[0].message @@ -339,42 +413,16 @@ class HermesAgentLoop: "Model called unknown tool '%s' on turn %d", tool_name, turn + 1, ) + tool_calls_exec_error += 1 else: - # Parse arguments and dispatch - try: - args = json.loads(tool_args_raw) - # Guard against double-encoded JSON strings - # Model sometimes outputs '{"command": "ls"}' as a JSON string - # so json.loads produces the string '{"command": "ls"}' not a dict - if isinstance(args, str): - try: - args2 = json.loads(args) - if isinstance(args2, dict): - args = args2 - elif isinstance(args2, str): - # Triple-encoded... just wrap it - if tool_name == "terminal": - args = {"command": args2} - else: - args = {"input": args2} - else: - args = {"input": args2} - except (json.JSONDecodeError, TypeError): - # Plain string, not JSON - wrap it - if tool_name == "terminal": - args = {"command": args} - else: - args = {"input": args} - logger.debug( - "Tool args for '%s' decoded from string: %s", - tool_name, tool_args_raw[:200], - ) - except json.JSONDecodeError: - args = {} - logger.warning( - "Invalid JSON in tool call arguments for '%s': %s", - tool_name, tool_args_raw[:200], - ) + tool_calls_attempted += 1 + + # Normalize args into a dict so we never crash due to formatting. + # Track schema_valid separately so reward shaping can penalize + # non-canonical formats (e.g. stringified JSON). + args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw) + if schema_valid: + tool_calls_schema_valid += 1 try: if tool_name == "terminal": @@ -382,7 +430,7 @@ class HermesAgentLoop: backend = os.getenv("TERMINAL_ENV", "local") if self.tool_handler: backend = "sandbox" - cmd_preview = args.get("command", "")[:80] + cmd_preview = str(args.get("command", ""))[:80] print(f" 🖥️ [{backend}] $ {cmd_preview}") if self.tool_handler: @@ -403,6 +451,7 @@ class HermesAgentLoop: ), ) except Exception as e: + tool_calls_exec_error += 1 tool_result = json.dumps( {"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"} ) @@ -416,22 +465,34 @@ class HermesAgentLoop: "Tool '%s' execution failed on turn %d: %s", tool_name, turn + 1, e, ) + else: + # Count tool result errors (if tool returns structured JSON error) + tool_err = False + try: + result_data = json.loads(tool_result) + if isinstance(result_data, dict): + err = result_data.get("error") + if err: + tool_err = True - # Also check if the tool returned an error in its JSON result - try: - result_data = json.loads(tool_result) - if isinstance(result_data, dict): - err = result_data.get("error") - exit_code = result_data.get("exit_code") - if err and exit_code and exit_code < 0: - tool_errors.append(ToolError( - turn=turn + 1, tool_name=tool_name, - arguments=tool_args_raw[:200], - error=str(err), - tool_result=tool_result[:500], - )) - except (json.JSONDecodeError, TypeError): - pass + # Keep existing behavior: treat negative exit_code as tool error + exit_code = result_data.get("exit_code") + if exit_code is not None and isinstance(exit_code, int) and exit_code < 0: + tool_err = True + tool_errors.append(ToolError( + turn=turn + 1, tool_name=tool_name, + arguments=tool_args_raw[:200], + error=str(err) if err else "nonzero exit_code", + tool_result=tool_result[:500], + )) + except (json.JSONDecodeError, TypeError): + # Non-JSON tool output — assume ok + pass + + if tool_err: + tool_calls_exec_error += 1 + else: + tool_calls_executed_ok += 1 # Add tool response to conversation messages.append( @@ -469,6 +530,10 @@ class HermesAgentLoop: finished_naturally=True, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) # Hit max turns without the model stopping @@ -480,6 +545,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) def _get_managed_state(self) -> Optional[Dict[str, Any]]: diff --git a/environments/swe_smith_oracle_env.py b/environments/swe_smith_oracle_env.py index ea34c2b502..81134ec8f1 100644 --- a/environments/swe_smith_oracle_env.py +++ b/environments/swe_smith_oracle_env.py @@ -488,19 +488,27 @@ class SweSmithOracleEnv(HermesAgentBaseEnv): """ repo_dir = self._repo_name(item) - # Count valid tool calls (assistant messages that have tool_calls) - tool_call_count = sum( + # Count tool calls. Prefer the agent-loop metrics if present: + # - attempted: model called a known tool name + # - schema_valid: args were a dict (no coercion/double-decoding) + fallback_count = sum( len(msg.get("tool_calls", [])) for msg in result.messages if msg.get("role") == "assistant" ) - if tool_call_count == 0: + attempted = getattr(result, "tool_calls_attempted", fallback_count) + schema_valid = getattr(result, "tool_calls_schema_valid", fallback_count) + + if attempted == 0: print(f"[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True) return 0.0 - # Partial reward: 0.05 per tool call, capped at 0.3 - tool_call_reward = min(tool_call_count * 0.05, 0.3) + # Shaping: reward attempting tool use a little, but reward schema-valid calls more. + # Full credit per call is still 0.05 when schema_valid. + attempt_reward = min(attempted * 0.02, 0.10) + schema_reward = min(schema_valid * 0.03, 0.20) + tool_call_reward = min(attempt_reward + schema_reward, 0.30) nodeids = self._tests_for_item(item) if not nodeids: