From 386478211b1f1709f1b508b41b43bcaee7478b0c Mon Sep 17 00:00:00 2001 From: Donovan Yohan Date: Fri, 26 Jun 2026 21:58:47 +0000 Subject: [PATCH] fix(tui): preserve filtered model provider selection --- ui-tui/src/__tests__/modelPicker.test.ts | 22 ++++++++++++++++++ ui-tui/src/components/modelPicker.tsx | 29 ++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 ui-tui/src/__tests__/modelPicker.test.ts diff --git a/ui-tui/src/__tests__/modelPicker.test.ts b/ui-tui/src/__tests__/modelPicker.test.ts new file mode 100644 index 00000000000..df573f3dd00 --- /dev/null +++ b/ui-tui/src/__tests__/modelPicker.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, it } from 'vitest' + +import { providerIndexAfterClearingFilter } from '../components/modelPicker.js' +import type { ModelOptionProvider } from '../gatewayTypes.js' + +const provider = (slug: string, name = slug): ModelOptionProvider => ({ name, slug }) + +describe('ModelPicker provider filtering', () => { + it('keeps the selected provider when clearing the provider filter', () => { + const nous = provider('nous', 'Nous Portal') + const ollama = provider('ollama-cloud', 'Ollama Cloud') + + const rows = [ + { name: nous.name, provider: nous }, + { name: ollama.name, provider: ollama } + ] + + // With a provider-stage filter like "ollama", the selected row is index 0 + // in the filtered list, but index 1 in the full list after setFilter(''). + expect(providerIndexAfterClearingFilter(rows, ollama)).toBe(1) + }) +}) diff --git a/ui-tui/src/components/modelPicker.tsx b/ui-tui/src/components/modelPicker.tsx index c18fbe9f058..ea6c1b70645 100644 --- a/ui-tui/src/components/modelPicker.tsx +++ b/ui-tui/src/components/modelPicker.tsx @@ -17,6 +17,16 @@ const MAX_WIDTH = 90 type Stage = 'provider' | 'key' | 'model' | 'disconnect' +type ProviderRow = { name: string; provider: ModelOptionProvider } + +export function providerIndexAfterClearingFilter(providerRows: ProviderRow[], provider: ModelOptionProvider | undefined) { + if (!provider) { + return -1 + } + + return providerRows.findIndex(row => row.provider.slug === provider.slug) +} + export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect, sessionId, t }: ModelPickerProps) { const [providers, setProviders] = useState([]) const [currentModel, setCurrentModel] = useState('') @@ -307,6 +317,12 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect, if (provider.authenticated === false) { // api_key providers: prompt for key inline if (provider.auth_type === 'api_key' && provider.key_env) { + const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider) + + if (fullProviderIdx >= 0) { + setProviderIdx(fullProviderIdx) + } + setStage('key') setKeyInput('') setKeyError('') @@ -317,6 +333,12 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect, return } + const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider) + + if (fullProviderIdx >= 0) { + setProviderIdx(fullProviderIdx) + } + setStage('model') setModelIdx(0) setFilter('') @@ -365,7 +387,14 @@ export function ModelPicker({ allowPersistGlobal = true, gw, onCancel, onSelect, // Disconnect (Ctrl+D): only in provider stage, only for authenticated providers. if (key.ctrl && ch === 'd' && stage === 'provider' && provider?.authenticated !== false) { + const fullProviderIdx = providerIndexAfterClearingFilter(providerRows, provider) + + if (fullProviderIdx >= 0) { + setProviderIdx(fullProviderIdx) + } + setStage('disconnect') + setFilter('') return }