diff --git a/cli.py b/cli.py index 739a1b91e..559224b5e 100644 --- a/cli.py +++ b/cli.py @@ -4130,6 +4130,16 @@ class HermesCLI: # Parse --provider and --global flags model_input, explicit_provider, persist_global = parse_model_flags(raw_args) + user_provs = None + custom_provs = None + try: + from hermes_cli.config import load_config + cfg = load_config() + user_provs = cfg.get("providers") + custom_provs = cfg.get("custom_providers") + except Exception: + pass + # No args at all: show available providers + models if not model_input and not explicit_provider: model_display = self.model or "unknown" @@ -4139,18 +4149,10 @@ class HermesCLI: # Show authenticated providers with top models try: - # Load user providers from config - user_provs = None - try: - from hermes_cli.config import load_config - cfg = load_config() - user_provs = cfg.get("providers") - except Exception: - pass - providers = list_authenticated_providers( current_provider=self.provider or "", user_providers=user_provs, + custom_providers=custom_provs, max_models=6, ) if providers: @@ -4191,6 +4193,8 @@ class HermesCLI: current_api_key=self.api_key or "", is_global=persist_global, explicit_provider=explicit_provider, + user_providers=user_provs, + custom_providers=custom_provs, ) if not result.success: diff --git a/gateway/run.py b/gateway/run.py index 5aa42cf53..9aae8217d 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3546,6 +3546,7 @@ class GatewayRunner: current_base_url = "" current_api_key = "" user_provs = None + custom_provs = None config_path = _hermes_home / "config.yaml" try: if config_path.exists(): @@ -3557,6 +3558,7 @@ class GatewayRunner: current_provider = model_cfg.get("provider", current_provider) current_base_url = model_cfg.get("base_url", "") user_provs = cfg.get("providers") + custom_provs = cfg.get("custom_providers") except Exception: pass @@ -3584,6 +3586,7 @@ class GatewayRunner: providers = list_authenticated_providers( current_provider=current_provider, user_providers=user_provs, + custom_providers=custom_provs, max_models=50, ) except Exception: @@ -3611,6 +3614,8 @@ class GatewayRunner: current_api_key=_cur_api_key, is_global=False, explicit_provider=provider_slug, + user_providers=user_provs, + custom_providers=custom_provs, ) if not result.success: return f"Error: {result.error_message}" @@ -3689,6 +3694,7 @@ class GatewayRunner: providers = list_authenticated_providers( current_provider=current_provider, user_providers=user_provs, + custom_providers=custom_provs, max_models=5, ) for p in providers: @@ -3718,6 +3724,8 @@ class GatewayRunner: current_api_key=current_api_key, is_global=persist_global, explicit_provider=explicit_provider, + user_providers=user_provs, + custom_providers=custom_provs, ) if not result.success: diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index ef35108df..d2cdcc908 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -336,6 +336,7 @@ def resolve_alias( def get_authenticated_provider_slugs( current_provider: str = "", user_providers: dict = None, + custom_providers: list | None = None, ) -> list[str]: """Return slugs of providers that have credentials. @@ -346,6 +347,7 @@ def get_authenticated_provider_slugs( providers = list_authenticated_providers( current_provider=current_provider, user_providers=user_providers, + custom_providers=custom_providers, max_models=0, ) return [p["slug"] for p in providers] @@ -383,6 +385,7 @@ def switch_model( is_global: bool = False, explicit_provider: str = "", user_providers: dict = None, + custom_providers: list | None = None, ) -> ModelSwitchResult: """Core model-switching pipeline shared between CLI and gateway. @@ -416,6 +419,7 @@ def switch_model( is_global: Whether to persist the switch. explicit_provider: From --provider flag (empty = no explicit provider). user_providers: The ``providers:`` dict from config.yaml (for user endpoints). + custom_providers: The ``custom_providers:`` list from config.yaml. Returns: ModelSwitchResult with all information the caller needs. @@ -436,7 +440,11 @@ def switch_model( # ================================================================= if explicit_provider: # Resolve the provider - pdef = resolve_provider_full(explicit_provider, user_providers) + pdef = resolve_provider_full( + explicit_provider, + user_providers, + custom_providers, + ) if pdef is None: _switch_err = ( f"Unknown provider '{explicit_provider}'. " @@ -516,6 +524,7 @@ def switch_model( authed = get_authenticated_provider_slugs( current_provider=current_provider, user_providers=user_providers, + custom_providers=custom_providers, ) fallback_result = _resolve_alias_fallback(raw_input, authed) if fallback_result is not None: @@ -590,6 +599,14 @@ def switch_model( provider_changed = target_provider != current_provider provider_label = get_label(target_provider) + if target_provider.startswith("custom:"): + custom_pdef = resolve_provider_full( + target_provider, + user_providers, + custom_providers, + ) + if custom_pdef is not None: + provider_label = custom_pdef.name # --- Resolve credentials --- api_key = current_api_key @@ -708,6 +725,7 @@ def switch_model( def list_authenticated_providers( current_provider: str = "", user_providers: dict = None, + custom_providers: list | None = None, max_models: int = 8, ) -> List[dict]: """Detect which providers have credentials and list their curated models. @@ -853,6 +871,43 @@ def list_authenticated_providers( "api_url": api_url, }) + # --- 4. Saved custom providers from config --- + if custom_providers and isinstance(custom_providers, list): + for entry in custom_providers: + if not isinstance(entry, dict): + continue + + display_name = (entry.get("name") or "").strip() + api_url = ( + entry.get("base_url", "") + or entry.get("url", "") + or entry.get("api", "") + or "" + ).strip() + if not display_name or not api_url: + continue + + slug = "custom:" + display_name.lower().replace(" ", "-") + if slug in seen_slugs: + continue + + models_list = [] + default_model = (entry.get("model") or "").strip() + if default_model: + models_list.append(default_model) + + results.append({ + "slug": slug, + "name": display_name, + "is_current": slug == current_provider, + "is_user_defined": True, + "models": models_list, + "total_models": len(models_list), + "source": "user-config", + "api_url": api_url, + }) + seen_slugs.add(slug) + # Sort: current provider first, then by model count descending results.sort(key=lambda r: (not r["is_current"], -r["total_models"])) diff --git a/hermes_cli/providers.py b/hermes_cli/providers.py index 18109e6ea..13081fddb 100644 --- a/hermes_cli/providers.py +++ b/hermes_cli/providers.py @@ -452,9 +452,55 @@ def resolve_user_provider(name: str, user_config: Dict[str, Any]) -> Optional[Pr ) +def resolve_custom_provider( + name: str, + custom_providers: Optional[List[Dict[str, Any]]], +) -> Optional[ProviderDef]: + """Resolve a provider from the user's config.yaml ``custom_providers`` list.""" + if not custom_providers or not isinstance(custom_providers, list): + return None + + requested = (name or "").strip().lower() + canonical = normalize_provider(name) + if not requested: + return None + + for entry in custom_providers: + if not isinstance(entry, dict): + continue + + display_name = (entry.get("name") or "").strip() + api_url = ( + entry.get("base_url", "") + or entry.get("url", "") + or entry.get("api", "") + or "" + ).strip() + if not display_name or not api_url: + continue + + slug = "custom:" + display_name.lower().replace(" ", "-") + if requested not in {display_name.lower(), slug, canonical}: + continue + + return ProviderDef( + id=slug, + name=display_name, + transport="openai_chat", + api_key_env_vars=(), + base_url=api_url, + is_aggregator=False, + auth_type="api_key", + source="user-config", + ) + + return None + + def resolve_provider_full( name: str, user_providers: Optional[Dict[str, Any]] = None, + custom_providers: Optional[List[Dict[str, Any]]] = None, ) -> Optional[ProviderDef]: """Full resolution chain: built-in → models.dev → user config. @@ -463,6 +509,7 @@ def resolve_provider_full( Args: name: Provider name or alias. user_providers: The ``providers:`` dict from config.yaml (optional). + custom_providers: The ``custom_providers:`` list from config.yaml (optional). Returns: ProviderDef if found, else None. @@ -485,6 +532,11 @@ def resolve_provider_full( if user_pdef is not None: return user_pdef + # 2b. Saved custom providers from config + custom_pdef = resolve_custom_provider(name, custom_providers) + if custom_pdef is not None: + return custom_pdef + # 3. Try models.dev directly (for providers not in our ALIASES) try: from agent.models_dev import get_provider_info as _mdev_provider diff --git a/tests/gateway/test_model_command_custom_providers.py b/tests/gateway/test_model_command_custom_providers.py new file mode 100644 index 000000000..f64ce85c2 --- /dev/null +++ b/tests/gateway/test_model_command_custom_providers.py @@ -0,0 +1,61 @@ +"""Regression tests for gateway /model support of config.yaml custom_providers.""" + +import yaml +import pytest + +from gateway.config import Platform +from gateway.platforms.base import MessageEvent, MessageType +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +def _make_runner(): + runner = object.__new__(GatewayRunner) + runner.adapters = {} + return runner + + +def _make_event(text="/model"): + return MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"), + ) + + +@pytest.mark.asyncio +async def test_handle_model_command_lists_saved_custom_provider(tmp_path, monkeypatch): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + yaml.safe_dump( + { + "model": { + "default": "gpt-5.4", + "provider": "openai-codex", + "base_url": "https://chatgpt.com/backend-api/codex", + }, + "providers": {}, + "custom_providers": [ + { + "name": "Local (127.0.0.1:4141)", + "base_url": "http://127.0.0.1:4141/v1", + "model": "rotator-openrouter-coding", + } + ], + } + ), + encoding="utf-8", + ) + + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home) + monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {}) + + result = await _make_runner()._handle_model_command(_make_event()) + + assert result is not None + assert "Local (127.0.0.1:4141)" in result + assert "custom:local-(127.0.0.1:4141)" in result + assert "rotator-openrouter-coding" in result diff --git a/tests/hermes_cli/test_model_switch_custom_providers.py b/tests/hermes_cli/test_model_switch_custom_providers.py new file mode 100644 index 000000000..9b81e5641 --- /dev/null +++ b/tests/hermes_cli/test_model_switch_custom_providers.py @@ -0,0 +1,104 @@ +"""Regression tests for /model support of config.yaml custom_providers. + +The terminal `hermes model` flow already exposes `custom_providers`, but the +shared slash-command pipeline (`/model` in CLI/gateway/Telegram) historically +only looked at `providers:`. +""" + +import hermes_cli.providers as providers_mod +from hermes_cli.model_switch import list_authenticated_providers, switch_model +from hermes_cli.providers import resolve_provider_full + + +_MOCK_VALIDATION = { + "accepted": True, + "persist": True, + "recognized": True, + "message": None, +} + + +def test_list_authenticated_providers_includes_custom_providers(monkeypatch): + """No-args /model menus should include saved custom_providers entries.""" + monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {}) + monkeypatch.setattr(providers_mod, "HERMES_OVERLAYS", {}) + + providers = list_authenticated_providers( + current_provider="openai-codex", + user_providers={}, + custom_providers=[ + { + "name": "Local (127.0.0.1:4141)", + "base_url": "http://127.0.0.1:4141/v1", + "model": "rotator-openrouter-coding", + } + ], + max_models=50, + ) + + assert any( + p["slug"] == "custom:local-(127.0.0.1:4141)" + and p["name"] == "Local (127.0.0.1:4141)" + and p["models"] == ["rotator-openrouter-coding"] + and p["api_url"] == "http://127.0.0.1:4141/v1" + for p in providers + ) + + +def test_resolve_provider_full_finds_named_custom_provider(): + """Explicit /model --provider should resolve saved custom_providers entries.""" + resolved = resolve_provider_full( + "custom:local-(127.0.0.1:4141)", + user_providers={}, + custom_providers=[ + { + "name": "Local (127.0.0.1:4141)", + "base_url": "http://127.0.0.1:4141/v1", + } + ], + ) + + assert resolved is not None + assert resolved.id == "custom:local-(127.0.0.1:4141)" + assert resolved.name == "Local (127.0.0.1:4141)" + assert resolved.base_url == "http://127.0.0.1:4141/v1" + assert resolved.source == "user-config" + + +def test_switch_model_accepts_explicit_named_custom_provider(monkeypatch): + """Shared /model switch pipeline should accept --provider for custom_providers.""" + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + lambda requested: { + "api_key": "no-key-required", + "base_url": "http://127.0.0.1:4141/v1", + "api_mode": "chat_completions", + }, + ) + monkeypatch.setattr("hermes_cli.models.validate_requested_model", lambda *a, **k: _MOCK_VALIDATION) + monkeypatch.setattr("hermes_cli.model_switch.get_model_info", lambda *a, **k: None) + monkeypatch.setattr("hermes_cli.model_switch.get_model_capabilities", lambda *a, **k: None) + + result = switch_model( + raw_input="rotator-openrouter-coding", + current_provider="openai-codex", + current_model="gpt-5.4", + current_base_url="https://chatgpt.com/backend-api/codex", + current_api_key="", + explicit_provider="custom:local-(127.0.0.1:4141)", + user_providers={}, + custom_providers=[ + { + "name": "Local (127.0.0.1:4141)", + "base_url": "http://127.0.0.1:4141/v1", + "model": "rotator-openrouter-coding", + } + ], + ) + + assert result.success is True + assert result.target_provider == "custom:local-(127.0.0.1:4141)" + assert result.provider_label == "Local (127.0.0.1:4141)" + assert result.new_model == "rotator-openrouter-coding" + assert result.base_url == "http://127.0.0.1:4141/v1" + assert result.api_key == "no-key-required"