fix(tui): preserve filtered model provider selection

This commit is contained in:
Donovan Yohan 2026-06-26 21:58:47 +00:00 committed by kshitij
parent b0f44d3fad
commit 386478211b
2 changed files with 51 additions and 0 deletions

View file

@ -0,0 +1,22 @@
import { describe, expect, it } from 'vitest'
import { providerIndexAfterClearingFilter } from '../components/modelPicker.js'
import type { ModelOptionProvider } from '../gatewayTypes.js'
const provider = (slug: string, name = slug): ModelOptionProvider => ({ name, slug })
describe('ModelPicker provider filtering', () => {
it('keeps the selected provider when clearing the provider filter', () => {
const nous = provider('nous', 'Nous Portal')
const ollama = provider('ollama-cloud', 'Ollama Cloud')
const rows = [
{ name: nous.name, provider: nous },
{ name: ollama.name, provider: ollama }
]
// With a provider-stage filter like "ollama", the selected row is index 0
// in the filtered list, but index 1 in the full list after setFilter('').
expect(providerIndexAfterClearingFilter(rows, ollama)).toBe(1)
})
})

View file

@ -17,6 +17,16 @@ const MAX_WIDTH = 90
type Stage = 'provider' | 'key' | 'model' | 'disconnect'
type ProviderRow = { name: string; provider: ModelOptionProvider }
export function providerIndexAfterClearingFilter(providerRows: ProviderRow[], provider: ModelOptionProvider | undefined) {
if (!provider) {
return -1
}
return providerRows.findIndex(row => row.provider.slug === provider.slug)
}
export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect, sessionId, t }: ModelPickerProps) {
const [providers, setProviders] = useState<ModelOptionProvider[]>([])
const [currentModel, setCurrentModel] = useState('')
@ -307,6 +317,12 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect,
if (provider.authenticated === false) {
// api_key providers: prompt for key inline
if (provider.auth_type === 'api_key' && provider.key_env) {
const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider)
if (fullProviderIdx >= 0) {
setProviderIdx(fullProviderIdx)
}
setStage('key')
setKeyInput('')
setKeyError('')
@ -317,6 +333,12 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect,
return
}
const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider)
if (fullProviderIdx >= 0) {
setProviderIdx(fullProviderIdx)
}
setStage('model')
setModelIdx(0)
setFilter('')
@ -365,7 +387,14 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect,
// Disconnect (Ctrl+D): only in provider stage, only for authenticated providers.
if (key.ctrl && ch === 'd' && stage === 'provider' && provider?.authenticated !== false) {
const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider)
if (fullProviderIdx >= 0) {
setProviderIdx(fullProviderIdx)
}
setStage('disconnect')
setFilter('')
return
}