hermes-agent/atropos_compatible_agent.py
2026-02-02 14:06:07 +10:00

415 lines
17 KiB
Python

#!/usr/bin/env python3
"""
Atropos-compatible Hermes agent runner.
This is a minimal subclass of Hermes-Agent's `AIAgent` that swaps the OpenAI
function-calling backend for Atroposlib's `ManagedServer`/`ServerManager` backend
and uses Hermes-style XML tool tags:
- <tool_call>{"name": "...", "arguments": {...}}</tool_call>
- <tool_response>{...}</tool_response>
Tool observations are appended as `role="user"` messages containing one or more
`<tool_response>` blocks so they survive common chat templates during tokenization.
"""
from __future__ import annotations
import asyncio
import json
import re
import time
import warnings
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from model_tools import cleanup_vm, handle_function_call
from run_agent import AIAgent
_TOOL_CALL_RE = re.compile(r"<tool_call>\\s*(.*?)\\s*</tool_call>", re.DOTALL)
ATROPOS_TOOL_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools.
## Available Tools
<tools>
{tool_descriptions}
</tools>
## How to Use Tools
To call a tool, output:
<tool_call>{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}</tool_call>
You may include optional reasoning in <think>...</think> before tool calls.
After each tool call, you will receive tool results as:
<tool_response>{{...}}</tool_response>
Continue until finished, then provide a final response with no <tool_call> blocks.
"""
class AtroposAIAgent(AIAgent):
    """
    Hermes `AIAgent` variant that uses Atroposlib ServerManager/ManagedServer.

    Chat completions are requested through the object yielded by `_managed()`
    rather than through an OpenAI client. Tool calls are taken from structured
    `tool_calls` on the response message when present, otherwise from
    Hermes-style `<tool_call>` tags found in the assistant text.

    Notes:
    - The default Hermes `AIAgent` remains unchanged; this class is opt-in.
    - The underlying server must expose `managed_server(tokenizer=...)` OR be a single
      APIServer-compatible object usable by Atroposlib's `ManagedServer`.
    """
def __init__(
    self,
    *,
    server: Any,
    tokenizer: Any = None,
    model: str = "local",
    max_iterations: int = 10,
    tool_delay: float = 0.0,
    enabled_toolsets: Optional[List[str]] = None,
    disabled_toolsets: Optional[List[str]] = None,
    save_trajectories: bool = False,
    verbose_logging: bool = False,
    quiet_mode: bool = False,
    ephemeral_system_prompt: Optional[str] = None,
    log_prefix_chars: int = 100,
    log_prefix: str = "",
    session_id: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
):
    """
    Build an agent that generates through `server` instead of an OpenAI client.

    Args:
        server: Backend object; either exposes `managed_server(tokenizer=...)`
            or is wrapped in Atroposlib's `ManagedServer` on demand.
        tokenizer: Passed through to the managed server when one is opened.
        temperature: Sampling temperature sent with each request when not None.
        max_tokens: Completion token limit sent with each request when not None.

    All remaining keyword arguments are handed to `AIAgent.__init__` unchanged.
    """
    # The parent constructor is run mainly for its tool selection and
    # trajectory-saving setup; the endpoint URL and API key are placeholders.
    parent_options = {
        "base_url": "http://unused",
        "api_key": "dummy-key",
        "model": model,
        "max_iterations": max_iterations,
        "tool_delay": tool_delay,
        "enabled_toolsets": enabled_toolsets,
        "disabled_toolsets": disabled_toolsets,
        "save_trajectories": save_trajectories,
        "verbose_logging": verbose_logging,
        "quiet_mode": quiet_mode,
        "ephemeral_system_prompt": ephemeral_system_prompt,
        "log_prefix_chars": log_prefix_chars,
        "log_prefix": log_prefix,
        "session_id": session_id,
    }
    super().__init__(**parent_options)

    # Generation backend and per-request sampling settings.
    self.server, self.tokenizer = server, tokenizer
    self.temperature, self.max_tokens = temperature, max_tokens
@asynccontextmanager
async def _managed(self) -> AsyncGenerator[Any, None]:
if hasattr(self.server, "managed_server"):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=r"Using OpenAIServer with managed_server does not allow for state tracking",
category=UserWarning,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
yield managed
return
# Fall back to directly wrapping a single server object.
from atroposlib.envs.server_handling.managed_server import ManagedServer
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
try:
yield managed
finally:
managed.reset()
def _tool_descriptions_text(self) -> str:
if not self.tools:
return "(no tools available)"
parts: List[str] = []
for tool in self.tools:
fn = (tool or {}).get("function", {})
name = fn.get("name", "")
desc = (fn.get("description") or "").strip()
if not name:
continue
if desc:
parts.append(f"- {name}: {desc}")
else:
parts.append(f"- {name}")
return "\n".join(parts) if parts else "(no tools available)"
def _build_system_prompt(self, system_message: Optional[str]) -> Optional[str]:
    """
    Assemble the system prompt sent with every request.

    The caller's `system_message` and the agent's `ephemeral_system_prompt` are
    included in that order when non-empty, followed by the tool-usage prompt,
    which is always present. Sections are separated by a blank line.
    """
    tool_section = ATROPOS_TOOL_SYSTEM_PROMPT.format(
        tool_descriptions=self._tool_descriptions_text()
    )
    leading = [text for text in (system_message, self.ephemeral_system_prompt) if text]
    return "\n\n".join([*leading, tool_section])
def _parse_tool_calls(self, content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
    """
    Extract Hermes-style tool calls from assistant text.

    Every `<tool_call>...</tool_call>` block found by `_TOOL_CALL_RE` is decoded
    as JSON and validated. Malformed blocks never raise; they are reported in
    the error list so the caller can ask the model to retry.

    Args:
        content: Assistant text to scan; None or "" yields no calls.

    Returns:
        (calls, errors) where `calls` is a list of `(name, arguments)` tuples in
        order of appearance and `errors` holds one message per rejected block.
    """
    calls: List[Tuple[str, Dict[str, Any]]] = []
    errors: List[str] = []
    for raw in _TOOL_CALL_RE.findall(content or ""):
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError as exc:
            errors.append(f"Invalid JSON inside <tool_call>: {exc}")
            continue
        # Valid JSON is not necessarily an object: a list, string, number or
        # null has no `.get`, so reject it here instead of raising AttributeError.
        if not isinstance(payload, dict):
            errors.append("Tool call payload must be a JSON object")
            continue
        name = payload.get("name")
        args = payload.get("arguments", {})
        if not isinstance(name, str) or not name:
            errors.append("Tool call missing 'name' string")
            continue
        if not isinstance(args, dict):
            errors.append("Tool call 'arguments' must be an object")
            continue
        calls.append((name, args))
    return calls, errors
async def run_conversation_async(
    self,
    user_message: str,
    system_message: Optional[str] = None,
    conversation_history: Optional[List[Dict[str, Any]]] = None,
    task_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run one agent conversation: query the model, execute requested tools, repeat.

    Each iteration sends the transcript (prefixed with the system prompt) to the
    managed server, then handles the reply in one of two modes:

    - Mode A: the reply carries structured `tool_calls`; each is executed and
      answered with a `role="tool"` message.
    - Mode B: the reply text contains `<tool_call>` tags; results are sent back
      as a single `role="user"` message of `<tool_response>` blocks.

    A reply with no tool calls is the final answer. The loop also stops after
    `self.max_iterations` requests.

    Args:
        user_message: New user turn appended to the transcript.
        system_message: Optional extra system text placed before the tool prompt.
        conversation_history: Prior messages; copied (shallowly), not mutated.
        task_id: Identifier passed to the tool handler and to `cleanup_vm`; a
            random UUID is generated when omitted.

    Returns:
        Dict with keys `final_response`, `messages`, `api_calls`, `completed`,
        `managed_state`, `system_prompt` and `task_id`.
    """
    import uuid

    effective_task_id = task_id or str(uuid.uuid4())
    messages: List[Dict[str, Any]] = conversation_history.copy() if conversation_history else []
    messages.append({"role": "user", "content": user_message})
    active_system_prompt = self._build_system_prompt(system_message)

    api_call_count = 0
    final_response: Optional[str] = None
    managed_state: Optional[Dict[str, Any]] = None
    completed = False

    # Debug switches are read once per conversation rather than on every iteration.
    debug_request = os.getenv("HERMES_DEBUG_ATROPOS_REQUEST") == "1"
    debug_response = os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1"

    try:
        async with self._managed() as managed:
            managed_cls = type(managed).__name__

            while api_call_count < self.max_iterations:
                api_call_count += 1

                api_messages = messages.copy()
                if active_system_prompt:
                    api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages

                chat_kwargs: Dict[str, Any] = {"messages": api_messages, "n": 1}
                if self.max_tokens is not None:
                    chat_kwargs["max_tokens"] = self.max_tokens
                if self.temperature is not None:
                    chat_kwargs["temperature"] = self.temperature

                # Prefer OpenAI tool calling when supported by the backend:
                # - Many providers normalize Hermes-style <tool_call> tags into tool_calls when `tools` is provided.
                # - ManagedServer (atroposlib) does prompt->completion conversion and does not support `tools`.
                # Only pass `tools` when we're calling an OpenAI-compatible chat endpoint directly.
                tool_schemas = self.tools if self.tools else None
                if tool_schemas and managed_cls != "ManagedServer":
                    chat_kwargs["tools"] = tool_schemas

                if debug_request:
                    meta = {
                        "managed_type": managed_cls,
                        "model": getattr(getattr(managed, "config", None), "model_name", self.model),
                        "base_url": getattr(getattr(managed, "config", None), "base_url", None),
                        "kwargs": chat_kwargs,
                    }
                    # Avoid dumping megabytes of data accidentally.
                    # (Messages can be large; this is still "full" but bounded.)
                    # `default=str` keeps a non-serialisable value from aborting the run.
                    print("\n=== HERMES_DEBUG_ATROPOS_REQUEST ===", flush=True)
                    print(json.dumps(meta, ensure_ascii=False, indent=2, default=str)[:200_000], flush=True)

                response = await managed.chat_completion(**chat_kwargs)

                if debug_response:
                    try:
                        dumped = response.model_dump()  # openai pydantic model
                    except Exception:
                        dumped = getattr(response, "__dict__", {"repr": repr(response)})
                    print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: ChatCompletion (raw) ===", flush=True)
                    # The `__dict__` fallback may hold arbitrary objects; stringify
                    # them instead of letting a debug print raise TypeError.
                    print(json.dumps(dumped, ensure_ascii=False, indent=2, default=str), flush=True)

                if hasattr(managed, "get_state"):
                    managed_state = managed.get_state()

                msg = response.choices[0].message
                assistant_content = msg.content or ""
                msg_reasoning = getattr(msg, "reasoning", None)
                # Only string reasoning is usable as a parsing fallback.
                reasoning_text = msg_reasoning if isinstance(msg_reasoning, str) else ""
                # Use tool_calls if the backend provides them (preferred).
                structured_tool_calls = getattr(msg, "tool_calls", None)

                # If the backend emits content="" but includes useful text in reasoning,
                # use it for parsing *only if needed* (e.g. tool tags).
                if assistant_content == "" and reasoning_text and debug_response:
                    print(
                        "\n=== HERMES_DEBUG_ATROPOS_RESPONSE: message.reasoning present (content empty) ===",
                        flush=True,
                    )
                    print(reasoning_text, flush=True)

                assistant_msg: Dict[str, Any] = {"role": "assistant", "content": assistant_content}
                if structured_tool_calls:
                    # Preserve tool_calls so the next request is consistent with OpenAI protocol.
                    try:
                        assistant_msg["tool_calls"] = [
                            {
                                "id": tc.id,
                                "type": tc.type,
                                "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                            }
                            for tc in structured_tool_calls
                        ]
                    except Exception:
                        # Best-effort; keep conversation moving.
                        pass
                messages.append(assistant_msg)

                # Mode A: OpenAI tool calling (preferred when supported)
                if structured_tool_calls:
                    for tc in structured_tool_calls:
                        try:
                            tool_args = json.loads(tc.function.arguments or "{}")
                        except (TypeError, ValueError):
                            tool_args = {}
                        # Valid JSON that is not an object (null, list, ...) gets the
                        # same empty-arguments fallback as undecodable JSON.
                        if not isinstance(tool_args, dict):
                            tool_args = {}
                        tool_result = handle_function_call(tc.function.name, tool_args, effective_task_id)
                        # Keep the raw tool result as tool content (OpenAI protocol expects role=tool).
                        messages.append(
                            {
                                "role": "tool",
                                "tool_call_id": tc.id,
                                "content": tool_result,
                            }
                        )
                        if self.tool_delay and self.tool_delay > 0:
                            await asyncio.sleep(self.tool_delay)
                    # Continue loop after tool execution.
                    continue

                # Mode B: Hermes XML tool tags in assistant text (fallback).
                parse_source = assistant_content or reasoning_text
                tool_calls, parse_errors = self._parse_tool_calls(parse_source)

                if parse_errors and not tool_calls:
                    # Ask the model to retry with valid tool JSON.
                    err_text = "; ".join(parse_errors[:3])
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                f"<tool_response>{json.dumps({'error': err_text}, ensure_ascii=False)}</tool_response>\n"
                                "The previous <tool_call> blocks were invalid. Please output valid JSON inside <tool_call>."
                            ),
                        }
                    )
                    continue

                if not tool_calls:
                    # No tool calls: treat as final answer.
                    final_response = (assistant_content or "").strip()
                    completed = True
                    break

                tool_responses: List[str] = []
                for tool_name, tool_args in tool_calls:
                    tool_start = time.time()
                    tool_result = handle_function_call(tool_name, tool_args, effective_task_id)
                    tool_duration = time.time() - tool_start
                    # Embed the result as structured JSON when it parses, else verbatim.
                    try:
                        payload: Any = json.loads(tool_result)
                    except Exception:
                        payload = tool_result
                    tool_payload = {
                        "name": tool_name,
                        "duration_s": round(tool_duration, 3),
                        "result": payload,
                    }
                    tool_responses.append(
                        f"<tool_response>{json.dumps(tool_payload, ensure_ascii=False)}</tool_response>"
                    )
                    if self.tool_delay and self.tool_delay > 0:
                        await asyncio.sleep(self.tool_delay)
                messages.append({"role": "user", "content": "\n".join(tool_responses)})

            if final_response is None:
                final_response = "I've reached the maximum number of iterations."
    finally:
        # Cleanup is best-effort and must not mask the outcome of the conversation.
        try:
            cleanup_vm(effective_task_id)
        except Exception:
            pass

    # Save trajectory using Hermes formatting (optional).
    self._save_trajectory(messages, user_message, completed=completed)

    return {
        "final_response": final_response,
        "messages": messages,
        "api_calls": api_call_count,
        "completed": completed,
        "managed_state": managed_state,
        "system_prompt": active_system_prompt,
        "task_id": effective_task_id,
    }
def run_conversation(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """
    Sync wrapper for convenience.

    Arguments are forwarded unchanged to `run_conversation_async`.

    If called from within a running event loop (e.g. prompt_toolkit), this
    runs the async conversation in a dedicated daemon thread with its own loop
    to avoid nested loops, blocking the caller until it finishes.

    Raises:
        Whatever `run_conversation_async` raises; in the threaded case the
        exception is carried back and re-raised in the calling thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if not loop_is_running:
        # Run outside the `except` block above: an error raised from inside the
        # handler would be implicitly chained to the irrelevant
        # "no running event loop" RuntimeError in tracebacks.
        return asyncio.run(self.run_conversation_async(*args, **kwargs))

    import queue
    import threading

    # Single-slot mailbox carrying either the result dict or the raised exception.
    out: "queue.Queue[object]" = queue.Queue(maxsize=1)

    def runner() -> None:
        try:
            out.put(asyncio.run(self.run_conversation_async(*args, **kwargs)))
        except BaseException as exc:  # noqa: BLE001
            out.put(exc)

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    result = out.get()
    if isinstance(result, BaseException):
        raise result
    return result  # type: ignore[return-value]