mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 01:21:43 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
41d3d7afb7
18 changed files with 2877 additions and 484 deletions
|
|
@ -727,6 +727,22 @@ def switch_model(
|
||||||
if not api_mode:
|
if not api_mode:
|
||||||
api_mode = determine_api_mode(target_provider, base_url)
|
api_mode = determine_api_mode(target_provider, base_url)
|
||||||
|
|
||||||
|
# OpenCode base URLs end with /v1 for OpenAI-compatible models, but the
|
||||||
|
# Anthropic SDK prepends its own /v1/messages to the base_url. Strip the
|
||||||
|
# trailing /v1 so the SDK constructs the correct path (e.g.
|
||||||
|
# https://opencode.ai/zen/go/v1/messages instead of .../v1/v1/messages).
|
||||||
|
# Mirrors the same logic in hermes_cli.runtime_provider.resolve_runtime_provider;
|
||||||
|
# without it, /model switches into an anthropic_messages-routed OpenCode
|
||||||
|
# model (e.g. `/model minimax-m2.7` on opencode-go, `/model claude-sonnet-4-6`
|
||||||
|
# on opencode-zen) hit a double /v1 and returned OpenCode's website 404 page.
|
||||||
|
if (
|
||||||
|
api_mode == "anthropic_messages"
|
||||||
|
and target_provider in {"opencode-zen", "opencode-go"}
|
||||||
|
and isinstance(base_url, str)
|
||||||
|
and base_url
|
||||||
|
):
|
||||||
|
base_url = re.sub(r"/v1/?$", "", base_url)
|
||||||
|
|
||||||
# --- Get capabilities (legacy) ---
|
# --- Get capabilities (legacy) ---
|
||||||
capabilities = get_model_capabilities(target_provider, new_model)
|
capabilities = get_model_capabilities(target_provider, new_model)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -258,14 +258,16 @@ TOOL_CATEGORIES = {
|
||||||
"requires_nous_auth": True,
|
"requires_nous_auth": True,
|
||||||
"managed_nous_feature": "image_gen",
|
"managed_nous_feature": "image_gen",
|
||||||
"override_env_vars": ["FAL_KEY"],
|
"override_env_vars": ["FAL_KEY"],
|
||||||
|
"imagegen_backend": "fal",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "FAL.ai",
|
"name": "FAL.ai",
|
||||||
"badge": "paid",
|
"badge": "paid",
|
||||||
"tag": "FLUX 2 Pro with auto-upscaling",
|
"tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.",
|
||||||
"env_vars": [
|
"env_vars": [
|
||||||
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
|
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
|
||||||
],
|
],
|
||||||
|
"imagegen_backend": "fal",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
@ -950,6 +952,106 @@ def _detect_active_provider_index(providers: list, config: dict) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Image Generation Model Pickers ───────────────────────────────────────────
|
||||||
|
#
|
||||||
|
# IMAGEGEN_BACKENDS is a per-backend catalog. Each entry exposes:
|
||||||
|
# - config_key: top-level config.yaml key for this backend's settings
|
||||||
|
# - model_catalog_fn: returns an OrderedDict-like {model_id: metadata}
|
||||||
|
# - default_model: fallback when nothing is configured
|
||||||
|
#
|
||||||
|
# This prepares for future imagegen backends (Replicate, Stability, etc.):
|
||||||
|
# each new backend registers its own entry; the FAL provider entry in
|
||||||
|
# TOOL_CATEGORIES tags itself with `imagegen_backend: "fal"` to select the
|
||||||
|
# right catalog at picker time.
|
||||||
|
|
||||||
|
|
||||||
|
def _fal_model_catalog():
|
||||||
|
"""Lazy-load the FAL model catalog from the tool module."""
|
||||||
|
from tools.image_generation_tool import FAL_MODELS, DEFAULT_MODEL
|
||||||
|
return FAL_MODELS, DEFAULT_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
IMAGEGEN_BACKENDS = {
|
||||||
|
"fal": {
|
||||||
|
"display": "FAL.ai",
|
||||||
|
"config_key": "image_gen",
|
||||||
|
"catalog_fn": _fal_model_catalog,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_imagegen_model_row(model_id: str, meta: dict, widths: dict) -> str:
|
||||||
|
"""Format a single picker row with column-aligned speed / strengths / price."""
|
||||||
|
return (
|
||||||
|
f"{model_id:<{widths['model']}} "
|
||||||
|
f"{meta.get('speed', ''):<{widths['speed']}} "
|
||||||
|
f"{meta.get('strengths', ''):<{widths['strengths']}} "
|
||||||
|
f"{meta.get('price', '')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_imagegen_model(backend_name: str, config: dict) -> None:
|
||||||
|
"""Prompt the user to pick a model for the given imagegen backend.
|
||||||
|
|
||||||
|
Writes selection to ``config[backend_config_key]["model"]``. Safe to
|
||||||
|
call even when stdin is not a TTY — curses_radiolist falls back to
|
||||||
|
keeping the current selection.
|
||||||
|
"""
|
||||||
|
backend = IMAGEGEN_BACKENDS.get(backend_name)
|
||||||
|
if not backend:
|
||||||
|
return
|
||||||
|
|
||||||
|
catalog, default_model = backend["catalog_fn"]()
|
||||||
|
if not catalog:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg_key = backend["config_key"]
|
||||||
|
cur_cfg = config.setdefault(cfg_key, {})
|
||||||
|
if not isinstance(cur_cfg, dict):
|
||||||
|
cur_cfg = {}
|
||||||
|
config[cfg_key] = cur_cfg
|
||||||
|
current_model = cur_cfg.get("model") or default_model
|
||||||
|
if current_model not in catalog:
|
||||||
|
current_model = default_model
|
||||||
|
|
||||||
|
model_ids = list(catalog.keys())
|
||||||
|
# Put current model at the top so the cursor lands on it by default.
|
||||||
|
ordered = [current_model] + [m for m in model_ids if m != current_model]
|
||||||
|
|
||||||
|
# Column widths
|
||||||
|
widths = {
|
||||||
|
"model": max(len(m) for m in model_ids),
|
||||||
|
"speed": max((len(catalog[m].get("speed", "")) for m in model_ids), default=6),
|
||||||
|
"strengths": max((len(catalog[m].get("strengths", "")) for m in model_ids), default=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
print()
|
||||||
|
header = (
|
||||||
|
f" {'Model':<{widths['model']}} "
|
||||||
|
f"{'Speed':<{widths['speed']}} "
|
||||||
|
f"{'Strengths':<{widths['strengths']}} "
|
||||||
|
f"Price"
|
||||||
|
)
|
||||||
|
print(color(header, Colors.CYAN))
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for mid in ordered:
|
||||||
|
row = _format_imagegen_model_row(mid, catalog[mid], widths)
|
||||||
|
if mid == current_model:
|
||||||
|
row += " ← currently in use"
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
idx = _prompt_choice(
|
||||||
|
f" Choose {backend['display']} model:",
|
||||||
|
rows,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
chosen = ordered[idx]
|
||||||
|
cur_cfg["model"] = chosen
|
||||||
|
_print_success(f" Model set to: {chosen}")
|
||||||
|
|
||||||
|
|
||||||
def _configure_provider(provider: dict, config: dict):
|
def _configure_provider(provider: dict, config: dict):
|
||||||
"""Configure a single provider - prompt for API keys and set config."""
|
"""Configure a single provider - prompt for API keys and set config."""
|
||||||
env_vars = provider.get("env_vars", [])
|
env_vars = provider.get("env_vars", [])
|
||||||
|
|
@ -1006,6 +1108,10 @@ def _configure_provider(provider: dict, config: dict):
|
||||||
_print_success(f" {provider['name']} - no configuration needed!")
|
_print_success(f" {provider['name']} - no configuration needed!")
|
||||||
if managed_feature:
|
if managed_feature:
|
||||||
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
||||||
|
# Imagegen backends prompt for model selection after backend pick.
|
||||||
|
backend = provider.get("imagegen_backend")
|
||||||
|
if backend:
|
||||||
|
_configure_imagegen_model(backend, config)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Prompt for each required env var
|
# Prompt for each required env var
|
||||||
|
|
@ -1040,6 +1146,10 @@ def _configure_provider(provider: dict, config: dict):
|
||||||
|
|
||||||
if all_configured:
|
if all_configured:
|
||||||
_print_success(f" {provider['name']} configured!")
|
_print_success(f" {provider['name']} configured!")
|
||||||
|
# Imagegen backends prompt for model selection after env vars are in.
|
||||||
|
backend = provider.get("imagegen_backend")
|
||||||
|
if backend:
|
||||||
|
_configure_imagegen_model(backend, config)
|
||||||
|
|
||||||
|
|
||||||
def _configure_simple_requirements(ts_key: str):
|
def _configure_simple_requirements(ts_key: str):
|
||||||
|
|
@ -1211,6 +1321,10 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||||
_print_success(f" {provider['name']} - no configuration needed!")
|
_print_success(f" {provider['name']} - no configuration needed!")
|
||||||
if managed_feature:
|
if managed_feature:
|
||||||
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
||||||
|
# Imagegen backends prompt for model selection on reconfig too.
|
||||||
|
backend = provider.get("imagegen_backend")
|
||||||
|
if backend:
|
||||||
|
_configure_imagegen_model(backend, config)
|
||||||
return
|
return
|
||||||
|
|
||||||
for var in env_vars:
|
for var in env_vars:
|
||||||
|
|
@ -1228,6 +1342,11 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||||
else:
|
else:
|
||||||
_print_info(" Kept current")
|
_print_info(" Kept current")
|
||||||
|
|
||||||
|
# Imagegen backends prompt for model selection on reconfig too.
|
||||||
|
backend = provider.get("imagegen_backend")
|
||||||
|
if backend:
|
||||||
|
_configure_imagegen_model(backend, config)
|
||||||
|
|
||||||
|
|
||||||
def _reconfigure_simple_requirements(ts_key: str):
|
def _reconfigure_simple_requirements(ts_key: str):
|
||||||
"""Reconfigure simple env var requirements."""
|
"""Reconfigure simple env var requirements."""
|
||||||
|
|
|
||||||
49
run_agent.py
49
run_agent.py
|
|
@ -1674,12 +1674,26 @@ class AIAgent:
|
||||||
turn-scoped).
|
turn-scoped).
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import re as _re
|
||||||
from hermes_cli.providers import determine_api_mode
|
from hermes_cli.providers import determine_api_mode
|
||||||
|
|
||||||
# ── Determine api_mode if not provided ──
|
# ── Determine api_mode if not provided ──
|
||||||
if not api_mode:
|
if not api_mode:
|
||||||
api_mode = determine_api_mode(new_provider, base_url)
|
api_mode = determine_api_mode(new_provider, base_url)
|
||||||
|
|
||||||
|
# Defense-in-depth: ensure OpenCode base_url doesn't carry a trailing
|
||||||
|
# /v1 into the anthropic_messages client, which would cause the SDK to
|
||||||
|
# hit /v1/v1/messages. `model_switch.switch_model()` already strips
|
||||||
|
# this, but we guard here so any direct callers (future code paths,
|
||||||
|
# tests) can't reintroduce the double-/v1 404 bug.
|
||||||
|
if (
|
||||||
|
api_mode == "anthropic_messages"
|
||||||
|
and new_provider in ("opencode-zen", "opencode-go")
|
||||||
|
and isinstance(base_url, str)
|
||||||
|
and base_url
|
||||||
|
):
|
||||||
|
base_url = _re.sub(r"/v1/?$", "", base_url)
|
||||||
|
|
||||||
old_model = self.model
|
old_model = self.model
|
||||||
old_provider = self.provider
|
old_provider = self.provider
|
||||||
|
|
||||||
|
|
@ -4381,6 +4395,41 @@ class AIAgent:
|
||||||
self._client_log_context(),
|
self._client_log_context(),
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
# Inject TCP keepalives so the kernel detects dead provider connections
|
||||||
|
# instead of letting them sit silently in CLOSE-WAIT (#10324). Without
|
||||||
|
# this, a peer that drops mid-stream leaves the socket in a state where
|
||||||
|
# epoll_wait never fires, ``httpx`` read timeout may not trigger, and
|
||||||
|
# the agent hangs until manually killed. Probes after 30s idle, retry
|
||||||
|
# every 10s, give up after 3 → dead peer detected within ~60s.
|
||||||
|
#
|
||||||
|
# Safety against #10933: the ``client_kwargs = dict(client_kwargs)``
|
||||||
|
# above means this injection only lands in the local per-call copy,
|
||||||
|
# never back into ``self._client_kwargs``. Each ``_create_openai_client``
|
||||||
|
# invocation therefore gets its OWN fresh ``httpx.Client`` whose
|
||||||
|
# lifetime is tied to the OpenAI client it is passed to. When the
|
||||||
|
# OpenAI client is closed (rebuild, teardown, credential rotation),
|
||||||
|
# the paired ``httpx.Client`` closes with it, and the next call
|
||||||
|
# constructs a fresh one — no stale closed transport can be reused.
|
||||||
|
# Tests in ``tests/run_agent/test_create_openai_client_reuse.py`` and
|
||||||
|
# ``tests/run_agent/test_sequential_chats_live.py`` pin this invariant.
|
||||||
|
if "http_client" not in client_kwargs:
|
||||||
|
try:
|
||||||
|
import httpx as _httpx
|
||||||
|
import socket as _socket
|
||||||
|
_sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)]
|
||||||
|
if hasattr(_socket, "TCP_KEEPIDLE"):
|
||||||
|
# Linux
|
||||||
|
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30))
|
||||||
|
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10))
|
||||||
|
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3))
|
||||||
|
elif hasattr(_socket, "TCP_KEEPALIVE"):
|
||||||
|
# macOS (uses TCP_KEEPALIVE instead of TCP_KEEPIDLE)
|
||||||
|
_sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30))
|
||||||
|
client_kwargs["http_client"] = _httpx.Client(
|
||||||
|
transport=_httpx.HTTPTransport(socket_options=_sock_opts),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Fall through to default transport if socket opts fail
|
||||||
client = OpenAI(**client_kwargs)
|
client = OpenAI(**client_kwargs)
|
||||||
logger.info(
|
logger.info(
|
||||||
"OpenAI client created (%s, shared=%s) %s",
|
"OpenAI client created (%s, shared=%s) %s",
|
||||||
|
|
|
||||||
|
|
@ -122,6 +122,43 @@ log_error() {
|
||||||
echo -e "${RED}✗${NC} $1"
|
echo -e "${RED}✗${NC} $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompt_yes_no() {
|
||||||
|
local question="$1"
|
||||||
|
local default="${2:-yes}"
|
||||||
|
local prompt_suffix
|
||||||
|
local answer=""
|
||||||
|
|
||||||
|
# Use case patterns (not ${var,,}) so this works on bash 3.2 (macOS /bin/bash).
|
||||||
|
case "$default" in
|
||||||
|
[yY]|[yY][eE][sS]|[tT][rR][uU][eE]|1) prompt_suffix="[Y/n]" ;;
|
||||||
|
*) prompt_suffix="[y/N]" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if [ "$IS_INTERACTIVE" = true ]; then
|
||||||
|
read -r -p "$question $prompt_suffix " answer || answer=""
|
||||||
|
elif [ -r /dev/tty ] && [ -w /dev/tty ]; then
|
||||||
|
printf "%s %s " "$question" "$prompt_suffix" > /dev/tty
|
||||||
|
IFS= read -r answer < /dev/tty || answer=""
|
||||||
|
else
|
||||||
|
answer=""
|
||||||
|
fi
|
||||||
|
|
||||||
|
answer="${answer#"${answer%%[![:space:]]*}"}"
|
||||||
|
answer="${answer%"${answer##*[![:space:]]}"}"
|
||||||
|
|
||||||
|
if [ -z "$answer" ]; then
|
||||||
|
case "$default" in
|
||||||
|
[yY]|[yY][eE][sS]|[tT][rR][uU][eE]|1) return 0 ;;
|
||||||
|
*) return 1 ;;
|
||||||
|
esac
|
||||||
|
fi
|
||||||
|
|
||||||
|
case "$answer" in
|
||||||
|
[yY]|[yY][eE][sS]) return 0 ;;
|
||||||
|
*) return 1 ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
is_termux() {
|
is_termux() {
|
||||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||||
}
|
}
|
||||||
|
|
@ -606,9 +643,7 @@ install_system_packages() {
|
||||||
echo ""
|
echo ""
|
||||||
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
||||||
log_info "Hermes Agent itself does not require or retain root access."
|
log_info "Hermes Agent itself does not require or retain root access."
|
||||||
read -p "Install ${description}? (requires sudo) [y/N] " -n 1 -r
|
if prompt_yes_no "Install ${description}? (requires sudo)" "no"; then
|
||||||
echo
|
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
|
||||||
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd; then
|
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd; then
|
||||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||||
|
|
@ -621,9 +656,7 @@ install_system_packages() {
|
||||||
echo ""
|
echo ""
|
||||||
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
log_info "sudo is needed ONLY to install optional system packages (${pkgs[*]}) via your package manager."
|
||||||
log_info "Hermes Agent itself does not require or retain root access."
|
log_info "Hermes Agent itself does not require or retain root access."
|
||||||
read -p "Install ${description}? [Y/n] " -n 1 -r < /dev/tty
|
if prompt_yes_no "Install ${description}?" "yes"; then
|
||||||
echo
|
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
|
||||||
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd < /dev/tty; then
|
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd < /dev/tty; then
|
||||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||||
|
|
@ -863,9 +896,7 @@ install_deps() {
|
||||||
else
|
else
|
||||||
log_info "sudo is needed ONLY to install build tools (build-essential, python3-dev, libffi-dev) via apt."
|
log_info "sudo is needed ONLY to install build tools (build-essential, python3-dev, libffi-dev) via apt."
|
||||||
log_info "Hermes Agent itself does not require or retain root access."
|
log_info "Hermes Agent itself does not require or retain root access."
|
||||||
read -p "Install build tools? [Y/n] " -n 1 -r < /dev/tty
|
if prompt_yes_no "Install build tools?" "yes"; then
|
||||||
echo
|
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
|
||||||
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get update -qq && sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get install -y -qq build-essential python3-dev libffi-dev >/dev/null 2>&1 || true
|
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get update -qq && sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a apt-get install -y -qq build-essential python3-dev libffi-dev >/dev/null 2>&1 || true
|
||||||
log_success "Build tools installed"
|
log_success "Build tools installed"
|
||||||
fi
|
fi
|
||||||
|
|
@ -1246,9 +1277,7 @@ maybe_start_gateway() {
|
||||||
log_info "WhatsApp is enabled but not yet paired."
|
log_info "WhatsApp is enabled but not yet paired."
|
||||||
log_info "Running 'hermes whatsapp' to pair via QR code..."
|
log_info "Running 'hermes whatsapp' to pair via QR code..."
|
||||||
echo ""
|
echo ""
|
||||||
read -p "Pair WhatsApp now? [Y/n] " -n 1 -r
|
if prompt_yes_no "Pair WhatsApp now?" "yes"; then
|
||||||
echo
|
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
|
||||||
HERMES_CMD="$(get_hermes_command_path)"
|
HERMES_CMD="$(get_hermes_command_path)"
|
||||||
$HERMES_CMD whatsapp || true
|
$HERMES_CMD whatsapp || true
|
||||||
fi
|
fi
|
||||||
|
|
@ -1263,14 +1292,18 @@ maybe_start_gateway() {
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
|
local should_install_gateway=false
|
||||||
if [ "$DISTRO" = "termux" ]; then
|
if [ "$DISTRO" = "termux" ]; then
|
||||||
read -p "Would you like to start the gateway in the background? [Y/n] " -n 1 -r < /dev/tty
|
if prompt_yes_no "Would you like to start the gateway in the background?" "yes"; then
|
||||||
else
|
should_install_gateway=true
|
||||||
read -p "Would you like to install the gateway as a background service? [Y/n] " -n 1 -r < /dev/tty
|
fi
|
||||||
|
else
|
||||||
|
if prompt_yes_no "Would you like to install the gateway as a background service?" "yes"; then
|
||||||
|
should_install_gateway=true
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
echo
|
|
||||||
|
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
if [ "$should_install_gateway" = true ]; then
|
||||||
HERMES_CMD="$(get_hermes_command_path)"
|
HERMES_CMD="$(get_hermes_command_path)"
|
||||||
|
|
||||||
if [ "$DISTRO" != "termux" ] && command -v systemctl &> /dev/null; then
|
if [ "$DISTRO" != "termux" ] && command -v systemctl &> /dev/null; then
|
||||||
|
|
|
||||||
252
tests/hermes_cli/test_model_switch_opencode_anthropic.py
Normal file
252
tests/hermes_cli/test_model_switch_opencode_anthropic.py
Normal file
|
|
@ -0,0 +1,252 @@
|
||||||
|
"""Regression tests for OpenCode /v1 stripping during /model switch.
|
||||||
|
|
||||||
|
When switching to an Anthropic-routed OpenCode model mid-session (e.g.
|
||||||
|
``/model minimax-m2.7`` on opencode-go, or ``/model claude-sonnet-4-6``
|
||||||
|
on opencode-zen), the resolved base_url must have its trailing ``/v1``
|
||||||
|
stripped before being handed to the Anthropic SDK.
|
||||||
|
|
||||||
|
Without the strip, the SDK prepends its own ``/v1/messages`` path and
|
||||||
|
requests hit ``https://opencode.ai/zen/go/v1/v1/messages`` — a double
|
||||||
|
``/v1`` that returns OpenCode's website 404 page with HTML body.
|
||||||
|
|
||||||
|
``hermes_cli.runtime_provider.resolve_runtime_provider`` already strips
|
||||||
|
``/v1`` at fresh agent init (PR #4918), but the ``/model`` mid-session
|
||||||
|
switch path in ``hermes_cli.model_switch.switch_model`` was missing the
|
||||||
|
same logic — these tests guard against that regression.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from hermes_cli.model_switch import switch_model
|
||||||
|
|
||||||
|
|
||||||
|
_MOCK_VALIDATION = {
|
||||||
|
"accepted": True,
|
||||||
|
"persist": True,
|
||||||
|
"recognized": True,
|
||||||
|
"message": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _run_opencode_switch(
|
||||||
|
raw_input: str,
|
||||||
|
current_provider: str,
|
||||||
|
current_model: str,
|
||||||
|
current_base_url: str,
|
||||||
|
explicit_provider: str = "",
|
||||||
|
runtime_base_url: str = "",
|
||||||
|
):
|
||||||
|
"""Run switch_model with OpenCode mocks and return the result.
|
||||||
|
|
||||||
|
runtime_base_url defaults to current_base_url; tests can override it
|
||||||
|
to simulate the credential resolver returning a base_url different
|
||||||
|
from the session's current one.
|
||||||
|
"""
|
||||||
|
effective_runtime_base = runtime_base_url or current_base_url
|
||||||
|
with (
|
||||||
|
patch("hermes_cli.model_switch.resolve_alias", return_value=None),
|
||||||
|
patch("hermes_cli.model_switch.list_provider_models", return_value=[]),
|
||||||
|
patch(
|
||||||
|
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||||
|
return_value={
|
||||||
|
"api_key": "sk-opencode-fake",
|
||||||
|
"base_url": effective_runtime_base,
|
||||||
|
"api_mode": "chat_completions",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"hermes_cli.models.validate_requested_model",
|
||||||
|
return_value=_MOCK_VALIDATION,
|
||||||
|
),
|
||||||
|
patch("hermes_cli.model_switch.get_model_info", return_value=None),
|
||||||
|
patch("hermes_cli.model_switch.get_model_capabilities", return_value=None),
|
||||||
|
patch("hermes_cli.models.detect_provider_for_model", return_value=None),
|
||||||
|
):
|
||||||
|
return switch_model(
|
||||||
|
raw_input=raw_input,
|
||||||
|
current_provider=current_provider,
|
||||||
|
current_model=current_model,
|
||||||
|
current_base_url=current_base_url,
|
||||||
|
current_api_key="sk-opencode-fake",
|
||||||
|
explicit_provider=explicit_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenCodeGoV1Strip:
|
||||||
|
"""OpenCode Go: ``/model minimax-*`` must strip /v1."""
|
||||||
|
|
||||||
|
def test_switch_to_minimax_m27_strips_v1(self):
|
||||||
|
"""GLM-5 → MiniMax-M2.7: base_url loses trailing /v1."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="minimax-m2.7",
|
||||||
|
current_provider="opencode-go",
|
||||||
|
current_model="glm-5",
|
||||||
|
current_base_url="https://opencode.ai/zen/go/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success, f"switch_model failed: {result.error_message}"
|
||||||
|
assert result.api_mode == "anthropic_messages"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/go", (
|
||||||
|
f"Expected /v1 stripped for anthropic_messages; got {result.base_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_switch_to_minimax_m25_strips_v1(self):
|
||||||
|
"""Same behavior for M2.5."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="minimax-m2.5",
|
||||||
|
current_provider="opencode-go",
|
||||||
|
current_model="kimi-k2.5",
|
||||||
|
current_base_url="https://opencode.ai/zen/go/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "anthropic_messages"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/go"
|
||||||
|
|
||||||
|
def test_switch_to_glm_leaves_v1_intact(self):
|
||||||
|
"""OpenAI-compatible models (GLM, Kimi, MiMo) keep /v1."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="glm-5.1",
|
||||||
|
current_provider="opencode-go",
|
||||||
|
current_model="minimax-m2.7",
|
||||||
|
current_base_url="https://opencode.ai/zen/go", # stripped from previous Anthropic model
|
||||||
|
runtime_base_url="https://opencode.ai/zen/go/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "chat_completions"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/go/v1", (
|
||||||
|
f"chat_completions must keep /v1; got {result.base_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_switch_to_kimi_leaves_v1_intact(self):
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="kimi-k2.5",
|
||||||
|
current_provider="opencode-go",
|
||||||
|
current_model="glm-5",
|
||||||
|
current_base_url="https://opencode.ai/zen/go/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "chat_completions"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/go/v1"
|
||||||
|
|
||||||
|
def test_trailing_slash_also_stripped(self):
|
||||||
|
"""``/v1/`` with trailing slash is also stripped cleanly."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="minimax-m2.7",
|
||||||
|
current_provider="opencode-go",
|
||||||
|
current_model="glm-5",
|
||||||
|
current_base_url="https://opencode.ai/zen/go/v1/",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "anthropic_messages"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/go"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenCodeZenV1Strip:
|
||||||
|
"""OpenCode Zen: ``/model claude-*`` must strip /v1."""
|
||||||
|
|
||||||
|
def test_switch_to_claude_sonnet_strips_v1(self):
|
||||||
|
"""Gemini → Claude on opencode-zen: /v1 stripped."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="claude-sonnet-4-6",
|
||||||
|
current_provider="opencode-zen",
|
||||||
|
current_model="gemini-3-flash",
|
||||||
|
current_base_url="https://opencode.ai/zen/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "anthropic_messages"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen"
|
||||||
|
|
||||||
|
def test_switch_to_gemini_leaves_v1_intact(self):
|
||||||
|
"""Gemini on opencode-zen stays on chat_completions with /v1."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="gemini-3-flash",
|
||||||
|
current_provider="opencode-zen",
|
||||||
|
current_model="claude-sonnet-4-6",
|
||||||
|
current_base_url="https://opencode.ai/zen", # stripped from previous Claude
|
||||||
|
runtime_base_url="https://opencode.ai/zen/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "chat_completions"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/v1"
|
||||||
|
|
||||||
|
def test_switch_to_gpt_uses_codex_responses_keeps_v1(self):
|
||||||
|
"""GPT on opencode-zen uses codex_responses api_mode — /v1 kept."""
|
||||||
|
result = _run_opencode_switch(
|
||||||
|
raw_input="gpt-5.4",
|
||||||
|
current_provider="opencode-zen",
|
||||||
|
current_model="claude-sonnet-4-6",
|
||||||
|
current_base_url="https://opencode.ai/zen",
|
||||||
|
runtime_base_url="https://opencode.ai/zen/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success
|
||||||
|
assert result.api_mode == "codex_responses"
|
||||||
|
assert result.base_url == "https://opencode.ai/zen/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentSwitchModelDefenseInDepth:
|
||||||
|
"""run_agent.AIAgent.switch_model() also strips /v1 as defense-in-depth."""
|
||||||
|
|
||||||
|
def test_agent_switch_model_strips_v1_for_anthropic_messages(self):
|
||||||
|
"""Even if a caller hands in a /v1 URL, the agent strips it."""
|
||||||
|
from run_agent import AIAgent
|
||||||
|
|
||||||
|
# Build a bare agent instance without running __init__; we only want
|
||||||
|
# to exercise switch_model's base_url normalization logic.
|
||||||
|
agent = AIAgent.__new__(AIAgent)
|
||||||
|
agent.model = "glm-5"
|
||||||
|
agent.provider = "opencode-go"
|
||||||
|
agent.base_url = "https://opencode.ai/zen/go/v1"
|
||||||
|
agent.api_key = "sk-opencode-fake"
|
||||||
|
agent.api_mode = "chat_completions"
|
||||||
|
agent._client_kwargs = {}
|
||||||
|
|
||||||
|
# Intercept the expensive client rebuild — we only need to verify
|
||||||
|
# that base_url was normalized before it reached the Anthropic
|
||||||
|
# client factory.
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def _fake_build_anthropic_client(api_key, base_url):
|
||||||
|
captured["api_key"] = api_key
|
||||||
|
captured["base_url"] = base_url
|
||||||
|
return object() # placeholder client — no real calls expected
|
||||||
|
|
||||||
|
# The downstream cache/plumbing touches a bunch of private state
|
||||||
|
# that wasn't initialized above; we don't want to rebuild the full
|
||||||
|
# runtime for this single assertion, so short-circuit after the
|
||||||
|
# strip by raising inside the stubbed factory.
|
||||||
|
class _Sentinel(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _raise_after_capture(api_key, base_url):
|
||||||
|
captured["api_key"] = api_key
|
||||||
|
captured["base_url"] = base_url
|
||||||
|
raise _Sentinel("strip verified")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"agent.anthropic_adapter.build_anthropic_client",
|
||||||
|
side_effect=_raise_after_capture,
|
||||||
|
), patch("agent.anthropic_adapter.resolve_anthropic_token", return_value=""), patch(
|
||||||
|
"agent.anthropic_adapter._is_oauth_token", return_value=False
|
||||||
|
):
|
||||||
|
with pytest.raises(_Sentinel):
|
||||||
|
agent.switch_model(
|
||||||
|
new_model="minimax-m2.7",
|
||||||
|
new_provider="opencode-go",
|
||||||
|
api_key="sk-opencode-fake",
|
||||||
|
base_url="https://opencode.ai/zen/go/v1",
|
||||||
|
api_mode="anthropic_messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured.get("base_url") == "https://opencode.ai/zen/go", (
|
||||||
|
f"agent.switch_model did not strip /v1; passed {captured.get('base_url')} "
|
||||||
|
"to build_anthropic_client"
|
||||||
|
)
|
||||||
|
|
@ -466,3 +466,90 @@ def test_numeric_mcp_server_name_does_not_crash_sorted():
|
||||||
|
|
||||||
# sorted() must not raise TypeError
|
# sorted() must not raise TypeError
|
||||||
sorted(enabled)
|
sorted(enabled)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Imagegen Backend Picker Wiring ────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestImagegenBackendRegistry:
|
||||||
|
"""IMAGEGEN_BACKENDS tags drive the model picker flow in tools_config."""
|
||||||
|
|
||||||
|
def test_fal_backend_registered(self):
|
||||||
|
from hermes_cli.tools_config import IMAGEGEN_BACKENDS
|
||||||
|
assert "fal" in IMAGEGEN_BACKENDS
|
||||||
|
|
||||||
|
def test_fal_catalog_loads_lazily(self):
|
||||||
|
"""catalog_fn should defer import to avoid import cycles."""
|
||||||
|
from hermes_cli.tools_config import IMAGEGEN_BACKENDS
|
||||||
|
catalog, default = IMAGEGEN_BACKENDS["fal"]["catalog_fn"]()
|
||||||
|
assert default == "fal-ai/flux-2/klein/9b"
|
||||||
|
assert "fal-ai/flux-2/klein/9b" in catalog
|
||||||
|
assert "fal-ai/flux-2-pro" in catalog
|
||||||
|
|
||||||
|
def test_image_gen_providers_tagged_with_fal_backend(self):
|
||||||
|
"""Both Nous Subscription and FAL.ai providers must carry the
|
||||||
|
imagegen_backend tag so _configure_provider fires the picker."""
|
||||||
|
from hermes_cli.tools_config import TOOL_CATEGORIES
|
||||||
|
providers = TOOL_CATEGORIES["image_gen"]["providers"]
|
||||||
|
for p in providers:
|
||||||
|
assert p.get("imagegen_backend") == "fal", (
|
||||||
|
f"{p['name']} missing imagegen_backend tag"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestImagegenModelPicker:
|
||||||
|
"""_configure_imagegen_model writes selection to config and respects
|
||||||
|
curses fallback semantics (returns default when stdin isn't a TTY)."""
|
||||||
|
|
||||||
|
def test_picker_writes_chosen_model_to_config(self):
|
||||||
|
from hermes_cli.tools_config import _configure_imagegen_model
|
||||||
|
config = {}
|
||||||
|
# Force _prompt_choice to pick index 1 (second-in-ordered-list).
|
||||||
|
with patch("hermes_cli.tools_config._prompt_choice", return_value=1):
|
||||||
|
_configure_imagegen_model("fal", config)
|
||||||
|
# ordered[0] == current (default klein), ordered[1] == first non-default
|
||||||
|
assert config["image_gen"]["model"] != "fal-ai/flux-2/klein/9b"
|
||||||
|
assert config["image_gen"]["model"].startswith("fal-ai/")
|
||||||
|
|
||||||
|
def test_picker_with_gpt_image_does_not_prompt_quality(self):
|
||||||
|
"""GPT-Image quality is pinned to medium in the tool's defaults —
|
||||||
|
no follow-up prompt, no config write for quality_setting."""
|
||||||
|
from hermes_cli.tools_config import (
|
||||||
|
_configure_imagegen_model,
|
||||||
|
IMAGEGEN_BACKENDS,
|
||||||
|
)
|
||||||
|
catalog, default_model = IMAGEGEN_BACKENDS["fal"]["catalog_fn"]()
|
||||||
|
model_ids = list(catalog.keys())
|
||||||
|
ordered = [default_model] + [m for m in model_ids if m != default_model]
|
||||||
|
gpt_idx = ordered.index("fal-ai/gpt-image-1.5")
|
||||||
|
|
||||||
|
# Only ONE picker call is expected (for model) — not two (model + quality).
|
||||||
|
call_count = {"n": 0}
|
||||||
|
def fake_prompt(*a, **kw):
|
||||||
|
call_count["n"] += 1
|
||||||
|
return gpt_idx
|
||||||
|
|
||||||
|
config = {}
|
||||||
|
with patch("hermes_cli.tools_config._prompt_choice", side_effect=fake_prompt):
|
||||||
|
_configure_imagegen_model("fal", config)
|
||||||
|
|
||||||
|
assert call_count["n"] == 1, (
|
||||||
|
f"Expected 1 picker call (model only), got {call_count['n']}"
|
||||||
|
)
|
||||||
|
assert config["image_gen"]["model"] == "fal-ai/gpt-image-1.5"
|
||||||
|
assert "quality_setting" not in config["image_gen"]
|
||||||
|
|
||||||
|
def test_picker_no_op_for_unknown_backend(self):
|
||||||
|
from hermes_cli.tools_config import _configure_imagegen_model
|
||||||
|
config = {}
|
||||||
|
_configure_imagegen_model("nonexistent-backend", config)
|
||||||
|
assert config == {} # untouched
|
||||||
|
|
||||||
|
def test_picker_repairs_corrupt_config_section(self):
|
||||||
|
"""When image_gen is a non-dict (user-edit YAML), the picker should
|
||||||
|
replace it with a fresh dict rather than crash."""
|
||||||
|
from hermes_cli.tools_config import _configure_imagegen_model
|
||||||
|
config = {"image_gen": "some-garbage-string"}
|
||||||
|
with patch("hermes_cli.tools_config._prompt_choice", return_value=0):
|
||||||
|
_configure_imagegen_model("fal", config)
|
||||||
|
assert isinstance(config["image_gen"], dict)
|
||||||
|
assert config["image_gen"]["model"] == "fal-ai/flux-2/klein/9b"
|
||||||
|
|
|
||||||
473
tests/tools/test_file_sync_back.py
Normal file
473
tests/tools/test_file_sync_back.py
Normal file
|
|
@ -0,0 +1,473 @@
|
||||||
|
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||||
|
|
||||||
|
import fcntl
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import tarfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments.file_sync import (
|
||||||
|
FileSyncManager,
|
||||||
|
_sha256_file,
|
||||||
|
_SYNC_BACK_BACKOFF,
|
||||||
|
_SYNC_BACK_MAX_RETRIES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_tar(files: dict[str, bytes], dest: Path):
|
||||||
|
"""Write a tar archive containing the given arcname->content pairs."""
|
||||||
|
with tarfile.open(dest, "w") as tar:
|
||||||
|
for arcname, content in files.items():
|
||||||
|
info = tarfile.TarInfo(name=arcname)
|
||||||
|
info.size = len(content)
|
||||||
|
tar.addfile(info, io.BytesIO(content))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_download_fn(files: dict[str, bytes]):
|
||||||
|
"""Return a bulk_download_fn that writes a tar of the given files."""
|
||||||
|
def download(dest: Path):
|
||||||
|
_make_tar(files, dest)
|
||||||
|
return download
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_bytes(data: bytes) -> str:
|
||||||
|
"""Compute SHA-256 hex digest of raw bytes (for test convenience)."""
|
||||||
|
import hashlib
|
||||||
|
return hashlib.sha256(data).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _write_file(path: Path, content: bytes) -> str:
|
||||||
|
"""Write bytes to *path*, creating parents, and return the string path."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_bytes(content)
|
||||||
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(
|
||||||
|
tmp_path: Path,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None,
|
||||||
|
bulk_download_fn=None,
|
||||||
|
seed_pushed_state: bool = True,
|
||||||
|
) -> FileSyncManager:
|
||||||
|
"""Create a FileSyncManager wired for testing.
|
||||||
|
|
||||||
|
*file_mapping* is a list of (host_path, remote_path) tuples that
|
||||||
|
``get_files_fn`` returns. If *None* an empty list is used.
|
||||||
|
|
||||||
|
When *seed_pushed_state* is True (default), populate ``_pushed_hashes``
|
||||||
|
from the mapping so sync_back doesn't early-return on the "nothing
|
||||||
|
previously pushed" guard. Set False to test the noop path.
|
||||||
|
"""
|
||||||
|
mapping = file_mapping or []
|
||||||
|
mgr = FileSyncManager(
|
||||||
|
get_files_fn=lambda: mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
bulk_download_fn=bulk_download_fn,
|
||||||
|
)
|
||||||
|
if seed_pushed_state:
|
||||||
|
# Seed _pushed_hashes so sync_back's "nothing previously pushed"
|
||||||
|
# guard does not early-return. Populate from the mapping when we
|
||||||
|
# can; otherwise drop a sentinel entry.
|
||||||
|
for host_path, remote_path in mapping:
|
||||||
|
if os.path.exists(host_path):
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||||
|
else:
|
||||||
|
mgr._pushed_hashes[remote_path] = "0" * 64
|
||||||
|
if not mgr._pushed_hashes:
|
||||||
|
mgr._pushed_hashes["/_sentinel"] = "0" * 64
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNoop:
|
||||||
|
"""sync_back() is a no-op when there is no download function."""
|
||||||
|
|
||||||
|
def test_sync_back_noop_without_download_fn(self, tmp_path):
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=None)
|
||||||
|
# Should return immediately without error
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
# Nothing to assert beyond "no exception raised"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNoChanges:
|
||||||
|
"""When all remote files match pushed hashes, nothing is applied."""
|
||||||
|
|
||||||
|
def test_sync_back_no_changes(self, tmp_path):
|
||||||
|
host_file = tmp_path / "host" / "cred.json"
|
||||||
|
host_content = b'{"key": "val"}'
|
||||||
|
_write_file(host_file, host_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/cred.json"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
# Remote tar contains the same content as was pushed
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/cred.json": host_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
# Simulate that we already pushed this file with this hash
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(host_content)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# Host file should be unchanged (same content, same bytes)
|
||||||
|
assert host_file.read_bytes() == host_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackAppliesChanged:
|
||||||
|
"""Remote file differs from pushed version -- gets copied to host."""
|
||||||
|
|
||||||
|
def test_sync_back_applies_changed_file(self, tmp_path):
|
||||||
|
host_file = tmp_path / "host" / "skill.py"
|
||||||
|
original_content = b"print('v1')"
|
||||||
|
_write_file(host_file, original_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/skill.py"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
remote_content = b"print('v2 - edited on remote')"
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/skill.py": remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
assert host_file.read_bytes() == remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackNewRemoteFile:
|
||||||
|
"""File created on remote (not in _pushed_hashes) is applied via _infer_host_path."""
|
||||||
|
|
||||||
|
def test_sync_back_detects_new_remote_file(self, tmp_path):
|
||||||
|
# Existing mapping gives _infer_host_path a prefix to work with
|
||||||
|
existing_host = tmp_path / "host" / "skills" / "existing.py"
|
||||||
|
_write_file(existing_host, b"existing")
|
||||||
|
mapping = [(str(existing_host), "/root/.hermes/skills/existing.py")]
|
||||||
|
|
||||||
|
# Remote has a NEW file in the same directory that was never pushed
|
||||||
|
new_remote_content = b"# brand new skill created on remote"
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/skills/new_skill.py": new_remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
# No entry in _pushed_hashes for the new file
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# The new file should have been inferred and written to the host
|
||||||
|
expected_host_path = tmp_path / "host" / "skills" / "new_skill.py"
|
||||||
|
assert expected_host_path.exists()
|
||||||
|
assert expected_host_path.read_bytes() == new_remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackConflict:
|
||||||
|
"""Host AND remote both changed since push -- warning logged, remote wins."""
|
||||||
|
|
||||||
|
def test_sync_back_conflict_warns(self, tmp_path, caplog):
|
||||||
|
host_file = tmp_path / "host" / "config.json"
|
||||||
|
original_content = b'{"v": 1}'
|
||||||
|
_write_file(host_file, original_content)
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/config.json"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
# Host was modified after push
|
||||||
|
host_file.write_bytes(b'{"v": 2, "host-edit": true}')
|
||||||
|
|
||||||
|
# Remote was also modified
|
||||||
|
remote_content = b'{"v": 3, "remote-edit": true}'
|
||||||
|
download_fn = _make_download_fn({
|
||||||
|
"root/.hermes/config.json": remote_content,
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
|
||||||
|
mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# Conflict warning was logged
|
||||||
|
assert any("conflict" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
# Remote version wins (last-write-wins)
|
||||||
|
assert host_file.read_bytes() == remote_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackRetries:
|
||||||
|
"""Retry behaviour with exponential backoff."""
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.time.sleep")
|
||||||
|
def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def flaky_download(dest: Path):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise RuntimeError(f"network error #{call_count}")
|
||||||
|
# Third attempt succeeds -- write a valid (empty) tar
|
||||||
|
_make_tar({}, dest)
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download)
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
assert call_count == 3
|
||||||
|
# Sleep called twice (between attempt 1->2 and 2->3)
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
|
||||||
|
mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.time.sleep")
|
||||||
|
def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
|
||||||
|
def always_fail(dest: Path):
|
||||||
|
raise RuntimeError("persistent failure")
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=always_fail)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
|
||||||
|
# Should NOT raise -- failures are logged, not propagated
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# All retries were attempted
|
||||||
|
assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1
|
||||||
|
|
||||||
|
# Final "all attempts failed" warning was logged
|
||||||
|
assert any("all" in r.message.lower() and "failed" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPushedHashesPopulated:
|
||||||
|
"""_pushed_hashes is populated during sync() and cleared on delete."""
|
||||||
|
|
||||||
|
def test_pushed_hashes_populated_on_sync(self, tmp_path):
|
||||||
|
host_file = tmp_path / "data.txt"
|
||||||
|
host_file.write_bytes(b"hello world")
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/data.txt"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
|
||||||
|
mgr = FileSyncManager(
|
||||||
|
get_files_fn=lambda: mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mgr.sync(force=True)
|
||||||
|
|
||||||
|
assert remote_path in mgr._pushed_hashes
|
||||||
|
assert mgr._pushed_hashes[remote_path] == _sha256_file(str(host_file))
|
||||||
|
|
||||||
|
def test_pushed_hashes_cleared_on_delete(self, tmp_path):
|
||||||
|
host_file = tmp_path / "deleteme.txt"
|
||||||
|
host_file.write_bytes(b"to be deleted")
|
||||||
|
|
||||||
|
remote_path = "/root/.hermes/deleteme.txt"
|
||||||
|
mapping = [(str(host_file), remote_path)]
|
||||||
|
current_mapping = list(mapping)
|
||||||
|
|
||||||
|
mgr = FileSyncManager(
|
||||||
|
get_files_fn=lambda: current_mapping,
|
||||||
|
upload_fn=MagicMock(),
|
||||||
|
delete_fn=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sync to populate hashes
|
||||||
|
mgr.sync(force=True)
|
||||||
|
assert remote_path in mgr._pushed_hashes
|
||||||
|
|
||||||
|
# Remove the file from the mapping (simulates local deletion)
|
||||||
|
os.unlink(str(host_file))
|
||||||
|
current_mapping.clear()
|
||||||
|
|
||||||
|
mgr.sync(force=True)
|
||||||
|
|
||||||
|
# Hash should be cleaned up
|
||||||
|
assert remote_path not in mgr._pushed_hashes
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackFileLock:
|
||||||
|
"""Verify that fcntl.flock is used during sync-back."""
|
||||||
|
|
||||||
|
@patch("tools.environments.file_sync.fcntl.flock")
|
||||||
|
def test_sync_back_file_lock(self, mock_flock, tmp_path):
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# flock should have been called at least twice: LOCK_EX to acquire, LOCK_UN to release
|
||||||
|
assert mock_flock.call_count >= 2
|
||||||
|
|
||||||
|
lock_calls = mock_flock.call_args_list
|
||||||
|
lock_ops = [c[0][1] for c in lock_calls]
|
||||||
|
assert fcntl.LOCK_EX in lock_ops
|
||||||
|
assert fcntl.LOCK_UN in lock_ops
|
||||||
|
|
||||||
|
def test_sync_back_skips_flock_when_fcntl_none(self, tmp_path):
|
||||||
|
"""On Windows (fcntl=None), sync_back should skip file locking."""
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.fcntl", None):
|
||||||
|
# Should not raise — locking is skipped
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferHostPath:
|
||||||
|
"""Edge cases for _infer_host_path prefix matching."""
|
||||||
|
|
||||||
|
def test_infer_no_matching_prefix(self, tmp_path):
|
||||||
|
"""Remote path in unmapped directory should return None."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/cache/new.json",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_infer_partial_prefix_no_false_match(self, tmp_path):
|
||||||
|
"""A partial prefix like /root/.hermes/sk should NOT match /root/.hermes/skills/."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
# /root/.hermes/skillsXtra/b.py shares prefix "skills" but the
|
||||||
|
# directory is different — should not match /root/.hermes/skills/
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/skillsXtra/b.py",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_infer_matching_prefix(self, tmp_path):
|
||||||
|
"""A file in a mapped directory should be correctly inferred."""
|
||||||
|
host_file = tmp_path / "host" / "skills" / "a.py"
|
||||||
|
_write_file(host_file, b"content")
|
||||||
|
mapping = [(str(host_file), "/root/.hermes/skills/a.py")]
|
||||||
|
|
||||||
|
mgr = _make_manager(tmp_path, file_mapping=mapping)
|
||||||
|
result = mgr._infer_host_path(
|
||||||
|
"/root/.hermes/skills/b.py",
|
||||||
|
file_mapping=mapping,
|
||||||
|
)
|
||||||
|
expected = str(tmp_path / "host" / "skills" / "b.py")
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackSIGINT:
|
||||||
|
"""SIGINT deferral during sync-back."""
|
||||||
|
|
||||||
|
def test_sync_back_defers_sigint_on_main_thread(self, tmp_path):
|
||||||
|
"""On the main thread, SIGINT handler should be swapped during sync."""
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
handlers_seen = []
|
||||||
|
original_getsignal = signal.getsignal
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.signal.getsignal",
|
||||||
|
side_effect=original_getsignal) as mock_get, \
|
||||||
|
patch("tools.environments.file_sync.signal.signal") as mock_set:
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# signal.getsignal was called to save the original handler
|
||||||
|
assert mock_get.called
|
||||||
|
# signal.signal was called at least twice: install defer, restore original
|
||||||
|
assert mock_set.call_count >= 2
|
||||||
|
|
||||||
|
def test_sync_back_skips_signal_on_worker_thread(self, tmp_path):
|
||||||
|
"""From a non-main thread, signal.signal should NOT be called."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
download_fn = _make_download_fn({})
|
||||||
|
mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)
|
||||||
|
|
||||||
|
signal_called = []
|
||||||
|
|
||||||
|
def tracking_signal(*args):
|
||||||
|
signal_called.append(args)
|
||||||
|
|
||||||
|
with patch("tools.environments.file_sync.signal.signal", side_effect=tracking_signal):
|
||||||
|
# Run from a worker thread
|
||||||
|
exc = []
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
except Exception as e:
|
||||||
|
exc.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=run)
|
||||||
|
t.start()
|
||||||
|
t.join(timeout=10)
|
||||||
|
|
||||||
|
assert not exc, f"sync_back raised: {exc}"
|
||||||
|
# signal.signal should NOT have been called from the worker thread
|
||||||
|
assert len(signal_called) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncBackSizeCap:
|
||||||
|
"""The size cap refuses to extract tars above the configured limit."""
|
||||||
|
|
||||||
|
def test_sync_back_refuses_oversized_tar(self, tmp_path, caplog):
|
||||||
|
"""A tar larger than _SYNC_BACK_MAX_BYTES should be skipped with a warning."""
|
||||||
|
# Build a download_fn that writes a small tar, but patch the cap
|
||||||
|
# so the test doesn't need to produce a 2 GiB file.
|
||||||
|
skill_host = _write_file(tmp_path / "host_skill.md", b"original")
|
||||||
|
files = {"root/.hermes/skill.md": b"remote_version"}
|
||||||
|
download_fn = _make_download_fn(files)
|
||||||
|
|
||||||
|
mgr = _make_manager(
|
||||||
|
tmp_path,
|
||||||
|
file_mapping=[(skill_host, "/root/.hermes/skill.md")],
|
||||||
|
bulk_download_fn=download_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cap at 1 byte so any non-empty tar exceeds it
|
||||||
|
with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
|
||||||
|
with patch("tools.environments.file_sync._SYNC_BACK_MAX_BYTES", 1):
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
|
||||||
|
# Host file should be untouched because extraction was skipped
|
||||||
|
assert Path(skill_host).read_bytes() == b"original"
|
||||||
|
# Warning should mention the cap
|
||||||
|
assert any("cap" in r.message for r in caplog.records)
|
||||||
|
|
||||||
|
def test_sync_back_applies_when_under_cap(self, tmp_path):
|
||||||
|
"""A tar under the cap should extract normally (sanity check)."""
|
||||||
|
host_file = _write_file(tmp_path / "host_skill.md", b"original")
|
||||||
|
files = {"root/.hermes/skill.md": b"remote_version"}
|
||||||
|
download_fn = _make_download_fn(files)
|
||||||
|
|
||||||
|
mgr = _make_manager(
|
||||||
|
tmp_path,
|
||||||
|
file_mapping=[(host_file, "/root/.hermes/skill.md")],
|
||||||
|
bulk_download_fn=download_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default cap (2 GiB) is far above our tiny tar; extraction should proceed
|
||||||
|
mgr.sync_back(hermes_home=tmp_path / ".hermes")
|
||||||
|
assert Path(host_file).read_bytes() == b"remote_version"
|
||||||
450
tests/tools/test_image_generation.py
Normal file
450
tests/tools/test_image_generation.py
Normal file
|
|
@ -0,0 +1,450 @@
|
||||||
|
"""Tests for tools/image_generation_tool.py — FAL multi-model support.
|
||||||
|
|
||||||
|
Covers the pure logic of the new wrapper: catalog integrity, the three size
|
||||||
|
families (image_size_preset / aspect_ratio / gpt_literal), the supports
|
||||||
|
whitelist, default merging, GPT quality override, and model resolution
|
||||||
|
fallback. Does NOT exercise fal_client submission — that's covered by
|
||||||
|
tests/tools/test_managed_media_gateways.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def image_tool():
|
||||||
|
"""Fresh import of tools.image_generation_tool per test."""
|
||||||
|
import importlib
|
||||||
|
import tools.image_generation_tool as mod
|
||||||
|
return importlib.reload(mod)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Catalog integrity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFalCatalog:
|
||||||
|
"""Every FAL_MODELS entry must have a consistent shape."""
|
||||||
|
|
||||||
|
def test_default_model_is_klein(self, image_tool):
|
||||||
|
assert image_tool.DEFAULT_MODEL == "fal-ai/flux-2/klein/9b"
|
||||||
|
|
||||||
|
def test_default_model_in_catalog(self, image_tool):
|
||||||
|
assert image_tool.DEFAULT_MODEL in image_tool.FAL_MODELS
|
||||||
|
|
||||||
|
def test_all_entries_have_required_keys(self, image_tool):
|
||||||
|
required = {
|
||||||
|
"display", "speed", "strengths", "price",
|
||||||
|
"size_style", "sizes", "defaults", "supports", "upscale",
|
||||||
|
}
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
missing = required - set(meta.keys())
|
||||||
|
assert not missing, f"{mid} missing required keys: {missing}"
|
||||||
|
|
||||||
|
def test_size_style_is_valid(self, image_tool):
|
||||||
|
valid = {"image_size_preset", "aspect_ratio", "gpt_literal"}
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
assert meta["size_style"] in valid, \
|
||||||
|
f"{mid} has invalid size_style: {meta['size_style']}"
|
||||||
|
|
||||||
|
def test_sizes_cover_all_aspect_ratios(self, image_tool):
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
assert set(meta["sizes"].keys()) >= {"landscape", "square", "portrait"}, \
|
||||||
|
f"{mid} missing a required aspect_ratio key"
|
||||||
|
|
||||||
|
def test_supports_is_a_set(self, image_tool):
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
assert isinstance(meta["supports"], set), \
|
||||||
|
f"{mid}.supports must be a set, got {type(meta['supports'])}"
|
||||||
|
|
||||||
|
def test_prompt_is_always_supported(self, image_tool):
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
assert "prompt" in meta["supports"], \
|
||||||
|
f"{mid} must support 'prompt'"
|
||||||
|
|
||||||
|
def test_only_flux2_pro_upscales_by_default(self, image_tool):
|
||||||
|
"""Upscaling should default to False for all new models to preserve
|
||||||
|
the <1s / fast-render value prop. Only flux-2-pro stays True for
|
||||||
|
backward-compat with the previous default."""
|
||||||
|
for mid, meta in image_tool.FAL_MODELS.items():
|
||||||
|
if mid == "fal-ai/flux-2-pro":
|
||||||
|
assert meta["upscale"] is True, \
|
||||||
|
"flux-2-pro should keep upscale=True for backward-compat"
|
||||||
|
else:
|
||||||
|
assert meta["upscale"] is False, \
|
||||||
|
f"{mid} should default to upscale=False"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Payload building — three size families
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestImageSizePresetFamily:
|
||||||
|
"""Flux, z-image, qwen, recraft, ideogram all use preset enum sizes."""
|
||||||
|
|
||||||
|
def test_klein_landscape_uses_preset(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "landscape")
|
||||||
|
assert p["image_size"] == "landscape_16_9"
|
||||||
|
assert "aspect_ratio" not in p
|
||||||
|
|
||||||
|
def test_klein_square_uses_preset(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "square")
|
||||||
|
assert p["image_size"] == "square_hd"
|
||||||
|
|
||||||
|
def test_klein_portrait_uses_preset(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hello", "portrait")
|
||||||
|
assert p["image_size"] == "portrait_16_9"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAspectRatioFamily:
|
||||||
|
"""Nano-banana uses aspect_ratio enum, NOT image_size."""
|
||||||
|
|
||||||
|
def test_nano_banana_landscape_uses_aspect_ratio(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "landscape")
|
||||||
|
assert p["aspect_ratio"] == "16:9"
|
||||||
|
assert "image_size" not in p
|
||||||
|
|
||||||
|
def test_nano_banana_square_uses_aspect_ratio(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "square")
|
||||||
|
assert p["aspect_ratio"] == "1:1"
|
||||||
|
|
||||||
|
def test_nano_banana_portrait_uses_aspect_ratio(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/nano-banana", "hello", "portrait")
|
||||||
|
assert p["aspect_ratio"] == "9:16"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGptLiteralFamily:
|
||||||
|
"""GPT-Image 1.5 uses literal size strings."""
|
||||||
|
|
||||||
|
def test_gpt_landscape_is_literal(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "landscape")
|
||||||
|
assert p["image_size"] == "1536x1024"
|
||||||
|
|
||||||
|
def test_gpt_square_is_literal(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "square")
|
||||||
|
assert p["image_size"] == "1024x1024"
|
||||||
|
|
||||||
|
def test_gpt_portrait_is_literal(self, image_tool):
|
||||||
|
p = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hello", "portrait")
|
||||||
|
assert p["image_size"] == "1024x1536"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Supports whitelist — the main safety property
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSupportsFilter:
    """The payload builder must never emit keys outside a model's `supports` set."""

    def test_payload_keys_are_subset_of_supports_for_all_models(self, image_tool):
        for model_id, spec in image_tool.FAL_MODELS.items():
            payload = image_tool._build_fal_payload(model_id, "test", "landscape", seed=42)
            unsupported = set(payload.keys()) - spec["supports"]
            assert not unsupported, \
                f"{model_id} payload has unsupported keys: {unsupported}"

    def test_gpt_image_has_no_seed_even_if_passed(self, image_tool):
        # GPT-Image 1.5 does not support seed — the filter must strip it.
        payload = image_tool._build_fal_payload(
            "fal-ai/gpt-image-1.5", "hi", "square", seed=42
        )
        assert "seed" not in payload

    def test_gpt_image_strips_unsupported_overrides(self, image_tool):
        payload = image_tool._build_fal_payload(
            "fal-ai/gpt-image-1.5",
            "hi",
            "square",
            overrides={"guidance_scale": 7.5, "num_inference_steps": 50},
        )
        assert "guidance_scale" not in payload
        assert "num_inference_steps" not in payload

    def test_recraft_has_minimal_payload(self, image_tool):
        # Recraft supports prompt, image_size, style only.
        payload = image_tool._build_fal_payload("fal-ai/recraft-v3", "hi", "landscape")
        assert set(payload.keys()) <= {"prompt", "image_size", "style"}

    def test_nano_banana_never_gets_image_size(self, image_tool):
        # Common bug: translator accidentally setting both image_size and aspect_ratio.
        payload = image_tool._build_fal_payload(
            "fal-ai/nano-banana", "hi", "landscape", seed=1
        )
        assert "image_size" not in payload
        assert payload["aspect_ratio"] == "16:9"
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Default merging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDefaults:
    """Per-model defaults flow into the payload unless the caller overrides them."""

    def test_klein_default_steps_is_4(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "square")
        assert payload["num_inference_steps"] == 4

    def test_flux_2_pro_default_steps_is_50(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2-pro", "hi", "square")
        assert payload["num_inference_steps"] == 50

    def test_override_replaces_default(self, image_tool):
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro", "hi", "square", overrides={"num_inference_steps": 25}
        )
        assert payload["num_inference_steps"] == 25

    def test_none_override_does_not_replace_default(self, image_tool):
        """A None-valued override from the caller is ignored; the default wins."""
        payload = image_tool._build_fal_payload(
            "fal-ai/flux-2-pro",
            "hi",
            "square",
            overrides={"num_inference_steps": None},
        )
        assert payload["num_inference_steps"] == 50
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GPT-Image quality is pinned to medium (not user-configurable)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGptQualityPinnedToMedium:
    """GPT-Image quality is a static 'medium' default in FAL_MODELS with no
    config override path — pinned so Nous Portal billing stays predictable
    for every user."""

    def test_gpt_payload_always_has_medium_quality(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_config_quality_setting_is_ignored(self, image_tool):
        """Even a hand-edited config.yaml with quality_setting must not change
        the payload — no code path reads that field."""
        fake_config = {"image_gen": {"quality_setting": "high"}}
        with patch("hermes_cli.config.load_config", return_value=fake_config):
            payload = image_tool._build_fal_payload("fal-ai/gpt-image-1.5", "hi", "square")
        assert payload["quality"] == "medium"

    def test_non_gpt_model_never_gets_quality(self, image_tool):
        """Only gpt-image-1.5 understands quality; every other model's payload
        must omit it."""
        for model_id in image_tool.FAL_MODELS:
            if model_id == "fal-ai/gpt-image-1.5":
                continue
            payload = image_tool._build_fal_payload(model_id, "hi", "square")
            assert "quality" not in payload, f"{model_id} unexpectedly has 'quality' in payload"

    def test_honors_quality_setting_flag_is_removed(self, image_tool):
        """honors_quality_setting was the old override trigger; no model entry
        may still carry it."""
        for model_id, spec in image_tool.FAL_MODELS.items():
            assert "honors_quality_setting" not in spec, (
                f"{model_id} still has honors_quality_setting; "
                f"remove it — quality is pinned to medium"
            )

    def test_resolve_gpt_quality_function_is_gone(self, image_tool):
        """_resolve_gpt_quality() was deleted — quality is a static default,
        not a runtime lookup."""
        assert not hasattr(image_tool, "_resolve_gpt_quality"), (
            "_resolve_gpt_quality should not exist — quality is pinned"
        )
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model resolution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestModelResolution:
    """Precedence for picking the FAL model: config > env var > built-in default."""

    def test_no_config_falls_back_to_default(self, image_tool):
        with patch("hermes_cli.config.load_config", return_value={}):
            model_id, meta = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2/klein/9b"

    def test_valid_config_model_is_used(self, image_tool):
        config = {"image_gen": {"model": "fal-ai/flux-2-pro"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, meta = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2-pro"
        assert meta["upscale"] is True  # flux-2-pro keeps backward-compat upscaling

    def test_unknown_model_falls_back_to_default_with_warning(self, image_tool, caplog):
        config = {"image_gen": {"model": "fal-ai/nonexistent-9000"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/flux-2/klein/9b"

    def test_env_var_fallback_when_no_config(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        with patch("hermes_cli.config.load_config", return_value={}):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/z-image/turbo"

    def test_config_wins_over_env_var(self, image_tool, monkeypatch):
        monkeypatch.setenv("FAL_IMAGE_MODEL", "fal-ai/z-image/turbo")
        config = {"image_gen": {"model": "fal-ai/nano-banana"}}
        with patch("hermes_cli.config.load_config", return_value=config):
            model_id, _ = image_tool._resolve_fal_model()
        assert model_id == "fal-ai/nano-banana"
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Aspect ratio handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAspectRatioNormalization:
    """Aspect-ratio input is case-insensitive and falls back to landscape."""

    def test_invalid_aspect_defaults_to_landscape(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "cinemascope")
        assert payload["image_size"] == "landscape_16_9"

    def test_uppercase_aspect_is_normalized(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "PORTRAIT")
        assert payload["image_size"] == "portrait_16_9"

    def test_empty_aspect_defaults_to_landscape(self, image_tool):
        payload = image_tool._build_fal_payload("fal-ai/flux-2/klein/9b", "hi", "")
        assert payload["image_size"] == "landscape_16_9"
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema + registry integrity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRegistryIntegration:
    """Sanity checks on the agent-facing tool schema."""

    def test_schema_exposes_only_prompt_and_aspect_ratio_to_agent(self, image_tool):
        """Model selection is a user-level config choice, not an agent arg, so
        the schema stays tight."""
        properties = image_tool.IMAGE_GENERATE_SCHEMA["parameters"]["properties"]
        assert set(properties.keys()) == {"prompt", "aspect_ratio"}

    def test_aspect_ratio_enum_is_three_values(self, image_tool):
        schema = image_tool.IMAGE_GENERATE_SCHEMA
        enum_values = schema["parameters"]["properties"]["aspect_ratio"]["enum"]
        assert set(enum_values) == {"landscape", "square", "portrait"}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Managed gateway 4xx translation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class _MockResponse:
|
||||||
|
def __init__(self, status_code: int):
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
|
||||||
|
class _MockHttpxError(Exception):
    """Simulates httpx.HTTPStatusError which exposes .response.status_code."""

    def __init__(self, status_code: int, message: str = "Bad Request") -> None:
        super().__init__(message)
        # Attach a response-shaped object so status-extraction helpers can
        # read err.response.status_code exactly as with a real httpx error.
        self.response = _MockResponse(status_code)
|
||||||
|
class TestExtractHttpStatus:
    """_extract_http_status must cope with several exception shapes."""

    def test_extracts_from_response_attr(self, image_tool):
        err = _MockHttpxError(403)
        assert image_tool._extract_http_status(err) == 403

    def test_extracts_from_status_code_attr(self, image_tool):
        err = Exception("fail")
        err.status_code = 404  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(err) == 404

    def test_returns_none_for_non_http_exception(self, image_tool):
        assert image_tool._extract_http_status(ValueError("nope")) is None
        assert image_tool._extract_http_status(RuntimeError("nope")) is None

    def test_response_attr_without_status_code_returns_none(self, image_tool):
        class OddResponse:
            pass

        err = Exception("weird")
        err.response = OddResponse()  # type: ignore[attr-defined]
        assert image_tool._extract_http_status(err) is None
|
||||||
|
class TestManagedGatewayErrorTranslation:
    """4xx from the Nous managed gateway should be translated to a user-actionable message."""

    def test_4xx_translates_to_value_error_with_remediation(self, image_tool, monkeypatch):
        """403 from managed gateway → ValueError mentioning FAL_KEY + hermes tools."""
        from unittest.mock import MagicMock

        # Simulate: managed mode active, managed submit raises 4xx.
        gateway = MagicMock()
        gateway.gateway_origin = "https://fal-queue-gateway.example.com"
        gateway.nous_user_token = "test-token"
        monkeypatch.setattr(
            image_tool, "_resolve_managed_fal_gateway", lambda: gateway
        )

        forbidden = _MockHttpxError(403, "Forbidden")
        managed_client = MagicMock()
        managed_client.submit.side_effect = forbidden
        monkeypatch.setattr(
            image_tool, "_get_managed_fal_client", lambda gw: managed_client
        )

        with pytest.raises(ValueError) as exc_info:
            image_tool._submit_fal_request("fal-ai/nano-banana", {"prompt": "x"})

        message = str(exc_info.value)
        assert "fal-ai/nano-banana" in message
        assert "403" in message
        assert "FAL_KEY" in message
        assert "hermes tools" in message
        # Original exception chained for debugging
        assert exc_info.value.__cause__ is forbidden

    def test_5xx_is_not_translated(self, image_tool, monkeypatch):
        """500s are real outages, not model-availability issues — don't rewrite them."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        monkeypatch.setattr(
            image_tool, "_resolve_managed_fal_gateway", lambda: gateway
        )

        bad_gateway = _MockHttpxError(502, "Bad Gateway")
        managed_client = MagicMock()
        managed_client.submit.side_effect = bad_gateway
        monkeypatch.setattr(
            image_tool, "_get_managed_fal_client", lambda gw: managed_client
        )

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_direct_fal_errors_are_not_translated(self, image_tool, monkeypatch):
        """When user has direct FAL_KEY (managed gateway returns None), raw
        errors from fal_client bubble up unchanged — fal_client already
        provides reasonable error messages for direct usage."""
        from unittest.mock import MagicMock

        monkeypatch.setattr(
            image_tool, "_resolve_managed_fal_gateway", lambda: None
        )

        forbidden = _MockHttpxError(403, "Forbidden")
        fake_fal_client = MagicMock()
        fake_fal_client.submit.side_effect = forbidden
        monkeypatch.setattr(image_tool, "fal_client", fake_fal_client)

        with pytest.raises(_MockHttpxError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})

    def test_non_http_exception_from_managed_bubbles_up(self, image_tool, monkeypatch):
        """Connection errors, timeouts, etc. from managed mode aren't 4xx —
        they should bubble up unchanged so callers can retry or diagnose."""
        from unittest.mock import MagicMock

        gateway = MagicMock()
        monkeypatch.setattr(
            image_tool, "_resolve_managed_fal_gateway", lambda: gateway
        )

        network_down = ConnectionError("network down")
        managed_client = MagicMock()
        managed_client.submit.side_effect = network_down
        monkeypatch.setattr(
            image_tool, "_get_managed_fal_client", lambda gw: managed_client
        )

        with pytest.raises(ConnectionError):
            image_tool._submit_fal_request("fal-ai/flux-2-pro", {"prompt": "x"})
||||||
495
tests/tools/test_sync_back_backends.py
Normal file
495
tests/tools/test_sync_back_backends.py
Normal file
|
|
@ -0,0 +1,495 @@
|
||||||
|
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.environments import ssh as ssh_env
|
||||||
|
from tools.environments import modal as modal_env
|
||||||
|
from tools.environments import daytona as daytona_env
|
||||||
|
from tools.environments.ssh import SSHEnvironment
|
||||||
|
|
||||||
|
|
||||||
|
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Build an SSHEnvironment whose connection setup and file sync are stubbed out."""
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    # No-op sync manager: both sync() and sync_back() do nothing.
    stub_factory = lambda **kw: type("M", (), {
        "sync": lambda self, **k: None,
        "sync_back": lambda self: None,
    })()
    monkeypatch.setattr(ssh_env, "FileSyncManager", stub_factory)
    return SSHEnvironment(host="example.com", user="testuser")
|
||||||
|
# ── Modal helpers ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_modal_env():
    """Build a bare ModalEnvironment (bypassing __init__) with mocked internals."""
    # object.__new__ skips __init__ and any sandbox provisioning it would do.
    env = object.__new__(modal_env.ModalEnvironment)
    env._sandbox = MagicMock()
    env._worker = MagicMock()
    env._persistent = False
    env._task_id = "test"
    env._sync_manager = None
    return env
|
||||||
|
|
||||||
|
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
|
||||||
|
"""Wire sandbox.exec.aio to return mock tar output for download tests.
|
||||||
|
|
||||||
|
Returns the exec_calls list for assertion.
|
||||||
|
"""
|
||||||
|
exec_calls = []
|
||||||
|
|
||||||
|
async def mock_exec_fn(*args, **kwargs):
|
||||||
|
exec_calls.append(args)
|
||||||
|
proc = MagicMock()
|
||||||
|
proc.stdout = MagicMock()
|
||||||
|
proc.stdout.read = MagicMock()
|
||||||
|
proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
|
||||||
|
proc.wait = MagicMock()
|
||||||
|
proc.wait.aio = AsyncMock(return_value=exit_code)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
env._sandbox.exec = MagicMock()
|
||||||
|
env._sandbox.exec.aio = mock_exec_fn
|
||||||
|
|
||||||
|
def real_run_coroutine(coro, **kwargs):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(coro)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
env._worker.run_coroutine = real_run_coroutine
|
||||||
|
return exec_calls
|
||||||
|
|
||||||
|
|
||||||
|
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_daytona_env():
    """Build a bare DaytonaEnvironment (bypassing __init__) with mocked internals."""
    # object.__new__ skips __init__ and any sandbox provisioning it would do.
    env = object.__new__(daytona_env.DaytonaEnvironment)
    env._sandbox = MagicMock()
    env._remote_home = "/root"
    env._sync_manager = None
    env._lock = __import__("threading").Lock()
    env._persistent = True
    env._task_id = "test"
    env._daytona = MagicMock()
    return env
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# SSH bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSHBulkDownload:
    """Unit tests for _ssh_bulk_download."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """subprocess.run command should include tar cf - over SSH."""
        dest = tmp_path / "backup.tar"

        completed = subprocess.CompletedProcess([], 0)
        with patch.object(subprocess, "run", return_value=completed) as mock_run:
            # open() will be called to write stdout; mock it to avoid actual file I/O
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        joined_cmd = " ".join(mock_run.call_args[0][0])
        assert "tar cf -" in joined_cmd
        assert "-C /" in joined_cmd
        assert "home/testuser/.hermes" in joined_cmd
        assert "ssh" in joined_cmd
        assert "testuser@example.com" in joined_cmd

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """subprocess.run should receive stdout=open(dest, 'wb')."""
        dest = tmp_path / "backup.tar"

        completed = subprocess.CompletedProcess([], 0)
        with patch.object(subprocess, "run", return_value=completed) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        # The stdout kwarg should be a file object opened for writing
        call_args = mock_run.call_args
        # stdout is passed as a keyword arg
        stdout_val = call_args.kwargs.get("stdout") or call_args[1].get("stdout")
        # The file was opened via `with open(dest, "wb") as f` and passed as stdout=f.
        # After the context manager exits, the file is closed, but we can verify
        # the dest path was used by checking if the file was created.
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """Non-zero returncode should raise RuntimeError."""
        dest = tmp_path / "backup.tar"

        failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")
        with patch.object(subprocess, "run", return_value=failed):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """The subprocess.run call should use a 120s timeout."""
        dest = tmp_path / "backup.tar"

        completed = subprocess.CompletedProcess([], 0)
        with patch.object(subprocess, "run", return_value=completed) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        call_args = mock_run.call_args
        assert call_args.kwargs.get("timeout") == 120 or call_args[1].get("timeout") == 120
|
||||||
|
class TestSSHCleanup:
    """Verify SSH cleanup() calls sync_back() before closing ControlMaster."""

    @staticmethod
    def _stub_ssh_env(monkeypatch, call_order):
        # Shared setup: stub out connection plumbing and install a sync
        # manager that records sync_back invocations into call_order.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
        return SSHEnvironment(host="h", user="u")

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """cleanup() should call sync_back() before SSH control socket teardown."""
        call_order = []
        env = self._stub_ssh_env(monkeypatch, call_order)
        # Ensure control_socket does not exist so cleanup skips the SSH exit call
        env.control_socket = Path("/nonexistent/socket")

        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch):
        """sync_back() must run before the ControlMaster exit command."""
        call_order = []
        env = self._stub_ssh_env(monkeypatch, call_order)

        # Create a fake control socket so cleanup tries the SSH exit
        import tempfile
        with tempfile.NamedTemporaryFile(delete=False, suffix=".sock") as tmp:
            env.control_socket = Path(tmp.name)

        def fake_run(cmd, **kwargs):
            if "-O" in cmd and "exit" in " ".join(cmd):
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", side_effect=fake_run):
            env.cleanup()

        assert call_order.index("sync_back") < call_order.index("control_exit")
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Modal bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestModalBulkDownload:
    """Unit tests for _modal_bulk_download."""

    def test_modal_bulk_download_command(self, tmp_path):
        """exec should be called with tar cf - -C /root/.hermes ."""
        env = _make_mock_modal_env()
        exec_calls = _wire_modal_download(env, tar_bytes=b"tar-content")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert len(exec_calls) == 1
        first_call = exec_calls[0]
        assert first_call[0] == "bash"
        assert first_call[1] == "-c"
        assert "tar cf -" in first_call[2]
        assert "-C / root/.hermes" in first_call[2]

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """Downloaded tar bytes should be written to the dest path."""
        env = _make_mock_modal_env()
        archive_bytes = b"some-tar-archive-bytes"
        _wire_modal_download(env, tar_bytes=archive_bytes)
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.exists()
        assert dest.read_bytes() == archive_bytes

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """If stdout returns str instead of bytes, it should be encoded."""
        env = _make_mock_modal_env()
        # Simulate Modal SDK returning str
        _wire_modal_download(env, tar_bytes="string-tar-data")
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert dest.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """Non-zero exit code should raise RuntimeError."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)
        dest = tmp_path / "backup.tar"

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(dest)

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """run_coroutine should be called with timeout=120."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        seen_kwargs = {}
        inner_run = env._worker.run_coroutine

        def tracking_run(coro, **kwargs):
            seen_kwargs.update(kwargs)
            return inner_run(coro, **kwargs)

        env._worker.run_coroutine = tracking_run
        dest = tmp_path / "backup.tar"

        env._modal_bulk_download(dest)

        assert seen_kwargs.get("timeout") == 120
|
||||||
|
class TestModalCleanup:
    """Verify Modal cleanup() calls sync_back() before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """cleanup() should call sync_back() before sandbox.terminate."""
        env = _make_mock_modal_env()

        call_order = []
        sync_manager = MagicMock()
        sync_manager.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_manager

        # Mock terminate to track call order
        async def mock_terminate():
            pass

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate
        env._worker.run_coroutine = lambda coro, **kw: (
            call_order.append("terminate"),
            asyncio.new_event_loop().run_until_complete(coro),
        )
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")
|
|
||||||
|
# =====================================================================
|
||||||
|
# Daytona bulk download
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDaytonaBulkDownload:
    """Unit tests for _daytona_bulk_download."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """exec and download_file should both be called."""
        env = _make_mock_daytona_env()
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        # exec called twice: tar creation + rm cleanup
        assert env._sandbox.process.exec.call_count == 2
        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "tar cf" in tar_cmd
        # PID-suffixed temp path avoids collisions on sync_back retry
        assert "/tmp/.hermes_sync." in tar_cmd
        assert ".tar" in tar_cmd
        assert ".hermes" in tar_cmd

        cleanup_cmd = env._sandbox.process.exec.call_args_list[1][0][0]
        assert "rm -f" in cleanup_cmd
        assert "/tmp/.hermes_sync." in cleanup_cmd

        # download_file called once with the same PID-suffixed path
        env._sandbox.fs.download_file.assert_called_once()
        remote_path, local_path = env._sandbox.fs.download_file.call_args[0][:2]
        assert remote_path.startswith("/tmp/.hermes_sync.")
        assert remote_path.endswith(".tar")
        assert local_path == str(dest)

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command should use the env's _remote_home."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"
        dest = tmp_path / "backup.tar"

        env._daytona_bulk_download(dest)

        tar_cmd = env._sandbox.process.exec.call_args_list[0][0][0]
        assert "home/daytona/.hermes" in tar_cmd
|
|
||||||
|
class TestDaytonaCleanup:
|
||||||
|
"""Verify Daytona cleanup() calls sync_back() before stop."""
|
||||||
|
|
||||||
|
def test_daytona_cleanup_calls_sync_back(self):
|
||||||
|
"""cleanup() should call sync_back() before sandbox.stop()."""
|
||||||
|
env = _make_mock_daytona_env()
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
sync_mgr = MagicMock()
|
||||||
|
sync_mgr.sync_back = lambda: call_order.append("sync_back")
|
||||||
|
env._sync_manager = sync_mgr
|
||||||
|
env._sandbox.stop = lambda: call_order.append("stop")
|
||||||
|
|
||||||
|
env.cleanup()
|
||||||
|
|
||||||
|
assert "sync_back" in call_order
|
||||||
|
assert "stop" in call_order
|
||||||
|
assert call_order.index("sync_back") < call_order.index("stop")
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkDownloadWiring:
|
||||||
|
"""Verify each backend passes bulk_download_fn to FileSyncManager."""
|
||||||
|
|
||||||
|
def test_ssh_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""SSHEnvironment should pass _ssh_bulk_download to FileSyncManager."""
|
||||||
|
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
|
||||||
|
monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
|
||||||
|
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
class CaptureSyncManager:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
def sync(self, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)
|
||||||
|
|
||||||
|
SSHEnvironment(host="h", user="u")
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
||||||
|
def test_modal_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""ModalEnvironment should pass _modal_bulk_download to FileSyncManager."""
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
def capture_fsm(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return type("M", (), {"sync": lambda self, **k: None})()
|
||||||
|
|
||||||
|
monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)
|
||||||
|
|
||||||
|
env = object.__new__(modal_env.ModalEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._worker = MagicMock()
|
||||||
|
env._persistent = False
|
||||||
|
env._task_id = "test"
|
||||||
|
|
||||||
|
# Replicate the wiring done in __init__
|
||||||
|
from tools.environments.file_sync import iter_sync_files
|
||||||
|
env._sync_manager = modal_env.FileSyncManager(
|
||||||
|
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||||
|
upload_fn=env._modal_upload,
|
||||||
|
delete_fn=env._modal_delete,
|
||||||
|
bulk_upload_fn=env._modal_bulk_upload,
|
||||||
|
bulk_download_fn=env._modal_bulk_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
||||||
|
def test_daytona_passes_bulk_download_fn(self, monkeypatch):
|
||||||
|
"""DaytonaEnvironment should pass _daytona_bulk_download to FileSyncManager."""
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
def capture_fsm(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return type("M", (), {"sync": lambda self, **k: None})()
|
||||||
|
|
||||||
|
monkeypatch.setattr(daytona_env, "FileSyncManager", capture_fsm)
|
||||||
|
|
||||||
|
env = object.__new__(daytona_env.DaytonaEnvironment)
|
||||||
|
env._sandbox = MagicMock()
|
||||||
|
env._remote_home = "/root"
|
||||||
|
env._lock = __import__("threading").Lock()
|
||||||
|
env._persistent = True
|
||||||
|
env._task_id = "test"
|
||||||
|
env._daytona = MagicMock()
|
||||||
|
|
||||||
|
# Replicate the wiring done in __init__
|
||||||
|
from tools.environments.file_sync import iter_sync_files
|
||||||
|
env._sync_manager = daytona_env.FileSyncManager(
|
||||||
|
get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
|
||||||
|
upload_fn=env._daytona_upload,
|
||||||
|
delete_fn=env._daytona_delete,
|
||||||
|
bulk_upload_fn=env._daytona_bulk_upload,
|
||||||
|
bulk_download_fn=env._daytona_bulk_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "bulk_download_fn" in captured_kwargs
|
||||||
|
assert callable(captured_kwargs["bulk_download_fn"])
|
||||||
|
|
@ -7,6 +7,7 @@ and resumed on next creation, preserving the filesystem across sessions.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -134,6 +135,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._daytona_upload,
|
upload_fn=self._daytona_upload,
|
||||||
delete_fn=self._daytona_delete,
|
delete_fn=self._daytona_delete,
|
||||||
bulk_upload_fn=self._daytona_bulk_upload,
|
bulk_upload_fn=self._daytona_bulk_upload,
|
||||||
|
bulk_download_fn=self._daytona_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
self.init_session()
|
self.init_session()
|
||||||
|
|
@ -166,6 +168,22 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
]
|
]
|
||||||
self._sandbox.fs.upload_files(uploads)
|
self._sandbox.fs.upload_files(uploads)
|
||||||
|
|
||||||
|
def _daytona_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive."""
|
||||||
|
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||||
|
# PID-suffixed remote temp path avoids collisions if sync_back fires
|
||||||
|
# concurrently for the same sandbox (e.g. retry after partial failure).
|
||||||
|
remote_tar = f"/tmp/.hermes_sync.{os.getpid()}.tar"
|
||||||
|
self._sandbox.process.exec(
|
||||||
|
f"tar cf {shlex.quote(remote_tar)} -C / {shlex.quote(rel_base)}"
|
||||||
|
)
|
||||||
|
self._sandbox.fs.download_file(remote_tar, str(dest))
|
||||||
|
# Clean up remote temp file
|
||||||
|
try:
|
||||||
|
self._sandbox.process.exec(f"rm -f {shlex.quote(remote_tar)}")
|
||||||
|
except Exception:
|
||||||
|
pass # best-effort cleanup
|
||||||
|
|
||||||
def _daytona_delete(self, remote_paths: list[str]) -> None:
|
def _daytona_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files via SDK exec."""
|
"""Batch-delete remote files via SDK exec."""
|
||||||
self._sandbox.process.exec(quoted_rm_command(remote_paths))
|
self._sandbox.process.exec(quoted_rm_command(remote_paths))
|
||||||
|
|
@ -216,6 +234,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Sync remote changes back to host before teardown. Running
|
||||||
|
# inside the lock (and after the _sandbox is None guard) avoids
|
||||||
|
# firing sync_back on an already-cleaned-up env, which would
|
||||||
|
# trigger a 3-attempt retry storm against a nil sandbox.
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("Daytona: syncing files from sandbox...")
|
||||||
|
try:
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Daytona: sync_back failed: %s", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self._persistent:
|
if self._persistent:
|
||||||
self._sandbox.stop()
|
self._sandbox.stop()
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,25 @@ and Daytona. Docker and Singularity use bind mounts (live host FS
|
||||||
view) and don't need this.
|
view) and don't need this.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import tarfile
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fcntl
|
||||||
|
except ImportError:
|
||||||
|
fcntl = None # Windows — file locking skipped
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
from hermes_constants import get_hermes_home
|
||||||
from tools.environments.base import _file_mtime_key
|
from tools.environments.base import _file_mtime_key
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,6 +35,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"
|
||||||
# Transport callbacks provided by each backend
|
# Transport callbacks provided by each backend
|
||||||
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
||||||
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
||||||
|
BulkDownloadFn = Callable[[Path], None] # (dest_tar_path) -> writes tar archive, raises on failure
|
||||||
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
||||||
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
||||||
|
|
||||||
|
|
@ -71,6 +84,20 @@ def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
|
||||||
return sorted({str(Path(remote).parent) for _, remote in files})
|
return sorted({str(Path(remote).parent) for _, remote in files})
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_file(path: str) -> str:
|
||||||
|
"""Return hex SHA-256 digest of a file."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(65536), b""):
|
||||||
|
h.update(chunk)
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
_SYNC_BACK_MAX_RETRIES = 3
|
||||||
|
_SYNC_BACK_BACKOFF = (2, 4, 8) # seconds between retries
|
||||||
|
_SYNC_BACK_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB — refuse to extract larger tars
|
||||||
|
|
||||||
|
|
||||||
class FileSyncManager:
|
class FileSyncManager:
|
||||||
"""Tracks local file changes and syncs to a remote environment.
|
"""Tracks local file changes and syncs to a remote environment.
|
||||||
|
|
||||||
|
|
@ -89,12 +116,15 @@ class FileSyncManager:
|
||||||
delete_fn: DeleteFn,
|
delete_fn: DeleteFn,
|
||||||
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
||||||
bulk_upload_fn: BulkUploadFn | None = None,
|
bulk_upload_fn: BulkUploadFn | None = None,
|
||||||
|
bulk_download_fn: BulkDownloadFn | None = None,
|
||||||
):
|
):
|
||||||
self._get_files_fn = get_files_fn
|
self._get_files_fn = get_files_fn
|
||||||
self._upload_fn = upload_fn
|
self._upload_fn = upload_fn
|
||||||
self._bulk_upload_fn = bulk_upload_fn
|
self._bulk_upload_fn = bulk_upload_fn
|
||||||
|
self._bulk_download_fn = bulk_download_fn
|
||||||
self._delete_fn = delete_fn
|
self._delete_fn = delete_fn
|
||||||
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
||||||
|
self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest
|
||||||
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
||||||
self._sync_interval = sync_interval
|
self._sync_interval = sync_interval
|
||||||
|
|
||||||
|
|
@ -136,6 +166,7 @@ class FileSyncManager:
|
||||||
|
|
||||||
# Snapshot for rollback (only when there's work to do)
|
# Snapshot for rollback (only when there's work to do)
|
||||||
prev_files = dict(self._synced_files)
|
prev_files = dict(self._synced_files)
|
||||||
|
prev_hashes = dict(self._pushed_hashes)
|
||||||
|
|
||||||
if to_upload:
|
if to_upload:
|
||||||
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
||||||
|
|
@ -156,13 +187,207 @@ class FileSyncManager:
|
||||||
logger.debug("file_sync: deleted %s", to_delete)
|
logger.debug("file_sync: deleted %s", to_delete)
|
||||||
|
|
||||||
# --- Commit (all succeeded) ---
|
# --- Commit (all succeeded) ---
|
||||||
|
for host_path, remote_path in to_upload:
|
||||||
|
self._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||||
|
|
||||||
for p in to_delete:
|
for p in to_delete:
|
||||||
new_files.pop(p, None)
|
new_files.pop(p, None)
|
||||||
|
self._pushed_hashes.pop(p, None)
|
||||||
|
|
||||||
self._synced_files = new_files
|
self._synced_files = new_files
|
||||||
self._last_sync_time = time.monotonic()
|
self._last_sync_time = time.monotonic()
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._synced_files = prev_files
|
self._synced_files = prev_files
|
||||||
|
self._pushed_hashes = prev_hashes
|
||||||
self._last_sync_time = time.monotonic()
|
self._last_sync_time = time.monotonic()
|
||||||
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Sync-back: pull remote changes to host on teardown
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sync_back(self, hermes_home: Path | None = None) -> None:
|
||||||
|
"""Pull remote changes back to the host filesystem.
|
||||||
|
|
||||||
|
Downloads the remote ``.hermes/`` directory as a tar archive,
|
||||||
|
unpacks it, and applies only files that differ from what was
|
||||||
|
originally pushed (based on SHA-256 content hashes).
|
||||||
|
|
||||||
|
Protected against SIGINT (defers the signal until complete) and
|
||||||
|
serialized across concurrent gateway sandboxes via file lock.
|
||||||
|
"""
|
||||||
|
if self._bulk_download_fn is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Nothing was ever committed through this manager — the initial
|
||||||
|
# push failed or never ran. Skip sync_back to avoid retry storms
|
||||||
|
# against an uninitialized remote .hermes/ directory.
|
||||||
|
if not self._pushed_hashes and not self._synced_files:
|
||||||
|
logger.debug("sync_back: no prior push state — skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
|
||||||
|
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(_SYNC_BACK_MAX_RETRIES):
|
||||||
|
try:
|
||||||
|
self._sync_back_once(lock_path)
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt < _SYNC_BACK_MAX_RETRIES - 1:
|
||||||
|
delay = _SYNC_BACK_BACKOFF[attempt]
|
||||||
|
logger.warning(
|
||||||
|
"sync_back: attempt %d failed (%s), retrying in %ds",
|
||||||
|
attempt + 1, exc, delay,
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)
|
||||||
|
|
||||||
|
def _sync_back_once(self, lock_path: Path) -> None:
|
||||||
|
"""Single sync-back attempt with SIGINT protection and file lock."""
|
||||||
|
# signal.signal() only works from the main thread. In gateway
|
||||||
|
# contexts cleanup() may run from a worker thread — skip SIGINT
|
||||||
|
# deferral there rather than crashing.
|
||||||
|
on_main_thread = threading.current_thread() is threading.main_thread()
|
||||||
|
|
||||||
|
deferred_sigint: list[object] = []
|
||||||
|
original_handler = None
|
||||||
|
if on_main_thread:
|
||||||
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
|
|
||||||
|
def _defer_sigint(signum, frame):
|
||||||
|
deferred_sigint.append((signum, frame))
|
||||||
|
logger.debug("sync_back: SIGINT deferred until sync completes")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, _defer_sigint)
|
||||||
|
try:
|
||||||
|
self._sync_back_locked(lock_path)
|
||||||
|
finally:
|
||||||
|
if on_main_thread and original_handler is not None:
|
||||||
|
signal.signal(signal.SIGINT, original_handler)
|
||||||
|
if deferred_sigint:
|
||||||
|
os.kill(os.getpid(), signal.SIGINT)
|
||||||
|
|
||||||
|
def _sync_back_locked(self, lock_path: Path) -> None:
|
||||||
|
"""Sync-back under file lock (serializes concurrent gateways)."""
|
||||||
|
if fcntl is None:
|
||||||
|
# Windows: no flock — run without serialization
|
||||||
|
self._sync_back_impl()
|
||||||
|
return
|
||||||
|
lock_fd = open(lock_path, "w")
|
||||||
|
try:
|
||||||
|
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||||
|
self._sync_back_impl()
|
||||||
|
finally:
|
||||||
|
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||||
|
lock_fd.close()
|
||||||
|
|
||||||
|
def _sync_back_impl(self) -> None:
|
||||||
|
"""Download, diff, and apply remote changes to host."""
|
||||||
|
if self._bulk_download_fn is None:
|
||||||
|
raise RuntimeError("_sync_back_impl called without bulk_download_fn")
|
||||||
|
|
||||||
|
# Cache file mapping once to avoid O(n*m) from repeated iteration
|
||||||
|
try:
|
||||||
|
file_mapping = list(self._get_files_fn())
|
||||||
|
except Exception:
|
||||||
|
file_mapping = []
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".tar") as tf:
|
||||||
|
self._bulk_download_fn(Path(tf.name))
|
||||||
|
|
||||||
|
# Defensive size cap: a misbehaving sandbox could produce an
|
||||||
|
# arbitrarily large tar. Refuse to extract if it exceeds the cap.
|
||||||
|
try:
|
||||||
|
tar_size = os.path.getsize(tf.name)
|
||||||
|
except OSError:
|
||||||
|
tar_size = 0
|
||||||
|
if tar_size > _SYNC_BACK_MAX_BYTES:
|
||||||
|
logger.warning(
|
||||||
|
"sync_back: remote tar is %d bytes (cap %d) — skipping extraction",
|
||||||
|
tar_size, _SYNC_BACK_MAX_BYTES,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging:
|
||||||
|
with tarfile.open(tf.name) as tar:
|
||||||
|
tar.extractall(staging, filter="data")
|
||||||
|
|
||||||
|
applied = 0
|
||||||
|
for dirpath, _dirnames, filenames in os.walk(staging):
|
||||||
|
for fname in filenames:
|
||||||
|
staged_file = os.path.join(dirpath, fname)
|
||||||
|
rel = os.path.relpath(staged_file, staging)
|
||||||
|
remote_path = "/" + rel
|
||||||
|
|
||||||
|
pushed_hash = self._pushed_hashes.get(remote_path)
|
||||||
|
|
||||||
|
# Skip hashing for files unchanged from push
|
||||||
|
if pushed_hash is not None:
|
||||||
|
remote_hash = _sha256_file(staged_file)
|
||||||
|
if remote_hash == pushed_hash:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
remote_hash = None # new remote file
|
||||||
|
|
||||||
|
# Resolve host path from cached mapping
|
||||||
|
host_path = self._resolve_host_path(remote_path, file_mapping)
|
||||||
|
if host_path is None:
|
||||||
|
host_path = self._infer_host_path(remote_path, file_mapping)
|
||||||
|
if host_path is None:
|
||||||
|
logger.debug(
|
||||||
|
"sync_back: skipping %s (no host mapping)",
|
||||||
|
remote_path,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.exists(host_path) and pushed_hash is not None:
|
||||||
|
host_hash = _sha256_file(host_path)
|
||||||
|
if host_hash != pushed_hash:
|
||||||
|
logger.warning(
|
||||||
|
"sync_back: conflict on %s — host modified "
|
||||||
|
"since push, remote also changed. Applying "
|
||||||
|
"remote version (last-write-wins).",
|
||||||
|
remote_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(host_path), exist_ok=True)
|
||||||
|
shutil.copy2(staged_file, host_path)
|
||||||
|
applied += 1
|
||||||
|
|
||||||
|
if applied:
|
||||||
|
logger.info("sync_back: applied %d changed file(s)", applied)
|
||||||
|
else:
|
||||||
|
logger.debug("sync_back: no remote changes detected")
|
||||||
|
|
||||||
|
def _resolve_host_path(self, remote_path: str,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||||
|
"""Find the host path for a known remote path from the file mapping."""
|
||||||
|
mapping = file_mapping if file_mapping is not None else []
|
||||||
|
for host, remote in mapping:
|
||||||
|
if remote == remote_path:
|
||||||
|
return host
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _infer_host_path(self, remote_path: str,
|
||||||
|
file_mapping: list[tuple[str, str]] | None = None) -> str | None:
|
||||||
|
"""Infer a host path for a new remote file by matching path prefixes.
|
||||||
|
|
||||||
|
Uses the existing file mapping to find a remote->host directory
|
||||||
|
pair, then applies the same prefix substitution to the new file.
|
||||||
|
For example, if the mapping has ``/root/.hermes/skills/a.md`` →
|
||||||
|
``~/.hermes/skills/a.md``, a new remote file at
|
||||||
|
``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
|
||||||
|
"""
|
||||||
|
mapping = file_mapping if file_mapping is not None else []
|
||||||
|
for host, remote in mapping:
|
||||||
|
remote_dir = str(Path(remote).parent)
|
||||||
|
if remote_path.startswith(remote_dir + "/"):
|
||||||
|
host_dir = str(Path(host).parent)
|
||||||
|
suffix = remote_path[len(remote_dir):]
|
||||||
|
return host_dir + suffix
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -269,6 +269,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._modal_upload,
|
upload_fn=self._modal_upload,
|
||||||
delete_fn=self._modal_delete,
|
delete_fn=self._modal_delete,
|
||||||
bulk_upload_fn=self._modal_bulk_upload,
|
bulk_upload_fn=self._modal_bulk_upload,
|
||||||
|
bulk_download_fn=self._modal_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
self.init_session()
|
self.init_session()
|
||||||
|
|
@ -347,6 +348,27 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
|
|
||||||
self._worker.run_coroutine(_bulk(), timeout=120)
|
self._worker.run_coroutine(_bulk(), timeout=120)
|
||||||
|
|
||||||
|
def _modal_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive.
|
||||||
|
|
||||||
|
Modal sandboxes always run as root, so /root/.hermes is hardcoded
|
||||||
|
(consistent with iter_sync_files call on line 269).
|
||||||
|
"""
|
||||||
|
async def _download():
|
||||||
|
proc = await self._sandbox.exec.aio(
|
||||||
|
"bash", "-c", "tar cf - -C / root/.hermes"
|
||||||
|
)
|
||||||
|
data = await proc.stdout.read.aio()
|
||||||
|
exit_code = await proc.wait.aio()
|
||||||
|
if exit_code != 0:
|
||||||
|
raise RuntimeError(f"Modal bulk download failed (exit {exit_code})")
|
||||||
|
return data
|
||||||
|
|
||||||
|
tar_bytes = self._worker.run_coroutine(_download(), timeout=120)
|
||||||
|
if isinstance(tar_bytes, str):
|
||||||
|
tar_bytes = tar_bytes.encode()
|
||||||
|
dest.write_bytes(tar_bytes)
|
||||||
|
|
||||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files via exec."""
|
"""Batch-delete remote files via exec."""
|
||||||
rm_cmd = quoted_rm_command(remote_paths)
|
rm_cmd = quoted_rm_command(remote_paths)
|
||||||
|
|
@ -404,6 +426,10 @@ class ModalEnvironment(BaseEnvironment):
|
||||||
if self._sandbox is None:
|
if self._sandbox is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("Modal: syncing files from sandbox...")
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
|
||||||
if self._persistent:
|
if self._persistent:
|
||||||
try:
|
try:
|
||||||
async def _snapshot():
|
async def _snapshot():
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
upload_fn=self._scp_upload,
|
upload_fn=self._scp_upload,
|
||||||
delete_fn=self._ssh_delete,
|
delete_fn=self._ssh_delete,
|
||||||
bulk_upload_fn=self._ssh_bulk_upload,
|
bulk_upload_fn=self._ssh_bulk_upload,
|
||||||
|
bulk_download_fn=self._ssh_bulk_download,
|
||||||
)
|
)
|
||||||
self._sync_manager.sync(force=True)
|
self._sync_manager.sync(force=True)
|
||||||
|
|
||||||
|
|
@ -216,6 +217,18 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
|
|
||||||
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||||
|
|
||||||
|
def _ssh_bulk_download(self, dest: Path) -> None:
|
||||||
|
"""Download remote .hermes/ as a tar archive."""
|
||||||
|
# Tar from / with the full path so archive entries preserve absolute
|
||||||
|
# paths (e.g. home/user/.hermes/skills/f.py), matching _pushed_hashes keys.
|
||||||
|
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||||
|
ssh_cmd = self._build_ssh_command()
|
||||||
|
ssh_cmd.append(f"tar cf - -C / {shlex.quote(rel_base)}")
|
||||||
|
with open(dest, "wb") as f:
|
||||||
|
result = subprocess.run(ssh_cmd, stdout=f, stderr=subprocess.PIPE, timeout=120)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"SSH bulk download failed: {result.stderr.decode(errors='replace').strip()}")
|
||||||
|
|
||||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||||
"""Batch-delete remote files in one SSH call."""
|
"""Batch-delete remote files in one SSH call."""
|
||||||
cmd = self._build_ssh_command()
|
cmd = self._build_ssh_command()
|
||||||
|
|
@ -245,6 +258,10 @@ class SSHEnvironment(BaseEnvironment):
|
||||||
return _popen_bash(cmd, stdin_data)
|
return _popen_bash(cmd, stdin_data)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
if self._sync_manager:
|
||||||
|
logger.info("SSH: syncing files from sandbox...")
|
||||||
|
self._sync_manager.sync_back()
|
||||||
|
|
||||||
if self.control_socket.exists():
|
if self.control_socket.exists():
|
||||||
try:
|
try:
|
||||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||||
|
|
|
||||||
|
|
@ -2,30 +2,22 @@
|
||||||
"""
|
"""
|
||||||
Image Generation Tools Module
|
Image Generation Tools Module
|
||||||
|
|
||||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
Provides image generation via FAL.ai. Multiple FAL models are supported and
|
||||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
selectable via ``hermes tools`` → Image Generation; the active model is
|
||||||
|
persisted to ``image_gen.model`` in ``config.yaml``.
|
||||||
|
|
||||||
Available tools:
|
Architecture:
|
||||||
- image_generate_tool: Generate images from text prompts with automatic upscaling
|
- ``FAL_MODELS`` is a catalog of supported models with per-model metadata
|
||||||
|
(size-style family, defaults, ``supports`` whitelist, upscaler flag).
|
||||||
|
- ``_build_fal_payload()`` translates the agent's unified inputs (prompt +
|
||||||
|
aspect_ratio) into the model-specific payload and filters to the
|
||||||
|
``supports`` whitelist so models never receive rejected keys.
|
||||||
|
- Upscaling via FAL's Clarity Upscaler is gated per-model via the ``upscale``
|
||||||
|
flag — on for FLUX 2 Pro (backward-compat), off for all faster/newer models
|
||||||
|
where upscaling would either hurt latency or add marginal quality.
|
||||||
|
|
||||||
Features:
|
Pricing shown in UI strings is as-of the initial commit; we accept drift and
|
||||||
- High-quality image generation using FLUX 2 Pro model
|
update when it's noticed.
|
||||||
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
|
|
||||||
- Comprehensive parameter control (size, steps, guidance, etc.)
|
|
||||||
- Proper error handling and validation with fallback to original images
|
|
||||||
- Debug logging support
|
|
||||||
- Sync mode for immediate results
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from image_generation_tool import image_generate_tool
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Generate and automatically upscale an image
|
|
||||||
result = await image_generate_tool(
|
|
||||||
prompt="A serene mountain landscape with cherry blossoms",
|
|
||||||
image_size="landscape_4_3",
|
|
||||||
num_images=1
|
|
||||||
)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -34,35 +26,237 @@ import os
|
||||||
import datetime
|
import datetime
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import fal_client
|
import fal_client
|
||||||
|
|
||||||
from tools.debug_helpers import DebugSession
|
from tools.debug_helpers import DebugSession
|
||||||
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
||||||
from tools.tool_backend_helpers import managed_nous_tools_enabled, prefers_gateway
|
from tools.tool_backend_helpers import managed_nous_tools_enabled, prefers_gateway
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Configuration for image generation
|
|
||||||
DEFAULT_MODEL = "fal-ai/flux-2-pro"
|
|
||||||
DEFAULT_ASPECT_RATIO = "landscape"
|
|
||||||
DEFAULT_NUM_INFERENCE_STEPS = 50
|
|
||||||
DEFAULT_GUIDANCE_SCALE = 4.5
|
|
||||||
DEFAULT_NUM_IMAGES = 1
|
|
||||||
DEFAULT_OUTPUT_FORMAT = "png"
|
|
||||||
|
|
||||||
# Safety settings
|
# ---------------------------------------------------------------------------
|
||||||
ENABLE_SAFETY_CHECKER = False
|
# FAL model catalog
|
||||||
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Each entry declares how to translate our unified inputs into the model's
|
||||||
|
# native payload shape. Size specification falls into three families:
|
||||||
|
#
|
||||||
|
# "image_size_preset" — preset enum ("square_hd", "landscape_16_9", ...)
|
||||||
|
# used by the flux family, z-image, qwen, recraft,
|
||||||
|
# ideogram.
|
||||||
|
# "aspect_ratio" — aspect ratio enum ("16:9", "1:1", ...) used by
|
||||||
|
# nano-banana (Gemini).
|
||||||
|
# "gpt_literal" — literal dimension strings ("1024x1024", etc.)
|
||||||
|
# used by gpt-image-1.5.
|
||||||
|
#
|
||||||
|
# ``supports`` is a whitelist of keys allowed in the outgoing payload — any
|
||||||
|
# key outside this set is stripped before submission so models never receive
|
||||||
|
# rejected parameters (each FAL model rejects unknown keys differently).
|
||||||
|
#
|
||||||
|
# ``upscale`` controls whether to chain Clarity Upscaler after generation.
|
||||||
|
|
||||||
# Aspect ratio mapping - simplified choices for model to select
|
FAL_MODELS: Dict[str, Dict[str, Any]] = {
|
||||||
ASPECT_RATIO_MAP = {
|
"fal-ai/flux-2/klein/9b": {
|
||||||
|
"display": "FLUX 2 Klein 9B",
|
||||||
|
"speed": "<1s",
|
||||||
|
"strengths": "Fast, crisp text",
|
||||||
|
"price": "$0.006/MP",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
"landscape": "landscape_16_9",
|
"landscape": "landscape_16_9",
|
||||||
"square": "square_hd",
|
"square": "square_hd",
|
||||||
"portrait": "portrait_16_9"
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"num_inference_steps": 4,
|
||||||
|
"output_format": "png",
|
||||||
|
"enable_safety_checker": False,
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "num_inference_steps", "seed",
|
||||||
|
"output_format", "enable_safety_checker",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/flux-2-pro": {
|
||||||
|
"display": "FLUX 2 Pro",
|
||||||
|
"speed": "~6s",
|
||||||
|
"strengths": "Studio photorealism",
|
||||||
|
"price": "$0.03/MP",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "landscape_16_9",
|
||||||
|
"square": "square_hd",
|
||||||
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"num_inference_steps": 50,
|
||||||
|
"guidance_scale": 4.5,
|
||||||
|
"num_images": 1,
|
||||||
|
"output_format": "png",
|
||||||
|
"enable_safety_checker": False,
|
||||||
|
"safety_tolerance": "5",
|
||||||
|
"sync_mode": True,
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "num_inference_steps", "guidance_scale",
|
||||||
|
"num_images", "output_format", "enable_safety_checker",
|
||||||
|
"safety_tolerance", "sync_mode", "seed",
|
||||||
|
},
|
||||||
|
"upscale": True, # Backward-compat: current default behavior.
|
||||||
|
},
|
||||||
|
"fal-ai/z-image/turbo": {
|
||||||
|
"display": "Z-Image Turbo",
|
||||||
|
"speed": "~2s",
|
||||||
|
"strengths": "Bilingual EN/CN, 6B",
|
||||||
|
"price": "$0.005/MP",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "landscape_16_9",
|
||||||
|
"square": "square_hd",
|
||||||
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"num_inference_steps": 8,
|
||||||
|
"num_images": 1,
|
||||||
|
"output_format": "png",
|
||||||
|
"enable_safety_checker": False,
|
||||||
|
"enable_prompt_expansion": False, # avoid the extra per-request charge
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "num_inference_steps", "num_images",
|
||||||
|
"seed", "output_format", "enable_safety_checker",
|
||||||
|
"enable_prompt_expansion",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/nano-banana": {
|
||||||
|
"display": "Nano Banana (Gemini 2.5 Flash Image)",
|
||||||
|
"speed": "~6s",
|
||||||
|
"strengths": "Gemini 2.5, consistency",
|
||||||
|
"price": "$0.08/image",
|
||||||
|
"size_style": "aspect_ratio",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "16:9",
|
||||||
|
"square": "1:1",
|
||||||
|
"portrait": "9:16",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"num_images": 1,
|
||||||
|
"output_format": "png",
|
||||||
|
"safety_tolerance": "5",
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "aspect_ratio", "num_images", "output_format",
|
||||||
|
"safety_tolerance", "seed", "sync_mode",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/gpt-image-1.5": {
|
||||||
|
"display": "GPT Image 1.5",
|
||||||
|
"speed": "~15s",
|
||||||
|
"strengths": "Prompt adherence",
|
||||||
|
"price": "$0.034/image",
|
||||||
|
"size_style": "gpt_literal",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "1536x1024",
|
||||||
|
"square": "1024x1024",
|
||||||
|
"portrait": "1024x1536",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
# Quality is pinned to medium to keep portal billing predictable
|
||||||
|
# across all users (low is too rough, high is 4-6x more expensive).
|
||||||
|
"quality": "medium",
|
||||||
|
"num_images": 1,
|
||||||
|
"output_format": "png",
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "quality", "num_images", "output_format",
|
||||||
|
"background", "sync_mode",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/ideogram/v3": {
|
||||||
|
"display": "Ideogram V3",
|
||||||
|
"speed": "~5s",
|
||||||
|
"strengths": "Best typography",
|
||||||
|
"price": "$0.03-0.09/image",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "landscape_16_9",
|
||||||
|
"square": "square_hd",
|
||||||
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"rendering_speed": "BALANCED",
|
||||||
|
"expand_prompt": True,
|
||||||
|
"style": "AUTO",
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "rendering_speed", "expand_prompt",
|
||||||
|
"style", "seed",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/recraft-v3": {
|
||||||
|
"display": "Recraft V3",
|
||||||
|
"speed": "~8s",
|
||||||
|
"strengths": "Vector, brand styles",
|
||||||
|
"price": "$0.04/image",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "landscape_16_9",
|
||||||
|
"square": "square_hd",
|
||||||
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"style": "realistic_image",
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "style",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
|
"fal-ai/qwen-image": {
|
||||||
|
"display": "Qwen Image",
|
||||||
|
"speed": "~12s",
|
||||||
|
"strengths": "LLM-based, complex text",
|
||||||
|
"price": "$0.02/MP",
|
||||||
|
"size_style": "image_size_preset",
|
||||||
|
"sizes": {
|
||||||
|
"landscape": "landscape_16_9",
|
||||||
|
"square": "square_hd",
|
||||||
|
"portrait": "portrait_16_9",
|
||||||
|
},
|
||||||
|
"defaults": {
|
||||||
|
"num_inference_steps": 30,
|
||||||
|
"guidance_scale": 2.5,
|
||||||
|
"num_images": 1,
|
||||||
|
"output_format": "png",
|
||||||
|
"acceleration": "regular",
|
||||||
|
},
|
||||||
|
"supports": {
|
||||||
|
"prompt", "image_size", "num_inference_steps", "guidance_scale",
|
||||||
|
"num_images", "output_format", "acceleration", "seed", "sync_mode",
|
||||||
|
},
|
||||||
|
"upscale": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Configuration for automatic upscaling
|
# Default model is the fastest reasonable option. Kept cheap and sub-1s.
|
||||||
|
DEFAULT_MODEL = "fal-ai/flux-2/klein/9b"
|
||||||
|
|
||||||
|
DEFAULT_ASPECT_RATIO = "landscape"
|
||||||
|
VALID_ASPECT_RATIOS = ("landscape", "square", "portrait")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Upscaler (Clarity Upscaler — unchanged from previous implementation)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
|
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
|
||||||
UPSCALER_FACTOR = 2
|
UPSCALER_FACTOR = 2
|
||||||
UPSCALER_SAFETY_CHECKER = False
|
UPSCALER_SAFETY_CHECKER = False
|
||||||
|
|
@ -73,12 +267,6 @@ UPSCALER_RESEMBLANCE = 0.6
|
||||||
UPSCALER_GUIDANCE_SCALE = 4
|
UPSCALER_GUIDANCE_SCALE = 4
|
||||||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||||
|
|
||||||
# Valid parameter values for validation based on FLUX 2 Pro documentation
|
|
||||||
VALID_IMAGE_SIZES = [
|
|
||||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
|
||||||
]
|
|
||||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
|
||||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
|
||||||
|
|
||||||
_debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
_debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
||||||
_managed_fal_client = None
|
_managed_fal_client = None
|
||||||
|
|
@ -86,6 +274,9 @@ _managed_fal_client_config = None
|
||||||
_managed_fal_client_lock = threading.Lock()
|
_managed_fal_client_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Managed FAL gateway (Nous Subscription)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
def _resolve_managed_fal_gateway():
|
def _resolve_managed_fal_gateway():
|
||||||
"""Return managed fal-queue gateway config when the user prefers the gateway
|
"""Return managed fal-queue gateway config when the user prefers the gateway
|
||||||
or direct FAL credentials are absent."""
|
or direct FAL credentials are absent."""
|
||||||
|
|
@ -208,104 +399,140 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]):
|
||||||
return fal_client.submit(model, arguments=arguments, headers=request_headers)
|
return fal_client.submit(model, arguments=arguments, headers=request_headers)
|
||||||
|
|
||||||
managed_client = _get_managed_fal_client(managed_gateway)
|
managed_client = _get_managed_fal_client(managed_gateway)
|
||||||
|
try:
|
||||||
return managed_client.submit(
|
return managed_client.submit(
|
||||||
model,
|
model,
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
headers=request_headers,
|
headers=request_headers,
|
||||||
)
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
# 4xx from the managed gateway typically means the portal doesn't
|
||||||
|
# currently proxy this model (allowlist miss, billing gate, etc.)
|
||||||
|
# — surface a clearer message with actionable remediation instead
|
||||||
|
# of a raw HTTP error from httpx.
|
||||||
|
status = _extract_http_status(exc)
|
||||||
|
if status is not None and 400 <= status < 500:
|
||||||
|
raise ValueError(
|
||||||
|
f"Nous Subscription gateway rejected model '{model}' "
|
||||||
|
f"(HTTP {status}). This model may not yet be enabled on "
|
||||||
|
f"the Nous Portal's FAL proxy. Either:\n"
|
||||||
|
f" • Set FAL_KEY in your environment to use FAL.ai directly, or\n"
|
||||||
|
f" • Pick a different model via `hermes tools` → Image Generation."
|
||||||
|
) from exc
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _validate_parameters(
|
def _extract_http_status(exc: BaseException) -> Optional[int]:
|
||||||
image_size: Union[str, Dict[str, int]],
|
"""Return an HTTP status code from httpx/fal exceptions, else None.
|
||||||
num_inference_steps: int,
|
|
||||||
guidance_scale: float,
|
Defensive across exception shapes — httpx.HTTPStatusError exposes
|
||||||
num_images: int,
|
``.response.status_code`` while fal_client wrappers may expose
|
||||||
output_format: str,
|
``.status_code`` directly.
|
||||||
acceleration: str = "none"
|
"""
|
||||||
|
response = getattr(exc, "response", None)
|
||||||
|
if response is not None:
|
||||||
|
status = getattr(response, "status_code", None)
|
||||||
|
if isinstance(status, int):
|
||||||
|
return status
|
||||||
|
status = getattr(exc, "status_code", None)
|
||||||
|
if isinstance(status, int):
|
||||||
|
return status
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model resolution + payload construction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def _resolve_fal_model() -> tuple:
|
||||||
|
"""Resolve the active FAL model from config.yaml (primary) or default.
|
||||||
|
|
||||||
|
Returns (model_id, metadata_dict). Falls back to DEFAULT_MODEL if the
|
||||||
|
configured model is unknown (logged as a warning).
|
||||||
|
"""
|
||||||
|
model_id = ""
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
cfg = load_config()
|
||||||
|
img_cfg = cfg.get("image_gen") if isinstance(cfg, dict) else None
|
||||||
|
if isinstance(img_cfg, dict):
|
||||||
|
raw = img_cfg.get("model")
|
||||||
|
if isinstance(raw, str):
|
||||||
|
model_id = raw.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Could not load image_gen.model from config: %s", exc)
|
||||||
|
|
||||||
|
# Env var escape hatch (undocumented; backward-compat for tests/scripts).
|
||||||
|
if not model_id:
|
||||||
|
model_id = os.getenv("FAL_IMAGE_MODEL", "").strip()
|
||||||
|
|
||||||
|
if not model_id:
|
||||||
|
return DEFAULT_MODEL, FAL_MODELS[DEFAULT_MODEL]
|
||||||
|
|
||||||
|
if model_id not in FAL_MODELS:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown FAL model '%s' in config; falling back to %s",
|
||||||
|
model_id, DEFAULT_MODEL,
|
||||||
|
)
|
||||||
|
return DEFAULT_MODEL, FAL_MODELS[DEFAULT_MODEL]
|
||||||
|
|
||||||
|
return model_id, FAL_MODELS[model_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fal_payload(
|
||||||
|
model_id: str,
|
||||||
|
prompt: str,
|
||||||
|
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
overrides: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
"""Build a FAL request payload for `model_id` from unified inputs.
|
||||||
|
|
||||||
|
Translates aspect_ratio into the model's native size spec (preset enum,
|
||||||
|
aspect-ratio enum, or GPT literal string), merges model defaults, applies
|
||||||
|
caller overrides, then filters to the model's ``supports`` whitelist.
|
||||||
"""
|
"""
|
||||||
Validate and normalize image generation parameters for FLUX 2 Pro model.
|
meta = FAL_MODELS[model_id]
|
||||||
|
size_style = meta["size_style"]
|
||||||
|
sizes = meta["sizes"]
|
||||||
|
|
||||||
Args:
|
aspect = (aspect_ratio or DEFAULT_ASPECT_RATIO).lower().strip()
|
||||||
image_size: Either a preset string or custom size dict
|
if aspect not in sizes:
|
||||||
num_inference_steps: Number of inference steps
|
aspect = DEFAULT_ASPECT_RATIO
|
||||||
guidance_scale: Guidance scale value
|
|
||||||
num_images: Number of images to generate
|
|
||||||
output_format: Output format for images
|
|
||||||
acceleration: Acceleration mode for generation speed
|
|
||||||
|
|
||||||
Returns:
|
payload: Dict[str, Any] = dict(meta.get("defaults", {}))
|
||||||
Dict[str, Any]: Validated and normalized parameters
|
payload["prompt"] = (prompt or "").strip()
|
||||||
|
|
||||||
Raises:
|
if size_style in ("image_size_preset", "gpt_literal"):
|
||||||
ValueError: If any parameter is invalid
|
payload["image_size"] = sizes[aspect]
|
||||||
"""
|
elif size_style == "aspect_ratio":
|
||||||
validated = {}
|
payload["aspect_ratio"] = sizes[aspect]
|
||||||
|
|
||||||
# Validate image_size
|
|
||||||
if isinstance(image_size, str):
|
|
||||||
if image_size not in VALID_IMAGE_SIZES:
|
|
||||||
raise ValueError(f"Invalid image_size '{image_size}'. Must be one of: {VALID_IMAGE_SIZES}")
|
|
||||||
validated["image_size"] = image_size
|
|
||||||
elif isinstance(image_size, dict):
|
|
||||||
if "width" not in image_size or "height" not in image_size:
|
|
||||||
raise ValueError("Custom image_size must contain 'width' and 'height' keys")
|
|
||||||
if not isinstance(image_size["width"], int) or not isinstance(image_size["height"], int):
|
|
||||||
raise ValueError("Custom image_size width and height must be integers")
|
|
||||||
if image_size["width"] < 64 or image_size["height"] < 64:
|
|
||||||
raise ValueError("Custom image_size dimensions must be at least 64x64")
|
|
||||||
if image_size["width"] > 2048 or image_size["height"] > 2048:
|
|
||||||
raise ValueError("Custom image_size dimensions must not exceed 2048x2048")
|
|
||||||
validated["image_size"] = image_size
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
raise ValueError(f"Unknown size_style: {size_style!r}")
|
||||||
|
|
||||||
# Validate num_inference_steps
|
if seed is not None and isinstance(seed, int):
|
||||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
payload["seed"] = seed
|
||||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
|
||||||
validated["num_inference_steps"] = num_inference_steps
|
|
||||||
|
|
||||||
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
|
if overrides:
|
||||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
for k, v in overrides.items():
|
||||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
if v is not None:
|
||||||
validated["guidance_scale"] = float(guidance_scale)
|
payload[k] = v
|
||||||
|
|
||||||
# Validate num_images
|
supports = meta["supports"]
|
||||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
return {k: v for k, v in payload.items() if k in supports}
|
||||||
raise ValueError("num_images must be an integer between 1 and 4")
|
|
||||||
validated["num_images"] = num_images
|
|
||||||
|
|
||||||
# Validate output_format
|
|
||||||
if output_format not in VALID_OUTPUT_FORMATS:
|
|
||||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
|
||||||
validated["output_format"] = output_format
|
|
||||||
|
|
||||||
# Validate acceleration
|
|
||||||
if acceleration not in VALID_ACCELERATION_MODES:
|
|
||||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
|
||||||
validated["acceleration"] = acceleration
|
|
||||||
|
|
||||||
return validated
|
|
||||||
|
|
||||||
|
|
||||||
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
# ---------------------------------------------------------------------------
|
||||||
"""
|
# Upscaler
|
||||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
# ---------------------------------------------------------------------------
|
||||||
|
def _upscale_image(image_url: str, original_prompt: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Upscale an image using FAL.ai's Clarity Upscaler.
|
||||||
|
|
||||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues
|
Returns upscaled image dict, or None on failure (caller falls back to
|
||||||
when called from threaded contexts (e.g. gateway thread pool).
|
the original image).
|
||||||
|
|
||||||
Args:
|
|
||||||
image_url (str): URL of the image to upscale
|
|
||||||
original_prompt (str): Original prompt used to generate the image
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("Upscaling image with Clarity Upscaler...")
|
logger.info("Upscaling image with Clarity Upscaler...")
|
||||||
|
|
||||||
# Prepare arguments for upscaler
|
|
||||||
upscaler_arguments = {
|
upscaler_arguments = {
|
||||||
"image_url": image_url,
|
"image_url": image_url,
|
||||||
"prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
|
"prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
|
||||||
|
|
@ -315,32 +542,26 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||||
"resemblance": UPSCALER_RESEMBLANCE,
|
"resemblance": UPSCALER_RESEMBLANCE,
|
||||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
|
handler = _submit_fal_request(UPSCALER_MODEL, arguments=upscaler_arguments)
|
||||||
# The async API (submit_async) caches a global httpx.AsyncClient via
|
|
||||||
# @cached_property, which breaks when asyncio.run() destroys the loop
|
|
||||||
# between calls (gateway thread-pool pattern).
|
|
||||||
handler = _submit_fal_request(
|
|
||||||
UPSCALER_MODEL,
|
|
||||||
arguments=upscaler_arguments,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the upscaled result (sync — blocks until done)
|
|
||||||
result = handler.get()
|
result = handler.get()
|
||||||
|
|
||||||
if result and "image" in result:
|
if result and "image" in result:
|
||||||
upscaled_image = result["image"]
|
upscaled_image = result["image"]
|
||||||
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
|
logger.info(
|
||||||
|
"Image upscaled successfully to %sx%s",
|
||||||
|
upscaled_image.get("width", "unknown"),
|
||||||
|
upscaled_image.get("height", "unknown"),
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"url": upscaled_image["url"],
|
"url": upscaled_image["url"],
|
||||||
"width": upscaled_image.get("width", 0),
|
"width": upscaled_image.get("width", 0),
|
||||||
"height": upscaled_image.get("height", 0),
|
"height": upscaled_image.get("height", 0),
|
||||||
"upscaled": True,
|
"upscaled": True,
|
||||||
"upscale_factor": UPSCALER_FACTOR
|
"upscale_factor": UPSCALER_FACTOR,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
logger.error("Upscaler returned invalid response")
|
logger.error("Upscaler returned invalid response")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -349,165 +570,137 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
def image_generate_tool(
|
def image_generate_tool(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||||
num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
|
num_inference_steps: Optional[int] = None,
|
||||||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
guidance_scale: Optional[float] = None,
|
||||||
num_images: int = DEFAULT_NUM_IMAGES,
|
num_images: Optional[int] = None,
|
||||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
output_format: Optional[str] = None,
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Generate an image from a text prompt using the configured FAL model.
|
||||||
|
|
||||||
|
The agent-facing schema exposes only ``prompt`` and ``aspect_ratio``; the
|
||||||
|
remaining kwargs are overrides for direct Python callers and are filtered
|
||||||
|
per-model via the ``supports`` whitelist (unsupported overrides are
|
||||||
|
silently dropped so legacy callers don't break when switching models).
|
||||||
|
|
||||||
|
Returns a JSON string with ``{"success": bool, "image": url | None,
|
||||||
|
"error": str, "error_type": str}``.
|
||||||
"""
|
"""
|
||||||
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
|
model_id, meta = _resolve_fal_model()
|
||||||
|
|
||||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
|
|
||||||
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
|
|
||||||
when asyncio.run() destroys and recreates event loops between calls, which
|
|
||||||
happens in the gateway's thread-pool pattern.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The text prompt describing the desired image
|
|
||||||
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
|
|
||||||
num_inference_steps (int): Number of denoising steps (1-50, default: 50)
|
|
||||||
guidance_scale (float): How closely to follow prompt (0.1-20.0, default: 4.5)
|
|
||||||
num_images (int): Number of images to generate (1-4, default: 1)
|
|
||||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
|
||||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON string containing minimal generation results:
|
|
||||||
{
|
|
||||||
"success": bool,
|
|
||||||
"image": str or None # URL of the upscaled image, or None if failed
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Validate and map aspect_ratio to actual image_size
|
|
||||||
aspect_ratio_lower = aspect_ratio.lower().strip() if aspect_ratio else DEFAULT_ASPECT_RATIO
|
|
||||||
if aspect_ratio_lower not in ASPECT_RATIO_MAP:
|
|
||||||
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
|
|
||||||
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
|
|
||||||
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
|
|
||||||
|
|
||||||
debug_call_data = {
|
debug_call_data = {
|
||||||
|
"model": model_id,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"aspect_ratio": aspect_ratio,
|
"aspect_ratio": aspect_ratio,
|
||||||
"image_size": image_size,
|
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"guidance_scale": guidance_scale,
|
"guidance_scale": guidance_scale,
|
||||||
"num_images": num_images,
|
"num_images": num_images,
|
||||||
"output_format": output_format,
|
"output_format": output_format,
|
||||||
"seed": seed
|
"seed": seed,
|
||||||
},
|
},
|
||||||
"error": None,
|
"error": None,
|
||||||
"success": False,
|
"success": False,
|
||||||
"images_generated": 0,
|
"images_generated": 0,
|
||||||
"generation_time": 0
|
"generation_time": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
|
|
||||||
|
|
||||||
# Validate prompt
|
|
||||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||||
raise ValueError("Prompt is required and must be a non-empty string")
|
raise ValueError("Prompt is required and must be a non-empty string")
|
||||||
|
|
||||||
# Check API key availability
|
|
||||||
if not (os.getenv("FAL_KEY") or _resolve_managed_fal_gateway()):
|
if not (os.getenv("FAL_KEY") or _resolve_managed_fal_gateway()):
|
||||||
message = "FAL_KEY environment variable not set"
|
message = "FAL_KEY environment variable not set"
|
||||||
if managed_nous_tools_enabled():
|
if managed_nous_tools_enabled():
|
||||||
message += " and managed FAL gateway is unavailable"
|
message += " and managed FAL gateway is unavailable"
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
|
|
||||||
# Validate other parameters
|
aspect_lc = (aspect_ratio or DEFAULT_ASPECT_RATIO).lower().strip()
|
||||||
validated_params = _validate_parameters(
|
if aspect_lc not in VALID_ASPECT_RATIOS:
|
||||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
|
logger.warning(
|
||||||
|
"Invalid aspect_ratio '%s', defaulting to '%s'",
|
||||||
|
aspect_ratio, DEFAULT_ASPECT_RATIO,
|
||||||
|
)
|
||||||
|
aspect_lc = DEFAULT_ASPECT_RATIO
|
||||||
|
|
||||||
|
overrides: Dict[str, Any] = {}
|
||||||
|
if num_inference_steps is not None:
|
||||||
|
overrides["num_inference_steps"] = num_inference_steps
|
||||||
|
if guidance_scale is not None:
|
||||||
|
overrides["guidance_scale"] = guidance_scale
|
||||||
|
if num_images is not None:
|
||||||
|
overrides["num_images"] = num_images
|
||||||
|
if output_format is not None:
|
||||||
|
overrides["output_format"] = output_format
|
||||||
|
|
||||||
|
arguments = _build_fal_payload(
|
||||||
|
model_id, prompt, aspect_lc, seed=seed, overrides=overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare arguments for FAL.ai FLUX 2 Pro API
|
logger.info(
|
||||||
arguments = {
|
"Generating image with %s (%s) — prompt: %s",
|
||||||
"prompt": prompt.strip(),
|
meta.get("display", model_id), model_id, prompt[:80],
|
||||||
"image_size": validated_params["image_size"],
|
|
||||||
"num_inference_steps": validated_params["num_inference_steps"],
|
|
||||||
"guidance_scale": validated_params["guidance_scale"],
|
|
||||||
"num_images": validated_params["num_images"],
|
|
||||||
"output_format": validated_params["output_format"],
|
|
||||||
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
|
|
||||||
"safety_tolerance": SAFETY_TOLERANCE,
|
|
||||||
"sync_mode": True # Use sync mode for immediate results
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add seed if provided
|
|
||||||
if seed is not None and isinstance(seed, int):
|
|
||||||
arguments["seed"] = seed
|
|
||||||
|
|
||||||
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
|
|
||||||
logger.info(" Model: %s", DEFAULT_MODEL)
|
|
||||||
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
|
|
||||||
logger.info(" Steps: %s", validated_params['num_inference_steps'])
|
|
||||||
logger.info(" Guidance: %s", validated_params['guidance_scale'])
|
|
||||||
|
|
||||||
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
|
|
||||||
handler = _submit_fal_request(
|
|
||||||
DEFAULT_MODEL,
|
|
||||||
arguments=arguments,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the result (sync — blocks until done)
|
handler = _submit_fal_request(model_id, arguments=arguments)
|
||||||
result = handler.get()
|
result = handler.get()
|
||||||
|
|
||||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||||
|
|
||||||
# Process the response
|
|
||||||
if not result or "images" not in result:
|
if not result or "images" not in result:
|
||||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
raise ValueError("Invalid response from FAL.ai API — no images returned")
|
||||||
|
|
||||||
images = result.get("images", [])
|
images = result.get("images", [])
|
||||||
if not images:
|
if not images:
|
||||||
raise ValueError("No images were generated")
|
raise ValueError("No images were generated")
|
||||||
|
|
||||||
# Format image data and upscale images
|
should_upscale = bool(meta.get("upscale", False))
|
||||||
|
|
||||||
formatted_images = []
|
formatted_images = []
|
||||||
for img in images:
|
for img in images:
|
||||||
if isinstance(img, dict) and "url" in img:
|
if not (isinstance(img, dict) and "url" in img):
|
||||||
|
continue
|
||||||
original_image = {
|
original_image = {
|
||||||
"url": img["url"],
|
"url": img["url"],
|
||||||
"width": img.get("width", 0),
|
"width": img.get("width", 0),
|
||||||
"height": img.get("height", 0)
|
"height": img.get("height", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Attempt to upscale the image
|
if should_upscale:
|
||||||
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
||||||
|
|
||||||
if upscaled_image:
|
if upscaled_image:
|
||||||
# Use upscaled image if successful
|
|
||||||
formatted_images.append(upscaled_image)
|
formatted_images.append(upscaled_image)
|
||||||
else:
|
continue
|
||||||
# Fall back to original image if upscaling fails
|
logger.warning("Using original image as fallback (upscale failed)")
|
||||||
logger.warning("Using original image as fallback")
|
|
||||||
original_image["upscaled"] = False
|
original_image["upscaled"] = False
|
||||||
formatted_images.append(original_image)
|
formatted_images.append(original_image)
|
||||||
|
|
||||||
if not formatted_images:
|
if not formatted_images:
|
||||||
raise ValueError("No valid image URLs returned from API")
|
raise ValueError("No valid image URLs returned from API")
|
||||||
|
|
||||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled"))
|
||||||
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
|
logger.info(
|
||||||
|
"Generated %s image(s) in %.1fs (%s upscaled) via %s",
|
||||||
|
len(formatted_images), generation_time, upscaled_count, model_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare successful response - minimal format
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"image": formatted_images[0]["url"] if formatted_images else None
|
"image": formatted_images[0]["url"] if formatted_images else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
debug_call_data["success"] = True
|
debug_call_data["success"] = True
|
||||||
debug_call_data["images_generated"] = len(formatted_images)
|
debug_call_data["images_generated"] = len(formatted_images)
|
||||||
debug_call_data["generation_time"] = generation_time
|
debug_call_data["generation_time"] = generation_time
|
||||||
|
|
||||||
# Log debug information
|
|
||||||
_debug.log_call("image_generate_tool", debug_call_data)
|
_debug.log_call("image_generate_tool", debug_call_data)
|
||||||
_debug.save()
|
_debug.save()
|
||||||
|
|
||||||
|
|
@ -518,7 +711,6 @@ def image_generate_tool(
|
||||||
error_msg = f"Error generating image: {str(e)}"
|
error_msg = f"Error generating image: {str(e)}"
|
||||||
logger.error("%s", error_msg, exc_info=True)
|
logger.error("%s", error_msg, exc_info=True)
|
||||||
|
|
||||||
# Include error details so callers can diagnose failures
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"success": False,
|
"success": False,
|
||||||
"image": None,
|
"image": None,
|
||||||
|
|
@ -535,109 +727,54 @@ def image_generate_tool(
|
||||||
|
|
||||||
|
|
||||||
def check_fal_api_key() -> bool:
|
def check_fal_api_key() -> bool:
|
||||||
"""
|
"""True if the FAL.ai API key (direct or managed gateway) is available."""
|
||||||
Check if the FAL.ai API key is available in environment variables.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if API key is set, False otherwise
|
|
||||||
"""
|
|
||||||
return bool(os.getenv("FAL_KEY") or _resolve_managed_fal_gateway())
|
return bool(os.getenv("FAL_KEY") or _resolve_managed_fal_gateway())
|
||||||
|
|
||||||
|
|
||||||
def check_image_generation_requirements() -> bool:
|
def check_image_generation_requirements() -> bool:
|
||||||
"""
|
"""True if FAL credentials and fal_client SDK are both available."""
|
||||||
Check if all requirements for image generation tools are met.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if requirements are met, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Check API key
|
|
||||||
if not check_fal_api_key():
|
if not check_fal_api_key():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if fal_client is available
|
|
||||||
import fal_client # noqa: F401 — SDK presence check
|
import fal_client # noqa: F401 — SDK presence check
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Demo / CLI entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
print("🎨 Image Generation Tools — FAL.ai multi-model support")
|
||||||
Simple test/demo when run directly
|
|
||||||
"""
|
|
||||||
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# Check if API key is available
|
if not check_fal_api_key():
|
||||||
api_available = check_fal_api_key()
|
|
||||||
|
|
||||||
if not api_available:
|
|
||||||
print("❌ FAL_KEY environment variable not set")
|
print("❌ FAL_KEY environment variable not set")
|
||||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
print(" Set it via: export FAL_KEY='your-key-here'")
|
||||||
print("Get API key at: https://fal.ai/")
|
print(" Get a key: https://fal.ai/")
|
||||||
exit(1)
|
raise SystemExit(1)
|
||||||
else:
|
|
||||||
print("✅ FAL.ai API key found")
|
print("✅ FAL.ai API key found")
|
||||||
|
|
||||||
# Check if fal_client is available
|
|
||||||
try:
|
try:
|
||||||
import fal_client
|
import fal_client # noqa: F401
|
||||||
print("✅ fal_client library available")
|
print("✅ fal_client library available")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("❌ fal_client library not found")
|
print("❌ fal_client library not found — pip install fal-client")
|
||||||
print("Please install: pip install fal-client")
|
raise SystemExit(1)
|
||||||
exit(1)
|
|
||||||
|
|
||||||
print("🛠️ Image generation tools ready for use!")
|
model_id, meta = _resolve_fal_model()
|
||||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
print(f"🤖 Active model: {meta.get('display', model_id)} ({model_id})")
|
||||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
print(f" Speed: {meta.get('speed', '?')} · Price: {meta.get('price', '?')}")
|
||||||
|
print(f" Upscaler: {'on' if meta.get('upscale') else 'off'}")
|
||||||
|
|
||||||
|
print("\nAvailable models:")
|
||||||
|
for mid, m in FAL_MODELS.items():
|
||||||
|
marker = " ← active" if mid == model_id else ""
|
||||||
|
print(f" {mid:<32} {m.get('speed', '?'):<6} {m.get('price', '?')}{marker}")
|
||||||
|
|
||||||
# Show debug mode status
|
|
||||||
if _debug.active:
|
if _debug.active:
|
||||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
print(f"\n🐛 Debug mode enabled — session {_debug.session_id}")
|
||||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
|
|
||||||
else:
|
|
||||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
|
||||||
|
|
||||||
print("\nBasic usage:")
|
|
||||||
print(" from image_generation_tool import image_generate_tool")
|
|
||||||
print(" import asyncio")
|
|
||||||
print("")
|
|
||||||
print(" async def main():")
|
|
||||||
print(" # Generate image with automatic 2x upscaling")
|
|
||||||
print(" result = await image_generate_tool(")
|
|
||||||
print(" prompt='A serene mountain landscape with cherry blossoms',")
|
|
||||||
print(" image_size='landscape_4_3',")
|
|
||||||
print(" num_images=1")
|
|
||||||
print(" )")
|
|
||||||
print(" print(result)")
|
|
||||||
print(" asyncio.run(main())")
|
|
||||||
|
|
||||||
print("\nSupported image sizes:")
|
|
||||||
for size in VALID_IMAGE_SIZES:
|
|
||||||
print(f" - {size}")
|
|
||||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
|
||||||
|
|
||||||
print("\nAcceleration modes:")
|
|
||||||
for mode in VALID_ACCELERATION_MODES:
|
|
||||||
print(f" - {mode}")
|
|
||||||
|
|
||||||
print("\nExample prompts:")
|
|
||||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
|
||||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
|
||||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
|
||||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
|
||||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
|
||||||
|
|
||||||
print("\nDebug mode:")
|
|
||||||
print(" # Enable debug logging")
|
|
||||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
|
||||||
print(" # Debug logs capture all image generation calls and results")
|
|
||||||
print(" # Logs saved to: ./logs/image_tools_debug_UUID.json")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -647,23 +784,28 @@ from tools.registry import registry, tool_error
|
||||||
|
|
||||||
IMAGE_GENERATE_SCHEMA = {
|
IMAGE_GENERATE_SCHEMA = {
|
||||||
"name": "image_generate",
|
"name": "image_generate",
|
||||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ",
|
"description": (
|
||||||
|
"Generate high-quality images from text prompts using FAL.ai. "
|
||||||
|
"The underlying model is user-configured (default: FLUX 2 Klein 9B, "
|
||||||
|
"sub-1s generation) and is not selectable by the agent. Returns a "
|
||||||
|
"single image URL. Display it using markdown: "
|
||||||
|
),
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
|
||||||
},
|
},
|
||||||
"aspect_ratio": {
|
"aspect_ratio": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["landscape", "square", "portrait"],
|
"enum": list(VALID_ASPECT_RATIOS),
|
||||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||||
"default": "landscape"
|
"default": DEFAULT_ASPECT_RATIO,
|
||||||
}
|
},
|
||||||
|
},
|
||||||
|
"required": ["prompt"],
|
||||||
},
|
},
|
||||||
"required": ["prompt"]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -673,12 +815,7 @@ def _handle_image_generate(args, **kw):
|
||||||
return tool_error("prompt is required for image generation")
|
return tool_error("prompt is required for image generation")
|
||||||
return image_generate_tool(
|
return image_generate_tool(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
aspect_ratio=args.get("aspect_ratio", "landscape"),
|
aspect_ratio=args.get("aspect_ratio", DEFAULT_ASPECT_RATIO),
|
||||||
num_inference_steps=50,
|
|
||||||
guidance_scale=4.5,
|
|
||||||
num_images=1,
|
|
||||||
output_format="png",
|
|
||||||
seed=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -689,6 +826,6 @@ registry.register(
|
||||||
handler=_handle_image_generate,
|
handler=_handle_image_generate,
|
||||||
check_fn=check_image_generation_requirements,
|
check_fn=check_image_generation_requirements,
|
||||||
requires_env=[],
|
requires_env=[],
|
||||||
is_async=False, # Switched to sync fal_client API to fix "Event loop is closed" in gateway
|
is_async=False, # sync fal_client API to avoid "Event loop is closed" in gateway
|
||||||
emoji="🎨",
|
emoji="🎨",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ In addition to built-in tools, Hermes can load tools dynamically from MCP server
|
||||||
|
|
||||||
| Tool | Description | Requires environment |
|
| Tool | Description | Requires environment |
|
||||||
|------|-------------|----------------------|
|
|------|-------------|----------------------|
|
||||||
| `image_generate` | Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using… | FAL_KEY |
|
| `image_generate` | Generate high-quality images from text prompts using FAL.ai. The underlying model is user-configured (default: FLUX 2 Klein 9B, sub-1s generation) and is not selectable by the agent. Returns a single image URL. Display it using… | FAL_KEY |
|
||||||
|
|
||||||
## `memory` toolset
|
## `memory` toolset
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,35 @@
|
||||||
---
|
---
|
||||||
title: Image Generation
|
title: Image Generation
|
||||||
description: Generate high-quality images using FLUX 2 Pro with automatic upscaling via FAL.ai.
|
description: Generate images via FAL.ai — 8 models including FLUX 2, GPT-Image, Nano Banana, Ideogram, and more, selectable via `hermes tools`.
|
||||||
sidebar_label: Image Generation
|
sidebar_label: Image Generation
|
||||||
sidebar_position: 6
|
sidebar_position: 6
|
||||||
---
|
---
|
||||||
|
|
||||||
# Image Generation
|
# Image Generation
|
||||||
|
|
||||||
Hermes Agent can generate images from text prompts using FAL.ai's **FLUX 2 Pro** model with automatic 2x upscaling via the **Clarity Upscaler** for enhanced quality.
|
Hermes Agent generates images from text prompts via FAL.ai. Eight models are supported out of the box, each with different speed, quality, and cost tradeoffs. The active model is user-configurable via `hermes tools` and persists in `config.yaml`.
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
| Model | Speed | Strengths | Price |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `fal-ai/flux-2/klein/9b` *(default)* | <1s | Fast, crisp text | $0.006/MP |
|
||||||
|
| `fal-ai/flux-2-pro` | ~6s | Studio photorealism | $0.03/MP |
|
||||||
|
| `fal-ai/z-image/turbo` | ~2s | Bilingual EN/CN, 6B params | $0.005/MP |
|
||||||
|
| `fal-ai/nano-banana` | ~6s | Gemini 2.5, character consistency | $0.08/image |
|
||||||
|
| `fal-ai/gpt-image-1.5` | ~15s | Prompt adherence | $0.034/image |
|
||||||
|
| `fal-ai/ideogram/v3` | ~5s | Best typography | $0.03–0.09/image |
|
||||||
|
| `fal-ai/recraft-v3` | ~8s | Vector art, brand styles | $0.04/image |
|
||||||
|
| `fal-ai/qwen-image` | ~12s | LLM-based, complex text | $0.02/MP |
|
||||||
|
|
||||||
|
Prices are FAL's pricing at time of writing; check [fal.ai](https://fal.ai/) for current numbers.
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
|
||||||
:::tip Nous Subscribers
|
:::tip Nous Subscribers
|
||||||
If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription, you can use image generation through the **[Tool Gateway](tool-gateway.md)** without a FAL API key. Run `hermes model` or `hermes tools` to enable it.
|
If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription, you can use image generation through the **[Tool Gateway](tool-gateway.md)** without a FAL API key. Your model selection persists across both paths.
|
||||||
|
|
||||||
|
If the managed gateway returns `HTTP 4xx` for a specific model, that model isn't yet proxied on the portal side — the agent will tell you so, with remediation steps (set `FAL_KEY` for direct access, or pick a different model).
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### Get a FAL API Key
|
### Get a FAL API Key
|
||||||
|
|
@ -20,150 +37,117 @@ If you have a paid [Nous Portal](https://portal.nousresearch.com) subscription,
|
||||||
1. Sign up at [fal.ai](https://fal.ai/)
|
1. Sign up at [fal.ai](https://fal.ai/)
|
||||||
2. Generate an API key from your dashboard
|
2. Generate an API key from your dashboard
|
||||||
|
|
||||||
### Configure the Key
|
### Configure and Pick a Model
|
||||||
|
|
||||||
|
Run the tools command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Add to ~/.hermes/.env
|
hermes tools
|
||||||
FAL_KEY=your-fal-api-key-here
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Install the Client Library
|
Navigate to **🎨 Image Generation**, pick your backend (Nous Subscription or FAL.ai), then the picker shows all supported models in a column-aligned table — arrow keys to navigate, Enter to select:
|
||||||
|
|
||||||
```bash
|
```
|
||||||
pip install fal-client
|
Model Speed Strengths Price
|
||||||
|
fal-ai/flux-2/klein/9b <1s Fast, crisp text $0.006/MP ← currently in use
|
||||||
|
fal-ai/flux-2-pro ~6s Studio photorealism $0.03/MP
|
||||||
|
fal-ai/z-image/turbo ~2s Bilingual EN/CN, 6B $0.005/MP
|
||||||
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
:::info
|
Your selection is saved to `config.yaml`:
|
||||||
The image generation tool is automatically available when `FAL_KEY` is set. No additional toolset configuration is needed.
|
|
||||||
:::
|
|
||||||
|
|
||||||
## How It Works
|
```yaml
|
||||||
|
image_gen:
|
||||||
|
model: fal-ai/flux-2/klein/9b
|
||||||
|
use_gateway: false # true if using Nous Subscription
|
||||||
|
```
|
||||||
|
|
||||||
When you ask Hermes to generate an image:
|
### GPT-Image Quality
|
||||||
|
|
||||||
1. **Generation** — Your prompt is sent to the FLUX 2 Pro model (`fal-ai/flux-2-pro`)
|
The `fal-ai/gpt-image-1.5` request quality is pinned to `medium` (~$0.034/image at 1024×1024). We don't expose the `low` / `high` tiers as a user-facing option so that Nous Portal billing stays predictable across all users — the cost spread between tiers is ~22×. If you want a cheaper GPT-Image option, pick a different model; if you want higher quality, use Klein 9B or Imagen-class models.
|
||||||
2. **Upscaling** — The generated image is automatically upscaled 2x using the Clarity Upscaler (`fal-ai/clarity-upscaler`)
|
|
||||||
3. **Delivery** — The upscaled image URL is returned
|
|
||||||
|
|
||||||
If upscaling fails for any reason, the original image is returned as a fallback.
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Simply ask Hermes to create an image:
|
The agent-facing schema is intentionally minimal — the model picks up whatever you've configured:
|
||||||
|
|
||||||
```
|
```
|
||||||
Generate an image of a serene mountain landscape with cherry blossoms
|
Generate an image of a serene mountain landscape with cherry blossoms
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
Create a portrait of a wise old owl perched on an ancient tree branch
|
Create a square portrait of a wise old owl — use the typography model
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
Make me a futuristic cityscape with flying cars and neon lights
|
Make me a futuristic cityscape, landscape orientation
|
||||||
```
|
```
|
||||||
|
|
||||||
## Parameters
|
|
||||||
|
|
||||||
The `image_generate_tool` accepts these parameters:
|
|
||||||
|
|
||||||
| Parameter | Default | Range | Description |
|
|
||||||
|-----------|---------|-------|-------------|
|
|
||||||
| `prompt` | *(required)* | — | Text description of the desired image |
|
|
||||||
| `aspect_ratio` | `"landscape"` | `landscape`, `square`, `portrait` | Image aspect ratio |
|
|
||||||
| `num_inference_steps` | `50` | 1–100 | Number of denoising steps (more = higher quality, slower) |
|
|
||||||
| `guidance_scale` | `4.5` | 0.1–20.0 | How closely to follow the prompt |
|
|
||||||
| `num_images` | `1` | 1–4 | Number of images to generate |
|
|
||||||
| `output_format` | `"png"` | `png`, `jpeg` | Image file format |
|
|
||||||
| `seed` | *(random)* | any integer | Random seed for reproducible results |
|
|
||||||
|
|
||||||
## Aspect Ratios
|
## Aspect Ratios
|
||||||
|
|
||||||
The tool uses simplified aspect ratio names that map to FLUX 2 Pro image sizes:
|
Every model accepts the same three aspect ratios from the agent's perspective. Internally, each model's native size spec is filled in automatically:
|
||||||
|
|
||||||
| Aspect Ratio | Maps To | Best For |
|
| Agent input | image_size (flux/z-image/qwen/recraft/ideogram) | aspect_ratio (nano-banana) | image_size (gpt-image) |
|
||||||
|-------------|---------|----------|
|
|---|---|---|---|
|
||||||
| `landscape` | `landscape_16_9` | Wallpapers, banners, scenes |
|
| `landscape` | `landscape_16_9` | `16:9` | `1536x1024` |
|
||||||
| `square` | `square_hd` | Profile pictures, social media posts |
|
| `square` | `square_hd` | `1:1` | `1024x1024` |
|
||||||
| `portrait` | `portrait_16_9` | Character art, phone wallpapers |
|
| `portrait` | `portrait_16_9` | `9:16` | `1024x1536` |
|
||||||
|
|
||||||
:::tip
|
This translation happens in `_build_fal_payload()` — agent code never has to know about per-model schema differences.
|
||||||
You can also use the raw FLUX 2 Pro size presets directly: `square_hd`, `square`, `portrait_4_3`, `portrait_16_9`, `landscape_4_3`, `landscape_16_9`. Custom sizes up to 2048x2048 are also supported.
|
|
||||||
:::
|
|
||||||
|
|
||||||
## Automatic Upscaling
|
## Automatic Upscaling
|
||||||
|
|
||||||
Every generated image is automatically upscaled 2x using FAL.ai's Clarity Upscaler with these settings:
|
Upscaling via FAL's **Clarity Upscaler** is gated per-model:
|
||||||
|
|
||||||
|
| Model | Upscale? | Why |
|
||||||
|
|---|---|---|
|
||||||
|
| `fal-ai/flux-2-pro` | ✓ | Backward-compat (was the pre-picker default) |
|
||||||
|
| All others | ✗ | Fast models would lose their sub-second value prop; hi-res models don't need it |
|
||||||
|
|
||||||
|
When upscaling runs, it uses these settings:
|
||||||
|
|
||||||
| Setting | Value |
|
| Setting | Value |
|
||||||
|---------|-------|
|
|---|---|
|
||||||
| Upscale Factor | 2x |
|
| Upscale factor | 2× |
|
||||||
| Creativity | 0.35 |
|
| Creativity | 0.35 |
|
||||||
| Resemblance | 0.6 |
|
| Resemblance | 0.6 |
|
||||||
| Guidance Scale | 4 |
|
| Guidance scale | 4 |
|
||||||
| Inference Steps | 18 |
|
| Inference steps | 18 |
|
||||||
| Positive Prompt | `"masterpiece, best quality, highres"` + your original prompt |
|
|
||||||
| Negative Prompt | `"(worst quality, low quality, normal quality:2)"` |
|
|
||||||
|
|
||||||
The upscaler enhances detail and resolution while preserving the original composition. If the upscaler fails (network issue, rate limit), the original resolution image is returned automatically.
|
If upscaling fails (network issue, rate limit), the original image is returned automatically.
|
||||||
|
|
||||||
## Example Prompts
|
## How It Works Internally
|
||||||
|
|
||||||
Here are some effective prompts to try:
|
1. **Model resolution** — `_resolve_fal_model()` reads `image_gen.model` from `config.yaml`, falls back to the `FAL_IMAGE_MODEL` env var, then to `fal-ai/flux-2/klein/9b`.
|
||||||
|
2. **Payload building** — `_build_fal_payload()` translates your `aspect_ratio` into the model's native format (preset enum, aspect-ratio enum, or GPT literal), merges the model's default params, applies any caller overrides, then filters to the model's `supports` whitelist so unsupported keys are never sent.
|
||||||
```
|
3. **Submission** — `_submit_fal_request()` routes via direct FAL credentials or the managed Nous gateway.
|
||||||
A candid street photo of a woman with a pink bob and bold eyeliner
|
4. **Upscaling** — runs only if the model's metadata has `upscale: True`.
|
||||||
```
|
5. **Delivery** — final image URL returned to the agent, which emits a `MEDIA:<url>` tag that platform adapters convert to native media.
|
||||||
|
|
||||||
```
|
|
||||||
Modern architecture building with glass facade, sunset lighting
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
Abstract art with vibrant colors and geometric patterns
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
Portrait of a wise old owl perched on ancient tree branch
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
Futuristic cityscape with flying cars and neon lights
|
|
||||||
```
|
|
||||||
|
|
||||||
## Debugging
|
## Debugging
|
||||||
|
|
||||||
Enable debug logging for image generation:
|
Enable debug logging:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export IMAGE_TOOLS_DEBUG=true
|
export IMAGE_TOOLS_DEBUG=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Debug logs are saved to `./logs/image_tools_debug_<session_id>.json` with details about each generation request, parameters, timing, and any errors.
|
Debug logs go to `./logs/image_tools_debug_<session_id>.json` with per-call details (model, parameters, timing, errors).
|
||||||
|
|
||||||
## Safety Settings
|
|
||||||
|
|
||||||
The image generation tool runs with safety checks disabled by default (`safety_tolerance: 5`, the most permissive setting). This is configured at the code level and is not user-adjustable.
|
|
||||||
|
|
||||||
## Platform Delivery
|
## Platform Delivery
|
||||||
|
|
||||||
Generated images are delivered differently depending on the platform:
|
| Platform | Delivery |
|
||||||
|
|---|---|
|
||||||
| Platform | Delivery method |
|
| **CLI** | Image URL printed as markdown `` — click to open |
|
||||||
|----------|----------------|
|
| **Telegram** | Photo message with the prompt as caption |
|
||||||
| **CLI** | Image URL printed as markdown `` — click to open in browser |
|
| **Discord** | Embedded in a message |
|
||||||
| **Telegram** | Image sent as a photo message with the prompt as caption |
|
| **Slack** | URL unfurled by Slack |
|
||||||
| **Discord** | Image embedded in a message |
|
| **WhatsApp** | Media message |
|
||||||
| **Slack** | Image URL in message (Slack unfurls it) |
|
| **Others** | URL in plain text |
|
||||||
| **WhatsApp** | Image sent as a media message |
|
|
||||||
| **Other platforms** | Image URL in plain text |
|
|
||||||
|
|
||||||
The agent uses `MEDIA:<url>` syntax in its response, which the platform adapter converts to the appropriate format.
|
|
||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- **Requires FAL API key** — image generation incurs API costs on your FAL.ai account
|
- **Requires FAL credentials** (direct `FAL_KEY` or Nous Subscription)
|
||||||
- **No image editing** — this is text-to-image only, no inpainting or img2img
|
- **Text-to-image only** — no inpainting, img2img, or editing via this tool
|
||||||
- **URL-based delivery** — images are returned as temporary FAL.ai URLs, not saved locally. URLs expire after a period (typically hours)
|
- **Temporary URLs** — FAL returns hosted URLs that expire after hours/days; save locally if needed
|
||||||
- **Upscaling adds latency** — the automatic 2x upscale step adds processing time
|
- **Per-model constraints** — some models don't support `seed`, `num_inference_steps`, etc. The `supports` filter silently drops unsupported params; this is expected behavior
|
||||||
- **Max 4 images per request** — `num_images` is capped at 4
|
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ Hermes Agent includes a rich set of capabilities that extend far beyond basic ch
|
||||||
- **[Voice Mode](voice-mode.md)** — Full voice interaction across CLI and messaging platforms. Talk to the agent using your microphone, hear spoken replies, and have live voice conversations in Discord voice channels.
|
- **[Voice Mode](voice-mode.md)** — Full voice interaction across CLI and messaging platforms. Talk to the agent using your microphone, hear spoken replies, and have live voice conversations in Discord voice channels.
|
||||||
- **[Browser Automation](browser.md)** — Full browser automation with multiple backends: Browserbase cloud, Browser Use cloud, local Chrome via CDP, or local Chromium. Navigate websites, fill forms, and extract information.
|
- **[Browser Automation](browser.md)** — Full browser automation with multiple backends: Browserbase cloud, Browser Use cloud, local Chrome via CDP, or local Chromium. Navigate websites, fill forms, and extract information.
|
||||||
- **[Vision & Image Paste](vision.md)** — Multimodal vision support. Paste images from your clipboard into the CLI and ask the agent to analyze, describe, or work with them using any vision-capable model.
|
- **[Vision & Image Paste](vision.md)** — Multimodal vision support. Paste images from your clipboard into the CLI and ask the agent to analyze, describe, or work with them using any vision-capable model.
|
||||||
- **[Image Generation](image-generation.md)** — Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic 2x upscaling via the Clarity Upscaler.
|
- **[Image Generation](image-generation.md)** — Generate images from text prompts using FAL.ai. Eight models supported (FLUX 2 Klein/Pro, GPT-Image 1.5, Nano Banana, Ideogram V3, Recraft V3, Qwen, Z-Image Turbo); pick one via `hermes tools`.
|
||||||
- **[Voice & TTS](tts.md)** — Text-to-speech output and voice message transcription across all messaging platforms, with five provider options: Edge TTS (free), ElevenLabs, OpenAI TTS, MiniMax, and NeuTTS.
|
- **[Voice & TTS](tts.md)** — Text-to-speech output and voice message transcription across all messaging platforms, with five provider options: Edge TTS (free), ElevenLabs, OpenAI TTS, MiniMax, and NeuTTS.
|
||||||
|
|
||||||
## Integrations
|
## Integrations
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ The **Tool Gateway** lets paid [Nous Portal](https://portal.nousresearch.com) su
|
||||||
| Tool | What It Does | Direct Alternative |
|
| Tool | What It Does | Direct Alternative |
|
||||||
|------|--------------|--------------------|
|
|------|--------------|--------------------|
|
||||||
| **Web search & extract** | Search the web and extract page content via Firecrawl | `FIRECRAWL_API_KEY`, `EXA_API_KEY`, `PARALLEL_API_KEY`, `TAVILY_API_KEY` |
|
| **Web search & extract** | Search the web and extract page content via Firecrawl | `FIRECRAWL_API_KEY`, `EXA_API_KEY`, `PARALLEL_API_KEY`, `TAVILY_API_KEY` |
|
||||||
| **Image generation** | Generate images via FAL (FLUX 2 Pro + upscaling) | `FAL_KEY` |
|
| **Image generation** | Generate images via FAL (8 models: FLUX 2 Klein/Pro, GPT-Image, Nano Banana, Ideogram, Recraft, Qwen, Z-Image) | `FAL_KEY` |
|
||||||
| **Text-to-speech** | Convert text to speech via OpenAI TTS | `VOICE_TOOLS_OPENAI_KEY`, `ELEVENLABS_API_KEY` |
|
| **Text-to-speech** | Convert text to speech via OpenAI TTS | `VOICE_TOOLS_OPENAI_KEY`, `ELEVENLABS_API_KEY` |
|
||||||
| **Browser automation** | Control cloud browsers via Browser Use | `BROWSER_USE_API_KEY`, `BROWSERBASE_API_KEY` |
|
| **Browser automation** | Control cloud browsers via Browser Use | `BROWSER_USE_API_KEY`, `BROWSERBASE_API_KEY` |
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue