mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 02:11:48 +00:00
moved in main atropos agent files to Hermes-Agent, updated paths, gated on optional package install
This commit is contained in:
parent
db348dc467
commit
8dccd6569e
42 changed files with 10978 additions and 4 deletions
501
atropos/tools/tool_executor.py
Normal file
501
atropos/tools/tool_executor.py
Normal file
|
|
@ -0,0 +1,501 @@
|
|||
"""
|
||||
ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories.
|
||||
|
||||
This component is responsible for:
|
||||
- Maintaining trajectory -> Slot affinity (workspace continuity)
|
||||
- Batching sandbox tool calls across trajectories to maximize container utilization
|
||||
- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5)
|
||||
|
||||
For now, only sandbox tools are executed:
|
||||
- bash
|
||||
- read_file
|
||||
- write_file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolCall,
|
||||
ToolCallPayload,
|
||||
ToolRegistry,
|
||||
ToolResult,
|
||||
ToolResultPayload,
|
||||
ToolServerExecuteRequest,
|
||||
)
|
||||
from ..slots import Slot, SlotPool
|
||||
|
||||
|
||||
@dataclass
class ToolExecutorConfig:
    """Tunables for :class:`ToolExecutor` batching and sandbox/tool-server policy."""

    # How long (ms) the dispatch loop waits after the first queued request
    # before flushing a batch, to let more requests accumulate.
    batch_window_ms: int = 20
    # Hard cap on requests per dispatched batch.
    max_batch_size: int = 200
    # Forced onto sandbox `bash`/`bash_stateful`/`tmux` args; the model's own
    # `allow_network` argument is always stripped and replaced by this.
    allow_network: bool = True
    # Forced onto sandbox `bash` args (model-supplied value is stripped).
    require_sandbox: bool = False
    # Forced onto `bash_stateful`/`tmux` args (model-supplied value is stripped).
    require_stateful_sandbox: bool = False
    # Base URL of the external ToolServer; when None, external tools fail
    # with a "not configured" ToolResult instead of being dispatched.
    tool_server_url: Optional[str] = None
    # Optional bearer token sent as an Authorization header to the ToolServer.
    tool_server_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class _QueuedToolRequest:
    """Internal queue entry pairing a tool call with the future that resolves it."""

    # Trajectory issuing the call; used for slot affinity.
    trajectory_id: str
    # The tool invocation to execute.
    call: ToolCall
    # Effective timeout (already clamped by ToolExecutor.execute), or None
    # to fall back to the pool executor's default.
    timeout_s: Optional[float]
    # Completed by the dispatch loop with the ToolResult (never an exception).
    future: asyncio.Future
|
||||
|
||||
|
||||
class ToolExecutor:
    """Queued, batched dispatcher for tool calls from multiplexed agent trajectories.

    Calls submitted via :meth:`execute` are funneled through one asyncio queue
    and drained by a background loop (:meth:`_run_loop`) that groups them into
    batches (bounded by ``batch_window_ms`` / ``max_batch_size``) before
    dispatching:

    - sandbox tools run on pool slots via ``self.pool.execute_batch``; each
      trajectory keeps affinity to a single ``Slot`` for workspace continuity
    - tools whose schema has ``external=True`` are forwarded over HTTP to a
      ToolServer (when ``tool_server_url`` is configured)
    - unknown tool names resolve to an error ``ToolResult``
    """

    def __init__(
        self,
        pool: SlotPool,
        tools: ToolRegistry,
        config: Optional[ToolExecutorConfig] = None,
    ) -> None:
        """Create an executor over *pool*, resolving call names through *tools*.

        The executor is inert until :meth:`start` is called; :meth:`execute`
        raises before that.
        """
        self.pool = pool
        self.tools = tools
        self.config = config or ToolExecutorConfig()

        # Work queue; a None item is the shutdown sentinel pushed by close().
        self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue()
        self._task: Optional[asyncio.Task] = None
        self._stopping = asyncio.Event()

        # Guards _slot_by_trajectory (trajectory -> Slot affinity map).
        self._slots_lock = asyncio.Lock()
        self._slot_by_trajectory: Dict[str, Slot] = {}

        # Lazily-created HTTP client for the external ToolServer.
        self._tool_server_client: Optional[httpx.AsyncClient] = None
        self._tool_server_lock = asyncio.Lock()

        # lightweight stats for status endpoints
        self.total_requests: int = 0
        self.total_errors: int = 0
        # NOTE(review): grows without bound over the executor's lifetime;
        # consider capping (e.g. keep only the most recent N samples).
        self.latencies_s: List[float] = []

    async def start(self) -> None:
        """Start the background dispatch loop. Idempotent."""
        if self._task is None:
            self._task = asyncio.create_task(self._run_loop())

    def queue_size(self) -> int:
        """Return the number of requests currently waiting in the queue."""
        return self._queue.qsize()

    async def close(self) -> None:
        """Shut down: drain the loop, close the HTTP client, release held slots.

        The None sentinel wakes the loop; it exits once the queue is empty and
        no batch is pending. Slot release is best-effort (errors swallowed).
        """
        self._stopping.set()
        await self._queue.put(None)
        if self._task:
            await self._task
            self._task = None

        client = self._tool_server_client
        self._tool_server_client = None
        if client is not None:
            await client.aclose()

        # Best-effort release any remaining slots.
        async with self._slots_lock:
            slots = list(self._slot_by_trajectory.items())
            self._slot_by_trajectory.clear()

        for _, slot in slots:
            try:
                await self.pool.release(slot, reset_workspace=False)
            except Exception:
                pass

    async def execute(
        self,
        trajectory_id: str,
        call: ToolCall,
        timeout_s: Optional[float] = None,
    ) -> ToolResult:
        """Enqueue *call* for *trajectory_id* and wait for its result.

        If *timeout_s* is not given, a numeric ``timeout`` tool argument is
        used instead; any effective timeout is clamped to [1.0, 600.0] seconds.
        End-to-end latency (queueing included) is appended to ``latencies_s``
        even when the call fails.

        Raises:
            RuntimeError: if :meth:`start` has not been called.
        """
        if self._task is None:
            raise RuntimeError("ToolExecutor not started (call start() first)")

        # Allow tool args to suggest a timeout (Hermes-compatible terminal tool),
        # but never let the model choose "infinite" timeouts.
        if timeout_s is None:
            raw_timeout = call.arguments.get("timeout")
            if isinstance(raw_timeout, (int, float)):
                timeout_s = float(raw_timeout)
        if timeout_s is not None:
            timeout_s = max(1.0, min(float(timeout_s), 600.0))

        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        started = time.perf_counter()
        await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut))
        try:
            result: ToolResult = await fut
            return result
        finally:
            self.latencies_s.append(time.perf_counter() - started)

    async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None:
        """Drop the trajectory's slot affinity and return the slot to the pool.

        No-op if the trajectory holds no slot. The pool release happens outside
        the lock so other trajectories are not blocked on it.
        """
        async with self._slots_lock:
            slot = self._slot_by_trajectory.pop(trajectory_id, None)

        if slot is not None:
            await self.pool.release(slot, reset_workspace=reset_workspace)

    async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
        """Return the trajectory's currently-bound slot, or None (no acquisition)."""
        async with self._slots_lock:
            return self._slot_by_trajectory.get(trajectory_id)

    # ---------------------------------------------------------------------
    # Artifact helpers (optional)
    # ---------------------------------------------------------------------

    async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload:
        """Read a workspace file from the trajectory's slot.

        Fails (success=False) when the trajectory has no bound slot yet.
        The internal ``http_status`` key is stripped from the raw executor
        response before it is parsed into the payload model.
        """
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
        data = await self.pool.executor.read_artifact(
            slot,
            req.path,
            encoding=req.encoding,
            max_bytes=req.max_bytes,
            include_sha256=req.include_sha256,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactReadResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}")

    async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload:
        """List workspace files in the trajectory's slot.

        Same slot-presence and ``http_status``-stripping behavior as
        :meth:`read_artifact`.
        """
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
        data = await self.pool.executor.list_artifacts(
            slot,
            req.path,
            recursive=req.recursive,
            max_entries=req.max_entries,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactListResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}")

    async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload:
        """Archive a workspace path from the trajectory's slot.

        Same slot-presence and ``http_status``-stripping behavior as
        :meth:`read_artifact`.
        """
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is None:
            return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
        data = await self.pool.executor.archive_artifacts(
            slot,
            req.path,
            archive_format=req.format,
            max_bytes=req.max_bytes,
            max_entries=req.max_entries,
        )
        if isinstance(data, dict):
            data = dict(data)
            data.pop("http_status", None)
        try:
            return ArtifactArchiveResponsePayload(**(data or {}))
        except Exception as e:
            return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}")

    async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot:
        """Return the trajectory's slot, acquiring one from the pool if needed.

        ``pool.acquire`` is awaited outside the lock (it may block on pool
        capacity), so two coroutines can race to acquire for the same
        trajectory; the second check detects the loser, which releases its
        freshly-acquired slot and adopts the winner's.
        """
        async with self._slots_lock:
            existing = self._slot_by_trajectory.get(trajectory_id)
            if existing is not None:
                return existing

        slot = await self.pool.acquire(trajectory_id)

        async with self._slots_lock:
            existing = self._slot_by_trajectory.get(trajectory_id)
            if existing is not None:
                # Another coroutine won the race; return its slot.
                await self.pool.release(slot, reset_workspace=False)
                return existing
            self._slot_by_trajectory[trajectory_id] = slot
            return slot

    async def _run_loop(self) -> None:
        """Background loop: collect requests into batches and dispatch them.

        A batch flushes when either the window (``batch_window_ms``, timed
        from the FIRST request in the batch) elapses or ``max_batch_size`` is
        reached. The None sentinel from close() only wakes the loop; exit
        happens once stopping is set, the queue is empty, and nothing is
        pending.
        """
        pending: List[_QueuedToolRequest] = []
        # Monotonic flush time for the current batch (perf_counter domain).
        deadline: Optional[float] = None

        batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0)
        max_batch = max(1, self.config.max_batch_size)

        while True:
            if self._stopping.is_set() and self._queue.empty() and not pending:
                break

            # With no pending batch, wait indefinitely for the next request;
            # otherwise wait only until the batch window expires.
            timeout = None
            if pending and deadline is not None:
                timeout = max(0.0, deadline - time.perf_counter())

            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=timeout)
                if item is None:
                    continue
                pending.append(item)
                if len(pending) == 1:
                    deadline = time.perf_counter() + batch_window_s
                if len(pending) < max_batch:
                    continue
            except asyncio.TimeoutError:
                # batch window elapsed
                pass

            if not pending:
                deadline = None
                continue

            batch = pending
            pending = []
            deadline = None

            await self._execute_batch(batch)

    async def _get_tool_server_client(self) -> httpx.AsyncClient:
        """Return (creating once, under a lock) the ToolServer HTTP client.

        Raises:
            RuntimeError: if ``tool_server_url`` is not configured.
        """
        url = self.config.tool_server_url
        if not url:
            raise RuntimeError("ToolServer not configured")

        # Fast path without the lock; safe in asyncio since assignment below
        # happens without an intervening await once inside the lock.
        if self._tool_server_client is not None:
            return self._tool_server_client

        async with self._tool_server_lock:
            if self._tool_server_client is None:
                self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/"))
            return self._tool_server_client

    def _tool_server_headers(self) -> Dict[str, str]:
        """Return auth headers for the ToolServer (empty when no token is set)."""
        token = self.config.tool_server_token
        if not token:
            return {}
        return {"Authorization": f"Bearer {token}"}

    async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult:
        """POST one external tool call to the ToolServer's /execute endpoint.

        The trajectory's slot id / container address are included when a slot
        is already bound (no slot is acquired just for an external call). Any
        failure — HTTP error, bad response shape — is converted into an error
        ToolResult; this method never raises.
        """
        client = await self._get_tool_server_client()
        slot_id: Optional[str] = None
        container_addr: Optional[str] = None
        slot = await self._get_slot_if_present(req.trajectory_id)
        if slot is not None:
            slot_id = slot.slot_id
            container_addr = slot.container_addr

        payload = ToolServerExecuteRequest(
            trajectory_id=req.trajectory_id,
            tool=ToolCallPayload.from_tool_call(req.call),
            timeout_s=req.timeout_s,
            slot_id=slot_id,
            container_addr=container_addr,
        )

        try:
            resp = await client.post(
                "/execute",
                json=payload.model_dump(),
                headers=self._tool_server_headers(),
                timeout=req.timeout_s,
            )
            resp.raise_for_status()
            data = resp.json()
            parsed = ToolResultPayload(**data)
            result = parsed.to_tool_result()
            # Ensure the result can be correlated back to the originating call.
            if result.uniq_id is None:
                result.uniq_id = req.call.uniq_id
            return result
        except Exception as e:
            return ToolResult(
                success=False,
                error=f"External tool failed: {e}",
                uniq_id=req.call.uniq_id,
            )

    async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None:
        """Dispatch one batch: route unknown/external/sandbox items and fulfil futures.

        Unknown tools fail immediately; external tools fan out concurrently to
        the ToolServer (or fail if it is unconfigured); sandbox tools acquire
        per-trajectory slots, get grouped by effective timeout, and run via
        ``pool.execute_batch``. Every item's future is completed exactly once
        (guarded by ``future.done()``); this method never raises.
        """
        # Resolve tool schemas once per request and separate sandbox/external/unknown.
        sandbox_items: List[_QueuedToolRequest] = []
        external_items: List[_QueuedToolRequest] = []
        unknown_items: List[_QueuedToolRequest] = []

        for it in batch:
            tool = self.tools.get(it.call.name)
            if tool is None:
                unknown_items.append(it)
                continue

            schema = tool.schema
            if not schema.external:
                sandbox_items.append(it)
            else:
                external_items.append(it)

        for it in unknown_items:
            self.total_requests += 1
            self.total_errors += 1
            if not it.future.done():
                it.future.set_result(
                    ToolResult(
                        success=False,
                        error=f"Unknown tool: {it.call.name}",
                        uniq_id=it.call.uniq_id,
                    )
                )

        if external_items:
            if not self.config.tool_server_url:
                for it in external_items:
                    self.total_requests += 1
                    self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(
                            ToolResult(
                                success=False,
                                error=f"External tool not available (ToolServer not configured): {it.call.name}",
                                uniq_id=it.call.uniq_id,
                            )
                        )
            else:
                # _execute_external never raises, so gather cannot fail here.
                results = await asyncio.gather(*[self._execute_external(it) for it in external_items])
                for it, res in zip(external_items, results):
                    self.total_requests += 1
                    if not getattr(res, "success", False):
                        self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(res)

        if not sandbox_items:
            return

        # Acquire slots for the distinct trajectories in this batch.
        try:
            traj_ids = list({it.trajectory_id for it in sandbox_items})
            slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids])
            slot_by_traj = dict(zip(traj_ids, slots))
        except Exception as e:
            for it in sandbox_items:
                self.total_requests += 1
                self.total_errors += 1
                if not it.future.done():
                    it.future.set_result(
                        ToolResult(
                            success=False,
                            error=f"Failed to acquire slot: {e}",
                            uniq_id=it.call.uniq_id,
                        )
                    )
            return

        # Group by timeout so we don't accidentally make short timeouts wait on long ones.
        by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
        default_timeout = None
        if self.pool.executor.timeout.total is not None:
            default_timeout = float(self.pool.executor.timeout.total)

        for it in sandbox_items:
            t = it.timeout_s
            if t is None:
                t = default_timeout
            if t is None:
                t = 30.0
            by_timeout.setdefault(float(t), []).append(it)

        for timeout_s, items in by_timeout.items():
            requests = []
            dispatched: List[_QueuedToolRequest] = []
            for it in items:
                slot = slot_by_traj[it.trajectory_id]
                tool_name = it.call.name
                args = dict(it.call.arguments)

                # Hermes compatibility: treat `terminal` as an alias of sandbox `bash`.
                if tool_name == "terminal":
                    # Background execution has no sandbox equivalent; fail fast.
                    if args.get("background"):
                        self.total_requests += 1
                        self.total_errors += 1
                        if not it.future.done():
                            it.future.set_result(
                                ToolResult(
                                    success=False,
                                    error="terminal background execution is not supported in sandbox",
                                    uniq_id=it.call.uniq_id,
                                )
                            )
                        continue
                    tool_name = "bash"
                    # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
                    args.pop("timeout", None)
                elif tool_name == "terminal_stateful":
                    tool_name = "bash_stateful"
                    args.pop("timeout", None)
                elif tool_name == "tmux":
                    # `tmux` is a sandbox tool backed by the stateful session manager.
                    # Network policy is env-controlled.
                    args.pop("allow_network", None)

                if tool_name == "bash":
                    # Network policy is set by the environment/executor, not by the model.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_sandbox"] = bool(self.config.require_sandbox)
                    # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
                    args.pop("timeout", None)
                elif tool_name == "bash_stateful":
                    # Network policy is set by the environment/executor, not by the model.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args.pop("require_stateful_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
                    args.pop("timeout", None)
                elif tool_name == "tmux":
                    # Network policy applies to the underlying stateful session.
                    args.pop("allow_network", None)
                    args.pop("require_sandbox", None)
                    args.pop("require_stateful_sandbox", None)
                    args["allow_network"] = bool(self.config.allow_network)
                    args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)

                requests.append((slot, tool_name, args))
                dispatched.append(it)

            results = None
            try:
                if not dispatched:
                    continue
                results = await self.pool.execute_batch(requests, timeout=timeout_s)
            except Exception as e:
                # NOTE(review): iterates `items` (not `dispatched`), so stats
                # are double-counted for requests already failed above; their
                # futures are protected by the done() guard.
                for it in items:
                    self.total_requests += 1
                    self.total_errors += 1
                    if not it.future.done():
                        it.future.set_result(
                            ToolResult(
                                success=False,
                                error=f"Batch execution failed: {e}",
                                uniq_id=it.call.uniq_id,
                            )
                        )
                continue

            # Batch results are positionally aligned with `dispatched`.
            for it, res in zip(dispatched, results):
                self.total_requests += 1
                if not getattr(res, "success", False):
                    self.total_errors += 1
                tool_result = res.to_tool_result()
                tool_result.uniq_id = it.call.uniq_id
                if not it.future.done():
                    it.future.set_result(tool_result)
|
||||
Loading…
Add table
Add a link
Reference in a new issue