"""Tests for backend-specific bulk download implementations and cleanup() wiring."""

import asyncio
import subprocess
import threading
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest

from tools.environments import ssh as ssh_env
from tools.environments import modal as modal_env
from tools.environments import daytona as daytona_env
from tools.environments.ssh import SSHEnvironment


# ── SSH helpers ──────────────────────────────────────────────────────


def _stub_ssh_connection(monkeypatch, remote_home):
    """Stub out the SSHEnvironment setup steps so construction needs no host.

    Makes ``shutil.which`` report ``/usr/bin/ssh`` and replaces connection
    setup, remote-directory creation and session init with no-ops.
    ``_detect_remote_home`` returns ``remote_home``.  ``FileSyncManager`` is
    deliberately left untouched so each caller can install its own stand-in.
    """
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: remote_home)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)


@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync.

    The remote home is ``/home/testuser`` and the sync manager is a stub whose
    ``sync`` and ``sync_back`` do nothing.
    """
    _stub_ssh_connection(monkeypatch, "/home/testuser")
    monkeypatch.setattr(
        ssh_env,
        "FileSyncManager",
        lambda **kw: type("M", (), {
            "sync": lambda self, **k: None,
            "sync_back": lambda self: None,
        })(),
    )
    return SSHEnvironment(host="example.com", user="testuser")


# ── Modal helpers ────────────────────────────────────────────────────


def _make_mock_modal_env():
    """Create a minimal ModalEnvironment without calling __init__.

    The sandbox and worker are bare MagicMocks, persistence is off and no
    sync manager is attached; tests wire in whatever behaviour they need.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    env._sandbox = MagicMock()
    env._worker = MagicMock()
    env._persistent = False
    env._task_id = "test"
    env._sync_manager = None
    return env


def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
    """Wire sandbox.exec.aio to return mock tar output for download tests.

    Each ``sandbox.exec.aio(*args)`` call records ``args`` and returns a fake
    process whose ``stdout.read.aio()`` yields ``tar_bytes`` and whose
    ``wait.aio()`` yields ``exit_code``.  ``worker.run_coroutine`` is replaced
    with a function that really runs the coroutine on a fresh event loop
    (closed afterwards) and returns its result; extra kwargs such as
    ``timeout`` are accepted and ignored.

    Returns the exec_calls list for assertion.
    """
    exec_calls = []

    async def mock_exec_fn(*args, **kwargs):
        exec_calls.append(args)
        proc = MagicMock()
        proc.stdout = MagicMock()
        proc.stdout.read = MagicMock()
        proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
        proc.wait = MagicMock()
        proc.wait.aio = AsyncMock(return_value=exit_code)
        return proc

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = mock_exec_fn

    def real_run_coroutine(coro, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = real_run_coroutine
    return exec_calls


# ── Daytona helpers ──────────────────────────────────────────────────


def _make_mock_daytona_env():
    """Create a minimal DaytonaEnvironment without calling __init__.

    Uses ``/root`` as the remote home, a real ``threading.Lock``, MagicMock
    sandbox and Daytona client, persistence on, and no sync manager.
    """
    env = object.__new__(daytona_env.DaytonaEnvironment)
    env._sandbox = MagicMock()
    env._remote_home = "/root"
    env._sync_manager = None
    env._lock = threading.Lock()
    env._persistent = True
    env._task_id = "test"
    env._daytona = MagicMock()
    return env


# =====================================================================
# SSH bulk download
# =====================================================================


class TestSSHBulkDownload:
    """Unit tests for _ssh_bulk_download."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """subprocess.run command should include tar cf - over SSH."""
        dest = tmp_path / "backup.tar"
        # Only subprocess.run is mocked; any file the download writes lands
        # under tmp_path, which pytest cleans up.
        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        cmd_str = " ".join(cmd)
        assert "tar cf -" in cmd_str
        assert "-C /" in cmd_str
        assert "home/testuser/.hermes" in cmd_str
        assert "ssh" in cmd_str
        assert "testuser@example.com" in cmd_str

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """subprocess.run should receive stdout=open(dest, 'wb')."""
        dest = tmp_path / "backup.tar"
        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        # stdout is passed as a keyword argument and should be a file object
        # opened on dest.  It may already be closed when inspected here, but
        # its ``name`` attribute stays readable.
        stdout_val = mock_run.call_args.kwargs.get("stdout")
        assert stdout_val is not None
        assert Path(stdout_val.name) == dest
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """Non-zero returncode should raise RuntimeError."""
        dest = tmp_path / "backup.tar"
        failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")
        with patch.object(subprocess, "run", return_value=failed):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """The subprocess.run call should use a 120s timeout."""
        dest = tmp_path / "backup.tar"
        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        assert mock_run.call_args.kwargs.get("timeout") == 120


class TestSSHCleanup:
    """Verify SSH cleanup() calls sync_back() before closing ControlMaster."""

    @staticmethod
    def _make_tracked_env(monkeypatch, call_order):
        """Build an SSHEnvironment whose sync manager logs sync_back() into call_order."""
        _stub_ssh_connection(monkeypatch, "/home/u")

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
        return SSHEnvironment(host="h", user="u")

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """cleanup() should call sync_back() before SSH control socket teardown."""
        call_order = []
        env = self._make_tracked_env(monkeypatch, call_order)

        # Ensure control_socket does not exist so cleanup skips the SSH exit call
        env.control_socket = Path("/nonexistent/socket")
        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch, tmp_path):
        """sync_back() must run before the ControlMaster exit command."""
        call_order = []
        env = self._make_tracked_env(monkeypatch, call_order)

        # An existing file stands in for the control socket so cleanup tries
        # the SSH exit.  Creating it under tmp_path lets pytest remove it
        # instead of leaving a stray file in the system temp directory.
        socket_path = tmp_path / "control.sock"
        socket_path.touch()
        env.control_socket = socket_path

        def mock_run(cmd, **kwargs):
            cmd_str = " ".join(cmd)
            if "-O" in cmd and "exit" in cmd_str:
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", side_effect=mock_run):
            env.cleanup()

        assert call_order.index("sync_back") < call_order.index("control_exit")
# =====================================================================
# Modal bulk download
# =====================================================================


class TestModalBulkDownload:
    """Unit tests for _modal_bulk_download."""

    def test_modal_bulk_download_command(self, tmp_path):
        """exec should run ``bash -c`` with ``tar cf - -C / root/.hermes``."""
        env = _make_mock_modal_env()
        exec_calls = _wire_modal_download(env, tar_bytes=b"tar-content")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert len(exec_calls) == 1
        args = exec_calls[0]
        assert args[0] == "bash"
        assert args[1] == "-c"
        assert "tar cf -" in args[2]
        assert "-C / root/.hermes" in args[2]

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """Downloaded tar bytes should be written to the dest path."""
        env = _make_mock_modal_env()
        expected_data = b"some-tar-archive-bytes"
        _wire_modal_download(env, tar_bytes=expected_data)
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.exists()
        assert dest.read_bytes() == expected_data

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """If stdout returns str instead of bytes, it should be encoded."""
        env = _make_mock_modal_env()
        # Simulate Modal SDK returning str
        _wire_modal_download(env, tar_bytes="string-tar-data")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """Non-zero exit code should raise RuntimeError."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)
        dest = tmp_path / "backup.tar"

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(dest)

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """run_coroutine should be called with timeout=120."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        run_kwargs = {}
        original_run = env._worker.run_coroutine

        def tracking_run(coro, **kwargs):
            run_kwargs.update(kwargs)
            return original_run(coro, **kwargs)

        env._worker.run_coroutine = tracking_run
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert run_kwargs.get("timeout") == 120


class TestModalCleanup:
    """Verify Modal cleanup() calls sync_back() before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.terminate."""
        env = _make_mock_modal_env()
        call_order = []

        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr

        # Record "terminate" when the terminate coroutine actually executes,
        # not merely when some coroutine is handed to the worker.
        async def mock_terminate():
            call_order.append("terminate")

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate

        def run_coroutine(coro, **kwargs):
            # Fresh loop per call, always closed so no unclosed-loop
            # ResourceWarning escapes the test.
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = run_coroutine
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert "terminate" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")


# =====================================================================
# Daytona bulk download
# =====================================================================


class TestDaytonaBulkDownload:
    """Unit tests for _daytona_bulk_download."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """exec and download_file should both be called."""
        env = _make_mock_daytona_env()
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        # exec called twice: tar creation + rm cleanup
        assert env._sandbox.process.exec.call_count == 2
        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "tar cf" in tar_cmd
        # PID-suffixed temp path avoids collisions on sync_back retry
        assert "/tmp/.hermes_sync." in tar_cmd
        assert ".tar" in tar_cmd
        assert ".hermes" in tar_cmd

        cleanup_cmd = env._sandbox.process.exec.call_args_list[1][0][0]
        assert "rm -f" in cleanup_cmd
        assert "/tmp/.hermes_sync." in cleanup_cmd

        # download_file called once with the same PID-suffixed path
        env._sandbox.fs.download_file.assert_called_once()
        download_args = env._sandbox.fs.download_file.call_args[0]
        remote_tar = download_args[0]
        assert remote_tar.startswith("/tmp/.hermes_sync.")
        assert remote_tar.endswith(".tar")
        assert download_args[1] == str(dest)

        # The tar, download and rm steps must all target that one temp path.
        assert remote_tar in tar_cmd
        assert remote_tar in cleanup_cmd

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command should use the env's _remote_home."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "home/daytona/.hermes" in tar_cmd


class TestDaytonaCleanup:
    """Verify Daytona cleanup() calls sync_back() before stop."""

    def test_daytona_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.stop()."""
        env = _make_mock_daytona_env()
        call_order = []

        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr
        env._sandbox.stop = lambda: call_order.append("stop")

        env.cleanup()

        assert "sync_back" in call_order
        assert "stop" in call_order
        assert call_order.index("sync_back") < call_order.index("stop")


# =====================================================================
# FileSyncManager wiring: bulk_download_fn passed by each backend
# =====================================================================


class TestBulkDownloadWiring:
    """Verify each backend passes bulk_download_fn to FileSyncManager."""

    def test_ssh_passes_bulk_download_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class CaptureSyncManager:
            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)

        SSHEnvironment(host="h", user="u")

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    # NOTE(review): the Modal and Daytona tests below rebuild the
    # FileSyncManager call by hand instead of running __init__, so they only
    # prove the *_bulk_download methods exist and are callable.  They will not
    # catch __init__ forgetting to pass bulk_download_fn.

    def test_modal_passes_bulk_download_fn(self, monkeypatch):
        """ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        env = _make_mock_modal_env()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
            bulk_download_fn=env._modal_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    def test_daytona_passes_bulk_download_fn(self, monkeypatch):
        """DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)

        env = _make_mock_daytona_env()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = daytona_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
            upload_fn=env._daytona_upload,
            delete_fn=env._daytona_delete,
            bulk_upload_fn=env._daytona_bulk_upload,
            bulk_download_fn=env._daytona_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])