mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
41d3d7afb7
18 changed files with 2877 additions and 484 deletions
|
|
@ -727,6 +727,22 @@ def switch_model(
|
|||
if not api_mode:
|
||||
api_mode = determine_api_mode(target_provider, base_url)
|
||||
|
||||
# OpenCode base URLs end with /v1 for OpenAI-compatible models, but the
|
||||
# Anthropic SDK prepends its own /v1/messages to the base_url. Strip the
|
||||
# trailing /v1 so the SDK constructs the correct path (e.g.
|
||||
# https://opencode.ai/zen/go/v1/messages instead of .../v1/v1/messages).
|
||||
# Mirrors the same logic in hermes_cli.runtime_provider.resolve_runtime_provider;
|
||||
# without it, /model switches into an anthropic_messages-routed OpenCode
|
||||
# model (e.g. `/model minimax-m2.7` on opencode-go, `/model claude-sonnet-4-6`
|
||||
# on opencode-zen) hit a double /v1 and returned OpenCode's website 404 page.
|
||||
if (
|
||||
api_mode == "anthropic_messages"
|
||||
and target_provider in {"opencode-zen", "opencode-go"}
|
||||
and isinstance(base_url, str)
|
||||
and base_url
|
||||
):
|
||||
base_url = re.sub(r"/v1/?$", "", base_url)
|
||||
|
||||
# --- Get capabilities (legacy) ---
|
||||
capabilities = get_model_capabilities(target_provider, new_model)
|
||||
|
||||
|
|
|
|||
|
|
@ -258,14 +258,16 @@ TOOL_CATEGORIES = {
|
|||
"requires_nous_auth": True,
|
||||
"managed_nous_feature": "image_gen",
|
||||
"override_env_vars": ["FAL_KEY"],
|
||||
"imagegen_backend": "fal",
|
||||
},
|
||||
{
|
||||
"name": "FAL.ai",
|
||||
"badge": "paid",
|
||||
"tag": "FLUX 2 Pro with auto-upscaling",
|
||||
"tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.",
|
||||
"env_vars": [
|
||||
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
|
||||
],
|
||||
"imagegen_backend": "fal",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@ -950,6 +952,106 @@ def _detect_active_provider_index(providers: list, config: dict) -> int:
|
|||
return 0
|
||||
|
||||
|
||||
# ─── Image Generation Model Pickers ───────────────────────────────────────────
|
||||
#
|
||||
# IMAGEGEN_BACKENDS is a per-backend catalog. Each entry exposes:
|
||||
# - config_key: top-level config.yaml key for this backend's settings
|
||||
# - model_catalog_fn: returns an OrderedDict-like {model_id: metadata}
|
||||
# - default_model: fallback when nothing is configured
|
||||
#
|
||||
# This prepares for future imagegen backends (Replicate, Stability, etc.):
|
||||
# each new backend registers its own entry; the FAL provider entry in
|
||||
# TOOL_CATEGORIES tags itself with `imagegen_backend: "fal"` to select the
|
||||
# right catalog at picker time.
|
||||
|
||||
|
||||
def _fal_model_catalog():
|
||||
"""Lazy-load the FAL model catalog from the tool module."""
|
||||
from tools.image_generation_tool import FAL_MODELS, DEFAULT_MODEL
|
||||
return FAL_MODELS, DEFAULT_MODEL
|
||||
|
||||
|
||||
IMAGEGEN_BACKENDS = {
|
||||
"fal": {
|
||||
"display": "FAL.ai",
|
||||
"config_key": "image_gen",
|
||||
"catalog_fn": _fal_model_catalog,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _format_imagegen_model_row(model_id: str, meta: dict, widths: dict) -> str:
|
||||
"""Format a single picker row with column-aligned speed / strengths / price."""
|
||||
return (
|
||||
f"{model_id:<{widths['model']}} "
|
||||
f"{meta.get('speed', ''):<{widths['speed']}} "
|
||||
f"{meta.get('strengths', ''):<{widths['strengths']}} "
|
||||
f"{meta.get('price', '')}"
|
||||
)
|
||||
|
||||
|
||||
def _configure_imagegen_model(backend_name: str, config: dict) -> None:
|
||||
"""Prompt the user to pick a model for the given imagegen backend.
|
||||
|
||||
Writes selection to ``config[backend_config_key]["model"]``. Safe to
|
||||
call even when stdin is not a TTY — curses_radiolist falls back to
|
||||
keeping the current selection.
|
||||
"""
|
||||
backend = IMAGEGEN_BACKENDS.get(backend_name)
|
||||
if not backend:
|
||||
return
|
||||
|
||||
catalog, default_model = backend["catalog_fn"]()
|
||||
if not catalog:
|
||||
return
|
||||
|
||||
cfg_key = backend["config_key"]
|
||||
cur_cfg = config.setdefault(cfg_key, {})
|
||||
if not isinstance(cur_cfg, dict):
|
||||
cur_cfg = {}
|
||||
config[cfg_key] = cur_cfg
|
||||
current_model = cur_cfg.get("model") or default_model
|
||||
if current_model not in catalog:
|
||||
current_model = default_model
|
||||
|
||||
model_ids = list(catalog.keys())
|
||||
# Put current model at the top so the cursor lands on it by default.
|
||||
ordered = [current_model] + [m for m in model_ids if m != current_model]
|
||||
|
||||
# Column widths
|
||||
widths = {
|
||||
"model": max(len(m) for m in model_ids),
|
||||
"speed": max((len(catalog[m].get("speed", "")) for m in model_ids), default=6),
|
||||
"strengths": max((len(catalog[m].get("strengths", "")) for m in model_ids), default=0),
|
||||
}
|
||||
|
||||
print()
|
||||
header = (
|
||||
f" {'Model':<{widths['model']}} "
|
||||
f"{'Speed':<{widths['speed']}} "
|
||||
f"{'Strengths':<{widths['strengths']}} "
|
||||
f"Price"
|
||||
)
|
||||
print(color(header, Colors.CYAN))
|
||||
|
||||
rows = []
|
||||
for mid in ordered:
|
||||
row = _format_imagegen_model_row(mid, catalog[mid], widths)
|
||||
if mid == current_model:
|
||||
row += " ← currently in use"
|
||||
rows.append(row)
|
||||
|
||||
idx = _prompt_choice(
|
||||
f" Choose {backend['display']} model:",
|
||||
rows,
|
||||
default=0,
|
||||
)
|
||||
|
||||
chosen = ordered[idx]
|
||||
cur_cfg["model"] = chosen
|
||||
_print_success(f" Model set to: {chosen}")
|
||||
|
||||
|
||||
def _configure_provider(provider: dict, config: dict):
|
||||
"""Configure a single provider - prompt for API keys and set config."""
|
||||
env_vars = provider.get("env_vars", [])
|
||||
|
|
@ -1006,6 +1108,10 @@ def _configure_provider(provider: dict, config: dict):
|
|||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
if managed_feature:
|
||||
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
||||
# Imagegen backends prompt for model selection after backend pick.
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
return
|
||||
|
||||
# Prompt for each required env var
|
||||
|
|
@ -1040,6 +1146,10 @@ def _configure_provider(provider: dict, config: dict):
|
|||
|
||||
if all_configured:
|
||||
_print_success(f" {provider['name']} configured!")
|
||||
# Imagegen backends prompt for model selection after env vars are in.
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
|
||||
|
||||
def _configure_simple_requirements(ts_key: str):
|
||||
|
|
@ -1211,6 +1321,10 @@ def _reconfigure_provider(provider: dict, config: dict):
|
|||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
if managed_feature:
|
||||
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
||||
# Imagegen backends prompt for model selection on reconfig too.
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
return
|
||||
|
||||
for var in env_vars:
|
||||
|
|
@ -1228,6 +1342,11 @@ def _reconfigure_provider(provider: dict, config: dict):
|
|||
else:
|
||||
_print_info(" Kept current")
|
||||
|
||||
# Imagegen backends prompt for model selection on reconfig too.
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
"""Reconfigure simple env var requirements."""
|
||||
|
|
|
|||
49
run_agent.py
49
run_agent.py
|
|
@ -1674,12 +1674,26 @@ class AIAgent:
|
|||
turn-scoped).
|
||||
"""
|
||||
import logging
|
||||
import re as _re
|
||||
from hermes_cli.providers import determine_api_mode
|
||||
|
||||
# ── Determine api_mode if not provided ──
|
||||
if not api_mode:
|
||||
api_mode = determine_api_mode(new_provider, base_url)
|
||||
|
||||
# Defense-in-depth: ensure OpenCode base_url doesn't carry a trailing
|
||||
# /v1 into the anthropic_messages client, which would cause the SDK to
|
||||
# hit /v1/v1/messages. `model_switch.switch_model()` already strips
|
||||
# this, but we guard here so any direct callers (future code paths,
|
||||
# tests) can't reintroduce the double-/v1 404 bug.
|
||||
if (
|
||||
api_mode == "anthropic_messages"
|
||||
and new_provider in ("opencode-zen", "opencode-go")
|
||||
and isinstance(base_url, str)
|
||||
and base_url
|
||||
):
|
||||
base_url = _re.sub(r"/v1/?$", "", base_url)
|
||||
|
||||
old_model = self.model
|
||||
old_provider = self.provider
|
||||
|
||||
|
|
@ -4381,6 +4395,41 @@ class AIAgent:
|
|||
self._client_log_context(),
|
||||
)
|
||||
return client
|
||||
# Inject TCP keepalives so the kernel detects dead provider connections
|
||||
# instead of letting them sit silently in CLOSE-WAIT (#10324). Without
|
||||
# this, a peer that drops mid-stream leaves the socket in a state where
|
||||
# epoll_wait never fires, ``httpx`` read timeout may not trigger, and
|
||||
# the agent hangs until manually killed. Probes after 30s idle, retry
|
||||
# every 10s, give up after 3 → dead peer detected within ~60s.
|
||||
#
|
||||
# Safety against #10933: the ``client_kwargs = dict(client_kwargs)``
|
||||
# above means this injection only lands in the local per-call copy,
|
||||
# never back into ``self._client_kwargs``. Each ``_create_openai_client``
|
||||
# invocation therefore gets its OWN fresh ``httpx.Client`` whose
|
||||
# lifetime is tied to the OpenAI client it is passed to. When the
|
||||
# OpenAI client is closed (rebuild, teardown, credential rotation),
|
||||
# the paired ``httpx.Client`` closes with it, and the next call
|
||||
# constructs a fresh one — no stale closed transport can be reused.
|
||||
# Tests in ``tests/run_agent/test_create_openai_client_reuse.py`` and
|
||||
# ``tests/run_agent/test_sequential_chats_live.py`` pin this invariant.
|
||||
if "http_client" not in client_kwargs:
|
||||
try:
|
||||
import httpx as _httpx
|
||||
import socket as _socket
|
||||
_sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)]
|
||||
if hasattr(_socket, "TCP_KEEPIDLE"):
|
||||
# Linux
|
||||
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30))
|
||||
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10))
|
||||
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3))
|
||||
elif hasattr(_socket, "TCP_KEEPALIVE"):
|
||||
# macOS (uses TCP_KEEPALIVE instead of TCP_KEEPIDLE)
|
||||
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30))
|
||||
client_kwargs["http_client"] = _httpx.Client(
|
||||
transport=_httpx.HTTPTransport(socket_options=_sock_opts),
|
||||
)
|
||||
except Exception:
|
||||
pass # Fall through to default transport if socket opts fail
|
||||
client = OpenAI(**client_kwargs)
|
||||
logger.info(
|
||||
"OpenAI client created (%s, shared=%s) %s",
|
||||
|
|
|
|||
|
|
@ -122,6 +122,43 @@ log_error() {
|
|||
echo -e "${RED}✗${NC} $1"
|
||||
}
|
||||
|
||||
prompt_yes_no() {
|
||||
local question="$1"
|
||||
local default="${2:-yes}"
|
||||
local prompt_suffix
|
||||
local answer=""
|
||||
|
||||
# Use case patterns (not ${var,,}) so this works on bash 3.2 (macOS /bin/bash).
|
||||
case "$default" in
|
||||
[yY]|[yY][eE][sS]|[tT][rR][uU][eE]|1) prompt_suffix="[Y/n]" ;;
|
||||
*) prompt_suffix="[y/N]" ;;
|
||||
esac
|
||||
|
||||
if [ "$IS_INTERACTIVE" = true ]; then
|
||||
read -r -p "$question $prompt_suffix " answer || answer=""
|
||||
elif [ -r /dev/tty ] && [ -w /dev/tty ]; then
|
||||
printf "%s %s " "$question" "$prompt_suffix" > /dev/tty
|
||||
IFS= read -r answer < /dev/tty || answer=""
|
||||
else
|
||||
answer=""
|
||||
fi
|
||||
|
||||
answer="${answer#"${answer%%[![:space:]]*}"}"
|
||||
answer="${answer%"${answer##*[![:space:]]}"}"
|
||||
|
||||
if [ -z "$answer" ]; then
|
||||
case "$default" in
|
||||
[yY]|[yY][eE][sS]|[tT][rR][uU][eE]|1) return 0 ;;
|
||||
*) return 1 ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
case "$answer" in
|
||||
[yY]|[yY][eE][sS]) return 0 ;;
|
||||
*) return 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
|
@ -606,9 +643,7 @@ install_system_packages() {
|
|||
echo ""
|
||||
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
||||
log_info "Hermes Agent itself does not require or retain root access."
|
||||
read -p "Install ${description}? (requires sudo) [y/N] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
if prompt_yes_no "Install ${description}? (requires sudo)" "no"; then
|
||||
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd; then
|
||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||
|
|
@ -621,9 +656,7 @@ install_system_packages() {
|
|||
echo ""
|
||||
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
||||
log_info "Hermes Agent itself does not require or retain root access."
|
||||
read -p "Install ${description}? [Y/n] " -n 1 -r < /dev/tty
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
if prompt_yes_no "Install ${description}?" "yes"; then
|
||||
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd < /dev/tty; then
|
||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||
|
|
@ -863,9 +896,7 @@ install_deps() {
|
|||
else
|
||||
log_info "sudo is needed ONLY to install build tools (build-essential, python3-dev, libffi-dev) via apt."
|
||||
log_info "Hermes Agent itself does not require or retain root access."
|
||||
read -p "Install build tools? [Y/n] " -n 1 -r < /dev/tty
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
if prompt_yes_no "Install build tools?" "yes"; then
|
||||
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get update -qq && sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get install -y -qq build-essential python3-dev libffi-dev >/dev/null 2>&1 || true
|
||||
log_success "Build tools installed"
|
||||
fi
|
||||
|
|
@ -1246,9 +1277,7 @@ maybe_start_gateway() {
|
|||
log_info "WhatsApp is enabled but not yet paired."
|
||||
log_info "Running 'hermes whatsapp' to pair via QR code..."
|
||||
echo ""
|
||||
read -p "Pair WhatsApp now? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
if prompt_yes_no "Pair WhatsApp now?" "yes"; then
|
||||
HERMES_CMD="$(get_hermes_command_path)"
|
||||
$HERMES_CMD whatsapp || true
|
||||
fi
|
||||
|
|
@ -1263,14 +1292,18 @@ maybe_start_gateway() {
|
|||
fi
|
||||
|
||||
echo ""
|
||||
local should_install_gateway=false
|
||||
if [ "$DISTRO" = "termux" ]; then
|
||||
read -p "Would you like to start the gateway in the background? [Y/n] " -n 1 -r < /dev/tty
|
||||
if prompt_yes_no "Would you like to start the gateway in the background?" "yes"; then
|
||||
should_install_gateway=true
|
||||
fi
|
||||
else
|
||||
read -p "Would you like to install the gateway as a background service? [Y/n] " -n 1 -r < /dev/tty
|
||||
if prompt_yes_no "Would you like to install the gateway as a background service?" "yes"; then
|
||||
should_install_gateway=true
|
||||
fi
|
||||
fi
|
||||
echo
|
||||
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
if [ "$should_install_gateway" = true ]; then
|
||||
HERMES_CMD="$(get_hermes_command_path)"
|
||||
|
||||
if [ "$DISTRO" != "termux" ] && command -v systemctl &> /dev/null; then
|
||||
|
|
|
|||
252
tests/hermes_cli/test_model_switch_opencode_anthropic.py
Normal file
252
tests/hermes_cli/test_model_switch_opencode_anthropic.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""Regression tests for OpenCode /v1 stripping during /model switch.
|
||||
|
||||
When switching to an Anthropic-routed OpenCode model mid-session (e.g.
|
||||
``/model minimax-m2.7`` on opencode-go, or ``/model claude-sonnet-4-6``
|
||||
on opencode-zen), the resolved base_url must have its trailing ``/v1``
|
||||
stripped before being handed to the Anthropic SDK.
|
||||
|
||||
Without the strip, the SDK prepends its own ``/v1/messages`` path and
|
||||
requests hit ``https://opencode.ai/zen/go/v1/v1/messages`` — a double
|
||||
``/v1`` that returns OpenCode's website 404 page with HTML body.
|
||||
|
||||
``hermes_cli.runtime_provider.resolve_runtime_provider`` already strips
|
||||
``/v1`` at fresh agent init (PR #4918), but the ``/model`` mid-session
|
||||
switch path in ``hermes_cli.model_switch.switch_model`` was missing the
|
||||
same logic — these tests guard against that regression.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.model_switch import switch_model
|
||||
|
||||
|
||||
_MOCK_VALIDATION = {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
|
||||
def _run_opencode_switch(
|
||||
raw_input: str,
|
||||
current_provider: str,
|
||||
current_model: str,
|
||||
current_base_url: str,
|
||||
explicit_provider: str = "",
|
||||
runtime_base_url: str = "",
|
||||
):
|
||||
"""Run switch_model with OpenCode mocks and return the result.
|
||||
|
||||
runtime_base_url defaults to current_base_url; tests can override it
|
||||
to simulate the credential resolver returning a base_url different
|
||||
from the session's current one.
|
||||
"""
|
||||
effective_runtime_base = runtime_base_url or current_base_url
|
||||
with (
|
||||
patch("hermes_cli.model_switch.resolve_alias", return_value=None),
|
||||
patch("hermes_cli.model_switch.list_provider_models", return_value=[]),
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "sk-opencode-fake",
|
||||
"base_url": effective_runtime_base,
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"hermes_cli.models.validate_requested_model",
|
||||
return_value=_MOCK_VALIDATION,
|
||||
),
|
||||
patch("hermes_cli.model_switch.get_model_info", return_value=None),
|
||||
patch("hermes_cli.model_switch.get_model_capabilities", return_value=None),
|
||||
patch("hermes_cli.models.detect_provider_for_model", return_value=None),
|
||||
):
|
||||
return switch_model(
|
||||
raw_input=raw_input,
|
||||
current_provider=current_provider,
|
||||
current_model=current_model,
|
||||
current_base_url=current_base_url,
|
||||
current_api_key="sk-opencode-fake",
|
||||
explicit_provider=explicit_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenCodeGoV1Strip:
|
||||
"""OpenCode Go: ``/model minimax-*`` must strip /v1."""
|
||||
|
||||
def test_switch_to_minimax_m27_strips_v1(self):
|
||||
"""GLM-5 → MiniMax-M2.7: base_url loses trailing /v1."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="minimax-m2.7",
|
||||
current_provider="opencode-go",
|
||||
current_model="glm-5",
|
||||
current_base_url="https://opencode.ai/zen/go/v1",
|
||||
)
|
||||
|
||||
assert result.success, f"switch_model failed: {result.error_message}"
|
||||
assert result.api_mode == "anthropic_messages"
|
||||
assert result.base_url == "https://opencode.ai/zen/go", (
|
||||
f"Expected /v1 stripped for anthropic_messages; got {result.base_url}"
|
||||
)
|
||||
|
||||
def test_switch_to_minimax_m25_strips_v1(self):
|
||||
"""Same behavior for M2.5."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="minimax-m2.5",
|
||||
current_provider="opencode-go",
|
||||
current_model="kimi-k2.5",
|
||||
current_base_url="https://opencode.ai/zen/go/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "anthropic_messages"
|
||||
assert result.base_url == "https://opencode.ai/zen/go"
|
||||
|
||||
def test_switch_to_glm_leaves_v1_intact(self):
|
||||
"""OpenAI-compatible models (GLM, Kimi, MiMo) keep /v1."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="glm-5.1",
|
||||
current_provider="opencode-go",
|
||||
current_model="minimax-m2.7",
|
||||
current_base_url="https://opencode.ai/zen/go", # stripped from previous Anthropic model
|
||||
runtime_base_url="https://opencode.ai/zen/go/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "chat_completions"
|
||||
assert result.base_url == "https://opencode.ai/zen/go/v1", (
|
||||
f"chat_completions must keep /v1; got {result.base_url}"
|
||||
)
|
||||
|
||||
def test_switch_to_kimi_leaves_v1_intact(self):
|
||||
result = _run_opencode_switch(
|
||||
raw_input="kimi-k2.5",
|
||||
current_provider="opencode-go",
|
||||
current_model="glm-5",
|
||||
current_base_url="https://opencode.ai/zen/go/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "chat_completions"
|
||||
assert result.base_url == "https://opencode.ai/zen/go/v1"
|
||||
|
||||
def test_trailing_slash_also_stripped(self):
|
||||
"""``/v1/`` with trailing slash is also stripped cleanly."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="minimax-m2.7",
|
||||
current_provider="opencode-go",
|
||||
current_model="glm-5",
|
||||
current_base_url="https://opencode.ai/zen/go/v1/",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "anthropic_messages"
|
||||
assert result.base_url == "https://opencode.ai/zen/go"
|
||||
|
||||
|
||||
class TestOpenCodeZenV1Strip:
|
||||
"""OpenCode Zen: ``/model claude-*`` must strip /v1."""
|
||||
|
||||
def test_switch_to_claude_sonnet_strips_v1(self):
|
||||
"""Gemini → Claude on opencode-zen: /v1 stripped."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="claude-sonnet-4-6",
|
||||
current_provider="opencode-zen",
|
||||
current_model="gemini-3-flash",
|
||||
current_base_url="https://opencode.ai/zen/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "anthropic_messages"
|
||||
assert result.base_url == "https://opencode.ai/zen"
|
||||
|
||||
def test_switch_to_gemini_leaves_v1_intact(self):
|
||||
"""Gemini on opencode-zen stays on chat_completions with /v1."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="gemini-3-flash",
|
||||
current_provider="opencode-zen",
|
||||
current_model="claude-sonnet-4-6",
|
||||
current_base_url="https://opencode.ai/zen", # stripped from previous Claude
|
||||
runtime_base_url="https://opencode.ai/zen/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "chat_completions"
|
||||
assert result.base_url == "https://opencode.ai/zen/v1"
|
||||
|
||||
def test_switch_to_gpt_uses_codex_responses_keeps_v1(self):
|
||||
"""GPT on opencode-zen uses codex_responses api_mode — /v1 kept."""
|
||||
result = _run_opencode_switch(
|
||||
raw_input="gpt-5.4",
|
||||
current_provider="opencode-zen",
|
||||
current_model="claude-sonnet-4-6",
|
||||
current_base_url="https://opencode.ai/zen",
|
||||
runtime_base_url="https://opencode.ai/zen/v1",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.api_mode == "codex_responses"
|
||||
assert result.base_url == "https://opencode.ai/zen/v1"
|
||||
|
||||
|
||||
class TestAgentSwitchModelDefenseInDepth:
|
||||
"""run_agent.AIAgent.switch_model() also strips /v1 as defense-in-depth."""
|
||||
|
||||
def test_agent_switch_model_strips_v1_for_anthropic_messages(self):
|
||||
"""Even if a caller hands in a /v1 URL, the agent strips it."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Build a bare agent instance without running __init__; we only want
|
||||
# to exercise switch_model's base_url normalization logic.
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.model = "glm-5"
|
||||
agent.provider = "opencode-go"
|
||||
agent.base_url = "https://opencode.ai/zen/go/v1"
|
||||
agent.api_key = "sk-opencode-fake"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._client_kwargs = {}
|
||||
|
||||
# Intercept the expensive client rebuild — we only need to verify
|
||||
# that base_url was normalized before it reached the Anthropic
|
||||
# client factory.
|
||||
captured = {}
|
||||
|
||||
def _fake_build_anthropic_client(api_key, base_url):
|
||||
captured["api_key"] = api_key
|
||||
captured["base_url"] = base_url
|
||||
return object() # placeholder client — no real calls expected
|
||||
|
||||
# The downstream cache/plumbing touches a bunch of private state
|
||||
# that wasn't initialized above; we don't want to rebuild the full
|
||||
# runtime for this single assertion, so short-circuit after the
|
||||
# strip by raising inside the stubbed factory.
|
||||
class _Sentinel(Exception):
|
||||
pass
|
||||
|
||||
def _raise_after_capture(api_key, base_url):
|
||||
captured["api_key"] = api_key
|
||||
captured["base_url"] = base_url
|
||||
raise _Sentinel("strip verified")
|
||||
|
||||
with patch(
|
||||
"agent.anthropic_adapter.build_anthropic_client",
|
||||
side_effect=_raise_after_capture,
|
||||
), patch("agent.anthropic_adapter.resolve_anthropic_token", return_value=""), patch(
|
||||
"agent.anthropic_adapter._is_oauth_token", return_value=False
|
||||
):
|
||||
with pytest.raises(_Sentinel):
|
||||
agent.switch_model(
|
||||
new_model="minimax-m2.7",
|
||||
new_provider="opencode-go",
|
||||
api_key="sk-opencode-fake",
|
||||
base_url="https://opencode.ai/zen/go/v1",
|
||||
api_mode="anthropic_messages",
|
||||
)
|
||||
|
||||
assert captured.get("base_url") == "https://opencode.ai/zen/go", (
|
||||
f"agent.switch_model did not strip /v1; passed {captured.get('base_url')} "
|
||||
"to build_anthropic_client"
|
||||
)
|
||||
|
|
@ -466,3 +466,90 @@ def test_numeric_mcp_server_name_does_not_crash_sorted():
|
|||
|
||||
# sorted() must not raise TypeError
|
||||
sorted(enabled)
|
||||
|
||||
|
||||
# ─── Imagegen Backend Picker Wiring ────────────────────────────────────────
|
||||
|
||||
class TestImagegenBackendRegistry:
|
||||
"""IMAGEGEN_BACKENDS tags drive the model picker flow in tools_config."""
|
||||
|
||||
def test_fal_backend_registered(self):
|
||||
from hermes_cli.tools_config import IMAGEGEN_BACKENDS
|
||||
assert "fal" in IMAGEGEN_BACKENDS
|
||||
|
||||
def test_fal_catalog_loads_lazily(self):
|
||||
"""catalog_fn should defer import to avoid import cycles."""
|
||||
from hermes_cli.tools_config import IMAGEGEN_BACKENDS
|
||||
catalog, default = IMAGEGEN_BACKENDS["fal"]["catalog_fn"]()
|
||||
assert default == "fal-ai/flux-2/klein/9b"
|
||||
assert "fal-ai/flux-2/klein/9b" in catalog
|
||||
assert "fal-ai/flux-2-pro" in catalog
|
||||
|
||||
def test_image_gen_providers_tagged_with_fal_backend(self):
|
||||
"""Both Nous Subscription and FAL.ai providers must carry the
|
||||
imagegen_backend tag so _configure_provider fires the picker."""
|
||||
from hermes_cli.tools_config import TOOL_CATEGORIES
|
||||
providers = TOOL_CATEGORIES["image_gen"]["providers"]
|
||||
for p in providers:
|
||||
assert p.get("imagegen_backend") == "fal", (
|
||||
f"{p['name']} missing imagegen_backend tag"
|
||||
)
|
||||
|
||||
|
||||
class TestImagegenModelPicker:
|
||||
"""_configure_imagegen_model writes selection to config and respects
|
||||
curses fallback semantics (returns default when stdin isn't a TTY)."""
|
||||
|
||||
def test_picker_writes_chosen_model_to_config(self):
|
||||
from hermes_cli.tools_config import _configure_imagegen_model
|
||||
config = {}
|
||||
# Force _prompt_choice to pick index 1 (second-in-ordered-list).
|
||||
with patch("hermes_cli.tools_config._prompt_choice", return_value=1):
|
||||
_configure_imagegen_model("fal", config)
|
||||
# ordered[0] == current (default klein), ordered[1] == first non-default
|
||||
assert config["image_gen"]["model"] != "fal-ai/flux-2/klein/9b"
|
||||
assert config["image_gen"]["model"].startswith("fal-ai/")
|
||||
|
||||
def test_picker_with_gpt_image_does_not_prompt_quality(self):
|
||||
"""GPT-Image quality is pinned to medium in the tool's defaults —
|
||||
no follow-up prompt, no config write for quality_setting."""
|
||||
from hermes_cli.tools_config import (
|
||||
_configure_imagegen_model,
|
||||
IMAGEGEN_BACKENDS,
|
||||
)
|
||||
catalog, default_model = IMAGEGEN_BACKENDS["fal"]["catalog_fn"]()
|
||||
model_ids = list(catalog.keys())
|
||||
ordered = [default_model] + [m for m in model_ids if m != default_model]
|
||||
gpt_idx = ordered.index("fal-ai/gpt-image-1.5")
|
||||
|
||||
# Only ONE picker call is expected (for model) — not two (model + quality).
|
||||
call_count = {"n": 0}
|
||||
def fake_prompt(*a, **kw):
|
||||
call_count["n"] += 1
|
||||
return gpt_idx
|
||||
|
||||
config = {}
|
||||
with patch("hermes_cli.tools_config._prompt_choice", side_effect=fake_prompt):
|
||||
_configure_imagegen_model("fal", config)
|
||||
|
||||
assert call_count["n"] == 1, (
|
||||
f"Expected 1 picker call (model only), got {call_count['n']}"
|
||||
)
|
||||
assert config["image_gen"]["model"] == "fal-ai/gpt-image-1.5"
|
||||
assert "quality_setting" not in config["image_gen"]
|
||||
|
||||
def test_picker_no_op_for_unknown_backend(self):
|
||||
from hermes_cli.tools_config import _configure_imagegen_model
|
||||
config = {}
|
||||
_configure_imagegen_model("nonexistent-backend", config)
|
||||
assert config == {} # untouched
|
||||
|
||||
def test_picker_repairs_corrupt_config_section(self):
|
||||
"""When image_gen is a non-dict (user-edit YAML), the picker should
|
||||
replace it with a fresh dict rather than crash."""
|
||||
from hermes_cli.tools_config import _configure_imagegen_model
|
||||
config = {"image_gen": "some-garbage-string"}
|
||||
with patch("hermes_cli.tools_config._prompt_choice", return_value=0):
|
||||
_configure_imagegen_model("fal", config)
|
||||
assert isinstance(config["image_gen"], dict)
|
||||
assert config["image_gen"]["model"] == "fal-ai/flux-2/klein/9b"
|
||||
|
|
|
|||
473
tests/tools/test_file_sync_back.py
Normal file
473
tests/tools/test_file_sync_back.py
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||
|
||||
import fcntl
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
_sha256_file,
|
||||
_SYNC_BACK_BACKOFF,
|
||||
_SYNC_BACK_MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tar(files: dict[str, bytes], dest: Path):
|
||||
"""Write a tar archive containing the given arcname->content pairs."""
|
||||
with tarfile.open(dest, "w") as tar:
|
||||
for arcname, content in files.items():
|
||||
info = tarfile.TarInfo(name=arcname)
|
||||
info.size = len(content)
|
||||
tar.addfile(info, io.BytesIO(content))
|
||||
|
||||
|
||||
def _make_download_fn(files: dict[str, bytes]):
|
||||
"""Return a bulk_download_fn that writes a tar of the given files."""
|
||||
def download(dest: Path):
|
||||
_make_tar(files, dest)
|
||||
return download
|
||||
|
||||
|
||||
def _sha256_bytes(data: bytes) -> str:
|
||||
"""Compute SHA-256 hex digest of raw bytes (for test convenience)."""
|
||||
import hashlib
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def _write_file(path: Path, content: bytes) -> str:
|
||||
"""Write bytes to *path*, creating parents, and return the string path."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
def _make_manager(
    tmp_path: Path,
    file_mapping: list[tuple[str, str]] | None = None,
    bulk_download_fn=None,
    seed_pushed_state: bool = True,
) -> FileSyncManager:
    """Create a FileSyncManager wired for testing.

    *file_mapping* is a list of (host_path, remote_path) tuples that
    ``get_files_fn`` returns; *None* means an empty mapping.

    When *seed_pushed_state* is True (default), ``_pushed_hashes`` is
    populated from the mapping so sync_back doesn't early-return on the
    "nothing previously pushed" guard. Pass False to exercise the noop path.
    """
    entries = file_mapping or []
    manager = FileSyncManager(
        get_files_fn=lambda: entries,
        upload_fn=MagicMock(),
        delete_fn=MagicMock(),
        bulk_download_fn=bulk_download_fn,
    )
    if seed_pushed_state:
        # Seed with real hashes where the host file exists, otherwise an
        # all-zero placeholder; an empty mapping gets a single sentinel so
        # the "nothing previously pushed" guard still doesn't trip.
        for host_path, remote_path in entries:
            if os.path.exists(host_path):
                seeded = _sha256_file(host_path)
            else:
                seeded = "0" * 64
            manager._pushed_hashes[remote_path] = seeded
        if not manager._pushed_hashes:
            manager._pushed_hashes["/_sentinel"] = "0" * 64
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncBackNoop:
    """sync_back() is a no-op when there is no download function."""

    def test_sync_back_noop_without_download_fn(self, tmp_path):
        manager = _make_manager(tmp_path, bulk_download_fn=None)
        # The call must simply return; absence of an exception is the assertion.
        manager.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestSyncBackNoChanges:
    """When all remote files match pushed hashes, nothing is applied."""

    def test_sync_back_no_changes(self, tmp_path):
        content = b'{"key": "val"}'
        host_file = tmp_path / "host" / "cred.json"
        _write_file(host_file, content)

        remote_path = "/root/.hermes/cred.json"
        mapping = [(str(host_file), remote_path)]

        # The remote tar holds exactly what was pushed, so no change applies.
        mgr = _make_manager(
            tmp_path,
            file_mapping=mapping,
            bulk_download_fn=_make_download_fn({"root/.hermes/cred.json": content}),
        )
        # Record the push with the matching content hash.
        mgr._pushed_hashes[remote_path] = _sha256_bytes(content)

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # Host bytes must be untouched.
        assert host_file.read_bytes() == content
|
||||
|
||||
|
||||
class TestSyncBackAppliesChanged:
    """Remote file differs from pushed version -- gets copied to host."""

    def test_sync_back_applies_changed_file(self, tmp_path):
        pushed = b"print('v1')"
        host_file = tmp_path / "host" / "skill.py"
        _write_file(host_file, pushed)

        remote_path = "/root/.hermes/skill.py"
        edited = b"print('v2 - edited on remote')"

        mgr = _make_manager(
            tmp_path,
            file_mapping=[(str(host_file), remote_path)],
            bulk_download_fn=_make_download_fn({"root/.hermes/skill.py": edited}),
        )
        # Recorded push hash matches the ORIGINAL bytes, so the differing
        # remote content counts as a remote-side edit.
        mgr._pushed_hashes[remote_path] = _sha256_bytes(pushed)

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        assert host_file.read_bytes() == edited
|
||||
|
||||
|
||||
class TestSyncBackNewRemoteFile:
    """File created on remote (not in _pushed_hashes) is applied via _infer_host_path."""

    def test_sync_back_detects_new_remote_file(self, tmp_path):
        # An existing mapped file gives _infer_host_path a directory prefix to use.
        existing_host = tmp_path / "host" / "skills" / "existing.py"
        _write_file(existing_host, b"existing")

        created = b"# brand new skill created on remote"
        mgr = _make_manager(
            tmp_path,
            file_mapping=[(str(existing_host), "/root/.hermes/skills/existing.py")],
            bulk_download_fn=_make_download_fn(
                {"root/.hermes/skills/new_skill.py": created}
            ),
        )
        # Intentionally no _pushed_hashes entry for the new file.

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # The host location should have been inferred and the file written.
        inferred = tmp_path / "host" / "skills" / "new_skill.py"
        assert inferred.exists()
        assert inferred.read_bytes() == created
|
||||
|
||||
|
||||
class TestSyncBackConflict:
    """Host AND remote both changed since push -- warning logged, remote wins."""

    def test_sync_back_conflict_warns(self, tmp_path, caplog):
        pushed = b'{"v": 1}'
        host_file = tmp_path / "host" / "config.json"
        _write_file(host_file, pushed)

        remote_path = "/root/.hermes/config.json"
        mapping = [(str(host_file), remote_path)]

        # Diverge both sides after the (recorded) push.
        host_file.write_bytes(b'{"v": 2, "host-edit": true}')
        remote_content = b'{"v": 3, "remote-edit": true}'

        mgr = _make_manager(
            tmp_path,
            file_mapping=mapping,
            bulk_download_fn=_make_download_fn(
                {"root/.hermes/config.json": remote_content}
            ),
        )
        mgr._pushed_hashes[remote_path] = _sha256_bytes(pushed)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # A conflict warning was emitted...
        assert any("conflict" in r.message.lower() for r in caplog.records)
        # ...and last-write-wins favors the remote version.
        assert host_file.read_bytes() == remote_content
|
||||
|
||||
|
||||
class TestSyncBackRetries:
    """Retry behaviour with exponential backoff."""

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
        # Counts download attempts: the closure fails twice, then succeeds.
        call_count = 0

        def flaky_download(dest: Path):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RuntimeError(f"network error #{call_count}")
            # Third attempt succeeds -- write a valid (empty) tar
            _make_tar({}, dest)

        mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download)
        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        assert call_count == 3
        # Sleep called twice (between attempt 1->2 and 2->3)
        assert mock_sleep.call_count == 2
        # Backoff delays must follow the module's configured schedule.
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
        # Every attempt fails, so the retry budget is fully consumed.
        def always_fail(dest: Path):
            raise RuntimeError("persistent failure")

        mgr = _make_manager(tmp_path, bulk_download_fn=always_fail)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            # Should NOT raise -- failures are logged, not propagated
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # All retries were attempted
        assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1

        # Final "all attempts failed" warning was logged
        assert any("all" in r.message.lower() and "failed" in r.message.lower() for r in caplog.records)
|
||||
|
||||
|
||||
class TestPushedHashesPopulated:
    """_pushed_hashes is populated during sync() and cleared on delete."""

    def test_pushed_hashes_populated_on_sync(self, tmp_path):
        # A single host file mapped into the remote ~/.hermes tree.
        host_file = tmp_path / "data.txt"
        host_file.write_bytes(b"hello world")

        remote_path = "/root/.hermes/data.txt"
        mapping = [(str(host_file), remote_path)]

        # Real FileSyncManager (not the _make_manager helper) so that
        # _pushed_hashes starts empty and is filled only by sync().
        mgr = FileSyncManager(
            get_files_fn=lambda: mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        mgr.sync(force=True)

        # sync() must record the content hash of what it pushed.
        assert remote_path in mgr._pushed_hashes
        assert mgr._pushed_hashes[remote_path] == _sha256_file(str(host_file))

    def test_pushed_hashes_cleared_on_delete(self, tmp_path):
        host_file = tmp_path / "deleteme.txt"
        host_file.write_bytes(b"to be deleted")

        remote_path = "/root/.hermes/deleteme.txt"
        mapping = [(str(host_file), remote_path)]
        # Mutable alias: get_files_fn closes over this list so the test can
        # later simulate the file disappearing from the mapping.
        current_mapping = list(mapping)

        mgr = FileSyncManager(
            get_files_fn=lambda: current_mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        # Sync to populate hashes
        mgr.sync(force=True)
        assert remote_path in mgr._pushed_hashes

        # Remove the file from the mapping (simulates local deletion)
        os.unlink(str(host_file))
        current_mapping.clear()

        mgr.sync(force=True)

        # Hash should be cleaned up
        assert remote_path not in mgr._pushed_hashes
|
||||
|
||||
|
||||
class TestSyncBackFileLock:
    """Verify that fcntl.flock is used during sync-back."""

    @patch("tools.environments.file_sync.fcntl.flock")
    def test_sync_back_file_lock(self, mock_flock, tmp_path):
        mgr = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # At minimum one acquire (LOCK_EX) and one release (LOCK_UN).
        assert mock_flock.call_count >= 2
        ops = [recorded[0][1] for recorded in mock_flock.call_args_list]
        assert fcntl.LOCK_EX in ops
        assert fcntl.LOCK_UN in ops

    def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
        """On Windows (fcntl=None), sync_back should skip file locking."""
        mgr = _make_manager(tmp_path, bulk_download_fn=_make_download_fn({}))

        # Locking is skipped entirely, so no exception is expected.
        with patch("tools.environments.file_sync.fcntl", None):
            mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||
|
||||
|
||||
class TestInferHostPath:
    """Edge cases for _infer_host_path prefix matching."""

    def _mgr_and_mapping(self, tmp_path):
        # One mapped file under host/skills gives a known directory prefix.
        host_file = tmp_path / "host" / "skills" / "a.py"
        _write_file(host_file, b"content")
        mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
        return _make_manager(tmp_path, file_mapping=mapping), mapping

    def test_infer_no_matching_prefix(self, tmp_path):
        """Remote path in unmapped directory should return None."""
        mgr, mapping = self._mgr_and_mapping(tmp_path)
        inferred = mgr._infer_host_path(
            "/root/.hermes/cache/new.json",
            file_mapping=mapping,
        )
        assert inferred is None

    def test_infer_partial_prefix_no_false_match(self, tmp_path):
        """A partial prefix like /root/.hermes/sk should NOT match /root/.hermes/skills/."""
        mgr, mapping = self._mgr_and_mapping(tmp_path)
        # skillsXtra shares leading characters with "skills" but names a
        # different directory, so it must not be treated as mapped.
        inferred = mgr._infer_host_path(
            "/root/.hermes/skillsXtra/b.py",
            file_mapping=mapping,
        )
        assert inferred is None

    def test_infer_matching_prefix(self, tmp_path):
        """A file in a mapped directory should be correctly inferred."""
        mgr, mapping = self._mgr_and_mapping(tmp_path)
        inferred = mgr._infer_host_path(
            "/root/.hermes/skills/b.py",
            file_mapping=mapping,
        )
        assert inferred == str(tmp_path / "host" / "skills" / "b.py")
|
||||
|
||||
|
||||
class TestSyncBackSIGINT:
    """SIGINT deferral during sync-back.

    sync_back defers Ctrl-C on the main thread by swapping the SIGINT
    handler for the duration of the operation; off the main thread the
    signal module cannot install handlers, so it must skip that entirely.
    """

    def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
        """On the main thread, SIGINT handler should be swapped during sync."""
        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        # Wrap getsignal with the real implementation so sync_back still
        # observes the genuine current handler while we count the calls.
        # (Removed the unused `handlers_seen` accumulator.)
        original_getsignal = signal.getsignal

        with patch("tools.environments.file_sync.signal.getsignal",
                   side_effect=original_getsignal) as mock_get, \
             patch("tools.environments.file_sync.signal.signal") as mock_set:
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # signal.getsignal was called to save the original handler
        assert mock_get.called
        # signal.signal was called at least twice: install defer, restore original
        assert mock_set.call_count >= 2

    def test_sync_back_skips_signal_on_worker_thread(self, tmp_path):
        """From a non-main thread, signal.signal should NOT be called."""
        import threading

        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        # Records every signal.signal invocation; must stay empty off-main-thread.
        signal_called = []

        def tracking_signal(*args):
            signal_called.append(args)

        with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal):
            # Run sync_back from a worker thread and capture any exception.
            exc = []

            def run():
                try:
                    mgr.sync_back(hermes_home=tmp_path / ".hermes")
                except Exception as e:
                    exc.append(e)

            t = threading.Thread(target=run)
            t.start()
            t.join(timeout=10)

        assert not exc, f"sync_back raised: {exc}"
        # signal.signal should NOT have been called from the worker thread
        assert len(signal_called) == 0
|
||||
|
||||
|
||||
class TestSyncBackSizeCap:
    """The size cap refuses to extract tars above the configured limit."""

    def _manager(self, tmp_path):
        # One mapped host file plus a remote tar carrying a changed version.
        host_path = _write_file(tmp_path / "host_skill.md", b"original")
        download = _make_download_fn({"root/.hermes/skill.md": b"remote_version"})
        mgr = _make_manager(
            tmp_path,
            file_mapping=[(host_path, "/root/.hermes/skill.md")],
            bulk_download_fn=download,
        )
        return mgr, host_path

    def test_sync_back_refuses_oversized_tar(self, tmp_path, caplog):
        """A tar larger than _SYNC_BACK_MAX_BYTES should be skipped with a warning."""
        mgr, host_path = self._manager(tmp_path)

        # Patch the cap to 1 byte so the tiny tar trips it -- no need to
        # materialize a multi-GiB archive in the test.
        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            with patch("tools.environments.file_sync._SYNC_BACK_MAX_BYTES", 1):
                mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # Extraction was refused: host bytes untouched, cap mentioned in logs.
        assert Path(host_path).read_bytes() == b"original"
        assert any("cap" in r.message for r in caplog.records)

    def test_sync_back_applies_when_under_cap(self, tmp_path):
        """A tar under the cap should extract normally (sanity check)."""
        mgr, host_path = self._manager(tmp_path)

        # Default cap (2 GiB) dwarfs the tiny tar, so extraction proceeds.
        mgr.sync_back(hermes_home=tmp_path / ".hermes")
        assert Path(host_path).read_bytes() == b"remote_version"
|
||||
450
tests/tools/test_image_generation.py
Normal file
450
tests/tools/test_image_generation.py
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
"""Tests for tools/image_generation_tool.py — FAL multi-model support.
|
||||
|
||||
Covers the pure logic of the new wrapper: catalog integrity, the three size
|
||||
families (image_size_preset / aspect_ratio / gpt_literal), the supports
|
||||
whitelist, default merging, GPT quality override, and model resolution
|
||||
fallback. Does NOT exercise fal_client submission — that's covered by
|
||||
tests/tools/test_managed_media_gateways.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def image_tool():
    """Reload tools.image_generation_tool so every test sees a fresh module."""
    import importlib

    import tools.image_generation_tool as module
    return importlib.reload(module)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalog integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFalCatalog:
    """Shape/consistency checks over every FAL_MODELS catalog entry."""

    def test_default_model_is_klein(self, image_tool):
        assert image_tool.DEFAULT_MODEL == "fal-ai/flux-2/klein/9b"

    def test_default_model_in_catalog(self, image_tool):
        assert image_tool.DEFAULT_MODEL in image_tool.FAL_MODELS

    def test_all_entries_have_required_keys(self, image_tool):
        required = {
            "display", "speed", "strengths", "price",
            "size_style", "sizes", "defaults", "supports", "upscale",
        }
        for mid, meta in image_tool.FAL_MODELS.items():
            missing = required.difference(meta)
            assert not missing, f"{mid} missing required keys: {missing}"

    def test_size_style_is_valid(self, image_tool):
        valid = {"image_size_preset", "aspect_ratio", "gpt_literal"}
        for mid, meta in image_tool.FAL_MODELS.items():
            assert meta["size_style"] in valid, (
                f"{mid} has invalid size_style: {meta['size_style']}"
            )

    def test_sizes_cover_all_aspect_ratios(self, image_tool):
        needed = {"landscape", "square", "portrait"}
        for mid, meta in image_tool.FAL_MODELS.items():
            assert set(meta["sizes"].keys()) >= needed, (
                f"{mid} missing a required aspect_ratio key"
            )

    def test_supports_is_a_set(self, image_tool):
        for mid, meta in image_tool.FAL_MODELS.items():
            assert isinstance(meta["supports"], set), (
                f"{mid}.supports must be a set, got {type(meta['supports'])}"
            )

    def test_prompt_is_always_supported(self, image_tool):
        for mid, meta in image_tool.FAL_MODELS.items():
            assert "prompt" in meta["supports"], (
                f"{mid} must support 'prompt'"
            )

    def test_only_flux2_pro_upscales_by_default(self, image_tool):
        """Upscaling should default to False for all new models to preserve
        the <1s / fast-render value prop. Only flux-2-pro stays True for
        backward-compat with the previous default."""
        for mid, meta in image_tool.FAL_MODELS.items():
            if mid == "fal-ai/flux-2-pro":
                assert meta["upscale"] is True, (
                    "flux-2-pro should keep upscale=True for backward-compat"
                )
            else:
                assert meta["upscale"] is False, (
                    f"{mid} should default to upscale=False"
                )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Payload building — three size families
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageSizePresetFamily:
    """Flux, z-image, qwen, recraft, ideogram all use preset enum sizes."""

    def _payload(self, image_tool, aspect):
        # All assertions target the default klein model.
        return image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", aspect)

    def test_klein_landscape_uses_preset(self, image_tool):
        payload = self._payload(image_tool, "landscape")
        assert payload["image_size"] == "landscape_16_9"
        assert "aspect_ratio" not in payload

    def test_klein_square_uses_preset(self, image_tool):
        assert self._payload(image_tool, "square")["image_size"] == "square_hd"

    def test_klein_portrait_uses_preset(self, image_tool):
        assert self._payload(image_tool, "portrait")["image_size"] == "portrait_16_9"
|
||||
|
||||
|
||||
class TestAspectRatioFamily:
    """Nano-banana uses aspect_ratio enum, NOT image_size."""

    def _payload(self, image_tool, aspect):
        return image_tool._build_fal_payload("fal-ai/nano-banana", "hello", aspect)

    def test_nano_banana_landscape_uses_aspect_ratio(self, image_tool):
        payload = self._payload(image_tool, "landscape")
        assert payload["aspect_ratio"] == "16:9"
        assert "image_size" not in payload

    def test_nano_banana_square_uses_aspect_ratio(self, image_tool):
        assert self._payload(image_tool, "square")["aspect_ratio"] == "1:1"

    def test_nano_banana_portrait_uses_aspect_ratio(self, image_tool):
        assert self._payload(image_tool, "portrait")["aspect_ratio"] == "9:16"
|
||||
|
||||
|
||||
class TestGptLiteralFamily:
    """GPT-Image 1.5 uses literal size strings."""

    def _size(self, image_tool, aspect):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", aspect)
        return payload["image_size"]

    def test_gpt_landscape_is_literal(self, image_tool):
        assert self._size(image_tool, "landscape") == "1536x1024"

    def test_gpt_square_is_literal(self, image_tool):
        assert self._size(image_tool, "square") == "1024x1024"

    def test_gpt_portrait_is_literal(self, image_tool):
        assert self._size(image_tool, "portrait") == "1024x1536"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supports whitelist — the main safety property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupportsFilter:
    """No model should receive keys outside its `supports` set."""

    def test_payload_keys_are_subset_of_supports_for_all_models(self, image_tool):
        for mid, meta in image_tool.FAL_MODELS.items():
            payload = image_tool._build_fal_payload(mid, "test", "landscape", seed=42)
            unsupported = set(payload) - meta["supports"]
            assert not unsupported, (
                f"{mid} payload has unsupported keys: {unsupported}"
            )

    def test_gpt_image_has_no_seed_even_if_passed(self, image_tool):
        # seed is absent from gpt-image-1.5's supports set; the filter must drop it.
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square", seed=42)
        assert "seed" not in payload

    def test_gpt_image_strips_unsupported_overrides(self, image_tool):
        payload = image_tool._build_fal_payload(
            "fal-ai/gpt-image-1.5", "hi", "square",
            overrides={"guidance_scale": 7.5, "num_inference_steps": 50},
        )
        assert not {"guidance_scale", "num_inference_steps"} & set(payload)

    def test_recraft_has_minimal_payload(self, image_tool):
        # Recraft's supports set is prompt/image_size/style only.
        payload = image_tool._build_fal_payload("fal-ai/recraft-v3", "hi", "landscape")
        assert set(payload) <= {"prompt", "image_size", "style"}

    def test_nano_banana_never_gets_image_size(self, image_tool):
        # Guards the classic bug of emitting both image_size and aspect_ratio.
        payload = image_tool._build_fal_payload("fal-ai/nano-banana", "hi", "landscape", seed=1)
        assert "image_size" not in payload
        assert payload["aspect_ratio"] == "16:9"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default merging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDefaults:
    """Model-level defaults should carry through unless overridden."""

    def test_klein_default_steps_is_4(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "square")
        assert payload["num_inference_steps"] == 4

    def test_flux_2_pro_default_steps_is_50(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2-pro", "hi", "square")
        assert payload["num_inference_steps"] == 50

    def test_override_replaces_default(self, image_tool):
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro", "hi", "square",
            overrides={"num_inference_steps": 25},
        )
        assert payload["num_inference_steps"] == 25

    def test_none_override_does_not_replace_default(self, image_tool):
        """None values from caller should be ignored (use default)."""
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro", "hi", "square",
            overrides={"num_inference_steps": None},
        )
        assert payload["num_inference_steps"] == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPT-Image quality is pinned to medium (not user-configurable)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGptQualityPinnedToMedium:
    """GPT-Image quality is baked into the FAL_MODELS defaults at 'medium'
    and cannot be overridden via config. Pinning keeps Nous Portal billing
    predictable across all users."""

    def test_gpt_payload_always_has_medium_quality(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_config_quality_setting_is_ignored(self, image_tool):
        """Even if a user manually edits config.yaml and adds quality_setting,
        the payload must still use medium. No code path reads that field."""
        with patch("hermes_cli.config.load_config",
                   return_value={"image_gen": {"quality_setting": "high"}}):
            payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_non_gpt_model_never_gets_quality(self, image_tool):
        """quality is only meaningful for gpt-image-1.5 — other models should
        never have it in their payload."""
        others = (m for m in image_tool.FAL_MODELS if m != "fal-ai/gpt-image-1.5")
        for mid in others:
            payload = image_tool._build_fal_payload(mid, "hi", "square")
            assert "quality" not in payload, (
                f"{mid} unexpectedly has 'quality' in payload"
            )

    def test_honors_quality_setting_flag_is_removed(self, image_tool):
        """The honors_quality_setting flag was the old override trigger.
        It must not be present on any model entry anymore."""
        for mid, meta in image_tool.FAL_MODELS.items():
            assert "honors_quality_setting" not in meta, (
                f"{mid} still has honors_quality_setting; "
                f"remove it — quality is pinned to medium"
            )

    def test_resolve_gpt_quality_function_is_gone(self, image_tool):
        """The _resolve_gpt_quality() helper was removed — quality is now
        a static default, not a runtime lookup."""
        assert not hasattr(image_tool, "_resolve_gpt_quality"), (
            "_resolve_gpt_quality should not exist — quality is pinned"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelResolution:
    """_resolve_fal_model(): config beats env var beats the built-in default."""

    def test_no_config_falls_back_to_default(self, image_tool):
        with patch("hermes_cli.config.load_config", return_value={}):
            mid, meta = image_tool._resolve_fal_model()
        assert mid == "fal-ai/flux-2/klein/9b"

    def test_valid_config_model_is_used(self, image_tool):
        with patch("hermes_cli.config.load_config",
                   return_value={"image_gen": {"model": "fal-ai/flux-2-pro"}}):
            mid, meta = image_tool._resolve_fal_model()
        assert mid == "fal-ai/flux-2-pro"
        # flux-2-pro keeps backward-compat upscaling
        assert meta["upscale"] is True

    def test_unknown_model_falls_back_to_default_with_warning(self, image_tool, caplog):
        with patch("hermes_cli.config.load_config",
                   return_value={"image_gen": {"model": "fal-ai/nonexistent-9000"}}):
            mid, _ = image_tool._resolve_fal_model()
        assert mid == "fal-ai/flux-2/klein/9b"

    def test_env_var_fallback_when_no_config(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        with patch("hermes_cli.config.load_config", return_value={}):
            mid, _ = image_tool._resolve_fal_model()
        assert mid == "fal-ai/z-image/turbo"

    def test_config_wins_over_env_var(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        with patch("hermes_cli.config.load_config",
                   return_value={"image_gen": {"model": "fal-ai/nano-banana"}}):
            mid, _ = image_tool._resolve_fal_model()
        assert mid == "fal-ai/nano-banana"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aspect ratio handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAspectRatioNormalization:
    """Aspect-ratio input is case-normalized; unknown/empty values fall back."""

    def _size_for(self, image_tool, aspect):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", aspect)
        return payload["image_size"]

    def test_invalid_aspect_defaults_to_landscape(self, image_tool):
        assert self._size_for(image_tool, "cinemascope") == "landscape_16_9"

    def test_uppercase_aspect_is_normalized(self, image_tool):
        assert self._size_for(image_tool, "PORTRAIT") == "portrait_16_9"

    def test_empty_aspect_defaults_to_landscape(self, image_tool):
        assert self._size_for(image_tool, "") == "landscape_16_9"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema + registry integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRegistryIntegration:
    """Agent-facing tool schema stays minimal and stable."""

    def test_schema_exposes_only_prompt_and_aspect_ratio_to_agent(self, image_tool):
        """The agent-facing schema must stay tight — model selection is a
        user-level config choice, not an agent-level arg."""
        params = image_tool.IMAGE_GENERATE_SCHEMA["parameters"]
        assert set(params["properties"]) == {"prompt", "aspect_ratio"}

    def test_aspect_ratio_enum_is_three_values(self, image_tool):
        params = image_tool.IMAGE_GENERATE_SCHEMA["parameters"]
        enum_values = params["properties"]["aspect_ratio"]["enum"]
        assert set(enum_values) == {"landscape", "square", "portrait"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed gateway 4xx translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MockResponse:
|
||||
def __init__(self, status_code: int):
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class _MockHttpxError(Exception):
    """Simulates httpx.HTTPStatusError, which exposes .response.status_code."""

    def __init__(self, status_code: int, message: str = "Bad Request"):
        # Keep Exception's normal message semantics while attaching a
        # response object shaped like httpx's.
        super().__init__(message)
        self.response = _MockResponse(status_code)
|
||||
|
||||
|
||||
class TestExtractHttpStatus:
    """Status-code extraction should work across exception shapes."""

    def test_extracts_from_response_attr(self, image_tool):
        # httpx-style: the code lives at exc.response.status_code.
        assert image_tool._extract_http_status(_MockHttpxError(403)) == 403

    def test_extracts_from_status_code_attr(self, image_tool):
        # Some clients attach status_code directly on the exception.
        exc = Exception("fail")
        exc.status_code = 404  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(exc) == 404

    def test_returns_none_for_non_http_exception(self, image_tool):
        # Plain exceptions carry no HTTP status at all.
        for plain_exc in (ValueError("nope"), RuntimeError("nope")):
            assert image_tool._extract_http_status(plain_exc) is None

    def test_response_attr_without_status_code_returns_none(self, image_tool):
        # A .response object lacking .status_code must not crash extraction.
        class OddResponse:
            pass

        exc = Exception("weird")
        exc.response = OddResponse()  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(exc) is None
|
||||
|
||||
|
||||
class TestManagedGatewayErrorTranslation:
    """4xx from the Nous managed gateway should be translated to a user-actionable message."""

    def test_4xx_translates_to_value_error_with_remediation(self, image_tool, monkeypatch):
        """403 from managed gateway → ValueError mentioning FAL_KEY + hermes tools."""
        from unittest.mock import MagicMock

        # Simulate: managed mode active, managed submit raises 4xx.
        gateway = MagicMock()
        gateway.gateway_origin = "https://fal-queue-gateway.example.com"
        gateway.nous_user_token = "test-token"
        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway", lambda: gateway)

        forbidden = _MockHttpxError(403, "Forbidden")
        managed_client = MagicMock()
        managed_client.submit.side_effect = forbidden
        monkeypatch.setattr(image_tool, "_get_managed_fal_client", lambda gw: managed_client)

        with pytest.raises(ValueError) as exc_info:
            image_tool._submit_fal_request("fal-ai/nano-banana", {"prompt": "x"})

        message = str(exc_info.value)
        # The remediation message names the model, the status, and the fix.
        for fragment in ("fal-ai/nano-banana", "403", "FAL_KEY", "hermes tools"):
            assert fragment in message
        # Original exception chained for debugging
        assert exc_info.value.__cause__ is forbidden

    def test_5xx_is_not_translated(self, image_tool, monkeypatch):
        """500s are real outages, not model-availability issues — don't rewrite them."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway", lambda: gateway)

        outage = _MockHttpxError(502, "Bad Gateway")
        managed_client = MagicMock()
        managed_client.submit.side_effect = outage
        monkeypatch.setattr(image_tool, "_get_managed_fal_client", lambda gw: managed_client)

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_direct_fal_errors_are_not_translated(self, image_tool, monkeypatch):
        """When user has direct FAL_KEY (managed gateway returns None), raw
        errors from fal_client bubble up unchanged — fal_client already
        provides reasonable error messages for direct usage."""
        from unittest.mock import MagicMock

        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway", lambda: None)

        direct_error = _MockHttpxError(403, "Forbidden")
        stub_fal_client = MagicMock()
        stub_fal_client.submit.side_effect = direct_error
        monkeypatch.setattr(image_tool, "fal_client", stub_fal_client)

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_non_http_exception_from_managed_bubbles_up(self, image_tool, monkeypatch):
        """Connection errors, timeouts, etc. from managed mode aren't 4xx —
        they should bubble up unchanged so callers can retry or diagnose."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        monkeypatch.setattr(image_tool, "_resolve_managed_fal_gateway", lambda: gateway)

        network_failure = ConnectionError("network down")
        managed_client = MagicMock()
        managed_client.submit.side_effect = network_failure
        monkeypatch.setattr(image_tool, "_get_managed_fal_client", lambda gw: managed_client)

        with pytest.raises(ConnectionError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})
|
||||
495
tests/tools/test_sync_back_backends.py
Normal file
495
tests/tools/test_sync_back_backends.py
Normal file
|
|
@ -0,0 +1,495 @@
|
|||
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments import modal as modal_env
|
||||
from tools.environments import daytona as daytona_env
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync."""
    # Neutralize everything that would touch a real host during __init__.
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

    class _StubSyncManager:
        """No-op FileSyncManager replacement for unit tests."""

        def __init__(self, **kwargs):
            pass

        def sync(self, **kw):
            pass

        def sync_back(self):
            pass

    monkeypatch.setattr(ssh_env, "FileSyncManager", _StubSyncManager)
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
# ── Modal helpers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_modal_env():
    """Create a minimal ModalEnvironment without calling __init__."""
    env = object.__new__(modal_env.ModalEnvironment)
    # Only the attributes the download/cleanup code paths touch.
    stub_attrs = {
        "_sandbox": MagicMock(),
        "_worker": MagicMock(),
        "_persistent": False,
        "_task_id": "test",
        "_sync_manager": None,
    }
    for name, value in stub_attrs.items():
        setattr(env, name, value)
    return env
|
||||
|
||||
|
||||
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
    """Wire sandbox.exec.aio to return mock tar output for download tests.

    Returns the exec_calls list for assertion.
    """
    exec_calls = []

    async def fake_exec(*args, **kwargs):
        # Record the positional command args; build a process whose stdout
        # yields *tar_bytes* and whose wait() reports *exit_code*.
        exec_calls.append(args)
        process = MagicMock()
        process.stdout = MagicMock()
        process.stdout.read = MagicMock()
        process.stdout.read.aio = AsyncMock(return_value=tar_bytes)
        process.wait = MagicMock()
        process.wait.aio = AsyncMock(return_value=exit_code)
        return process

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = fake_exec

    def run_in_fresh_loop(coro, **kwargs):
        # Drive the coroutine to completion on a throwaway event loop.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = run_in_fresh_loop
    return exec_calls
|
||||
|
||||
|
||||
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_daytona_env():
    """Create a minimal DaytonaEnvironment without calling __init__."""
    import threading

    env = object.__new__(daytona_env.DaytonaEnvironment)
    # Only the attributes the download/cleanup code paths touch.
    stub_attrs = {
        "_sandbox": MagicMock(),
        "_remote_home": "/root",
        "_sync_manager": None,
        "_lock": threading.Lock(),
        "_persistent": True,
        "_task_id": "test",
        "_daytona": MagicMock(),
    }
    for name, value in stub_attrs.items():
        setattr(env, name, value)
    return env
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SSH bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestSSHBulkDownload:
    """Unit tests for _ssh_bulk_download."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """subprocess.run command should include tar cf - over SSH."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        cmd_str = " ".join(cmd)
        assert "tar cf -" in cmd_str
        assert "-C /" in cmd_str
        assert "home/testuser/.hermes" in cmd_str
        assert "ssh" in cmd_str
        assert "testuser@example.com" in cmd_str

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """subprocess.run should receive stdout=open(dest, 'wb')."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)):
            ssh_mock_env._ssh_bulk_download(dest)

        # The implementation opens dest via `with open(dest, "wb") as f` and
        # passes it as stdout=f; by the time we get here the file handle is
        # closed, so the observable effect is simply that dest was created.
        # (A previous revision also fetched the stdout kwarg into a local it
        # never asserted on — dead code, removed.)
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """Non-zero returncode should raise RuntimeError."""
        dest = tmp_path / "backup.tar"

        failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")
        with patch.object(subprocess, "run", return_value=failed):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """The subprocess.run call should use a 120s timeout."""
        dest = tmp_path / "backup.tar"

        with patch.object(subprocess, "run", return_value=subprocess.CompletedProcess([], 0)) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        # call_args.kwargs is the canonical accessor (Python 3.8+); the old
        # `kwargs.get(...) or call_args[1].get(...)` form queried the same
        # mapping twice.
        assert mock_run.call_args.kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestSSHCleanup:
    """Verify SSH cleanup() calls sync_back() before closing ControlMaster."""

    @staticmethod
    def _make_env(monkeypatch, call_order):
        """Build an SSHEnvironment whose sync manager records sync_back calls.

        Shared setup previously duplicated verbatim across both tests.
        """
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
        return SSHEnvironment(host="h", user="u")

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """cleanup() should call sync_back() before SSH control socket teardown."""
        call_order = []
        env = self._make_env(monkeypatch, call_order)
        # Ensure control_socket does not exist so cleanup skips the SSH exit call
        env.control_socket = Path("/nonexistent/socket")

        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch):
        """sync_back() must run before the ControlMaster exit command."""
        import tempfile

        call_order = []
        env = self._make_env(monkeypatch, call_order)

        # Create a fake control socket so cleanup tries the SSH exit.
        # delete=False is required because cleanup() inspects the path after
        # the handle closes; the file is unlinked in the finally below (the
        # previous revision leaked it).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp:
            env.control_socket = Path(tmp.name)

        def mock_run(cmd, **kwargs):
            cmd_str = " ".join(cmd)
            if "-O" in cmd and "exit" in cmd_str:
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        try:
            with patch.object(subprocess, "run", side_effect=mock_run):
                env.cleanup()
        finally:
            Path(tmp.name).unlink(missing_ok=True)

        assert call_order.index("sync_back") < call_order.index("control_exit")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Modal bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestModalBulkDownload:
    """Unit tests for _modal_bulk_download."""

    def test_modal_bulk_download_command(self, tmp_path):
        """exec should be called with tar cf - -C /root/.hermes ."""
        env = _make_mock_modal_env()
        exec_calls = _wire_modal_download(env, tar_bytes=b"tar-content")

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert len(exec_calls) == 1
        args = exec_calls[0]
        # Shell invocation shape: bash -c "<script>"
        assert args[0] == "bash"
        assert args[1] == "-c"
        script = args[2]
        assert "tar cf -" in script
        assert "-C / root/.hermes" in script

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """Downloaded tar bytes should be written to the dest path."""
        env = _make_mock_modal_env()
        payload = b"some-tar-archive-bytes"
        _wire_modal_download(env, tar_bytes=payload)
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.exists()
        assert dest.read_bytes() == payload

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """If stdout returns str instead of bytes, it should be encoded."""
        env = _make_mock_modal_env()
        # Simulate Modal SDK returning str
        _wire_modal_download(env, tar_bytes="string-tar-data")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """Non-zero exit code should raise RuntimeError."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(tmp_path / "backup.tar")

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """run_coroutine should be called with timeout=120."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        seen_kwargs = {}
        inner_run = env._worker.run_coroutine

        def tracking_run(coro, **kwargs):
            # Capture the kwargs then delegate to the real (mocked) runner.
            seen_kwargs.update(kwargs)
            return inner_run(coro, **kwargs)

        env._worker.run_coroutine = tracking_run

        env._modal_bulk_download(tmp_path / "backup.tar")

        assert seen_kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestModalCleanup:
    """Verify Modal cleanup() calls sync_back() before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.terminate."""
        env = _make_mock_modal_env()

        call_order = []
        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr

        # Mock terminate to track call order
        async def mock_terminate():
            pass

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate

        def run_and_record(coro, **kw):
            call_order.append("terminate")
            # Run on a fresh loop and close it afterwards — the previous
            # lambda created a loop via asyncio.new_event_loop() and never
            # closed it, leaking the loop (ResourceWarning / fd leak).
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = run_and_record
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Daytona bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestDaytonaBulkDownload:
    """Unit tests for _daytona_bulk_download."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """exec and download_file should both be called."""
        env = _make_mock_daytona_env()
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        # exec called twice: tar creation + rm cleanup
        exec_mock = env._sandbox.process.exec
        assert exec_mock.call_count == 2

        tar_cmd = exec_mock.call_args_list[0][0][0]
        assert "tar cf" in tar_cmd
        # PID-suffixed temp path avoids collisions on sync_back retry
        for needle in ("/tmp/.hermes_sync.", ".tar", ".hermes"):
            assert needle in tar_cmd

        cleanup_cmd = exec_mock.call_args_list[1][0][0]
        assert "rm -f" in cleanup_cmd
        assert "/tmp/.hermes_sync." in cleanup_cmd

        # download_file called once with the same PID-suffixed path
        env._sandbox.fs.download_file.assert_called_once()
        remote_path, local_path = env._sandbox.fs.download_file.call_args[0]
        assert remote_path.startswith("/tmp/.hermes_sync.")
        assert remote_path.endswith(".tar")
        assert local_path == str(dest)

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command should use the env's _remote_home."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"

        env._daytona_bulk_download(tmp_path / "backup.tar")

        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "home/daytona/.hermes" in tar_cmd
|
||||
|
||||
|
||||
class TestDaytonaCleanup:
    """Verify Daytona cleanup() calls sync_back() before stop."""

    def test_daytona_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.stop()."""
        env = _make_mock_daytona_env()

        call_order = []
        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr
        env._sandbox.stop = lambda: call_order.append("stop")

        env.cleanup()

        # Both events fired, and sync_back preceded stop.
        assert call_order.count("sync_back") >= 1
        assert call_order.count("stop") >= 1
        assert call_order.index("sync_back") < call_order.index("stop")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestBulkDownloadWiring:
    """Verify each backend passes bulk_download_fn to FileSyncManager."""

    def test_ssh_passes_bulk_download_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        seen = {}

        class CaptureSyncManager:
            def __init__(self, **kwargs):
                seen.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)

        SSHEnvironment(host="h", user="u")

        assert "bulk_download_fn" in seen
        assert callable(seen["bulk_download_fn"])

    def test_modal_passes_bulk_download_fn(self, monkeypatch):
        """ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
        seen = {}

        def capture_fsm(**kwargs):
            seen.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
            bulk_download_fn=env._modal_bulk_download,
        )

        assert "bulk_download_fn" in seen
        assert callable(seen["bulk_download_fn"])

    def test_daytona_passes_bulk_download_fn(self, monkeypatch):
        """DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
        import threading

        seen = {}

        def capture_fsm(**kwargs):
            seen.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)

        env = object.__new__(daytona_env.DaytonaEnvironment)
        env._sandbox = MagicMock()
        env._remote_home = "/root"
        env._lock = threading.Lock()
        env._persistent = True
        env._task_id = "test"
        env._daytona = MagicMock()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = daytona_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
            upload_fn=env._daytona_upload,
            delete_fn=env._daytona_delete,
            bulk_upload_fn=env._daytona_bulk_upload,
            bulk_download_fn=env._daytona_bulk_download,
        )

        assert "bulk_download_fn" in seen
        assert callable(seen["bulk_download_fn"])
|
||||
|
|
@ -7,6 +7,7 @@ and resumed on next creation, preserving the filesystem across sessions.
|
|||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shlex
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
|
@ -134,6 +135,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
upload_fn=self._daytona_upload,
|
||||
delete_fn=self._daytona_delete,
|
||||
bulk_upload_fn=self._daytona_bulk_upload,
|
||||
bulk_download_fn=self._daytona_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
|
@ -166,6 +168,22 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
]
|
||||
self._sandbox.fs.upload_files(uploads)
|
||||
|
||||
def _daytona_bulk_download(self, dest: Path) -> None:
    """Download remote .hermes/ as a tar archive."""
    rel_base = f"{self._remote_home}/.hermes".lstrip("/")
    # PID-suffixed remote temp path avoids collisions if sync_back fires
    # concurrently for the same sandbox (e.g. retry after partial failure).
    remote_tar = f"/tmp/.hermes_sync.{os.getpid()}.tar"
    tar_cmd = f"tar cf {shlex.quote(remote_tar)} -C / {shlex.quote(rel_base)}"
    self._sandbox.process.exec(tar_cmd)
    self._sandbox.fs.download_file(remote_tar, str(dest))
    # Clean up remote temp file
    try:
        self._sandbox.process.exec(f"rm -f {shlex.quote(remote_tar)}")
    except Exception:
        pass  # best-effort cleanup
|
||||
|
||||
def _daytona_delete(self, remote_paths: list[str]) -> None:
    """Batch-delete remote files via SDK exec."""
    rm_cmd = quoted_rm_command(remote_paths)
    self._sandbox.process.exec(rm_cmd)
|
||||
|
|
@ -216,6 +234,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
with self._lock:
|
||||
if self._sandbox is None:
|
||||
return
|
||||
|
||||
# Sync remote changes back to host before teardown. Running
|
||||
# inside the lock (and after the _sandbox is None guard) avoids
|
||||
# firing sync_back on an already-cleaned-up env, which would
|
||||
# trigger a 3-attempt retry storm against a nil sandbox.
|
||||
if self._sync_manager:
|
||||
logger.info("Daytona: syncing files from sandbox...")
|
||||
try:
|
||||
self._sync_manager.sync_back()
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: sync_back failed: %s", e)
|
||||
|
||||
try:
|
||||
if self._persistent:
|
||||
self._sandbox.stop()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,25 @@ and Daytona. Docker and Singularity use bind mounts (live host FS
|
|||
view) and don't need this.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
fcntl = None # Windows — file locking skipped
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import _file_mtime_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -23,6 +35,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"
|
|||
# Transport callbacks provided by each backend
|
||||
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
||||
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
||||
BulkDownloadFn = Callable[[Path], None] # (dest_tar_path) -> writes tar archive, raises on failure
|
||||
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
||||
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
||||
|
||||
|
|
@ -71,6 +84,20 @@ def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
|||
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||
|
||||
|
||||
def _sha256_file(path: str) -> str:
|
||||
"""Return hex SHA-256 digest of a file."""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
_SYNC_BACK_MAX_RETRIES = 3
|
||||
_SYNC_BACK_BACKOFF = (2, 4, 8) # seconds between retries
|
||||
_SYNC_BACK_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB — refuse to extract larger tars
|
||||
|
||||
|
||||
class FileSyncManager:
|
||||
"""Tracks local file changes and syncs to a remote environment.
|
||||
|
||||
|
|
@ -89,12 +116,15 @@ class FileSyncManager:
|
|||
delete_fn: DeleteFn,
|
||||
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
||||
bulk_upload_fn: BulkUploadFn | None = None,
|
||||
bulk_download_fn: BulkDownloadFn | None = None,
|
||||
):
|
||||
self._get_files_fn = get_files_fn
|
||||
self._upload_fn = upload_fn
|
||||
self._bulk_upload_fn = bulk_upload_fn
|
||||
self._bulk_download_fn = bulk_download_fn
|
||||
self._delete_fn = delete_fn
|
||||
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
||||
self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest
|
||||
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
||||
self._sync_interval = sync_interval
|
||||
|
||||
|
|
@ -136,6 +166,7 @@ class FileSyncManager:
|
|||
|
||||
# Snapshot for rollback (only when there's work to do)
|
||||
prev_files = dict(self._synced_files)
|
||||
prev_hashes = dict(self._pushed_hashes)
|
||||
|
||||
if to_upload:
|
||||
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
||||
|
|
@ -156,13 +187,207 @@ class FileSyncManager:
|
|||
logger.debug("file_sync: deleted %s", to_delete)
|
||||
|
||||
# --- Commit (all succeeded) ---
|
||||
for host_path, remote_path in to_upload:
|
||||
self._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||
|
||||
for p in to_delete:
|
||||
new_files.pop(p, None)
|
||||
self._pushed_hashes.pop(p, None)
|
||||
|
||||
self._synced_files = new_files
|
||||
self._last_sync_time = time.monotonic()
|
||||
|
||||
except Exception as exc:
|
||||
self._synced_files = prev_files
|
||||
self._pushed_hashes = prev_hashes
|
||||
self._last_sync_time = time.monotonic()
|
||||
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync-back: pull remote changes to host on teardown
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def sync_back(self, hermes_home: Path | None = None) -> None:
|
||||
"""Pull remote changes back to the host filesystem.
|
||||
|
||||
Downloads the remote ``.hermes/`` directory as a tar archive,
|
||||
unpacks it, and applies only files that differ from what was
|
||||
originally pushed (based on SHA-256 content hashes).
|
||||
|
||||
Protected against SIGINT (defers the signal until complete) and
|
||||
serialized across concurrent gateway sandboxes via file lock.
|
||||
"""
|
||||
if self._bulk_download_fn is None:
|
||||
return
|
||||
|
||||
# Nothing was ever committed through this manager — the initial
|
||||
# push failed or never ran. Skip sync_back to avoid retry storms
|
||||
# against an uninitialized remote .hermes/ directory.
|
||||
if not self._pushed_hashes and not self._synced_files:
|
||||
logger.debug("sync_back: no prior push state — skipping")
|
||||
return
|
||||
|
||||
lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_SYNC_BACK_MAX_RETRIES):
|
||||
try:
|
||||
self._sync_back_once(lock_path)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < _SYNC_BACK_MAX_RETRIES - 1:
|
||||
delay = _SYNC_BACK_BACKOFF[attempt]
|
||||
logger.warning(
|
||||
"sync_back: attempt %d failed (%s), retrying in %ds",
|
||||
attempt + 1, exc, delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)
|
||||
|
||||
def _sync_back_once(self, lock_path: Path) -> None:
|
||||
"""Single sync-back attempt with SIGINT protection and file lock."""
|
||||
# signal.signal() only works from the main thread. In gateway
|
||||
# contexts cleanup() may run from a worker thread — skip SIGINT
|
||||
# deferral there rather than crashing.
|
||||
on_main_thread = threading.current_thread() is threading.main_thread()
|
||||
|
||||
deferred_sigint: list[object] = []
|
||||
original_handler = None
|
||||
if on_main_thread:
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def _defer_sigint(signum, frame):
|
||||
deferred_sigint.append((signum, frame))
|
||||
logger.debug("sync_back: SIGINT deferred until sync completes")
|
||||
|
||||
signal.signal(signal.SIGINT, _defer_sigint)
|
||||
try:
|
||||
self._sync_back_locked(lock_path)
|
||||
finally:
|
||||
if on_main_thread and original_handler is not None:
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
if deferred_sigint:
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
def _sync_back_locked(self, lock_path: Path) -> None:
|
||||
"""Sync-back under file lock (serializes concurrent gateways)."""
|
||||
if fcntl is None:
|
||||
# Windows: no flock — run without serialization
|
||||
self._sync_back_impl()
|
||||
return
|
||||
lock_fd = open(lock_path, "w")
|
||||
try:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||
self._sync_back_impl()
|
||||
finally:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
lock_fd.close()
|
||||
|
||||
def _sync_back_impl(self) -> None:
|
||||
"""Download, diff, and apply remote changes to host."""
|
||||
if self._bulk_download_fn is None:
|
||||
raise RuntimeError("_sync_back_impl called without bulk_download_fn")
|
||||
|
||||
# Cache file mapping once to avoid O(n*m) from repeated iteration
|
||||
try:
|
||||
file_mapping = list(self._get_files_fn())
|
||||
except Exception:
|
||||
file_mapping = []
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
|
||||
self._bulk_download_fn(Path(tf.name))
|
||||
|
||||
# Defensive size cap: a misbehaving sandbox could produce an
|
||||
# arbitrarily large tar. Refuse to extract if it exceeds the cap.
|
||||
try:
|
||||
tar_size = os.path.getsize(tf.name)
|
||||
except OSError:
|
||||
tar_size = 0
|
||||
if tar_size > _SYNC_BACK_MAX_BYTES:
|
||||
logger.warning(
|
||||
"sync_back: remote tar is %d bytes (cap %d) — skipping extraction",
|
||||
tar_size, _SYNC_BACK_MAX_BYTES,
|
||||
)
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
|
||||
with tarfile.open(tf.name) as tar:
|
||||
tar.extractall(staging, filter="data")
|
||||
|
||||
applied = 0
|
||||
for dirpath, _dirnames, filenames in os.walk(staging):
|
||||
for fname in filenames:
|
||||
staged_file = os.path.join(dirpath, fname)
|
||||
rel = os.path.relpath(staged_file, staging)
|
||||
remote_path = "/" + rel
|
||||
|
||||
pushed_hash = self._pushed_hashes.get(remote_path)
|
||||
|
||||
# Skip hashing for files unchanged from push
|
||||
if pushed_hash is not None:
|
||||
remote_hash = _sha256_file(staged_file)
|
||||
if remote_hash == pushed_hash:
|
||||
continue
|
||||
else:
|
||||
remote_hash = None # new remote file
|
||||
|
||||
# Resolve host path from cached mapping
|
||||
host_path = self._resolve_host_path(remote_path, file_mapping)
|
||||
if host_path is None:
|
||||
host_path = self._infer_host_path(remote_path, file_mapping)
|
||||
if host_path is None:
|
||||
logger.debug(
|
||||
"sync_back: skipping %s (no host mapping)",
|
||||
remote_path,
|
||||
)
|
||||
continue
|
||||
|
||||
if os.path.exists(host_path) and pushed_hash is not None:
|
||||
host_hash = _sha256_file(host_path)
|
||||
if host_hash != pushed_hash:
|
||||
logger.warning(
|
||||
"sync_back: conflict on %s — host modified "
|
||||
"since push, remote also changed. Applying "
|
||||
"remote version (last-write-wins).",
|
||||
remote_path,
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(host_path), exist_ok=True)
|
||||
shutil.copy2(staged_file, host_path)
|
||||
applied += 1
|
||||
|
||||
if applied:
|
||||
logger.info("sync_back: applied %d changed file(s)", applied)
|
||||
else:
|
||||
logger.debug("sync_back: no remote changes detected")
|
||||
|
||||
def _resolve_host_path(self, remote_path: str,
|
||||
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||
"""Find the host path for a known remote path from the file mapping."""
|
||||
mapping = file_mapping if file_mapping is not None else []
|
||||
for host, remote in mapping:
|
||||
if remote == remote_path:
|
||||
return host
|
||||
return None
|
||||
|
||||
def _infer_host_path(self, remote_path: str,
|
||||
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||
"""Infer a host path for a new remote file by matching path prefixes.
|
||||
|
||||
Uses the existing file mapping to find a remote->host directory
|
||||
pair, then applies the same prefix substitution to the new file.
|
||||
For example, if the mapping has ``/root/.hermes/skills/a.md`` →
|
||||
``~/.hermes/skills/a.md``, a new remote file at
|
||||
``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
|
||||
"""
|
||||
mapping = file_mapping if file_mapping is not None else []
|
||||
for host, remote in mapping:
|
||||
remote_dir = str(Path(remote).parent)
|
||||
if remote_path.startswith(remote_dir + "/"):
|
||||
host_dir = str(Path(host).parent)
|
||||
suffix = remote_path[len(remote_dir):]
|
||||
return host_dir + suffix
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
upload_fn=self._modal_upload,
|
||||
delete_fn=self._modal_delete,
|
||||
bulk_upload_fn=self._modal_bulk_upload,
|
||||
bulk_download_fn=self._modal_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
|
@ -347,6 +348,27 @@ class ModalEnvironment(BaseEnvironment):
|
|||
|
||||
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||
|
||||
def _modal_bulk_download(self, dest: Path) -> None:
|
||||
"""Download remote .hermes/ as a tar archive.
|
||||
|
||||
Modal sandboxes always run as root, so /root/.hermes is hardcoded
|
||||
(consistent with iter_sync_files call on line 269).
|
||||
"""
|
||||
async def _download():
|
||||
proc = await self._sandbox.exec.aio(
|
||||
"bash", "-c", "tar cf - -C / root/.hermes"
|
||||
)
|
||||
data = await proc.stdout.read.aio()
|
||||
exit_code = await proc.wait.aio()
|
||||
if exit_code != 0:
|
||||
raise RuntimeError(f"Modal bulk download failed (exit {exit_code})")
|
||||
return data
|
||||
|
||||
tar_bytes = self._worker.run_coroutine(_download(), timeout=120)
|
||||
if isinstance(tar_bytes, str):
|
||||
tar_bytes = tar_bytes.encode()
|
||||
dest.write_bytes(tar_bytes)
|
||||
|
||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files via exec."""
|
||||
rm_cmd = quoted_rm_command(remote_paths)
|
||||
|
|
@ -404,6 +426,10 @@ class ModalEnvironment(BaseEnvironment):
|
|||
if self._sandbox is None:
|
||||
return
|
||||
|
||||
if self._sync_manager:
|
||||
logger.info("Modal: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self._persistent:
|
||||
try:
|
||||
async def _snapshot():
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||
upload_fn=self._scp_upload,
|
||||
delete_fn=self._ssh_delete,
|
||||
bulk_upload_fn=self._ssh_bulk_upload,
|
||||
bulk_download_fn=self._ssh_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
|
||||
|
|
@ -216,6 +217,18 @@ class SSHEnvironment(BaseEnvironment):
|
|||
|
||||
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||
|
||||
def _ssh_bulk_download(self, dest: Path) -> None:
|
||||
"""Download remote .hermes/ as a tar archive."""
|
||||
# Tar from / with the full path so archive entries preserve absolute
|
||||
# paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys.
|
||||
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||
ssh_cmd = self._build_ssh_command()
|
||||
ssh_cmd.append(f"tar cf - -C / {shlex.quote(rel_base)}")
|
||||
with open(dest, "wb") as f:
|
||||
result = subprocess.run(ssh_cmd, stdout=f, stderr=subprocess.PIPE, timeout=120)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"SSH bulk download failed: {result.stderr.decode(errors='replace').strip()}")
|
||||
|
||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files in one SSH call."""
|
||||
cmd = self._build_ssh_command()
|
||||
|
|
@ -245,6 +258,10 @@ class SSHEnvironment(BaseEnvironment):
|
|||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
if self._sync_manager:
|
||||
logger.info("SSH: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -79,7 +79,7 @@ In addition to built-in tools, Hermes can load tools dynamically from MCP server
|
|||
|
||||
| Tool | Description | Requires environment |
|
||||
|------|-------------|----------------------|
|
||||
| `image_generate` | Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using… | FAL_KEY |
|
||||
| `image_generate` | Generate high-quality images from text prompts using FAL.ai. The underlying model is user-configured (default: FLUX 2 Klein 9B, sub-1s generation) and is not selectable by the agent. Returns a single image URL. Display it using… | FAL_KEY |
|
||||
|
||||
## `memory` toolset
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,35 @@
|
|||
---
|
||||
title: Image Generation
|
||||
description: Generate high-quality images using FLUX 2 Pro with automatic upscaling via FAL.ai.
|
||||
description: Generate images via FAL.ai — 8 models including FLUX 2, GPT-Image, Nano Banana, Ideogram, and more, selectable via `hermes tools`.
|
||||
sidebar_label: Image Generation
|
||||
sidebar_position: 6
|
||||
---
|
||||
|
||||
# Image Generation
|
||||
|
||||
Hermes Agent can generate images from text prompts using FAL.ai's **FLUX 2 Pro** model with automatic 2x upscaling via the **Clarity Upscaler** for enhanced quality.
|
||||
Hermes Agent generates images from text prompts via FAL.ai. Eight models are supported out of the box, each with different speed, quality, and cost tradeoffs. The active model is user-configurable via `hermes tools` and persists in `config.yaml`.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Speed | Strengths | Price |
|
||||
|---|---|---|---|
|
||||
| `fal-ai/flux-2/klein/9b` *(default)* | <1s | Fast, crisp text | $0.006/MP |
|
||||
| `fal-ai/flux-2-pro` | ~6s | Studio photorealism | $0.03/MP |
|
||||
| `fal-ai/z-image/turbo` | ~2s | Bilingual EN/CN, 6B params | $0.005/MP |
|
||||
| `fal-ai/nano-banana` | ~6s | Gemini 2.5, character consistency | $0.08/image |
|
||||
| `fal-ai/gpt-image-1.5` | ~15s | Prompt adherence | $0.034/image |
|
||||
| `fal-ai/ideogram/v3` | ~5s | Best typography | $0.03–0.09/image |
|
||||
| `fal-ai/recraft-v3` | ~8s | Vector art, brand styles | $0.04/image |
|
||||
| `fal-ai/qwen-image` | ~12s | LLM-based, complex text | $0.02/MP |
|
||||
|
||||
Prices are FAL's pricing at time of writing; check [fal.ai](https://fal.ai/) for current numbers.
|
||||
|
||||
## Setup
|
||||
|
||||
:::tip Nous Subscribers
|
||||
If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription, you can use image generation through the **[Tool Gateway](tool-gateway.md)** without a FAL API key. Run `hermes model` or `hermes tools` to enable it.
|
||||
If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription, you can use image generation through the **[Tool Gateway](tool-gateway.md)** without a FAL API key. Your model selection persists across both paths.
|
||||
|
||||
If the managed gateway returns `HTTP 4xx` for a specific model, that model isn't yet proxied on the portal side — the agent will tell you so, with remediation steps (set `FAL_KEY` for direct access, or pick a different model).
|
||||
:::
|
||||
|
||||
### Get a FAL API Key
|
||||
|
|
@ -20,150 +37,117 @@ If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription,
|
|||
1. Sign up at [fal.ai](https://fal.ai/)
|
||||
2. Generate an API key from your dashboard
|
||||
|
||||
### Configure the Key
|
||||
### Configure and Pick a Model
|
||||
|
||||
Run the tools command:
|
||||
|
||||
```bash
|
||||
# Add to ~/.hermes/.env
|
||||
FAL_KEY=your-fal-api-key-here
|
||||
hermes tools
|
||||
```
|
||||
|
||||
### Install the Client Library
|
||||
Navigate to **🎨 Image Generation**, pick your backend (Nous Subscription or FAL.ai), then the picker shows all supported models in a column-aligned table — arrow keys to navigate, Enter to select:
|
||||
|
||||
```bash
|
||||
pip install fal-client
|
||||
```
|
||||
Model Speed Strengths Price
|
||||
fal-ai/flux-2/klein/9b <1s Fast, crisp text $0.006/MP ← currently in use
|
||||
fal-ai/flux-2-pro ~6s Studio photorealism $0.03/MP
|
||||
fal-ai/z-image/turbo ~2s Bilingual EN/CN, 6B $0.005/MP
|
||||
...
|
||||
```
|
||||
|
||||
:::info
|
||||
The image generation tool is automatically available when `FAL_KEY` is set. No additional toolset configuration is needed.
|
||||
:::
|
||||
Your selection is saved to `config.yaml`:
|
||||
|
||||
## How It Works
|
||||
```yaml
|
||||
image_gen:
|
||||
model: fal-ai/flux-2/klein/9b
|
||||
use_gateway: false # true if using Nous Subscription
|
||||
```
|
||||
|
||||
When you ask Hermes to generate an image:
|
||||
### GPT-Image Quality
|
||||
|
||||
1. **Generation** — Your prompt is sent to the FLUX 2 Pro model (`fal-ai/flux-2-pro`)
|
||||
2. **Upscaling** — The generated image is automatically upscaled 2x using the Clarity Upscaler (`fal-ai/clarity-upscaler`)
|
||||
3. **Delivery** — The upscaled image URL is returned
|
||||
|
||||
If upscaling fails for any reason, the original image is returned as a fallback.
|
||||
The `fal-ai/gpt-image-1.5` request quality is pinned to `medium` (~$0.034/image at 1024×1024). We don't expose the `low` / `high` tiers as a user-facing option so that Nous Portal billing stays predictable across all users — the cost spread between tiers is ~22×. If you want a cheaper GPT-Image option, pick a different model; if you want higher quality, use Klein 9B or Imagen-class models.
|
||||
|
||||
## Usage
|
||||
|
||||
Simply ask Hermes to create an image:
|
||||
The agent-facing schema is intentionally minimal — the model picks up whatever you've configured:
|
||||
|
||||
```
|
||||
Generate an image of a serene mountain landscape with cherry blossoms
|
||||
```
|
||||
|
||||
```
|
||||
Create a portrait of a wise old owl perched on an ancient tree branch
|
||||
Create a square portrait of a wise old owl — use the typography model
|
||||
```
|
||||
|
||||
```
|
||||
Make me a futuristic cityscape with flying cars and neon lights
|
||||
Make me a futuristic cityscape, landscape orientation
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
The `image_generate_tool` accepts these parameters:
|
||||
|
||||
| Parameter | Default | Range | Description |
|
||||
|-----------|---------|-------|-------------|
|
||||
| `prompt` | *(required)* | — | Text description of the desired image |
|
||||
| `aspect_ratio` | `"landscape"` | `landscape`, `square`, `portrait` | Image aspect ratio |
|
||||
| `num_inference_steps` | `50` | 1–100 | Number of denoising steps (more = higher quality, slower) |
|
||||
| `guidance_scale` | `4.5` | 0.1–20.0 | How closely to follow the prompt |
|
||||
| `num_images` | `1` | 1–4 | Number of images to generate |
|
||||
| `output_format` | `"png"` | `png`, `jpeg` | Image file format |
|
||||
| `seed` | *(random)* | any integer | Random seed for reproducible results |
|
||||
|
||||
## Aspect Ratios
|
||||
|
||||
The tool uses simplified aspect ratio names that map to FLUX 2 Pro image sizes:
|
||||
Every model accepts the same three aspect ratios from the agent's perspective. Internally, each model's native size spec is filled in automatically:
|
||||
|
||||
| Aspect Ratio | Maps To | Best For |
|
||||
|-------------|---------|----------|
|
||||
| `landscape` | `landscape_16_9` | Wallpapers, banners, scenes |
|
||||
| `square` | `square_hd` | Profile pictures, social media posts |
|
||||
| `portrait` | `portrait_16_9` | Character art, phone wallpapers |
|
||||
| Agent input | image_size (flux/z-image/qwen/recraft/ideogram) | aspect_ratio (nano-banana) | image_size (gpt-image) |
|
||||
|---|---|---|---|
|
||||
| `landscape` | `landscape_16_9` | `16:9` | `1536x1024` |
|
||||
| `square` | `square_hd` | `1:1` | `1024x1024` |
|
||||
| `portrait` | `portrait_16_9` | `9:16` | `1024x1536` |
|
||||
|
||||
:::tip
|
||||
You can also use the raw FLUX 2 Pro size presets directly: `square_hd`, `square`, `portrait_4_3`, `portrait_16_9`, `landscape_4_3`, `landscape_16_9`. Custom sizes up to 2048x2048 are also supported.
|
||||
:::
|
||||
This translation happens in `_build_fal_payload()` — agent code never has to know about per-model schema differences.
|
||||
|
||||
## Automatic Upscaling
|
||||
|
||||
Every generated image is automatically upscaled 2x using FAL.ai's Clarity Upscaler with these settings:
|
||||
Upscaling via FAL's **Clarity Upscaler** is gated per-model:
|
||||
|
||||
| Model | Upscale? | Why |
|
||||
|---|---|---|
|
||||
| `fal-ai/flux-2-pro` | ✓ | Backward-compat (was the pre-picker default) |
|
||||
| All others | ✗ | Fast models would lose their sub-second value prop; hi-res models don't need it |
|
||||
|
||||
When upscaling runs, it uses these settings:
|
||||
|
||||
| Setting | Value |
|
||||
|---------|-------|
|
||||
| Upscale Factor | 2x |
|
||||
|---|---|
|
||||
| Upscale factor | 2× |
|
||||
| Creativity | 0.35 |
|
||||
| Resemblance | 0.6 |
|
||||
| Guidance Scale | 4 |
|
||||
| Inference Steps | 18 |
|
||||
| Positive Prompt | `"masterpiece, best quality, highres"` + your original prompt |
|
||||
| Negative Prompt | `"(worst quality, low quality, normal quality:2)"` |
|
||||
| Guidance scale | 4 |
|
||||
| Inference steps | 18 |
|
||||
|
||||
The upscaler enhances detail and resolution while preserving the original composition. If the upscaler fails (network issue, rate limit), the original resolution image is returned automatically.
|
||||
If upscaling fails (network issue, rate limit), the original image is returned automatically.
|
||||
|
||||
## Example Prompts
|
||||
## How It Works Internally
|
||||
|
||||
Here are some effective prompts to try:
|
||||
|
||||
```
|
||||
A candid street photo of a woman with a pink bob and bold eyeliner
|
||||
```
|
||||
|
||||
```
|
||||
Modern architecture building with glass facade, sunset lighting
|
||||
```
|
||||
|
||||
```
|
||||
Abstract art with vibrant colors and geometric patterns
|
||||
```
|
||||
|
||||
```
|
||||
Portrait of a wise old owl perched on ancient tree branch
|
||||
```
|
||||
|
||||
```
|
||||
Futuristic cityscape with flying cars and neon lights
|
||||
```
|
||||
1. **Model resolution** — `_resolve_fal_model()` reads `image_gen.model` from `config.yaml`, falls back to the `FAL_IMAGE_MODEL` env var, then to `fal-ai/flux-2/klein/9b`.
|
||||
2. **Payload building** — `_build_fal_payload()` translates your `aspect_ratio` into the model's native format (preset enum, aspect-ratio enum, or GPT literal), merges the model's default params, applies any caller overrides, then filters to the model's `supports` whitelist so unsupported keys are never sent.
|
||||
3. **Submission** — `_submit_fal_request()` routes via direct FAL credentials or the managed Nous gateway.
|
||||
4. **Upscaling** — runs only if the model's metadata has `upscale: True`.
|
||||
5. **Delivery** — final image URL returned to the agent, which emits a `MEDIA:<url>` tag that platform adapters convert to native media.
|
||||
|
||||
## Debugging
|
||||
|
||||
Enable debug logging for image generation:
|
||||
Enable debug logging:
|
||||
|
||||
```bash
|
||||
export IMAGE_TOOLS_DEBUG=true
|
||||
```
|
||||
|
||||
Debug logs are saved to `./logs/image_tools_debug_<session_id>.json` with details about each generation request, parameters, timing, and any errors.
|
||||
|
||||
## Safety Settings
|
||||
|
||||
The image generation tool runs with safety checks disabled by default (`safety_tolerance: 5`, the most permissive setting). This is configured at the code level and is not user-adjustable.
|
||||
Debug logs go to `./logs/image_tools_debug_<session_id>.json` with per-call details (model, parameters, timing, errors).
|
||||
|
||||
## Platform Delivery
|
||||
|
||||
Generated images are delivered differently depending on the platform:
|
||||
|
||||
| Platform | Delivery method |
|
||||
|----------|----------------|
|
||||
| **CLI** | Image URL printed as markdown `` — click to open in browser |
|
||||
| **Telegram** | Image sent as a photo message with the prompt as caption |
|
||||
| **Discord** | Image embedded in a message |
|
||||
| **Slack** | Image URL in message (Slack unfurls it) |
|
||||
| **WhatsApp** | Image sent as a media message |
|
||||
| **Other platforms** | Image URL in plain text |
|
||||
|
||||
The agent uses `MEDIA:<url>` syntax in its response, which the platform adapter converts to the appropriate format.
|
||||
| Platform | Delivery |
|
||||
|---|---|
|
||||
| **CLI** | Image URL printed as markdown `` — click to open |
|
||||
| **Telegram** | Photo message with the prompt as caption |
|
||||
| **Discord** | Embedded in a message |
|
||||
| **Slack** | URL unfurled by Slack |
|
||||
| **WhatsApp** | Media message |
|
||||
| **Others** | URL in plain text |
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Requires FAL API key** — image generation incurs API costs on your FAL.ai account
|
||||
- **No image editing** — this is text-to-image only, no inpainting or img2img
|
||||
- **URL-based delivery** — images are returned as temporary FAL.ai URLs, not saved locally. URLs expire after a period (typically hours)
|
||||
- **Upscaling adds latency** — the automatic 2x upscale step adds processing time
|
||||
- **Max 4 images per request** — `num_images` is capped at 4
|
||||
- **Requires FAL credentials** (direct `FAL_KEY` or Nous Subscription)
|
||||
- **Text-to-image only** — no inpainting, img2img, or editing via this tool
|
||||
- **Temporary URLs** — FAL returns hosted URLs that expire after hours/days; save locally if needed
|
||||
- **Per-model constraints** — some models don't support `seed`, `num_inference_steps`, etc. The `supports` filter silently drops unsupported params; this is expected behavior
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ Hermes Agent includes a rich set of capabilities that extend far beyond basic ch
|
|||
- **[Voice Mode](voice-mode.md)** — Full voice interaction across CLI and messaging platforms. Talk to the agent using your microphone, hear spoken replies, and have live voice conversations in Discord voice channels.
|
||||
- **[Browser Automation](browser.md)** — Full browser automation with multiple backends: Browserbase cloud, Browser Use cloud, local Chrome via CDP, or local Chromium. Navigate websites, fill forms, and extract information.
|
||||
- **[Vision & Image Paste](vision.md)** — Multimodal vision support. Paste images from your clipboard into the CLI and ask the agent to analyze, describe, or work with them using any vision-capable model.
|
||||
- **[Image Generation](image-generation.md)** — Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic 2x upscaling via the Clarity Upscaler.
|
||||
- **[Image Generation](image-generation.md)** — Generate images from text prompts using FAL.ai. Eight models supported (FLUX 2 Klein/Pro, GPT-Image 1.5, Nano Banana, Ideogram V3, Recraft V3, Qwen, Z-Image Turbo); pick one via `hermes tools`.
|
||||
- **[Voice & TTS](tts.md)** — Text-to-speech output and voice message transcription across all messaging platforms, with five provider options: Edge TTS (free), ElevenLabs, OpenAI TTS, MiniMax, and NeuTTS.
|
||||
|
||||
## Integrations
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ The **Tool Gateway** lets paid [Nous Portal](https://portal.nousresearch.com) su
|
|||
| Tool | What It Does | Direct Alternative |
|
||||
|------|--------------|--------------------|
|
||||
| **Web search & extract** | Search the web and extract page content via Firecrawl | `FIRECRAWL_API_KEY`, `EXA_API_KEY`, `PARALLEL_API_KEY`, `TAVILY_API_KEY` |
|
||||
| **Image generation** | Generate images via FAL (FLUX 2 Pro + upscaling) | `FAL_KEY` |
|
||||
| **Image generation** | Generate images via FAL (8 models: FLUX 2 Klein/Pro, GPT-Image, Nano Banana, Ideogram, Recraft, Qwen, Z-Image) | `FAL_KEY` |
|
||||
| **Text-to-speech** | Convert text to speech via OpenAI TTS | `VOICE_TOOLS_OPENAI_KEY`, `ELEVENLABS_API_KEY` |
|
||||
| **Browser automation** | Control cloud browsers via Browser Use | `BROWSER_USE_API_KEY`, `BROWSERBASE_API_KEY` |
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue