mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(image-gen): persist plugin provider on reconfigure
This commit is contained in:
parent
d1ce358646
commit
bace220d29
2 changed files with 146 additions and 14 deletions
|
|
@ -1019,6 +1019,11 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
|||
|
||||
def _is_provider_active(provider: dict, config: dict) -> bool:
|
||||
"""Check if a provider entry matches the currently active config."""
|
||||
plugin_name = provider.get("image_gen_plugin_name")
|
||||
if plugin_name:
|
||||
image_cfg = config.get("image_gen", {})
|
||||
return isinstance(image_cfg, dict) and image_cfg.get("provider") == plugin_name
|
||||
|
||||
managed_feature = provider.get("managed_nous_feature")
|
||||
if managed_feature:
|
||||
features = get_nous_subscription_features(config)
|
||||
|
|
@ -1026,6 +1031,13 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
|
|||
if feature is None:
|
||||
return False
|
||||
if managed_feature == "image_gen":
|
||||
image_cfg = config.get("image_gen", {})
|
||||
if isinstance(image_cfg, dict):
|
||||
configured_provider = image_cfg.get("provider")
|
||||
if configured_provider not in (None, "", "fal"):
|
||||
return False
|
||||
if image_cfg.get("use_gateway") is False:
|
||||
return False
|
||||
return feature.managed_by_nous
|
||||
if provider.get("tts_provider"):
|
||||
return (
|
||||
|
|
@ -1048,6 +1060,16 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
|
|||
if provider.get("web_backend"):
|
||||
current = config.get("web", {}).get("backend")
|
||||
return current == provider["web_backend"]
|
||||
if provider.get("imagegen_backend"):
|
||||
image_cfg = config.get("image_gen", {})
|
||||
if not isinstance(image_cfg, dict):
|
||||
return False
|
||||
configured_provider = image_cfg.get("provider")
|
||||
return (
|
||||
provider["imagegen_backend"] == "fal"
|
||||
and configured_provider in (None, "", "fal")
|
||||
and not image_cfg.get("use_gateway")
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -1245,6 +1267,18 @@ def _configure_imagegen_model_for_plugin(plugin_name: str, config: dict) -> None
|
|||
_print_success(f" Model set to: {chosen}")
|
||||
|
||||
|
||||
def _select_plugin_image_gen_provider(plugin_name: str, config: dict) -> None:
|
||||
"""Persist a plugin-backed image generation provider selection."""
|
||||
img_cfg = config.setdefault("image_gen", {})
|
||||
if not isinstance(img_cfg, dict):
|
||||
img_cfg = {}
|
||||
config["image_gen"] = img_cfg
|
||||
img_cfg["provider"] = plugin_name
|
||||
img_cfg["use_gateway"] = False
|
||||
_print_success(f" image_gen.provider set to: {plugin_name}")
|
||||
_configure_imagegen_model_for_plugin(plugin_name, config)
|
||||
|
||||
|
||||
def _configure_provider(provider: dict, config: dict):
|
||||
"""Configure a single provider - prompt for API keys and set config."""
|
||||
env_vars = provider.get("env_vars", [])
|
||||
|
|
@ -1305,13 +1339,7 @@ def _configure_provider(provider: dict, config: dict):
|
|||
# and route model selection to the plugin's own catalog.
|
||||
plugin_name = provider.get("image_gen_plugin_name")
|
||||
if plugin_name:
|
||||
img_cfg = config.setdefault("image_gen", {})
|
||||
if not isinstance(img_cfg, dict):
|
||||
img_cfg = {}
|
||||
config["image_gen"] = img_cfg
|
||||
img_cfg["provider"] = plugin_name
|
||||
_print_success(f" image_gen.provider set to: {plugin_name}")
|
||||
_configure_imagegen_model_for_plugin(plugin_name, config)
|
||||
_select_plugin_image_gen_provider(plugin_name, config)
|
||||
return
|
||||
# Imagegen backends prompt for model selection after backend pick.
|
||||
backend = provider.get("imagegen_backend")
|
||||
|
|
@ -1359,13 +1387,7 @@ def _configure_provider(provider: dict, config: dict):
|
|||
_print_success(f" {provider['name']} configured!")
|
||||
plugin_name = provider.get("image_gen_plugin_name")
|
||||
if plugin_name:
|
||||
img_cfg = config.setdefault("image_gen", {})
|
||||
if not isinstance(img_cfg, dict):
|
||||
img_cfg = {}
|
||||
config["image_gen"] = img_cfg
|
||||
img_cfg["provider"] = plugin_name
|
||||
_print_success(f" image_gen.provider set to: {plugin_name}")
|
||||
_configure_imagegen_model_for_plugin(plugin_name, config)
|
||||
_select_plugin_image_gen_provider(plugin_name, config)
|
||||
return
|
||||
# Imagegen backends prompt for model selection after env vars are in.
|
||||
backend = provider.get("imagegen_backend")
|
||||
|
|
@ -1539,16 +1561,39 @@ def _reconfigure_provider(provider: dict, config: dict):
|
|||
config.setdefault("web", {})["backend"] = provider["web_backend"]
|
||||
_print_success(f" Web backend set to: {provider['web_backend']}")
|
||||
|
||||
if managed_feature and managed_feature not in ("web", "tts", "browser"):
|
||||
section = config.setdefault(managed_feature, {})
|
||||
if not isinstance(section, dict):
|
||||
section = {}
|
||||
config[managed_feature] = section
|
||||
section["use_gateway"] = True
|
||||
elif not managed_feature:
|
||||
for cat_key, cat in TOOL_CATEGORIES.items():
|
||||
if provider in cat.get("providers", []):
|
||||
section = config.get(cat_key)
|
||||
if isinstance(section, dict) and section.get("use_gateway"):
|
||||
section["use_gateway"] = False
|
||||
break
|
||||
|
||||
if not env_vars:
|
||||
if provider.get("post_setup"):
|
||||
_run_post_setup(provider["post_setup"])
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
if managed_feature:
|
||||
_print_info(" Requests for this tool will be billed to your Nous subscription.")
|
||||
plugin_name = provider.get("image_gen_plugin_name")
|
||||
if plugin_name:
|
||||
_select_plugin_image_gen_provider(plugin_name, config)
|
||||
return
|
||||
# Imagegen backends prompt for model selection on reconfig too.
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
if backend == "fal":
|
||||
img_cfg = config.setdefault("image_gen", {})
|
||||
if isinstance(img_cfg, dict):
|
||||
img_cfg["provider"] = "fal"
|
||||
img_cfg["use_gateway"] = False
|
||||
return
|
||||
|
||||
for var in env_vars:
|
||||
|
|
@ -1567,9 +1612,19 @@ def _reconfigure_provider(provider: dict, config: dict):
|
|||
_print_info(" Kept current")
|
||||
|
||||
# Imagegen backends prompt for model selection on reconfig too.
|
||||
plugin_name = provider.get("image_gen_plugin_name")
|
||||
if plugin_name:
|
||||
_select_plugin_image_gen_provider(plugin_name, config)
|
||||
return
|
||||
|
||||
backend = provider.get("imagegen_backend")
|
||||
if backend:
|
||||
_configure_imagegen_model(backend, config)
|
||||
if backend == "fal":
|
||||
img_cfg = config.setdefault("image_gen", {})
|
||||
if isinstance(img_cfg, dict):
|
||||
img_cfg["provider"] = "fal"
|
||||
img_cfg["use_gateway"] = False
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue