diff --git a/tests/tools/test_modal_bulk_upload.py b/tests/tools/test_modal_bulk_upload.py new file mode 100644 index 0000000000..ffde9c3df0 --- /dev/null +++ b/tests/tools/test_modal_bulk_upload.py @@ -0,0 +1,224 @@ +"""Tests for Modal bulk upload via tar/base64 archive.""" + +import asyncio +import base64 +import io +import tarfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.environments import modal as modal_env + + +def _make_mock_modal_env(monkeypatch, tmp_path): + """Create a minimal mock ModalEnvironment for testing upload methods. + + Returns a ModalEnvironment-like object with _sandbox and _worker mocked. + We don't call __init__ because it requires the Modal SDK. + """ + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + env._sync_manager = None + return env + + +def _wire_async_exec(env, exec_calls=None): + """Wire mock sandbox.exec.aio and a real run_coroutine on the env. + + Optionally captures exec call args into *exec_calls* list. + Returns a dict that will contain ``run_coroutine`` kwargs after + the next call (useful for timeout assertions). + """ + if exec_calls is None: + exec_calls = [] + run_kwargs: dict = {} + + async def mock_exec(*args, **kwargs): + exec_calls.append(args) + proc = MagicMock() + proc.wait = MagicMock() + proc.wait.aio = AsyncMock(return_value=0) + return proc + + env._sandbox.exec = MagicMock() + env._sandbox.exec.aio = mock_exec + + def real_run_coroutine(coro, **kwargs): + run_kwargs.update(kwargs) + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + env._worker.run_coroutine = real_run_coroutine + return exec_calls, run_kwargs + + +class TestModalBulkUpload: + """Test _modal_bulk_upload method.""" + + def test_empty_files_is_noop(self, monkeypatch, tmp_path): + """Empty file list should not call worker.run_coroutine.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + env._modal_bulk_upload([]) + env._worker.run_coroutine.assert_not_called() + + def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path): + """The tar archive sent to the sandbox should contain all files.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src_a = tmp_path / "a.json" + src_b = tmp_path / "b.py" + src_a.write_text("cred_content") + src_b.write_text("skill_content") + + files = [ + (str(src_a), "/root/.hermes/credentials/a.json"), + (str(src_b), "/root/.hermes/skills/b.py"), + ] + + exec_calls, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # Verify exec was called with bash -c and a tar command + assert len(exec_calls) == 1 + args = exec_calls[0] + assert args[0] == "bash" + assert args[1] == "-c" + cmd = args[2] + assert "mkdir -p" in cmd + assert "base64 -d" in cmd + assert "tar xzf" in cmd + assert "-C /" in cmd + + # Extract the base64 payload and verify tar contents + import re + match = re.search(r"echo '?([A-Za-z0-9+/=]+)'?", cmd) + assert match, f"Could not find base64 payload in command: {cmd}" + payload = match.group(1) + + tar_data = base64.b64decode(payload) + buf = io.BytesIO(tar_data) + with tarfile.open(fileobj=buf, mode="r:gz") as tar: + names = sorted(tar.getnames()) + assert "root/.hermes/credentials/a.json" in names + assert "root/.hermes/skills/b.py" in names + + # Verify content + a_content = tar.extractfile("root/.hermes/credentials/a.json").read() + assert a_content == b"cred_content" + b_content = tar.extractfile("root/.hermes/skills/b.py").read() + assert b_content == b"skill_content" + + def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path): + """Remote parent directories should be pre-created in the command.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + + files = [ + (str(src), "/root/.hermes/credentials/f.txt"), + (str(src), "/root/.hermes/skills/deep/nested/f.txt"), + ] + + exec_calls, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + cmd = exec_calls[0][2] + assert "/root/.hermes/credentials" in cmd + assert "/root/.hermes/skills/deep/nested" in cmd + + def test_single_exec_call(self, monkeypatch, tmp_path): + """Bulk upload should use exactly one exec call regardless of file count.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + files = [] + for i in range(20): + src = tmp_path / f"file_{i}.txt" + src.write_text(f"content_{i}") + files.append((str(src), f"/root/.hermes/cache/file_{i}.txt")) + + exec_calls, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # Should be exactly 1 exec call, not 20 + assert len(exec_calls) == 1 + + def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch): + """Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager.""" + captured_kwargs = {} + + def capture_fsm(**kwargs): + captured_kwargs.update(kwargs) + return type("M", (), {"sync": lambda self, **k: None})() + + monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm) + + # Create a minimal env without full __init__ + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + + # Manually call the part of __init__ that wires FileSyncManager + from tools.environments.file_sync import iter_sync_files + env._sync_manager = modal_env.FileSyncManager( + get_files_fn=lambda: iter_sync_files("/root/.hermes"), + upload_fn=env._modal_upload, + delete_fn=env._modal_delete, + bulk_upload_fn=env._modal_bulk_upload, + ) + + assert "bulk_upload_fn" in captured_kwargs + assert captured_kwargs["bulk_upload_fn"] is not None + assert callable(captured_kwargs["bulk_upload_fn"]) + + def test_timeout_set_to_120(self, monkeypatch, tmp_path): + """Bulk upload uses a 120s timeout (not the per-file 15s).""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + files = [(str(src), "/root/.hermes/f.txt")] + + _, run_kwargs = _wire_async_exec(env) + env._modal_bulk_upload(files) + + assert run_kwargs.get("timeout") == 120 + + def test_nonzero_exit_raises(self, monkeypatch, tmp_path): + """Non-zero exit code from remote exec should raise RuntimeError.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + files = [(str(src), "/root/.hermes/f.txt")] + + async def mock_exec(*args, **kwargs): + proc = MagicMock() + proc.wait = MagicMock() + proc.wait.aio = AsyncMock(return_value=1) # non-zero exit + return proc + + env._sandbox.exec = MagicMock() + env._sandbox.exec.aio = mock_exec + + def real_run_coroutine(coro, **kwargs): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + env._worker.run_coroutine = real_run_coroutine + + with pytest.raises(RuntimeError, match="Modal bulk upload failed"): + env._modal_bulk_upload(files) diff --git a/tests/tools/test_ssh_bulk_upload.py b/tests/tools/test_ssh_bulk_upload.py new file mode 100644 index 0000000000..97cb39f53c --- /dev/null +++ b/tests/tools/test_ssh_bulk_upload.py @@ -0,0 +1,517 @@ +"""Tests for SSH bulk upload via tar pipe.""" + +import os +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from tools.environments import ssh as ssh_env +from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs +from tools.environments.ssh import SSHEnvironment + + +def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""), + stderr_read=b""): + """Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes.""" + m = MagicMock() + m.stdout = MagicMock() + m.returncode = returncode + m.poll.return_value = poll_return + m.communicate.return_value = communicate_return + m.stderr = MagicMock() + m.stderr.read.return_value = stderr_read + return m + + +@pytest.fixture +def mock_env(monkeypatch): + """Create an SSHEnvironment with mocked connection/sync.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + monkeypatch.setattr( + ssh_env, "FileSyncManager", + lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), + ) + return SSHEnvironment(host="example.com", user="testuser") + + +class TestSSHBulkUpload: + """Unit tests for _ssh_bulk_upload — tar pipe mechanics.""" + + def test_empty_files_is_noop(self, mock_env): + """Empty file list should not spawn any subprocesses.""" + with patch.object(subprocess, "run") as mock_run, \ + patch.object(subprocess, "Popen") as mock_popen: + mock_env._ssh_bulk_upload([]) + mock_run.assert_not_called() + mock_popen.assert_not_called() + + def test_mkdir_batched_into_single_call(self, mock_env, tmp_path): + """All parent directories should be created in one SSH call.""" + # Create test files + f1 = tmp_path / "a.txt" + f1.write_text("aaa") + f2 = tmp_path / "b.txt" + f2.write_text("bbb") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/a.txt"), + (str(f2), "/home/testuser/.hermes/credentials/b.txt"), + ] + + # Mock subprocess.run for mkdir and Popen for tar pipe + mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0)) + + def make_proc(cmd, **kwargs): + m = MagicMock() + m.stdout = MagicMock() + m.returncode = 0 + m.poll.return_value = 0 + m.communicate.return_value = (b"", b"") + m.stderr = MagicMock() + m.stderr.read.return_value = b"" + return m + + with patch.object(subprocess, "run", mock_run), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + mock_env._ssh_bulk_upload(files) + + # Exactly one subprocess.run call for mkdir + assert mock_run.call_count == 1 + mkdir_cmd = mock_run.call_args[0][0] + # Should contain mkdir -p with both parent dirs + mkdir_str = " ".join(mkdir_cmd) + assert "mkdir -p" in mkdir_str + assert "/home/testuser/.hermes/skills" in mkdir_str + assert "/home/testuser/.hermes/credentials" in mkdir_str + + def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path): + """Symlinks in staging dir should mirror the remote path structure.""" + f1 = tmp_path / "local_a.txt" + f1.write_text("content a") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/my_skill.md"), + ] + + staging_paths = [] + + def capture_tar_cmd(cmd, **kwargs): + if cmd[0] == "tar": + # Capture the staging dir from -C argument + c_idx = cmd.index("-C") + staging_dir = cmd[c_idx + 1] + # Check the symlink exists + expected = os.path.join( + staging_dir, "home/testuser/.hermes/skills/my_skill.md" + ) + staging_paths.append(expected) + assert os.path.islink(expected), f"Expected symlink at {expected}" + assert os.readlink(expected) == os.path.abspath(str(f1)) + + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_tar_cmd): + mock_env._ssh_bulk_upload(files) + + assert len(staging_paths) == 1, "tar command should have been called" + + def test_tar_pipe_commands(self, mock_env, tmp_path): + """Verify tar and SSH commands are wired correctly.""" + f1 = tmp_path / "x.txt" + f1.write_text("x") + + files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")] + + popen_cmds = [] + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + mock_env._ssh_bulk_upload(files) + + assert len(popen_cmds) == 2, "Should spawn tar + ssh processes" + + tar_cmd = popen_cmds[0] + ssh_cmd = popen_cmds[1] + + # tar: create, dereference symlinks, to stdout + assert tar_cmd[0] == "tar" + assert "-chf" in tar_cmd + assert "-" in tar_cmd # stdout + assert "-C" in tar_cmd + + # ssh: extract from stdin at / + ssh_str = " ".join(ssh_cmd) + assert "ssh" in ssh_str + assert "tar xf - -C /" in ssh_str + assert "testuser@example.com" in ssh_str + + def test_mkdir_failure_raises(self, mock_env, tmp_path): + """mkdir failure should raise RuntimeError before tar pipe.""" + f1 = tmp_path / "y.txt" + f1.write_text("y") + files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")] + + failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied") + with patch.object(subprocess, "run", return_value=failed_run): + with pytest.raises(RuntimeError, match="remote mkdir failed"): + mock_env._ssh_bulk_upload(files) + + def test_tar_create_failure_raises(self, mock_env, tmp_path): + """tar create failure should raise RuntimeError.""" + f1 = tmp_path / "z.txt" + f1.write_text("z") + files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = MagicMock() + mock_tar.returncode = 1 + mock_tar.poll.return_value = 1 + mock_tar.communicate.return_value = (b"tar: error", b"") + mock_tar.stderr = MagicMock() + mock_tar.stderr.read.return_value = b"tar: error" + + mock_ssh = MagicMock() + mock_ssh.communicate.return_value = (b"", b"") + mock_ssh.returncode = 0 + + def popen_side_effect(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=popen_side_effect): + with pytest.raises(RuntimeError, match="tar create failed"): + mock_env._ssh_bulk_upload(files) + + def test_ssh_extract_failure_raises(self, mock_env, tmp_path): + """SSH tar extract failure should raise RuntimeError.""" + f1 = tmp_path / "w.txt" + f1.write_text("w") + files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = MagicMock() + mock_tar.returncode = 0 + mock_tar.poll.return_value = 0 + mock_tar.communicate.return_value = (b"", b"") + mock_tar.stderr = MagicMock() + mock_tar.stderr.read.return_value = b"" + + mock_ssh = MagicMock() + mock_ssh.communicate.return_value = (b"", b"Permission denied") + mock_ssh.returncode = 1 + + def popen_side_effect(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=popen_side_effect): + with pytest.raises(RuntimeError, match="tar extract over SSH failed"): + mock_env._ssh_bulk_upload(files) + + def test_ssh_command_uses_control_socket(self, mock_env, tmp_path): + """SSH command for tar extract should reuse ControlMaster socket.""" + f1 = tmp_path / "c.txt" + f1.write_text("c") + files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")] + + popen_cmds = [] + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + mock_env._ssh_bulk_upload(files) + + # The SSH command (second Popen call) should include ControlPath + ssh_cmd = popen_cmds[1] + assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd) + + def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path): + """Bulk upload SSH command should include custom port and key.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + monkeypatch.setattr( + ssh_env, "FileSyncManager", + lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), + ) + env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key") + + f1 = tmp_path / "d.txt" + f1.write_text("d") + files = [(str(f1), "/home/u/.hermes/skills/d.txt")] + + run_cmds = [] + popen_cmds = [] + + def capture_run(cmd, **kwargs): + run_cmds.append(cmd) + return subprocess.CompletedProcess([], 0) + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", side_effect=capture_run), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + env._ssh_bulk_upload(files) + + # Check mkdir SSH call includes port and key + assert len(run_cmds) == 1 + mkdir_cmd = run_cmds[0] + assert "-p" in mkdir_cmd and "2222" in mkdir_cmd + assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd + + # Check tar extract SSH call includes port and key + ssh_cmd = popen_cmds[1] + assert "-p" in ssh_cmd and "2222" in ssh_cmd + assert "-i" in ssh_cmd and "/my/key" in ssh_cmd + + def test_parent_dirs_deduplicated(self, mock_env, tmp_path): + """Multiple files in the same dir should produce one mkdir entry.""" + f1 = tmp_path / "a.txt" + f1.write_text("a") + f2 = tmp_path / "b.txt" + f2.write_text("b") + f3 = tmp_path / "c.txt" + f3.write_text("c") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/a.txt"), + (str(f2), "/home/testuser/.hermes/skills/b.txt"), + (str(f3), "/home/testuser/.hermes/credentials/c.txt"), + ] + + run_cmds = [] + + def capture_run(cmd, **kwargs): + run_cmds.append(cmd) + return subprocess.CompletedProcess([], 0) + + def make_mock_proc(cmd, **kwargs): + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", side_effect=capture_run), \ + patch.object(subprocess, "Popen", side_effect=make_mock_proc): + mock_env._ssh_bulk_upload(files) + + # Only one mkdir call + assert len(run_cmds) == 1 + mkdir_str = " ".join(run_cmds[0]) + # skills dir should appear exactly once despite two files + assert mkdir_str.count("/home/testuser/.hermes/skills") == 1 + assert "/home/testuser/.hermes/credentials" in mkdir_str + + def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path): + """tar_proc.stdout must be closed so SIGPIPE propagates correctly.""" + f1 = tmp_path / "s.txt" + f1.write_text("s") + files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")] + + mock_tar_stdout = MagicMock() + + def make_proc(cmd, **kwargs): + mock = MagicMock() + if cmd[0] == "tar": + mock.stdout = mock_tar_stdout + else: + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + mock_env._ssh_bulk_upload(files) + + mock_tar_stdout.close.assert_called_once() + + def test_timeout_kills_both_processes(self, mock_env, tmp_path): + """TimeoutExpired during communicate should kill both processes.""" + f1 = tmp_path / "t.txt" + f1.write_text("t") + files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = MagicMock() + mock_tar.returncode = None + mock_tar.poll.return_value = None + + mock_ssh = MagicMock() + mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120) + mock_ssh.returncode = None + + def make_proc(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + with pytest.raises(RuntimeError, match="SSH bulk upload timed out"): + mock_env._ssh_bulk_upload(files) + + mock_tar.kill.assert_called_once() + mock_ssh.kill.assert_called_once() + + +class TestSSHBulkUploadWiring: + """Verify bulk_upload_fn is wired into FileSyncManager.""" + + def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch): + """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + + captured_kwargs = {} + + class FakeSyncManager: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + def sync(self, **kw): + pass + + monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager) + + env = SSHEnvironment(host="h", user="u") + + assert "bulk_upload_fn" in captured_kwargs + assert captured_kwargs["bulk_upload_fn"] is not None + # Should be the bound method + assert callable(captured_kwargs["bulk_upload_fn"]) + + +class TestSharedHelpers: + """Direct unit tests for file_sync.py helpers.""" + + def test_quoted_mkdir_command_basic(self): + result = quoted_mkdir_command(["/a", "/b/c"]) + assert result == "mkdir -p /a /b/c" + + def test_quoted_mkdir_command_quotes_special_chars(self): + result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"]) + assert "mkdir -p" in result + # shlex.quote wraps in single quotes + assert "'/path/with spaces'" in result + + def test_quoted_mkdir_command_empty(self): + result = quoted_mkdir_command([]) + assert result == "mkdir -p " + + def test_unique_parent_dirs_deduplicates(self): + files = [ + ("/local/a.txt", "/remote/dir/a.txt"), + ("/local/b.txt", "/remote/dir/b.txt"), + ("/local/c.txt", "/remote/other/c.txt"), + ] + result = unique_parent_dirs(files) + assert result == ["/remote/dir", "/remote/other"] + + def test_unique_parent_dirs_sorted(self): + files = [ + ("/local/z.txt", "/z/file.txt"), + ("/local/a.txt", "/a/file.txt"), + ] + result = unique_parent_dirs(files) + assert result == ["/a", "/z"] + + def test_unique_parent_dirs_empty(self): + assert unique_parent_dirs([]) == [] + + +class TestSSHBulkUploadEdgeCases: + """Edge cases for _ssh_bulk_upload.""" + + def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path): + """If SSH Popen raises, tar process must be killed and cleaned up.""" + f1 = tmp_path / "e.txt" + f1.write_text("e") + files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")] + + mock_tar = _mock_proc() + + call_count = 0 + + def failing_ssh_popen(cmd, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_tar # tar Popen succeeds + raise OSError("SSH binary not found") + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=failing_ssh_popen): + with pytest.raises(OSError, match="SSH binary not found"): + mock_env._ssh_bulk_upload(files) + + mock_tar.kill.assert_called_once() + mock_tar.wait.assert_called_once() diff --git a/tools/environments/file_sync.py b/tools/environments/file_sync.py index 29b45f858f..64a5b56dc4 100644 --- a/tools/environments/file_sync.py +++ b/tools/environments/file_sync.py @@ -10,6 +10,7 @@ import logging import os import shlex import time +from pathlib import Path from typing import Callable from tools.environments.base import _file_mtime_key @@ -60,6 +61,16 @@ def quoted_rm_command(remote_paths: list[str]) -> str: return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths) +def quoted_mkdir_command(dirs: list[str]) -> str: + """Build a shell ``mkdir -p`` command for a batch of directories.""" + return "mkdir -p " + " ".join(shlex.quote(d) for d in dirs) + + +def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]: + """Extract sorted unique parent directories from (host, remote) pairs.""" + return sorted({str(Path(remote).parent) for _, remote in files}) + + class FileSyncManager: """Tracks local file changes and syncs to a remote environment. diff --git a/tools/environments/modal.py b/tools/environments/modal.py index 365eca9fb1..a122eb0ee8 100644 --- a/tools/environments/modal.py +++ b/tools/environments/modal.py @@ -5,8 +5,11 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions. """ import asyncio +import base64 +import io import logging import shlex +import tarfile import threading from pathlib import Path from typing import Any, Optional @@ -18,7 +21,13 @@ from tools.environments.base import ( _load_json_store, _save_json_store, ) -from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_mkdir_command, + quoted_rm_command, + unique_parent_dirs, +) logger = logging.getLogger(__name__) @@ -259,13 +268,13 @@ class ModalEnvironment(BaseEnvironment): get_files_fn=lambda: iter_sync_files("/root/.hermes"), upload_fn=self._modal_upload, delete_fn=self._modal_delete, + bulk_upload_fn=self._modal_bulk_upload, ) self._sync_manager.sync(force=True) self.init_session() def _modal_upload(self, host_path: str, remote_path: str) -> None: """Upload a single file via base64-over-exec.""" - import base64 content = Path(host_path).read_bytes() b64 = base64.b64encode(content).decode("ascii") container_dir = str(Path(remote_path).parent) @@ -280,6 +289,44 @@ class ModalEnvironment(BaseEnvironment): self._worker.run_coroutine(_write(), timeout=15) + def _modal_bulk_upload(self, files: list[tuple[str, str]]) -> None: + """Upload many files in a single exec call via tar archive. + + Builds a gzipped tar archive in memory, base64-encodes it, and + decodes+extracts in one ``exec`` call. Avoids per-file + exec+encoding overhead (~580 files goes from minutes to seconds). + """ + if not files: + return + + # Build a tar archive in memory with files at their remote paths + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for host_path, remote_path in files: + # Store with leading '/' stripped so extracting at '/' + # recreates the full absolute path + tar.add(host_path, arcname=remote_path.lstrip("/")) + payload = base64.b64encode(buf.getvalue()).decode("ascii") + + # Pre-create parent dirs + decode + extract in one exec call + parents = unique_parent_dirs(files) + mkdir_part = quoted_mkdir_command(parents) + cmd = ( + f"{mkdir_part} && " + f"echo {shlex.quote(payload)} | base64 -d | tar xzf - -C /" + ) + sandbox = self._sandbox + + async def _bulk(): + proc = await sandbox.exec.aio("bash", "-c", cmd) + exit_code = await proc.wait.aio() + if exit_code != 0: + raise RuntimeError( + f"Modal bulk upload failed (exit {exit_code})" + ) + + self._worker.run_coroutine(_bulk(), timeout=120) + def _modal_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files via exec.""" rm_cmd = quoted_rm_command(remote_paths) diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 8cb1b0c570..48d72554d7 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -1,6 +1,7 @@ """SSH remote execution environment with ControlMaster connection persistence.""" import logging +import os import shlex import shutil import subprocess @@ -8,7 +9,13 @@ import tempfile from pathlib import Path from tools.environments.base import BaseEnvironment, _popen_bash -from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_mkdir_command, + quoted_rm_command, + unique_parent_dirs, +) logger = logging.getLogger(__name__) @@ -50,6 +57,7 @@ class SSHEnvironment(BaseEnvironment): get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"), upload_fn=self._scp_upload, delete_fn=self._ssh_delete, + bulk_upload_fn=self._ssh_bulk_upload, ) self._sync_manager.sync(force=True) @@ -107,9 +115,8 @@ class SSHEnvironment(BaseEnvironment): """Create base ~/.hermes directory tree on remote in one SSH call.""" base = f"{self._remote_home}/.hermes" dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"] - mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(d) for d in dirs) cmd = self._build_ssh_command() - cmd.append(mkdir_cmd) + cmd.append(quoted_mkdir_command(dirs)) subprocess.run(cmd, capture_output=True, text=True, timeout=10) # _get_sync_files provided via iter_sync_files in FileSyncManager init @@ -131,6 +138,92 @@ class SSHEnvironment(BaseEnvironment): if result.returncode != 0: raise RuntimeError(f"scp failed: {result.stderr.strip()}") + def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None: + """Upload many files in a single tar-over-SSH stream. + + Pipes ``tar c`` on the local side through an SSH connection to + ``tar x`` on the remote, transferring all files in one TCP stream + instead of spawning a subprocess per file. Directory creation is + batched into a single ``mkdir -p`` call beforehand. + + Typical improvement: ~580 files goes from O(N) scp round-trips + to a single streaming transfer. + """ + if not files: + return + + # Pre-create all unique parent directories in one SSH call + parents = unique_parent_dirs(files) + if parents: + cmd = self._build_ssh_command() + cmd.append(quoted_mkdir_command(parents)) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}") + + # Symlink staging avoids fragile GNU tar --transform rules. + with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging: + for host_path, remote_path in files: + # remote_path is absolute (e.g. /home/user/.hermes/skills/foo.md) + # Create the directory structure under staging + staged = os.path.join(staging, remote_path.lstrip("/")) + os.makedirs(os.path.dirname(staged), exist_ok=True) + # Symlink to the actual file (avoid copying) + os.symlink(os.path.abspath(host_path), staged) + + # tar: dereference symlinks (-h), create archive from staging root + # The archive paths are relative to staging, which mirrors / on remote + tar_cmd = ["tar", "-chf", "-", "-C", staging, "."] + + # ssh: extract on remote at / + ssh_cmd = self._build_ssh_command() + ssh_cmd.append("tar xf - -C /") + + tar_proc = subprocess.Popen( + tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + try: + ssh_proc = subprocess.Popen( + ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except Exception: + tar_proc.kill() + tar_proc.wait() + raise + + # Allow tar_proc to receive SIGPIPE if ssh_proc exits early + tar_proc.stdout.close() + + try: + _, ssh_stderr = ssh_proc.communicate(timeout=120) + # Use communicate() instead of wait() to drain stderr and + # avoid deadlock if tar produces more than PIPE_BUF of errors. + tar_stderr_raw = b"" + if tar_proc.poll() is None: + _, tar_stderr_raw = tar_proc.communicate(timeout=10) + else: + tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b"" + except subprocess.TimeoutExpired: + tar_proc.kill() + ssh_proc.kill() + tar_proc.wait() + ssh_proc.wait() + raise RuntimeError("SSH bulk upload timed out") + + if tar_proc.returncode != 0: + raise RuntimeError( + f"tar create failed (rc={tar_proc.returncode}): " + f"{tar_stderr_raw.decode(errors='replace').strip()}" + ) + if ssh_proc.returncode != 0: + raise RuntimeError( + f"tar extract over SSH failed (rc={ssh_proc.returncode}): " + f"{ssh_stderr.decode(errors='replace').strip()}" + ) + + logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files)) + def _ssh_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files in one SSH call.""" cmd = self._build_ssh_command()