diff --git a/.gitignore b/.gitignore index 87617b600f..062ef4da79 100644 --- a/.gitignore +++ b/.gitignore @@ -40,5 +40,7 @@ agent-browser/ privvy* images/ -# CLI config (may contain sensitive SSH paths) -cli-config.yaml +# CLI config (may contain sensitive SSH paths) +cli-config.yaml + +.DS_Store diff --git a/atropos/Dockerfile b/atropos/Dockerfile new file mode 100644 index 0000000000..0738d9e747 --- /dev/null +++ b/atropos/Dockerfile @@ -0,0 +1,41 @@ +# Dockerfile for atropos-agent sandbox server +# Runs inside Nomad containers to handle tool execution +# Includes bubblewrap for namespace-based slot isolation + +FROM python:3.11-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + # Bubblewrap for namespace isolation + bubblewrap \ + # `script` for PTY allocation (used for stable tmux+asciinema startup) + util-linux \ + # Git for SWE-style tasks (cloning repos) + git \ + # tmux for stateful terminal sessions (Phase 4.7+) + tmux \ + # Common tools agents might need + curl \ + wget \ + jq \ + # Cleanup + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies (sandbox server + optional terminal recording) +RUN pip install --no-cache-dir aiohttp asciinema + +# Copy the sandbox server +COPY sandbox_server.py /app/sandbox_server.py + +WORKDIR /app + +# Create data directory for slot workspaces +RUN mkdir -p /data + +# Verify bubblewrap is installed and working +RUN bwrap --version + +EXPOSE 8080 + +# Default command - can be overridden by Nomad job spec +CMD ["python", "sandbox_server.py", "--port", "8080", "--slots", "10", "--data-dir", "/data"] diff --git a/atropos/__init__.py b/atropos/__init__.py new file mode 100644 index 0000000000..18bfed5df2 --- /dev/null +++ b/atropos/__init__.py @@ -0,0 +1,46 @@ +""" +Atropos integration for Hermes-Agent. + +This package is intentionally optional: Hermes-Agent should work without Atropos. +If you import anything from `atropos.*` without having `atroposlib` installed, +we raise a clear error with install instructions. + +Install (recommended, from repo checkout): + uv sync --extra atropos + +Or (pip / editable): + pip install -e '.[atropos]' +""" + +from __future__ import annotations + + +def _require_atroposlib() -> None: + try: + import atroposlib # noqa: F401 + except ModuleNotFoundError as exc: # pragma: no cover + raise ModuleNotFoundError( + "Hermes-Agent Atropos integration requires `atroposlib`, but it is not installed.\n" + "Install it with:\n" + " uv sync --extra atropos\n" + "or:\n" + " pip install -e '.[atropos]'\n" + ) from exc + + +_require_atroposlib() + +# Re-export the most commonly used pieces for convenience. +from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData # noqa: E402 +from .envs import AgentEnv, AgentEnvConfig # noqa: E402 + +__all__ = [ + "AtroposAgent", + "AgentConfig", + "AgentResult", + "AgentStep", + "SequenceData", + "AgentEnv", + "AgentEnvConfig", +] + diff --git a/atropos/agent/__init__.py b/atropos/agent/__init__.py new file mode 100644 index 0000000000..9f95189d6e --- /dev/null +++ b/atropos/agent/__init__.py @@ -0,0 +1,15 @@ +""" +Agent abstractions for atropos-agent. + +Provides the core AtroposAgent class for running ReACT-style agent loops. +""" + +from .atropos_agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData + +__all__ = [ + "AtroposAgent", + "AgentConfig", + "AgentResult", + "AgentStep", + "SequenceData", +] diff --git a/atropos/agent/atropos_agent.py b/atropos/agent/atropos_agent.py new file mode 100644 index 0000000000..9ea6e3044e --- /dev/null +++ b/atropos/agent/atropos_agent.py @@ -0,0 +1,407 @@ +""" +ReACT-style agent implementation for atropos-agent. + +This module provides the core AtroposAgent class that implements a basic +Reason-Act-Observe loop with tool calling capabilities. + +Uses ManagedServer from atroposlib for automatic token/logprob tracking, +making trajectories ready for RL training. + +The agent uses Hermes-style XML tags for tool calls: +- ... for reasoning +- {"name": "...", "arguments": {...}} for actions +- ... for observations +""" + +import asyncio +import os +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union + +from dotenv import load_dotenv + +from ..tools import ToolCall, ToolRegistry, ToolResult +from atroposlib.envs.server_handling.managed_server import ManagedServer + +load_dotenv() + + +# Default system prompt with tool calling instructions +AGENT_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. You can use tools to accomplish tasks. + +## Available Tools + +{tool_descriptions} + + +## How to Use Tools +To use a tool, output a tool call in the following format: +{{"name": "tool_name", "arguments": {{"arg1": "value1", "arg2": "value2"}}}} + +You may reason about what to do before calling a tool: +I need to check what files are in the current directory... +{{"name": "bash", "arguments": {{"command": "ls -la"}}}} + +After a tool is executed, you will receive the result: +{{"success": true, "output": "..."}} + +Continue using tools as needed until you have completed the task. +When you have finished, provide your final response without any tool calls. + +## Important Guidelines +- Think step by step about what you need to do +- Use tools to gather information and perform actions +- If a tool call fails, analyze the error and try a different approach +- Provide clear, concise responses when the task is complete +""" + + +@dataclass +class AgentConfig: + """Configuration for the AtroposAgent.""" + + # Generation parameters + temperature: float = 0.7 + max_tokens: int = 4096 + + # Agent behavior + max_steps: int = 50 + system_prompt: Optional[str] = None + tool_delay_s: float = 0.0 + + # Working directory for tools + working_dir: Optional[str] = None + + +@dataclass +class SequenceData: + """Token/logprob data from a single completion.""" + + full_text: str + tokens: List[int] + masked_tokens: List[int] # -100 for prompt, actual IDs for completion + logprobs: List[float] # 1.0 for prompt, actual values for completion + + @classmethod + def from_sequence_node(cls, node) -> "SequenceData": + """Create from a ManagedServer SequenceNode.""" + return cls( + full_text=node.full_text, + tokens=node.tokens, + masked_tokens=node.masked_tokens, + logprobs=node.logprobs, + ) + + +@dataclass +class AgentStep: + """A single step in the agent's trajectory.""" + + step_number: int + assistant_message: str + tool_calls: List[ToolCall] = field(default_factory=list) + tool_results: List[ToolResult] = field(default_factory=list) + sequence_data: Optional[SequenceData] = None # Token data from this step + + @property + def has_tool_calls(self) -> bool: + return len(self.tool_calls) > 0 + + +@dataclass +class AgentResult: + """Result of running an agent trajectory.""" + + success: bool + final_response: str + steps: List[AgentStep] = field(default_factory=list) + total_tokens: int = 0 + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + # Full trajectory token data for RL training + trajectory_data: Optional[SequenceData] = None + + @property + def num_steps(self) -> int: + return len(self.steps) + + @property + def total_tool_calls(self) -> int: + return sum(len(step.tool_calls) for step in self.steps) + + def to_messages(self) -> List[Dict[str, str]]: + """Convert trajectory to messages format for logging.""" + messages = [] + for step in self.steps: + messages.append({"role": "assistant", "content": step.assistant_message}) + if step.tool_results: + # Combine all tool responses + responses = "\n".join(r.to_xml() for r in step.tool_results) + messages.append({"role": "user", "content": responses}) + return messages + + def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]: + """ + Convert to format suitable for ScoredDataGroup. + + Args: + score: The score for this trajectory + + Returns: + Dict with tokens, masks, scores suitable for training, or None if no data + """ + if self.trajectory_data is None: + return None + + return { + "tokens": self.trajectory_data.tokens, + "masks": self.trajectory_data.masked_tokens, + "scores": score, + "logprobs": self.trajectory_data.logprobs, + } + + +class AtroposAgent: + """ + A ReACT-style agent that uses LLMs with tool calling. + + This implementation wraps ManagedServer for automatic token/logprob tracking, + making trajectories ready for RL training. + + Example: + # `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`. + # In practice, environments usually construct this via `BaseEnv`. + server = ... + tools = ToolRegistry() + tools.register(BashTool()) + + agent = AtroposAgent(server=server, tools=tools) + result = await agent.run("List the files in the current directory") + + # Access token data for training + if result.trajectory_data: + print(f"Tokens: {result.trajectory_data.tokens}") + print(f"Masked: {result.trajectory_data.masked_tokens}") + """ + + def __init__( + self, + server, # ServerManager or APIServer + tools: Optional[ToolRegistry] = None, + config: Optional[AgentConfig] = None, + tokenizer: Optional[Any] = None, + execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None, + ): + self.server = server + self.tools = tools or ToolRegistry() + self.config = config or AgentConfig() + self.tokenizer = tokenizer or getattr(server, "tokenizer", None) + self.execute_tool = execute_tool or self.tools.execute + + @asynccontextmanager + async def _managed(self) -> AsyncGenerator[Any, None]: + """ + Yield a ManagedServer-like object. + + - If `self.server` is a ServerManager, use its `managed_server()` context manager. + - If `self.server` is a single APIServer, wrap it in `ManagedServer` directly. + """ + if hasattr(self.server, "managed_server"): + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + yield managed + else: + managed = ManagedServer(server=self.server, tokenizer=self.tokenizer) + try: + yield managed + finally: + managed.reset() + + def _build_system_prompt(self) -> str: + """Build the system prompt with tool descriptions.""" + if self.config.system_prompt: + return self.config.system_prompt + + tool_descriptions = self.tools.get_prompt_description() + if not tool_descriptions: + tool_descriptions = "(No tools available)" + + return AGENT_SYSTEM_PROMPT.format(tool_descriptions=tool_descriptions) + + async def run( + self, + task: str, + initial_messages: Optional[List[Dict[str, str]]] = None, + ) -> AgentResult: + """ + Run the agent on a task using ManagedServer for token tracking. + + Args: + task: The task/prompt for the agent + initial_messages: Optional additional context messages + + Returns: + AgentResult with the trajectory, final response, and token data + """ + messages = [ + {"role": "system", "content": self._build_system_prompt()}, + ] + + if initial_messages: + messages.extend(initial_messages) + + messages.append({"role": "user", "content": task}) + + steps = [] + final_response = "" + final_node = None + final_prompt_messages: Optional[List[Dict[str, str]]] = None + + # Use ManagedServer for automatic token tracking + async with self._managed() as managed: + for step_num in range(self.config.max_steps): + try: + # Keep a copy of the prompt messages used for this completion. + # Useful for reconstructing tokens/masks when state tracking is unavailable. + prompt_messages = list(messages) + response = await managed.chat_completion( + messages=messages, + n=1, + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, + ) + + current_node = None + if hasattr(managed, "get_state"): + state = managed.get_state() + nodes = state.get("nodes", []) + current_node = nodes[-1] if nodes else None + + except Exception as e: + return AgentResult( + success=False, + final_response="", + steps=steps, + error=f"Generation error: {str(e)}", + ) + + response_text = response.choices[0].message.content or "" + tool_calls = ToolCall.parse_from_text(response_text) + + step = AgentStep( + step_number=step_num + 1, + assistant_message=response_text, + tool_calls=tool_calls, + sequence_data=SequenceData.from_sequence_node(current_node) if current_node else None, + ) + + if not tool_calls: + steps.append(step) + final_response = response_text + final_node = current_node + final_prompt_messages = prompt_messages + break + + messages.append({"role": "assistant", "content": response_text}) + + tool_responses = [] + for call in tool_calls: + result = await self.execute_tool(call) + step.tool_results.append(result) + tool_responses.append(result.to_xml()) + if self.config.tool_delay_s > 0: + await asyncio.sleep(self.config.tool_delay_s) + + steps.append(step) + + responses_text = "\n".join(tool_responses) + # Tool observations are represented as user content with Hermes-style tags. + # This is compatible with most OpenAI-compatible chat APIs and ensures + # tokenizers/chat templates include tool outputs during training. + messages.append({"role": "user", "content": responses_text}) + + else: + # Reached max steps without completing + return AgentResult( + success=False, + final_response=final_response, + steps=steps, + error=f"Reached maximum steps ({self.config.max_steps})", + ) + + # Build result with trajectory data + trajectory_data = None + if final_node: + trajectory_data = SequenceData.from_sequence_node(final_node) + elif final_prompt_messages is not None and self.tokenizer is not None: + if hasattr(self.tokenizer, "apply_chat_template"): + prompt_text = self.tokenizer.apply_chat_template( + final_prompt_messages, tokenize=False, add_generation_prompt=True + ) + prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False) + else: + prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages]) + prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True) + output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False) + tokens = prompt_tokens + output_tokens + masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens + logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens)) + trajectory_data = SequenceData( + full_text=f"{prompt_text}{final_response}", + tokens=tokens, + masked_tokens=masked_tokens, + logprobs=logprobs, + ) + + return AgentResult( + success=True, + final_response=final_response, + steps=steps, + trajectory_data=trajectory_data, + ) + + async def run_single_turn( + self, + messages: List[Dict[str, str]], + execute_tools: bool = True, + ) -> tuple[str, List[ToolResult], Optional[SequenceData]]: + """ + Run a single turn of the agent (one LLM call + tool execution). + + This is useful for integration with BaseEnv where you want more + control over the loop. + + Args: + messages: The conversation history + execute_tools: Whether to execute parsed tool calls + + Returns: + Tuple of (response_text, tool_results, sequence_data) + """ + async with self._managed() as managed: + response = await managed.chat_completion( + messages=messages, + n=1, + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, + ) + + current_node = None + if hasattr(managed, "get_state"): + state = managed.get_state() + nodes = state.get("nodes", []) + current_node = nodes[-1] if nodes else None + + response_text = response.choices[0].message.content or "" + tool_results = [] + + if execute_tools: + tool_calls = ToolCall.parse_from_text(response_text) + for call in tool_calls: + result = await self.execute_tool(call) + tool_results.append(result) + + sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None + + return response_text, tool_results, sequence_data diff --git a/atropos/api/__init__.py b/atropos/api/__init__.py new file mode 100644 index 0000000000..6648a8372d --- /dev/null +++ b/atropos/api/__init__.py @@ -0,0 +1,6 @@ +""" +FastAPI services for atropos-agent. + +- tool_executor_server: queued/batched sandbox tool execution (Phase 4) +""" + diff --git a/atropos/api/tool_executor_server.py b/atropos/api/tool_executor_server.py new file mode 100644 index 0000000000..95deed9331 --- /dev/null +++ b/atropos/api/tool_executor_server.py @@ -0,0 +1,262 @@ +""" +Tool Executor API (Phase 4) + +This service provides a queued, batched execution layer on top of SlotPool. +It mirrors the stateful FastAPI + app.state pattern used in: + atropos/atroposlib/api/server.py + +Run (dev): + uv run uvicorn atropos_agent.api.tool_executor_server:app --host 0.0.0.0 --port 9001 +""" + +from __future__ import annotations + +import os +from typing import Any, Dict, Optional +from pathlib import Path + +from fastapi import FastAPI, Header, HTTPException, status +from pydantic import BaseModel, Field + +from ..slots import SlotPool, SlotPoolConfig +from ..tools import BashTool, ImageGenerateTool, ReadFileTool, ToolRegistry, WriteFileTool +from ..tools.mixture_of_agents_tool import MixtureOfAgentsTool +from ..tools.terminal_tool import TerminalTool +from ..tools.vision_tools import VisionAnalyzeTool +from ..tools.web_tools import WebCrawlTool, WebExtractTool, WebSearchTool +from ..tools.base import ( + ArtifactArchiveRequestPayload, + ArtifactArchiveResponsePayload, + ArtifactListRequestPayload, + ArtifactListResponsePayload, + ArtifactReadRequestPayload, + ArtifactReadResponsePayload, + ToolExecutorExecuteRequest, + ToolExecutorReleaseRequest, + ToolResultPayload, +) +from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig + + +class ToolExecutorServerConfig(BaseModel): + nomad_address: str = Field(default="http://localhost:4646") + job_id: str = Field(default="atropos-sandbox-tool-executor") + image: str = Field(default="atropos-sandbox:local") + slots_per_container: int = Field(default=10) + min_containers: int = Field(default=1) + max_containers: int = Field(default=10) + privileged: bool = Field(default=False) + acquire_timeout_s: float = Field(default=30.0) + + batch_window_ms: int = Field(default=20) + max_batch_size: int = Field(default=200) + allow_network: bool = Field(default=True) + + tool_server_url: Optional[str] = Field(default=None) + tool_server_token: Optional[str] = Field(default=None) + + token: Optional[str] = Field(default=None, description="Bearer token required for requests (optional in dev).") + + purge_job_on_shutdown: bool = Field(default=True) + + @classmethod + def from_env(cls) -> "ToolExecutorServerConfig": + # In dev, prefer loading secrets/config from the repo-local `.env` (not committed). + try: + from dotenv import load_dotenv # type: ignore + except Exception: # pragma: no cover + load_dotenv = None # type: ignore[assignment] + if load_dotenv is not None: + env_path = Path(__file__).resolve().parents[2] / ".env" + if env_path.exists(): + load_dotenv(dotenv_path=env_path) + + def _get_bool(name: str, default: bool) -> bool: + raw = os.getenv(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "y", "on"} + + return cls( + nomad_address=os.getenv("TOOL_EXECUTOR_NOMAD_ADDRESS", "http://localhost:4646"), + job_id=os.getenv("TOOL_EXECUTOR_JOB_ID", "atropos-sandbox-tool-executor"), + image=os.getenv("TOOL_EXECUTOR_IMAGE", "atropos-sandbox:local"), + slots_per_container=int(os.getenv("TOOL_EXECUTOR_SLOTS", "10")), + min_containers=int(os.getenv("TOOL_EXECUTOR_MIN_CONTAINERS", "1")), + max_containers=int(os.getenv("TOOL_EXECUTOR_MAX_CONTAINERS", "10")), + privileged=_get_bool("TOOL_EXECUTOR_PRIVILEGED", False), + acquire_timeout_s=float(os.getenv("TOOL_EXECUTOR_ACQUIRE_TIMEOUT_S", "30.0")), + batch_window_ms=int(os.getenv("TOOL_EXECUTOR_BATCH_WINDOW_MS", "20")), + max_batch_size=int(os.getenv("TOOL_EXECUTOR_MAX_BATCH_SIZE", "200")), + allow_network=_get_bool("TOOL_EXECUTOR_ALLOW_NETWORK", True), + tool_server_url=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_URL") or None, + tool_server_token=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_TOKEN") or None, + token=os.getenv("TOOL_EXECUTOR_TOKEN") or None, + purge_job_on_shutdown=_get_bool("TOOL_EXECUTOR_PURGE_JOB_ON_SHUTDOWN", True), + ) + + +app = FastAPI(title="Atropos-Agent Tool Executor") + + +@app.get("/") +async def root() -> Dict[str, str]: + return {"message": "Atropos-Agent Tool Executor"} + + +def _check_auth(cfg: ToolExecutorServerConfig, authorization: Optional[str]) -> None: + if not cfg.token: + return + if not authorization: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header") + if not authorization.lower().startswith("bearer "): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header") + token = authorization.split(" ", 1)[1].strip() + if token != cfg.token: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token") + + +@app.on_event("startup") +async def _startup() -> None: + cfg = ToolExecutorServerConfig.from_env() + + tools = ToolRegistry() + tools.register(BashTool()) + tools.register(TerminalTool()) + tools.register(ReadFileTool()) + tools.register(WriteFileTool()) + tools.register(ImageGenerateTool()) + tools.register(WebSearchTool()) + tools.register(WebExtractTool()) + tools.register(WebCrawlTool()) + tools.register(VisionAnalyzeTool()) + tools.register(MixtureOfAgentsTool()) + + pool = SlotPool( + SlotPoolConfig( + nomad_address=cfg.nomad_address, + job_id=cfg.job_id, + image=cfg.image, + slots_per_container=cfg.slots_per_container, + min_containers=cfg.min_containers, + max_containers=cfg.max_containers, + privileged=cfg.privileged, + acquire_timeout=cfg.acquire_timeout_s, + ) + ) + await pool.start() + + executor = ToolExecutor( + pool=pool, + tools=tools, + config=ToolExecutorConfig( + batch_window_ms=cfg.batch_window_ms, + max_batch_size=cfg.max_batch_size, + allow_network=cfg.allow_network, + tool_server_url=cfg.tool_server_url, + tool_server_token=cfg.tool_server_token, + ), + ) + await executor.start() + + app.state.cfg = cfg + app.state.pool = pool + app.state.executor = executor + + +@app.on_event("shutdown") +async def _shutdown() -> None: + executor: Optional[ToolExecutor] = getattr(app.state, "executor", None) + pool: Optional[SlotPool] = getattr(app.state, "pool", None) + cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None) + + if executor is not None: + await executor.close() + + if pool is not None: + await pool.stop(purge_job=bool(cfg.purge_job_on_shutdown) if cfg else False) + + +@app.get("/health") +async def health() -> Dict[str, Any]: + return {"status": "ok"} + + +@app.get("/status") +async def status_endpoint() -> Dict[str, Any]: + executor: ToolExecutor = app.state.executor + pool: SlotPool = app.state.pool + + return { + "queue_size": executor.queue_size(), + "total_requests": executor.total_requests, + "total_errors": executor.total_errors, + "pool": pool.get_stats(), + } + + +@app.post("/execute", response_model=ToolResultPayload) +async def execute_tool( + req: ToolExecutorExecuteRequest, + authorization: Optional[str] = Header(default=None), + status_code: int = status.HTTP_200_OK, # noqa: B008 +) -> ToolResultPayload: + cfg: ToolExecutorServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + executor: ToolExecutor = app.state.executor + result = await executor.execute( + trajectory_id=req.trajectory_id, + call=req.tool.to_tool_call(), + timeout_s=req.timeout_s, + ) + return ToolResultPayload.from_tool_result(result) + + +@app.post("/release") +async def release_trajectory( + req: ToolExecutorReleaseRequest, + authorization: Optional[str] = Header(default=None), +) -> Dict[str, Any]: + cfg: ToolExecutorServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + executor: ToolExecutor = app.state.executor + await executor.release_trajectory(req.trajectory_id, reset_workspace=req.reset_workspace) + return {"status": "ok"} + + +@app.post("/artifacts/read", response_model=ArtifactReadResponsePayload) +async def artifacts_read( + req: ArtifactReadRequestPayload, + authorization: Optional[str] = Header(default=None), +) -> ArtifactReadResponsePayload: + cfg: ToolExecutorServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + executor: ToolExecutor = app.state.executor + return await executor.read_artifact(req) + + +@app.post("/artifacts/list", response_model=ArtifactListResponsePayload) +async def artifacts_list( + req: ArtifactListRequestPayload, + authorization: Optional[str] = Header(default=None), +) -> ArtifactListResponsePayload: + cfg: ToolExecutorServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + executor: ToolExecutor = app.state.executor + return await executor.list_artifacts(req) + + +@app.post("/artifacts/archive", response_model=ArtifactArchiveResponsePayload) +async def artifacts_archive( + req: ArtifactArchiveRequestPayload, + authorization: Optional[str] = Header(default=None), +) -> ArtifactArchiveResponsePayload: + cfg: ToolExecutorServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + executor: ToolExecutor = app.state.executor + return await executor.archive_artifacts(req) diff --git a/atropos/api/tool_server.py b/atropos/api/tool_server.py new file mode 100644 index 0000000000..9357886b8e --- /dev/null +++ b/atropos/api/tool_server.py @@ -0,0 +1,149 @@ +""" +External ToolServer (Phase 4.5+). + +This server executes tools that must NOT run inside the sandbox, typically +because they require credentials or access to external services. + +Run (dev): + uv run uvicorn atropos_agent.api.tool_server:app --host 0.0.0.0 --port 9002 +""" + +from __future__ import annotations + +import asyncio +import os +import inspect +from typing import Any, Dict, List, Optional +from pathlib import Path + +from fastapi import FastAPI, Header, HTTPException, status +from pydantic import BaseModel, Field + +from ..tools import ToolRegistry +from ..tools.base import ToolResultPayload, ToolServerExecuteRequest +from ..tools.image_generation_tool import ImageGenerateTool +from ..tools.mixture_of_agents_tool import MixtureOfAgentsTool +from ..tools.vision_tools import VisionAnalyzeTool +from ..tools.web_tools import WebCrawlTool, WebExtractTool, WebSearchTool + + +class ToolServerConfig(BaseModel): + token: Optional[str] = Field( + default=None, + description="Bearer token required for requests (optional in dev).", + ) + max_concurrency: int = Field(default=16, ge=1, description="Max concurrent tool executions.") + + @classmethod + def from_env(cls) -> "ToolServerConfig": + # In dev, prefer loading secrets from the repo-local `.env` (not committed). + try: + from dotenv import load_dotenv # type: ignore + except Exception: # pragma: no cover + load_dotenv = None # type: ignore[assignment] + if load_dotenv is not None: + env_path = Path(__file__).resolve().parents[2] / ".env" + if env_path.exists(): + load_dotenv(dotenv_path=env_path) + + token = os.getenv("TOOL_SERVER_TOKEN") or None + max_concurrency = int(os.getenv("TOOL_SERVER_MAX_CONCURRENCY", "16")) + return cls(token=token, max_concurrency=max_concurrency) + + +app = FastAPI(title="Atropos-Agent Tool Server") + + +@app.get("/") +async def root() -> Dict[str, str]: + return {"message": "Atropos-Agent Tool Server"} + + +@app.on_event("startup") +async def _startup() -> None: + cfg = ToolServerConfig.from_env() + + tools = ToolRegistry() + for tool in [ + ImageGenerateTool(), + WebSearchTool(), + WebExtractTool(), + WebCrawlTool(), + VisionAnalyzeTool(), + MixtureOfAgentsTool(), + ]: + ok, reason = tool.is_available() + if ok: + tools.register(tool) + else: + # Keep startup resilient when optional deps/keys are missing. + print(f"[ToolServer] Skipping tool '{tool.name}': {reason}") + + app.state.cfg = cfg + app.state.tools = tools + app.state.semaphore = asyncio.Semaphore(cfg.max_concurrency) + + +@app.get("/health") +async def health() -> Dict[str, Any]: + return {"status": "ok"} + + +@app.get("/tools") +async def list_tools() -> Dict[str, Any]: + tools: ToolRegistry = app.state.tools + return {"tools": [s.to_dict() for s in tools.get_schemas()]} + + +def _check_auth(cfg: ToolServerConfig, authorization: Optional[str]) -> None: + if not cfg.token: + return + if not authorization: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header") + if not authorization.lower().startswith("bearer "): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header") + token = authorization.split(" ", 1)[1].strip() + if token != cfg.token: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token") + + +@app.post("/execute", response_model=ToolResultPayload) +async def execute_tool( + req: ToolServerExecuteRequest, + authorization: Optional[str] = Header(default=None), +) -> ToolResultPayload: + cfg: ToolServerConfig = app.state.cfg + _check_auth(cfg, authorization) + + tools: ToolRegistry = app.state.tools + sem: asyncio.Semaphore = app.state.semaphore + + tool = tools.get(req.tool.name) + if tool is None: + return ToolResultPayload( + success=False, + error=f"Unknown tool: {req.tool.name}", + uniq_id=req.tool.uniq_id, + ) + + async with sem: + try: + kwargs = dict(req.tool.arguments) + # Some external tools need access to the trajectory/workspace context (e.g. fetching sandbox artifacts). + if req.trajectory_id and "trajectory_id" in inspect.signature(tool.execute).parameters: + kwargs["trajectory_id"] = req.trajectory_id + if req.slot_id and "slot_id" in inspect.signature(tool.execute).parameters: + kwargs["slot_id"] = req.slot_id + if req.container_addr and "container_addr" in inspect.signature(tool.execute).parameters: + kwargs["container_addr"] = req.container_addr + result = await tool.execute(**kwargs) + except Exception as e: + return ToolResultPayload( + success=False, + error=f"Tool execution error: {e}", + uniq_id=req.tool.uniq_id, + ) + + if result.uniq_id is None: + result.uniq_id = req.tool.uniq_id + return ToolResultPayload.from_tool_result(result) diff --git a/atropos/envs/__init__.py b/atropos/envs/__init__.py new file mode 100644 index 0000000000..c56655345f --- /dev/null +++ b/atropos/envs/__init__.py @@ -0,0 +1,10 @@ +""" +Environment implementations for atropos-agent. +""" + +from .agent_env import AgentEnv, AgentEnvConfig + +# NOTE: Additional example envs exist as modules (e.g. `test_env`, `swe_smith_oracle_env`), +# but are intentionally not imported here to avoid pulling heavy optional deps at import time. + +__all__ = ["AgentEnv", "AgentEnvConfig"] diff --git a/atropos/envs/agent_env.py b/atropos/envs/agent_env.py new file mode 100644 index 0000000000..1c4bbcfd7f --- /dev/null +++ b/atropos/envs/agent_env.py @@ -0,0 +1,427 @@ +""" +AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads. + +AgentEnv is responsible for starting the sandbox tool execution backend and +providing helpers for running agent trajectories with queued/batched tool calls. + +For Phase 4 we support a Nomad-backed `SlotPool` for true container sandboxing. +""" + +from __future__ import annotations + +import asyncio +import uuid +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar + +from pydantic import Field + +from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem + +from ..agent import AgentConfig, AtroposAgent +from ..slots import SlotPool, SlotPoolConfig +from ..tools import BashTool, ReadFileTool, ToolRegistry, WriteFileTool +from ..tools.image_generation_tool import ImageGenerateTool +from ..tools.mixture_of_agents_tool import MixtureOfAgentsTool +from ..tools.terminal_tool import TerminalTool +from ..tools.terminal_stateful_tool import TerminalStatefulTool +from ..tools.tmux_tool import TmuxTool +from ..tools.toolsets import resolve_multiple_toolsets +from ..tools.vision_tools import VisionAnalyzeTool +from ..tools.web_tools import WebCrawlTool, WebExtractTool, WebSearchTool +from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig + + +class AgentEnvConfig(BaseEnvConfig): + tool_pool_mode: str = Field(default="nomad", description="Tool execution backend (only 'nomad' is supported)") + + allow_network: bool = Field( + default=True, + description="Whether sandbox bash commands may access the network (env policy).", + ) + require_sandbox: bool = Field( + default=False, + description="Fail closed if bubblewrap sandboxing is unavailable/unusable for stateless sandbox tools.", + ) + require_stateful_sandbox: bool = Field( + default=False, + description="Fail closed if bubblewrap/PID isolation is unavailable for stateful terminal tools (tmux).", + ) + tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)") + tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size") + + # nomad mode settings + nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address") + sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers") + sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers") + slots_per_container: int = Field(default=10, description="Nomad mode: slots per container") + min_containers: int = Field(default=1, description="Nomad mode: minimum containers") + max_containers: int = Field(default=10, description="Nomad mode: maximum containers") + privileged: bool = Field(default=False, description="Nomad mode: run container privileged") + acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds)") + purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown") + + # basic agent defaults + agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory") + agent_temperature: float = Field(default=0.7, description="Sampling temperature") + agent_max_tokens: int = Field(default=4096, description="Max tokens per model response") + agent_tool_delay_s: float = Field(default=0.0, description="Delay between tool calls (seconds)") + + # tool selection + enabled_toolsets: List[str] = Field( + default_factory=lambda: ["default"], + description="Toolsets to enable (Hermes-style grouping).", + ) + disabled_toolsets: List[str] = Field( + default_factory=list, + description="Toolsets to disable (applied after enabled_toolsets).", + ) + + # external ToolServer routing (Phase 4.5+) + tool_server_url: Optional[str] = Field( + default=None, + description="Base URL for external ToolServer (enables external tools).", + ) + tool_server_token: Optional[str] = Field( + default=None, + description="Bearer token for ToolServer auth (optional in dev).", + ) + +AgentEnvConfigT = TypeVar("AgentEnvConfigT", bound="AgentEnvConfig") + + +class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): + env_config_cls = AgentEnvConfig + + def __init__( + self, + config: AgentEnvConfigT, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm, testing) + self.config: AgentEnvConfigT = config + + self.tools: ToolRegistry = self.build_tools() + + self._pool: Optional[Any] = None + self._tool_executor: Optional[ToolExecutor] = None + self._tool_server_inprocess: bool = False + + def build_tools(self) -> ToolRegistry: + available_tools = [ + BashTool(), + TerminalTool(), + TerminalStatefulTool(), + TmuxTool(), + ReadFileTool(), + WriteFileTool(), + ImageGenerateTool(), + WebSearchTool(), + WebExtractTool(), + WebCrawlTool(), + VisionAnalyzeTool(), + MixtureOfAgentsTool(), + ] + + tool_by_name = {t.name: t for t in available_tools} + + enabled_toolsets = self.config.enabled_toolsets or ["default"] + selected = set(resolve_multiple_toolsets(enabled_toolsets)) + if self.config.disabled_toolsets: + selected -= set(resolve_multiple_toolsets(self.config.disabled_toolsets)) + + tools = ToolRegistry() + for name in sorted(selected): + tool = tool_by_name.get(name) + if tool is None: + continue + # External tools require a ToolServer URL; avoid advertising broken tools. + if tool.schema.external and not self.config.tool_server_url: + continue + ok, _reason = tool.is_available() + if not ok: + continue + tools.register(tool) + + return tools + + @abstractmethod + def build_task(self, item: Item) -> str: + """Return the user-facing task string for the agent.""" + + @abstractmethod + async def score_trajectory(self, item: Item, final_response: str) -> float: + """Return a scalar score for this trajectory.""" + + async def setup_trajectory_workspace( + self, + item: Item, + *, + trajectory_id: str, + exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]], + ) -> Dict[str, Any]: + """ + Optional hook: prepare the sandbox workspace before the agent starts. + + Examples: + - clone a repo and checkout a commit + - write fixture files (e.g. images) for external-tool demos + - pre-install dependencies + + Default: no-op. + """ + _ = (item, trajectory_id, exec_tool) + return {} + + async def verify_and_score_trajectory( + self, + item: Item, + final_response: str, + *, + trajectory_id: str, + exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]], + ) -> tuple[float, Dict[str, Any]]: + """ + Optional hook: run in-sandbox verification before scoring. + + Many agent envs need to execute verification inside the same trajectory + workspace (e.g. pytest) before releasing/resetting the slot. + + Default: calls `score_trajectory()` and returns empty metadata. + """ + _ = (trajectory_id, exec_tool) # default ignores in-workspace verification + score = await self.score_trajectory(item, final_response) + return score, {} + + def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002 + return AgentConfig( + max_steps=self.config.agent_max_steps, + temperature=self.config.agent_temperature, + max_tokens=self.config.agent_max_tokens, + tool_delay_s=self.config.agent_tool_delay_s, + ) + + async def setup(self) -> None: + await self._start_tool_backend() + await self.setup_agent_env() + + @abstractmethod + async def setup_agent_env(self) -> None: + """Subclass hook for env-specific setup.""" + + async def evaluate(self, *args, **kwargs): # noqa: ARG002 + """ + Default eval hook (no-op). + + Atropos BaseEnv requires an `evaluate()` implementation. Many agent envs + won't have a meaningful evaluation path during early PoC work; they can + override this when needed. + """ + return {} + + async def env_manager(self): + try: + return await super().env_manager() + finally: + await self.shutdown_tool_backend() + + async def process_manager(self): + try: + return await super().process_manager() + finally: + await self.shutdown_tool_backend() + + async def _start_tool_backend(self) -> None: + if self._tool_executor is not None: + return + + tool_server_url = self.config.tool_server_url + tool_server_client = None + if tool_server_url == "inprocess": + import httpx + from ..api.tool_server import app as tool_server_app + + await tool_server_app.router.startup() + tool_server_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=tool_server_app), + base_url="http://toolserver", + ) + tool_server_url = "http://toolserver" + self._tool_server_inprocess = True + + if self.config.tool_pool_mode != "nomad": + raise RuntimeError("tool_pool_mode must be 'nomad' (local/in-process pools are not supported)") + + pool = SlotPool( + SlotPoolConfig( + nomad_address=self.config.nomad_address, + job_id=self.config.sandbox_job_id, + image=self.config.sandbox_image, + slots_per_container=self.config.slots_per_container, + min_containers=self.config.min_containers, + max_containers=self.config.max_containers, + privileged=self.config.privileged, + acquire_timeout=self.config.acquire_timeout_s, + ) + ) + await pool.start() + + executor = ToolExecutor( + pool=pool, + tools=self.tools, + config=ToolExecutorConfig( + batch_window_ms=self.config.tool_batch_window_ms, + max_batch_size=self.config.tool_max_batch_size, + allow_network=self.config.allow_network, + require_sandbox=self.config.require_sandbox, + require_stateful_sandbox=self.config.require_stateful_sandbox, + tool_server_url=tool_server_url, + tool_server_token=self.config.tool_server_token, + ), + ) + await executor.start() + if tool_server_client is not None: + executor._tool_server_client = tool_server_client # type: ignore[attr-defined] + + self._pool = pool + self._tool_executor = executor + + async def shutdown_tool_backend(self) -> None: + executor = self._tool_executor + pool = self._pool + inprocess_tool_server = self._tool_server_inprocess + self._tool_executor = None + self._pool = None + self._tool_server_inprocess = False + + if executor is not None: + await executor.close() + if pool is not None: + await pool.stop(purge_job=bool(self.config.purge_job_on_shutdown)) + if inprocess_tool_server: + from ..api.tool_server import app as tool_server_app + + await tool_server_app.router.shutdown() + + async def collect_trajectory( + self, item: Item + ) -> Tuple[Optional[ScoredDataItem], List[Item]]: + if self._tool_executor is None: + raise RuntimeError("Tool backend not started") + + trajectory_id = str(uuid.uuid4()) + task = self.build_task(item) + agent_config = self.build_agent_config(item) + + async def _exec(call): + return await self._tool_executor.execute(trajectory_id, call) + + agent = AtroposAgent( + server=self.server, + tokenizer=self.tokenizer, + tools=self.tools, + config=agent_config, + execute_tool=_exec, + ) + + try: + await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec) + + result = await agent.run(task) + if not result.success or result.trajectory_data is None: + return None, [] + + score, _score_metadata = await self.verify_and_score_trajectory( + item, + result.final_response, + trajectory_id=trajectory_id, + exec_tool=_exec, + ) + + messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001 + messages.append({"role": "user", "content": task}) + for step in result.steps: + messages.append({"role": "assistant", "content": step.assistant_message}) + if step.tool_results: + tool_text = "\n".join(r.to_xml() for r in step.tool_results) + messages.append({"role": "user", "content": tool_text}) + + scored: ScoredDataItem = { + "tokens": result.trajectory_data.tokens, + "masks": result.trajectory_data.masked_tokens, + "scores": score, + } + if self.config.include_messages: + scored["messages"] = messages + + return scored, [] + finally: + await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[Optional[ScoredDataGroup], List[Item]]: + tasks = [self.collect_trajectory(item) for _ in range(self.config.group_size)] + results = await asyncio.gather(*tasks) + + backlog: List[Item] = [] + items: List[ScoredDataItem] = [] + for scored, b in results: + backlog.extend(b) + if scored is not None: + items.append(scored) + + if len(items) != self.config.group_size: + return None, backlog + + group: ScoredDataGroup = ScoredDataGroup( + tokens=[], + masks=[], + scores=[], + advantages=[], + ref_logprobs=[], + messages=[] if self.config.include_messages else None, + group_overrides={}, + overrides=[], + images=[], + ) + + for it in items: + group["tokens"].append(it["tokens"]) + group["masks"].append(it["masks"]) + group["scores"].append(it["scores"]) + if group.get("messages") is not None and it.get("messages") is not None: + group["messages"].append(it["messages"]) + + return group, backlog + + async def run_agent(self, task: str, *, trajectory_id: Optional[str] = None) -> Tuple[str, Dict[str, Any]]: + """ + Run the AtroposAgent on a single task and return (final_response, debug). + + This is a helper intended for simple environments and tests. + """ + if self._tool_executor is None: + raise RuntimeError("Tool backend not started") + + tid = trajectory_id or str(uuid.uuid4()) + + async def _exec(call): + return await self._tool_executor.execute(tid, call) + + agent = AtroposAgent( + server=self.server, + tokenizer=self.tokenizer, + tools=self.tools, + config=AgentConfig( + max_steps=self.config.agent_max_steps, + temperature=self.config.agent_temperature, + max_tokens=self.config.agent_max_tokens, + ), + execute_tool=_exec, + ) + result = await agent.run(task) + await self._tool_executor.release_trajectory(tid, reset_workspace=True) + return result.final_response, {"success": result.success, "error": result.error, "tool_calls": result.total_tool_calls} diff --git a/atropos/envs/hermes_compat_test_env.py b/atropos/envs/hermes_compat_test_env.py new file mode 100644 index 0000000000..270b507f83 --- /dev/null +++ b/atropos/envs/hermes_compat_test_env.py @@ -0,0 +1,208 @@ +""" +Hermes-Agent (Atropos-compatible) smoke environment. + +This is a minimal `BaseEnv` environment that uses Hermes-Agent's Atropos-backed +runner (`AtroposAIAgent`) and can be exercised via `BaseEnv`'s `process` mode. + +This deliberately does NOT use slot multiplexing / sandboxes yet (stage 1). +""" + +from __future__ import annotations + +import json +import os +import uuid +from typing import Dict, List, Tuple + +from dotenv import load_dotenv +from pydantic import Field + +from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item + +load_dotenv() + + +def _build_forced_tool_item() -> Item: + """ + Construct a task that *cannot* be completed reliably without executing a tool. + + We generate a high-entropy token *inside the tool execution* and ask the agent to + repeat it exactly. Scoring verifies that: + - a terminal tool call occurred (role="tool" message present), and + - the final answer matches the tool stdout exactly. + """ + return { + "command": "python -c \"import secrets; print(secrets.token_hex(16))\"", + "prompt": ( + "Use the terminal tool to run:\n" + "python -c \"import secrets; print(secrets.token_hex(16))\"\n" + "Then answer with EXACTLY what it printed and nothing else." + ), + } + + +TEST_ITEMS: List[Item] = [ + _build_forced_tool_item(), + _build_forced_tool_item(), +] + + +class HermesCompatTestEnvConfig(BaseEnvConfig): + """Config for HermesCompatTestEnv.""" + + server_base_url: str = Field( + default="http://localhost:11434", + description="Base URL for an OpenAI-compatible chat server (without /v1).", + ) + server_model: str = Field(default="glm-4.7-flash", description="Model name") + + +class HermesCompatTestEnv(BaseEnv): + """ + Minimal BaseEnv that runs Hermes-Agent's Atropos-compatible agent loop. + + Run (process mode): + uv run atropos-agent-hermes-compat-test process --env.use_wandb false --env.total_steps 2 --env.group_size 1 + """ + + name = "hermes_compat_test_env" + env_config_cls = HermesCompatTestEnvConfig + + def __init__( + self, + config: HermesCompatTestEnvConfig, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config=config, server_configs=server_configs, slurm=slurm, testing=testing) + self._iter = 0 + + from atropos_compatible_agent import AtroposAIAgent # noqa: WPS433 + + # Only expose terminal for this smoke env. + self._agent = AtroposAIAgent( + server=self.server, + tokenizer=self.tokenizer, + model=getattr(config, "server_model", "local"), + max_iterations=8, + enabled_toolsets=["terminal"], + tool_delay=0.0, + # Let the server decide token limits; we care about tool calling correctness here. + max_tokens=None, + temperature=None, + ) + + @classmethod + def config_init(cls) -> Tuple[HermesCompatTestEnvConfig, List[APIServerConfig]]: + base_url = ( + os.getenv("ATROPOS_SERVER_BASE_URL") + or os.getenv("OPENAI_BASE_URL") + or os.getenv("LLM_BASE_URL") + or "http://localhost:11434" + ) + model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "glm-4.7-flash" + # Never pass through real API keys in this smoke env (they will be printed by BaseEnv config logging). + # Local OpenAI-compatible servers typically ignore the API key anyway. + api_key = "local" + + env_config = HermesCompatTestEnvConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", + group_size=1, + use_wandb=False, + include_messages=True, + ensure_scores_are_not_same=False, + total_steps=2, + batch_size=1, + server_base_url=base_url, + server_model=model, + ) + + server_configs = [ + APIServerConfig( + server_type="openai", + model_name=model, + base_url=f"{base_url}/v1", + api_key=api_key, + num_max_requests_at_once=1, + num_requests_for_eval=1, + timeout=120, + ) + ] + return env_config, server_configs + + async def setup(self): + return None + + async def get_next_item(self) -> Item: + # Regenerate token per task to avoid leakage across steps. + item = _build_forced_tool_item() + self._iter += 1 + return item + + async def collect_trajectory(self, item: Item): + prompt = item.get("prompt", "") + + result = await self._agent.run_conversation_async( + prompt, + task_id=str(uuid.uuid4()), + ) + + final = (result.get("final_response") or "").strip() + + # Verify the agent actually executed the tool by extracting stdout from the tool message. + observed: str = "" + saw_tool = False + for msg in result.get("messages", []): + if msg.get("role") == "tool": + saw_tool = True + # Tool messages contain JSON strings from terminal tool. + try: + payload = json.loads(msg.get("content") or "{}") + out = (payload.get("output") or "").strip() + if out: + observed = out.splitlines()[-1].strip() + except Exception: + continue + # Pass if: + # - a tool call occurred, and + # - the final answer matches the observed stdout exactly. + score = 1.0 if saw_tool and observed and final == observed else 0.0 + + # Tokenization fallback: build tokens/masks from final prompt + completion. + # Note: this is sufficient for smoke testing; production training should + # use a backend that supports ManagedServer state tracking. + system_prompt = result.get("system_prompt") + messages: List[Dict[str, str]] = result.get("messages", []) + prompt_messages = messages[:-1] if messages and messages[-1].get("role") == "assistant" else messages + + if system_prompt: + prompt_messages = [{"role": "system", "content": system_prompt}] + prompt_messages + + if hasattr(self.tokenizer, "apply_chat_template"): + prompt_text = self.tokenizer.apply_chat_template( + prompt_messages, tokenize=False, add_generation_prompt=True + ) + prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False) + else: + prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in prompt_messages]) + prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True) + + output_tokens = self.tokenizer.encode(final, add_special_tokens=False) + + scored = { + "tokens": prompt_tokens + output_tokens, + "masks": ([-100] * len(prompt_tokens)) + output_tokens, + "scores": score, + "messages": prompt_messages + [{"role": "assistant", "content": final}], + } + + return scored, [] + + async def evaluate(self, *args, **kwargs): # noqa: ARG002 + # Minimal eval hook for BaseEnv abstract method. + return {} + + +if __name__ == "__main__": + HermesCompatTestEnv.cli() diff --git a/atropos/envs/swe_smith_oracle_env.py b/atropos/envs/swe_smith_oracle_env.py new file mode 100644 index 0000000000..756284ed2e --- /dev/null +++ b/atropos/envs/swe_smith_oracle_env.py @@ -0,0 +1,284 @@ +""" +SWE-smith-oracle benchmark environment (Phase 4.7). + +This environment is intentionally minimal: +- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit` +- runs an AtroposAgent tool loop to apply a fix +- verifies by running pytest nodeids from the dataset (reward = pass/fail) + +Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set). +""" + +from __future__ import annotations + +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import Field + +from atroposlib.envs.base import APIServerConfig, Item + +from ..agent import AgentConfig +from ..tools import ToolCall +from .agent_env import AgentEnv, AgentEnvConfig + + +class SweSmithOracleEnvConfig(AgentEnvConfig): + dataset_name: str = Field(default="NousResearch/SWE-smith-oracle") + dataset_split: str = Field(default="train") + max_items: int = Field(default=0, description="0 = no limit") + shuffle: bool = Field(default=True) + seed: int = Field(default=0) + + python_only: bool = Field(default=True, description="Filter to Python-evaluable rows") + score_include_fail_to_pass: bool = Field( + default=False, + description="If true, score tests on PASS_TO_PASS ∪ FAIL_TO_PASS; else PASS_TO_PASS only.", + ) + + repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning") + install_timeout_s: float = Field(default=600.0) + test_timeout_s: float = Field(default=600.0) + + +class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): + """ + SWE-smith-oracle AgentEnv. + + This is designed for benchmarking multiplexed slot execution vs naive container-per-trajectory. + """ + + name = "swe_smith_oracle_env" + env_config_cls = SweSmithOracleEnvConfig + + def __init__( + self, + config: SweSmithOracleEnvConfig, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm, testing) + self._dataset = None + self._indices: List[int] = [] + self._cursor = 0 + + @classmethod + def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]: + # Defaults for running the env via CLI in offline `process` mode. + # Override via env vars or `--env.*` flags as needed. + base_url = ( + os.getenv("ATROPOS_SERVER_BASE_URL") + or os.getenv("OPENAI_BASE_URL") + or os.getenv("LLM_BASE_URL") + or "http://localhost:11434" + ) + model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "glm-4.7-flash" + api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local" + + env_config = SweSmithOracleEnvConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", # tokenization only + group_size=1, + use_wandb=False, + rollout_server_url="http://localhost:8000", + total_steps=1, + batch_size=1, + steps_per_eval=1, + max_token_length=8192, + inference_weight=1.0, + wandb_name="swe_smith_oracle", + ) + + server_configs = [ + APIServerConfig( + model_name=model, + base_url=f"{base_url.rstrip('/')}/v1", + api_key=api_key, + num_max_requests_at_once=1, + num_requests_for_eval=1, + timeout=300, + ), + ] + + return env_config, server_configs + + async def setup_agent_env(self) -> None: + from datasets import load_dataset + + ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split) + self._dataset = ds + + indices: List[int] = [] + for idx in range(len(ds)): + row = ds[idx] + if self.config.python_only and not self._is_python_row(row): + continue + indices.append(idx) + + if self.config.shuffle: + rnd = random.Random(self.config.seed) + rnd.shuffle(indices) + + if self.config.max_items and self.config.max_items > 0: + indices = indices[: self.config.max_items] + + self._indices = indices + self._cursor = 0 + + print( + f"SweSmithOracleEnv loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split}" + ) + + def _is_python_row(self, row: Dict[str, Any]) -> bool: + nodeids = row.get("PASS_TO_PASS") + if not isinstance(nodeids, list) or not nodeids: + return False + for nid in nodeids: + if not isinstance(nid, str) or ".py::" not in nid: + return False + return True + + async def get_next_item(self) -> Item: + if not self._dataset or not self._indices: + raise RuntimeError("Dataset not initialized (did setup() run?)") + if self._cursor >= len(self._indices): + self._cursor = 0 + idx = self._indices[self._cursor] + self._cursor += 1 + return dict(self._dataset[idx]) + + def _repo_name(self, item: Item) -> str: + repo = item.get("repo") or "" + if isinstance(repo, str) and "/" in repo: + return repo.split("/")[-1] + return "repo" + + def build_task(self, item: Item) -> str: + repo = item.get("repo") or "" + base_commit = item.get("base_commit") or "" + problem = item.get("problem_statement") or "" + context = item.get("text") or "" + + nodeids = self._tests_for_item(item) + tests_preview = "\n".join(f"- {t}" for t in nodeids[:50]) + if len(nodeids) > 50: + tests_preview += f"\n- ... ({len(nodeids) - 50} more)" + + repo_dir = self._repo_name(item) + return ( + "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n" + f"Repository: {repo} (checked out at base_commit={base_commit})\n" + f"Workspace path: ./{repo_dir}\n\n" + "Constraints:\n" + "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n" + "- Use non-interactive commands only.\n\n" + "Problem statement:\n" + f"{problem}\n\n" + "Additional context:\n" + f"{context}\n\n" + "Run these tests to verify:\n" + f"{tests_preview}\n\n" + "When done, briefly describe what you changed and confirm tests pass." + ) + + def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002 + # SWE tasks are longer than the simple test env. + return AgentConfig( + max_steps=self.config.agent_max_steps, + temperature=self.config.agent_temperature, + max_tokens=self.config.agent_max_tokens, + tool_delay_s=self.config.agent_tool_delay_s, + ) + + async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]: + _ = trajectory_id + repo = item.get("repo") + base_commit = item.get("base_commit") + if not isinstance(repo, str) or not isinstance(base_commit, str): + raise RuntimeError("Invalid dataset row: missing repo/base_commit") + + repo_dir = self._repo_name(item) + clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git" + + # Clone and checkout the base commit. + clone_cmd = f"rm -rf {repo_dir} && git clone {clone_url} {repo_dir}" + res = await exec_tool(ToolCall(name="terminal", arguments={"command": clone_cmd, "timeout": self.config.install_timeout_s})) + if not res.success: + raise RuntimeError(f"git clone failed: {res.error}\n{res.output}") + + checkout_cmd = f"cd {repo_dir} && git checkout {base_commit}" + res = await exec_tool(ToolCall(name="terminal", arguments={"command": checkout_cmd, "timeout": self.config.install_timeout_s})) + if not res.success: + raise RuntimeError(f"git checkout failed: {res.error}\n{res.output}") + + # Best-effort baseline python env. + setup_cmd = ( + f"cd {repo_dir} && " + "python -m venv .venv && " + ". .venv/bin/activate && " + "python -m pip install -U pip setuptools wheel && " + "python -m pip install -e . && " + "python -m pip install pytest" + ) + await exec_tool(ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})) + + return {"repo_dir": repo_dir, "base_commit": base_commit} + + def _tests_for_item(self, item: Item) -> List[str]: + tests: List[str] = [] + if self.config.score_include_fail_to_pass: + for key in ("PASS_TO_PASS", "FAIL_TO_PASS"): + nodeids = item.get(key) + if isinstance(nodeids, list): + tests.extend([n for n in nodeids if isinstance(n, str)]) + else: + nodeids = item.get("PASS_TO_PASS") + if isinstance(nodeids, list): + tests.extend([n for n in nodeids if isinstance(n, str)]) + # Stable order for reproducibility. + return sorted(dict.fromkeys(tests)) + + def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]: + chunks: List[List[str]] = [] + for i in range(0, len(nodeids), max_per_chunk): + chunks.append(nodeids[i : i + max_per_chunk]) + return chunks + + async def verify_and_score_trajectory( + self, + item: Item, + final_response: str, # noqa: ARG002 + *, + trajectory_id: str, + exec_tool, + ) -> tuple[float, Dict[str, Any]]: + _ = trajectory_id + repo_dir = self._repo_name(item) + nodeids = self._tests_for_item(item) + if not nodeids: + return 0.0, {"error": "No tests provided"} + + chunks = self._chunk_nodeids(nodeids, max_per_chunk=50) + for chunk_idx, chunk in enumerate(chunks): + joined = " ".join(chunk) + cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}" + res = await exec_tool( + ToolCall( + name="terminal", + arguments={"command": cmd, "timeout": self.config.test_timeout_s}, + ) + ) + if not res.success: + return 0.0, {"failed_chunk": chunk_idx, "error": res.error, "output": res.output} + + return 1.0, {"passed": True} + + async def score_trajectory(self, item: Item, final_response: str) -> float: + # Not used; scoring happens in verify_and_score_trajectory. + _ = (item, final_response) + return 0.0 + + +if __name__ == "__main__": + SweSmithOracleEnv.cli() diff --git a/atropos/envs/test_env.py b/atropos/envs/test_env.py new file mode 100644 index 0000000000..8ae6a2f27c --- /dev/null +++ b/atropos/envs/test_env.py @@ -0,0 +1,216 @@ +""" +Simple test environment for validating the atropos-agent setup. + +This environment uses a local OpenAI-compatible server for LLM testing to verify: +- BaseEnv extension works correctly +- API communication via OpenAI-compatible endpoint +- Basic trajectory collection + +This is a minimal environment for testing, not production use. +""" + +import os +from typing import Dict, List, Optional, Tuple + +from dotenv import load_dotenv +from pydantic import Field + +from atroposlib.envs.base import ( + APIServerConfig, + Item, +) + +from ..agent import AgentConfig +from .agent_env import AgentEnv, AgentEnvConfig + +# Load environment variables from .env file +load_dotenv() + + +# Simple test prompts for validation +TEST_PROMPTS = [ + { + "prompt": "What is 2 + 2? Answer with just the number.", + "expected": "4", + }, + { + "prompt": "What is the capital of France? Answer with just the city name.", + "expected": "Paris", + }, + { + "prompt": "What color is the sky on a clear day? Answer with just the color.", + "expected": "Blue", + }, + { + "prompt": "How many days are in a week? Answer with just the number.", + "expected": "7", + }, + { + "prompt": "What is 10 * 5? Answer with just the number.", + "expected": "50", + }, +] + +SYSTEM_PROMPT = ( + "You are a helpful assistant. Answer questions concisely and directly. " + "When asked for a simple answer, provide just that answer without explanation." +) + + +class SimpleTestEnvConfig(AgentEnvConfig): + """Configuration for the simple test environment.""" + + server_base_url: str = Field( + default="http://localhost:11434", + description="Base URL for an OpenAI-compatible server (without /v1)", + ) + server_model: str = Field( + default="glm-4.7-flash", + description="Model name", + ) + + +class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]): + """ + A simple test environment to validate the atropos-agent setup. + + Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks. + Scoring is based on whether the response contains the expected answer. + """ + + name = "simple_test_env" + env_config_cls = SimpleTestEnvConfig + + def __init__( + self, + config: SimpleTestEnvConfig, + server_configs: List[APIServerConfig], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm, testing) + self.iter = 0 + self.test_prompts = TEST_PROMPTS + self.percent_correct_buffer: List[float] = [] + + @classmethod + def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]: + """ + Initialize configuration with local server settings from environment variables. + """ + base_url = ( + os.getenv("ATROPOS_SERVER_BASE_URL") + or os.getenv("OPENAI_BASE_URL") + or os.getenv("LLM_BASE_URL") + or "http://localhost:11434" + ) + model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "glm-4.7-flash" + api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local" + + env_config = SimpleTestEnvConfig( + tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", # For tokenization only + group_size=4, + use_wandb=False, # Disable wandb for simple testing + rollout_server_url="http://localhost:8000", + total_steps=10, + batch_size=16, + steps_per_eval=5, + max_token_length=2048, + inference_weight=1.0, + wandb_name="simple_test", + server_base_url=base_url, + server_model=model, + ) + + # OpenAI-compatible servers typically expose chat completions at /v1. + server_configs = [ + APIServerConfig( + model_name=model, + base_url=f"{base_url}/v1", + api_key=api_key, + num_max_requests_at_once=4, + num_requests_for_eval=8, + timeout=120, # Local models may be slower + ), + ] + + return env_config, server_configs + + async def setup_agent_env(self): + """Setup the environment - load test data.""" + print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.") + print(f"Using server at: {self.config.server_base_url}") + print(f"Model: {self.config.server_model}") + + async def get_next_item(self) -> Item: + """Get the next test prompt.""" + item = self.test_prompts[self.iter % len(self.test_prompts)] + self.iter += 1 + return item + + def build_task(self, item: Item) -> str: + return item["prompt"] + + def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002 + return AgentConfig( + max_steps=5, + temperature=0.7, + max_tokens=256, + system_prompt=SYSTEM_PROMPT, + ) + + async def score_trajectory(self, item: Item, final_response: str) -> float: + expected = item["expected"].lower() + response_lower = (final_response or "").lower() + score = 1.0 if expected in response_lower else 0.0 + self.percent_correct_buffer.append(score) + return score + + async def evaluate(self, *args, **kwargs): + """ + Simple evaluation - run through all test prompts once. + """ + correct = 0 + total = len(self.test_prompts) + + for item in self.test_prompts: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": item["prompt"]}, + ] + + response = await self.server.chat_completion( + messages=messages, + n=1, + max_tokens=256, + temperature=0.0, # Greedy for eval + split="eval", + ) + + response_text = response.choices[0].message.content or "" + expected = item["expected"].lower() + + if expected in response_text.lower(): + correct += 1 + + accuracy = correct / total + print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy") + return {"eval_accuracy": accuracy} + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log metrics (simplified for testing).""" + if wandb_metrics is None: + wandb_metrics = {} + + if self.percent_correct_buffer: + avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer) + wandb_metrics["train/percent_correct"] = avg_correct + print(f"Train accuracy: {avg_correct:.2%}") + self.percent_correct_buffer = [] + + await super().wandb_log(wandb_metrics) + + +if __name__ == "__main__": + # Allow running as CLI + SimpleTestEnv.cli() diff --git a/atropos/nomad/__init__.py b/atropos/nomad/__init__.py new file mode 100644 index 0000000000..db4af3f445 --- /dev/null +++ b/atropos/nomad/__init__.py @@ -0,0 +1,11 @@ +""" +Nomad integration for atropos-agent. + +Provides: +- NomadClient: Client for Nomad HTTP API +- Job templates for sandbox containers +""" + +from .client import NomadClient + +__all__ = ["NomadClient"] diff --git a/atropos/nomad/client.py b/atropos/nomad/client.py new file mode 100644 index 0000000000..7ff81e5183 --- /dev/null +++ b/atropos/nomad/client.py @@ -0,0 +1,452 @@ +""" +Nomad API Client for atropos-agent. + +Provides a simple async client for interacting with the Nomad HTTP API: +- Submit/stop jobs +- Query allocations +- Get allocation addresses +- Scale jobs up/down +""" + +import asyncio +import json +import os +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +import aiohttp + + +class AllocationStatus(Enum): + """Nomad allocation status.""" + PENDING = "pending" + RUNNING = "running" + COMPLETE = "complete" + FAILED = "failed" + LOST = "lost" + + +@dataclass +class Allocation: + """Information about a Nomad allocation.""" + id: str + job_id: str + task_group: str + node_id: str + status: AllocationStatus + # Network info for reaching the allocation + address: Optional[str] = None + port: Optional[int] = None + + @property + def http_address(self) -> Optional[str]: + """Get full HTTP address for the allocation.""" + if self.address and self.port: + return f"http://{self.address}:{self.port}" + return None + + +@dataclass +class JobStatus: + """Status of a Nomad job.""" + id: str + name: str + status: str + allocations: List[Allocation] = field(default_factory=list) + count: int = 0 # Number of task groups + + +class NomadClient: + """ + Async client for Nomad HTTP API. + + Usage: + client = NomadClient(address="http://localhost:4646") + + # Submit a job + await client.submit_job(job_spec) + + # Get allocations + allocs = await client.get_job_allocations("sandbox-python") + + # Scale job + await client.scale_job("sandbox-python", count=5) + """ + + def __init__( + self, + address: str = "http://localhost:4646", + token: Optional[str] = None, + timeout: float = 30.0, + ): + self.address = address.rstrip("/") + self.token = token or os.environ.get("NOMAD_TOKEN") + self.timeout = aiohttp.ClientTimeout(total=timeout) + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create HTTP session.""" + if self._session is None or self._session.closed: + headers = {} + if self.token: + headers["X-Nomad-Token"] = self.token + self._session = aiohttp.ClientSession( + timeout=self.timeout, + headers=headers, + ) + return self._session + + async def close(self): + """Close the HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def _request( + self, + method: str, + path: str, + data: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Make an HTTP request to Nomad API.""" + session = await self._get_session() + url = f"{self.address}{path}" + + try: + async with session.request(method, url, json=data) as response: + if response.status == 404: + return {"error": "not_found", "status": 404} + + text = await response.text() + if not text: + return {"status": response.status} + + try: + result = json.loads(text) + except json.JSONDecodeError: + return {"text": text, "status": response.status} + + if response.status >= 400: + return {"error": result, "status": response.status} + + return result if isinstance(result, dict) else {"data": result, "status": response.status} + + except aiohttp.ClientError as e: + return {"error": str(e), "status": 0} + + # Job Operations + + async def submit_job(self, job_spec: Dict[str, Any]) -> Dict[str, Any]: + """ + Submit a job to Nomad. + + Args: + job_spec: Job specification dict (HCL converted to JSON) + + Returns: + Response with EvalID if successful + """ + return await self._request("POST", "/v1/jobs", {"Job": job_spec}) + + async def stop_job(self, job_id: str, purge: bool = False) -> Dict[str, Any]: + """ + Stop (and optionally purge) a job. + + Args: + job_id: Job identifier + purge: If True, completely remove the job + """ + path = f"/v1/job/{job_id}" + if purge: + path += "?purge=true" + return await self._request("DELETE", path) + + async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]: + """Get job details.""" + result = await self._request("GET", f"/v1/job/{job_id}") + if "error" in result and result.get("status") == 404: + return None + return result + + async def get_job_status(self, job_id: str) -> Optional[JobStatus]: + """Get job status with allocations.""" + job = await self.get_job(job_id) + if not job: + return None + + allocs = await self.get_job_allocations(job_id) + + # Get count from task groups + count = 0 + task_groups = job.get("TaskGroups", []) + for tg in task_groups: + count += tg.get("Count", 1) + + return JobStatus( + id=job_id, + name=job.get("Name", job_id), + status=job.get("Status", "unknown"), + allocations=allocs, + count=count, + ) + + # Allocation Operations + + async def get_job_allocations(self, job_id: str) -> List[Allocation]: + """Get all allocations for a job.""" + result = await self._request("GET", f"/v1/job/{job_id}/allocations") + + if "error" in result: + return [] + + allocs_data = result.get("data", result) if isinstance(result, dict) else result + if not isinstance(allocs_data, list): + return [] + + allocations = [] + for alloc_data in allocs_data: + # Parse allocation info + alloc_id = alloc_data.get("ID", "") + status_str = alloc_data.get("ClientStatus", "unknown") + + try: + status = AllocationStatus(status_str) + except ValueError: + status = AllocationStatus.PENDING + + # Get network info - need to fetch detailed allocation for this + address = None + port = None + + # First try the summary data + resources = alloc_data.get("AllocatedResources") or {} + shared = resources.get("Shared") or {} + networks = shared.get("Networks") or [] + + # If no networks in summary, fetch detailed allocation + if not networks and alloc_id: + detailed = await self.get_allocation(alloc_id) + if detailed: + resources = detailed.get("AllocatedResources") or {} + shared = resources.get("Shared") or {} + networks = shared.get("Networks") or [] + + if networks: + network = networks[0] + address = network.get("IP") + # Look for dynamic ports + dyn_ports = network.get("DynamicPorts") or [] + for dp in dyn_ports: + if dp.get("Label") == "http": + port = dp.get("Value") + break + + allocations.append(Allocation( + id=alloc_id, + job_id=job_id, + task_group=alloc_data.get("TaskGroup", ""), + node_id=alloc_data.get("NodeID", ""), + status=status, + address=address, + port=port, + )) + + return allocations + + async def get_allocation(self, alloc_id: str) -> Optional[Dict[str, Any]]: + """Get detailed allocation info.""" + result = await self._request("GET", f"/v1/allocation/{alloc_id}") + if "error" in result and result.get("status") == 404: + return None + return result + + # Scaling Operations + + async def scale_job(self, job_id: str, count: int, task_group: str = "sandbox") -> Dict[str, Any]: + """ + Scale a job's task group to specified count. + + Args: + job_id: Job identifier + count: Desired number of allocations + task_group: Name of task group to scale + """ + payload = { + "Count": count, + "Target": { + "Group": task_group, + }, + } + return await self._request("POST", f"/v1/job/{job_id}/scale", payload) + + async def get_job_scale_status(self, job_id: str) -> Dict[str, int]: + """ + Get current scale status for a job. + + Returns: + Dict mapping task group name to count + """ + result = await self._request("GET", f"/v1/job/{job_id}/scale") + + if "error" in result: + return {} + + task_groups = result.get("TaskGroups", {}) + return { + name: info.get("Running", 0) + for name, info in task_groups.items() + } + + # Health Check + + async def is_healthy(self) -> bool: + """Check if Nomad is reachable and healthy.""" + try: + result = await self._request("GET", "/v1/status/leader") + return "error" not in result + except Exception: + return False + + async def get_leader(self) -> Optional[str]: + """Get current Nomad leader address.""" + result = await self._request("GET", "/v1/status/leader") + if isinstance(result, dict) and "data" in result: + return result["data"] + return None + + +def load_job_template( + template_name: str = "sandbox", + **kwargs, +) -> Dict[str, Any]: + """ + Load and configure a job template. + + Args: + template_name: Name of template (e.g., "sandbox") + **kwargs: Template variables to substitute + + Returns: + Job specification dict ready for Nomad API + """ + # Default job template for sandbox container + if template_name == "sandbox": + return create_sandbox_job(**kwargs) + else: + raise ValueError(f"Unknown template: {template_name}") + + +def create_sandbox_job( + job_id: str = "atropos-sandbox", + image: str = "atropos-sandbox:local", # Use :local tag to avoid registry pull + count: int = 1, + slots_per_container: int = 10, + privileged: bool = False, + cpu: int = 500, + memory: int = 512, + port: int = 8080, + datacenter: str = "dc1", +) -> Dict[str, Any]: + """ + Create a sandbox job specification. + + This job runs the sandbox_server.py inside a Python container, + with the specified number of slots for agent workspaces. + + Args: + job_id: Unique job identifier + image: Docker image to use + count: Number of container instances + slots_per_container: Number of slots per container + privileged: Run container in privileged mode (recommended for bubblewrap) + cpu: CPU allocation in MHz + memory: Memory allocation in MB + port: HTTP port for sandbox server + datacenter: Nomad datacenter + + Returns: + Job specification dict + """ + return { + "ID": job_id, + "Name": job_id, + "Type": "service", + "Datacenters": [datacenter], + "TaskGroups": [ + { + "Name": "sandbox", + "Count": count, + # Speed up deployments and avoid Consul checks. Without this, Nomad may + # keep an "active deployment" around for the default MinHealthyTime, + # which blocks immediate scaling under load. + "Update": { + "HealthCheck": "task_states", + "MinHealthyTime": 0, + }, + "Networks": [ + { + "Mode": "host", + "DynamicPorts": [ + { + "Label": "http", + "To": port, + } + ], + } + ], + "Tasks": [ + { + "Name": "sandbox-server", + "Driver": "docker", + "Config": { + "image": image, + "force_pull": False, # Use local image, don't try to pull + "ports": ["http"], + "privileged": privileged, + "command": "python", + "args": [ + "sandbox_server.py", + "--port", str(port), + "--slots", str(slots_per_container), + "--data-dir", "/data", + ], + # Note: On Linux, you can mount persistent storage: + # "volumes": ["${NOMAD_ALLOC_DIR}/data:/data"], + # On macOS/Docker Desktop, skip volumes for PoC + # (container /data is ephemeral but works for testing) + }, + "Env": { + "PYTHONUNBUFFERED": "1", + "NOMAD_ALLOC_DIR": "${NOMAD_ALLOC_DIR}", + }, + "Resources": { + "CPU": cpu, + "MemoryMB": memory, + }, + # Note: Services with Checks require Consul, which we skip for the PoC + } + ], + "RestartPolicy": { + "Attempts": 3, + "Interval": 300_000_000_000, # 5 minutes + "Delay": 10_000_000_000, # 10 seconds + "Mode": "delay", + }, + "ReschedulePolicy": { + "Attempts": 5, + "Interval": 3600_000_000_000, # 1 hour + "Delay": 30_000_000_000, # 30 seconds + "DelayFunction": "exponential", + "MaxDelay": 300_000_000_000, # 5 minutes + "Unlimited": False, + }, + } + ], + } diff --git a/atropos/sandbox_server.py b/atropos/sandbox_server.py new file mode 100644 index 0000000000..e9635b9622 --- /dev/null +++ b/atropos/sandbox_server.py @@ -0,0 +1,1912 @@ +#!/usr/bin/env python3 +""" +Sandbox Server - HTTP server that runs inside Nomad containers. + +This server handles tool execution requests from the SlotPool/SandboxExecutor. +Each slot has an isolated workspace directory where tools execute. + +Usage (inside container): + python -m atropos_agent.sandbox_server --port 8080 --slots 10 + +API: + POST /execute - Execute a single tool in a slot's workspace + POST /batch - Execute multiple tools in parallel + GET /health - Health check and status + POST /reset - Reset a slot's workspace (clear files) +""" + +import argparse +import asyncio +import base64 +import hashlib +import json +import os +import socket +import shutil +import signal +import subprocess +import tempfile +import tarfile +import uuid +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +from aiohttp import web + + +# Check if bubblewrap is available +def _check_bwrap_available() -> bool: + """Check if bubblewrap (bwrap) is installed and usable for sandboxing.""" + try: + version = subprocess.run( + ["bwrap", "--version"], + capture_output=True, + timeout=5, + ) + if version.returncode != 0: + return False + + # Some environments (e.g. Docker without extra capabilities) have bwrap installed + # but can't create the required namespaces/mounts. Probe a minimal sandbox so we + # can fall back to unsandboxed execution instead of failing all bash tools. + probe = subprocess.run( + [ + "bwrap", + "--unshare-user", + "--unshare-pid", + "--unshare-uts", + "--unshare-ipc", + "--die-with-parent", + "--ro-bind", + "/", + "/", + "--proc", + "/proc", + "--dev", + "/dev", + "--tmpfs", + "/tmp", + "/bin/sh", + "-c", + "true", + ], + capture_output=True, + timeout=5, + ) + return probe.returncode == 0 + except (FileNotFoundError, subprocess.TimeoutExpired): + return False + + +BWRAP_AVAILABLE = _check_bwrap_available() + + +class SlotState(Enum): + """State of a sandbox slot.""" + AVAILABLE = "available" + EXECUTING = "executing" + + +@dataclass +class SlotInfo: + """Information about a slot in this container.""" + slot_id: str + workspace_dir: Path + state: SlotState = SlotState.AVAILABLE + current_execution_id: Optional[str] = None + + +@dataclass +class ExecuteRequest: + """Request to execute a tool in a slot.""" + slot_id: str + tool: str # Tool name: "bash", "bash_stateful", "read_file", "write_file", "tmux" + args: Dict[str, Any] + execution_id: Optional[str] = None # For tracking + timeout: float = 30.0 + + +@dataclass +class ExecuteResponse: + """Response from tool execution.""" + success: bool + output: str = "" + error: str = "" + execution_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "success": self.success, + "output": self.output, + "error": self.error, + "execution_id": self.execution_id, + "metadata": self.metadata, + } + + +@dataclass +class _StatefulTmuxSession: + """ + Per-slot stateful tmux session hosted inside a long-lived bubblewrap sandbox. + + This provides: + - PID namespace isolation (so state doesn't leak across slots) + - Persistent process state across tool calls (tmux session stays alive) + - Kill/cleanup by terminating the bwrap process + """ + + bwrap_proc: Optional[asyncio.subprocess.Process] = None + sock_relpath: str = "tmux.sock" + session_name: str = "s" + pane_width: int = 120 + pane_height: int = 40 + allow_network: bool = True + prev_capture: str = "" + record_relpath: str = ".asciinema.cast" + record_offset: int = 0 + + +class SandboxServer: + """ + HTTP server for tool execution inside a Nomad container. + + Manages multiple slots, each with an isolated workspace directory. + Tools execute within the slot's workspace (via chdir). + """ + + def __init__( + self, + data_dir: str = "/data", + num_slots: int = 10, + max_output_size: int = 50000, + max_file_size: int = 100000, + max_artifact_size: int = 5_000_000, + max_artifact_entries: int = 5_000, + stateful_dir: Optional[str] = None, + ): + self.data_dir = Path(data_dir) + self.num_slots = num_slots + self.max_output_size = max_output_size + self.max_file_size = max_file_size + self.max_artifact_size = max_artifact_size + self.max_artifact_entries = max_artifact_entries + + # Directory for stateful per-slot runtime artifacts (tmux socket, temp files). + # IMPORTANT: this MUST be on a filesystem that supports unix domain sockets. + # + # On macOS + Docker Desktop + Nomad, `${NOMAD_ALLOC_DIR}/data` is often backed by a + # host-shared filesystem that rejects unix sockets (e.g. Errno 95). + # + # TODO(prod): confirm best default on Linux clusters (likely `${NOMAD_ALLOC_DIR}/local/atropos_stateful`) + # and ensure it is backed by node-local disk (not NFS/CSI that may not support unix sockets). + self._stateful_dir = self._choose_stateful_dir(stateful_dir) + + # Initialize slots + self.slots: Dict[str, SlotInfo] = {} + self._bwrap_available = BWRAP_AVAILABLE + self._init_slots() + + # Lock per slot to prevent concurrent execution in same slot + self.slot_locks: Dict[str, asyncio.Lock] = { + slot_id: asyncio.Lock() for slot_id in self.slots + } + + # Per-slot stateful session state (created lazily). + self._stateful_tmux: Dict[str, _StatefulTmuxSession] = {} + + def _choose_stateful_dir(self, stateful_dir: Optional[str]) -> Path: + configured = stateful_dir or os.environ.get("ATROPOS_STATEFUL_DIR") + if configured: + path = Path(configured) + if not self._dir_supports_unix_sockets(path): + raise RuntimeError( + f"Configured stateful_dir={path} does not support unix domain sockets; " + "set ATROPOS_STATEFUL_DIR to a different path." + ) + return path + + candidates: List[Path] = [] + nomad_alloc_dir = os.environ.get("NOMAD_ALLOC_DIR") + if nomad_alloc_dir: + candidates.append(Path(nomad_alloc_dir) / "local" / "atropos_stateful") + candidates.append(Path("/tmp/atropos_stateful")) + + for path in candidates: + if self._dir_supports_unix_sockets(path): + return path + + # Absolute last resort: try /tmp even if the probe failed (should be rare). + fallback = Path("/tmp/atropos_stateful") + fallback.mkdir(parents=True, exist_ok=True) + return fallback + + def _dir_supports_unix_sockets(self, path: Path) -> bool: + try: + path.mkdir(parents=True, exist_ok=True) + except Exception: + return False + + probe_path = path / f".atropos_socket_probe_{uuid.uuid4().hex}.sock" + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + s.bind(str(probe_path)) + except OSError: + return False + finally: + try: + s.close() + except Exception: + pass + try: + probe_path.unlink(missing_ok=True) + except Exception: + pass + + return True + + def _stateful_runtime_dir(self, slot_id: str) -> Path: + # slot_id is always server-generated (slot_0, slot_1, ...) and not user-controlled. + return self._stateful_dir / slot_id + + def _stateful_tmux_sock_path(self, slot_id: str, sess: _StatefulTmuxSession) -> Path: + return self._stateful_runtime_dir(slot_id) / sess.sock_relpath + + def _stateful_tmux_target(self, sess: _StatefulTmuxSession) -> str: + # Be explicit about pane targeting. Some tmux operations behave differently + # when only the session name is provided and there is no attached client. + return f"{sess.session_name}:0.0" + + def _init_slots(self) -> None: + """Initialize slot workspaces.""" + self.data_dir.mkdir(parents=True, exist_ok=True) + + for i in range(self.num_slots): + slot_id = f"slot_{i}" + workspace_dir = self.data_dir / slot_id + workspace_dir.mkdir(parents=True, exist_ok=True) + + self.slots[slot_id] = SlotInfo( + slot_id=slot_id, + workspace_dir=workspace_dir, + state=SlotState.AVAILABLE, + ) + + def _validate_path(self, workspace_dir: Path, path: str) -> Optional[Path]: + """ + Validate and resolve a path within the workspace. + Returns None if path escapes workspace. + """ + try: + workspace_root = workspace_dir.resolve() + full_path = (workspace_root / path).resolve() + if not full_path.is_relative_to(workspace_root): + return None + return full_path + except Exception: + return None + + def _mime_type_from_suffix(self, path: Path) -> str: + suffix = path.suffix.lower() + mapping = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + ".gif": "image/gif", + ".txt": "text/plain", + ".log": "text/plain", + ".json": "application/json", + ".xml": "application/xml", + ".html": "text/html", + ".md": "text/markdown", + ".tar.gz": "application/gzip", + ".tgz": "application/gzip", + } + return mapping.get(suffix, "application/octet-stream") + + def _clamp_int(self, value: Any, *, default: int, minimum: int, maximum: int) -> int: + try: + n = int(value) + except Exception: + return default + return max(minimum, min(maximum, n)) + + def _safe_relpath(self, workspace_dir: Path, path: Path) -> str: + try: + return str(path.resolve().relative_to(workspace_dir.resolve())) + except Exception: + return str(path.name) + + async def artifacts_read( + self, + workspace_dir: Path, + path: str, + *, + encoding: str = "text", + max_bytes: Optional[int] = None, + include_sha256: bool = False, + ) -> web.Response: + resolved = self._validate_path(workspace_dir, path) + if resolved is None: + return web.json_response({"success": False, "error": "Access denied: path outside workspace"}, status=403) + if not resolved.exists(): + return web.json_response({"success": False, "error": f"Not found: {path}"}, status=404) + if not resolved.is_file(): + return web.json_response({"success": False, "error": f"Not a file: {path}"}, status=400) + + enc = (encoding or "text").strip().lower() + if enc not in {"text", "base64"}: + return web.json_response({"success": False, "error": f"Unsupported encoding: {encoding}"}, status=400) + + limit = self.max_artifact_size + if max_bytes is not None: + limit = self._clamp_int(max_bytes, default=limit, minimum=1, maximum=self.max_artifact_size) + + file_size = resolved.stat().st_size + truncated = file_size > limit + + with resolved.open("rb") as f: + data = f.read(limit) + + sha256_hex: Optional[str] = None + if include_sha256: + sha256_hex = hashlib.sha256(data).hexdigest() + + if enc == "text": + content = data.decode("utf-8", errors="replace") + else: + content = base64.b64encode(data).decode("ascii") + + return web.json_response( + { + "success": True, + "content": content, + "encoding": enc, + "truncated": truncated, + "bytes": len(data), + "file_size": file_size, + "path": self._safe_relpath(workspace_dir, resolved), + "mime": self._mime_type_from_suffix(resolved), + "sha256": sha256_hex, + } + ) + + async def artifacts_list( + self, + workspace_dir: Path, + path: str, + *, + recursive: bool = False, + max_entries: Optional[int] = None, + ) -> web.Response: + resolved = self._validate_path(workspace_dir, path) + if resolved is None: + return web.json_response({"success": False, "error": "Access denied: path outside workspace"}, status=403) + if not resolved.exists(): + return web.json_response({"success": False, "error": f"Not found: {path}"}, status=404) + + limit = self.max_artifact_entries + if max_entries is not None: + limit = self._clamp_int(max_entries, default=limit, minimum=1, maximum=self.max_artifact_entries) + + workspace_root = workspace_dir.resolve() + + entries: List[Dict[str, Any]] = [] + + def _maybe_add(p: Path) -> None: + if len(entries) >= limit: + return + try: + rp = p.resolve() + if not rp.is_relative_to(workspace_root): + return + st = rp.stat() + entries.append( + { + "path": str(rp.relative_to(workspace_root)), + "is_dir": rp.is_dir(), + "size": st.st_size, + "mtime": st.st_mtime, + } + ) + except Exception: + return + + if resolved.is_file(): + _maybe_add(resolved) + return web.json_response({"success": True, "entries": entries, "truncated": False}) + + if not resolved.is_dir(): + return web.json_response({"success": False, "error": f"Unsupported path type: {path}"}, status=400) + + if recursive: + for root, dirnames, filenames in os.walk(resolved, followlinks=False): + # Ensure we don't traverse out via tricky paths. + root_path = Path(root) + for name in sorted(dirnames): + _maybe_add(root_path / name) + for name in sorted(filenames): + _maybe_add(root_path / name) + if len(entries) >= limit: + break + else: + for child in sorted(resolved.iterdir()): + _maybe_add(child) + if len(entries) >= limit: + break + + truncated = len(entries) >= limit + return web.json_response({"success": True, "entries": entries, "truncated": truncated}) + + async def artifacts_archive( + self, + workspace_dir: Path, + path: str, + *, + archive_format: str = "tar.gz", + max_bytes: Optional[int] = None, + max_entries: Optional[int] = None, + ) -> web.Response: + fmt = (archive_format or "tar.gz").strip().lower() + if fmt not in {"tar.gz", "tgz"}: + return web.json_response({"success": False, "error": f"Unsupported archive format: {archive_format}"}, status=400) + + resolved = self._validate_path(workspace_dir, path) + if resolved is None: + return web.json_response({"success": False, "error": "Access denied: path outside workspace"}, status=403) + if not resolved.exists(): + return web.json_response({"success": False, "error": f"Not found: {path}"}, status=404) + + byte_limit = self.max_artifact_size + if max_bytes is not None: + byte_limit = self._clamp_int(max_bytes, default=byte_limit, minimum=1, maximum=self.max_artifact_size) + + entry_limit = self.max_artifact_entries + if max_entries is not None: + entry_limit = self._clamp_int(max_entries, default=entry_limit, minimum=1, maximum=self.max_artifact_entries) + + workspace_root = workspace_dir.resolve() + + # Create a temp archive, then enforce size limits. + tmp_path = None + try: + with tempfile.NamedTemporaryFile(prefix="atropos-artifacts-", suffix=".tar.gz", delete=False) as tmp: + tmp_path = Path(tmp.name) + + count = 0 + with tarfile.open(tmp_path, mode="w:gz") as tf: + if resolved.is_file(): + arcname = str(resolved.resolve().relative_to(workspace_root)) + tf.add(resolved, arcname=arcname, recursive=False) + count = 1 + else: + for root, _dirnames, filenames in os.walk(resolved, followlinks=False): + root_path = Path(root) + for name in sorted(filenames): + if count >= entry_limit: + break + fp = root_path / name + try: + rp = fp.resolve() + if not rp.is_relative_to(workspace_root): + continue + if not rp.is_file(): + continue + arcname = str(rp.relative_to(workspace_root)) + tf.add(rp, arcname=arcname, recursive=False) + count += 1 + except Exception: + continue + if count >= entry_limit: + break + + archive_size = tmp_path.stat().st_size if tmp_path.exists() else 0 + if archive_size > byte_limit: + return web.json_response( + { + "success": False, + "error": f"Archive too large: {archive_size} bytes (max {byte_limit})", + "bytes": archive_size, + }, + status=413, + ) + + data = tmp_path.read_bytes() if tmp_path.exists() else b"" + content = base64.b64encode(data).decode("ascii") + return web.json_response( + { + "success": True, + "content": content, + "encoding": "base64", + "format": "tar.gz", + "bytes": len(data), + "entry_count": count, + } + ) + finally: + if tmp_path is not None: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + + def _build_bwrap_command( + self, + workspace_dir: Path, + command: str, + allow_network: bool = True, + isolate_pid: bool = True, + tmp_bind_dir: Optional[Path] = None, + ) -> List[str]: + """ + Build bubblewrap command for isolated execution. + + Creates a sandboxed environment where: + - Only the slot's workspace is visible and writable + - System binaries are read-only + - PID namespace is isolated (can't see other processes) + - User namespace isolates privileges + """ + bwrap_args = [ + "bwrap", + # Namespace isolation + "--unshare-user", # User namespace (root in sandbox) + "--unshare-uts", # UTS namespace (hostname) + "--unshare-ipc", # IPC namespace + "--die-with-parent", # Kill sandbox when parent dies + + # Filesystem: read-only system, writable workspace only + "--ro-bind", "/usr", "/usr", + "--ro-bind", "/bin", "/bin", + ] + + if isolate_pid: + # PID namespace isolation prevents persistent background processes. + # For "stateful" terminal sessions (tmux), we intentionally disable this. + bwrap_args.insert(3, "--unshare-pid") + + # Handle /lib and /lib64 (may be symlinks or directories) + if Path("/lib").exists(): + if Path("/lib").is_symlink(): + bwrap_args.extend(["--symlink", os.readlink("/lib"), "/lib"]) + else: + bwrap_args.extend(["--ro-bind", "/lib", "/lib"]) + + if Path("/lib64").exists(): + if Path("/lib64").is_symlink(): + bwrap_args.extend(["--symlink", os.readlink("/lib64"), "/lib64"]) + else: + bwrap_args.extend(["--ro-bind", "/lib64", "/lib64"]) + + # /etc is needed for things like /etc/passwd, timezone, etc. + bwrap_args.extend(["--ro-bind", "/etc", "/etc"]) + + # Required virtual filesystems + bwrap_args.extend(["--proc", "/proc", "--dev", "/dev"]) + + if tmp_bind_dir is None: + bwrap_args.extend(["--tmpfs", "/tmp"]) + else: + bwrap_args.extend(["--bind", str(tmp_bind_dir.resolve()), "/tmp"]) + + # THE KEY: Only this slot's workspace is visible and writable! + # Map to /workspace inside the sandbox + bwrap_args.extend([ + "--bind", str(workspace_dir.resolve()), "/workspace", + "--chdir", "/workspace", + ]) + + # Network isolation (optional) + if not allow_network: + bwrap_args.append("--unshare-net") + + # Execute the command via sh + bwrap_args.extend(["/bin/sh", "-c", command]) + + return bwrap_args + + # --------------------------------------------------------------------- + # Stateful tmux session helpers + # --------------------------------------------------------------------- + + def _get_stateful_tmux(self, slot_id: str) -> _StatefulTmuxSession: + existing = self._stateful_tmux.get(slot_id) + if existing is not None: + return existing + sess = _StatefulTmuxSession() + self._stateful_tmux[slot_id] = sess + return sess + + async def _run_host_cmd( + self, + argv: List[str], + *, + cwd: Path, + timeout_s: float, + ) -> tuple[int, str, str]: + proc = await asyncio.create_subprocess_exec( + *argv, + cwd=str(cwd), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout_s) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return -1, "", "timeout" + + stdout = (stdout_b or b"").decode("utf-8", errors="replace") + stderr = (stderr_b or b"").decode("utf-8", errors="replace") + return int(proc.returncode or 0), stdout, stderr + + def _stateful_init_script_path(self, workspace_dir: Path) -> Path: + return workspace_dir / ".atropos_stateful_init.sh" + + async def _ensure_stateful_tmux_session( + self, + *, + slot_id: str, + workspace_dir: Path, + allow_network: bool, + require_stateful_sandbox: bool = False, + pane_width: Optional[int] = None, + pane_height: Optional[int] = None, + ) -> _StatefulTmuxSession: + """ + Ensure a per-slot tmux server is running inside a long-lived bwrap sandbox. + + If bubblewrap is unavailable, may fall back to starting tmux directly in the + container's host namespace (less isolated; intended for local/dev only). + """ + if shutil.which("tmux") is None: + raise RuntimeError("tmux is not installed") + if shutil.which("asciinema") is None: + raise RuntimeError("asciinema is not installed") + sess = self._get_stateful_tmux(slot_id) + + if pane_width is not None: + sess.pane_width = self._clamp_int(pane_width, default=sess.pane_width, minimum=20, maximum=500) + if pane_height is not None: + sess.pane_height = self._clamp_int(pane_height, default=sess.pane_height, minimum=10, maximum=200) + + # Restart if network policy changed. + if sess.bwrap_proc is not None and sess.allow_network != bool(allow_network): + await self._stop_stateful_tmux(slot_id) + + sess.allow_network = bool(allow_network) + + runtime_dir = self._stateful_runtime_dir(slot_id) + runtime_dir.mkdir(parents=True, exist_ok=True) + sock_path = self._stateful_tmux_sock_path(slot_id, sess) + record_path = workspace_dir / sess.record_relpath + + # In bwrap mode, the presence of a live bwrap process is the source of truth. + if self._bwrap_available: + if sess.bwrap_proc is not None and sess.bwrap_proc.returncode is None: + return sess + else: + # In host fallback mode, check whether the tmux session already exists. + if require_stateful_sandbox: + raise RuntimeError("Stateful sandboxing required but bubblewrap is unavailable in this environment") + if not allow_network: + raise RuntimeError("Network isolation requested but bubblewrap is unavailable in this environment") + if sock_path.exists(): + rc, _out, _err = await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "has-session", "-t", sess.session_name], + cwd=workspace_dir, + timeout_s=1.0, + ) + if rc == 0: + return sess + + # Cleanup any stale socket/recording file. + try: + sock_path.unlink(missing_ok=True) + except Exception: + pass + try: + record_path.unlink(missing_ok=True) + except Exception: + pass + + # Best-effort: kill any tmux server that might still be attached to this socket. + try: + await self._run_host_cmd(["tmux", "-S", str(sock_path), "kill-server"], cwd=workspace_dir, timeout_s=2.0) + except Exception: + pass + + # Start sandboxed tmux server. + if self._bwrap_available: + init_script = self._stateful_init_script_path(workspace_dir) + init_script.write_text( + "\n".join( + [ + "#!/bin/sh", + "set -eu", + "export TERM=xterm-256color", + "export SHELL=/bin/bash", + f"sock=/tmp/{sess.sock_relpath}", + f"record=/workspace/{sess.record_relpath}", + f"session={sess.session_name!s}", + "rm -f \"$sock\"", + # Start a shell in a new tmux session. + # NOTE: Use `script` to allocate a PTY so tmux behaves consistently + # even when started from a non-interactive background process. + f"script -qc \"tmux -S \\\"$sock\\\" new-session -d -s \\\"$session\\\" -x {sess.pane_width} -y {sess.pane_height} asciinema rec --overwrite -c sh \\\"$record\\\"\" /dev/null", + # Large but bounded history for capture-pane -S - + "tmux -S \"$sock\" set-option -g history-limit 20000 >/dev/null 2>&1 || true", + # Keep PID namespace alive so background processes persist. + "exec tail -f /dev/null", + "", + ] + ), + encoding="utf-8", + ) + init_script.chmod(0o700) + + bwrap_cmd = self._build_bwrap_command( + workspace_dir, + "/workspace/.atropos_stateful_init.sh", + allow_network=bool(allow_network), + isolate_pid=True, + tmp_bind_dir=runtime_dir, + ) + + proc = await asyncio.create_subprocess_exec( + *bwrap_cmd, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + sess.bwrap_proc = proc + else: + # Dev fallback: start tmux in host namespace. + rc, _out, _err = await self._run_host_cmd( + [ + "tmux", + "-S", + str(sock_path), + "new-session", + "-d", + "-s", + sess.session_name, + "-x", + str(sess.pane_width), + "-y", + str(sess.pane_height), + "asciinema", + "rec", + "--overwrite", + "-c", + "sh", + str(record_path), + ], + cwd=workspace_dir, + timeout_s=5.0, + ) + if rc != 0: + raise RuntimeError("Failed to start tmux (host fallback)") + sess.bwrap_proc = None + + # Wait for the socket to appear. + for _ in range(200): # ~10s + if sock_path.exists(): + break + if sess.bwrap_proc is not None and sess.bwrap_proc.returncode is not None: + raise RuntimeError("Stateful bwrap process exited during tmux startup") + await asyncio.sleep(0.05) + else: + await self._stop_stateful_tmux(slot_id) + raise RuntimeError("Timed out waiting for tmux socket to appear") + + # Wait for the recorded shell to be ready. We use the asciinema .cast file + # as a simple readiness signal because `tmux new-session -d` can return before + # the pane's command is actually accepting input, causing early `send-keys` + # (especially blocking ones) to hang. + for _ in range(200): # ~10s + try: + if record_path.exists() and record_path.stat().st_size > 0: + # Header is a single JSON object line; once we see any event line + # (starts with '[') we consider the shell "ready enough". + head = record_path.read_text(encoding="utf-8", errors="ignore")[:4096] + if "\n[" in head: + break + except Exception: + pass + await asyncio.sleep(0.05) + + sess.prev_capture = "" + return sess + + async def _stop_stateful_tmux(self, slot_id: str) -> None: + sess = self._stateful_tmux.get(slot_id) + if sess is None: + return + + # Best-effort: ask tmux to stop (works for both bwrap-backed and host fallback sessions). + try: + slot = self.slots.get(slot_id) + if slot is not None: + sock_path = self._stateful_tmux_sock_path(slot_id, sess) + await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "kill-server"], + cwd=slot.workspace_dir, + timeout_s=2.0, + ) + except Exception: + pass + + proc = sess.bwrap_proc + sess.bwrap_proc = None + sess.prev_capture = "" + + if proc is not None and proc.returncode is None: + try: + proc.terminate() + await asyncio.wait_for(proc.wait(), timeout=2.0) + except Exception: + try: + proc.kill() + except Exception: + pass + try: + await asyncio.wait_for(proc.wait(), timeout=2.0) + except Exception: + pass + + # Best-effort cleanup of the per-slot runtime dir (socket/temp files). + try: + runtime_dir = self._stateful_runtime_dir(slot_id) + # Safety: never delete outside our configured base dir. + if runtime_dir.resolve().is_relative_to(self._stateful_dir.resolve()): + shutil.rmtree(runtime_dir, ignore_errors=True) + except Exception: + pass + + async def execute_bash_sandboxed( + self, + workspace_dir: Path, + command: str, + timeout: float = 30.0, + allow_network: bool = True, + isolate_pid: bool = True, + ) -> ExecuteResponse: + """ + Execute a bash command in an isolated namespace via bubblewrap. + + Provides: + - Filesystem isolation: Only /workspace (slot's dir) is writable + - PID isolation: Can't see other processes + - User isolation: Runs as root in sandbox but unprivileged outside + """ + try: + bwrap_cmd = self._build_bwrap_command( + workspace_dir, + command, + allow_network, + isolate_pid=isolate_pid, + ) + + process = await asyncio.create_subprocess_exec( + *bwrap_cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=timeout, + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + return ExecuteResponse( + success=False, + error=f"Command timed out after {timeout}s", + metadata={"exit_code": -1, "timeout": True, "sandboxed": True}, + ) + + stdout_str = stdout.decode("utf-8", errors="replace") + stderr_str = stderr.decode("utf-8", errors="replace") + + # Truncate if too long + if len(stdout_str) > self.max_output_size: + stdout_str = stdout_str[:self.max_output_size] + "\n... (output truncated)" + if len(stderr_str) > self.max_output_size: + stderr_str = stderr_str[:self.max_output_size] + "\n... (output truncated)" + + exit_code = process.returncode + success = exit_code == 0 + + output = stdout_str + if stderr_str: + output = f"{stdout_str}\n[stderr]\n{stderr_str}" if stdout_str else stderr_str + + return ExecuteResponse( + success=success, + output=output.strip(), + error="" if success else f"Exit code: {exit_code}", + metadata={ + "exit_code": exit_code, + "sandboxed": True, + "network_isolated": not allow_network, + }, + ) + + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to execute sandboxed command: {str(e)}", + metadata={"sandboxed": True}, + ) + + async def execute_bash_unsandboxed( + self, + workspace_dir: Path, + command: str, + timeout: float = 30.0 + ) -> ExecuteResponse: + """Execute a bash command without sandboxing (development/fallback mode).""" + try: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(workspace_dir), + start_new_session=True, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=timeout, + ) + except asyncio.TimeoutError: + # Kill the whole process group so shell children don't keep pipes open. + if process.pid is not None: + try: + os.killpg(process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + except Exception: + process.kill() + else: + process.kill() + await process.wait() + return ExecuteResponse( + success=False, + error=f"Command timed out after {timeout}s", + metadata={ + "exit_code": -1, + "timeout": True, + "sandboxed": False, + "network_isolated": False, + }, + ) + + stdout_str = stdout.decode("utf-8", errors="replace") + stderr_str = stderr.decode("utf-8", errors="replace") + + # Truncate if too long + if len(stdout_str) > self.max_output_size: + stdout_str = stdout_str[:self.max_output_size] + "\n... (output truncated)" + if len(stderr_str) > self.max_output_size: + stderr_str = stderr_str[:self.max_output_size] + "\n... (output truncated)" + + exit_code = process.returncode + success = exit_code == 0 + + output = stdout_str + if stderr_str: + output = f"{stdout_str}\n[stderr]\n{stderr_str}" if stdout_str else stderr_str + + return ExecuteResponse( + success=success, + output=output.strip(), + error="" if success else f"Exit code: {exit_code}", + metadata={ + "exit_code": exit_code, + "sandboxed": False, + "network_isolated": False, + }, + ) + + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to execute command: {str(e)}", + metadata={"sandboxed": False}, + ) + + async def execute_bash( + self, + workspace_dir: Path, + command: str, + timeout: float = 30.0, + allow_network: bool = True, + require_sandbox: bool = False, + ) -> ExecuteResponse: + """ + Execute a bash command in the workspace. + + Uses bubblewrap for sandboxed execution if available, + otherwise falls back to unsandboxed execution. + """ + if self._bwrap_available: + result = await self.execute_bash_sandboxed( + workspace_dir, + command, + timeout, + allow_network=allow_network, + ) + if result.success: + return result + + # Bubblewrap can be installed but unusable in some container runtimes. + # If we detect a namespace failure, disable bwrap for the remainder of + # this server lifetime and fall back to unsandboxed execution. + output = f"{result.output}\n{result.error}".lower() + if "no permissions to create new namespace" in output: + self._bwrap_available = False + if require_sandbox: + return ExecuteResponse( + success=False, + error="Sandboxing required but bubblewrap is unavailable in this environment", + metadata={"sandboxed": False, "network_isolated": False, "requires_bwrap": True}, + ) + if not allow_network: + return ExecuteResponse( + success=False, + error="Network isolation requested but bubblewrap is unavailable in this environment", + metadata={"sandboxed": False, "network_isolated": False, "requires_bwrap": True}, + ) + + fallback = await self.execute_bash_unsandboxed(workspace_dir, command, timeout) + fallback.metadata = dict(fallback.metadata) + fallback.metadata["sandbox_fallback"] = True + return fallback + + return result + + if require_sandbox: + return ExecuteResponse( + success=False, + error="Sandboxing required but bubblewrap is unavailable in this environment", + metadata={"sandboxed": False, "network_isolated": False, "requires_bwrap": True}, + ) + + if not allow_network: + return ExecuteResponse( + success=False, + error="Network isolation requested but bubblewrap is unavailable in this environment", + metadata={"sandboxed": False, "network_isolated": False, "requires_bwrap": True}, + ) + + return await self.execute_bash_unsandboxed(workspace_dir, command, timeout) + + async def execute_bash_stateful( + self, + slot_id: str, + workspace_dir: Path, + command: str, + timeout: float = 30.0, + allow_network: bool = True, + require_stateful_sandbox: bool = False, + ) -> ExecuteResponse: + """ + Execute a command in a per-slot stateful terminal session. + + Implementation: + - Start (lazily) a long-lived bwrap sandbox per slot with a tmux server inside. + - Send the command into a tmux-backed shell and wait for completion. + - Capture the pane output and return a diff vs the previous capture. + + This provides both: + - persistent process state across tool calls (good for TUIs / background tasks) + - PID namespace isolation (no global process visibility across slots) + """ + command = (command or "").strip() + if not command: + return ExecuteResponse(success=False, error="Missing command") + + # Ensure the per-slot stateful sandbox exists. + try: + sess = await self._ensure_stateful_tmux_session( + slot_id=slot_id, + workspace_dir=workspace_dir, + allow_network=allow_network, + require_stateful_sandbox=require_stateful_sandbox, + ) + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to start stateful session: {e}", + metadata={"stateful": True, "requires_bwrap": bool(require_stateful_sandbox)}, + ) + + sock_path = self._stateful_tmux_sock_path(slot_id, sess) + + marker = uuid.uuid4().hex + wait_name = f"done_{marker}" + exit_marker = f"__ATROPOS_EXIT_{marker}__" + + # Append a unique completion signal and exit code marker. + # NOTE: This assumes we're at a shell prompt (not a full-screen TUI). + cmd_to_send = f"{command}; echo {exit_marker}:$?; tmux wait-for -S {wait_name}" + + # Send command + rc, _out, err = await self._run_host_cmd( + [ + "tmux", + "-S", + str(sock_path), + "send-keys", + "-t", + self._stateful_tmux_target(sess), + cmd_to_send, + "Enter", + ], + cwd=workspace_dir, + timeout_s=5.0, + ) + if rc != 0: + # Try restarting the session once (e.g. user killed tmux server). + try: + await self._stop_stateful_tmux(slot_id) + sess = await self._ensure_stateful_tmux_session( + slot_id=slot_id, + workspace_dir=workspace_dir, + allow_network=allow_network, + require_stateful_sandbox=require_stateful_sandbox, + ) + sock_path = self._stateful_tmux_sock_path(slot_id, sess) + rc, _out, err = await self._run_host_cmd( + [ + "tmux", + "-S", + str(sock_path), + "send-keys", + "-t", + self._stateful_tmux_target(sess), + cmd_to_send, + "Enter", + ], + cwd=workspace_dir, + timeout_s=5.0, + ) + except Exception: + pass + if rc != 0: + return ExecuteResponse(success=False, error=f"tmux send-keys failed: {err.strip()}") + + # Wait for completion. + rc, _out, err = await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "wait-for", wait_name], + cwd=workspace_dir, + timeout_s=float(timeout), + ) + if rc != 0: + # Timeout or failure: tear down the stateful sandbox to avoid leaking runaway processes. + await self._stop_stateful_tmux(slot_id) + return ExecuteResponse( + success=False, + error=f"Command timed out after {timeout}s", + metadata={"exit_code": -1, "timeout": True, "sandboxed": True, "stateful": True}, + ) + + # Capture output. + rc, capture, err = await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "capture-pane", "-p", "-S", "-", "-t", self._stateful_tmux_target(sess)], + cwd=workspace_dir, + timeout_s=10.0, + ) + if rc != 0: + return ExecuteResponse(success=False, error=f"tmux capture-pane failed: {err.strip()}") + + # Diff against previous capture for this slot. + prev = sess.prev_capture or "" + if prev and capture.startswith(prev): + delta = capture[len(prev) :] + else: + delta = capture + sess.prev_capture = capture + + # Extract exit code from the capture. + exit_code = 0 + marker_idx = capture.rfind(exit_marker) + if marker_idx != -1: + line = capture[marker_idx:].splitlines()[0] + try: + exit_code = int(line.split(":", 1)[1].strip()) + except Exception: + exit_code = 0 + + # Remove the marker line from delta. + cleaned_lines = [] + for line in delta.splitlines(): + if exit_marker in line: + continue + cleaned_lines.append(line) + output = "\n".join(cleaned_lines).strip() + + if len(output) > self.max_output_size: + output = output[: self.max_output_size] + "\n... (output truncated)" + + success = exit_code == 0 + return ExecuteResponse( + success=success, + output=output, + error="" if success else f"Exit code: {exit_code}", + metadata={ + "exit_code": exit_code, + "sandboxed": bool(self._bwrap_available), + "network_isolated": not allow_network, + "stateful": True, + }, + ) + + async def execute_tmux( + self, + slot_id: str, + workspace_dir: Path, + args: Dict[str, Any], + timeout: float, + allow_network: bool, + require_stateful_sandbox: bool = False, + ) -> ExecuteResponse: + """ + Direct tmux session control for TUI-style terminal interactions. + + Supported actions: + - start: ensure session exists (optionally sets pane size) + - send_keys: send keys into the session (optionally block using tmux wait-for) + - capture: capture the current pane contents (or entire history) + - stop: stop the session and tear down the stateful sandbox + """ + action = str(args.get("action") or "capture").strip().lower() + + # Always resolve session state (creates default entry). + sess = self._get_stateful_tmux(slot_id) + sock_path = self._stateful_tmux_sock_path(slot_id, sess) + + if action == "stop": + try: + await self._run_host_cmd(["tmux", "-S", str(sock_path), "kill-server"], cwd=workspace_dir, timeout_s=2.0) + except Exception: + pass + await self._stop_stateful_tmux(slot_id) + try: + sock_path.unlink(missing_ok=True) + except Exception: + pass + self._stateful_tmux.pop(slot_id, None) + return ExecuteResponse(success=True, output="stopped", metadata={"stateful": True}) + + # start/send/capture: ensure session exists + try: + await self._ensure_stateful_tmux_session( + slot_id=slot_id, + workspace_dir=workspace_dir, + allow_network=allow_network, + require_stateful_sandbox=require_stateful_sandbox, + pane_width=args.get("pane_width"), + pane_height=args.get("pane_height"), + ) + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to ensure tmux session: {e}", + metadata={"stateful": True, "requires_bwrap": bool(require_stateful_sandbox)}, + ) + + if action == "start": + sess.record_offset = 0 + return ExecuteResponse( + success=True, + output="started", + metadata={ + "stateful": True, + "pane_width": sess.pane_width, + "pane_height": sess.pane_height, + "socket": sess.sock_relpath, + "session": sess.session_name, + "recording": sess.record_relpath, + }, + ) + + if action == "send_keys": + keys = args.get("keys") + if isinstance(keys, str): + key_list = [keys] + elif isinstance(keys, list) and all(isinstance(k, str) for k in keys): + key_list = list(keys) + else: + return ExecuteResponse(success=False, error="tmux send_keys requires 'keys' as string or list[string]") + + block = bool(args.get("block", False)) + min_wait_s = float(args.get("min_wait_s", 0.0) or 0.0) + max_wait_s = float(args.get("max_wait_s", timeout) or timeout) + + wait_name = None + if block: + # Blocking mode assumes we're at a shell prompt (not a full-screen TUI). + marker = uuid.uuid4().hex + wait_name = f"keys_{marker}" + if key_list: + key_list[-1] = key_list[-1] + f"; tmux wait-for -S {wait_name}" + else: + key_list = [f"tmux wait-for -S {wait_name}"] + if not key_list or key_list[-1] not in {"Enter", "C-m", "KPEnter"}: + key_list.append("Enter") + + rc, _out, err = await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "send-keys", "-t", self._stateful_tmux_target(sess), *key_list], + cwd=workspace_dir, + timeout_s=5.0, + ) + if rc != 0: + return ExecuteResponse(success=False, error=f"tmux send-keys failed: {err.strip()}") + + if block and wait_name: + rc, _out, err = await self._run_host_cmd( + ["tmux", "-S", str(sock_path), "wait-for", wait_name], + cwd=workspace_dir, + timeout_s=max_wait_s, + ) + if rc != 0: + await self._stop_stateful_tmux(slot_id) + return ExecuteResponse(success=False, error=f"tmux blocked keys timed out after {max_wait_s}s") + elif min_wait_s > 0: + await asyncio.sleep(min_wait_s) + + return ExecuteResponse(success=True, output="ok", metadata={"stateful": True, "blocked": block}) + + if action == "capture": + capture_entire = bool(args.get("capture_entire", False)) + tmux_args: List[str] = ["tmux", "-S", str(sock_path), "capture-pane", "-p"] + if capture_entire: + tmux_args.extend(["-S", "-"]) + tmux_args.extend(["-t", self._stateful_tmux_target(sess)]) + + rc, out, err = await self._run_host_cmd( + tmux_args, + cwd=workspace_dir, + timeout_s=min(10.0, max(1.0, float(timeout or 10.0))), + ) + if rc != 0: + return ExecuteResponse(success=False, error=f"tmux capture-pane failed: {err.strip()}") + + if len(out) > self.max_output_size: + out = out[: self.max_output_size] + "\n... (output truncated)" + return ExecuteResponse( + success=True, + output=out, + metadata={"stateful": True, "capture_entire": capture_entire}, + ) + + if action == "stream": + # Streaming: return asciinema .cast lines since last offset. + record_path = workspace_dir / sess.record_relpath + if not record_path.exists(): + return ExecuteResponse( + success=True, + output="", + metadata={"stateful": True, "stream": True, "offset": sess.record_offset, "recording": sess.record_relpath}, + ) + + if bool(args.get("reset", False)): + sess.record_offset = 0 + + max_bytes = self._clamp_int( + args.get("max_bytes", self.max_output_size), + default=self.max_output_size, + minimum=1, + maximum=self.max_output_size * 4, + ) + + try: + with record_path.open("rb") as f: + f.seek(sess.record_offset) + data = f.read(max_bytes) + except Exception as e: + return ExecuteResponse(success=False, error=f"Failed to read asciinema recording: {e}") + + if not data: + return ExecuteResponse( + success=True, + output="", + metadata={"stateful": True, "stream": True, "offset": sess.record_offset, "recording": sess.record_relpath}, + ) + + # Avoid returning partial lines. + last_nl = data.rfind(b"\n") + if last_nl == -1: + return ExecuteResponse( + success=True, + output="", + metadata={"stateful": True, "stream": True, "offset": sess.record_offset, "recording": sess.record_relpath}, + ) + chunk = data[: last_nl + 1] + sess.record_offset += len(chunk) + + out = chunk.decode("utf-8", errors="replace") + return ExecuteResponse( + success=True, + output=out, + metadata={"stateful": True, "stream": True, "offset": sess.record_offset, "recording": sess.record_relpath}, + ) + + return ExecuteResponse(success=False, error=f"Unknown tmux action: {action}") + + async def execute_read_file( + self, + workspace_dir: Path, + path: str + ) -> ExecuteResponse: + """Read a file from the workspace.""" + try: + file_path = self._validate_path(workspace_dir, path) + if file_path is None: + return ExecuteResponse( + success=False, + error="Access denied: path outside workspace", + ) + + if not file_path.exists(): + return ExecuteResponse( + success=False, + error=f"File not found: {path}", + ) + + if not file_path.is_file(): + return ExecuteResponse( + success=False, + error=f"Not a file: {path}", + ) + + size = file_path.stat().st_size + if size > self.max_file_size: + return ExecuteResponse( + success=False, + error=f"File too large: {size} bytes (max {self.max_file_size})", + ) + + content = file_path.read_text(encoding="utf-8", errors="replace") + + return ExecuteResponse( + success=True, + output=content, + metadata={"path": str(file_path), "size": size}, + ) + + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to read file: {str(e)}", + ) + + async def execute_write_file( + self, + workspace_dir: Path, + path: str, + content: str + ) -> ExecuteResponse: + """Write content to a file in the workspace.""" + try: + if len(content) > self.max_file_size: + return ExecuteResponse( + success=False, + error=f"Content too large: {len(content)} bytes (max {self.max_file_size})", + ) + + file_path = self._validate_path(workspace_dir, path) + if file_path is None: + return ExecuteResponse( + success=False, + error="Access denied: path outside workspace", + ) + + # Create parent directories + file_path.parent.mkdir(parents=True, exist_ok=True) + + file_path.write_text(content, encoding="utf-8") + + return ExecuteResponse( + success=True, + output=f"Successfully wrote {len(content)} bytes to {path}", + metadata={"path": str(file_path), "size": len(content)}, + ) + + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to write file: {str(e)}", + ) + + async def execute_tool(self, request: ExecuteRequest) -> ExecuteResponse: + """Execute a tool in a slot's workspace.""" + # Validate slot + if request.slot_id not in self.slots: + return ExecuteResponse( + success=False, + error=f"Unknown slot: {request.slot_id}", + execution_id=request.execution_id, + ) + + slot = self.slots[request.slot_id] + + # Acquire slot lock + async with self.slot_locks[request.slot_id]: + slot.state = SlotState.EXECUTING + slot.current_execution_id = request.execution_id + + try: + # Route to appropriate tool + if request.tool == "bash": + command = request.args.get("command", "") + allow_network = bool(request.args.get("allow_network", True)) + require_sandbox = bool(request.args.get("require_sandbox", False)) + result = await self.execute_bash( + slot.workspace_dir, + command, + request.timeout, + allow_network=allow_network, + require_sandbox=require_sandbox, + ) + elif request.tool == "bash_stateful": + command = request.args.get("command", "") + allow_network = bool(request.args.get("allow_network", True)) + require_stateful_sandbox = bool( + request.args.get("require_stateful_sandbox", request.args.get("require_sandbox", False)) + ) + result = await self.execute_bash_stateful( + request.slot_id, + slot.workspace_dir, + command, + request.timeout, + allow_network=allow_network, + require_stateful_sandbox=require_stateful_sandbox, + ) + elif request.tool == "read_file": + path = request.args.get("path", "") + result = await self.execute_read_file(slot.workspace_dir, path) + elif request.tool == "write_file": + path = request.args.get("path", "") + content = request.args.get("content", "") + result = await self.execute_write_file(slot.workspace_dir, path, content) + elif request.tool == "tmux": + allow_network = bool(request.args.get("allow_network", True)) + require_stateful_sandbox = bool( + request.args.get("require_stateful_sandbox", request.args.get("require_sandbox", False)) + ) + result = await self.execute_tmux( + request.slot_id, + slot.workspace_dir, + request.args, + request.timeout, + allow_network, + require_stateful_sandbox=require_stateful_sandbox, + ) + else: + result = ExecuteResponse( + success=False, + error=f"Unknown tool: {request.tool}", + ) + + result.execution_id = request.execution_id + return result + + finally: + slot.state = SlotState.AVAILABLE + slot.current_execution_id = None + + async def reset_slot(self, slot_id: str) -> ExecuteResponse: + """Reset a slot's workspace (delete all files).""" + if slot_id not in self.slots: + return ExecuteResponse( + success=False, + error=f"Unknown slot: {slot_id}", + ) + + slot = self.slots[slot_id] + + async with self.slot_locks[slot_id]: + try: + # Stop any long-lived per-slot stateful sandbox process. + try: + await self._stop_stateful_tmux(slot_id) + except Exception: + pass + self._stateful_tmux.pop(slot_id, None) + + # Best-effort cleanup of stateful runtime dir (socket/temp files), + # even if the session was never registered (e.g. crash mid-start). + try: + runtime_dir = self._stateful_runtime_dir(slot_id) + if runtime_dir.resolve().is_relative_to(self._stateful_dir.resolve()): + shutil.rmtree(runtime_dir, ignore_errors=True) + except Exception: + pass + + # Remove all contents but keep the directory + for item in slot.workspace_dir.iterdir(): + if item.is_dir(): + shutil.rmtree(item) + else: + item.unlink() + + return ExecuteResponse( + success=True, + output=f"Reset workspace for {slot_id}", + ) + except Exception as e: + return ExecuteResponse( + success=False, + error=f"Failed to reset workspace: {str(e)}", + ) + + # HTTP Handlers + + async def handle_execute(self, request: web.Request) -> web.Response: + """Handle POST /execute.""" + try: + data = await request.json() + + exec_request = ExecuteRequest( + slot_id=data.get("slot_id", ""), + tool=data.get("tool", ""), + args=data.get("args", {}), + execution_id=data.get("execution_id"), + timeout=data.get("timeout", 30.0), + ) + + result = await self.execute_tool(exec_request) + return web.json_response(result.to_dict()) + + except json.JSONDecodeError: + return web.json_response( + {"success": False, "error": "Invalid JSON"}, + status=400, + ) + except Exception as e: + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def handle_batch(self, request: web.Request) -> web.Response: + """Handle POST /batch - execute multiple tools in parallel.""" + try: + data = await request.json() + + if not isinstance(data, list): + return web.json_response( + {"success": False, "error": "Expected array of requests"}, + status=400, + ) + + # Create execution requests + exec_requests = [ + ExecuteRequest( + slot_id=item.get("slot_id", ""), + tool=item.get("tool", ""), + args=item.get("args", {}), + execution_id=item.get("execution_id"), + timeout=item.get("timeout", 30.0), + ) + for item in data + ] + + # Execute in parallel + results = await asyncio.gather( + *[self.execute_tool(req) for req in exec_requests], + return_exceptions=True, + ) + + # Convert results + response_data = [] + for result in results: + if isinstance(result, BaseException): + response_data.append({ + "success": False, + "error": str(result), + }) + elif isinstance(result, ExecuteResponse): + response_data.append(result.to_dict()) + else: + response_data.append({ + "success": False, + "error": f"Unexpected result type: {type(result)}", + }) + + return web.json_response(response_data) + + except json.JSONDecodeError: + return web.json_response( + {"success": False, "error": "Invalid JSON"}, + status=400, + ) + except Exception as e: + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def handle_health(self, request: web.Request) -> web.Response: + """Handle GET /health.""" + available = sum( + 1 for slot in self.slots.values() + if slot.state == SlotState.AVAILABLE + ) + executing = sum( + 1 for slot in self.slots.values() + if slot.state == SlotState.EXECUTING + ) + + return web.json_response({ + "status": "ok", + "slots": self.num_slots, + "available": available, + "executing": executing, + "data_dir": str(self.data_dir), + "bwrap_available": bool(self._bwrap_available), + "stateful_dir": str(self._stateful_dir), + }) + + async def handle_reset(self, request: web.Request) -> web.Response: + """Handle POST /reset - reset a slot's workspace.""" + try: + data = await request.json() + slot_id = data.get("slot_id", "") + + result = await self.reset_slot(slot_id) + return web.json_response(result.to_dict()) + + except json.JSONDecodeError: + return web.json_response( + {"success": False, "error": "Invalid JSON"}, + status=400, + ) + except Exception as e: + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def handle_list_slots(self, request: web.Request) -> web.Response: + """Handle GET /slots - list all slots and their status.""" + slots_info = [] + for slot_id, slot in self.slots.items(): + slots_info.append({ + "slot_id": slot.slot_id, + "state": slot.state.value, + "workspace_dir": str(slot.workspace_dir), + "current_execution_id": slot.current_execution_id, + }) + + return web.json_response({"slots": slots_info}) + + async def handle_artifacts_read(self, request: web.Request) -> web.Response: + try: + data = await request.json() + except json.JSONDecodeError: + return web.json_response({"success": False, "error": "Invalid JSON"}, status=400) + + slot_id = data.get("slot_id", "") + if slot_id not in self.slots: + return web.json_response({"success": False, "error": f"Unknown slot: {slot_id}"}, status=404) + + path = data.get("path", "") + encoding = data.get("encoding", "text") + max_bytes = data.get("max_bytes") + include_sha256 = bool(data.get("include_sha256", False)) + + slot = self.slots[slot_id] + async with self.slot_locks[slot_id]: + return await self.artifacts_read( + slot.workspace_dir, + path, + encoding=encoding, + max_bytes=max_bytes, + include_sha256=include_sha256, + ) + + async def handle_artifacts_list(self, request: web.Request) -> web.Response: + try: + data = await request.json() + except json.JSONDecodeError: + return web.json_response({"success": False, "error": "Invalid JSON"}, status=400) + + slot_id = data.get("slot_id", "") + if slot_id not in self.slots: + return web.json_response({"success": False, "error": f"Unknown slot: {slot_id}"}, status=404) + + path = data.get("path", ".") + recursive = bool(data.get("recursive", False)) + max_entries = data.get("max_entries") + + slot = self.slots[slot_id] + async with self.slot_locks[slot_id]: + return await self.artifacts_list(slot.workspace_dir, path, recursive=recursive, max_entries=max_entries) + + async def handle_artifacts_archive(self, request: web.Request) -> web.Response: + try: + data = await request.json() + except json.JSONDecodeError: + return web.json_response({"success": False, "error": "Invalid JSON"}, status=400) + + slot_id = data.get("slot_id", "") + if slot_id not in self.slots: + return web.json_response({"success": False, "error": f"Unknown slot: {slot_id}"}, status=404) + + path = data.get("path", ".") + archive_format = data.get("format", "tar.gz") + max_bytes = data.get("max_bytes") + max_entries = data.get("max_entries") + + slot = self.slots[slot_id] + async with self.slot_locks[slot_id]: + return await self.artifacts_archive( + slot.workspace_dir, + path, + archive_format=archive_format, + max_bytes=max_bytes, + max_entries=max_entries, + ) + + def create_app(self) -> web.Application: + """Create the aiohttp application.""" + app = web.Application() + + app.router.add_post("/execute", self.handle_execute) + app.router.add_post("/batch", self.handle_batch) + app.router.add_get("/health", self.handle_health) + app.router.add_post("/reset", self.handle_reset) + app.router.add_get("/slots", self.handle_list_slots) + app.router.add_post("/artifacts/read", self.handle_artifacts_read) + app.router.add_post("/artifacts/list", self.handle_artifacts_list) + app.router.add_post("/artifacts/archive", self.handle_artifacts_archive) + + return app + + +def main(): + """Run the sandbox server.""" + parser = argparse.ArgumentParser(description="Sandbox Server for Nomad containers") + parser.add_argument("--port", type=int, default=8080, help="Port to listen on") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--slots", type=int, default=10, help="Number of slots") + parser.add_argument( + "--data-dir", + type=str, + default=os.environ.get("NOMAD_ALLOC_DIR", "/data"), + help="Base directory for slot workspaces" + ) + + args = parser.parse_args() + + # If in Nomad, use alloc dir + data_dir = args.data_dir + if os.environ.get("NOMAD_ALLOC_DIR"): + data_dir = os.path.join(os.environ["NOMAD_ALLOC_DIR"], "data") + + print(f"Starting Sandbox Server on {args.host}:{args.port}") + print(f"Data directory: {data_dir}") + print(f"Number of slots: {args.slots}") + + server = SandboxServer( + data_dir=data_dir, + num_slots=args.slots, + ) + + app = server.create_app() + web.run_app(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/atropos/slots/__init__.py b/atropos/slots/__init__.py new file mode 100644 index 0000000000..9692822740 --- /dev/null +++ b/atropos/slots/__init__.py @@ -0,0 +1,20 @@ +""" +Slot-based multiplexing for atropos-agent. + +Provides: +- Slot: Isolated workspace for a single trajectory +- SlotPool: Manages slots across Nomad allocations +- SandboxExecutor: Executes tools in sandbox containers +""" + +from .executor import SandboxExecutor +from .pool import SlotPool, SlotPoolConfig +from .slot import Slot, SlotState + +__all__ = [ + "Slot", + "SlotState", + "SlotPool", + "SlotPoolConfig", + "SandboxExecutor", +] diff --git a/atropos/slots/executor.py b/atropos/slots/executor.py new file mode 100644 index 0000000000..56e2fd3934 --- /dev/null +++ b/atropos/slots/executor.py @@ -0,0 +1,457 @@ +""" +SandboxExecutor - HTTP client for sandbox container communication. + +Sends tool execution requests to sandbox_server.py running inside Nomad containers. +Supports single and batch execution for efficiency. +""" + +import asyncio +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import aiohttp + +from .slot import Slot, SlotState +from ..tools.base import ToolCall, ToolResult + + +@dataclass +class ExecutionRequest: + """Request to execute a tool in a slot.""" + slot: Slot + tool_name: str + args: Dict[str, Any] + execution_id: str = field(default_factory=lambda: str(uuid.uuid4())) + timeout: float = 30.0 + + +@dataclass +class ExecutionResult: + """Result from sandbox execution.""" + success: bool + output: str = "" + error: str = "" + execution_id: str = "" + slot_id: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_tool_result(self) -> ToolResult: + """Convert to ToolResult for agent consumption.""" + return ToolResult( + success=self.success, + output=self.output, + error=self.error, + metadata=self.metadata, + uniq_id=self.execution_id, + ) + + +class SandboxExecutor: + """ + HTTP client for executing tools in sandbox containers. + + Communicates with sandbox_server.py running inside Nomad allocations. + Supports both single execution and batched parallel execution. + + Usage: + executor = SandboxExecutor() + + # Single execution + result = await executor.execute(slot, "bash", {"command": "ls"}) + + # Batch execution + results = await executor.execute_batch([ + (slot1, "bash", {"command": "ls"}), + (slot2, "write_file", {"path": "test.txt", "content": "hello"}), + ]) + """ + + def __init__( + self, + timeout: float = 30.0, + max_retries: int = 3, + retry_delay: float = 1.0, + ): + self.timeout = aiohttp.ClientTimeout(total=timeout) + self.max_retries = max_retries + self.retry_delay = retry_delay + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create HTTP session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(timeout=self.timeout) + return self._session + + async def close(self): + """Close HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def execute( + self, + slot: Slot, + tool_name: str, + args: Dict[str, Any], + timeout: Optional[float] = None, + ) -> ExecutionResult: + """ + Execute a tool in a slot's workspace. + + Args: + slot: Slot to execute in + tool_name: Name of tool (bash, read_file, write_file) + args: Tool arguments + timeout: Optional timeout override + + Returns: + ExecutionResult with output or error + """ + execution_id = str(uuid.uuid4()) + exec_timeout = timeout or self.timeout.total or 30.0 + + # Mark slot as executing + original_state = slot.state + try: + if slot.state == SlotState.ACQUIRED: + slot.start_execution(execution_id) + + result = await self._send_execute_request( + container_addr=slot.container_addr, + slot_id=slot.slot_id, + tool_name=tool_name, + args=args, + execution_id=execution_id, + timeout=exec_timeout, + ) + result.slot_id = slot.slot_id + return result + + finally: + # Restore slot state + if slot.state == SlotState.EXECUTING: + slot.end_execution() + + async def _send_execute_request( + self, + container_addr: str, + slot_id: str, + tool_name: str, + args: Dict[str, Any], + execution_id: str, + timeout: float, + ) -> ExecutionResult: + """Send execution request to sandbox server with retry logic.""" + session = await self._get_session() + url = f"{container_addr}/execute" + + payload = { + "slot_id": slot_id, + "tool": tool_name, + "args": args, + "execution_id": execution_id, + "timeout": timeout, + } + + last_error = None + for attempt in range(self.max_retries): + try: + async with session.post(url, json=payload) as response: + data = await response.json() + + return ExecutionResult( + success=data.get("success", False), + output=data.get("output", ""), + error=data.get("error", ""), + execution_id=data.get("execution_id", execution_id), + metadata=data.get("metadata", {}), + ) + + except aiohttp.ClientError as e: + last_error = str(e) + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay * (attempt + 1)) + continue + except asyncio.TimeoutError: + last_error = f"Request timed out after {timeout}s" + break + except Exception as e: + last_error = str(e) + break + + return ExecutionResult( + success=False, + error=f"Failed after {self.max_retries} attempts: {last_error}", + execution_id=execution_id, + ) + + async def execute_batch( + self, + requests: List[Tuple[Slot, str, Dict[str, Any]]], + timeout: Optional[float] = None, + ) -> List[ExecutionResult]: + """ + Execute multiple tools in parallel across slots. + + This is the key optimization - we batch tool calls to maximize + container utilization while agents are waiting for LLM responses. + + Args: + requests: List of (slot, tool_name, args) tuples + timeout: Optional timeout override + + Returns: + List of ExecutionResults in same order as requests + """ + if not requests: + return [] + + # Group requests by container address for batch API + by_container: Dict[str, List[Tuple[int, Slot, str, Dict[str, Any], str]]] = {} + + for idx, (slot, tool_name, args) in enumerate(requests): + execution_id = str(uuid.uuid4()) + container = slot.container_addr + + if container not in by_container: + by_container[container] = [] + by_container[container].append((idx, slot, tool_name, args, execution_id)) + + # Mark slots as executing + if slot.state == SlotState.ACQUIRED: + slot.start_execution(execution_id) + + # Execute batches in parallel + exec_timeout = timeout or self.timeout.total or 30.0 + batch_tasks = [] + + for container_addr, batch_requests in by_container.items(): + task = self._send_batch_request( + container_addr=container_addr, + batch_requests=batch_requests, + timeout=exec_timeout, + ) + batch_tasks.append(task) + + # Gather all batch results + batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) + + # Collect results in original order + results: List[Optional[ExecutionResult]] = [None] * len(requests) + + for batch_result in batch_results: + if isinstance(batch_result, Exception): + # Mark all in this batch as failed + continue + + for idx, result in batch_result: + results[idx] = result + + # Fill in any missing results + for idx, result in enumerate(results): + if result is None: + slot, tool_name, args = requests[idx] + results[idx] = ExecutionResult( + success=False, + error="Batch execution failed", + slot_id=slot.slot_id, + ) + + # End execution on all slots + for slot, _, _ in requests: + if slot.state == SlotState.EXECUTING: + slot.end_execution() + + return results # type: ignore + + async def _send_batch_request( + self, + container_addr: str, + batch_requests: List[Tuple[int, Slot, str, Dict[str, Any], str]], + timeout: float, + ) -> List[Tuple[int, ExecutionResult]]: + """Send batch execution request to a single container.""" + session = await self._get_session() + url = f"{container_addr}/batch" + + # Build batch payload + payload = [ + { + "slot_id": slot.slot_id, + "tool": tool_name, + "args": args, + "execution_id": execution_id, + "timeout": timeout, + } + for _, slot, tool_name, args, execution_id in batch_requests + ] + + try: + async with session.post(url, json=payload) as response: + data = await response.json() + + if not isinstance(data, list): + raise ValueError(f"Expected list response, got {type(data)}") + + results = [] + for i, (idx, slot, _, _, execution_id) in enumerate(batch_requests): + if i < len(data): + item = data[i] + result = ExecutionResult( + success=item.get("success", False), + output=item.get("output", ""), + error=item.get("error", ""), + execution_id=item.get("execution_id", execution_id), + slot_id=slot.slot_id, + metadata=item.get("metadata", {}), + ) + else: + result = ExecutionResult( + success=False, + error="Missing result in batch response", + execution_id=execution_id, + slot_id=slot.slot_id, + ) + results.append((idx, result)) + + return results + + except Exception as e: + # Return error for all requests in batch + return [ + (idx, ExecutionResult( + success=False, + error=str(e), + execution_id=execution_id, + slot_id=slot.slot_id, + )) + for idx, slot, _, _, execution_id in batch_requests + ] + + async def reset_slot(self, slot: Slot) -> ExecutionResult: + """ + Reset a slot's workspace (delete all files). + + Useful when reusing a slot for a new trajectory. + """ + session = await self._get_session() + url = f"{slot.container_addr}/reset" + + try: + async with session.post(url, json={"slot_id": slot.slot_id}) as response: + data = await response.json() + return ExecutionResult( + success=data.get("success", False), + output=data.get("output", ""), + error=data.get("error", ""), + slot_id=slot.slot_id, + ) + except Exception as e: + return ExecutionResult( + success=False, + error=str(e), + slot_id=slot.slot_id, + ) + + async def health_check(self, container_addr: str) -> bool: + """Check if a sandbox container is healthy.""" + session = await self._get_session() + url = f"{container_addr}/health" + + try: + async with session.get(url) as response: + data = await response.json() + return data.get("status") == "ok" + except Exception: + return False + + async def get_container_status( + self, + container_addr: str + ) -> Optional[Dict[str, Any]]: + """Get status info from a sandbox container.""" + session = await self._get_session() + url = f"{container_addr}/health" + + try: + async with session.get(url) as response: + return await response.json() + except Exception: + return None + + # ------------------------------------------------------------------------- + # Artifact helpers (optional) + # ------------------------------------------------------------------------- + + async def _post_json( + self, + url: str, + payload: Dict[str, Any], + timeout: Optional[float] = None, + ) -> Dict[str, Any]: + session = await self._get_session() + try: + async with session.post(url, json=payload, timeout=timeout) as response: + data = await response.json() + if isinstance(data, dict): + data.setdefault("http_status", response.status) + return data + return {"success": False, "error": f"Unexpected response type: {type(data)}", "http_status": response.status} + except Exception as e: + return {"success": False, "error": str(e)} + + async def read_artifact( + self, + slot: Slot, + path: str, + *, + encoding: str = "text", + max_bytes: Optional[int] = None, + include_sha256: bool = False, + timeout: Optional[float] = None, + ) -> Dict[str, Any]: + url = f"{slot.container_addr}/artifacts/read" + payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "encoding": encoding, "include_sha256": include_sha256} + if max_bytes is not None: + payload["max_bytes"] = max_bytes + return await self._post_json(url, payload, timeout=timeout) + + async def list_artifacts( + self, + slot: Slot, + path: str = ".", + *, + recursive: bool = False, + max_entries: Optional[int] = None, + timeout: Optional[float] = None, + ) -> Dict[str, Any]: + url = f"{slot.container_addr}/artifacts/list" + payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "recursive": recursive} + if max_entries is not None: + payload["max_entries"] = max_entries + return await self._post_json(url, payload, timeout=timeout) + + async def archive_artifacts( + self, + slot: Slot, + path: str = ".", + *, + archive_format: str = "tar.gz", + max_bytes: Optional[int] = None, + max_entries: Optional[int] = None, + timeout: Optional[float] = None, + ) -> Dict[str, Any]: + url = f"{slot.container_addr}/artifacts/archive" + payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "format": archive_format} + if max_bytes is not None: + payload["max_bytes"] = max_bytes + if max_entries is not None: + payload["max_entries"] = max_entries + return await self._post_json(url, payload, timeout=timeout) diff --git a/atropos/slots/pool.py b/atropos/slots/pool.py new file mode 100644 index 0000000000..ba7eb683b3 --- /dev/null +++ b/atropos/slots/pool.py @@ -0,0 +1,484 @@ +""" +SlotPool - Manages slots across Nomad allocations. + +The SlotPool is the core abstraction for slot-based multiplexing: +- Tracks available/acquired slots across containers +- Handles slot acquisition and release +- Auto-scales Nomad job count based on demand +- Provides batched tool execution +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from ..nomad.client import ( + Allocation, + AllocationStatus, + NomadClient, + create_sandbox_job, +) +from .executor import ExecutionResult, SandboxExecutor +from .slot import Slot, SlotState, create_slots_for_allocation + +logger = logging.getLogger(__name__) + + +@dataclass +class SlotPoolConfig: + """Configuration for SlotPool.""" + + # Nomad settings + nomad_address: str = "http://localhost:4646" + job_id: str = "atropos-sandbox" + datacenter: str = "dc1" + + # Container settings + image: str = "atropos-sandbox:local" # Use :local tag to avoid registry pull + slots_per_container: int = 10 + privileged: bool = False + cpu: int = 500 # MHz + memory: int = 512 # MB + + # Scaling settings + min_containers: int = 1 + max_containers: int = 10 + + # Timeouts + acquire_timeout: float = 30.0 # Seconds between acquire polls (also triggers scale-up attempts) + health_check_interval: float = 30.0 # Seconds between health checks + scale_cooldown: float = 60.0 # Seconds between scale operations + + +class SlotPool: + """ + Manages a pool of slots across Nomad allocations. + + The SlotPool: + - Deploys sandbox containers to Nomad + - Tracks slots across all running containers + - Handles slot acquisition/release + - Auto-scales based on demand + - Provides batched execution via SandboxExecutor + + Usage: + config = SlotPoolConfig( + nomad_address="http://localhost:4646", + job_id="my-sandbox", + slots_per_container=10, + ) + + pool = SlotPool(config) + await pool.start() + + # Acquire a slot + slot = await pool.acquire() + + # Execute tool + result = await pool.execute(slot, "bash", {"command": "ls"}) + + # Release slot + await pool.release(slot) + + # Shutdown + await pool.stop() + """ + + def __init__(self, config: Optional[SlotPoolConfig] = None): + self.config = config or SlotPoolConfig() + + # Nomad client + self.nomad = NomadClient(address=self.config.nomad_address) + + # Sandbox executor for tool execution + self.executor = SandboxExecutor() + + # Slot tracking + self._slots: Dict[str, Slot] = {} # slot_key -> Slot + self._available_queue: asyncio.Queue[str] = asyncio.Queue() + self._lock = asyncio.Lock() + self._scale_lock = asyncio.Lock() + + # State + self._started = False + self._health_task: Optional[asyncio.Task] = None + self._scale_task: Optional[asyncio.Task] = None + self._last_scale_time = 0.0 + + def _slot_key(self, alloc_id: str, slot_id: str) -> str: + """Generate unique key for a slot.""" + return f"{alloc_id}:{slot_id}" + + @property + def total_slots(self) -> int: + """Total number of slots in pool.""" + return len(self._slots) + + @property + def available_slots(self) -> int: + """Number of available slots.""" + return sum(1 for s in self._slots.values() if s.is_available) + + @property + def acquired_slots(self) -> int: + """Number of acquired slots.""" + return sum(1 for s in self._slots.values() if s.is_acquired) + + async def start(self) -> None: + """ + Start the slot pool. + + - Checks if Nomad is healthy + - Deploys sandbox job if not running + - Discovers existing allocations + - Starts health check background task + """ + if self._started: + return + + logger.info(f"Starting SlotPool (job_id={self.config.job_id})") + + # Check Nomad health + if not await self.nomad.is_healthy(): + raise RuntimeError(f"Nomad is not reachable at {self.config.nomad_address}") + + # Check if job exists + job = await self.nomad.get_job(self.config.job_id) + + if job is None: + # Deploy new job + logger.info(f"Deploying sandbox job: {self.config.job_id}") + job_spec = create_sandbox_job( + job_id=self.config.job_id, + image=self.config.image, + count=self.config.min_containers, + slots_per_container=self.config.slots_per_container, + privileged=self.config.privileged, + cpu=self.config.cpu, + memory=self.config.memory, + datacenter=self.config.datacenter, + ) + result = await self.nomad.submit_job(job_spec) + if "error" in result: + raise RuntimeError(f"Failed to submit job: {result}") + + # Wait for allocations to be running + await self._wait_for_healthy_allocations(self.config.min_containers) + + # Discover existing allocations and slots + await self._refresh_slots() + + # Start health check task + self._health_task = asyncio.create_task(self._health_check_loop()) + + self._started = True + logger.info(f"SlotPool started: {self.total_slots} slots available") + + async def stop(self, purge_job: bool = False) -> None: + """ + Stop the slot pool. + + Args: + purge_job: If True, also stop the Nomad job + """ + logger.info("Stopping SlotPool") + + # Cancel health check task + if self._health_task: + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + pass + finally: + self._health_task = None + + if self._scale_task: + self._scale_task.cancel() + try: + await self._scale_task + except asyncio.CancelledError: + pass + finally: + self._scale_task = None + + # Optionally stop the job (do this even if start() never completed). + if purge_job: + logger.info(f"Stopping Nomad job: {self.config.job_id}") + await self.nomad.stop_job(self.config.job_id, purge=True) + + # Close connections + await self.executor.close() + await self.nomad.close() + + self._started = False + self._slots.clear() + + # Clear the queue + while not self._available_queue.empty(): + try: + self._available_queue.get_nowait() + except asyncio.QueueEmpty: + break + + async def acquire(self, trajectory_id: Optional[str] = None) -> Slot: + """ + Acquire an available slot. + + If no slots are available, waits up to acquire_timeout seconds. + If still no slots, attempts to scale up. + + Args: + trajectory_id: Optional ID of trajectory acquiring the slot + + Returns: + Acquired Slot + + Raises: + asyncio.TimeoutError: If no slot becomes available + """ + if not self._started: + raise RuntimeError("SlotPool not started") + + while True: + try: + # Try to get an available slot + slot_key = await asyncio.wait_for( + self._available_queue.get(), + timeout=self.config.acquire_timeout, + ) + except asyncio.TimeoutError: + # Try to scale up, but keep waiting even if scaling isn't possible. + # In practice, slots may become available shortly (e.g. contention), + # and scaling may be temporarily blocked by Nomad deployments. + await self._try_scale_up() + continue + + slot = self._slots.get(slot_key) + if slot is None: + # Slot was removed; discard stale queue entry and retry. + continue + + try: + slot.acquire(trajectory_id) + except RuntimeError: + # Slot isn't actually available (e.g. duplicate queue entry); retry. + continue + + logger.debug(f"Acquired slot {slot.slot_id} (alloc={slot.alloc_id[:8]})") + return slot + + async def release(self, slot: Slot, reset_workspace: bool = False) -> None: + """ + Release a slot back to the pool. + + Args: + slot: Slot to release + reset_workspace: If True, clear the workspace files + """ + slot_key = self._slot_key(slot.alloc_id, slot.slot_id) + + if slot_key not in self._slots: + logger.warning(f"Releasing unknown slot: {slot_key}") + return + + # Optionally reset workspace + if reset_workspace: + await self.executor.reset_slot(slot) + + slot.release() + await self._available_queue.put(slot_key) + + logger.debug(f"Released slot {slot.slot_id}") + + async def execute( + self, + slot: Slot, + tool_name: str, + args: Dict[str, Any], + timeout: Optional[float] = None, + ) -> ExecutionResult: + """ + Execute a tool in a slot's workspace. + + Args: + slot: Slot to execute in + tool_name: Name of tool (bash, read_file, write_file) + args: Tool arguments + timeout: Optional timeout override + + Returns: + ExecutionResult + """ + return await self.executor.execute(slot, tool_name, args, timeout) + + async def execute_batch( + self, + requests: List[Tuple[Slot, str, Dict[str, Any]]], + timeout: Optional[float] = None, + ) -> List[ExecutionResult]: + """ + Execute multiple tools in parallel. + + This is the key optimization - batch execution across multiple slots + maximizes container utilization. + + Args: + requests: List of (slot, tool_name, args) tuples + timeout: Optional timeout override + + Returns: + List of ExecutionResults in same order + """ + return await self.executor.execute_batch(requests, timeout) + + async def _refresh_slots(self) -> None: + """Refresh slot inventory from Nomad allocations.""" + async with self._lock: + allocs = await self.nomad.get_job_allocations(self.config.job_id) + + # Track which slots we've seen + seen_keys = set() + + for alloc in allocs: + if alloc.status != AllocationStatus.RUNNING: + continue + + if not alloc.http_address: + continue + + # Check container health + healthy = await self.executor.health_check(alloc.http_address) + if not healthy: + continue + + # Create slots for this allocation + for i in range(self.config.slots_per_container): + slot_id = f"slot_{i}" + slot_key = self._slot_key(alloc.id, slot_id) + seen_keys.add(slot_key) + + if slot_key not in self._slots: + # New slot + slot = Slot( + slot_id=slot_id, + alloc_id=alloc.id, + container_addr=alloc.http_address, + ) + self._slots[slot_key] = slot + await self._available_queue.put(slot_key) + logger.debug(f"Added slot: {slot_key}") + + # Remove slots from dead allocations + for slot_key in list(self._slots.keys()): + if slot_key not in seen_keys: + slot = self._slots.pop(slot_key) + logger.debug(f"Removed slot: {slot_key}") + + async def _wait_for_healthy_allocations( + self, + min_count: int, + timeout: float = 120.0 + ) -> None: + """Wait for allocations to become healthy.""" + import time + start = time.time() + + while time.time() - start < timeout: + allocs = await self.nomad.get_job_allocations(self.config.job_id) + + healthy_count = 0 + for alloc in allocs: + if alloc.status == AllocationStatus.RUNNING and alloc.http_address: + if await self.executor.health_check(alloc.http_address): + healthy_count += 1 + + if healthy_count >= min_count: + return + + await asyncio.sleep(2.0) + + raise RuntimeError(f"Timed out waiting for {min_count} healthy allocations") + + async def _try_scale_up(self) -> bool: + """Attempt to scale up the job.""" + import time + + async with self._scale_lock: + # Check cooldown + if time.time() - self._last_scale_time < self.config.scale_cooldown: + return False + + # Check max containers + status = await self.nomad.get_job_status(self.config.job_id) + if status is None: + return False + + current_count = status.count + if current_count >= self.config.max_containers: + logger.warning(f"Cannot scale up: already at max ({self.config.max_containers})") + return False + + # Scale up + new_count = min(current_count + 1, self.config.max_containers) + logger.info(f"Scaling up from {current_count} to {new_count} containers") + + scale_resp = await self.nomad.scale_job( + self.config.job_id, + count=new_count, + task_group="sandbox", + ) + + # Nomad may return non-JSON errors (e.g. plain text) with a status field. + if isinstance(scale_resp, dict) and scale_resp.get("status", 200) >= 400: + logger.warning(f"Scale request rejected: {scale_resp}") + self._last_scale_time = time.time() + return False + + self._last_scale_time = time.time() + + # Wait for new allocation in the background so contended acquires can still + # make progress (e.g. by grabbing slots released by other trajectories). + if self._scale_task is None or self._scale_task.done(): + self._scale_task = asyncio.create_task(self._wait_for_scale(new_count)) + + return True + + async def _wait_for_scale(self, desired_count: int) -> None: + try: + await self._wait_for_healthy_allocations(desired_count, timeout=60.0) + await self._refresh_slots() + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Failed to scale up: {e}") + + async def _health_check_loop(self) -> None: + """Background task to monitor container health.""" + while True: + try: + await asyncio.sleep(self.config.health_check_interval) + await self._refresh_slots() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Health check error: {e}") + + def get_stats(self) -> Dict[str, Any]: + """Get pool statistics.""" + slots_by_state = {} + for slot in self._slots.values(): + state = slot.state.value + slots_by_state[state] = slots_by_state.get(state, 0) + 1 + + container_count = len({s.alloc_id for s in self._slots.values()}) if self._slots else 0 + + return { + "total_slots": self.total_slots, + "available_slots": self.available_slots, + "acquired_slots": self.acquired_slots, + "containers": container_count, + "slots_by_state": slots_by_state, + "started": self._started, + } diff --git a/atropos/slots/slot.py b/atropos/slots/slot.py new file mode 100644 index 0000000000..fb7c6747c7 --- /dev/null +++ b/atropos/slots/slot.py @@ -0,0 +1,159 @@ +""" +Slot abstraction for atropos-agent. + +A Slot represents an isolated workspace for a single agent trajectory. +Slots are hosted on Nomad allocations and provide workspace isolation +via filesystem directories. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Optional +import uuid + + +class SlotState(Enum): + """State of a slot in the pool.""" + AVAILABLE = "available" # Ready to be acquired + ACQUIRED = "acquired" # Assigned to a trajectory + EXECUTING = "executing" # Currently executing a tool + RELEASING = "releasing" # Being released back to pool + ERROR = "error" # In error state + + +@dataclass +class Slot: + """ + An isolated workspace for a single agent trajectory. + + Slots are the unit of scheduling - each trajectory runs in its own slot, + with an isolated workspace directory. Multiple slots share a container. + + Attributes: + slot_id: Unique identifier for this slot (e.g., "slot_0") + alloc_id: Nomad allocation ID hosting this slot + container_addr: HTTP address of the sandbox server (e.g., "http://10.0.0.1:8080") + workspace_dir: Path to workspace in container (e.g., "/data/slot_0") + state: Current state of the slot + trajectory_id: ID of trajectory currently using this slot (if acquired) + metadata: Additional metadata + """ + slot_id: str + alloc_id: str + container_addr: str + workspace_dir: str = "" + state: SlotState = SlotState.AVAILABLE + trajectory_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Set default workspace_dir if not provided.""" + if not self.workspace_dir: + self.workspace_dir = f"/data/{self.slot_id}" + + @property + def is_available(self) -> bool: + """Check if slot is available for acquisition.""" + return self.state == SlotState.AVAILABLE + + @property + def is_acquired(self) -> bool: + """Check if slot is currently acquired.""" + return self.state in (SlotState.ACQUIRED, SlotState.EXECUTING) + + def acquire(self, trajectory_id: Optional[str] = None) -> None: + """ + Mark slot as acquired by a trajectory. + + Args: + trajectory_id: Optional ID of acquiring trajectory + """ + if not self.is_available: + raise RuntimeError(f"Cannot acquire slot {self.slot_id}: state is {self.state}") + + self.state = SlotState.ACQUIRED + self.trajectory_id = trajectory_id or str(uuid.uuid4()) + + def start_execution(self, execution_id: Optional[str] = None) -> None: + """Mark slot as executing.""" + if self.state != SlotState.ACQUIRED: + raise RuntimeError(f"Cannot start execution on slot {self.slot_id}: state is {self.state}") + + self.state = SlotState.EXECUTING + if execution_id: + self.metadata["current_execution_id"] = execution_id + + def end_execution(self) -> None: + """Mark execution as complete, return to acquired state.""" + if self.state != SlotState.EXECUTING: + raise RuntimeError(f"Cannot end execution on slot {self.slot_id}: state is {self.state}") + + self.state = SlotState.ACQUIRED + self.metadata.pop("current_execution_id", None) + + def release(self) -> None: + """Release slot back to available state.""" + self.state = SlotState.AVAILABLE + self.trajectory_id = None + self.metadata.pop("current_execution_id", None) + + def mark_error(self, error: str) -> None: + """Mark slot as in error state.""" + self.state = SlotState.ERROR + self.metadata["error"] = error + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "slot_id": self.slot_id, + "alloc_id": self.alloc_id, + "container_addr": self.container_addr, + "workspace_dir": self.workspace_dir, + "state": self.state.value, + "trajectory_id": self.trajectory_id, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Slot": + """Create from dictionary.""" + return cls( + slot_id=data["slot_id"], + alloc_id=data["alloc_id"], + container_addr=data["container_addr"], + workspace_dir=data.get("workspace_dir", ""), + state=SlotState(data.get("state", "available")), + trajectory_id=data.get("trajectory_id"), + metadata=data.get("metadata", {}), + ) + + def __repr__(self) -> str: + return f"Slot({self.slot_id}, state={self.state.value}, alloc={self.alloc_id[:8]}...)" + + +def create_slots_for_allocation( + alloc_id: str, + container_addr: str, + num_slots: int = 10, +) -> list["Slot"]: + """ + Create slots for a Nomad allocation. + + Args: + alloc_id: Nomad allocation ID + container_addr: HTTP address of sandbox server + num_slots: Number of slots to create + + Returns: + List of Slot objects + """ + slots = [] + for i in range(num_slots): + slot_id = f"slot_{i}" + slots.append(Slot( + slot_id=slot_id, + alloc_id=alloc_id, + container_addr=container_addr, + workspace_dir=f"/data/{slot_id}", + )) + return slots diff --git a/atropos/terminal/__init__.py b/atropos/terminal/__init__.py new file mode 100644 index 0000000000..94a321ddc9 --- /dev/null +++ b/atropos/terminal/__init__.py @@ -0,0 +1,2 @@ +"""Terminal helpers for stateful sandbox interactions.""" + diff --git a/atropos/terminal/asciinema_stream.py b/atropos/terminal/asciinema_stream.py new file mode 100644 index 0000000000..9b32b4d147 --- /dev/null +++ b/atropos/terminal/asciinema_stream.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json +from typing import Any + +import pyte + + +class AsciinemaStreamDecoder: + def __init__(self, *, default_width: int = 80, default_height: int = 24) -> None: + self._default_width = max(1, int(default_width)) + self._default_height = max(1, int(default_height)) + self._buffer = "" + self._has_header = False + self.width = self._default_width + self.height = self._default_height + self._screen = pyte.Screen(self.width, self.height) + self._stream = pyte.Stream(self._screen) + + def reset(self) -> None: + self._buffer = "" + self._has_header = False + self.width = self._default_width + self.height = self._default_height + self._screen = pyte.Screen(self.width, self.height) + self._stream = pyte.Stream(self._screen) + + def feed(self, chunk: str | bytes) -> None: + if not chunk: + return + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="replace") + self._buffer += chunk + while True: + line, sep, rest = self._buffer.partition("\n") + if not sep: + break + self._buffer = rest + line = line.strip() + if not line: + continue + parsed = self._parse_json_line(line) + if parsed is None: + continue + if not self._has_header: + if isinstance(parsed, dict): + self._init_from_header(parsed) + continue + if isinstance(parsed, list): + self._has_header = True + self._apply_event(parsed) + continue + continue + if isinstance(parsed, list): + self._apply_event(parsed) + + def render(self) -> str: + return "\n".join(self._screen.display) + + def _parse_json_line(self, line: str) -> Any | None: + try: + return json.loads(line) + except json.JSONDecodeError: + return None + + def _init_from_header(self, header: dict[str, Any]) -> None: + width = _coerce_int( + header.get("width") or header.get("columns") or header.get("cols"), + self._default_width, + ) + height = _coerce_int( + header.get("height") or header.get("rows") or header.get("lines"), + self._default_height, + ) + self.width = max(1, width) + self.height = max(1, height) + self._screen = pyte.Screen(self.width, self.height) + self._stream = pyte.Stream(self._screen) + self._has_header = True + + def _apply_event(self, event: list[Any]) -> None: + if len(event) < 2: + return + event_type = event[1] + payload = event[2] if len(event) > 2 else "" + if event_type == "o": + if isinstance(payload, str): + self._stream.feed(payload) + elif event_type == "r": + width, height = _parse_resize(payload) + if width and height: + self.width = width + self.height = height + self._screen.resize(width, height) + + +def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return int(default) + + +def _parse_resize(payload: Any) -> tuple[int, int]: + if isinstance(payload, str) and "x" in payload: + left, right = payload.lower().split("x", 1) + return _coerce_int(left, 0), _coerce_int(right, 0) + if isinstance(payload, dict): + width = _coerce_int(payload.get("width") or payload.get("columns") or payload.get("cols"), 0) + height = _coerce_int(payload.get("height") or payload.get("rows") or payload.get("lines"), 0) + return width, height + if isinstance(payload, list) and len(payload) >= 2: + return _coerce_int(payload[0], 0), _coerce_int(payload[1], 0) + return 0, 0 + diff --git a/atropos/tools/__init__.py b/atropos/tools/__init__.py new file mode 100644 index 0000000000..1ba666159c --- /dev/null +++ b/atropos/tools/__init__.py @@ -0,0 +1,35 @@ +""" +Tool abstractions for atropos-agent. + +Provides base Tool class and common tool implementations. +""" + +from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema +from .basic_tools import BashTool, ReadFileTool, WriteFileTool +from .image_generation_tool import ImageGenerateTool +from .mixture_of_agents_tool import MixtureOfAgentsTool +from .terminal_tool import TerminalTool +from .terminal_stateful_tool import TerminalStatefulTool +from .tmux_tool import TmuxTool +from .vision_tools import VisionAnalyzeTool +from .web_tools import WebCrawlTool, WebExtractTool, WebSearchTool + +__all__ = [ + "Tool", + "ToolCall", + "ToolRegistry", + "ToolResult", + "ToolSchema", + "BashTool", + "ReadFileTool", + "WriteFileTool", + "ImageGenerateTool", + "TerminalTool", + "TerminalStatefulTool", + "TmuxTool", + "WebSearchTool", + "WebExtractTool", + "WebCrawlTool", + "VisionAnalyzeTool", + "MixtureOfAgentsTool", +] diff --git a/atropos/tools/base.py b/atropos/tools/base.py new file mode 100644 index 0000000000..33c9d1017a --- /dev/null +++ b/atropos/tools/base.py @@ -0,0 +1,363 @@ +""" +Base Tool abstraction for atropos-agent. + +Tools follow a simple pattern: +1. Define schema (name, description, parameters) +2. Implement execute() method +3. Return ToolResult with output/error + +Tool calls use Hermes-style XML tags: +{"name": "bash", "arguments": {"command": "ls"}} +""" + +import json +import re +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + + +@dataclass +class ToolSchema: + """JSON Schema for a tool's parameters.""" + + name: str + description: str + parameters: Dict[str, Any] = field(default_factory=dict) + required: List[str] = field(default_factory=list) + external: bool = False # Whether the tool must be executed via an external ToolServer (secret proxy) and not inside the sandbox. + + def to_dict(self) -> Dict[str, Any]: + """Convert to OpenAI-compatible function schema.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": self.parameters, + "required": self.required, + }, + }, + } + + def to_prompt_description(self) -> str: + """Convert to human-readable description for system prompt.""" + params_desc = [] + for name, spec in self.parameters.items(): + req = "(required)" if name in self.required else "(optional)" + desc = spec.get("description", "") + param_type = spec.get("type", "string") + params_desc.append(f" - {name} ({param_type}) {req}: {desc}") + + params_str = "\n".join(params_desc) if params_desc else " (no parameters)" + return f"**{self.name}**: {self.description}\nParameters:\n{params_str}" + + +@dataclass +class ToolCall: + """A parsed tool call from model output.""" + + name: str + arguments: Dict[str, Any] + raw_text: str = "" # Original XML/JSON text + uniq_id: str = field(default_factory=lambda: str(uuid.uuid4())) # Unique tool-call id for traceability/reconstruction. + + @classmethod + def parse_from_text(cls, text: str) -> List["ToolCall"]: + """ + Extract tool calls from text using Hermes-style XML tags. + + Format: {"name": "...", "arguments": {...}} + """ + calls = [] + pattern = r"\s*(.*?)\s*" + matches = re.findall(pattern, text, re.DOTALL) + + for match in matches: + try: + data = json.loads(match) + uniq_id = data.get("uniq_id") or data.get("id") or str(uuid.uuid4()) + calls.append(cls( + name=data.get("name", ""), + arguments=data.get("arguments", {}), + raw_text=match, + uniq_id=uniq_id, + )) + except json.JSONDecodeError: + # Skip malformed tool calls + continue + + return calls + + @classmethod + def has_tool_call(cls, text: str) -> bool: + """Check if text contains any tool calls.""" + return bool(re.search(r"", text)) + + +@dataclass +class ToolResult: + """Result from executing a tool.""" + + success: bool + output: str = "" + error: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + uniq_id: Optional[str] = None # Should match ToolCall.uniq_id for async execution tracking. + + def to_xml(self) -> str: + """Format as XML for including in conversation.""" + data = { + "success": self.success, + "output": self.output, + } + if self.uniq_id: + data["uniq_id"] = self.uniq_id + if self.error: + data["error"] = self.error + if self.metadata: + data["metadata"] = self.metadata + return f"{json.dumps(data)}" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "success": self.success, + "output": self.output, + "error": self.error, + "metadata": self.metadata, + "uniq_id": self.uniq_id, + } + + +class Tool(ABC): + """ + Abstract base class for tools. + + Subclasses must implement: + - schema: ToolSchema describing the tool + - execute(): async method that performs the tool action + """ + + @property + @abstractmethod + def schema(self) -> ToolSchema: + """Return the tool's schema.""" + pass + + @property + def name(self) -> str: + """Tool name (from schema).""" + return self.schema.name + + @abstractmethod + async def execute(self, **kwargs) -> ToolResult: + """ + Execute the tool with given arguments. + + Args: + **kwargs: Tool-specific arguments + + Returns: + ToolResult with success/failure and output + """ + pass + + def is_available(self) -> tuple[bool, str | None]: + """ + Return whether this tool should be exposed/executable in the current process. + + Tools that depend on optional binaries/services/env vars can override this + to avoid advertising a tool that will fail at runtime. + """ + return True, None + + async def __call__(self, **kwargs) -> ToolResult: + """Allow calling tool instance directly.""" + return await self.execute(**kwargs) + +# Note: This is only wrapping declarations for the external ToolServer (for execution on external process tools), and tools preinstalled in envs +class ToolRegistry: + """Registry of available tools.""" + + def __init__(self): + self._tools: Dict[str, Tool] = {} + + def register(self, tool: Tool) -> None: + """Register a tool.""" + self._tools[tool.name] = tool + + def get(self, name: str) -> Optional[Tool]: + """Get a tool by name.""" + return self._tools.get(name) + + def list_tools(self) -> List[Tool]: + """List all registered tools.""" + return list(self._tools.values()) + + def get_schemas(self) -> List[ToolSchema]: + """Get schemas for all registered tools.""" + return [tool.schema for tool in self._tools.values()] + + def get_prompt_description(self) -> str: + """Generate tool descriptions for system prompt.""" + descriptions = [tool.schema.to_prompt_description() for tool in self._tools.values()] + return "\n\n".join(descriptions) + + async def execute(self, call: ToolCall) -> ToolResult: + """Execute a tool call.""" + tool = self.get(call.name) + if tool is None: + return ToolResult( + success=False, + error=f"Unknown tool: {call.name}", + uniq_id=call.uniq_id, + ) + + try: + result = await tool.execute(**call.arguments) + if result.uniq_id is None: + result.uniq_id = call.uniq_id + return result + except Exception as e: + return ToolResult( + success=False, + error=f"Tool execution error: {str(e)}", + uniq_id=call.uniq_id, + ) + + +# ============================================================================= +# FastAPI / transport models +# ============================================================================= + + +class ToolCallPayload(BaseModel): + name: str + arguments: Dict[str, Any] = Field(default_factory=dict) + uniq_id: str + + @classmethod + def from_tool_call(cls, call: ToolCall) -> "ToolCallPayload": + return cls(name=call.name, arguments=call.arguments, uniq_id=call.uniq_id) + + def to_tool_call(self) -> ToolCall: + return ToolCall(name=self.name, arguments=self.arguments, uniq_id=self.uniq_id) + + +class ToolResultPayload(BaseModel): + success: bool + output: str = "" + error: str = "" + metadata: Dict[str, Any] = Field(default_factory=dict) + uniq_id: Optional[str] = None + + @classmethod + def from_tool_result(cls, result: ToolResult) -> "ToolResultPayload": + return cls( + success=result.success, + output=result.output, + error=result.error, + metadata=result.metadata, + uniq_id=result.uniq_id, + ) + + def to_tool_result(self) -> ToolResult: + return ToolResult( + success=self.success, + output=self.output, + error=self.error, + metadata=self.metadata, + uniq_id=self.uniq_id, + ) + + +class ToolExecutorExecuteRequest(BaseModel): + trajectory_id: str + tool: ToolCallPayload + timeout_s: Optional[float] = None + + +class ToolExecutorReleaseRequest(BaseModel): + trajectory_id: str + reset_workspace: bool = False + + +class ToolServerExecuteRequest(BaseModel): + trajectory_id: Optional[str] = None + tool: ToolCallPayload + timeout_s: Optional[float] = None + # Optional sandbox context for tools that need workspace artifacts. + # This is set by ToolExecutor and is NOT model-controlled. + slot_id: Optional[str] = None + container_addr: Optional[str] = None + + +# ============================================================================= +# Artifact transport models +# ============================================================================= + + +class ArtifactReadRequestPayload(BaseModel): + trajectory_id: str + path: str + encoding: Literal["text", "base64"] = "text" + max_bytes: Optional[int] = None + include_sha256: bool = False + + +class ArtifactReadResponsePayload(BaseModel): + success: bool + content: str = "" + error: str = "" + encoding: str = "text" + truncated: bool = False + bytes: int = 0 + file_size: Optional[int] = None + path: str = "" + mime: Optional[str] = None + sha256: Optional[str] = None + + +class ArtifactListRequestPayload(BaseModel): + trajectory_id: str + path: str = "." + recursive: bool = False + max_entries: Optional[int] = None + + +class ArtifactListEntryPayload(BaseModel): + path: str + is_dir: bool + size: int + mtime: float + + +class ArtifactListResponsePayload(BaseModel): + success: bool + entries: List[ArtifactListEntryPayload] = Field(default_factory=list) + truncated: bool = False + error: str = "" + + +class ArtifactArchiveRequestPayload(BaseModel): + trajectory_id: str + path: str = "." + format: Literal["tar.gz", "tgz"] = "tar.gz" + max_bytes: Optional[int] = None + max_entries: Optional[int] = None + + +class ArtifactArchiveResponsePayload(BaseModel): + success: bool + content: str = "" + error: str = "" + encoding: str = "base64" + format: str = "tar.gz" + bytes: int = 0 + entry_count: int = 0 diff --git a/atropos/tools/basic_tools.py b/atropos/tools/basic_tools.py new file mode 100644 index 0000000000..0efd625efa --- /dev/null +++ b/atropos/tools/basic_tools.py @@ -0,0 +1,283 @@ +""" +Basic tool implementations for atropos-agent. + +These tools provide simple sandbox operations: +- BashTool: Execute shell commands +- ReadFileTool: Read file contents +- WriteFileTool: Write content to files + +For PoC, these run via subprocess in the local environment. +Production usage should use proper sandbox isolation. +""" + +import asyncio +import os +from pathlib import Path +from typing import Optional + +from .base import Tool, ToolResult, ToolSchema + + +class BashTool(Tool): + """ + Execute bash commands in a sandboxed environment. + + TODO: Nomad slot execution + """ + + def __init__( + self, + working_dir: Optional[str] = None, + timeout: float = 30.0, + max_output_size: int = 10000, + ): + self.working_dir = working_dir or os.getcwd() + self.timeout = timeout + self.max_output_size = max_output_size + + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="bash", + description="Execute a bash command and return stdout/stderr. Use for running shell commands, scripts, or system operations.", + parameters={ + "command": { + "type": "string", + "description": "The bash command to execute", + }, + }, + required=["command"], + ) + + async def execute(self, command: str) -> ToolResult: + """Execute a bash command.""" + try: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self.working_dir, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=self.timeout, + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + return ToolResult( + success=False, + error=f"Command timed out after {self.timeout}s", + metadata={"exit_code": -1, "timeout": True}, + ) + + stdout_str = stdout.decode("utf-8", errors="replace") + stderr_str = stderr.decode("utf-8", errors="replace") + + # Truncate if too long + if len(stdout_str) > self.max_output_size: + stdout_str = stdout_str[:self.max_output_size] + "\n... (output truncated)" + if len(stderr_str) > self.max_output_size: + stderr_str = stderr_str[:self.max_output_size] + "\n... (output truncated)" + + exit_code = process.returncode + success = exit_code == 0 + + output = stdout_str + if stderr_str: + output = f"{stdout_str}\n[stderr]\n{stderr_str}" if stdout_str else stderr_str + + return ToolResult( + success=success, + output=output.strip(), + error="" if success else f"Exit code: {exit_code}", + metadata={"exit_code": exit_code}, + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"Failed to execute command: {str(e)}", + ) + + +class ReadFileTool(Tool): + """Read the contents of a file.""" + + def __init__( + self, + working_dir: Optional[str] = None, + max_file_size: int = 100000, + ): + self.working_dir = Path(working_dir) if working_dir else Path.cwd() + self.max_file_size = max_file_size + + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="read_file", + description="Read the contents of a file at the given path.", + parameters={ + "path": { + "type": "string", + "description": "Path to the file (relative to working directory)", + }, + }, + required=["path"], + ) + + async def execute(self, path: str) -> ToolResult: + """Read a file's contents.""" + try: + file_path = self.working_dir / path + + # Security: prevent path traversal outside working dir + resolved = file_path.resolve() + if not str(resolved).startswith(str(self.working_dir.resolve())): + return ToolResult( + success=False, + error="Access denied: path outside working directory", + ) + + if not resolved.exists(): + return ToolResult( + success=False, + error=f"File not found: {path}", + ) + + if not resolved.is_file(): + return ToolResult( + success=False, + error=f"Not a file: {path}", + ) + + size = resolved.stat().st_size + if size > self.max_file_size: + return ToolResult( + success=False, + error=f"File too large: {size} bytes (max {self.max_file_size})", + ) + + content = resolved.read_text(encoding="utf-8", errors="replace") + + return ToolResult( + success=True, + output=content, + metadata={"path": str(resolved), "size": size}, + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"Failed to read file: {str(e)}", + ) + + +class WriteFileTool(Tool): + """Write content to a file.""" + + def __init__( + self, + working_dir: Optional[str] = None, + max_file_size: int = 100000, + ): + self.working_dir = Path(working_dir) if working_dir else Path.cwd() + self.max_file_size = max_file_size + + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="write_file", + description="Write content to a file at the given path. Creates parent directories if needed.", + parameters={ + "path": { + "type": "string", + "description": "Path to the file (relative to working directory)", + }, + "content": { + "type": "string", + "description": "Content to write to the file", + }, + }, + required=["path", "content"], + ) + + async def execute(self, path: str, content: str) -> ToolResult: + """Write content to a file.""" + try: + if len(content) > self.max_file_size: + return ToolResult( + success=False, + error=f"Content too large: {len(content)} bytes (max {self.max_file_size})", + ) + + file_path = self.working_dir / path + + # Security: prevent path traversal outside working dir + resolved = file_path.resolve() + if not str(resolved).startswith(str(self.working_dir.resolve())): + return ToolResult( + success=False, + error="Access denied: path outside working directory", + ) + + # Create parent directories + resolved.parent.mkdir(parents=True, exist_ok=True) + + resolved.write_text(content, encoding="utf-8") + + return ToolResult( + success=True, + output=f"Successfully wrote {len(content)} bytes to {path}", + metadata={"path": str(resolved), "size": len(content)}, + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"Failed to write file: {str(e)}", + ) + +class FireCrawl(Tool): + """Perform a web crawl using FireCrawl tool.""" + + def __init__( + self, + working_dir: Optional[str] = None, + timeout: float = 60.0, + ): + self.working_dir = working_dir or os.getcwd() + self.timeout = timeout + + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="firecrawl", + description="Perform a web crawl starting from a given URL using FireCrawl.", + parameters={ + "start_url": { + "type": "string", + "description": "The starting URL for the web crawl", + }, + "max_pages": { + "type": "integer", + "description": "Maximum number of pages to crawl", + }, + }, + required=["start_url"], + ) + + async def execute(self, start_url: str, max_pages: int = 100) -> ToolResult: + """Execute a web crawl using FireCrawl.""" + try: + command = f"firecrawl --start-url {start_url} --max-pages {max_pages}" + bash_tool = BashTool(working_dir=self.working_dir, timeout=self.timeout) + result = await bash_tool.execute(command) + return result + except Exception as e: + return ToolResult( + success=False, + error=f"Failed to execute FireCrawl: {str(e)}", + ) \ No newline at end of file diff --git a/atropos/tools/image_generation_tool.py b/atropos/tools/image_generation_tool.py new file mode 100644 index 0000000000..d24a264dc2 --- /dev/null +++ b/atropos/tools/image_generation_tool.py @@ -0,0 +1,129 @@ +""" +Image generation tool (external). + +This is intentionally minimal for the Phase 4.6 "external tool demo": +- executed via ToolServer (no secrets in sandboxes) +- by default returns a tiny inline PNG data URL (no network required) +- can optionally proxy to an OpenAI-compatible images endpoint (e.g. a local service) +""" + +from __future__ import annotations + +import base64 +import json +import os +from dataclasses import dataclass +from typing import Literal, Optional + +import httpx + +from .base import Tool, ToolResult, ToolSchema + + +def _tiny_png_data_url() -> str: + # 1x1 transparent PNG + png_bytes = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMB/6X0kQAAAABJRU5ErkJggg==" + ) + return f"data:image/png;base64,{base64.b64encode(png_bytes).decode('ascii')}" + + +@dataclass +class ImageGenerateConfig: + backend: Literal["inline", "openai"] = "inline" + base_url: Optional[str] = None + model: Optional[str] = None + timeout_s: float = 120.0 + + @classmethod + def from_env(cls) -> "ImageGenerateConfig": + backend = (os.getenv("IMAGE_GENERATE_BACKEND") or "inline").strip().lower() + if backend not in {"inline", "openai"}: + backend = "inline" + + base_url = os.getenv("IMAGE_GENERATE_BASE_URL") or os.getenv("OLLAMA_BASE_URL") + model = os.getenv("IMAGE_GENERATE_MODEL") or os.getenv("OLLAMA_IMAGE_MODEL") + + timeout_s = float(os.getenv("IMAGE_GENERATE_TIMEOUT_S", "120.0")) + return cls( + backend=backend, # type: ignore[arg-type] + base_url=base_url, + model=model, + timeout_s=timeout_s, + ) + + +class ImageGenerateTool(Tool): + def __init__(self, config: Optional[ImageGenerateConfig] = None) -> None: + self._config = config or ImageGenerateConfig.from_env() + + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="image_generate", + description=( + "Generate an image from a text prompt. Returns a JSON object containing an image URL " + "(often a data URL) plus metadata." + ), + parameters={ + "prompt": {"type": "string", "description": "The image prompt. Be detailed."}, + "aspect_ratio": { + "type": "string", + "enum": ["landscape", "square", "portrait"], + "description": "Desired aspect ratio.", + "default": "landscape", + }, + }, + required=["prompt"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + # Availability of this external tool is primarily governed by whether a ToolServer + # is configured/enabled. The tool itself will return a clear error if an upstream + # image endpoint is required but not configured. + return True, None + + async def execute(self, prompt: str, aspect_ratio: str = "landscape") -> ToolResult: # noqa: ARG002 + cfg = self._config + + if cfg.backend == "inline": + payload = { + "url": _tiny_png_data_url(), + "backend": "inline", + "width": 1, + "height": 1, + } + return ToolResult(success=True, output=json.dumps(payload)) + + # OpenAI-compatible images generation endpoint. + # Expected: POST {base_url}/v1/images/generations with OpenAI-style body. + base_url = cfg.base_url or "" + base_url = base_url.rstrip("/") + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + url = f"{base_url}/images/generations" + + body = { + "prompt": prompt, + "n": 1, + "response_format": "url", + } + if cfg.model: + body["model"] = cfg.model + + try: + async with httpx.AsyncClient(timeout=cfg.timeout_s) as client: + resp = await client.post(url, json=body) + resp.raise_for_status() + data = resp.json() + except Exception as e: + return ToolResult(success=False, error=f"image_generate upstream call failed: {e}") + + try: + image_url = data["data"][0]["url"] + except Exception: + return ToolResult(success=False, error=f"Unexpected image response format: {data!r}") + + payload = {"url": image_url, "backend": "openai"} + return ToolResult(success=True, output=json.dumps(payload)) diff --git a/atropos/tools/mixture_of_agents_tool.py b/atropos/tools/mixture_of_agents_tool.py new file mode 100644 index 0000000000..6ff25326eb --- /dev/null +++ b/atropos/tools/mixture_of_agents_tool.py @@ -0,0 +1,649 @@ +#!/usr/bin/env python3 +""" +Mixture-of-Agents Tool Module + +This module implements the Mixture-of-Agents (MoA) methodology that leverages +the collective strengths of multiple LLMs through a layered architecture to +achieve state-of-the-art performance on complex reasoning tasks. + +Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities" +by Junlin Wang et al. (arXiv:2406.04692v1) + +Key Features: +- Multi-layer LLM collaboration for enhanced reasoning +- Parallel processing of reference models for efficiency +- Intelligent aggregation and synthesis of diverse responses +- Specialized for extremely difficult problems requiring intense reasoning +- Optimized for coding, mathematics, and complex analytical tasks + +Available Tool: +- mixture_of_agents_tool: Process complex queries using multiple frontier models + +Architecture: +1. Reference models generate diverse initial responses in parallel +2. Aggregator model synthesizes responses into a high-quality output +3. Multiple layers can be used for iterative refinement (future enhancement) + +Models Used (via OpenRouter): +- Reference Models: claude-opus-4, gemini-2.5-pro, gpt-4.1, deepseek-r1 +- Aggregator Model: claude-opus-4 (highest capability for synthesis) + +Configuration: + To customize the MoA setup, modify the configuration constants at the top of this file: + - REFERENCE_MODELS: List of models for generating diverse initial responses + - AGGREGATOR_MODEL: Model used to synthesize the final response + - REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures + - MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed + +Usage: + from mixture_of_agents_tool import mixture_of_agents_tool + import asyncio + + # Process a complex query + result = await mixture_of_agents_tool( + user_prompt="Solve this complex mathematical proof..." + ) +""" + +import json +import os +import asyncio +import uuid +import datetime +from pathlib import Path +from typing import Dict, Any, List, Optional +from openai import AsyncOpenAI + +# Initialize OpenRouter API client for MoA processing +openrouter_client = AsyncOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1" +) + +# Configuration for MoA processing +# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs) +REFERENCE_MODELS = [ + "anthropic/claude-opus-4.5", + "google/gemini-3-pro-preview", + "openai/gpt-5.2-pro", + "deepseek/deepseek-v3.2" +] + +# Aggregator model - synthesizes reference responses into final output +AGGREGATOR_MODEL = "anthropic/claude-opus-4.5" # Use highest capability model for aggregation + +# Temperature settings optimized for MoA performance +REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives +AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency + +# Failure handling configuration +MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed + +# System prompt for the aggregator model (from the research paper) +AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models:""" + +# Debug mode configuration +DEBUG_MODE = os.getenv("MOA_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› MoA debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"moa_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› MoA debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving MoA debug log: {str(e)}") + + +def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str: + """ + Construct the final system prompt for the aggregator including all model responses. + + Args: + system_prompt (str): Base system prompt for aggregation + responses (List[str]): List of responses from reference models + + Returns: + str: Complete system prompt with enumerated responses + """ + response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)]) + return f"{system_prompt}\n\n{response_text}" + + +async def _run_reference_model_safe( + model: str, + user_prompt: str, + temperature: float = REFERENCE_TEMPERATURE, + max_tokens: int = 32000, + max_retries: int = 6 +) -> tuple[str, str, bool]: + """ + Run a single reference model with retry logic and graceful failure handling. + + Args: + model (str): Model identifier to use + user_prompt (str): The user's query + temperature (float): Sampling temperature for response generation + max_tokens (int): Maximum tokens in response + max_retries (int): Maximum number of retry attempts + + Returns: + tuple[str, str, bool]: (model_name, response_content_or_error, success_flag) + """ + for attempt in range(max_retries): + try: + print(f"šŸ¤– Querying {model} (attempt {attempt + 1}/{max_retries})") + + # Build parameters for the API call + api_params = { + "model": model, + "messages": [{"role": "user", "content": user_prompt}], + "extra_body": { + "reasoning": { + "enabled": True, + "effort": "xhigh" + } + } + } + + # GPT models (especially gpt-4o-mini) don't support custom temperature values + # Only include temperature for non-GPT models + if not model.lower().startswith('gpt-'): + api_params["temperature"] = temperature + + response = await openrouter_client.chat.completions.create(**api_params) + + content = response.choices[0].message.content.strip() + print(f"āœ… {model} responded ({len(content)} characters)") + return model, content, True + + except Exception as e: + error_str = str(e) + # Log more detailed error information for debugging + if "invalid" in error_str.lower(): + print(f"āš ļø {model} invalid request error (attempt {attempt + 1}): {error_str}") + elif "rate" in error_str.lower() or "limit" in error_str.lower(): + print(f"āš ļø {model} rate limit error (attempt {attempt + 1}): {error_str}") + else: + print(f"āš ļø {model} unknown error (attempt {attempt + 1}): {error_str}") + + if attempt < max_retries - 1: + # Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s + sleep_time = min(2 ** (attempt + 1), 60) + print(f" Retrying in {sleep_time}s...") + await asyncio.sleep(sleep_time) + else: + error_msg = f"{model} failed after {max_retries} attempts: {error_str}" + print(f"āŒ {error_msg}") + return model, error_msg, False + + +async def _run_aggregator_model( + system_prompt: str, + user_prompt: str, + temperature: float = AGGREGATOR_TEMPERATURE, + max_tokens: int = None +) -> str: + """ + Run the aggregator model to synthesize the final response. + + Args: + system_prompt (str): System prompt with all reference responses + user_prompt (str): Original user query + temperature (float): Focused temperature for consistent aggregation + max_tokens (int): Maximum tokens in final response + + Returns: + str: Synthesized final response + """ + print(f"🧠 Running aggregator model: {AGGREGATOR_MODEL}") + + # Build parameters for the API call + api_params = { + "model": AGGREGATOR_MODEL, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + "extra_body": { + "reasoning": { + "enabled": True, + "effort": "xhigh" + } + } + } + + # GPT models (especially gpt-4o-mini) don't support custom temperature values + # Only include temperature for non-GPT models + if not AGGREGATOR_MODEL.lower().startswith('gpt-'): + api_params["temperature"] = temperature + + response = await openrouter_client.chat.completions.create(**api_params) + + content = response.choices[0].message.content.strip() + print(f"āœ… Aggregation complete ({len(content)} characters)") + return content + + +async def mixture_of_agents_tool( + user_prompt: str, + reference_models: Optional[List[str]] = None, + aggregator_model: Optional[str] = None +) -> str: + """ + Process a complex query using the Mixture-of-Agents methodology. + + This tool leverages multiple frontier language models to collaboratively solve + extremely difficult problems requiring intense reasoning. It's particularly + effective for: + - Complex mathematical proofs and calculations + - Advanced coding problems and algorithm design + - Multi-step analytical reasoning tasks + - Problems requiring diverse domain expertise + - Tasks where single models show limitations + + The MoA approach uses a fixed 2-layer architecture: + 1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6) + 2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4) + + Args: + user_prompt (str): The complex query or problem to solve + reference_models (Optional[List[str]]): Custom reference models to use + aggregator_model (Optional[str]): Custom aggregator model to use + + Returns: + str: JSON string containing the MoA results with the following structure: + { + "success": bool, + "response": str, + "models_used": { + "reference_models": List[str], + "aggregator_model": str + }, + "processing_time": float + } + + Raises: + Exception: If MoA processing fails or API key is not set + """ + start_time = datetime.datetime.now() + + debug_call_data = { + "parameters": { + "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, + "reference_models": reference_models or REFERENCE_MODELS, + "aggregator_model": aggregator_model or AGGREGATOR_MODEL, + "reference_temperature": REFERENCE_TEMPERATURE, + "aggregator_temperature": AGGREGATOR_TEMPERATURE, + "min_successful_references": MIN_SUCCESSFUL_REFERENCES + }, + "error": None, + "success": False, + "reference_responses_count": 0, + "failed_models_count": 0, + "failed_models": [], + "final_response_length": 0, + "processing_time_seconds": 0, + "models_used": {} + } + + try: + print(f"šŸš€ Starting Mixture-of-Agents processing...") + print(f"šŸ“ Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}") + + # Validate API key availability + if not os.getenv("OPENROUTER_API_KEY"): + raise ValueError("OPENROUTER_API_KEY environment variable not set") + + # Use provided models or defaults + ref_models = reference_models or REFERENCE_MODELS + agg_model = aggregator_model or AGGREGATOR_MODEL + + print(f"šŸ”„ Using {len(ref_models)} reference models in 2-layer MoA architecture") + + # Layer 1: Generate diverse responses from reference models (with failure handling) + print("šŸ“” Layer 1: Generating reference responses...") + model_results = await asyncio.gather(*[ + _run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) + for model in ref_models + ]) + + # Separate successful and failed responses + successful_responses = [] + failed_models = [] + + for model_name, content, success in model_results: + if success: + successful_responses.append(content) + else: + failed_models.append(model_name) + + successful_count = len(successful_responses) + failed_count = len(failed_models) + + print(f"šŸ“Š Reference model results: {successful_count} successful, {failed_count} failed") + + if failed_models: + print(f"āš ļø Failed models: {', '.join(failed_models)}") + + # Check if we have enough successful responses to proceed + if successful_count < MIN_SUCCESSFUL_REFERENCES: + raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.") + + debug_call_data["reference_responses_count"] = successful_count + debug_call_data["failed_models_count"] = failed_count + debug_call_data["failed_models"] = failed_models + + # Layer 2: Aggregate responses using the aggregator model + print("🧠 Layer 2: Synthesizing final response...") + aggregator_system_prompt = _construct_aggregator_prompt( + AGGREGATOR_SYSTEM_PROMPT, + successful_responses + ) + + final_response = await _run_aggregator_model( + aggregator_system_prompt, + user_prompt, + AGGREGATOR_TEMPERATURE + ) + + # Calculate processing time + end_time = datetime.datetime.now() + processing_time = (end_time - start_time).total_seconds() + + print(f"āœ… MoA processing completed in {processing_time:.2f} seconds") + + # Prepare successful response (only final aggregated result, minimal fields) + result = { + "success": True, + "response": final_response, + "models_used": { + "reference_models": ref_models, + "aggregator_model": agg_model + } + } + + debug_call_data["success"] = True + debug_call_data["final_response_length"] = len(final_response) + debug_call_data["processing_time_seconds"] = processing_time + debug_call_data["models_used"] = result["models_used"] + + # Log debug information + _log_debug_call("mixture_of_agents_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + except Exception as e: + error_msg = f"Error in MoA processing: {str(e)}" + print(f"āŒ {error_msg}") + + # Calculate processing time even for errors + end_time = datetime.datetime.now() + processing_time = (end_time - start_time).total_seconds() + + # Prepare error response (minimal fields) + result = { + "success": False, + "response": "MoA processing failed. Please try again or use a single model for this query.", + "models_used": { + "reference_models": reference_models or REFERENCE_MODELS, + "aggregator_model": aggregator_model or AGGREGATOR_MODEL + }, + "error": error_msg + } + + debug_call_data["error"] = error_msg + debug_call_data["processing_time_seconds"] = processing_time + _log_debug_call("mixture_of_agents_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + +def check_openrouter_api_key() -> bool: + """ + Check if the OpenRouter API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("OPENROUTER_API_KEY")) + + +def check_moa_requirements() -> bool: + """ + Check if all requirements for MoA tools are met. + + Returns: + bool: True if requirements are met, False otherwise + """ + return check_openrouter_api_key() + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"moa_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +def get_available_models() -> Dict[str, List[str]]: + """ + Get information about available models for MoA processing. + + Returns: + Dict[str, List[str]]: Dictionary with reference and aggregator models + """ + return { + "reference_models": REFERENCE_MODELS, + "aggregator_models": [AGGREGATOR_MODEL], + "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL] + } + + +def get_moa_configuration() -> Dict[str, Any]: + """ + Get the current MoA configuration settings. + + Returns: + Dict[str, Any]: Dictionary containing all configuration parameters + """ + return { + "reference_models": REFERENCE_MODELS, + "aggregator_model": AGGREGATOR_MODEL, + "reference_temperature": REFERENCE_TEMPERATURE, + "aggregator_temperature": AGGREGATOR_TEMPERATURE, + "min_successful_references": MIN_SUCCESSFUL_REFERENCES, + "total_reference_models": len(REFERENCE_MODELS), + "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail" + } + + +# ============================================================================= +# Atropos-Agent Tool wrapper (Hermes compatibility) +# ============================================================================= + +from .base import Tool, ToolResult, ToolSchema # noqa: E402 + + +def _tool_result_from_json(output: str) -> ToolResult: + try: + data = json.loads(output) + except Exception: + return ToolResult(success=True, output=output) + + if isinstance(data, dict) and "success" in data: + ok = bool(data.get("success")) + if not ok: + err = data.get("error") or "Mixture-of-Agents failed" + return ToolResult(success=False, output=output, error=str(err)) + return ToolResult(success=True, output=output) + + return ToolResult(success=True, output=output) + + +class MixtureOfAgentsTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="mixture_of_agents", + description=( + "Process extremely difficult problems requiring intense reasoning using a Mixture-of-Agents." + ), + parameters={ + "user_prompt": { + "type": "string", + "description": "The complex query or problem to solve using multiple models.", + } + }, + required=["user_prompt"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + if not os.getenv("OPENROUTER_API_KEY"): + return False, "OPENROUTER_API_KEY not set" + return True, None + + async def execute(self, user_prompt: str) -> ToolResult: + output = await mixture_of_agents_tool(user_prompt=user_prompt) + return _tool_result_from_json(output) + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("šŸ¤– Mixture-of-Agents Tool Module") + print("=" * 50) + + # Check if API key is available + api_available = check_openrouter_api_key() + + if not api_available: + print("āŒ OPENROUTER_API_KEY environment variable not set") + print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'") + print("Get API key at: https://openrouter.ai/") + exit(1) + else: + print("āœ… OpenRouter API key found") + + print("šŸ› ļø MoA tools ready for use!") + + # Show current configuration + config = get_moa_configuration() + print(f"\nāš™ļø Current Configuration:") + print(f" šŸ¤– Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}") + print(f" 🧠 Aggregator model: {config['aggregator_model']}") + print(f" šŸŒ”ļø Reference temperature: {config['reference_temperature']}") + print(f" šŸŒ”ļø Aggregator temperature: {config['aggregator_temperature']}") + print(f" šŸ›”ļø Failure tolerance: {config['failure_tolerance']}") + print(f" šŸ“Š Minimum successful models: {config['min_successful_references']}") + + # Show debug mode status + if DEBUG_MODE: + print(f"\nšŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("\nšŸ› Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from mixture_of_agents_tool import mixture_of_agents_tool") + print(" import asyncio") + print("") + print(" async def main():") + print(" result = await mixture_of_agents_tool(") + print(" user_prompt='Solve this complex mathematical proof...'") + print(" )") + print(" print(result)") + print(" asyncio.run(main())") + + print("\nBest use cases:") + print(" - Complex mathematical proofs and calculations") + print(" - Advanced coding problems and algorithm design") + print(" - Multi-step analytical reasoning tasks") + print(" - Problems requiring diverse domain expertise") + print(" - Tasks where single models show limitations") + + print("\nPerformance characteristics:") + print(" - Higher latency due to multiple model calls") + print(" - Significantly improved quality for complex tasks") + print(" - Parallel processing for efficiency") + print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation") + print(" - Token-efficient: only returns final aggregated response") + print(" - Resilient: continues with partial model failures") + print(f" - Configurable: easy to modify models and settings at top of file") + print(" - State-of-the-art results on challenging benchmarks") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export MOA_TOOLS_DEBUG=true") + print(" # Debug logs capture all MoA processing steps and metrics") + print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json") diff --git a/atropos/tools/terminal_hecate.py b/atropos/tools/terminal_hecate.py new file mode 100644 index 0000000000..fc25eec6d1 --- /dev/null +++ b/atropos/tools/terminal_hecate.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +""" +Terminal Hecate Tool Module + +A terminal tool that executes commands on MorphCloud/Hecate VMs. +Uses E2B-style cloud VMs for execution with automatic lifecycle management. + +Features: +- Direct SSH command execution on cloud VMs +- Background task support +- VM lifecycle management with TTL +- Automatic cleanup after inactivity + +Usage: + from terminal_hecate import terminal_hecate_tool + + # Execute a simple command + result = terminal_hecate_tool("ls -la") + + # Execute in background + result = terminal_hecate_tool("python server.py", background=True) +""" + +import json +import os +import time +import threading +import atexit +from typing import Optional, Dict, Any + +# Tool description for LLM +TERMINAL_HECATE_DESCRIPTION = """Execute commands on a secure cloud Linux VM environment (Hecate/MorphCloud). + +**Environment:** +- Minimal Debian-based OS with internet access +- Automatic VM lifecycle management (creates on-demand, reuses, cleans up) +- Filesystem is persisted between tool calls but environment variables, venvs, etc are reset. + +**Command Execution:** +- Simple commands: Just provide the 'command' parameter +- Background processes: Set 'background': True for servers/long-running tasks +- Command timeout: Optional 'timeout' parameter in seconds + +**Examples:** +- Run command: `{"command": "ls -la"}` +- Background task: `{"command": "source path/to/my/venv/bin/activate && python server.py", "background": True}` +- With timeout: `{"command": "long_task.sh", "timeout": 300}` + +**Best Practices:** +- Run servers/long processes in background +- Monitor disk usage for large tasks +- Install whatever tools you need with sudo apt-get +- Do not be afraid to run pip with --break-system-packages + +**Things to avoid** +- Do NOT use interactive tools such as tmux, vim, nano, python repl - you will get stuck. Even git sometimes becomes interactive if the output is large. If you're not sure pipe to cat. +""" + +# Global state for VM lifecycle management +_active_instances: Dict[str, Any] = {} +_last_activity: Dict[str, float] = {} +_instance_lock = threading.Lock() +_cleanup_thread = None +_cleanup_running = False + + +def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300): + """Clean up VMs that have been inactive for longer than vm_lifetime_seconds.""" + global _active_instances, _last_activity + + current_time = time.time() + tasks_to_cleanup = [] + + with _instance_lock: + for task_id, last_time in list(_last_activity.items()): + if current_time - last_time > vm_lifetime_seconds: + tasks_to_cleanup.append(task_id) + + for task_id in tasks_to_cleanup: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + del _active_instances[task_id] + print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + # 404 errors are benign - VM already cleaned up by TTL + error_str = str(e) + if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower(): + print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)") + else: + print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}") + + # Always remove from tracking dicts to prevent infinite retry loops + if task_id in _active_instances: + del _active_instances[task_id] + if task_id in _last_activity: + del _last_activity[task_id] + + +def _cleanup_thread_worker(): + """Background thread worker that periodically cleans up inactive VMs.""" + global _cleanup_running + + while _cleanup_running: + try: + vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300")) + _cleanup_inactive_vms(vm_lifetime) + except Exception as e: + print(f"[VM Cleanup] Error in cleanup thread: {e}") + + for _ in range(60): + if not _cleanup_running: + break + time.sleep(1) + + +def _start_cleanup_thread(): + """Start the background cleanup thread if not already running.""" + global _cleanup_thread, _cleanup_running + + with _instance_lock: + if _cleanup_thread is None or not _cleanup_thread.is_alive(): + _cleanup_running = True + _cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True) + _cleanup_thread.start() + + +def _stop_cleanup_thread(): + """Stop the background cleanup thread.""" + global _cleanup_running + _cleanup_running = False + if _cleanup_thread is not None: + _cleanup_thread.join(timeout=5) + + +def cleanup_vm(task_id: str): + """Manually clean up a specific VM by task_id.""" + global _active_instances, _last_activity + + with _instance_lock: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + del _active_instances[task_id] + print(f"[VM Cleanup] Manually terminated VM for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + # 404 errors are benign - VM already cleaned up by TTL + error_str = str(e) + if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower(): + print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)") + else: + print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}") + + +atexit.register(_stop_cleanup_thread) + + +def _execute_command(instance, command: str, timeout: Optional[int] = None) -> Dict[str, Any]: + """ + Execute a command on the VM instance using instance.exec() for proper stderr capture. + + Args: + instance: MorphVM instance + command: Command to execute + timeout: Optional timeout in seconds (Note: exec() may not support timeout directly) + + Returns: + dict with stdout, stderr, returncode + """ + try: + # Use instance.exec() which properly captures both stdout and stderr + # (unlike ssh.run() which doesn't capture stderr correctly) + result = instance.exec(command) + + # Debug logging only for verbose mode or unusual cases + # Note: Non-zero exit codes are normal (model's command failed) - not a tool error + if result.exit_code != 0 and not result.stdout and not result.stderr: + # Only log if we got absolutely no output - might indicate an issue + print(f"āš ļø Command returned exit={result.exit_code} with no output") + + return { + "stdout": result.stdout or "", + "stderr": result.stderr or "", + "returncode": result.exit_code + } + + except Exception as e: + # Check if it's a timeout + error_str = str(e).lower() + if "timeout" in error_str: + return { + "stdout": "", + "stderr": f"Command timed out after {timeout or 120} seconds", + "returncode": 124 + } + + return { + "stdout": "", + "stderr": f"Command execution failed: {str(e)}", + "returncode": -1 + } + + +def terminal_hecate_tool( + command: str, + background: bool = False, + timeout: Optional[int] = None, + task_id: Optional[str] = None +) -> str: + """ + Execute a command on a MorphCloud/Hecate VM without session persistence. + + Args: + command: The command to execute + background: Whether to run in background (default: False) + timeout: Command timeout in seconds (default: 120) + task_id: Unique identifier for VM isolation (optional) + + Returns: + str: JSON string with output, exit_code, and error fields + + Examples: + # Execute a simple command + >>> result = terminal_hecate_tool(command="ls -la /tmp") + + # Run a background task + >>> result = terminal_hecate_tool(command="python server.py", background=True) + + # With custom timeout + >>> result = terminal_hecate_tool(command="long_task.sh", timeout=300) + """ + global _active_instances, _last_activity + + try: + # Import required modules + try: + from morphcloud.api import MorphCloudClient + except ImportError as import_error: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Terminal tool disabled: {import_error}", + "status": "disabled" + }, ensure_ascii=False) + + # Get configuration + vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) + snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg") + + # Check API key + morph_api_key = os.getenv("MORPH_API_KEY") + if not morph_api_key: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": "MORPH_API_KEY environment variable not set", + "status": "disabled" + }, ensure_ascii=False) + + # Use task_id for VM isolation + effective_task_id = task_id or "default" + + # Start cleanup thread + _start_cleanup_thread() + + # Get or create VM instance + with _instance_lock: + if effective_task_id not in _active_instances: + morph_client = MorphCloudClient(api_key=morph_api_key) + _active_instances[effective_task_id] = morph_client.instances.start( + snapshot_id=snapshot_id, + ttl_seconds=vm_ttl_seconds, + ttl_action="stop" + ) + + # Update last activity time + _last_activity[effective_task_id] = time.time() + instance = _active_instances[effective_task_id] + + # Wait for instance to be ready + instance.wait_until_ready() + + # Prepare command for execution + if background: + # Run in background with nohup and redirect output + exec_command = f"nohup {command} > /tmp/bg_output.log 2>&1 &" + result = _execute_command(instance, exec_command, timeout=10) + + # For background tasks, return immediately with info + if result["returncode"] == 0: + return json.dumps({ + "output": "Background task started successfully", + "exit_code": 0, + "error": None + }, ensure_ascii=False) + else: + # Include stderr in output but don't set error (command failure, not tool failure) + bg_output = result["stdout"] + if result["stderr"]: + bg_output = f"{bg_output}\n{result['stderr']}" if bg_output else result["stderr"] + return json.dumps({ + "output": bg_output, + "exit_code": result["returncode"], + "error": None # Only set for actual tool failures + }, ensure_ascii=False) + else: + # Run foreground command with retry logic for transient failures + max_retries = 3 + retry_count = 0 + result = None + + while retry_count <= max_retries: + result = _execute_command(instance, command, timeout=timeout) + + # Check if we should retry (only for transient errors, not normal results) + stdout = result.get("stdout", "") + stderr = result.get("stderr", "") + returncode = result.get("returncode", 0) + + should_retry = False + retry_reason = "" + + # NOTE: Empty output with exit_code=0 is NORMAL for many commands: + # - File writes: cat > file, echo > file + # - Directory ops: mkdir, cd + # - Silent installs: pip install --quiet + # So we do NOT retry on exit_code=0, even with empty output. + + # Only retry on special error codes that suggest transient/infra issues + if not stdout and not stderr and returncode in [-1, 124]: + should_retry = True + retry_reason = f"transient error (code {returncode})" + + if should_retry and retry_count < max_retries: + retry_count += 1 + wait_time = 2 ** retry_count # Exponential backoff: 2s, 4s, 8s + print(f"āš ļø Terminal: {retry_reason}, retrying in {wait_time}s (attempt {retry_count}/{max_retries})") + time.sleep(wait_time) + continue + + # Got a result (success or normal command failure) - exit retry loop + break + + # Combine stdout and stderr for output + output = result["stdout"] + if result["stderr"] and result["returncode"] != 0: + output = f"{output}\n{result['stderr']}" if output else result["stderr"] + + # Truncate output if too long (max 50,000 chars to avoid context explosion) + MAX_OUTPUT_CHARS = 50000 + if len(output) > MAX_OUTPUT_CHARS: + truncated_notice = f"\n\n... [OUTPUT TRUNCATED - showing last {MAX_OUTPUT_CHARS} chars of {len(output)} total] ..." + output = truncated_notice + output[-MAX_OUTPUT_CHARS:] + + # NOTE: error is only set for FUNCTIONAL tool failures (VM issues, timeouts, etc.) + # Non-zero exit codes from the model's commands are NOT tool failures - + # the model can self-correct. The exit_code field tells the model if the command succeeded. + # Retries that eventually succeed also don't count as failures. + return json.dumps({ + "output": output.strip(), + "exit_code": result["returncode"], + "error": None # Only set for actual tool failures, not command failures + }, ensure_ascii=False) + + except Exception as e: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Failed to execute command: {str(e)}", + "status": "error" + }, ensure_ascii=False) + + +def check_hecate_requirements() -> bool: + """Check if all requirements for the Hecate terminal tool are met.""" + required_vars = ["MORPH_API_KEY"] + missing_required = [var for var in required_vars if not os.getenv(var)] + + if missing_required: + print(f"Missing required environment variables: {', '.join(missing_required)}") + return False + + try: + from morphcloud.api import MorphCloudClient + return True + except Exception as e: + print(f"MorphCloud not available: {e}") + return False + + +if __name__ == "__main__": + """Simple test when run directly.""" + print("Terminal Hecate Tool Module (MorphCloud/E2B)") + print("=" * 40) + + if not check_hecate_requirements(): + print("Requirements not met. Please check the messages above.") + exit(1) + + print("All requirements met!") + print("\nAvailable Tool:") + print(" - terminal_hecate_tool: Execute commands on cloud VMs") + + print("\nUsage Examples:") + print(" # Execute a command") + print(" result = terminal_hecate_tool(command='ls -la')") + print(" ") + print(" # Run a background task") + print(" result = terminal_hecate_tool(command='python server.py', background=True)") + + print("\nEnvironment Variables:") + print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}") + print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)") + print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)") + print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')}") diff --git a/atropos/tools/terminal_stateful_tool.py b/atropos/tools/terminal_stateful_tool.py new file mode 100644 index 0000000000..0e5786ca82 --- /dev/null +++ b/atropos/tools/terminal_stateful_tool.py @@ -0,0 +1,45 @@ +""" +Stateful terminal tool schema. + +This is a sandbox tool that routes to the sandbox server as `bash_stateful` +via ToolExecutor mapping. It exists to expose an explicit, opt-in terminal +primitive suitable for stateful workflows (e.g. tmux sessions / TUIs). +""" + +from __future__ import annotations + +from typing import Optional + +from .base import Tool, ToolResult, ToolSchema +from .basic_tools import BashTool + + +class TerminalStatefulTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="terminal_stateful", + description=( + "Execute a command in the sandbox, allowing stateful/background processes to persist " + "across tool calls within the same trajectory slot (e.g. tmux sessions). " + "Use sparingly; output is still non-interactive." + ), + parameters={ + "command": {"type": "string", "description": "The command to execute"}, + "timeout": { + "type": "integer", + "description": "Command timeout in seconds (optional).", + "minimum": 1, + }, + }, + required=["command"], + ) + + def is_available(self) -> tuple[bool, str | None]: + return True, None + + async def execute(self, command: str, timeout: Optional[int] = None) -> ToolResult: + # Fallback direct execution (not stateful) when used outside ToolExecutor. + bash = BashTool(timeout=float(timeout) if timeout else 30.0) + return await bash.execute(command=command) + diff --git a/atropos/tools/terminal_tool.py b/atropos/tools/terminal_tool.py new file mode 100644 index 0000000000..a865a274be --- /dev/null +++ b/atropos/tools/terminal_tool.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python3 +""" +Terminal Tool Module (mini-swe-agent backend) + +A terminal tool that executes commands using mini-swe-agent's execution environments. +Supports local execution, Docker containers, and Modal cloud sandboxes. + +Environment Selection (via TERMINAL_ENV environment variable): +- "local": Execute directly on the host machine (default, fastest) +- "docker": Execute in Docker containers (isolated, requires Docker) +- "modal": Execute in Modal cloud sandboxes (scalable, requires Modal account) + +Features: +- Multiple execution backends (local, docker, modal) +- Background task support +- VM/container lifecycle management +- Automatic cleanup after inactivity + +Usage: + from terminal_tool import terminal_tool + + # Execute a simple command + result = terminal_tool("ls -la") + + # Execute in background + result = terminal_tool("python server.py", background=True) +""" + +import json +import os +import sys +import time +import threading +import atexit +from pathlib import Path +from typing import Optional, Dict, Any + +# Add mini-swe-agent to path if not installed +mini_swe_path = Path(__file__).parent.parent / "mini-swe-agent" / "src" +if mini_swe_path.exists(): + sys.path.insert(0, str(mini_swe_path)) + +# Tool description for LLM +TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure Linux environment. + +**Environment:** +- Isolated execution environment (local, Docker, or Modal cloud based on configuration) +- Filesystem persists between tool calls within the same task +- Internet access available + +**Command Execution:** +- Simple commands: Just provide the 'command' parameter +- Background processes: Set 'background': True for servers/long-running tasks +- Command timeout: Optional 'timeout' parameter in seconds + +**Examples:** +- Run command: `{"command": "ls -la"}` +- Background task: `{"command": "source venv/bin/activate && python server.py", "background": True}` +- With timeout: `{"command": "long_task.sh", "timeout": 300}` + +**Best Practices:** +- Run servers/long processes in background +- Monitor disk usage for large tasks +- Install whatever tools you need with apt-get or pip +- Do not be afraid to run pip with --break-system-packages + +**Things to avoid:** +- Do NOT use interactive tools such as tmux, vim, nano, python repl - you will get stuck. +- Even git sometimes becomes interactive if the output is large. If you're not sure, pipe to cat. +""" + +# Global state for environment lifecycle management +_active_environments: Dict[str, Any] = {} +_last_activity: Dict[str, float] = {} +_env_lock = threading.Lock() +_cleanup_thread = None +_cleanup_running = False + +# Configuration from environment variables +def _get_env_config() -> Dict[str, Any]: + """Get terminal environment configuration from environment variables.""" + return { + "env_type": os.getenv("TERMINAL_ENV", "local"), # local, docker, or modal + "docker_image": os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim"), + "modal_image": os.getenv("TERMINAL_MODAL_IMAGE", "python:3.11-slim"), + "cwd": os.getenv("TERMINAL_CWD", "/tmp"), + "timeout": int(os.getenv("TERMINAL_TIMEOUT", "60")), + "lifetime_seconds": int(os.getenv("TERMINAL_LIFETIME_SECONDS", "300")), + } + + +def _create_environment(env_type: str, image: str, cwd: str, timeout: int): + """ + Create an execution environment from mini-swe-agent. + + Args: + env_type: One of "local", "docker", "modal" + image: Docker/Modal image name (ignored for local) + cwd: Working directory + timeout: Default command timeout + + Returns: + Environment instance with execute() method + """ + if env_type == "local": + from minisweagent.environments.local import LocalEnvironment + return LocalEnvironment(cwd=cwd, timeout=timeout) + + elif env_type == "docker": + from minisweagent.environments.docker import DockerEnvironment + return DockerEnvironment(image=image, cwd=cwd, timeout=timeout) + + elif env_type == "modal": + from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment + return SwerexModalEnvironment(image=image, cwd=cwd, timeout=timeout) + + else: + raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', or 'modal'") + + +def _cleanup_inactive_envs(lifetime_seconds: int = 300): + """Clean up environments that have been inactive for longer than lifetime_seconds.""" + global _active_environments, _last_activity + + current_time = time.time() + tasks_to_cleanup = [] + + with _env_lock: + for task_id, last_time in list(_last_activity.items()): + if current_time - last_time > lifetime_seconds: + tasks_to_cleanup.append(task_id) + + for task_id in tasks_to_cleanup: + try: + if task_id in _active_environments: + env = _active_environments[task_id] + # Try various cleanup methods + if hasattr(env, 'cleanup'): + env.cleanup() + elif hasattr(env, 'stop'): + env.stop() + elif hasattr(env, 'terminate'): + env.terminate() + + del _active_environments[task_id] + print(f"[Terminal Cleanup] Cleaned up inactive environment for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + error_str = str(e) + if "404" in error_str or "not found" in error_str.lower(): + print(f"[Terminal Cleanup] Environment for task {task_id} already cleaned up") + else: + print(f"[Terminal Cleanup] Error cleaning up environment for task {task_id}: {e}") + + # Always remove from tracking dicts + if task_id in _active_environments: + del _active_environments[task_id] + if task_id in _last_activity: + del _last_activity[task_id] + + +def _cleanup_thread_worker(): + """Background thread worker that periodically cleans up inactive environments.""" + global _cleanup_running + + while _cleanup_running: + try: + config = _get_env_config() + _cleanup_inactive_envs(config["lifetime_seconds"]) + except Exception as e: + print(f"[Terminal Cleanup] Error in cleanup thread: {e}") + + for _ in range(60): + if not _cleanup_running: + break + time.sleep(1) + + +def _start_cleanup_thread(): + """Start the background cleanup thread if not already running.""" + global _cleanup_thread, _cleanup_running + + with _env_lock: + if _cleanup_thread is None or not _cleanup_thread.is_alive(): + _cleanup_running = True + _cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True) + _cleanup_thread.start() + + +def _stop_cleanup_thread(): + """Stop the background cleanup thread.""" + global _cleanup_running + _cleanup_running = False + if _cleanup_thread is not None: + _cleanup_thread.join(timeout=5) + + +def cleanup_vm(task_id: str): + """Manually clean up a specific environment by task_id.""" + global _active_environments, _last_activity + + with _env_lock: + try: + if task_id in _active_environments: + env = _active_environments[task_id] + if hasattr(env, 'cleanup'): + env.cleanup() + elif hasattr(env, 'stop'): + env.stop() + elif hasattr(env, 'terminate'): + env.terminate() + + del _active_environments[task_id] + print(f"[Terminal Cleanup] Manually cleaned up environment for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + error_str = str(e) + if "404" in error_str or "not found" in error_str.lower(): + print(f"[Terminal Cleanup] Environment for task {task_id} already cleaned up") + else: + print(f"[Terminal Cleanup] Error cleaning up environment for task {task_id}: {e}") + + +atexit.register(_stop_cleanup_thread) + + +def terminal_tool( + command: str, + background: bool = False, + timeout: Optional[int] = None, + task_id: Optional[str] = None +) -> str: + """ + Execute a command using mini-swe-agent's execution environments. + + Args: + command: The command to execute + background: Whether to run in background (default: False) + timeout: Command timeout in seconds (default: from config) + task_id: Unique identifier for environment isolation (optional) + + Returns: + str: JSON string with output, exit_code, and error fields + + Examples: + # Execute a simple command + >>> result = terminal_tool(command="ls -la /tmp") + + # Run a background task + >>> result = terminal_tool(command="python server.py", background=True) + + # With custom timeout + >>> result = terminal_tool(command="long_task.sh", timeout=300) + """ + global _active_environments, _last_activity + + try: + # Get configuration + config = _get_env_config() + env_type = config["env_type"] + + # Select image based on env type + if env_type == "docker": + image = config["docker_image"] + elif env_type == "modal": + image = config["modal_image"] + else: + image = "" + + cwd = config["cwd"] + default_timeout = config["timeout"] + effective_timeout = timeout or default_timeout + + # Use task_id for environment isolation + effective_task_id = task_id or "default" + + # Start cleanup thread + _start_cleanup_thread() + + # Get or create environment + with _env_lock: + if effective_task_id not in _active_environments: + try: + _active_environments[effective_task_id] = _create_environment( + env_type=env_type, + image=image, + cwd=cwd, + timeout=effective_timeout + ) + except ImportError as e: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Terminal tool disabled: mini-swe-agent not available ({e})", + "status": "disabled" + }, ensure_ascii=False) + + # Update last activity time + _last_activity[effective_task_id] = time.time() + env = _active_environments[effective_task_id] + + # Prepare command for execution + if background: + # Run in background with nohup and redirect output + exec_command = f"nohup {command} > /tmp/bg_output.log 2>&1 &" + try: + result = env.execute(exec_command, timeout=10) + return json.dumps({ + "output": "Background task started successfully", + "exit_code": 0, + "error": None + }, ensure_ascii=False) + except Exception as e: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Failed to start background task: {str(e)}" + }, ensure_ascii=False) + else: + # Run foreground command with retry logic + max_retries = 3 + retry_count = 0 + result = None + + while retry_count <= max_retries: + try: + result = env.execute(command, timeout=effective_timeout) + except Exception as e: + error_str = str(e).lower() + if "timeout" in error_str: + return json.dumps({ + "output": "", + "exit_code": 124, + "error": f"Command timed out after {effective_timeout} seconds" + }, ensure_ascii=False) + + # Retry on transient errors + if retry_count < max_retries: + retry_count += 1 + wait_time = 2 ** retry_count + print(f"āš ļø Terminal: execution error, retrying in {wait_time}s (attempt {retry_count}/{max_retries})") + time.sleep(wait_time) + continue + + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Command execution failed: {str(e)}" + }, ensure_ascii=False) + + # Got a result + break + + # Extract output + output = result.get("output", "") + returncode = result.get("returncode", 0) + + # Truncate output if too long + MAX_OUTPUT_CHARS = 50000 + if len(output) > MAX_OUTPUT_CHARS: + truncated_notice = f"\n\n... [OUTPUT TRUNCATED - showing last {MAX_OUTPUT_CHARS} chars of {len(output)} total] ..." + output = truncated_notice + output[-MAX_OUTPUT_CHARS:] + + return json.dumps({ + "output": output.strip() if output else "", + "exit_code": returncode, + "error": None + }, ensure_ascii=False) + + except Exception as e: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Failed to execute command: {str(e)}", + "status": "error" + }, ensure_ascii=False) + + +def check_terminal_requirements() -> bool: + """Check if all requirements for the terminal tool are met.""" + config = _get_env_config() + env_type = config["env_type"] + + try: + if env_type == "local": + from minisweagent.environments.local import LocalEnvironment + return True + elif env_type == "docker": + from minisweagent.environments.docker import DockerEnvironment + # Check if docker is available + import subprocess + result = subprocess.run(["docker", "version"], capture_output=True, timeout=5) + return result.returncode == 0 + elif env_type == "modal": + from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment + # Check for modal token + return os.getenv("MODAL_TOKEN_ID") is not None or Path.home().joinpath(".modal.toml").exists() + else: + return False + except Exception as e: + print(f"Terminal requirements check failed: {e}") + return False + + +# ============================================================================= +# Atropos-Agent Tool wrapper (Hermes compatibility) +# ============================================================================= + +from .base import Tool, ToolResult, ToolSchema # noqa: E402 +from .basic_tools import BashTool # noqa: E402 + + +class TerminalTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="terminal", + description=TERMINAL_TOOL_DESCRIPTION, + parameters={ + "command": {"type": "string", "description": "The command to execute"}, + "background": { + "type": "boolean", + "description": "Run the command in the background (not supported in sandbox).", + "default": False, + }, + "timeout": { + "type": "integer", + "description": "Command timeout in seconds (optional).", + "minimum": 1, + }, + }, + required=["command"], + ) + + def is_available(self) -> tuple[bool, str | None]: + # The canonical terminal tool in atropos-agent is sandboxed; availability is governed + # by whether the environment includes this tool and its own network policy. + return True, None + + async def execute( + self, + command: str, + background: bool = False, + timeout: Optional[int] = None, + ) -> ToolResult: + if background: + return ToolResult(success=False, error="background execution is not supported in sandboxed terminal") + bash = BashTool(timeout=float(timeout) if timeout else 30.0) + return await bash.execute(command=command) + + +if __name__ == "__main__": + """Simple test when run directly.""" + print("Terminal Tool Module (mini-swe-agent backend)") + print("=" * 50) + + config = _get_env_config() + print(f"\nCurrent Configuration:") + print(f" Environment type: {config['env_type']}") + print(f" Docker image: {config['docker_image']}") + print(f" Modal image: {config['modal_image']}") + print(f" Working directory: {config['cwd']}") + print(f" Default timeout: {config['timeout']}s") + print(f" Lifetime: {config['lifetime_seconds']}s") + + if not check_terminal_requirements(): + print("\nāŒ Requirements not met. Please check the messages above.") + exit(1) + + print("\nāœ… All requirements met!") + print("\nAvailable Tool:") + print(" - terminal_tool: Execute commands using mini-swe-agent environments") + + print("\nUsage Examples:") + print(" # Execute a command") + print(" result = terminal_tool(command='ls -la')") + print(" ") + print(" # Run a background task") + print(" result = terminal_tool(command='python server.py', background=True)") + + print("\nEnvironment Variables:") + print(f" TERMINAL_ENV: {os.getenv('TERMINAL_ENV', 'local')} (local/docker/modal)") + print(f" TERMINAL_DOCKER_IMAGE: {os.getenv('TERMINAL_DOCKER_IMAGE', 'python:3.11-slim')}") + print(f" TERMINAL_MODAL_IMAGE: {os.getenv('TERMINAL_MODAL_IMAGE', 'python:3.11-slim')}") + print(f" TERMINAL_CWD: {os.getenv('TERMINAL_CWD', '/tmp')}") + print(f" TERMINAL_TIMEOUT: {os.getenv('TERMINAL_TIMEOUT', '60')}") + print(f" TERMINAL_LIFETIME_SECONDS: {os.getenv('TERMINAL_LIFETIME_SECONDS', '300')}") diff --git a/atropos/tools/tmux_tool.py b/atropos/tools/tmux_tool.py new file mode 100644 index 0000000000..e7c4a91a92 --- /dev/null +++ b/atropos/tools/tmux_tool.py @@ -0,0 +1,89 @@ +""" +tmux tool schema (sandbox). + +This is a sandbox tool that provides basic tmux session control suitable for +TUI-style terminal interactions: +- send keys (arrow keys, enter, etc.) +- capture the current screen buffer + +Execution is routed by ToolExecutor to the sandbox server's `tmux` backend. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from .base import Tool, ToolResult, ToolSchema + + +class TmuxTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="tmux", + description=( + "Control a per-trajectory tmux session inside the sandbox (stateful terminal). " + "Use this for TUI-style interactions: send keys and capture the current screen." + ), + parameters={ + "action": { + "type": "string", + "description": "Action to perform: start | send_keys | stream | stop.", + "enum": ["start", "send_keys", "stream", "stop", "capture"], + }, + "keys": { + "description": "Keys to send (string or list of strings) when action=send_keys.", + }, + "block": { + "type": "boolean", + "description": "If true, wait for shell command completion (only valid at a shell prompt).", + "default": False, + }, + "min_wait_s": { + "type": "number", + "description": "For non-blocking send_keys, sleep this long after sending keys (seconds).", + "default": 0.0, + }, + "max_wait_s": { + "type": "number", + "description": "For blocking send_keys, max time to wait for completion (seconds).", + }, + "capture_entire": { + "type": "boolean", + "description": "Deprecated. Streaming is preferred.", + "default": False, + }, + "max_bytes": { + "type": "integer", + "description": "Max bytes to return per stream call.", + }, + "reset": { + "type": "boolean", + "description": "If true, reset stream offset to the beginning of the asciinema recording.", + "default": False, + }, + "pane_width": { + "type": "integer", + "description": "Pane width for action=start (columns).", + "minimum": 20, + }, + "pane_height": { + "type": "integer", + "description": "Pane height for action=start (rows).", + "minimum": 10, + }, + }, + required=["action"], + ) + + def is_available(self) -> tuple[bool, str | None]: + return True, None + + async def execute(self, **kwargs: Dict[str, Any]) -> ToolResult: + # This tool is intended to be executed via ToolExecutor -> sandbox server. + # We keep a safe fallback for non-sandbox contexts. + action = str(kwargs.get("action") or "").strip() + return ToolResult( + success=False, + error=f"tmux tool must be executed in the sandbox (got action={action!r})", + ) diff --git a/atropos/tools/tool_executor.py b/atropos/tools/tool_executor.py new file mode 100644 index 0000000000..148354db3b --- /dev/null +++ b/atropos/tools/tool_executor.py @@ -0,0 +1,501 @@ +""" +ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories. + +This component is responsible for: +- Maintaining trajectory -> Slot affinity (workspace continuity) +- Batching sandbox tool calls across trajectories to maximize container utilization +- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5) + +For now, only sandbox tools are executed: +- bash +- read_file +- write_file +""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import httpx + +from .base import ( + ArtifactArchiveRequestPayload, + ArtifactArchiveResponsePayload, + ArtifactListRequestPayload, + ArtifactListResponsePayload, + ArtifactReadRequestPayload, + ArtifactReadResponsePayload, + ToolCall, + ToolCallPayload, + ToolRegistry, + ToolResult, + ToolResultPayload, + ToolServerExecuteRequest, +) +from ..slots import Slot, SlotPool + + +@dataclass +class ToolExecutorConfig: + batch_window_ms: int = 20 + max_batch_size: int = 200 + allow_network: bool = True + require_sandbox: bool = False + require_stateful_sandbox: bool = False + tool_server_url: Optional[str] = None + tool_server_token: Optional[str] = None + + +@dataclass +class _QueuedToolRequest: + trajectory_id: str + call: ToolCall + timeout_s: Optional[float] + future: asyncio.Future + + +class ToolExecutor: + def __init__( + self, + pool: SlotPool, + tools: ToolRegistry, + config: Optional[ToolExecutorConfig] = None, + ) -> None: + self.pool = pool + self.tools = tools + self.config = config or ToolExecutorConfig() + + self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue() + self._task: Optional[asyncio.Task] = None + self._stopping = asyncio.Event() + + self._slots_lock = asyncio.Lock() + self._slot_by_trajectory: Dict[str, Slot] = {} + + self._tool_server_client: Optional[httpx.AsyncClient] = None + self._tool_server_lock = asyncio.Lock() + + # lightweight stats for status endpoints + self.total_requests: int = 0 + self.total_errors: int = 0 + self.latencies_s: List[float] = [] + + async def start(self) -> None: + if self._task is None: + self._task = asyncio.create_task(self._run_loop()) + + def queue_size(self) -> int: + return self._queue.qsize() + + async def close(self) -> None: + self._stopping.set() + await self._queue.put(None) + if self._task: + await self._task + self._task = None + + client = self._tool_server_client + self._tool_server_client = None + if client is not None: + await client.aclose() + + # Best-effort release any remaining slots. + async with self._slots_lock: + slots = list(self._slot_by_trajectory.items()) + self._slot_by_trajectory.clear() + + for _, slot in slots: + try: + await self.pool.release(slot, reset_workspace=False) + except Exception: + pass + + async def execute( + self, + trajectory_id: str, + call: ToolCall, + timeout_s: Optional[float] = None, + ) -> ToolResult: + if self._task is None: + raise RuntimeError("ToolExecutor not started (call start() first)") + + # Allow tool args to suggest a timeout (Hermes-compatible terminal tool), + # but never let the model choose "infinite" timeouts. + if timeout_s is None: + raw_timeout = call.arguments.get("timeout") + if isinstance(raw_timeout, (int, float)): + timeout_s = float(raw_timeout) + if timeout_s is not None: + timeout_s = max(1.0, min(float(timeout_s), 600.0)) + + loop = asyncio.get_running_loop() + fut: asyncio.Future = loop.create_future() + started = time.perf_counter() + await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut)) + try: + result: ToolResult = await fut + return result + finally: + self.latencies_s.append(time.perf_counter() - started) + + async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None: + async with self._slots_lock: + slot = self._slot_by_trajectory.pop(trajectory_id, None) + + if slot is not None: + await self.pool.release(slot, reset_workspace=reset_workspace) + + async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]: + async with self._slots_lock: + return self._slot_by_trajectory.get(trajectory_id) + + # --------------------------------------------------------------------- + # Artifact helpers (optional) + # --------------------------------------------------------------------- + + async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload: + slot = await self._get_slot_if_present(req.trajectory_id) + if slot is None: + return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") + data = await self.pool.executor.read_artifact( + slot, + req.path, + encoding=req.encoding, + max_bytes=req.max_bytes, + include_sha256=req.include_sha256, + ) + if isinstance(data, dict): + data = dict(data) + data.pop("http_status", None) + try: + return ArtifactReadResponsePayload(**(data or {})) + except Exception as e: + return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}") + + async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload: + slot = await self._get_slot_if_present(req.trajectory_id) + if slot is None: + return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") + data = await self.pool.executor.list_artifacts( + slot, + req.path, + recursive=req.recursive, + max_entries=req.max_entries, + ) + if isinstance(data, dict): + data = dict(data) + data.pop("http_status", None) + try: + return ArtifactListResponsePayload(**(data or {})) + except Exception as e: + return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}") + + async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload: + slot = await self._get_slot_if_present(req.trajectory_id) + if slot is None: + return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)") + data = await self.pool.executor.archive_artifacts( + slot, + req.path, + archive_format=req.format, + max_bytes=req.max_bytes, + max_entries=req.max_entries, + ) + if isinstance(data, dict): + data = dict(data) + data.pop("http_status", None) + try: + return ArtifactArchiveResponsePayload(**(data or {})) + except Exception as e: + return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}") + + async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot: + async with self._slots_lock: + existing = self._slot_by_trajectory.get(trajectory_id) + if existing is not None: + return existing + + slot = await self.pool.acquire(trajectory_id) + + async with self._slots_lock: + existing = self._slot_by_trajectory.get(trajectory_id) + if existing is not None: + # Another coroutine won the race; return its slot. + await self.pool.release(slot, reset_workspace=False) + return existing + self._slot_by_trajectory[trajectory_id] = slot + return slot + + async def _run_loop(self) -> None: + pending: List[_QueuedToolRequest] = [] + deadline: Optional[float] = None + + batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0) + max_batch = max(1, self.config.max_batch_size) + + while True: + if self._stopping.is_set() and self._queue.empty() and not pending: + break + + timeout = None + if pending and deadline is not None: + timeout = max(0.0, deadline - time.perf_counter()) + + try: + item = await asyncio.wait_for(self._queue.get(), timeout=timeout) + if item is None: + continue + pending.append(item) + if len(pending) == 1: + deadline = time.perf_counter() + batch_window_s + if len(pending) < max_batch: + continue + except asyncio.TimeoutError: + # batch window elapsed + pass + + if not pending: + deadline = None + continue + + batch = pending + pending = [] + deadline = None + + await self._execute_batch(batch) + + async def _get_tool_server_client(self) -> httpx.AsyncClient: + url = self.config.tool_server_url + if not url: + raise RuntimeError("ToolServer not configured") + + if self._tool_server_client is not None: + return self._tool_server_client + + async with self._tool_server_lock: + if self._tool_server_client is None: + self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/")) + return self._tool_server_client + + def _tool_server_headers(self) -> Dict[str, str]: + token = self.config.tool_server_token + if not token: + return {} + return {"Authorization": f"Bearer {token}"} + + async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult: + client = await self._get_tool_server_client() + slot_id: Optional[str] = None + container_addr: Optional[str] = None + slot = await self._get_slot_if_present(req.trajectory_id) + if slot is not None: + slot_id = slot.slot_id + container_addr = slot.container_addr + + payload = ToolServerExecuteRequest( + trajectory_id=req.trajectory_id, + tool=ToolCallPayload.from_tool_call(req.call), + timeout_s=req.timeout_s, + slot_id=slot_id, + container_addr=container_addr, + ) + + try: + resp = await client.post( + "/execute", + json=payload.model_dump(), + headers=self._tool_server_headers(), + timeout=req.timeout_s, + ) + resp.raise_for_status() + data = resp.json() + parsed = ToolResultPayload(**data) + result = parsed.to_tool_result() + if result.uniq_id is None: + result.uniq_id = req.call.uniq_id + return result + except Exception as e: + return ToolResult( + success=False, + error=f"External tool failed: {e}", + uniq_id=req.call.uniq_id, + ) + + async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None: + # Resolve tool schemas once per request and separate sandbox/external/unknown. + sandbox_items: List[_QueuedToolRequest] = [] + external_items: List[_QueuedToolRequest] = [] + unknown_items: List[_QueuedToolRequest] = [] + + for it in batch: + tool = self.tools.get(it.call.name) + if tool is None: + unknown_items.append(it) + continue + + schema = tool.schema + if not schema.external: + sandbox_items.append(it) + else: + external_items.append(it) + + for it in unknown_items: + self.total_requests += 1 + self.total_errors += 1 + if not it.future.done(): + it.future.set_result( + ToolResult( + success=False, + error=f"Unknown tool: {it.call.name}", + uniq_id=it.call.uniq_id, + ) + ) + + if external_items: + if not self.config.tool_server_url: + for it in external_items: + self.total_requests += 1 + self.total_errors += 1 + if not it.future.done(): + it.future.set_result( + ToolResult( + success=False, + error=f"External tool not available (ToolServer not configured): {it.call.name}", + uniq_id=it.call.uniq_id, + ) + ) + else: + results = await asyncio.gather(*[self._execute_external(it) for it in external_items]) + for it, res in zip(external_items, results): + self.total_requests += 1 + if not getattr(res, "success", False): + self.total_errors += 1 + if not it.future.done(): + it.future.set_result(res) + + if not sandbox_items: + return + + # Acquire slots for the distinct trajectories in this batch. + try: + traj_ids = list({it.trajectory_id for it in sandbox_items}) + slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids]) + slot_by_traj = dict(zip(traj_ids, slots)) + except Exception as e: + for it in sandbox_items: + self.total_requests += 1 + self.total_errors += 1 + if not it.future.done(): + it.future.set_result( + ToolResult( + success=False, + error=f"Failed to acquire slot: {e}", + uniq_id=it.call.uniq_id, + ) + ) + return + + # Group by timeout so we don't accidentally make short timeouts wait on long ones. + by_timeout: Dict[float, List[_QueuedToolRequest]] = {} + default_timeout = None + if self.pool.executor.timeout.total is not None: + default_timeout = float(self.pool.executor.timeout.total) + + for it in sandbox_items: + t = it.timeout_s + if t is None: + t = default_timeout + if t is None: + t = 30.0 + by_timeout.setdefault(float(t), []).append(it) + + for timeout_s, items in by_timeout.items(): + requests = [] + dispatched: List[_QueuedToolRequest] = [] + for it in items: + slot = slot_by_traj[it.trajectory_id] + tool_name = it.call.name + args = dict(it.call.arguments) + + # Hermes compatibility: treat `terminal` as an alias of sandbox `bash`. + if tool_name == "terminal": + if args.get("background"): + self.total_requests += 1 + self.total_errors += 1 + if not it.future.done(): + it.future.set_result( + ToolResult( + success=False, + error="terminal background execution is not supported in sandbox", + uniq_id=it.call.uniq_id, + ) + ) + continue + tool_name = "bash" + # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args. + args.pop("timeout", None) + elif tool_name == "terminal_stateful": + tool_name = "bash_stateful" + args.pop("timeout", None) + elif tool_name == "tmux": + # `tmux` is a sandbox tool backed by the stateful session manager. + # Network policy is env-controlled. + args.pop("allow_network", None) + + if tool_name == "bash": + # Network policy is set by the environment/executor, not by the model. + args.pop("allow_network", None) + args.pop("require_sandbox", None) + args["allow_network"] = bool(self.config.allow_network) + args["require_sandbox"] = bool(self.config.require_sandbox) + # `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args. + args.pop("timeout", None) + elif tool_name == "bash_stateful": + # Network policy is set by the environment/executor, not by the model. + args.pop("allow_network", None) + args.pop("require_sandbox", None) + args.pop("require_stateful_sandbox", None) + args["allow_network"] = bool(self.config.allow_network) + args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox) + args.pop("timeout", None) + elif tool_name == "tmux": + # Network policy applies to the underlying stateful session. + args.pop("allow_network", None) + args.pop("require_sandbox", None) + args.pop("require_stateful_sandbox", None) + args["allow_network"] = bool(self.config.allow_network) + args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox) + + requests.append((slot, tool_name, args)) + dispatched.append(it) + + results = None + try: + if not dispatched: + continue + results = await self.pool.execute_batch(requests, timeout=timeout_s) + except Exception as e: + for it in items: + self.total_requests += 1 + self.total_errors += 1 + if not it.future.done(): + it.future.set_result( + ToolResult( + success=False, + error=f"Batch execution failed: {e}", + uniq_id=it.call.uniq_id, + ) + ) + continue + + for it, res in zip(dispatched, results): + self.total_requests += 1 + if not getattr(res, "success", False): + self.total_errors += 1 + tool_result = res.to_tool_result() + tool_result.uniq_id = it.call.uniq_id + if not it.future.done(): + it.future.set_result(tool_result) diff --git a/atropos/tools/toolset_distributions.py b/atropos/tools/toolset_distributions.py new file mode 100644 index 0000000000..885649e1ab --- /dev/null +++ b/atropos/tools/toolset_distributions.py @@ -0,0 +1,68 @@ +""" +Toolset distributions (Hermes-Agent inspired). + +Distributions are optional helpers for data generation runs where you want +probabilistic inclusion of toolsets per trajectory/item. +""" + +from __future__ import annotations + +import random +from typing import Dict, List, Optional, TypedDict + +from .toolsets import validate_toolset + + +class DistributionDef(TypedDict): + description: str + toolsets: Dict[str, int] + + +DISTRIBUTIONS: Dict[str, DistributionDef] = { + "default": { + "description": "All common sandbox tools.", + "toolsets": {"sandbox": 100}, + }, + "code_agent_plus_image": { + "description": "Sandbox tools with optional image generation.", + "toolsets": {"sandbox": 100, "image_gen": 30}, + }, + "sandbox_only": { + "description": "Only sandbox tools (terminal + filesystem).", + "toolsets": {"sandbox": 100}, + }, +} + + +def get_distribution(name: str) -> Optional[DistributionDef]: + return DISTRIBUTIONS.get(name) + + +def list_distributions() -> Dict[str, DistributionDef]: + return DISTRIBUTIONS.copy() + + +def validate_distribution(name: str) -> bool: + return name in DISTRIBUTIONS + + +def sample_toolsets_from_distribution(distribution_name: str) -> List[str]: + dist = get_distribution(distribution_name) + if not dist: + raise ValueError(f"Unknown distribution: {distribution_name}") + + selected: List[str] = [] + for toolset_name, probability in dist["toolsets"].items(): + if not validate_toolset(toolset_name): + continue + if random.random() * 100 < probability: + selected.append(toolset_name) + + # Ensure at least one toolset if the distribution isn't empty. + if not selected and dist["toolsets"]: + highest = max(dist["toolsets"].items(), key=lambda x: x[1])[0] + if validate_toolset(highest): + selected.append(highest) + + return selected + diff --git a/atropos/tools/toolsets.py b/atropos/tools/toolsets.py new file mode 100644 index 0000000000..ddd4df0586 --- /dev/null +++ b/atropos/tools/toolsets.py @@ -0,0 +1,157 @@ +""" +Toolsets (Hermes-Agent inspired). + +Toolsets are named groups of tools with optional composition (includes). +They are used to decide which tools are advertised to the model and/or enabled +for a particular environment run. + +This module is intentionally lightweight and dependency-free. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Set, TypedDict + + +class ToolsetDef(TypedDict): + description: str + tools: List[str] + includes: List[str] + + +TOOLSETS: Dict[str, ToolsetDef] = { + # Primitive building blocks + "filesystem": { + "description": "Read/write files in the current workspace.", + "tools": ["read_file", "write_file"], + "includes": [], + }, + "terminal": { + "description": "Terminal/command execution tools.", + # Prefer `terminal` for Hermes compatibility; keep `bash` as a legacy alias. + "tools": ["terminal", "bash"], + "includes": [], + }, + "terminal_stateful": { + "description": "Stateful terminal execution (enables persistent background processes like tmux).", + "tools": ["terminal_stateful", "tmux"], + "includes": [], + }, + "sandbox": { + "description": "Standard sandbox tools (terminal + filesystem).", + "tools": [], + "includes": ["terminal", "filesystem"], + }, + # External tools (executed via ToolServer) + "web": { + "description": "Web research and content extraction tools (external).", + "tools": ["web_search", "web_extract", "web_crawl"], + "includes": [], + }, + "vision": { + "description": "Vision/image analysis tools (external).", + "tools": ["vision_analyze"], + "includes": [], + }, + "image_gen": { + "description": "Image generation tools (external).", + "tools": ["image_generate"], + "includes": [], + }, + "moa": { + "description": "Advanced reasoning tools (Mixture-of-Agents, external).", + "tools": ["mixture_of_agents"], + "includes": [], + }, + # Convenience presets + "default": { + "description": "Default toolset for code-agent tasks.", + "tools": [], + "includes": ["sandbox"], + }, + "debugging": { + "description": "Debugging toolkit (terminal + web).", + "tools": [], + "includes": ["sandbox", "web"], + }, + "research": { + "description": "Research toolkit (web + vision + reasoning).", + "tools": [], + "includes": ["web", "vision", "moa"], + }, + "safe": { + "description": "Safe toolkit without terminal access.", + "tools": [], + "includes": ["web", "vision", "image_gen", "moa"], + }, + "full": { + "description": "All common tools (sandbox + external).", + "tools": [], + "includes": ["sandbox", "web", "vision", "image_gen", "moa"], + }, +} + + +def get_toolset(name: str) -> Optional[ToolsetDef]: + return TOOLSETS.get(name) + + +def get_toolset_names() -> List[str]: + return list(TOOLSETS.keys()) + + +def validate_toolset(name: str) -> bool: + return name in {"all", "*"} or name in TOOLSETS + + +def resolve_toolset(name: str, visited: Optional[Set[str]] = None) -> List[str]: + """ + Recursively resolve a toolset to a list of tool names. + + Includes are expanded depth-first with cycle protection. + """ + if visited is None: + visited = set() + + if name in {"all", "*"}: + all_tools: Set[str] = set() + for toolset_name in get_toolset_names(): + all_tools.update(resolve_toolset(toolset_name, visited=set())) + return sorted(all_tools) + + if name in visited: + # Cycle: return empty to avoid infinite recursion. + return [] + + visited.add(name) + toolset = TOOLSETS.get(name) + if toolset is None: + return [] + + tools: Set[str] = set(toolset.get("tools", [])) + for included in toolset.get("includes", []): + tools.update(resolve_toolset(included, visited=set(visited))) + return sorted(tools) + + +def resolve_multiple_toolsets(toolset_names: List[str]) -> List[str]: + tools: Set[str] = set() + for name in toolset_names: + tools.update(resolve_toolset(name)) + return sorted(tools) + + +def get_toolset_info(name: str) -> Optional[Dict[str, Any]]: + toolset = get_toolset(name) + if toolset is None: + return None + resolved = resolve_toolset(name) + return { + "name": name, + "description": toolset["description"], + "direct_tools": toolset["tools"], + "includes": toolset["includes"], + "resolved_tools": resolved, + "tool_count": len(resolved), + "is_composite": bool(toolset["includes"]), + } diff --git a/atropos/tools/vision_tools.py b/atropos/tools/vision_tools.py new file mode 100644 index 0000000000..1adc0667a8 --- /dev/null +++ b/atropos/tools/vision_tools.py @@ -0,0 +1,553 @@ +#!/usr/bin/env python3 +""" +Vision Tools Module + +This module provides vision analysis tools that work with image URLs. +Uses Gemini 3 Flash Preview via OpenRouter API for intelligent image understanding. + +Available tools: +- vision_analyze_tool: Analyze images from URLs with custom prompts + +Features: +- Downloads images from URLs and converts to base64 for API compatibility +- Comprehensive image description +- Context-aware analysis based on user queries +- Automatic temporary file cleanup +- Proper error handling and validation +- Debug logging support + +Usage: + from vision_tools import vision_analyze_tool + import asyncio + + # Analyze an image + result = await vision_analyze_tool( + image_url="https://example.com/image.jpg", + user_prompt="What architectural style is this building?" + ) +""" + +import json +import os +import asyncio +import uuid +import datetime +import base64 +from pathlib import Path +from typing import Dict, Any, Optional +from openai import AsyncOpenAI +import httpx # Use httpx for async HTTP requests + +# Initialize OpenRouter API client for vision processing +openrouter_client = AsyncOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1" +) + +# Configuration for vision processing +DEFAULT_VISION_MODEL = "google/gemini-3-flash-preview" + +# Debug mode configuration +DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"vision_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› Vision debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving vision debug log: {str(e)}") + + +def _validate_image_url(url: str) -> bool: + """ + Basic validation of image URL format. + + Args: + url (str): The URL to validate + + Returns: + bool: True if URL appears to be valid, False otherwise + """ + if not url or not isinstance(url, str): + return False + + # Check if it's a valid URL format + if not (url.startswith('http://') or url.startswith('https://')): + return False + + # Check for common image extensions (optional, as URLs may not have extensions) + image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg'] + + return True # Allow all HTTP/HTTPS URLs for flexibility + + +async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path: + """ + Download an image from a URL to a local destination (async) with retry logic. + + Args: + image_url (str): The URL of the image to download + destination (Path): The path where the image should be saved + max_retries (int): Maximum number of retry attempts (default: 3) + + Returns: + Path: The path to the downloaded image + + Raises: + Exception: If download fails after all retries + """ + import asyncio + + # Create parent directories if they don't exist + destination.parent.mkdir(parents=True, exist_ok=True) + + last_error = None + for attempt in range(max_retries): + try: + # Download the image with appropriate headers using async httpx + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get( + image_url, + headers={"User-Agent": "hermes-agent-vision/1.0"}, + ) + response.raise_for_status() + + # Save the image content + destination.write_bytes(response.content) + + return destination + except Exception as e: + last_error = e + if attempt < max_retries - 1: + wait_time = 2 ** (attempt + 1) # 2s, 4s, 8s + print(f"āš ļø Image download failed (attempt {attempt + 1}/{max_retries}): {str(e)[:50]}") + print(f" Retrying in {wait_time}s...") + await asyncio.sleep(wait_time) + else: + print(f"āŒ Image download failed after {max_retries} attempts: {str(e)[:100]}") + + raise last_error + + +def _determine_mime_type(image_path: Path) -> str: + """ + Determine the MIME type of an image based on its file extension. + + Args: + image_path (Path): Path to the image file + + Returns: + str: The MIME type (defaults to image/jpeg if unknown) + """ + extension = image_path.suffix.lower() + mime_types = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + '.svg': 'image/svg+xml' + } + return mime_types.get(extension, 'image/jpeg') + + +def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str: + """ + Convert an image file to a base64-encoded data URL. + + Args: + image_path (Path): Path to the image file + mime_type (Optional[str]): MIME type of the image (auto-detected if None) + + Returns: + str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...") + """ + # Read the image as bytes + data = image_path.read_bytes() + + # Encode to base64 + encoded = base64.b64encode(data).decode("ascii") + + # Determine MIME type + mime = mime_type or _determine_mime_type(image_path) + + # Create data URL + data_url = f"data:{mime};base64,{encoded}" + + return data_url + + +async def vision_analyze_tool( + image_url: str, + user_prompt: str, + model: str = DEFAULT_VISION_MODEL +) -> str: + """ + Analyze an image from a URL using vision AI. + + This tool downloads images from URLs, converts them to base64, and processes + them using Gemini 3 Flash Preview via OpenRouter API. The image is downloaded to a + temporary location and automatically cleaned up after processing. + + The user_prompt parameter is expected to be pre-formatted by the calling + function (typically model_tools.py) to include both full description + requests and specific questions. + + Args: + image_url (str): The URL of the image to analyze (must be http:// or https://) + user_prompt (str): The pre-formatted prompt for the vision model + model (str): The vision model to use (default: google/gemini-3-flash-preview) + + Returns: + str: JSON string containing the analysis results with the following structure: + { + "success": bool, + "analysis": str (defaults to error message if None) + } + + Raises: + Exception: If download fails, analysis fails, or API key is not set + + Note: + - Temporary images are stored in ./temp_vision_images/ + - Images are automatically deleted after processing + - Supports common image formats (JPEG, PNG, GIF, WebP, etc.) + """ + debug_call_data = { + "parameters": { + "image_url": image_url, + "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, + "model": model + }, + "error": None, + "success": False, + "analysis_length": 0, + "model_used": model, + "image_size_bytes": 0 + } + + temp_image_path = None + + try: + print(f"šŸ” Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}", flush=True) + print(f"šŸ“ User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}", flush=True) + + # Validate image URL + if not _validate_image_url(image_url): + raise ValueError("Invalid image URL format. Must start with http:// or https://") + + # Check API key availability + if not os.getenv("OPENROUTER_API_KEY"): + raise ValueError("OPENROUTER_API_KEY environment variable not set") + + # Download the image to a temporary location + print(f"ā¬‡ļø Downloading image from URL...", flush=True) + temp_dir = Path("./temp_vision_images") + temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg" + + await _download_image(image_url, temp_image_path) + + # Get image file size for logging + image_size_bytes = temp_image_path.stat().st_size + image_size_kb = image_size_bytes / 1024 + print(f"āœ… Image downloaded successfully ({image_size_kb:.1f} KB)", flush=True) + + # Convert image to base64 data URL + print(f"šŸ”„ Converting image to base64...", flush=True) + image_data_url = _image_to_base64_data_url(temp_image_path) + # Calculate size in KB for better readability + data_size_kb = len(image_data_url) / 1024 + print(f"āœ… Image converted to base64 ({data_size_kb:.1f} KB)", flush=True) + + debug_call_data["image_size_bytes"] = image_size_bytes + + # Use the prompt as provided (model_tools.py now handles full description formatting) + comprehensive_prompt = user_prompt + + # Prepare the message with base64-encoded image + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": comprehensive_prompt + }, + { + "type": "image_url", + "image_url": { + "url": image_data_url + } + } + ] + } + ] + + print(f"🧠 Processing image with {model}...", flush=True) + + # Call the vision API with reasoning enabled + response = await openrouter_client.chat.completions.create( + model=model, + messages=messages, + temperature=0.1, # Low temperature for consistent analysis + max_tokens=2000, # Generous limit for detailed analysis + extra_body={ + "reasoning": { + "enabled": True, + "effort": "xhigh" + } + } + ) + + # Extract the analysis + analysis = response.choices[0].message.content.strip() + analysis_length = len(analysis) + + print(f"āœ… Image analysis completed ({analysis_length} characters)", flush=True) + + # Prepare successful response + result = { + "success": True, + "analysis": analysis or "There was a problem with the request and the image could not be analyzed." + } + + debug_call_data["success"] = True + debug_call_data["analysis_length"] = analysis_length + + # Log debug information + _log_debug_call("vision_analyze_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + except Exception as e: + error_msg = f"Error analyzing image: {str(e)}" + print(f"āŒ {error_msg}", flush=True) + + # Prepare error response + result = { + "success": False, + "analysis": "There was a problem with the request and the image could not be analyzed." + } + + debug_call_data["error"] = error_msg + _log_debug_call("vision_analyze_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + finally: + # Clean up temporary image file + if temp_image_path and temp_image_path.exists(): + try: + temp_image_path.unlink() + print(f"🧹 Cleaned up temporary image file", flush=True) + except Exception as cleanup_error: + print(f"āš ļø Warning: Could not delete temporary file: {cleanup_error}", flush=True) + + +def check_openrouter_api_key() -> bool: + """ + Check if the OpenRouter API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("OPENROUTER_API_KEY")) + + +def check_vision_requirements() -> bool: + """ + Check if all requirements for vision tools are met. + + Returns: + bool: True if requirements are met, False otherwise + """ + return check_openrouter_api_key() + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +# ============================================================================= +# Atropos-Agent Tool wrapper (Hermes compatibility) +# ============================================================================= + +from .base import Tool, ToolResult, ToolSchema # noqa: E402 + + +def _tool_result_from_json(output: str) -> ToolResult: + try: + data = json.loads(output) + except Exception: + return ToolResult(success=True, output=output) + + if isinstance(data, dict) and "success" in data: + ok = bool(data.get("success")) + if not ok: + err = data.get("error") or "Vision analysis failed" + return ToolResult(success=False, output=output, error=str(err)) + return ToolResult(success=True, output=output) + + return ToolResult(success=True, output=output) + + +class VisionAnalyzeTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="vision_analyze", + description=( + "Analyze images from URLs using AI vision. Provides a comprehensive description and answers a question." + ), + parameters={ + "image_url": { + "type": "string", + "description": "The URL of the image to analyze (must be publicly accessible HTTP/HTTPS URL)", + }, + "question": { + "type": "string", + "description": "A specific question about the image content to answer.", + }, + }, + required=["image_url", "question"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + if not os.getenv("OPENROUTER_API_KEY"): + return False, "OPENROUTER_API_KEY not set" + return True, None + + async def execute(self, image_url: str, question: str) -> ToolResult: + full_prompt = ( + "Fully describe and explain everything about this image, then answer the following question:\n\n" + f"{question}" + ) + output = await vision_analyze_tool(image_url=image_url, user_prompt=full_prompt, model=DEFAULT_VISION_MODEL) + return _tool_result_from_json(output) + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("šŸ‘ļø Vision Tools Module") + print("=" * 40) + + # Check if API key is available + api_available = check_openrouter_api_key() + + if not api_available: + print("āŒ OPENROUTER_API_KEY environment variable not set") + print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'") + print("Get API key at: https://openrouter.ai/") + exit(1) + else: + print("āœ… OpenRouter API key found") + + print("šŸ› ļø Vision tools ready for use!") + print(f"🧠 Using model: {DEFAULT_VISION_MODEL}") + + # Show debug mode status + if DEBUG_MODE: + print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("šŸ› Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from vision_tools import vision_analyze_tool") + print(" import asyncio") + print("") + print(" async def main():") + print(" result = await vision_analyze_tool(") + print(" image_url='https://example.com/image.jpg',") + print(" user_prompt='What do you see in this image?'") + print(" )") + print(" print(result)") + print(" asyncio.run(main())") + + print("\nExample prompts:") + print(" - 'What architectural style is this building?'") + print(" - 'Describe the emotions and mood in this image'") + print(" - 'What text can you read in this image?'") + print(" - 'Identify any safety hazards visible'") + print(" - 'What products or brands are shown?'") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export VISION_TOOLS_DEBUG=true") + print(" # Debug logs capture all vision analysis calls and results") + print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json") diff --git a/atropos/tools/web_tools.py b/atropos/tools/web_tools.py new file mode 100644 index 0000000000..f0e253afce --- /dev/null +++ b/atropos/tools/web_tools.py @@ -0,0 +1,1417 @@ +#!/usr/bin/env python3 +""" +Standalone Web Tools Module + +This module provides generic web tools that work with multiple backend providers. +Currently uses Firecrawl as the backend, and the interface makes it easy to swap +providers without changing the function signatures. + +Available tools: +- web_search_tool: Search the web for information +- web_extract_tool: Extract content from specific web pages +- web_crawl_tool: Crawl websites with specific instructions + +Backend compatibility: +- Firecrawl: https://docs.firecrawl.dev/introduction + +LLM Processing: +- Uses OpenRouter API with Gemini 3 Flash Preview for intelligent content extraction +- Extracts key excerpts and creates markdown summaries to reduce token usage + +Debug Mode: +- Set WEB_TOOLS_DEBUG=true to enable detailed logging +- Creates web_tools_debug_UUID.json in ./logs directory +- Captures all tool calls, results, and compression metrics + +Usage: + from web_tools import web_search_tool, web_extract_tool, web_crawl_tool + + # Search the web + results = web_search_tool("Python machine learning libraries", limit=3) + + # Extract content from URLs + content = web_extract_tool(["https://example.com"], format="markdown") + + # Crawl a website + crawl_data = web_crawl_tool("example.com", "Find contact information") +""" + +#TODO: Search Capabilities over the scraped pages +#TODO: Store the pages in something +#TODO: Tool to see what pages are available/saved to search over + +import json +import os +import re +import asyncio +import uuid +import datetime +from pathlib import Path +from typing import List, Dict, Any, Optional +from openai import AsyncOpenAI + +try: + from firecrawl import Firecrawl # type: ignore +except Exception: # pragma: no cover - optional dependency + Firecrawl = None # type: ignore[assignment] + +_firecrawl_client = None +_summarizer_client: Optional[AsyncOpenAI] = None + + +def _get_firecrawl_client(): + global _firecrawl_client + + if _firecrawl_client is not None: + return _firecrawl_client + + if Firecrawl is None: + raise RuntimeError("firecrawl package not installed") + + api_key = os.getenv("FIRECRAWL_API_KEY") + if not api_key: + raise RuntimeError("FIRECRAWL_API_KEY environment variable not set") + + _firecrawl_client = Firecrawl(api_key=api_key) + return _firecrawl_client + + +def _get_summarizer_client() -> AsyncOpenAI: + global _summarizer_client + + if _summarizer_client is not None: + return _summarizer_client + + _summarizer_client = AsyncOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1", + ) + return _summarizer_client + +# Configuration for LLM processing +DEFAULT_SUMMARIZER_MODEL = "google/gemini-3-flash-preview" +DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 + +# Debug mode configuration +DEBUG_MODE = os.getenv("WEB_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› Debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"web_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› Debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving debug log: {str(e)}") + + +async def process_content_with_llm( + content: str, + url: str = "", + title: str = "", + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> Optional[str]: + """ + Process web content using LLM to create intelligent summaries with key excerpts. + + This function uses Gemini 3 Flash Preview (or specified model) via OpenRouter API + to intelligently extract key information and create markdown summaries, + significantly reducing token usage while preserving all important information. + + For very large content (>500k chars), uses chunked processing with synthesis. + For extremely large content (>2M chars), refuses to process entirely. + + Args: + content (str): The raw content to process + url (str): The source URL (for context, optional) + title (str): The page title (for context, optional) + model (str): The model to use for processing (default: google/gemini-3-flash-preview) + min_length (int): Minimum content length to trigger processing (default: 5000) + + Returns: + Optional[str]: Processed markdown content, or None if content too short or processing fails + """ + # Size thresholds + MAX_CONTENT_SIZE = 2_000_000 # 2M chars - refuse entirely above this + CHUNK_THRESHOLD = 500_000 # 500k chars - use chunked processing above this + CHUNK_SIZE = 100_000 # 100k chars per chunk + MAX_OUTPUT_SIZE = 5000 # Hard cap on final output size + + try: + content_len = len(content) + + # Refuse if content is absurdly large + if content_len > MAX_CONTENT_SIZE: + size_mb = content_len / 1_000_000 + print(f"🚫 Content too large ({size_mb:.1f}MB > 2MB limit). Refusing to process.") + return f"[Content too large to process: {size_mb:.1f}MB. Try using web_crawl with specific extraction instructions, or search for a more focused source.]" + + # Skip processing if content is too short + if content_len < min_length: + print(f"šŸ“ Content too short ({content_len} < {min_length} chars), skipping LLM processing") + return None + + # If we don't have an OpenRouter key, skip processing rather than failing. + if not os.getenv("OPENROUTER_API_KEY"): + print("šŸ”‘ OPENROUTER_API_KEY not set; skipping LLM processing") + return None + + # Create context information + context_info = [] + if title: + context_info.append(f"Title: {title}") + if url: + context_info.append(f"Source: {url}") + context_str = "\n".join(context_info) + "\n\n" if context_info else "" + + # Check if we need chunked processing + if content_len > CHUNK_THRESHOLD: + print(f"šŸ“¦ Content large ({content_len:,} chars). Using chunked processing...") + return await _process_large_content_chunked( + content, context_str, model, CHUNK_SIZE, MAX_OUTPUT_SIZE + ) + + # Standard single-pass processing for normal content + print(f"🧠 Processing content with LLM ({content_len} characters)") + + processed_content = await _call_summarizer_llm(content, context_str, model) + + if processed_content: + # Enforce output cap + if len(processed_content) > MAX_OUTPUT_SIZE: + processed_content = processed_content[:MAX_OUTPUT_SIZE] + "\n\n[... summary truncated for context management ...]" + + # Log compression metrics + processed_length = len(processed_content) + compression_ratio = processed_length / content_len if content_len > 0 else 1.0 + print(f"āœ… Content processed: {content_len} → {processed_length} chars ({compression_ratio:.1%})") + + return processed_content + + except Exception as e: + print(f"āŒ Error processing content with LLM: {str(e)}") + return f"[Failed to process content: {str(e)[:100]}. Content size: {len(content):,} chars]" + + +async def _call_summarizer_llm( + content: str, + context_str: str, + model: str, + max_tokens: int = 4000, + is_chunk: bool = False, + chunk_info: str = "" +) -> Optional[str]: + """ + Make a single LLM call to summarize content. + + Args: + content: The content to summarize + context_str: Context information (title, URL) + model: Model to use + max_tokens: Maximum output tokens + is_chunk: Whether this is a chunk of a larger document + chunk_info: Information about chunk position (e.g., "Chunk 2/5") + + Returns: + Summarized content or None on failure + """ + if is_chunk: + # Chunk-specific prompt - aware that this is partial content + system_prompt = """You are an expert content analyst processing a SECTION of a larger document. Your job is to extract and summarize the key information from THIS SECTION ONLY. + +Important guidelines for chunk processing: +1. Do NOT write introductions or conclusions - this is a partial document +2. Focus on extracting ALL key facts, figures, data points, and insights from this section +3. Preserve important quotes, code snippets, and specific details verbatim +4. Use bullet points and structured formatting for easy synthesis later +5. Note any references to other sections (e.g., "as mentioned earlier", "see below") without trying to resolve them + +Your output will be combined with summaries of other sections, so focus on thorough extraction rather than narrative flow.""" + + user_prompt = f"""Extract key information from this SECTION of a larger document: + +{context_str}{chunk_info} + +SECTION CONTENT: +{content} + +Extract all important information from this section in a structured format. Focus on facts, data, insights, and key details. Do not add introductions or conclusions.""" + + else: + # Standard full-document prompt + system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. + +Create a well-structured markdown summary that includes: +1. Key excerpts (quotes, code snippets, important facts) in their original format +2. Comprehensive summary of all other important information +3. Proper markdown formatting with headers, bullets, and emphasis + +Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized.""" + + user_prompt = f"""Please process this web content and create a comprehensive markdown summary: + +{context_str}CONTENT TO PROCESS: +{content} + +Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights.""" + + # Call the LLM with retry logic + max_retries = 6 + retry_delay = 2 + last_error = None + + for attempt in range(max_retries): + try: + response = await _get_summarizer_client().chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.1, + max_tokens=max_tokens, + extra_body={ + "reasoning": { + "enabled": True, + "effort": "xhigh" + } + } + ) + return response.choices[0].message.content.strip() + except Exception as api_error: + last_error = api_error + if attempt < max_retries - 1: + print(f"āš ļø LLM API call failed (attempt {attempt + 1}/{max_retries}): {str(api_error)[:100]}") + print(f" Retrying in {retry_delay}s...") + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, 60) + else: + raise last_error + + return None + + +async def _process_large_content_chunked( + content: str, + context_str: str, + model: str, + chunk_size: int, + max_output_size: int +) -> Optional[str]: + """ + Process large content by chunking, summarizing each chunk in parallel, + then synthesizing the summaries. + + Args: + content: The large content to process + context_str: Context information + model: Model to use + chunk_size: Size of each chunk in characters + max_output_size: Maximum final output size + + Returns: + Synthesized summary or None on failure + """ + # Split content into chunks + chunks = [] + for i in range(0, len(content), chunk_size): + chunk = content[i:i + chunk_size] + chunks.append(chunk) + + print(f" šŸ“¦ Split into {len(chunks)} chunks of ~{chunk_size:,} chars each") + + # Summarize each chunk in parallel + async def summarize_chunk(chunk_idx: int, chunk_content: str) -> tuple[int, Optional[str]]: + """Summarize a single chunk.""" + try: + chunk_info = f"[Processing chunk {chunk_idx + 1} of {len(chunks)}]" + summary = await _call_summarizer_llm( + chunk_content, + context_str, + model, + max_tokens=2000, + is_chunk=True, + chunk_info=chunk_info + ) + if summary: + print(f" āœ… Chunk {chunk_idx + 1}/{len(chunks)} summarized: {len(chunk_content):,} → {len(summary):,} chars") + return chunk_idx, summary + except Exception as e: + print(f" āš ļø Chunk {chunk_idx + 1}/{len(chunks)} failed: {str(e)[:50]}") + return chunk_idx, None + + # Run all chunk summarizations in parallel + tasks = [summarize_chunk(i, chunk) for i, chunk in enumerate(chunks)] + results = await asyncio.gather(*tasks) + + # Collect successful summaries in order + summaries = [] + for chunk_idx, summary in sorted(results, key=lambda x: x[0]): + if summary: + summaries.append(f"## Section {chunk_idx + 1}\n{summary}") + + if not summaries: + print(f" āŒ All chunk summarizations failed") + return "[Failed to process large content: all chunk summarizations failed]" + + print(f" šŸ“Š Got {len(summaries)}/{len(chunks)} chunk summaries") + + # If only one chunk succeeded, just return it (with cap) + if len(summaries) == 1: + result = summaries[0] + if len(result) > max_output_size: + result = result[:max_output_size] + "\n\n[... truncated ...]" + return result + + # Synthesize the summaries into a final summary + print(f" šŸ”— Synthesizing {len(summaries)} summaries...") + + combined_summaries = "\n\n---\n\n".join(summaries) + + synthesis_prompt = f"""You have been given summaries of different sections of a large document. +Synthesize these into ONE cohesive, comprehensive summary that: +1. Removes redundancy between sections +2. Preserves all key facts, figures, and actionable information +3. Is well-organized with clear structure +4. Is under {max_output_size} characters + +{context_str}SECTION SUMMARIES: +{combined_summaries} + +Create a single, unified markdown summary.""" + + try: + response = await _get_summarizer_client().chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "You synthesize multiple summaries into one cohesive, comprehensive summary. Be thorough but concise."}, + {"role": "user", "content": synthesis_prompt} + ], + temperature=0.1, + max_tokens=4000, + extra_body={ + "reasoning": { + "enabled": True, + "effort": "xhigh" + } + } + ) + final_summary = response.choices[0].message.content.strip() + + # Enforce hard cap + if len(final_summary) > max_output_size: + final_summary = final_summary[:max_output_size] + "\n\n[... summary truncated for context management ...]" + + original_len = len(content) + final_len = len(final_summary) + compression = final_len / original_len if original_len > 0 else 1.0 + + print(f" āœ… Synthesis complete: {original_len:,} → {final_len:,} chars ({compression:.2%})") + return final_summary + + except Exception as e: + print(f" āš ļø Synthesis failed: {str(e)[:100]}") + # Fall back to concatenated summaries with truncation + fallback = "\n\n".join(summaries) + if len(fallback) > max_output_size: + fallback = fallback[:max_output_size] + "\n\n[... truncated due to synthesis failure ...]" + return fallback + + +def clean_base64_images(text: str) -> str: + """ + Remove base64 encoded images from text to reduce token count and clutter. + + This function finds and removes base64 encoded images in various formats: + - (data:image/png;base64,...) + - (data:image/jpeg;base64,...) + - (data:image/svg+xml;base64,...) + - data:image/[type];base64,... (without parentheses) + + Args: + text: The text content to clean + + Returns: + Cleaned text with base64 images replaced with placeholders + """ + # Pattern to match base64 encoded images wrapped in parentheses + # Matches: (data:image/[type];base64,[base64-string]) + base64_with_parens_pattern = r'\(data:image/[^;]+;base64,[A-Za-z0-9+/=]+\)' + + # Pattern to match base64 encoded images without parentheses + # Matches: data:image/[type];base64,[base64-string] + base64_pattern = r'data:image/[^;]+;base64,[A-Za-z0-9+/=]+' + + # Replace parentheses-wrapped images first + cleaned_text = re.sub(base64_with_parens_pattern, '[BASE64_IMAGE_REMOVED]', text) + + # Then replace any remaining non-parentheses images + cleaned_text = re.sub(base64_pattern, '[BASE64_IMAGE_REMOVED]', cleaned_text) + + return cleaned_text + + +def web_search_tool(query: str, limit: int = 5) -> str: + """ + Search the web for information using available search API backend. + + This function provides a generic interface for web search that can work + with multiple backends. Currently uses Firecrawl. + + Note: This function returns search result metadata only (URLs, titles, descriptions). + Use web_extract_tool to get full content from specific URLs. + + Args: + query (str): The search query to look up + limit (int): Maximum number of results to return (default: 5) + + Returns: + str: JSON string containing search results with the following structure: + { + "success": bool, + "data": { + "web": [ + { + "title": str, + "url": str, + "description": str, + "position": int + }, + ... + ] + } + } + + Raises: + Exception: If search fails or API key is not set + """ + debug_call_data = { + "parameters": { + "query": query, + "limit": limit + }, + "error": None, + "results_count": 0, + "original_response_size": 0, + "final_response_size": 0 + } + + try: + print(f"šŸ” Searching the web for: '{query}' (limit: {limit})") + + # Use Firecrawl's v2 search functionality WITHOUT scraping + # We only want search result metadata, not scraped content + # Docs: https://docs.firecrawl.dev/features/search + response = _get_firecrawl_client().search( + query=query, + limit=limit + ) + + # The response is a SearchData object with web, news, and images attributes + # When not scraping, the results are directly in these attributes + web_results = [] + + # Check if response has web attribute (SearchData object) + if hasattr(response, 'web'): + # Response is a SearchData object with web attribute + if response.web: + # Convert each SearchResultWeb object to dict + for result in response.web: + if hasattr(result, 'model_dump'): + # Pydantic model - use model_dump + web_results.append(result.model_dump()) + elif hasattr(result, '__dict__'): + # Regular object - use __dict__ + web_results.append(result.__dict__) + elif isinstance(result, dict): + # Already a dict + web_results.append(result) + elif hasattr(response, 'model_dump'): + # Response has model_dump method - use it to get dict + response_dict = response.model_dump() + if 'web' in response_dict and response_dict['web']: + web_results = response_dict['web'] + elif isinstance(response, dict): + # Response is already a dictionary + if 'web' in response and response['web']: + web_results = response['web'] + + results_count = len(web_results) + print(f"āœ… Found {results_count} search results") + + # Build response with just search metadata (URLs, titles, descriptions) + response_data = { + "success": True, + "data": { + "web": web_results + } + } + + # Capture debug information + debug_call_data["results_count"] = results_count + + # Convert to JSON + result_json = json.dumps(response_data, indent=2, ensure_ascii=False) + + debug_call_data["final_response_size"] = len(result_json) + + # Log debug information + _log_debug_call("web_search_tool", debug_call_data) + _save_debug_log() + + return result_json + + except Exception as e: + error_msg = f"Error searching web: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_search_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +async def web_extract_tool( + urls: List[str], + format: str = None, + use_llm_processing: bool = True, + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> str: + """ + Extract content from specific web pages using available extraction API backend. + + This function provides a generic interface for web content extraction that + can work with multiple backends. Currently uses Firecrawl. + + Args: + urls (List[str]): List of URLs to extract content from + format (str): Desired output format ("markdown" or "html", optional) + use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) + model (str): The model to use for LLM processing (default: google/gemini-3-flash-preview) + min_length (int): Minimum content length to trigger LLM processing (default: 5000) + + Returns: + str: JSON string containing extracted content. If LLM processing is enabled and successful, + the 'content' field will contain the processed markdown summary instead of raw content. + + Raises: + Exception: If extraction fails or API key is not set + """ + debug_call_data = { + "parameters": { + "urls": urls, + "format": format, + "use_llm_processing": use_llm_processing, + "model": model, + "min_length": min_length + }, + "error": None, + "pages_extracted": 0, + "pages_processed_with_llm": 0, + "original_response_size": 0, + "final_response_size": 0, + "compression_metrics": [], + "processing_applied": [] + } + + try: + print(f"šŸ“„ Extracting content from {len(urls)} URL(s)") + + # Determine requested formats for Firecrawl v2 + formats: List[str] = [] + if format == "markdown": + formats = ["markdown"] + elif format == "html": + formats = ["html"] + else: + # Default: request markdown for LLM-readiness and include html as backup + formats = ["markdown", "html"] + + # Always use individual scraping for simplicity and reliability + # Batch scraping adds complexity without much benefit for small numbers of URLs + results: List[Dict[str, Any]] = [] + + for url in urls: + try: + print(f" šŸ“„ Scraping: {url}") + scrape_result = _get_firecrawl_client().scrape( + url=url, + formats=formats + ) + + # Process the result - properly handle object serialization + metadata = {} + title = "" + content_markdown = None + content_html = None + + # Extract data from the scrape result + if hasattr(scrape_result, 'model_dump'): + # Pydantic model - use model_dump to get dict + result_dict = scrape_result.model_dump() + content_markdown = result_dict.get('markdown') + content_html = result_dict.get('html') + metadata = result_dict.get('metadata', {}) + elif hasattr(scrape_result, '__dict__'): + # Regular object with attributes + content_markdown = getattr(scrape_result, 'markdown', None) + content_html = getattr(scrape_result, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(scrape_result, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(scrape_result, dict): + # Already a dictionary + content_markdown = scrape_result.get('markdown') + content_html = scrape_result.get('html') + metadata = scrape_result.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Get title from metadata + title = metadata.get("title", "") + + # Choose content based on requested format + chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" + + results.append({ + "url": metadata.get("sourceURL", url), + "title": title, + "content": chosen_content, + "raw_content": chosen_content, + "metadata": metadata # Now guaranteed to be a dict + }) + + except Exception as scrape_err: + print(f" āŒ Error scraping {url}: {str(scrape_err)}") + results.append({ + "url": url, + "title": "", + "content": "", + "raw_content": "", + "error": str(scrape_err) + }) + + response = {"results": results} + + pages_extracted = len(response.get('results', [])) + print(f"āœ… Extracted content from {pages_extracted} pages") + + debug_call_data["pages_extracted"] = pages_extracted + debug_call_data["original_response_size"] = len(json.dumps(response)) + + # Process each result with LLM if enabled + if use_llm_processing and os.getenv("OPENROUTER_API_KEY"): + print("🧠 Processing extracted content with LLM (parallel)...") + debug_call_data["processing_applied"].append("llm_processing") + + # Prepare tasks for parallel processing + async def process_single_result(result): + """Process a single result with LLM and return updated result with metrics.""" + url = result.get('url', 'Unknown URL') + title = result.get('title', '') + raw_content = result.get('raw_content', '') or result.get('content', '') + + if not raw_content: + return result, None, "no_content" + + original_size = len(raw_content) + + # Process content with LLM + processed = await process_content_with_llm( + raw_content, url, title, model, min_length + ) + + if processed: + processed_size = len(processed) + compression_ratio = processed_size / original_size if original_size > 0 else 1.0 + + # Update result with processed content + result['content'] = processed + result['raw_content'] = raw_content + + metrics = { + "url": url, + "original_size": original_size, + "processed_size": processed_size, + "compression_ratio": compression_ratio, + "model_used": model + } + return result, metrics, "processed" + else: + metrics = { + "url": url, + "original_size": original_size, + "processed_size": original_size, + "compression_ratio": 1.0, + "model_used": None, + "reason": "content_too_short" + } + return result, metrics, "too_short" + + # Run all LLM processing in parallel + results_list = response.get('results', []) + tasks = [process_single_result(result) for result in results_list] + processed_results = await asyncio.gather(*tasks) + + # Collect metrics and print results + for result, metrics, status in processed_results: + url = result.get('url', 'Unknown URL') + if status == "processed": + debug_call_data["compression_metrics"].append(metrics) + debug_call_data["pages_processed_with_llm"] += 1 + print(f" šŸ“ {url} (processed)") + elif status == "too_short": + debug_call_data["compression_metrics"].append(metrics) + print(f" šŸ“ {url} (no processing - content too short)") + else: + print(f" āš ļø {url} (no content to process)") + else: + if use_llm_processing and not os.getenv("OPENROUTER_API_KEY"): + print("āš ļø LLM processing requested but OPENROUTER_API_KEY not set, returning raw content") + debug_call_data["processing_applied"].append("llm_processing_unavailable") + + # Print summary of extracted pages for debugging (original behavior) + for result in response.get('results', []): + url = result.get('url', 'Unknown URL') + content_length = len(result.get('raw_content', '')) + print(f" šŸ“ {url} ({content_length} characters)") + + # Trim output to minimal fields per entry: title, content, error + trimmed_results = [ + { + "title": r.get("title", ""), + "content": r.get("content", ""), + "error": r.get("error"), + } + for r in response.get("results", []) + ] + trimmed_response = {"results": trimmed_results} + + if trimmed_response.get("results") == []: + result_json = json.dumps({"error": "Content was inaccessible or not found"}, ensure_ascii=False) + + cleaned_result = clean_base64_images(result_json) + + else: + result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) + + cleaned_result = clean_base64_images(result_json) + + debug_call_data["final_response_size"] = len(cleaned_result) + debug_call_data["processing_applied"].append("base64_image_removal") + + # Log debug information + _log_debug_call("web_extract_tool", debug_call_data) + _save_debug_log() + + return cleaned_result + + except Exception as e: + error_msg = f"Error extracting content: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_extract_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +async def web_crawl_tool( + url: str, + instructions: str = None, + depth: str = "basic", + use_llm_processing: bool = True, + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> str: + """ + Crawl a website with specific instructions using available crawling API backend. + + This function provides a generic interface for web crawling that can work + with multiple backends. Currently uses Firecrawl. + + Args: + url (str): The base URL to crawl (can include or exclude https://) + instructions (str): Instructions for what to crawl/extract using LLM intelligence (optional) + depth (str): Depth of extraction ("basic" or "advanced", default: "basic") + use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) + model (str): The model to use for LLM processing (default: google/gemini-3-flash-preview) + min_length (int): Minimum content length to trigger LLM processing (default: 5000) + + Returns: + str: JSON string containing crawled content. If LLM processing is enabled and successful, + the 'content' field will contain the processed markdown summary instead of raw content. + Each page is processed individually. + + Raises: + Exception: If crawling fails or API key is not set + """ + debug_call_data = { + "parameters": { + "url": url, + "instructions": instructions, + "depth": depth, + "use_llm_processing": use_llm_processing, + "model": model, + "min_length": min_length + }, + "error": None, + "pages_crawled": 0, + "pages_processed_with_llm": 0, + "original_response_size": 0, + "final_response_size": 0, + "compression_metrics": [], + "processing_applied": [] + } + + try: + # Ensure URL has protocol + if not url.startswith(('http://', 'https://')): + url = f'https://{url}' + print(f" šŸ“ Added https:// prefix to URL: {url}") + + instructions_text = f" with instructions: '{instructions}'" if instructions else "" + print(f"šŸ•·ļø Crawling {url}{instructions_text}") + + # Use Firecrawl's v2 crawl functionality + # Docs: https://docs.firecrawl.dev/features/crawl + # The crawl() method automatically waits for completion and returns all data + + # Build crawl parameters - keep it simple + crawl_params = { + "limit": 20, # Limit number of pages to crawl + "scrape_options": { + "formats": ["markdown"] # Just markdown for simplicity + } + } + + # Note: The 'prompt' parameter is not documented for crawl + # Instructions are typically used with the Extract endpoint, not Crawl + if instructions: + print(f" ā„¹ļø Note: Instructions parameter ignored (not supported in crawl API)") + + # Use the crawl method which waits for completion automatically + try: + crawl_result = _get_firecrawl_client().crawl( + url=url, + **crawl_params + ) + except Exception as e: + print(f" āŒ Crawl API call failed: {e}") + raise + + pages: List[Dict[str, Any]] = [] + + # Process crawl results - the crawl method returns a CrawlJob object with data attribute + data_list = [] + + # The crawl_result is a CrawlJob object with a 'data' attribute containing list of Document objects + if hasattr(crawl_result, 'data'): + data_list = crawl_result.data if crawl_result.data else [] + print(f" šŸ“Š Status: {getattr(crawl_result, 'status', 'unknown')}") + print(f" šŸ“„ Retrieved {len(data_list)} pages") + + # Debug: Check other attributes if no data + if not data_list: + print(f" šŸ” Debug - CrawlJob attributes: {[attr for attr in dir(crawl_result) if not attr.startswith('_')]}") + print(f" šŸ” Debug - Status: {getattr(crawl_result, 'status', 'N/A')}") + print(f" šŸ” Debug - Total: {getattr(crawl_result, 'total', 'N/A')}") + print(f" šŸ” Debug - Completed: {getattr(crawl_result, 'completed', 'N/A')}") + + elif isinstance(crawl_result, dict) and 'data' in crawl_result: + data_list = crawl_result.get("data", []) + else: + print(" āš ļø Unexpected crawl result type") + print(f" šŸ” Debug - Result type: {type(crawl_result)}") + if hasattr(crawl_result, '__dict__'): + print(f" šŸ” Debug - Result attributes: {list(crawl_result.__dict__.keys())}") + + for item in data_list: + # Process each crawled page - properly handle object serialization + page_url = "Unknown URL" + title = "" + content_markdown = None + content_html = None + metadata = {} + + # Extract data from the item + if hasattr(item, 'model_dump'): + # Pydantic model - use model_dump to get dict + item_dict = item.model_dump() + content_markdown = item_dict.get('markdown') + content_html = item_dict.get('html') + metadata = item_dict.get('metadata', {}) + elif hasattr(item, '__dict__'): + # Regular object with attributes + content_markdown = getattr(item, 'markdown', None) + content_html = getattr(item, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(item, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(item, dict): + # Already a dictionary + content_markdown = item.get('markdown') + content_html = item.get('html') + metadata = item.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Extract URL and title from metadata + page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL")) + title = metadata.get("title", "") + + # Choose content (prefer markdown) + content = content_markdown or content_html or "" + + pages.append({ + "url": page_url, + "title": title, + "content": content, + "raw_content": content, + "metadata": metadata # Now guaranteed to be a dict + }) + + response = {"results": pages} + + pages_crawled = len(response.get('results', [])) + print(f"āœ… Crawled {pages_crawled} pages") + + debug_call_data["pages_crawled"] = pages_crawled + debug_call_data["original_response_size"] = len(json.dumps(response)) + + # Process each result with LLM if enabled + if use_llm_processing and os.getenv("OPENROUTER_API_KEY"): + print("🧠 Processing crawled content with LLM (parallel)...") + debug_call_data["processing_applied"].append("llm_processing") + + # Prepare tasks for parallel processing + async def process_single_crawl_result(result): + """Process a single crawl result with LLM and return updated result with metrics.""" + page_url = result.get('url', 'Unknown URL') + title = result.get('title', '') + content = result.get('content', '') + + if not content: + return result, None, "no_content" + + original_size = len(content) + + # Process content with LLM + processed = await process_content_with_llm( + content, page_url, title, model, min_length + ) + + if processed: + processed_size = len(processed) + compression_ratio = processed_size / original_size if original_size > 0 else 1.0 + + # Update result with processed content + result['raw_content'] = content + result['content'] = processed + + metrics = { + "url": page_url, + "original_size": original_size, + "processed_size": processed_size, + "compression_ratio": compression_ratio, + "model_used": model + } + return result, metrics, "processed" + else: + metrics = { + "url": page_url, + "original_size": original_size, + "processed_size": original_size, + "compression_ratio": 1.0, + "model_used": None, + "reason": "content_too_short" + } + return result, metrics, "too_short" + + # Run all LLM processing in parallel + results_list = response.get('results', []) + tasks = [process_single_crawl_result(result) for result in results_list] + processed_results = await asyncio.gather(*tasks) + + # Collect metrics and print results + for result, metrics, status in processed_results: + page_url = result.get('url', 'Unknown URL') + if status == "processed": + debug_call_data["compression_metrics"].append(metrics) + debug_call_data["pages_processed_with_llm"] += 1 + print(f" 🌐 {page_url} (processed)") + elif status == "too_short": + debug_call_data["compression_metrics"].append(metrics) + print(f" 🌐 {page_url} (no processing - content too short)") + else: + print(f" āš ļø {page_url} (no content to process)") + else: + if use_llm_processing and not os.getenv("OPENROUTER_API_KEY"): + print("āš ļø LLM processing requested but OPENROUTER_API_KEY not set, returning raw content") + debug_call_data["processing_applied"].append("llm_processing_unavailable") + + # Print summary of crawled pages for debugging (original behavior) + for result in response.get('results', []): + page_url = result.get('url', 'Unknown URL') + content_length = len(result.get('content', '')) + print(f" 🌐 {page_url} ({content_length} characters)") + + # Trim output to minimal fields per entry: title, content, error + trimmed_results = [ + { + "title": r.get("title", ""), + "content": r.get("content", ""), + "error": r.get("error") + } + for r in response.get("results", []) + ] + trimmed_response = {"results": trimmed_results} + + result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) + # Clean base64 images from crawled content + cleaned_result = clean_base64_images(result_json) + + debug_call_data["final_response_size"] = len(cleaned_result) + debug_call_data["processing_applied"].append("base64_image_removal") + + # Log debug information + _log_debug_call("web_crawl_tool", debug_call_data) + _save_debug_log() + + return cleaned_result + + except Exception as e: + error_msg = f"Error crawling website: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_crawl_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +# Convenience function to check if API key is available +def check_firecrawl_api_key() -> bool: + """ + Check if the Firecrawl API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("FIRECRAWL_API_KEY")) and Firecrawl is not None + + +def check_nous_api_key() -> bool: + """ + Check if the Nous Research API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("OPENROUTER_API_KEY")) + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information: + - enabled: Whether debug mode is enabled + - session_id: Current session UUID (if enabled) + - log_path: Path where debug logs are saved (if enabled) + - total_calls: Number of tool calls logged so far (if enabled) + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"web_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +# ============================================================================= +# Atropos-Agent Tool wrappers (Hermes compatibility) +# ============================================================================= + +from .base import Tool, ToolResult, ToolSchema # noqa: E402 + + +def _tool_result_from_json(output: str) -> ToolResult: + try: + data = json.loads(output) + except Exception: + return ToolResult(success=True, output=output) + + if isinstance(data, dict): + if data.get("success") is False: + err = data.get("error") or data.get("message") or "Tool failed" + return ToolResult(success=False, output=output, error=str(err)) + if "error" in data and data.get("error"): + return ToolResult(success=False, output=output, error=str(data["error"])) + + return ToolResult(success=True, output=output) + + +class WebSearchTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="web_search", + description=( + "Search the web for information on any topic. Returns up to 5 relevant results with titles and URLs." + ), + parameters={ + "query": {"type": "string", "description": "The search query to look up on the web"}, + }, + required=["query"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + if Firecrawl is None: + return False, "firecrawl package not installed" + if not os.getenv("FIRECRAWL_API_KEY"): + return False, "FIRECRAWL_API_KEY not set" + return True, None + + async def execute(self, query: str) -> ToolResult: + output = web_search_tool(query, limit=5) + return _tool_result_from_json(output) + + +class WebExtractTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="web_extract", + description=( + "Extract and read the full content from specific web page URLs. Returns excerpts and key points." + ), + parameters={ + "urls": { + "type": "array", + "items": {"type": "string"}, + "description": "List of URLs to extract content from (max 5 URLs per call)", + "maxItems": 5, + }, + }, + required=["urls"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + if Firecrawl is None: + return False, "firecrawl package not installed" + if not os.getenv("FIRECRAWL_API_KEY"): + return False, "FIRECRAWL_API_KEY not set" + return True, None + + async def execute(self, urls: List[str]) -> ToolResult: + safe_urls = urls[:5] if isinstance(urls, list) else [] + output = await web_extract_tool(safe_urls, format="markdown", use_llm_processing=True) + return _tool_result_from_json(output) + + +class WebCrawlTool(Tool): + @property + def schema(self) -> ToolSchema: + return ToolSchema( + name="web_crawl", + description="Crawl a website and extract relevant content across pages.", + parameters={ + "url": {"type": "string", "description": "The base URL to crawl (can include or exclude https://)"}, + "instructions": { + "type": "string", + "description": "Specific instructions for what to crawl/extract (optional).", + }, + }, + required=["url"], + external=True, + ) + + def is_available(self) -> tuple[bool, str | None]: + if Firecrawl is None: + return False, "firecrawl package not installed" + if not os.getenv("FIRECRAWL_API_KEY"): + return False, "FIRECRAWL_API_KEY not set" + return True, None + + async def execute(self, url: str, instructions: Optional[str] = None) -> ToolResult: + output = await web_crawl_tool(url, instructions=instructions, depth="basic", use_llm_processing=True) + return _tool_result_from_json(output) + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("🌐 Standalone Web Tools Module") + print("=" * 40) + + # Check if API keys are available + firecrawl_available = check_firecrawl_api_key() + nous_available = check_nous_api_key() + + if not firecrawl_available: + print("āŒ FIRECRAWL_API_KEY environment variable not set") + print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'") + print("Get API key at: https://firecrawl.dev/") + else: + print("āœ… Firecrawl API key found") + + if not nous_available: + print("āŒ OPENROUTER_API_KEY environment variable not set") + print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'") + print("Get API key at: https://inference-api.nousresearch.com/") + print("āš ļø Without Nous API key, LLM content processing will be disabled") + else: + print("āœ… Nous Research API key found") + + if not firecrawl_available: + exit(1) + + print("šŸ› ļø Web tools ready for use!") + + if nous_available: + print("🧠 LLM content processing available with Gemini 3 Flash Preview via OpenRouter") + print(f" Default min length for processing: {DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION} chars") + + # Show debug mode status + if DEBUG_MODE: + print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/web_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("šŸ› Debug mode disabled (set WEB_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from web_tools import web_search_tool, web_extract_tool, web_crawl_tool") + print(" import asyncio") + print("") + print(" # Search (synchronous)") + print(" results = web_search_tool('Python tutorials')") + print("") + print(" # Extract and crawl (asynchronous)") + print(" async def main():") + print(" content = await web_extract_tool(['https://example.com'])") + print(" crawl_data = await web_crawl_tool('example.com', 'Find docs')") + print(" asyncio.run(main())") + + if nous_available: + print("\nLLM-enhanced usage:") + print(" # Content automatically processed for pages >5000 chars (default)") + print(" content = await web_extract_tool(['https://python.org/about/'])") + print("") + print(" # Customize processing parameters") + print(" crawl_data = await web_crawl_tool(") + print(" 'docs.python.org',") + print(" 'Find key concepts',") + print(" model='google/gemini-3-flash-preview',") + print(" min_length=3000") + print(" )") + print("") + print(" # Disable LLM processing") + print(" raw_content = await web_extract_tool(['https://example.com'], use_llm_processing=False)") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export WEB_TOOLS_DEBUG=true") + print(" # Debug logs capture:") + print(" # - All tool calls with parameters") + print(" # - Original API responses") + print(" # - LLM compression metrics") + print(" # - Final processed results") + print(" # Logs saved to: ./logs/web_tools_debug_UUID.json") + + print(f"\nšŸ“ Run 'python test_web_tools_llm.py' to test LLM processing capabilities") diff --git a/hermes_agent.egg-info/PKG-INFO b/hermes_agent.egg-info/PKG-INFO index edf2642ae9..98fb7f6d68 100644 --- a/hermes_agent.egg-info/PKG-INFO +++ b/hermes_agent.egg-info/PKG-INFO @@ -30,6 +30,10 @@ Requires-Dist: pytest; extra == "dev" Requires-Dist: pytest-asyncio; extra == "dev" Provides-Extra: atropos Requires-Dist: atroposlib @ git+ssh://git@github.com/NousResearch/atropos.git ; extra == "atropos" +Requires-Dist: aiohttp; extra == "atropos" +Requires-Dist: fastapi; extra == "atropos" +Requires-Dist: uvicorn; extra == "atropos" +Requires-Dist: pyte; extra == "atropos" # Hermes Agent diff --git a/hermes_agent.egg-info/SOURCES.txt b/hermes_agent.egg-info/SOURCES.txt index 57a84acdf8..3464034f41 100644 --- a/hermes_agent.egg-info/SOURCES.txt +++ b/hermes_agent.egg-info/SOURCES.txt @@ -8,6 +8,40 @@ run_agent.py toolset_distributions.py toolsets.py trajectory_compressor.py +atropos/__init__.py +atropos/sandbox_server.py +atropos/agent/__init__.py +atropos/agent/atropos_agent.py +atropos/api/__init__.py +atropos/api/tool_executor_server.py +atropos/api/tool_server.py +atropos/envs/__init__.py +atropos/envs/agent_env.py +atropos/envs/hermes_compat_test_env.py +atropos/envs/swe_smith_oracle_env.py +atropos/envs/test_env.py +atropos/nomad/__init__.py +atropos/nomad/client.py +atropos/slots/__init__.py +atropos/slots/executor.py +atropos/slots/pool.py +atropos/slots/slot.py +atropos/terminal/__init__.py +atropos/terminal/asciinema_stream.py +atropos/tools/__init__.py +atropos/tools/base.py +atropos/tools/basic_tools.py +atropos/tools/image_generation_tool.py +atropos/tools/mixture_of_agents_tool.py +atropos/tools/terminal_hecate.py +atropos/tools/terminal_stateful_tool.py +atropos/tools/terminal_tool.py +atropos/tools/tmux_tool.py +atropos/tools/tool_executor.py +atropos/tools/toolset_distributions.py +atropos/tools/toolsets.py +atropos/tools/vision_tools.py +atropos/tools/web_tools.py hermes_agent.egg-info/PKG-INFO hermes_agent.egg-info/SOURCES.txt hermes_agent.egg-info/dependency_links.txt diff --git a/hermes_agent.egg-info/requires.txt b/hermes_agent.egg-info/requires.txt index 1aeae131ef..0ef7437426 100644 --- a/hermes_agent.egg-info/requires.txt +++ b/hermes_agent.egg-info/requires.txt @@ -17,6 +17,10 @@ platformdirs [atropos] atroposlib @ git+ssh://git@github.com/NousResearch/atropos.git +aiohttp +fastapi +uvicorn +pyte [dev] pytest diff --git a/hermes_agent.egg-info/top_level.txt b/hermes_agent.egg-info/top_level.txt index fc74ec88bd..bc15cf1e3b 100644 --- a/hermes_agent.egg-info/top_level.txt +++ b/hermes_agent.egg-info/top_level.txt @@ -1,3 +1,4 @@ +atropos atropos_compatible_agent batch_runner local_server diff --git a/pyproject.toml b/pyproject.toml index e93b0ba936..aa78046db3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,14 @@ dependencies = [ modal = ["modal", "boto3"] dev = ["pytest", "pytest-asyncio"] # Install Atropos from source (PyPI is often stale for this internal dependency). -atropos = ["atroposlib @ git+ssh://git@github.com/NousResearch/atropos.git"] +atropos = [ + "atroposlib @ git+ssh://git@github.com/NousResearch/atropos.git", + # Atropos integration runtime deps (kept optional for Hermes-only users) + "aiohttp", + "fastapi", + "uvicorn", + "pyte", +] [project.scripts] hermes-agent = "run_agent:main" @@ -54,4 +61,4 @@ py-modules = [ ] [tool.setuptools.packages.find] -include = ["tools"] +include = ["tools", "atropos", "atropos.*"] diff --git a/uv.lock b/uv.lock index 20e3909a2d..52eaaccbe0 100644 --- a/uv.lock +++ b/uv.lock @@ -895,7 +895,11 @@ dependencies = [ [package.optional-dependencies] atropos = [ + { name = "aiohttp" }, { name = "atroposlib" }, + { name = "fastapi" }, + { name = "pyte" }, + { name = "uvicorn" }, ] dev = [ { name = "pytest" }, @@ -908,9 +912,11 @@ modal = [ [package.metadata] requires-dist = [ + { name = "aiohttp", marker = "extra == 'atropos'" }, { name = "atroposlib", marker = "extra == 'atropos'", git = "ssh://git@github.com/NousResearch/atropos.git" }, { name = "boto3", marker = "extra == 'modal'" }, { name = "fal-client" }, + { name = "fastapi", marker = "extra == 'atropos'" }, { name = "fire" }, { name = "firecrawl-py" }, { name = "httpx" }, @@ -921,6 +927,7 @@ requires-dist = [ { name = "platformdirs" }, { name = "prompt-toolkit" }, { name = "pydantic", specifier = ">=2.0" }, + { name = "pyte", marker = "extra == 'atropos'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-asyncio", marker = "extra == 'dev'" }, { name = "python-dotenv" }, @@ -929,6 +936,7 @@ requires-dist = [ { name = "rich" }, { name = "tenacity" }, { name = "typer" }, + { name = "uvicorn", marker = "extra == 'atropos'" }, ] provides-extras = ["modal", "dev", "atropos"] @@ -2467,6 +2475,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyte" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/ab/b599762933eba04de7dc5b31ae083112a6c9a9db15b01d3109ad797559d9/pyte-0.8.2.tar.gz", hash = "sha256:5af970e843fa96a97149d64e170c984721f20e52227a2f57f0a54207f08f083f", size = 92301, upload-time = "2023-11-12T09:33:43.217Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/d0/bb522283b90853afbf506cd5b71c650cf708829914efd0003d615cf426cd/pyte-0.8.2-py3-none-any.whl", hash = "sha256:85db42a35798a5aafa96ac4d8da78b090b2c933248819157fc0e6f78876a0135", size = 31627, upload-time = "2023-11-12T09:33:41.096Z" }, +] + [[package]] name = "pytest" version = "9.0.2"