Modal backend stubs

This commit is contained in:
Shannon Sands 2026-02-04 15:20:37 +10:00
parent 7130fa50cb
commit ea7aa0b0d4
7 changed files with 383 additions and 56 deletions

View file

@ -1,7 +1,7 @@
"""
Tool Executor API (Phase 4)
This service provides a queued, batched execution layer on top of SlotPool.
This service provides a queued, batched execution layer on top of a ToolBackend.
It mirrors the stateful FastAPI + app.state pattern used in:
atropos/atroposlib/api/server.py
@ -18,7 +18,7 @@ from pathlib import Path
from fastapi import FastAPI, Header, HTTPException, status
from pydantic import BaseModel, Field
from ..slots import SlotPool, SlotPoolConfig
from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend
from ..tools import ToolRegistry, build_tool_registry
from ..tools.base import (
ArtifactArchiveRequestPayload,
@ -123,22 +123,23 @@ async def _startup() -> None:
tool_server_url=cfg.tool_server_url,
)
pool = SlotPool(
SlotPoolConfig(
backend = NomadToolBackend(
NomadBackendConfig(
nomad_address=cfg.nomad_address,
job_id=cfg.job_id,
image=cfg.image,
sandbox_job_id=cfg.job_id,
sandbox_image=cfg.image,
slots_per_container=cfg.slots_per_container,
min_containers=cfg.min_containers,
max_containers=cfg.max_containers,
privileged=cfg.privileged,
acquire_timeout=cfg.acquire_timeout_s,
acquire_timeout_s=cfg.acquire_timeout_s,
purge_job_on_start=False,
)
)
await pool.start()
await backend.start()
executor = ToolExecutor(
pool=pool,
backend=backend,
tools=tools,
config=ToolExecutorConfig(
batch_window_ms=cfg.batch_window_ms,
@ -151,21 +152,21 @@ async def _startup() -> None:
await executor.start()
app.state.cfg = cfg
app.state.pool = pool
app.state.backend = backend
app.state.executor = executor
@app.on_event("shutdown")
async def _shutdown() -> None:
executor: Optional[ToolExecutor] = getattr(app.state, "executor", None)
pool: Optional[SlotPool] = getattr(app.state, "pool", None)
backend: Optional[NomadToolBackend] = getattr(app.state, "backend", None)
cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None)
if executor is not None:
await executor.close()
if pool is not None:
await pool.stop(purge_job=bool(cfg.purge_job_on_shutdown) if cfg else False)
if backend is not None:
await backend.stop(purge=bool(cfg.purge_job_on_shutdown) if cfg else False)
@app.get("/health")
@ -176,13 +177,13 @@ async def health() -> Dict[str, Any]:
@app.get("/status")
async def status_endpoint() -> Dict[str, Any]:
executor: ToolExecutor = app.state.executor
pool: SlotPool = app.state.pool
backend: NomadToolBackend = app.state.backend
return {
"queue_size": executor.queue_size(),
"total_requests": executor.total_requests,
"total_errors": executor.total_errors,
"pool": pool.get_stats(),
"pool": backend.get_stats(),
}

View file

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Any
from .base import ToolBackend
from .modal_backend import ModalBackendConfig, ModalToolBackend
from .nomad_backend import NomadBackendConfig, NomadToolBackend
def create_tool_backend(cfg: Any) -> ToolBackend:
mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
if mode == "nomad":
return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
if mode == "modal":
return ModalToolBackend(ModalBackendConfig.from_agent_env_config(cfg))
raise ValueError(f"Unknown tool_pool_mode: {mode}")
__all__ = [
"ToolBackend",
"create_tool_backend",
"NomadBackendConfig",
"NomadToolBackend",
"ModalBackendConfig",
"ModalToolBackend",
]

89
atropos/backends/base.py Normal file
View file

@ -0,0 +1,89 @@
"""
Backend interfaces for AgentEnv tool execution.
The goal of this module is to decouple ToolExecutor / AgentEnv from any single
execution backend (Nomad/Docker today; Modal later).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Protocol, Tuple
from ..slots.executor import ExecutionResult
from ..slots.slot import Slot
class ToolBackend(Protocol):
"""
Minimal interface required by ToolExecutor.
Backends provide:
- lifecycle (start/stop)
- slot acquisition/release (workspace affinity)
- batched tool execution across slots
- optional artifact helpers (for env verification / demos)
"""
@property
def default_timeout_s(self) -> Optional[float]:
"""Default sandbox execution timeout in seconds (if any)."""
async def start(self) -> None:
"""Start the backend (provision workers/containers, health checks, etc)."""
async def stop(self, *, purge: bool = False) -> None:
"""Stop the backend and optionally purge remote resources."""
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
"""Acquire a slot for a trajectory (workspace affinity)."""
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
"""Release a slot back to the pool."""
async def execute_batch(
self,
requests: List[Tuple[Slot, str, Dict[str, Any]]],
*,
timeout_s: Optional[float] = None,
) -> List[ExecutionResult]:
"""Execute a batch of sandbox tool calls and return results in order."""
# ---------------------------------------------------------------------
# Optional artifact helpers (supported by the Nomad sandbox-server today)
# ---------------------------------------------------------------------
async def read_artifact(
self,
slot: Slot,
path: str,
*,
encoding: str = "text",
max_bytes: Optional[int] = None,
include_sha256: bool = False,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
raise NotImplementedError
async def list_artifacts(
self,
slot: Slot,
path: str = ".",
*,
recursive: bool = False,
max_entries: Optional[int] = None,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
raise NotImplementedError
async def archive_artifacts(
self,
slot: Slot,
path: str = ".",
*,
archive_format: str = "tar.gz",
max_bytes: Optional[int] = None,
max_entries: Optional[int] = None,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
raise NotImplementedError

View file

@ -0,0 +1,73 @@
"""
Modal tool backend (stub).
We intentionally ship a placeholder implementation so AgentEnv can expose a
backend switch without forcing Modal as a hard dependency for Hermes-Agent.
When org access is available, this backend will be implemented by running a
long-lived Modal worker (or pool) that owns N slots and exposes `execute_batch`.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from ..slots.executor import ExecutionResult
from ..slots.slot import Slot
from .base import ToolBackend
@dataclass(frozen=True)
class ModalBackendConfig:
# Placeholders for future implementation.
app_name: str = "atropos-sandbox"
function_name: str = "sandbox_server"
volume_name: Optional[str] = None
volume_mount_path: str = "/data"
@classmethod
def from_agent_env_config(cls, cfg: Any) -> "ModalBackendConfig":
return cls(
app_name=str(getattr(cfg, "modal_app_name", cls.app_name)),
function_name=str(getattr(cfg, "modal_function_name", cls.function_name)),
volume_name=(getattr(cfg, "modal_volume_name", None) or None),
volume_mount_path=str(getattr(cfg, "modal_volume_mount_path", cls.volume_mount_path)),
)
class ModalToolBackend(ToolBackend):
def __init__(self, config: ModalBackendConfig):
self.config = config
@property
def default_timeout_s(self) -> Optional[float]:
return None
def _unavailable(self) -> RuntimeError:
return RuntimeError(
"Modal tool backend is not implemented yet. "
"Keep `--env.tool_pool_mode nomad` for now."
)
async def start(self) -> None:
raise self._unavailable()
async def stop(self, *, purge: bool = False) -> None: # noqa: ARG002
# If start() isn't implemented, stop() is also unavailable.
raise self._unavailable()
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot: # noqa: ARG002
raise self._unavailable()
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None: # noqa: ARG002
raise self._unavailable()
async def execute_batch(
self,
requests: List[Tuple[Slot, str, Dict[str, Any]]],
*,
timeout_s: Optional[float] = None, # noqa: ARG002
) -> List[ExecutionResult]:
raise self._unavailable()

View file

@ -0,0 +1,148 @@
"""
Nomad/Docker tool backend.
This backend is the current default for AgentEnv: it provisions a Nomad job
running `sandbox_server.py` and multiplexes stateless slots inside each container.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from ..slots import Slot, SlotPool, SlotPoolConfig
from ..slots.executor import ExecutionResult
from .base import ToolBackend
@dataclass(frozen=True)
class NomadBackendConfig:
nomad_address: str
sandbox_job_id: str
sandbox_image: str
slots_per_container: int
min_containers: int
max_containers: int
privileged: bool
acquire_timeout_s: float
purge_job_on_start: bool
@classmethod
def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
return cls(
nomad_address=str(getattr(cfg, "nomad_address")),
sandbox_job_id=str(getattr(cfg, "sandbox_job_id")),
sandbox_image=str(getattr(cfg, "sandbox_image")),
slots_per_container=int(getattr(cfg, "slots_per_container")),
min_containers=int(getattr(cfg, "min_containers")),
max_containers=int(getattr(cfg, "max_containers")),
privileged=bool(getattr(cfg, "privileged")),
acquire_timeout_s=float(getattr(cfg, "acquire_timeout_s")),
purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
)
class NomadToolBackend(ToolBackend):
def __init__(self, config: NomadBackendConfig):
self.config = config
self.pool = SlotPool(
SlotPoolConfig(
nomad_address=config.nomad_address,
job_id=config.sandbox_job_id,
image=config.sandbox_image,
slots_per_container=config.slots_per_container,
min_containers=config.min_containers,
max_containers=config.max_containers,
privileged=config.privileged,
acquire_timeout=config.acquire_timeout_s,
purge_job_on_start=bool(config.purge_job_on_start),
)
)
@property
def default_timeout_s(self) -> Optional[float]:
t = getattr(self.pool.executor, "timeout", None)
total = getattr(t, "total", None)
try:
return float(total) if total is not None else None
except Exception:
return None
async def start(self) -> None:
await self.pool.start()
async def stop(self, *, purge: bool = False) -> None:
await self.pool.stop(purge_job=purge)
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
return await self.pool.acquire(trajectory_id)
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
await self.pool.release(slot, reset_workspace=reset_workspace)
async def execute_batch(
self,
requests: List[Tuple[Slot, str, Dict[str, Any]]],
*,
timeout_s: Optional[float] = None,
) -> List[ExecutionResult]:
return await self.pool.execute_batch(requests, timeout=timeout_s)
async def read_artifact(
self,
slot: Slot,
path: str,
*,
encoding: str = "text",
max_bytes: Optional[int] = None,
include_sha256: bool = False,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
return await self.pool.executor.read_artifact(
slot,
path,
encoding=encoding,
max_bytes=max_bytes,
include_sha256=include_sha256,
timeout=timeout_s,
)
async def list_artifacts(
self,
slot: Slot,
path: str = ".",
*,
recursive: bool = False,
max_entries: Optional[int] = None,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
return await self.pool.executor.list_artifacts(
slot,
path,
recursive=recursive,
max_entries=max_entries,
timeout=timeout_s,
)
async def archive_artifacts(
self,
slot: Slot,
path: str = ".",
*,
archive_format: str = "tar.gz",
max_bytes: Optional[int] = None,
max_entries: Optional[int] = None,
timeout_s: Optional[float] = None,
) -> Dict[str, Any]:
return await self.pool.executor.archive_artifacts(
slot,
path,
archive_format=archive_format,
max_bytes=max_bytes,
max_entries=max_entries,
timeout=timeout_s,
)
def get_stats(self) -> Dict[str, Any]:
return self.pool.get_stats()

View file

@ -19,14 +19,14 @@ from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item,
from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight
from ..agent import AgentConfig, AgentResult, AtroposAgent
from ..slots import SlotPool, SlotPoolConfig
from ..backends import ToolBackend, create_tool_backend
from ..tools import ToolRegistry, build_tool_registry
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
# Main BaseEnv child classes. Child class THESE to get agent+tooling functionality easily.
class AgentEnvConfig(BaseEnvConfig):
tool_pool_mode: str = Field(default="nomad", description="Tool execution backend (only 'nomad' is supported)")
tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')")
allow_network: bool = Field(
default=True,
@ -61,6 +61,12 @@ class AgentEnvConfig(BaseEnvConfig):
)
purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown")
# modal mode settings (stub; implementation pending)
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name (stub)")
modal_function_name: str = Field(default="sandbox_server", description="Modal function/actor name (stub)")
modal_volume_name: Optional[str] = Field(default=None, description="Modal Volume name for persistent storage (stub)")
modal_volume_mount_path: str = Field(default="/data", description="Modal Volume mount path (stub)")
# basic agent defaults
agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory")
agent_temperature: float = Field(default=0.7, description="Sampling temperature")
@ -108,7 +114,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
self.tools: ToolRegistry = self.build_tools()
self._pool: Optional[Any] = None
self._backend: Optional[ToolBackend] = None
self._tool_executor: Optional[ToolExecutor] = None
self._tool_server_inprocess: bool = False
self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}
@ -263,27 +269,11 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
tool_server_url = "http://toolserver"
self._tool_server_inprocess = True
if self.config.tool_pool_mode != "nomad":
# TODO Add Modal here, maybe in-process, but not safe to have that tbh
raise RuntimeError("tool_pool_mode must be 'nomad' (local/in-process pools are not supported)")
pool = SlotPool(
SlotPoolConfig(
nomad_address=self.config.nomad_address,
job_id=self.config.sandbox_job_id,
image=self.config.sandbox_image,
slots_per_container=self.config.slots_per_container,
min_containers=self.config.min_containers,
max_containers=self.config.max_containers,
privileged=self.config.privileged,
acquire_timeout=self.config.acquire_timeout_s,
purge_job_on_start=bool(self.config.purge_job_on_start),
)
)
await pool.start()
backend = create_tool_backend(self.config)
await backend.start()
executor = ToolExecutor(
pool=pool,
backend=backend,
tools=self.tools,
config=ToolExecutorConfig(
batch_window_ms=self.config.tool_batch_window_ms,
@ -299,21 +289,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
if tool_server_client is not None:
executor._tool_server_client = tool_server_client # type: ignore[attr-defined]
self._pool = pool
self._backend = backend
self._tool_executor = executor
async def shutdown_tool_backend(self) -> None:
executor = self._tool_executor
pool = self._pool
backend = self._backend
inprocess_tool_server = self._tool_server_inprocess
self._tool_executor = None
self._pool = None
self._backend = None
self._tool_server_inprocess = False
if executor is not None:
await executor.close()
if pool is not None:
await pool.stop(purge_job=bool(self.config.purge_job_on_shutdown))
if backend is not None:
await backend.stop(purge=bool(self.config.purge_job_on_shutdown))
if inprocess_tool_server:
from ..api.tool_server import app as tool_server_app

View file

@ -35,7 +35,8 @@ from .base import (
ToolResultPayload,
ToolServerExecuteRequest,
)
from ..slots import Slot, SlotPool
from ..backends.base import ToolBackend
from ..slots import Slot
@dataclass
@ -60,11 +61,11 @@ class _QueuedToolRequest:
class ToolExecutor:
def __init__(
self,
pool: SlotPool,
backend: ToolBackend,
tools: ToolRegistry,
config: Optional[ToolExecutorConfig] = None,
) -> None:
self.pool = pool
self.backend = backend
self.tools = tools
self.config = config or ToolExecutorConfig()
@ -109,7 +110,7 @@ class ToolExecutor:
for _, slot in slots:
try:
await self.pool.release(slot, reset_workspace=False)
await self.backend.release(slot, reset_workspace=False)
except Exception:
pass
@ -146,7 +147,7 @@ class ToolExecutor:
slot = self._slot_by_trajectory.pop(trajectory_id, None)
if slot is not None:
await self.pool.release(slot, reset_workspace=reset_workspace)
await self.backend.release(slot, reset_workspace=reset_workspace)
async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
async with self._slots_lock:
@ -160,7 +161,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.read_artifact(
data = await self.backend.read_artifact(
slot,
req.path,
encoding=req.encoding,
@ -179,7 +180,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.list_artifacts(
data = await self.backend.list_artifacts(
slot,
req.path,
recursive=req.recursive,
@ -197,7 +198,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.archive_artifacts(
data = await self.backend.archive_artifacts(
slot,
req.path,
archive_format=req.format,
@ -218,13 +219,13 @@ class ToolExecutor:
if existing is not None:
return existing
slot = await self.pool.acquire(trajectory_id)
slot = await self.backend.acquire(trajectory_id)
async with self._slots_lock:
existing = self._slot_by_trajectory.get(trajectory_id)
if existing is not None:
# Another coroutine won the race; return its slot.
await self.pool.release(slot, reset_workspace=False)
await self.backend.release(slot, reset_workspace=False)
return existing
self._slot_by_trajectory[trajectory_id] = slot
return slot
@ -400,9 +401,7 @@ class ToolExecutor:
# Group by timeout so we don't accidentally make short timeouts wait on long ones.
by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
default_timeout = None
if self.pool.executor.timeout.total is not None:
default_timeout = float(self.pool.executor.timeout.total)
default_timeout = self.backend.default_timeout_s
for it in sandbox_items:
t = it.timeout_s
@ -476,7 +475,7 @@ class ToolExecutor:
try:
if not dispatched:
continue
results = await self.pool.execute_batch(requests, timeout=timeout_s)
results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
except Exception as e:
for it in items:
self.total_requests += 1