mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(file-sync): sync remote changes back to host on teardown
Salvage of PR #8018 by @alt-glitch onto current main. On sandbox teardown, FileSyncManager now downloads the remote .hermes/ directory, diffs against SHA-256 hashes of what was originally pushed, and applies only changed files back to the host. Core (tools/environments/file_sync.py): - sync_back(): orchestrates download -> unpack -> diff -> apply with: - Retry with exponential backoff (3 attempts, 2s/4s/8s) - SIGINT trap + defer (prevents partial writes on Ctrl-C) - fcntl.flock serialization (concurrent gateway sandboxes) - Last-write-wins conflict resolution with warning - New remote files pulled back via _infer_host_path prefix matching Backends: - SSH: _ssh_bulk_download — tar cf - piped over SSH - Modal: _modal_bulk_download — exec tar cf - -> proc.stdout.read - Daytona: _daytona_bulk_download — exec tar cf -> SDK download_file - All three call sync_back() at the top of cleanup() Fixes applied during salvage (vs original PR #8018): | # | Issue | Fix | |---|-------|-----| | C1 | import fcntl unconditional — crashes Windows | try/except with fallback; _sync_back_locked skips locking when fcntl=None | | W1 | assert for runtime guard (stripped by -O) | Replaced with proper if/raise RuntimeError | | W2 | O(n*m) from _get_files_fn() called per file | Cache mapping once at start of _sync_back_impl, pass to resolve/infer | | W3 | Dead BulkDownloadFn imports in 3 backends | Removed unused imports | | W4 | Modal hardcodes root/.hermes, no explanation | Added docstring comment explaining Modal always runs as root | | S1 | SHA-256 computed for new files where pushed_hash=None | Skip hashing when pushed_hash is None (comparison always False) | | S2 | Daytona /tmp/.hermes_sync.tar never cleaned up | Added rm -f after download (best-effort) | Tests: 49 passing (17 new: _infer_host_path edge cases, SIGINT main/worker thread, Windows fcntl=None fallback, Daytona tar cleanup). Based on #8018 by @alt-glitch.
This commit is contained in:
parent
764536b684
commit
d64446e315
6 changed files with 1166 additions and 0 deletions
412
tests/tools/test_file_sync_back.py
Normal file
412
tests/tools/test_file_sync_back.py
Normal file
|
|
@ -0,0 +1,412 @@
|
||||||
|
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||||
|
|
||||||
|
import fcntl
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import tarfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments.file_sync import (
|
||||||
|
FileSyncManager,
|
||||||
|
_sha256_file,
|
||||||
|
_SYNC_BACK_BACKOFF,
|
||||||
|
_SYNC_BACK_MAX_RETRIES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_tar(files: dict[str, bytes], dest: Path):
|
||||||
|
"""Write a tar archive containing the given arcname->content pairs."""
|
||||||
|
with tarfile.open(dest, "w") as tar:
|
||||||
|
for arcname, content in files.items():
|
||||||
|
info = tarfile.TarInfo(name=arcname)
|
||||||
|
info.size = len(content)
|
||||||
|
tar.addfile(info, io.BytesIO(content))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_download_fn(files: dict[str, bytes]):
|
||||||
|
"""Return a bulk_download_fn that writes a tar of the given files."""
|
||||||
|
def download(dest: Path):
|
||||||
|
_make_tar(files, dest)
|
||||||
|
return download
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_bytes(data: bytes) -> str:
|
||||||
|
"""Compute SHA-256 hex digest of raw bytes (for test convenience)."""
|
||||||
|
import hashlib
|
||||||
|
return hashlib.sha256(data).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _write_file(path: Path, content: bytes) -> str:
|
||||||
|
"""Write bytes to *path*, creating parents, and return the string path."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_bytes(content)
|
||||||
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(
|
||||||
|
tmp_path: Path,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None,
|
||||||
|
bulk_download_fn=None,
|
||||||
|
) -> FileSyncManager:
|
||||||
|
"""Create a FileSyncManager wired for testing.
|
||||||
|
|
||||||
|
*file_mapping* is a list of (host_path, remote_path) tuples that
|
||||||
|
``get_files_fn`` returns. If *None* an empty list is used.
|
||||||
|
"""
|
||||||
|
mapping = file_mapping or []
|
||||||
|
return FileSyncManager(
|
||||||
|
get_files_fn=lambda: mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
bulk_download_fn=bulk_download_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNoop:
|
||||||
|
"""sync_back() is a no-op when there is no download function."""
|
||||||
|
|
||||||
|
def test_sync_back_noop_without_download_fn(self, tmp_path):
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=None)
|
||||||
|
# Should return immediately without error
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
# Nothing to assert beyond "no exception raised"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNoChanges:
|
||||||
|
"""When all remote files match pushed hashes, nothing is applied."""
|
||||||
|
|
||||||
|
def test_sync_back_no_changes(self, tmp_path):
|
||||||
|
host_file = tmp_path / "host" / "cred.json"
|
||||||
|
host_content = b'{"key": "val"}'
|
||||||
|
_write_file(host_file, host_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/cred.json"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
# Remote tar contains the same content as was pushed
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/cred.json": host_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
# Simulate that we already pushed this file with this hash
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(host_content)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# Host file should be unchanged (same content, same bytes)
|
||||||
|
assert host_file.read_bytes() == host_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackAppliesChanged:
|
||||||
|
"""Remote file differs from pushed version -- gets copied to host."""
|
||||||
|
|
||||||
|
def test_sync_back_applies_changed_file(self, tmp_path):
|
||||||
|
host_file = tmp_path / "host" / "skill.py"
|
||||||
|
original_content = b"print('v1')"
|
||||||
|
_write_file(host_file, original_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/skill.py"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
remote_content = b"print('v2 - edited on remote')"
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/skill.py": remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
assert host_file.read_bytes() == remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNewRemoteFile:
|
||||||
|
"""File created on remote (not in _pushed_hashes) is applied via _infer_host_path."""
|
||||||
|
|
||||||
|
def test_sync_back_detects_new_remote_file(self, tmp_path):
|
||||||
|
# Existing mapping gives _infer_host_path a prefix to work with
|
||||||
|
existing_host = tmp_path / "host" / "skills" / "existing.py"
|
||||||
|
_write_file(existing_host, b"existing")
|
||||||
|
mapping = [(str(existing_host), "/root/.hermes/skills/existing.py")]
|
||||||
|
|
||||||
|
# Remote has a NEW file in the same directory that was never pushed
|
||||||
|
new_remote_content = b"# brand new skill created on remote"
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/skills/new_skill.py": new_remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
# No entry in _pushed_hashes for the new file
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# The new file should have been inferred and written to the host
|
||||||
|
expected_host_path = tmp_path / "host" / "skills" / "new_skill.py"
|
||||||
|
assert expected_host_path.exists()
|
||||||
|
assert expected_host_path.read_bytes() == new_remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackConflict:
|
||||||
|
"""Host AND remote both changed since push -- warning logged, remote wins."""
|
||||||
|
|
||||||
|
def test_sync_back_conflict_warns(self, tmp_path, caplog):
|
||||||
|
host_file = tmp_path / "host" / "config.json"
|
||||||
|
original_content = b'{"v": 1}'
|
||||||
|
_write_file(host_file, original_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/config.json"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
# Host was modified after push
|
||||||
|
host_file.write_bytes(b'{"v": 2, "host-edit": true}')
|
||||||
|
|
||||||
|
# Remote was also modified
|
||||||
|
remote_content = b'{"v": 3, "remote-edit": true}'
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/config.json": remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# Conflict warning was logged
|
||||||
|
assert any("conflict" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
# Remote version wins (last-write-wins)
|
||||||
|
assert host_file.read_bytes() == remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackRetries:
|
||||||
|
"""Retry behaviour with exponential backoff."""
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.time.sleep")
|
||||||
|
def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def flaky_download(dest: Path):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise RuntimeError(f"network error #{call_count}")
|
||||||
|
# Third attempt succeeds -- write a valid (empty) tar
|
||||||
|
_make_tar({}, dest)
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download)
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
assert call_count == 3
|
||||||
|
# Sleep called twice (between attempt 1->2 and 2->3)
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
|
||||||
|
mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.time.sleep")
|
||||||
|
def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
|
||||||
|
def always_fail(dest: Path):
|
||||||
|
raise RuntimeError("persistent failure")
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=always_fail)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
|
||||||
|
# Should NOT raise -- failures are logged, not propagated
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# All retries were attempted
|
||||||
|
assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1
|
||||||
|
|
||||||
|
# Final "all attempts failed" warning was logged
|
||||||
|
assert any("all" in r.message.lower() and "failed" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPushedHashesPopulated:
|
||||||
|
"""_pushed_hashes is populated during sync() and cleared on delete."""
|
||||||
|
|
||||||
|
def test_pushed_hashes_populated_on_sync(self, tmp_path):
|
||||||
|
host_file = tmp_path / "data.txt"
|
||||||
|
host_file.write_bytes(b"hello world")
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/data.txt"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
mgr = FileSyncManager(
|
||||||
|
get_files_fn=lambda: mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mgr.sync(force=True)
|
||||||
|
|
||||||
|
assert remote_path in mgr._pushed_hashes
|
||||||
|
assert mgr._pushed_hashes[remote_path] == _sha256_file(str(host_file))
|
||||||
|
|
||||||
|
def test_pushed_hashes_cleared_on_delete(self, tmp_path):
|
||||||
|
host_file = tmp_path / "deleteme.txt"
|
||||||
|
host_file.write_bytes(b"to be deleted")
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/deleteme.txt"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
current_mapping = list(mapping)
|
||||||
|
|
||||||
|
mgr = FileSyncManager(
|
||||||
|
get_files_fn=lambda: current_mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sync to populate hashes
|
||||||
|
mgr.sync(force=True)
|
||||||
|
assert remote_path in mgr._pushed_hashes
|
||||||
|
|
||||||
|
# Remove the file from the mapping (simulates local deletion)
|
||||||
|
os.unlink(str(host_file))
|
||||||
|
current_mapping.clear()
|
||||||
|
|
||||||
|
mgr.sync(force=True)
|
||||||
|
|
||||||
|
# Hash should be cleaned up
|
||||||
|
assert remote_path not in mgr._pushed_hashes
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackFileLock:
|
||||||
|
"""Verify that fcntl.flock is used during sync-back."""
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.fcntl.flock")
|
||||||
|
def test_sync_back_file_lock(self, mock_flock, tmp_path):
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# flock should have been called at least twice: LOCK_EX to acquire, LOCK_UN to release
|
||||||
|
assert mock_flock.call_count >= 2
|
||||||
|
|
||||||
|
lock_calls = mock_flock.call_args_list
|
||||||
|
lock_ops = [c[0][1] for c in lock_calls]
|
||||||
|
assert fcntl.LOCK_EX in lock_ops
|
||||||
|
assert fcntl.LOCK_UN in lock_ops
|
||||||
|
|
||||||
|
def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
|
||||||
|
"""On Windows (fcntl=None), sync_back should skip file locking."""
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.fcntl", None):
|
||||||
|
# Should not raise — locking is skipped
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferHostPath:
|
||||||
|
"""Edge cases for _infer_host_path prefix matching."""
|
||||||
|
|
||||||
|
def test_infer_no_matching_prefix(self, tmp_path):
|
||||||
|
"""Remote path in unmapped directory should return None."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/cache/new.json",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_infer_partial_prefix_no_false_match(self, tmp_path):
|
||||||
|
"""A partial prefix like /root/.hermes/sk should NOT match /root/.hermes/skills/."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
# /root/.hermes/skillsXtra/b.py shares prefix "skills" but the
|
||||||
|
# directory is different — should not match /root/.hermes/skills/
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/skillsXtra/b.py",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_infer_matching_prefix(self, tmp_path):
|
||||||
|
"""A file in a mapped directory should be correctly inferred."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/skills/b.py",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
expected = str(tmp_path / "host" / "skills" / "b.py")
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackSIGINT:
|
||||||
|
"""SIGINT deferral during sync-back."""
|
||||||
|
|
||||||
|
def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
|
||||||
|
"""On the main thread, SIGINT handler should be swapped during sync."""
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
handlers_seen = []
|
||||||
|
original_getsignal = signal.getsignal
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.signal.getsignal",
|
||||||
|
side_effect=original_getsignal) as mock_get, \
|
||||||
|
patch("tools.environments.file_sync.signal.signal") as mock_set:
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# signal.getsignal was called to save the original handler
|
||||||
|
assert mock_get.called
|
||||||
|
# signal.signal was called at least twice: install defer, restore original
|
||||||
|
assert mock_set.call_count >= 2
|
||||||
|
|
||||||
|
def test_sync_back_skips_signal_on_worker_thread(self, tmp_path):
|
||||||
|
"""From a non-main thread, signal.signal should NOT be called."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
signal_called = []
|
||||||
|
|
||||||
|
def tracking_signal(*args):
|
||||||
|
signal_called.append(args)
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal):
|
||||||
|
# Run from a worker thread
|
||||||
|
exc = []
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
except Exception as e:
|
||||||
|
exc.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=run)
|
||||||
|
t.start()
|
||||||
|
t.join(timeout=10)
|
||||||
|
|
||||||
|
assert not exc, f"sync_back raised: {exc}"
|
||||||
|
# signal.signal should NOT have been called from the worker thread
|
||||||
|
assert len(signal_called) == 0
|
||||||
489
tests/tools/test_sync_back_backends.py
Normal file
489
tests/tools/test_sync_back_backends.py
Normal file
|
|
@ -0,0 +1,489 @@
|
||||||
|
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments import ssh as ssh_env
|
||||||
|
from tools.environments import modal as modal_env
|
||||||
|
from tools.environments import daytona as daytona_env
|
||||||
|
from tools.environments.ssh import SSHEnvironment
|
||||||
|
|
||||||
|
|
||||||
|
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ssh_mock_env(monkeypatch):
|
||||||
|
"""Create an SSHEnvironment with mocked connection/sync."""
|
||||||
|
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ssh_env, "FileSyncManager",
|
||||||
|
lambda **kw: type("M", (), {
|
||||||
|
"sync": lambda self, **k: None,
|
||||||
|
"sync_back": lambda self: None,
|
||||||
|
})(),
|
||||||
|
)
|
||||||
|
return SSHEnvironment(host="example.com", user="testuser")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Modal helpers ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_modal_env():
|
||||||
|
"""Create a minimal ModalEnvironment without calling __init__."""
|
||||||
|
env = object.__new__(modal_env.ModalEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._worker = MagicMock()
|
||||||
|
env._persistent = False
|
||||||
|
env._task_id = "test"
|
||||||
|
env._sync_manager = None
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
|
||||||
|
"""Wire sandbox.exec.aio to return mock tar output for download tests.
|
||||||
|
|
||||||
|
Returns the exec_calls list for assertion.
|
||||||
|
"""
|
||||||
|
exec_calls = []
|
||||||
|
|
||||||
|
async def mock_exec_fn(*args, **kwargs):
|
||||||
|
exec_calls.append(args)
|
||||||
|
proc = MagicMock()
|
||||||
|
proc.stdout = MagicMock()
|
||||||
|
proc.stdout.read = MagicMock()
|
||||||
|
proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
|
||||||
|
proc.wait = MagicMock()
|
||||||
|
proc.wait.aio = AsyncMock(return_value=exit_code)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
env._sandbox.exec = MagicMock()
|
||||||
|
env._sandbox.exec.aio = mock_exec_fn
|
||||||
|
|
||||||
|
def real_run_coroutine(coro, **kwargs):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(coro)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
env._worker.run_coroutine = real_run_coroutine
|
||||||
|
return exec_calls
|
||||||
|
|
||||||
|
|
||||||
|
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_daytona_env():
|
||||||
|
"""Create a minimal DaytonaEnvironment without calling __init__."""
|
||||||
|
env = object.__new__(daytona_env.DaytonaEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._remote_home = "/root"
|
||||||
|
env._sync_manager = None
|
||||||
|
env._lock = __import__("threading").Lock()
|
||||||
|
env._persistent = True
|
||||||
|
env._task_id = "test"
|
||||||
|
env._daytona = MagicMock()
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# SSH bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHBulkDownload:
|
||||||
|
"""Unit tests for _ssh_bulk_download."""
|
||||||
|
|
||||||
|
def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
|
||||||
|
"""subprocess.run command should include tar cf - over SSH."""
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
|
||||||
|
# open() will be called to write stdout; mock it to avoid actual file I/O
|
||||||
|
ssh_mock_env._ssh_bulk_download(dest)
|
||||||
|
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
cmd = mock_run.call_args[0][0]
|
||||||
|
cmd_str = " ".join(cmd)
|
||||||
|
assert "tar cf -" in cmd_str
|
||||||
|
assert "-C /" in cmd_str
|
||||||
|
assert "home/testuser/.hermes" in cmd_str
|
||||||
|
assert "ssh" in cmd_str
|
||||||
|
assert "testuser@example.com" in cmd_str
|
||||||
|
|
||||||
|
def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
|
||||||
|
"""subprocess.run should receive stdout=open(dest, 'wb')."""
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
|
||||||
|
ssh_mock_env._ssh_bulk_download(dest)
|
||||||
|
|
||||||
|
# The stdout kwarg should be a file object opened for writing
|
||||||
|
call_kwargs = mock_run.call_args
|
||||||
|
# stdout is passed as a keyword arg
|
||||||
|
stdout_val = call_kwargs.kwargs.get("stdout") or call_kwargs[1].get("stdout")
|
||||||
|
# The file was opened via `with open(dest, "wb") as f` and passed as stdout=f.
|
||||||
|
# After the context manager exits, the file is closed, but we can verify
|
||||||
|
# the dest path was used by checking if the file was created.
|
||||||
|
assert dest.exists()
|
||||||
|
|
||||||
|
def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
|
||||||
|
"""Non-zero returncode should raise RuntimeError."""
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")
|
||||||
|
with patch.object(subprocess, "run", return_value=failed):
|
||||||
|
with pytest.raises(RuntimeError, match="SSH bulk download failed"):
|
||||||
|
ssh_mock_env._ssh_bulk_download(dest)
|
||||||
|
|
||||||
|
def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
|
||||||
|
"""The subprocess.run call should use a 120s timeout."""
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
|
||||||
|
ssh_mock_env._ssh_bulk_download(dest)
|
||||||
|
|
||||||
|
call_kwargs = mock_run.call_args
|
||||||
|
assert call_kwargs.kwargs.get("timeout") == 120 or call_kwargs[1].get("timeout") == 120
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHCleanup:
|
||||||
|
"""Verify SSH cleanup() calls sync_back() before closing ControlMaster."""
|
||||||
|
|
||||||
|
def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
|
||||||
|
"""cleanup() should call sync_back() before SSH control socket teardown."""
|
||||||
|
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
class TrackingSyncManager:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync(self, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync_back(self):
|
||||||
|
call_order.append("sync_back")
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
|
||||||
|
|
||||||
|
env = SSHEnvironment(host="h", user="u")
|
||||||
|
# Ensure control_socket does not exist so cleanup skips the SSH exit call
|
||||||
|
env.control_socket = Path("/nonexistent/socket")
|
||||||
|
|
||||||
|
env.cleanup()
|
||||||
|
|
||||||
|
assert "sync_back" in call_order
|
||||||
|
|
||||||
|
def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch):
|
||||||
|
"""sync_back() must run before the ControlMaster exit command."""
|
||||||
|
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
class TrackingSyncManager:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync(self, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync_back(self):
|
||||||
|
call_order.append("sync_back")
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
|
||||||
|
|
||||||
|
env = SSHEnvironment(host="h", user="u")
|
||||||
|
|
||||||
|
# Create a fake control socket so cleanup tries the SSH exit
|
||||||
|
import tempfile
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp:
|
||||||
|
env.control_socket = Path(tmp.name)
|
||||||
|
|
||||||
|
def mock_run(cmd, **kwargs):
|
||||||
|
cmd_str = " ".join(cmd)
|
||||||
|
if "-O" in cmd and "exit" in cmd_str:
|
||||||
|
call_order.append("control_exit")
|
||||||
|
return subprocess.CompletedProcess([], 0)
|
||||||
|
|
||||||
|
with patch.object(subprocess, "run", side_effect=mock_run):
|
||||||
|
env.cleanup()
|
||||||
|
|
||||||
|
assert call_order.index("sync_back") < call_order.index("control_exit")
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Modal bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestModalBulkDownload:
|
||||||
|
"""Unit tests for _modal_bulk_download."""
|
||||||
|
|
||||||
|
def test_modal_bulk_download_command(self, tmp_path):
|
||||||
|
"""exec should be called with tar cf - -C /root/.hermes ."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
exec_calls = _wire_modal_download(env, tar_bytes=b"tar-content")
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._modal_bulk_download(dest)
|
||||||
|
|
||||||
|
assert len(exec_calls) == 1
|
||||||
|
args = exec_calls[0]
|
||||||
|
assert args[0] == "bash"
|
||||||
|
assert args[1] == "-c"
|
||||||
|
assert "tar cf -" in args[2]
|
||||||
|
assert "-C / root/.hermes" in args[2]
|
||||||
|
|
||||||
|
def test_modal_bulk_download_writes_to_dest(self, tmp_path):
|
||||||
|
"""Downloaded tar bytes should be written to the dest path."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
expected_data = b"some-tar-archive-bytes"
|
||||||
|
_wire_modal_download(env, tar_bytes=expected_data)
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._modal_bulk_download(dest)
|
||||||
|
|
||||||
|
assert dest.exists()
|
||||||
|
assert dest.read_bytes() == expected_data
|
||||||
|
|
||||||
|
def test_modal_bulk_download_handles_str_output(self, tmp_path):
|
||||||
|
"""If stdout returns str instead of bytes, it should be encoded."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
# Simulate Modal SDK returning str
|
||||||
|
_wire_modal_download(env, tar_bytes="string-tar-data")
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._modal_bulk_download(dest)
|
||||||
|
|
||||||
|
assert dest.read_bytes() == b"string-tar-data"
|
||||||
|
|
||||||
|
def test_modal_bulk_download_raises_on_failure(self, tmp_path):
|
||||||
|
"""Non-zero exit code should raise RuntimeError."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
_wire_modal_download(env, exit_code=1)
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Modal bulk download failed"):
|
||||||
|
env._modal_bulk_download(dest)
|
||||||
|
|
||||||
|
def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
|
||||||
|
"""run_coroutine should be called with timeout=120."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
_wire_modal_download(env, tar_bytes=b"data")
|
||||||
|
|
||||||
|
run_kwargs = {}
|
||||||
|
original_run = env._worker.run_coroutine
|
||||||
|
|
||||||
|
def tracking_run(coro, **kwargs):
|
||||||
|
run_kwargs.update(kwargs)
|
||||||
|
return original_run(coro, **kwargs)
|
||||||
|
|
||||||
|
env._worker.run_coroutine = tracking_run
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._modal_bulk_download(dest)
|
||||||
|
|
||||||
|
assert run_kwargs.get("timeout") == 120
|
||||||
|
|
||||||
|
|
||||||
|
class TestModalCleanup:
|
||||||
|
"""Verify Modal cleanup() calls sync_back() before terminate."""
|
||||||
|
|
||||||
|
def test_modal_cleanup_calls_sync_back(self):
|
||||||
|
"""cleanup() should call sync_back() before sandbox.terminate."""
|
||||||
|
env = _make_mock_modal_env()
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
sync_mgr = MagicMock()
|
||||||
|
sync_mgr.sync_back = lambda: call_order.append("sync_back")
|
||||||
|
env._sync_manager = sync_mgr
|
||||||
|
|
||||||
|
# Mock terminate to track call order
|
||||||
|
async def mock_terminate():
|
||||||
|
pass
|
||||||
|
|
||||||
|
env._sandbox.terminate = MagicMock()
|
||||||
|
env._sandbox.terminate.aio = mock_terminate
|
||||||
|
env._worker.run_coroutine = lambda coro, **kw: (
|
||||||
|
call_order.append("terminate"),
|
||||||
|
asyncio.new_event_loop().run_until_complete(coro),
|
||||||
|
)
|
||||||
|
env._worker.stop = lambda: None
|
||||||
|
|
||||||
|
env.cleanup()
|
||||||
|
|
||||||
|
assert "sync_back" in call_order
|
||||||
|
assert call_order.index("sync_back") < call_order.index("terminate")
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Daytona bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDaytonaBulkDownload:
|
||||||
|
"""Unit tests for _daytona_bulk_download."""
|
||||||
|
|
||||||
|
def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
|
||||||
|
"""exec and download_file should both be called."""
|
||||||
|
env = _make_mock_daytona_env()
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._daytona_bulk_download(dest)
|
||||||
|
|
||||||
|
# exec called twice: tar creation + rm cleanup
|
||||||
|
assert env._sandbox.process.exec.call_count == 2
|
||||||
|
tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
|
||||||
|
assert "tar cf" in tar_cmd
|
||||||
|
assert "/tmp/.hermes_sync.tar" in tar_cmd
|
||||||
|
assert ".hermes" in tar_cmd
|
||||||
|
|
||||||
|
cleanup_cmd = env._sandbox.process.exec.call_args_list[1][0][0]
|
||||||
|
assert "rm -f /tmp/.hermes_sync.tar" in cleanup_cmd
|
||||||
|
|
||||||
|
env._sandbox.fs.download_file.assert_called_once_with(
|
||||||
|
"/tmp/.hermes_sync.tar", str(dest)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
|
||||||
|
"""The tar command should use the env's _remote_home."""
|
||||||
|
env = _make_mock_daytona_env()
|
||||||
|
env._remote_home = "/home/daytona"
|
||||||
|
dest = tmp_path / "backup.tar"
|
||||||
|
|
||||||
|
env._daytona_bulk_download(dest)
|
||||||
|
|
||||||
|
tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
|
||||||
|
assert "home/daytona/.hermes" in tar_cmd
|
||||||
|
|
||||||
|
|
||||||
|
class TestDaytonaCleanup:
|
||||||
|
"""Verify Daytona cleanup() calls sync_back() before stop."""
|
||||||
|
|
||||||
|
def test_daytona_cleanup_calls_sync_back(self):
|
||||||
|
"""cleanup() should call sync_back() before sandbox.stop()."""
|
||||||
|
env = _make_mock_daytona_env()
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
sync_mgr = MagicMock()
|
||||||
|
sync_mgr.sync_back = lambda: call_order.append("sync_back")
|
||||||
|
env._sync_manager = sync_mgr
|
||||||
|
env._sandbox.stop = lambda: call_order.append("stop")
|
||||||
|
|
||||||
|
env.cleanup()
|
||||||
|
|
||||||
|
assert "sync_back" in call_order
|
||||||
|
assert "stop" in call_order
|
||||||
|
assert call_order.index("sync_back") < call_order.index("stop")
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkDownloadWiring:
|
||||||
|
"""Verify each backend passes bulk_download_fn to FileSyncManager."""
|
||||||
|
|
||||||
|
def test_ssh_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
|
||||||
|
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||||
|
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
class CaptureSyncManager:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
def sync(self, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)
|
||||||
|
|
||||||
|
SSHEnvironment(host="h", user="u")
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
||||||
|
def test_modal_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
def capture_fsm(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return type("M", (), {"sync": lambda self, **k: None})()
|
||||||
|
|
||||||
|
monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)
|
||||||
|
|
||||||
|
env = object.__new__(modal_env.ModalEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._worker = MagicMock()
|
||||||
|
env._persistent = False
|
||||||
|
env._task_id = "test"
|
||||||
|
|
||||||
|
# Replicate the wiring done in __init__
|
||||||
|
from tools.environments.file_sync import iter_sync_files
|
||||||
|
env._sync_manager = modal_env.FileSyncManager(
|
||||||
|
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||||
|
upload_fn=env._modal_upload,
|
||||||
|
delete_fn=env._modal_delete,
|
||||||
|
bulk_upload_fn=env._modal_bulk_upload,
|
||||||
|
bulk_download_fn=env._modal_bulk_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
||||||
|
def test_daytona_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
def capture_fsm(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return type("M", (), {"sync": lambda self, **k: None})()
|
||||||
|
|
||||||
|
monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)
|
||||||
|
|
||||||
|
env = object.__new__(daytona_env.DaytonaEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._remote_home = "/root"
|
||||||
|
env._lock = __import__("threading").Lock()
|
||||||
|
env._persistent = True
|
||||||
|
env._task_id = "test"
|
||||||
|
env._daytona = MagicMock()
|
||||||
|
|
||||||
|
# Replicate the wiring done in __init__
|
||||||
|
from tools.environments.file_sync import iter_sync_files
|
||||||
|
env._sync_manager = daytona_env.FileSyncManager(
|
||||||
|
get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
|
||||||
|
upload_fn=env._daytona_upload,
|
||||||
|
delete_fn=env._daytona_delete,
|
||||||
|
bulk_upload_fn=env._daytona_bulk_upload,
|
||||||
|
bulk_download_fn=env._daytona_bulk_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
@ -134,6 +134,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._daytona_upload,
|
upload_fn=self._daytona_upload,
|
||||||
delete_fn=self._daytona_delete,
|
delete_fn=self._daytona_delete,
|
||||||
bulk_upload_fn=self._daytona_bulk_upload,
|
bulk_upload_fn=self._daytona_bulk_upload,
|
||||||
|
bulk_download_fn=self._daytona_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
self.init_session()
|
self.init_session()
|
||||||
|
|
@ -166,6 +167,19 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
]
|
]
|
||||||
self._sandbox.fs.upload_files(uploads)
|
self._sandbox.fs.upload_files(uploads)
|
||||||
|
|
||||||
|
def _daytona_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive."""
|
||||||
|
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||||
|
self._sandbox.process.exec(
|
||||||
|
f"tar cf /tmp/.hermes_sync.tar -C / {shlex.quote(rel_base)}"
|
||||||
|
)
|
||||||
|
self._sandbox.fs.download_file("/tmp/.hermes_sync.tar", str(dest))
|
||||||
|
# Clean up remote temp file
|
||||||
|
try:
|
||||||
|
self._sandbox.process.exec("rm -f /tmp/.hermes_sync.tar")
|
||||||
|
except Exception:
|
||||||
|
pass # best-effort cleanup
|
||||||
|
|
||||||
def _daytona_delete(self, remote_paths: list[str]) -> None:
|
def _daytona_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files via SDK exec."""
|
"""Batch-delete remote files via SDK exec."""
|
||||||
self._sandbox.process.exec(quoted_rm_command(remote_paths))
|
self._sandbox.process.exec(quoted_rm_command(remote_paths))
|
||||||
|
|
@ -213,6 +227,10 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("Daytona: syncing files from sandbox...")
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,25 @@ and Daytona. Docker and Singularity use bind mounts (live host FS
|
||||||
view) and don't need this.
|
view) and don't need this.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import tarfile
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fcntl
|
||||||
|
except ImportError:
|
||||||
|
fcntl = None # Windows — file locking skipped
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
from hermes_constants import get_hermes_home
|
||||||
from tools.environments.base import _file_mtime_key
|
from tools.environments.base import _file_mtime_key
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,6 +35,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"
|
||||||
# Transport callbacks provided by each backend
|
# Transport callbacks provided by each backend
|
||||||
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
||||||
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
||||||
|
BulkDownloadFn = Callable[[Path], None] # (dest_tar_path) -> writes tar archive, raises on failure
|
||||||
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
||||||
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
||||||
|
|
||||||
|
|
@ -71,6 +84,19 @@ def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
||||||
return sorted({str(Path(remote).parent) for _, remote in files})
|
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_file(path: str) -> str:
|
||||||
|
"""Return hex SHA-256 digest of a file."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(65536), b""):
|
||||||
|
h.update(chunk)
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
_SYNC_BACK_MAX_RETRIES = 3
|
||||||
|
_SYNC_BACK_BACKOFF = (2, 4, 8) # seconds between retries
|
||||||
|
|
||||||
|
|
||||||
class FileSyncManager:
|
class FileSyncManager:
|
||||||
"""Tracks local file changes and syncs to a remote environment.
|
"""Tracks local file changes and syncs to a remote environment.
|
||||||
|
|
||||||
|
|
@ -89,12 +115,15 @@ class FileSyncManager:
|
||||||
delete_fn: DeleteFn,
|
delete_fn: DeleteFn,
|
||||||
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
||||||
bulk_upload_fn: BulkUploadFn | None = None,
|
bulk_upload_fn: BulkUploadFn | None = None,
|
||||||
|
bulk_download_fn: BulkDownloadFn | None = None,
|
||||||
):
|
):
|
||||||
self._get_files_fn = get_files_fn
|
self._get_files_fn = get_files_fn
|
||||||
self._upload_fn = upload_fn
|
self._upload_fn = upload_fn
|
||||||
self._bulk_upload_fn = bulk_upload_fn
|
self._bulk_upload_fn = bulk_upload_fn
|
||||||
|
self._bulk_download_fn = bulk_download_fn
|
||||||
self._delete_fn = delete_fn
|
self._delete_fn = delete_fn
|
||||||
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
||||||
|
self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest
|
||||||
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
||||||
self._sync_interval = sync_interval
|
self._sync_interval = sync_interval
|
||||||
|
|
||||||
|
|
@ -136,6 +165,7 @@ class FileSyncManager:
|
||||||
|
|
||||||
# Snapshot for rollback (only when there's work to do)
|
# Snapshot for rollback (only when there's work to do)
|
||||||
prev_files = dict(self._synced_files)
|
prev_files = dict(self._synced_files)
|
||||||
|
prev_hashes = dict(self._pushed_hashes)
|
||||||
|
|
||||||
if to_upload:
|
if to_upload:
|
||||||
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
||||||
|
|
@ -156,13 +186,187 @@ class FileSyncManager:
|
||||||
logger.debug("file_sync: deleted %s", to_delete)
|
logger.debug("file_sync: deleted %s", to_delete)
|
||||||
|
|
||||||
# --- Commit (all succeeded) ---
|
# --- Commit (all succeeded) ---
|
||||||
|
for host_path, remote_path in to_upload:
|
||||||
|
self._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||||
|
|
||||||
for p in to_delete:
|
for p in to_delete:
|
||||||
new_files.pop(p, None)
|
new_files.pop(p, None)
|
||||||
|
self._pushed_hashes.pop(p, None)
|
||||||
|
|
||||||
self._synced_files = new_files
|
self._synced_files = new_files
|
||||||
self._last_sync_time = time.monotonic()
|
self._last_sync_time = time.monotonic()
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._synced_files = prev_files
|
self._synced_files = prev_files
|
||||||
|
self._pushed_hashes = prev_hashes
|
||||||
self._last_sync_time = time.monotonic()
|
self._last_sync_time = time.monotonic()
|
||||||
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Sync-back: pull remote changes to host on teardown
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sync_back(self, hermes_home: Path | None = None) -> None:
|
||||||
|
"""Pull remote changes back to the host filesystem.
|
||||||
|
|
||||||
|
Downloads the remote ``.hermes/`` directory as a tar archive,
|
||||||
|
unpacks it, and applies only files that differ from what was
|
||||||
|
originally pushed (based on SHA-256 content hashes).
|
||||||
|
|
||||||
|
Protected against SIGINT (defers the signal until complete) and
|
||||||
|
serialized across concurrent gateway sandboxes via file lock.
|
||||||
|
"""
|
||||||
|
if self._bulk_download_fn is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
|
||||||
|
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(_SYNC_BACK_MAX_RETRIES):
|
||||||
|
try:
|
||||||
|
self._sync_back_once(lock_path)
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt < _SYNC_BACK_MAX_RETRIES - 1:
|
||||||
|
delay = _SYNC_BACK_BACKOFF[attempt]
|
||||||
|
logger.warning(
|
||||||
|
"sync_back: attempt %d failed (%s), retrying in %ds",
|
||||||
|
attempt + 1, exc, delay,
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)
|
||||||
|
|
||||||
|
def _sync_back_once(self, lock_path: Path) -> None:
|
||||||
|
"""Single sync-back attempt with SIGINT protection and file lock."""
|
||||||
|
# signal.signal() only works from the main thread. In gateway
|
||||||
|
# contexts cleanup() may run from a worker thread — skip SIGINT
|
||||||
|
# deferral there rather than crashing.
|
||||||
|
on_main_thread = threading.current_thread() is threading.main_thread()
|
||||||
|
|
||||||
|
deferred_sigint: list[object] = []
|
||||||
|
original_handler = None
|
||||||
|
if on_main_thread:
|
||||||
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
|
|
||||||
|
def _defer_sigint(signum, frame):
|
||||||
|
deferred_sigint.append((signum, frame))
|
||||||
|
logger.debug("sync_back: SIGINT deferred until sync completes")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, _defer_sigint)
|
||||||
|
try:
|
||||||
|
self._sync_back_locked(lock_path)
|
||||||
|
finally:
|
||||||
|
if on_main_thread and original_handler is not None:
|
||||||
|
signal.signal(signal.SIGINT, original_handler)
|
||||||
|
if deferred_sigint:
|
||||||
|
os.kill(os.getpid(), signal.SIGINT)
|
||||||
|
|
||||||
|
def _sync_back_locked(self, lock_path: Path) -> None:
|
||||||
|
"""Sync-back under file lock (serializes concurrent gateways)."""
|
||||||
|
if fcntl is None:
|
||||||
|
# Windows: no flock — run without serialization
|
||||||
|
self._sync_back_impl()
|
||||||
|
return
|
||||||
|
lock_fd = open(lock_path, "w")
|
||||||
|
try:
|
||||||
|
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||||
|
self._sync_back_impl()
|
||||||
|
finally:
|
||||||
|
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||||
|
lock_fd.close()
|
||||||
|
|
||||||
|
def _sync_back_impl(self) -> None:
|
||||||
|
"""Download, diff, and apply remote changes to host."""
|
||||||
|
if self._bulk_download_fn is None:
|
||||||
|
raise RuntimeError("_sync_back_impl called without bulk_download_fn")
|
||||||
|
|
||||||
|
# Cache file mapping once to avoid O(n*m) from repeated iteration
|
||||||
|
try:
|
||||||
|
file_mapping = list(self._get_files_fn())
|
||||||
|
except Exception:
|
||||||
|
file_mapping = []
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
|
||||||
|
self._bulk_download_fn(Path(tf.name))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
|
||||||
|
with tarfile.open(tf.name) as tar:
|
||||||
|
tar.extractall(staging, filter="data")
|
||||||
|
|
||||||
|
applied = 0
|
||||||
|
for dirpath, _dirnames, filenames in os.walk(staging):
|
||||||
|
for fname in filenames:
|
||||||
|
staged_file = os.path.join(dirpath, fname)
|
||||||
|
rel = os.path.relpath(staged_file, staging)
|
||||||
|
remote_path = "/" + rel
|
||||||
|
|
||||||
|
pushed_hash = self._pushed_hashes.get(remote_path)
|
||||||
|
|
||||||
|
# Skip hashing for files unchanged from push
|
||||||
|
if pushed_hash is not None:
|
||||||
|
remote_hash = _sha256_file(staged_file)
|
||||||
|
if remote_hash == pushed_hash:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
remote_hash = None # new remote file
|
||||||
|
|
||||||
|
# Resolve host path from cached mapping
|
||||||
|
host_path = self._resolve_host_path(remote_path, file_mapping)
|
||||||
|
if host_path is None:
|
||||||
|
host_path = self._infer_host_path(remote_path, file_mapping)
|
||||||
|
if host_path is None:
|
||||||
|
logger.debug(
|
||||||
|
"sync_back: skipping %s (no host mapping)",
|
||||||
|
remote_path,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.exists(host_path) and pushed_hash is not None:
|
||||||
|
host_hash = _sha256_file(host_path)
|
||||||
|
if host_hash != pushed_hash:
|
||||||
|
logger.warning(
|
||||||
|
"sync_back: conflict on %s — host modified "
|
||||||
|
"since push, remote also changed. Applying "
|
||||||
|
"remote version (last-write-wins).",
|
||||||
|
remote_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(host_path), exist_ok=True)
|
||||||
|
shutil.copy2(staged_file, host_path)
|
||||||
|
applied += 1
|
||||||
|
|
||||||
|
if applied:
|
||||||
|
logger.info("sync_back: applied %d changed file(s)", applied)
|
||||||
|
else:
|
||||||
|
logger.debug("sync_back: no remote changes detected")
|
||||||
|
|
||||||
|
def _resolve_host_path(self, remote_path: str,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||||
|
"""Find the host path for a known remote path from the file mapping."""
|
||||||
|
mapping = file_mapping if file_mapping is not None else []
|
||||||
|
for host, remote in mapping:
|
||||||
|
if remote == remote_path:
|
||||||
|
return host
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _infer_host_path(self, remote_path: str,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||||
|
"""Infer a host path for a new remote file by matching path prefixes.
|
||||||
|
|
||||||
|
Uses the existing file mapping to find a remote->host directory
|
||||||
|
pair, then applies the same prefix substitution to the new file.
|
||||||
|
For example, if the mapping has ``/root/.hermes/skills/a.md`` →
|
||||||
|
``~/.hermes/skills/a.md``, a new remote file at
|
||||||
|
``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
|
||||||
|
"""
|
||||||
|
mapping = file_mapping if file_mapping is not None else []
|
||||||
|
for host, remote in mapping:
|
||||||
|
remote_dir = str(Path(remote).parent)
|
||||||
|
if remote_path.startswith(remote_dir + "/"):
|
||||||
|
host_dir = str(Path(host).parent)
|
||||||
|
suffix = remote_path[len(remote_dir):]
|
||||||
|
return host_dir + suffix
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -269,6 +269,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._modal_upload,
|
upload_fn=self._modal_upload,
|
||||||
delete_fn=self._modal_delete,
|
delete_fn=self._modal_delete,
|
||||||
bulk_upload_fn=self._modal_bulk_upload,
|
bulk_upload_fn=self._modal_bulk_upload,
|
||||||
|
bulk_download_fn=self._modal_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
self.init_session()
|
self.init_session()
|
||||||
|
|
@ -347,6 +348,27 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
|
|
||||||
self._worker.run_coroutine(_bulk(), timeout=120)
|
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||||
|
|
||||||
|
def _modal_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive.
|
||||||
|
|
||||||
|
Modal sandboxes always run as root, so /root/.hermes is hardcoded
|
||||||
|
(consistent with iter_sync_files call on line 269).
|
||||||
|
"""
|
||||||
|
async def _download():
|
||||||
|
proc = await self._sandbox.exec.aio(
|
||||||
|
"bash", "-c", "tar cf - -C / root/.hermes"
|
||||||
|
)
|
||||||
|
data = await proc.stdout.read.aio()
|
||||||
|
exit_code = await proc.wait.aio()
|
||||||
|
if exit_code != 0:
|
||||||
|
raise RuntimeError(f"Modal bulk download failed (exit {exit_code})")
|
||||||
|
return data
|
||||||
|
|
||||||
|
tar_bytes = self._worker.run_coroutine(_download(), timeout=120)
|
||||||
|
if isinstance(tar_bytes, str):
|
||||||
|
tar_bytes = tar_bytes.encode()
|
||||||
|
dest.write_bytes(tar_bytes)
|
||||||
|
|
||||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files via exec."""
|
"""Batch-delete remote files via exec."""
|
||||||
rm_cmd = quoted_rm_command(remote_paths)
|
rm_cmd = quoted_rm_command(remote_paths)
|
||||||
|
|
@ -404,6 +426,10 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("Modal: syncing files from sandbox...")
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
|
||||||
if self._persistent:
|
if self._persistent:
|
||||||
try:
|
try:
|
||||||
async def _snapshot():
|
async def _snapshot():
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._scp_upload,
|
upload_fn=self._scp_upload,
|
||||||
delete_fn=self._ssh_delete,
|
delete_fn=self._ssh_delete,
|
||||||
bulk_upload_fn=self._ssh_bulk_upload,
|
bulk_upload_fn=self._ssh_bulk_upload,
|
||||||
|
bulk_download_fn=self._ssh_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
|
|
||||||
|
|
@ -216,6 +217,18 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
|
|
||||||
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||||
|
|
||||||
|
def _ssh_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive."""
|
||||||
|
# Tar from / with the full path so archive entries preserve absolute
|
||||||
|
# paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys.
|
||||||
|
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||||
|
ssh_cmd = self._build_ssh_command()
|
||||||
|
ssh_cmd.append(f"tar cf - -C / {shlex.quote(rel_base)}")
|
||||||
|
with open(dest, "wb") as f:
|
||||||
|
result = subprocess.run(ssh_cmd, stdout=f, stderr=subprocess.PIPE, timeout=120)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"SSH bulk download failed: {result.stderr.decode(errors='replace').strip()}")
|
||||||
|
|
||||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files in one SSH call."""
|
"""Batch-delete remote files in one SSH call."""
|
||||||
cmd = self._build_ssh_command()
|
cmd = self._build_ssh_command()
|
||||||
|
|
@ -245,6 +258,10 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
return _popen_bash(cmd, stdin_data)
|
return _popen_bash(cmd, stdin_data)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("SSH: syncing files from sandbox...")
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
|
||||||
if self.control_socket.exists():
|
if self.control_socket.exists():
|
||||||
try:
|
try:
|
||||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue