def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Shrink a conversation so its estimated size fits ``max_context_tokens``.

    Strategy, applied in order until the estimate fits:
      1. Shorten long ``tool`` message contents in the middle of the
         conversation (first 100 + last 50 characters are kept).
      2. Drop the oldest middle messages one at a time. When a dropped
         message is an assistant message carrying ``tool_calls``, every
         ``tool`` reply answering one of those calls is dropped as well.

    The first two messages (expected to be the system and initial user
    message) are never touched. The last six messages are kept unless one
    of them is a ``tool`` reply whose originating assistant call was
    dropped, since an unanswered-for tool reply is not a coherent prompt.

    Token counts are a rough ``len(text) // 4`` estimate plus fixed
    per-message and per-tool-call overheads; no tokenizer is used.

    This is best effort: if the protected head and tail alone exceed the
    budget, the returned list is still over budget.

    NOTE: The list passed in is mutated (entries may be popped or replaced
    by shortened copies). The message dicts themselves are never modified
    in place. Pass a copy of the list to preserve the full trajectory.

    Args:
        messages: Chat messages as dicts with at least a ``role`` key.

    Returns:
        The same list object, possibly shortened.
    """
    budget = self.max_context_tokens
    if budget is None:
        return messages

    def message_cost(msg):
        # ~4 chars per token + 10 tokens of per-message overhead,
        # + 50 tokens per tool call the message carries.
        content = msg.get("content") or ""
        if not isinstance(content, str):
            # Structured (e.g. multi-part) content: measure its text form
            # rather than the number of parts.
            content = str(content)
        # ``tool_calls`` may be present but None on plain assistant turns.
        return len(content) // 4 + 10 + 50 * len(msg.get("tool_calls") or ())

    def call_id(tool_call):
        if isinstance(tool_call, dict):
            return tool_call.get("id") or tool_call.get("tool_call_id")
        return getattr(tool_call, "id", None)

    # Keep per-message costs alongside the list so each removal is O(1)
    # bookkeeping instead of a full re-estimate.
    costs = [message_cost(m) for m in messages]
    total = sum(costs)
    if total <= budget:
        return messages

    protect_head = 2
    protect_tail = max(0, min(6, len(messages) - protect_head))
    middle_start = protect_head
    middle_end = len(messages) - protect_tail

    # Phase 1: shorten tool outputs in the middle.
    for i in range(middle_start, middle_end):
        msg = messages[i]
        if msg.get("role") != "tool":
            continue
        content = msg.get("content") or ""
        if isinstance(content, str) and len(content) > 200:
            shortened = dict(msg)  # copy: never edit the caller's dict
            shortened["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
            messages[i] = shortened
            total -= costs[i]
            costs[i] = message_cost(shortened)
            total += costs[i]

    if total <= budget:
        return messages

    # Phase 2: drop oldest middle messages, keeping assistant/tool pairs whole.
    while middle_start < middle_end and total > budget:
        dropped = messages.pop(middle_start)
        total -= costs.pop(middle_start)
        middle_end -= 1

        if dropped.get("role") != "assistant" or not dropped.get("tool_calls"):
            continue

        # Empty ids are ignored so that tool replies lacking an id are not
        # matched by accident.
        tool_ids = {cid for cid in map(call_id, dropped["tool_calls"]) if cid}
        if not tool_ids:
            continue

        # Scan to the end of the list, not just the middle window: the
        # replies to the last middle assistant message sit in the tail.
        i = middle_start
        while i < len(messages):
            candidate = messages[i]
            if candidate.get("role") == "tool" and candidate.get("tool_call_id") in tool_ids:
                messages.pop(i)
                total -= costs.pop(i)
                if i < middle_end:
                    middle_end -= 1
            else:
                i += 1

    return messages


def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> "tuple[Dict[str, Any], bool]":
    """Normalize raw tool-call arguments into a dict without ever raising.

    Args:
        tool_name: Name of the tool being called. ``"terminal"`` gets special
            handling: bare values are wrapped under ``"command"`` instead of
            ``"input"``, and a dict is only schema-valid when it carries a
            non-blank string ``"command"``.
        tool_args_raw: The arguments as emitted by the model, normally a JSON
            string. ``None``, blank strings and already-decoded values are
            tolerated.

    Returns:
        ``(args_dict, schema_valid)``. ``schema_valid`` is True only when the
        arguments are a dict as given (no double-decoding, wrapping or key
        coercion was needed), so callers can count or penalize malformed
        formats while still executing the call.
    """
    fallback_key = "command" if tool_name == "terminal" else "input"

    if tool_args_raw is None or (isinstance(tool_args_raw, str) and not tool_args_raw.strip()):
        # No arguments at all. Wrapping an empty string would inject a
        # parameter the tool never declared, so pass an empty dict instead.
        return {}, False

    if isinstance(tool_args_raw, str):
        try:
            decoded = json.loads(tool_args_raw)
        except json.JSONDecodeError:
            # Not JSON at all: treat the text itself as the single argument.
            return {fallback_key: tool_args_raw}, False
    else:
        # Already decoded upstream (e.g. a dict); json.loads would raise
        # TypeError on it.
        decoded = tool_args_raw

    if isinstance(decoded, dict):
        if tool_name != "terminal":
            return decoded, True
        command = decoded.get("command")
        if isinstance(command, str) and command.strip():
            return decoded, True
        # Common drift: the command was sent under "input".
        alternate = decoded.get("input")
        if isinstance(alternate, str):
            return {"command": alternate}, False
        return decoded, False

    if isinstance(decoded, str):
        # Double-encoded JSON: a JSON string whose text is a JSON object.
        stripped = decoded.strip()
        if stripped.startswith("{") and stripped.endswith("}"):
            try:
                inner = json.loads(stripped)
            except json.JSONDecodeError:
                inner = None
            if isinstance(inner, dict):
                return inner, False
        return {fallback_key: decoded}, False

    # Any other JSON value (number, list, bool, null).
    if tool_name == "terminal":
        return {"command": str(decoded)}, False
    return {"input": decoded}, False
_time.monotonic() + # Truncate prompt view on a copy (preserve full trajectory in `messages`) + prompt_messages = self._truncate_context(list(messages)) + # Build the chat_completion kwargs chat_kwargs = { - "messages": messages, + "messages": prompt_messages, "n": 1, "temperature": self.temperature, } @@ -215,6 +355,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) api_elapsed = _time.monotonic() - api_start @@ -228,6 +372,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) assistant_msg = response.choices[0].message @@ -270,6 +418,7 @@ class HermesAgentLoop: # Validate tool name if tool_name not in self.valid_tool_names: + tool_calls_exec_error += 1 tool_result = json.dumps( { "error": f"Unknown tool '{tool_name}'. 
" @@ -287,35 +436,35 @@ class HermesAgentLoop: tool_name, turn + 1, ) else: - # Parse arguments and dispatch - try: - args = json.loads(tool_args_raw) - except json.JSONDecodeError: - args = {} - logger.warning( - "Invalid JSON in tool call arguments for '%s': %s", - tool_name, tool_args_raw[:200], - ) + tool_calls_attempted += 1 + args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw) + if schema_valid: + tool_calls_schema_valid += 1 try: if tool_name == "terminal": backend = os.getenv("TERMINAL_ENV", "local") - cmd_preview = args.get("command", "")[:80] + cmd_preview = str(args.get("command", ""))[:80] logger.info( "[%s] $ %s", self.task_id[:8], cmd_preview, ) - # Run tool calls in a thread pool so backends that use - # asyncio.run() internally (modal, docker) get a clean - # event loop instead of deadlocking inside Atropos's loop. tool_submit_time = _time.monotonic() - loop = asyncio.get_event_loop() - tool_result = await loop.run_in_executor( - _tool_executor, - lambda: handle_function_call( - tool_name, args, task_id=self.task_id - ), - ) + + if self.tool_handler: + tool_result = await self.tool_handler(tool_name, args, self.task_id) + else: + # Run tool calls in a thread pool so backends that use + # asyncio.run() internally (modal, docker) get a clean + # event loop instead of deadlocking inside Atropos's loop. 
+ loop = asyncio.get_event_loop() + tool_result = await loop.run_in_executor( + _tool_executor, + lambda: handle_function_call( + tool_name, args, task_id=self.task_id + ), + ) + tool_elapsed = _time.monotonic() - tool_submit_time # Log slow tools and thread pool stats for debugging @@ -327,6 +476,7 @@ class HermesAgentLoop: tool_elapsed, pool_active, ) except Exception as e: + tool_calls_exec_error += 1 tool_result = json.dumps( {"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"} ) @@ -340,22 +490,31 @@ class HermesAgentLoop: "Tool '%s' execution failed on turn %d: %s", tool_name, turn + 1, e, ) + else: + tool_err = False + try: + result_data = json.loads(tool_result) + if isinstance(result_data, dict): + err = result_data.get("error") + if err: + tool_err = True - # Also check if the tool returned an error in its JSON result - try: - result_data = json.loads(tool_result) - if isinstance(result_data, dict): - err = result_data.get("error") - exit_code = result_data.get("exit_code") - if err and exit_code and exit_code < 0: - tool_errors.append(ToolError( - turn=turn + 1, tool_name=tool_name, - arguments=tool_args_raw[:200], - error=str(err), - tool_result=tool_result[:500], - )) - except (json.JSONDecodeError, TypeError): - pass + exit_code = result_data.get("exit_code") + if exit_code is not None and isinstance(exit_code, int) and exit_code < 0: + tool_err = True + tool_errors.append(ToolError( + turn=turn + 1, tool_name=tool_name, + arguments=tool_args_raw[:200], + error=str(err) if err else "nonzero exit_code", + tool_result=tool_result[:500], + )) + except (json.JSONDecodeError, TypeError): + pass + + if tool_err: + tool_calls_exec_error += 1 + else: + tool_calls_executed_ok += 1 # Add tool response to conversation messages.append( @@ -396,6 +555,10 @@ class HermesAgentLoop: finished_naturally=True, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + 
tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) # Hit max turns without the model stopping @@ -407,6 +570,10 @@ class HermesAgentLoop: finished_naturally=False, reasoning_per_turn=reasoning_per_turn, tool_errors=tool_errors, + tool_calls_attempted=tool_calls_attempted, + tool_calls_schema_valid=tool_calls_schema_valid, + tool_calls_executed_ok=tool_calls_executed_ok, + tool_calls_exec_error=tool_calls_exec_error, ) def _get_managed_state(self) -> Optional[Dict[str, Any]]: diff --git a/environments/hermes_base_env.py b/environments/hermes_base_env.py index 98a40dd5d19..1a4152a02d4 100644 --- a/environments/hermes_base_env.py +++ b/environments/hermes_base_env.py @@ -478,6 +478,7 @@ class HermesAgentBaseEnv(BaseEnv): tokenizer=self.tokenizer, tool_call_parser=tc_parser, ) as managed: + _max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None agent = HermesAgentLoop( server=managed, tool_schemas=tools, @@ -487,6 +488,7 @@ class HermesAgentBaseEnv(BaseEnv): temperature=self.config.agent_temperature, max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, + max_context_tokens=_max_ctx, ) result = await agent.run(messages) except NotImplementedError: @@ -495,6 +497,7 @@ class HermesAgentBaseEnv(BaseEnv): "ManagedServer not available (OpenAI server?). " "Falling back to direct server mode." 
) + _max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None agent = HermesAgentLoop( server=self.server, tool_schemas=tools, @@ -504,10 +507,12 @@ class HermesAgentBaseEnv(BaseEnv): temperature=self.config.agent_temperature, max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, + max_context_tokens=_max_ctx, ) result = await agent.run(messages) else: # Phase 1: OpenAI server -- native tool_calls, placeholder tokens + _max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None agent = HermesAgentLoop( server=self.server, tool_schemas=tools, @@ -517,6 +522,7 @@ class HermesAgentBaseEnv(BaseEnv): temperature=self.config.agent_temperature, max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, + max_context_tokens=_max_ctx, ) result = await agent.run(messages) diff --git a/environments/tool_call_parsers/hermes_parser.py b/environments/tool_call_parsers/hermes_parser.py index c1902fd623c..2f3e164270b 100644 --- a/environments/tool_call_parsers/hermes_parser.py +++ b/environments/tool_call_parsers/hermes_parser.py @@ -49,15 +49,22 @@ class HermesToolCallParser(ToolCallParser): continue tc_data = json.loads(raw_json) + # Handle arguments: could be dict or already a JSON string + raw_args = tc_data.get("arguments", {}) + if isinstance(raw_args, str): + # Already a string — pass through as-is. + # It may be a JSON string ("{...}") or a plain string ("ls"). + args_str = raw_args + else: + # Dict — serialize to JSON + args_str = json.dumps(raw_args, ensure_ascii=False) tool_calls.append( ChatCompletionMessageToolCall( id=f"call_{uuid.uuid4().hex[:8]}", type="function", function=Function( name=tc_data["name"], - arguments=json.dumps( - tc_data.get("arguments", {}), ensure_ascii=False - ), + arguments=args_str, ), ) )