mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
fix(model): match custom provider by active base url
This commit is contained in:
parent
60bb98e003
commit
d9829ab45f
2 changed files with 220 additions and 54 deletions
|
|
@ -1832,52 +1832,10 @@ def select_provider_and_model(args=None):
|
|||
config_provider or os.getenv("HERMES_INFERENCE_PROVIDER") or "auto"
|
||||
)
|
||||
compatible_custom_providers = get_compatible_custom_providers(config)
|
||||
active = None
|
||||
if effective_provider != "auto":
|
||||
active_def = resolve_provider_full(
|
||||
effective_provider,
|
||||
config.get("providers"),
|
||||
compatible_custom_providers,
|
||||
)
|
||||
if active_def is not None:
|
||||
active = active_def.id
|
||||
else:
|
||||
warning = (
|
||||
f"Unknown provider '{effective_provider}'. Check 'hermes model' for "
|
||||
"available providers, or run 'hermes doctor' to diagnose config "
|
||||
"issues."
|
||||
)
|
||||
print(f"Warning: {warning} Falling back to auto provider detection.")
|
||||
if active is None:
|
||||
try:
|
||||
active = resolve_provider("auto")
|
||||
except AuthError as exc:
|
||||
if effective_provider == "auto":
|
||||
warning = format_auth_error(exc)
|
||||
print(f"Warning: {warning} Falling back to auto provider detection.")
|
||||
active = None # no provider yet; default to first in list
|
||||
|
||||
# Detect custom endpoint
|
||||
if active == "openrouter" and get_env_value("OPENAI_BASE_URL"):
|
||||
active = "custom"
|
||||
|
||||
from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS
|
||||
|
||||
provider_labels = dict(_PROVIDER_LABELS) # derive from canonical list
|
||||
active_label = provider_labels.get(active, active) if active else "none"
|
||||
|
||||
print()
|
||||
print(f" Current model: {current_model}")
|
||||
print(f" Active provider: {active_label}")
|
||||
print()
|
||||
|
||||
# Step 1: Provider selection — flat list from CANONICAL_PROVIDERS
|
||||
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
|
||||
|
||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
# Build a lookup of raw (un-expanded) api_key templates keyed by a
|
||||
# Build lookups of raw (un-expanded) templates keyed by a
|
||||
# stable identity. We intentionally bypass
|
||||
# ``get_compatible_custom_providers(read_raw_config())`` here because
|
||||
# its ``_normalize_custom_provider_entry`` step calls ``urlparse()``
|
||||
|
|
@ -1886,6 +1844,7 @@ def select_provider_and_model(args=None):
|
|||
# entries is exactly how env-ref preservation fails for the user
|
||||
# config that motivated this fix.
|
||||
raw_api_key_refs: dict[tuple, str] = {}
|
||||
raw_base_url_refs: dict[tuple, str] = {}
|
||||
raw_cfg = read_raw_config()
|
||||
|
||||
def _record_raw(
|
||||
|
|
@ -1893,10 +1852,10 @@ def select_provider_and_model(args=None):
|
|||
provider_key: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
) -> None:
|
||||
template = str(api_key or "").strip()
|
||||
if "${" not in template:
|
||||
return
|
||||
base_template = str(base_url or "").strip()
|
||||
name = str(name or "").strip()
|
||||
provider_key = str(provider_key or "").strip()
|
||||
model = str(model or "").strip()
|
||||
|
|
@ -1904,12 +1863,19 @@ def select_provider_and_model(args=None):
|
|||
# might present: (name), (name, model), (provider_key), and
|
||||
# (provider_key, model). Case-insensitive on name/provider_key so
|
||||
# the loaded entry matches regardless of display casing.
|
||||
identities = []
|
||||
if name:
|
||||
raw_api_key_refs.setdefault((name.lower(),), template)
|
||||
raw_api_key_refs.setdefault((name.lower(), model), template)
|
||||
identities.extend(((name.lower(),), (name.lower(), model)))
|
||||
if provider_key:
|
||||
raw_api_key_refs.setdefault((provider_key.lower(),), template)
|
||||
raw_api_key_refs.setdefault((provider_key.lower(), model), template)
|
||||
identities.extend(
|
||||
((provider_key.lower(),), (provider_key.lower(), model))
|
||||
)
|
||||
if "${" in template:
|
||||
for identity in identities:
|
||||
raw_api_key_refs.setdefault(identity, template)
|
||||
if "${" in base_template:
|
||||
for identity in identities:
|
||||
raw_base_url_refs.setdefault(identity, base_template)
|
||||
|
||||
raw_list = raw_cfg.get("custom_providers")
|
||||
if isinstance(raw_list, list):
|
||||
|
|
@ -1921,6 +1887,9 @@ def select_provider_and_model(args=None):
|
|||
"",
|
||||
raw_entry.get("model", "") or raw_entry.get("default_model", ""),
|
||||
raw_entry.get("api_key", ""),
|
||||
raw_entry.get("base_url", "")
|
||||
or raw_entry.get("url", "")
|
||||
or raw_entry.get("api", ""),
|
||||
)
|
||||
raw_providers = raw_cfg.get("providers")
|
||||
if isinstance(raw_providers, dict):
|
||||
|
|
@ -1932,9 +1901,17 @@ def select_provider_and_model(args=None):
|
|||
raw_key,
|
||||
raw_entry.get("model", "") or raw_entry.get("default_model", ""),
|
||||
raw_entry.get("api_key", ""),
|
||||
raw_entry.get("base_url", "")
|
||||
or raw_entry.get("url", "")
|
||||
or raw_entry.get("api", ""),
|
||||
)
|
||||
|
||||
def _lookup_ref(name: str, provider_key: str, model: str) -> str:
|
||||
def _lookup_ref(
|
||||
refs: dict[tuple, str],
|
||||
name: str,
|
||||
provider_key: str,
|
||||
model: str,
|
||||
) -> str:
|
||||
name_lc = str(name or "").strip().lower()
|
||||
pkey_lc = str(provider_key or "").strip().lower()
|
||||
model = str(model or "").strip()
|
||||
|
|
@ -1944,8 +1921,8 @@ def select_provider_and_model(args=None):
|
|||
(name_lc, model),
|
||||
(name_lc,),
|
||||
):
|
||||
if identity[0] and identity in raw_api_key_refs:
|
||||
return raw_api_key_refs[identity]
|
||||
if identity[0] and identity in refs:
|
||||
return refs[identity]
|
||||
return ""
|
||||
|
||||
custom_provider_map = {}
|
||||
|
|
@ -1971,14 +1948,81 @@ def select_provider_and_model(args=None):
|
|||
"model": entry.get("model", ""),
|
||||
"api_mode": entry.get("api_mode", ""),
|
||||
"provider_key": provider_key,
|
||||
"api_key_ref": _lookup_ref(name, provider_key, entry.get("model", "")),
|
||||
"api_key_ref": _lookup_ref(
|
||||
raw_api_key_refs, name, provider_key, entry.get("model", "")
|
||||
),
|
||||
"base_url_ref": _lookup_ref(
|
||||
raw_base_url_refs, name, provider_key, entry.get("model", "")
|
||||
),
|
||||
}
|
||||
return custom_provider_map
|
||||
|
||||
def _norm_base_url(url: str) -> str:
|
||||
return str(url or "").strip().rstrip("/").lower()
|
||||
|
||||
# Add user-defined custom providers from config.yaml
|
||||
_custom_provider_map = _named_custom_provider_map(
|
||||
config
|
||||
) # key → {name, base_url, api_key}
|
||||
|
||||
def _active_custom_key_from_base_url() -> str:
|
||||
if effective_provider != "custom" or not isinstance(model_cfg, dict):
|
||||
return ""
|
||||
current_base = _norm_base_url(model_cfg.get("base_url", ""))
|
||||
if not current_base:
|
||||
return ""
|
||||
for key, provider_info in _custom_provider_map.items():
|
||||
if _norm_base_url(provider_info.get("base_url", "")) == current_base:
|
||||
return key
|
||||
return ""
|
||||
|
||||
active = _active_custom_key_from_base_url()
|
||||
if active is None:
|
||||
active = ""
|
||||
if not active and effective_provider != "auto":
|
||||
active_def = resolve_provider_full(
|
||||
effective_provider,
|
||||
config.get("providers"),
|
||||
compatible_custom_providers,
|
||||
)
|
||||
if active_def is not None:
|
||||
active = active_def.id
|
||||
else:
|
||||
warning = (
|
||||
f"Unknown provider '{effective_provider}'. Check 'hermes model' for "
|
||||
"available providers, or run 'hermes doctor' to diagnose config "
|
||||
"issues."
|
||||
)
|
||||
print(f"Warning: {warning} Falling back to auto provider detection.")
|
||||
if not active:
|
||||
try:
|
||||
active = resolve_provider("auto")
|
||||
except AuthError as exc:
|
||||
if effective_provider == "auto":
|
||||
warning = format_auth_error(exc)
|
||||
print(f"Warning: {warning} Falling back to auto provider detection.")
|
||||
active = None # no provider yet; default to first in list
|
||||
|
||||
# Detect custom endpoint
|
||||
if active == "openrouter" and get_env_value("OPENAI_BASE_URL"):
|
||||
active = "custom"
|
||||
|
||||
from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS
|
||||
|
||||
provider_labels = dict(_PROVIDER_LABELS) # derive from canonical list
|
||||
if active and active in _custom_provider_map:
|
||||
active_label = _custom_provider_map[active]["name"]
|
||||
else:
|
||||
active_label = provider_labels.get(active, active) if active else "none"
|
||||
|
||||
print()
|
||||
print(f" Current model: {current_model}")
|
||||
print(f" Active provider: {active_label}")
|
||||
print()
|
||||
|
||||
# Step 1: Provider selection — flat list from CANONICAL_PROVIDERS
|
||||
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
|
||||
|
||||
for key, provider_info in _custom_provider_map.items():
|
||||
name = provider_info["name"]
|
||||
base_url = provider_info["base_url"]
|
||||
|
|
@ -3501,6 +3545,14 @@ def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""):
|
|||
return str(resolved_api_key or "").strip()
|
||||
|
||||
|
||||
def _custom_provider_base_url_config_value(provider_info, resolved_base_url=""):
|
||||
"""Return the value that should be persisted for a custom provider URL."""
|
||||
base_url_ref = str(provider_info.get("base_url_ref", "") or "").strip()
|
||||
if base_url_ref:
|
||||
return base_url_ref
|
||||
return str(resolved_base_url or "").strip()
|
||||
|
||||
|
||||
def _save_custom_provider(
|
||||
base_url, api_key="", model="", context_length=None, name=None, api_mode=None
|
||||
):
|
||||
|
|
@ -4114,7 +4166,9 @@ def _model_flow_named_custom(config, provider_info):
|
|||
model.pop("api_key", None)
|
||||
else:
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
model["base_url"] = _custom_provider_base_url_config_value(
|
||||
provider_info, base_url
|
||||
)
|
||||
if config_api_key:
|
||||
model["api_key"] = config_api_key
|
||||
# Apply api_mode from custom_providers entry, or clear stale value
|
||||
|
|
|
|||
|
|
@ -327,6 +327,118 @@ class TestCustomProviderModelSwitch:
|
|||
assert config["custom_providers"][0]["api_key"] == "${NEURALWATT_API_KEY}"
|
||||
assert "sk-live-neuralwatt-secret" not in saved
|
||||
|
||||
def test_bare_custom_current_provider_matches_env_base_url_before_first_fallback(
|
||||
self, config_home, monkeypatch
|
||||
):
|
||||
"""`hermes model` must mark the custom provider matching model.base_url
|
||||
as current instead of falling back to the first saved custom provider.
|
||||
|
||||
Regression: with ``model.provider: custom`` and multiple
|
||||
``custom_providers`` entries, the CLI resolved bare ``custom`` through
|
||||
``resolve_custom_provider()``, whose compatibility fallback returns the
|
||||
first entry. A config with Cerebras first and NeuralWatt active then
|
||||
showed Cerebras as current.
|
||||
"""
|
||||
from hermes_cli.main import select_provider_and_model
|
||||
|
||||
config_path = config_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"model:\n"
|
||||
" default: kimi-k2.6-fast\n"
|
||||
" provider: custom\n"
|
||||
" base_url: ${NEURALWATT_API_BASE}\n"
|
||||
" api_key: ${NEURALWATT_API_KEY}\n"
|
||||
"providers: {}\n"
|
||||
"custom_providers:\n"
|
||||
"- name: Cerebras.ai\n"
|
||||
" base_url: ${CEREBRAS_API_BASE}\n"
|
||||
" api_key: ${CEREBRAS_API_KEY}\n"
|
||||
" model: qwen-3-235b-a22b-instruct-2507\n"
|
||||
" models: []\n"
|
||||
"- name: NeuralWatt\n"
|
||||
" base_url: ${NEURALWATT_API_BASE}\n"
|
||||
" api_key: ${NEURALWATT_API_KEY}\n"
|
||||
" model: kimi-k2.6-fast\n"
|
||||
" models: []\n"
|
||||
)
|
||||
monkeypatch.setenv("CEREBRAS_API_BASE", "https://api.cerebras.ai/v1")
|
||||
monkeypatch.setenv("CEREBRAS_API_KEY", "sk-live-cerebras-secret")
|
||||
monkeypatch.setenv("NEURALWATT_API_BASE", "https://api.neuralwatt.com/v1")
|
||||
monkeypatch.setenv("NEURALWATT_API_KEY", "sk-live-neuralwatt-secret")
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
def _capture_and_cancel(labels, default=0):
|
||||
captured["labels"] = labels
|
||||
captured["default"] = default
|
||||
return len(labels) - 1 # Leave unchanged
|
||||
|
||||
with patch("hermes_cli.main._prompt_provider_choice",
|
||||
side_effect=_capture_and_cancel), \
|
||||
patch("builtins.print"):
|
||||
select_provider_and_model()
|
||||
|
||||
labels = captured["labels"]
|
||||
default_label = labels[captured["default"]]
|
||||
assert "NeuralWatt" in default_label
|
||||
assert "currently active" in default_label
|
||||
assert "Cerebras.ai" not in default_label
|
||||
assert not any(
|
||||
"Cerebras.ai" in label and "currently active" in label
|
||||
for label in labels
|
||||
)
|
||||
|
||||
def test_named_custom_provider_selection_preserves_base_url_env_ref(
|
||||
self, config_home, monkeypatch
|
||||
):
|
||||
"""Selecting an env-backed custom provider should not expand its
|
||||
``base_url`` template into ``model.base_url`` on disk."""
|
||||
import yaml
|
||||
from hermes_cli.main import select_provider_and_model
|
||||
|
||||
config_path = config_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"model:\n"
|
||||
" default: old-model\n"
|
||||
" provider: openrouter\n"
|
||||
"custom_providers:\n"
|
||||
"- name: NeuralWatt\n"
|
||||
" base_url: ${NEURALWATT_API_BASE}\n"
|
||||
" api_key: ${NEURALWATT_API_KEY}\n"
|
||||
" model: qwen3.6-35b-fast\n"
|
||||
" models: []\n"
|
||||
)
|
||||
monkeypatch.setenv("NEURALWATT_API_BASE", "https://api.neuralwatt.com/v1")
|
||||
monkeypatch.setenv("NEURALWATT_API_KEY", "sk-live-neuralwatt-secret")
|
||||
|
||||
def _pick_neuralwatt(labels, default=0):
|
||||
for i, label in enumerate(labels):
|
||||
if "NeuralWatt" in label:
|
||||
return i
|
||||
raise AssertionError(
|
||||
f"NeuralWatt entry missing from provider menu: {labels}"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.main._prompt_provider_choice",
|
||||
side_effect=_pick_neuralwatt), \
|
||||
patch("hermes_cli.models.fetch_api_models",
|
||||
return_value=["qwen3.6-35b-fast"]) as mock_fetch, \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
select_provider_and_model()
|
||||
|
||||
mock_fetch.assert_called_once()
|
||||
probe_args, _ = mock_fetch.call_args
|
||||
assert probe_args[1] == "https://api.neuralwatt.com/v1"
|
||||
|
||||
saved = config_path.read_text()
|
||||
config = yaml.safe_load(saved) or {}
|
||||
assert config["model"]["base_url"] == "${NEURALWATT_API_BASE}"
|
||||
assert config["model"]["api_key"] == "${NEURALWATT_API_KEY}"
|
||||
assert "https://api.neuralwatt.com/v1" not in saved
|
||||
assert "sk-live-neuralwatt-secret" not in saved
|
||||
|
||||
def test_key_env_providers_dict_entry_does_not_add_api_key(
|
||||
self, config_home, monkeypatch
|
||||
):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue