diff --git a/apps/desktop/src/app/session/hooks/use-model-controls.test.tsx b/apps/desktop/src/app/session/hooks/use-model-controls.test.tsx index 8f52018982a..612290800e0 100644 --- a/apps/desktop/src/app/session/hooks/use-model-controls.test.tsx +++ b/apps/desktop/src/app/session/hooks/use-model-controls.test.tsx @@ -1,5 +1,5 @@ -import { renderHook } from '@testing-library/react' import { QueryClient } from '@tanstack/react-query' +import { cleanup, render, renderHook } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { getGlobalModelInfo } from '@/hermes' @@ -13,12 +13,51 @@ import { import { useModelControls } from './use-model-controls' +const setGlobalModel = vi.fn() +const notifyError = vi.fn() + vi.mock('@/hermes', () => ({ getGlobalModelInfo: vi.fn(), - setGlobalModel: vi.fn() + setGlobalModel: (...args: Parameters) => setGlobalModel(...args) })) -describe('useModelControls.refreshCurrentModel', () => { +vi.mock('@/i18n', () => ({ + useI18n: () => ({ + t: { + desktop: { + modelSwitchFailed: 'Model switch failed' + } + } + }) +})) + +vi.mock('@/store/notifications', () => ({ + notifyError: (...args: Parameters) => notifyError(...args) +})) + +type Controls = ReturnType + +function Harness({ + activeSessionId, + onReady, + requestGateway +}: { + activeSessionId: string | null + onReady: (controls: Controls) => void + requestGateway: (method: string, params?: Record) => Promise +}) { + const controls = useModelControls({ + activeSessionId, + queryClient: new QueryClient(), + requestGateway + }) + + onReady(controls) + + return null +} + +describe('useModelControls', () => { beforeEach(() => { $activeSessionId.set(null) setCurrentModel('') @@ -26,6 +65,7 @@ describe('useModelControls.refreshCurrentModel', () => { }) afterEach(() => { + cleanup() vi.restoreAllMocks() $activeSessionId.set(null) setCurrentModel('') @@ -74,4 +114,55 @@ describe('useModelControls.refreshCurrentModel', () => { expect($currentModel.get()).toBe('deepseek/deepseek-v4-pro') expect($currentProvider.get()).toBe('deepseek') }) + + it('routes active-session picker changes through config.set with an explicit provider', async () => { + const requestGateway = vi.fn(async () => ({ key: 'model', value: 'claude-sonnet-4.6' }) as never) + let controls!: Controls + + render( + (controls = value)} + requestGateway={requestGateway} + /> + ) + + await expect( + controls.selectModel({ + model: 'claude-sonnet-4.6', + persistGlobal: false, + provider: 'anthropic' + }) + ).resolves.toBe(true) + + expect(requestGateway).toHaveBeenCalledWith('config.set', { + session_id: 'session-1', + key: 'model', + value: 'claude-sonnet-4.6 --provider anthropic' + }) + expect(requestGateway).not.toHaveBeenCalledWith('slash.exec', expect.anything()) + }) + + it('keeps the global path on setGlobalModel when there is no active session', async () => { + setGlobalModel.mockResolvedValue(undefined) + let controls!: Controls + + render( + (controls = value)} + requestGateway={vi.fn()} + /> + ) + + await expect( + controls.selectModel({ + model: 'claude-sonnet-4.6', + persistGlobal: false, + provider: 'anthropic' + }) + ).resolves.toBe(true) + + expect(setGlobalModel).toHaveBeenCalledWith('anthropic', 'claude-sonnet-4.6') + }) }) diff --git a/apps/desktop/src/app/session/hooks/use-model-controls.ts b/apps/desktop/src/app/session/hooks/use-model-controls.ts index 525c8d8385b..681eac871a2 100644 --- a/apps/desktop/src/app/session/hooks/use-model-controls.ts +++ b/apps/desktop/src/app/session/hooks/use-model-controls.ts @@ -82,9 +82,10 @@ export function useModelControls({ activeSessionId, queryClient, requestGateway try { if (activeSessionId) { - await requestGateway('slash.exec', { + await requestGateway('config.set', { session_id: activeSessionId, - command: `/model ${selection.model} --provider ${selection.provider}${selection.persistGlobal ? ' --global' : ''}` + key: 'model', + value: `${selection.model} --provider ${selection.provider}${selection.persistGlobal ? ' --global' : ''}` }) if (selection.persistGlobal) { diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 90a7f200255..da85cc26ad6 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -3036,6 +3036,94 @@ def test_config_set_model_global_persists(monkeypatch): assert saved["model"]["base_url"] == "https://api.anthropic.com" +def test_config_set_model_explicit_provider_skips_broken_default_init(monkeypatch): + seen = {"build": 0, "wait": 0, "requested": []} + session = _session() + session["agent"] = None + server._sessions["sid"] = session + monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"default": "broken/model", "provider": "openrouter"}}) + monkeypatch.setattr(server, "_start_agent_build", lambda *_args: seen.__setitem__("build", seen["build"] + 1)) + monkeypatch.setattr(server, "_wait_agent", lambda *_args: seen.__setitem__("wait", seen["wait"] + 1)) + monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None) + monkeypatch.setattr(server, "_restart_slash_worker", lambda *args, **kwargs: None) + + def fake_runtime_provider(*, requested=None, target_model=None, **_kwargs): + seen["requested"].append((requested, target_model)) + if requested is None: + raise RuntimeError("broken default provider should not be initialized") + if requested == "anthropic": + return { + "api_key": "sk-anthropic", + "api_mode": "anthropic_messages", + "base_url": "https://api.anthropic.com", + } + raise RuntimeError(f"unexpected provider {requested}") + + monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", fake_runtime_provider) + + try: + resp = server.handle_request( + { + "id": "1", + "method": "config.set", + "params": { + "session_id": "sid", + "key": "model", + "value": "claude-sonnet-4.6 --provider anthropic", + }, + } + ) + + assert resp["result"]["value"] == "claude-sonnet-4-6" + assert seen["build"] == 0 + assert seen["wait"] == 0 + assert seen["requested"] == [("anthropic", "claude-sonnet-4.6")] + assert session["model_override"]["provider"] == "anthropic" + assert session["model_override"]["model"] == "claude-sonnet-4-6" + finally: + server._sessions.pop("sid", None) + + +def test_config_set_model_explicit_provider_surfaces_selected_provider_errors(monkeypatch): + seen = {"build": 0, "wait": 0} + session = _session() + session["agent"] = None + server._sessions["sid"] = session + monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"default": "broken/model", "provider": "openrouter"}}) + monkeypatch.setattr(server, "_start_agent_build", lambda *_args: seen.__setitem__("build", seen["build"] + 1)) + monkeypatch.setattr(server, "_wait_agent", lambda *_args: seen.__setitem__("wait", seen["wait"] + 1)) + + def fake_runtime_provider(*, requested=None, **_kwargs): + if requested is None: + raise RuntimeError("broken default provider should not be initialized") + if requested == "anthropic": + raise RuntimeError("missing anthropic API key") + raise RuntimeError(f"unexpected provider {requested}") + + monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", fake_runtime_provider) + + try: + resp = server.handle_request( + { + "id": "1", + "method": "config.set", + "params": { + "session_id": "sid", + "key": "model", + "value": "claude-sonnet-4.6 --provider anthropic", + }, + } + ) + + assert resp["error"]["code"] == 5001 + assert "anthropic" in resp["error"]["message"].lower() + assert "missing anthropic api key" in resp["error"]["message"].lower() + assert seen["build"] == 0 + assert seen["wait"] == 0 + finally: + server._sessions.pop("sid", None) + + def test_config_set_model_does_not_leak_inference_provider_env(monkeypatch): """A /model switch must NOT mutate process-global env vars. The desktop / dashboard tui_gateway backend hosts every same-profile session in one diff --git a/tui_gateway/server.py b/tui_gateway/server.py index d34f558f6cf..715ca8b48b6 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1961,11 +1961,14 @@ def _apply_model_switch( *, confirm_expensive_model: bool = False, pin_session_override: bool = True, + parsed_flags: tuple[str, str, bool, bool] | None = None, ) -> dict: from hermes_cli.model_switch import parse_model_flags, switch_model from hermes_cli.runtime_provider import resolve_runtime_provider - model_input, explicit_provider, persist_global, _force_refresh = parse_model_flags(raw_input) + if parsed_flags is None: + parsed_flags = parse_model_flags(raw_input) + model_input, explicit_provider, persist_global, _force_refresh = parsed_flags if not model_input: raise ValueError("model value required") @@ -1976,20 +1979,24 @@ def _apply_model_switch( current_base_url = getattr(agent, "base_url", "") or "" current_api_key = getattr(agent, "api_key", "") or "" else: - runtime = resolve_runtime_provider(requested=None) - current_provider = str(runtime.get("provider", "") or "") current_model = _resolve_model() - current_base_url = str(runtime.get("base_url", "") or "") - # Preserve a callable api_key (Azure Foundry Entra ID bearer - # provider) unchanged — ``str(...)`` would produce - # ``""`` and poison downstream switch_model - # validation. Match the agent-present branch's behavior at the - # top of this block. - _runtime_key = runtime.get("api_key", "") - if callable(_runtime_key) and not isinstance(_runtime_key, str): - current_api_key = _runtime_key - else: - current_api_key = str(_runtime_key or "") + current_provider = explicit_provider.strip() + current_base_url = "" + current_api_key = "" + if not explicit_provider: + runtime = resolve_runtime_provider(requested=None) + current_provider = str(runtime.get("provider", "") or "") + current_base_url = str(runtime.get("base_url", "") or "") + # Preserve a callable api_key (Azure Foundry Entra ID bearer + # provider) unchanged — ``str(...)`` would produce + # ``""`` and poison downstream switch_model + # validation. Match the agent-present branch's behavior at the + # top of this block. + _runtime_key = runtime.get("api_key", "") + if callable(_runtime_key) and not isinstance(_runtime_key, str): + current_api_key = _runtime_key + else: + current_api_key = str(_runtime_key or "") # Load user-defined providers so switch_model can resolve named custom # endpoints (e.g. "ollama-launch") and validate against saved model lists. @@ -6996,7 +7003,11 @@ def _(rid, params: dict) -> dict: 4009, "session busy — /interrupt the current turn before switching models", ) - if session.get("agent") is None: + from hermes_cli.model_switch import parse_model_flags + + parsed_flags = parse_model_flags(value) + _model_input, explicit_provider, _persist_global, _force_refresh = parsed_flags + if session.get("agent") is None and not explicit_provider.strip(): session_id = params.get("session_id", "") _start_agent_build(session_id, session) init_err = _wait_agent(session, rid) @@ -7011,6 +7022,7 @@ def _(rid, params: dict) -> dict: confirm_expensive_model=bool( params.get("confirm_expensive_model", False) ), + parsed_flags=parsed_flags, ) else: result = _apply_model_switch(