mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-17 09:41:58 +00:00
When a desktop/dashboard session had no agent built yet and the user explicitly
picked a provider in the model picker, config.set('model', ...) would first try
to initialize the agent from the (possibly broken) config default provider —
failing before the user's explicit switch could take effect, trapping them on a
misconfigured default.
config.set now pre-parses the model flags: if an explicit --provider is present
and no agent exists yet, it skips the default-provider agent build and routes
straight through _apply_model_switch with the explicit provider. _apply_model_switch
gained a parsed_flags passthrough (avoids double-parsing) and only falls back to
resolve_runtime_provider(requested=None) when no explicit provider was given.
The desktop hook now sends config.set instead of slash.exec for active-session
model changes, so errors from the selected provider surface to the user instead
of being swallowed.
Co-authored-by: rodboev <rod.boev@gmail.com>
This commit is contained in:
parent
2a08b8c86f
commit
ed20f5ed06
4 changed files with 212 additions and 20 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import { renderHook } from '@testing-library/react'
|
||||
import { QueryClient } from '@tanstack/react-query'
|
||||
import { cleanup, render, renderHook } from '@testing-library/react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getGlobalModelInfo } from '@/hermes'
|
||||
|
|
@ -13,12 +13,51 @@ import {
|
|||
|
||||
import { useModelControls } from './use-model-controls'
|
||||
|
||||
const setGlobalModel = vi.fn()
|
||||
const notifyError = vi.fn()
|
||||
|
||||
vi.mock('@/hermes', () => ({
|
||||
getGlobalModelInfo: vi.fn(),
|
||||
setGlobalModel: vi.fn()
|
||||
setGlobalModel: (...args: Parameters<typeof setGlobalModel>) => setGlobalModel(...args)
|
||||
}))
|
||||
|
||||
describe('useModelControls.refreshCurrentModel', () => {
|
||||
vi.mock('@/i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: {
|
||||
desktop: {
|
||||
modelSwitchFailed: 'Model switch failed'
|
||||
}
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/store/notifications', () => ({
|
||||
notifyError: (...args: Parameters<typeof notifyError>) => notifyError(...args)
|
||||
}))
|
||||
|
||||
type Controls = ReturnType<typeof useModelControls>
|
||||
|
||||
function Harness({
|
||||
activeSessionId,
|
||||
onReady,
|
||||
requestGateway
|
||||
}: {
|
||||
activeSessionId: string | null
|
||||
onReady: (controls: Controls) => void
|
||||
requestGateway: <T = unknown>(method: string, params?: Record<string, unknown>) => Promise<T>
|
||||
}) {
|
||||
const controls = useModelControls({
|
||||
activeSessionId,
|
||||
queryClient: new QueryClient(),
|
||||
requestGateway
|
||||
})
|
||||
|
||||
onReady(controls)
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
describe('useModelControls', () => {
|
||||
beforeEach(() => {
|
||||
$activeSessionId.set(null)
|
||||
setCurrentModel('')
|
||||
|
|
@ -26,6 +65,7 @@ describe('useModelControls.refreshCurrentModel', () => {
|
|||
})
|
||||
|
||||
afterEach(() => {
|
||||
cleanup()
|
||||
vi.restoreAllMocks()
|
||||
$activeSessionId.set(null)
|
||||
setCurrentModel('')
|
||||
|
|
@ -74,4 +114,55 @@ describe('useModelControls.refreshCurrentModel', () => {
|
|||
expect($currentModel.get()).toBe('deepseek/deepseek-v4-pro')
|
||||
expect($currentProvider.get()).toBe('deepseek')
|
||||
})
|
||||
|
||||
it('routes active-session picker changes through config.set with an explicit provider', async () => {
|
||||
const requestGateway = vi.fn(async () => ({ key: 'model', value: 'claude-sonnet-4.6' }) as never)
|
||||
let controls!: Controls
|
||||
|
||||
render(
|
||||
<Harness
|
||||
activeSessionId="session-1"
|
||||
onReady={value => (controls = value)}
|
||||
requestGateway={requestGateway}
|
||||
/>
|
||||
)
|
||||
|
||||
await expect(
|
||||
controls.selectModel({
|
||||
model: 'claude-sonnet-4.6',
|
||||
persistGlobal: false,
|
||||
provider: 'anthropic'
|
||||
})
|
||||
).resolves.toBe(true)
|
||||
|
||||
expect(requestGateway).toHaveBeenCalledWith('config.set', {
|
||||
session_id: 'session-1',
|
||||
key: 'model',
|
||||
value: 'claude-sonnet-4.6 --provider anthropic'
|
||||
})
|
||||
expect(requestGateway).not.toHaveBeenCalledWith('slash.exec', expect.anything())
|
||||
})
|
||||
|
||||
it('keeps the global path on setGlobalModel when there is no active session', async () => {
|
||||
setGlobalModel.mockResolvedValue(undefined)
|
||||
let controls!: Controls
|
||||
|
||||
render(
|
||||
<Harness
|
||||
activeSessionId={null}
|
||||
onReady={value => (controls = value)}
|
||||
requestGateway={vi.fn()}
|
||||
/>
|
||||
)
|
||||
|
||||
await expect(
|
||||
controls.selectModel({
|
||||
model: 'claude-sonnet-4.6',
|
||||
persistGlobal: false,
|
||||
provider: 'anthropic'
|
||||
})
|
||||
).resolves.toBe(true)
|
||||
|
||||
expect(setGlobalModel).toHaveBeenCalledWith('anthropic', 'claude-sonnet-4.6')
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -82,9 +82,10 @@ export function useModelControls({ activeSessionId, queryClient, requestGateway
|
|||
|
||||
try {
|
||||
if (activeSessionId) {
|
||||
await requestGateway('slash.exec', {
|
||||
await requestGateway('config.set', {
|
||||
session_id: activeSessionId,
|
||||
command: `/model ${selection.model} --provider ${selection.provider}${selection.persistGlobal ? ' --global' : ''}`
|
||||
key: 'model',
|
||||
value: `${selection.model} --provider ${selection.provider}${selection.persistGlobal ? ' --global' : ''}`
|
||||
})
|
||||
|
||||
if (selection.persistGlobal) {
|
||||
|
|
|
|||
|
|
@ -3036,6 +3036,94 @@ def test_config_set_model_global_persists(monkeypatch):
|
|||
assert saved["model"]["base_url"] == "https://api.anthropic.com"
|
||||
|
||||
|
||||
def test_config_set_model_explicit_provider_skips_broken_default_init(monkeypatch):
|
||||
seen = {"build": 0, "wait": 0, "requested": []}
|
||||
session = _session()
|
||||
session["agent"] = None
|
||||
server._sessions["sid"] = session
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"default": "broken/model", "provider": "openrouter"}})
|
||||
monkeypatch.setattr(server, "_start_agent_build", lambda *_args: seen.__setitem__("build", seen["build"] + 1))
|
||||
monkeypatch.setattr(server, "_wait_agent", lambda *_args: seen.__setitem__("wait", seen["wait"] + 1))
|
||||
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(server, "_restart_slash_worker", lambda *args, **kwargs: None)
|
||||
|
||||
def fake_runtime_provider(*, requested=None, target_model=None, **_kwargs):
|
||||
seen["requested"].append((requested, target_model))
|
||||
if requested is None:
|
||||
raise RuntimeError("broken default provider should not be initialized")
|
||||
if requested == "anthropic":
|
||||
return {
|
||||
"api_key": "sk-anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
}
|
||||
raise RuntimeError(f"unexpected provider {requested}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", fake_runtime_provider)
|
||||
|
||||
try:
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "config.set",
|
||||
"params": {
|
||||
"session_id": "sid",
|
||||
"key": "model",
|
||||
"value": "claude-sonnet-4.6 --provider anthropic",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert resp["result"]["value"] == "claude-sonnet-4-6"
|
||||
assert seen["build"] == 0
|
||||
assert seen["wait"] == 0
|
||||
assert seen["requested"] == [("anthropic", "claude-sonnet-4.6")]
|
||||
assert session["model_override"]["provider"] == "anthropic"
|
||||
assert session["model_override"]["model"] == "claude-sonnet-4-6"
|
||||
finally:
|
||||
server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_config_set_model_explicit_provider_surfaces_selected_provider_errors(monkeypatch):
|
||||
seen = {"build": 0, "wait": 0}
|
||||
session = _session()
|
||||
session["agent"] = None
|
||||
server._sessions["sid"] = session
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"default": "broken/model", "provider": "openrouter"}})
|
||||
monkeypatch.setattr(server, "_start_agent_build", lambda *_args: seen.__setitem__("build", seen["build"] + 1))
|
||||
monkeypatch.setattr(server, "_wait_agent", lambda *_args: seen.__setitem__("wait", seen["wait"] + 1))
|
||||
|
||||
def fake_runtime_provider(*, requested=None, **_kwargs):
|
||||
if requested is None:
|
||||
raise RuntimeError("broken default provider should not be initialized")
|
||||
if requested == "anthropic":
|
||||
raise RuntimeError("missing anthropic API key")
|
||||
raise RuntimeError(f"unexpected provider {requested}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", fake_runtime_provider)
|
||||
|
||||
try:
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "config.set",
|
||||
"params": {
|
||||
"session_id": "sid",
|
||||
"key": "model",
|
||||
"value": "claude-sonnet-4.6 --provider anthropic",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert resp["error"]["code"] == 5001
|
||||
assert "anthropic" in resp["error"]["message"].lower()
|
||||
assert "missing anthropic api key" in resp["error"]["message"].lower()
|
||||
assert seen["build"] == 0
|
||||
assert seen["wait"] == 0
|
||||
finally:
|
||||
server._sessions.pop("sid", None)
|
||||
|
||||
|
||||
def test_config_set_model_does_not_leak_inference_provider_env(monkeypatch):
|
||||
"""A /model switch must NOT mutate process-global env vars. The desktop /
|
||||
dashboard tui_gateway backend hosts every same-profile session in one
|
||||
|
|
|
|||
|
|
@ -1961,11 +1961,14 @@ def _apply_model_switch(
|
|||
*,
|
||||
confirm_expensive_model: bool = False,
|
||||
pin_session_override: bool = True,
|
||||
parsed_flags: tuple[str, str, bool, bool] | None = None,
|
||||
) -> dict:
|
||||
from hermes_cli.model_switch import parse_model_flags, switch_model
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
model_input, explicit_provider, persist_global, _force_refresh = parse_model_flags(raw_input)
|
||||
if parsed_flags is None:
|
||||
parsed_flags = parse_model_flags(raw_input)
|
||||
model_input, explicit_provider, persist_global, _force_refresh = parsed_flags
|
||||
if not model_input:
|
||||
raise ValueError("model value required")
|
||||
|
||||
|
|
@ -1976,20 +1979,24 @@ def _apply_model_switch(
|
|||
current_base_url = getattr(agent, "base_url", "") or ""
|
||||
current_api_key = getattr(agent, "api_key", "") or ""
|
||||
else:
|
||||
runtime = resolve_runtime_provider(requested=None)
|
||||
current_provider = str(runtime.get("provider", "") or "")
|
||||
current_model = _resolve_model()
|
||||
current_base_url = str(runtime.get("base_url", "") or "")
|
||||
# Preserve a callable api_key (Azure Foundry Entra ID bearer
|
||||
# provider) unchanged — ``str(...)`` would produce
|
||||
# ``"<function ...>"`` and poison downstream switch_model
|
||||
# validation. Match the agent-present branch's behavior at the
|
||||
# top of this block.
|
||||
_runtime_key = runtime.get("api_key", "")
|
||||
if callable(_runtime_key) and not isinstance(_runtime_key, str):
|
||||
current_api_key = _runtime_key
|
||||
else:
|
||||
current_api_key = str(_runtime_key or "")
|
||||
current_provider = explicit_provider.strip()
|
||||
current_base_url = ""
|
||||
current_api_key = ""
|
||||
if not explicit_provider:
|
||||
runtime = resolve_runtime_provider(requested=None)
|
||||
current_provider = str(runtime.get("provider", "") or "")
|
||||
current_base_url = str(runtime.get("base_url", "") or "")
|
||||
# Preserve a callable api_key (Azure Foundry Entra ID bearer
|
||||
# provider) unchanged — ``str(...)`` would produce
|
||||
# ``"<function ...>"`` and poison downstream switch_model
|
||||
# validation. Match the agent-present branch's behavior at the
|
||||
# top of this block.
|
||||
_runtime_key = runtime.get("api_key", "")
|
||||
if callable(_runtime_key) and not isinstance(_runtime_key, str):
|
||||
current_api_key = _runtime_key
|
||||
else:
|
||||
current_api_key = str(_runtime_key or "")
|
||||
|
||||
# Load user-defined providers so switch_model can resolve named custom
|
||||
# endpoints (e.g. "ollama-launch") and validate against saved model lists.
|
||||
|
|
@ -6996,7 +7003,11 @@ def _(rid, params: dict) -> dict:
|
|||
4009,
|
||||
"session busy — /interrupt the current turn before switching models",
|
||||
)
|
||||
if session.get("agent") is None:
|
||||
from hermes_cli.model_switch import parse_model_flags
|
||||
|
||||
parsed_flags = parse_model_flags(value)
|
||||
_model_input, explicit_provider, _persist_global, _force_refresh = parsed_flags
|
||||
if session.get("agent") is None and not explicit_provider.strip():
|
||||
session_id = params.get("session_id", "")
|
||||
_start_agent_build(session_id, session)
|
||||
init_err = _wait_agent(session, rid)
|
||||
|
|
@ -7011,6 +7022,7 @@ def _(rid, params: dict) -> dict:
|
|||
confirm_expensive_model=bool(
|
||||
params.get("confirm_expensive_model", False)
|
||||
),
|
||||
parsed_flags=parsed_flags,
|
||||
)
|
||||
else:
|
||||
result = _apply_model_switch(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue