diff --git a/crates/forge_domain/src/provider.rs b/crates/forge_domain/src/provider.rs index a65b43e416..fedb9a1605 100644 --- a/crates/forge_domain/src/provider.rs +++ b/crates/forge_domain/src/provider.rs @@ -73,6 +73,8 @@ impl ProviderId { pub const FIREWORKS_AI: ProviderId = ProviderId(Cow::Borrowed("fireworks-ai")); pub const NOVITA: ProviderId = ProviderId(Cow::Borrowed("novita")); pub const GOOGLE_AI_STUDIO: ProviderId = ProviderId(Cow::Borrowed("google_ai_studio")); + pub const ADAL: ProviderId = ProviderId(Cow::Borrowed("adal")); + pub const MODAL: ProviderId = ProviderId(Cow::Borrowed("modal")); /// Returns all built-in provider IDs /// @@ -106,6 +108,8 @@ impl ProviderId { ProviderId::FIREWORKS_AI, ProviderId::NOVITA, ProviderId::GOOGLE_AI_STUDIO, + ProviderId::ADAL, + ProviderId::MODAL, ] } @@ -132,6 +136,8 @@ impl ProviderId { "fireworks-ai" => "FireworksAI".to_string(), "novita" => "Novita".to_string(), "google_ai_studio" => "GoogleAIStudio".to_string(), + "adal" => "AdaL".to_string(), + "modal" => "Modal".to_string(), _ => { // For other providers, use UpperCamelCase conversion use convert_case::{Case, Casing}; @@ -176,7 +182,12 @@ impl std::str::FromStr for ProviderId { "codex" => ProviderId::CODEX, "fireworks-ai" => ProviderId::FIREWORKS_AI, "novita" => ProviderId::NOVITA, + "vertex_ai_anthropic" => ProviderId::VERTEX_AI_ANTHROPIC, + "bedrock" => ProviderId::BEDROCK, + "opencode_zen" => ProviderId::OPENCODE_ZEN, "google_ai_studio" => ProviderId::GOOGLE_AI_STUDIO, + "adal" => ProviderId::ADAL, + "modal" => ProviderId::MODAL, // For custom providers, use Cow::Owned to avoid memory leaks custom => ProviderId(Cow::Owned(custom.to_string())), }; @@ -581,6 +592,42 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_adal_from_str() { + let actual = ProviderId::from_str("adal").unwrap(); + let expected = ProviderId::ADAL; + assert_eq!(actual, expected); + } + + #[test] + fn test_adal_display_name() { + assert_eq!(ProviderId::ADAL.to_string(), "AdaL"); + } + + #[test] + fn test_adal_in_built_in_providers() { + let built_in = ProviderId::built_in_providers(); + assert!(built_in.contains(&ProviderId::ADAL)); + } + + #[test] + fn test_modal_from_str() { + let actual = ProviderId::from_str("modal").unwrap(); + let expected = ProviderId::MODAL; + assert_eq!(actual, expected); + } + + #[test] + fn test_modal_display_name() { + assert_eq!(ProviderId::MODAL.to_string(), "Modal"); + } + + #[test] + fn test_modal_in_built_in_providers() { + let built_in = ProviderId::built_in_providers(); + assert!(built_in.contains(&ProviderId::MODAL)); + } + #[test] fn test_io_intelligence() { let fixture = "test_key"; diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 223c892e1c..7c465f3c0a 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -2076,13 +2076,13 @@ impl A + Send + Sync> UI /// selected the model list is scoped to that provider only. /// /// # Returns - /// - `Ok(Some(ModelId))` if a model was selected + /// - `Ok(Some((ProviderId, ModelId)))` if a model was selected /// - `Ok(None)` if selection was canceled #[async_recursion::async_recursion] async fn select_model( &mut self, provider_filter: Option, - ) -> Result> { + ) -> Result> { // Check if provider is set otherwise first ask to select a provider if provider_filter.is_none() && self.api.get_default_provider().await.is_err() { if !self.on_provider_selection().await? { @@ -2182,21 +2182,21 @@ impl A + Send + Sync> UI return Ok(None); } - // Build a flat list of (ModelId, display_line) for the data rows. + // Build a flat list of (ProviderId, ModelId) for the data rows. // The first line is the header; data rows follow in the same order as // the Info entries (sorted by provider, then model within provider). - let mut model_ids: Vec = Vec::new(); + let mut selections: Vec<(ProviderId, ModelId)> = Vec::new(); for pm in &all_provider_models { for model in &pm.models { - model_ids.push(model.id.clone()); + selections.push((pm.provider_id.clone(), model.id.clone())); } } // Create display items: header line first, then data lines paired with - // model IDs. + // provider/model selections. #[derive(Clone)] struct ModelRow { - model_id: Option, + selection: Option<(ProviderId, ModelId)>, display: String, } impl std::fmt::Display for ModelRow { @@ -2207,32 +2207,40 @@ impl A + Send + Sync> UI let mut rows: Vec = Vec::with_capacity(all_lines.len()); // Header row (non-selectable via header_lines=1) - rows.push(ModelRow { model_id: None, display: all_lines[0].to_string() }); + rows.push(ModelRow { selection: None, display: all_lines[0].to_string() }); // Data rows for (i, line) in all_lines.iter().skip(1).enumerate() { rows.push(ModelRow { - model_id: model_ids.get(i).cloned(), + selection: selections.get(i).cloned(), display: line.to_string(), }); } - // Find starting cursor position for the current model. + // Find starting cursor position for the current provider/model pair. // The cursor position is relative to the data rows (header is excluded // by fzf's --header-lines), so index 0 = first data row. - let current_model = self - .get_agent_model(self.api.get_active_agent().await) - .await; - let starting_cursor = current_model - .as_ref() - .and_then(|current| model_ids.iter().position(|id| id == current)) - .unwrap_or(0); + let active_agent = self.api.get_active_agent().await; + let current_model = self.get_agent_model(active_agent.clone()).await; + let current_provider = self.get_provider(active_agent).await.ok().map(|p| p.id); + let starting_cursor = match (current_provider, current_model) { + (Some(provider), Some(model)) => selections + .iter() + .position(|(pid, mid)| pid == &provider && mid == &model) + .or_else(|| selections.iter().position(|(_, mid)| mid == &model)) + .unwrap_or(0), + (_, Some(model)) => selections + .iter() + .position(|(_, mid)| mid == &model) + .unwrap_or(0), + _ => 0, + }; match ForgeWidget::select("Model", rows) .with_starting_cursor(starting_cursor) .with_header_lines(1) .prompt()? { - Some(row) => Ok(row.model_id), + Some(row) => Ok(row.selection), None => Ok(None), } } @@ -2736,22 +2744,31 @@ impl A + Send + Sync> UI provider_filter: Option, provider_to_activate: Option, ) -> Result> { - // Select a model - let model_option = self.select_model(provider_filter).await?; + // Select a model/provider pair + let selection = self.select_model(provider_filter).await?; // If no model was selected (user canceled), return early - let model = match model_option { - Some(model) => model, + let (selected_provider, model) = match selection { + Some(selection) => selection, None => return Ok(None), }; - // If we have a provider to activate, write both atomically + // If we have a provider to activate, write both atomically. + // Otherwise, if the selected model belongs to a different provider, + // switch provider and model together. if let Some(provider_id) = provider_to_activate { self.api .set_default_provider_and_model(provider_id, model.clone()) .await?; } else { - self.api.set_default_model(model.clone()).await?; + let current_provider = self.api.get_default_provider().await.ok().map(|p| p.id); + if current_provider.as_ref() == Some(&selected_provider) { + self.api.set_default_model(model.clone()).await?; + } else { + self.api + .set_default_provider_and_model(selected_provider, model.clone()) + .await?; + } } // Update the UI state with the new model diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 16dc7899f6..8792a7a621 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -3099,5 +3099,34 @@ "input_modalities": ["text"] } ] + }, + { + "id": "adal", + "api_key_vars": "ADAL_API_KEY", + "url_param_vars": [], + "response_type": "OpenAI", + "url": "https://api.sylph.ai/v1/chat/completions", + "models": "https://api.sylph.ai/v1/models", + "auth_methods": ["api_key"] + }, + { + "id": "modal", + "api_key_vars": "MODAL_API_KEY", + "url_param_vars": [], + "response_type": "OpenAI", + "url": "https://api.us-west-2.modal.direct/v1/chat/completions", + "models": [ + { + "id": "zai-org/GLM-5-FP8", + "name": "GLM-5", + "description": "Z.ai's 745B parameter flagship open-source MoE model for long-horizon agents and systems engineering, hosted on Modal", + "context_length": 192000, + "tools_supported": true, + "supports_parallel_tool_calls": true, + "supports_reasoning": true, + "input_modalities": ["text"] + } + ], + "auth_methods": ["api_key"] } ] diff --git a/crates/forge_services/src/agent_registry.rs b/crates/forge_services/src/agent_registry.rs index dc718e3682..1085bc1ade 100644 --- a/crates/forge_services/src/agent_registry.rs +++ b/crates/forge_services/src/agent_registry.rs @@ -14,8 +14,9 @@ pub struct ForgeAgentRegistryService { // Infrastructure dependency for loading agents repository: Arc, - // Startup configuration snapshot used to resolve default provider/model - config: ForgeConfig, + // Runtime configuration used to resolve default provider/model. + // Refreshed from disk on reload so model/provider switches apply immediately. + config: RwLock, // In-memory storage for agents keyed by AgentId string // Lazily initialized on first access @@ -31,7 +32,7 @@ impl ForgeAgentRegistryService { pub fn new(repository: Arc, config: ForgeConfig) -> Self { Self { repository, - config, + config: RwLock::new(config), agents: RwLock::new(None), active_agent_id: RwLock::new(None), } @@ -39,9 +40,21 @@ impl ForgeAgentRegistryService { } impl ForgeAgentRegistryService { - /// Lazily initializes and returns the agents map + /// Reloads Forge config using the same precedence order as startup. + fn load_current_config(&self) -> anyhow::Result { + ForgeConfig::read().map_err(Into::into) + } + + /// Refreshes the in-memory config snapshot from disk/env sources. + async fn refresh_config(&self) -> anyhow::Result<()> { + let config = self.load_current_config()?; + *self.config.write().await = config; + Ok(()) + } + + /// Lazily initializes and returns the agents map. /// Loads agents from repository on first call, subsequent calls return - /// cached value + /// cached value. async fn ensure_agents_loaded(&self) -> anyhow::Result> { // Check if already loaded { @@ -74,11 +87,8 @@ impl ForgeAgentRegistryService { /// them to the repository so agents that do not specify their own /// provider/model receive the session-level defaults. async fn load_agents(&self) -> anyhow::Result> { - let session = self - .config - .session - .as_ref() - .ok_or(Error::NoDefaultProvider)?; + let config = self.config.read().await; + let session = config.session.as_ref().ok_or(Error::NoDefaultProvider)?; let provider_id = session .provider_id .as_ref() @@ -129,9 +139,161 @@ impl forge_app::AgentRegist } async fn reload_agents(&self) -> anyhow::Result<()> { + self.refresh_config().await?; *self.agents.write().await = None; self.ensure_agents_loaded().await?; Ok(()) } } + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use std::fs; + use std::path::Path; + use std::sync::{Mutex, MutexGuard}; + + use forge_app::AgentRegistry; + use forge_config::{ForgeConfig, ModelConfig}; + use forge_domain::{Agent, AgentId, ConfigOperation, Environment, ModelId, ProviderId}; + use pretty_assertions::assert_eq; + use tempfile::tempdir; + + use super::*; + + static HOME_ENV_MUTEX: Mutex<()> = Mutex::new(()); + + struct HomeEnvGuard { + original_home: Option, + _lock: MutexGuard<'static, ()>, + } + + impl HomeEnvGuard { + fn set(home: &Path) -> Self { + let lock = HOME_ENV_MUTEX.lock().unwrap(); + let original_home = std::env::var("HOME").ok(); + unsafe { + std::env::set_var("HOME", home); + } + Self { + original_home, + _lock: lock, + } + } + } + + impl Drop for HomeEnvGuard { + fn drop(&mut self) { + if let Some(home) = &self.original_home { + unsafe { + std::env::set_var("HOME", home); + } + } else { + unsafe { + std::env::remove_var("HOME"); + } + } + } + } + + #[derive(Default)] + struct MockRepository; + + #[async_trait::async_trait] + impl forge_app::AgentRepository for MockRepository { + async fn get_agents( + &self, + provider_id: ProviderId, + model_id: ModelId, + ) -> anyhow::Result> { + Ok(vec![Agent::new(AgentId::new("forge"), provider_id, model_id)]) + } + } + + impl forge_app::EnvironmentInfra for MockRepository { + type Config = ForgeConfig; + + fn get_env_var(&self, _key: &str) -> Option { + None + } + + fn get_env_vars(&self) -> BTreeMap { + BTreeMap::new() + } + + fn get_environment(&self) -> Environment { + Environment { + os: "test".to_string(), + pid: 0, + cwd: std::path::PathBuf::from("."), + home: None, + shell: "zsh".to_string(), + base_path: std::path::PathBuf::from("."), + } + } + + async fn update_environment(&self, _ops: Vec) -> anyhow::Result<()> { + Ok(()) + } + } + + #[tokio::test] + async fn test_reload_agents_refreshes_provider_model_from_config_file() { + let fixture_home = tempdir().unwrap(); + let _fixture_home_env_guard = HomeEnvGuard::set(fixture_home.path()); + let fixture_forge_dir = fixture_home.path().join("forge"); + fs::create_dir_all(&fixture_forge_dir).unwrap(); + + let fixture_first_config = r#" +[session] +provider_id = "anthropic" +model_id = "claude-3-5-sonnet-20241022" +"#; + fs::write(fixture_forge_dir.join(".forge.toml"), fixture_first_config).unwrap(); + + let fixture_startup_config = ForgeConfig { + session: Some(ModelConfig { + provider_id: Some("openai".to_string()), + model_id: Some("gpt-4".to_string()), + }), + ..Default::default() + }; + + let fixture_repository = Arc::new(MockRepository); + let fixture_service = ForgeAgentRegistryService::new(fixture_repository, fixture_startup_config); + + let fixture_agent_id = AgentId::new("forge"); + let actual_before = fixture_service + .get_agent(&fixture_agent_id) + .await + .unwrap() + .unwrap(); + + let expected_before_provider = ProviderId::OPENAI; + let expected_before_model = ModelId::new("gpt-4"); + assert_eq!(actual_before.provider, expected_before_provider); + assert_eq!(actual_before.model, expected_before_model); + + let fixture_second_config = r#" +[session] +provider_id = "modal" +model_id = "zai-org/GLM-5-FP8" +"#; + fs::write(fixture_forge_dir.join(".forge.toml"), fixture_second_config).unwrap(); + + fixture_service.reload_agents().await.unwrap(); + + let actual_after = fixture_service + .get_agent(&fixture_agent_id) + .await + .unwrap() + .unwrap(); + + let expected_after_provider = ProviderId::MODAL; + let expected_after_model = ModelId::new("zai-org/GLM-5-FP8"); + assert_eq!(actual_after.provider, expected_after_provider); + assert_eq!(actual_after.model, expected_after_model); + } +} + diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 8dbbe6ca85..5d5f8ccbfd 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -114,8 +114,18 @@ impl AppConfigService provider_id: ProviderId, model: ModelId, ) -> anyhow::Result<()> { - self.update(ConfigOperation::SetModel(provider_id, model)) - .await + self.update(ConfigOperation::SetModel( + provider_id.clone(), + model.clone(), + )) + .await?; + + let mut config = self.config.lock().unwrap(); + let session = config.session.get_or_insert_with(Default::default); + session.provider_id = Some(provider_id.as_ref().to_string()); + session.model_id = Some(model.to_string()); + + Ok(()) } async fn get_commit_config(&self) -> anyhow::Result> { @@ -552,4 +562,25 @@ mod tests { assert_eq!(actual_model, ModelId::new("claude-3")); Ok(()) } + + #[tokio::test] + async fn test_set_default_provider_and_model_updates_session_cache() -> anyhow::Result<()> { + let fixture = MockInfra::new(); + let service = ForgeAppConfigService::new(Arc::new(fixture), ForgeConfig::default()); + + service + .set_default_provider_and_model(ProviderId::OPENAI, ModelId::new("gpt-4")) + .await?; + + let actual_provider = service.get_default_provider().await?; + let actual_model = service + .get_provider_model(Some(&ProviderId::OPENAI)) + .await?; + let expected_provider = ProviderId::OPENAI; + let expected_model = ModelId::new("gpt-4"); + + assert_eq!(actual_provider, expected_provider); + assert_eq!(actual_model, expected_model); + Ok(()) + } }