Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions crates/forge_domain/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ impl ProviderId {
pub const FIREWORKS_AI: ProviderId = ProviderId(Cow::Borrowed("fireworks-ai"));
pub const NOVITA: ProviderId = ProviderId(Cow::Borrowed("novita"));
pub const GOOGLE_AI_STUDIO: ProviderId = ProviderId(Cow::Borrowed("google_ai_studio"));
pub const ADAL: ProviderId = ProviderId(Cow::Borrowed("adal"));
pub const MODAL: ProviderId = ProviderId(Cow::Borrowed("modal"));

/// Returns all built-in provider IDs
///
Expand Down Expand Up @@ -106,6 +108,8 @@ impl ProviderId {
ProviderId::FIREWORKS_AI,
ProviderId::NOVITA,
ProviderId::GOOGLE_AI_STUDIO,
ProviderId::ADAL,
ProviderId::MODAL,
]
}

Expand All @@ -132,6 +136,8 @@ impl ProviderId {
"fireworks-ai" => "FireworksAI".to_string(),
"novita" => "Novita".to_string(),
"google_ai_studio" => "GoogleAIStudio".to_string(),
"adal" => "AdaL".to_string(),
"modal" => "Modal".to_string(),
_ => {
// For other providers, use UpperCamelCase conversion
use convert_case::{Case, Casing};
Expand Down Expand Up @@ -176,7 +182,12 @@ impl std::str::FromStr for ProviderId {
"codex" => ProviderId::CODEX,
"fireworks-ai" => ProviderId::FIREWORKS_AI,
"novita" => ProviderId::NOVITA,
"vertex_ai_anthropic" => ProviderId::VERTEX_AI_ANTHROPIC,
"bedrock" => ProviderId::BEDROCK,
"opencode_zen" => ProviderId::OPENCODE_ZEN,
"google_ai_studio" => ProviderId::GOOGLE_AI_STUDIO,
"adal" => ProviderId::ADAL,
"modal" => ProviderId::MODAL,
// For custom providers, use Cow::Owned to avoid memory leaks
custom => ProviderId(Cow::Owned(custom.to_string())),
};
Expand Down Expand Up @@ -581,6 +592,42 @@ mod tests {
assert_eq!(actual, expected);
}

#[test]
fn test_adal_from_str() {
let actual = ProviderId::from_str("adal").unwrap();
let expected = ProviderId::ADAL;
assert_eq!(actual, expected);
}

#[test]
fn test_adal_display_name() {
assert_eq!(ProviderId::ADAL.to_string(), "AdaL");
}

#[test]
fn test_adal_in_built_in_providers() {
let built_in = ProviderId::built_in_providers();
assert!(built_in.contains(&ProviderId::ADAL));
}

#[test]
fn test_modal_from_str() {
let actual = ProviderId::from_str("modal").unwrap();
let expected = ProviderId::MODAL;
assert_eq!(actual, expected);
}

#[test]
fn test_modal_display_name() {
assert_eq!(ProviderId::MODAL.to_string(), "Modal");
}

#[test]
fn test_modal_in_built_in_providers() {
let built_in = ProviderId::built_in_providers();
assert!(built_in.contains(&ProviderId::MODAL));
}

#[test]
fn test_io_intelligence() {
let fixture = "test_key";
Expand Down
65 changes: 41 additions & 24 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2076,13 +2076,13 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
/// selected the model list is scoped to that provider only.
///
/// # Returns
/// - `Ok(Some(ModelId))` if a model was selected
/// - `Ok(Some((ProviderId, ModelId)))` if a model was selected
/// - `Ok(None)` if selection was canceled
#[async_recursion::async_recursion]
async fn select_model(
&mut self,
provider_filter: Option<ProviderId>,
) -> Result<Option<ModelId>> {
) -> Result<Option<(ProviderId, ModelId)>> {
// Check if provider is set otherwise first ask to select a provider
if provider_filter.is_none() && self.api.get_default_provider().await.is_err() {
if !self.on_provider_selection().await? {
Expand Down Expand Up @@ -2182,21 +2182,21 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
return Ok(None);
}

// Build a flat list of (ModelId, display_line) for the data rows.
// Build a flat list of (ProviderId, ModelId) for the data rows.
// The first line is the header; data rows follow in the same order as
// the Info entries (sorted by provider, then model within provider).
let mut model_ids: Vec<ModelId> = Vec::new();
let mut selections: Vec<(ProviderId, ModelId)> = Vec::new();
for pm in &all_provider_models {
for model in &pm.models {
model_ids.push(model.id.clone());
selections.push((pm.provider_id.clone(), model.id.clone()));
}
}

// Create display items: header line first, then data lines paired with
// model IDs.
// provider/model selections.
#[derive(Clone)]
struct ModelRow {
model_id: Option<ModelId>,
selection: Option<(ProviderId, ModelId)>,
display: String,
}
impl std::fmt::Display for ModelRow {
Expand All @@ -2207,32 +2207,40 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI

let mut rows: Vec<ModelRow> = Vec::with_capacity(all_lines.len());
// Header row (non-selectable via header_lines=1)
rows.push(ModelRow { model_id: None, display: all_lines[0].to_string() });
rows.push(ModelRow { selection: None, display: all_lines[0].to_string() });
// Data rows
for (i, line) in all_lines.iter().skip(1).enumerate() {
rows.push(ModelRow {
model_id: model_ids.get(i).cloned(),
selection: selections.get(i).cloned(),
display: line.to_string(),
});
}

// Find starting cursor position for the current model.
// Find starting cursor position for the current provider/model pair.
// The cursor position is relative to the data rows (header is excluded
// by fzf's --header-lines), so index 0 = first data row.
let current_model = self
.get_agent_model(self.api.get_active_agent().await)
.await;
let starting_cursor = current_model
.as_ref()
.and_then(|current| model_ids.iter().position(|id| id == current))
.unwrap_or(0);
let active_agent = self.api.get_active_agent().await;
let current_model = self.get_agent_model(active_agent.clone()).await;
let current_provider = self.get_provider(active_agent).await.ok().map(|p| p.id);
let starting_cursor = match (current_provider, current_model) {
(Some(provider), Some(model)) => selections
.iter()
.position(|(pid, mid)| pid == &provider && mid == &model)
.or_else(|| selections.iter().position(|(_, mid)| mid == &model))
.unwrap_or(0),
(_, Some(model)) => selections
.iter()
.position(|(_, mid)| mid == &model)
.unwrap_or(0),
_ => 0,
};

match ForgeWidget::select("Model", rows)
.with_starting_cursor(starting_cursor)
.with_header_lines(1)
.prompt()?
{
Some(row) => Ok(row.model_id),
Some(row) => Ok(row.selection),
None => Ok(None),
}
}
Expand Down Expand Up @@ -2736,22 +2744,31 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
provider_filter: Option<ProviderId>,
provider_to_activate: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Select a model
let model_option = self.select_model(provider_filter).await?;
// Select a model/provider pair
let selection = self.select_model(provider_filter).await?;

// If no model was selected (user canceled), return early
let model = match model_option {
Some(model) => model,
let (selected_provider, model) = match selection {
Some(selection) => selection,
None => return Ok(None),
};

// If we have a provider to activate, write both atomically
// If we have a provider to activate, write both atomically.
// Otherwise, if the selected model belongs to a different provider,
// switch provider and model together.
if let Some(provider_id) = provider_to_activate {
self.api
.set_default_provider_and_model(provider_id, model.clone())
.await?;
} else {
self.api.set_default_model(model.clone()).await?;
let current_provider = self.api.get_default_provider().await.ok().map(|p| p.id);
if current_provider.as_ref() == Some(&selected_provider) {
self.api.set_default_model(model.clone()).await?;
} else {
self.api
.set_default_provider_and_model(selected_provider, model.clone())
.await?;
}
}

// Update the UI state with the new model
Expand Down
29 changes: 29 additions & 0 deletions crates/forge_repo/src/provider/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -3099,5 +3099,34 @@
"input_modalities": ["text"]
}
]
},
{
"id": "adal",
"api_key_vars": "ADAL_API_KEY",
"url_param_vars": [],
"response_type": "OpenAI",
"url": "https://api.sylph.ai/v1/chat/completions",
"models": "https://api.sylph.ai/v1/models",
"auth_methods": ["api_key"]
},
{
"id": "modal",
"api_key_vars": "MODAL_API_KEY",
"url_param_vars": [],
"response_type": "OpenAI",
"url": "https://api.us-west-2.modal.direct/v1/chat/completions",
"models": [
{
"id": "zai-org/GLM-5-FP8",
"name": "GLM-5",
"description": "Z.ai's 745B parameter flagship open-source MoE model for long-horizon agents and systems engineering, hosted on Modal",
"context_length": 192000,
"tools_supported": true,
"supports_parallel_tool_calls": true,
"supports_reasoning": true,
"input_modalities": ["text"]
}
],
"auth_methods": ["api_key"]
}
]
Loading
Loading