mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 02:11:48 +00:00
491 lines
19 KiB
Python
491 lines
19 KiB
Python
"""
|
|
ReACT-style agent implementation for atropos-agent.
|
|
|
|
This module provides the core AtroposAgent class that implements a basic
|
|
Reason-Act-Observe loop with tool calling capabilities.
|
|
|
|
Uses ManagedServer from atroposlib for automatic token/logprob tracking,
|
|
making trajectories ready for RL training.
|
|
|
|
The agent uses Hermes-style XML tags for tool calls:
|
|
- <think>...</think> for reasoning
|
|
- <tool_call>{"name": "...", "arguments": {...}}</tool_call> for actions
|
|
- <tool_response>...</tool_response> for observations
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import json
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from ..tools import ToolCall, ToolRegistry, ToolResult
|
|
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
|
|
|
load_dotenv()
|
|
|
|
|
|
# Default system prompt with tool calling instructions.
#
# IMPORTANT: In training-mode environments we want "raw text in -> raw text out" and we
# parse tool calls from completion text. Do not rely on server-specific `tool_calls` fields.
#
# NOTE: this template contains many literal `{}` braces (JSON examples), so it is filled
# in via `str.replace("{tools_json}", ...)` rather than `str.format()` — see
# `AtroposAgent._build_system_prompt`. Do not add other `{placeholders}` expecting
# `format()` semantics.
AGENT_SYSTEM_PROMPT = """You are a function-calling AI model.

You are provided with function signatures within <tools></tools> XML tags.
You may call one or more functions to assist with the user query. If available tools are not relevant,
respond in natural language.

After calling & executing a function, you will be provided with function results within
<tool_response></tool_response> XML tags.

Here are the available tools:
<tools>
{tools_json}
</tools>

## REQUIRED TOOL FORMAT

When you decide to call a tool, your assistant message MUST be:
1) exactly one <think>...</think> block, followed by
2) one or more <tool_call>...</tool_call> blocks,
and NOTHING else in that message.

For each tool call, output a JSON object with this schema:
{"name": "function_name", "arguments": { ... }}

Each tool call MUST be enclosed within <tool_call></tool_call> XML tags.
The JSON inside <tool_call> MUST be valid JSON with double quotes.

Do NOT output <tool_response> in an assistant message.

After you receive tool results, you may either call more tools (same required format) or provide the final answer.
When providing the final answer, do NOT include any <tool_call> blocks.

## ICL (examples)

User: Show the current directory.
Assistant:
<think>I should use the terminal tool to print the current directory.</think>
<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
Assistant: /tmp

User: List files, then count them.
Assistant:
<think>I should list files and count lines.</think>
<tool_call>{"name": "terminal", "arguments": {"command": "ls -1 | wc -l"}}</tool_call>
User: <tool_response>{"success": true, "output": "3\\n"}</tool_response>
Assistant: 3

User: Run pwd, then print ok.
Assistant:
<think>I should run pwd, then run a command that prints ok.</think>
<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>
<tool_call>{"name": "terminal", "arguments": {"command": "echo ok"}}</tool_call>
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
User: <tool_response>{"success": true, "output": "ok\\n"}</tool_response>
Assistant: ok
"""
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """Configuration for the AtroposAgent.

    Every field has a default, so ``AgentConfig()`` is a usable configuration.
    """

    # Generation parameters
    # Sampling temperature; None means "omit from the request" so the backend default applies.
    temperature: Optional[float] = 0.7
    # Default to "let the backend decide" (important for tool-tag completions that may be longer).
    max_tokens: Optional[int] = None

    # Agent behavior
    # Hard cap on Reason-Act-Observe iterations; hitting it yields a failed AgentResult.
    max_steps: int = 50
    # When set, completely replaces the default AGENT_SYSTEM_PROMPT (no tool JSON is injected).
    system_prompt: Optional[str] = None
    # Optional pause (seconds) after each tool execution, e.g. for rate limiting.
    tool_delay_s: float = 0.0

    # Working directory for tools
    # NOTE(review): not read anywhere in this module — presumably consumed by tool
    # construction elsewhere; confirm before relying on it.
    working_dir: Optional[str] = None
|
|
|
|
|
|
@dataclass
class SequenceData:
    """Token/logprob data captured from a single completion."""

    full_text: str
    tokens: List[int]
    masked_tokens: List[int]  # -100 for prompt positions, actual token IDs for the completion
    logprobs: List[float]  # 1.0 for prompt positions, actual values for the completion

    @classmethod
    def from_sequence_node(cls, node) -> "SequenceData":
        """Snapshot the relevant fields of a ManagedServer SequenceNode."""
        names = ("full_text", "tokens", "masked_tokens", "logprobs")
        return cls(**{name: getattr(node, name) for name in names})
|
|
|
|
|
|
@dataclass
class AgentStep:
    """A single step in the agent's trajectory.

    One step corresponds to one assistant completion plus the execution of any
    tool calls parsed from it.
    """

    # 1-based index of this step within the trajectory.
    step_number: int
    # Raw assistant completion text (may contain <think>/<tool_call> tags).
    assistant_message: str
    # Tool calls parsed out of `assistant_message` (empty for a final answer).
    tool_calls: List[ToolCall] = field(default_factory=list)
    # Results appended as each parsed call is executed, in call order.
    tool_results: List[ToolResult] = field(default_factory=list)
    sequence_data: Optional[SequenceData] = None  # Token data from this step

    @property
    def has_tool_calls(self) -> bool:
        """True when at least one tool call was parsed from this step's message."""
        return len(self.tool_calls) > 0
|
|
|
|
|
|
@dataclass
class AgentResult:
    """Result of running an agent trajectory."""

    success: bool
    final_response: str
    steps: List[AgentStep] = field(default_factory=list)
    total_tokens: int = 0
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Full trajectory token data for RL training
    trajectory_data: Optional[SequenceData] = None

    @property
    def num_steps(self) -> int:
        """Number of Reason-Act-Observe iterations recorded."""
        return len(self.steps)

    @property
    def total_tool_calls(self) -> int:
        """Total tool invocations across all steps."""
        count = 0
        for step in self.steps:
            count += len(step.tool_calls)
        return count

    def to_messages(self) -> List[Dict[str, str]]:
        """Convert trajectory to messages format for logging."""
        rendered: List[Dict[str, str]] = []
        for step in self.steps:
            rendered.append({"role": "assistant", "content": step.assistant_message})
            if step.tool_results:
                # Combine all tool responses into one user-turn observation.
                combined = "\n".join(result.to_xml() for result in step.tool_results)
                rendered.append({"role": "user", "content": combined})
        return rendered

    def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]:
        """
        Convert to format suitable for ScoredDataGroup.

        Args:
            score: The score for this trajectory

        Returns:
            Dict with tokens, masks, scores suitable for training, or None if no data
        """
        data = self.trajectory_data
        if data is None:
            return None
        return {
            "tokens": data.tokens,
            "masks": data.masked_tokens,
            "scores": score,
            "logprobs": data.logprobs,
        }
|
|
|
|
|
|
class AtroposAgent:
    """
    A ReACT-style agent that uses LLMs with tool calling.

    This implementation wraps ManagedServer for automatic token/logprob tracking,
    making trajectories ready for RL training.

    Example:
        # `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`.
        # In practice, environments usually construct this via `BaseEnv`.
        server = ...
        tools = ToolRegistry()
        tools.register(BashTool())

        agent = AtroposAgent(server=server, tools=tools)
        result = await agent.run("List the files in the current directory")

        # Access token data for training
        if result.trajectory_data:
            print(f"Tokens: {result.trajectory_data.tokens}")
            print(f"Masked: {result.trajectory_data.masked_tokens}")
    """

    def __init__(
        self,
        server,  # ServerManager or APIServer
        tools: Optional[ToolRegistry] = None,
        config: Optional[AgentConfig] = None,
        tokenizer: Optional[Any] = None,
        execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None,
    ):
        """Create an agent.

        Args:
            server: ServerManager or a single APIServer used for chat completions.
            tools: Tool registry; an empty registry is created when omitted.
            config: Generation/behavior settings; defaults to ``AgentConfig()``.
            tokenizer: Tokenizer for the fallback trajectory reconstruction in
                ``run``; falls back to ``server.tokenizer`` when present.
            execute_tool: Override for tool execution; defaults to ``tools.execute``.
        """
        self.server = server
        self.tools = tools or ToolRegistry()
        self.config = config or AgentConfig()
        # Prefer an explicit tokenizer, then whatever the server exposes (may be None).
        self.tokenizer = tokenizer or getattr(server, "tokenizer", None)
        self.execute_tool = execute_tool or self.tools.execute

    @asynccontextmanager
    async def _managed(self) -> AsyncGenerator[Any, None]:
        """
        Yield a ManagedServer-like object.

        - If `self.server` is a ServerManager, use its `managed_server()` context manager.
        - If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.

        In the direct-wrap case, `managed.reset()` is always called on exit so the
        token/logprob state does not leak across runs.
        """
        if hasattr(self.server, "managed_server"):
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                yield managed
        else:
            managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
            try:
                yield managed
            finally:
                managed.reset()

    def _build_system_prompt(self) -> str:
        """Build the system prompt with tool descriptions.

        A user-supplied `config.system_prompt` wins outright; otherwise the
        default template is filled with the registry's tool definitions.
        """
        if self.config.system_prompt:
            return self.config.system_prompt

        tools_json = self.tools.get_prompt_tool_definitions_json()
        # Avoid `str.format()` here because the prompt contains many literal `{}` braces
        # in JSON examples; we only want to substitute the single `{tools_json}` token.
        return AGENT_SYSTEM_PROMPT.replace("{tools_json}", tools_json)

    def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
        """Print request metadata (and optionally the full payload) when debug env vars are set.

        Controlled by ATROPOS_DEBUG_AGENT_REQUEST ("1" to enable) and
        ATROPOS_DEBUG_AGENT_REQUEST_FULL ("1" to also dump the payload, truncated).
        Best-effort: any failure here is swallowed so debugging never breaks a run.
        """
        if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
            return
        try:
            # Avoid dumping megabytes by default; messages can be huge.
            meta = {
                "step": step_num,
                "chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
                "n": chat_kwargs.get("n"),
                "max_tokens": chat_kwargs.get("max_tokens"),
                "temperature": chat_kwargs.get("temperature"),
                "num_messages": len(chat_kwargs.get("messages") or []),
            }
            print("\n=== ATROPOS_DEBUG_AGENT_REQUEST ===", flush=True)
            print(meta, flush=True)

            if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_FULL") == "1":
                payload = dict(chat_kwargs)
                # Make the payload more legible and less huge.
                try:
                    dumped = json.dumps(payload, ensure_ascii=False, indent=2)
                except Exception:
                    dumped = repr(payload)
                print("\n=== ATROPOS_DEBUG_AGENT_REQUEST_FULL ===", flush=True)
                # Truncate to keep logs bounded.
                print(dumped[:200_000], flush=True)
        except Exception:
            return

    def _debug_dump_response(self, *, step_num: int, response: Any) -> None:
        """Print a bounded dump of the completion response when ATROPOS_DEBUG_AGENT_RESPONSE=1."""
        if os.getenv("ATROPOS_DEBUG_AGENT_RESPONSE") != "1":
            return
        print("\n=== ATROPOS_DEBUG_AGENT_RESPONSE ===", flush=True)
        print({"step": step_num, "type": type(response).__name__}, flush=True)
        try:
            dumped = response.model_dump()  # openai pydantic model
        except Exception:
            dumped = getattr(response, "__dict__", {"repr": repr(response)})
        # Keep the dump bounded; we only need enough to see the assistant message content.
        text = str(dumped)
        print(text[:200_000], flush=True)

    async def run(
        self,
        task: str,
        initial_messages: Optional[List[Dict[str, str]]] = None,
    ) -> AgentResult:
        """
        Run the agent on a task using ManagedServer for token tracking.

        The loop per step: generate a completion, parse Hermes-style tool calls
        from the raw text, execute them, and feed the results back as a user
        message. The loop ends on a completion with no tool calls (final
        answer), on a generation error, or when `config.max_steps` is hit.

        Args:
            task: The task/prompt for the agent
            initial_messages: Optional additional context messages

        Returns:
            AgentResult with the trajectory, final response, and token data
        """
        messages = [
            {"role": "system", "content": self._build_system_prompt()},
        ]

        if initial_messages:
            messages.extend(initial_messages)

        messages.append({"role": "user", "content": task})

        steps = []
        final_response = ""
        final_node = None
        final_prompt_messages: Optional[List[Dict[str, str]]] = None

        # Use ManagedServer for automatic token tracking
        async with self._managed() as managed:
            for step_num in range(self.config.max_steps):
                try:
                    # Keep a copy of the prompt messages used for this completion.
                    # Useful for reconstructing tokens/masks when state tracking is unavailable.
                    prompt_messages = list(messages)
                    chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
                    # Only pass sampling params that are explicitly configured;
                    # None means "let the backend decide".
                    if self.config.max_tokens is not None:
                        chat_kwargs["max_tokens"] = self.config.max_tokens
                    if self.config.temperature is not None:
                        chat_kwargs["temperature"] = self.config.temperature

                    self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
                    response = await managed.chat_completion(**chat_kwargs)
                    self._debug_dump_response(step_num=step_num + 1, response=response)

                    # Snapshot the most recent sequence node for this completion.
                    # NOTE(review): assumes the last entry of get_state()["nodes"]
                    # corresponds to the completion just issued — confirm against
                    # ManagedServer's state semantics.
                    current_node = None
                    if hasattr(managed, "get_state"):
                        state = managed.get_state()
                        nodes = state.get("nodes", [])
                        current_node = nodes[-1] if nodes else None

                except Exception as e:
                    # Any generation failure aborts the run; partial steps are returned.
                    return AgentResult(
                        success=False,
                        final_response="",
                        steps=steps,
                        error=f"Generation error: {str(e)}",
                    )

                msg = response.choices[0].message
                # Some OpenAI-compatible servers populate `message.reasoning` and leave `content=""`.
                response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
                tool_calls = ToolCall.parse_from_text(response_text)

                step = AgentStep(
                    step_number=step_num + 1,
                    assistant_message=response_text,
                    tool_calls=tool_calls,
                    sequence_data=SequenceData.from_sequence_node(current_node) if current_node else None,
                )

                # No tool calls means the model produced its final answer.
                if not tool_calls:
                    steps.append(step)
                    final_response = response_text
                    final_node = current_node
                    final_prompt_messages = prompt_messages
                    break

                messages.append({"role": "assistant", "content": response_text})

                # Execute each parsed tool call sequentially, in order.
                tool_responses = []
                for call in tool_calls:
                    result = await self.execute_tool(call)
                    step.tool_results.append(result)
                    tool_responses.append(result.to_xml())
                    if self.config.tool_delay_s > 0:
                        await asyncio.sleep(self.config.tool_delay_s)

                steps.append(step)

                responses_text = "\n".join(tool_responses)
                # Tool observations are represented as user content with Hermes-style tags.
                # This is compatible with most OpenAI-compatible chat APIs and ensures
                # tokenizers/chat templates include tool outputs during training.
                messages.append({"role": "user", "content": responses_text})

            else:
                # Reached max steps without completing
                return AgentResult(
                    success=False,
                    final_response=final_response,
                    steps=steps,
                    error=f"Reached maximum steps ({self.config.max_steps})",
                )

        # Build result with trajectory data
        trajectory_data = None
        if final_node:
            # Preferred path: ManagedServer tracked tokens/logprobs for us.
            trajectory_data = SequenceData.from_sequence_node(final_node)
        elif final_prompt_messages is not None and self.tokenizer is not None:
            # Fallback: re-tokenize prompt + final response ourselves. Logprobs are
            # placeholders here (1.0 for prompt, 0.0 for output) since the real
            # values were not captured.
            if hasattr(self.tokenizer, "apply_chat_template"):
                prompt_text = self.tokenizer.apply_chat_template(
                    final_prompt_messages, tokenize=False, add_generation_prompt=True
                )
                prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
            else:
                # Crude plain-text rendering when no chat template is available.
                prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
                prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
            output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
            tokens = prompt_tokens + output_tokens
            # Mask prompt positions with -100 (standard "ignore" label value).
            masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
            logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
            trajectory_data = SequenceData(
                full_text=f"{prompt_text}{final_response}",
                tokens=tokens,
                masked_tokens=masked_tokens,
                logprobs=logprobs,
            )

        return AgentResult(
            success=True,
            final_response=final_response,
            steps=steps,
            trajectory_data=trajectory_data,
        )

    async def run_single_turn(
        self,
        messages: List[Dict[str, str]],
        execute_tools: bool = True,
    ) -> tuple[str, List[ToolResult], Optional[SequenceData]]:
        """
        Run a single turn of the agent (one LLM call + tool execution).

        This is useful for integration with BaseEnv where you want more
        control over the loop.

        Args:
            messages: The conversation history
            execute_tools: Whether to execute parsed tool calls

        Returns:
            Tuple of (response_text, tool_results, sequence_data)
        """
        async with self._managed() as managed:
            chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
            # Same "omit when None" convention as in run().
            if self.config.max_tokens is not None:
                chat_kwargs["max_tokens"] = self.config.max_tokens
            if self.config.temperature is not None:
                chat_kwargs["temperature"] = self.config.temperature

            self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
            response = await managed.chat_completion(**chat_kwargs)
            self._debug_dump_response(step_num=1, response=response)

            # Snapshot token data for this single completion (see note in run()).
            current_node = None
            if hasattr(managed, "get_state"):
                state = managed.get_state()
                nodes = state.get("nodes", [])
                current_node = nodes[-1] if nodes else None

            msg = response.choices[0].message
            # Same `reasoning` fallback as in run() for servers that leave content empty.
            response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
            tool_results = []

            if execute_tools:
                tool_calls = ToolCall.parse_from_text(response_text)
                for call in tool_calls:
                    result = await self.execute_tool(call)
                    tool_results.append(result)

            sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None

            return response_text, tool_results, sequence_data