hermes-agent/tools/environments/ssh.py
Siddharth Balyan 27eeea0555
perf(ssh,modal): bulk file sync via tar pipe and tar/base64 archive (#8014)
* perf(ssh,modal): bulk file sync via tar pipe and tar/base64 archive

SSH: symlink-staging + tar -ch piped over SSH in a single TCP stream.
Eliminates per-file scp round-trips. Handles timeout (kills both
processes), SSH Popen failure (kills tar), and tar create failure.

Modal: in-memory gzipped tar archive, base64-encoded, decoded+extracted
in one exec call. Checks exit code and raises on failure.

Both backends use shared helpers extracted into file_sync.py:
- quoted_mkdir_command() — mirrors existing quoted_rm_command()
- unique_parent_dirs() — deduplicates parent dirs from file pairs

Migrates _ensure_remote_dirs to use the new helpers.

28 new tests (21 SSH + 7 Modal), all passing.

Closes #7465
Closes #7467

* fix(modal): pipe stdin to avoid ARG_MAX, clean up review findings

- Modal bulk upload: stream base64 payload through proc.stdin in 1MB
  chunks instead of embedding in command string (Modal SDK enforces
  64KB ARG_MAX_BYTES — typical payloads are ~4.3MB)
- Modal single-file upload: same stdin fix, add exit code checking
- Remove what-narrating comments in ssh.py and modal.py (keep WHY
  comments: symlink staging rationale, SIGPIPE, deadlock avoidance)
- Remove unnecessary `sandbox = self._sandbox` alias in modal bulk
- Daytona: use shared helpers (unique_parent_dirs, quoted_mkdir_command)
  instead of inlined duplicates

---------

Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com>
2026-04-12 06:18:05 +05:30

258 lines
10 KiB
Python

"""SSH remote execution environment with ControlMaster connection persistence."""
import logging
import os
import shlex
import shutil
import subprocess
import tempfile
from pathlib import Path
from tools.environments.base import BaseEnvironment, _popen_bash
from tools.environments.file_sync import (
FileSyncManager,
iter_sync_files,
quoted_mkdir_command,
quoted_rm_command,
unique_parent_dirs,
)
logger = logging.getLogger(__name__)
def _ensure_ssh_available() -> None:
"""Fail fast with a clear error when the SSH client is unavailable."""
if not shutil.which("ssh"):
raise RuntimeError(
"SSH is not installed or not in PATH. Install OpenSSH client: apt install openssh-client"
)
class SSHEnvironment(BaseEnvironment):
    """Run commands on a remote machine over SSH.

    Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process.
    Session snapshot preserves env vars across calls.
    CWD persists via in-band stdout markers.
    Uses SSH ControlMaster for connection reuse.
    """

    def __init__(self, host: str, user: str, cwd: str = "~",
                 timeout: int = 60, port: int = 22, key_path: str = "") -> None:
        """Connect to ``user@host`` and prepare the remote environment.

        Args:
            host: Remote hostname or IP address.
            user: Remote login user.
            cwd: Initial remote working directory.
            timeout: Default command timeout in seconds (stored by the base class).
            port: SSH port; only added to argv when it differs from 22.
            key_path: Optional private-key path passed to ``ssh -i``.

        Raises:
            RuntimeError: if the ssh client is missing or the initial
                connection cannot be established.
        """
        super().__init__(cwd=cwd, timeout=timeout)
        self.host = host
        self.user = user
        self.port = port
        self.key_path = key_path
        # One ControlMaster socket per (user, host, port) triple, kept in a
        # temp dir so stale sockets never land in the project tree.
        self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
        self.control_dir.mkdir(parents=True, exist_ok=True)
        self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
        _ensure_ssh_available()
        # Order matters: the master connection must exist before home
        # detection, remote dir creation, and the initial file sync run.
        self._establish_connection()
        self._remote_home = self._detect_remote_home()
        self._ensure_remote_dirs()
        self._sync_manager = FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
            upload_fn=self._scp_upload,
            delete_fn=self._ssh_delete,
            bulk_upload_fn=self._ssh_bulk_upload,
        )
        # force=True: first sync must transfer everything regardless of the
        # manager's internal rate limiting.
        self._sync_manager.sync(force=True)
        self.init_session()
def _build_ssh_command(self, extra_args: list | None = None) -> list:
cmd = ["ssh"]
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
cmd.extend(["-o", "ControlMaster=auto"])
cmd.extend(["-o", "ControlPersist=300"])
cmd.extend(["-o", "BatchMode=yes"])
cmd.extend(["-o", "StrictHostKeyChecking=accept-new"])
cmd.extend(["-o", "ConnectTimeout=10"])
if self.port != 22:
cmd.extend(["-p", str(self.port)])
if self.key_path:
cmd.extend(["-i", self.key_path])
if extra_args:
cmd.extend(extra_args)
cmd.append(f"{self.user}@{self.host}")
return cmd
def _establish_connection(self):
cmd = self._build_ssh_command()
cmd.append("echo 'SSH connection established'")
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
if result.returncode != 0:
error_msg = result.stderr.strip() or result.stdout.strip()
raise RuntimeError(f"SSH connection failed: {error_msg}")
except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
def _detect_remote_home(self) -> str:
"""Detect the remote user's home directory."""
try:
cmd = self._build_ssh_command()
cmd.append("echo $HOME")
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
home = result.stdout.strip()
if home and result.returncode == 0:
logger.debug("SSH: remote home = %s", home)
return home
except Exception:
pass
if self.user == "root":
return "/root"
return f"/home/{self.user}"
# ------------------------------------------------------------------
# File sync (via FileSyncManager)
# ------------------------------------------------------------------
def _ensure_remote_dirs(self) -> None:
"""Create base ~/.hermes directory tree on remote in one SSH call."""
base = f"{self._remote_home}/.hermes"
dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
cmd = self._build_ssh_command()
cmd.append(quoted_mkdir_command(dirs))
subprocess.run(cmd, capture_output=True, text=True, timeout=10)
# _get_sync_files provided via iter_sync_files in FileSyncManager init
def _scp_upload(self, host_path: str, remote_path: str) -> None:
"""Upload a single file via scp over ControlMaster."""
parent = str(Path(remote_path).parent)
mkdir_cmd = self._build_ssh_command()
mkdir_cmd.append(f"mkdir -p {shlex.quote(parent)}")
subprocess.run(mkdir_cmd, capture_output=True, text=True, timeout=10)
scp_cmd = ["scp", "-o", f"ControlPath={self.control_socket}"]
if self.port != 22:
scp_cmd.extend(["-P", str(self.port)])
if self.key_path:
scp_cmd.extend(["-i", self.key_path])
scp_cmd.extend([host_path, f"{self.user}@{self.host}:{remote_path}"])
result = subprocess.run(scp_cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
raise RuntimeError(f"scp failed: {result.stderr.strip()}")
    def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None:
        """Upload many files in a single tar-over-SSH stream.

        Pipes ``tar c`` on the local side through an SSH connection to
        ``tar x`` on the remote, transferring all files in one TCP stream
        instead of spawning a subprocess per file. Directory creation is
        batched into a single ``mkdir -p`` call beforehand.

        Typical improvement: ~580 files go from O(N) scp round-trips
        to a single streaming transfer.

        Args:
            files: ``(host_path, remote_path)`` pairs. Remote paths are
                treated as absolute — extraction runs with ``-C /``.

        Raises:
            RuntimeError: on remote mkdir failure, a tar failure on either
                side, or a transfer timeout.
        """
        if not files:
            return
        parents = unique_parent_dirs(files)
        if parents:
            cmd = self._build_ssh_command()
            cmd.append(quoted_mkdir_command(parents))
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
            if result.returncode != 0:
                raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}")
        # Symlink staging avoids fragile GNU tar --transform rules: the
        # staging tree mirrors the desired remote layout, and tar's ``-h``
        # flag dereferences the links so real file contents get archived.
        with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging:
            for host_path, remote_path in files:
                staged = os.path.join(staging, remote_path.lstrip("/"))
                os.makedirs(os.path.dirname(staged), exist_ok=True)
                os.symlink(os.path.abspath(host_path), staged)
            tar_cmd = ["tar", "-chf", "-", "-C", staging, "."]
            ssh_cmd = self._build_ssh_command()
            ssh_cmd.append("tar xf - -C /")
            tar_proc = subprocess.Popen(
                tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            try:
                ssh_proc = subprocess.Popen(
                    ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
            except Exception:
                # ssh never started — reap tar so it doesn't linger.
                tar_proc.kill()
                tar_proc.wait()
                raise
            # Allow tar_proc to receive SIGPIPE if ssh_proc exits early
            tar_proc.stdout.close()
            try:
                _, ssh_stderr = ssh_proc.communicate(timeout=120)
                # Use communicate() instead of wait() to drain stderr and
                # avoid deadlock if tar produces more than PIPE_BUF of errors.
                tar_stderr_raw = b""
                if tar_proc.poll() is None:
                    _, tar_stderr_raw = tar_proc.communicate(timeout=10)
                else:
                    # tar already exited; just collect whatever it wrote.
                    tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b""
            except subprocess.TimeoutExpired:
                # Kill BOTH ends of the pipe before reporting the timeout,
                # otherwise one side can block forever on the dead pipe.
                tar_proc.kill()
                ssh_proc.kill()
                tar_proc.wait()
                ssh_proc.wait()
                raise RuntimeError("SSH bulk upload timed out")
            # Report the local (create) failure first — it is usually the
            # root cause when both sides exit nonzero.
            if tar_proc.returncode != 0:
                raise RuntimeError(
                    f"tar create failed (rc={tar_proc.returncode}): "
                    f"{tar_stderr_raw.decode(errors='replace').strip()}"
                )
            if ssh_proc.returncode != 0:
                raise RuntimeError(
                    f"tar extract over SSH failed (rc={ssh_proc.returncode}): "
                    f"{ssh_stderr.decode(errors='replace').strip()}"
                )
        logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
def _ssh_delete(self, remote_paths: list[str]) -> None:
"""Batch-delete remote files in one SSH call."""
cmd = self._build_ssh_command()
cmd.append(quoted_rm_command(remote_paths))
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
raise RuntimeError(f"remote rm failed: {result.stderr.strip()}")
def _before_execute(self) -> None:
"""Sync files to remote via FileSyncManager (rate-limited internally)."""
self._sync_manager.sync()
# ------------------------------------------------------------------
# Execution
# ------------------------------------------------------------------
def _run_bash(self, cmd_string: str, *, login: bool = False,
timeout: int = 120,
stdin_data: str | None = None) -> subprocess.Popen:
"""Spawn an SSH process that runs bash on the remote host."""
cmd = self._build_ssh_command()
if login:
cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)])
else:
cmd.extend(["bash", "-c", shlex.quote(cmd_string)])
return _popen_bash(cmd, stdin_data)
def cleanup(self):
if self.control_socket.exists():
try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
"-O", "exit", f"{self.user}@{self.host}"]
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass
try:
self.control_socket.unlink()
except OSError:
pass