diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index f185c8788c2..87e7816169c 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -311,6 +311,16 @@ TOOL_CATEGORIES = { "image_gen": { "name": "Image Generation", "icon": "🎨", + # Per-provider rows for FAL.ai (`plugins/image_gen/fal`), OpenAI, + # OpenAI Codex, and xAI are injected at runtime from each + # ``plugins.image_gen.`` package via + # ``_plugin_image_gen_providers()`` in ``_visible_providers``. + # Only non-provider UX setup-flow rows remain here: + # - "Nous Subscription" — managed FAL billed via the Nous + # subscription (requires_nous_auth + override_env_vars). + # Uses the fal plugin as the underlying backend but has a + # distinct setup UX. + # Mirrors the shape browser/video_gen ship today. "providers": [ { "name": "Nous Subscription", @@ -322,15 +332,6 @@ TOOL_CATEGORIES = { "override_env_vars": ["FAL_KEY"], "imagegen_backend": "fal", }, - { - "name": "FAL.ai", - "badge": "paid", - "tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.", - "env_vars": [ - {"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"}, - ], - "imagegen_backend": "fal", - }, ], }, "video_gen": { @@ -1567,12 +1568,9 @@ def _plugin_image_gen_providers() -> list[dict]: Each returned dict looks like a regular ``TOOL_CATEGORIES`` provider row but carries an ``image_gen_plugin_name`` marker so downstream code (config writing, model picker) knows to route through the - plugin registry instead of the in-tree FAL backend. - - FAL is skipped — it's already exposed by the hardcoded - ``TOOL_CATEGORIES["image_gen"]`` entries. When FAL gets ported to - a plugin in a follow-up PR, the hardcoded entries go away and this - function surfaces it alongside OpenAI automatically. + plugin registry. 
Every image-gen backend is a plugin now — there + are no hardcoded rows left in ``TOOL_CATEGORIES["image_gen"]`` for + this function to dedupe against (see issue #26241). """ try: from agent.image_gen_registry import list_providers @@ -1585,9 +1583,6 @@ def _plugin_image_gen_providers() -> list[dict]: rows: list[dict] = [] for provider in providers: - if getattr(provider, "name", None) == "fal": - # FAL has its own hardcoded rows today. - continue try: schema = provider.get_setup_schema() except Exception: diff --git a/plugins/image_gen/fal/__init__.py b/plugins/image_gen/fal/__init__.py new file mode 100644 index 00000000000..21b88f37f34 --- /dev/null +++ b/plugins/image_gen/fal/__init__.py @@ -0,0 +1,182 @@ +"""FAL.ai image generation backend. + +Wraps the 18-model FAL catalog (FLUX 2, Z-Image, Nano Banana, GPT +Image 1.5, Recraft, Imagen 4, Qwen, Ideogram, …) as an +:class:`ImageGenProvider` implementation. + +The heavy lifting — model catalog, payload construction, request +submission, managed-Nous-gateway selection, Clarity Upscaler chaining +— lives in :mod:`tools.image_generation_tool`. This plugin reaches into +that module via call-time indirection (``import tools.image_generation_tool as _it``) +so: + +* the existing test suite (``tests/tools/test_image_generation.py``, + ``tests/tools/test_managed_media_gateways.py``) keeps patching + ``image_tool._submit_fal_request`` / ``image_tool.fal_client`` / + ``image_tool._managed_fal_client`` without modification, and +* there's exactly one canonical FAL code path on disk — the plugin is a + registration adapter, not a parallel implementation. + +See issue #26241 for the migration plan and the +``plugin-extraction-test-patch-compatibility.md`` rules this follows. 
class FalImageGenProvider(ImageGenProvider):
    """FAL.ai image generation backend.

    Thin registration adapter: every method delegates to
    ``tools.image_generation_tool``, imported at call time as ``_it``, so
    the in-tree FAL implementation stays the single source of truth and
    tests that monkey-patch attributes on that module keep taking effect.
    """

    # Optional tuning kwargs that ``generate`` forwards to the legacy
    # pipeline when the caller supplied a non-``None`` value.
    _PASSTHROUGH_KEYS = (
        "num_inference_steps",
        "guidance_scale",
        "num_images",
        "output_format",
        "seed",
    )

    @property
    def name(self) -> str:
        """Registry identifier of this provider."""
        return "fal"

    @property
    def display_name(self) -> str:
        """Human-readable provider label."""
        return "FAL.ai"

    def is_available(self) -> bool:
        """Return whether the legacy FAL credential check passes.

        Available when direct FAL_KEY is set OR the managed Nous gateway
        resolves a fal-queue origin. Both checks come from the legacy
        module so this provider tracks whatever logic ships there.

        Returns:
            ``True`` when ``check_fal_api_key()`` is truthy; ``False`` when
            it is falsy or when anything (including importing the legacy
            module) raises.
        """
        try:
            # Import inside the ``try`` so a broken legacy module degrades
            # to "not available" instead of crashing the picker.
            import tools.image_generation_tool as _it

            return bool(_it.check_fal_api_key())
        except Exception:  # noqa: BLE001 — defensive; never break the picker
            logger.debug("FAL availability check failed", exc_info=True)
            return False

    def list_models(self) -> List[Dict[str, Any]]:
        """Return one row per entry of the legacy ``FAL_MODELS`` catalog.

        Each row carries ``id``, ``display``, ``speed``, ``strengths`` and
        ``price``; missing metadata falls back to the model id (for
        ``display``) or an empty string.
        """
        import tools.image_generation_tool as _it

        return [
            {
                "id": model_id,
                "display": meta.get("display", model_id),
                "speed": meta.get("speed", ""),
                "strengths": meta.get("strengths", ""),
                "price": meta.get("price", ""),
            }
            for model_id, meta in _it.FAL_MODELS.items()
        ]

    def default_model(self) -> Optional[str]:
        """Return the legacy module's ``DEFAULT_MODEL``."""
        import tools.image_generation_tool as _it

        return _it.DEFAULT_MODEL

    def get_setup_schema(self) -> Dict[str, Any]:
        """Return the provider row used by the setup picker."""
        return {
            "name": "FAL.ai",
            "badge": "paid",
            "tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.",
            "env_vars": [
                {
                    "key": "FAL_KEY",
                    "prompt": "FAL API key",
                    "url": "https://fal.ai/dashboard/keys",
                },
            ],
        }

    def generate(
        self,
        prompt: str,
        aspect_ratio: str = DEFAULT_ASPECT_RATIO,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Generate an image via the legacy FAL pipeline.

        Forwards prompt + aspect_ratio (and any forward-compat extras
        listed in ``_PASSTHROUGH_KEYS``) into
        :func:`tools.image_generation_tool.image_generate_tool`, then
        reshapes its JSON-string response into a dict.

        Args:
            prompt: Text prompt for the image.
            aspect_ratio: Requested ratio; normalised through
                ``resolve_aspect_ratio`` before being forwarded.
            **kwargs: Optional extras; only the keys in
                ``_PASSTHROUGH_KEYS`` with non-``None`` values are
                forwarded, everything else is ignored.

        Returns:
            The pipeline's response dict, stamped with ``provider``,
            ``prompt``, ``aspect_ratio`` and (best effort) ``model`` when
            those keys are absent. Pipeline failures are reported as a
            ``success: False`` dict carrying ``error`` and ``error_type``
            rather than raised.
        """
        aspect = resolve_aspect_ratio(aspect_ratio)
        passthrough = {
            key: kwargs[key]
            for key in self._PASSTHROUGH_KEYS
            if kwargs.get(key) is not None
        }

        try:
            # Imported inside the ``try`` so an import failure is reported
            # through the same error dict as a pipeline failure.
            import tools.image_generation_tool as _it

            raw = _it.image_generate_tool(
                prompt=prompt,
                aspect_ratio=aspect,
                **passthrough,
            )
        except Exception as exc:  # noqa: BLE001 — never raise out of generate
            logger.warning("FAL image_generate_tool raised: %s", exc, exc_info=True)
            return {
                "success": False,
                "image": None,
                "error": f"FAL image generation failed: {exc}",
                "error_type": type(exc).__name__,
                "provider": "fal",
                "prompt": prompt,
                "aspect_ratio": aspect,
            }

        try:
            response = json.loads(raw) if isinstance(raw, str) else raw
        except Exception:  # noqa: BLE001
            logger.warning("FAL pipeline returned a string that is not valid JSON")
            response = {
                "success": False,
                "image": None,
                "error": "Invalid JSON from FAL pipeline",
                # Same tag as the non-dict case below so every failure
                # shape carries an ``error_type``.
                "error_type": "provider_contract",
            }

        if not isinstance(response, dict):
            response = {
                "success": False,
                "image": None,
                "error": "FAL pipeline returned a non-dict response",
                "error_type": "provider_contract",
            }

        # Stamp provider/prompt/aspect_ratio so downstream consumers see
        # the uniform shape declared in ``agent.image_gen_provider``.
        response.setdefault("provider", "fal")
        response.setdefault("prompt", prompt)
        response.setdefault("aspect_ratio", aspect)
        # Annotate model best-effort — the legacy pipeline resolves it
        # internally, so query it after the fact for the response shape.
        if "model" not in response:
            try:
                model_id, _meta = _it._resolve_fal_model()
                response["model"] = model_id
            except Exception:  # noqa: BLE001
                pass
        return response
def test_fal_surfaced_alongside_other_plugins(self, monkeypatch):
    """A registered ``fal`` plugin shows up in the picker rows like any other."""
    from hermes_cli import tools_config

    # After #26241, FAL is itself a plugin (`plugins/image_gen/fal/`)
    # and the hardcoded `TOOL_CATEGORIES["image_gen"]` FAL row is
    # gone, so the plugin-row builder has nothing to deduplicate.
    for backend in ("fal", "openai"):
        image_gen_registry.register_provider(_FakeProvider(backend))

    surfaced = [
        row.get("image_gen_plugin_name")
        for row in tools_config._plugin_image_gen_providers()
    ]
    assert "fal" in surfaced
    assert "openai" in surfaced
+ +Spawns one subprocess per (version, scenario) cell — pinned to either +``origin/main`` (legacy in-tree FAL fall-through + ``configured == "fal"`` +skip in ``_dispatch_to_plugin_provider``) or this PR's worktree (FAL is +itself a plugin and the dispatcher routes every set provider through +the registry). Each subprocess clears all FAL-related env vars + writes +a ``config.yaml``, then asks the dispatcher how it would route an +``image_generate`` call. The emitted shape tuple is +``{dispatch_kind, provider_name, model}``: + +* ``dispatch_kind`` ∈ ``{"legacy_fal", "plugin", "error", None}`` — + whether the call would go straight to the in-tree pipeline, + through ``_dispatch_to_plugin_provider``, raise an explicit + provider-not-registered error, or fall through silently. +* ``provider_name`` — when ``dispatch_kind == "plugin"``, the + resolved provider name. ``None`` otherwise. +* ``model`` — the resolved FAL model id when applicable. + +The parent process diffs the shapes per scenario. A diff means the +migration introduced an observable behaviour change vs origin/main — +likely a real regression for users on the existing config keys. + +Run from the PR worktree: + + python tests/plugins/image_gen/check_parity_vs_main.py +""" +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[3] + + +# Pin one path to current main, one to the PR worktree. +# ``REPO_ROOT`` is ``.../.worktrees/``; the main checkout lives +# two levels up. When running directly from a regular clone (no +# worktree), ``MAIN_DIR`` falls back to a sibling ``hermes-agent-main`` +# checkout if one exists. 
def _resolve_main_dir() -> Path:
    """Locate the checkout to treat as ``origin/main``.

    Tries, in order: the directory two levels above ``REPO_ROOT``, then a
    sibling directory named ``hermes-agent-main``. A candidate counts only
    if it contains ``tools/image_generation_tool.py``. Falls back to
    ``REPO_ROOT`` itself when neither qualifies.
    """
    marker = Path("tools") / "image_generation_tool.py"

    two_up = REPO_ROOT.parent.parent
    if two_up != REPO_ROOT and (two_up / marker).exists():
        return two_up

    named_sibling = REPO_ROOT.parent / "hermes-agent-main"
    return named_sibling if (named_sibling / marker).exists() else REPO_ROOT
+for name in list(sys.modules): + if (name.startswith("tools.") + or name.startswith("agent.") + or name.startswith("plugins.") + or name.startswith("hermes_cli.")): + sys.modules.pop(name, None) + +import tools.image_generation_tool as image_tool + +dispatch_kind = None +provider_name = None +model = None +error_text = None + +try: + raw = image_tool._dispatch_to_plugin_provider("ping", "landscape") + if raw is None: + dispatch_kind = "legacy_fal" + else: + parsed = json.loads(raw) if isinstance(raw, str) else raw + if isinstance(parsed, dict): + if parsed.get("error_type") == "provider_not_registered": + dispatch_kind = "error" + error_text = parsed.get("error") + else: + dispatch_kind = "plugin" + provider_name = parsed.get("provider") + model = parsed.get("model") + else: + dispatch_kind = "unknown_payload" + + if model is None: + # _resolve_fal_model still returns the active FAL model id even + # when dispatch goes to a non-FAL plugin — used for the diff + # only when applicable. + try: + model_id, _meta = image_tool._resolve_fal_model() + if dispatch_kind == "legacy_fal": + model = model_id + except Exception: + pass +except Exception as exc: + dispatch_kind = "exception" + error_text = repr(exc) + +shape = { + "dispatch_kind": dispatch_kind, + "provider_name": provider_name, + "model": model, + "error_present": error_text is not None, +} +print(json.dumps(shape)) +""" + + +SCENARIOS: list[tuple[str, str, dict[str, str]]] = [ + # (label, config.yaml body, extra env vars) + ("no-config-no-env", "", {}), + ( + "explicit-fal-no-creds", + "image_gen:\n provider: fal\n", + {}, + ), + ( + "explicit-fal-with-creds", + "image_gen:\n provider: fal\n", + {"FAL_KEY": "test-key"}, + ), + ( + "explicit-fal-with-model", + "image_gen:\n provider: fal\n model: fal-ai/flux-2-pro\n", + {"FAL_KEY": "test-key"}, + ), + ( + "explicit-typo-provider", + "image_gen:\n provider: not-a-real-backend\n", + {"FAL_KEY": "test-key"}, + ), + ( + "managed-gateway-only", + "", + { + 
"TOOL_GATEWAY_DOMAIN": "nousresearch.com", + "TOOL_GATEWAY_USER_TOKEN": "nous-token", + }, + ), +] + + +def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict) -> dict: + venv_python = repo_path / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = MAIN_DIR / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = Path("python3") + + out = subprocess.run( + [ + str(venv_python), + "-c", + SUBPROCESS_SCRIPT, + str(repo_path), + json.dumps(env), + config_yaml, + ], + capture_output=True, + text=True, + timeout=60, + ) + if out.returncode != 0: + return { + "error": "subprocess failed", + "stdout": out.stdout[-500:], + "stderr": out.stderr[-500:], + } + try: + return json.loads(out.stdout.strip().splitlines()[-1]) + except Exception as exc: + return {"error": f"could not parse output: {exc}", "stdout": out.stdout} + + +def _reduce(shape: dict) -> dict: + """Reduce to the parts that matter for user-visible parity. + + On origin/main, ``explicit-fal-*`` scenarios short-circuit to + ``legacy_fal`` because of the ``configured == "fal"`` skip. On the + PR, those same scenarios route through the plugin and emit + ``dispatch_kind == "plugin"`` with ``provider_name == "fal"``. + + Both shapes are functionally equivalent — the plugin's ``generate()`` + re-enters the same in-tree pipeline via ``_it`` indirection — but + we want the diff to be visible so reviewers can sign off on the + intentional behaviour delta. + """ + return { + "dispatch_kind": shape.get("dispatch_kind"), + "provider_name": shape.get("provider_name"), + "model": shape.get("model"), + "error_present": shape.get("error_present"), + } + + +def main() -> int: + print(f"main: {MAIN_DIR}") + print(f"pr: {PR_DIR}") + print() + + if MAIN_DIR == PR_DIR: + print( + "WARN: MAIN_DIR == PR_DIR — diffs will be trivially identical.\n" + " Set up a sibling 'hermes-agent-main' checkout pinned to " + "origin/main to get real parity coverage." 
+ ) + print() + + failures: list[str] = [] + errors: list[str] = [] + intentional_diffs: list[tuple[str, dict, dict]] = [] + for label, config_yaml, env in SCENARIOS: + main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env) + pr_shape = _run_scenario(PR_DIR, label, config_yaml, env) + + if "error" in main_shape or "error" in pr_shape: + print(f" [ERR ] {label}: subprocess failed") + print(f" main: {main_shape}") + print(f" pr: {pr_shape}") + errors.append(label) + continue + + main_reduced = _reduce(main_shape) + pr_reduced = _reduce(pr_shape) + + if main_reduced == pr_reduced: + print(f" [OK] {label}: {main_reduced}") + continue + + # On main, "explicit-fal-*" returns legacy_fal; on PR, plugin + # dispatch. That's the only acceptable diff — flag everything + # else as a regression. + legacy_to_plugin_fal = ( + main_reduced.get("dispatch_kind") == "legacy_fal" + and pr_reduced.get("dispatch_kind") == "plugin" + and pr_reduced.get("provider_name") == "fal" + ) + if legacy_to_plugin_fal: + print(f" [DIFF] {label}: legacy_fal → plugin (fal) — expected") + intentional_diffs.append((label, main_reduced, pr_reduced)) + else: + print(f" [FAIL] {label}") + print(f" main: {main_reduced}") + print(f" pr: {pr_reduced}") + failures.append(label) + + print() + if errors: + print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):") + for e in errors: + print(f" - {e}") + if failures: + print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):") + for f in failures: + print(f" - {f}") + if intentional_diffs: + print( + f"INTENTIONAL DIFFS ({len(intentional_diffs)}): " + f"legacy_fal → plugin dispatch for explicit FAL paths." 
class TestFalImageGenProviderSurface:
    """Static surface of the provider: identity, model catalog, setup schema."""

    def test_name(self):
        from plugins.image_gen.fal import FalImageGenProvider

        provider = FalImageGenProvider()
        assert provider.name == "fal"

    def test_display_name(self):
        from plugins.image_gen.fal import FalImageGenProvider

        provider = FalImageGenProvider()
        assert provider.display_name == "FAL.ai"

    def test_default_model_matches_legacy(self):
        from plugins.image_gen.fal import FalImageGenProvider
        from tools.image_generation_tool import DEFAULT_MODEL

        provider = FalImageGenProvider()
        assert provider.default_model() == DEFAULT_MODEL

    def test_list_models_uses_legacy_catalog(self):
        from plugins.image_gen.fal import FalImageGenProvider
        from tools.image_generation_tool import FAL_MODELS

        catalog = FalImageGenProvider().list_models()

        # The provider mirrors whatever ids FAL_MODELS ships.
        listed_ids = {row["id"] for row in catalog}
        assert listed_ids == set(FAL_MODELS.keys())

        # Every row exposes the expected first-class fields.
        expected_fields = ("id", "display", "speed", "strengths", "price")
        for row in catalog:
            for field_name in expected_fields:
                assert field_name in row

    def test_setup_schema_advertises_fal_key(self):
        from plugins.image_gen.fal import FalImageGenProvider

        setup = FalImageGenProvider().get_setup_schema()

        assert setup["name"] == "FAL.ai"
        assert setup["badge"] == "paid"
        advertised = {item["key"] for item in setup.get("env_vars", [])}
        assert "FAL_KEY" in advertised
class TestFalImageGenProviderGenerate:
    """``generate()`` behaviour with the legacy module's entry points patched.

    Every test replaces ``image_generate_tool`` (and, where the response is
    inspected, ``_resolve_fal_model``) on ``tools.image_generation_tool``
    via ``monkeypatch`` and then calls the provider.
    """

    def test_generate_delegates_to_legacy_image_generate_tool(self, monkeypatch):
        """Plugin must look up ``image_generate_tool`` at call time so
        ``monkeypatch.setattr(image_tool, "image_generate_tool", ...)``
        takes effect."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        # Records what the provider actually forwarded.
        captured = {}

        def fake_image_generate_tool(prompt, aspect_ratio, **kwargs):
            captured["prompt"] = prompt
            captured["aspect_ratio"] = aspect_ratio
            captured["kwargs"] = kwargs
            return json.dumps({"success": True, "image": "https://fake/image.png"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake_image_generate_tool)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        result = FalImageGenProvider().generate(
            "a serene mountain landscape",
            aspect_ratio="square",
            seed=42,
        )

        assert captured["prompt"] == "a serene mountain landscape"
        assert captured["aspect_ratio"] == "square"
        assert captured["kwargs"] == {"seed": 42}
        assert result["success"] is True
        assert result["image"] == "https://fake/image.png"
        # Stamped fields for the unified response shape
        assert result["provider"] == "fal"
        assert result["prompt"] == "a serene mountain landscape"
        assert result["aspect_ratio"] == "square"
        assert result["model"] == "fal-ai/flux-2/klein/9b"

    def test_generate_invalid_aspect_ratio_is_coerced(self, monkeypatch):
        """An unknown aspect ratio reaches the legacy tool as ``landscape``."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        seen_aspect = {}

        def fake(prompt, aspect_ratio, **kwargs):
            seen_aspect["v"] = aspect_ratio
            return json.dumps({"success": True, "image": "x"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        FalImageGenProvider().generate("p", aspect_ratio="not-a-real-ratio")
        # ``resolve_aspect_ratio`` clamps to landscape.
        assert seen_aspect["v"] == "landscape"

    def test_generate_passthrough_drops_none_kwargs(self, monkeypatch):
        """Extras set to ``None`` are dropped; non-``None`` extras are forwarded."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        seen = {}

        def fake(prompt, aspect_ratio, **kwargs):
            seen.update(kwargs)
            return json.dumps({"success": True, "image": "x"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        FalImageGenProvider().generate(
            "p",
            aspect_ratio="landscape",
            seed=None,
            num_images=2,
            guidance_scale=None,
        )

        # ``None`` values must not be forwarded — they'd override the
        # model's defaults inside the legacy payload builder.
        assert "seed" not in seen
        assert "guidance_scale" not in seen
        assert seen.get("num_images") == 2

    def test_generate_catches_exception_from_legacy(self, monkeypatch):
        """An exception from the legacy tool becomes a ``success: False`` dict."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        def boom(*args, **kwargs):
            raise RuntimeError("FAL endpoint exploded")

        monkeypatch.setattr(image_tool, "image_generate_tool", boom)

        result = FalImageGenProvider().generate("p")
        assert result["success"] is False
        assert "FAL image generation failed" in result["error"]
        assert result["error_type"] == "RuntimeError"
        assert result["provider"] == "fal"

    def test_generate_invalid_json_response(self, monkeypatch):
        """A non-JSON string from the legacy tool yields an "Invalid JSON" error."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        monkeypatch.setattr(image_tool, "image_generate_tool", lambda **kw: "not-json")
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        result = FalImageGenProvider().generate("p")
        assert result["success"] is False
        assert "Invalid JSON" in result["error"]
        assert result["provider"] == "fal"
+ +Holds the stateless atoms that every FAL-backed tool needs: + +* :func:`import_fal_client` — lazy import + ``lazy_deps`` integration so + ``fal_client`` isn't pulled at cold start (it added ~64 ms per CLI + invocation when imported eagerly). +* :class:`_ManagedFalSyncClient` — wrapper that drives a Nous-managed + fal-queue gateway through the standard ``fal_client.SyncClient`` + primitives. +* :func:`_normalize_fal_queue_url_format`, :func:`_extract_http_status` + — small helpers used by both the managed client wrapper and + ``_submit_fal_request``. + +Stateful pieces (cache globals, ``_managed_fal_client*`` selectors, +``_submit_fal_request``) intentionally stay on +:mod:`tools.image_generation_tool`. That module is the patch target for +existing test suites (``tests/tools/test_image_generation.py``, +``tests/tools/test_managed_media_gateways.py``) and for the +``plugins/image_gen/fal/`` plugin's ``_it`` indirection — moving the +caches here would silently defeat ``monkeypatch.setattr(image_tool, +"_managed_fal_client", None)`` because the lookups would go against +``fal_common``'s namespace instead. See the per-rule walkthrough at +issue #26241 for details. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Union +from urllib.parse import urlencode + + +def import_fal_client() -> Any: + """Import ``fal_client`` (via ``lazy_deps`` when available) and return + the module reference. + + Callers are responsible for caching the result on their own module + global — keeping per-module globals lets tests monkey-patch the + target module's ``fal_client`` attribute and have the patched value + stick for that module's call sites. + + Raises :class:`ImportError` if the package is genuinely unavailable. 
def _extract_http_status(exc: BaseException) -> Optional[int]:
    """Pull an integer HTTP status code off an exception, if it has one.

    Looks at ``exc.response.status_code`` first (the shape
    ``httpx.HTTPStatusError`` exposes) and then at ``exc.status_code``
    (which fal_client wrappers may expose directly). The first value that
    is an ``int`` wins; returns ``None`` when neither location holds one.
    """
    # Probe order matters: the nested response takes precedence.
    for holder in (getattr(exc, "response", None), exc):
        if holder is None:
            continue
        code = getattr(holder, "status_code", None)
        if isinstance(code, int):
            return code
    return None
+ """ + + def __init__(self, fal_client: Any, *, key: str, queue_run_origin: str): + sync_client_class = getattr(fal_client, "SyncClient", None) + if sync_client_class is None: + raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode") + + client_module = getattr(fal_client, "client", None) + if client_module is None: + raise RuntimeError("fal_client.client is required for managed FAL gateway mode") + + self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin) + self._sync_client = sync_client_class(key=key) + self._http_client = getattr(self._sync_client, "_client", None) + self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None) + self._raise_for_status = getattr(client_module, "_raise_for_status", None) + self._request_handle_class = getattr(client_module, "SyncRequestHandle", None) + self._add_hint_header = getattr(client_module, "add_hint_header", None) + self._add_priority_header = getattr(client_module, "add_priority_header", None) + self._add_timeout_header = getattr(client_module, "add_timeout_header", None) + + if self._http_client is None: + raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode") + if self._maybe_retry_request is None or self._raise_for_status is None: + raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode") + if self._request_handle_class is None: + raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode") + + def submit( + self, + application: str, + arguments: Dict[str, Any], + *, + path: str = "", + hint: Optional[str] = None, + webhook_url: Optional[str] = None, + priority: Any = None, + headers: Optional[Dict[str, str]] = None, + start_timeout: Optional[Union[int, float]] = None, + ): + url = self._queue_url_format + application + if path: + url += "/" + path.lstrip("/") + if webhook_url is not None: + url += "?" 
+ urlencode({"fal_webhook": webhook_url}) + + request_headers = dict(headers or {}) + if hint is not None and self._add_hint_header is not None: + self._add_hint_header(hint, request_headers) + if priority is not None: + if self._add_priority_header is None: + raise RuntimeError("fal_client.client.add_priority_header is required for priority requests") + self._add_priority_header(priority, request_headers) + if start_timeout is not None: + if self._add_timeout_header is None: + raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests") + self._add_timeout_header(start_timeout, request_headers) + + response = self._maybe_retry_request( + self._http_client, + "POST", + url, + json=arguments, + timeout=getattr(self._sync_client, "default_timeout", 120.0), + headers=request_headers, + ) + self._raise_for_status(response) + + data = response.json() + return self._request_handle_class( + request_id=data["request_id"], + response_url=data["response_url"], + status_url=data["status_url"], + cancel_url=data["cancel_url"], + client=self._http_client, + ) diff --git a/tools/image_generation_tool.py b/tools/image_generation_tool.py index 3d171f093c9..584f5e9fa1c 100644 --- a/tools/image_generation_tool.py +++ b/tools/image_generation_tool.py @@ -26,8 +26,7 @@ import os import datetime import threading import uuid -from typing import Any, Dict, Optional, Union -from urllib.parse import urlencode +from typing import Any, Dict, Optional # fal_client is imported lazily — see _load_fal_client(). 
Pulling it # eagerly added ~64 ms to every CLI cold start because @@ -52,19 +51,17 @@ def _load_fal_client() -> Any: global fal_client if fal_client is not None: return fal_client - try: - from tools.lazy_deps import ensure as _lazy_ensure - _lazy_ensure("image.fal", prompt=False) - except ImportError: - pass - except Exception as e: - raise ImportError(str(e)) - import fal_client as _fal_client # noqa: F811 — module-global rebind - fal_client = _fal_client + from tools.fal_common import import_fal_client + fal_client = import_fal_client() return fal_client from tools.debug_helpers import DebugSession +from tools.fal_common import ( + _ManagedFalSyncClient, + _extract_http_status, + _normalize_fal_queue_url_format, # noqa: F401 — re-exported for tests +) from tools.managed_tool_gateway import resolve_managed_tool_gateway from tools.tool_backend_helpers import ( fal_key_is_configured, @@ -360,95 +357,6 @@ def _resolve_managed_fal_gateway(): return resolve_managed_tool_gateway("fal-queue") -def _normalize_fal_queue_url_format(queue_run_origin: str) -> str: - normalized_origin = str(queue_run_origin or "").strip().rstrip("/") - if not normalized_origin: - raise ValueError("Managed FAL queue origin is required") - return f"{normalized_origin}/" - - -class _ManagedFalSyncClient: - """Small per-instance wrapper around fal_client.SyncClient for managed queue hosts.""" - - def __init__(self, *, key: str, queue_run_origin: str): - # Trigger the lazy import on first construction. Idempotent — the - # placeholder is overwritten with the real module on first call. 
- _load_fal_client() - sync_client_class = getattr(fal_client, "SyncClient", None) - if sync_client_class is None: - raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode") - - client_module = getattr(fal_client, "client", None) - if client_module is None: - raise RuntimeError("fal_client.client is required for managed FAL gateway mode") - - self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin) - self._sync_client = sync_client_class(key=key) - self._http_client = getattr(self._sync_client, "_client", None) - self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None) - self._raise_for_status = getattr(client_module, "_raise_for_status", None) - self._request_handle_class = getattr(client_module, "SyncRequestHandle", None) - self._add_hint_header = getattr(client_module, "add_hint_header", None) - self._add_priority_header = getattr(client_module, "add_priority_header", None) - self._add_timeout_header = getattr(client_module, "add_timeout_header", None) - - if self._http_client is None: - raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode") - if self._maybe_retry_request is None or self._raise_for_status is None: - raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode") - if self._request_handle_class is None: - raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode") - - def submit( - self, - application: str, - arguments: Dict[str, Any], - *, - path: str = "", - hint: Optional[str] = None, - webhook_url: Optional[str] = None, - priority: Any = None, - headers: Optional[Dict[str, str]] = None, - start_timeout: Optional[Union[int, float]] = None, - ): - url = self._queue_url_format + application - if path: - url += "/" + path.lstrip("/") - if webhook_url is not None: - url += "?" 
+ urlencode({"fal_webhook": webhook_url}) - - request_headers = dict(headers or {}) - if hint is not None and self._add_hint_header is not None: - self._add_hint_header(hint, request_headers) - if priority is not None: - if self._add_priority_header is None: - raise RuntimeError("fal_client.client.add_priority_header is required for priority requests") - self._add_priority_header(priority, request_headers) - if start_timeout is not None: - if self._add_timeout_header is None: - raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests") - self._add_timeout_header(start_timeout, request_headers) - - response = self._maybe_retry_request( - self._http_client, - "POST", - url, - json=arguments, - timeout=getattr(self._sync_client, "default_timeout", 120.0), - headers=request_headers, - ) - self._raise_for_status(response) - - data = response.json() - return self._request_handle_class( - request_id=data["request_id"], - response_url=data["response_url"], - status_url=data["status_url"], - cancel_url=data["cancel_url"], - client=self._http_client, - ) - - def _get_managed_fal_client(managed_gateway): """Reuse the managed FAL client so its internal httpx.Client is not leaked per call.""" global _managed_fal_client, _managed_fal_client_config @@ -461,7 +369,11 @@ def _get_managed_fal_client(managed_gateway): if _managed_fal_client is not None and _managed_fal_client_config == client_config: return _managed_fal_client + # Resolve fal_client on the legacy module — preserves the test + # pattern of monkey-patching ``image_generation_tool.fal_client``. + _load_fal_client() _managed_fal_client = _ManagedFalSyncClient( + fal_client, key=managed_gateway.nous_user_token, queue_run_origin=managed_gateway.gateway_origin, ) @@ -502,24 +414,6 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]): raise -def _extract_http_status(exc: BaseException) -> Optional[int]: - """Return an HTTP status code from httpx/fal exceptions, else None. 
- - Defensive across exception shapes — httpx.HTTPStatusError exposes - ``.response.status_code`` while fal_client wrappers may expose - ``.status_code`` directly. - """ - response = getattr(exc, "response", None) - if response is not None: - status = getattr(response, "status_code", None) - if isinstance(status, int): - return status - status = getattr(exc, "status_code", None) - if isinstance(status, int): - return status - return None - - # --------------------------------------------------------------------------- # Model resolution + payload construction # --------------------------------------------------------------------------- @@ -973,9 +867,12 @@ def _read_configured_image_provider(): """Return the value of ``image_gen.provider`` from config.yaml, or None. We only consult the plugin registry when this is explicitly set — an - unset value keeps users on the legacy in-tree FAL path even when other + unset value keeps users on the in-tree FAL fallback even when other providers happen to be registered (e.g. a user has OPENAI_API_KEY set - for other features but never asked for OpenAI image gen). + for other features but never asked for OpenAI image gen). ``"fal"`` + explicitly routes through ``plugins/image_gen/fal/`` (which delegates + back into this module's pipeline via call-time indirection — see + issue #26241). """ try: from hermes_cli.config import load_config @@ -994,15 +891,16 @@ def _dispatch_to_plugin_provider(prompt: str, aspect_ratio: str): """Route the call to a plugin-registered provider when one is selected. Returns a JSON string on dispatch, or ``None`` to fall through to the - built-in FAL path. + in-tree FAL fallback in ``image_generate_tool``. - Dispatch only fires when ``image_gen.provider`` is explicitly set AND - it does not point to ``fal`` (FAL still lives in-tree in this PR; - a later PR ports it into ``plugins/image_gen/fal/``). Any other value - that matches a registered plugin provider wins. 
+ Dispatch fires when ``image_gen.provider`` is explicitly set — including + ``"fal"`` itself, which now resolves to the + ``plugins/image_gen/fal/`` plugin (the plugin re-enters this module's + pipeline via ``_it`` indirection so behavior is identical to the + direct call, just routed through the registry). """ configured = _read_configured_image_provider() - if not configured or configured == "fal": + if not configured: return None # Also read configured model so we can pass it to the plugin