Mirror of https://github.com/NousResearch/hermes-agent.git (synced 2026-05-29 06:31:32 +00:00)
refactor(image_gen): port FAL backend to plugins/image_gen/fal

Mirrors the architecture established by the web (#25182), browser (#25214), and video_gen (#25126) plugin migrations:

* `tools/fal_common.py`: stateless atoms shared by both FAL-backed plugins (image_gen + video_gen). Holds the lazy `fal_client` import helper, `_ManagedFalSyncClient`, `_normalize_fal_queue_url_format`, `_extract_http_status`. Stateful pieces (`fal_client` module global, `_managed_fal_client*` cache, `_submit_fal_request`, `_resolve_managed_fal_gateway`, `_get_managed_fal_client`) intentionally stay on `tools.image_generation_tool` so the existing `monkeypatch.setattr(image_tool, ...)` patch sites keep working unchanged.
* `plugins/video_gen/fal/__init__.py`: drops its inline `_load_fal_client` duplicate; consumes `tools.fal_common.import_fal_client`.
* `plugins/image_gen/fal/{plugin.yaml,__init__.py}`: new plugin. `FalImageGenProvider` is a thin registration adapter that resolves the legacy module via `import tools.image_generation_tool as _it` and calls `_it.image_generate_tool` + `_it._resolve_fal_model` at call time. The 18-model catalog, `_build_fal_payload`, managed-gateway selection, and Clarity Upscaler chaining all remain in `tools.image_generation_tool` as the single source of truth; the plugin is a registration adapter, not a parallel implementation.
* `tools/image_generation_tool.py::_dispatch_to_plugin_provider`: drops the `configured == "fal"` skip. Setting `image_gen.provider: fal` now routes through the registry like any other provider; the plugin re-enters this module's pipeline so behavior is identical. Unset `image_gen.provider` still falls through to the in-tree pipeline (preserves no-config-with-FAL_KEY UX from #15696).
* `hermes_cli/tools_config.py`: drops the hardcoded "FAL.ai" row from `TOOL_CATEGORIES["image_gen"]["providers"]` (now injected by `_plugin_image_gen_providers` like every other backend) and the `getattr(provider, "name") == "fal"` skip that protected against duplication with the hardcoded row. The "Nous Subscription" row stays as a setup-flow entry, the same shape browser kept with "Nous Subscription (Browser Use cloud)" after #25214.
* `tests/plugins/image_gen/test_fal_provider.py`: 14 cases covering the ABC surface, call-time indirection (verifying `monkeypatch.setattr(image_tool, "image_generate_tool", ...)` takes effect through the plugin), response-shape stamping, exception handling, and registry wiring.
* `tests/plugins/image_gen/check_parity_vs_main.py`: subprocess harness mirroring `tests/plugins/browser/check_parity_vs_main.py`. Pins one path to origin/main, one to the worktree; runs six scenarios (unset, explicit-fal-no-creds, explicit-fal-with-creds, explicit-fal-with-model, typo provider, managed-gateway-only) and diffs the reduced shape `{dispatch_kind, provider_name, model}` per scenario. The only acceptable diff is "legacy_fal → plugin (fal)" for explicit-FAL paths; every other delta is flagged as a regression.
* `tests/hermes_cli/test_image_gen_picker.py::test_fal_surfaced_alongside_other_plugins`: flips the previous `test_fal_skipped_to_avoid_duplicate` to match the new shape (FAL is a plugin now, no dedup needed).

Verified: 195/195 tests across `tests/{tools/test_image_generation*,tools/test_managed_media_gateways,plugins/image_gen,plugins/video_gen,hermes_cli/test_image_gen_picker}.py` pass on this branch with no test patches modified outside the picker test that asserted the old skip behaviour.

Fixes #26241
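The commit message leans on "call-time indirection" to keep existing `monkeypatch.setattr(image_tool, ...)` patch sites working. The sketch below illustrates the general Python behaviour behind that choice; it is not the plugin's code. `legacy` is a hypothetical stand-in for `tools.image_generation_tool`, and both `generate_*` functions are illustrative names.

```python
import types

# Hypothetical stand-in for the legacy module (tools.image_generation_tool).
legacy = types.ModuleType("legacy")
legacy.image_generate_tool = lambda prompt: f"real:{prompt}"

# Import-time binding: the function object is captured once, much like
# `from legacy import image_generate_tool` would capture it.
_bound = legacy.image_generate_tool


def generate_bound_at_import(prompt: str) -> str:
    # Uses the object captured above; later patches to the module are invisible.
    return _bound(prompt)


def generate_resolved_at_call(prompt: str) -> str:
    # Attribute lookup happens on every call, so a later setattr on the
    # module is picked up.
    return legacy.image_generate_tool(prompt)


# Simulate what monkeypatch.setattr(legacy, "image_generate_tool", ...) does.
legacy.image_generate_tool = lambda prompt: f"fake:{prompt}"

print(generate_bound_at_import("p"))   # still reaches the original function
print(generate_resolved_at_call("p"))  # reaches the patched function
```

Resolving the attribute on each call costs one extra lookup, which is negligible next to a network-bound image request, and it lets a test patch a single module attribute.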
parent 7dea33303a
commit 3ac2125140
9 changed files with 930 additions and 154 deletions
hermes_cli/tools_config.py
@@ -311,6 +311,16 @@ TOOL_CATEGORIES = {
     "image_gen": {
         "name": "Image Generation",
         "icon": "🎨",
+        # Per-provider rows for FAL.ai (`plugins/image_gen/fal`), OpenAI,
+        # OpenAI Codex, and xAI are injected at runtime from each
+        # ``plugins.image_gen.<vendor>`` package via
+        # ``_plugin_image_gen_providers()`` in ``_visible_providers``.
+        # Only non-provider UX setup-flow rows remain here:
+        # - "Nous Subscription": managed FAL billed via the Nous
+        # subscription (requires_nous_auth + override_env_vars).
+        # Uses the fal plugin as the underlying backend but has a
+        # distinct setup UX.
+        # Mirrors the shape browser/video_gen ship today.
         "providers": [
             {
                 "name": "Nous Subscription",
@@ -322,15 +332,6 @@ TOOL_CATEGORIES = {
                 "override_env_vars": ["FAL_KEY"],
                 "imagegen_backend": "fal",
             },
-            {
-                "name": "FAL.ai",
-                "badge": "paid",
-                "tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.",
-                "env_vars": [
-                    {"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
-                ],
-                "imagegen_backend": "fal",
-            },
         ],
     },
     "video_gen": {
@@ -1567,12 +1568,9 @@ def _plugin_image_gen_providers() -> list[dict]:
     Each returned dict looks like a regular ``TOOL_CATEGORIES`` provider
     row but carries an ``image_gen_plugin_name`` marker so downstream
     code (config writing, model picker) knows to route through the
-    plugin registry instead of the in-tree FAL backend.
-
-    FAL is skipped: it's already exposed by the hardcoded
-    ``TOOL_CATEGORIES["image_gen"]`` entries. When FAL gets ported to
-    a plugin in a follow-up PR, the hardcoded entries go away and this
-    function surfaces it alongside OpenAI automatically.
+    plugin registry. Every image-gen backend is a plugin now, so there
+    are no hardcoded rows left in ``TOOL_CATEGORIES["image_gen"]`` for
+    this function to dedupe against (see issue #26241).
     """
     try:
         from agent.image_gen_registry import list_providers
@@ -1585,9 +1583,6 @@ def _plugin_image_gen_providers() -> list[dict]:

     rows: list[dict] = []
     for provider in providers:
-        if getattr(provider, "name", None) == "fal":
-            # FAL has its own hardcoded rows today.
-            continue
         try:
             schema = provider.get_setup_schema()
         except Exception:
|
|
|
|||
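The hunk above is cut off at the `except Exception:` line, so the rest of the row builder is not visible here. As a rough sketch of the idea the docstring describes (a provider row that carries an `image_gen_plugin_name` marker), something like the following could work. `FakeProvider`, `build_rows`, and the choice to skip providers whose schema call fails are all assumptions made for illustration, not the repository's code.

```python
from typing import Any, Dict, List


class FakeProvider:
    """Hypothetical provider exposing only what this sketch needs."""

    def __init__(self, name: str, broken: bool = False) -> None:
        self.name = name
        self._broken = broken

    def get_setup_schema(self) -> Dict[str, Any]:
        if self._broken:
            raise RuntimeError("schema unavailable")
        return {"name": self.name.upper(), "badge": "paid", "env_vars": []}


def build_rows(providers: List[Any]) -> List[Dict[str, Any]]:
    """Turn registered providers into picker rows (illustrative only)."""
    rows: List[Dict[str, Any]] = []
    for provider in providers:
        try:
            schema = provider.get_setup_schema()
        except Exception:
            # Assumption: a provider whose schema call fails is left out of
            # the picker rather than breaking it.
            continue
        row = dict(schema)
        # Marker telling downstream code to route through the plugin registry.
        row["image_gen_plugin_name"] = provider.name
        rows.append(row)
    return rows


rows = build_rows([FakeProvider("fal"), FakeProvider("bad", broken=True), FakeProvider("openai")])
print([r["image_gen_plugin_name"] for r in rows])
```

With the special case for `"fal"` gone, every provider takes the same path through the loop, which is the behaviour the flipped picker test asserts.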
182
plugins/image_gen/fal/__init__.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""FAL.ai image generation backend.
|
||||
|
||||
Wraps the 18-model FAL catalog (FLUX 2, Z-Image, Nano Banana, GPT
|
||||
Image 1.5, Recraft, Imagen 4, Qwen, Ideogram, …) as an
|
||||
:class:`ImageGenProvider` implementation.
|
||||
|
||||
The heavy lifting — model catalog, payload construction, request
|
||||
submission, managed-Nous-gateway selection, Clarity Upscaler chaining
|
||||
— lives in :mod:`tools.image_generation_tool`. This plugin reaches into
|
||||
that module via call-time indirection (``import tools.image_generation_tool as _it``)
|
||||
so:
|
||||
|
||||
* the existing test suite (``tests/tools/test_image_generation.py``,
|
||||
``tests/tools/test_managed_media_gateways.py``) keeps patching
|
||||
``image_tool._submit_fal_request`` / ``image_tool.fal_client`` /
|
||||
``image_tool._managed_fal_client`` without modification, and
|
||||
* there's exactly one canonical FAL code path on disk — the plugin is a
|
||||
registration adapter, not a parallel implementation.
|
||||
|
||||
See issue #26241 for the migration plan and the
|
||||
``plugin-extraction-test-patch-compatibility.md`` rules this follows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.image_gen_provider import (
|
||||
DEFAULT_ASPECT_RATIO,
|
||||
ImageGenProvider,
|
||||
resolve_aspect_ratio,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FalImageGenProvider(ImageGenProvider):
|
||||
"""FAL.ai image generation backend.
|
||||
|
||||
Delegates to ``tools.image_generation_tool.image_generate_tool`` so
|
||||
the in-tree FAL implementation (model catalog, payload builder,
|
||||
managed-gateway selection, Clarity Upscaler chaining) is the single
|
||||
source of truth. Everything is resolved at call time via the
|
||||
``_it`` indirection so tests can monkey-patch the legacy module.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "fal"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return "FAL.ai"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# Available when direct FAL_KEY is set OR the managed Nous
|
||||
# gateway resolves a fal-queue origin. Both checks come from the
|
||||
# legacy module so this provider tracks whatever logic ships
|
||||
# there.
|
||||
import tools.image_generation_tool as _it
|
||||
try:
|
||||
return bool(_it.check_fal_api_key())
|
||||
except Exception: # noqa: BLE001 — defensive; never break the picker
|
||||
return False
|
||||
|
||||
def list_models(self) -> List[Dict[str, Any]]:
|
||||
import tools.image_generation_tool as _it
|
||||
return [
|
||||
{
|
||||
"id": model_id,
|
||||
"display": meta.get("display", model_id),
|
||||
"speed": meta.get("speed", ""),
|
||||
"strengths": meta.get("strengths", ""),
|
||||
"price": meta.get("price", ""),
|
||||
}
|
||||
for model_id, meta in _it.FAL_MODELS.items()
|
||||
]
|
||||
|
||||
def default_model(self) -> Optional[str]:
|
||||
import tools.image_generation_tool as _it
|
||||
return _it.DEFAULT_MODEL
|
||||
|
||||
def get_setup_schema(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": "FAL.ai",
|
||||
"badge": "paid",
|
||||
"tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.",
|
||||
"env_vars": [
|
||||
{
|
||||
"key": "FAL_KEY",
|
||||
"prompt": "FAL API key",
|
||||
"url": "https://fal.ai/dashboard/keys",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate an image via the legacy FAL pipeline.
|
||||
|
||||
Forwards prompt + aspect_ratio (and any forward-compat extras
|
||||
the schema supports) into :func:`tools.image_generation_tool.image_generate_tool`,
|
||||
then reshapes its JSON-string response into the provider-ABC
|
||||
dict format consumed by ``_dispatch_to_plugin_provider``.
|
||||
"""
|
||||
import tools.image_generation_tool as _it
|
||||
|
||||
aspect = resolve_aspect_ratio(aspect_ratio)
|
||||
passthrough = {
|
||||
key: kwargs[key]
|
||||
for key in (
|
||||
"num_inference_steps",
|
||||
"guidance_scale",
|
||||
"num_images",
|
||||
"output_format",
|
||||
"seed",
|
||||
)
|
||||
if key in kwargs and kwargs[key] is not None
|
||||
}
|
||||
|
||||
try:
|
||||
raw = _it.image_generate_tool(
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect,
|
||||
**passthrough,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — never raise out of generate
|
||||
logger.warning("FAL image_generate_tool raised: %s", exc, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"image": None,
|
||||
"error": f"FAL image generation failed: {exc}",
|
||||
"error_type": type(exc).__name__,
|
||||
"provider": "fal",
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect,
|
||||
}
|
||||
|
||||
try:
|
||||
response = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception: # noqa: BLE001
|
||||
response = {"success": False, "image": None, "error": "Invalid JSON from FAL pipeline"}
|
||||
|
||||
if not isinstance(response, dict):
|
||||
response = {
|
||||
"success": False,
|
||||
"image": None,
|
||||
"error": "FAL pipeline returned a non-dict response",
|
||||
"error_type": "provider_contract",
|
||||
}
|
||||
|
||||
# Stamp provider/prompt/aspect_ratio so downstream consumers see
|
||||
# the uniform shape declared in ``agent.image_gen_provider``.
|
||||
response.setdefault("provider", "fal")
|
||||
response.setdefault("prompt", prompt)
|
||||
response.setdefault("aspect_ratio", aspect)
|
||||
# Annotate model best-effort — the legacy pipeline resolves it
|
||||
# internally, so query it after the fact for the response shape.
|
||||
if "model" not in response:
|
||||
try:
|
||||
model_id, _meta = _it._resolve_fal_model()
|
||||
response["model"] = model_id
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def register(ctx) -> None:
|
||||
"""Plugin entry point — wire ``FalImageGenProvider`` into the registry."""
|
||||
ctx.register_image_gen_provider(FalImageGenProvider())
|
||||
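The tail of `generate()` normalizes whatever the legacy pipeline returns and then stamps shared fields with `setdefault`. The standalone sketch below isolates that step so its behaviour is easy to check. `normalize_response` is a hypothetical helper name; the error strings are copied from the plugin source above.

```python
import json
from typing import Any, Dict, Union


def normalize_response(
    raw: Union[str, Dict[str, Any], Any], prompt: str, aspect: str
) -> Dict[str, Any]:
    """Coerce a legacy-pipeline result into a dict and stamp shared fields."""
    try:
        response = json.loads(raw) if isinstance(raw, str) else raw
    except Exception:
        response = {"success": False, "image": None, "error": "Invalid JSON from FAL pipeline"}

    if not isinstance(response, dict):
        response = {
            "success": False,
            "image": None,
            "error": "FAL pipeline returned a non-dict response",
            "error_type": "provider_contract",
        }

    # setdefault only fills missing keys, so values already set by the
    # pipeline are preserved.
    response.setdefault("provider", "fal")
    response.setdefault("prompt", prompt)
    response.setdefault("aspect_ratio", aspect)
    return response


ok = normalize_response('{"success": true, "image": "x"}', "p", "square")
bad_json = normalize_response("not json", "p", "square")
non_dict = normalize_response("[1, 2]", "p", "square")
print(ok["provider"], bad_json["error"], non_dict["error_type"])
```

Using `setdefault` rather than plain assignment means that if the pipeline ever reports its own `provider` or `aspect_ratio`, the adapter does not overwrite it.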
7
plugins/image_gen/fal/plugin.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
name: fal
|
||||
version: 1.0.0
|
||||
description: "FAL.ai image generation backend (flux-2-klein, flux-2-pro, nano-banana, gpt-image-1.5, recraft-v3, etc.)."
|
||||
author: NousResearch
|
||||
kind: backend
|
||||
requires_env:
|
||||
- FAL_KEY
|
||||
|
|
plugins/video_gen/fal/__init__.py
@@ -282,20 +282,24 @@ def _build_payload(


 # ---------------------------------------------------------------------------
-# fal_client lazy import (same pattern as image_generation_tool)
+# fal_client lazy import (shared with image_generation_tool via fal_common)
 # ---------------------------------------------------------------------------

 _fal_client: Any = None


 def _load_fal_client() -> Any:
+    """Lazy-load the ``fal_client`` SDK and cache it on this module.
+
+    Delegates the actual import to :func:`tools.fal_common.import_fal_client`
+    so the ``lazy_deps`` ensure-install handling stays in one place.
+    """
     global _fal_client
     if _fal_client is not None:
         return _fal_client
-    import fal_client  # type: ignore
-
-    _fal_client = fal_client
-    return fal_client
+    from tools.fal_common import import_fal_client
+    _fal_client = import_fal_client()
+    return _fal_client


 # ---------------------------------------------------------------------------
|
|
|
|||
|
|
tests/hermes_cli/test_image_gen_picker.py
@@ -69,18 +69,19 @@ class TestPluginPickerInjection:
         assert "Myimg" in names
         assert "myimg" in plugin_names

-    def test_fal_skipped_to_avoid_duplicate(self, monkeypatch):
+    def test_fal_surfaced_alongside_other_plugins(self, monkeypatch):
         from hermes_cli import tools_config

-        # Simulate a FAL plugin being registered: the picker already has
-        # hardcoded FAL rows in TOOL_CATEGORIES, so plugin-FAL must be
-        # skipped to avoid showing FAL twice.
+        # After #26241, FAL is itself a plugin (`plugins/image_gen/fal/`)
+        # and the hardcoded `TOOL_CATEGORIES["image_gen"]` FAL row is
+        # gone. The plugin-row builder therefore surfaces it like any
+        # other backend, with no deduplication step needed.
         image_gen_registry.register_provider(_FakeProvider("fal"))
         image_gen_registry.register_provider(_FakeProvider("openai"))

         rows = tools_config._plugin_image_gen_providers()
         names = [r.get("image_gen_plugin_name") for r in rows]
-        assert "fal" not in names
+        assert "fal" in names
         assert "openai" in names

     def test_visible_providers_includes_plugins_for_image_gen(self, monkeypatch):
|
|
|
|||
300
tests/plugins/image_gen/check_parity_vs_main.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
"""Behavior-parity check for the image-gen FAL plugin migration (#26241).
|
||||
|
||||
Spawns one subprocess per (version, scenario) cell — pinned to either
|
||||
``origin/main`` (legacy in-tree FAL fall-through + ``configured == "fal"``
|
||||
skip in ``_dispatch_to_plugin_provider``) or this PR's worktree (FAL is
|
||||
itself a plugin and the dispatcher routes every set provider through
|
||||
the registry). Each subprocess clears all FAL-related env vars + writes
|
||||
a ``config.yaml``, then asks the dispatcher how it would route an
|
||||
``image_generate`` call. The emitted shape dict is
|
||||
``{dispatch_kind, provider_name, model}``:
|
||||
|
||||
* ``dispatch_kind`` ∈ ``{"legacy_fal", "plugin", "error", None}`` —
|
||||
whether the call would go straight to the in-tree pipeline,
|
||||
through ``_dispatch_to_plugin_provider``, raise an explicit
|
||||
provider-not-registered error, or fall through silently.
|
||||
* ``provider_name`` — when ``dispatch_kind == "plugin"``, the
|
||||
resolved provider name. ``None`` otherwise.
|
||||
* ``model`` — the resolved FAL model id when applicable.
|
||||
|
||||
The parent process diffs the shapes per scenario. A diff means the
|
||||
migration introduced an observable behaviour change vs origin/main —
|
||||
likely a real regression for users on the existing config keys.
|
||||
|
||||
Run from the PR worktree:
|
||||
|
||||
python tests/plugins/image_gen/check_parity_vs_main.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
# Pin one path to current main, one to the PR worktree.
|
||||
# ``REPO_ROOT`` is ``.../.worktrees/<name>``; the main checkout lives
|
||||
# two levels up. When running directly from a regular clone (no
|
||||
# worktree), ``MAIN_DIR`` falls back to a sibling ``hermes-agent-main``
|
||||
# checkout if one exists.
|
||||
def _resolve_main_dir() -> Path:
|
||||
candidate = REPO_ROOT.parent.parent
|
||||
if (candidate / "tools" / "image_generation_tool.py").exists() and candidate != REPO_ROOT:
|
||||
return candidate
|
||||
sibling = REPO_ROOT.parent / "hermes-agent-main"
|
||||
if (sibling / "tools" / "image_generation_tool.py").exists():
|
||||
return sibling
|
||||
return REPO_ROOT
|
||||
|
||||
|
||||
MAIN_DIR = _resolve_main_dir()
|
||||
PR_DIR = REPO_ROOT
|
||||
assert (PR_DIR / "tools" / "image_generation_tool.py").exists(), (
|
||||
f"PR_DIR={PR_DIR} doesn't look like a hermes-agent checkout"
|
||||
)
|
||||
|
||||
|
||||
SUBPROCESS_SCRIPT = r"""
|
||||
import json, os, sys, tempfile
|
||||
sys.path.insert(0, sys.argv[1])
|
||||
|
||||
# Isolated HERMES_HOME so the config write is hermetic.
|
||||
home = tempfile.mkdtemp()
|
||||
os.environ["HERMES_HOME"] = home
|
||||
|
||||
# Clear FAL-related env so dispatch decisions are config-driven.
|
||||
for k in (
|
||||
"FAL_KEY", "FAL_QUEUE_GATEWAY_URL",
|
||||
"TOOL_GATEWAY_DOMAIN", "TOOL_GATEWAY_USER_TOKEN",
|
||||
"FAL_IMAGE_MODEL",
|
||||
):
|
||||
os.environ.pop(k, None)
|
||||
|
||||
scenario_env = json.loads(sys.argv[2])
|
||||
os.environ.update(scenario_env)
|
||||
|
||||
config_yaml = sys.argv[3]
|
||||
config_path = os.path.join(home, "config.yaml")
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config_yaml)
|
||||
|
||||
# Fresh import — must not have anything cached.
|
||||
for name in list(sys.modules):
|
||||
if (name.startswith("tools.")
|
||||
or name.startswith("agent.")
|
||||
or name.startswith("plugins.")
|
||||
or name.startswith("hermes_cli.")):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
import tools.image_generation_tool as image_tool
|
||||
|
||||
dispatch_kind = None
|
||||
provider_name = None
|
||||
model = None
|
||||
error_text = None
|
||||
|
||||
try:
|
||||
raw = image_tool._dispatch_to_plugin_provider("ping", "landscape")
|
||||
if raw is None:
|
||||
dispatch_kind = "legacy_fal"
|
||||
else:
|
||||
parsed = json.loads(raw) if isinstance(raw, str) else raw
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get("error_type") == "provider_not_registered":
|
||||
dispatch_kind = "error"
|
||||
error_text = parsed.get("error")
|
||||
else:
|
||||
dispatch_kind = "plugin"
|
||||
provider_name = parsed.get("provider")
|
||||
model = parsed.get("model")
|
||||
else:
|
||||
dispatch_kind = "unknown_payload"
|
||||
|
||||
if model is None:
|
||||
# _resolve_fal_model still returns the active FAL model id even
|
||||
# when dispatch goes to a non-FAL plugin — used for the diff
|
||||
# only when applicable.
|
||||
try:
|
||||
model_id, _meta = image_tool._resolve_fal_model()
|
||||
if dispatch_kind == "legacy_fal":
|
||||
model = model_id
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
dispatch_kind = "exception"
|
||||
error_text = repr(exc)
|
||||
|
||||
shape = {
|
||||
"dispatch_kind": dispatch_kind,
|
||||
"provider_name": provider_name,
|
||||
"model": model,
|
||||
"error_present": error_text is not None,
|
||||
}
|
||||
print(json.dumps(shape))
|
||||
"""
|
||||
|
||||
|
||||
SCENARIOS: list[tuple[str, str, dict[str, str]]] = [
|
||||
# (label, config.yaml body, extra env vars)
|
||||
("no-config-no-env", "", {}),
|
||||
(
|
||||
"explicit-fal-no-creds",
|
||||
"image_gen:\n provider: fal\n",
|
||||
{},
|
||||
),
|
||||
(
|
||||
"explicit-fal-with-creds",
|
||||
"image_gen:\n provider: fal\n",
|
||||
{"FAL_KEY": "test-key"},
|
||||
),
|
||||
(
|
||||
"explicit-fal-with-model",
|
||||
"image_gen:\n provider: fal\n model: fal-ai/flux-2-pro\n",
|
||||
{"FAL_KEY": "test-key"},
|
||||
),
|
||||
(
|
||||
"explicit-typo-provider",
|
||||
"image_gen:\n provider: not-a-real-backend\n",
|
||||
{"FAL_KEY": "test-key"},
|
||||
),
|
||||
(
|
||||
"managed-gateway-only",
|
||||
"",
|
||||
{
|
||||
"TOOL_GATEWAY_DOMAIN": "nousresearch.com",
|
||||
"TOOL_GATEWAY_USER_TOKEN": "nous-token",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict) -> dict:
|
||||
venv_python = repo_path / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = MAIN_DIR / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = Path("python3")
|
||||
|
||||
out = subprocess.run(
|
||||
[
|
||||
str(venv_python),
|
||||
"-c",
|
||||
SUBPROCESS_SCRIPT,
|
||||
str(repo_path),
|
||||
json.dumps(env),
|
||||
config_yaml,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
if out.returncode != 0:
|
||||
return {
|
||||
"error": "subprocess failed",
|
||||
"stdout": out.stdout[-500:],
|
||||
"stderr": out.stderr[-500:],
|
||||
}
|
||||
try:
|
||||
return json.loads(out.stdout.strip().splitlines()[-1])
|
||||
except Exception as exc:
|
||||
return {"error": f"could not parse output: {exc}", "stdout": out.stdout}
|
||||
|
||||
|
||||
def _reduce(shape: dict) -> dict:
|
||||
"""Reduce to the parts that matter for user-visible parity.
|
||||
|
||||
On origin/main, ``explicit-fal-*`` scenarios short-circuit to
|
||||
``legacy_fal`` because of the ``configured == "fal"`` skip. On the
|
||||
PR, those same scenarios route through the plugin and emit
|
||||
``dispatch_kind == "plugin"`` with ``provider_name == "fal"``.
|
||||
|
||||
Both shapes are functionally equivalent — the plugin's ``generate()``
|
||||
re-enters the same in-tree pipeline via ``_it`` indirection — but
|
||||
we want the diff to be visible so reviewers can sign off on the
|
||||
intentional behaviour delta.
|
||||
"""
|
||||
return {
|
||||
"dispatch_kind": shape.get("dispatch_kind"),
|
||||
"provider_name": shape.get("provider_name"),
|
||||
"model": shape.get("model"),
|
||||
"error_present": shape.get("error_present"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
print(f"main: {MAIN_DIR}")
|
||||
print(f"pr: {PR_DIR}")
|
||||
print()
|
||||
|
||||
if MAIN_DIR == PR_DIR:
|
||||
print(
|
||||
"WARN: MAIN_DIR == PR_DIR — diffs will be trivially identical.\n"
|
||||
" Set up a sibling 'hermes-agent-main' checkout pinned to "
|
||||
"origin/main to get real parity coverage."
|
||||
)
|
||||
print()
|
||||
|
||||
failures: list[str] = []
|
||||
errors: list[str] = []
|
||||
intentional_diffs: list[tuple[str, dict, dict]] = []
|
||||
for label, config_yaml, env in SCENARIOS:
|
||||
main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env)
|
||||
pr_shape = _run_scenario(PR_DIR, label, config_yaml, env)
|
||||
|
||||
if "error" in main_shape or "error" in pr_shape:
|
||||
print(f" [ERR ] {label}: subprocess failed")
|
||||
print(f" main: {main_shape}")
|
||||
print(f" pr: {pr_shape}")
|
||||
errors.append(label)
|
||||
continue
|
||||
|
||||
main_reduced = _reduce(main_shape)
|
||||
pr_reduced = _reduce(pr_shape)
|
||||
|
||||
if main_reduced == pr_reduced:
|
||||
print(f" [OK] {label}: {main_reduced}")
|
||||
continue
|
||||
|
||||
# On main, "explicit-fal-*" returns legacy_fal; on PR, plugin
|
||||
# dispatch. That's the only acceptable diff — flag everything
|
||||
# else as a regression.
|
||||
legacy_to_plugin_fal = (
|
||||
main_reduced.get("dispatch_kind") == "legacy_fal"
|
||||
and pr_reduced.get("dispatch_kind") == "plugin"
|
||||
and pr_reduced.get("provider_name") == "fal"
|
||||
)
|
||||
if legacy_to_plugin_fal:
|
||||
print(f" [DIFF] {label}: legacy_fal → plugin (fal) — expected")
|
||||
intentional_diffs.append((label, main_reduced, pr_reduced))
|
||||
else:
|
||||
print(f" [FAIL] {label}")
|
||||
print(f" main: {main_reduced}")
|
||||
print(f" pr: {pr_reduced}")
|
||||
failures.append(label)
|
||||
|
||||
print()
|
||||
if errors:
|
||||
print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
if failures:
|
||||
print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
if intentional_diffs:
|
||||
print(
|
||||
f"INTENTIONAL DIFFS ({len(intentional_diffs)}): "
|
||||
f"legacy_fal → plugin dispatch for explicit FAL paths."
|
||||
)
|
||||
if failures or errors:
|
||||
return 1
|
||||
print(f"PARITY OK across {len(SCENARIOS)} scenarios.")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
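The decision at the core of the parity harness is a three-way classification of each scenario's pair of reduced shapes. The sketch below pulls that rule into a pure function so it can be checked without spawning subprocesses. `classify` and its return labels are hypothetical names; the predicate itself mirrors `legacy_to_plugin_fal` in `main()` above.

```python
from typing import Any, Dict

Shape = Dict[str, Any]


def classify(main_shape: Shape, pr_shape: Shape) -> str:
    """Label a scenario as 'ok', 'expected_diff', or 'regression'."""
    if main_shape == pr_shape:
        return "ok"
    # The only tolerated difference: main fell through to the in-tree
    # pipeline while the PR dispatched to the fal plugin.
    legacy_to_plugin_fal = (
        main_shape.get("dispatch_kind") == "legacy_fal"
        and pr_shape.get("dispatch_kind") == "plugin"
        and pr_shape.get("provider_name") == "fal"
    )
    return "expected_diff" if legacy_to_plugin_fal else "regression"


legacy = {"dispatch_kind": "legacy_fal", "provider_name": None, "model": "m"}
plugin_fal = {"dispatch_kind": "plugin", "provider_name": "fal", "model": "m"}
plugin_other = {"dispatch_kind": "plugin", "provider_name": "openai", "model": None}

print(classify(legacy, legacy))        # identical shapes
print(classify(legacy, plugin_fal))    # the intentional delta
print(classify(legacy, plugin_other))  # anything else
```

One thing this makes visible: the predicate does not compare `model`, so an explicit-FAL scenario whose resolved model changed would still be reported as the expected delta.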
226
tests/plugins/image_gen/test_fal_provider.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for the FAL.ai image generation plugin.
|
||||
|
||||
The plugin is a thin registration adapter — actual FAL pipeline logic
|
||||
lives in ``tools.image_generation_tool`` and is exercised by
|
||||
``tests/tools/test_image_generation.py``. These tests focus on:
|
||||
|
||||
* the ``ImageGenProvider`` ABC surface (name, models, schema)
|
||||
* call-time indirection (``_it`` resolution at ``generate()`` time so
|
||||
``monkeypatch.setattr(image_tool, ...)`` keeps working)
|
||||
* response shape stamping (provider/prompt/aspect_ratio/model)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider surface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFalImageGenProviderSurface:
|
||||
def test_name(self):
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
assert FalImageGenProvider().name == "fal"
|
||||
|
||||
def test_display_name(self):
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
assert FalImageGenProvider().display_name == "FAL.ai"
|
||||
|
||||
def test_default_model_matches_legacy(self):
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
from tools.image_generation_tool import DEFAULT_MODEL
|
||||
|
||||
assert FalImageGenProvider().default_model() == DEFAULT_MODEL
|
||||
|
||||
def test_list_models_uses_legacy_catalog(self):
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
from tools.image_generation_tool import FAL_MODELS
|
||||
|
||||
provider = FalImageGenProvider()
|
||||
models = provider.list_models()
|
||||
ids = {m["id"] for m in models}
|
||||
# Whatever FAL_MODELS ships, the provider mirrors verbatim.
|
||||
assert ids == set(FAL_MODELS.keys())
|
||||
# Spot-check the expected first-class fields are present.
|
||||
for entry in models:
|
||||
for field in ("id", "display", "speed", "strengths", "price"):
|
||||
assert field in entry
|
||||
|
||||
def test_setup_schema_advertises_fal_key(self):
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
schema = FalImageGenProvider().get_setup_schema()
|
||||
assert schema["name"] == "FAL.ai"
|
||||
assert schema["badge"] == "paid"
|
||||
env_keys = {entry["key"] for entry in schema.get("env_vars", [])}
|
||||
assert "FAL_KEY" in env_keys
|
||||
|
||||
|
||||
class TestFalImageGenProviderAvailability:
|
||||
def test_is_available_when_legacy_check_passes(self, monkeypatch):
|
||||
import tools.image_generation_tool as image_tool
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
monkeypatch.setattr(image_tool, "check_fal_api_key", lambda: True)
|
||||
assert FalImageGenProvider().is_available() is True
|
||||
|
||||
def test_is_available_false_when_legacy_check_fails(self, monkeypatch):
|
||||
import tools.image_generation_tool as image_tool
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
monkeypatch.setattr(image_tool, "check_fal_api_key", lambda: False)
|
||||
assert FalImageGenProvider().is_available() is False
|
||||
|
||||
def test_is_available_handles_legacy_exception(self, monkeypatch):
|
||||
import tools.image_generation_tool as image_tool
|
||||
from plugins.image_gen.fal import FalImageGenProvider
|
||||
|
||||
def _boom():
|
||||
raise RuntimeError("config broke")
|
||||
|
||||
monkeypatch.setattr(image_tool, "check_fal_api_key", _boom)
|
||||
# Picker must not propagate exceptions — show as "not available".
|
||||
        assert FalImageGenProvider().is_available() is False


# ---------------------------------------------------------------------------
# generate(): call-time indirection
# ---------------------------------------------------------------------------


class TestFalImageGenProviderGenerate:
    def test_generate_delegates_to_legacy_image_generate_tool(self, monkeypatch):
        """Plugin must look up ``image_generate_tool`` at call time so
        ``monkeypatch.setattr(image_tool, "image_generate_tool", ...)``
        takes effect."""
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        captured = {}

        def fake_image_generate_tool(prompt, aspect_ratio, **kwargs):
            captured["prompt"] = prompt
            captured["aspect_ratio"] = aspect_ratio
            captured["kwargs"] = kwargs
            return json.dumps({"success": True, "image": "https://fake/image.png"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake_image_generate_tool)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        result = FalImageGenProvider().generate(
            "a serene mountain landscape",
            aspect_ratio="square",
            seed=42,
        )

        assert captured["prompt"] == "a serene mountain landscape"
        assert captured["aspect_ratio"] == "square"
        assert captured["kwargs"] == {"seed": 42}
        assert result["success"] is True
        assert result["image"] == "https://fake/image.png"
        # Stamped fields for the unified response shape
        assert result["provider"] == "fal"
        assert result["prompt"] == "a serene mountain landscape"
        assert result["aspect_ratio"] == "square"
        assert result["model"] == "fal-ai/flux-2/klein/9b"

    def test_generate_invalid_aspect_ratio_is_coerced(self, monkeypatch):
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        seen_aspect = {}

        def fake(prompt, aspect_ratio, **kwargs):
            seen_aspect["v"] = aspect_ratio
            return json.dumps({"success": True, "image": "x"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        FalImageGenProvider().generate("p", aspect_ratio="not-a-real-ratio")
        # ``resolve_aspect_ratio`` clamps to landscape.
        assert seen_aspect["v"] == "landscape"

    def test_generate_passthrough_drops_none_kwargs(self, monkeypatch):
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        seen = {}

        def fake(prompt, aspect_ratio, **kwargs):
            seen.update(kwargs)
            return json.dumps({"success": True, "image": "x"})

        monkeypatch.setattr(image_tool, "image_generate_tool", fake)
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        FalImageGenProvider().generate(
            "p",
            aspect_ratio="landscape",
            seed=None,
            num_images=2,
            guidance_scale=None,
        )

        # ``None`` values must not be forwarded: they'd override the
        # model's defaults inside the legacy payload builder.
        assert "seed" not in seen
        assert "guidance_scale" not in seen
        assert seen.get("num_images") == 2

    def test_generate_catches_exception_from_legacy(self, monkeypatch):
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        def boom(*args, **kwargs):
            raise RuntimeError("FAL endpoint exploded")

        monkeypatch.setattr(image_tool, "image_generate_tool", boom)

        result = FalImageGenProvider().generate("p")
        assert result["success"] is False
        assert "FAL image generation failed" in result["error"]
        assert result["error_type"] == "RuntimeError"
        assert result["provider"] == "fal"

    def test_generate_invalid_json_response(self, monkeypatch):
        import tools.image_generation_tool as image_tool
        from plugins.image_gen.fal import FalImageGenProvider

        monkeypatch.setattr(image_tool, "image_generate_tool", lambda **kw: "not-json")
        monkeypatch.setattr(image_tool, "_resolve_fal_model",
                            lambda: ("fal-ai/flux-2/klein/9b", {}))

        result = FalImageGenProvider().generate("p")
        assert result["success"] is False
        assert "Invalid JSON" in result["error"]
        assert result["provider"] == "fal"


# ---------------------------------------------------------------------------
# Registry wiring
# ---------------------------------------------------------------------------


class TestFalImageGenPluginRegistration:
    def test_register_wires_provider_into_registry(self):
        from plugins.image_gen.fal import FalImageGenProvider, register

        ctx = MagicMock()
        register(ctx)

        ctx.register_image_gen_provider.assert_called_once()
        (registered,), _ = ctx.register_image_gen_provider.call_args
        assert isinstance(registered, FalImageGenProvider)
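The behaviours these tests pin down (call-time lookup of the legacy function, dropping `None` kwargs, stamping `provider`/`prompt`/`aspect_ratio`, and turning exceptions or non-JSON output into error dicts) can be sketched roughly as follows. This is an illustration only, not the code in `plugins/image_gen/fal/__init__.py`: the `legacy` stand-in module, `SketchProvider`, and the exact error strings are assumptions made for the example.

```python
import json
import types

# Hypothetical stand-in for tools.image_generation_tool (assumption).
legacy = types.ModuleType("legacy_image_tool")
legacy.image_generate_tool = lambda prompt, aspect_ratio, **kw: json.dumps(
    {"success": True, "image": "https://example.invalid/a.png"}
)


class SketchProvider:
    """Thin adapter: delegates to ``legacy`` and normalises the result."""

    name = "fal"

    def generate(self, prompt, aspect_ratio="landscape", **kwargs):
        # Drop None values so they cannot override downstream defaults.
        forwarded = {k: v for k, v in kwargs.items() if v is not None}
        try:
            # The attribute lookup happens here, at call time, so a test that
            # replaces ``legacy.image_generate_tool`` is honoured.
            raw = legacy.image_generate_tool(prompt, aspect_ratio, **forwarded)
        except Exception as exc:
            return {
                "success": False,
                "error": f"FAL image generation failed: {exc}",
                "error_type": type(exc).__name__,
                "provider": self.name,
            }
        try:
            result = json.loads(raw)
        except (TypeError, ValueError):
            return {
                "success": False,
                "error": "Invalid JSON from legacy tool",
                "provider": self.name,
            }
        # Stamp fields for a unified response shape.
        result.update(provider=self.name, prompt=prompt, aspect_ratio=aspect_ratio)
        return result


seen = {}


def fake(prompt, aspect_ratio, **kwargs):
    seen.update(kwargs)
    return json.dumps({"success": True, "image": "x"})


legacy.image_generate_tool = fake  # roughly what monkeypatch.setattr does
ok = SketchProvider().generate("p", seed=None, num_images=2)

legacy.image_generate_tool = lambda *a, **kw: "not-json"
bad = SketchProvider().generate("p")
```

Had the adapter bound the function with `from legacy import image_generate_tool` at import time, the later reassignment would not be visible to it, which is the failure mode the first test guards against.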
163
tools/fal_common.py
Normal file
@@ -0,0 +1,163 @@
"""Shared FAL.ai SDK plumbing.

Holds the stateless atoms that every FAL-backed tool needs:

* :func:`import_fal_client`: lazy import + ``lazy_deps`` integration so
  ``fal_client`` isn't pulled at cold start (it added ~64 ms per CLI
  invocation when imported eagerly).
* :class:`_ManagedFalSyncClient`: wrapper that drives a Nous-managed
  fal-queue gateway through the standard ``fal_client.SyncClient``
  primitives.
* :func:`_normalize_fal_queue_url_format`, :func:`_extract_http_status`:
  small helpers used by both the managed client wrapper and
  ``_submit_fal_request``.

Stateful pieces (cache globals, ``_managed_fal_client*`` selectors,
``_submit_fal_request``) intentionally stay on
:mod:`tools.image_generation_tool`. That module is the patch target for
existing test suites (``tests/tools/test_image_generation.py``,
``tests/tools/test_managed_media_gateways.py``) and for the
``plugins/image_gen/fal/`` plugin's ``_it`` indirection; moving the
caches here would silently defeat ``monkeypatch.setattr(image_tool,
"_managed_fal_client", None)`` because the lookups would go against
``fal_common``'s namespace instead. See the per-rule walkthrough at
issue #26241 for details.
"""

from __future__ import annotations

from typing import Any, Dict, Optional, Union
from urllib.parse import urlencode


def import_fal_client() -> Any:
    """Import ``fal_client`` (via ``lazy_deps`` when available) and return
    the module reference.

    Callers are responsible for caching the result on their own module
    global; keeping per-module globals lets tests monkey-patch the
    target module's ``fal_client`` attribute and have the patched value
    stick for that module's call sites.

    Raises :class:`ImportError` if the package is genuinely unavailable.
    """
    try:
        from tools.lazy_deps import ensure as _lazy_ensure
        _lazy_ensure("image.fal", prompt=False)
    except ImportError:
        pass
    except Exception as exc:  # noqa: BLE001 - lazy_deps surfaces install hints
        raise ImportError(str(exc))
    import fal_client  # type: ignore  # noqa: WPS433 - intentionally lazy
    return fal_client


def _normalize_fal_queue_url_format(queue_run_origin: str) -> str:
    normalized_origin = str(queue_run_origin or "").strip().rstrip("/")
    if not normalized_origin:
        raise ValueError("Managed FAL queue origin is required")
    return f"{normalized_origin}/"


def _extract_http_status(exc: BaseException) -> Optional[int]:
    """Return an HTTP status code from httpx/fal exceptions, else None.

    Defensive across exception shapes: httpx.HTTPStatusError exposes
    ``.response.status_code`` while fal_client wrappers may expose
    ``.status_code`` directly.
    """
    response = getattr(exc, "response", None)
    if response is not None:
        status = getattr(response, "status_code", None)
        if isinstance(status, int):
            return status
    status = getattr(exc, "status_code", None)
    if isinstance(status, int):
        return status
    return None


class _ManagedFalSyncClient:
    """Small per-instance wrapper around ``fal_client.SyncClient`` for
    managed queue hosts.

    The wrapper carries its own ``fal_client`` module reference instead
    of reaching into a module global, so callers stay in control of
    which module's ``fal_client`` is in scope (matters for the test
    patches that swap the legacy module's ``fal_client`` attribute).
    """

    def __init__(self, fal_client: Any, *, key: str, queue_run_origin: str):
        sync_client_class = getattr(fal_client, "SyncClient", None)
        if sync_client_class is None:
            raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode")

        client_module = getattr(fal_client, "client", None)
        if client_module is None:
            raise RuntimeError("fal_client.client is required for managed FAL gateway mode")

        self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin)
        self._sync_client = sync_client_class(key=key)
        self._http_client = getattr(self._sync_client, "_client", None)
        self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None)
        self._raise_for_status = getattr(client_module, "_raise_for_status", None)
        self._request_handle_class = getattr(client_module, "SyncRequestHandle", None)
        self._add_hint_header = getattr(client_module, "add_hint_header", None)
        self._add_priority_header = getattr(client_module, "add_priority_header", None)
        self._add_timeout_header = getattr(client_module, "add_timeout_header", None)

        if self._http_client is None:
            raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode")
        if self._maybe_retry_request is None or self._raise_for_status is None:
            raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode")
        if self._request_handle_class is None:
            raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode")

    def submit(
        self,
        application: str,
        arguments: Dict[str, Any],
        *,
        path: str = "",
        hint: Optional[str] = None,
        webhook_url: Optional[str] = None,
        priority: Any = None,
        headers: Optional[Dict[str, str]] = None,
        start_timeout: Optional[Union[int, float]] = None,
    ):
        url = self._queue_url_format + application
        if path:
            url += "/" + path.lstrip("/")
        if webhook_url is not None:
            url += "?" + urlencode({"fal_webhook": webhook_url})

        request_headers = dict(headers or {})
        if hint is not None and self._add_hint_header is not None:
            self._add_hint_header(hint, request_headers)
        if priority is not None:
            if self._add_priority_header is None:
                raise RuntimeError("fal_client.client.add_priority_header is required for priority requests")
            self._add_priority_header(priority, request_headers)
        if start_timeout is not None:
            if self._add_timeout_header is None:
                raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests")
            self._add_timeout_header(start_timeout, request_headers)

        response = self._maybe_retry_request(
            self._http_client,
            "POST",
            url,
            json=arguments,
            timeout=getattr(self._sync_client, "default_timeout", 120.0),
            headers=request_headers,
        )
        self._raise_for_status(response)

        data = response.json()
        return self._request_handle_class(
            request_id=data["request_id"],
            response_url=data["response_url"],
            status_url=data["status_url"],
            cancel_url=data["cancel_url"],
            client=self._http_client,
        )
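The URL assembly inside `submit` is easy to check in isolation. The sketch below copies that logic into a standalone function; the name `build_queue_url` and the `example.invalid` hosts are hypothetical and exist only for this illustration.

```python
from typing import Optional
from urllib.parse import urlencode


def build_queue_url(
    origin: str,
    application: str,
    path: str = "",
    webhook_url: Optional[str] = None,
) -> str:
    """Hypothetical helper mirroring the URL assembly in ``submit``."""
    # Same normalisation as _normalize_fal_queue_url_format: trim, drop
    # trailing slashes, then append exactly one.
    normalized = str(origin or "").strip().rstrip("/")
    if not normalized:
        raise ValueError("Managed FAL queue origin is required")
    url = f"{normalized}/" + application
    if path:
        url += "/" + path.lstrip("/")
    if webhook_url is not None:
        # urlencode percent-escapes the webhook so it survives as one value.
        url += "?" + urlencode({"fal_webhook": webhook_url})
    return url


plain = build_queue_url("https://gateway.example.invalid//", "fal-ai/flux-2/klein/9b")
hooked = build_queue_url(
    "https://gateway.example.invalid",
    "fal-ai/flux-2/klein/9b",
    path="/requests",
    webhook_url="https://hooks.example.invalid/cb?x=1",
)
print(plain)
print(hooked)
```

Because the origin is stripped of trailing slashes before one is re-added, callers can pass the gateway origin with or without a trailing `/` and get the same URL.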
tools/image_generation_tool.py
@@ -26,8 +26,7 @@ import os
 import datetime
 import threading
 import uuid
-from typing import Any, Dict, Optional, Union
-from urllib.parse import urlencode
+from typing import Any, Dict, Optional
 
 # fal_client is imported lazily; see _load_fal_client(). Pulling it
 # eagerly added ~64 ms to every CLI cold start because
@@ -52,19 +51,17 @@ def _load_fal_client() -> Any:
     global fal_client
     if fal_client is not None:
         return fal_client
-    try:
-        from tools.lazy_deps import ensure as _lazy_ensure
-        _lazy_ensure("image.fal", prompt=False)
-    except ImportError:
-        pass
-    except Exception as e:
-        raise ImportError(str(e))
-    import fal_client as _fal_client  # noqa: F811 - module-global rebind
-    fal_client = _fal_client
+    from tools.fal_common import import_fal_client
+    fal_client = import_fal_client()
     return fal_client
 
 
 from tools.debug_helpers import DebugSession
+from tools.fal_common import (
+    _ManagedFalSyncClient,
+    _extract_http_status,
+    _normalize_fal_queue_url_format,  # noqa: F401 - re-exported for tests
+)
 from tools.managed_tool_gateway import resolve_managed_tool_gateway
 from tools.tool_backend_helpers import (
     fal_key_is_configured,
@@ -360,95 +357,6 @@ def _resolve_managed_fal_gateway():
     return resolve_managed_tool_gateway("fal-queue")
 
 
-def _normalize_fal_queue_url_format(queue_run_origin: str) -> str:
-    normalized_origin = str(queue_run_origin or "").strip().rstrip("/")
-    if not normalized_origin:
-        raise ValueError("Managed FAL queue origin is required")
-    return f"{normalized_origin}/"
-
-
-class _ManagedFalSyncClient:
-    """Small per-instance wrapper around fal_client.SyncClient for managed queue hosts."""
-
-    def __init__(self, *, key: str, queue_run_origin: str):
-        # Trigger the lazy import on first construction. Idempotent: the
-        # placeholder is overwritten with the real module on first call.
-        _load_fal_client()
-        sync_client_class = getattr(fal_client, "SyncClient", None)
-        if sync_client_class is None:
-            raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode")
-
-        client_module = getattr(fal_client, "client", None)
-        if client_module is None:
-            raise RuntimeError("fal_client.client is required for managed FAL gateway mode")
-
-        self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin)
-        self._sync_client = sync_client_class(key=key)
-        self._http_client = getattr(self._sync_client, "_client", None)
-        self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None)
-        self._raise_for_status = getattr(client_module, "_raise_for_status", None)
-        self._request_handle_class = getattr(client_module, "SyncRequestHandle", None)
-        self._add_hint_header = getattr(client_module, "add_hint_header", None)
-        self._add_priority_header = getattr(client_module, "add_priority_header", None)
-        self._add_timeout_header = getattr(client_module, "add_timeout_header", None)
-
-        if self._http_client is None:
-            raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode")
-        if self._maybe_retry_request is None or self._raise_for_status is None:
-            raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode")
-        if self._request_handle_class is None:
-            raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode")
-
-    def submit(
-        self,
-        application: str,
-        arguments: Dict[str, Any],
-        *,
-        path: str = "",
-        hint: Optional[str] = None,
-        webhook_url: Optional[str] = None,
-        priority: Any = None,
-        headers: Optional[Dict[str, str]] = None,
-        start_timeout: Optional[Union[int, float]] = None,
-    ):
-        url = self._queue_url_format + application
-        if path:
-            url += "/" + path.lstrip("/")
-        if webhook_url is not None:
-            url += "?" + urlencode({"fal_webhook": webhook_url})
-
-        request_headers = dict(headers or {})
-        if hint is not None and self._add_hint_header is not None:
-            self._add_hint_header(hint, request_headers)
-        if priority is not None:
-            if self._add_priority_header is None:
-                raise RuntimeError("fal_client.client.add_priority_header is required for priority requests")
-            self._add_priority_header(priority, request_headers)
-        if start_timeout is not None:
-            if self._add_timeout_header is None:
-                raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests")
-            self._add_timeout_header(start_timeout, request_headers)
-
-        response = self._maybe_retry_request(
-            self._http_client,
-            "POST",
-            url,
-            json=arguments,
-            timeout=getattr(self._sync_client, "default_timeout", 120.0),
-            headers=request_headers,
-        )
-        self._raise_for_status(response)
-
-        data = response.json()
-        return self._request_handle_class(
-            request_id=data["request_id"],
-            response_url=data["response_url"],
-            status_url=data["status_url"],
-            cancel_url=data["cancel_url"],
-            client=self._http_client,
-        )
-
-
 def _get_managed_fal_client(managed_gateway):
     """Reuse the managed FAL client so its internal httpx.Client is not leaked per call."""
     global _managed_fal_client, _managed_fal_client_config
@@ -461,7 +369,11 @@ def _get_managed_fal_client(managed_gateway):
         if _managed_fal_client is not None and _managed_fal_client_config == client_config:
             return _managed_fal_client
 
+        # Resolve fal_client on the legacy module - preserves the test
+        # pattern of monkey-patching ``image_generation_tool.fal_client``.
+        _load_fal_client()
         _managed_fal_client = _ManagedFalSyncClient(
+            fal_client,
             key=managed_gateway.nous_user_token,
             queue_run_origin=managed_gateway.gateway_origin,
         )
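Passing the SDK module into the wrapper while keeping the cache on the calling module can be illustrated with stand-ins. Everything named here (`Gateway`, `FakeWrapper`, `get_client`, the lock, and the shape of the cache key) is an assumption of the sketch; the hunk only shows the cache comparison and the constructor call.

```python
import threading
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class Gateway:
    """Hypothetical stand-in for the managed gateway descriptor."""
    gateway_origin: str
    nous_user_token: str


class FakeWrapper:
    """Hypothetical stand-in for a wrapper that receives its SDK module."""

    def __init__(self, sdk: Any, *, key: str, queue_run_origin: str):
        self.sdk = sdk
        self.key = key
        self.queue_run_origin = queue_run_origin


_client: Optional[FakeWrapper] = None
_client_config: Optional[Tuple[str, str]] = None
_lock = threading.Lock()  # assumption: guards the cache against races


def get_client(sdk: Any, gateway: Gateway) -> FakeWrapper:
    """Reuse the cached wrapper while the gateway config is unchanged."""
    global _client, _client_config
    config = (gateway.gateway_origin.rstrip("/"), gateway.nous_user_token)
    with _lock:
        if _client is not None and _client_config == config:
            return _client
        _client = FakeWrapper(
            sdk,
            key=gateway.nous_user_token,
            queue_run_origin=gateway.gateway_origin,
        )
        _client_config = config
        return _client


sdk = object()
a = get_client(sdk, Gateway("https://gw.example.invalid/", "token-1"))
b = get_client(sdk, Gateway("https://gw.example.invalid", "token-1"))
c = get_client(sdk, Gateway("https://gw.example.invalid", "token-2"))
```

Since the wrapper no longer reaches for a global, whichever module owns the cache decides which SDK object is used, which is what keeps existing patch sites effective.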
@@ -502,24 +414,6 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]):
         raise
 
 
-def _extract_http_status(exc: BaseException) -> Optional[int]:
-    """Return an HTTP status code from httpx/fal exceptions, else None.
-
-    Defensive across exception shapes: httpx.HTTPStatusError exposes
-    ``.response.status_code`` while fal_client wrappers may expose
-    ``.status_code`` directly.
-    """
-    response = getattr(exc, "response", None)
-    if response is not None:
-        status = getattr(response, "status_code", None)
-        if isinstance(status, int):
-            return status
-    status = getattr(exc, "status_code", None)
-    if isinstance(status, int):
-        return status
-    return None
-
-
 # ---------------------------------------------------------------------------
 # Model resolution + payload construction
 # ---------------------------------------------------------------------------
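To illustrate the two exception shapes that `_extract_http_status` (now in `tools/fal_common.py`) is written to tolerate, here is a standalone copy of the function exercised against made-up exception classes. `_Resp`, `ResponseError`, and `DirectStatusError` are invented for the example and are not httpx or fal_client types.

```python
from typing import Optional


def extract_http_status(exc: BaseException) -> Optional[int]:
    """Standalone copy of the helper's logic, for illustration."""
    response = getattr(exc, "response", None)
    if response is not None:
        status = getattr(response, "status_code", None)
        if isinstance(status, int):
            return status
    status = getattr(exc, "status_code", None)
    if isinstance(status, int):
        return status
    return None


class _Resp:
    def __init__(self, status_code):
        self.status_code = status_code


class ResponseError(Exception):
    """Hypothetical error carrying ``.response.status_code``."""

    def __init__(self, status_code):
        super().__init__("response error")
        self.response = _Resp(status_code)


class DirectStatusError(Exception):
    """Hypothetical error carrying ``.status_code`` directly."""

    def __init__(self, status_code):
        super().__init__("direct status")
        self.status_code = status_code


nested = extract_http_status(ResponseError(429))
direct = extract_http_status(DirectStatusError(503))
missing = extract_http_status(RuntimeError("no status"))
non_int = extract_http_status(ResponseError("500"))  # string, so ignored
```

The `isinstance(status, int)` checks mean a malformed status (a string or a mock attribute) yields `None` instead of a misleading value.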
@@ -973,9 +867,12 @@ def _read_configured_image_provider():
     """Return the value of ``image_gen.provider`` from config.yaml, or None.
 
     We only consult the plugin registry when this is explicitly set - an
-    unset value keeps users on the legacy in-tree FAL path even when other
+    unset value keeps users on the in-tree FAL fallback even when other
     providers happen to be registered (e.g. a user has OPENAI_API_KEY set
-    for other features but never asked for OpenAI image gen).
+    for other features but never asked for OpenAI image gen). ``"fal"``
+    explicitly routes through ``plugins/image_gen/fal/`` (which delegates
+    back into this module's pipeline via call-time indirection - see
+    issue #26241).
     """
     try:
         from hermes_cli.config import load_config
@@ -994,15 +891,16 @@ def _dispatch_to_plugin_provider(prompt: str, aspect_ratio: str):
     """Route the call to a plugin-registered provider when one is selected.
 
     Returns a JSON string on dispatch, or ``None`` to fall through to the
-    built-in FAL path.
+    in-tree FAL fallback in ``image_generate_tool``.
 
-    Dispatch only fires when ``image_gen.provider`` is explicitly set AND
-    it does not point to ``fal`` (FAL still lives in-tree in this PR;
-    a later PR ports it into ``plugins/image_gen/fal/``). Any other value
-    that matches a registered plugin provider wins.
+    Dispatch fires when ``image_gen.provider`` is explicitly set - including
+    ``"fal"`` itself, which now resolves to the
+    ``plugins/image_gen/fal/`` plugin (the plugin re-enters this module's
+    pipeline via ``_it`` indirection so behavior is identical to the
+    direct call, just routed through the registry).
     """
     configured = _read_configured_image_provider()
-    if not configured or configured == "fal":
+    if not configured:
         return None
 
     # Also read configured model so we can pass it to the plugin
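The routing rule after this hunk can be summarised in a small sketch: an unset provider returns `None` so the caller falls through to the in-tree pipeline, while any set value, now including `"fal"`, is looked up in a registry. The `REGISTRY` dict, the `dispatch` function, the handling of unknown names, and the payloads are all assumptions for illustration; the real registry lookup is not shown in this diff.

```python
import json
from typing import Callable, Dict, Optional

# Hypothetical registry: provider name -> callable returning a result dict.
REGISTRY: Dict[str, Callable[[str, str], dict]] = {
    "fal": lambda prompt, ratio: {"success": True, "provider": "fal"},
    "openai": lambda prompt, ratio: {"success": True, "provider": "openai"},
}


def dispatch(configured: Optional[str], prompt: str, aspect_ratio: str) -> Optional[str]:
    """Return a JSON string on dispatch, or None to fall through."""
    if not configured:
        # Unset (None or empty): keep the caller on the in-tree pipeline.
        return None
    provider = REGISTRY.get(configured)
    if provider is None:
        # Assumption: the sketch reports unknown names instead of falling through.
        return json.dumps({"success": False, "error": f"Unknown provider: {configured}"})
    return json.dumps(provider(prompt, aspect_ratio))


unset = dispatch(None, "p", "landscape")
empty = dispatch("", "p", "landscape")
via_fal = json.loads(dispatch("fal", "p", "landscape"))
unknown = json.loads(dispatch("nope", "p", "landscape"))
```

Before this change the guard also short-circuited on `configured == "fal"`; dropping that clause is what sends an explicit `fal` setting through the registry like any other provider.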