Fixes and refactors enabled by recent updates to main.

This commit is contained in:
Robin Fernandes 2026-03-31 09:29:59 +09:00
parent 1126284c97
commit 1b7473e702
5 changed files with 406 additions and 150 deletions

View file

@ -9,13 +9,16 @@ import json
import logging
import shlex
import threading
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
from hermes_constants import get_hermes_home
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
from tools.environments.modal_common import (
BaseModalExecutionEnvironment,
ModalExecStart,
PreparedModalExec,
)
logger = logging.getLogger(__name__)
@ -135,9 +138,20 @@ class _AsyncWorker:
self._thread.join(timeout=10)
class ModalEnvironment(BaseEnvironment):
@dataclass
class _DirectModalExecHandle:
    """Handle for one in-flight Modal exec that is driven by a background thread."""

    # Thread running the exec; its is_alive() state is what polling inspects.
    thread: threading.Thread
    # Shared dict the thread fills in: "value" holds the finished result,
    # "error" holds the exception caught while running (None otherwise).
    result_holder: Dict[str, Any]
class ModalEnvironment(BaseModalExecutionEnvironment):
"""Modal cloud execution via native Modal sandboxes."""
_stdin_mode = "heredoc"
_poll_interval_seconds = 0.2
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
_unexpected_error_prefix = "Modal execution error"
def __init__(
self,
image: str,
@ -312,36 +326,11 @@ class ModalEnvironment(BaseEnvironment):
except Exception as e:
logger.debug("Modal: file sync failed: %s", e)
def execute(
self,
command: str,
cwd: str = "",
*,
timeout: int | None = None,
stdin_data: str | None = None,
) -> dict:
def _before_execute(self) -> None:
    """Run ``_sync_files`` ahead of command execution.

    NOTE(review): presumably invoked by ``BaseModalExecutionEnvironment``
    before each command -- the base class is not visible here; confirm.
    """
    self._sync_files()
if stdin_data is not None:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
while marker in stdin_data:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
exec_command, sudo_stdin = self._prepare_command(command)
# Modal sandboxes execute commands via exec() and cannot pipe
# subprocess stdin directly. When a sudo password is present,
# use a shell-level pipe from printf.
if sudo_stdin is not None:
exec_command = (
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
)
effective_cwd = cwd or self.cwd
effective_timeout = timeout or self.timeout
full_command = f"cd {shlex.quote(effective_cwd)} && {exec_command}"
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
result_holder = {"value": None, "error": None}
def _run():
@ -351,7 +340,7 @@ class ModalEnvironment(BaseEnvironment):
"bash",
"-c",
full_command,
timeout=effective_timeout,
timeout=prepared.timeout,
)
stdout = await process.stdout.read.aio()
stderr = await process.stderr.read.aio()
@ -363,42 +352,31 @@ class ModalEnvironment(BaseEnvironment):
output = stdout
if stderr:
output = f"{stdout}\n{stderr}" if stdout else stderr
return output, exit_code
return self._result(output, exit_code)
output, exit_code = self._worker.run_coroutine(
result_holder["value"] = self._worker.run_coroutine(
_do_execute(),
timeout=effective_timeout + 30,
timeout=prepared.timeout + 30,
)
result_holder["value"] = {
"output": output,
"returncode": exit_code,
}
except Exception as e:
result_holder["error"] = e
t = threading.Thread(target=_run, daemon=True)
t.start()
while t.is_alive():
t.join(timeout=0.2)
if is_interrupted():
try:
self._worker.run_coroutine(
self._sandbox.terminate.aio(),
timeout=15,
)
except Exception:
pass
return {
"output": "[Command interrupted - Modal sandbox terminated]",
"returncode": 130,
}
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
if result_holder["error"]:
return {
"output": f"Modal execution error: {result_holder['error']}",
"returncode": 1,
}
return result_holder["value"]
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
    """Report the outcome of a started exec without blocking.

    Args:
        handle: The handle wrapped in the ``ModalExecStart`` returned by
            ``_start_modal_exec``; carries the runner thread and the
            result holder that thread writes to.

    Returns:
        ``None`` while the runner thread is still alive. Once it has
        finished: an error result if the thread recorded an exception,
        otherwise whatever the thread stored under ``"value"``.
    """
    if handle.thread.is_alive():
        return None

    error = handle.result_holder["error"]
    # Compare against None instead of relying on truthiness: an exception
    # object that evaluates falsy (custom __bool__/__len__) must still be
    # reported rather than falling through to a missing "value".
    if error is not None:
        # Reuse the class-level prefix so the message stays in sync with
        # ``_unexpected_error_prefix`` instead of duplicating the literal.
        return self._error_result(f"{self._unexpected_error_prefix}: {error}")
    return handle.result_holder["value"]
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
    """Cancel a running exec by terminating the sandbox.

    The ``handle`` argument is not consulted. The sandbox's ``terminate``
    coroutine is run on the worker with a 15 second timeout, and any
    exception it raises propagates to the caller.
    """
    run_on_worker = self._worker.run_coroutine
    run_on_worker(self._sandbox.terminate.aio(), timeout=15)
def cleanup(self):
"""Snapshot the filesystem (if persistent) then stop the sandbox."""