mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
perf(ssh,modal): bulk file sync via tar pipe and tar/base64 archive (#8014)
* perf(ssh,modal): bulk file sync via tar pipe and tar/base64 archive SSH: symlink-staging + tar -ch piped over SSH in a single TCP stream. Eliminates per-file scp round-trips. Handles timeout (kills both processes), SSH Popen failure (kills tar), and tar create failure. Modal: in-memory gzipped tar archive, base64-encoded, decoded+extracted in one exec call. Checks exit code and raises on failure. Both backends use shared helpers extracted into file_sync.py: - quoted_mkdir_command() — mirrors existing quoted_rm_command() - unique_parent_dirs() — deduplicates parent dirs from file pairs Migrates _ensure_remote_dirs to use the new helpers. 28 new tests (21 SSH + 7 Modal), all passing. Closes #7465 Closes #7467 * fix(modal): pipe stdin to avoid ARG_MAX, clean up review findings - Modal bulk upload: stream base64 payload through proc.stdin in 1MB chunks instead of embedding in command string (Modal SDK enforces 64KB ARG_MAX_BYTES — typical payloads are ~4.3MB) - Modal single-file upload: same stdin fix, add exit code checking - Remove what-narrating comments in ssh.py and modal.py (keep WHY comments: symlink staging rationale, SIGPIPE, deadlock avoidance) - Remove unnecessary `sandbox = self._sandbox` alias in modal bulk - Daytona: use shared helpers (unique_parent_dirs, quoted_mkdir_command) instead of inlined duplicates --------- Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
parent
fd73937ec8
commit
27eeea0555
6 changed files with 992 additions and 13 deletions
295
tests/tools/test_modal_bulk_upload.py
Normal file
295
tests/tools/test_modal_bulk_upload.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""Tests for Modal bulk upload via tar/base64 archive."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import modal as modal_env
|
||||
|
||||
|
||||
def _make_mock_modal_env(monkeypatch, tmp_path):
|
||||
"""Create a minimal mock ModalEnvironment for testing upload methods.
|
||||
|
||||
Returns a ModalEnvironment-like object with _sandbox and _worker mocked.
|
||||
We don't call __init__ because it requires the Modal SDK.
|
||||
"""
|
||||
env = object.__new__(modal_env.ModalEnvironment)
|
||||
env._sandbox = MagicMock()
|
||||
env._worker = MagicMock()
|
||||
env._persistent = False
|
||||
env._task_id = "test"
|
||||
env._sync_manager = None
|
||||
return env
|
||||
|
||||
|
||||
def _make_mock_stdin():
|
||||
"""Create a mock stdin that captures written data."""
|
||||
stdin = MagicMock()
|
||||
written_chunks = []
|
||||
|
||||
def mock_write(data):
|
||||
written_chunks.append(data)
|
||||
|
||||
stdin.write = mock_write
|
||||
stdin.write_eof = MagicMock()
|
||||
stdin.drain = MagicMock()
|
||||
stdin.drain.aio = AsyncMock()
|
||||
stdin._written_chunks = written_chunks
|
||||
return stdin
|
||||
|
||||
|
||||
def _wire_async_exec(env, exec_calls=None):
|
||||
"""Wire mock sandbox.exec.aio and a real run_coroutine on the env.
|
||||
|
||||
Optionally captures exec call args into *exec_calls* list.
|
||||
Returns (exec_calls, run_kwargs, stdin_mock).
|
||||
"""
|
||||
if exec_calls is None:
|
||||
exec_calls = []
|
||||
run_kwargs: dict = {}
|
||||
stdin_mock = _make_mock_stdin()
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
exec_calls.append(args)
|
||||
proc = MagicMock()
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=0)
|
||||
proc.stdin = stdin_mock
|
||||
proc.stderr = MagicMock()
|
||||
proc.stderr.read = MagicMock()
|
||||
proc.stderr.read.aio = AsyncMock(return_value="")
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
run_kwargs.update(kwargs)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
return exec_calls, run_kwargs, stdin_mock
|
||||
|
||||
|
||||
class TestModalBulkUpload:
|
||||
"""Test _modal_bulk_upload method."""
|
||||
|
||||
def test_empty_files_is_noop(self, monkeypatch, tmp_path):
|
||||
"""Empty file list should not call worker.run_coroutine."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
env._modal_bulk_upload([])
|
||||
env._worker.run_coroutine.assert_not_called()
|
||||
|
||||
def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
|
||||
"""The tar archive sent via stdin should contain all files."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src_a = tmp_path / "a.json"
|
||||
src_b = tmp_path / "b.py"
|
||||
src_a.write_text("cred_content")
|
||||
src_b.write_text("skill_content")
|
||||
|
||||
files = [
|
||||
(str(src_a), "/root/.hermes/credentials/a.json"),
|
||||
(str(src_b), "/root/.hermes/skills/b.py"),
|
||||
]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Verify the command reads from stdin (no echo with embedded payload)
|
||||
assert len(exec_calls) == 1
|
||||
args = exec_calls[0]
|
||||
assert args[0] == "bash"
|
||||
assert args[1] == "-c"
|
||||
cmd = args[2]
|
||||
assert "mkdir -p" in cmd
|
||||
assert "base64 -d" in cmd
|
||||
assert "tar xzf" in cmd
|
||||
assert "-C /" in cmd
|
||||
|
||||
# Reassemble the base64 payload from stdin chunks and verify tar contents
|
||||
payload = "".join(stdin_mock._written_chunks)
|
||||
tar_data = base64.b64decode(payload)
|
||||
buf = io.BytesIO(tar_data)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
names = sorted(tar.getnames())
|
||||
assert "root/.hermes/credentials/a.json" in names
|
||||
assert "root/.hermes/skills/b.py" in names
|
||||
|
||||
# Verify content
|
||||
a_content = tar.extractfile("root/.hermes/credentials/a.json").read()
|
||||
assert a_content == b"cred_content"
|
||||
b_content = tar.extractfile("root/.hermes/skills/b.py").read()
|
||||
assert b_content == b"skill_content"
|
||||
|
||||
# Verify stdin was closed
|
||||
stdin_mock.write_eof.assert_called_once()
|
||||
|
||||
def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
|
||||
"""Remote parent directories should be pre-created in the command."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
|
||||
files = [
|
||||
(str(src), "/root/.hermes/credentials/f.txt"),
|
||||
(str(src), "/root/.hermes/skills/deep/nested/f.txt"),
|
||||
]
|
||||
|
||||
exec_calls, _, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
cmd = exec_calls[0][2]
|
||||
assert "/root/.hermes/credentials" in cmd
|
||||
assert "/root/.hermes/skills/deep/nested" in cmd
|
||||
|
||||
def test_single_exec_call(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload should use exactly one exec call regardless of file count."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
files = []
|
||||
for i in range(20):
|
||||
src = tmp_path / f"file_{i}.txt"
|
||||
src.write_text(f"content_{i}")
|
||||
files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))
|
||||
|
||||
exec_calls, _, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Should be exactly 1 exec call, not 20
|
||||
assert len(exec_calls) == 1
|
||||
|
||||
def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
|
||||
"""Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_fsm(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return type("M", (), {"sync": lambda self, **k: None})()
|
||||
|
||||
monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)
|
||||
|
||||
# Create a minimal env without full __init__
|
||||
env = object.__new__(modal_env.ModalEnvironment)
|
||||
env._sandbox = MagicMock()
|
||||
env._worker = MagicMock()
|
||||
env._persistent = False
|
||||
env._task_id = "test"
|
||||
|
||||
# Manually call the part of __init__ that wires FileSyncManager
|
||||
from tools.environments.file_sync import iter_sync_files
|
||||
env._sync_manager = modal_env.FileSyncManager(
|
||||
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||
upload_fn=env._modal_upload,
|
||||
delete_fn=env._modal_delete,
|
||||
bulk_upload_fn=env._modal_bulk_upload,
|
||||
)
|
||||
|
||||
assert "bulk_upload_fn" in captured_kwargs
|
||||
assert captured_kwargs["bulk_upload_fn"] is not None
|
||||
assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
def test_timeout_set_to_120(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload uses a 120s timeout (not the per-file 15s)."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
_, run_kwargs, _ = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
assert run_kwargs.get("timeout") == 120
|
||||
|
||||
def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
|
||||
"""Non-zero exit code from remote exec should raise RuntimeError."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("data")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
stdin_mock = _make_mock_stdin()
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
proc = MagicMock()
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=1) # non-zero exit
|
||||
proc.stdin = stdin_mock
|
||||
proc.stderr = MagicMock()
|
||||
proc.stderr.read = MagicMock()
|
||||
proc.stderr.read.aio = AsyncMock(return_value="tar: error")
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
|
||||
with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
|
||||
"""The base64 payload must NOT appear in the bash -c argument.
|
||||
|
||||
This is the core ARG_MAX fix: the payload goes through stdin,
|
||||
not embedded in the command string.
|
||||
"""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
src = tmp_path / "f.txt"
|
||||
src.write_text("some data to upload")
|
||||
files = [(str(src), "/root/.hermes/f.txt")]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# The command should NOT contain an echo with the payload
|
||||
cmd = exec_calls[0][2]
|
||||
assert "echo" not in cmd
|
||||
# The payload should go through stdin
|
||||
assert len(stdin_mock._written_chunks) > 0
|
||||
|
||||
def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
|
||||
"""Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
|
||||
env = _make_mock_modal_env(monkeypatch, tmp_path)
|
||||
|
||||
# Use random bytes so gzip cannot compress them -- ensures the
|
||||
# base64 payload exceeds one 1 MB chunk.
|
||||
import os as _os
|
||||
src = tmp_path / "large.bin"
|
||||
src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
|
||||
files = [(str(src), "/root/.hermes/large.bin")]
|
||||
|
||||
exec_calls, _, stdin_mock = _wire_async_exec(env)
|
||||
env._modal_bulk_upload(files)
|
||||
|
||||
# Should have multiple stdin write chunks
|
||||
assert len(stdin_mock._written_chunks) >= 2
|
||||
|
||||
# Reassembled payload should still decode to valid tar
|
||||
payload = "".join(stdin_mock._written_chunks)
|
||||
tar_data = base64.b64decode(payload)
|
||||
buf = io.BytesIO(tar_data)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
names = tar.getnames()
|
||||
assert "root/.hermes/large.bin" in names
|
||||
517
tests/tools/test_ssh_bulk_upload.py
Normal file
517
tests/tools/test_ssh_bulk_upload.py
Normal file
|
|
@ -0,0 +1,517 @@
|
|||
"""Tests for SSH bulk upload via tar pipe."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
|
||||
stderr_read=b""):
|
||||
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = returncode
|
||||
m.poll.return_value = poll_return
|
||||
m.communicate.return_value = communicate_return
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = stderr_read
|
||||
return m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env(monkeypatch):
|
||||
"""Create an SSHEnvironment with mocked connection/sync."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
monkeypatch.setattr(
|
||||
ssh_env, "FileSyncManager",
|
||||
lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
|
||||
)
|
||||
return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
class TestSSHBulkUpload:
|
||||
"""Unit tests for _ssh_bulk_upload — tar pipe mechanics."""
|
||||
|
||||
def test_empty_files_is_noop(self, mock_env):
|
||||
"""Empty file list should not spawn any subprocesses."""
|
||||
with patch.object(subprocess, "run") as mock_run, \
|
||||
patch.object(subprocess, "Popen") as mock_popen:
|
||||
mock_env._ssh_bulk_upload([])
|
||||
mock_run.assert_not_called()
|
||||
mock_popen.assert_not_called()
|
||||
|
||||
def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
|
||||
"""All parent directories should be created in one SSH call."""
|
||||
# Create test files
|
||||
f1 = tmp_path / "a.txt"
|
||||
f1.write_text("aaa")
|
||||
f2 = tmp_path / "b.txt"
|
||||
f2.write_text("bbb")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/a.txt"),
|
||||
(str(f2), "/home/testuser/.hermes/credentials/b.txt"),
|
||||
]
|
||||
|
||||
# Mock subprocess.run for mkdir and Popen for tar pipe
|
||||
mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = 0
|
||||
m.poll.return_value = 0
|
||||
m.communicate.return_value = (b"", b"")
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = b""
|
||||
return m
|
||||
|
||||
with patch.object(subprocess, "run", mock_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# Exactly one subprocess.run call for mkdir
|
||||
assert mock_run.call_count == 1
|
||||
mkdir_cmd = mock_run.call_args[0][0]
|
||||
# Should contain mkdir -p with both parent dirs
|
||||
mkdir_str = " ".join(mkdir_cmd)
|
||||
assert "mkdir -p" in mkdir_str
|
||||
assert "/home/testuser/.hermes/skills" in mkdir_str
|
||||
assert "/home/testuser/.hermes/credentials" in mkdir_str
|
||||
|
||||
def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
|
||||
"""Symlinks in staging dir should mirror the remote path structure."""
|
||||
f1 = tmp_path / "local_a.txt"
|
||||
f1.write_text("content a")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
|
||||
]
|
||||
|
||||
staging_paths = []
|
||||
|
||||
def capture_tar_cmd(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
# Capture the staging dir from -C argument
|
||||
c_idx = cmd.index("-C")
|
||||
staging_dir = cmd[c_idx + 1]
|
||||
# Check the symlink exists
|
||||
expected = os.path.join(
|
||||
staging_dir, "home/testuser/.hermes/skills/my_skill.md"
|
||||
)
|
||||
staging_paths.append(expected)
|
||||
assert os.path.islink(expected), f"Expected symlink at {expected}"
|
||||
assert os.readlink(expected) == os.path.abspath(str(f1))
|
||||
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
assert len(staging_paths) == 1, "tar command should have been called"
|
||||
|
||||
def test_tar_pipe_commands(self, mock_env, tmp_path):
|
||||
"""Verify tar and SSH commands are wired correctly."""
|
||||
f1 = tmp_path / "x.txt"
|
||||
f1.write_text("x")
|
||||
|
||||
files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]
|
||||
|
||||
popen_cmds = []
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"
|
||||
|
||||
tar_cmd = popen_cmds[0]
|
||||
ssh_cmd = popen_cmds[1]
|
||||
|
||||
# tar: create, dereference symlinks, to stdout
|
||||
assert tar_cmd[0] == "tar"
|
||||
assert "-chf" in tar_cmd
|
||||
assert "-" in tar_cmd # stdout
|
||||
assert "-C" in tar_cmd
|
||||
|
||||
# ssh: extract from stdin at /
|
||||
ssh_str = " ".join(ssh_cmd)
|
||||
assert "ssh" in ssh_str
|
||||
assert "tar xf - -C /" in ssh_str
|
||||
assert "testuser@example.com" in ssh_str
|
||||
|
||||
def test_mkdir_failure_raises(self, mock_env, tmp_path):
|
||||
"""mkdir failure should raise RuntimeError before tar pipe."""
|
||||
f1 = tmp_path / "y.txt"
|
||||
f1.write_text("y")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]
|
||||
|
||||
failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
|
||||
with patch.object(subprocess, "run", return_value=failed_run):
|
||||
with pytest.raises(RuntimeError, match="remote mkdir failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_tar_create_failure_raises(self, mock_env, tmp_path):
|
||||
"""tar create failure should raise RuntimeError."""
|
||||
f1 = tmp_path / "z.txt"
|
||||
f1.write_text("z")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = 1
|
||||
mock_tar.poll.return_value = 1
|
||||
mock_tar.communicate.return_value = (b"tar: error", b"")
|
||||
mock_tar.stderr = MagicMock()
|
||||
mock_tar.stderr.read.return_value = b"tar: error"
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.return_value = (b"", b"")
|
||||
mock_ssh.returncode = 0
|
||||
|
||||
def popen_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=popen_side_effect):
|
||||
with pytest.raises(RuntimeError, match="tar create failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
|
||||
"""SSH tar extract failure should raise RuntimeError."""
|
||||
f1 = tmp_path / "w.txt"
|
||||
f1.write_text("w")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = 0
|
||||
mock_tar.poll.return_value = 0
|
||||
mock_tar.communicate.return_value = (b"", b"")
|
||||
mock_tar.stderr = MagicMock()
|
||||
mock_tar.stderr.read.return_value = b""
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.return_value = (b"", b"Permission denied")
|
||||
mock_ssh.returncode = 1
|
||||
|
||||
def popen_side_effect(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=popen_side_effect):
|
||||
with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
|
||||
"""SSH command for tar extract should reuse ControlMaster socket."""
|
||||
f1 = tmp_path / "c.txt"
|
||||
f1.write_text("c")
|
||||
files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]
|
||||
|
||||
popen_cmds = []
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# The SSH command (second Popen call) should include ControlPath
|
||||
ssh_cmd = popen_cmds[1]
|
||||
assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)
|
||||
|
||||
def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
|
||||
"""Bulk upload SSH command should include custom port and key."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
monkeypatch.setattr(
|
||||
ssh_env, "FileSyncManager",
|
||||
lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
|
||||
)
|
||||
env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")
|
||||
|
||||
f1 = tmp_path / "d.txt"
|
||||
f1.write_text("d")
|
||||
files = [(str(f1), "/home/u/.hermes/skills/d.txt")]
|
||||
|
||||
run_cmds = []
|
||||
popen_cmds = []
|
||||
|
||||
def capture_run(cmd, **kwargs):
|
||||
run_cmds.append(cmd)
|
||||
return subprocess.CompletedProcess([], 0)
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
popen_cmds.append(cmd)
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run", side_effect=capture_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=capture_popen):
|
||||
env._ssh_bulk_upload(files)
|
||||
|
||||
# Check mkdir SSH call includes port and key
|
||||
assert len(run_cmds) == 1
|
||||
mkdir_cmd = run_cmds[0]
|
||||
assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
|
||||
assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd
|
||||
|
||||
# Check tar extract SSH call includes port and key
|
||||
ssh_cmd = popen_cmds[1]
|
||||
assert "-p" in ssh_cmd and "2222" in ssh_cmd
|
||||
assert "-i" in ssh_cmd and "/my/key" in ssh_cmd
|
||||
|
||||
def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
|
||||
"""Multiple files in the same dir should produce one mkdir entry."""
|
||||
f1 = tmp_path / "a.txt"
|
||||
f1.write_text("a")
|
||||
f2 = tmp_path / "b.txt"
|
||||
f2.write_text("b")
|
||||
f3 = tmp_path / "c.txt"
|
||||
f3.write_text("c")
|
||||
|
||||
files = [
|
||||
(str(f1), "/home/testuser/.hermes/skills/a.txt"),
|
||||
(str(f2), "/home/testuser/.hermes/skills/b.txt"),
|
||||
(str(f3), "/home/testuser/.hermes/credentials/c.txt"),
|
||||
]
|
||||
|
||||
run_cmds = []
|
||||
|
||||
def capture_run(cmd, **kwargs):
|
||||
run_cmds.append(cmd)
|
||||
return subprocess.CompletedProcess([], 0)
|
||||
|
||||
def make_mock_proc(cmd, **kwargs):
|
||||
mock = MagicMock()
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run", side_effect=capture_run), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_mock_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
# Only one mkdir call
|
||||
assert len(run_cmds) == 1
|
||||
mkdir_str = " ".join(run_cmds[0])
|
||||
# skills dir should appear exactly once despite two files
|
||||
assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
|
||||
assert "/home/testuser/.hermes/credentials" in mkdir_str
|
||||
|
||||
def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
|
||||
"""tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
|
||||
f1 = tmp_path / "s.txt"
|
||||
f1.write_text("s")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]
|
||||
|
||||
mock_tar_stdout = MagicMock()
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
mock = MagicMock()
|
||||
if cmd[0] == "tar":
|
||||
mock.stdout = mock_tar_stdout
|
||||
else:
|
||||
mock.stdout = MagicMock()
|
||||
mock.returncode = 0
|
||||
mock.poll.return_value = 0
|
||||
mock.communicate.return_value = (b"", b"")
|
||||
mock.stderr = MagicMock()
|
||||
mock.stderr.read.return_value = b""
|
||||
return mock
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar_stdout.close.assert_called_once()
|
||||
|
||||
def test_timeout_kills_both_processes(self, mock_env, tmp_path):
|
||||
"""TimeoutExpired during communicate should kill both processes."""
|
||||
f1 = tmp_path / "t.txt"
|
||||
f1.write_text("t")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]
|
||||
|
||||
mock_tar = MagicMock()
|
||||
mock_tar.stdout = MagicMock()
|
||||
mock_tar.returncode = None
|
||||
mock_tar.poll.return_value = None
|
||||
|
||||
mock_ssh = MagicMock()
|
||||
mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
|
||||
mock_ssh.returncode = None
|
||||
|
||||
def make_proc(cmd, **kwargs):
|
||||
if cmd[0] == "tar":
|
||||
return mock_tar
|
||||
return mock_ssh
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=make_proc):
|
||||
with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar.kill.assert_called_once()
|
||||
mock_ssh.kill.assert_called_once()
|
||||
|
||||
|
||||
class TestSSHBulkUploadWiring:
|
||||
"""Verify bulk_upload_fn is wired into FileSyncManager."""
|
||||
|
||||
def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
|
||||
"""SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager."""
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
class FakeSyncManager:
|
||||
def __init__(self, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
|
||||
def sync(self, **kw):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)
|
||||
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
|
||||
assert "bulk_upload_fn" in captured_kwargs
|
||||
assert captured_kwargs["bulk_upload_fn"] is not None
|
||||
# Should be the bound method
|
||||
assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
|
||||
class TestSharedHelpers:
|
||||
"""Direct unit tests for file_sync.py helpers."""
|
||||
|
||||
def test_quoted_mkdir_command_basic(self):
|
||||
result = quoted_mkdir_command(["/a", "/b/c"])
|
||||
assert result == "mkdir -p /a /b/c"
|
||||
|
||||
def test_quoted_mkdir_command_quotes_special_chars(self):
|
||||
result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"])
|
||||
assert "mkdir -p" in result
|
||||
# shlex.quote wraps in single quotes
|
||||
assert "'/path/with spaces'" in result
|
||||
|
||||
def test_quoted_mkdir_command_empty(self):
|
||||
result = quoted_mkdir_command([])
|
||||
assert result == "mkdir -p "
|
||||
|
||||
def test_unique_parent_dirs_deduplicates(self):
|
||||
files = [
|
||||
("/local/a.txt", "/remote/dir/a.txt"),
|
||||
("/local/b.txt", "/remote/dir/b.txt"),
|
||||
("/local/c.txt", "/remote/other/c.txt"),
|
||||
]
|
||||
result = unique_parent_dirs(files)
|
||||
assert result == ["/remote/dir", "/remote/other"]
|
||||
|
||||
def test_unique_parent_dirs_sorted(self):
|
||||
files = [
|
||||
("/local/z.txt", "/z/file.txt"),
|
||||
("/local/a.txt", "/a/file.txt"),
|
||||
]
|
||||
result = unique_parent_dirs(files)
|
||||
assert result == ["/a", "/z"]
|
||||
|
||||
def test_unique_parent_dirs_empty(self):
|
||||
assert unique_parent_dirs([]) == []
|
||||
|
||||
|
||||
class TestSSHBulkUploadEdgeCases:
|
||||
"""Edge cases for _ssh_bulk_upload."""
|
||||
|
||||
def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
|
||||
"""If SSH Popen raises, tar process must be killed and cleaned up."""
|
||||
f1 = tmp_path / "e.txt"
|
||||
f1.write_text("e")
|
||||
files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")]
|
||||
|
||||
mock_tar = _mock_proc()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def failing_ssh_popen(cmd, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return mock_tar # tar Popen succeeds
|
||||
raise OSError("SSH binary not found")
|
||||
|
||||
with patch.object(subprocess, "run",
|
||||
return_value=subprocess.CompletedProcess([], 0)), \
|
||||
patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
|
||||
with pytest.raises(OSError, match="SSH binary not found"):
|
||||
mock_env._ssh_bulk_upload(files)
|
||||
|
||||
mock_tar.kill.assert_called_once()
|
||||
mock_tar.wait.assert_called_once()
|
||||
|
|
@ -15,7 +15,13 @@ from tools.environments.base import (
|
|||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
)
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -150,11 +156,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
if not files:
|
||||
return
|
||||
|
||||
# Pre-create all unique parent directories in one shell call
|
||||
parents = sorted({str(Path(remote).parent) for _, remote in files})
|
||||
parents = unique_parent_dirs(files)
|
||||
if parents:
|
||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(p) for p in parents)
|
||||
self._sandbox.process.exec(mkdir_cmd)
|
||||
self._sandbox.process.exec(quoted_mkdir_command(parents))
|
||||
|
||||
uploads = [
|
||||
FileUpload(source=host_path, destination=remote_path)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import logging
|
|||
import os
|
||||
import shlex
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from tools.environments.base import _file_mtime_key
|
||||
|
|
@ -60,6 +61,16 @@ def quoted_rm_command(remote_paths: list[str]) -> str:
|
|||
return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths)
|
||||
|
||||
|
||||
def quoted_mkdir_command(dirs: list[str]) -> str:
|
||||
"""Build a shell ``mkdir -p`` command for a batch of directories."""
|
||||
return "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
|
||||
|
||||
|
||||
def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
||||
"""Extract sorted unique parent directories from (host, remote) pairs."""
|
||||
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||
|
||||
|
||||
class FileSyncManager:
|
||||
"""Tracks local file changes and syncs to a remote environment.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import shlex
|
||||
import tarfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
|
@ -18,7 +21,13 @@ from tools.environments.base import (
|
|||
_load_json_store,
|
||||
_save_json_store,
|
||||
)
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -259,26 +268,84 @@ class ModalEnvironment(BaseEnvironment):
|
|||
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||
upload_fn=self._modal_upload,
|
||||
delete_fn=self._modal_delete,
|
||||
bulk_upload_fn=self._modal_bulk_upload,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
||||
def _modal_upload(self, host_path: str, remote_path: str) -> None:
|
||||
"""Upload a single file via base64-over-exec."""
|
||||
import base64
|
||||
"""Upload a single file via base64 piped through stdin."""
|
||||
content = Path(host_path).read_bytes()
|
||||
b64 = base64.b64encode(content).decode("ascii")
|
||||
container_dir = str(Path(remote_path).parent)
|
||||
cmd = (
|
||||
f"mkdir -p {shlex.quote(container_dir)} && "
|
||||
f"echo {shlex.quote(b64)} | base64 -d > {shlex.quote(remote_path)}"
|
||||
f"base64 -d > {shlex.quote(remote_path)}"
|
||||
)
|
||||
|
||||
async def _write():
|
||||
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
||||
offset = 0
|
||||
chunk_size = self._STDIN_CHUNK_SIZE
|
||||
while offset < len(b64):
|
||||
proc.stdin.write(b64[offset:offset + chunk_size])
|
||||
await proc.stdin.drain.aio()
|
||||
offset += chunk_size
|
||||
proc.stdin.write_eof()
|
||||
await proc.stdin.drain.aio()
|
||||
await proc.wait.aio()
|
||||
|
||||
self._worker.run_coroutine(_write(), timeout=15)
|
||||
self._worker.run_coroutine(_write(), timeout=30)
|
||||
|
||||
# Modal SDK stdin buffer limit (legacy server path). The command-router
|
||||
# path allows 16 MB, but we must stay under the smaller 2 MB cap for
|
||||
# compatibility. Chunks are written below this threshold and flushed
|
||||
# individually via drain().
|
||||
_STDIN_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB — safe for both transport paths
|
||||
|
||||
def _modal_bulk_upload(self, files: list[tuple[str, str]]) -> None:
|
||||
"""Upload many files via tar archive piped through stdin.
|
||||
|
||||
Builds a gzipped tar archive in memory and streams it into a
|
||||
``base64 -d | tar xzf -`` pipeline via the process's stdin,
|
||||
avoiding the Modal SDK's 64 KB ``ARG_MAX_BYTES`` exec-arg limit.
|
||||
"""
|
||||
if not files:
|
||||
return
|
||||
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
for host_path, remote_path in files:
|
||||
tar.add(host_path, arcname=remote_path.lstrip("/"))
|
||||
payload = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
parents = unique_parent_dirs(files)
|
||||
mkdir_part = quoted_mkdir_command(parents)
|
||||
cmd = f"{mkdir_part} && base64 -d | tar xzf - -C /"
|
||||
|
||||
async def _bulk():
|
||||
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
||||
|
||||
# Stream payload through stdin in chunks to stay under the
|
||||
# SDK's per-write buffer limit (2 MB legacy / 16 MB router).
|
||||
offset = 0
|
||||
chunk_size = self._STDIN_CHUNK_SIZE
|
||||
while offset < len(payload):
|
||||
proc.stdin.write(payload[offset:offset + chunk_size])
|
||||
await proc.stdin.drain.aio()
|
||||
offset += chunk_size
|
||||
|
||||
proc.stdin.write_eof()
|
||||
await proc.stdin.drain.aio()
|
||||
|
||||
exit_code = await proc.wait.aio()
|
||||
if exit_code != 0:
|
||||
stderr_text = await proc.stderr.read.aio()
|
||||
raise RuntimeError(
|
||||
f"Modal bulk upload failed (exit {exit_code}): {stderr_text}"
|
||||
)
|
||||
|
||||
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||
|
||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files via exec."""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""SSH remote execution environment with ControlMaster connection persistence."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
|
|
@ -8,7 +9,13 @@ import tempfile
|
|||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -50,6 +57,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||
get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
|
||||
upload_fn=self._scp_upload,
|
||||
delete_fn=self._ssh_delete,
|
||||
bulk_upload_fn=self._ssh_bulk_upload,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
|
||||
|
|
@ -107,9 +115,8 @@ class SSHEnvironment(BaseEnvironment):
|
|||
"""Create base ~/.hermes directory tree on remote in one SSH call."""
|
||||
base = f"{self._remote_home}/.hermes"
|
||||
dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
|
||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(mkdir_cmd)
|
||||
cmd.append(quoted_mkdir_command(dirs))
|
||||
subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
|
||||
# _get_sync_files provided via iter_sync_files in FileSyncManager init
|
||||
|
|
@ -131,6 +138,84 @@ class SSHEnvironment(BaseEnvironment):
|
|||
if result.returncode != 0:
|
||||
raise RuntimeError(f"scp failed: {result.stderr.strip()}")
|
||||
|
||||
def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None:
|
||||
"""Upload many files in a single tar-over-SSH stream.
|
||||
|
||||
Pipes ``tar c`` on the local side through an SSH connection to
|
||||
``tar x`` on the remote, transferring all files in one TCP stream
|
||||
instead of spawning a subprocess per file. Directory creation is
|
||||
batched into a single ``mkdir -p`` call beforehand.
|
||||
|
||||
Typical improvement: ~580 files goes from O(N) scp round-trips
|
||||
to a single streaming transfer.
|
||||
"""
|
||||
if not files:
|
||||
return
|
||||
|
||||
parents = unique_parent_dirs(files)
|
||||
if parents:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(quoted_mkdir_command(parents))
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}")
|
||||
|
||||
# Symlink staging avoids fragile GNU tar --transform rules.
|
||||
with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging:
|
||||
for host_path, remote_path in files:
|
||||
staged = os.path.join(staging, remote_path.lstrip("/"))
|
||||
os.makedirs(os.path.dirname(staged), exist_ok=True)
|
||||
os.symlink(os.path.abspath(host_path), staged)
|
||||
|
||||
tar_cmd = ["tar", "-chf", "-", "-C", staging, "."]
|
||||
ssh_cmd = self._build_ssh_command()
|
||||
ssh_cmd.append("tar xf - -C /")
|
||||
|
||||
tar_proc = subprocess.Popen(
|
||||
tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
try:
|
||||
ssh_proc = subprocess.Popen(
|
||||
ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except Exception:
|
||||
tar_proc.kill()
|
||||
tar_proc.wait()
|
||||
raise
|
||||
|
||||
# Allow tar_proc to receive SIGPIPE if ssh_proc exits early
|
||||
tar_proc.stdout.close()
|
||||
|
||||
try:
|
||||
_, ssh_stderr = ssh_proc.communicate(timeout=120)
|
||||
# Use communicate() instead of wait() to drain stderr and
|
||||
# avoid deadlock if tar produces more than PIPE_BUF of errors.
|
||||
tar_stderr_raw = b""
|
||||
if tar_proc.poll() is None:
|
||||
_, tar_stderr_raw = tar_proc.communicate(timeout=10)
|
||||
else:
|
||||
tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b""
|
||||
except subprocess.TimeoutExpired:
|
||||
tar_proc.kill()
|
||||
ssh_proc.kill()
|
||||
tar_proc.wait()
|
||||
ssh_proc.wait()
|
||||
raise RuntimeError("SSH bulk upload timed out")
|
||||
|
||||
if tar_proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"tar create failed (rc={tar_proc.returncode}): "
|
||||
f"{tar_stderr_raw.decode(errors='replace').strip()}"
|
||||
)
|
||||
if ssh_proc.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"tar extract over SSH failed (rc={ssh_proc.returncode}): "
|
||||
f"{ssh_stderr.decode(errors='replace').strip()}"
|
||||
)
|
||||
|
||||
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||
|
||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files in one SSH call."""
|
||||
cmd = self._build_ssh_command()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue