Updated hermes_base_env, moved the sandbox logic in from the old agent, and added a patch so sglang on runpod works with the /generate format (will remove). Test run worked: the model didn't produce tool calls, but full logprobs worked.

Shannon Sands 2026-02-10 06:06:21 +00:00
parent 4619d1c8ef
commit a69924631c
7 changed files with 766 additions and 123 deletions
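
The sglang /generate patch itself lives in one of the other changed files and is not shown in the diff below. For context, SGLang's native /generate endpoint accepts a JSON payload along these lines; this is a hedged sketch of a request with logprobs enabled (the URL and port are hypothetical, and the actual patch in this commit may differ):

import requests

# Hedged sketch of an SGLang native /generate request with logprobs.
# The URL/port are hypothetical; the patch in this commit is in a file
# not shown in this diff and may look different.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello, world",
        "sampling_params": {"temperature": 0.7, "max_new_tokens": 64},
        "return_logprob": True,   # what makes "full logprobs" possible
        "logprob_start_len": 0,   # also return logprobs over the prompt
    },
)
print(resp.json()["text"])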


@@ -64,7 +64,7 @@ from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext
# Import hermes-agent toolset infrastructure
-from model_tools import get_tool_definitions
+from model_tools import get_tool_definitions, handle_function_call
from toolset_distributions import sample_toolsets_from_distribution
logger = logging.getLogger(__name__)
@@ -538,6 +538,13 @@ class HermesAgentBaseEnv(BaseEnv):
await super().wandb_log(wandb_metrics)
def _use_sandbox_backend(self) -> bool:
"""Check if we should route tool execution through a sandbox backend."""
return (
self.config.tool_pool_mode != "default"
and self._sandbox_backend is not None
)
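
Both conditions must hold for the sandbox path to be taken. A standalone restatement of the predicate for clarity (a free function mirroring the method above, not part of the commit):

# Standalone restatement of the routing predicate above (not part of
# the commit): sandbox routing requires a non-default pool mode AND an
# initialized backend.
def use_sandbox_backend(tool_pool_mode: str, sandbox_backend) -> bool:
    return tool_pool_mode != "default" and sandbox_backend is not None

assert use_sandbox_backend("nomad", object()) is True
assert use_sandbox_backend("default", object()) is False  # default mode -> local path
assert use_sandbox_backend("nomad", None) is False        # backend never initialized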
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
@@ -546,12 +553,19 @@ class HermesAgentBaseEnv(BaseEnv):
This is called group_size times in parallel by collect_trajectories().
Each call gets its own task_id for terminal/browser session isolation.
When tool_pool_mode != "default", routes tool execution through the
sandbox backend (Modal, Nomad) with slot-based multiplexing:
1. Acquire a slot from the sandbox pool
2. Setup workspace via subclass hook (e.g., git clone + worktree)
3. Run agent loop with terminal calls routed through sandbox
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
5. Release the slot
"""
task_id = str(uuid.uuid4())
# Get group-level tools (resolved once in collect_trajectories)
if self._current_group_tools is None:
# Fallback: resolve per-trajectory if called outside collect_trajectories
tools, valid_names = self._resolve_tools_for_group()
else:
tools, valid_names = self._current_group_tools
@@ -562,11 +576,183 @@ class HermesAgentBaseEnv(BaseEnv):
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": self.format_prompt(item)})
-# Run the agent loop
-result: AgentResult
# Dispatch to the appropriate path
if self._use_sandbox_backend():
return await self._collect_trajectory_sandbox(
item, task_id, tools, valid_names, messages
)
else:
return await self._collect_trajectory_local(
item, task_id, tools, valid_names, messages
)
async def _collect_trajectory_local(
self,
item: Item,
task_id: str,
tools: List[Dict[str, Any]],
valid_names: Set[str],
messages: List[Dict[str, Any]],
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
"""
Default (local) trajectory collection path.
Uses hermes-agent's handle_function_call() for tool execution.
Reward computed via compute_reward() with ToolContext.
"""
result = await self._run_agent_loop(
task_id, tools, valid_names, messages, tool_handler=None
)
# Skip reward if the agent loop produced no meaningful work
only_system_and_user = all(
msg.get("role") in ("system", "user") for msg in result.messages
)
if result.turns_used == 0 or only_system_and_user:
logger.warning(
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
result.turns_used, len(result.messages),
)
reward = 0.0
else:
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
except Exception as e:
logger.error("compute_reward failed: %s", e)
reward = 0.0
finally:
ctx.cleanup()
return self._build_scored_item(item, result, reward)
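
compute_reward() is a subclass hook. A minimal hedged sketch of an override (only the (item, result, ctx) signature is taken from the call site above; the item schema and grading logic are hypothetical):

# Hedged sketch of a compute_reward() override for the local path.
# Only the signature comes from the call site above; the "answer" item
# key and the grading logic are hypothetical.
class MyTaskEnv(HermesAgentBaseEnv):
    async def compute_reward(self, item, result, ctx) -> float:
        # Grade the final assistant message against the expected answer.
        final = next(
            (m["content"] for m in reversed(result.messages)
             if m.get("role") == "assistant" and m.get("content")),
            "",
        )
        expected = item.get("answer", "")  # hypothetical item schema
        return 1.0 if expected and expected in final else 0.0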
async def _collect_trajectory_sandbox(
self,
item: Item,
task_id: str,
tools: List[Dict[str, Any]],
valid_names: Set[str],
messages: List[Dict[str, Any]],
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
"""
Sandbox trajectory collection path (Modal, Nomad).
Acquires a slot, sets up the workspace, runs the agent loop with
terminal calls routed through the sandbox backend, then verifies
and scores in-sandbox before releasing the slot.
"""
from atropos.slots.executor import ExecutionResult
backend = self._sandbox_backend
slot = None
try:
# 1. Acquire a slot from the sandbox pool
logger.info("Acquiring sandbox slot for task %s", task_id)
slot = await backend.acquire(task_id)
logger.info(
"Acquired slot %s (sandbox %s) for task %s",
slot.slot_id, slot.alloc_id, task_id,
)
# 2. Create exec_tool callable for setup/verify hooks
# Returns ExecutionResult (structured) for env hook consumption
async def exec_tool(tool_name: str, args: Dict[str, Any], timeout: float = 300) -> ExecutionResult:
results = await backend.execute_batch(
[(slot, tool_name, args)], timeout_s=timeout
)
return results[0]
# 3. Setup workspace (subclass hook: git clone, worktree, etc.)
workspace_meta = await self.setup_trajectory_workspace(
item, trajectory_id=task_id, exec_tool=exec_tool
)
# 4. Create tool_handler for agent loop
# Returns JSON string (same format as hermes terminal tool)
async def sandbox_tool_handler(
tool_name: str, args: Dict[str, Any], handler_task_id: str
) -> str:
if tool_name == "terminal":
timeout_val = min(float(args.get("timeout", 300)), 600)
result = await backend.execute_batch(
[(slot, "bash", {"command": args.get("command", "")})],
timeout_s=timeout_val,
)
r = result[0]
output = r.output if r.success else f"{r.output}\n{r.error}" if r.output else r.error
return json.dumps({
"output": output,
"exit_code": r.metadata.get("returncode", 0 if r.success else 1),
})
else:
# Non-terminal tools run locally via hermes-agent
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
lambda: handle_function_call(
tool_name, args, task_id=handler_task_id
),
)
# 5. Run agent loop with sandbox routing
result = await self._run_agent_loop(
task_id, tools, valid_names, messages,
tool_handler=sandbox_tool_handler,
)
# 6. Skip verification if no meaningful work
only_system_and_user = all(
msg.get("role") in ("system", "user") for msg in result.messages
)
if result.turns_used == 0 or only_system_and_user:
logger.warning(
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
result.turns_used, len(result.messages),
)
reward = 0.0
else:
# 7. Verify and score in-sandbox (subclass hook: pytest, etc.)
reward, score_meta = await self.verify_and_score_trajectory(
item, result,
trajectory_id=task_id,
exec_tool=exec_tool,
workspace_meta=workspace_meta,
)
logger.info("Sandbox reward for task %s: %.2f", task_id, reward)
return self._build_scored_item(item, result, reward)
except Exception as e:
logger.error("Sandbox trajectory failed for task %s: %s", task_id, e, exc_info=True)
# Return a zero-reward placeholder so the pipeline doesn't break
dummy_result = AgentResult(
messages=messages, turns_used=0, finished_naturally=False
)
return self._build_scored_item(item, dummy_result, 0.0)
finally:
if slot is not None:
try:
await backend.release(slot, reset_workspace=True)
logger.info("Released slot %s for task %s", slot.slot_id, task_id)
except Exception as e:
logger.error("Failed to release slot %s: %s", slot.slot_id, e)
async def _run_agent_loop(
self,
task_id: str,
tools: List[Dict[str, Any]],
valid_names: Set[str],
messages: List[Dict[str, Any]],
tool_handler=None,
) -> AgentResult:
"""
Run the agent loop in either Phase 1 or Phase 2 mode.
Shared between local and sandbox paths -- the only difference is
the tool_handler parameter (None for local, sandbox callable for sandbox).
"""
if self._use_managed_server():
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
# Load the tool call parser from registry based on config
from environments.tool_call_parsers import get_parser
try:
tc_parser = get_parser(self.config.tool_call_parser)
@@ -590,10 +776,10 @@ class HermesAgentBaseEnv(BaseEnv):
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
tool_handler=tool_handler,
)
-result = await agent.run(messages)
+return await agent.run(messages)
except NotImplementedError:
# DummyManagedServer not allowed -- fall back to Phase 1
logger.warning(
"ManagedServer not available (OpenAI server?). "
"Falling back to direct server mode."
@@ -606,10 +792,10 @@ class HermesAgentBaseEnv(BaseEnv):
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
tool_handler=tool_handler,
)
-result = await agent.run(messages)
+return await agent.run(messages)
else:
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
@@ -618,32 +804,21 @@ class HermesAgentBaseEnv(BaseEnv):
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
tool_handler=tool_handler,
)
-result = await agent.run(messages)
+return await agent.run(messages)
-# Skip reward computation if the agent loop produced no meaningful work
-# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
-# just to verify files that were never created.
-only_system_and_user = all(
-    msg.get("role") in ("system", "user") for msg in result.messages
-)
-if result.turns_used == 0 or only_system_and_user:
-    logger.warning(
-        "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
-        result.turns_used, len(result.messages),
-    )
-    reward = 0.0
-else:
-    # Compute reward using ToolContext (gives verifier full tool access)
-    ctx = ToolContext(task_id)
-    try:
-        reward = await self.compute_reward(item, result, ctx)
-    except Exception as e:
-        logger.error("compute_reward failed: %s", e)
-        reward = 0.0
-    finally:
-        ctx.cleanup()
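
All three agent-loop variants accept the same tool_handler contract: an async callable taking (tool_name, args, task_id) and returning a JSON string in the same shape the hermes terminal tool emits. A minimal hedged dry-run handler for illustration (the error shape for unsupported tools is an assumption):

import json

# Hedged sketch of the tool_handler contract inferred from
# sandbox_tool_handler above: async (tool_name, args, task_id) -> str,
# returning JSON in the terminal-tool shape. The "error" shape for
# unsupported tools is an assumption.
async def dry_run_tool_handler(tool_name: str, args: dict, task_id: str) -> str:
    if tool_name == "terminal":
        return json.dumps({
            "output": f"(dry run) {args.get('command', '')}",
            "exit_code": 0,
        })
    return json.dumps({"error": f"unsupported tool: {tool_name}"})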
def _build_scored_item(
self,
item: Item,
result: AgentResult,
reward: float,
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
"""
Build a ScoredDataItem from an AgentResult and reward.
Shared between local and sandbox paths.
"""
# Track tool errors for wandb logging
if result.tool_errors:
for err in result.tool_errors:
@@ -656,28 +831,19 @@ class HermesAgentBaseEnv(BaseEnv):
})
# Build ScoredDataItem from ManagedServer state
# Phase 2: real tokens/masks/logprobs from SequenceNodes
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
nodes = (result.managed_state or {}).get("nodes", [])
if nodes:
# Phase 2 (or DummyManagedServer): use actual node data
-node = nodes[-1] # Final sequence node = full trajectory
+node = nodes[-1]
scored_item: Dict[str, Any] = {
"tokens": node.tokens,
"masks": node.masked_tokens,
"scores": reward,
}
# Include logprobs if available (Phase 2)
if hasattr(node, "logprobs") and node.logprobs:
scored_item["advantages"] = None # Computed by trainer
scored_item["advantages"] = None
scored_item["ref_logprobs"] = None
else:
# Phase 1 with no managed state: create placeholder tokens
# so the data pipeline doesn't break. These are NOT suitable
# for training but allow process mode (SFT data gen) to work.
# Tokenize the full conversation to get approximate tokens.
full_text = "\n".join(
msg.get("content", "") for msg in result.messages if msg.get("content")
)
@@ -688,13 +854,11 @@ class HermesAgentBaseEnv(BaseEnv):
scored_item = {
"tokens": tokens,
"masks": [-100] + tokens[1:], # Mask first token as prompt
"masks": [-100] + tokens[1:],
"scores": reward,
}
# Always include messages for wandb rollout display and data logging
scored_item["messages"] = result.messages
return scored_item, []
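
Concretely, the dict handed back to the pipeline has one of two shapes. A hedged illustration of the Phase 2 variant (token values are made up):

# Hedged illustration of the Phase 2 ScoredDataItem shape built above
# (token ids are made up; Phase 1 differs only in using placeholder
# tokens and a [-100] + tokens[1:] mask).
scored_item_example = {
    "tokens": [151644, 882, 198],  # exact token ids from the final SequenceNode
    "masks": [-100, 882, 198],     # masked positions (from node.masked_tokens)
    "scores": 1.0,                 # trajectory reward
    "advantages": None,            # filled in downstream by the trainer
    "ref_logprobs": None,
    "messages": [...],             # full conversation, for wandb and data logging
}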
# =========================================================================