diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 4a7f4785e6..523655d4b9 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1085,9 +1085,7 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict: from hermes_cli.config import get_compatible_custom_providers, load_config cfg = load_config() - user_provs = [ - {"provider": k, **v} for k, v in (cfg.get("providers") or {}).items() - ] + user_provs = cfg.get("providers") custom_provs = get_compatible_custom_providers(cfg) except Exception: pass @@ -4737,6 +4735,7 @@ def _(rid, params: dict) -> dict: def _(rid, params: dict) -> dict: try: from hermes_cli.model_switch import list_authenticated_providers + from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS session = _sessions.get(params.get("session_id", "")) agent = session.get("agent") if session else None @@ -4750,6 +4749,127 @@ def _(rid, params: dict) -> dict: # provider_model_ids() — that bypasses curation and pulls in # non-agentic models (e.g. Nous /models returns ~400 IDs including # TTS, embeddings, rerankers, image/video generators). + user_provs = ( + cfg.get("providers") if isinstance(cfg.get("providers"), dict) else {} + ) + custom_provs = ( + cfg.get("custom_providers") + if isinstance(cfg.get("custom_providers"), list) + else [] + ) + authenticated = list_authenticated_providers( + current_provider=current_provider, + current_base_url=current_base_url, + current_model=current_model, + user_providers=user_provs, + custom_providers=custom_provs, + max_models=50, + ) + + # Mark authenticated providers and build lookup by slug + authed_map: dict = {} + authed_extra: list = [] # user-defined/custom not in CANONICAL_PROVIDERS + canonical_slugs = {e.slug for e in CANONICAL_PROVIDERS} + for p in authenticated: + p["authenticated"] = True + authed_map[p["slug"]] = p + if p["slug"] not in canonical_slugs: + authed_extra.append(p) + + # Build final list in CANONICAL_PROVIDERS order, merging auth data + from hermes_cli.auth import PROVIDER_REGISTRY as _auth_reg + ordered: list = [] + for entry in CANONICAL_PROVIDERS: + if entry.slug in authed_map: + ordered.append(authed_map[entry.slug]) + else: + pconfig = _auth_reg.get(entry.slug) + auth_type = pconfig.auth_type if pconfig else "api_key" + key_env = pconfig.api_key_env_vars[0] if (pconfig and pconfig.api_key_env_vars) else "" + if auth_type == "api_key" and key_env: + warning = f"paste {key_env} to activate" + else: + warning = f"run `hermes model` to configure ({auth_type})" + ordered.append({ + "slug": entry.slug, + "name": _PROVIDER_LABELS.get(entry.slug, entry.label), + "is_current": entry.slug == current_provider, + "is_user_defined": False, + "models": [], + "total_models": 0, + "source": "built-in", + "authenticated": False, + "auth_type": auth_type, + "key_env": key_env, + "warning": warning, + }) + + # Append user-defined/custom providers not in canonical list + ordered.extend(authed_extra) + + return _ok( + rid, + { + "providers": ordered, + "model": current_model, + "provider": current_provider, + }, + ) + except Exception as e: + return _err(rid, 5033, str(e)) + + +@method("model.save_key") +def _(rid, params: dict) -> dict: + """Save an API key for a provider, then return its refreshed model list. + + Params: + slug: provider slug (e.g. "deepseek", "xai") + api_key: the key value to save + + Returns the provider dict with models populated (same shape as + model.options entries) on success. + """ + try: + from hermes_cli.auth import PROVIDER_REGISTRY + from hermes_cli.config import is_managed, save_env_value + from hermes_cli.model_switch import list_authenticated_providers + + slug = (params.get("slug") or "").strip() + api_key = (params.get("api_key") or "").strip() + if not slug or not api_key: + return _err(rid, 4001, "slug and api_key are required") + + if is_managed(): + return _err(rid, 4006, "managed install — credentials are read-only") + + pconfig = PROVIDER_REGISTRY.get(slug) + if not pconfig: + return _err(rid, 4002, f"unknown provider: {slug}") + if pconfig.auth_type != "api_key": + return _err( + rid, 4003, + f"{pconfig.name} uses {pconfig.auth_type} auth — " + f"run `hermes model` to configure" + ) + if not pconfig.api_key_env_vars: + return _err(rid, 4004, f"no env var defined for {pconfig.name}") + + # Save the key to ~/.hermes/.env + env_var = pconfig.api_key_env_vars[0] + save_env_value(env_var, api_key) + # Also set in current process so list_authenticated_providers sees it + import os + os.environ[env_var] = api_key + + # Refresh provider data + cfg = _load_cfg() + session = _sessions.get(params.get("session_id", "")) + agent = session.get("agent") if session else None + current_provider = getattr(agent, "provider", "") or "" + current_model = getattr(agent, "model", "") or _resolve_model() + current_base_url = getattr(agent, "base_url", "") or "" + providers = list_authenticated_providers( current_provider=current_provider, current_base_url=current_base_url, @@ -4764,16 +4884,72 @@ def _(rid, params: dict) -> dict: ), max_models=50, ) - return _ok( - rid, - { - "providers": providers, - "model": current_model, - "provider": current_provider, - }, - ) + + # Find the newly-authenticated provider + provider_data = None + for p in providers: + if p["slug"] == slug: + provider_data = p + break + + if not provider_data: + # Key was saved but provider didn't appear — still return success + provider_data = { + "slug": slug, + "name": pconfig.name, + "is_current": False, + "models": [], + "total_models": 0, + "authenticated": True, + } + + provider_data["authenticated"] = True + return _ok(rid, {"provider": provider_data}) except Exception as e: - return _err(rid, 5033, str(e)) + return _err(rid, 5034, str(e)) + + +@method("model.disconnect") +def _(rid, params: dict) -> dict: + """Remove credentials for a provider. + + Params: + slug: provider slug (e.g. "deepseek", "xai") + + Returns success status and the provider's slug. + """ + try: + from hermes_cli.auth import PROVIDER_REGISTRY, clear_provider_auth + from hermes_cli.config import remove_env_value + + slug = (params.get("slug") or "").strip() + if not slug: + return _err(rid, 4001, "slug is required") + + pconfig = PROVIDER_REGISTRY.get(slug) + cleared_env = False + cleared_auth = False + + # Remove API key env vars from .env and process + if pconfig and pconfig.api_key_env_vars: + for ev in pconfig.api_key_env_vars: + if remove_env_value(ev): + cleared_env = True + + # Clear OAuth / credential pool state + cleared_auth = clear_provider_auth(slug) + + if not cleared_env and not cleared_auth: + return _err(rid, 4005, f"no credentials found for {slug}") + + provider_name = pconfig.name if pconfig else slug + return _ok(rid, { + "slug": slug, + "name": provider_name, + "disconnected": True, + }) + except Exception as e: + return _err(rid, 5035, str(e)) # ── Methods: slash.exec ────────────────────────────────────────────── diff --git a/ui-tui/src/components/modelPicker.tsx b/ui-tui/src/components/modelPicker.tsx index 833496e4ff..45c9bc4cda 100644 --- a/ui-tui/src/components/modelPicker.tsx +++ b/ui-tui/src/components/modelPicker.tsx @@ -8,12 +8,14 @@ import type { ModelOptionProvider, ModelOptionsResponse } from '../gatewayTypes. import { asRpcResult, rpcErrorMessage } from '../lib/rpc.js' import type { Theme } from '../theme.js' -import { OverlayHint, useOverlayKeys, windowItems, windowOffset } from './overlayControls.js' +import { OverlayHint, useOverlayKeys, windowItems } from './overlayControls.js' const VISIBLE = 12 const MIN_WIDTH = 40 const MAX_WIDTH = 90 +type Stage = 'provider' | 'key' | 'model' | 'disconnect' + export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPickerProps) { const [providers, setProviders] = useState([]) const [currentModel, setCurrentModel] = useState('') @@ -22,7 +24,10 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke const [persistGlobal, setPersistGlobal] = useState(false) const [providerIdx, setProviderIdx] = useState(0) const [modelIdx, setModelIdx] = useState(0) - const [stage, setStage] = useState<'model' | 'provider'>('provider') + const [stage, setStage] = useState('provider') + const [keyInput, setKeyInput] = useState('') + const [keySaving, setKeySaving] = useState(false) + const [keyError, setKeyError] = useState('') const { stdout } = useStdout() // Pin the picker to a stable width so the FloatBox parent (which shrinks- @@ -68,9 +73,12 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke const names = useMemo(() => providerDisplayNames(providers), [providers]) const back = () => { - if (stage === 'model') { + if (stage === 'model' || stage === 'key' || stage === 'disconnect') { setStage('provider') setModelIdx(0) + setKeyInput('') + setKeyError('') + setKeySaving(false) return } @@ -81,6 +89,118 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke useOverlayKeys({ onBack: back, onClose: onCancel }) useInput((ch, key) => { + // Key entry stage handles its own input + if (stage === 'key') { + if (keySaving) { + return + } + + if (key.return) { + if (!keyInput.trim()) { + return + } + + setKeySaving(true) + setKeyError('') + gw.request<{ provider?: ModelOptionProvider }>('model.save_key', { + slug: provider?.slug, + api_key: keyInput.trim(), + ...(sessionId ? { session_id: sessionId } : {}), + }) + .then(raw => { + const r = asRpcResult<{ provider?: ModelOptionProvider }>(raw) + + if (!r?.provider) { + setKeyError('failed to save key') + setKeySaving(false) + + return + } + + // Update the provider in our list with fresh data + setProviders(prev => + prev.map(p => p.slug === r.provider!.slug ? r.provider! : p) + ) + setKeyInput('') + setKeySaving(false) + setStage('model') + setModelIdx(0) + }) + .catch((e: unknown) => { + setKeyError(rpcErrorMessage(e)) + setKeySaving(false) + }) + + return + } + + if (key.backspace || key.delete) { + setKeyInput(v => v.slice(0, -1)) + + return + } + + // ctrl+u clears input + if (ch === '\u0015') { + setKeyInput('') + + return + } + + if (ch && !key.ctrl && !key.meta) { + setKeyInput(v => v + ch) + } + + return + } + + // Disconnect confirmation stage + if (stage === 'disconnect') { + if (ch.toLowerCase() === 'y' || key.return) { + if (!provider) { + setStage('provider') + + return + } + + setKeySaving(true) + gw.request<{ disconnected?: boolean }>('model.disconnect', { + slug: provider.slug, + ...(sessionId ? { session_id: sessionId } : {}), + }) + .then(raw => { + const r = asRpcResult<{ disconnected?: boolean }>(raw) + + if (r?.disconnected) { + // Mark provider as unauthenticated in local state + setProviders(prev => + prev.map(p => p.slug === provider.slug + ? { ...p, authenticated: false, models: [], total_models: 0, warning: p.key_env ? `paste ${p.key_env} to activate` : 'run `hermes model` to configure' } + : p + ) + ) + } + + setKeySaving(false) + setStage('provider') + }) + .catch(() => { + setKeySaving(false) + setStage('provider') + }) + + return + } + + if (ch.toLowerCase() === 'n' || key.escape) { + setStage('provider') + + return + } + + return + } + const count = stage === 'provider' ? providers.length : models.length const sel = stage === 'provider' ? providerIdx : modelIdx const setSel = stage === 'provider' ? setProviderIdx : setModelIdx @@ -103,6 +223,18 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke return } + if (provider.authenticated === false) { + // api_key providers: prompt for key inline + if (provider.auth_type === 'api_key' && provider.key_env) { + setStage('key') + setKeyInput('') + setKeyError('') + } + + // Other auth types: no-op (warning shown tells them to run hermes model) + return + } + setStage('model') setModelIdx(0) @@ -126,22 +258,11 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke return } - const n = ch === '0' ? 10 : parseInt(ch, 10) + // Disconnect: only in provider stage, only for authenticated providers + if (ch.toLowerCase() === 'd' && stage === 'provider' && provider?.authenticated !== false) { + setStage('disconnect') - if (!Number.isNaN(n) && n >= 1 && n <= Math.min(10, count)) { - const offset = windowOffset(count, sel, VISIBLE) - - if (stage === 'provider') { - const next = offset + n - 1 - - if (providers[next]) { - setProviderIdx(next) - } - } else if (provider && models[offset + n - 1]) { - onSelect( - `${models[offset + n - 1]} --provider ${provider.slug}${persistGlobal ? ' --global' : ` ${TUI_SESSION_MODEL_FLAG}`}` - ) - } + return } }) @@ -161,15 +282,96 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke if (!providers.length) { return ( - no authenticated providers + no providers available Esc/q cancel ) } + // ── Key entry stage ────────────────────────────────────────────────── + if (stage === 'key' && provider) { + const masked = keyInput ? '•'.repeat(Math.min(keyInput.length, 40)) : '' + + return ( + + + Configure {provider.name} + + + + Paste your API key below (saved to ~/.hermes/.env) + + + + + + {provider.key_env}: + + + + {' '}{masked || '(empty)'}{keySaving ? '' : '▎'} + + + + + {keyError ? ( + + error: {keyError} + + ) : keySaving ? ( + + saving… + + ) : ( + + )} + + Enter save · Ctrl+U clear · Esc back + + ) + } + + // ── Disconnect confirmation stage ───────────────────────────────────── + if (stage === 'disconnect' && provider) { + return ( + + + Disconnect {provider.name}? + + + + + + This removes saved credentials for {provider.name}. + + + + You can re-authenticate later by selecting it again. + + + + + {keySaving ? ( + disconnecting… + ) : ( + y/Enter confirm · n/Esc cancel + )} + + ) + } + + // ── Provider selection stage ───────────────────────────────────────── if (stage === 'provider') { const rows = providers.map( - (p, i) => `${p.is_current ? '*' : ' '} ${names[i]} · ${p.total_models ?? p.models?.length ?? 0} models` + (p, i) => { + const authMark = p.authenticated === false ? '○' : p.is_current ? '*' : '●' + const modelCount = p.total_models ?? p.models?.length ?? 0 + const suffix = p.authenticated === false + ? (p.auth_type === 'api_key' ? '(no key)' : '(needs setup)') + : `${modelCount} models` + + return `${authMark} ${names[i]} · ${suffix}` + } ) const { items, offset } = windowItems(rows, providerIdx, VISIBLE) @@ -197,17 +399,19 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke {Array.from({ length: VISIBLE }, (_, i) => { const row = items[i] const idx = offset + i + const p = providers[idx] + const dimmed = p?.authenticated === false return row ? ( {providerIdx === idx ? '▸ ' : ' '} - {i + 1}. {row} + {idx + 1}. {row} ) : ( @@ -223,11 +427,12 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke persist: {persistGlobal ? 'global' : 'session'} · g toggle - ↑/↓ select · Enter choose · 1-9,0 quick · Esc/q cancel + ↑/↓ select · Enter choose · d disconnect · Esc/q cancel ) } + // ── Model selection stage ──────────────────────────────────────────── const { items, offset } = windowItems(models, modelIdx, VISIBLE) return ( @@ -273,7 +478,7 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke wrap="truncate-end" > {prefix} - {i + 1}. {row} + {idx + 1}. {row} ) })} @@ -286,7 +491,7 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke persist: {persistGlobal ? 'global' : 'session'} · g toggle - {models.length ? '↑/↓ select · Enter switch · 1-9,0 quick · Esc back · q close' : 'Enter/Esc back · q close'} + {models.length ? '↑/↓ select · Enter switch · Esc back · q close' : 'Enter/Esc back · q close'} ) diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index 02d878cb21..390e7af3e0 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -302,7 +302,10 @@ export interface ToolsConfigureResponse { // ── Model picker ───────────────────────────────────────────────────── export interface ModelOptionProvider { + auth_type?: string + authenticated?: boolean is_current?: boolean + key_env?: string models?: string[] name: string slug: string