Mirror of https://github.com/NousResearch/hermes-agent.git (synced 2026-05-03 02:11:48 +00:00)
Updated hermes_base_env, moved the sandbox logic in from the old agent, and added a patch so sglang on RunPod works with the /generate format (will remove). It worked: the model didn't produce tool calls, but full logprobs worked.
Parent: 4619d1c8ef
Commit: a69924631c
7 changed files with 766 additions and 123 deletions
@@ -64,7 +64,7 @@ from environments.agent_loop import AgentResult, HermesAgentLoop
 from environments.tool_context import ToolContext
 
 # Import hermes-agent toolset infrastructure
-from model_tools import get_tool_definitions
+from model_tools import get_tool_definitions, handle_function_call
 from toolset_distributions import sample_toolsets_from_distribution
 
 logger = logging.getLogger(__name__)
@@ -538,6 +538,13 @@ class HermesAgentBaseEnv(BaseEnv):
 
         await super().wandb_log(wandb_metrics)
 
+    def _use_sandbox_backend(self) -> bool:
+        """Check if we should route tool execution through a sandbox backend."""
+        return (
+            self.config.tool_pool_mode != "default"
+            and self._sandbox_backend is not None
+        )
+
     async def collect_trajectory(
         self, item: Item
     ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
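The sandbox path is gated on two things at once: a non-default tool_pool_mode in the config and an initialized _sandbox_backend. If either is missing, execution falls through to the local path. A minimal sketch of the wiring (the "modal" mode value and the ModalBackend name are illustrative assumptions; only tool_pool_mode and _sandbox_backend appear in this diff):

    # Sketch only -- ModalBackend is a hypothetical stand-in for a real backend.
    env.config.tool_pool_mode = "modal"       # anything other than "default"
    env._sandbox_backend = ModalBackend()     # normally wired up during env setup
    assert env._use_sandbox_backend()         # trajectories now route through the sandbox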
@@ -546,12 +553,19 @@ class HermesAgentBaseEnv(BaseEnv):
         This is called group_size times in parallel by collect_trajectories().
         Each call gets its own task_id for terminal/browser session isolation.
 
+        When tool_pool_mode != "default", routes tool execution through the
+        sandbox backend (Modal, Nomad) with slot-based multiplexing:
+        1. Acquire a slot from the sandbox pool
+        2. Setup workspace via subclass hook (e.g., git clone + worktree)
+        3. Run agent loop with terminal calls routed through sandbox
+        4. Verify and score in-sandbox via subclass hook (e.g., pytest)
+        5. Release the slot
         """
         task_id = str(uuid.uuid4())
 
         # Get group-level tools (resolved once in collect_trajectories)
         if self._current_group_tools is None:
             # Fallback: resolve per-trajectory if called outside collect_trajectories
             tools, valid_names = self._resolve_tools_for_group()
         else:
             tools, valid_names = self._current_group_tools
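The two subclass hooks named in steps 2 and 4 receive exec_tool, an async callable that runs a tool in the acquired slot and returns a structured ExecutionResult. A minimal sketch of what a concrete environment might implement (the item["repo_url"] field and the pytest command are hypothetical; the hook signatures and the exec_tool/ExecutionResult contract match the code added below):

    class MyCodingEnv(HermesAgentBaseEnv):
        async def setup_trajectory_workspace(self, item, trajectory_id, exec_tool):
            # Step 2: prepare an isolated workspace inside the sandbox slot.
            r = await exec_tool(
                "bash",
                {"command": f"git clone {item['repo_url']} /work/{trajectory_id}"},
            )
            return {"workdir": f"/work/{trajectory_id}", "clone_ok": r.success}

        async def verify_and_score_trajectory(
            self, item, result, trajectory_id, exec_tool, workspace_meta
        ):
            # Step 4: verify in the same slot, map the outcome to a scalar reward.
            r = await exec_tool(
                "bash",
                {"command": f"cd {workspace_meta['workdir']} && pytest -q"},
                timeout=600,
            )
            return (1.0 if r.success else 0.0), {"verifier_output": r.output}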
@@ -562,11 +576,183 @@ class HermesAgentBaseEnv(BaseEnv):
         messages.append({"role": "system", "content": self.config.system_prompt})
         messages.append({"role": "user", "content": self.format_prompt(item)})
 
-        # Run the agent loop
-        result: AgentResult
+        # Dispatch to the appropriate path
+        if self._use_sandbox_backend():
+            return await self._collect_trajectory_sandbox(
+                item, task_id, tools, valid_names, messages
+            )
+        else:
+            return await self._collect_trajectory_local(
+                item, task_id, tools, valid_names, messages
+            )
+
+    async def _collect_trajectory_local(
+        self,
+        item: Item,
+        task_id: str,
+        tools: List[Dict[str, Any]],
+        valid_names: Set[str],
+        messages: List[Dict[str, Any]],
+    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
+        """
+        Default (local) trajectory collection path.
+
+        Uses hermes-agent's handle_function_call() for tool execution.
+        Reward computed via compute_reward() with ToolContext.
+        """
+        result = await self._run_agent_loop(
+            task_id, tools, valid_names, messages, tool_handler=None
+        )
+
+        # Skip reward if the agent loop produced no meaningful work
+        only_system_and_user = all(
+            msg.get("role") in ("system", "user") for msg in result.messages
+        )
+        if result.turns_used == 0 or only_system_and_user:
+            logger.warning(
+                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
+                result.turns_used, len(result.messages),
+            )
+            reward = 0.0
+        else:
+            ctx = ToolContext(task_id)
+            try:
+                reward = await self.compute_reward(item, result, ctx)
+            except Exception as e:
+                logger.error("compute_reward failed: %s", e)
+                reward = 0.0
+            finally:
+                ctx.cleanup()
+
+        return self._build_scored_item(item, result, reward)
+
+    async def _collect_trajectory_sandbox(
+        self,
+        item: Item,
+        task_id: str,
+        tools: List[Dict[str, Any]],
+        valid_names: Set[str],
+        messages: List[Dict[str, Any]],
+    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
+        """
+        Sandbox trajectory collection path (Modal, Nomad).
+
+        Acquires a slot, sets up the workspace, runs the agent loop with
+        terminal calls routed through the sandbox backend, then verifies
+        and scores in-sandbox before releasing the slot.
+        """
+        from atropos.slots.executor import ExecutionResult
+
+        backend = self._sandbox_backend
+        slot = None
+        try:
+            # 1. Acquire a slot from the sandbox pool
+            logger.info("Acquiring sandbox slot for task %s", task_id)
+            slot = await backend.acquire(task_id)
+            logger.info(
+                "Acquired slot %s (sandbox %s) for task %s",
+                slot.slot_id, slot.alloc_id, task_id,
+            )
+
+            # 2. Create exec_tool callable for setup/verify hooks
+            #    Returns ExecutionResult (structured) for env hook consumption
+            async def exec_tool(tool_name: str, args: Dict[str, Any], timeout: float = 300) -> ExecutionResult:
+                results = await backend.execute_batch(
+                    [(slot, tool_name, args)], timeout_s=timeout
+                )
+                return results[0]
+
+            # 3. Setup workspace (subclass hook: git clone, worktree, etc.)
+            workspace_meta = await self.setup_trajectory_workspace(
+                item, trajectory_id=task_id, exec_tool=exec_tool
+            )
+
+            # 4. Create tool_handler for agent loop
+            #    Returns JSON string (same format as hermes terminal tool)
+            async def sandbox_tool_handler(
+                tool_name: str, args: Dict[str, Any], handler_task_id: str
+            ) -> str:
+                if tool_name == "terminal":
+                    timeout_val = min(float(args.get("timeout", 300)), 600)
+                    result = await backend.execute_batch(
+                        [(slot, "bash", {"command": args.get("command", "")})],
+                        timeout_s=timeout_val,
+                    )
+                    r = result[0]
+                    output = r.output if r.success else f"{r.output}\n{r.error}" if r.output else r.error
+                    return json.dumps({
+                        "output": output,
+                        "exit_code": r.metadata.get("returncode", 0 if r.success else 1),
+                    })
+                else:
+                    # Non-terminal tools run locally via hermes-agent
+                    loop = asyncio.get_event_loop()
+                    return await loop.run_in_executor(
+                        None,
+                        lambda: handle_function_call(
+                            tool_name, args, task_id=handler_task_id
+                        ),
+                    )
+
+            # 5. Run agent loop with sandbox routing
+            result = await self._run_agent_loop(
+                task_id, tools, valid_names, messages,
+                tool_handler=sandbox_tool_handler,
+            )
+
+            # 6. Skip verification if no meaningful work
+            only_system_and_user = all(
+                msg.get("role") in ("system", "user") for msg in result.messages
+            )
+            if result.turns_used == 0 or only_system_and_user:
+                logger.warning(
+                    "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
+                    result.turns_used, len(result.messages),
+                )
+                reward = 0.0
+            else:
+                # 7. Verify and score in-sandbox (subclass hook: pytest, etc.)
+                reward, score_meta = await self.verify_and_score_trajectory(
+                    item, result,
+                    trajectory_id=task_id,
+                    exec_tool=exec_tool,
+                    workspace_meta=workspace_meta,
+                )
+                logger.info("Sandbox reward for task %s: %.2f", task_id, reward)
+
+            return self._build_scored_item(item, result, reward)
+
+        except Exception as e:
+            logger.error("Sandbox trajectory failed for task %s: %s", task_id, e, exc_info=True)
+            # Return a zero-reward placeholder so the pipeline doesn't break
+            dummy_result = AgentResult(
+                messages=messages, turns_used=0, finished_naturally=False
+            )
+            return self._build_scored_item(item, dummy_result, 0.0)
+
+        finally:
+            if slot is not None:
+                try:
+                    await backend.release(slot, reset_workspace=True)
+                    logger.info("Released slot %s for task %s", slot.slot_id, task_id)
+                except Exception as e:
+                    logger.error("Failed to release slot %s: %s", slot.slot_id, e)
+
+    async def _run_agent_loop(
+        self,
+        task_id: str,
+        tools: List[Dict[str, Any]],
+        valid_names: Set[str],
+        messages: List[Dict[str, Any]],
+        tool_handler=None,
+    ) -> AgentResult:
+        """
+        Run the agent loop in either Phase 1 or Phase 2 mode.
+
+        Shared between local and sandbox paths -- the only difference is
+        the tool_handler parameter (None for local, sandbox callable for sandbox).
+        """
         if self._use_managed_server():
             # Phase 2: ManagedServer with parser -- exact tokens + logprobs
             # Load the tool call parser from registry based on config
             from environments.tool_call_parsers import get_parser
             try:
                 tc_parser = get_parser(self.config.tool_call_parser)
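Both trajectory paths hand _run_agent_loop the same kind of tool_handler: an async callable taking (tool_name, args, task_id) and returning the JSON string the model sees as the tool result. A minimal stand-in that satisfies the contract (purely illustrative; the "output"/"exit_code" shape matches sandbox_tool_handler above, while the "error" payload is an assumption):

    import json

    async def dry_run_tool_handler(tool_name: str, args: dict, task_id: str) -> str:
        # Same signature _run_agent_loop expects; echoes instead of executing.
        if tool_name == "terminal":
            return json.dumps({"output": f"(dry run) {args.get('command', '')}", "exit_code": 0})
        return json.dumps({"error": f"tool {tool_name!r} disabled in dry-run mode"})  # assumed shape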
@@ -590,10 +776,10 @@ class HermesAgentBaseEnv(BaseEnv):
                 task_id=task_id,
                 temperature=self.config.agent_temperature,
                 max_tokens=self.config.max_token_length,
                 tool_handler=tool_handler,
             )
-            result = await agent.run(messages)
+            return await agent.run(messages)
         except NotImplementedError:
             # DummyManagedServer not allowed -- fall back to Phase 1
             logger.warning(
                 "ManagedServer not available (OpenAI server?). "
                 "Falling back to direct server mode."
@@ -606,10 +792,10 @@ class HermesAgentBaseEnv(BaseEnv):
                 task_id=task_id,
                 temperature=self.config.agent_temperature,
                 max_tokens=self.config.max_token_length,
                 tool_handler=tool_handler,
             )
-            result = await agent.run(messages)
+            return await agent.run(messages)
         else:
             # Phase 1: OpenAI server -- native tool_calls, placeholder tokens
             agent = HermesAgentLoop(
                 server=self.server,
                 tool_schemas=tools,
@@ -618,32 +804,21 @@ class HermesAgentBaseEnv(BaseEnv):
                 task_id=task_id,
                 temperature=self.config.agent_temperature,
                 max_tokens=self.config.max_token_length,
                 tool_handler=tool_handler,
             )
-            result = await agent.run(messages)
+            return await agent.run(messages)
-
-        # Skip reward computation if the agent loop produced no meaningful work
-        # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
-        # just to verify files that were never created.
-        only_system_and_user = all(
-            msg.get("role") in ("system", "user") for msg in result.messages
-        )
-        if result.turns_used == 0 or only_system_and_user:
-            logger.warning(
-                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
-                result.turns_used, len(result.messages),
-            )
-            reward = 0.0
-        else:
-            # Compute reward using ToolContext (gives verifier full tool access)
-            ctx = ToolContext(task_id)
-            try:
-                reward = await self.compute_reward(item, result, ctx)
-            except Exception as e:
-                logger.error("compute_reward failed: %s", e)
-                reward = 0.0
-            finally:
-                ctx.cleanup()
+
+    def _build_scored_item(
+        self,
+        item: Item,
+        result: AgentResult,
+        reward: float,
+    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
+        """
+        Build a ScoredDataItem from an AgentResult and reward.
+
+        Shared between local and sandbox paths.
+        """
         # Track tool errors for wandb logging
         if result.tool_errors:
             for err in result.tool_errors:
@@ -656,28 +831,19 @@ class HermesAgentBaseEnv(BaseEnv):
                 })
 
         # Build ScoredDataItem from ManagedServer state
        # Phase 2: real tokens/masks/logprobs from SequenceNodes
         # Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
         nodes = (result.managed_state or {}).get("nodes", [])
 
         if nodes:
             # Phase 2 (or DummyManagedServer): use actual node data
-            node = nodes[-1]  # Final sequence node = full trajectory
+            node = nodes[-1]
             scored_item: Dict[str, Any] = {
                 "tokens": node.tokens,
                 "masks": node.masked_tokens,
                 "scores": reward,
             }
 
             # Include logprobs if available (Phase 2)
             if hasattr(node, "logprobs") and node.logprobs:
-                scored_item["advantages"] = None  # Computed by trainer
+                scored_item["advantages"] = None
                 scored_item["ref_logprobs"] = None
         else:
             # Phase 1 with no managed state: create placeholder tokens
             # so the data pipeline doesn't break. These are NOT suitable
             # for training but allow process mode (SFT data gen) to work.
             # Tokenize the full conversation to get approximate tokens.
             full_text = "\n".join(
                 msg.get("content", "") for msg in result.messages if msg.get("content")
             )
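The tokenization call itself sits in the lines elided between this hunk and the next. Under the assumption that the env exposes an HF-style tokenizer (a self.tokenizer.encode attribute is not shown anywhere in this diff), the placeholder path amounts to:

    # Sketch of the elided step, assuming an HF-style tokenizer on the env.
    tokens = self.tokenizer.encode(full_text)
    masks = [-100] + tokens[1:]   # same length as tokens; first position masked as prompt
    assert len(masks) == len(tokens)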
@@ -688,13 +854,11 @@ class HermesAgentBaseEnv(BaseEnv):
 
             scored_item = {
                 "tokens": tokens,
-                "masks": [-100] + tokens[1:],  # Mask first token as prompt
+                "masks": [-100] + tokens[1:],
                 "scores": reward,
             }
 
         # Always include messages for wandb rollout display and data logging
         scored_item["messages"] = result.messages
 
         return scored_item, []
 
     # =========================================================================
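Whichever branch runs, both paths converge on the same ScoredDataItem shape, which is what lets the trainer, the wandb rollout table, and SFT data dumping stay agnostic about Phase 1 vs Phase 2. Illustratively (field names from the hunks above; values are placeholders):

    scored_item = {
        "tokens": [...],     # real node tokens (Phase 2) or tokenizer placeholders (Phase 1)
        "masks": [...],      # same length as tokens
        "scores": 0.0,       # scalar trajectory reward
        "messages": [...],   # always attached for logging and data generation
        # Phase 2 only: "advantages": None, "ref_logprobs": None (filled downstream)
    }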