diff --git a/tests/tools/test_file_sync_back.py b/tests/tools/test_file_sync_back.py
new file mode 100644
index 0000000000..29a8d71236
--- /dev/null
+++ b/tests/tools/test_file_sync_back.py
@@ -0,0 +1,412 @@
+"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
+
+import fcntl
+import hashlib
+import io
+import logging
+import os
+import signal
+import tarfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+from tools.environments.file_sync import (
+    FileSyncManager,
+    _sha256_file,
+    _SYNC_BACK_BACKOFF,
+    _SYNC_BACK_MAX_RETRIES,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_tar(files: dict[str, bytes], dest: Path):
+    """Write a tar archive containing the given arcname->content pairs."""
+    with tarfile.open(dest, "w") as tar:
+        for arcname, content in files.items():
+            info = tarfile.TarInfo(name=arcname)
+            info.size = len(content)
+            tar.addfile(info, io.BytesIO(content))
+
+
+def _make_download_fn(files: dict[str, bytes]):
+    """Return a bulk_download_fn that writes a tar of the given files."""
+    def download(dest: Path):
+        _make_tar(files, dest)
+    return download
+
+
+def _sha256_bytes(data: bytes) -> str:
+    """Compute SHA-256 hex digest of raw bytes (for test convenience)."""
+    return hashlib.sha256(data).hexdigest()
+
+
+def _write_file(path: Path, content: bytes) -> str:
+    """Write bytes to *path*, creating parents, and return the string path."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_bytes(content)
+    return str(path)
+
+
+def _make_manager(
+    tmp_path: Path,
+    file_mapping: list[tuple[str, str]] | None = None,
+    bulk_download_fn=None,
+) -> FileSyncManager:
+    """Create a FileSyncManager wired for testing.
+
+    *file_mapping* is a list of (host_path, remote_path) tuples that
+    ``get_files_fn`` returns. If *None*, an empty list is used.
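+
+    ``bulk_download_fn`` is forwarded to FileSyncManager unchanged, so
+    passing ``None`` exercises the sync_back() no-op path.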
+ """ + mapping = file_mapping or [] + return FileSyncManager( + get_files_fn=lambda: mapping, + upload_fn=MagicMock(), + delete_fn=MagicMock(), + bulk_download_fn=bulk_download_fn, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSyncBackNoop: + """sync_back() is a no-op when there is no download function.""" + + def test_sync_back_noop_without_download_fn(self, tmp_path): + mgr = _make_manager(tmp_path, bulk_download_fn=None) + # Should return immediately without error + mgr.sync_back(hermes_home=tmp_path / ".hermes") + # Nothing to assert beyond "no exception raised" + + +class TestSyncBackNoChanges: + """When all remote files match pushed hashes, nothing is applied.""" + + def test_sync_back_no_changes(self, tmp_path): + host_file = tmp_path / "host" / "cred.json" + host_content = b'{"key": "val"}' + _write_file(host_file, host_content) + + remote_path = "/root/.hermes/cred.json" + mapping = [(str(host_file), remote_path)] + + # Remote tar contains the same content as was pushed + download_fn = _make_download_fn({ + "root/.hermes/cred.json": host_content, + }) + + mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn) + # Simulate that we already pushed this file with this hash + mgr._pushed_hashes[remote_path] = _sha256_bytes(host_content) + + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + # Host file should be unchanged (same content, same bytes) + assert host_file.read_bytes() == host_content + + +class TestSyncBackAppliesChanged: + """Remote file differs from pushed version -- gets copied to host.""" + + def test_sync_back_applies_changed_file(self, tmp_path): + host_file = tmp_path / "host" / "skill.py" + original_content = b"print('v1')" + _write_file(host_file, original_content) + + remote_path = "/root/.hermes/skill.py" + mapping = [(str(host_file), remote_path)] + + remote_content = b"print('v2 - edited on remote')" + download_fn = _make_download_fn({ + "root/.hermes/skill.py": remote_content, + }) + + mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn) + mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content) + + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + assert host_file.read_bytes() == remote_content + + +class TestSyncBackNewRemoteFile: + """File created on remote (not in _pushed_hashes) is applied via _infer_host_path.""" + + def test_sync_back_detects_new_remote_file(self, tmp_path): + # Existing mapping gives _infer_host_path a prefix to work with + existing_host = tmp_path / "host" / "skills" / "existing.py" + _write_file(existing_host, b"existing") + mapping = [(str(existing_host), "/root/.hermes/skills/existing.py")] + + # Remote has a NEW file in the same directory that was never pushed + new_remote_content = b"# brand new skill created on remote" + download_fn = _make_download_fn({ + "root/.hermes/skills/new_skill.py": new_remote_content, + }) + + mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn) + # No entry in _pushed_hashes for the new file + + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + # The new file should have been inferred and written to the host + expected_host_path = tmp_path / "host" / "skills" / "new_skill.py" + assert expected_host_path.exists() + assert expected_host_path.read_bytes() == new_remote_content + + +class TestSyncBackConflict: + """Host AND remote both changed since push -- warning 
logged, remote wins.""" + + def test_sync_back_conflict_warns(self, tmp_path, caplog): + host_file = tmp_path / "host" / "config.json" + original_content = b'{"v": 1}' + _write_file(host_file, original_content) + + remote_path = "/root/.hermes/config.json" + mapping = [(str(host_file), remote_path)] + + # Host was modified after push + host_file.write_bytes(b'{"v": 2, "host-edit": true}') + + # Remote was also modified + remote_content = b'{"v": 3, "remote-edit": true}' + download_fn = _make_download_fn({ + "root/.hermes/config.json": remote_content, + }) + + mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn) + mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content) + + with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"): + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + # Conflict warning was logged + assert any("conflict" in r.message.lower() for r in caplog.records) + + # Remote version wins (last-write-wins) + assert host_file.read_bytes() == remote_content + + +class TestSyncBackRetries: + """Retry behaviour with exponential backoff.""" + + @patch("tools.environments.file_sync.time.sleep") + def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path): + call_count = 0 + + def flaky_download(dest: Path): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError(f"network error #{call_count}") + # Third attempt succeeds -- write a valid (empty) tar + _make_tar({}, dest) + + mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download) + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + assert call_count == 3 + # Sleep called twice (between attempt 1->2 and 2->3) + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0]) + mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1]) + + @patch("tools.environments.file_sync.time.sleep") + def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog): + def always_fail(dest: Path): + raise RuntimeError("persistent failure") + + mgr = _make_manager(tmp_path, bulk_download_fn=always_fail) + + with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"): + # Should NOT raise -- failures are logged, not propagated + mgr.sync_back(hermes_home=tmp_path / ".hermes") + + # All retries were attempted + assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1 + + # Final "all attempts failed" warning was logged + assert any("all" in r.message.lower() and "failed" in r.message.lower() for r in caplog.records) + + +class TestPushedHashesPopulated: + """_pushed_hashes is populated during sync() and cleared on delete.""" + + def test_pushed_hashes_populated_on_sync(self, tmp_path): + host_file = tmp_path / "data.txt" + host_file.write_bytes(b"hello world") + + remote_path = "/root/.hermes/data.txt" + mapping = [(str(host_file), remote_path)] + + mgr = FileSyncManager( + get_files_fn=lambda: mapping, + upload_fn=MagicMock(), + delete_fn=MagicMock(), + ) + + mgr.sync(force=True) + + assert remote_path in mgr._pushed_hashes + assert mgr._pushed_hashes[remote_path] == _sha256_file(str(host_file)) + + def test_pushed_hashes_cleared_on_delete(self, tmp_path): + host_file = tmp_path / "deleteme.txt" + host_file.write_bytes(b"to be deleted") + + remote_path = "/root/.hermes/deleteme.txt" + mapping = [(str(host_file), remote_path)] + current_mapping = list(mapping) + + mgr = FileSyncManager( + get_files_fn=lambda: current_mapping, + upload_fn=MagicMock(), + delete_fn=MagicMock(), + ) + + # Sync to populate 
hashes
+        mgr.sync(force=True)
+        assert remote_path in mgr._pushed_hashes
+
+        # Remove the file from the mapping (simulates local deletion)
+        os.unlink(str(host_file))
+        current_mapping.clear()
+
+        mgr.sync(force=True)
+
+        # Hash should be cleaned up
+        assert remote_path not in mgr._pushed_hashes
+
+
+class TestSyncBackFileLock:
+    """Verify that fcntl.flock is used during sync-back."""
+
+    @patch("tools.environments.file_sync.fcntl.flock")
+    def test_sync_back_file_lock(self, mock_flock, tmp_path):
+        download_fn = _make_download_fn({})
+        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
+
+        mgr.sync_back(hermes_home=tmp_path / ".hermes")
+
+        # flock should have been called at least twice: LOCK_EX to acquire, LOCK_UN to release
+        assert mock_flock.call_count >= 2
+
+        lock_calls = mock_flock.call_args_list
+        lock_ops = [c[0][1] for c in lock_calls]
+        assert fcntl.LOCK_EX in lock_ops
+        assert fcntl.LOCK_UN in lock_ops
+
+    def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
+        """On Windows (fcntl=None), sync_back should skip file locking."""
+        download_fn = _make_download_fn({})
+        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
+
+        with patch("tools.environments.file_sync.fcntl", None):
+            # Should not raise — locking is skipped
+            mgr.sync_back(hermes_home=tmp_path / ".hermes")
+
+
+class TestInferHostPath:
+    """Edge cases for _infer_host_path prefix matching."""
+
+    def test_infer_no_matching_prefix(self, tmp_path):
+        """Remote path in an unmapped directory should return None."""
+        host_file = tmp_path / "host" / "skills" / "a.py"
+        _write_file(host_file, b"content")
+        mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
+
+        mgr = _make_manager(tmp_path, file_mapping=mapping)
+        result = mgr._infer_host_path(
+            "/root/.hermes/cache/new.json",
+            file_mapping=mapping,
+        )
+        assert result is None
+
+    def test_infer_partial_prefix_no_false_match(self, tmp_path):
+        """A sibling dir like /root/.hermes/skillsXtra/ should NOT match /root/.hermes/skills/."""
+        host_file = tmp_path / "host" / "skills" / "a.py"
+        _write_file(host_file, b"content")
+        mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
+
+        mgr = _make_manager(tmp_path, file_mapping=mapping)
+        # /root/.hermes/skillsXtra/b.py shares prefix "skills" but the
+        # directory is different — should not match /root/.hermes/skills/
+        result = mgr._infer_host_path(
+            "/root/.hermes/skillsXtra/b.py",
+            file_mapping=mapping,
+        )
+        assert result is None
+
+    def test_infer_matching_prefix(self, tmp_path):
+        """A file in a mapped directory should be correctly inferred."""
+        host_file = tmp_path / "host" / "skills" / "a.py"
+        _write_file(host_file, b"content")
+        mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
+
+        mgr = _make_manager(tmp_path, file_mapping=mapping)
+        result = mgr._infer_host_path(
+            "/root/.hermes/skills/b.py",
+            file_mapping=mapping,
+        )
+        expected = str(tmp_path / "host" / "skills" / "b.py")
+        assert result == expected
+
+
+class TestSyncBackSIGINT:
+    """SIGINT deferral during sync-back."""
+
+    def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
+        """On the main thread, the SIGINT handler should be swapped during sync."""
+        download_fn = _make_download_fn({})
+        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
+
+        original_getsignal = signal.getsignal
+
+        with patch("tools.environments.file_sync.signal.getsignal",
+                   side_effect=original_getsignal) as mock_get, \
+             patch("tools.environments.file_sync.signal.signal") as mock_set:
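+            # sync_back() saves the current SIGINT handler via getsignal,
+            # installs its deferral handler, then restores the original.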
mgr.sync_back(hermes_home=tmp_path / ".hermes") + + # signal.getsignal was called to save the original handler + assert mock_get.called + # signal.signal was called at least twice: install defer, restore original + assert mock_set.call_count >= 2 + + def test_sync_back_skips_signal_on_worker_thread(self, tmp_path): + """From a non-main thread, signal.signal should NOT be called.""" + import threading + + download_fn = _make_download_fn({}) + mgr = _make_manager(tmp_path, bulk_download_fn=download_fn) + + signal_called = [] + + def tracking_signal(*args): + signal_called.append(args) + + with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal): + # Run from a worker thread + exc = [] + def run(): + try: + mgr.sync_back(hermes_home=tmp_path / ".hermes") + except Exception as e: + exc.append(e) + + t = threading.Thread(target=run) + t.start() + t.join(timeout=10) + + assert not exc, f"sync_back raised: {exc}" + # signal.signal should NOT have been called from the worker thread + assert len(signal_called) == 0 diff --git a/tests/tools/test_sync_back_backends.py b/tests/tools/test_sync_back_backends.py new file mode 100644 index 0000000000..fb48796590 --- /dev/null +++ b/tests/tools/test_sync_back_backends.py @@ -0,0 +1,489 @@ +"""Tests for backend-specific bulk download implementations and cleanup() wiring.""" + +import asyncio +import subprocess +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +from tools.environments import ssh as ssh_env +from tools.environments import modal as modal_env +from tools.environments import daytona as daytona_env +from tools.environments.ssh import SSHEnvironment + + +# ── SSH helpers ────────────────────────────────────────────────────── + + +@pytest.fixture +def ssh_mock_env(monkeypatch): + """Create an SSHEnvironment with mocked connection/sync.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + monkeypatch.setattr( + ssh_env, "FileSyncManager", + lambda **kw: type("M", (), { + "sync": lambda self, **k: None, + "sync_back": lambda self: None, + })(), + ) + return SSHEnvironment(host="example.com", user="testuser") + + +# ── Modal helpers ──────────────────────────────────────────────────── + + +def _make_mock_modal_env(): + """Create a minimal ModalEnvironment without calling __init__.""" + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + env._sync_manager = None + return env + + +def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0): + """Wire sandbox.exec.aio to return mock tar output for download tests. + + Returns the exec_calls list for assertion. 
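+
+    Only positional arguments to exec are recorded; this mirrors how
+    _modal_bulk_download invokes sandbox.exec.aio("bash", "-c", ...).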
+    """
+    exec_calls = []
+
+    async def mock_exec_fn(*args, **kwargs):
+        exec_calls.append(args)
+        proc = MagicMock()
+        proc.stdout = MagicMock()
+        proc.stdout.read = MagicMock()
+        proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
+        proc.wait = MagicMock()
+        proc.wait.aio = AsyncMock(return_value=exit_code)
+        return proc
+
+    env._sandbox.exec = MagicMock()
+    env._sandbox.exec.aio = mock_exec_fn
+
+    def real_run_coroutine(coro, **kwargs):
+        loop = asyncio.new_event_loop()
+        try:
+            return loop.run_until_complete(coro)
+        finally:
+            loop.close()
+
+    env._worker.run_coroutine = real_run_coroutine
+    return exec_calls
+
+
+# ── Daytona helpers ──────────────────────────────────────────────────
+
+
+def _make_mock_daytona_env():
+    """Create a minimal DaytonaEnvironment without calling __init__."""
+    env = object.__new__(daytona_env.DaytonaEnvironment)
+    env._sandbox = MagicMock()
+    env._remote_home = "/root"
+    env._sync_manager = None
+    env._lock = __import__("threading").Lock()
+    env._persistent = True
+    env._task_id = "test"
+    env._daytona = MagicMock()
+    return env
+
+
+# =====================================================================
+# SSH bulk download
+# =====================================================================
+
+
+class TestSSHBulkDownload:
+    """Unit tests for _ssh_bulk_download."""
+
+    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
+        """subprocess.run command should include tar cf - over SSH."""
+        dest = tmp_path / "backup.tar"
+
+        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
+            # subprocess.run is mocked; open(dest, "wb") still runs but only
+            # creates an empty file under tmp_path
+            ssh_mock_env._ssh_bulk_download(dest)
+
+        mock_run.assert_called_once()
+        cmd = mock_run.call_args[0][0]
+        cmd_str = " ".join(cmd)
+        assert "tar cf -" in cmd_str
+        assert "-C /" in cmd_str
+        assert "home/testuser/.hermes" in cmd_str
+        assert "ssh" in cmd_str
+        assert "testuser@example.com" in cmd_str
+
+    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
+        """dest should be created via stdout=open(dest, 'wb')."""
+        dest = tmp_path / "backup.tar"
+
+        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)):
+            ssh_mock_env._ssh_bulk_download(dest)
+
+        # The file is opened via `with open(dest, "wb") as f` and passed as
+        # stdout=f; it is closed again by the time the call returns, so the
+        # observable effect is that the dest file now exists on disk.
+ assert dest.exists() + + def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path): + """Non-zero returncode should raise RuntimeError.""" + dest = tmp_path / "backup.tar" + + failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied") + with patch.object(subprocess, "run", return_value=failed): + with pytest.raises(RuntimeError, match="SSH bulk download failed"): + ssh_mock_env._ssh_bulk_download(dest) + + def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path): + """The subprocess.run call should use a 120s timeout.""" + dest = tmp_path / "backup.tar" + + with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run: + ssh_mock_env._ssh_bulk_download(dest) + + call_kwargs = mock_run.call_args + assert call_kwargs.kwargs.get("timeout") == 120 or call_kwargs[1].get("timeout") == 120 + + +class TestSSHCleanup: + """Verify SSH cleanup() calls sync_back() before closing ControlMaster.""" + + def test_ssh_cleanup_calls_sync_back(self, monkeypatch): + """cleanup() should call sync_back() before SSH control socket teardown.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + + call_order = [] + + class TrackingSyncManager: + def __init__(self, **kwargs): + pass + + def sync(self, **kw): + pass + + def sync_back(self): + call_order.append("sync_back") + + monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager) + + env = SSHEnvironment(host="h", user="u") + # Ensure control_socket does not exist so cleanup skips the SSH exit call + env.control_socket = Path("/nonexistent/socket") + + env.cleanup() + + assert "sync_back" in call_order + + def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch): + """sync_back() must run before the ControlMaster exit command.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + + call_order = [] + + class TrackingSyncManager: + def __init__(self, **kwargs): + pass + + def sync(self, **kw): + pass + + def sync_back(self): + call_order.append("sync_back") + + monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager) + + env = SSHEnvironment(host="h", user="u") + + # Create a fake control socket so cleanup tries the SSH exit + import tempfile + with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp: + env.control_socket = Path(tmp.name) + + def mock_run(cmd, **kwargs): + cmd_str = " ".join(cmd) + if "-O" in cmd and "exit" in cmd_str: + call_order.append("control_exit") + return subprocess.CompletedProcess([], 0) + + with patch.object(subprocess, "run", side_effect=mock_run): + env.cleanup() + + assert call_order.index("sync_back") < call_order.index("control_exit") + + +# ===================================================================== +# Modal bulk download +# 
===================================================================== + + +class TestModalBulkDownload: + """Unit tests for _modal_bulk_download.""" + + def test_modal_bulk_download_command(self, tmp_path): + """exec should be called with tar cf - -C /root/.hermes .""" + env = _make_mock_modal_env() + exec_calls = _wire_modal_download(env, tar_bytes=b"tar-content") + dest = tmp_path / "backup.tar" + + env._modal_bulk_download(dest) + + assert len(exec_calls) == 1 + args = exec_calls[0] + assert args[0] == "bash" + assert args[1] == "-c" + assert "tar cf -" in args[2] + assert "-C / root/.hermes" in args[2] + + def test_modal_bulk_download_writes_to_dest(self, tmp_path): + """Downloaded tar bytes should be written to the dest path.""" + env = _make_mock_modal_env() + expected_data = b"some-tar-archive-bytes" + _wire_modal_download(env, tar_bytes=expected_data) + dest = tmp_path / "backup.tar" + + env._modal_bulk_download(dest) + + assert dest.exists() + assert dest.read_bytes() == expected_data + + def test_modal_bulk_download_handles_str_output(self, tmp_path): + """If stdout returns str instead of bytes, it should be encoded.""" + env = _make_mock_modal_env() + # Simulate Modal SDK returning str + _wire_modal_download(env, tar_bytes="string-tar-data") + dest = tmp_path / "backup.tar" + + env._modal_bulk_download(dest) + + assert dest.read_bytes() == b"string-tar-data" + + def test_modal_bulk_download_raises_on_failure(self, tmp_path): + """Non-zero exit code should raise RuntimeError.""" + env = _make_mock_modal_env() + _wire_modal_download(env, exit_code=1) + dest = tmp_path / "backup.tar" + + with pytest.raises(RuntimeError, match="Modal bulk download failed"): + env._modal_bulk_download(dest) + + def test_modal_bulk_download_uses_120s_timeout(self, tmp_path): + """run_coroutine should be called with timeout=120.""" + env = _make_mock_modal_env() + _wire_modal_download(env, tar_bytes=b"data") + + run_kwargs = {} + original_run = env._worker.run_coroutine + + def tracking_run(coro, **kwargs): + run_kwargs.update(kwargs) + return original_run(coro, **kwargs) + + env._worker.run_coroutine = tracking_run + dest = tmp_path / "backup.tar" + + env._modal_bulk_download(dest) + + assert run_kwargs.get("timeout") == 120 + + +class TestModalCleanup: + """Verify Modal cleanup() calls sync_back() before terminate.""" + + def test_modal_cleanup_calls_sync_back(self): + """cleanup() should call sync_back() before sandbox.terminate.""" + env = _make_mock_modal_env() + + call_order = [] + sync_mgr = MagicMock() + sync_mgr.sync_back = lambda: call_order.append("sync_back") + env._sync_manager = sync_mgr + + # Mock terminate to track call order + async def mock_terminate(): + pass + + env._sandbox.terminate = MagicMock() + env._sandbox.terminate.aio = mock_terminate + env._worker.run_coroutine = lambda coro, **kw: ( + call_order.append("terminate"), + asyncio.new_event_loop().run_until_complete(coro), + ) + env._worker.stop = lambda: None + + env.cleanup() + + assert "sync_back" in call_order + assert call_order.index("sync_back") < call_order.index("terminate") + + +# ===================================================================== +# Daytona bulk download +# ===================================================================== + + +class TestDaytonaBulkDownload: + """Unit tests for _daytona_bulk_download.""" + + def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path): + """exec and download_file should both be called.""" + env = _make_mock_daytona_env() + dest = tmp_path / 
"backup.tar" + + env._daytona_bulk_download(dest) + + # exec called twice: tar creation + rm cleanup + assert env._sandbox.process.exec.call_count == 2 + tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0] + assert "tar cf" in tar_cmd + assert "/tmp/.hermes_sync.tar" in tar_cmd + assert ".hermes" in tar_cmd + + cleanup_cmd = env._sandbox.process.exec.call_args_list[1][0][0] + assert "rm -f /tmp/.hermes_sync.tar" in cleanup_cmd + + env._sandbox.fs.download_file.assert_called_once_with( + "/tmp/.hermes_sync.tar", str(dest) + ) + + def test_daytona_bulk_download_uses_remote_home(self, tmp_path): + """The tar command should use the env's _remote_home.""" + env = _make_mock_daytona_env() + env._remote_home = "/home/daytona" + dest = tmp_path / "backup.tar" + + env._daytona_bulk_download(dest) + + tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0] + assert "home/daytona/.hermes" in tar_cmd + + +class TestDaytonaCleanup: + """Verify Daytona cleanup() calls sync_back() before stop.""" + + def test_daytona_cleanup_calls_sync_back(self): + """cleanup() should call sync_back() before sandbox.stop().""" + env = _make_mock_daytona_env() + + call_order = [] + sync_mgr = MagicMock() + sync_mgr.sync_back = lambda: call_order.append("sync_back") + env._sync_manager = sync_mgr + env._sandbox.stop = lambda: call_order.append("stop") + + env.cleanup() + + assert "sync_back" in call_order + assert "stop" in call_order + assert call_order.index("sync_back") < call_order.index("stop") + + +# ===================================================================== +# FileSyncManager wiring: bulk_download_fn passed by each backend +# ===================================================================== + + +class TestBulkDownloadWiring: + """Verify each backend passes bulk_download_fn to FileSyncManager.""" + + def test_ssh_passes_bulk_download_fn(self, monkeypatch): + """SSHEnvironment should pass _ssh_bulk_download to FileSyncManager.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + + captured_kwargs = {} + + class CaptureSyncManager: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + def sync(self, **kw): + pass + + monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager) + + SSHEnvironment(host="h", user="u") + + assert "bulk_download_fn" in captured_kwargs + assert callable(captured_kwargs["bulk_download_fn"]) + + def test_modal_passes_bulk_download_fn(self, monkeypatch): + """ModalEnvironment should pass _modal_bulk_download to FileSyncManager.""" + captured_kwargs = {} + + def capture_fsm(**kwargs): + captured_kwargs.update(kwargs) + return type("M", (), {"sync": lambda self, **k: None})() + + monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm) + + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + + # Replicate the wiring done in __init__ + from tools.environments.file_sync import iter_sync_files + env._sync_manager = modal_env.FileSyncManager( + get_files_fn=lambda: iter_sync_files("/root/.hermes"), + upload_fn=env._modal_upload, + delete_fn=env._modal_delete, + 
bulk_upload_fn=env._modal_bulk_upload, + bulk_download_fn=env._modal_bulk_download, + ) + + assert "bulk_download_fn" in captured_kwargs + assert callable(captured_kwargs["bulk_download_fn"]) + + def test_daytona_passes_bulk_download_fn(self, monkeypatch): + """DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager.""" + captured_kwargs = {} + + def capture_fsm(**kwargs): + captured_kwargs.update(kwargs) + return type("M", (), {"sync": lambda self, **k: None})() + + monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm) + + env = object.__new__(daytona_env.DaytonaEnvironment) + env._sandbox = MagicMock() + env._remote_home = "/root" + env._lock = __import__("threading").Lock() + env._persistent = True + env._task_id = "test" + env._daytona = MagicMock() + + # Replicate the wiring done in __init__ + from tools.environments.file_sync import iter_sync_files + env._sync_manager = daytona_env.FileSyncManager( + get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"), + upload_fn=env._daytona_upload, + delete_fn=env._daytona_delete, + bulk_upload_fn=env._daytona_bulk_upload, + bulk_download_fn=env._daytona_bulk_download, + ) + + assert "bulk_download_fn" in captured_kwargs + assert callable(captured_kwargs["bulk_download_fn"]) diff --git a/tools/environments/daytona.py b/tools/environments/daytona.py index c2913e585e..1351444715 100644 --- a/tools/environments/daytona.py +++ b/tools/environments/daytona.py @@ -134,6 +134,7 @@ class DaytonaEnvironment(BaseEnvironment): upload_fn=self._daytona_upload, delete_fn=self._daytona_delete, bulk_upload_fn=self._daytona_bulk_upload, + bulk_download_fn=self._daytona_bulk_download, ) self._sync_manager.sync(force=True) self.init_session() @@ -166,6 +167,19 @@ class DaytonaEnvironment(BaseEnvironment): ] self._sandbox.fs.upload_files(uploads) + def _daytona_bulk_download(self, dest: Path) -> None: + """Download remote .hermes/ as a tar archive.""" + rel_base = f"{self._remote_home}/.hermes".lstrip("/") + self._sandbox.process.exec( + f"tar cf /tmp/.hermes_sync.tar -C / {shlex.quote(rel_base)}" + ) + self._sandbox.fs.download_file("/tmp/.hermes_sync.tar", str(dest)) + # Clean up remote temp file + try: + self._sandbox.process.exec("rm -f /tmp/.hermes_sync.tar") + except Exception: + pass # best-effort cleanup + def _daytona_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files via SDK exec.""" self._sandbox.process.exec(quoted_rm_command(remote_paths)) @@ -213,6 +227,10 @@ class DaytonaEnvironment(BaseEnvironment): return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) def cleanup(self): + if self._sync_manager: + logger.info("Daytona: syncing files from sandbox...") + self._sync_manager.sync_back() + with self._lock: if self._sandbox is None: return diff --git a/tools/environments/file_sync.py b/tools/environments/file_sync.py index 64a5b56dc4..ae61eb13b7 100644 --- a/tools/environments/file_sync.py +++ b/tools/environments/file_sync.py @@ -6,13 +6,25 @@ and Daytona. Docker and Singularity use bind mounts (live host FS view) and don't need this. 
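+
+On teardown, FileSyncManager.sync_back() reverses the flow and pulls
+remote edits (e.g. skills created inside the sandbox) back to the host.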
""" +import hashlib import logging import os import shlex +import shutil +import signal +import tarfile +import tempfile +import threading import time + +try: + import fcntl +except ImportError: + fcntl = None # Windows — file locking skipped from pathlib import Path from typing import Callable +from hermes_constants import get_hermes_home from tools.environments.base import _file_mtime_key logger = logging.getLogger(__name__) @@ -23,6 +35,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC" # Transport callbacks provided by each backend UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure +BulkDownloadFn = Callable[[Path], None] # (dest_tar_path) -> writes tar archive, raises on failure DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...] @@ -71,6 +84,19 @@ def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]: return sorted({str(Path(remote).parent) for _, remote in files}) +def _sha256_file(path: str) -> str: + """Return hex SHA-256 digest of a file.""" + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +_SYNC_BACK_MAX_RETRIES = 3 +_SYNC_BACK_BACKOFF = (2, 4, 8) # seconds between retries + + class FileSyncManager: """Tracks local file changes and syncs to a remote environment. @@ -89,12 +115,15 @@ class FileSyncManager: delete_fn: DeleteFn, sync_interval: float = _SYNC_INTERVAL_SECONDS, bulk_upload_fn: BulkUploadFn | None = None, + bulk_download_fn: BulkDownloadFn | None = None, ): self._get_files_fn = get_files_fn self._upload_fn = upload_fn self._bulk_upload_fn = bulk_upload_fn + self._bulk_download_fn = bulk_download_fn self._delete_fn = delete_fn self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size) + self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs self._sync_interval = sync_interval @@ -136,6 +165,7 @@ class FileSyncManager: # Snapshot for rollback (only when there's work to do) prev_files = dict(self._synced_files) + prev_hashes = dict(self._pushed_hashes) if to_upload: logger.debug("file_sync: uploading %d file(s)", len(to_upload)) @@ -156,13 +186,187 @@ class FileSyncManager: logger.debug("file_sync: deleted %s", to_delete) # --- Commit (all succeeded) --- + for host_path, remote_path in to_upload: + self._pushed_hashes[remote_path] = _sha256_file(host_path) + for p in to_delete: new_files.pop(p, None) + self._pushed_hashes.pop(p, None) self._synced_files = new_files self._last_sync_time = time.monotonic() except Exception as exc: self._synced_files = prev_files + self._pushed_hashes = prev_hashes self._last_sync_time = time.monotonic() logger.warning("file_sync: sync failed, rolled back state: %s", exc) + + # ------------------------------------------------------------------ + # Sync-back: pull remote changes to host on teardown + # ------------------------------------------------------------------ + + def sync_back(self, hermes_home: Path | None = None) -> None: + """Pull remote changes back to the host filesystem. + + Downloads the remote ``.hermes/`` directory as a tar archive, + unpacks it, and applies only files that differ from what was + originally pushed (based on SHA-256 content hashes). 
+
+        Protected against SIGINT (defers the signal until complete) and
+        serialized across concurrent gateway sandboxes via a file lock.
+        """
+        if self._bulk_download_fn is None:
+            return
+
+        lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
+        lock_path.parent.mkdir(parents=True, exist_ok=True)
+
+        last_exc: Exception | None = None
+        for attempt in range(_SYNC_BACK_MAX_RETRIES):
+            try:
+                self._sync_back_once(lock_path)
+                return
+            except Exception as exc:
+                last_exc = exc
+                if attempt < _SYNC_BACK_MAX_RETRIES - 1:
+                    delay = _SYNC_BACK_BACKOFF[attempt]
+                    logger.warning(
+                        "sync_back: attempt %d failed (%s), retrying in %ds",
+                        attempt + 1, exc, delay,
+                    )
+                    time.sleep(delay)
+
+        logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)
+
+    def _sync_back_once(self, lock_path: Path) -> None:
+        """Single sync-back attempt with SIGINT protection and file lock."""
+        # signal.signal() only works from the main thread. In gateway
+        # contexts cleanup() may run from a worker thread — skip SIGINT
+        # deferral there rather than crashing.
+        on_main_thread = threading.current_thread() is threading.main_thread()
+
+        deferred_sigint: list[object] = []
+        original_handler = None
+        if on_main_thread:
+            original_handler = signal.getsignal(signal.SIGINT)
+
+            def _defer_sigint(signum, frame):
+                deferred_sigint.append((signum, frame))
+                logger.debug("sync_back: SIGINT deferred until sync completes")
+
+            signal.signal(signal.SIGINT, _defer_sigint)
+        try:
+            self._sync_back_locked(lock_path)
+        finally:
+            if on_main_thread and original_handler is not None:
+                signal.signal(signal.SIGINT, original_handler)
+                if deferred_sigint:
+                    os.kill(os.getpid(), signal.SIGINT)
+
+    def _sync_back_locked(self, lock_path: Path) -> None:
+        """Sync-back under file lock (serializes concurrent gateways)."""
+        if fcntl is None:
+            # Windows: no flock — run without serialization
+            self._sync_back_impl()
+            return
+        lock_fd = open(lock_path, "w")
+        try:
+            fcntl.flock(lock_fd, fcntl.LOCK_EX)
+            self._sync_back_impl()
+        finally:
+            fcntl.flock(lock_fd, fcntl.LOCK_UN)
+            lock_fd.close()
+
+    def _sync_back_impl(self) -> None:
+        """Download, diff, and apply remote changes to host."""
+        if self._bulk_download_fn is None:
+            raise RuntimeError("_sync_back_impl called without bulk_download_fn")
+
+        # Cache file mapping once to avoid O(n*m) from repeated iteration
+        try:
+            file_mapping = list(self._get_files_fn())
+        except Exception:
+            file_mapping = []
+
+        with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
+            self._bulk_download_fn(Path(tf.name))
+
+            with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
+                with tarfile.open(tf.name) as tar:
+                    tar.extractall(staging, filter="data")
+
+                applied = 0
+                for dirpath, _dirnames, filenames in os.walk(staging):
+                    for fname in filenames:
+                        staged_file = os.path.join(dirpath, fname)
+                        rel = os.path.relpath(staged_file, staging)
+                        remote_path = "/" + rel
+
+                        pushed_hash = self._pushed_hashes.get(remote_path)
+
+                        # Skip files whose content still matches what was
+                        # pushed; a missing hash means the file is new on
+                        # the remote and should be applied.
+                        if pushed_hash is not None and _sha256_file(staged_file) == pushed_hash:
+                            continue
+
+                        # Resolve host path from cached mapping
+                        host_path = self._resolve_host_path(remote_path, file_mapping)
+                        if host_path is None:
+                            host_path = self._infer_host_path(remote_path, file_mapping)
+                        if host_path is None:
+                            logger.debug(
+                                "sync_back: skipping %s (no host mapping)",
+                                remote_path,
+                            )
+                            continue
+
+                        if os.path.exists(host_path) and pushed_hash is not None:
+                            host_hash = _sha256_file(host_path)
+                            if host_hash != pushed_hash:
+                                logger.warning(
+                                    "sync_back: conflict on %s — host modified "
+                                    "since push, remote also changed. Applying "
+                                    "remote version (last-write-wins).",
+                                    remote_path,
+                                )
+
+                        os.makedirs(os.path.dirname(host_path), exist_ok=True)
+                        shutil.copy2(staged_file, host_path)
+                        applied += 1
+
+                if applied:
+                    logger.info("sync_back: applied %d changed file(s)", applied)
+                else:
+                    logger.debug("sync_back: no remote changes detected")
+
+    def _resolve_host_path(self, remote_path: str,
+                           file_mapping: list[tuple[str, str]] | None = None) -> str | None:
+        """Find the host path for a known remote path from the file mapping."""
+        mapping = file_mapping if file_mapping is not None else []
+        for host, remote in mapping:
+            if remote == remote_path:
+                return host
+        return None
+
+    def _infer_host_path(self, remote_path: str,
+                         file_mapping: list[tuple[str, str]] | None = None) -> str | None:
+        """Infer a host path for a new remote file by matching path prefixes.
+
+        Uses the existing file mapping to find a remote->host directory
+        pair, then applies the same prefix substitution to the new file.
+        For example, if the mapping has ``/root/.hermes/skills/a.md`` →
+        ``~/.hermes/skills/a.md``, a new remote file at
+        ``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
+        """
+        mapping = file_mapping if file_mapping is not None else []
+        for host, remote in mapping:
+            remote_dir = str(Path(remote).parent)
+            if remote_path.startswith(remote_dir + "/"):
+                host_dir = str(Path(host).parent)
+                suffix = remote_path[len(remote_dir):]
+                return host_dir + suffix
+        return None
diff --git a/tools/environments/modal.py b/tools/environments/modal.py
index 5c5c721c1e..4b7e9db0cd 100644
--- a/tools/environments/modal.py
+++ b/tools/environments/modal.py
@@ -269,6 +269,7 @@ class ModalEnvironment(BaseEnvironment):
             upload_fn=self._modal_upload,
             delete_fn=self._modal_delete,
             bulk_upload_fn=self._modal_bulk_upload,
+            bulk_download_fn=self._modal_bulk_download,
         )
         self._sync_manager.sync(force=True)
         self.init_session()
@@ -347,6 +348,27 @@ class ModalEnvironment(BaseEnvironment):
 
         self._worker.run_coroutine(_bulk(), timeout=120)
 
+    def _modal_bulk_download(self, dest: Path) -> None:
+        """Download remote .hermes/ as a tar archive.
+
+        Modal sandboxes always run as root, so /root/.hermes is hardcoded
+        (consistent with the iter_sync_files call on line 269).
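+
+        The tar stream is buffered fully in memory before being written
+        to *dest*; this assumes the .hermes directory stays small.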
+ """ + async def _download(): + proc = await self._sandbox.exec.aio( + "bash", "-c", "tar cf - -C / root/.hermes" + ) + data = await proc.stdout.read.aio() + exit_code = await proc.wait.aio() + if exit_code != 0: + raise RuntimeError(f"Modal bulk download failed (exit {exit_code})") + return data + + tar_bytes = self._worker.run_coroutine(_download(), timeout=120) + if isinstance(tar_bytes, str): + tar_bytes = tar_bytes.encode() + dest.write_bytes(tar_bytes) + def _modal_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files via exec.""" rm_cmd = quoted_rm_command(remote_paths) @@ -404,6 +426,10 @@ class ModalEnvironment(BaseEnvironment): if self._sandbox is None: return + if self._sync_manager: + logger.info("Modal: syncing files from sandbox...") + self._sync_manager.sync_back() + if self._persistent: try: async def _snapshot(): diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 0491764b2f..568112b2c8 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -58,6 +58,7 @@ class SSHEnvironment(BaseEnvironment): upload_fn=self._scp_upload, delete_fn=self._ssh_delete, bulk_upload_fn=self._ssh_bulk_upload, + bulk_download_fn=self._ssh_bulk_download, ) self._sync_manager.sync(force=True) @@ -216,6 +217,18 @@ class SSHEnvironment(BaseEnvironment): logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files)) + def _ssh_bulk_download(self, dest: Path) -> None: + """Download remote .hermes/ as a tar archive.""" + # Tar from / with the full path so archive entries preserve absolute + # paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys. + rel_base = f"{self._remote_home}/.hermes".lstrip("/") + ssh_cmd = self._build_ssh_command() + ssh_cmd.append(f"tar cf - -C / {shlex.quote(rel_base)}") + with open(dest, "wb") as f: + result = subprocess.run(ssh_cmd, stdout=f, stderr=subprocess.PIPE, timeout=120) + if result.returncode != 0: + raise RuntimeError(f"SSH bulk download failed: {result.stderr.decode(errors='replace').strip()}") + def _ssh_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files in one SSH call.""" cmd = self._build_ssh_command() @@ -245,6 +258,10 @@ class SSHEnvironment(BaseEnvironment): return _popen_bash(cmd, stdin_data) def cleanup(self): + if self._sync_manager: + logger.info("SSH: syncing files from sandbox...") + self._sync_manager.sync_back() + if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",