mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
41d3d7afb7
18 changed files with 2877 additions and 484 deletions
473
tests/tools/test_file_sync_back.py
Normal file
473
tests/tools/test_file_sync_back.py
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||
|
||||
import fcntl
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
_sha256_file,
|
||||
_SYNC_BACK_BACKOFF,
|
||||
_SYNC_BACK_MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tar(files: dict[str, bytes], dest: Path):
|
||||
"""Write a tar archive containing the given arcname->content pairs."""
|
||||
with tarfile.open(dest, "w") as tar:
|
||||
for arcname, content in files.items():
|
||||
info = tarfile.TarInfo(name=arcname)
|
||||
info.size = len(content)
|
||||
tar.addfile(info, io.BytesIO(content))
|
||||
|
||||
|
||||
def _make_download_fn(files: dict[str, bytes]):
    """Build a bulk_download_fn stub that materialises *files* as a tar at the
    destination path it is handed."""
    def _download(dest: Path):
        _make_tar(files, dest)

    return _download
|
||||
|
||||
|
||||
def _sha256_bytes(data: bytes) -> str:
|
||||
"""Compute SHA-256 hex digest of raw bytes (for test convenience)."""
|
||||
import hashlib
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def _write_file(path: Path, content: bytes) -> str:
|
||||
"""Write bytes to *path*, creating parents, and return the string path."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
def _make_manager(
    tmp_path: Path,
    file_mapping: list[tuple[str, str]] | None = None,
    bulk_download_fn=None,
    seed_pushed_state: bool = True,
) -> FileSyncManager:
    """Build a FileSyncManager wired for testing.

    *file_mapping* is the list of (host_path, remote_path) pairs that
    ``get_files_fn`` will yield; ``None`` means an empty mapping.

    With *seed_pushed_state* True (the default), ``_pushed_hashes`` is
    pre-populated from the mapping so sync_back does not bail out on its
    "nothing previously pushed" guard. Pass False to exercise the noop path.
    """
    pairs = file_mapping or []
    manager = FileSyncManager(
        get_files_fn=lambda: pairs,
        upload_fn=MagicMock(),
        delete_fn=MagicMock(),
        bulk_download_fn=bulk_download_fn,
    )
    if seed_pushed_state:
        # Pre-seed _pushed_hashes: hash files that exist on disk, use an
        # all-zero digest for missing ones, and fall back to a sentinel
        # entry when the mapping produced nothing at all.
        for host, remote in pairs:
            if os.path.exists(host):
                manager._pushed_hashes[remote] = _sha256_file(host)
            else:
                manager._pushed_hashes[remote] = "0" * 64
        if not manager._pushed_hashes:
            manager._pushed_hashes["/_sentinel"] = "0" * 64
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncBackNoop:
    """Without a bulk_download_fn, sync_back() must do nothing at all."""

    def test_sync_back_noop_without_download_fn(self, tmp_path):
        manager = _make_manager(tmp_path, bulk_download_fn=None)
        # Returning without raising is the whole assertion here.
        manager.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestSyncBackNoChanges:
    """If remote content still matches the pushed hash, the host stays untouched."""

    def test_sync_back_no_changes(self, tmp_path):
        content = b'{"key": "val"}'
        host_file = tmp_path / "host" / "cred.json"
        _write_file(host_file, content)

        remote_path = "/root/.hermes/cred.json"
        mapping = [(str(host_file), remote_path)]

        # The downloaded tar carries exactly what was originally pushed.
        download = _make_download_fn({"root/.hermes/cred.json": content})

        manager = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download)
        # Record that this exact content was pushed.
        manager._pushed_hashes[remote_path] = _sha256_bytes(content)

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # Byte-for-byte identical afterwards.
        assert host_file.read_bytes() == content
|
||||
|
||||
|
||||
class TestSyncBackAppliesChanged:
    """A remote file whose content diverged from the pushed hash is pulled down."""

    def test_sync_back_applies_changed_file(self, tmp_path):
        v1 = b"print('v1')"
        host_file = tmp_path / "host" / "skill.py"
        _write_file(host_file, v1)

        remote_path = "/root/.hermes/skill.py"
        mapping = [(str(host_file), remote_path)]

        # Remote copy was edited after the push.
        v2 = b"print('v2 - edited on remote')"
        download = _make_download_fn({"root/.hermes/skill.py": v2})

        manager = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download)
        manager._pushed_hashes[remote_path] = _sha256_bytes(v1)

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # The remote edit must now be present on the host.
        assert host_file.read_bytes() == v2
|
||||
|
||||
|
||||
class TestSyncBackNewRemoteFile:
    """A remote-only file (absent from _pushed_hashes) lands via _infer_host_path."""

    def test_sync_back_detects_new_remote_file(self, tmp_path):
        # One existing mapping supplies the directory prefix for inference.
        seeded_host = tmp_path / "host" / "skills" / "existing.py"
        _write_file(seeded_host, b"existing")
        mapping = [(str(seeded_host), "/root/.hermes/skills/existing.py")]

        # The tar carries a file in the same remote dir that was never pushed.
        fresh_content = b"# brand new skill created on remote"
        download = _make_download_fn({"root/.hermes/skills/new_skill.py": fresh_content})

        manager = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download)
        # Deliberately no _pushed_hashes entry for the new file.

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # Inference should have placed it next to the existing skill.
        inferred = tmp_path / "host" / "skills" / "new_skill.py"
        assert inferred.exists()
        assert inferred.read_bytes() == fresh_content
|
||||
|
||||
|
||||
class TestSyncBackConflict:
    """Both sides changed after push: a warning fires and the remote copy wins."""

    def test_sync_back_conflict_warns(self, tmp_path, caplog):
        pushed = b'{"v": 1}'
        host_file = tmp_path / "host" / "config.json"
        _write_file(host_file, pushed)

        remote_path = "/root/.hermes/config.json"
        mapping = [(str(host_file), remote_path)]

        # Host edited after the push...
        host_file.write_bytes(b'{"v": 2, "host-edit": true}')

        # ...and so was the remote copy.
        remote_version = b'{"v": 3, "remote-edit": true}'
        download = _make_download_fn({"root/.hermes/config.json": remote_version})

        manager = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download)
        manager._pushed_hashes[remote_path] = _sha256_bytes(pushed)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            manager.sync_back(hermes_home=tmp_path / ".hermes")

        # A conflict warning must have been emitted...
        assert any("conflict" in record.message.lower() for record in caplog.records)

        # ...and last-write-wins means the remote bytes end up on the host.
        assert host_file.read_bytes() == remote_version
|
||||
|
||||
|
||||
class TestSyncBackRetries:
    """Retry behaviour with exponential backoff."""

    # time.sleep is patched inside the module under test so retries are instant.
    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
        call_count = 0

        def flaky_download(dest: Path):
            # Fail the first two attempts, succeed on the third.
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RuntimeError(f"network error #{call_count}")
            # Third attempt succeeds -- write a valid (empty) tar
            _make_tar({}, dest)

        mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download)
        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        assert call_count == 3
        # Sleep called twice (between attempt 1->2 and 2->3)
        assert mock_sleep.call_count == 2
        # Each wait should use the configured backoff schedule entries.
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
        def always_fail(dest: Path):
            raise RuntimeError("persistent failure")

        mgr = _make_manager(tmp_path, bulk_download_fn=always_fail)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            # Should NOT raise -- failures are logged, not propagated
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # All retries were attempted
        # (one sleep between each consecutive pair of attempts, hence -1).
        assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1

        # Final "all attempts failed" warning was logged
        assert any("all" in r.message.lower() and "failed" in r.message.lower() for r in caplog.records)
|
||||
|
||||
|
||||
class TestPushedHashesPopulated:
    """_pushed_hashes is populated during sync() and cleared on delete."""

    def test_pushed_hashes_populated_on_sync(self, tmp_path):
        host_file = tmp_path / "data.txt"
        host_file.write_bytes(b"hello world")

        remote_path = "/root/.hermes/data.txt"
        pairs = [(str(host_file), remote_path)]

        manager = FileSyncManager(
            get_files_fn=lambda: pairs,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        manager.sync(force=True)

        # The push must be recorded with the file's current digest.
        assert remote_path in manager._pushed_hashes
        assert manager._pushed_hashes[remote_path] == _sha256_file(str(host_file))

    def test_pushed_hashes_cleared_on_delete(self, tmp_path):
        host_file = tmp_path / "deleteme.txt"
        host_file.write_bytes(b"to be deleted")

        remote_path = "/root/.hermes/deleteme.txt"
        live_mapping = [(str(host_file), remote_path)]

        manager = FileSyncManager(
            get_files_fn=lambda: live_mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        # First sync records the hash.
        manager.sync(force=True)
        assert remote_path in manager._pushed_hashes

        # Drop the file and empty the mapping to simulate local deletion.
        os.unlink(str(host_file))
        live_mapping.clear()

        manager.sync(force=True)

        # The stale hash must be gone.
        assert remote_path not in manager._pushed_hashes
|
||||
|
||||
|
||||
class TestSyncBackFileLock:
    """sync-back must guard extraction with fcntl.flock."""

    @patch("tools.environments.file_sync.fcntl.flock")
    def test_sync_back_file_lock(self, mock_flock, tmp_path):
        manager = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        manager.sync_back(hermes_home=tmp_path / ".hermes")

        # At minimum one acquire (LOCK_EX) and one release (LOCK_UN).
        assert mock_flock.call_count >= 2

        operations = [recorded[0][1] for recorded in mock_flock.call_args_list]
        assert fcntl.LOCK_EX in operations
        assert fcntl.LOCK_UN in operations

    def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
        """On Windows (fcntl=None), sync_back should skip file locking."""
        manager = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        with patch("tools.environments.file_sync.fcntl", None):
            # Completing without an exception proves locking was skipped.
            manager.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestInferHostPath:
    """Prefix-matching edge cases for _infer_host_path."""

    def test_infer_no_matching_prefix(self, tmp_path):
        """Remote path in unmapped directory should return None."""
        mapped_host = tmp_path / "host" / "skills" / "a.py"
        _write_file(mapped_host, b"content")
        mapping = [(str(mapped_host), "/root/.hermes/skills/a.py")]

        manager = _make_manager(tmp_path, file_mapping=mapping)
        inferred = manager._infer_host_path(
            "/root/.hermes/cache/new.json",
            file_mapping=mapping,
        )
        assert inferred is None

    def test_infer_partial_prefix_no_false_match(self, tmp_path):
        """A partial prefix like /root/.hermes/sk should NOT match /root/.hermes/skills/."""
        mapped_host = tmp_path / "host" / "skills" / "a.py"
        _write_file(mapped_host, b"content")
        mapping = [(str(mapped_host), "/root/.hermes/skills/a.py")]

        manager = _make_manager(tmp_path, file_mapping=mapping)
        # "skillsXtra" shares leading characters with "skills" but is a
        # different directory, so it must not be treated as mapped.
        inferred = manager._infer_host_path(
            "/root/.hermes/skillsXtra/b.py",
            file_mapping=mapping,
        )
        assert inferred is None

    def test_infer_matching_prefix(self, tmp_path):
        """A file in a mapped directory should be correctly inferred."""
        mapped_host = tmp_path / "host" / "skills" / "a.py"
        _write_file(mapped_host, b"content")
        mapping = [(str(mapped_host), "/root/.hermes/skills/a.py")]

        manager = _make_manager(tmp_path, file_mapping=mapping)
        inferred = manager._infer_host_path(
            "/root/.hermes/skills/b.py",
            file_mapping=mapping,
        )
        assert inferred == str(tmp_path / "host" / "skills" / "b.py")
|
||||
|
||||
|
||||
class TestSyncBackSIGINT:
    """SIGINT deferral during sync-back.

    Fix: removed the unused ``handlers_seen`` local from the main-thread
    test — it was never appended to or asserted on (dead variable).
    """

    def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
        """On the main thread, SIGINT handler should be swapped during sync."""
        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        original_getsignal = signal.getsignal

        # Spy on both halves of the save/swap/restore dance; getsignal keeps
        # its real behaviour via side_effect so the saved handler is genuine.
        with patch("tools.environments.file_sync.signal.getsignal",
                   side_effect=original_getsignal) as mock_get, \
             patch("tools.environments.file_sync.signal.signal") as mock_set:
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # signal.getsignal was called to save the original handler
        assert mock_get.called
        # signal.signal was called at least twice: install defer, restore original
        assert mock_set.call_count >= 2

    def test_sync_back_skips_signal_on_worker_thread(self, tmp_path):
        """From a non-main thread, signal.signal should NOT be called."""
        import threading

        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        signal_called = []

        def tracking_signal(*args):
            signal_called.append(args)

        with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal):
            # Run from a worker thread; collect any exception for assertion.
            exc = []

            def run():
                try:
                    mgr.sync_back(hermes_home=tmp_path / ".hermes")
                except Exception as e:
                    exc.append(e)

            t = threading.Thread(target=run)
            t.start()
            t.join(timeout=10)

        assert not exc, f"sync_back raised: {exc}"
        # signal.signal should NOT have been called from the worker thread
        assert len(signal_called) == 0
|
||||
|
||||
|
||||
class TestSyncBackSizeCap:
    """Tars above _SYNC_BACK_MAX_BYTES are refused rather than extracted."""

    def test_sync_back_refuses_oversized_tar(self, tmp_path, caplog):
        """A tar larger than _SYNC_BACK_MAX_BYTES should be skipped with a warning."""
        # Keep the tar tiny and shrink the cap instead of building a 2 GiB file.
        host_path = _write_file(tmp_path / "host_skill.md", b"original")
        download = _make_download_fn({"root/.hermes/skill.md": b"remote_version"})

        manager = _make_manager(
            tmp_path,
            file_mapping=[(host_path, "/root/.hermes/skill.md")],
            bulk_download_fn=download,
        )

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            # A 1-byte cap guarantees any non-empty tar exceeds it.
            with patch("tools.environments.file_sync._SYNC_BACK_MAX_BYTES", 1):
                manager.sync_back(hermes_home=tmp_path / ".hermes")

        # Extraction was skipped, so the host copy is untouched...
        assert Path(host_path).read_bytes() == b"original"
        # ...and the warning mentions the cap.
        assert any("cap" in record.message for record in caplog.records)

    def test_sync_back_applies_when_under_cap(self, tmp_path):
        """A tar under the cap should extract normally (sanity check)."""
        host_path = _write_file(tmp_path / "host_skill.md", b"original")
        download = _make_download_fn({"root/.hermes/skill.md": b"remote_version"})

        manager = _make_manager(
            tmp_path,
            file_mapping=[(host_path, "/root/.hermes/skill.md")],
            bulk_download_fn=download,
        )

        # The default 2 GiB cap dwarfs this tiny tar, so extraction proceeds.
        manager.sync_back(hermes_home=tmp_path / ".hermes")
        assert Path(host_path).read_bytes() == b"remote_version"
|
||||
450
tests/tools/test_image_generation.py
Normal file
450
tests/tools/test_image_generation.py
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
"""Tests for tools/image_generation_tool.py — FAL multi-model support.
|
||||
|
||||
Covers the pure logic of the new wrapper: catalog integrity, the three size
|
||||
families (image_size_preset / aspect_ratio / gpt_literal), the supports
|
||||
whitelist, default merging, GPT quality override, and model resolution
|
||||
fallback. Does NOT exercise fal_client submission — that's covered by
|
||||
tests/tools/test_managed_media_gateways.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def image_tool():
    """Reload tools.image_generation_tool so every test sees fresh module state."""
    import importlib

    import tools.image_generation_tool as mod

    reloaded = importlib.reload(mod)
    return reloaded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalog integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFalCatalog:
    """Every FAL_MODELS entry must have a consistent shape."""

    def test_default_model_is_klein(self, image_tool):
        assert image_tool.DEFAULT_MODEL == "fal-ai/flux-2/klein/9b"

    def test_default_model_in_catalog(self, image_tool):
        assert image_tool.DEFAULT_MODEL in image_tool.FAL_MODELS

    def test_all_entries_have_required_keys(self, image_tool):
        required = {
            "display", "speed", "strengths", "price",
            "size_style", "sizes", "defaults", "supports", "upscale",
        }
        for model_id, entry in image_tool.FAL_MODELS.items():
            missing = required - set(entry.keys())
            assert not missing, f"{model_id} missing required keys: {missing}"

    def test_size_style_is_valid(self, image_tool):
        allowed = {"image_size_preset", "aspect_ratio", "gpt_literal"}
        for model_id, entry in image_tool.FAL_MODELS.items():
            assert entry["size_style"] in allowed, \
                f"{model_id} has invalid size_style: {entry['size_style']}"

    def test_sizes_cover_all_aspect_ratios(self, image_tool):
        needed = {"landscape", "square", "portrait"}
        for model_id, entry in image_tool.FAL_MODELS.items():
            assert set(entry["sizes"].keys()) >= needed, \
                f"{model_id} missing a required aspect_ratio key"

    def test_supports_is_a_set(self, image_tool):
        for model_id, entry in image_tool.FAL_MODELS.items():
            assert isinstance(entry["supports"], set), \
                f"{model_id}.supports must be a set, got {type(entry['supports'])}"

    def test_prompt_is_always_supported(self, image_tool):
        for model_id, entry in image_tool.FAL_MODELS.items():
            assert "prompt" in entry["supports"], \
                f"{model_id} must support 'prompt'"

    def test_only_flux2_pro_upscales_by_default(self, image_tool):
        """Upscaling should default to False for all new models to preserve
        the <1s / fast-render value prop. Only flux-2-pro stays True for
        backward-compat with the previous default."""
        for model_id, entry in image_tool.FAL_MODELS.items():
            if model_id == "fal-ai/flux-2-pro":
                assert entry["upscale"] is True, \
                    "flux-2-pro should keep upscale=True for backward-compat"
            else:
                assert entry["upscale"] is False, \
                    f"{model_id} should default to upscale=False"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Payload building — three size families
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageSizePresetFamily:
    """Flux, z-image, qwen, recraft, ideogram all use preset enum sizes."""

    def test_klein_landscape_uses_preset(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "landscape")
        assert payload["image_size"] == "landscape_16_9"
        # Preset-style models must never carry the aspect_ratio key.
        assert "aspect_ratio" not in payload

    def test_klein_square_uses_preset(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "square")
        assert payload["image_size"] == "square_hd"

    def test_klein_portrait_uses_preset(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "portrait")
        assert payload["image_size"] == "portrait_16_9"
|
||||
|
||||
|
||||
class TestAspectRatioFamily:
    """Nano-banana uses aspect_ratio enum, NOT image_size."""

    def test_nano_banana_landscape_uses_aspect_ratio(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "landscape")
        assert payload["aspect_ratio"] == "16:9"
        # This family must never emit the preset-style key.
        assert "image_size" not in payload

    def test_nano_banana_square_uses_aspect_ratio(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "square")
        assert payload["aspect_ratio"] == "1:1"

    def test_nano_banana_portrait_uses_aspect_ratio(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "portrait")
        assert payload["aspect_ratio"] == "9:16"
|
||||
|
||||
|
||||
class TestGptLiteralFamily:
    """GPT-Image 1.5 uses literal size strings."""

    def test_gpt_landscape_is_literal(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "landscape")
        assert payload["image_size"] == "1536x1024"

    def test_gpt_square_is_literal(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "square")
        assert payload["image_size"] == "1024x1024"

    def test_gpt_portrait_is_literal(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "portrait")
        assert payload["image_size"] == "1024x1536"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supports whitelist — the main safety property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupportsFilter:
    """No model should receive keys outside its `supports` set."""

    def test_payload_keys_are_subset_of_supports_for_all_models(self, image_tool):
        for model_id, entry in image_tool.FAL_MODELS.items():
            built = image_tool._build_fal_payload(model_id, "test", "landscape", seed=42)
            extras = set(built.keys()) - entry["supports"]
            assert not extras, \
                f"{model_id} payload has unsupported keys: {extras}"

    def test_gpt_image_has_no_seed_even_if_passed(self, image_tool):
        # GPT-Image 1.5 does not support seed — the filter must strip it.
        built = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square", seed=42)
        assert "seed" not in built

    def test_gpt_image_strips_unsupported_overrides(self, image_tool):
        built = image_tool._build_fal_payload(
            "fal-ai/gpt-image-1.5", "hi", "square",
            overrides={"guidance_scale": 7.5, "num_inference_steps": 50},
        )
        assert "guidance_scale" not in built
        assert "num_inference_steps" not in built

    def test_recraft_has_minimal_payload(self, image_tool):
        # Recraft supports prompt, image_size, style only.
        built = image_tool._build_fal_payload("fal-ai/recraft-v3", "hi", "landscape")
        assert set(built.keys()) <= {"prompt", "image_size", "style"}

    def test_nano_banana_never_gets_image_size(self, image_tool):
        # Common bug: translator accidentally setting both image_size and aspect_ratio.
        built = image_tool._build_fal_payload("fal-ai/nano-banana", "hi", "landscape", seed=1)
        assert "image_size" not in built
        assert built["aspect_ratio"] == "16:9"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default merging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDefaults:
    """Model-level defaults should carry through unless overridden."""

    def test_klein_default_steps_is_4(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "square")
        assert payload["num_inference_steps"] == 4

    def test_flux_2_pro_default_steps_is_50(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2-pro", "hi", "square")
        assert payload["num_inference_steps"] == 50

    def test_override_replaces_default(self, image_tool):
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro", "hi", "square", overrides={"num_inference_steps": 25}
        )
        assert payload["num_inference_steps"] == 25

    def test_none_override_does_not_replace_default(self, image_tool):
        """None values from caller should be ignored (use default)."""
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro", "hi", "square",
            overrides={"num_inference_steps": None},
        )
        assert payload["num_inference_steps"] == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPT-Image quality is pinned to medium (not user-configurable)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGptQualityPinnedToMedium:
    """GPT-Image quality is baked into the FAL_MODELS defaults at 'medium'
    and cannot be overridden via config. Pinning keeps Nous Portal billing
    predictable across all users."""

    def test_gpt_payload_always_has_medium_quality(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_config_quality_setting_is_ignored(self, image_tool):
        """Even if a user manually edits config.yaml and adds quality_setting,
        the payload must still use medium. No code path reads that field."""
        fake_config = {"image_gen": {"quality_setting": "high"}}
        with patch("hermes_cli.config.load_config", return_value=fake_config):
            payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_non_gpt_model_never_gets_quality(self, image_tool):
        """quality is only meaningful for gpt-image-1.5 — other models should
        never have it in their payload."""
        for model_id in image_tool.FAL_MODELS:
            if model_id == "fal-ai/gpt-image-1.5":
                continue
            payload = image_tool._build_fal_payload(model_id, "hi", "square")
            assert "quality" not in payload, f"{model_id} unexpectedly has 'quality' in payload"

    def test_honors_quality_setting_flag_is_removed(self, image_tool):
        """The honors_quality_setting flag was the old override trigger.
        It must not be present on any model entry anymore."""
        for model_id, entry in image_tool.FAL_MODELS.items():
            assert "honors_quality_setting" not in entry, (
                f"{model_id} still has honors_quality_setting; "
                f"remove it — quality is pinned to medium"
            )

    def test_resolve_gpt_quality_function_is_gone(self, image_tool):
        """The _resolve_gpt_quality() helper was removed — quality is now
        a static default, not a runtime lookup."""
        assert not hasattr(image_tool, "_resolve_gpt_quality"), (
            "_resolve_gpt_quality should not exist — quality is pinned"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelResolution:
    """How _resolve_fal_model picks between config, env var, and the default."""

    def test_no_config_falls_back_to_default(self, image_tool):
        with patch("hermes_cli.config.load_config", return_value={}):
            model_id, _meta = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2/klein/9b"

    def test_valid_config_model_is_used(self, image_tool):
        config = {"image_gen": {"model": "fal-ai/flux-2-pro"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, meta = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2-pro"
        assert meta["upscale"] is True  # flux-2-pro keeps backward-compat upscaling

    def test_unknown_model_falls_back_to_default_with_warning(self, image_tool, caplog):
        config = {"image_gen": {"model": "fal-ai/nonexistent-9000"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2/klein/9b"

    def test_env_var_fallback_when_no_config(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        with patch("hermes_cli.config.load_config", return_value={}):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/z-image/turbo"

    def test_config_wins_over_env_var(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        config = {"image_gen": {"model": "fal-ai/nano-banana"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/nano-banana"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aspect ratio handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAspectRatioNormalization:
    """Aspect-ratio strings are normalized defensively before hitting fal."""

    def test_invalid_aspect_defaults_to_landscape(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "cinemascope")
        assert payload["image_size"] == "landscape_16_9"

    def test_uppercase_aspect_is_normalized(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "PORTRAIT")
        assert payload["image_size"] == "portrait_16_9"

    def test_empty_aspect_defaults_to_landscape(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "")
        assert payload["image_size"] == "landscape_16_9"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema + registry integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRegistryIntegration:
    """The agent-facing schema shape must stay stable."""

    def test_schema_exposes_only_prompt_and_aspect_ratio_to_agent(self, image_tool):
        """Model selection is a user-level config choice, not an agent-level
        arg, so the agent-facing schema must stay tight."""
        params = image_tool.IMAGE_GENERATE_SCHEMA["parameters"]
        assert set(params["properties"].keys()) == {"prompt", "aspect_ratio"}

    def test_aspect_ratio_enum_is_three_values(self, image_tool):
        params = image_tool.IMAGE_GENERATE_SCHEMA["parameters"]
        allowed = params["properties"]["aspect_ratio"]["enum"]
        assert set(allowed) == {"landscape", "square", "portrait"}
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed gateway 4xx translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MockResponse:
    """Minimal stand-in for an HTTP response: exposes only ``status_code``,
    the single attribute the error-translation code under test reads."""

    def __init__(self, status_code: int):
        # Stored verbatim; asserted against via exc.response.status_code.
        self.status_code = status_code
||||
|
||||
|
||||
class _MockHttpxError(Exception):
    """Simulates httpx.HTTPStatusError which exposes .response.status_code."""

    def __init__(self, status_code: int, message: str = "Bad Request"):
        # The message becomes str(exc), mirroring httpx closely enough here.
        super().__init__(message)
        # .response.status_code is what _extract_http_status inspects.
        self.response = _MockResponse(status_code)
|
||||
|
||||
class TestExtractHttpStatus:
    """Status-code extraction should work across exception shapes."""

    def test_extracts_from_response_attr(self, image_tool):
        assert image_tool._extract_http_status(_MockHttpxError(403)) == 403

    def test_extracts_from_status_code_attr(self, image_tool):
        err = Exception("fail")
        err.status_code = 404  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(err) == 404

    def test_returns_none_for_non_http_exception(self, image_tool):
        for err in (ValueError("nope"), RuntimeError("nope")):
            assert image_tool._extract_http_status(err) is None

    def test_response_attr_without_status_code_returns_none(self, image_tool):
        class BareResponse:
            pass

        err = Exception("weird")
        err.response = BareResponse()  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(err) is None
|
||||
|
||||
class TestManagedGatewayErrorTranslation:
    """4xx from the Nous managed gateway should be translated to a user-actionable message."""

    def _wire_managed(self, image_tool, monkeypatch, gateway, submit_error):
        """Point the tool at *gateway* with a managed client whose submit raises *submit_error*."""
        from unittest.mock import MagicMock

        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway",
                            lambda: gateway)
        client = MagicMock()
        client.submit.side_effect = submit_error
        monkeypatch.setattr(image_tool, "_get_managed_fal_client",
                            lambda gw: client)

    def test_4xx_translates_to_value_error_with_remediation(self, image_tool, monkeypatch):
        """403 from managed gateway → ValueError mentioning FAL_KEY + hermes tools."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        gateway.gateway_origin = "https://fal-queue-gateway.example.com"
        gateway.nous_user_token = "test-token"
        forbidden = _MockHttpxError(403, "Forbidden")
        self._wire_managed(image_tool, monkeypatch, gateway, forbidden)

        with pytest.raises(ValueError) as exc_info:
            image_tool._submit_fal_request("fal-ai/nano-banana", {"prompt": "x"})

        message = str(exc_info.value)
        for needle in ("fal-ai/nano-banana", "403", "FAL_KEY", "hermes tools"):
            assert needle in message
        # Original exception stays chained for debugging.
        assert exc_info.value.__cause__ is forbidden

    def test_5xx_is_not_translated(self, image_tool, monkeypatch):
        """500s are real outages, not model-availability issues — don't rewrite them."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        self._wire_managed(image_tool, monkeypatch, gateway,
                           _MockHttpxError(502, "Bad Gateway"))

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_direct_fal_errors_are_not_translated(self, image_tool, monkeypatch):
        """When user has direct FAL_KEY (managed gateway returns None), raw
        errors from fal_client bubble up unchanged — fal_client already
        provides reasonable error messages for direct usage."""
        from unittest.mock import MagicMock

        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway",
                            lambda: None)
        fake_fal = MagicMock()
        fake_fal.submit.side_effect = _MockHttpxError(403, "Forbidden")
        monkeypatch.setattr(image_tool, "fal_client", fake_fal)

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_non_http_exception_from_managed_bubbles_up(self, image_tool, monkeypatch):
        """Connection errors, timeouts, etc. from managed mode aren't 4xx —
        they should bubble up unchanged so callers can retry or diagnose."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        self._wire_managed(image_tool, monkeypatch, gateway,
                           ConnectionError("network down"))

        with pytest.raises(ConnectionError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})
||||
495
tests/tools/test_sync_back_backends.py
Normal file
495
tests/tools/test_sync_back_backends.py
Normal file
|
|
@ -0,0 +1,495 @@
|
|||
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments import modal as modal_env
|
||||
from tools.environments import daytona as daytona_env
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync."""
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    # Neutralize every side-effecting step of __init__.
    for hook in ("_establish_connection", "_ensure_remote_dirs", "init_session"):
        monkeypatch.setattr(ssh_env.SSHEnvironment, hook, lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home",
                        lambda self: "/home/testuser")

    class _NoopSyncManager:
        def sync(self, **kwargs):
            pass

        def sync_back(self):
            pass

    monkeypatch.setattr(ssh_env, "FileSyncManager",
                        lambda **kwargs: _NoopSyncManager())
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
# ── Modal helpers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_modal_env():
    """Build a skeletal ModalEnvironment, bypassing __init__ entirely."""
    env = object.__new__(modal_env.ModalEnvironment)
    env._sandbox, env._worker = MagicMock(), MagicMock()
    env._persistent = False
    env._task_id = "test"
    env._sync_manager = None
    return env
|
||||
|
||||
|
||||
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
    """Wire sandbox.exec.aio to return mock tar output for download tests.

    Returns the exec_calls list for assertion.
    """
    recorded_calls = []

    async def fake_exec(*args, **kwargs):
        recorded_calls.append(args)
        proc = MagicMock()
        proc.stdout = MagicMock()
        proc.stdout.read = MagicMock()
        proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
        proc.wait = MagicMock()
        proc.wait.aio = AsyncMock(return_value=exit_code)
        return proc

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = fake_exec

    def run_sync(coro, **kwargs):
        # Drive the coroutine on a private loop, mirroring the real worker.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = run_sync
    return recorded_calls
|
||||
|
||||
|
||||
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_daytona_env():
    """Create a minimal DaytonaEnvironment without calling __init__.

    Sets exactly the attributes cleanup()/_daytona_bulk_download read so
    tests can exercise those paths without a real sandbox.
    """
    # Plain local import instead of the opaque __import__("threading") hack.
    import threading

    env = object.__new__(daytona_env.DaytonaEnvironment)
    env._sandbox = MagicMock()
    env._remote_home = "/root"
    env._sync_manager = None
    env._lock = threading.Lock()
    env._persistent = True
    env._task_id = "test"
    env._daytona = MagicMock()
    return env
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SSH bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestSSHBulkDownload:
    """Unit tests for _ssh_bulk_download."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """subprocess.run command should include tar cf - over SSH."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        cmd_str = " ".join(cmd)
        assert "tar cf -" in cmd_str
        assert "-C /" in cmd_str
        assert "home/testuser/.hermes" in cmd_str
        assert "ssh" in cmd_str
        assert "testuser@example.com" in cmd_str

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """subprocess.run should receive stdout=open(dest, 'wb')."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        # BUGFIX: the captured stdout value was previously computed but never
        # asserted. The handle is already closed when we inspect it, but its
        # .name still records which path was opened.
        stdout_val = mock_run.call_args.kwargs.get("stdout")
        assert stdout_val is not None
        assert Path(stdout_val.name) == dest
        # Opening dest for writing also creates it on disk.
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """Non-zero returncode should raise RuntimeError."""
        dest = tmp_path / "backup.tar"

        failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")
        with patch.object(subprocess, "run", return_value=failed):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """The subprocess.run call should use a 120s timeout."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        assert mock_run.call_args.kwargs.get("timeout") == 120
||||
|
||||
class TestSSHCleanup:
    """Verify SSH cleanup() calls sync_back() before closing ControlMaster."""

    @staticmethod
    def _build_env(monkeypatch, call_order):
        """Construct an SSHEnvironment whose sync manager appends 'sync_back'
        to *call_order*; all connection side effects are neutralized."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
        return SSHEnvironment(host="h", user="u")

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """cleanup() should call sync_back() before SSH control socket teardown."""
        call_order = []
        env = self._build_env(monkeypatch, call_order)
        # Ensure control_socket does not exist so cleanup skips the SSH exit call
        env.control_socket = Path("/nonexistent/socket")

        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch):
        """sync_back() must run before the ControlMaster exit command."""
        call_order = []
        env = self._build_env(monkeypatch, call_order)

        # Create a fake control socket so cleanup tries the SSH exit
        import tempfile
        with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp:
            env.control_socket = Path(tmp.name)

        def mock_run(cmd, **kwargs):
            if "-O" in cmd and "exit" in " ".join(cmd):
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        try:
            with patch.object(subprocess, "run", side_effect=mock_run):
                env.cleanup()
        finally:
            # BUGFIX: the delete=False temp file used to leak on disk after
            # every run; remove it regardless of test outcome.
            Path(tmp.name).unlink(missing_ok=True)

        assert call_order.index("sync_back") < call_order.index("control_exit")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Modal bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestModalBulkDownload:
    """Unit tests for _modal_bulk_download."""

    def test_modal_bulk_download_command(self, tmp_path):
        """exec should be called with tar cf - -C /root/.hermes ."""
        env = _make_mock_modal_env()
        recorded = _wire_modal_download(env, tar_bytes=b"tar-content")

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert len(recorded) == 1
        call_args = recorded[0]
        assert call_args[0] == "bash"
        assert call_args[1] == "-c"
        script = call_args[2]
        assert "tar cf -" in script
        assert "-C / root/.hermes" in script

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """Downloaded tar bytes should be written to the dest path."""
        env = _make_mock_modal_env()
        payload = b"some-tar-archive-bytes"
        _wire_modal_download(env, tar_bytes=payload)
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.exists()
        assert dest.read_bytes() == payload

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """If stdout returns str instead of bytes, it should be encoded."""
        env = _make_mock_modal_env()
        # Simulate the Modal SDK handing back str rather than bytes.
        _wire_modal_download(env, tar_bytes="string-tar-data")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """Non-zero exit code should raise RuntimeError."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(tmp_path / "backup.tar")

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """run_coroutine should be called with timeout=120."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        seen_kwargs = {}
        inner_run = env._worker.run_coroutine

        def spying_run(coro, **kwargs):
            seen_kwargs.update(kwargs)
            return inner_run(coro, **kwargs)

        env._worker.run_coroutine = spying_run

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert seen_kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestModalCleanup:
    """Verify Modal cleanup() calls sync_back() before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.terminate."""
        env = _make_mock_modal_env()

        call_order = []
        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr

        # Mock terminate to track call order
        async def mock_terminate():
            pass

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate

        def tracked_run(coro, **kw):
            # Record ordering, then drive the coroutine on a short-lived loop.
            # BUGFIX: the loop is now closed instead of leaked (the previous
            # lambda called asyncio.new_event_loop() and never closed it).
            call_order.append("terminate")
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = tracked_run
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")
||||
|
||||
|
||||
# =====================================================================
|
||||
# Daytona bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestDaytonaBulkDownload:
    """Unit tests for _daytona_bulk_download."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """exec and download_file should both be called."""
        env = _make_mock_daytona_env()
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        exec_mock = env._sandbox.process.exec
        # exec called twice: tar creation + rm cleanup
        assert exec_mock.call_count == 2
        tar_cmd = exec_mock.call_args_list[0][0][0]
        cleanup_cmd = exec_mock.call_args_list[1][0][0]

        assert "tar cf" in tar_cmd
        # PID-suffixed temp path avoids collisions on sync_back retry
        assert "/tmp/.hermes_sync." in tar_cmd
        assert ".tar" in tar_cmd
        assert ".hermes" in tar_cmd

        assert "rm -f" in cleanup_cmd
        assert "/tmp/.hermes_sync." in cleanup_cmd

        # download_file called once with the same PID-suffixed path
        env._sandbox.fs.download_file.assert_called_once()
        remote_path, local_path = env._sandbox.fs.download_file.call_args[0][:2]
        assert remote_path.startswith("/tmp/.hermes_sync.")
        assert remote_path.endswith(".tar")
        assert local_path == str(dest)

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command should use the env's _remote_home."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"

        env._daytona_bulk_download(tmp_path / "backup.tar")

        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "home/daytona/.hermes" in tar_cmd
|
||||
|
||||
class TestDaytonaCleanup:
    """Verify Daytona cleanup() calls sync_back() before stop."""

    def test_daytona_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.stop()."""
        env = _make_mock_daytona_env()

        events = []
        manager = MagicMock()
        manager.sync_back = lambda: events.append("sync_back")
        env._sync_manager = manager
        env._sandbox.stop = lambda: events.append("stop")

        env.cleanup()

        assert "sync_back" in events
        assert "stop" in events
        assert events.index("sync_back") < events.index("stop")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestBulkDownloadWiring:
    """Verify each backend passes bulk_download_fn to FileSyncManager."""

    def test_ssh_passes_bulk_download_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class CaptureSyncManager:
            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)

        SSHEnvironment(host="h", user="u")

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    def test_modal_passes_bulk_download_fn(self, monkeypatch):
        """ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
            bulk_download_fn=env._modal_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    def test_daytona_passes_bulk_download_fn(self, monkeypatch):
        """DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
        # Plain local import instead of the opaque __import__("threading") hack.
        import threading

        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)

        env = object.__new__(daytona_env.DaytonaEnvironment)
        env._sandbox = MagicMock()
        env._remote_home = "/root"
        env._lock = threading.Lock()
        env._persistent = True
        env._task_id = "test"
        env._daytona = MagicMock()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = daytona_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
            upload_fn=env._daytona_upload,
            delete_fn=env._daytona_delete,
            bulk_upload_fn=env._daytona_bulk_upload,
            bulk_download_fn=env._daytona_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue