diff --git a/cmd/opencodereview/config_cmd.go b/cmd/opencodereview/config_cmd.go index ccd3afe..8d95295 100644 --- a/cmd/opencodereview/config_cmd.go +++ b/cmd/opencodereview/config_cmd.go @@ -370,6 +370,16 @@ func parseModelListValue(value string) ([]string, error) { return normalizeModelList(strings.Split(value, ",")), nil } +func activeModelForProvider(cfg *Config, providerName string, entry ProviderEntry) string { + if entry.Model != "" { + return entry.Model + } + if cfg != nil && cfg.Provider == providerName && cfg.Model != "" { + return cfg.Model + } + return "" +} + func normalizeModelList(models []string) []string { out := make([]string, 0, len(models)) seen := make(map[string]struct{}, len(models)) @@ -395,6 +405,19 @@ func mergeModelLists(lists ...[]string) []string { return normalizeModelList(merged) } +// ensureModelInList appends model to the end when missing; never reorders existing entries. +func ensureModelInList(models []string, model string) []string { + model = strings.TrimSpace(model) + if model == "" { + return models + } + if modelListContains(models, model) { + return models + } + out := append([]string(nil), models...) + return append(out, model) +} + func modelListContains(models []string, target string) bool { target = strings.TrimSpace(target) for _, model := range models { diff --git a/cmd/opencodereview/config_cmd_test.go b/cmd/opencodereview/config_cmd_test.go index 875b2f6..741794c 100644 --- a/cmd/opencodereview/config_cmd_test.go +++ b/cmd/opencodereview/config_cmd_test.go @@ -339,3 +339,23 @@ func TestUnsetInvalidKey(t *testing.T) { }) } } + +func TestEnsureModelInList(t *testing.T) { + models := []string{"test-model", "test-model-2", "bbb", "aaa", "test-model-3"} + + got := ensureModelInList(models, "test-model-3") + if len(got) != len(models) { + t.Fatalf("existing model should not reorder: got %v", got) + } + for i := range models { + if got[i] != models[i] { + t.Errorf("models[%d] = %q, want %q", i, got[i], models[i]) + } + } + + got = ensureModelInList(models, "new-model") + want := append(append([]string(nil), models...), "new-model") + if len(got) != len(want) || got[len(got)-1] != "new-model" { + t.Errorf("new model should append: got %v, want %v", got, want) + } +} diff --git a/cmd/opencodereview/provider_cmd.go b/cmd/opencodereview/provider_cmd.go index 9e29394..995efb7 100644 --- a/cmd/opencodereview/provider_cmd.go +++ b/cmd/opencodereview/provider_cmd.go @@ -22,7 +22,7 @@ func runConfigProvider() error { return fmt.Errorf("load config: %w", err) } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, configPath) p := tea.NewProgram(m) finalModel, err := p.Run() if err != nil { @@ -31,19 +31,11 @@ func runConfigProvider() error { final := finalModel.(providerTUIModel) - if len(final.deletedProviders) > 0 { - clearedActive, err := applyProviderDeletions(configPath, cfg, final.deletedProviders) - if err != nil { - return err - } - if clearedActive && !final.confirmed { - fmt.Fprintf(os.Stderr, "[ocr] WARNING: active provider was deleted; 'provider' and 'model' have been cleared.\n") - fmt.Fprintf(os.Stderr, "[ocr] Run 'ocr config provider' to select a new provider.\n") - } - } - if !final.confirmed { - if len(final.deletedProviders) > 0 { + // TUI persists changes (create/edit/model/add/delete) directly to disk + // during the session, so the on-disk file is already up to date for any + // savedInSession operation. No additional post-TUI apply step is needed. + if final.savedInSession { return nil } fmt.Println("Cancelled.") @@ -82,6 +74,21 @@ func applyProviderDeletions(configPath string, cfg *Config, names []string) (boo return clearedActive, nil } +func removeModels(existing, toRemove []string) []string { + removeSet := make(map[string]struct{}, len(toRemove)) + for _, m := range toRemove { + removeSet[m] = struct{}{} + } + result := make([]string, 0, len(existing)) + for _, m := range existing { + if _, found := removeSet[m]; found { + continue + } + result = append(result, m) + } + return result +} + func applyManualConfig(configPath string, cfg *Config, result providerTUIResult) error { if result.url == "" { return fmt.Errorf("URL is required for manual configuration") @@ -95,6 +102,13 @@ func applyManualConfig(configPath string, cfg *Config, result providerTUIResult) cfg.Llm.URL = result.url cfg.Llm.Model = result.model cfg.Llm.AuthToken = result.apiKey + authHeader, err := llm.NormalizeAuthHeader(result.authHeader) + if err != nil { + return fmt.Errorf("invalid auth_header: %w", err) + } + cfg.Llm.AuthHeader = authHeader + useAnthropic := result.protocol == "anthropic" + cfg.Llm.UseAnthropic = &useAnthropic if err := saveConfig(configPath, cfg); err != nil { return err @@ -102,6 +116,7 @@ func applyManualConfig(configPath string, cfg *Config, result providerTUIResult) fmt.Println("\nManual configuration saved.") fmt.Printf("URL: %s\n", result.url) + fmt.Printf("Protocol: %s\n", result.protocol) fmt.Printf("Model: %s\n", result.model) fmt.Println("\nTesting connection...") @@ -129,8 +144,9 @@ func applyCustomProviderConfig(configPath string, cfg *Config, result providerTU entry := cfg.CustomProviders[result.provider] entry.Model = result.model if len(result.models) > 0 { - entry.Models = mergeModelLists([]string{result.model}, result.models) + entry.Models = append([]string(nil), result.models...) } + entry.Models = ensureModelInList(entry.Models, result.model) if result.url != "" { entry.URL = result.url } @@ -138,22 +154,39 @@ func applyCustomProviderConfig(configPath string, cfg *Config, result providerTU entry.Protocol = result.protocol } if result.authHeader != "" { - entry.AuthHeader = result.authHeader + authHeader, err := llm.NormalizeAuthHeader(result.authHeader) + if err != nil { + return fmt.Errorf("invalid auth_header: %w", err) + } + entry.AuthHeader = authHeader } if result.apiKey != "" { entry.APIKey = result.apiKey } cfg.CustomProviders[result.provider] = entry - if cfg.Provider != result.provider { - cfg.Model = "" + if !result.isEdit { + cfg.Provider = result.provider + cfg.Model = result.model + } else if cfg.Provider == result.provider { + cfg.Model = result.model } - cfg.Provider = result.provider if err := saveConfig(configPath, cfg); err != nil { return err } + if result.isEdit { + if cfg.Provider == result.provider { + fmt.Printf("\nActive provider %q updated.\n", result.provider) + } else { + fmt.Printf("\nCustom provider %q updated (not currently active).\n", result.provider) + } + fmt.Printf("Model: %s\n", result.model) + fmt.Println("\nTip: run 'ocr config model' to switch model later.") + return nil + } + fmt.Printf("\nProvider set to: %s (custom)\n", result.provider) fmt.Printf("Model: %s\n", result.model) @@ -203,6 +236,7 @@ func applyOfficialProviderConfig(configPath string, cfg *Config, result provider cfg.Model = "" } cfg.Provider = result.provider + cfg.Model = result.model if err := saveConfig(configPath, cfg); err != nil { return err @@ -243,7 +277,7 @@ func runConfigModel() error { if preset, isPreset := llm.LookupProvider(cfg.Provider); isPreset { provider = preset if entry, ok := cfg.Providers[cfg.Provider]; ok { - currentModel = entry.Model + currentModel = activeModelForProvider(cfg, cfg.Provider, entry) provider.Models = mergeModelLists(provider.Models, entry.Models) } } else { @@ -252,15 +286,12 @@ func runConfigModel() error { if !ok { return fmt.Errorf("provider %q is not configured in custom_providers", cfg.Provider) } - currentModel = entry.Model + currentModel = activeModelForProvider(cfg, cfg.Provider, entry) provider.DisplayName = cfg.Provider + " (custom)" provider.Protocol = entry.Protocol provider.BaseURL = entry.URL provider.Models = mergeModelLists(entry.Models) } - if currentModel == "" { - currentModel = cfg.Model - } m := newModelTUI(provider, currentModel) p := tea.NewProgram(m) @@ -286,7 +317,7 @@ func runConfigModel() error { } entry := cfg.CustomProviders[cfg.Provider] entry.Model = selectedModel - entry.Models = mergeModelLists([]string{selectedModel}, entry.Models) + entry.Models = ensureModelInList(entry.Models, selectedModel) cfg.CustomProviders[cfg.Provider] = entry } else { if cfg.Providers == nil { @@ -295,10 +326,11 @@ func runConfigModel() error { entry := cfg.Providers[cfg.Provider] entry.Model = selectedModel if !modelListContains(provider.Models, selectedModel) { - entry.Models = mergeModelLists([]string{selectedModel}, entry.Models) + entry.Models = ensureModelInList(entry.Models, selectedModel) } cfg.Providers[cfg.Provider] = entry } + cfg.Model = selectedModel if err := saveConfig(configPath, cfg); err != nil { return err diff --git a/cmd/opencodereview/provider_tui.go b/cmd/opencodereview/provider_tui.go index 52e9493..2e0b795 100644 --- a/cmd/opencodereview/provider_tui.go +++ b/cmd/opencodereview/provider_tui.go @@ -36,8 +36,6 @@ const ( cpStepName customProviderStep = iota cpStepProtocol cpStepBaseURL - cpStepModel - cpStepModels cpStepAPIKey cpStepAuthHeader ) @@ -46,8 +44,10 @@ type manualStep int const ( manualStepURL manualStep = iota + manualStepProtocol manualStepModel manualStepAuthToken + manualStepAuthHeader ) var cpProtocols = []string{"anthropic", "openai"} @@ -58,15 +58,17 @@ type customProviderListItem struct { } type providerTUIResult struct { - provider string - model string - models []string - apiKey string - isCustom bool - isManual bool - url string - protocol string - authHeader string + provider string + model string + models []string + apiKey string + isCustom bool + isEdit bool + editTargetName string + isManual bool + url string + protocol string + authHeader string } type providerTUIModel struct { @@ -84,20 +86,24 @@ type providerTUIModel struct { customProviders []customProviderListItem customIdx int creatingCustom bool + editingCustom bool + editTargetName string cpStep customProviderStep cpProtocolIdx int cpNameInput textinput.Model cpURLInput textinput.Model - cpModelInput textinput.Model - cpModelsInput textinput.Model cpAuthInput textinput.Model // --- tab: manual --- - inManualForm bool - manualStep manualStep - manualURLInput textinput.Model - manualModelInput textinput.Model - manualTokenInput textinput.Model + inManualForm bool + manualStep manualStep + manualProtocolIdx int + manualURLInput textinput.Model + manualModelInput textinput.Model + manualAuthHeaderInput textinput.Model + manualTokenInput textinput.Model + manualTokenMasked bool + manualTokenOriginal string // --- shared model/api-key steps (official + existing custom) --- modelIdx int @@ -108,15 +114,36 @@ type providerTUIModel struct { apiKeyMasked bool apiKeyOriginal string - existingCfg *Config - confirmed bool - cancelled bool + existingCfg *Config + configPath string + confirmed bool + cancelled bool + formError string + savedInSession bool // --- delete confirmation --- - confirmingDelete bool - deleteTargetIdx int - deleteTargetName string - deletedProviders []string + confirmingDelete bool + deleteTargetIdx int + deleteTargetName string + deletedProviders []string + confirmingDeleteModel bool + deleteModelName string +} + +func (m providerTUIModel) customProviderNameTaken(name string) bool { + if m.existingCfg == nil || m.existingCfg.CustomProviders == nil { + return false + } + _, exists := m.existingCfg.CustomProviders[name] + return exists +} + +func (m providerTUIModel) customProviderActiveModel(cp customProviderListItem) string { + if m.existingCfg == nil || m.existingCfg.Provider != cp.name { + return "" + } + entry := m.customProviderEntry(cp.name, cp.entry) + return activeModelForProvider(m.existingCfg, cp.name, entry) } func collectCustomProviders(cfg *Config) []customProviderListItem { @@ -131,7 +158,7 @@ func collectCustomProviders(cfg *Config) []customProviderListItem { return out } -func newProviderTUI(cfg *Config) providerTUIModel { +func newProviderTUI(cfg *Config, configPath string) providerTUIModel { providers := llm.ListProviders() sort.SliceStable(providers, func(i, j int) bool { left := strings.ToLower(providers[i].DisplayName) @@ -143,8 +170,8 @@ func newProviderTUI(cfg *Config) providerTUIModel { }) mi := textinput.New() - mi.Placeholder = "model name" - mi.SetWidth(40) + mi.Placeholder = "enter model name" + mi.SetWidth(50) ai := textinput.New() ai.Placeholder = "paste your API key here" @@ -160,14 +187,6 @@ func newProviderTUI(cfg *Config) providerTUIModel { cpURL.Placeholder = "enter your API base URL" cpURL.SetWidth(50) - cpModel := textinput.New() - cpModel.Placeholder = "model name" - cpModel.SetWidth(40) - - cpModels := textinput.New() - cpModels.Placeholder = "optional comma-separated models" - cpModels.SetWidth(50) - cpAuth := textinput.New() cpAuth.Placeholder = "optional, leave empty for default (Authorization)" cpAuth.SetWidth(55) @@ -180,6 +199,10 @@ func newProviderTUI(cfg *Config) providerTUIModel { manualModel.Placeholder = "enter model name" manualModel.SetWidth(40) + manualAuthHeader := textinput.New() + manualAuthHeader.Placeholder = "optional, leave empty for default (Authorization)" + manualAuthHeader.SetWidth(55) + manualToken := textinput.New() manualToken.Placeholder = "enter your auth token" manualToken.SetWidth(50) @@ -187,22 +210,22 @@ func newProviderTUI(cfg *Config) providerTUIModel { manualToken.EchoCharacter = '*' m := providerTUIModel{ - providers: providers, - existingCfg: cfg, - modelInput: mi, - apiKeyInput: ai, - cpNameInput: cpName, - cpURLInput: cpURL, - cpModelInput: cpModel, - cpModelsInput: cpModels, - cpAuthInput: cpAuth, - manualURLInput: manualURL, - manualModelInput: manualModel, - manualTokenInput: manualToken, - width: 80, - height: 24, - activeTab: tabOfficial, - customProviders: collectCustomProviders(cfg), + providers: providers, + existingCfg: cfg, + modelInput: mi, + apiKeyInput: ai, + cpNameInput: cpName, + cpURLInput: cpURL, + cpAuthInput: cpAuth, + manualURLInput: manualURL, + manualModelInput: manualModel, + manualAuthHeaderInput: manualAuthHeader, + manualTokenInput: manualToken, + width: 80, + height: 24, + activeTab: tabOfficial, + customProviders: collectCustomProviders(cfg), + configPath: configPath, } providerFound := false @@ -253,12 +276,23 @@ func newProviderTUI(cfg *Config) providerTUIModel { if cfg.Provider == "" && cfg.Llm.URL != "" { m.activeTab = tabManual } + // Intentionally do not auto-switch activeTab to tabCustom when only custom + // providers exist — leave the cursor on Official so users navigate + // explicitly via Tab/Right. if cfg.Llm.URL != "" { m.manualURLInput.SetValue(cfg.Llm.URL) m.manualModelInput.SetValue(cfg.Llm.Model) + m.manualAuthHeaderInput.SetValue(cfg.Llm.AuthHeader) if cfg.Llm.AuthToken != "" { - m.manualTokenInput.SetValue(cfg.Llm.AuthToken) + m.manualTokenOriginal = cfg.Llm.AuthToken + m.manualTokenMasked = true + m.manualTokenInput.SetValue(strings.Repeat("*", 20)) + } + if cfg.Llm.UseAnthropic == nil || *cfg.Llm.UseAnthropic { + m.manualProtocolIdx = 0 // anthropic + } else { + m.manualProtocolIdx = 1 // openai } } @@ -336,6 +370,63 @@ func (m *providerTUIModel) prepareModelSelection(currentModel string) { m.modelInput.SetValue(currentModel) } +func (m *providerTUIModel) customProviderEntry(name string, fallback ProviderEntry) ProviderEntry { + if m.existingCfg != nil { + if entry, ok := m.existingCfg.CustomProviders[name]; ok { + return entry + } + } + return fallback +} + +func (m *providerTUIModel) syncSessionModelSelection() error { + if m.existingCfg == nil { + return nil + } + model := m.selectedModelFromState() + if model == "" { + return nil + } + + switch m.activeTab { + case tabCustom: + cp, ok := m.selectedCustomProvider() + if !ok { + return nil + } + entry := m.customProviderEntry(cp.name, cp.entry) + entry.Model = model + if m.existingCfg.CustomProviders == nil { + m.existingCfg.CustomProviders = make(map[string]ProviderEntry) + } + m.existingCfg.CustomProviders[cp.name] = entry + cp.entry = entry + m.customProviders[m.customIdx] = cp + if m.existingCfg.Provider == cp.name { + m.existingCfg.Model = model + } + case tabOfficial: + provider := m.currentProvider() + if m.existingCfg.Providers == nil { + m.existingCfg.Providers = make(map[string]ProviderEntry) + } + entry := m.existingCfg.Providers[provider.Name] + entry.Model = model + m.existingCfg.Providers[provider.Name] = entry + if m.existingCfg.Provider == provider.Name { + m.existingCfg.Model = model + } + } + + if m.configPath != "" { + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + } + m.savedInSession = true + return nil +} + func (m providerTUIModel) isCustomModelItem(idx int) bool { return idx == len(m.models()) } @@ -368,7 +459,7 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m.updateAPIKeyInput(key, msg) } - if m.step == stepProvider && m.creatingCustom { + if m.step == stepProvider && (m.creatingCustom || m.editingCustom) { return m.updateCustomProviderForm(key, msg) } @@ -380,6 +471,10 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m.updateDeleteConfirm(key) } + if m.step == stepModel && m.confirmingDeleteModel { + return m.updateDeleteModelConfirm(key) + } + switch key { case "ctrl+c": m.cancelled = true @@ -391,6 +486,7 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Quit } m.step-- + m.formError = "" return m, nil case "enter": @@ -406,6 +502,7 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.step == stepProvider { if m.activeTab > 0 { m.activeTab-- + m.formError = "" } } return m, nil @@ -414,6 +511,7 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.step == stepProvider { if m.activeTab < tabCount-1 { m.activeTab++ + m.formError = "" } } return m, nil @@ -421,6 +519,7 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case "tab": if m.step == stepProvider { m.activeTab = (m.activeTab + 1) % tabCount + m.formError = "" } return m, nil @@ -429,12 +528,27 @@ func (m providerTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.confirmingDelete = true m.deleteTargetIdx = m.customIdx m.deleteTargetName = m.customProviders[m.customIdx].name + return m, nil + } + if m.step == stepModel && m.activeTab == tabCustom && m.customIdx < len(m.customProviders) { + models := m.models() + if m.modelIdx < len(models) { + m.confirmingDeleteModel = true + m.deleteModelName = models[m.modelIdx] + } + } + return m, nil + + case "e": + if m.step == stepProvider && m.activeTab == tabCustom && !m.creatingCustom && m.customIdx < len(m.customProviders) { + m.enterEditCustomProvider() + return m, m.cpNameInput.Focus() } return m, nil } default: - if m.step == stepProvider && m.creatingCustom { + if m.step == stepProvider && (m.creatingCustom || m.editingCustom) { return m.passThroughCPInput(msg) } if m.step == stepProvider && m.inManualForm { @@ -460,28 +574,89 @@ func (m providerTUIModel) updateCustomModelInput(key string, msg tea.KeyPressMsg m.customModel = false m.modelInput.Blur() m.modelInput.SetValue("") + m.formError = "" return m, nil case "enter": - if m.modelInput.Value() != "" { - m.customModel = false - m.modelInput.Blur() - m.step = stepAPIKey - m.loadExistingAPIKey() - return m, m.apiKeyInput.Focus() + name := strings.TrimSpace(m.modelInput.Value()) + if name == "" { + return m, nil + } + for _, existing := range m.models() { + if existing == name { + m.formError = fmt.Sprintf("Already in list: %s", name) + return m, nil + } } + m.formError = "" + if err := m.addCustomModelToSession(name); err != nil { + m.formError = err.Error() + return m, nil + } + m.customModel = false + m.modelInput.Blur() + m.modelInput.SetValue("") + // Reposition the cursor on the first newly-added model so the user + // can see what just landed. + m.refreshModelSelectionForCustom() return m, nil default: var cmd tea.Cmd m.modelInput, cmd = m.modelInput.Update(msg) + m.formError = "" return m, cmd } } +// addCustomModelToSession appends a single model name to the current custom +// provider's Models list and persists in-memory state to disk. It does not +// change the active model — the user picks that explicitly from the list +// afterwards. +func (m *providerTUIModel) addCustomModelToSession(name string) error { + if m.existingCfg == nil { + return nil + } + cp, ok := m.selectedCustomProvider() + if !ok { + return nil + } + entry := m.customProviderEntry(cp.name, cp.entry) + prevEntry := cloneProviderEntry(entry) + entry.Models = append(entry.Models, name) + if m.existingCfg.CustomProviders == nil { + m.existingCfg.CustomProviders = make(map[string]ProviderEntry) + } + m.existingCfg.CustomProviders[cp.name] = entry + cp.entry = entry + m.customProviders[m.customIdx] = cp + if m.configPath != "" { + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + m.existingCfg.CustomProviders[cp.name] = prevEntry + cp.entry = prevEntry + m.customProviders[m.customIdx] = cp + return fmt.Errorf("failed to save models: %w", err) + } + } + m.savedInSession = true + return nil +} + +// refreshModelSelectionForCustom moves the cursor to "Enter custom model name..." +// after the user adds models via the input field. +func (m *providerTUIModel) refreshModelSelectionForCustom() { + models := m.models() + m.modelIdx = 0 + if len(models) == 0 { + return + } + m.modelIdx = len(models) // land on "Enter custom model name..." +} + func (m providerTUIModel) updateAPIKeyInput(key string, msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { switch key { case "esc": m.apiKeyInput.Blur() m.step = stepModel + m.formError = "" return m, nil case "enter": m.confirmed = true @@ -500,6 +675,7 @@ func (m providerTUIModel) updateAPIKeyInput(key string, msg tea.KeyPressMsg) (te } var cmd tea.Cmd m.apiKeyInput, cmd = m.apiKeyInput.Update(msg) + m.formError = "" return m, cmd } } @@ -512,35 +688,109 @@ func (m providerTUIModel) updateCustomProviderForm(key string, msg tea.KeyPressM case "esc": if m.cpStep == cpStepName { m.creatingCustom = false + m.editingCustom = false + m.editTargetName = "" m.cpNameInput.Blur() + m.cpNameInput.SetValue("") + m.cpURLInput.SetValue("") + m.cpAuthInput.SetValue("") + m.apiKeyInput.SetValue("") + m.apiKeyMasked = false + m.apiKeyOriginal = "" + m.formError = "" return m, nil } m.blurCPStep() - m.cpStep-- + if m.editingCustom && m.cpStep == cpStepAPIKey { + m.cpStep = cpStepBaseURL + } else { + m.cpStep-- + } + m.formError = "" return m, m.focusCPStep() case "enter": return m.handleCustomFormEnter() - case "up", "k": - if m.cpStep == cpStepProtocol && m.cpProtocolIdx > 0 { - m.cpProtocolIdx-- + default: + if m.cpStep == cpStepProtocol { + switch key { + case "up", "k": + if m.cpProtocolIdx > 0 { + m.cpProtocolIdx-- + } + return m, nil + case "down", "j": + if m.cpProtocolIdx < len(cpProtocols)-1 { + m.cpProtocolIdx++ + } + return m, nil + } } - return m, nil - case "down", "j": - if m.cpStep == cpStepProtocol && m.cpProtocolIdx < len(cpProtocols)-1 { - m.cpProtocolIdx++ + if m.cpStep == cpStepAPIKey { + if m.apiKeyMasked { + if len(key) == 1 { + m.apiKeyMasked = false + m.apiKeyInput.SetValue("") + } else { + return m, nil + } + } + var cmd tea.Cmd + m.apiKeyInput, cmd = m.apiKeyInput.Update(msg) + return m, cmd } - return m, nil - default: return m.passThroughCPInput(msg) } } +func (m *providerTUIModel) enterEditCustomProvider() { + cp := m.customProviders[m.customIdx] + entry := m.customProviderEntry(cp.name, cp.entry) + m.editingCustom = true + m.editTargetName = cp.name + m.cpStep = cpStepName + m.formError = "" + protoIdx := 1 + if entry.Protocol == "anthropic" { + protoIdx = 0 + } + m.cpProtocolIdx = protoIdx + m.cpNameInput.SetValue(cp.name) + m.cpURLInput.SetValue(entry.URL) + m.cpAuthInput.SetValue(entry.AuthHeader) + if entry.APIKey != "" { + m.apiKeyOriginal = entry.APIKey + m.apiKeyMasked = true + m.apiKeyInput.SetValue(strings.Repeat("*", 20)) + } else { + m.apiKeyInput.SetValue("") + m.apiKeyMasked = false + m.apiKeyOriginal = "" + } +} + +func authHeaderFormError(raw string) string { + return fmt.Sprintf( + "Unsupported Auth Header %q. Use 'authorization' (default), 'x-api-key', or leave empty.", + strings.TrimSpace(raw), + ) +} + func (m providerTUIModel) handleCustomFormEnter() (tea.Model, tea.Cmd) { switch m.cpStep { case cpStepName: - if m.cpNameInput.Value() == "" { + name := m.cpNameInput.Value() + if name == "" { + return m, nil + } + if m.creatingCustom && m.customProviderNameTaken(name) { + m.formError = fmt.Sprintf(`Provider "%s" already exists`, name) + return m, nil + } + if m.editingCustom && name != m.editTargetName && m.customProviderNameTaken(name) { + m.formError = fmt.Sprintf(`Provider "%s" already exists`, name) return m, nil } + m.formError = "" m.cpNameInput.Blur() m.cpStep = cpStepProtocol return m, nil @@ -552,43 +802,234 @@ func (m providerTUIModel) handleCustomFormEnter() (tea.Model, tea.Cmd) { return m, nil } m.cpURLInput.Blur() - m.cpStep = cpStepModel - return m, m.cpModelInput.Focus() - case cpStepModel: - if m.cpModelInput.Value() == "" { - return m, nil - } - m.cpModelInput.Blur() - m.cpStep = cpStepModels - return m, m.cpModelsInput.Focus() - case cpStepModels: - m.cpModelsInput.Blur() m.cpStep = cpStepAPIKey - m.apiKeyInput.SetValue("") - m.apiKeyMasked = false - return m, m.apiKeyInput.Focus() + if m.creatingCustom { + m.apiKeyInput.SetValue("") + m.apiKeyMasked = false + } + return m, m.focusCPStep() case cpStepAPIKey: m.apiKeyInput.Blur() m.cpStep = cpStepAuthHeader return m, m.cpAuthInput.Focus() case cpStepAuthHeader: + raw := m.cpAuthInput.Value() + if _, err := llm.NormalizeAuthHeader(raw); err != nil { + m.formError = authHeaderFormError(raw) + return m, nil + } m.cpAuthInput.Blur() + if m.editingCustom { + r := m.result() + if err := m.applyEditCustomProviderSave(); err != nil { + return m, nil + } + // Edit succeeded — drop the user into the model list for this provider. + m.editingCustom = false + m.editTargetName = "" + if idx := m.findCustomIdx(r.provider); idx >= 0 { + m.customIdx = idx + } + m.step = stepModel + m.prepareModelSelection(m.customProviderEntry(r.provider, ProviderEntry{}).Model) + return m, nil + } + if m.creatingCustom { + return m.applyCreateCustomProvider() + } m.confirmed = true return m, tea.Quit } return m, nil } +func (m providerTUIModel) applyCreateCustomProvider() (tea.Model, tea.Cmd) { + if m.existingCfg == nil { + m.formError = "failed to save: config not loaded" + return m, nil + } + if m.configPath == "" { + m.formError = "failed to save: config path not available" + return m, nil + } + r := m.result() + if r.provider == "" { + m.formError = "Provider name is required" + m.cpStep = cpStepName + return m, m.cpNameInput.Focus() + } + if m.customProviderNameTaken(r.provider) { + m.formError = fmt.Sprintf(`Provider "%s" already exists`, r.provider) + m.cpStep = cpStepName + return m, m.cpNameInput.Focus() + } + + if m.existingCfg.CustomProviders == nil { + m.existingCfg.CustomProviders = make(map[string]ProviderEntry) + } + + entry := ProviderEntry{ + URL: r.url, + Protocol: r.protocol, + AuthHeader: r.authHeader, + } + if r.apiKey != "" { + entry.APIKey = r.apiKey + } + m.existingCfg.CustomProviders[r.provider] = entry + + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + m.formError = fmt.Sprintf("failed to save: %v", err) + return m, nil + } + + m.customProviders = collectCustomProviders(m.existingCfg) + if idx := m.findCustomIdx(r.provider); idx >= 0 { + m.customIdx = idx + } + m.creatingCustom = false + m.cpNameInput.SetValue("") + m.cpURLInput.SetValue("") + m.cpAuthInput.SetValue("") + m.apiKeyInput.SetValue("") + m.apiKeyMasked = false + m.apiKeyOriginal = "" + m.formError = "" + m.cpStep = cpStepName + m.savedInSession = true + // Drop into the model selection step so the user picks/adds a model for + // the newly created provider right away. + m.step = stepModel + m.prepareModelSelection("") + return m, nil +} + +// cloneProviderEntry deep-copies a ProviderEntry so callers (rollback paths, +// map cloning) can safely mutate the returned value without aliasing the +// original's slice or map fields. +func cloneProviderEntry(v ProviderEntry) ProviderEntry { + out := ProviderEntry{ + APIKey: v.APIKey, + URL: v.URL, + Protocol: v.Protocol, + Model: v.Model, + Models: append([]string(nil), v.Models...), + AuthHeader: v.AuthHeader, + } + if v.ExtraBody != nil { + out.ExtraBody = make(map[string]any, len(v.ExtraBody)) + for k, val := range v.ExtraBody { + // Shallow copy only: nested maps/slices inside val are not cloned. + out.ExtraBody[k] = val + } + } + return out +} + +func cloneCustomProvidersMap(src map[string]ProviderEntry) map[string]ProviderEntry { + if src == nil { + return nil + } + out := make(map[string]ProviderEntry, len(src)) + for k, v := range src { + out[k] = cloneProviderEntry(v) + } + return out +} + +func cloneCustomProviderList(src []customProviderListItem) []customProviderListItem { + out := make([]customProviderListItem, len(src)) + for i, cp := range src { + out[i] = customProviderListItem{name: cp.name, entry: cloneProviderEntry(cp.entry)} + } + return out +} + +func (m *providerTUIModel) applyEditCustomProviderSave() error { + if m.existingCfg == nil { + m.formError = "failed to save: config not loaded" + return fmt.Errorf("config not loaded") + } + if m.configPath == "" { + m.formError = "failed to save: config path not available" + return fmt.Errorf("config path not available") + } + r := m.result() + backupProviders := cloneCustomProvidersMap(m.existingCfg.CustomProviders) + backupActiveProvider := m.existingCfg.Provider + backupActiveModel := m.existingCfg.Model + backupCustomList := cloneCustomProviderList(m.customProviders) + + if m.existingCfg.CustomProviders == nil { + m.existingCfg.CustomProviders = make(map[string]ProviderEntry) + } + entry := m.existingCfg.CustomProviders[r.editTargetName] + if r.model != "" { + entry.Model = r.model + } + if len(r.models) > 0 { + entry.Models = append([]string(nil), r.models...) + } + entry.Models = ensureModelInList(entry.Models, r.model) + // Optional fields are always applied so users can intentionally clear them. + // To detect "user cleared the API key" vs "user left it masked/untouched", + // apiKey is only overwritten when the user actively typed something. + entry.URL = r.url + entry.Protocol = r.protocol + entry.AuthHeader = r.authHeader + if r.apiKey != "" { + entry.APIKey = r.apiKey + } + // If name changed, delete old key + if r.editTargetName != "" && r.editTargetName != r.provider { + if _, exists := m.existingCfg.CustomProviders[r.provider]; exists { + m.formError = fmt.Sprintf(`Provider "%s" already exists`, r.provider) + return fmt.Errorf("provider %q already exists", r.provider) + } + delete(m.existingCfg.CustomProviders, r.editTargetName) + if m.existingCfg.Provider == r.editTargetName { + m.existingCfg.Provider = r.provider + m.existingCfg.Model = "" + } + } + m.existingCfg.CustomProviders[r.provider] = entry + + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + m.formError = fmt.Sprintf("failed to save: %v", err) + if reloaded, reloadErr := loadOrCreateConfig(m.configPath); reloadErr == nil { + m.existingCfg = reloaded + m.customProviders = collectCustomProviders(reloaded) + } else { + m.existingCfg.CustomProviders = backupProviders + m.existingCfg.Provider = backupActiveProvider + m.existingCfg.Model = backupActiveModel + m.customProviders = backupCustomList + } + return fmt.Errorf("save config: %w", err) + } + m.customProviders = collectCustomProviders(m.existingCfg) + if idx := m.findCustomIdx(r.provider); idx >= 0 { + m.customIdx = idx + } + m.savedInSession = true + return nil +} + +func (m providerTUIModel) findCustomIdx(name string) int { + for i, cp := range m.customProviders { + if cp.name == name { + return i + } + } + return -1 +} + func (m *providerTUIModel) blurCPStep() { switch m.cpStep { case cpStepName: m.cpNameInput.Blur() case cpStepBaseURL: m.cpURLInput.Blur() - case cpStepModel: - m.cpModelInput.Blur() - case cpStepModels: - m.cpModelsInput.Blur() case cpStepAPIKey: m.apiKeyInput.Blur() case cpStepAuthHeader: @@ -596,16 +1037,12 @@ func (m *providerTUIModel) blurCPStep() { } } -func (m providerTUIModel) focusCPStep() tea.Cmd { +func (m *providerTUIModel) focusCPStep() tea.Cmd { switch m.cpStep { case cpStepName: return m.cpNameInput.Focus() case cpStepBaseURL: return m.cpURLInput.Focus() - case cpStepModel: - return m.cpModelInput.Focus() - case cpStepModels: - return m.cpModelsInput.Focus() case cpStepAPIKey: return m.apiKeyInput.Focus() case cpStepAuthHeader: @@ -621,18 +1058,15 @@ func (m providerTUIModel) passThroughCPInput(msg tea.Msg) (tea.Model, tea.Cmd) { m.cpNameInput, cmd = m.cpNameInput.Update(msg) case cpStepBaseURL: m.cpURLInput, cmd = m.cpURLInput.Update(msg) - case cpStepModel: - m.cpModelInput, cmd = m.cpModelInput.Update(msg) - case cpStepModels: - m.cpModelsInput, cmd = m.cpModelsInput.Update(msg) case cpStepAPIKey: - if m.apiKeyMasked { - return m, nil - } + // masked unlock is handled in updateCustomProviderForm default branch m.apiKeyInput, cmd = m.apiKeyInput.Update(msg) case cpStepAuthHeader: m.cpAuthInput, cmd = m.cpAuthInput.Update(msg) } + if _, ok := msg.(tea.KeyPressMsg); ok { + m.formError = "" + } return m, cmd } @@ -648,20 +1082,56 @@ func (m providerTUIModel) updateManualForm(key string, msg tea.KeyPressMsg) (tea if m.existingCfg != nil { m.manualURLInput.SetValue(m.existingCfg.Llm.URL) m.manualModelInput.SetValue(m.existingCfg.Llm.Model) - m.manualTokenInput.SetValue(m.existingCfg.Llm.AuthToken) + m.manualAuthHeaderInput.SetValue(m.existingCfg.Llm.AuthHeader) + if m.existingCfg.Llm.AuthToken != "" { + m.manualTokenOriginal = m.existingCfg.Llm.AuthToken + m.manualTokenMasked = true + m.manualTokenInput.SetValue(strings.Repeat("*", 20)) + } else { + m.manualTokenInput.SetValue("") + m.manualTokenMasked = false + m.manualTokenOriginal = "" + } } else { m.manualURLInput.SetValue("") m.manualModelInput.SetValue("") + m.manualAuthHeaderInput.SetValue("") m.manualTokenInput.SetValue("") + m.manualTokenMasked = false + m.manualTokenOriginal = "" } + m.formError = "" return m, nil } m.blurManualStep() m.manualStep-- + m.formError = "" return m, m.focusManualStep() case "enter": return m.handleManualFormEnter() default: + if m.manualStep == manualStepProtocol { + switch key { + case "up", "k": + if m.manualProtocolIdx > 0 { + m.manualProtocolIdx-- + } + return m, nil + case "down", "j": + if m.manualProtocolIdx < len(cpProtocols)-1 { + m.manualProtocolIdx++ + } + return m, nil + } + } + if m.manualStep == manualStepAuthToken && m.manualTokenMasked { + if len(key) == 1 { + m.manualTokenMasked = false + m.manualTokenInput.SetValue("") + } else { + return m, nil + } + } return m.passThroughManualInput(msg) } } @@ -681,6 +1151,27 @@ func (m providerTUIModel) updateDeleteConfirm(key string) (tea.Model, tea.Cmd) { if m.customIdx >= len(m.customProviders) && m.customIdx > 0 { m.customIdx = len(m.customProviders) - 1 } + if m.existingCfg != nil { + if m.existingCfg.CustomProviders != nil { + delete(m.existingCfg.CustomProviders, m.deleteTargetName) + } + if m.existingCfg.Provider == m.deleteTargetName { + m.existingCfg.Provider = "" + m.existingCfg.Model = "" + } + if m.configPath != "" { + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + if reloaded, reloadErr := loadOrCreateConfig(m.configPath); reloadErr == nil { + m.existingCfg = reloaded + m.customProviders = collectCustomProviders(reloaded) + } + m.formError = fmt.Sprintf("failed to save: %v", err) + m.confirmingDelete = false + return m, nil + } + } + } + m.savedInSession = true m.confirmingDelete = false return m, nil case "n", "N", "esc": @@ -693,6 +1184,64 @@ func (m providerTUIModel) updateDeleteConfirm(key string) (tea.Model, tea.Cmd) { return m, nil } +func (m providerTUIModel) updateDeleteModelConfirm(key string) (tea.Model, tea.Cmd) { + switch key { + case "y", "Y": + if m.customIdx >= len(m.customProviders) { + m.confirmingDeleteModel = false + return m, nil + } + models := m.models() + if m.modelIdx < len(models) { + cp := m.customProviders[m.customIdx] + cp.entry.Models = removeModels(cp.entry.Models, []string{m.deleteModelName}) + if cp.entry.Model == m.deleteModelName { + cp.entry.Model = "" + } + if m.existingCfg != nil && m.existingCfg.Provider == cp.name && + m.existingCfg.Model == m.deleteModelName { + m.existingCfg.Model = "" + } + m.customProviders[m.customIdx] = cp + if m.existingCfg != nil { + if m.existingCfg.CustomProviders == nil { + m.existingCfg.CustomProviders = make(map[string]ProviderEntry) + } + m.existingCfg.CustomProviders[cp.name] = cp.entry + } + if m.configPath != "" { + if err := saveConfig(m.configPath, m.existingCfg); err != nil { + if reloaded, reloadErr := loadOrCreateConfig(m.configPath); reloadErr == nil { + m.existingCfg = reloaded + m.customProviders = collectCustomProviders(reloaded) + } + m.formError = fmt.Sprintf("failed to save: %v", err) + m.confirmingDeleteModel = false + return m, nil + } + } + updated := m.models() + if m.modelIdx >= len(updated) { + if len(updated) > 0 { + m.modelIdx = len(updated) - 1 + } else { + m.modelIdx = 0 + } + } + } + m.savedInSession = true + m.confirmingDeleteModel = false + return m, nil + case "n", "N", "esc": + m.confirmingDeleteModel = false + return m, nil + case "ctrl+c": + m.cancelled = true + return m, tea.Quit + } + return m, nil +} + func (m providerTUIModel) handleManualFormEnter() (tea.Model, tea.Cmd) { switch m.manualStep { case manualStepURL: @@ -700,6 +1249,9 @@ func (m providerTUIModel) handleManualFormEnter() (tea.Model, tea.Cmd) { return m, nil } m.manualURLInput.Blur() + m.manualStep = manualStepProtocol + return m, nil + case manualStepProtocol: m.manualStep = manualStepModel return m, m.manualModelInput.Focus() case manualStepModel: @@ -710,7 +1262,19 @@ func (m providerTUIModel) handleManualFormEnter() (tea.Model, tea.Cmd) { m.manualStep = manualStepAuthToken return m, m.manualTokenInput.Focus() case manualStepAuthToken: + if m.manualTokenInput.Value() == "" && m.manualTokenOriginal == "" { + return m, nil + } m.manualTokenInput.Blur() + m.manualStep = manualStepAuthHeader + return m, m.manualAuthHeaderInput.Focus() + case manualStepAuthHeader: + raw := m.manualAuthHeaderInput.Value() + if _, err := llm.NormalizeAuthHeader(raw); err != nil { + m.formError = authHeaderFormError(raw) + return m, nil + } + m.manualAuthHeaderInput.Blur() m.confirmed = true return m, tea.Quit } @@ -721,10 +1285,14 @@ func (m *providerTUIModel) blurManualStep() { switch m.manualStep { case manualStepURL: m.manualURLInput.Blur() + case manualStepProtocol: + // no input to blur case manualStepModel: m.manualModelInput.Blur() case manualStepAuthToken: m.manualTokenInput.Blur() + case manualStepAuthHeader: + m.manualAuthHeaderInput.Blur() } } @@ -732,10 +1300,14 @@ func (m providerTUIModel) focusManualStep() tea.Cmd { switch m.manualStep { case manualStepURL: return m.manualURLInput.Focus() + case manualStepProtocol: + return nil case manualStepModel: return m.manualModelInput.Focus() case manualStepAuthToken: return m.manualTokenInput.Focus() + case manualStepAuthHeader: + return m.manualAuthHeaderInput.Focus() } return nil } @@ -745,10 +1317,17 @@ func (m providerTUIModel) passThroughManualInput(msg tea.Msg) (tea.Model, tea.Cm switch m.manualStep { case manualStepURL: m.manualURLInput, cmd = m.manualURLInput.Update(msg) + case manualStepProtocol: + return m, nil case manualStepModel: m.manualModelInput, cmd = m.manualModelInput.Update(msg) case manualStepAuthToken: m.manualTokenInput, cmd = m.manualTokenInput.Update(msg) + case manualStepAuthHeader: + m.manualAuthHeaderInput, cmd = m.manualAuthHeaderInput.Update(msg) + } + if _, ok := msg.(tea.KeyPressMsg); ok { + m.formError = "" } return m, cmd } @@ -761,8 +1340,8 @@ func (m providerTUIModel) handleEnter() (tea.Model, tea.Cmd) { m.step = stepModel currentModel := "" if m.existingCfg != nil { - if entry, ok := m.existingCfg.Providers[m.currentProvider().Name]; ok && entry.Model != "" { - currentModel = entry.Model + if entry, ok := m.existingCfg.Providers[m.currentProvider().Name]; ok { + currentModel = activeModelForProvider(m.existingCfg, m.currentProvider().Name, entry) } } m.prepareModelSelection(currentModel) @@ -773,11 +1352,10 @@ func (m providerTUIModel) handleEnter() (tea.Model, tea.Cmd) { if m.customIdx == addIdx { m.creatingCustom = true m.cpStep = cpStepName - m.cpProtocolIdx = 1 // default openai + m.cpProtocolIdx = 0 // default anthropic + m.formError = "" m.cpNameInput.SetValue("") m.cpURLInput.SetValue("") - m.cpModelInput.SetValue("") - m.cpModelsInput.SetValue("") m.cpAuthInput.SetValue("") m.apiKeyInput.SetValue("") m.apiKeyMasked = false @@ -785,7 +1363,8 @@ func (m providerTUIModel) handleEnter() (tea.Model, tea.Cmd) { } cp := m.customProviders[m.customIdx] m.step = stepModel - m.prepareModelSelection(cp.entry.Model) + entry := m.customProviderEntry(cp.name, cp.entry) + m.prepareModelSelection(activeModelForProvider(m.existingCfg, cp.name, entry)) return m, nil case tabManual: @@ -799,7 +1378,12 @@ func (m providerTUIModel) handleEnter() (tea.Model, tea.Cmd) { m.customModel = true return m, m.modelInput.Focus() } + if err := m.syncSessionModelSelection(); err != nil { + m.formError = err.Error() + return m, nil + } m.step = stepAPIKey + m.formError = "" m.loadExistingAPIKey() return m, m.apiKeyInput.Focus() } @@ -813,15 +1397,21 @@ func (m providerTUIModel) handleUp() (tea.Model, tea.Cmd) { case tabOfficial: if m.officialIdx > 0 { m.officialIdx-- + } else if len(m.providers) > 0 { + m.officialIdx = len(m.providers) - 1 } case tabCustom: if m.customIdx > 0 { m.customIdx-- + } else { + m.customIdx = m.customListCount() - 1 } } case stepModel: if m.modelIdx > 0 { m.modelIdx-- + } else { + m.modelIdx = m.modelCount() - 1 } } return m, nil @@ -834,15 +1424,21 @@ func (m providerTUIModel) handleDown() (tea.Model, tea.Cmd) { case tabOfficial: if m.officialIdx < len(m.providers)-1 { m.officialIdx++ + } else if len(m.providers) > 0 { + m.officialIdx = 0 } case tabCustom: if m.customIdx < m.customListCount()-1 { m.customIdx++ + } else { + m.customIdx = 0 } } case stepModel: if m.modelIdx < m.modelCount()-1 { m.modelIdx++ + } else { + m.modelIdx = 0 } } return m, nil @@ -902,22 +1498,32 @@ func (m providerTUIModel) result() providerTUIResult { } case tabCustom: - if m.creatingCustom { + if m.creatingCustom || m.editingCustom { protocol := cpProtocols[m.cpProtocolIdx] - models := mergeModelLists( - []string{m.cpModelInput.Value()}, - strings.Split(m.cpModelsInput.Value(), ","), - ) - return providerTUIResult{ - provider: m.cpNameInput.Value(), - model: m.cpModelInput.Value(), - models: models, - apiKey: m.apiKeyInput.Value(), - isCustom: true, - url: m.cpURLInput.Value(), - protocol: protocol, - authHeader: m.cpAuthInput.Value(), + apiKey := m.apiKeyInput.Value() + if m.apiKeyMasked { + apiKey = m.apiKeyOriginal + } + authHeader, _ := llm.NormalizeAuthHeader(m.cpAuthInput.Value()) + r := providerTUIResult{ + provider: m.cpNameInput.Value(), + apiKey: apiKey, + isCustom: true, + isEdit: m.editingCustom, + editTargetName: m.editTargetName, + url: m.cpURLInput.Value(), + protocol: protocol, + authHeader: authHeader, } + // Models are managed in the model selection step, not in the + // create/edit form. Preserve existing model/models when editing. + if m.editingCustom { + if idx := m.findCustomIdx(m.editTargetName); idx >= 0 { + r.model = m.customProviders[idx].entry.Model + r.models = m.customProviders[idx].entry.Models + } + } + return r } if m.customIdx < len(m.customProviders) { cp := m.customProviders[m.customIdx] @@ -934,7 +1540,7 @@ func (m providerTUIModel) result() providerTUIResult { return providerTUIResult{ provider: cp.name, model: model, - models: mergeModelLists([]string{model}, cp.entry.Models), + models: append([]string(nil), cp.entry.Models...), apiKey: apiKey, isCustom: true, url: cp.entry.URL, @@ -945,17 +1551,38 @@ func (m providerTUIModel) result() providerTUIResult { return providerTUIResult{} case tabManual: + apiKey := m.manualTokenInput.Value() + if m.manualTokenMasked || (apiKey == "" && m.manualTokenOriginal != "") { + apiKey = m.manualTokenOriginal + } + authHeader, _ := llm.NormalizeAuthHeader(m.manualAuthHeaderInput.Value()) return providerTUIResult{ - isManual: true, - url: m.manualURLInput.Value(), - model: m.manualModelInput.Value(), - apiKey: m.manualTokenInput.Value(), + isManual: true, + url: m.manualURLInput.Value(), + model: m.manualModelInput.Value(), + apiKey: apiKey, + protocol: cpProtocols[m.manualProtocolIdx], + authHeader: authHeader, } } return providerTUIResult{} } +func listCursorPrefix(isCursor bool) string { + if isCursor { + return " " + tuiCursorStyle.Render(tuiCursor) + " " + } + return " " +} + +func renderListName(name string, isCursor bool) string { + if isCursor { + return tuiSelectedItemStyle.Render(name) + } + return tuiItemStyle.Render(name) +} + // --- View --- func (m providerTUIModel) View() tea.View { @@ -1011,12 +1638,12 @@ func (m providerTUIModel) viewProvider(s *strings.Builder) { } s.WriteString("\n") - if m.creatingCustom || m.inManualForm { + if m.creatingCustom || m.editingCustom || m.inManualForm { s.WriteString(tuiHelpStyle.Render(" Enter Confirm · Esc Back")) } else if m.confirmingDelete { s.WriteString(tuiHelpStyle.Render(" y Confirm · n/Esc Cancel")) } else if m.activeTab == tabCustom && m.customIdx < len(m.customProviders) { - s.WriteString(tuiHelpStyle.Render(" Enter Select · d Delete · Tab/Arrow Navigate · Esc Cancel")) + s.WriteString(tuiHelpStyle.Render(" Enter Select · e Edit · d Delete · Tab/Arrow Navigate · Esc Cancel")) } else { s.WriteString(tuiHelpStyle.Render(" Enter to select · Tab/Arrow keys to navigate · Esc to cancel")) } @@ -1028,22 +1655,14 @@ func (m providerTUIModel) viewOfficialTab(s *strings.Builder) { s.WriteString("\n\n") for i, p := range m.providers { - cursor := " " - if i == m.officialIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } - name := p.DisplayName - if i == m.officialIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(name)) - } else { - s.WriteString(cursor + tuiItemStyle.Render(name)) - } + isCursor := i == m.officialIdx + s.WriteString(listCursorPrefix(isCursor) + renderListName(p.DisplayName, isCursor)) s.WriteString("\n") } } func (m providerTUIModel) viewCustomTab(s *strings.Builder) { - if m.creatingCustom { + if m.creatingCustom || m.editingCustom { m.viewCustomProviderForm(s) return } @@ -1052,21 +1671,12 @@ func (m providerTUIModel) viewCustomTab(s *strings.Builder) { s.WriteString("\n\n") for i, cp := range m.customProviders { - cursor := " " - if i == m.customIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } - label := cp.name - if cp.entry.Model != "" { - label += " " + tuiDimStyle.Render("("+cp.entry.Model+")") - } - if i == m.customIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(cp.name)) - if cp.entry.Model != "" { - s.WriteString(" " + tuiDimStyle.Render("("+cp.entry.Model+")")) - } - } else { - s.WriteString(cursor + label) + isCursor := i == m.customIdx + activeModel := m.customProviderActiveModel(cp) + + s.WriteString(listCursorPrefix(isCursor) + renderListName(cp.name, isCursor)) + if activeModel != "" { + s.WriteString(" " + tuiDimStyle.Render("("+activeModel+")")) } s.WriteString("\n") } @@ -1099,7 +1709,11 @@ func (m providerTUIModel) viewCustomTab(s *strings.Builder) { } func (m providerTUIModel) viewCustomProviderForm(s *strings.Builder) { - s.WriteString(tuiTitleStyle.Render(" Add Custom Provider")) + title := " Add Custom Provider" + if m.editingCustom { + title = fmt.Sprintf(" Edit Custom Provider (%s)", m.editTargetName) + } + s.WriteString(tuiTitleStyle.Render(title)) s.WriteString("\n\n") type field struct { @@ -1112,8 +1726,6 @@ func (m providerTUIModel) viewCustomProviderForm(s *strings.Builder) { {"Provider name", m.cpNameInput.Value(), m.cpStep == cpStepName}, {"Protocol", cpProtocols[m.cpProtocolIdx], m.cpStep == cpStepProtocol}, {"Base URL", m.cpURLInput.Value(), m.cpStep == cpStepBaseURL}, - {"Model", m.cpModelInput.Value(), m.cpStep == cpStepModel}, - {"Models", m.cpModelsInput.Value(), m.cpStep == cpStepModels}, {"API Key", strings.Repeat("*", len(m.apiKeyInput.Value())), m.cpStep == cpStepAPIKey}, {"Auth Header", m.cpAuthInput.Value(), m.cpStep == cpStepAuthHeader}, } @@ -1126,31 +1738,39 @@ func (m providerTUIModel) viewCustomProviderForm(s *strings.Builder) { s.WriteString(" " + m.cpNameInput.View() + "\n") case cpStepProtocol: for i, proto := range cpProtocols { - cur := " " - if i == m.cpProtocolIdx { - cur = " " + tuiCursorStyle.Render(tuiCursor) + " " - } if i == m.cpProtocolIdx { + cur := " " + tuiCursorStyle.Render(tuiCursor) + " " s.WriteString(cur + tuiSelectedItemStyle.Render(proto) + "\n") } else { + cur := " " s.WriteString(cur + tuiItemStyle.Render(proto) + "\n") } } case cpStepBaseURL: s.WriteString(" " + m.cpURLInput.View() + "\n") - case cpStepModel: - s.WriteString(" " + m.cpModelInput.View() + "\n") - case cpStepModels: - s.WriteString(" " + m.cpModelsInput.View() + "\n") case cpStepAPIKey: s.WriteString(" " + m.apiKeyInput.View() + "\n") case cpStepAuthHeader: s.WriteString(" " + m.cpAuthInput.View() + "\n") } - } else if f.value != "" { - s.WriteString(" " + tuiDimStyle.Render(f.label+": "+f.value) + "\n") + } else { + display := f.value + if display == "" && f.label == "Auth Header" { + display = "(Authorization)" + } + if display == "" { + s.WriteString(" " + tuiDimStyle.Render(f.label+":") + "\n") + } else { + s.WriteString(" " + tuiDimStyle.Render(f.label+": "+display) + "\n") + } } } + + if m.formError != "" { + s.WriteString("\n") + s.WriteString(tuiErrorStyle.Render(" " + m.formError)) + s.WriteString("\n") + } } func (m providerTUIModel) viewManualTab(s *strings.Builder) { @@ -1181,8 +1801,10 @@ func (m providerTUIModel) viewManualTab(s *strings.Builder) { fields := []field{ {"URL", m.manualURLInput.Value(), m.manualStep == manualStepURL}, + {"Protocol", cpProtocols[m.manualProtocolIdx], m.manualStep == manualStepProtocol}, {"Model", m.manualModelInput.Value(), m.manualStep == manualStepModel}, {"Auth Token", strings.Repeat("*", len(m.manualTokenInput.Value())), m.manualStep == manualStepAuthToken}, + {"Auth Header", m.manualAuthHeaderInput.Value(), m.manualStep == manualStepAuthHeader}, } for _, f := range fields { @@ -1191,15 +1813,41 @@ func (m providerTUIModel) viewManualTab(s *strings.Builder) { switch m.manualStep { case manualStepURL: s.WriteString(" " + m.manualURLInput.View() + "\n") + case manualStepProtocol: + for i, proto := range cpProtocols { + if i == m.manualProtocolIdx { + cur := " " + tuiCursorStyle.Render(tuiCursor) + " " + s.WriteString(cur + tuiSelectedItemStyle.Render(proto) + "\n") + } else { + cur := " " + s.WriteString(cur + tuiItemStyle.Render(proto) + "\n") + } + } case manualStepModel: s.WriteString(" " + m.manualModelInput.View() + "\n") case manualStepAuthToken: s.WriteString(" " + m.manualTokenInput.View() + "\n") + case manualStepAuthHeader: + s.WriteString(" " + m.manualAuthHeaderInput.View() + "\n") + } + } else { + display := f.value + if display == "" && f.label == "Auth Header" { + display = "(Authorization)" + } + if display == "" { + s.WriteString(" " + tuiDimStyle.Render(f.label+":") + "\n") + } else { + s.WriteString(" " + tuiDimStyle.Render(f.label+": "+display) + "\n") } - } else if f.value != "" { - s.WriteString(" " + tuiDimStyle.Render(f.label+": "+f.value) + "\n") } } + + if m.formError != "" { + s.WriteString("\n") + s.WriteString(tuiErrorStyle.Render(" " + m.formError)) + s.WriteString("\n") + } } func (m providerTUIModel) viewModel(s *strings.Builder) { @@ -1207,40 +1855,44 @@ func (m providerTUIModel) viewModel(s *strings.Builder) { s.WriteString("\n\n") models := m.models() + for i, model := range models { - cursor := " " - if i == m.modelIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } - if i == m.modelIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(model)) - } else { - s.WriteString(cursor + tuiItemStyle.Render(model)) - } + isCursor := i == m.modelIdx + s.WriteString(listCursorPrefix(isCursor) + renderListName(model, isCursor)) s.WriteString("\n") } customIdx := len(models) - cursor := " " - if m.modelIdx == customIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } + isCursor := m.modelIdx == customIdx customLabel := "Enter custom model name..." - if m.modelIdx == customIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(customLabel)) + if isCursor { + s.WriteString(listCursorPrefix(isCursor) + tuiSelectedItemStyle.Render(customLabel)) } else { - s.WriteString(cursor + tuiDimStyle.Render(customLabel)) + s.WriteString(listCursorPrefix(isCursor) + tuiDimStyle.Render(customLabel)) } s.WriteString("\n") if m.customModel { s.WriteString("\n") s.WriteString(" " + m.modelInput.View()) + if m.formError != "" { + s.WriteString("\n") + s.WriteString(" " + tuiErrorStyle.Render(m.formError)) + } s.WriteString("\n") } s.WriteString("\n") - s.WriteString(tuiHelpStyle.Render(" ↑/↓ Select Enter Confirm Esc Back")) + + if m.confirmingDeleteModel { + s.WriteString(" " + tuiSelectedItemStyle.Render(fmt.Sprintf("Delete %q? (y/n)", m.deleteModelName))) + s.WriteString("\n") + s.WriteString(tuiHelpStyle.Render(" y Confirm · n/Esc Cancel")) + } else if m.activeTab == tabCustom && m.customIdx < len(m.customProviders) { + s.WriteString(tuiHelpStyle.Render(" ↑/↓ Select Enter Confirm d Delete Esc Back")) + } else { + s.WriteString(tuiHelpStyle.Render(" ↑/↓ Select Enter Confirm Esc Back")) + } s.WriteString("\n") } @@ -1306,6 +1958,9 @@ var ( tuiInactiveTabStyle = lipgloss.NewStyle(). Foreground(lipgloss.Color("8")) + + tuiErrorStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("9")) ) // --- Model-only TUI (for `ocr config model`) --- @@ -1319,6 +1974,7 @@ type modelTUIModel struct { modelIdx int customModel bool modelInput textinput.Model + activeModel string confirmed bool cancelled bool @@ -1326,20 +1982,21 @@ type modelTUIModel struct { func newModelTUI(provider llm.Provider, currentModel string) modelTUIModel { mi := textinput.New() - mi.Placeholder = "model name" - mi.SetWidth(40) + mi.Placeholder = "model name(s), comma-separated" + mi.SetWidth(50) m := modelTUIModel{ - provider: provider, - models: provider.Models, - width: 80, - height: 24, - modelInput: mi, + provider: provider, + models: provider.Models, + width: 80, + height: 24, + modelInput: mi, + activeModel: currentModel, } if currentModel != "" { found := false - for i, model := range provider.Models { + for i, model := range m.models { if model == currentModel { m.modelIdx = i found = true @@ -1347,7 +2004,7 @@ func newModelTUI(provider llm.Provider, currentModel string) modelTUIModel { } } if !found { - m.modelIdx = len(provider.Models) + m.modelIdx = len(m.models) m.modelInput.SetValue(currentModel) } } @@ -1411,11 +2068,15 @@ func (m modelTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case "up", "k": if m.modelIdx > 0 { m.modelIdx-- + } else { + m.modelIdx = m.itemCount() - 1 } return m, nil case "down", "j": if m.modelIdx < m.itemCount()-1 { m.modelIdx++ + } else { + m.modelIdx = 0 } return m, nil } @@ -1447,28 +2108,18 @@ func (m modelTUIModel) View() tea.View { s.WriteString("\n\n") for i, model := range m.models { - cursor := " " - if i == m.modelIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } - if i == m.modelIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(model)) - } else { - s.WriteString(cursor + tuiItemStyle.Render(model)) - } + isCursor := i == m.modelIdx + s.WriteString(listCursorPrefix(isCursor) + renderListName(model, isCursor)) s.WriteString("\n") } customIdx := len(m.models) - cursor := " " - if m.modelIdx == customIdx { - cursor = " " + tuiCursorStyle.Render(tuiCursor) + " " - } + isCursor := m.modelIdx == customIdx customLabel := "Enter custom model name..." - if m.modelIdx == customIdx { - s.WriteString(cursor + tuiSelectedItemStyle.Render(customLabel)) + if isCursor { + s.WriteString(listCursorPrefix(isCursor) + tuiSelectedItemStyle.Render(customLabel)) } else { - s.WriteString(cursor + tuiDimStyle.Render(customLabel)) + s.WriteString(listCursorPrefix(isCursor) + tuiDimStyle.Render(customLabel)) } s.WriteString("\n") diff --git a/cmd/opencodereview/provider_tui_test.go b/cmd/opencodereview/provider_tui_test.go index cb54b83..42e9d41 100644 --- a/cmd/opencodereview/provider_tui_test.go +++ b/cmd/opencodereview/provider_tui_test.go @@ -1,6 +1,8 @@ package main import ( + "os" + "path/filepath" "sort" "strings" "testing" @@ -33,13 +35,13 @@ func tabKeyMsg() tea.KeyPressMsg { } func charKey(c rune) tea.KeyPressMsg { - return tea.KeyPressMsg{Code: c} + return tea.KeyPressMsg{Code: c, Text: string(c)} } // --- Tab switching tests --- func TestProviderTUI_TabSwitchRight(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") if m.activeTab != tabOfficial { t.Fatalf("initial tab = %d, want %d", m.activeTab, tabOfficial) } @@ -65,7 +67,7 @@ func TestProviderTUI_TabSwitchRight(t *testing.T) { } func TestProviderTUI_TabSwitchLeft(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Go to manual tab first result, _ := m.Update(rightKey()) @@ -97,7 +99,7 @@ func TestProviderTUI_TabSwitchLeft(t *testing.T) { } func TestProviderTUI_TabKeyCycles(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") result, _ := m.Update(tabKeyMsg()) m2 := result.(providerTUIModel) @@ -119,7 +121,7 @@ func TestProviderTUI_TabKeyCycles(t *testing.T) { } func TestProviderTUI_TabSwitchOnlyOnStepProvider(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Advance to stepModel result, _ := m.Update(enterKey()) @@ -139,7 +141,7 @@ func TestProviderTUI_TabSwitchOnlyOnStepProvider(t *testing.T) { // --- Official tab tests (updated from original) --- func TestProviderTUI_OfficialProvidersSortedByDisplayName(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") displayNames := make([]string, len(m.providers)) normalized := make([]string, len(m.providers)) @@ -154,7 +156,7 @@ func TestProviderTUI_OfficialProvidersSortedByDisplayName(t *testing.T) { } func TestProviderTUI_EscFromModelGoesBackToProvider(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") result, _ := m.Update(enterKey()) m2 := result.(providerTUIModel) @@ -173,7 +175,7 @@ func TestProviderTUI_EscFromModelGoesBackToProvider(t *testing.T) { } func TestProviderTUI_EscFromAPIKeyGoesBackToModel(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") result, _ := m.Update(enterKey()) m2 := result.(providerTUIModel) @@ -192,7 +194,7 @@ func TestProviderTUI_EscFromAPIKeyGoesBackToModel(t *testing.T) { } func TestProviderTUI_EscFromProviderCancels(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") result, cmd := m.Update(escKey()) m2 := result.(providerTUIModel) @@ -214,7 +216,7 @@ func TestProviderTUI_EscKeyString(t *testing.T) { // --- Manual tab tests --- func TestProviderTUI_ManualTabEnterStartsForm(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Switch to manual tab result, _ := m.Update(rightKey()) @@ -237,7 +239,7 @@ func TestProviderTUI_ManualTabEnterStartsForm(t *testing.T) { } func TestProviderTUI_ManualFormEscFromURLExitsForm(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Switch to manual tab and enter form result, _ := m.Update(rightKey()) @@ -269,7 +271,7 @@ func TestProviderTUI_ManualFormEscRestoresOriginalValues(t *testing.T) { AuthToken: "token-123", }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Enter the form result, _ := m.Update(enterKey()) @@ -293,8 +295,11 @@ func TestProviderTUI_ManualFormEscRestoresOriginalValues(t *testing.T) { if m3.manualModelInput.Value() != "test-model" { t.Errorf("Model not restored: got %q, want %q", m3.manualModelInput.Value(), "test-model") } - if m3.manualTokenInput.Value() != "token-123" { - t.Errorf("Token not restored: got %q, want %q", m3.manualTokenInput.Value(), "token-123") + if !m3.manualTokenMasked { + t.Error("Token should be masked after Esc restore") + } + if m3.manualTokenOriginal != "token-123" { + t.Errorf("Token original not restored: got %q, want %q", m3.manualTokenOriginal, "token-123") } } @@ -306,7 +311,7 @@ func TestProviderTUI_ManualFormPrefilledValues(t *testing.T) { AuthToken: "token-123", }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") if m.activeTab != tabManual { t.Fatalf("should auto-select manual tab when Llm.URL is set, got %d", m.activeTab) @@ -317,8 +322,14 @@ func TestProviderTUI_ManualFormPrefilledValues(t *testing.T) { if m.manualModelInput.Value() != "test-model" { t.Errorf("Model not prefilled: got %q", m.manualModelInput.Value()) } - if m.manualTokenInput.Value() != "token-123" { - t.Errorf("Token not prefilled: got %q", m.manualTokenInput.Value()) + if !m.manualTokenMasked { + t.Error("Token should be masked when prefilled") + } + if m.manualTokenOriginal != "token-123" { + t.Errorf("Token original not prefilled: got %q, want %q", m.manualTokenOriginal, "token-123") + } + if m.manualTokenInput.Value() != strings.Repeat("*", 20) { + t.Errorf("Token input not masked display: got %q", m.manualTokenInput.Value()) } } @@ -330,7 +341,7 @@ func TestProviderTUI_ManualResult(t *testing.T) { AuthToken: "token-123", }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Enter the form result, _ := m.Update(enterKey()) @@ -361,7 +372,7 @@ func TestProviderTUI_ManualFormPrefilledWhenProviderSet(t *testing.T) { AuthToken: "manual-token", }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") if m.activeTab != tabCustom { t.Fatalf("should auto-select custom tab, got %d", m.activeTab) @@ -372,15 +383,77 @@ func TestProviderTUI_ManualFormPrefilledWhenProviderSet(t *testing.T) { if m.manualModelInput.Value() != "manual-model" { t.Errorf("Model not prefilled: got %q", m.manualModelInput.Value()) } - if m.manualTokenInput.Value() != "manual-token" { - t.Errorf("Token not prefilled: got %q", m.manualTokenInput.Value()) + if !m.manualTokenMasked { + t.Error("Token should be masked when prefilled") + } + if m.manualTokenOriginal != "manual-token" { + t.Errorf("Token original not prefilled: got %q, want %q", m.manualTokenOriginal, "manual-token") + } +} + +func TestProviderTUI_ManualFormPrefillsAuthHeader(t *testing.T) { + cfg := &Config{ + Llm: LlmConfig{ + URL: "https://manual.example.com/v1", + Model: "manual-model", + AuthToken: "manual-token", + AuthHeader: "X-Custom-Auth", + }, + } + m := newProviderTUI(cfg, "") + + if got := m.manualAuthHeaderInput.Value(); got != "X-Custom-Auth" { + t.Errorf("manualAuthHeaderInput not prefilled: got %q, want %q", got, "X-Custom-Auth") + } +} + +func TestProviderTUI_ManualFormSkipsEmptyTokenWhenOriginalExists(t *testing.T) { + cfg := &Config{ + Llm: LlmConfig{ + URL: "https://example.com/v1", + Model: "test-model", + AuthToken: "token-123", + }, + } + m := newProviderTUI(cfg, "") + m.inManualForm = true + m.manualStep = manualStepAuthToken + m.manualTokenOriginal = "token-123" + m.manualTokenMasked = false + m.manualTokenInput.SetValue("") + m.manualTokenInput.Focus() + + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + if m2.manualStep != manualStepAuthHeader { + t.Errorf("manualStep = %d, want %d", m2.manualStep, manualStepAuthHeader) + } + + m2.confirmed = true + r := m2.result() + if r.apiKey != "token-123" { + t.Errorf("result apiKey = %q, want %q", r.apiKey, "token-123") + } +} + +func TestProviderTUI_ManualFormRequiresTokenOnFirstSetup(t *testing.T) { + m := newProviderTUI(&Config{}, "") + m.inManualForm = true + m.manualStep = manualStepAuthToken + m.manualTokenInput.SetValue("") + m.manualTokenInput.Focus() + + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + if m2.manualStep != manualStepAuthToken { + t.Errorf("should stay on auth token step, got %d", m2.manualStep) } } // --- Custom tab tests --- func TestProviderTUI_CustomTabShowsAddOption(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Switch to custom tab result, _ := m.Update(rightKey()) @@ -396,7 +469,7 @@ func TestProviderTUI_CustomTabShowsAddOption(t *testing.T) { } func TestProviderTUI_CustomTabSelectAddStartsForm(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Switch to custom tab result, _ := m.Update(rightKey()) @@ -414,7 +487,7 @@ func TestProviderTUI_CustomTabSelectAddStartsForm(t *testing.T) { } func TestProviderTUI_CustomFormEscFromNameExitsForm(t *testing.T) { - m := newProviderTUI(&Config{}) + m := newProviderTUI(&Config{}, "") // Switch to custom tab and start form result, _ := m.Update(rightKey()) @@ -436,6 +509,231 @@ func TestProviderTUI_CustomFormEscFromNameExitsForm(t *testing.T) { } } +func TestProviderTUI_CustomFormRejectsDuplicateName(t *testing.T) { + cfg := &Config{ + Provider: "stepfun", + CustomProviders: map[string]ProviderEntry{ + "stepfun": {Model: "xxx"}, + }, + } + m := newProviderTUI(cfg, "") + + result, _ := m.Update(downKey()) + m2 := result.(providerTUIModel) + + result, _ = m2.Update(enterKey()) + m3 := result.(providerTUIModel) + if !m3.creatingCustom { + t.Fatal("should be creating custom") + } + + m3.cpNameInput.SetValue("stepfun") + result, _ = m3.Update(enterKey()) + m4 := result.(providerTUIModel) + if m4.cpStep != cpStepName { + t.Errorf("cpStep = %d, want %d", m4.cpStep, cpStepName) + } + if m4.formError == "" { + t.Error("expected formError for duplicate name") + } + if !strings.Contains(m4.formError, "stepfun") { + t.Errorf("formError = %q, want to mention stepfun", m4.formError) + } + + result, _ = m4.Update(charKey('x')) + m4b := result.(providerTUIModel) + if m4b.formError != "" { + t.Errorf("formError should clear on keystroke, got %q", m4b.formError) + } + + m4b.cpNameInput.SetValue("stepfun2") + result, _ = m4b.Update(enterKey()) + m5 := result.(providerTUIModel) + if m5.cpStep != cpStepProtocol { + t.Errorf("cpStep = %d, want %d", m5.cpStep, cpStepProtocol) + } + if m5.formError != "" { + t.Errorf("formError = %q, want empty after valid name", m5.formError) + } +} + +func TestProviderTUI_CustomFormRejectsInvalidAuthHeader(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{} + m := newProviderTUI(cfg, configPath) + + result, _ := m.Update(rightKey()) + m2 := result.(providerTUIModel) + result, _ = m2.Update(enterKey()) + m3 := result.(providerTUIModel) + + m3.cpNameInput.SetValue("my-new") + result, _ = m3.Update(enterKey()) + m4 := result.(providerTUIModel) + result, _ = m4.Update(enterKey()) + m5 := result.(providerTUIModel) + m5.cpURLInput.SetValue("https://api.example.com") + result, _ = m5.Update(enterKey()) + m6 := result.(providerTUIModel) + result, _ = m6.Update(enterKey()) + m7 := result.(providerTUIModel) + if m7.cpStep != cpStepAuthHeader { + t.Fatalf("cpStep = %d, want %d", m7.cpStep, cpStepAuthHeader) + } + + for _, c := range "bad-header" { + result, _ = m7.Update(charKey(c)) + m7 = result.(providerTUIModel) + } + result, _ = m7.Update(enterKey()) + m8 := result.(providerTUIModel) + + if m8.cpStep != cpStepAuthHeader { + t.Errorf("cpStep = %d, want %d", m8.cpStep, cpStepAuthHeader) + } + if m8.formError == "" { + t.Error("expected formError for invalid auth header") + } + if !strings.Contains(m8.formError, "Unsupported Auth Header") { + t.Errorf("formError = %q, want unsupported auth header message", m8.formError) + } + if !m8.creatingCustom { + t.Error("creatingCustom should remain true when validation fails") + } + if _, err := os.Stat(configPath); err == nil { + t.Error("config should not be saved for invalid auth header") + } +} + +func TestProviderTUI_CustomFormEditRejectsInvalidAuthHeader(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{ + CustomProviders: map[string]ProviderEntry{ + "stepfun": { + URL: "https://api.example.com", + Protocol: "anthropic", + AuthHeader: "authorization", + }, + }, + } + m := newProviderTUI(cfg, configPath) + m.activeTab = tabCustom + m.customIdx = 0 + m.enterEditCustomProvider() + m.cpStep = cpStepAuthHeader + m.cpAuthInput.SetValue("bad-header") + m.cpAuthInput.Focus() + + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + + if m2.cpStep != cpStepAuthHeader { + t.Errorf("cpStep = %d, want %d", m2.cpStep, cpStepAuthHeader) + } + if m2.formError == "" { + t.Error("expected formError for invalid auth header") + } + if !m2.editingCustom { + t.Error("editingCustom should remain true when validation fails") + } + if got := cfg.CustomProviders["stepfun"].AuthHeader; got != "authorization" { + t.Errorf("AuthHeader = %q, want unchanged %q", got, "authorization") + } +} + +func TestProviderTUI_EditCustomProviderSaveRejectsDuplicateRename(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{ + CustomProviders: map[string]ProviderEntry{ + "stepfun": { + URL: "https://stepfun.example.com", + Protocol: "anthropic", + }, + "other": { + URL: "https://other.example.com", + Protocol: "openai", + }, + }, + } + m := newProviderTUI(cfg, configPath) + m.activeTab = tabCustom + m.editingCustom = true + m.editTargetName = "other" + m.cpProtocolIdx = 1 // openai + m.cpNameInput.SetValue("stepfun") + m.cpURLInput.SetValue("https://other.example.com") + + err := m.applyEditCustomProviderSave() + if err == nil { + t.Fatal("expected error when renaming to existing provider name") + } + if !strings.Contains(m.formError, "stepfun") { + t.Errorf("formError = %q, want to mention stepfun", m.formError) + } + if _, ok := cfg.CustomProviders["other"]; !ok { + t.Error("original provider 'other' should still exist") + } + if cfg.CustomProviders["other"].URL != "https://other.example.com" { + t.Errorf("provider 'other' URL = %q, want unchanged", cfg.CustomProviders["other"].URL) + } +} + +func TestProviderTUI_CustomFormCreateReturnsToModelList(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{} + m := newProviderTUI(cfg, configPath) + + result, _ := m.Update(rightKey()) + m2 := result.(providerTUIModel) + result, _ = m2.Update(enterKey()) + m3 := result.(providerTUIModel) + + m3.cpNameInput.SetValue("my-new") + result, _ = m3.Update(enterKey()) // name -> protocol + m4 := result.(providerTUIModel) + result, _ = m4.Update(enterKey()) // protocol -> URL + m5 := result.(providerTUIModel) + m5.cpURLInput.SetValue("https://api.example.com") + result, _ = m5.Update(enterKey()) // URL -> API key + m6 := result.(providerTUIModel) + m6.apiKeyInput.SetValue("key-123") + result, _ = m6.Update(enterKey()) // API key -> auth header + m7 := result.(providerTUIModel) + result, cmd := m7.Update(enterKey()) // auth header -> save + m8 := result.(providerTUIModel) + + if cmd != nil { + t.Error("create should not quit TUI") + } + if m8.creatingCustom { + t.Error("creatingCustom should be false after create") + } + // Create should drop the user into the model selection step for the new + // provider so they can pick/add a model right away. + if m8.step != stepModel { + t.Errorf("step = %d, want stepModel", m8.step) + } + if len(m8.customProviders) != 1 { + t.Fatalf("expected 1 custom provider, got %d", len(m8.customProviders)) + } + if m8.customProviders[0].name != "my-new" { + t.Errorf("provider name = %q, want %q", m8.customProviders[0].name, "my-new") + } + if cfg.Provider != "" { + t.Error("active provider should not be set when only creating") + } + if !m8.savedInSession { + t.Error("savedInSession should be true after create") + } + if _, err := os.Stat(configPath); err != nil { + t.Fatalf("config should be saved: %v", err) + } +} + func TestProviderTUI_CustomProviderExistsInList(t *testing.T) { cfg := &Config{ Provider: "my-llm", @@ -448,7 +746,7 @@ func TestProviderTUI_CustomProviderExistsInList(t *testing.T) { }, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") if m.activeTab != tabCustom { t.Fatalf("should auto-select custom tab, got %d", m.activeTab) @@ -474,7 +772,7 @@ func TestProviderTUI_SelectExistingCustomGoesToModel(t *testing.T) { }, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Enter on existing custom provider should go to model selection first. result, _ := m.Update(enterKey()) @@ -482,8 +780,9 @@ func TestProviderTUI_SelectExistingCustomGoesToModel(t *testing.T) { if m2.step != stepModel { t.Errorf("step = %d, want %d (stepModel)", m2.step, stepModel) } - if m2.models()[0] != "custom-model" { - t.Errorf("first model = %q, want %q", m2.models()[0], "custom-model") + gotModels := m2.models() + if len(gotModels) != 2 || gotModels[0] != "custom-model" || gotModels[1] != "custom-fast" { + t.Errorf("models = %v, want [custom-model custom-fast] (config order)", gotModels) } } @@ -555,7 +854,7 @@ func TestProviderTUI_DeleteCustomProvider(t *testing.T) { "my-llm": {URL: "https://custom.api/v1", Protocol: "openai", Model: "custom-model"}, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Switch to custom tab result, _ := m.Update(rightKey()) @@ -595,8 +894,9 @@ func TestProviderTUI_DeleteCustomProviderCancel(t *testing.T) { "my-llm": {URL: "https://custom.api/v1", Protocol: "openai", Model: "custom-model"}, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") + // Force custom tab so this test is independent of init-time tab routing. // Switch to custom tab, select provider, press d result, _ := m.Update(rightKey()) m2 := result.(providerTUIModel) @@ -627,7 +927,7 @@ func TestProviderTUI_DeleteOnAddOptionIgnored(t *testing.T) { "my-llm": {URL: "https://custom.api/v1", Protocol: "openai"}, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Switch to custom tab result, _ := m.Update(rightKey()) @@ -649,7 +949,7 @@ func TestProviderTUI_DeleteActiveCustomProvider(t *testing.T) { "my-llm": {URL: "https://custom.api/v1", Protocol: "openai", Model: "custom-model"}, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") // Should auto-select custom tab with active provider if m.activeTab != tabCustom { @@ -678,7 +978,7 @@ func TestProviderTUI_DeleteEscCancels(t *testing.T) { "my-llm": {URL: "https://custom.api/v1", Protocol: "openai"}, }, } - m := newProviderTUI(cfg) + m := newProviderTUI(cfg, "") result, _ := m.Update(rightKey()) m2 := result.(providerTUIModel) @@ -696,3 +996,300 @@ func TestProviderTUI_DeleteEscCancels(t *testing.T) { t.Error("no providers should be deleted after Esc") } } + +func TestActiveModelForProvider_PrefersEntryModel(t *testing.T) { + cfg := &Config{Provider: "stepfun", Model: "step-3.7-flash"} + entry := ProviderEntry{Model: "step-3.5-flash"} + got := activeModelForProvider(cfg, "stepfun", entry) + if got != "step-3.5-flash" { + t.Errorf("got %q, want step-3.5-flash", got) + } +} + +func TestActiveModelForProvider_FallsBackToCfgModel(t *testing.T) { + cfg := &Config{Provider: "stepfun", Model: "step-3.5-flash"} + entry := ProviderEntry{} + got := activeModelForProvider(cfg, "stepfun", entry) + if got != "step-3.5-flash" { + t.Errorf("got %q, want step-3.5-flash", got) + } +} + +func TestProviderTUI_CustomModelInput_AddsSingleName(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{ + Provider: "stepfun", + Model: "step-3.5-flash", + CustomProviders: map[string]ProviderEntry{ + "stepfun": { + URL: "https://api.stepfun.com/v1", + Model: "step-3.5-flash", + Models: []string{"step-3.5-flash"}, + }, + }, + } + m := newProviderTUI(cfg, configPath) + m.activeTab = tabCustom + m.customIdx = 0 + m.step = stepModel + m.modelIdx = len(m.models()) // land on "Enter custom model name..." + m.customModel = true + m.modelInput.SetValue("newmodel") + m.modelInput.Focus() + + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + + if m2.customModel { + t.Error("customModel should be cleared after Enter") + } + if m2.formError != "" { + t.Errorf("formError = %q, want empty", m2.formError) + } + got := m2.existingCfg.CustomProviders["stepfun"].Models + want := []string{"step-3.5-flash", "newmodel"} + if len(got) != len(want) || got[0] != want[0] || got[1] != want[1] { + t.Errorf("Models = %v, want %v", got, want) + } + if !m2.savedInSession { + t.Error("savedInSession should be true after add") + } + + diskCfg, err := loadOrCreateConfig(configPath) + if err != nil { + t.Fatalf("load disk config: %v", err) + } + diskModels := diskCfg.CustomProviders["stepfun"].Models + if len(diskModels) != 2 || diskModels[1] != "newmodel" { + t.Errorf("disk Models = %v, want last=step-3.5-flash,newmodel", diskModels) + } +} + +func TestProviderTUI_CustomModelInput_RejectsDuplicate(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{ + Provider: "stepfun", + Model: "step-3.5-flash", + CustomProviders: map[string]ProviderEntry{ + "stepfun": { + URL: "https://api.stepfun.com/v1", + Model: "step-3.5-flash", + Models: []string{"step-3.5-flash"}, + }, + }, + } + m := newProviderTUI(cfg, configPath) + m.activeTab = tabCustom + m.customIdx = 0 + m.step = stepModel + m.modelIdx = len(m.models()) + m.customModel = true + m.modelInput.SetValue("step-3.5-flash") + m.modelInput.Focus() + + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + + if !m2.customModel { + t.Error("customModel should stay true after duplicate reject") + } + if m2.formError != "Already in list: step-3.5-flash" { + t.Errorf("formError = %q, want %q", m2.formError, "Already in list: step-3.5-flash") + } + if m2.modelInput.Value() != "step-3.5-flash" { + t.Errorf("input should be preserved on dup; got %q", m2.modelInput.Value()) + } + if len(m2.existingCfg.CustomProviders["stepfun"].Models) != 1 { + t.Errorf("Models mutated: %v", m2.existingCfg.CustomProviders["stepfun"].Models) + } + if _, err := os.Stat(configPath); err == nil { + t.Errorf("disk file should not exist; duplicate did not persist") + } + if m2.savedInSession { + t.Error("savedInSession should be false after rejected duplicate") + } +} + +func TestProviderTUI_ManualFormPassesKToAuthHeaderInput(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{Llm: LlmConfig{URL: "https://example.com/v1", Model: "m", AuthToken: "k"}} + m := newProviderTUI(cfg, configPath) + m.activeTab = tabManual + m.inManualForm = true + m.manualStep = manualStepAuthHeader + m.manualAuthHeaderInput.Focus() + + result, _ := m.Update(charKey('x')) + m2 := result.(providerTUIModel) + result, _ = m2.Update(charKey('-')) + m3 := result.(providerTUIModel) + result, _ = m3.Update(charKey('a')) + m4 := result.(providerTUIModel) + result, _ = m4.Update(charKey('p')) + m5 := result.(providerTUIModel) + result, _ = m5.Update(charKey('i')) + m6 := result.(providerTUIModel) + result, _ = m6.Update(charKey('-')) + m7 := result.(providerTUIModel) + result, _ = m7.Update(charKey('k')) + m8 := result.(providerTUIModel) + result, _ = m8.Update(charKey('e')) + m9 := result.(providerTUIModel) + result, _ = m9.Update(charKey('y')) + m10 := result.(providerTUIModel) + + if got := m10.manualAuthHeaderInput.Value(); got != "x-api-key" { + t.Errorf("manualAuthHeaderInput.Value() = %q, want %q", got, "x-api-key") + } +} + +func TestProviderTUI_CustomFormPassesKToAuthHeaderInput(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{} + m := newProviderTUI(cfg, configPath) + m.creatingCustom = true + m.cpStep = cpStepAuthHeader + m.cpAuthInput.Focus() + + result, _ := m.Update(charKey('k')) + m2 := result.(providerTUIModel) + result, _ = m2.Update(charKey('e')) + m3 := result.(providerTUIModel) + result, _ = m3.Update(charKey('y')) + m4 := result.(providerTUIModel) + + if got := m4.cpAuthInput.Value(); got != "key" { + t.Errorf("cpAuthInput.Value() = %q, want %q", got, "key") + } +} + +func TestProviderTUI_DeleteModelPreservesActiveModel(t *testing.T) { + cfg := &Config{ + Provider: "stepfun", + Model: "step-3.5-flash", + CustomProviders: map[string]ProviderEntry{ + "stepfun": { + Model: "step-3.5-flash", + Models: []string{"step-3.5-flash", "aaa"}, + }, + }, + } + m := newProviderTUI(cfg, "") + m.activeTab = tabCustom + m.customIdx = 0 + m.step = stepModel + m.modelIdx = 1 // aaa + + m.confirmingDeleteModel = true + m.deleteModelName = "aaa" + result, _ := m.Update(yKey()) + m2 := result.(providerTUIModel) + + if m2.existingCfg.CustomProviders["stepfun"].Model != "step-3.5-flash" { + t.Errorf("entry.Model = %q, want step-3.5-flash", m2.existingCfg.CustomProviders["stepfun"].Model) + } + if m2.existingCfg.Model != "step-3.5-flash" { + t.Errorf("cfg.Model = %q, want step-3.5-flash", m2.existingCfg.Model) + } + if !m2.savedInSession { + t.Error("savedInSession should be true after deleting a model") + } +} + +func TestApplyCustomProviderConfigPreservesModelOrder(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + models := []string{"test-model", "test-model-2", "bbb", "aaa", "test-model-3"} + cfg := &Config{ + Provider: "test-provider", + Model: "test-model-2", + CustomProviders: map[string]ProviderEntry{ + "test-provider": { + Model: "test-model-2", + Models: append([]string(nil), models...), + }, + }, + } + if err := saveConfig(configPath, cfg); err != nil { + t.Fatalf("saveConfig: %v", err) + } + + result := providerTUIResult{ + provider: "test-provider", + model: "test-model-3", + models: append([]string(nil), models...), + isCustom: true, + isEdit: true, + } + if err := applyCustomProviderConfig(configPath, cfg, result); err != nil { + t.Fatalf("applyCustomProviderConfig: %v", err) + } + + got := cfg.CustomProviders["test-provider"].Models + if len(got) != len(models) { + t.Fatalf("Models length = %d, want %d: %v", len(got), len(models), got) + } + for i := range models { + if got[i] != models[i] { + t.Errorf("Models[%d] = %q, want %q", i, got[i], models[i]) + } + } + if cfg.CustomProviders["test-provider"].Model != "test-model-3" { + t.Errorf("entry.Model = %q, want test-model-3", cfg.CustomProviders["test-provider"].Model) + } +} + +func TestApplyManualConfigNormalizesAuthHeader(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{} + + result := providerTUIResult{ + isManual: true, + url: "https://example.com/v1", + model: "test-model", + apiKey: "token", + protocol: "anthropic", + authHeader: "X-Api-Key", + } + if err := applyManualConfig(configPath, cfg, result); err != nil { + t.Fatalf("applyManualConfig: %v", err) + } + if got := cfg.Llm.AuthHeader; got != "x-api-key" { + t.Errorf("Llm.AuthHeader = %q, want %q", got, "x-api-key") + } + useAnthropic := true + if cfg.Llm.UseAnthropic == nil || *cfg.Llm.UseAnthropic != useAnthropic { + t.Errorf("UseAnthropic = %v, want %v", cfg.Llm.UseAnthropic, useAnthropic) + } +} + +func TestApplyCustomProviderConfigNormalizesAuthHeader(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + cfg := &Config{ + CustomProviders: map[string]ProviderEntry{ + "test-provider": {URL: "https://example.com", Model: "m"}, + }, + } + + result := providerTUIResult{ + provider: "test-provider", + model: "m", + url: "https://example.com", + protocol: "anthropic", + authHeader: "Authorization", + isCustom: true, + isEdit: true, + } + if err := applyCustomProviderConfig(configPath, cfg, result); err != nil { + t.Fatalf("applyCustomProviderConfig: %v", err) + } + if got := cfg.CustomProviders["test-provider"].AuthHeader; got != "authorization" { + t.Errorf("AuthHeader = %q, want %q", got, "authorization") + } +}