From 87b27c0df0e01f55421ce96658c7fd4d21b41a41 Mon Sep 17 00:00:00 2001 From: Gergely Radics Date: Mon, 18 May 2026 15:47:05 +0200 Subject: [PATCH] feat: add multi-provider support (Gemini + OpenAI) with per-provider embedder config Introduces a provider abstraction so the AI backend can be swapped via config.yaml without code changes. Gemini remains the default; OpenAI (and any OpenAI-compatible endpoint) is the new alternative. - New `internal/ai/providers` package with a `Provider` interface and factory; Gemini and OpenAI implementations backed by genkit's native plugins. - Config gains `provider`, `providers.gemini/openai` (api_key + base_url) blocks; legacy `api_keys.gemini` is migrated automatically on load. Note: `base_url` is currently honored only by the OpenAI provider; the Gemini provider accepts the field but does not use it. - RAG embedder is independently configurable: `embedder_provider`, `embedder_base_url`, and `embedding_model` under `capabilities.rag`. - `internal/ai/gemini/config.go` removed; all logic lives in the providers package. - Gemini-specific settings (`gemini_search_disabled`, `gemini_thinking_level`) are silently ignored for non-Gemini providers. 
Co-Authored-By: Claude Sonnet 4.6 --- cmd/server-bot/main.go | 84 ++++++++++++++++------ config.yaml.example | 23 ++++-- go.mod | 5 ++ go.sum | 12 ++++ internal/ai/gemini/config.go | 113 ------------------------------ internal/ai/providers/gemini.go | 89 +++++++++++++++++++++++ internal/ai/providers/openai.go | 65 +++++++++++++++++ internal/ai/providers/provider.go | 43 ++++++++++++ internal/config/config.go | 62 +++++++++++++--- 9 files changed, 345 insertions(+), 151 deletions(-) delete mode 100644 internal/ai/gemini/config.go create mode 100644 internal/ai/providers/gemini.go create mode 100644 internal/ai/providers/openai.go create mode 100644 internal/ai/providers/provider.go diff --git a/cmd/server-bot/main.go b/cmd/server-bot/main.go index d3ab052..c1c7581 100644 --- a/cmd/server-bot/main.go +++ b/cmd/server-bot/main.go @@ -14,13 +14,14 @@ import ( "hairy-botter/internal/ai/adapters" "hairy-botter/internal/ai/agent" - "hairy-botter/internal/ai/gemini" + "hairy-botter/internal/ai/providers" "hairy-botter/internal/config" "hairy-botter/internal/history" "hairy-botter/internal/mcpserver" "hairy-botter/internal/rag" "hairy-botter/internal/server" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/genkit" ) @@ -75,11 +76,37 @@ func main() { return } - if cfg.APIKeys.Gemini == "" { - logger.Error("GEMINI_API_KEY is not set in config or env") + // Build the main AI provider. + mainProvider, err := providers.New(cfg.Provider, providerCfgFor(cfg.Provider, cfg)) + if err != nil { + logger.Error("failed to create AI provider", slog.String("provider", cfg.Provider), slog.String("err", err.Error())) return } + // Build the embedder provider (may differ from the main provider). 
+ embedderProviderName := cfg.Capabilities.Rag.EmbedderProvider + embedderCfg := providerCfgFor(embedderProviderName, cfg) + if cfg.Capabilities.Rag.EmbedderBaseURL != "" { + embedderCfg.BaseURL = cfg.Capabilities.Rag.EmbedderBaseURL + } + embedderProvider, err := providers.New(embedderProviderName, embedderCfg) + if err != nil { + logger.Error("failed to create embedder provider", slog.String("provider", embedderProviderName), slog.String("err", err.Error())) + return + } + + // Collect distinct plugins for genkit.Init. + g := genkit.Init(context.Background(), genkit.WithPlugins(distinctPlugins(mainProvider.Plugin(), embedderProvider.Plugin())...)) + + model, err := mainProvider.Model(g, cfg.Model) + if err != nil { + logger.Error("failed to define model", slog.String("err", err.Error())) + return + } + + searchEnable := !cfg.GeminiSearchDisabled + customModelConfig := mainProvider.GenerateOptions(cfg.Model, searchEnable, cfg.GeminiThinkingLevel) + historySummary := 0 if cfg.Capabilities.HistorySummary.Enabled { historySummary = 20 @@ -101,30 +128,14 @@ func main() { } logger.Info("MCP servers from config", slog.Int("count", len(mcpServers))) - searchEnable := !cfg.GeminiSearchDisabled - - // Initialize the Gemini AI logic - ga := gemini.ConfigPlugin(cfg.APIKeys.Gemini) - g := genkit.Init(context.Background(), genkit.WithPlugins(ga)) - - model, err := gemini.ConfigModel(g, ga, cfg.Model) - if err != nil { - logger.Error("failed to define model", slog.String("err", err.Error())) - return - } - customModelConfig := gemini.GenerateOptions(cfg.Model, searchEnable, cfg.GeminiThinkingLevel) - var ragL *rag.Logic if cfg.Capabilities.Rag.Enabled && cfg.Capabilities.Rag.Directory != "" { - // First load/created embedder config - embedder, err := gemini.ConfigEmbedder(g, ga, cfg.Capabilities.Rag.EmbeddingModel) + embedder, err := embedderProvider.Embedder(g, cfg.Capabilities.Rag.EmbeddingModel) if err != nil { logger.Error("failed to define embedder", slog.String("err", 
err.Error())) - return } - // Init the RAG ragL, err = rag.New(logger, cfg.Capabilities.Rag.Directory, adapters.NewEmbedder(g, embedder)) if err != nil { logger.Error("failed to create RAG logic", slog.String("err", err.Error())) @@ -141,7 +152,6 @@ func main() { aiLogic, err := agent.New(logger, g, model, hist, mcpServers, ragL, systemPrompt, customModelConfig, cfg.Context) if err != nil { logger.Error("failed to create AI logic", slog.String("err", err.Error())) - return } @@ -228,7 +238,7 @@ func main() { stopCh := make(chan os.Signal, 1) signal.Notify(stopCh, os.Interrupt, syscall.SIGTERM) - finishedCh := make(chan struct{}) // Signal the end of the graceful shutdown + finishedCh := make(chan struct{}) go func() { <-stopCh logger.Info("shutting down server") @@ -262,8 +272,36 @@ func main() { logger.Error("server failed", slog.String("err", err.Error())) } } else { - // Wait if no chat proxy to keep agent alive logger.Info("agent running without chat proxy") } <-finishedCh } + +// providerCfgFor returns the providers.Config for the named provider from cfg. +func providerCfgFor(name string, cfg *config.Config) providers.Config { + switch name { + case "openai": + return providers.Config{ + APIKey: cfg.Providers.OpenAI.APIKey, + BaseURL: cfg.Providers.OpenAI.BaseURL, + } + default: // gemini + return providers.Config{ + APIKey: cfg.Providers.Gemini.APIKey, + BaseURL: cfg.Providers.Gemini.BaseURL, + } + } +} + +// distinctPlugins returns a deduplicated slice of api.Plugin by name. +func distinctPlugins(ps ...api.Plugin) []api.Plugin { + seen := make(map[string]bool, len(ps)) + out := make([]api.Plugin, 0, len(ps)) + for _, p := range ps { + if !seen[p.Name()] { + seen[p.Name()] = true + out = append(out, p) + } + } + return out +} diff --git a/config.yaml.example b/config.yaml.example index 3a2b744..2be80c4 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -22,8 +22,12 @@ agent_config: # --------------------------------------------------------- # 3. 
Universal Settings (Evaluated by all modes) # --------------------------------------------------------- -model: "gemini-flash-latest" -gemini_search_disabled: true # Gemini specific config + +# Provider for the main AI model: "gemini" (default) or "openai" +provider: "gemini" + +model: "gemini-flash-latest" # gemini: e.g. "gemini-2.5-flash"; openai: e.g. "gpt-4o" +gemini_search_disabled: true # Gemini-specific; ignored for other providers log_level: warning @@ -37,7 +41,9 @@ capabilities: rag: enabled: true directory: "./knowledge_base" - embedding_model: "gemini-embedding-001" + # embedder_provider: "gemini" # defaults to top-level provider; can be different + # embedder_base_url: "" # overrides the provider's base_url for the embedder only + embedding_model: "gemini-embedding-001" # openai default: "text-embedding-3-small" history_summary: enabled: true message_count: 20 @@ -80,7 +86,16 @@ context: # args: ["--city", "New York"] # --------------------------------------------------------- -# API Keys (Fallback config, can be overridden by env vars) +# 5. 
Provider Credentials +# API keys can also be set via env vars: GEMINI_API_KEY, OPENAI_API_KEY # --------------------------------------------------------- +# providers: +# gemini: +# api_key: "your_gemini_api_key_here" +# openai: +# api_key: "your_openai_api_key_here" +# base_url: "" # optional; override for any OpenAI-compatible endpoint + +# Legacy (still supported, merged into providers.gemini.api_key automatically): # api_keys: # gemini: "your_gemini_api_key_here" diff --git a/go.mod b/go.mod index 052983c..d01b9c4 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,12 @@ require ( github.com/mattn/go-colorable v0.1.2 // indirect github.com/mattn/go-isatty v0.0.8 // indirect github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect + github.com/openai/openai-go v1.8.2 // indirect github.com/spf13/cast v1.7.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/go.sum b/go.sum index a4bd16a..143a1c6 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= +github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8= +github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/philippgille/chromem-go v0.7.0 
h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -80,6 +82,16 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= diff --git a/internal/ai/gemini/config.go b/internal/ai/gemini/config.go deleted file mode 100644 index 0781f8e..0000000 --- a/internal/ai/gemini/config.go +++ /dev/null @@ -1,113 +0,0 @@ -package gemini - -import ( - "strings" - - 
"github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core/api" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -// AgentConfigurator . -type AgentConfigurator interface { - api.Plugin - modelDefiner - modelEmbedder -} -type modelDefiner interface { - DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOptions) (ai.Model, error) -} - -type modelEmbedder interface { - DefineEmbedder(g *genkit.Genkit, name string, embedOpts *ai.EmbedderOptions) (ai.Embedder, error) -} - -func ConfigPlugin(apiKey string) AgentConfigurator { - return &googlegenai.GoogleAI{APIKey: apiKey} -} - -// ConfigModel . -func ConfigModel(g *genkit.Genkit, ga modelDefiner, modelName string) (ai.Model, error) { - if modelName == "" { - modelName = "gemini-flash-latest" - } - - // Try the known-model path first (nil opts = look up from plugin's registry). - // -latest aliases and next-gen model names are not in the registry, so fall back - // to generic multimodal options so the caller still gets a usable model. - model, err := ga.DefineModel(g, modelName, nil) - if err != nil { - model, err = ga.DefineModel(g, modelName, &ai.ModelOptions{ - Supports: &googlegenai.Multimodal, - Stage: ai.ModelStageUnstable, - }) - } - - return model, err -} - -// ConfigEmbedder . -func ConfigEmbedder(g *genkit.Genkit, ga modelEmbedder, modelName string) (ai.Embedder, error) { - if modelName == "" { - modelName = "gemini-embedding-001" - } - - embedder, err := ga.DefineEmbedder(g, modelName, &ai.EmbedderOptions{}) - if err != nil { - return nil, err - } - - return embedder, nil -} - -// GenerateOptions returns Gemini-specific generate options (thinking config, Google Search). -// Returns nil when neither feature is requested so no provider-specific config is sent -// to models that don't support it (e.g. older flash models without thinking support). 
-// modelName is used to filter out thinking levels unsupported by the specific model variant: -// MINIMAL is only valid for Flash models (e.g. gemini-2.5-flash); Pro models only support LOW/MEDIUM/HIGH. -func GenerateOptions(modelName string, searchEnable bool, thinkingLevel string) []ai.GenerateOption { - cfg := &genai.GenerateContentConfig{} - hasConfig := false - - if thinkingLevel != "" { - isFlash := strings.Contains(strings.ToLower(modelName), "flash") - var level genai.ThinkingLevel - switch thinkingLevel { - case "NONE", "MINIMAL": - // Gemini 3 Flash cannot fully disable thinking; MINIMAL is as low as it goes. - // NONE is accepted as an alias so callers can express intent clearly. - // Pro models do not support MINIMAL — skip silently for them. - if isFlash { - level = genai.ThinkingLevelMinimal - } - case "LOW": - level = genai.ThinkingLevelLow - case "MEDIUM": - level = genai.ThinkingLevelMedium - case "HIGH": - level = genai.ThinkingLevelHigh - } - if level != "" { - cfg.ThinkingConfig = &genai.ThinkingConfig{ThinkingLevel: level} - hasConfig = true - } - } - - if searchEnable { - ist := true - cfg.Tools = []*genai.Tool{ - {GoogleSearch: &genai.GoogleSearch{}}, - } - cfg.ToolConfig = &genai.ToolConfig{ - IncludeServerSideToolInvocations: &ist, - } - hasConfig = true - } - - if !hasConfig { - return nil - } - return []ai.GenerateOption{ai.WithConfig(cfg)} -} diff --git a/internal/ai/providers/gemini.go b/internal/ai/providers/gemini.go new file mode 100644 index 0000000..5210d1e --- /dev/null +++ b/internal/ai/providers/gemini.go @@ -0,0 +1,89 @@ +package providers + +import ( + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +type geminiProvider struct { + plugin *googlegenai.GoogleAI +} + +func newGemini(cfg Config) (Provider, error) { + return &geminiProvider{ + plugin: 
&googlegenai.GoogleAI{APIKey: cfg.APIKey}, + }, nil +} + +func (p *geminiProvider) Plugin() api.Plugin { return p.plugin } + +func (p *geminiProvider) Model(g *genkit.Genkit, name string) (ai.Model, error) { + if name == "" { + name = "gemini-flash-latest" + } + model, err := p.plugin.DefineModel(g, name, nil) + if err != nil { + model, err = p.plugin.DefineModel(g, name, &ai.ModelOptions{ + Supports: &googlegenai.Multimodal, + Stage: ai.ModelStageUnstable, + }) + } + return model, err +} + +func (p *geminiProvider) Embedder(g *genkit.Genkit, name string) (ai.Embedder, error) { + if name == "" { + name = "gemini-embedding-001" + } + return p.plugin.DefineEmbedder(g, name, &ai.EmbedderOptions{}) +} + +// GenerateOptions returns Gemini-specific generate options (thinking config, Google Search). +// Returns nil when neither feature is requested. +// MINIMAL thinking is only valid for Flash models; silently skipped for Pro. +func (p *geminiProvider) GenerateOptions(modelName string, searchEnable bool, thinkingLevel string) []ai.GenerateOption { + cfg := &genai.GenerateContentConfig{} + hasConfig := false + + if thinkingLevel != "" { + isFlash := strings.Contains(strings.ToLower(modelName), "flash") + var level genai.ThinkingLevel + switch thinkingLevel { + case "NONE", "MINIMAL": + if isFlash { + level = genai.ThinkingLevelMinimal + } + case "LOW": + level = genai.ThinkingLevelLow + case "MEDIUM": + level = genai.ThinkingLevelMedium + case "HIGH": + level = genai.ThinkingLevelHigh + } + if level != "" { + cfg.ThinkingConfig = &genai.ThinkingConfig{ThinkingLevel: level} + hasConfig = true + } + } + + if searchEnable { + ist := true + cfg.Tools = []*genai.Tool{ + {GoogleSearch: &genai.GoogleSearch{}}, + } + cfg.ToolConfig = &genai.ToolConfig{ + IncludeServerSideToolInvocations: &ist, + } + hasConfig = true + } + + if !hasConfig { + return nil + } + return []ai.GenerateOption{ai.WithConfig(cfg)} +} diff --git a/internal/ai/providers/openai.go 
b/internal/ai/providers/openai.go new file mode 100644 index 0000000..906d371 --- /dev/null +++ b/internal/ai/providers/openai.go @@ -0,0 +1,65 @@ +package providers + +import ( + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai/openai" + openaiopt "github.com/openai/openai-go/option" +) + +type openaiProvider struct { + plugin *openai.OpenAI +} + +func newOpenAI(cfg Config) (Provider, error) { + var extraOpts []openaiopt.RequestOption + if cfg.BaseURL != "" { + extraOpts = append(extraOpts, openaiopt.WithBaseURL(cfg.BaseURL)) + } + return &openaiProvider{ + plugin: &openai.OpenAI{ + APIKey: cfg.APIKey, + Opts: extraOpts, + }, + }, nil +} + +func (p *openaiProvider) Plugin() api.Plugin { return p.plugin } + +func (p *openaiProvider) Model(g *genkit.Genkit, name string) (ai.Model, error) { + if name == "" { + name = "gpt-4o" + } + model := p.plugin.Model(g, name) + if model != nil { + return model, nil + } + // Unknown model name — define it with generic multimodal capabilities. + return p.plugin.DefineModel(name, ai.ModelOptions{ + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: true, + SystemRole: true, + Media: true, + }, + Stage: ai.ModelStageUnstable, + }), nil +} + +func (p *openaiProvider) Embedder(g *genkit.Genkit, name string) (ai.Embedder, error) { + if name == "" { + name = "text-embedding-3-small" + } + embedder := p.plugin.Embedder(g, name) + if embedder != nil { + return embedder, nil + } + // Unknown embedder name — define it with generic options. + return p.plugin.DefineEmbedder(name, &ai.EmbedderOptions{}), nil +} + +// GenerateOptions returns nil — OpenAI has no Gemini-specific extras. 
+func (p *openaiProvider) GenerateOptions(_ string, _ bool, _ string) []ai.GenerateOption { + return nil +} diff --git a/internal/ai/providers/provider.go b/internal/ai/providers/provider.go new file mode 100644 index 0000000..83d2d8e --- /dev/null +++ b/internal/ai/providers/provider.go @@ -0,0 +1,43 @@ +// Package providers abstracts AI provider initialization behind a common interface. +// Supported providers: gemini, openai (and any OpenAI-compatible endpoint via base_url). +package providers + +import ( + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" +) + +// Provider is the common interface that every AI backend must satisfy. +type Provider interface { + // Plugin returns the genkit plugin to register with genkit.Init. + Plugin() api.Plugin + // Model looks up or defines the named model on the given genkit instance. + Model(g *genkit.Genkit, name string) (ai.Model, error) + // Embedder looks up or defines the named embedder on the given genkit instance. + Embedder(g *genkit.Genkit, name string) (ai.Embedder, error) + // GenerateOptions returns provider-specific generation options (e.g. Google Search, thinking). + // Returns nil when the provider has no extra options for the given parameters. + GenerateOptions(modelName string, searchEnable bool, thinkingLevel string) []ai.GenerateOption +} + +// Config holds credentials and endpoint for a single provider. +type Config struct { + APIKey string + BaseURL string // optional; empty means provider default +} + +// New returns the Provider implementation for the given name. +// name must be "gemini" or "openai". 
+func New(name string, cfg Config) (Provider, error) { + switch name { + case "gemini", "": + return newGemini(cfg) + case "openai": + return newOpenAI(cfg) + default: + return nil, fmt.Errorf("unknown provider %q; supported: gemini, openai", name) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 2ad74f6..88b58f6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( type Config struct { RunMode string `yaml:"run_mode"` AgentConfig AgentConfig `yaml:"agent_config"` + Provider string `yaml:"provider"` // gemini (default) | openai Model string `yaml:"model"` GeminiSearchDisabled bool `yaml:"gemini_search_disabled"` GeminiThinkingLevel string `yaml:"gemini_thinking_level"` @@ -18,7 +19,8 @@ type Config struct { Personality PersonalityConfig `yaml:"personality"` Capabilities CapabilitiesConfig `yaml:"capabilities"` Context ContextConfig `yaml:"context"` - APIKeys APIKeysConfig `yaml:"api_keys"` // Keep this distinct for overriding + APIKeys APIKeysConfig `yaml:"api_keys"` // legacy; merged into Providers on load + Providers ProvidersConfig `yaml:"providers"` } // AgentConfig holds settings specific to the 'agent' run mode. @@ -49,9 +51,11 @@ type CapabilitiesConfig struct { // RagConfig specifies Retrieval-Augmented Generation settings. type RagConfig struct { - Enabled bool `yaml:"enabled"` - Directory string `yaml:"directory"` - EmbeddingModel string `yaml:"embedding_model"` + Enabled bool `yaml:"enabled"` + Directory string `yaml:"directory"` + EmbedderProvider string `yaml:"embedder_provider"` // gemini | openai; defaults to top-level provider + EmbedderBaseURL string `yaml:"embedder_base_url"` // overrides provider base_url for the embedder + EmbeddingModel string `yaml:"embedding_model"` } // HistorySummaryConfig controls conversation history summarization. 
@@ -81,7 +85,19 @@ type DynamicDataConfig struct { Args []string `yaml:"args"` } -// APIKeysConfig holds API keys to prioritize yaml over env +// ProvidersConfig holds per-provider credentials and endpoints. +type ProvidersConfig struct { + Gemini ProviderConfig `yaml:"gemini"` + OpenAI ProviderConfig `yaml:"openai"` +} + +// ProviderConfig holds the api key and optional base URL for a single provider. +type ProviderConfig struct { + APIKey string `yaml:"api_key"` + BaseURL string `yaml:"base_url"` +} + +// APIKeysConfig holds API keys to prioritize yaml over env (legacy). type APIKeysConfig struct { Gemini string `yaml:"gemini"` } @@ -99,22 +115,46 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("failed to parse config file: %w", err) } - // Apply environment variable fallbacks for API keys - if cfg.APIKeys.Gemini == "" { - cfg.APIKeys.Gemini = os.Getenv("GEMINI_API_KEY") + // Migrate legacy api_keys.gemini into providers.gemini.api_key + if cfg.APIKeys.Gemini != "" && cfg.Providers.Gemini.APIKey == "" { + cfg.Providers.Gemini.APIKey = cfg.APIKeys.Gemini + } + + // Apply environment variable fallbacks + if cfg.Providers.Gemini.APIKey == "" { + cfg.Providers.Gemini.APIKey = os.Getenv("GEMINI_API_KEY") + } + if cfg.Providers.OpenAI.APIKey == "" { + cfg.Providers.OpenAI.APIKey = os.Getenv("OPENAI_API_KEY") + } + + // Set defaults + if cfg.Provider == "" { + cfg.Provider = "gemini" } - // Set defaults if some values are absent if cfg.Model == "" { - cfg.Model = "gemini-flash-latest" + if cfg.Provider == "openai" { + cfg.Model = "gpt-4o" + } else { + cfg.Model = "gemini-flash-latest" + } } if cfg.LogLevel == "" { cfg.LogLevel = "warning" } + if cfg.Capabilities.Rag.EmbedderProvider == "" { + cfg.Capabilities.Rag.EmbedderProvider = cfg.Provider + } + if cfg.Capabilities.Rag.EmbeddingModel == "" { - cfg.Capabilities.Rag.EmbeddingModel = "gemini-embedding-001" + if cfg.Capabilities.Rag.EmbedderProvider == "openai" { + 
cfg.Capabilities.Rag.EmbeddingModel = "text-embedding-3-small" + } else { + cfg.Capabilities.Rag.EmbeddingModel = "gemini-embedding-001" + } } return &cfg, nil