Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ pub trait API: Sync + Send {
/// Provides a list of models available in the current environment
async fn get_models(&self) -> Result<Vec<Model>>;

/// Provides models from all configured providers. Providers that
/// successfully return models are included in the result. If every
/// configured provider fails (e.g. due to an invalid API key), the
/// first error is returned so the caller sees the real underlying cause
/// rather than an empty list.
async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>>;
/// Provides models from all configured providers. Each element is
/// either a successful `ProviderModels` or an error for a provider
/// that failed (e.g. due to stale credentials), so callers can show
/// partial results alongside per-provider errors.
async fn get_all_provider_models(&self) -> Result<Vec<Result<ProviderModels>>>;

/// Provides a list of agents available in the current environment
async fn get_agents(&self) -> Result<Vec<Agent>>;
Expand Down
2 changes: 1 addition & 1 deletion crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<
self.app().get_models().await
}

async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>> {
async fn get_all_provider_models(&self) -> Result<Vec<Result<ProviderModels>>> {
self.app().get_all_provider_models().await
}

Expand Down
19 changes: 8 additions & 11 deletions crates/forge_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,11 @@ impl<S: Services + EnvironmentInfra<Config = forge_config::ForgeConfig>> ForgeAp

/// Gets available models from all configured providers concurrently.
///
/// Returns a list of `ProviderModels` for each configured provider that
/// successfully returned models. If every configured provider fails (e.g.
/// due to an invalid API key), the first error encountered is returned so
/// the caller receives the real underlying cause rather than an empty list.
pub async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>> {
/// Returns one `Result<ProviderModels>` per configured provider so the
/// caller can display partial results alongside per-provider errors
/// (e.g. stale credentials on one provider should not hide models from
/// others).
pub async fn get_all_provider_models(&self) -> Result<Vec<Result<ProviderModels>>> {
let all_providers = self.services.get_all_providers().await?;

// Build one future per configured provider, preserving the error on failure.
Expand All @@ -312,6 +312,7 @@ impl<S: Services + EnvironmentInfra<Config = forge_config::ForgeConfig>> ForgeAp
let provider_id = provider.id.clone();
let services = self.services.clone();
async move {
let pid = provider_id.clone();
let result: Result<ProviderModels> = async {
let refreshed = services
.provider_auth_service()
Expand All @@ -321,15 +322,11 @@ impl<S: Services + EnvironmentInfra<Config = forge_config::ForgeConfig>> ForgeAp
Ok(ProviderModels { provider_id, models })
}
.await;
result
result.map_err(|e| e.context(format!("provider '{pid}'")))
}
})
.collect();

// Execute all provider fetches concurrently.
futures::future::join_all(futures)
.await
.into_iter()
.collect::<anyhow::Result<Vec<_>>>()
Ok(futures::future::join_all(futures).await)
}
}
47 changes: 42 additions & 5 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use forge_app::{CommitResult, ToolResolver};
use forge_config::ForgeConfig;
use forge_display::MarkdownFormat;
use forge_domain::{
AuthMethod, ChatResponseContent, ConsoleWriter, ContextMessage, Role, TitleFormat, UserCommand,
AuthMethod, ChatResponseContent, ConsoleWriter, ContextMessage, ProviderModels, Role,
TitleFormat, UserCommand,
};
use forge_fs::ForgeFS;
use forge_select::ForgeWidget;
Expand Down Expand Up @@ -129,6 +130,37 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
self.spinner.ewrite_ln(title)
}

/// Partitions provider model results into successes, writing
/// per-provider errors to stderr. If every provider failed, the
/// first error is returned so callers surface a real failure rather
/// than silently treating it as "no models configured".
fn collect_provider_models(
&mut self,
results: Vec<Result<ProviderModels>>,
) -> anyhow::Result<Vec<ProviderModels>> {
let mut models = Vec::new();
let mut first_error: Option<anyhow::Error> = None;
for result in results {
match result {
Ok(pm) => models.push(pm),
Err(err) => {
self.writeln_to_stderr(
TitleFormat::error(format!("{err:#}")).display().to_string(),
)?;
if first_error.is_none() {
first_error = Some(err);
}
}
}
}
if models.is_empty()
&& let Some(err) = first_error
{
return Err(err);
}
Ok(models)
}

/// Helper to get provider for an optional agent, defaulting to the current
/// active agent's provider
async fn get_provider(&self, agent_id: Option<AgentId>) -> Result<Provider<Url>> {
Expand Down Expand Up @@ -1217,13 +1249,15 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
async fn on_show_models(&mut self, porcelain: bool) -> anyhow::Result<()> {
self.spinner.start(Some("Fetching Models"))?;

let mut all_provider_models = match self.api.get_all_provider_models().await {
Ok(provider_models) => provider_models,
let results = match self.api.get_all_provider_models().await {
Ok(results) => results,
Err(err) => {
self.spinner.stop(None)?;
return Err(err);
}
};
self.spinner.stop(None)?;
let mut all_provider_models = self.collect_provider_models(results)?;

if all_provider_models.is_empty() {
return Ok(());
Expand Down Expand Up @@ -2158,8 +2192,9 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
// Fetch models from ALL configured providers (matches shell plugin's
// `forge list models --porcelain`), then optionally filter by provider.
self.spinner.start(Some("Loading"))?;
let mut all_provider_models = self.api.get_all_provider_models().await?;
let results = self.api.get_all_provider_models().await?;
self.spinner.stop(None)?;
let mut all_provider_models = self.collect_provider_models(results)?;

// When a provider filter is specified (e.g. during onboarding after a
// provider was just selected), restrict the list to that provider's
Expand Down Expand Up @@ -2928,7 +2963,8 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
let (needs_model_selection, compatible_model) = match current_model {
None => (true, None),
Some(current_model) => {
let provider_models = self.api.get_all_provider_models().await?;
let results = self.api.get_all_provider_models().await?;
let provider_models = self.collect_provider_models(results)?;
let model_available = provider_models
.iter()
.find(|pm| pm.provider_id == provider.id)
Expand Down Expand Up @@ -3821,6 +3857,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
.get_all_provider_models()
.await?
.into_iter()
.filter_map(Result::ok)
.find(|pm| &pm.provider_id == provider_id)
.with_context(|| {
format!("Provider '{provider_id}' not found or returned no models")
Expand Down
Loading