diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 7a9a598f9..e89f96178 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -1019,6 +1019,11 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): def _is_provider_active(provider: dict, config: dict) -> bool: """Check if a provider entry matches the currently active config.""" + plugin_name = provider.get("image_gen_plugin_name") + if plugin_name: + image_cfg = config.get("image_gen", {}) + return isinstance(image_cfg, dict) and image_cfg.get("provider") == plugin_name + managed_feature = provider.get("managed_nous_feature") if managed_feature: features = get_nous_subscription_features(config) @@ -1026,6 +1031,13 @@ def _is_provider_active(provider: dict, config: dict) -> bool: if feature is None: return False if managed_feature == "image_gen": + image_cfg = config.get("image_gen", {}) + if isinstance(image_cfg, dict): + configured_provider = image_cfg.get("provider") + if configured_provider not in (None, "", "fal"): + return False + if image_cfg.get("use_gateway") is False: + return False return feature.managed_by_nous if provider.get("tts_provider"): return ( @@ -1048,6 +1060,16 @@ def _is_provider_active(provider: dict, config: dict) -> bool: if provider.get("web_backend"): current = config.get("web", {}).get("backend") return current == provider["web_backend"] + if provider.get("imagegen_backend"): + image_cfg = config.get("image_gen", {}) + if not isinstance(image_cfg, dict): + return False + configured_provider = image_cfg.get("provider") + return ( + provider["imagegen_backend"] == "fal" + and configured_provider in (None, "", "fal") + and not image_cfg.get("use_gateway") + ) return False @@ -1245,6 +1267,18 @@ def _configure_imagegen_model_for_plugin(plugin_name: str, config: dict) -> None _print_success(f" Model set to: {chosen}") +def _select_plugin_image_gen_provider(plugin_name: str, config: dict) -> None: + """Persist a plugin-backed image generation provider selection.""" + img_cfg = config.setdefault("image_gen", {}) + if not isinstance(img_cfg, dict): + img_cfg = {} + config["image_gen"] = img_cfg + img_cfg["provider"] = plugin_name + img_cfg["use_gateway"] = False + _print_success(f" image_gen.provider set to: {plugin_name}") + _configure_imagegen_model_for_plugin(plugin_name, config) + + def _configure_provider(provider: dict, config: dict): """Configure a single provider - prompt for API keys and set config.""" env_vars = provider.get("env_vars", []) @@ -1305,13 +1339,7 @@ def _configure_provider(provider: dict, config: dict): # and route model selection to the plugin's own catalog. plugin_name = provider.get("image_gen_plugin_name") if plugin_name: - img_cfg = config.setdefault("image_gen", {}) - if not isinstance(img_cfg, dict): - img_cfg = {} - config["image_gen"] = img_cfg - img_cfg["provider"] = plugin_name - _print_success(f" image_gen.provider set to: {plugin_name}") - _configure_imagegen_model_for_plugin(plugin_name, config) + _select_plugin_image_gen_provider(plugin_name, config) return # Imagegen backends prompt for model selection after backend pick. backend = provider.get("imagegen_backend") @@ -1359,13 +1387,7 @@ def _configure_provider(provider: dict, config: dict): _print_success(f" {provider['name']} configured!") plugin_name = provider.get("image_gen_plugin_name") if plugin_name: - img_cfg = config.setdefault("image_gen", {}) - if not isinstance(img_cfg, dict): - img_cfg = {} - config["image_gen"] = img_cfg - img_cfg["provider"] = plugin_name - _print_success(f" image_gen.provider set to: {plugin_name}") - _configure_imagegen_model_for_plugin(plugin_name, config) + _select_plugin_image_gen_provider(plugin_name, config) return # Imagegen backends prompt for model selection after env vars are in. backend = provider.get("imagegen_backend") @@ -1539,16 +1561,39 @@ def _reconfigure_provider(provider: dict, config: dict): config.setdefault("web", {})["backend"] = provider["web_backend"] _print_success(f" Web backend set to: {provider['web_backend']}") + if managed_feature and managed_feature not in ("web", "tts", "browser"): + section = config.setdefault(managed_feature, {}) + if not isinstance(section, dict): + section = {} + config[managed_feature] = section + section["use_gateway"] = True + elif not managed_feature: + for cat_key, cat in TOOL_CATEGORIES.items(): + if provider in cat.get("providers", []): + section = config.get(cat_key) + if isinstance(section, dict) and section.get("use_gateway"): + section["use_gateway"] = False + break + if not env_vars: if provider.get("post_setup"): _run_post_setup(provider["post_setup"]) _print_success(f" {provider['name']} - no configuration needed!") if managed_feature: _print_info(" Requests for this tool will be billed to your Nous subscription.") + plugin_name = provider.get("image_gen_plugin_name") + if plugin_name: + _select_plugin_image_gen_provider(plugin_name, config) + return # Imagegen backends prompt for model selection on reconfig too. backend = provider.get("imagegen_backend") if backend: _configure_imagegen_model(backend, config) + if backend == "fal": + img_cfg = config.setdefault("image_gen", {}) + if isinstance(img_cfg, dict): + img_cfg["provider"] = "fal" + img_cfg["use_gateway"] = False return for var in env_vars: @@ -1567,9 +1612,19 @@ def _reconfigure_provider(provider: dict, config: dict): _print_info(" Kept current") # Imagegen backends prompt for model selection on reconfig too. + plugin_name = provider.get("image_gen_plugin_name") + if plugin_name: + _select_plugin_image_gen_provider(plugin_name, config) + return + backend = provider.get("imagegen_backend") if backend: _configure_imagegen_model(backend, config) + if backend == "fal": + img_cfg = config.setdefault("image_gen", {}) + if isinstance(img_cfg, dict): + img_cfg["provider"] = "fal" + img_cfg["use_gateway"] = False def _reconfigure_simple_requirements(ts_key: str): diff --git a/tests/hermes_cli/test_image_gen_picker.py b/tests/hermes_cli/test_image_gen_picker.py index 27c502def..6da847691 100644 --- a/tests/hermes_cli/test_image_gen_picker.py +++ b/tests/hermes_cli/test_image_gen_picker.py @@ -6,6 +6,8 @@ Covers `_plugin_image_gen_providers`, `_visible_providers`, and from __future__ import annotations +from types import SimpleNamespace + import pytest from agent import image_gen_registry @@ -172,3 +174,78 @@ class TestConfigWriting: assert config["image_gen"]["provider"] == "noenv" assert config["image_gen"]["model"] == "noenv-model-v1" + + def test_reconfiguring_plugin_provider_writes_provider_and_model(self, monkeypatch, tmp_path): + """The reconfigure path should switch image_gen away from managed FAL + and onto the selected plugin provider.""" + from hermes_cli import tools_config + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + image_gen_registry.register_provider(_FakeProvider("testopenai")) + monkeypatch.setattr(tools_config, "_prompt_choice", lambda *a, **kw: 0) + monkeypatch.setattr(tools_config, "_prompt", lambda *a, **kw: "") + monkeypatch.setattr( + tools_config, + "get_env_value", + lambda key: "sk-test" if key == "OPENAI_API_KEY" else "", + ) + + config = {"image_gen": {"use_gateway": True}} + provider_row = { + "name": "OpenAI", + "env_vars": [{"key": "OPENAI_API_KEY", "prompt": "OpenAI API key"}], + "image_gen_plugin_name": "testopenai", + } + + tools_config._reconfigure_provider(provider_row, config) + + assert config["image_gen"]["provider"] == "testopenai" + assert config["image_gen"]["model"] == "testopenai-model-v1" + assert config["image_gen"]["use_gateway"] is False + + def test_plugin_provider_active_overrides_managed_nous_active_label(self, monkeypatch): + from hermes_cli import tools_config + + monkeypatch.setattr( + tools_config, + "get_nous_subscription_features", + lambda config: SimpleNamespace( + features={"image_gen": SimpleNamespace(managed_by_nous=True)} + ), + ) + + config = {"image_gen": {"provider": "openai", "use_gateway": False}} + nous_row = { + "name": "Nous Subscription", + "managed_nous_feature": "image_gen", + } + openai_row = { + "name": "OpenAI", + "image_gen_plugin_name": "openai", + } + + assert tools_config._is_provider_active(openai_row, config) is True + assert tools_config._is_provider_active(nous_row, config) is False + + def test_reconfiguring_fal_clears_plugin_provider(self, monkeypatch): + from hermes_cli import tools_config + + monkeypatch.setattr(tools_config, "_prompt_choice", lambda *a, **kw: 0) + monkeypatch.setattr(tools_config, "_prompt", lambda *a, **kw: "") + monkeypatch.setattr( + tools_config, + "get_env_value", + lambda key: "fal-key" if key == "FAL_KEY" else "", + ) + + config = {"image_gen": {"provider": "openai", "use_gateway": False}} + provider_row = { + "name": "FAL.ai", + "env_vars": [{"key": "FAL_KEY", "prompt": "FAL API key"}], + "imagegen_backend": "fal", + } + + tools_config._reconfigure_provider(provider_row, config) + + assert config["image_gen"]["provider"] == "fal" + assert config["image_gen"]["use_gateway"] is False