Enhance async tool execution and error handling in Hermes agent for Atropos integration

- Updated `.gitignore` to exclude `testlogs` directory.
- Refactored `handle_web_function_call` in `model_tools.py` to support running async functions in existing event loops, improving compatibility with Atropos.
- Introduced a thread pool executor in `agent_loop.py` for running synchronous tool calls that internally use `asyncio.run()`, preventing deadlocks.
- Added `ToolError` class to track tool execution errors, enhancing error reporting during agent loops.
- Updated `wandb_log` method in `hermes_base_env.py` to log tool error statistics for better monitoring.
- Implemented patches in `patches.py` to ensure async-safe operation of tools within Atropos's event loop.
- Enhanced `ToolContext` and `terminal_tool.py` to utilize the new async handling, improving overall tool execution reliability.
This commit is contained in:
teknium 2026-02-08 05:00:47 +00:00
parent a8809bbd3e
commit d999d9876d
9 changed files with 540 additions and 64 deletions

View file

@ -11,6 +11,8 @@ identical to hermes-agent's run_agent.py. Tool execution is dispatched via
handle_function_call() from model_tools.py.
"""
import asyncio
import concurrent.futures
import json
import logging
import uuid
@ -19,9 +21,25 @@ from typing import Any, Dict, List, Optional, Set
from model_tools import handle_function_call
# Thread pool for running sync tool calls that internally use asyncio.run()
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
logger = logging.getLogger(__name__)
@dataclass
class ToolError:
"""Record of a tool execution error during the agent loop."""
turn: int # Which turn the error occurred on
tool_name: str # Which tool was called
arguments: str # The arguments passed (truncated)
error: str # The error message
tool_result: str # The raw result returned to the model
@dataclass
class AgentResult:
"""Result of running the agent loop."""
@ -36,6 +54,8 @@ class AgentResult:
finished_naturally: bool = False
# Extracted reasoning content per turn (from PR #297 helpers)
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
# Tool errors encountered during the loop
tool_errors: List[ToolError] = field(default_factory=list)
def _extract_reasoning_from_message(message) -> Optional[str]:
@ -133,6 +153,7 @@ class HermesAgentLoop:
AgentResult with full conversation history, managed state, and metadata
"""
reasoning_per_turn = []
tool_errors: List[ToolError] = []
for turn in range(self.max_turns):
# Build the chat_completion kwargs
@ -161,6 +182,7 @@ class HermesAgentLoop:
turns_used=turn + 1,
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
)
if not response or not response.choices:
@ -171,6 +193,7 @@ class HermesAgentLoop:
turns_used=turn + 1,
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
)
assistant_msg = response.choices[0].message
@ -209,6 +232,7 @@ class HermesAgentLoop:
# Execute each tool call via hermes-agent's dispatch
for tc in assistant_msg.tool_calls:
tool_name = tc.function.name
tool_args_raw = tc.function.arguments
# Validate tool name
if tool_name not in self.valid_tool_names:
@ -218,35 +242,75 @@ class HermesAgentLoop:
f"Available tools: {sorted(self.valid_tool_names)}"
}
)
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"Unknown tool '{tool_name}'",
tool_result=tool_result,
))
logger.warning(
"Model called unknown tool '%s' on turn %d",
tool_name,
turn + 1,
tool_name, turn + 1,
)
else:
# Parse arguments and dispatch
try:
args = json.loads(tc.function.arguments)
args = json.loads(tool_args_raw)
except json.JSONDecodeError:
args = {}
logger.warning(
"Invalid JSON in tool call arguments for '%s': %s",
tool_name,
tc.function.arguments[:200],
tool_name, tool_args_raw[:200],
)
try:
tool_result = handle_function_call(
tool_name, args, task_id=self.task_id
if tool_name == "terminal":
import os
backend = os.getenv("TERMINAL_ENV", "local")
cmd_preview = args.get("command", "")[:80]
print(f" 🖥️ [{backend}] $ {cmd_preview}")
# Run tool calls in a thread pool so backends that use
# asyncio.run() internally (modal, docker) get a clean
# event loop instead of deadlocking inside Atropos's loop.
loop = asyncio.get_event_loop()
tool_result = await loop.run_in_executor(
_tool_executor,
lambda: handle_function_call(
tool_name, args, task_id=self.task_id
),
)
except Exception as e:
tool_result = json.dumps(
{"error": f"Tool execution failed: {str(e)}"}
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
)
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=f"{type(e).__name__}: {str(e)}",
tool_result=tool_result,
))
logger.error(
"Tool '%s' execution failed: %s", tool_name, e
"Tool '%s' execution failed on turn %d: %s",
tool_name, turn + 1, e,
)
# Also check if the tool returned an error in its JSON result
try:
result_data = json.loads(tool_result)
if isinstance(result_data, dict):
err = result_data.get("error")
exit_code = result_data.get("exit_code")
if err and exit_code and exit_code < 0:
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err),
tool_result=tool_result[:500],
))
except (json.JSONDecodeError, TypeError):
pass
# Add tool response to conversation
messages.append(
{
@ -282,6 +346,7 @@ class HermesAgentLoop:
turns_used=turn + 1,
finished_naturally=True,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
)
# Hit max turns without the model stopping
@ -292,6 +357,7 @@ class HermesAgentLoop:
turns_used=self.max_turns,
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
)
def _get_managed_state(self) -> Optional[Dict[str, Any]]:

View file

@ -41,6 +41,12 @@ _env_path = _repo_root / ".env"
if _env_path.exists():
load_dotenv(dotenv_path=_env_path)
# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
# This patches SwerexModalEnvironment to use a background thread instead of
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
from environments.patches import apply_patches
apply_patches()
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
@ -172,10 +178,14 @@ class HermesAgentBaseEnv(BaseEnv):
# Set terminal backend environment variable so hermes tools pick it up
if config.terminal_backend:
os.environ["TERMINAL_ENV"] = config.terminal_backend
print(f"🖥️ Terminal backend: {config.terminal_backend}")
# Current group's resolved tools (set in collect_trajectories)
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
# Tool error tracking for wandb logging
self._tool_error_buffer: List[Dict[str, Any]] = []
# =========================================================================
# Toolset resolution (per-group)
# =========================================================================
@ -348,6 +358,33 @@ class HermesAgentBaseEnv(BaseEnv):
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log base metrics including tool errors to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
# Log tool error stats
if self._tool_error_buffer:
wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
# Log error details as a summary string (tables can crash wandb on tmp cleanup)
error_summaries = []
for err in self._tool_error_buffer:
error_summaries.append(
f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
)
wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
# Also print to stdout for immediate visibility
for summary in error_summaries:
print(f" Tool Error: {summary}")
self._tool_error_buffer = []
else:
wandb_metrics["train/tool_errors_count"] = 0
await super().wandb_log(wandb_metrics)
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
@ -376,8 +413,22 @@ class HermesAgentBaseEnv(BaseEnv):
result: AgentResult
if self._use_managed_server():
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
# Load the tool call parser from registry based on config
from environments.tool_call_parsers import get_parser
try:
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
tc_parser = get_parser(self.config.tool_call_parser)
except KeyError:
logger.warning(
"Tool call parser '%s' not found, falling back to 'hermes'",
self.config.tool_call_parser,
)
tc_parser = get_parser("hermes")
try:
async with self.server.managed_server(
tokenizer=self.tokenizer,
tool_call_parser=tc_parser,
) as managed:
agent = HermesAgentLoop(
server=managed,
tool_schemas=tools,
@ -417,15 +468,39 @@ class HermesAgentBaseEnv(BaseEnv):
)
result = await agent.run(messages)
# Compute reward using ToolContext (gives verifier full tool access)
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
except Exception as e:
logger.error("compute_reward failed: %s", e)
# Skip reward computation if the agent loop produced no meaningful work
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
# just to verify files that were never created.
only_system_and_user = all(
msg.get("role") in ("system", "user") for msg in result.messages
)
if result.turns_used == 0 or only_system_and_user:
logger.warning(
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
result.turns_used, len(result.messages),
)
reward = 0.0
finally:
ctx.cleanup()
else:
# Compute reward using ToolContext (gives verifier full tool access)
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
except Exception as e:
logger.error("compute_reward failed: %s", e)
reward = 0.0
finally:
ctx.cleanup()
# Track tool errors for wandb logging
if result.tool_errors:
for err in result.tool_errors:
self._tool_error_buffer.append({
"turn": err.turn,
"tool": err.tool_name,
"args": err.arguments[:150],
"error": err.error[:300],
"result": err.tool_result[:300],
})
# Build ScoredDataItem from ManagedServer state
# Phase 2: real tokens/masks/logprobs from SequenceNodes

188
environments/patches.py Normal file
View file

@ -0,0 +1,188 @@
"""
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
Problem:
Some tools use asyncio.run() internally (e.g., mini-swe-agent's Modal backend,
web_extract). This crashes when called from inside Atropos's event loop because
asyncio.run() can't be nested.
Solution:
Replace the problematic methods with versions that use a dedicated background
thread with its own event loop. The calling code sees the same sync interface --
call a function, get a result -- but internally the async work happens on a
separate thread that doesn't conflict with Atropos's loop.
These patches are safe for normal CLI use too: when there's no running event
loop, the behavior is identical (the background thread approach works regardless).
What gets patched:
- SwerexModalEnvironment.__init__ -- creates Modal deployment on a background thread
- SwerexModalEnvironment.execute -- runs commands on the same background thread
- SwerexModalEnvironment.stop -- stops deployment on the background thread
Usage:
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
This is idempotent -- calling it multiple times is safe.
"""
import asyncio
import logging
import threading
from typing import Any
logger = logging.getLogger(__name__)
_patches_applied = False
class _AsyncWorker:
"""
A dedicated background thread with its own event loop.
Allows sync code to submit async coroutines and block for results,
even when called from inside another running event loop. Used to
bridge sync tool interfaces with async backends (Modal, SWE-ReX).
"""
def __init__(self):
self._loop: asyncio.AbstractEventLoop = None
self._thread: threading.Thread = None
self._started = threading.Event()
def start(self):
"""Start the background event loop thread."""
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._started.wait(timeout=30)
def _run_loop(self):
"""Background thread entry point -- runs the event loop forever."""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._started.set()
self._loop.run_forever()
def run_coroutine(self, coro, timeout=600):
"""
Submit a coroutine to the background loop and block until it completes.
Safe to call from any thread, including threads that already have
a running event loop.
"""
if self._loop is None or self._loop.is_closed():
raise RuntimeError("AsyncWorker loop is not running")
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result(timeout=timeout)
def stop(self):
"""Stop the background event loop and join the thread."""
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=10)
def _patch_swerex_modal():
"""
Monkey patch SwerexModalEnvironment to use a background thread event loop
instead of asyncio.run(). This makes it safe to call from inside Atropos's
async event loop.
The patched methods have the exact same interface and behavior -- the only
difference is HOW the async work is executed internally.
"""
try:
from minisweagent.environments.extra.swerex_modal import (
SwerexModalEnvironment,
SwerexModalEnvironmentConfig,
)
from swerex.deployment.modal import ModalDeployment
from swerex.runtime.abstract import Command as RexCommand
except ImportError:
# mini-swe-agent or swe-rex not installed -- nothing to patch
logger.debug("mini-swe-agent Modal backend not available, skipping patch")
return
# Save original methods so we can refer to config handling
_original_init = SwerexModalEnvironment.__init__
def _patched_init(self, **kwargs):
"""Patched __init__: creates Modal deployment on a background thread."""
self.config = SwerexModalEnvironmentConfig(**kwargs)
# Start a dedicated event loop thread for all Modal async operations
self._worker = _AsyncWorker()
self._worker.start()
# Create AND start the deployment entirely on the worker's loop/thread
# so all gRPC channels and async state are bound to that loop
async def _create_and_start():
deployment = ModalDeployment(
image=self.config.image,
startup_timeout=self.config.startup_timeout,
runtime_timeout=self.config.runtime_timeout,
deployment_timeout=self.config.deployment_timeout,
install_pipx=self.config.install_pipx,
modal_sandbox_kwargs=self.config.modal_sandbox_kwargs,
)
await deployment.start()
return deployment
self.deployment = self._worker.run_coroutine(_create_and_start())
def _patched_execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
"""Patched execute: runs commands on the background thread's loop."""
async def _do_execute():
return await self.deployment.runtime.execute(
RexCommand(
command=command,
shell=True,
check=False,
cwd=cwd or self.config.cwd,
timeout=timeout or self.config.timeout,
merge_output_streams=True,
env=self.config.env if self.config.env else None,
)
)
output = self._worker.run_coroutine(_do_execute())
return {
"output": output.stdout,
"returncode": output.exit_code,
}
def _patched_stop(self):
"""Patched stop: stops deployment on the background thread, then stops the thread."""
try:
self._worker.run_coroutine(
asyncio.wait_for(self.deployment.stop(), timeout=10),
timeout=15,
)
except Exception:
pass
finally:
self._worker.stop()
# Apply the patches
SwerexModalEnvironment.__init__ = _patched_init
SwerexModalEnvironment.execute = _patched_execute
SwerexModalEnvironment.stop = _patched_stop
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
def apply_patches():
"""
Apply all monkey patches needed for Atropos compatibility.
Safe to call multiple times -- patches are only applied once.
Safe for normal CLI use -- patched code works identically when
there is no running event loop.
"""
global _patches_applied
if _patches_applied:
return
_patch_swerex_modal()
_patches_applied = True

View file

@ -132,7 +132,7 @@ class TerminalTestEnv(HermesAgentBaseEnv):
terminal_backend="modal",
# Atropos settings
group_size=3, # 3 rollouts per group
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
tokenizer_name="NousResearch/q-30b-t-h45-e1",
tool_call_parser="hermes",
steps_per_eval=3, # Eval after all 3 steps
total_steps=3, # 3 groups total (1 group per step)

View file

@ -25,14 +25,43 @@ Example usage in a compute_reward():
import json
import logging
import os
from typing import Any, Dict, List, Optional
import asyncio
import concurrent.futures
from model_tools import handle_function_call
from tools.terminal_tool import cleanup_vm
from tools.browser_tool import cleanup_browser
logger = logging.getLogger(__name__)
# Thread pool for running sync tool calls that internally use asyncio.run()
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
"""
Run a tool call in a thread pool executor so backends that use asyncio.run()
internally (modal, docker) get a clean event loop.
If we're already in an async context, uses run_in_executor.
If not (e.g., called from sync code), runs directly.
"""
try:
loop = asyncio.get_running_loop()
# We're in an async context -- need to run in thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(
handle_function_call, tool_name, arguments, task_id
)
return future.result(timeout=300)
except RuntimeError:
# No running event loop -- safe to call directly
return handle_function_call(tool_name, arguments, task_id)
class ToolContext:
"""
@ -61,10 +90,15 @@ class ToolContext:
Returns:
Dict with 'exit_code' (int) and 'output' (str)
"""
result = handle_function_call(
import os
backend = os.getenv("TERMINAL_ENV", "local")
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
# Run in thread pool so modal/docker backends' asyncio.run() doesn't deadlock
result = _run_tool_in_thread(
"terminal",
{"command": command, "timeout": timeout},
task_id=self.task_id,
self.task_id,
)
try:
return json.loads(result)
@ -222,7 +256,7 @@ class ToolContext:
Returns:
Raw JSON string result from the tool
"""
return handle_function_call(tool_name, arguments, task_id=self.task_id)
return _run_tool_in_thread(tool_name, arguments, self.task_id)
# -------------------------------------------------------------------------
# Cleanup
@ -240,7 +274,16 @@ class ToolContext:
except Exception as e:
logger.debug("VM cleanup for task %s: %s", self.task_id, e)
# Suppress browser_tool's noisy debug prints during cleanup.
# The cleanup still runs (safe), it just doesn't spam the console.
_prev_quiet = os.environ.get("HERMES_QUIET")
os.environ["HERMES_QUIET"] = "1"
try:
cleanup_browser(self.task_id)
except Exception as e:
logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
finally:
if _prev_quiet is None:
os.environ.pop("HERMES_QUIET", None)
else:
os.environ["HERMES_QUIET"] = _prev_quiet