Merge pull request #18117 from NousResearch/austin/fix/model-selector

feat(tui): overhaul /model picker to match hermes model with inline auth
This commit is contained in:
Austin Pickett 2026-05-01 05:30:05 -07:00 committed by GitHub
commit 20132435c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 421 additions and 37 deletions

View file

@ -1085,9 +1085,7 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
from hermes_cli.config import get_compatible_custom_providers, load_config from hermes_cli.config import get_compatible_custom_providers, load_config
cfg = load_config() cfg = load_config()
user_provs = [ user_provs = cfg.get("providers")
{"provider": k, **v} for k, v in (cfg.get("providers") or {}).items()
]
custom_provs = get_compatible_custom_providers(cfg) custom_provs = get_compatible_custom_providers(cfg)
except Exception: except Exception:
pass pass
@ -4737,6 +4735,7 @@ def _(rid, params: dict) -> dict:
def _(rid, params: dict) -> dict: def _(rid, params: dict) -> dict:
try: try:
from hermes_cli.model_switch import list_authenticated_providers from hermes_cli.model_switch import list_authenticated_providers
from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS
session = _sessions.get(params.get("session_id", "")) session = _sessions.get(params.get("session_id", ""))
agent = session.get("agent") if session else None agent = session.get("agent") if session else None
@ -4750,6 +4749,127 @@ def _(rid, params: dict) -> dict:
# provider_model_ids() — that bypasses curation and pulls in # provider_model_ids() — that bypasses curation and pulls in
# non-agentic models (e.g. Nous /models returns ~400 IDs including # non-agentic models (e.g. Nous /models returns ~400 IDs including
# TTS, embeddings, rerankers, image/video generators). # TTS, embeddings, rerankers, image/video generators).
user_provs = (
cfg.get("providers") if isinstance(cfg.get("providers"), dict) else {}
)
custom_provs = (
cfg.get("custom_providers")
if isinstance(cfg.get("custom_providers"), list)
else []
)
authenticated = list_authenticated_providers(
current_provider=current_provider,
current_base_url=current_base_url,
current_model=current_model,
user_providers=user_provs,
custom_providers=custom_provs,
max_models=50,
)
# Mark authenticated providers and build lookup by slug
authed_map: dict = {}
authed_extra: list = [] # user-defined/custom not in CANONICAL_PROVIDERS
canonical_slugs = {e.slug for e in CANONICAL_PROVIDERS}
for p in authenticated:
p["authenticated"] = True
authed_map[p["slug"]] = p
if p["slug"] not in canonical_slugs:
authed_extra.append(p)
# Build final list in CANONICAL_PROVIDERS order, merging auth data
from hermes_cli.auth import PROVIDER_REGISTRY as _auth_reg
ordered: list = []
for entry in CANONICAL_PROVIDERS:
if entry.slug in authed_map:
ordered.append(authed_map[entry.slug])
else:
pconfig = _auth_reg.get(entry.slug)
auth_type = pconfig.auth_type if pconfig else "api_key"
key_env = pconfig.api_key_env_vars[0] if (pconfig and pconfig.api_key_env_vars) else ""
if auth_type == "api_key" and key_env:
warning = f"paste {key_env} to activate"
else:
warning = f"run `hermes model` to configure ({auth_type})"
ordered.append({
"slug": entry.slug,
"name": _PROVIDER_LABELS.get(entry.slug, entry.label),
"is_current": entry.slug == current_provider,
"is_user_defined": False,
"models": [],
"total_models": 0,
"source": "built-in",
"authenticated": False,
"auth_type": auth_type,
"key_env": key_env,
"warning": warning,
})
# Append user-defined/custom providers not in canonical list
ordered.extend(authed_extra)
return _ok(
rid,
{
"providers": ordered,
"model": current_model,
"provider": current_provider,
},
)
except Exception as e:
return _err(rid, 5033, str(e))
@method("model.save_key")
def _(rid, params: dict) -> dict:
"""Save an API key for a provider, then return its refreshed model list.
Params:
slug: provider slug (e.g. "deepseek", "xai")
api_key: the key value to save
Returns the provider dict with models populated (same shape as
model.options entries) on success.
"""
try:
from hermes_cli.auth import PROVIDER_REGISTRY
from hermes_cli.config import is_managed, save_env_value
from hermes_cli.model_switch import list_authenticated_providers
slug = (params.get("slug") or "").strip()
api_key = (params.get("api_key") or "").strip()
if not slug or not api_key:
return _err(rid, 4001, "slug and api_key are required")
if is_managed():
return _err(rid, 4006, "managed install — credentials are read-only")
pconfig = PROVIDER_REGISTRY.get(slug)
if not pconfig:
return _err(rid, 4002, f"unknown provider: {slug}")
if pconfig.auth_type != "api_key":
return _err(
rid, 4003,
f"{pconfig.name} uses {pconfig.auth_type} auth — "
f"run `hermes model` to configure"
)
if not pconfig.api_key_env_vars:
return _err(rid, 4004, f"no env var defined for {pconfig.name}")
# Save the key to ~/.hermes/.env
env_var = pconfig.api_key_env_vars[0]
save_env_value(env_var, api_key)
# Also set in current process so list_authenticated_providers sees it
import os
os.environ[env_var] = api_key
# Refresh provider data
cfg = _load_cfg()
session = _sessions.get(params.get("session_id", ""))
agent = session.get("agent") if session else None
current_provider = getattr(agent, "provider", "") or ""
current_model = getattr(agent, "model", "") or _resolve_model()
current_base_url = getattr(agent, "base_url", "") or ""
providers = list_authenticated_providers( providers = list_authenticated_providers(
current_provider=current_provider, current_provider=current_provider,
current_base_url=current_base_url, current_base_url=current_base_url,
@ -4764,16 +4884,72 @@ def _(rid, params: dict) -> dict:
), ),
max_models=50, max_models=50,
) )
return _ok(
rid, # Find the newly-authenticated provider
{ provider_data = None
"providers": providers, for p in providers:
"model": current_model, if p["slug"] == slug:
"provider": current_provider, provider_data = p
}, break
)
if not provider_data:
# Key was saved but provider didn't appear — still return success
provider_data = {
"slug": slug,
"name": pconfig.name,
"is_current": False,
"models": [],
"total_models": 0,
"authenticated": True,
}
provider_data["authenticated"] = True
return _ok(rid, {"provider": provider_data})
except Exception as e: except Exception as e:
return _err(rid, 5033, str(e)) return _err(rid, 5034, str(e))
@method("model.disconnect")
def _(rid, params: dict) -> dict:
"""Remove credentials for a provider.
Params:
slug: provider slug (e.g. "deepseek", "xai")
Returns success status and the provider's slug.
"""
try:
from hermes_cli.auth import PROVIDER_REGISTRY, clear_provider_auth
from hermes_cli.config import remove_env_value
slug = (params.get("slug") or "").strip()
if not slug:
return _err(rid, 4001, "slug is required")
pconfig = PROVIDER_REGISTRY.get(slug)
cleared_env = False
cleared_auth = False
# Remove API key env vars from .env and process
if pconfig and pconfig.api_key_env_vars:
for ev in pconfig.api_key_env_vars:
if remove_env_value(ev):
cleared_env = True
# Clear OAuth / credential pool state
cleared_auth = clear_provider_auth(slug)
if not cleared_env and not cleared_auth:
return _err(rid, 4005, f"no credentials found for {slug}")
provider_name = pconfig.name if pconfig else slug
return _ok(rid, {
"slug": slug,
"name": provider_name,
"disconnected": True,
})
except Exception as e:
return _err(rid, 5035, str(e))
# ── Methods: slash.exec ────────────────────────────────────────────── # ── Methods: slash.exec ──────────────────────────────────────────────

View file

@ -8,12 +8,14 @@ import type { ModelOptionProvider, ModelOptionsResponse } from '../gatewayTypes.
import { asRpcResult, rpcErrorMessage } from '../lib/rpc.js' import { asRpcResult, rpcErrorMessage } from '../lib/rpc.js'
import type { Theme } from '../theme.js' import type { Theme } from '../theme.js'
import { OverlayHint, useOverlayKeys, windowItems, windowOffset } from './overlayControls.js' import { OverlayHint, useOverlayKeys, windowItems } from './overlayControls.js'
const VISIBLE = 12 const VISIBLE = 12
const MIN_WIDTH = 40 const MIN_WIDTH = 40
const MAX_WIDTH = 90 const MAX_WIDTH = 90
type Stage = 'provider' | 'key' | 'model' | 'disconnect'
export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPickerProps) { export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPickerProps) {
const [providers, setProviders] = useState<ModelOptionProvider[]>([]) const [providers, setProviders] = useState<ModelOptionProvider[]>([])
const [currentModel, setCurrentModel] = useState('') const [currentModel, setCurrentModel] = useState('')
@ -22,7 +24,10 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
const [persistGlobal, setPersistGlobal] = useState(false) const [persistGlobal, setPersistGlobal] = useState(false)
const [providerIdx, setProviderIdx] = useState(0) const [providerIdx, setProviderIdx] = useState(0)
const [modelIdx, setModelIdx] = useState(0) const [modelIdx, setModelIdx] = useState(0)
const [stage, setStage] = useState<'model' | 'provider'>('provider') const [stage, setStage] = useState<Stage>('provider')
const [keyInput, setKeyInput] = useState('')
const [keySaving, setKeySaving] = useState(false)
const [keyError, setKeyError] = useState('')
const { stdout } = useStdout() const { stdout } = useStdout()
// Pin the picker to a stable width so the FloatBox parent (which shrinks- // Pin the picker to a stable width so the FloatBox parent (which shrinks-
@ -68,9 +73,12 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
const names = useMemo(() => providerDisplayNames(providers), [providers]) const names = useMemo(() => providerDisplayNames(providers), [providers])
const back = () => { const back = () => {
if (stage === 'model') { if (stage === 'model' || stage === 'key' || stage === 'disconnect') {
setStage('provider') setStage('provider')
setModelIdx(0) setModelIdx(0)
setKeyInput('')
setKeyError('')
setKeySaving(false)
return return
} }
@ -81,6 +89,118 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
useOverlayKeys({ onBack: back, onClose: onCancel }) useOverlayKeys({ onBack: back, onClose: onCancel })
useInput((ch, key) => { useInput((ch, key) => {
// Key entry stage handles its own input
if (stage === 'key') {
if (keySaving) {
return
}
if (key.return) {
if (!keyInput.trim()) {
return
}
setKeySaving(true)
setKeyError('')
gw.request<{ provider?: ModelOptionProvider }>('model.save_key', {
slug: provider?.slug,
api_key: keyInput.trim(),
...(sessionId ? { session_id: sessionId } : {}),
})
.then(raw => {
const r = asRpcResult<{ provider?: ModelOptionProvider }>(raw)
if (!r?.provider) {
setKeyError('failed to save key')
setKeySaving(false)
return
}
// Update the provider in our list with fresh data
setProviders(prev =>
prev.map(p => p.slug === r.provider!.slug ? r.provider! : p)
)
setKeyInput('')
setKeySaving(false)
setStage('model')
setModelIdx(0)
})
.catch((e: unknown) => {
setKeyError(rpcErrorMessage(e))
setKeySaving(false)
})
return
}
if (key.backspace || key.delete) {
setKeyInput(v => v.slice(0, -1))
return
}
// ctrl+u clears input
if (ch === '\u0015') {
setKeyInput('')
return
}
if (ch && !key.ctrl && !key.meta) {
setKeyInput(v => v + ch)
}
return
}
// Disconnect confirmation stage
if (stage === 'disconnect') {
if (ch.toLowerCase() === 'y' || key.return) {
if (!provider) {
setStage('provider')
return
}
setKeySaving(true)
gw.request<{ disconnected?: boolean }>('model.disconnect', {
slug: provider.slug,
...(sessionId ? { session_id: sessionId } : {}),
})
.then(raw => {
const r = asRpcResult<{ disconnected?: boolean }>(raw)
if (r?.disconnected) {
// Mark provider as unauthenticated in local state
setProviders(prev =>
prev.map(p => p.slug === provider.slug
? { ...p, authenticated: false, models: [], total_models: 0, warning: p.key_env ? `paste ${p.key_env} to activate` : 'run `hermes model` to configure' }
: p
)
)
}
setKeySaving(false)
setStage('provider')
})
.catch(() => {
setKeySaving(false)
setStage('provider')
})
return
}
if (ch.toLowerCase() === 'n' || key.escape) {
setStage('provider')
return
}
return
}
const count = stage === 'provider' ? providers.length : models.length const count = stage === 'provider' ? providers.length : models.length
const sel = stage === 'provider' ? providerIdx : modelIdx const sel = stage === 'provider' ? providerIdx : modelIdx
const setSel = stage === 'provider' ? setProviderIdx : setModelIdx const setSel = stage === 'provider' ? setProviderIdx : setModelIdx
@ -103,6 +223,18 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
return return
} }
if (provider.authenticated === false) {
// api_key providers: prompt for key inline
if (provider.auth_type === 'api_key' && provider.key_env) {
setStage('key')
setKeyInput('')
setKeyError('')
}
// Other auth types: no-op (warning shown tells them to run hermes model)
return
}
setStage('model') setStage('model')
setModelIdx(0) setModelIdx(0)
@ -126,22 +258,11 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
return return
} }
const n = ch === '0' ? 10 : parseInt(ch, 10) // Disconnect: only in provider stage, only for authenticated providers
if (ch.toLowerCase() === 'd' && stage === 'provider' && provider?.authenticated !== false) {
setStage('disconnect')
if (!Number.isNaN(n) && n >= 1 && n <= Math.min(10, count)) { return
const offset = windowOffset(count, sel, VISIBLE)
if (stage === 'provider') {
const next = offset + n - 1
if (providers[next]) {
setProviderIdx(next)
}
} else if (provider && models[offset + n - 1]) {
onSelect(
`${models[offset + n - 1]} --provider ${provider.slug}${persistGlobal ? ' --global' : ` ${TUI_SESSION_MODEL_FLAG}`}`
)
}
} }
}) })
@ -161,15 +282,96 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
if (!providers.length) { if (!providers.length) {
return ( return (
<Box flexDirection="column"> <Box flexDirection="column">
<Text color={t.color.muted}>no authenticated providers</Text> <Text color={t.color.muted}>no providers available</Text>
<OverlayHint t={t}>Esc/q cancel</OverlayHint> <OverlayHint t={t}>Esc/q cancel</OverlayHint>
</Box> </Box>
) )
} }
// ── Key entry stage ──────────────────────────────────────────────────
if (stage === 'key' && provider) {
const masked = keyInput ? '•'.repeat(Math.min(keyInput.length, 40)) : ''
return (
<Box flexDirection="column" width={width}>
<Text bold color={t.color.accent} wrap="truncate-end">
Configure {provider.name}
</Text>
<Text color={t.color.muted} wrap="truncate-end">
Paste your API key below (saved to ~/.hermes/.env)
</Text>
<Text color={t.color.muted} wrap="truncate-end"> </Text>
<Text color={t.color.muted} wrap="truncate-end">
{provider.key_env}:
</Text>
<Text color={t.color.accent} wrap="truncate-end">
{' '}{masked || '(empty)'}{keySaving ? '' : '▎'}
</Text>
<Text color={t.color.muted} wrap="truncate-end"> </Text>
{keyError ? (
<Text color={t.color.label} wrap="truncate-end">
error: {keyError}
</Text>
) : keySaving ? (
<Text color={t.color.muted} wrap="truncate-end">
saving
</Text>
) : (
<Text color={t.color.muted} wrap="truncate-end"> </Text>
)}
<OverlayHint t={t}>Enter save · Ctrl+U clear · Esc back</OverlayHint>
</Box>
)
}
// ── Disconnect confirmation stage ─────────────────────────────────────
if (stage === 'disconnect' && provider) {
return (
<Box flexDirection="column" width={width}>
<Text bold color={t.color.accent} wrap="truncate-end">
Disconnect {provider.name}?
</Text>
<Text color={t.color.muted} wrap="truncate-end"> </Text>
<Text color={t.color.muted} wrap="truncate-end">
This removes saved credentials for {provider.name}.
</Text>
<Text color={t.color.muted} wrap="truncate-end">
You can re-authenticate later by selecting it again.
</Text>
<Text color={t.color.muted} wrap="truncate-end"> </Text>
{keySaving ? (
<Text color={t.color.muted} wrap="truncate-end">disconnecting</Text>
) : (
<OverlayHint t={t}>y/Enter confirm · n/Esc cancel</OverlayHint>
)}
</Box>
)
}
// ── Provider selection stage ─────────────────────────────────────────
if (stage === 'provider') { if (stage === 'provider') {
const rows = providers.map( const rows = providers.map(
(p, i) => `${p.is_current ? '*' : ' '} ${names[i]} · ${p.total_models ?? p.models?.length ?? 0} models` (p, i) => {
const authMark = p.authenticated === false ? '○' : p.is_current ? '*' : '●'
const modelCount = p.total_models ?? p.models?.length ?? 0
const suffix = p.authenticated === false
? (p.auth_type === 'api_key' ? '(no key)' : '(needs setup)')
: `${modelCount} models`
return `${authMark} ${names[i]} · ${suffix}`
}
) )
const { items, offset } = windowItems(rows, providerIdx, VISIBLE) const { items, offset } = windowItems(rows, providerIdx, VISIBLE)
@ -197,17 +399,19 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
{Array.from({ length: VISIBLE }, (_, i) => { {Array.from({ length: VISIBLE }, (_, i) => {
const row = items[i] const row = items[i]
const idx = offset + i const idx = offset + i
const p = providers[idx]
const dimmed = p?.authenticated === false
return row ? ( return row ? (
<Text <Text
bold={providerIdx === idx} bold={providerIdx === idx}
color={providerIdx === idx ? t.color.accent : t.color.muted} color={providerIdx === idx ? t.color.accent : dimmed ? t.color.label : t.color.muted}
inverse={providerIdx === idx} inverse={providerIdx === idx}
key={providers[idx]?.slug ?? `row-${idx}`} key={providers[idx]?.slug ?? `row-${idx}`}
wrap="truncate-end" wrap="truncate-end"
> >
{providerIdx === idx ? '▸ ' : ' '} {providerIdx === idx ? '▸ ' : ' '}
{i + 1}. {row} {idx + 1}. {row}
</Text> </Text>
) : ( ) : (
<Text color={t.color.muted} key={`pad-${i}`} wrap="truncate-end"> <Text color={t.color.muted} key={`pad-${i}`} wrap="truncate-end">
@ -223,11 +427,12 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
<Text color={t.color.muted} wrap="truncate-end"> <Text color={t.color.muted} wrap="truncate-end">
persist: {persistGlobal ? 'global' : 'session'} · g toggle persist: {persistGlobal ? 'global' : 'session'} · g toggle
</Text> </Text>
<OverlayHint t={t}>/ select · Enter choose · 1-9,0 quick · Esc/q cancel</OverlayHint> <OverlayHint t={t}>/ select · Enter choose · d disconnect · Esc/q cancel</OverlayHint>
</Box> </Box>
) )
} }
// ── Model selection stage ────────────────────────────────────────────
const { items, offset } = windowItems(models, modelIdx, VISIBLE) const { items, offset } = windowItems(models, modelIdx, VISIBLE)
return ( return (
@ -273,7 +478,7 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
wrap="truncate-end" wrap="truncate-end"
> >
{prefix} {prefix}
{i + 1}. {row} {idx + 1}. {row}
</Text> </Text>
) )
})} })}
@ -286,7 +491,7 @@ export function ModelPicker({ gw, onCancel, onSelect, sessionId, t }: ModelPicke
persist: {persistGlobal ? 'global' : 'session'} · g toggle persist: {persistGlobal ? 'global' : 'session'} · g toggle
</Text> </Text>
<OverlayHint t={t}> <OverlayHint t={t}>
{models.length ? '↑/↓ select · Enter switch · 1-9,0 quick · Esc back · q close' : 'Enter/Esc back · q close'} {models.length ? '↑/↓ select · Enter switch · Esc back · q close' : 'Enter/Esc back · q close'}
</OverlayHint> </OverlayHint>
</Box> </Box>
) )

View file

@ -302,7 +302,10 @@ export interface ToolsConfigureResponse {
// ── Model picker ───────────────────────────────────────────────────── // ── Model picker ─────────────────────────────────────────────────────
export interface ModelOptionProvider { export interface ModelOptionProvider {
auth_type?: string
authenticated?: boolean
is_current?: boolean is_current?: boolean
key_env?: string
models?: string[] models?: string[]
name: string name: string
slug: string slug: string