mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(file-sync): sync remote changes back to host on teardown
Salvage of PR #8018 by @alt-glitch onto current main. On sandbox teardown, FileSyncManager now downloads the remote .hermes/ directory, diffs against SHA-256 hashes of what was originally pushed, and applies only changed files back to the host. Core (tools/environments/file_sync.py): - sync_back(): orchestrates download -> unpack -> diff -> apply with: - Retry with exponential backoff (3 attempts, 2s/4s/8s) - SIGINT trap + defer (prevents partial writes on Ctrl-C) - fcntl.flock serialization (concurrent gateway sandboxes) - Last-write-wins conflict resolution with warning - New remote files pulled back via _infer_host_path prefix matching Backends: - SSH: _ssh_bulk_download — tar cf - piped over SSH - Modal: _modal_bulk_download — exec tar cf - -> proc.stdout.read - Daytona: _daytona_bulk_download — exec tar cf -> SDK download_file - All three call sync_back() at the top of cleanup() Fixes applied during salvage (vs original PR #8018): | # | Issue | Fix | |---|-------|-----| | C1 | import fcntl unconditional — crashes Windows | try/except with fallback; _sync_back_locked skips locking when fcntl=None | | W1 | assert for runtime guard (stripped by -O) | Replaced with proper if/raise RuntimeError | | W2 | O(n*m) from _get_files_fn() called per file | Cache mapping once at start of _sync_back_impl, pass to resolve/infer | | W3 | Dead BulkDownloadFn imports in 3 backends | Removed unused imports | | W4 | Modal hardcodes root/.hermes, no explanation | Added docstring comment explaining Modal always runs as root | | S1 | SHA-256 computed for new files where pushed_hash=None | Skip hashing when pushed_hash is None (comparison always False) | | S2 | Daytona /tmp/.hermes_sync.tar never cleaned up | Added rm -f after download (best-effort) | Tests: 49 passing (17 new: _infer_host_path edge cases, SIGINT main/worker thread, Windows fcntl=None fallback, Daytona tar cleanup). Based on #8018 by @alt-glitch.
This commit is contained in:
parent
764536b684
commit
d64446e315
6 changed files with 1166 additions and 0 deletions
412
tests/tools/test_file_sync_back.py
Normal file
412
tests/tools/test_file_sync_back.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||
|
||||
import fcntl
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
_sha256_file,
|
||||
_SYNC_BACK_BACKOFF,
|
||||
_SYNC_BACK_MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tar(files: dict[str, bytes], dest: Path):
|
||||
"""Write a tar archive containing the given arcname->content pairs."""
|
||||
with tarfile.open(dest, "w") as tar:
|
||||
for arcname, content in files.items():
|
||||
info = tarfile.TarInfo(name=arcname)
|
||||
info.size = len(content)
|
||||
tar.addfile(info, io.BytesIO(content))
|
||||
|
||||
|
||||
def _make_download_fn(files: dict[str, bytes]):
|
||||
"""Return a bulk_download_fn that writes a tar of the given files."""
|
||||
def download(dest: Path):
|
||||
_make_tar(files, dest)
|
||||
return download
|
||||
|
||||
|
||||
def _sha256_bytes(data: bytes) -> str:
|
||||
"""Compute SHA-256 hex digest of raw bytes (for test convenience)."""
|
||||
import hashlib
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def _write_file(path: Path, content: bytes) -> str:
|
||||
"""Write bytes to *path*, creating parents, and return the string path."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
def _make_manager(
    tmp_path: Path,
    file_mapping: list[tuple[str, str]] | None = None,
    bulk_download_fn=None,
) -> FileSyncManager:
    """Build a FileSyncManager wired with mocks for testing.

    *file_mapping* is the list of (host_path, remote_path) tuples that
    ``get_files_fn`` will return; ``None`` means an empty mapping.
    """
    pairs = [] if file_mapping is None else file_mapping
    return FileSyncManager(
        get_files_fn=lambda: pairs,
        upload_fn=MagicMock(),
        delete_fn=MagicMock(),
        bulk_download_fn=bulk_download_fn,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncBackNoop:
    """sync_back() is a no-op when there is no download function."""

    def test_sync_back_noop_without_download_fn(self, tmp_path):
        manager = _make_manager(tmp_path, bulk_download_fn=None)
        # With no bulk_download_fn there is nothing to pull back; the call
        # must simply return without raising.
        manager.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestSyncBackNoChanges:
    """When all remote files match pushed hashes, nothing is applied."""

    def test_sync_back_no_changes(self, tmp_path):
        payload = b'{"key": "val"}'
        local_path = tmp_path / "host" / "cred.json"
        _write_file(local_path, payload)

        remote = "/root/.hermes/cred.json"
        # Remote archive holds exactly what was pushed, so the diff is empty.
        manager = _make_manager(
            tmp_path,
            file_mapping=[(str(local_path), remote)],
            bulk_download_fn=_make_download_fn({"root/.hermes/cred.json": payload}),
        )
        # Record the hash as if sync() had already pushed this file.
        manager._pushed_hashes[remote] = _sha256_bytes(payload)

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # The host copy must be untouched.
        assert local_path.read_bytes() == payload
|
||||
|
||||
|
||||
class TestSyncBackAppliesChanged:
    """Remote file differs from pushed version -- gets copied to host."""

    def test_sync_back_applies_changed_file(self, tmp_path):
        pushed = b"print('v1')"
        local_path = tmp_path / "host" / "skill.py"
        _write_file(local_path, pushed)

        remote = "/root/.hermes/skill.py"
        edited = b"print('v2 - edited on remote')"
        # Remote tar contains a version that no longer matches the pushed hash.
        manager = _make_manager(
            tmp_path,
            file_mapping=[(str(local_path), remote)],
            bulk_download_fn=_make_download_fn({"root/.hermes/skill.py": edited}),
        )
        manager._pushed_hashes[remote] = _sha256_bytes(pushed)

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # The remote edit must have been written back over the host copy.
        assert local_path.read_bytes() == edited
|
||||
|
||||
|
||||
class TestSyncBackNewRemoteFile:
    """File created on remote (not in _pushed_hashes) is applied via _infer_host_path."""

    def test_sync_back_detects_new_remote_file(self, tmp_path):
        # A pre-existing mapping supplies the directory prefix that
        # _infer_host_path uses to place previously unseen remote files.
        known_host = tmp_path / "host" / "skills" / "existing.py"
        _write_file(known_host, b"existing")

        fresh = b"# brand new skill created on remote"
        manager = _make_manager(
            tmp_path,
            file_mapping=[(str(known_host), "/root/.hermes/skills/existing.py")],
            bulk_download_fn=_make_download_fn({"root/.hermes/skills/new_skill.py": fresh}),
        )
        # Deliberately no _pushed_hashes entry: the file is new on the remote.

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # The new file's host location is inferred from the mapped prefix.
        inferred = tmp_path / "host" / "skills" / "new_skill.py"
        assert inferred.exists()
        assert inferred.read_bytes() == fresh
|
||||
|
||||
|
||||
class TestSyncBackConflict:
    """Host AND remote both changed since push -- warning logged, remote wins."""

    def test_sync_back_conflict_warns(self, tmp_path, caplog):
        pushed = b'{"v": 1}'
        local_path = tmp_path / "host" / "config.json"
        _write_file(local_path, pushed)

        remote = "/root/.hermes/config.json"

        # Both sides diverge after the push: a host edit...
        local_path.write_bytes(b'{"v": 2, "host-edit": true}')
        # ...and a different remote edit.
        remote_edit = b'{"v": 3, "remote-edit": true}'

        manager = _make_manager(
            tmp_path,
            file_mapping=[(str(local_path), remote)],
            bulk_download_fn=_make_download_fn({"root/.hermes/config.json": remote_edit}),
        )
        manager._pushed_hashes[remote] = _sha256_bytes(pushed)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            manager.sync_back(hermes_home=tmp_path / ".hermes")

        # A conflict warning must be emitted...
        assert any("conflict" in record.message.lower() for record in caplog.records)
        # ...and last-write-wins means the remote version lands on the host.
        assert local_path.read_bytes() == remote_edit
|
||||
|
||||
|
||||
class TestSyncBackRetries:
    """Retry behaviour with exponential backoff."""

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
        attempts = 0

        def flaky_download(dest: Path):
            nonlocal attempts
            attempts += 1
            # Fail twice, then succeed with an empty (but valid) archive.
            if attempts < 3:
                raise RuntimeError(f"network error #{attempts}")
            _make_tar({}, dest)

        manager = _make_manager(tmp_path, bulk_download_fn=flaky_download)
        manager.sync_back(hermes_home=tmp_path / ".hermes")

        assert attempts == 3
        # Backoff sleeps occur between attempts 1->2 and 2->3 only.
        assert mock_sleep.call_count == 2
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
        def always_fail(dest: Path):
            raise RuntimeError("persistent failure")

        manager = _make_manager(tmp_path, bulk_download_fn=always_fail)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            # Failures are swallowed and logged; sync_back never raises.
            manager.sync_back(hermes_home=tmp_path / ".hermes")

        # One sleep per retry gap, for every configured retry.
        assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1

        # A final "all attempts failed" warning is recorded.
        assert any(
            "all" in record.message.lower() and "failed" in record.message.lower()
            for record in caplog.records
        )
|
||||
|
||||
|
||||
class TestPushedHashesPopulated:
    """_pushed_hashes is populated during sync() and cleared on delete."""

    def test_pushed_hashes_populated_on_sync(self, tmp_path):
        local_path = tmp_path / "data.txt"
        local_path.write_bytes(b"hello world")
        remote = "/root/.hermes/data.txt"

        pairs = [(str(local_path), remote)]
        manager = FileSyncManager(
            get_files_fn=lambda: pairs,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        manager.sync(force=True)

        # The pushed hash must match the file's on-disk SHA-256.
        assert remote in manager._pushed_hashes
        assert manager._pushed_hashes[remote] == _sha256_file(str(local_path))

    def test_pushed_hashes_cleared_on_delete(self, tmp_path):
        local_path = tmp_path / "deleteme.txt"
        local_path.write_bytes(b"to be deleted")
        remote = "/root/.hermes/deleteme.txt"

        # A mutable mapping lets us simulate the file disappearing later.
        live_mapping = [(str(local_path), remote)]
        manager = FileSyncManager(
            get_files_fn=lambda: live_mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        # First sync records the hash.
        manager.sync(force=True)
        assert remote in manager._pushed_hashes

        # Delete the file locally and drop it from the mapping.
        os.unlink(str(local_path))
        live_mapping.clear()

        manager.sync(force=True)

        # The stale hash entry must be cleaned up.
        assert remote not in manager._pushed_hashes
|
||||
|
||||
|
||||
class TestSyncBackFileLock:
    """Verify that fcntl.flock is used during sync-back."""

    @patch("tools.environments.file_sync.fcntl.flock")
    def test_sync_back_file_lock(self, mock_flock, tmp_path):
        manager = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # Expect at least an acquire (LOCK_EX) and a release (LOCK_UN).
        assert mock_flock.call_count >= 2
        observed_ops = [args[0][1] for args in mock_flock.call_args_list]
        assert fcntl.LOCK_EX in observed_ops
        assert fcntl.LOCK_UN in observed_ops

    def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
        """On Windows (fcntl=None), sync_back should skip file locking."""
        manager = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        # With the module-level fcntl patched to None, locking is skipped
        # entirely and sync_back must still complete without raising.
        with patch("tools.environments.file_sync.fcntl", None):
            manager.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestInferHostPath:
    """Edge cases for _infer_host_path prefix matching."""

    def _manager_with_skill(self, tmp_path):
        # Shared setup: a single mapped file under host/skills/.
        host_file = tmp_path / "host" / "skills" / "a.py"
        _write_file(host_file, b"content")
        mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
        return _make_manager(tmp_path, file_mapping=mapping), mapping

    def test_infer_no_matching_prefix(self, tmp_path):
        """Remote path in unmapped directory should return None."""
        manager, mapping = self._manager_with_skill(tmp_path)
        result = manager._infer_host_path(
            "/root/.hermes/cache/new.json",
            file_mapping=mapping,
        )
        assert result is None

    def test_infer_partial_prefix_no_false_match(self, tmp_path):
        """A partial prefix like /root/.hermes/sk should NOT match /root/.hermes/skills/."""
        manager, mapping = self._manager_with_skill(tmp_path)
        # "skillsXtra" shares leading characters with "skills" but is a
        # different directory, so no host path may be inferred.
        result = manager._infer_host_path(
            "/root/.hermes/skillsXtra/b.py",
            file_mapping=mapping,
        )
        assert result is None

    def test_infer_matching_prefix(self, tmp_path):
        """A file in a mapped directory should be correctly inferred."""
        manager, mapping = self._manager_with_skill(tmp_path)
        result = manager._infer_host_path(
            "/root/.hermes/skills/b.py",
            file_mapping=mapping,
        )
        assert result == str(tmp_path / "host" / "skills" / "b.py")
|
||||
|
||||
|
||||
class TestSyncBackSIGINT:
    """SIGINT deferral during sync-back."""

    def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
        """On the main thread, SIGINT handler should be swapped during sync."""
        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        handlers_seen = []
        original_getsignal = signal.getsignal

        # getsignal keeps its real behaviour (side_effect delegates to the
        # saved original) so sync_back can read a genuine handler, while
        # signal.signal is fully mocked so no real handler is installed.
        with patch("tools.environments.file_sync.signal.getsignal",
                   side_effect=original_getsignal) as mock_get, \
             patch("tools.environments.file_sync.signal.signal") as mock_set:
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # signal.getsignal was called to save the original handler
        assert mock_get.called
        # signal.signal was called at least twice: install defer, restore original
        assert mock_set.call_count >= 2

    def test_sync_back_skips_signal_on_worker_thread(self, tmp_path):
        """From a non-main thread, signal.signal should NOT be called."""
        import threading

        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        signal_called = []

        def tracking_signal(*args):
            # Records any attempt to install a handler; CPython only permits
            # signal.signal from the main thread, so none should occur here.
            signal_called.append(args)

        with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal):
            # Run from a worker thread
            exc = []

            def run():
                try:
                    mgr.sync_back(hermes_home=tmp_path / ".hermes")
                except Exception as e:
                    exc.append(e)

            t = threading.Thread(target=run)
            t.start()
            t.join(timeout=10)

            assert not exc, f"sync_back raised: {exc}"
            # signal.signal should NOT have been called from the worker thread
            assert len(signal_called) == 0
|
||||
489
tests/tools/test_sync_back_backends.py
Normal file
489
tests/tools/test_sync_back_backends.py
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments import modal as modal_env
|
||||
from tools.environments import daytona as daytona_env
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync."""
    # Pretend the ssh binary is installed, and stub out every
    # network-touching step of SSHEnvironment construction.
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    # Replace FileSyncManager with an anonymous stub class whose sync /
    # sync_back methods are no-ops, so construction does no file I/O.
    monkeypatch.setattr(
        ssh_env, "FileSyncManager",
        lambda **kw: type("M", (), {
            "sync": lambda self, **k: None,
            "sync_back": lambda self: None,
        })(),
    )
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
# ── Modal helpers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_modal_env():
    """Build a bare ModalEnvironment for unit tests, skipping __init__."""
    # object.__new__ avoids __init__ side effects (sandbox creation etc.).
    env = object.__new__(modal_env.ModalEnvironment)
    env._sandbox = MagicMock()
    env._worker = MagicMock()
    env._sync_manager = None
    env._persistent = False
    env._task_id = "test"
    return env
|
||||
|
||||
|
||||
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
|
||||
"""Wire sandbox.exec.aio to return mock tar output for download tests.
|
||||
|
||||
Returns the exec_calls list for assertion.
|
||||
"""
|
||||
exec_calls = []
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
exec_calls.append(args)
|
||||
proc = MagicMock()
|
||||
proc.stdout = MagicMock()
|
||||
proc.stdout.read = MagicMock()
|
||||
proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=exit_code)
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
return exec_calls
|
||||
|
||||
|
||||
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_daytona_env():
    """Build a bare DaytonaEnvironment for unit tests, skipping __init__."""
    import threading

    # object.__new__ avoids __init__ side effects (sandbox creation etc.).
    env = object.__new__(daytona_env.DaytonaEnvironment)
    env._sandbox = MagicMock()
    env._daytona = MagicMock()
    env._remote_home = "/root"
    env._sync_manager = None
    env._lock = threading.Lock()
    env._persistent = True
    env._task_id = "test"
    return env
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SSH bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestSSHBulkDownload:
    """Unit tests for _ssh_bulk_download."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """subprocess.run command should include tar cf - over SSH."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        joined = " ".join(mock_run.call_args[0][0])
        # The remote side must stream a tar of ~/.hermes over the SSH pipe.
        for fragment in (
            "tar cf -",
            "-C /",
            "home/testuser/.hermes",
            "ssh",
            "testuser@example.com",
        ):
            assert fragment in joined

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """subprocess.run should receive stdout=open(dest, 'wb')."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)):
            ssh_mock_env._ssh_bulk_download(dest)

        # The implementation opens dest for writing and hands the handle to
        # subprocess.run as stdout; by the time we get here the file is
        # closed, so the observable effect is that dest was created.
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """Non-zero returncode should raise RuntimeError."""
        dest = tmp_path / "backup.tar"
        failure = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")

        with patch.object(subprocess, "run", return_value=failure):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """The subprocess.run call should use a 120s timeout."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        observed = mock_run.call_args
        assert observed.kwargs.get("timeout") == 120 or observed[1].get("timeout") == 120
|
||||
|
||||
|
||||
class TestSSHCleanup:
    """Verify SSH cleanup() calls sync_back() before closing ControlMaster."""

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """cleanup() should call sync_back() before SSH control socket teardown."""
        # Stub out every network-touching step of construction.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        call_order = []

        # Sync-manager stand-in that records when sync_back is invoked.
        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)

        env = SSHEnvironment(host="h", user="u")
        # Ensure control_socket does not exist so cleanup skips the SSH exit call
        env.control_socket = Path("/nonexistent/socket")

        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch):
        """sync_back() must run before the ControlMaster exit command."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        call_order = []

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)

        env = SSHEnvironment(host="h", user="u")

        # Create a fake control socket so cleanup tries the SSH exit
        import tempfile
        with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp:
            env.control_socket = Path(tmp.name)

        def mock_run(cmd, **kwargs):
            # Recognise the ControlMaster teardown (`ssh -O exit ...`) and
            # record its position relative to sync_back.
            cmd_str = " ".join(cmd)
            if "-O" in cmd and "exit" in cmd_str:
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", side_effect=mock_run):
            env.cleanup()

        assert call_order.index("sync_back") < call_order.index("control_exit")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Modal bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestModalBulkDownload:
    """Unit tests for _modal_bulk_download."""

    def test_modal_bulk_download_command(self, tmp_path):
        """exec should be called with tar cf - -C /root/.hermes ."""
        env = _make_mock_modal_env()
        recorded = _wire_modal_download(env, tar_bytes=b"tar-content")

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert len(recorded) == 1
        shell_args = recorded[0]
        # Invocation shape: bash -c "tar cf - -C / root/.hermes"
        assert shell_args[0] == "bash"
        assert shell_args[1] == "-c"
        assert "tar cf -" in shell_args[2]
        assert "-C / root/.hermes" in shell_args[2]

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """Downloaded tar bytes should be written to the dest path."""
        env = _make_mock_modal_env()
        payload = b"some-tar-archive-bytes"
        _wire_modal_download(env, tar_bytes=payload)

        dest = tmp_path / "backup.tar"
        env._modal_bulk_download(dest)

        assert dest.exists()
        assert dest.read_bytes() == payload

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """If stdout returns str instead of bytes, it should be encoded."""
        env = _make_mock_modal_env()
        # Some Modal SDK paths yield text rather than bytes.
        _wire_modal_download(env, tar_bytes="string-tar-data")

        dest = tmp_path / "backup.tar"
        env._modal_bulk_download(dest)

        assert dest.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """Non-zero exit code should raise RuntimeError."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(tmp_path / "backup.tar")

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """run_coroutine should be called with timeout=120."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        seen_kwargs = {}
        inner_run = env._worker.run_coroutine

        def spying_run(coro, **kwargs):
            # Capture the kwargs, then delegate to the wired runner.
            seen_kwargs.update(kwargs)
            return inner_run(coro, **kwargs)

        env._worker.run_coroutine = spying_run

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert seen_kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestModalCleanup:
    """Verify Modal cleanup() calls sync_back() before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.terminate."""
        env = _make_mock_modal_env()

        call_order = []
        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr

        # Mock terminate so no real sandbox call happens.
        async def mock_terminate():
            pass

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate

        def run_and_record(coro, **kw):
            # Record the terminate step and drive the coroutine to completion.
            # Unlike the previous lambda, the event loop is closed afterwards
            # so the test does not leak loops / emit ResourceWarning.
            call_order.append("terminate")
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = run_and_record
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Daytona bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestDaytonaBulkDownload:
    """Unit tests for _daytona_bulk_download."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """exec and download_file should both be called."""
        env = _make_mock_daytona_env()
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        exec_mock = env._sandbox.process.exec
        # Two exec calls: one builds the tar, one removes it afterwards.
        assert exec_mock.call_count == 2

        tar_cmd = exec_mock.call_args_list[0][0][0]
        for fragment in ("tar cf", "/tmp/.hermes_sync.tar", ".hermes"):
            assert fragment in tar_cmd

        cleanup_cmd = exec_mock.call_args_list[1][0][0]
        assert "rm -f /tmp/.hermes_sync.tar" in cleanup_cmd

        env._sandbox.fs.download_file.assert_called_once_with(
            "/tmp/.hermes_sync.tar", str(dest)
        )

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command should use the env's _remote_home."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"

        env._daytona_bulk_download(tmp_path / "backup.tar")

        first_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "home/daytona/.hermes" in first_cmd
|
||||
|
||||
|
||||
class TestDaytonaCleanup:
    """Verify Daytona cleanup() calls sync_back() before stop."""

    def test_daytona_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.stop()."""
        env = _make_mock_daytona_env()

        events = []

        manager = MagicMock()
        manager.sync_back = lambda: events.append("sync_back")
        env._sync_manager = manager
        env._sandbox.stop = lambda: events.append("stop")

        env.cleanup()

        # Both phases must have run, and sync must have preceded teardown.
        assert {"sync_back", "stop"} <= set(events)
        assert events.index("sync_back") < events.index("stop")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestBulkDownloadWiring:
    """Verify each backend passes bulk_download_fn to FileSyncManager."""

    def test_ssh_passes_bulk_download_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        # Neutralize every network-touching step of __init__.
        for name in ("_establish_connection", "_ensure_remote_dirs", "init_session"):
            monkeypatch.setattr(ssh_env.SSHEnvironment, name, lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")

        seen = {}

        class CaptureSyncManager:
            def __init__(self, **kwargs):
                seen.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)

        SSHEnvironment(host="h", user="u")

        # Present and callable in one check: .get(...) yields None otherwise.
        assert callable(seen.get("bulk_download_fn"))

    def test_modal_passes_bulk_download_fn(self, monkeypatch):
        """ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
        seen = {}

        def capture_fsm(**kwargs):
            seen.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
            bulk_download_fn=env._modal_bulk_download,
        )

        assert callable(seen.get("bulk_download_fn"))

    def test_daytona_passes_bulk_download_fn(self, monkeypatch):
        """DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
        import threading

        seen = {}

        def capture_fsm(**kwargs):
            seen.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)

        env = object.__new__(daytona_env.DaytonaEnvironment)
        env._sandbox = MagicMock()
        env._remote_home = "/root"
        env._lock = threading.Lock()
        env._persistent = True
        env._task_id = "test"
        env._daytona = MagicMock()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = daytona_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
            upload_fn=env._daytona_upload,
            delete_fn=env._daytona_delete,
            bulk_upload_fn=env._daytona_bulk_upload,
            bulk_download_fn=env._daytona_bulk_download,
        )

        assert callable(seen.get("bulk_download_fn"))
|
||||
|
|
@ -134,6 +134,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
upload_fn=self._daytona_upload,
|
||||
delete_fn=self._daytona_delete,
|
||||
bulk_upload_fn=self._daytona_bulk_upload,
|
||||
bulk_download_fn=self._daytona_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
|
@ -166,6 +167,19 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
]
|
||||
self._sandbox.fs.upload_files(uploads)
|
||||
|
||||
def _daytona_bulk_download(self, dest: Path) -> None:
    """Download remote .hermes/ as a tar archive."""
    remote_tar = "/tmp/.hermes_sync.tar"
    # Archive relative to / so entry names mirror absolute remote paths.
    hermes_rel = f"{self._remote_home}/.hermes".lstrip("/")
    self._sandbox.process.exec(
        f"tar cf {remote_tar} -C / {shlex.quote(hermes_rel)}"
    )
    self._sandbox.fs.download_file(remote_tar, str(dest))
    # Best-effort removal of the remote temp archive.
    try:
        self._sandbox.process.exec(f"rm -f {remote_tar}")
    except Exception:
        pass
|
||||
|
||||
def _daytona_delete(self, remote_paths: list[str]) -> None:
    """Batch-delete remote files via SDK exec.

    Args:
        remote_paths: Absolute remote paths to remove.
    """
    # quoted_rm_command builds one shell-quoted `rm` covering every path,
    # so the whole batch costs a single exec round-trip.
    self._sandbox.process.exec(quoted_rm_command(remote_paths))
|
||||
|
|
@ -213,6 +227,10 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
if self._sync_manager:
|
||||
logger.info("Daytona: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
with self._lock:
|
||||
if self._sandbox is None:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -6,13 +6,25 @@ and Daytona. Docker and Singularity use bind mounts (live host FS
|
|||
view) and don't need this.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
fcntl = None # Windows — file locking skipped
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import _file_mtime_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +35,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"
|
|||
# Transport callbacks provided by each backend (SSH / Modal / Daytona).
# Plain callables rather than an ABC so backends can wire in their own
# transport without subclassing FileSyncManager.
UploadFn = Callable[[str, str], None]  # (host_path, remote_path) -> raises on failure
BulkUploadFn = Callable[[list[tuple[str, str]]], None]  # [(host_path, remote_path), ...] -> raises on failure
BulkDownloadFn = Callable[[Path], None]  # (dest_tar_path) -> writes tar archive, raises on failure
DeleteFn = Callable[[list[str]], None]  # (remote_paths) -> raises on failure
GetFilesFn = Callable[[], list[tuple[str, str]]]  # () -> [(host_path, remote_path), ...]
|
||||
|
||||
|
|
@ -71,6 +84,19 @@ def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
|||
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||
|
||||
|
||||
def _sha256_file(path: str) -> str:
|
||||
"""Return hex SHA-256 digest of a file."""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
# Retry policy for sync_back(): up to 3 attempts, sleeping between
# failures per the exponential backoff schedule below.
_SYNC_BACK_MAX_RETRIES = 3
_SYNC_BACK_BACKOFF = (2, 4, 8)  # seconds between retries
|
||||
|
||||
|
||||
class FileSyncManager:
|
||||
"""Tracks local file changes and syncs to a remote environment.
|
||||
|
||||
|
|
@ -89,12 +115,15 @@ class FileSyncManager:
|
|||
delete_fn: DeleteFn,
|
||||
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
||||
bulk_upload_fn: BulkUploadFn | None = None,
|
||||
bulk_download_fn: BulkDownloadFn | None = None,
|
||||
):
|
||||
self._get_files_fn = get_files_fn
|
||||
self._upload_fn = upload_fn
|
||||
self._bulk_upload_fn = bulk_upload_fn
|
||||
self._bulk_download_fn = bulk_download_fn
|
||||
self._delete_fn = delete_fn
|
||||
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
||||
self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest
|
||||
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
||||
self._sync_interval = sync_interval
|
||||
|
||||
|
|
@ -136,6 +165,7 @@ class FileSyncManager:
|
|||
|
||||
# Snapshot for rollback (only when there's work to do)
|
||||
prev_files = dict(self._synced_files)
|
||||
prev_hashes = dict(self._pushed_hashes)
|
||||
|
||||
if to_upload:
|
||||
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
||||
|
|
@ -156,13 +186,187 @@ class FileSyncManager:
|
|||
logger.debug("file_sync: deleted %s", to_delete)
|
||||
|
||||
# --- Commit (all succeeded) ---
|
||||
for host_path, remote_path in to_upload:
|
||||
self._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||
|
||||
for p in to_delete:
|
||||
new_files.pop(p, None)
|
||||
self._pushed_hashes.pop(p, None)
|
||||
|
||||
self._synced_files = new_files
|
||||
self._last_sync_time = time.monotonic()
|
||||
|
||||
except Exception as exc:
|
||||
self._synced_files = prev_files
|
||||
self._pushed_hashes = prev_hashes
|
||||
self._last_sync_time = time.monotonic()
|
||||
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync-back: pull remote changes to host on teardown
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def sync_back(self, hermes_home: Path | None = None) -> None:
    """Pull remote changes back to the host filesystem.

    Downloads the remote ``.hermes/`` directory as a tar archive,
    unpacks it, and applies only files whose SHA-256 differs from the
    content originally pushed. A no-op when no bulk-download transport
    was configured.

    Protected against SIGINT (the signal is deferred until the sync
    finishes) and serialized across concurrent gateway sandboxes via a
    file lock under the hermes home.

    Args:
        hermes_home: Override for the hermes home directory; defaults
            to ``get_hermes_home()``.
    """
    if self._bulk_download_fn is None:
        return

    home = hermes_home or get_hermes_home()
    lock_path = home / ".sync.lock"
    lock_path.parent.mkdir(parents=True, exist_ok=True)

    failure: Exception | None = None
    for attempt in range(1, _SYNC_BACK_MAX_RETRIES + 1):
        try:
            self._sync_back_once(lock_path)
            return
        except Exception as exc:
            failure = exc
            if attempt == _SYNC_BACK_MAX_RETRIES:
                break
            delay = _SYNC_BACK_BACKOFF[attempt - 1]
            logger.warning(
                "sync_back: attempt %d failed (%s), retrying in %ds",
                attempt, exc, delay,
            )
            time.sleep(delay)

    # Teardown path: never raise, just report the best-effort failure.
    logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, failure)
|
||||
|
||||
def _sync_back_once(self, lock_path: Path) -> None:
    """Single sync-back attempt with SIGINT protection and file lock.

    Installs a temporary SIGINT handler that records (rather than
    delivers) Ctrl-C while files are being written, restores the prior
    handler, and only then re-raises any deferred signal — so a
    mid-sync interrupt can never leave a partially written host file.

    Args:
        lock_path: Lock file used to serialize concurrent sync-backs.
    """
    # signal.signal() only works from the main thread. In gateway
    # contexts cleanup() may run from a worker thread — skip SIGINT
    # deferral there rather than crashing.
    on_main_thread = threading.current_thread() is threading.main_thread()

    deferred_sigint: list[object] = []
    original_handler = None
    # Track installation explicitly: getsignal() may legitimately return
    # None (handler installed from C code), and the previous
    # `original_handler is not None` guard would then skip restoration,
    # leaving our temporary handler active after teardown.
    handler_installed = False
    if on_main_thread:
        original_handler = signal.getsignal(signal.SIGINT)

        def _defer_sigint(signum, frame):
            deferred_sigint.append((signum, frame))
            logger.debug("sync_back: SIGINT deferred until sync completes")

        signal.signal(signal.SIGINT, _defer_sigint)
        handler_installed = True

    try:
        self._sync_back_locked(lock_path)
    finally:
        if handler_installed:
            # Fall back to SIG_DFL when the prior handler is unknowable
            # (None = installed outside Python).
            restore = original_handler if original_handler is not None else signal.SIG_DFL
            signal.signal(signal.SIGINT, restore)
        if deferred_sigint:
            # Re-deliver the Ctrl-C now that host state is consistent.
            os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
def _sync_back_locked(self, lock_path: Path) -> None:
    """Run the sync-back under an exclusive file lock.

    Serializes concurrent gateway sandboxes writing into the same
    hermes home. On Windows (no ``fcntl``) the lock is skipped and the
    sync runs unserialized.

    Args:
        lock_path: Lock file; created/truncated on each attempt.
    """
    if fcntl is None:
        # Windows: no flock — run without serialization
        self._sync_back_impl()
        return
    # `with` guarantees the descriptor is closed even if flock() or the
    # sync raises. The previous manual close() leaked the fd when the
    # LOCK_UN call in `finally` itself raised. Closing the fd also
    # releases the flock, so the explicit unlock is belt-and-braces.
    with open(lock_path, "w") as lock_fd:
        fcntl.flock(lock_fd, fcntl.LOCK_EX)
        try:
            self._sync_back_impl()
        finally:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
|
||||
def _sync_back_impl(self) -> None:
    """Download, diff, and apply remote changes to host.

    Flow: bulk-download the remote ``.hermes/`` tree as one tar
    archive, unpack it into a temp staging dir, then for each staged
    file compare its SHA-256 against the hash recorded at push time and
    copy only changed (or brand-new) files back to the host.

    Raises:
        RuntimeError: if called without a bulk_download_fn configured
            (callers should have checked via sync_back()).
    """
    if self._bulk_download_fn is None:
        raise RuntimeError("_sync_back_impl called without bulk_download_fn")

    # Cache file mapping once to avoid O(n*m) from repeated iteration
    try:
        file_mapping = list(self._get_files_fn())
    except Exception:
        # Best-effort: with no mapping, known-path resolution below just
        # finds nothing and files without a host mapping are skipped.
        file_mapping = []

    with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
        self._bulk_download_fn(Path(tf.name))

        with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
            # filter="data" rejects absolute paths / traversal members.
            # NOTE(review): the filter= kwarg needs Python >= 3.12 (or a
            # 3.10/3.11 security backport) — confirm the project floor.
            with tarfile.open(tf.name) as tar:
                tar.extractall(staging, filter="data")

            applied = 0
            for dirpath, _dirnames, filenames in os.walk(staging):
                for fname in filenames:
                    staged_file = os.path.join(dirpath, fname)
                    rel = os.path.relpath(staged_file, staging)
                    # Backends tar with `-C /`, so re-rooting the relative
                    # entry at "/" reconstructs the absolute remote path
                    # used as the _pushed_hashes key.
                    remote_path = "/" + rel

                    pushed_hash = self._pushed_hashes.get(remote_path)

                    # Skip hashing for files unchanged from push
                    if pushed_hash is not None:
                        remote_hash = _sha256_file(staged_file)
                        if remote_hash == pushed_hash:
                            continue
                    else:
                        remote_hash = None  # new remote file

                    # Resolve host path from cached mapping
                    host_path = self._resolve_host_path(remote_path, file_mapping)
                    if host_path is None:
                        # New remote file: fall back to prefix inference.
                        host_path = self._infer_host_path(remote_path, file_mapping)
                    if host_path is None:
                        logger.debug(
                            "sync_back: skipping %s (no host mapping)",
                            remote_path,
                        )
                        continue

                    # Conflict detection: host changed since push AND the
                    # remote copy differs. We still apply the remote copy
                    # (last-write-wins) but surface a warning.
                    if os.path.exists(host_path) and pushed_hash is not None:
                        host_hash = _sha256_file(host_path)
                        if host_hash != pushed_hash:
                            logger.warning(
                                "sync_back: conflict on %s — host modified "
                                "since push, remote also changed. Applying "
                                "remote version (last-write-wins).",
                                remote_path,
                            )

                    # copy2 preserves mtime/mode from the staged copy.
                    os.makedirs(os.path.dirname(host_path), exist_ok=True)
                    shutil.copy2(staged_file, host_path)
                    applied += 1

            if applied:
                logger.info("sync_back: applied %d changed file(s)", applied)
            else:
                logger.debug("sync_back: no remote changes detected")
|
||||
|
||||
def _resolve_host_path(self, remote_path: str,
|
||||
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||
"""Find the host path for a known remote path from the file mapping."""
|
||||
mapping = file_mapping if file_mapping is not None else []
|
||||
for host, remote in mapping:
|
||||
if remote == remote_path:
|
||||
return host
|
||||
return None
|
||||
|
||||
def _infer_host_path(self, remote_path: str,
|
||||
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||
"""Infer a host path for a new remote file by matching path prefixes.
|
||||
|
||||
Uses the existing file mapping to find a remote->host directory
|
||||
pair, then applies the same prefix substitution to the new file.
|
||||
For example, if the mapping has ``/root/.hermes/skills/a.md`` →
|
||||
``~/.hermes/skills/a.md``, a new remote file at
|
||||
``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
|
||||
"""
|
||||
mapping = file_mapping if file_mapping is not None else []
|
||||
for host, remote in mapping:
|
||||
remote_dir = str(Path(remote).parent)
|
||||
if remote_path.startswith(remote_dir + "/"):
|
||||
host_dir = str(Path(host).parent)
|
||||
suffix = remote_path[len(remote_dir):]
|
||||
return host_dir + suffix
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
upload_fn=self._modal_upload,
|
||||
delete_fn=self._modal_delete,
|
||||
bulk_upload_fn=self._modal_bulk_upload,
|
||||
bulk_download_fn=self._modal_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
|
@ -347,6 +348,27 @@ class ModalEnvironment(BaseEnvironment):
|
|||
|
||||
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||
|
||||
def _modal_bulk_download(self, dest: Path) -> None:
    """Download remote .hermes/ as a tar archive.

    Modal sandboxes always run as root, so /root/.hermes is hardcoded
    (consistent with iter_sync_files call on line 269).

    The archive is base64-encoded on the remote side because Modal's
    ``proc.stdout`` is a text stream: raw tar bytes would be decoded
    as UTF-8 in transit and ``str.encode()`` cannot round-trip
    arbitrary binary content, silently corrupting the archive. Base64
    keeps stdout pure ASCII so the transfer is lossless.

    Args:
        dest: Host path where the tar archive is written.

    Raises:
        RuntimeError: if the remote tar/base64 pipeline exits non-zero.
    """
    import base64

    async def _download():
        # pipefail so a tar failure isn't masked by base64's exit code.
        proc = await self._sandbox.exec.aio(
            "bash", "-c", "set -o pipefail; tar cf - -C / root/.hermes | base64"
        )
        data = await proc.stdout.read.aio()
        exit_code = await proc.wait.aio()
        if exit_code != 0:
            raise RuntimeError(f"Modal bulk download failed (exit {exit_code})")
        return data

    encoded = self._worker.run_coroutine(_download(), timeout=120)
    if isinstance(encoded, bytes):
        encoded = encoded.decode("ascii")
    # b64decode (validate=False, default) ignores the newlines that
    # base64(1) inserts every 76 characters.
    dest.write_bytes(base64.b64decode(encoded))
|
||||
|
||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files via exec."""
|
||||
rm_cmd = quoted_rm_command(remote_paths)
|
||||
|
|
@ -404,6 +426,10 @@ class ModalEnvironment(BaseEnvironment):
|
|||
if self._sandbox is None:
|
||||
return
|
||||
|
||||
if self._sync_manager:
|
||||
logger.info("Modal: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self._persistent:
|
||||
try:
|
||||
async def _snapshot():
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||
upload_fn=self._scp_upload,
|
||||
delete_fn=self._ssh_delete,
|
||||
bulk_upload_fn=self._ssh_bulk_upload,
|
||||
bulk_download_fn=self._ssh_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
|
||||
|
|
@ -216,6 +217,18 @@ class SSHEnvironment(BaseEnvironment):
|
|||
|
||||
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||
|
||||
def _ssh_bulk_download(self, dest: Path) -> None:
    """Download remote .hermes/ as a tar archive."""
    # Tar from / with the full path so archive entries preserve absolute
    # paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys.
    hermes_rel = f"{self._remote_home}/.hermes".lstrip("/")
    cmd = self._build_ssh_command()
    cmd.append(f"tar cf - -C / {shlex.quote(hermes_rel)}")
    with open(dest, "wb") as out:
        proc = subprocess.run(cmd, stdout=out, stderr=subprocess.PIPE, timeout=120)
    if proc.returncode != 0:
        detail = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"SSH bulk download failed: {detail}")
|
||||
|
||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files in one SSH call."""
|
||||
cmd = self._build_ssh_command()
|
||||
|
|
@ -245,6 +258,10 @@ class SSHEnvironment(BaseEnvironment):
|
|||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
if self._sync_manager:
|
||||
logger.info("SSH: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue