Modal backend stubs

This commit is contained in:
Shannon Sands 2026-02-04 15:20:37 +10:00
parent 7130fa50cb
commit ea7aa0b0d4
7 changed files with 383 additions and 56 deletions

View file

@ -35,7 +35,8 @@ from .base import (
ToolResultPayload,
ToolServerExecuteRequest,
)
from ..slots import Slot, SlotPool
from ..backends.base import ToolBackend
from ..slots import Slot
@dataclass
@ -60,11 +61,11 @@ class _QueuedToolRequest:
class ToolExecutor:
def __init__(
self,
pool: SlotPool,
backend: ToolBackend,
tools: ToolRegistry,
config: Optional[ToolExecutorConfig] = None,
) -> None:
self.pool = pool
self.backend = backend
self.tools = tools
self.config = config or ToolExecutorConfig()
@ -109,7 +110,7 @@ class ToolExecutor:
for _, slot in slots:
try:
await self.pool.release(slot, reset_workspace=False)
await self.backend.release(slot, reset_workspace=False)
except Exception:
pass
@ -146,7 +147,7 @@ class ToolExecutor:
slot = self._slot_by_trajectory.pop(trajectory_id, None)
if slot is not None:
await self.pool.release(slot, reset_workspace=reset_workspace)
await self.backend.release(slot, reset_workspace=reset_workspace)
async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
async with self._slots_lock:
@ -160,7 +161,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.read_artifact(
data = await self.backend.read_artifact(
slot,
req.path,
encoding=req.encoding,
@ -179,7 +180,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.list_artifacts(
data = await self.backend.list_artifacts(
slot,
req.path,
recursive=req.recursive,
@ -197,7 +198,7 @@ class ToolExecutor:
slot = await self._get_slot_if_present(req.trajectory_id)
if slot is None:
return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
data = await self.pool.executor.archive_artifacts(
data = await self.backend.archive_artifacts(
slot,
req.path,
archive_format=req.format,
@ -218,13 +219,13 @@ class ToolExecutor:
if existing is not None:
return existing
slot = await self.pool.acquire(trajectory_id)
slot = await self.backend.acquire(trajectory_id)
async with self._slots_lock:
existing = self._slot_by_trajectory.get(trajectory_id)
if existing is not None:
# Another coroutine won the race; return its slot.
await self.pool.release(slot, reset_workspace=False)
await self.backend.release(slot, reset_workspace=False)
return existing
self._slot_by_trajectory[trajectory_id] = slot
return slot
@ -400,9 +401,7 @@ class ToolExecutor:
# Group by timeout so we don't accidentally make short timeouts wait on long ones.
by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
default_timeout = None
if self.pool.executor.timeout.total is not None:
default_timeout = float(self.pool.executor.timeout.total)
default_timeout = self.backend.default_timeout_s
for it in sandbox_items:
t = it.timeout_s
@ -476,7 +475,7 @@ class ToolExecutor:
try:
if not dispatched:
continue
results = await self.pool.execute_batch(requests, timeout=timeout_s)
results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
except Exception as e:
for it in items:
self.total_requests += 1