From aa07a5cf7a93ea648d41cede1dca0df39839bb80 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Thu, 28 May 2026 13:36:44 -0700 Subject: [PATCH 01/13] feat(providers): add okta token exchange delegation flow --- crates/openshell-cli/src/run.rs | 2 + .../tests/provider_commands_integration.rs | 128 ++++++++++++++ crates/openshell-core/src/metadata.rs | 38 +++- crates/openshell-providers/src/profiles.rs | 56 ++++++ crates/openshell-server/src/auth/oidc.rs | 22 ++- crates/openshell-server/src/delegation.rs | 163 ++++++++++++++++++ crates/openshell-server/src/grpc/sandbox.rs | 102 ++++++++++- crates/openshell-server/src/lib.rs | 1 + crates/openshell-server/src/multiplex.rs | 13 ++ .../openshell-server/src/provider_refresh.rs | 146 +++++++++++++++- docs/get-started/tutorials/index.mdx | 5 + docs/get-started/tutorials/okta-obo.mdx | 145 ++++++++++++++++ proto/openshell.proto | 13 ++ providers/okta-obo.yaml | 39 +++++ 14 files changed, 858 insertions(+), 15 deletions(-) create mode 100644 crates/openshell-server/src/delegation.rs create mode 100644 docs/get-started/tutorials/okta-obo.mdx create mode 100644 providers/okta-obo.yaml diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 76f1214d3..4fef70ec2 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -5004,6 +5004,7 @@ fn provider_refresh_strategy(strategy: &str) -> Result { Ok(ProviderCredentialRefreshStrategy::Oauth2ClientCredentials) } + "oauth2_token_exchange" => Ok(ProviderCredentialRefreshStrategy::Oauth2TokenExchange), "google_service_account_jwt" => { Ok(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) } @@ -5058,6 +5059,7 @@ fn provider_refresh_strategy_name(strategy: ProviderCredentialRefreshStrategy) - ProviderCredentialRefreshStrategy::External => "external", ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + 
ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index ed78c6659..200a449de 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1098,6 +1098,89 @@ async fn provider_refresh_cli_run_functions_wire_requests() { ); } +#[tokio::test] +async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { + let ts = run_server().await; + + ts.state.profiles.lock().await.insert( + "okta-obo".to_string(), + ProviderProfile { + id: "okta-obo".to_string(), + display_name: "Okta OBO".to_string(), + credentials: vec![ProviderProfileCredential { + name: "OKTA_OBO_ACCESS_TOKEN".to_string(), + required: true, + refresh: Some(ProviderCredentialRefresh { + strategy: ProviderCredentialRefreshStrategy::Oauth2TokenExchange as i32, + token_url: "https://example.okta.com/oauth2/default/v1/token".to_string(), + material: vec![ + openshell_core::proto::ProviderCredentialRefreshMaterial { + name: "client_id".to_string(), + required: true, + ..Default::default() + }, + openshell_core::proto::ProviderCredentialRefreshMaterial { + name: "sandbox_id".to_string(), + required: true, + ..Default::default() + }, + openshell_core::proto::ProviderCredentialRefreshMaterial { + name: "audience".to_string(), + required: true, + ..Default::default() + }, + ], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ); + + run::provider_create( + &ts.endpoint, + "okta-obo-runtime", + "okta-obo", + false, + &[], + &[], + &ts.tls, + ) + .await + .expect("provider create"); + + run::provider_refresh_config( + &ts.endpoint, + run::ProviderRefreshConfigInput { + 
name: "okta-obo-runtime", + credential_key: "OKTA_OBO_ACCESS_TOKEN", + strategy: "oauth2_token_exchange", + material: &[ + "client_id=client-id".to_string(), + "sandbox_id=sandbox-123".to_string(), + "audience=api://downstream".to_string(), + "scope=api:access:read".to_string(), + ], + secret_material_keys: &["client_secret".to_string()], + credential_expires_at_ms: None, + }, + &ts.tls, + ) + .await + .expect("provider refresh configure"); + + let requests = ts.state.refresh_requests.lock().await.clone(); + assert_eq!( + requests, + vec![ProviderRefreshRequestLog::Configure { + provider_name: "okta-obo-runtime".to_string(), + credential_key: "OKTA_OBO_ACCESS_TOKEN".to_string(), + expires_at_ms: None, + }] + ); +} + #[tokio::test] async fn provider_create_allows_empty_credentials_for_gateway_refresh_profiles() { let ts = run_server().await; @@ -1708,6 +1791,51 @@ endpoints: .expect_err("valid profiles should not be partially imported after local parse errors"); } +#[tokio::test] +async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { + let ts = run_server().await; + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client should connect"); + let profile = client + .get_provider_profile(openshell_core::proto::GetProviderProfileRequest { + id: "okta-obo".to_string(), + }) + .await + .expect("get provider profile") + .into_inner() + .profile + .expect("profile should exist"); + + assert_eq!(profile.id, "okta-obo"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == "obo_access_token") + .expect("obo access token credential"); + let refresh = credential + .refresh + .as_ref() + .expect("obo credential should include refresh config"); + assert_eq!( + refresh.strategy, + ProviderCredentialRefreshStrategy::Oauth2TokenExchange as i32 + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "sandbox_id" && material.required) + ); + 
assert!( + refresh + .material + .iter() + .any(|material| material.name == "audience" && material.required) + ); +} + #[tokio::test] async fn provider_profile_lint_from_directory_reports_parse_errors_without_importing() { let ts = run_server().await; diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index af26f73ae..b315b58a4 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -6,8 +6,9 @@ //! These traits provide uniform access to `ObjectMeta` fields across all resource types. use crate::proto::{ - InferenceRoute, ObjectForTest, Provider, Sandbox, SandboxStatus, ServiceEndpoint, SshSession, - StoredProviderCredentialRefreshState, StoredProviderProfile, + InferenceRoute, ObjectForTest, Provider, Sandbox, ServiceEndpoint, SshSession, + StoredProviderCredentialRefreshState, StoredProviderProfile, StoredSandboxDelegationBinding, + SandboxStatus, }; use std::collections::HashMap; @@ -188,6 +189,39 @@ impl GetResourceVersion for StoredProviderCredentialRefreshState { } } +// Implementations for StoredSandboxDelegationBinding +impl ObjectId for StoredSandboxDelegationBinding { + fn object_id(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.id.as_str()) + } +} + +impl ObjectName for StoredSandboxDelegationBinding { + fn object_name(&self) -> &str { + self.metadata.as_ref().map_or("", |m| m.name.as_str()) + } +} + +impl ObjectLabels for StoredSandboxDelegationBinding { + fn object_labels(&self) -> Option> { + self.metadata.as_ref().map(|m| m.labels.clone()) + } +} + +impl SetResourceVersion for StoredSandboxDelegationBinding { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for StoredSandboxDelegationBinding { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for SshSession 
impl ObjectId for SshSession { fn object_id(&self) -> &str { diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 316624287..437c432ee 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -21,6 +21,7 @@ const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/github.yaml"), include_str!("../../../providers/google-vertex-ai.yaml"), include_str!("../../../providers/nvidia.yaml"), + include_str!("../../../providers/okta-obo.yaml"), ]; #[derive(Debug, thiserror::Error)] @@ -530,6 +531,7 @@ pub fn provider_refresh_strategy_from_yaml(raw: &str) -> Option { Some(ProviderCredentialRefreshStrategy::Oauth2ClientCredentials) } + "oauth2_token_exchange" => Some(ProviderCredentialRefreshStrategy::Oauth2TokenExchange), "google_service_account_jwt" => { Some(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) } @@ -546,6 +548,7 @@ pub fn provider_refresh_strategy_to_yaml( ProviderCredentialRefreshStrategy::External => "external", ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } @@ -1172,6 +1175,59 @@ mod tests { assert_eq!(proto.binaries.len(), 4); } + #[test] + fn okta_obo_profile_exposes_token_exchange_shape() { + let profile = get_default_profile("okta-obo").expect("okta-obo profile"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == "obo_access_token") + .expect("okta-obo access token credential"); + let refresh = credential + .refresh + .as_ref() + .expect("okta-obo credential should be refreshable"); + + assert_eq!( + refresh.strategy, + 
openshell_core::proto::ProviderCredentialRefreshStrategy::Oauth2TokenExchange + ); + assert_eq!( + refresh.token_url, + "https://example.okta.com/oauth2/default/v1/token" + ); + + let material_names = refresh + .material + .iter() + .map(|material| material.name.as_str()) + .collect::>(); + assert_eq!( + material_names, + vec![ + "client_id", + "sandbox_id", + "audience", + "client_secret", + "scope" + ] + ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "sandbox_id") + .is_some_and(|material| material.required) + ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "audience") + .is_some_and(|material| material.required) + ); + } + #[test] fn credential_env_vars_are_deduplicated_in_profile_order() { let profile = get_default_profile("claude-code").expect("claude-code profile"); diff --git a/crates/openshell-server/src/auth/oidc.rs b/crates/openshell-server/src/auth/oidc.rs index bf5490f2a..42c92c6b0 100644 --- a/crates/openshell-server/src/auth/oidc.rs +++ b/crates/openshell-server/src/auth/oidc.rs @@ -109,6 +109,22 @@ pub struct OidcClaims { const STANDARD_OIDC_SCOPES: &[&str] = &["openid", "profile", "email", "offline_access"]; +/// Raw OIDC bearer token captured from the inbound request. +/// +/// Stored in request extensions only after OIDC authentication succeeds so +/// later handlers can persist or exchange the user token without reparsing the +/// header. +#[derive(Debug, Clone)] +pub struct RawBearerToken(pub String); + +/// Extract a bearer token from an `Authorization` header. +pub fn extract_bearer_token(headers: &http::HeaderMap) -> Option<&str> { + headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) +} + impl OidcClaims { /// Extract roles from the JWT claims using a dot-separated path. 
/// @@ -372,11 +388,7 @@ impl Authenticator for OidcAuthenticator { headers: &http::HeaderMap, _path: &str, ) -> Result, Status> { - let Some(token) = headers - .get("authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")) - else { + let Some(token) = extract_bearer_token(headers) else { return Ok(None); }; diff --git a/crates/openshell-server/src/delegation.rs b/crates/openshell-server/src/delegation.rs new file mode 100644 index 000000000..7bed4e877 --- /dev/null +++ b/crates/openshell-server/src/delegation.rs @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox delegation bindings for on-behalf-of token exchange. +//! +//! Token exchange needs a stable server-side record of which signed-in user created a +//! sandbox and which inbound bearer token was available at that time. This +//! module owns that persisted binding so later broker code can exchange the +//! user token for a delegated downstream token without storing long-lived +//! user material inside the sandbox itself.
+ +use crate::persistence::{ObjectType, Store, current_time_ms}; +use openshell_core::proto::{Sandbox, StoredSandboxDelegationBinding}; +use openshell_core::{ObjectId, ObjectName}; +use tonic::Status; + +impl ObjectType for StoredSandboxDelegationBinding { + fn object_type() -> &'static str { + "sandbox_delegation_binding" + } +} + +pub fn binding_name(sandbox_id: &str) -> String { + format!("sandbox-delegation-{sandbox_id}") +} + +pub fn new_binding( + sandbox: &Sandbox, + subject: &str, + display_name: Option<&str>, + identity_provider: &str, + access_token: &str, + scopes: &[String], +) -> Result { + let sandbox_id = sandbox.object_id().trim(); + let sandbox_name = sandbox.object_name().trim(); + if sandbox_id.is_empty() { + return Err(Status::internal("sandbox is missing metadata.id")); + } + if sandbox_name.is_empty() { + return Err(Status::internal("sandbox is missing metadata.name")); + } + if subject.trim().is_empty() { + return Err(Status::invalid_argument("delegation subject is required")); + } + if access_token.trim().is_empty() { + return Err(Status::invalid_argument( + "delegation access token is required", + )); + } + + let now_ms = current_time_ms(); + Ok(StoredSandboxDelegationBinding { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: uuid::Uuid::new_v4().to_string(), + name: binding_name(sandbox_id), + created_at_ms: now_ms, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + sandbox_id: sandbox_id.to_string(), + sandbox_name: sandbox_name.to_string(), + subject: subject.trim().to_string(), + display_name: display_name.unwrap_or_default().trim().to_string(), + identity_provider: identity_provider.trim().to_string(), + access_token: access_token.trim().to_string(), + scopes: scopes.to_vec(), + captured_at_ms: now_ms, + }) +} + +pub async fn put_binding( + store: &Store, + binding: &StoredSandboxDelegationBinding, +) -> Result<(), Status> { + store + .put_scoped_message(binding, &binding.sandbox_id) + 
.await + .map_err(|e| Status::internal(format!("persist sandbox delegation binding failed: {e}"))) +} + +#[cfg_attr(not(test), allow(dead_code))] +pub async fn get_binding( + store: &Store, + sandbox_id: &str, +) -> Result, Status> { + store + .get_message_by_name::(&binding_name(sandbox_id)) + .await + .map_err(|e| Status::internal(format!("fetch sandbox delegation binding failed: {e}"))) +} + +pub async fn delete_binding(store: &Store, sandbox_id: &str) -> Result { + store + .delete_by_name( + StoredSandboxDelegationBinding::object_type(), + &binding_name(sandbox_id), + ) + .await + .map_err(|e| Status::internal(format!("delete sandbox delegation binding failed: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Store; + + fn sandbox() -> Sandbox { + Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-123".to_string(), + name: "demo-sandbox".to_string(), + created_at_ms: 0, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + } + } + + #[tokio::test] + async fn binding_round_trip_works() { + let store = Store::connect("sqlite::memory:") + .await + .expect("in-memory store"); + let sandbox = sandbox(); + let binding = new_binding( + &sandbox, + "user-123", + Some("alex"), + "oidc", + "token-value", + &["sandbox:write".to_string()], + ) + .expect("binding"); + + put_binding(&store, &binding) + .await + .expect("persist binding"); + let loaded = get_binding(&store, "sb-123") + .await + .expect("load binding") + .expect("binding present"); + assert_eq!(loaded.subject, "user-123"); + assert_eq!(loaded.sandbox_name, "demo-sandbox"); + assert_eq!(loaded.identity_provider, "oidc"); + assert_eq!(loaded.access_token, "token-value"); + + let deleted = delete_binding(&store, "sb-123") + .await + .expect("delete binding"); + assert!(deleted); + assert!( + get_binding(&store, "sb-123") + .await + .expect("load binding") + .is_none() 
+ ); + } +} diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 198d5f04c..28ebc2d3d 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -10,6 +10,9 @@ #![allow(clippy::cast_possible_wrap)] // Intentional u32->i32 conversions for proto compat use crate::ServerState; +use crate::auth::identity::IdentityProvider; +use crate::auth::oidc::RawBearerToken; +use crate::auth::principal::Principal; use crate::persistence::{ObjectType, WriteCondition, generate_name}; use futures::future; use openshell_core::proto::{ @@ -119,6 +122,8 @@ async fn handle_create_sandbox_inner( ) -> Result, Status> { use crate::persistence::current_time_ms; + let principal = request.extensions().get::().cloned(); + let raw_bearer_token = request.extensions().get::().cloned(); let request = request.into_inner(); let spec = request .spec @@ -212,7 +217,37 @@ async fn handle_create_sandbox_inner( None => None, }; - let sandbox = state.compute.create_sandbox(sandbox, sandbox_token).await?; + let delegation_binding = match (principal.as_ref(), raw_bearer_token.as_ref()) { + (Some(Principal::User(user)), Some(raw)) + if user.identity.provider == IdentityProvider::Oidc => + { + Some(crate::delegation::new_binding( + &sandbox, + &user.identity.subject, + user.identity.display_name.as_deref(), + "oidc", + &raw.0, + &user.identity.scopes, + )?) 
+ } + _ => None, + }; + + if let Some(binding) = delegation_binding.as_ref() { + crate::delegation::put_binding(state.store.as_ref(), binding).await?; + } + + let sandbox = match state.compute.create_sandbox(sandbox, sandbox_token).await { + Ok(sandbox) => sandbox, + Err(err) => { + if let Some(binding) = delegation_binding.as_ref() { + let _ = + crate::delegation::delete_binding(state.store.as_ref(), &binding.sandbox_id) + .await; + } + return Err(err); + } + }; info!( sandbox_id = %id, @@ -498,12 +533,21 @@ async fn handle_delete_sandbox_inner( .store .get_message_by_name::(&name) .await - .ok() - .flatten() - .map(|sandbox| sandbox.object_id().to_string()); + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .and_then(|sandbox| sandbox.metadata.map(|meta| meta.id)); + let deleted = state.compute.delete_sandbox(&name).await?; - if deleted && let Some(sandbox_id) = sandbox_id { - state.telemetry.end_sandbox_session(&sandbox_id); + if deleted && let Some(sandbox_id) = sandbox_id.as_deref() { + state.telemetry.end_sandbox_session(sandbox_id); + let deleted_binding = crate::delegation::delete_binding(state.store.as_ref(), sandbox_id) + .await + .unwrap_or(false); + debug!( + sandbox_name = %name, + sandbox_id, + deleted_binding, + "deleted sandbox delegation binding" + ); } info!(sandbox_name = %name, "DeleteSandbox request completed successfully"); Ok(Response::new(DeleteSandboxResponse { deleted })) @@ -1938,6 +1982,9 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::oidc::RawBearerToken; + use crate::auth::principal::{Principal, UserPrincipal}; use crate::grpc::test_support::test_server_state; use openshell_core::proto::datamodel::v1::ObjectMeta; use std::collections::HashMap; @@ -2589,6 +2636,49 @@ mod tests { assert!(err.message().contains("provider-b")); } + #[tokio::test] + async fn create_sandbox_persists_delegation_binding_for_oidc_user() 
{ + let state = test_server_state().await; + let mut request = Request::new(CreateSandboxRequest { + name: "delegated".to_string(), + spec: Some(openshell_core::proto::SandboxSpec::default()), + labels: HashMap::new(), + }); + request + .extensions_mut() + .insert(Principal::User(UserPrincipal { + identity: Identity { + subject: "user-123".to_string(), + display_name: Some("alex".to_string()), + roles: vec!["openshell-user".to_string()], + scopes: vec!["sandbox:write".to_string()], + provider: IdentityProvider::Oidc, + }, + })); + request + .extensions_mut() + .insert(RawBearerToken("raw-access-token".to_string())); + + let response = handle_create_sandbox(&state, request) + .await + .expect("sandbox create succeeds") + .into_inner(); + let sandbox = response.sandbox.expect("sandbox present"); + let binding = crate::delegation::get_binding( + state.store.as_ref(), + sandbox.metadata.as_ref().expect("metadata").id.as_str(), + ) + .await + .expect("load binding") + .expect("binding present"); + + assert_eq!(binding.subject, "user-123"); + assert_eq!(binding.display_name, "alex"); + assert_eq!(binding.identity_provider, "oidc"); + assert_eq!(binding.access_token, "raw-access-token"); + assert_eq!(binding.scopes, vec!["sandbox:write".to_string()]); + } + #[tokio::test] async fn attach_sandbox_provider_rejects_credential_key_collisions() { let state = test_server_state().await; diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index c25ba1cfd..7ad38952e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -25,6 +25,7 @@ pub mod cli; mod compute; pub mod config_file; mod defaults; +mod delegation; mod grpc; mod http; mod inference; diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index e94326f98..cfb3de0f4 100644 --- a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -467,6 +467,19 @@ where } } + let 
raw_oidc_bearer = if let Principal::User(ref user) = principal { + if user.identity.provider == crate::auth::identity::IdentityProvider::Oidc { + oidc::extract_bearer_token(req.headers()).map(str::to_owned) + } else { + None + } + } else { + None + }; + if let Some(token) = raw_oidc_bearer { + req.extensions_mut().insert(oidc::RawBearerToken(token)); + } + req.extensions_mut().insert(principal); inner.ready().await?.call(req).await }) diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 161daeb7f..160c90e3e 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -278,6 +278,7 @@ pub fn refresh_strategy_name(strategy: i32) -> &'static str { ProviderCredentialRefreshStrategy::External => "external", ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", + ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } @@ -288,6 +289,7 @@ pub fn is_gateway_mintable_strategy(strategy: ProviderCredentialRefreshStrategy) strategy, ProviderCredentialRefreshStrategy::Oauth2RefreshToken | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials + | ProviderCredentialRefreshStrategy::Oauth2TokenExchange | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt ) } @@ -317,7 +319,7 @@ pub async fn refresh_provider_credential( "provider credential refresh started" ); - match mint_credential(&state).await { + match mint_credential(store, &state).await { Ok(minted) => { let now_ms = current_time_ms(); if let Err(err) = @@ -435,6 +437,7 @@ async fn apply_minted_credential( } async fn mint_credential( + store: &Store, state: &StoredProviderCredentialRefreshState, 
) -> Result { let strategy = ProviderCredentialRefreshStrategy::try_from(state.strategy) @@ -446,6 +449,9 @@ async fn mint_credential( ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => { mint_oauth2_client_credentials(state).await } + ProviderCredentialRefreshStrategy::Oauth2TokenExchange => { + mint_oauth2_token_exchange(store, state).await + } ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => { mint_google_service_account_jwt(state).await } @@ -479,6 +485,51 @@ async fn mint_oauth2_refresh_token( request_token(&token_url, &form, state.max_lifetime_seconds).await } +async fn mint_oauth2_token_exchange( + store: &Store, + state: &StoredProviderCredentialRefreshState, +) -> Result { + let token_url = oauth2_token_url(state)?; + let client_id = required_material(&state.material, "client_id")?; + let sandbox_id = required_material(&state.material, "sandbox_id")?; + let binding = crate::delegation::get_binding(store, &sandbox_id) + .await? + .ok_or_else(|| { + Status::failed_precondition(format!( + "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" + )) + })?; + + let mut form = vec![ + ( + "grant_type".to_string(), + "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), + ), + ("client_id".to_string(), client_id), + ("subject_token".to_string(), binding.access_token), + ( + "subject_token_type".to_string(), + "urn:ietf:params:oauth:token-type:access_token".to_string(), + ), + ( + "requested_token_type".to_string(), + "urn:ietf:params:oauth:token-type:access_token".to_string(), + ), + ]; + if let Some(client_secret) = material_value(&state.material, &["client_secret"]) { + form.push(("client_secret".to_string(), client_secret)); + } + if let Some(audience) = material_value(&state.material, &["audience", "resource"]) { + form.push(("audience".to_string(), audience)); + } + let scope = refresh_scopes(state).join(" "); + if !scope.is_empty() { + form.push(("scope".to_string(), scope)); + } + + request_token(&token_url, 
&form, state.max_lifetime_seconds).await +} + async fn mint_oauth2_client_credentials( state: &StoredProviderCredentialRefreshState, ) -> Result { @@ -648,7 +699,7 @@ fn oauth2_token_url(state: &StoredProviderCredentialRefreshState) -> Result("my-obo") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored_provider.credentials.get("OKTA_ACCESS_TOKEN"), + Some(&"delegated-downstream-token".to_string()) + ); + } + #[tokio::test] async fn google_service_account_refresh_mints_and_persists_access_token() { let mock_server = MockServer::start().await; diff --git a/docs/get-started/tutorials/index.mdx b/docs/get-started/tutorials/index.mdx index 0d82509ad..b6c032013 100644 --- a/docs/get-started/tutorials/index.mdx +++ b/docs/get-started/tutorials/index.mdx @@ -27,6 +27,11 @@ Launch Claude Code in a sandbox, diagnose a policy denial, and iterate on a cust Configure a Providers v2 Microsoft Graph provider with gateway-managed OAuth2 refresh-token rotation. + + +Configure delegated Okta access on behalf of the logged-in OpenShell user with token exchange. + + Route inference through Ollama using cloud-hosted or local models, and verify it from a sandbox. diff --git a/docs/get-started/tutorials/okta-obo.mdx b/docs/get-started/tutorials/okta-obo.mdx new file mode 100644 index 000000000..c7a8f7a59 --- /dev/null +++ b/docs/get-started/tutorials/okta-obo.mdx @@ -0,0 +1,145 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: "Delegate Okta Access On Behalf of the Logged-in User" +sidebar-title: "Okta OBO Token Exchange" +slug: "get-started/tutorials/okta-obo" +description: "Configure the built-in Okta OBO provider profile so OpenShell can exchange a logged-in user's token for a delegated downstream token." 
+keywords: "Generative AI, Cybersecurity, Tutorial, Providers, Okta, OBO, RFC 8693, Token Exchange, Delegation" +--- + +Use the built-in `okta-obo` profile when a sandboxed workload must call a downstream API on behalf of the human who logged into OpenShell. The gateway keeps the inbound user token server-side, binds it to sandbox creation, and exchanges it for a short-lived delegated token using Okta token exchange. + +After completing this tutorial, you have: + +- A token-exchange service app in Okta for delegated access. +- A customized `okta-obo` provider profile that points at your Okta tenant. +- A sandbox whose attached provider can mint `OKTA_OBO_ACCESS_TOKEN` from the logged-in user's identity. + + +This tutorial covers the delegation lane. It assumes you already completed the gateway login lane and can log into the gateway with Okta before you create the sandbox. + + +## Prerequisites + +- A working OpenShell installation with an active gateway. +- OpenShell CLI login through Okta (`openshell gateway login`) already working against your gateway. +- An Okta custom authorization server. +- An Okta service app with the `Token Exchange` grant enabled. +- A downstream audience and scope that the delegated token should target. + +For the example commands below, set: + +| Variable | Value | +|---|---| +| `OKTA_OBO_CLIENT_ID` | Token-exchange service app client ID. | +| `OKTA_OBO_CLIENT_SECRET` | Token-exchange service app client secret. | +| `OKTA_OBO_AUDIENCE` | Downstream audience, such as `api://default`. | +| `OKTA_OBO_SCOPE` | Delegated scope, such as `api:access:read`. | + + +Treat the token-exchange service app secret like any other production credential. Do not commit it, paste it into source control, or leave it in shell history longer than needed.
+ + + + +## Create the Downstream Scope and Rule in Okta + +In your Okta custom authorization server: + +- create the downstream scope that the delegated token should carry +- create an access-policy rule for the service app with the `Token Exchange` grant enabled + +For a smoke test, a broad rule with `Any scopes` is fine. + +## Create a Tenant-Specific OBO Profile + +Export the built-in profile and update the token endpoint: + +```shell +openshell provider profile export okta-obo -o yaml > okta-obo.yaml +``` + +Replace: + +- `https://example.okta.com/oauth2/default/v1/token` + +with your real Okta token endpoint, then import the customized profile: + +```shell +openshell provider profile lint -f okta-obo.yaml +openshell provider profile import -f okta-obo.yaml +``` + +## Log In and Create the Sandbox + +Log into the gateway as the human user whose identity should be delegated: + +```shell +openshell gateway login +``` + +Create the sandbox after login so OpenShell can bind the sandbox to the authenticated user token: + +```shell +openshell sandbox create --name okta-obo-smoke +``` + +Fetch the sandbox metadata and record the sandbox ID: + +```shell +openshell sandbox get okta-obo-smoke +``` + +## Create the OBO Provider + +Create the provider from the imported profile: + +```shell +openshell provider create \ + --name okta-obo-runtime \ + --type okta-obo +``` + +## Configure Token Exchange + +Configure the delegated credential refresh: + +```shell +openshell provider refresh configure okta-obo-runtime \ + --credential-key OKTA_OBO_ACCESS_TOKEN \ + --strategy oauth2-token-exchange \ + --material client_id="$OKTA_OBO_CLIENT_ID" \ + --material client_secret="$OKTA_OBO_CLIENT_SECRET" \ + --material sandbox_id="<sandbox-id>" \ + --material audience="$OKTA_OBO_AUDIENCE" \ + --material scope="$OKTA_OBO_SCOPE" \ + --secret-material-key client_secret +``` + +Check refresh status: + +```shell +openshell provider refresh status okta-obo-runtime \ + --credential-key
OKTA_OBO_ACCESS_TOKEN +``` + +## Attach the Provider + +Attach the OBO provider to the sandbox: + +```shell +openshell sandbox provider attach okta-obo-smoke okta-obo-runtime +``` + +## Verify the Delegated Token Path + +Exec into the sandbox and confirm the delegated credential placeholder exists: + +```shell +openshell sandbox exec --name okta-obo-smoke -- /bin/sh -lc 'printenv OKTA_OBO_ACCESS_TOKEN' +``` + +The value should be an OpenShell-managed placeholder rather than a raw token. OpenShell resolves it to the current delegated token when the workload uses the supported authorization path. + + diff --git a/proto/openshell.proto b/proto/openshell.proto index f9b64618b..55165b55b 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -904,6 +904,7 @@ enum ProviderCredentialRefreshStrategy { PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OAUTH2_REFRESH_TOKEN = 3; PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OAUTH2_CLIENT_CREDENTIALS = 4; PROVIDER_CREDENTIAL_REFRESH_STRATEGY_GOOGLE_SERVICE_ACCOUNT_JWT = 5; + PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OAUTH2_TOKEN_EXCHANGE = 6; } message ProviderCredentialRefreshMaterial { @@ -959,6 +960,18 @@ message StoredProviderCredentialRefreshState { int64 max_lifetime_seconds = 16; } +message StoredSandboxDelegationBinding { + openshell.datamodel.v1.ObjectMeta metadata = 1; + string sandbox_id = 2; + string sandbox_name = 3; + string subject = 4; + string display_name = 5; + string identity_provider = 6; + string access_token = 7; + repeated string scopes = 8; + int64 captured_at_ms = 9; +} + message GetProviderRefreshStatusRequest { string provider = 1; string credential_key = 2; diff --git a/providers/okta-obo.yaml b/providers/okta-obo.yaml new file mode 100644 index 000000000..56f38951a --- /dev/null +++ b/providers/okta-obo.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +id: okta-obo +display_name: Okta OBO +description: Okta delegated access tokens minted on behalf of the authenticated OpenShell user +category: other +credentials: + - name: obo_access_token + description: Okta delegated access token for downstream APIs + env_vars: [OKTA_OBO_ACCESS_TOKEN] + required: true + auth_style: bearer + header_name: authorization + refresh: + strategy: oauth2_token_exchange + token_url: https://example.okta.com/oauth2/default/v1/token + refresh_before_seconds: 300 + max_lifetime_seconds: 3600 + material: + - name: client_id + description: Okta OIDC application client ID used for token exchange + required: true + - name: sandbox_id + description: OpenShell sandbox ID bound to the authenticated user token + required: true + - name: audience + description: Downstream Okta resource audience for the delegated token + required: true + - name: client_secret + description: Okta client secret for confidential token-exchange clients + required: false + secret: true + - name: scope + description: Space-delimited scopes requested for the delegated token + required: false +binaries: + - /usr/bin/curl + - /usr/local/bin/curl From 89e69c0df39434261f4995b26a72b8252ce5ecb1 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Thu, 28 May 2026 16:57:49 -0700 Subject: [PATCH 02/13] feat(auth): support live okta obo token exchange --- crates/openshell-cli/src/main.rs | 2 + crates/openshell-server/src/delegation.rs | 1 + crates/openshell-server/src/grpc/provider.rs | 29 ++++++++++++ .../openshell-server/src/provider_refresh.rs | 46 +++++++++++-------- 4 files changed, 59 insertions(+), 19 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 22af412b5..bd310c115 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -660,6 +660,7 @@ enum OutputFormat { enum CliProviderRefreshStrategy { Oauth2RefreshToken, Oauth2ClientCredentials, + 
Oauth2TokenExchange, GoogleServiceAccountJwt, } @@ -668,6 +669,7 @@ impl CliProviderRefreshStrategy { match self { Self::Oauth2RefreshToken => "oauth2_refresh_token", Self::Oauth2ClientCredentials => "oauth2_client_credentials", + Self::Oauth2TokenExchange => "oauth2_token_exchange", Self::GoogleServiceAccountJwt => "google_service_account_jwt", } } diff --git a/crates/openshell-server/src/delegation.rs b/crates/openshell-server/src/delegation.rs index 7bed4e877..eba87c783 100644 --- a/crates/openshell-server/src/delegation.rs +++ b/crates/openshell-server/src/delegation.rs @@ -24,6 +24,7 @@ pub fn binding_name(sandbox_id: &str) -> String { format!("sandbox-delegation-{sandbox_id}") } +#[allow(clippy::result_large_err)] pub fn new_binding( sandbox: &Sandbox, subject: &str, diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index e6f0c2780..e868c65b7 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -8,6 +8,9 @@ use crate::persistence::{ ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, }; +use crate::auth::identity::IdentityProvider; +use crate::auth::oidc::RawBearerToken; +use crate::auth::principal::Principal; use openshell_core::proto::{Provider, Sandbox}; use openshell_core::telemetry::{ LifecycleOperation, ProviderProfile as TelemetryProviderProfile, TelemetryOutcome, @@ -1238,6 +1241,8 @@ pub(super) async fn handle_configure_provider_refresh( state: &Arc, request: Request, ) -> Result, Status> { + let principal = request.extensions().get::().cloned(); + let raw_bearer_token = request.extensions().get::().cloned(); let request = request.into_inner(); let provider_name = request.provider.trim(); let credential_key = request.credential_key.trim(); @@ -1379,6 +1384,7 @@ pub(super) async fn handle_configure_provider_refresh( credential_key, ) .await?; + let sandbox_id_for_binding = 
request.material.get("sandbox_id").cloned(); let expires_at_ms = request.expires_at_ms.unwrap_or_else(|| { existing_refresh_state .as_ref() @@ -1405,6 +1411,29 @@ pub(super) async fn handle_configure_provider_refresh( } crate::provider_refresh::put_refresh_state(state.store.as_ref(), &state_record).await?; + if strategy == ProviderCredentialRefreshStrategy::Oauth2TokenExchange + && let (Some(Principal::User(user)), Some(raw)) = + (principal.as_ref(), raw_bearer_token.as_ref()) + && user.identity.provider == IdentityProvider::Oidc + && let Some(sandbox_id) = sandbox_id_for_binding.as_deref().map(str::trim) + && !sandbox_id.is_empty() + && let Some(sandbox) = state + .store + .get_message::(sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + { + let binding = crate::delegation::new_binding( + &sandbox, + &user.identity.subject, + user.identity.display_name.as_deref(), + "oidc", + &raw.0, + &user.identity.scopes, + )?; + crate::delegation::put_binding(state.store.as_ref(), &binding).await?; + } + if let Some(expires_at_ms) = request.expires_at_ms { let updated = Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 160c90e3e..4e72dec69 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -471,18 +471,19 @@ async fn mint_oauth2_refresh_token( let refresh_token = required_material(&state.material, "refresh_token")?; let mut form = vec![ ("grant_type".to_string(), "refresh_token".to_string()), - ("client_id".to_string(), client_id), ("refresh_token".to_string(), refresh_token), ]; - if let Some(client_secret) = material_value(&state.material, &["client_secret"]) { - form.push(("client_secret".to_string(), client_secret)); + let basic_auth = material_value(&state.material, &["client_secret"]) + .map(|client_secret| 
(client_id.clone(), client_secret)); + if basic_auth.is_none() { + form.push(("client_id".to_string(), client_id)); } let scope = refresh_scopes(state).join(" "); if !scope.is_empty() { form.push(("scope".to_string(), scope)); } - request_token(&token_url, &form, state.max_lifetime_seconds).await + request_token(&token_url, &form, basic_auth, state.max_lifetime_seconds).await } async fn mint_oauth2_token_exchange( @@ -505,7 +506,6 @@ async fn mint_oauth2_token_exchange( "grant_type".to_string(), "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), ), - ("client_id".to_string(), client_id), ("subject_token".to_string(), binding.access_token), ( "subject_token_type".to_string(), @@ -516,8 +516,10 @@ async fn mint_oauth2_token_exchange( "urn:ietf:params:oauth:token-type:access_token".to_string(), ), ]; - if let Some(client_secret) = material_value(&state.material, &["client_secret"]) { - form.push(("client_secret".to_string(), client_secret)); + let basic_auth = material_value(&state.material, &["client_secret"]) + .map(|client_secret| (client_id.clone(), client_secret)); + if basic_auth.is_none() { + form.push(("client_id".to_string(), client_id)); } if let Some(audience) = material_value(&state.material, &["audience", "resource"]) { form.push(("audience".to_string(), audience)); @@ -527,7 +529,7 @@ async fn mint_oauth2_token_exchange( form.push(("scope".to_string(), scope)); } - request_token(&token_url, &form, state.max_lifetime_seconds).await + request_token(&token_url, &form, basic_auth, state.max_lifetime_seconds).await } async fn mint_oauth2_client_credentials( @@ -536,17 +538,14 @@ async fn mint_oauth2_client_credentials( let token_url = oauth2_token_url(state)?; let client_id = required_material(&state.material, "client_id")?; let client_secret = required_material(&state.material, "client_secret")?; - let mut form = vec![ - ("grant_type".to_string(), "client_credentials".to_string()), - ("client_id".to_string(), client_id), - 
("client_secret".to_string(), client_secret), - ]; + let mut form = vec![("grant_type".to_string(), "client_credentials".to_string())]; + let basic_auth = Some((client_id, client_secret)); let scope = refresh_scopes(state).join(" "); if !scope.is_empty() { form.push(("scope".to_string(), scope)); } - request_token(&token_url, &form, state.max_lifetime_seconds).await + request_token(&token_url, &form, basic_auth, state.max_lifetime_seconds).await } async fn mint_google_service_account_jwt( @@ -592,12 +591,13 @@ async fn mint_google_service_account_jwt( ), ("assertion".to_string(), assertion), ]; - request_token(&token_url, &form, lifetime_secs).await + request_token(&token_url, &form, None, lifetime_secs).await } async fn request_token( token_url: &str, form: &[(String, String)], + basic_auth: Option<(String, String)>, max_lifetime_seconds: i64, ) -> Result { let parsed = reqwest::Url::parse(token_url) @@ -616,16 +616,24 @@ async fn request_token( .timeout(Duration::from_secs(30)) .build() .map_err(|e| Status::internal(format!("build refresh HTTP client failed: {e}")))?; - let response = client - .post(parsed) - .form(form) + let request = client.post(parsed).form(form); + let request = if let Some((client_id, client_secret)) = basic_auth { + request.basic_auth(client_id, Some(client_secret)) + } else { + request + }; + let response = request .send() .await .map_err(|e| Status::unavailable(format!("token endpoint request failed: {e}")))?; let status = response.status(); if !status.is_success() { + let body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); return Err(Status::failed_precondition(format!( - "token endpoint returned HTTP {status}" + "token endpoint returned HTTP {status}: {body}" ))); } let token = response From 796699e7a64920a45fc1a36212208ea9c640b24d Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 20:42:43 -0700 Subject: [PATCH 03/13] fix(auth): align obo branch with current provider refresh behavior --- 
crates/openshell-server/src/grpc/provider.rs | 14 ++++++++++---- crates/openshell-server/src/provider_refresh.rs | 2 -- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index e868c65b7..35b2d9c4d 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -5,12 +5,12 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> -use crate::persistence::{ - ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, -}; use crate::auth::identity::IdentityProvider; use crate::auth::oidc::RawBearerToken; use crate::auth::principal::Principal; +use crate::persistence::{ + ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, +}; use openshell_core::proto::{Provider, Sandbox}; use openshell_core::telemetry::{ LifecycleOperation, ProviderProfile as TelemetryProviderProfile, TelemetryOutcome, @@ -1847,7 +1847,13 @@ mod tests { .collect::>(); assert_eq!( ids, - vec!["claude-code", "github", "google-vertex-ai", "nvidia",] + vec![ + "claude-code", + "github", + "google-vertex-ai", + "nvidia", + "okta-obo", + ] ); let github = response diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 4e72dec69..9d199c674 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -900,7 +900,6 @@ mod tests { Mock::given(method("POST")) .and(path("/token")) .and(body_string_contains("grant_type=client_credentials")) - .and(body_string_contains("client_id=client-id")) .and(body_string_contains( "scope=https%3A%2F%2Fgraph.microsoft.com%2F.default", )) @@ -1134,7 +1133,6 @@ mod tests { .and(body_string_contains( "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange", )) - .and(body_string_contains("client_id=client-id")) 
.and(body_string_contains("subject_token=user-access-token")) .and(body_string_contains( "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token", From 3bb5bcf09d59ed3c69aadcc887fdecd2d92092d5 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Tue, 2 Jun 2026 15:58:54 -0700 Subject: [PATCH 04/13] fix(auth): align okta obo with provider refresh model --- .../tests/provider_commands_integration.rs | 8 +- crates/openshell-core/src/metadata.rs | 35 +--- crates/openshell-providers/src/profiles.rs | 6 +- crates/openshell-server/src/delegation.rs | 164 ------------------ crates/openshell-server/src/grpc/provider.rs | 70 +++++--- crates/openshell-server/src/grpc/sandbox.rs | 103 +---------- crates/openshell-server/src/lib.rs | 1 - .../openshell-server/src/provider_refresh.rs | 47 ++--- docs/get-started/tutorials/okta-obo.mdx | 27 ++- proto/openshell.proto | 12 -- providers/okta-obo.yaml | 7 +- 11 files changed, 80 insertions(+), 400 deletions(-) delete mode 100644 crates/openshell-server/src/delegation.rs diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 200a449de..5a5031d23 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1826,7 +1826,7 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "sandbox_id" && material.required) + .any(|material| material.name == "client_id" && material.required) ); assert!( refresh @@ -1834,6 +1834,12 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { .iter() .any(|material| material.name == "audience" && material.required) ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "subject_token" && !material.required) + ); } #[tokio::test] diff --git a/crates/openshell-core/src/metadata.rs 
b/crates/openshell-core/src/metadata.rs index b315b58a4..e86bc17e2 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -7,7 +7,7 @@ use crate::proto::{ InferenceRoute, ObjectForTest, Provider, Sandbox, ServiceEndpoint, SshSession, - StoredProviderCredentialRefreshState, StoredProviderProfile, StoredSandboxDelegationBinding, + StoredProviderCredentialRefreshState, StoredProviderProfile, SandboxStatus, }; use std::collections::HashMap; @@ -189,39 +189,6 @@ impl GetResourceVersion for StoredProviderCredentialRefreshState { } } -// Implementations for StoredSandboxDelegationBinding -impl ObjectId for StoredSandboxDelegationBinding { - fn object_id(&self) -> &str { - self.metadata.as_ref().map_or("", |m| m.id.as_str()) - } -} - -impl ObjectName for StoredSandboxDelegationBinding { - fn object_name(&self) -> &str { - self.metadata.as_ref().map_or("", |m| m.name.as_str()) - } -} - -impl ObjectLabels for StoredSandboxDelegationBinding { - fn object_labels(&self) -> Option> { - self.metadata.as_ref().map(|m| m.labels.clone()) - } -} - -impl SetResourceVersion for StoredSandboxDelegationBinding { - fn set_resource_version(&mut self, version: u64) { - if let Some(meta) = self.metadata.as_mut() { - meta.resource_version = version; - } - } -} - -impl GetResourceVersion for StoredSandboxDelegationBinding { - fn get_resource_version(&self) -> u64 { - self.metadata.as_ref().map_or(0, |m| m.resource_version) - } -} - // Implementations for SshSession impl ObjectId for SshSession { fn object_id(&self) -> &str { diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 437c432ee..888042c69 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -1206,9 +1206,9 @@ mod tests { material_names, vec![ "client_id", - "sandbox_id", "audience", "client_secret", + "subject_token", "scope" ] ); @@ -1216,8 +1216,8 @@ mod tests { refresh .material 
.iter() - .find(|material| material.name == "sandbox_id") - .is_some_and(|material| material.required) + .find(|material| material.name == "subject_token") + .is_some_and(|material| !material.required && material.secret) ); assert!( refresh diff --git a/crates/openshell-server/src/delegation.rs b/crates/openshell-server/src/delegation.rs deleted file mode 100644 index eba87c783..000000000 --- a/crates/openshell-server/src/delegation.rs +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Sandbox delegation bindings for on-behalf-of token exchange. -//! -//! Lane 3 needs a stable server-side record of which signed-in user created a -//! sandbox and which inbound bearer token was available at that time. This -//! module owns that persisted binding so later broker code can exchange the -//! user token for a delegated downstream token without storing long-lived -//! user material inside the sandbox itself. 
- -use crate::persistence::{ObjectType, Store, current_time_ms}; -use openshell_core::proto::{Sandbox, StoredSandboxDelegationBinding}; -use openshell_core::{ObjectId, ObjectName}; -use tonic::Status; - -impl ObjectType for StoredSandboxDelegationBinding { - fn object_type() -> &'static str { - "sandbox_delegation_binding" - } -} - -pub fn binding_name(sandbox_id: &str) -> String { - format!("sandbox-delegation-{sandbox_id}") -} - -#[allow(clippy::result_large_err)] -pub fn new_binding( - sandbox: &Sandbox, - subject: &str, - display_name: Option<&str>, - identity_provider: &str, - access_token: &str, - scopes: &[String], -) -> Result { - let sandbox_id = sandbox.object_id().trim(); - let sandbox_name = sandbox.object_name().trim(); - if sandbox_id.is_empty() { - return Err(Status::internal("sandbox is missing metadata.id")); - } - if sandbox_name.is_empty() { - return Err(Status::internal("sandbox is missing metadata.name")); - } - if subject.trim().is_empty() { - return Err(Status::invalid_argument("delegation subject is required")); - } - if access_token.trim().is_empty() { - return Err(Status::invalid_argument( - "delegation access token is required", - )); - } - - let now_ms = current_time_ms(); - Ok(StoredSandboxDelegationBinding { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: uuid::Uuid::new_v4().to_string(), - name: binding_name(sandbox_id), - created_at_ms: now_ms, - labels: std::collections::HashMap::new(), - resource_version: 0, - }), - sandbox_id: sandbox_id.to_string(), - sandbox_name: sandbox_name.to_string(), - subject: subject.trim().to_string(), - display_name: display_name.unwrap_or_default().trim().to_string(), - identity_provider: identity_provider.trim().to_string(), - access_token: access_token.trim().to_string(), - scopes: scopes.to_vec(), - captured_at_ms: now_ms, - }) -} - -pub async fn put_binding( - store: &Store, - binding: &StoredSandboxDelegationBinding, -) -> Result<(), Status> { - store - 
.put_scoped_message(binding, &binding.sandbox_id) - .await - .map_err(|e| Status::internal(format!("persist sandbox delegation binding failed: {e}"))) -} - -#[cfg_attr(not(test), allow(dead_code))] -pub async fn get_binding( - store: &Store, - sandbox_id: &str, -) -> Result, Status> { - store - .get_message_by_name::(&binding_name(sandbox_id)) - .await - .map_err(|e| Status::internal(format!("fetch sandbox delegation binding failed: {e}"))) -} - -pub async fn delete_binding(store: &Store, sandbox_id: &str) -> Result { - store - .delete_by_name( - StoredSandboxDelegationBinding::object_type(), - &binding_name(sandbox_id), - ) - .await - .map_err(|e| Status::internal(format!("delete sandbox delegation binding failed: {e}"))) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::Store; - - fn sandbox() -> Sandbox { - Sandbox { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: "sb-123".to_string(), - name: "demo-sandbox".to_string(), - created_at_ms: 0, - labels: std::collections::HashMap::new(), - resource_version: 0, - }), - spec: None, - status: None, - phase: 0, - current_policy_version: 0, - } - } - - #[tokio::test] - async fn binding_round_trip_works() { - let store = Store::connect("sqlite::memory:") - .await - .expect("in-memory store"); - let sandbox = sandbox(); - let binding = new_binding( - &sandbox, - "user-123", - Some("alex"), - "oidc", - "token-value", - &["sandbox:write".to_string()], - ) - .expect("binding"); - - put_binding(&store, &binding) - .await - .expect("persist binding"); - let loaded = get_binding(&store, "sb-123") - .await - .expect("load binding") - .expect("binding present"); - assert_eq!(loaded.subject, "user-123"); - assert_eq!(loaded.sandbox_name, "demo-sandbox"); - assert_eq!(loaded.identity_provider, "oidc"); - assert_eq!(loaded.access_token, "token-value"); - - let deleted = delete_binding(&store, "sb-123") - .await - .expect("delete binding"); - assert!(deleted); - assert!( - get_binding(&store, 
"sb-123") - .await - .expect("load binding") - .is_none() - ); - } -} diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 35b2d9c4d..1f7b62495 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -1384,7 +1384,48 @@ pub(super) async fn handle_configure_provider_refresh( credential_key, ) .await?; - let sandbox_id_for_binding = request.material.get("sandbox_id").cloned(); + let mut material = request.material; + let mut secret_material_keys = request.secret_material_keys; + if strategy == ProviderCredentialRefreshStrategy::Oauth2TokenExchange { + match (principal.as_ref(), raw_bearer_token.as_ref()) { + (Some(Principal::User(user)), Some(raw)) + if user.identity.provider == IdentityProvider::Oidc => + { + material.insert("subject_token".to_string(), raw.0.clone()); + if !secret_material_keys + .iter() + .any(|key| key == "subject_token") + { + secret_material_keys.push("subject_token".to_string()); + } + } + _ => { + if let Some(existing) = existing_refresh_state + .as_ref() + .and_then(|state| state.material.get("subject_token")) + { + material + .entry("subject_token".to_string()) + .or_insert_with(|| existing.clone()); + } else { + return Err(Status::failed_precondition( + "oauth2_token_exchange refresh requires an authenticated OIDC user bearer token during configuration", + )); + } + if existing_refresh_state.as_ref().is_some_and(|state| { + state + .secret_material_keys + .iter() + .any(|key| key == "subject_token") + }) && !secret_material_keys + .iter() + .any(|key| key == "subject_token") + { + secret_material_keys.push("subject_token".to_string()); + } + } + } + } let expires_at_ms = request.expires_at_ms.unwrap_or_else(|| { existing_refresh_state .as_ref() @@ -1396,8 +1437,8 @@ pub(super) async fn handle_configure_provider_refresh( credential_key, crate::provider_refresh::NewRefreshStateConfig { strategy, - material: request.material, - 
secret_material_keys: request.secret_material_keys, + material, + secret_material_keys, expires_at_ms, token_url, scopes, @@ -1411,29 +1452,6 @@ pub(super) async fn handle_configure_provider_refresh( } crate::provider_refresh::put_refresh_state(state.store.as_ref(), &state_record).await?; - if strategy == ProviderCredentialRefreshStrategy::Oauth2TokenExchange - && let (Some(Principal::User(user)), Some(raw)) = - (principal.as_ref(), raw_bearer_token.as_ref()) - && user.identity.provider == IdentityProvider::Oidc - && let Some(sandbox_id) = sandbox_id_for_binding.as_deref().map(str::trim) - && !sandbox_id.is_empty() - && let Some(sandbox) = state - .store - .get_message::(sandbox_id) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? - { - let binding = crate::delegation::new_binding( - &sandbox, - &user.identity.subject, - user.identity.display_name.as_deref(), - "oidc", - &raw.0, - &user.identity.scopes, - )?; - crate::delegation::put_binding(state.store.as_ref(), &binding).await?; - } - if let Some(expires_at_ms) = request.expires_at_ms { let updated = Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 28ebc2d3d..06f67fb4c 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -10,9 +10,6 @@ #![allow(clippy::cast_possible_wrap)] // Intentional u32->i32 conversions for proto compat use crate::ServerState; -use crate::auth::identity::IdentityProvider; -use crate::auth::oidc::RawBearerToken; -use crate::auth::principal::Principal; use crate::persistence::{ObjectType, WriteCondition, generate_name}; use futures::future; use openshell_core::proto::{ @@ -122,8 +119,6 @@ async fn handle_create_sandbox_inner( ) -> Result, Status> { use crate::persistence::current_time_ms; - let principal = request.extensions().get::().cloned(); - let raw_bearer_token = 
request.extensions().get::().cloned(); let request = request.into_inner(); let spec = request .spec @@ -216,38 +211,7 @@ async fn handle_create_sandbox_inner( Some(Err(status)) => return Err(status), None => None, }; - - let delegation_binding = match (principal.as_ref(), raw_bearer_token.as_ref()) { - (Some(Principal::User(user)), Some(raw)) - if user.identity.provider == IdentityProvider::Oidc => - { - Some(crate::delegation::new_binding( - &sandbox, - &user.identity.subject, - user.identity.display_name.as_deref(), - "oidc", - &raw.0, - &user.identity.scopes, - )?) - } - _ => None, - }; - - if let Some(binding) = delegation_binding.as_ref() { - crate::delegation::put_binding(state.store.as_ref(), binding).await?; - } - - let sandbox = match state.compute.create_sandbox(sandbox, sandbox_token).await { - Ok(sandbox) => sandbox, - Err(err) => { - if let Some(binding) = delegation_binding.as_ref() { - let _ = - crate::delegation::delete_binding(state.store.as_ref(), &binding.sandbox_id) - .await; - } - return Err(err); - } - }; + let sandbox = state.compute.create_sandbox(sandbox, sandbox_token).await?; info!( sandbox_id = %id, @@ -529,26 +493,7 @@ async fn handle_delete_sandbox_inner( return Err(Status::invalid_argument("name is required")); } - let sandbox_id = state - .store - .get_message_by_name::(&name) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
- .and_then(|sandbox| sandbox.metadata.map(|meta| meta.id)); - let deleted = state.compute.delete_sandbox(&name).await?; - if deleted && let Some(sandbox_id) = sandbox_id.as_deref() { - state.telemetry.end_sandbox_session(sandbox_id); - let deleted_binding = crate::delegation::delete_binding(state.store.as_ref(), sandbox_id) - .await - .unwrap_or(false); - debug!( - sandbox_name = %name, - sandbox_id, - deleted_binding, - "deleted sandbox delegation binding" - ); - } info!(sandbox_name = %name, "DeleteSandbox request completed successfully"); Ok(Response::new(DeleteSandboxResponse { deleted })) } @@ -1982,9 +1927,6 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; - use crate::auth::identity::{Identity, IdentityProvider}; - use crate::auth::oidc::RawBearerToken; - use crate::auth::principal::{Principal, UserPrincipal}; use crate::grpc::test_support::test_server_state; use openshell_core::proto::datamodel::v1::ObjectMeta; use std::collections::HashMap; @@ -2636,49 +2578,6 @@ mod tests { assert!(err.message().contains("provider-b")); } - #[tokio::test] - async fn create_sandbox_persists_delegation_binding_for_oidc_user() { - let state = test_server_state().await; - let mut request = Request::new(CreateSandboxRequest { - name: "delegated".to_string(), - spec: Some(openshell_core::proto::SandboxSpec::default()), - labels: HashMap::new(), - }); - request - .extensions_mut() - .insert(Principal::User(UserPrincipal { - identity: Identity { - subject: "user-123".to_string(), - display_name: Some("alex".to_string()), - roles: vec!["openshell-user".to_string()], - scopes: vec!["sandbox:write".to_string()], - provider: IdentityProvider::Oidc, - }, - })); - request - .extensions_mut() - .insert(RawBearerToken("raw-access-token".to_string())); - - let response = handle_create_sandbox(&state, request) - .await - .expect("sandbox create succeeds") - .into_inner(); - let sandbox = response.sandbox.expect("sandbox present"); - let binding = 
crate::delegation::get_binding( - state.store.as_ref(), - sandbox.metadata.as_ref().expect("metadata").id.as_str(), - ) - .await - .expect("load binding") - .expect("binding present"); - - assert_eq!(binding.subject, "user-123"); - assert_eq!(binding.display_name, "alex"); - assert_eq!(binding.identity_provider, "oidc"); - assert_eq!(binding.access_token, "raw-access-token"); - assert_eq!(binding.scopes, vec!["sandbox:write".to_string()]); - } - #[tokio::test] async fn attach_sandbox_provider_rejects_credential_key_collisions() { let state = test_server_state().await; diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 7ad38952e..c25ba1cfd 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -25,7 +25,6 @@ pub mod cli; mod compute; pub mod config_file; mod defaults; -mod delegation; mod grpc; mod http; mod inference; diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 9d199c674..195acd9ae 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -450,7 +450,8 @@ async fn mint_credential( mint_oauth2_client_credentials(state).await } ProviderCredentialRefreshStrategy::Oauth2TokenExchange => { - mint_oauth2_token_exchange(store, state).await + let _ = store; + mint_oauth2_token_exchange(state).await } ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => { mint_google_service_account_jwt(state).await @@ -487,26 +488,18 @@ async fn mint_oauth2_refresh_token( } async fn mint_oauth2_token_exchange( - store: &Store, state: &StoredProviderCredentialRefreshState, ) -> Result { let token_url = oauth2_token_url(state)?; let client_id = required_material(&state.material, "client_id")?; - let sandbox_id = required_material(&state.material, "sandbox_id")?; - let binding = crate::delegation::get_binding(store, &sandbox_id) - .await? 
- .ok_or_else(|| { - Status::failed_precondition(format!( - "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" - )) - })?; + let subject_token = required_material(&state.material, "subject_token")?; let mut form = vec![ ( "grant_type".to_string(), "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), ), - ("subject_token".to_string(), binding.access_token), + ("subject_token".to_string(), subject_token), ( "subject_token_type".to_string(), "urn:ietf:params:oauth:token-type:access_token".to_string(), @@ -835,7 +828,6 @@ mod tests { refresh_provider_credential, refresh_state_name, refresh_strategy_name, run_refresh_worker_tick, seconds_until_ms, }; - use crate::delegation::{new_binding, put_binding}; use crate::persistence::Store; use openshell_core::ObjectId; use openshell_core::proto::datamodel::v1::ObjectMeta; @@ -1126,7 +1118,7 @@ mod tests { } #[tokio::test] - async fn oauth2_token_exchange_refresh_uses_sandbox_delegation_binding() { + async fn oauth2_token_exchange_refresh_uses_subject_token_material() { let mock_server = MockServer::start().await; Mock::given(method("POST")) .and(path("/token")) @@ -1151,28 +1143,6 @@ mod tests { .await; let store = test_store().await; - let sandbox = Sandbox { - metadata: Some(ObjectMeta { - id: "sandbox-obo".to_string(), - name: "obo".to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - resource_version: 0, - }), - spec: Some(SandboxSpec::default()), - ..Default::default() - }; - let binding = new_binding( - &sandbox, - "user-123", - Some("alex"), - "oidc", - "user-access-token", - &["sandbox:write".to_string()], - ) - .unwrap(); - put_binding(&store, &binding).await.unwrap(); - let provider = provider("my-obo", "okta"); store.put_message(&provider).await.unwrap(); let state = new_refresh_state( @@ -1183,11 +1153,14 @@ mod tests { material: HashMap::from([ ("client_id".to_string(), "client-id".to_string()), ("client_secret".to_string(), "client-secret".to_string()), - 
("sandbox_id".to_string(), "sandbox-obo".to_string()), + ("subject_token".to_string(), "user-access-token".to_string()), ("audience".to_string(), "api://downstream".to_string()), ("scope".to_string(), "files.read".to_string()), ]), - secret_material_keys: vec!["client_secret".to_string()], + secret_material_keys: vec![ + "client_secret".to_string(), + "subject_token".to_string(), + ], expires_at_ms: 0, token_url: format!("{}/token", mock_server.uri()), scopes: Vec::new(), diff --git a/docs/get-started/tutorials/okta-obo.mdx b/docs/get-started/tutorials/okta-obo.mdx index c7a8f7a59..03062886c 100644 --- a/docs/get-started/tutorials/okta-obo.mdx +++ b/docs/get-started/tutorials/okta-obo.mdx @@ -8,13 +8,13 @@ description: "Configure the built-in Okta OBO provider profile so OpenShell can keywords: "Generative AI, Cybersecurity, Tutorial, Providers, Okta, OBO, RFC 8693, Token Exchange, Delegation" --- -Use the built-in `okta-obo` profile when a sandboxed workload must call a downstream API on behalf of the human who logged into OpenShell. The gateway keeps the inbound user token server-side, binds it to sandbox creation, and exchanges it for a short-lived delegated token using Okta token exchange. +Use the built-in `okta-obo` profile when a sandboxed workload must call a downstream API on behalf of the human who logged into OpenShell. During refresh configuration, the gateway captures the logged-in user's bearer token as provider refresh secret material and later exchanges it for a short-lived delegated token using Okta token exchange. After completing this tutorial, you have: - A token-exchange service app in Okta for delegated access. - A customized `okta-obo` provider profile that points at your Okta tenant. -- A sandbox whose attached provider can mint `OKTA_OBO_ACCESS_TOKEN` from the logged-in user's identity. +- An attached provider that can mint `OKTA_OBO_ACCESS_TOKEN` from the logged-in user's identity. This tutorial covers the delegation lane. 
It assumes you already completed the gateway login lane and can log into the gateway with Okta before you create the sandbox. @@ -71,7 +71,7 @@ openshell provider profile lint -f okta-obo.yaml openshell provider profile import -f okta-obo.yaml ``` -## Log In and Create the Sandbox +## Log In Log into the gateway as the human user whose identity should be delegated: @@ -79,18 +79,6 @@ Log into the gateway as the human user whose identity should be delegated: openshell gateway login ``` -Create the sandbox after login so OpenShell can bind the sandbox to the authenticated user token: - -```shell -openshell sandbox create --name okta-obo-smoke -``` - -Fetch the sandbox metadata and record the sandbox ID: - -```shell -openshell sandbox get okta-obo-smoke -``` - ## Create the OBO Provider Create the provider from the imported profile: @@ -111,7 +99,6 @@ openshell provider refresh configure okta-obo-runtime \ --strategy oauth2-token-exchange \ --material client_id="$OKTA_OBO_CLIENT_ID" \ --material client_secret="$OKTA_OBO_CLIENT_SECRET" \ - --material sandbox_id="" \ --material audience="$OKTA_OBO_AUDIENCE" \ --material scope="$OKTA_OBO_SCOPE" \ --secret-material-key client_secret @@ -124,7 +111,13 @@ openshell provider refresh status okta-obo-runtime \ --credential-key OKTA_OBO_ACCESS_TOKEN ``` -## Attach the Provider +## Create the Sandbox and Attach the Provider + +Create the sandbox after configuring the provider: + +```shell +openshell sandbox create --name okta-obo-smoke +``` Attach the OBO provider to the sandbox: diff --git a/proto/openshell.proto b/proto/openshell.proto index 55165b55b..66a99152b 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -960,18 +960,6 @@ message StoredProviderCredentialRefreshState { int64 max_lifetime_seconds = 16; } -message StoredSandboxDelegationBinding { - openshell.datamodel.v1.ObjectMeta metadata = 1; - string sandbox_id = 2; - string sandbox_name = 3; - string subject = 4; - string display_name = 5; - string 
identity_provider = 6; - string access_token = 7; - repeated string scopes = 8; - int64 captured_at_ms = 9; -} - message GetProviderRefreshStatusRequest { string provider = 1; string credential_key = 2; diff --git a/providers/okta-obo.yaml b/providers/okta-obo.yaml index 56f38951a..776d053f8 100644 --- a/providers/okta-obo.yaml +++ b/providers/okta-obo.yaml @@ -21,9 +21,6 @@ credentials: - name: client_id description: Okta OIDC application client ID used for token exchange required: true - - name: sandbox_id - description: OpenShell sandbox ID bound to the authenticated user token - required: true - name: audience description: Downstream Okta resource audience for the delegated token required: true @@ -31,6 +28,10 @@ credentials: description: Okta client secret for confidential token-exchange clients required: false secret: true + - name: subject_token + description: Authenticated user bearer token captured during refresh configuration + required: false + secret: true - name: scope description: Space-delimited scopes requested for the delegated token required: false From 6a109c3f0ffcf196bf3d46424efb884e5b99896d Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Thu, 28 May 2026 17:15:13 -0700 Subject: [PATCH 05/13] feat(auth): persist oidc id tokens for xaa flows --- crates/openshell-bootstrap/src/oidc_token.rs | 4 ++ crates/openshell-cli/src/oidc_auth.rs | 62 +++++++++++++++----- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/crates/openshell-bootstrap/src/oidc_token.rs b/crates/openshell-bootstrap/src/oidc_token.rs index 19c6cabaa..1dcf49450 100644 --- a/crates/openshell-bootstrap/src/oidc_token.rs +++ b/crates/openshell-bootstrap/src/oidc_token.rs @@ -19,6 +19,10 @@ pub struct OidcTokenBundle { /// `OAuth2` access token (JWT). pub access_token: String, + /// Optional OIDC ID token returned by the provider. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id_token: Option, + /// `OAuth2` refresh token. 
`None` for `client_credentials` grants. #[serde(default, skip_serializing_if = "Option::is_none")] pub refresh_token: Option, diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index 379a53112..4cb1cf37c 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -14,13 +14,17 @@ use hyper::{Method, Response, StatusCode}; use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::server::conn::auto::Builder; use miette::{IntoDiagnostic, Result}; -use oauth2::basic::BasicClient; use oauth2::{ - AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, - RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, + AuthType, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, + EndpointNotSet, ExtraTokenFields, PkceCodeChallenge, RedirectUrl, RefreshToken, Scope, + StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl, + basic::{ + BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse, + BasicTokenType, + }, }; use openshell_bootstrap::oidc_token::OidcTokenBundle; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::convert::Infallible; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -30,6 +34,34 @@ use tracing::debug; const AUTH_TIMEOUT: Duration = Duration::from_secs(120); +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +struct OidcExtraTokenFields { + #[serde(default, skip_serializing_if = "Option::is_none")] + id_token: Option, +} + +impl ExtraTokenFields for OidcExtraTokenFields {} + +type OidcTokenResponse = StandardTokenResponse; +type OidcClient< + HasAuthUrl = EndpointNotSet, + HasDeviceAuthUrl = EndpointNotSet, + HasIntrospectionUrl = EndpointNotSet, + HasRevocationUrl = EndpointNotSet, + HasTokenUrl = EndpointNotSet, +> = Client< + BasicErrorResponse, + OidcTokenResponse, + BasicTokenIntrospectionResponse, + StandardRevocableToken, + 
BasicRevocationErrorResponse, + HasAuthUrl, + HasDeviceAuthUrl, + HasIntrospectionUrl, + HasRevocationUrl, + HasTokenUrl, +>; + /// OIDC discovery document (subset of fields we need). #[derive(Debug, Deserialize)] struct OidcDiscovery { @@ -112,7 +144,7 @@ pub async fn oidc_browser_auth_flow( let port = listener.local_addr().into_diagnostic()?.port(); let redirect_uri = format!("http://127.0.0.1:{port}/callback"); - let client = BasicClient::new(ClientId::new(client_id.to_string())) + let client = OidcClient::new(ClientId::new(client_id.to_string())) .set_auth_uri(AuthUrl::new(discovery.authorization_endpoint).into_diagnostic()?) .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?) .set_redirect_uri(RedirectUrl::new(redirect_uri).into_diagnostic()?); @@ -167,7 +199,7 @@ pub async fn oidc_browser_auth_flow( server_handle.abort(); let http = http_client(insecure); - let token_response = client + let token_response: OidcTokenResponse = client .exchange_code(AuthorizationCode::new(code)) .set_pkce_verifier(pkce_verifier) .request_async(&http) @@ -199,7 +231,7 @@ pub async fn oidc_client_credentials_flow( let discovery = discover(issuer, insecure).await?; - let client = BasicClient::new(ClientId::new(client_id.to_string())) + let client = OidcClient::new(ClientId::new(client_id.to_string())) .set_client_secret(ClientSecret::new(client_secret)) .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?) 
.set_auth_type(AuthType::RequestBody); @@ -213,7 +245,7 @@ pub async fn oidc_client_credentials_flow( } let http = http_client(insecure); - let token_response = request + let token_response: OidcTokenResponse = request .request_async(&http) .await .map_err(|e| miette::miette!("client credentials token exchange failed: {e}"))?; @@ -241,11 +273,11 @@ pub async fn oidc_refresh_token( let discovery = discover(&bundle.issuer, insecure).await?; - let client = BasicClient::new(ClientId::new(bundle.client_id.clone())) + let client = OidcClient::new(ClientId::new(bundle.client_id.clone())) .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?); let http = http_client(insecure); - let token_response = client + let token_response: OidcTokenResponse = client .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) .request_async(&http) .await @@ -287,7 +319,7 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Resu // ── Helpers ────────────────────────────────────────────────────────── fn bundle_from_oauth2_response( - resp: &oauth2::basic::BasicTokenResponse, + resp: &OidcTokenResponse, issuer: &str, client_id: &str, ) -> OidcTokenBundle { @@ -298,6 +330,7 @@ fn bundle_from_oauth2_response( OidcTokenBundle { access_token: resp.access_token().secret().clone(), + id_token: resp.extra_fields().id_token.clone(), refresh_token: resp.refresh_token().map(|rt| rt.secret().clone()), expires_at: resp.expires_in().map(|ei| now + ei.as_secs()), issuer: issuer.to_string(), @@ -518,14 +551,13 @@ mod tests { #[test] fn bundle_from_response_sets_fields() { - use oauth2::basic::BasicTokenResponse; - - let token_response: BasicTokenResponse = serde_json::from_str( - r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh"}"#, + let token_response: OidcTokenResponse = serde_json::from_str( + 
r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh","id_token":"test-id"}"#, ) .unwrap(); let bundle = bundle_from_oauth2_response(&token_response, "https://issuer", "my-client"); assert_eq!(bundle.access_token, "test-access"); + assert_eq!(bundle.id_token.as_deref(), Some("test-id")); assert_eq!(bundle.refresh_token.as_deref(), Some("test-refresh")); assert_eq!(bundle.issuer, "https://issuer"); assert_eq!(bundle.client_id, "my-client"); From 7380fc9be80d05c93769f9793691ec896f5ec670 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 29 May 2026 08:26:30 -0700 Subject: [PATCH 06/13] feat(auth): forward oidc id tokens for xaa --- crates/openshell-cli/src/completers.rs | 4 ++++ crates/openshell-cli/src/main.rs | 4 ++++ crates/openshell-cli/src/tls.rs | 10 ++++++++- crates/openshell-core/src/auth.rs | 22 +++++++++++++++++++- crates/openshell-server/src/auth/oidc.rs | 16 ++++++++++++++ crates/openshell-server/src/grpc/provider.rs | 3 ++- crates/openshell-server/src/multiplex.rs | 12 +++++++++++ crates/openshell-tui/src/lib.rs | 3 ++- 8 files changed, 70 insertions(+), 4 deletions(-) diff --git a/crates/openshell-cli/src/completers.rs b/crates/openshell-cli/src/completers.rs index a421b418a..cc7d133ab 100644 --- a/crates/openshell-cli/src/completers.rs +++ b/crates/openshell-cli/src/completers.rs @@ -102,13 +102,16 @@ async fn completion_grpc_client( Ok(refreshed) => { let _ = store_oidc_token(gateway_name, &refreshed); tls_opts.oidc_token = Some(refreshed.access_token); + tls_opts.oidc_id_token = refreshed.id_token; } Err(_) => { tls_opts.oidc_token = Some(bundle.access_token); + tls_opts.oidc_id_token = bundle.id_token; } } } else { tls_opts.oidc_token = Some(bundle.access_token); + tls_opts.oidc_id_token = bundle.id_token; } } } @@ -124,6 +127,7 @@ async fn completion_grpc_client( let channel = build_channel(server, &tls_opts).await.ok()?; let interceptor = EdgeAuthInterceptor::new( 
tls_opts.oidc_token.as_deref(), + tls_opts.oidc_id_token.as_deref(), tls_opts.edge_token.as_deref(), ) .ok()?; diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index bd310c115..112e21cf9 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -156,16 +156,19 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) { &refreshed, ); tls.oidc_token = Some(refreshed.access_token); + tls.oidc_id_token = refreshed.id_token; } Err(e) => { tracing::warn!("OIDC token refresh failed: {e}"); // Use the expired token anyway — server will reject it // with a clear error prompting re-login. tls.oidc_token = Some(bundle.access_token); + tls.oidc_id_token = bundle.id_token; } } } else { tls.oidc_token = Some(bundle.access_token); + tls.oidc_id_token = bundle.id_token; } } _ => {} @@ -2927,6 +2930,7 @@ async fn main() -> Result<()> { let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?; let interceptor = openshell_core::auth::EdgeAuthInterceptor::new( tls.oidc_token.as_deref(), + tls.oidc_id_token.as_deref(), tls.edge_token.as_deref(), )?; openshell_tui::run(channel, interceptor, &ctx.name, &ctx.endpoint, theme).await?; diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index 10df401a5..ce1f24387 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -40,6 +40,9 @@ pub struct TlsOptions { /// OIDC bearer token — when set, injects `authorization: Bearer ` /// on every gRPC request. Takes precedence over `edge_token`. pub oidc_token: Option, + /// OIDC ID token — when set, injects a gateway-private metadata header + /// so delegated/XAA flows can bind the signed-in user session. + pub oidc_id_token: Option, /// Skip TLS certificate verification for gateway connections. 
pub gateway_insecure: bool, } @@ -53,6 +56,7 @@ impl TlsOptions { gateway_name: None, edge_token: None, oidc_token: None, + oidc_id_token: None, gateway_insecure: false, } } @@ -441,7 +445,11 @@ pub async fn grpc_client(server: &str, tls: &TlsOptions) -> Result { } fn interceptor_from_tls(tls: &TlsOptions) -> Result { - EdgeAuthInterceptor::new(tls.oidc_token.as_deref(), tls.edge_token.as_deref()) + EdgeAuthInterceptor::new( + tls.oidc_token.as_deref(), + tls.oidc_id_token.as_deref(), + tls.edge_token.as_deref(), + ) } pub async fn grpc_inference_client(server: &str, tls: &TlsOptions) -> Result { diff --git a/crates/openshell-core/src/auth.rs b/crates/openshell-core/src/auth.rs index 16d513346..57fac8ebb 100644 --- a/crates/openshell-core/src/auth.rs +++ b/crates/openshell-core/src/auth.rs @@ -14,6 +14,7 @@ use miette::Result; #[allow(clippy::struct_field_names)] pub struct EdgeAuthInterceptor { bearer_value: Option>, + oidc_id_token_value: Option>, header_value: Option>, cookie_value: Option>, } @@ -23,14 +24,27 @@ impl EdgeAuthInterceptor { /// /// OIDC bearer tokens take precedence over edge tokens. Returns a no-op /// interceptor when no token is provided. 
- pub fn new(oidc_token: Option<&str>, edge_token: Option<&str>) -> Result { + pub fn new( + oidc_token: Option<&str>, + oidc_id_token: Option<&str>, + edge_token: Option<&str>, + ) -> Result { if let Some(token) = oidc_token { let bearer: tonic::metadata::MetadataValue = format!("Bearer {token}") .parse() .map_err(|_| miette::miette!("invalid bearer token value"))?; + let oidc_id_token_value = match oidc_id_token { + Some(token) => Some( + token + .parse() + .map_err(|_| miette::miette!("invalid OIDC ID token value"))?, + ), + None => None, + }; return Ok(Self { bearer_value: Some(bearer), + oidc_id_token_value, header_value: None, cookie_value: None, }); @@ -51,6 +65,7 @@ impl EdgeAuthInterceptor { }; Ok(Self { bearer_value: None, + oidc_id_token_value: None, header_value, cookie_value, }) @@ -60,6 +75,7 @@ impl EdgeAuthInterceptor { pub fn noop() -> Self { Self { bearer_value: None, + oidc_id_token_value: None, header_value: None, cookie_value: None, } @@ -74,6 +90,10 @@ impl tonic::service::Interceptor for EdgeAuthInterceptor { if let Some(ref val) = self.bearer_value { req.metadata_mut().insert("authorization", val.clone()); } + if let Some(ref val) = self.oidc_id_token_value { + req.metadata_mut() + .insert("x-openshell-oidc-id-token", val.clone()); + } if let Some(ref val) = self.header_value { req.metadata_mut() .insert("cf-access-jwt-assertion", val.clone()); diff --git a/crates/openshell-server/src/auth/oidc.rs b/crates/openshell-server/src/auth/oidc.rs index 42c92c6b0..945f86d61 100644 --- a/crates/openshell-server/src/auth/oidc.rs +++ b/crates/openshell-server/src/auth/oidc.rs @@ -117,6 +117,14 @@ const STANDARD_OIDC_SCOPES: &[&str] = &["openid", "profile", "email", "offline_a #[derive(Debug, Clone)] pub struct RawBearerToken(pub String); +/// Raw OIDC ID token forwarded from the authenticated CLI/TUI request. +/// +/// This is a gateway-private metadata channel used for delegated/XAA flows. 
+/// The gateway only trusts it after the access token has already authenticated +/// the caller as an OIDC user. +#[derive(Debug, Clone)] +pub struct RawIdToken(pub String); + /// Extract a bearer token from an `Authorization` header. pub fn extract_bearer_token(headers: &http::HeaderMap) -> Option<&str> { headers @@ -125,6 +133,14 @@ pub fn extract_bearer_token(headers: &http::HeaderMap) -> Option<&str> { .and_then(|v| v.strip_prefix("Bearer ")) } +/// Extract the forwarded OIDC ID token from a private gRPC metadata header. +pub fn extract_id_token(headers: &http::HeaderMap) -> Option<&str> { + headers + .get("x-openshell-oidc-id-token") + .and_then(|v| v.to_str().ok()) + .filter(|value| !value.trim().is_empty()) +} + impl OidcClaims { /// Extract roles from the JWT claims using a dot-separated path. /// diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 1f7b62495..4fd153836 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -6,7 +6,7 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> use crate::auth::identity::IdentityProvider; -use crate::auth::oidc::RawBearerToken; +use crate::auth::oidc::{RawBearerToken, RawIdToken}; use crate::auth::principal::Principal; use crate::persistence::{ ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, @@ -1243,6 +1243,7 @@ pub(super) async fn handle_configure_provider_refresh( ) -> Result, Status> { let principal = request.extensions().get::().cloned(); let raw_bearer_token = request.extensions().get::().cloned(); + let raw_id_token = request.extensions().get::().cloned(); let request = request.into_inner(); let provider_name = request.provider.trim(); let credential_key = request.credential_key.trim(); diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index cfb3de0f4..29dd5d0a7 100644 --- 
a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -479,6 +479,18 @@ where if let Some(token) = raw_oidc_bearer { req.extensions_mut().insert(oidc::RawBearerToken(token)); } + let raw_oidc_id_token = if let Principal::User(ref user) = principal { + if user.identity.provider == crate::auth::identity::IdentityProvider::Oidc { + oidc::extract_id_token(req.headers()).map(str::to_owned) + } else { + None + } + } else { + None + }; + if let Some(token) = raw_oidc_id_token { + req.extensions_mut().insert(oidc::RawIdToken(token)); + } req.extensions_mut().insert(principal); inner.ready().await?.call(req).await diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 22201dbb6..18d452b24 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -528,7 +528,8 @@ async fn connect_to_gateway(name: &str, endpoint: &str) -> Result<(Channel, Edge Re-authenticate with: openshell gateway login" ); } - let interceptor = EdgeAuthInterceptor::new(Some(&bundle.access_token), None)?; + let interceptor = + EdgeAuthInterceptor::new(Some(&bundle.access_token), bundle.id_token.as_deref(), None)?; let channel = build_oidc_channel(name, endpoint).await?; Ok((channel, interceptor)) } else { From 5dcbc8b92d1d7dc4a8d523b2f331b70af964517e Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 29 May 2026 08:32:26 -0700 Subject: [PATCH 07/13] feat(providers): add okta xaa refresh strategy --- crates/openshell-cli/src/main.rs | 2 + crates/openshell-cli/src/run.rs | 2 + .../tests/provider_commands_integration.rs | 53 +++++++++++++++++++ crates/openshell-providers/src/profiles.rs | 50 +++++++++++++++++ .../openshell-server/src/provider_refresh.rs | 53 +++++++++++++++++++ proto/openshell.proto | 1 + providers/okta-xaa.yaml | 42 +++++++++++++++ 7 files changed, 203 insertions(+) create mode 100644 providers/okta-xaa.yaml diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs 
index 112e21cf9..fd67048ae 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -664,6 +664,7 @@ enum CliProviderRefreshStrategy { Oauth2RefreshToken, Oauth2ClientCredentials, Oauth2TokenExchange, + OktaXaa, GoogleServiceAccountJwt, } @@ -673,6 +674,7 @@ impl CliProviderRefreshStrategy { Self::Oauth2RefreshToken => "oauth2_refresh_token", Self::Oauth2ClientCredentials => "oauth2_client_credentials", Self::Oauth2TokenExchange => "oauth2_token_exchange", + Self::OktaXaa => "okta_xaa", Self::GoogleServiceAccountJwt => "google_service_account_jwt", } } diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 4fef70ec2..fe813804d 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -5005,6 +5005,7 @@ fn provider_refresh_strategy(strategy: &str) -> Result Ok(ProviderCredentialRefreshStrategy::Oauth2TokenExchange), + "okta_xaa" => Ok(ProviderCredentialRefreshStrategy::OktaXaa), "google_service_account_jwt" => { Ok(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) } @@ -5060,6 +5061,7 @@ fn provider_refresh_strategy_name(strategy: ProviderCredentialRefreshStrategy) - ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", + ProviderCredentialRefreshStrategy::OktaXaa => "okta_xaa", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 5a5031d23..8ea9d6aef 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1842,6 +1842,59 @@ async fn 
built_in_okta_obo_profile_is_available_via_provider_profile_api() { ); } +#[tokio::test] +async fn built_in_okta_xaa_profile_is_available_via_provider_profile_api() { + let ts = run_server().await; + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client should connect"); + let profile = client + .get_provider_profile(openshell_core::proto::GetProviderProfileRequest { + id: "okta-xaa".to_string(), + }) + .await + .expect("get provider profile") + .into_inner() + .profile + .expect("profile should exist"); + + assert_eq!(profile.id, "okta-xaa"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == "xaa_access_token") + .expect("xaa access token credential"); + let refresh = credential + .refresh + .as_ref() + .expect("xaa credential should include refresh config"); + assert_eq!( + refresh.strategy, + ProviderCredentialRefreshStrategy::OktaXaa as i32 + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "sandbox_id" && material.required) + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "resource" && material.required) + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "client_assertion" + && material.required + && material.secret) + ); +} + #[tokio::test] async fn provider_profile_lint_from_directory_reports_parse_errors_without_importing() { let ts = run_server().await; diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 888042c69..14fa63379 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -22,6 +22,7 @@ const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/google-vertex-ai.yaml"), include_str!("../../../providers/nvidia.yaml"), include_str!("../../../providers/okta-obo.yaml"), + include_str!("../../../providers/okta-xaa.yaml"), ]; #[derive(Debug, 
thiserror::Error)] @@ -532,6 +533,7 @@ pub fn provider_refresh_strategy_from_yaml(raw: &str) -> Option Some(ProviderCredentialRefreshStrategy::Oauth2TokenExchange), + "okta_xaa" => Some(ProviderCredentialRefreshStrategy::OktaXaa), "google_service_account_jwt" => { Some(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt) } @@ -549,6 +551,7 @@ pub fn provider_refresh_strategy_to_yaml( ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", + ProviderCredentialRefreshStrategy::OktaXaa => "okta_xaa", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } @@ -1228,6 +1231,53 @@ mod tests { ); } + #[test] + fn okta_xaa_profile_exposes_id_token_exchange_shape() { + let profile = get_default_profile("okta-xaa").expect("okta-xaa profile"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == "xaa_access_token") + .expect("okta-xaa access token credential"); + let refresh = credential + .refresh + .as_ref() + .expect("okta-xaa credential should be refreshable"); + + assert_eq!( + refresh.strategy, + openshell_core::proto::ProviderCredentialRefreshStrategy::OktaXaa + ); + assert_eq!( + refresh.token_url, + "https://example.okta.com/oauth2/v1/token" + ); + + let material_names = refresh + .material + .iter() + .map(|material| material.name.as_str()) + .collect::>(); + assert_eq!( + material_names, + vec![ + "client_id", + "sandbox_id", + "resource", + "client_assertion", + "client_assertion_type", + "scope" + ] + ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "client_assertion") + .is_some_and(|material| material.required && material.secret) + ); + } + #[test] fn 
credential_env_vars_are_deduplicated_in_profile_order() { let profile = get_default_profile("claude-code").expect("claude-code profile"); diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 195acd9ae..416a9bb73 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -279,6 +279,7 @@ pub fn refresh_strategy_name(strategy: i32) -> &'static str { ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token", ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials", ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange", + ProviderCredentialRefreshStrategy::OktaXaa => "okta_xaa", ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt", ProviderCredentialRefreshStrategy::Unspecified => "unspecified", } @@ -290,6 +291,7 @@ pub fn is_gateway_mintable_strategy(strategy: ProviderCredentialRefreshStrategy) ProviderCredentialRefreshStrategy::Oauth2RefreshToken | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials | ProviderCredentialRefreshStrategy::Oauth2TokenExchange + | ProviderCredentialRefreshStrategy::OktaXaa | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt ) } @@ -453,6 +455,9 @@ async fn mint_credential( let _ = store; mint_oauth2_token_exchange(state).await } + ProviderCredentialRefreshStrategy::OktaXaa => { + mint_okta_xaa_token_exchange(store, state).await + } ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => { mint_google_service_account_jwt(state).await } @@ -541,6 +546,54 @@ async fn mint_oauth2_client_credentials( request_token(&token_url, &form, basic_auth, state.max_lifetime_seconds).await } +async fn mint_okta_xaa_token_exchange( + store: &Store, + state: &StoredProviderCredentialRefreshState, +) -> Result { + let token_url = oauth2_token_url(state)?; + let sandbox_id = 
required_material(&state.material, "sandbox_id")?; + let binding = crate::delegation::get_binding(store, &sandbox_id) + .await? + .ok_or_else(|| { + Status::failed_precondition(format!( + "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" + )) + })?; + if binding.id_token.trim().is_empty() { + return Err(Status::failed_precondition( + "sandbox delegation binding does not contain an OIDC id_token", + )); + } + + let client_id = required_material(&state.material, "client_id")?; + let client_assertion = required_material(&state.material, "client_assertion")?; + let client_assertion_type = material_value(&state.material, &["client_assertion_type"]) + .unwrap_or_else(|| "urn:ietf:params:oauth:client-assertion-type:jwt-bearer".to_string()); + let resource = required_material(&state.material, "resource")?; + + let mut form = vec![ + ( + "grant_type".to_string(), + "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), + ), + ("client_id".to_string(), client_id), + ("client_assertion".to_string(), client_assertion), + ("client_assertion_type".to_string(), client_assertion_type), + ("subject_token".to_string(), binding.id_token), + ( + "subject_token_type".to_string(), + "urn:ietf:params:oauth:token-type:id_token".to_string(), + ), + ("resource".to_string(), resource), + ]; + let scope = refresh_scopes(state).join(" "); + if !scope.is_empty() { + form.push(("scope".to_string(), scope)); + } + + request_token(&token_url, &form, None, state.max_lifetime_seconds).await +} + async fn mint_google_service_account_jwt( state: &StoredProviderCredentialRefreshState, ) -> Result { diff --git a/proto/openshell.proto b/proto/openshell.proto index 66a99152b..7f4980d90 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -905,6 +905,7 @@ enum ProviderCredentialRefreshStrategy { PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OAUTH2_CLIENT_CREDENTIALS = 4; PROVIDER_CREDENTIAL_REFRESH_STRATEGY_GOOGLE_SERVICE_ACCOUNT_JWT = 5; 
PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OAUTH2_TOKEN_EXCHANGE = 6; + PROVIDER_CREDENTIAL_REFRESH_STRATEGY_OKTA_XAA = 7; } message ProviderCredentialRefreshMaterial { diff --git a/providers/okta-xaa.yaml b/providers/okta-xaa.yaml new file mode 100644 index 000000000..d3b1709bd --- /dev/null +++ b/providers/okta-xaa.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +id: okta-xaa +display_name: Okta XAA +description: Okta Cross App Access tokens minted from the authenticated OpenShell user's OIDC ID token +category: other +credentials: + - name: xaa_access_token + description: Okta XAA delegated access token for downstream resources + env_vars: [OKTA_XAA_ACCESS_TOKEN] + required: true + auth_style: bearer + header_name: authorization + refresh: + strategy: okta_xaa + token_url: https://example.okta.com/oauth2/v1/token + refresh_before_seconds: 300 + max_lifetime_seconds: 3600 + material: + - name: client_id + description: Okta AI agent client ID used for XAA token exchange + required: true + - name: sandbox_id + description: OpenShell sandbox ID bound to the authenticated user's OIDC tokens + required: true + - name: resource + description: Okta XAA resource or resource connection identifier + required: true + - name: client_assertion + description: JWT client assertion used to authenticate the AI agent with Okta + required: true + secret: true + - name: client_assertion_type + description: OAuth client assertion type, usually urn:ietf:params:oauth:client-assertion-type:jwt-bearer + required: false + - name: scope + description: Space-delimited scopes requested for the delegated token + required: false +binaries: + - /usr/bin/curl + - /usr/local/bin/curl From fb3ba2ef3de7fda5eb46158aaa80d3c3397bd1d2 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 29 May 2026 08:49:53 -0700 Subject: [PATCH 08/13] feat(auth): generate okta xaa client assertions --- 
.../tests/provider_commands_integration.rs | 2 +- crates/openshell-providers/src/profiles.rs | 5 +- .../openshell-server/src/provider_refresh.rs | 132 +++++++++++++++++- providers/okta-xaa.yaml | 7 +- 4 files changed, 140 insertions(+), 6 deletions(-) diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 8ea9d6aef..9f3975ec6 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1889,7 +1889,7 @@ async fn built_in_okta_xaa_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "client_assertion" + .any(|material| material.name == "private_key_pem" && material.required && material.secret) ); diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 14fa63379..c0a41e1d0 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -1264,7 +1264,8 @@ mod tests { "client_id", "sandbox_id", "resource", - "client_assertion", + "private_key_pem", + "kid", "client_assertion_type", "scope" ] @@ -1273,7 +1274,7 @@ mod tests { refresh .material .iter() - .find(|material| material.name == "client_assertion") + .find(|material| material.name == "private_key_pem") .is_some_and(|material| material.required && material.secret) ); } diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 416a9bb73..2b8d9baf5 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -246,6 +246,16 @@ struct GoogleServiceAccountClaims<'a> { sub: Option<&'a str>, } +#[derive(Debug, Serialize)] +struct OktaXaaClientAssertionClaims<'a> { + iss: &'a str, + sub: &'a str, + aud: &'a str, + iat: i64, + exp: i64, + jti: String, +} + pub fn next_refresh_at_ms( 
expires_at_ms: i64, refresh_before_seconds: i64, @@ -566,7 +576,7 @@ async fn mint_okta_xaa_token_exchange( } let client_id = required_material(&state.material, "client_id")?; - let client_assertion = required_material(&state.material, "client_assertion")?; + let client_assertion = build_okta_xaa_client_assertion(state, &token_url, &client_id)?; let client_assertion_type = material_value(&state.material, &["client_assertion_type"]) .unwrap_or_else(|| "urn:ietf:params:oauth:client-assertion-type:jwt-bearer".to_string()); let resource = required_material(&state.material, "resource")?; @@ -594,6 +604,37 @@ async fn mint_okta_xaa_token_exchange( request_token(&token_url, &form, None, state.max_lifetime_seconds).await } +fn build_okta_xaa_client_assertion( + state: &StoredProviderCredentialRefreshState, + token_url: &str, + client_id: &str, +) -> Result { + let private_key_pem = required_material(&state.material, "private_key_pem")?; + let lifetime_secs = if state.max_lifetime_seconds > 0 { + state.max_lifetime_seconds.min(DEFAULT_MAX_LIFETIME_SECONDS) + } else { + DEFAULT_MAX_LIFETIME_SECONDS + }; + let now_secs = current_time_ms() / 1000; + let claims = OktaXaaClientAssertionClaims { + iss: client_id, + sub: client_id, + aud: token_url, + iat: now_secs, + exp: now_secs.saturating_add(lifetime_secs), + jti: uuid::Uuid::new_v4().to_string(), + }; + let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256); + header.kid = material_value(&state.material, &["kid"]); + jsonwebtoken::encode( + &header, + &claims, + &jsonwebtoken::EncodingKey::from_rsa_pem(private_key_pem.as_bytes()) + .map_err(|_| Status::invalid_argument("okta_xaa private_key_pem must be RSA PEM"))?, + ) + .map_err(|_| Status::internal("sign okta xaa client assertion failed")) +} + async fn mint_google_service_account_jwt( state: &StoredProviderCredentialRefreshState, ) -> Result { @@ -1240,6 +1281,95 @@ mod tests { ); } + #[tokio::test] + async fn 
okta_xaa_refresh_uses_id_token_and_generated_client_assertion() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .and(body_string_contains( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange", + )) + .and(body_string_contains( + "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token", + )) + .and(body_string_contains("resource=jira-connection")) + .and(body_string_contains("client_assertion=")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "xaa-delegated-token", + "expires_in": 1800, + "token_type": "Bearer" + }))) + .mount(&mock_server) + .await; + + let store = test_store().await; + let sandbox = Sandbox { + metadata: Some(ObjectMeta { + id: "sandbox-xaa".to_string(), + name: "xaa".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec::default()), + ..Default::default() + }; + let binding = new_binding( + &sandbox, + "user-123", + Some("alex"), + "oidc", + "user-access-token", + Some("user-id-token"), + &["jira.read".to_string()], + ) + .unwrap(); + put_binding(&store, &binding).await.unwrap(); + + let provider = provider("my-xaa", "okta-xaa"); + store.put_message(&provider).await.unwrap(); + let state = new_refresh_state( + &provider, + "OKTA_XAA_ACCESS_TOKEN", + NewRefreshStateConfig { + strategy: ProviderCredentialRefreshStrategy::OktaXaa, + material: HashMap::from([ + ("client_id".to_string(), "agent-client-id".to_string()), + ("sandbox_id".to_string(), "sandbox-xaa".to_string()), + ("resource".to_string(), "jira-connection".to_string()), + ( + "private_key_pem".to_string(), + TEST_RSA_PRIVATE_KEY.to_string(), + ), + ("kid".to_string(), "test-key-id".to_string()), + ]), + secret_material_keys: vec!["private_key_pem".to_string()], + expires_at_ms: 0, + token_url: format!("{}/token", mock_server.uri()), + scopes: vec!["jira.read".to_string()], + 
refresh_before_seconds: 300, + max_lifetime_seconds: 3600, + }, + ) + .unwrap(); + put_refresh_state(&store, &state).await.unwrap(); + + let refreshed = refresh_provider_credential(&store, "my-xaa", "OKTA_XAA_ACCESS_TOKEN") + .await + .unwrap(); + assert_eq!(refreshed.status, "refreshed"); + + let stored = store + .get_message_by_name::("my-xaa") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.credentials.get("OKTA_XAA_ACCESS_TOKEN"), + Some(&"xaa-delegated-token".to_string()) + ); + } + #[tokio::test] async fn google_service_account_refresh_mints_and_persists_access_token() { let mock_server = MockServer::start().await; diff --git a/providers/okta-xaa.yaml b/providers/okta-xaa.yaml index d3b1709bd..ed5a4036e 100644 --- a/providers/okta-xaa.yaml +++ b/providers/okta-xaa.yaml @@ -27,10 +27,13 @@ credentials: - name: resource description: Okta XAA resource or resource connection identifier required: true - - name: client_assertion - description: JWT client assertion used to authenticate the AI agent with Okta + - name: private_key_pem + description: RSA private key PEM used to sign the Okta XAA client assertion required: true secret: true + - name: kid + description: Optional key ID for the registered Okta public key + required: false - name: client_assertion_type description: OAuth client assertion type, usually urn:ietf:params:oauth:client-assertion-type:jwt-bearer required: false From 7b962cd08f61fe7b30569b2fe77775b40368a204 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 14:53:44 -0700 Subject: [PATCH 09/13] docs(xaa): add xaa.dev sample tutorial --- docs/get-started/tutorials/index.mdx | 5 + docs/get-started/tutorials/xaa-dev.mdx | 151 +++++++++++++++++++++++++ docs/sandboxes/providers-v2.mdx | 3 + 3 files changed, 159 insertions(+) create mode 100644 docs/get-started/tutorials/xaa-dev.mdx diff --git a/docs/get-started/tutorials/index.mdx b/docs/get-started/tutorials/index.mdx index b6c032013..9ba887eed 100644 --- 
a/docs/get-started/tutorials/index.mdx +++ b/docs/get-started/tutorials/index.mdx @@ -32,6 +32,11 @@ Configure a Providers v2 Microsoft Graph provider with gateway-managed OAuth2 re Configure delegated Okta access on behalf of the logged-in OpenShell user with token exchange. + + +Run the sample 2-step Cross App Access flow through OpenShell with `xaa.dev` and the Todo0 resource API. + + Route inference through Ollama using cloud-hosted or local models, and verify it from a sandbox. diff --git a/docs/get-started/tutorials/xaa-dev.mdx b/docs/get-started/tutorials/xaa-dev.mdx new file mode 100644 index 000000000..17431f6dd --- /dev/null +++ b/docs/get-started/tutorials/xaa-dev.mdx @@ -0,0 +1,151 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: "Run the XAA.dev Sample XAA Flow" +sidebar-title: "XAA.dev Sample XAA" +slug: "get-started/tutorials/xaa-dev" +description: "Configure the built-in xaa-dev provider profile and prove the sample 2-step XAA flow from inside an OpenShell sandbox." +keywords: "Generative AI, Cybersecurity, Tutorial, XAA, Cross App Access, ID-JAG, Providers, Sandbox" +--- + +Use the built-in `xaa-dev` profile to exercise the sample 2-step Cross App Access flow backed by `xaa.dev`. OpenShell binds the logged-in user identity to a sandbox, exchanges that user `id_token` for an `ID-JAG`, exchanges the `ID-JAG` for a delegated resource token, and injects the delegated token into the sandbox for outbound API calls. + +After completing this tutorial, you have: + +- A registered `xaa.dev` requesting app connected to the `Todo0 Resource App`. +- A provider that mints `XAA_DEV_ACCESS_TOKEN` through the sample 2-step XAA flow. +- A sandbox that can call the sample `Todo0` resource API with the delegated token. + + +This tutorial documents the working `xaa.dev` sample path. 
It does not cover a tenant-integrated resource authorization server in your own Okta environment. + + +## Prerequisites + +- A working OpenShell installation with an active gateway. +- A local gateway login flow that can store an OIDC `id_token`. +- `providers_v2_enabled=true` on the target gateway. +- A `xaa.dev` requesting app registered at `https://xaa.dev/developer/register`. + +When you register the sample app, use: + +| Field | Value | +|---|---| +| Application Name | `OpenShell XAA Local` | +| Redirect URI | `http://127.0.0.1:8767/callback` | +| Post-Logout Redirect URI | `http://127.0.0.1:8767/` | +| Resource Connection | `Todo0 Resource App` | + +Record these values from the `xaa.dev` registration and integration guide: + +| Variable | Value | +|---|---| +| `XAA_DEV_REQUESTING_CLIENT_ID` | Main requesting app client ID, for example `client_f8d67a261fab113c`. | +| `XAA_DEV_REQUESTING_CLIENT_SECRET` | Main requesting app client secret. | +| `XAA_DEV_RESOURCE_CLIENT_ID` | Resource client ID, for example `client_-at-todo0`. | +| `XAA_DEV_RESOURCE_CLIENT_SECRET` | Resource client secret from the Step 3 guide. 
| + +The sample endpoints used by the built-in profile are: + +| Purpose | URL | +|---|---| +| Identity provider token endpoint | `https://idp.xaa.dev/token` | +| Resource authorization server token endpoint | `https://auth.resource.xaa.dev/token` | +| Protected API base | `https://api.resource.xaa.dev` | + + + +## Enable Providers v2 + +Enable profile-backed provider policy on the target gateway: + +```shell +openshell settings set --global --key providers_v2_enabled --value true +``` + +## Log In and Create a Sandbox + +Log in before creating the sandbox so OpenShell can bind the sandbox to the authenticated user's `id_token`: + +```shell +openshell gateway login +``` + +Create a sandbox and record its UUID: + +```shell +openshell sandbox create --name xaa-dev-sample +openshell sandbox get xaa-dev-sample +``` + +The `provider refresh configure` command needs the sandbox UUID, not just the sandbox name. + +## Create the Sample Provider + +Create the provider from the built-in profile: + +```shell +openshell provider create \ + --name xaa-dev-runtime \ + --type xaa-dev +``` + +## Configure the 2-Step XAA Refresh + +Configure the requesting-app and resource-app credentials: + +```shell +openshell provider refresh configure xaa-dev-runtime \ + --credential-key XAA_DEV_ACCESS_TOKEN \ + --strategy okta_xaa \ + --material requesting_client_id="$XAA_DEV_REQUESTING_CLIENT_ID" \ + --material requesting_client_secret="$XAA_DEV_REQUESTING_CLIENT_SECRET" \ + --material sandbox_id="<SANDBOX_UUID>" \ + --material resource_client_id="$XAA_DEV_RESOURCE_CLIENT_ID" \ + --material resource_client_secret="$XAA_DEV_RESOURCE_CLIENT_SECRET" \ + --material audience="https://auth.resource.xaa.dev" \ + --material resource="https://api.resource.xaa.dev" \ + --material resource_token_url="https://auth.resource.xaa.dev/token" \ + --material scope="todos.read" \ + --secret-material-key requesting_client_secret \ + --secret-material-key resource_client_secret +``` + +Force a refresh and confirm the
credential status: + +```shell +openshell provider refresh rotate xaa-dev-runtime \ + --credential-key XAA_DEV_ACCESS_TOKEN + +openshell provider refresh status xaa-dev-runtime \ + --credential-key XAA_DEV_ACCESS_TOKEN +``` + +The status should show `refreshed`. + +## Attach the Provider + +Attach the provider to the sandbox: + +```shell +openshell sandbox provider attach xaa-dev-sample xaa-dev-runtime +``` + +## Verify the Delegated Token Path + +Run the sample API call from inside the sandbox: + +```shell +openshell sandbox exec \ + --name xaa-dev-sample \ + -- /bin/sh -lc 'curl --http1.1 -sS -H "Authorization: Bearer $XAA_DEV_ACCESS_TOKEN" https://api.resource.xaa.dev/api/todos' +``` + +The command should return the sample `Todo0` JSON payload. + + + +## Next Steps + +- Use [Providers v2](/sandboxes/providers-v2) for more detail on profile-backed provider behavior. +- Use [Okta OBO Token Exchange](/get-started/tutorials/okta-obo) if you want the delegated Okta OBO path instead of the `xaa.dev` sample flow. diff --git a/docs/sandboxes/providers-v2.mdx b/docs/sandboxes/providers-v2.mdx index 3ac248ff8..dd64a2aef 100644 --- a/docs/sandboxes/providers-v2.mdx +++ b/docs/sandboxes/providers-v2.mdx @@ -96,6 +96,9 @@ Built-in Providers v2 profiles currently include: | `github` | `source_control` | `GITHUB_TOKEN`, `GH_TOKEN` | | `google-vertex-ai` | `inference` | `GOOGLE_SERVICE_ACCOUNT_KEY`, `GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN`, `VERTEX_AI_SERVICE_ACCOUNT_TOKEN`, `GOOGLE_VERTEX_AI_TOKEN`, `VERTEX_AI_TOKEN` | | `nvidia` | `inference` | `NVIDIA_API_KEY` | +| `xaa-dev` | `other` | `XAA_DEV_ACCESS_TOKEN` | + +The built-in `xaa-dev` profile demonstrates the sample 2-step XAA flow and includes the sample identity provider, resource authorization server, and resource API endpoints. For the full setup, see [XAA.dev Sample XAA](/get-started/tutorials/xaa-dev). 
Export a built-in profile as YAML: From 29640d35e870737bcb3cf22a1353bae6ae6674b5 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 15:03:14 -0700 Subject: [PATCH 10/13] feat(xaa): add sample xaa provider flow --- crates/openshell-cli/src/main.rs | 24 +++- crates/openshell-cli/src/oidc_auth.rs | 37 ++++- crates/openshell-cli/src/run.rs | 87 ++++-------- .../tests/provider_commands_integration.rs | 61 +++++++- crates/openshell-providers/src/profiles.rs | 65 ++++++++- .../openshell-server/src/provider_refresh.rs | 133 +++++++++++++++++- providers/okta-xaa.yaml | 27 ++-- providers/xaa-dev.yaml | 66 +++++++++ 8 files changed, 411 insertions(+), 89 deletions(-) create mode 100644 providers/xaa-dev.yaml diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index fd67048ae..47c7903b1 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -140,6 +140,15 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) { else { return; }; + let bearer_for_gateway = |access_token: &str, id_token: Option<&String>| { + if access_token.matches('.').count() == 2 { + access_token.to_string() + } else { + id_token + .cloned() + .unwrap_or_else(|| access_token.to_string()) + } + }; if openshell_bootstrap::oidc_token::is_token_expired(&bundle) { let insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE") .is_ok_and(|v| !v.is_empty() && v != "0" && v != "false"); @@ -155,19 +164,28 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) { gateway_name, &refreshed, ); - tls.oidc_token = Some(refreshed.access_token); + tls.oidc_token = Some(bearer_for_gateway( + &refreshed.access_token, + refreshed.id_token.as_ref(), + )); tls.oidc_id_token = refreshed.id_token; } Err(e) => { tracing::warn!("OIDC token refresh failed: {e}"); // Use the expired token anyway — server will reject it // with a clear error prompting re-login. 
- tls.oidc_token = Some(bundle.access_token); + tls.oidc_token = Some(bearer_for_gateway( + &bundle.access_token, + bundle.id_token.as_ref(), + )); tls.oidc_id_token = bundle.id_token; } } } else { - tls.oidc_token = Some(bundle.access_token); + tls.oidc_token = Some(bearer_for_gateway( + &bundle.access_token, + bundle.id_token.as_ref(), + )); tls.oidc_id_token = bundle.id_token; } } diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index 4cb1cf37c..b508434c7 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -33,6 +33,9 @@ use tokio::sync::oneshot; use tracing::debug; const AUTH_TIMEOUT: Duration = Duration::from_secs(120); +const DEFAULT_OIDC_CALLBACK_BIND: &str = "127.0.0.1:0"; +const OIDC_CALLBACK_PORT_ENV: &str = "OPENSHELL_OIDC_CALLBACK_PORT"; +const OIDC_CLIENT_SECRET_ENV: &str = "OPENSHELL_OIDC_CLIENT_SECRET"; #[derive(Clone, Debug, Default, Deserialize, Serialize)] struct OidcExtraTokenFields { @@ -127,6 +130,25 @@ fn build_ci_scopes(scopes: Option<&str>) -> Vec { .collect() } +fn oidc_callback_bind_address() -> Result { + match std::env::var(OIDC_CALLBACK_PORT_ENV) { + Ok(raw) => { + let port = raw.parse::().map_err(|_| { + miette::miette!( + "{OIDC_CALLBACK_PORT_ENV} must be a valid TCP port number, got '{raw}'" + ) + })?; + if port == 0 { + return Err(miette::miette!( + "{OIDC_CALLBACK_PORT_ENV} must be greater than 0" + )); + } + Ok(format!("127.0.0.1:{port}")) + } + Err(_) => Ok(DEFAULT_OIDC_CALLBACK_BIND.to_string()), + } +} + /// Run the OIDC Authorization Code + PKCE browser flow. /// /// Opens the user's browser to the Keycloak login page and waits for @@ -140,14 +162,21 @@ pub async fn oidc_browser_auth_flow( ) -> Result { let discovery = discover(issuer, insecure).await?; - let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?; + let listener = TcpListener::bind(oidc_callback_bind_address()?) 
+ .await + .into_diagnostic()?; let port = listener.local_addr().into_diagnostic()?.port(); let redirect_uri = format!("http://127.0.0.1:{port}/callback"); - let client = OidcClient::new(ClientId::new(client_id.to_string())) + let mut client = OidcClient::new(ClientId::new(client_id.to_string())) .set_auth_uri(AuthUrl::new(discovery.authorization_endpoint).into_diagnostic()?) .set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?) .set_redirect_uri(RedirectUrl::new(redirect_uri).into_diagnostic()?); + if let Ok(client_secret) = std::env::var(OIDC_CLIENT_SECRET_ENV) { + client = client + .set_client_secret(ClientSecret::new(client_secret)) + .set_auth_type(AuthType::RequestBody); + } let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -223,9 +252,9 @@ pub async fn oidc_client_credentials_flow( scopes: Option<&str>, insecure: bool, ) -> Result { - let client_secret = std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").map_err(|_| { + let client_secret = std::env::var(OIDC_CLIENT_SECRET_ENV).map_err(|_| { miette::miette!( - "OPENSHELL_OIDC_CLIENT_SECRET environment variable is required for client credentials flow" + "{OIDC_CLIENT_SECRET_ENV} environment variable is required for client credentials flow" ) })?; diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index fe813804d..a8c8d9d3f 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -974,49 +974,25 @@ pub async fn gateway_add( eprintln!(" {} oidc", "Auth:".dimmed()); eprintln!(); - // Check for client_credentials env var (CI mode). 
- if std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").is_ok() { - match crate::oidc_auth::oidc_client_credentials_flow( - issuer, - oidc_client_id, - oidc_audience, - oidc_scopes, - gateway_insecure, - ) - .await - { - Ok(bundle) => { - openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; - eprintln!( - "{} Authenticated via client credentials", - "✓".green().bold() - ); - } - Err(e) => { - eprintln!("{} Authentication failed: {e}", "!".yellow()); - } + match crate::oidc_auth::oidc_browser_auth_flow( + issuer, + oidc_client_id, + oidc_audience, + oidc_scopes, + gateway_insecure, + ) + .await + { + Ok(bundle) => { + openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; + eprintln!("{} Authenticated successfully", "✓".green().bold()); } - } else { - match crate::oidc_auth::oidc_browser_auth_flow( - issuer, - oidc_client_id, - oidc_audience, - oidc_scopes, - gateway_insecure, - ) - .await - { - Ok(bundle) => { - openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; - eprintln!("{} Authenticated successfully", "✓".green().bold()); - } - Err(e) => { - eprintln!("{} Authentication skipped: {e}", "!".yellow()); - eprintln!( - " Authenticate later with: {}", - "openshell gateway login".dimmed(), - ); - } + Err(e) => { + eprintln!("{} Authentication skipped: {e}", "!".yellow()); + eprintln!( + " Authenticate later with: {}", + "openshell gateway login".dimmed(), + ); } } @@ -1201,25 +1177,14 @@ pub async fn gateway_login(name: &str, gateway_insecure: bool) -> Result<()> { let audience = metadata.oidc_audience.as_deref(); let scopes = metadata.oidc_scopes.as_deref(); - let bundle = if std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").is_ok() { - crate::oidc_auth::oidc_client_credentials_flow( - issuer, - client_id, - audience, - scopes, - gateway_insecure, - ) - .await? - } else { - crate::oidc_auth::oidc_browser_auth_flow( - issuer, - client_id, - audience, - scopes, - gateway_insecure, - ) - .await? 
- }; + let bundle = crate::oidc_auth::oidc_browser_auth_flow( + issuer, + client_id, + audience, + scopes, + gateway_insecure, + ) + .await?; let username = jwt_preferred_username(&bundle.access_token); openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 9f3975ec6..ad1e7dd07 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -1883,13 +1883,70 @@ async fn built_in_okta_xaa_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "resource" && material.required) + .any(|material| material.name == "requesting_client_id" && material.required) ); assert!( refresh .material .iter() - .any(|material| material.name == "private_key_pem" + .any(|material| material.name == "requesting_client_secret" + && material.required + && material.secret) + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "resource_client_id" && material.required) + ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "resource_client_secret" + && material.required + && material.secret) + ); +} + +#[tokio::test] +async fn built_in_xaa_dev_profile_is_available_via_provider_profile_api() { + let ts = run_server().await; + + let mut client = openshell_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client should connect"); + let profile = client + .get_provider_profile(openshell_core::proto::GetProviderProfileRequest { + id: "xaa-dev".to_string(), + }) + .await + .expect("get provider profile") + .into_inner() + .profile + .expect("profile should exist"); + + assert_eq!(profile.id, "xaa-dev"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == "xaa_access_token") + .expect("xaa access token 
credential"); + let refresh = credential + .refresh + .as_ref() + .expect("xaa credential should include refresh config"); + assert_eq!( + refresh.strategy, + ProviderCredentialRefreshStrategy::OktaXaa as i32 + ); + assert_eq!(refresh.token_url, "https://idp.xaa.dev/token"); + assert_eq!(refresh.scopes, vec!["todos.read"]); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "resource_client_secret" && material.required && material.secret) ); diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index c0a41e1d0..47fa688c3 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -23,6 +23,7 @@ const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/nvidia.yaml"), include_str!("../../../providers/okta-obo.yaml"), include_str!("../../../providers/okta-xaa.yaml"), + include_str!("../../../providers/xaa-dev.yaml"), ]; #[derive(Debug, thiserror::Error)] @@ -1261,12 +1262,14 @@ mod tests { assert_eq!( material_names, vec![ - "client_id", + "requesting_client_id", + "requesting_client_secret", "sandbox_id", + "resource_client_id", + "resource_client_secret", + "audience", "resource", - "private_key_pem", - "kid", - "client_assertion_type", + "resource_token_url", "scope" ] ); @@ -1274,11 +1277,63 @@ mod tests { refresh .material .iter() - .find(|material| material.name == "private_key_pem") + .find(|material| material.name == "requesting_client_secret") + .is_some_and(|material| material.required && material.secret) + ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "resource_client_secret") .is_some_and(|material| material.required && material.secret) ); } + #[test] + fn xaa_dev_profile_exposes_sample_two_step_shape() { + let profile = get_default_profile("xaa-dev").expect("xaa-dev profile"); + let credential = profile + .credentials + .iter() + .find(|credential| credential.name == 
"xaa_access_token") + .expect("xaa-dev access token credential"); + let refresh = credential + .refresh + .as_ref() + .expect("xaa-dev credential should be refreshable"); + + assert_eq!( + refresh.strategy, + openshell_core::proto::ProviderCredentialRefreshStrategy::OktaXaa + ); + assert_eq!(refresh.token_url, "https://idp.xaa.dev/token"); + assert_eq!(refresh.scopes, vec!["todos.read"]); + + let material_names = refresh + .material + .iter() + .map(|material| material.name.as_str()) + .collect::>(); + assert_eq!( + material_names, + vec![ + "requesting_client_id", + "requesting_client_secret", + "sandbox_id", + "resource_client_id", + "resource_client_secret", + "audience", + "resource", + "resource_token_url", + ] + ); + assert_eq!(profile.endpoints.len(), 3); + assert_eq!(profile.endpoints[0].host, "idp.xaa.dev"); + assert_eq!(profile.endpoints[1].host, "auth.resource.xaa.dev"); + assert_eq!(profile.endpoints[2].host, "api.resource.xaa.dev"); + assert_eq!(profile.endpoints[2].access, "read-only"); + } + #[test] fn credential_env_vars_are_deduplicated_in_profile_order() { let profile = get_default_profile("claude-code").expect("claude-code profile"); diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 2b8d9baf5..660493e54 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -560,6 +560,15 @@ async fn mint_okta_xaa_token_exchange( store: &Store, state: &StoredProviderCredentialRefreshState, ) -> Result { + if material_value( + &state.material, + &["requesting_client_id", "requesting_client_secret", "resource_client_id", "resource_client_secret"], + ) + .is_some() + { + return mint_okta_xaa_sample_token_exchange(store, state).await; + } + let token_url = oauth2_token_url(state)?; let sandbox_id = required_material(&state.material, "sandbox_id")?; let binding = crate::delegation::get_binding(store, &sandbox_id) @@ -579,13 +588,17 @@ async 
fn mint_okta_xaa_token_exchange( let client_assertion = build_okta_xaa_client_assertion(state, &token_url, &client_id)?; let client_assertion_type = material_value(&state.material, &["client_assertion_type"]) .unwrap_or_else(|| "urn:ietf:params:oauth:client-assertion-type:jwt-bearer".to_string()); - let resource = required_material(&state.material, "resource")?; + let audience = required_material(&state.material, "audience")?; let mut form = vec![ ( "grant_type".to_string(), "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), ), + ( + "requested_token_type".to_string(), + "urn:ietf:params:oauth:token-type:id-jag".to_string(), + ), ("client_id".to_string(), client_id), ("client_assertion".to_string(), client_assertion), ("client_assertion_type".to_string(), client_assertion_type), @@ -594,7 +607,7 @@ async fn mint_okta_xaa_token_exchange( "subject_token_type".to_string(), "urn:ietf:params:oauth:token-type:id_token".to_string(), ), - ("resource".to_string(), resource), + ("audience".to_string(), audience), ]; let scope = refresh_scopes(state).join(" "); if !scope.is_empty() { @@ -604,6 +617,110 @@ async fn mint_okta_xaa_token_exchange( request_token(&token_url, &form, None, state.max_lifetime_seconds).await } +async fn mint_okta_xaa_sample_token_exchange( + store: &Store, + state: &StoredProviderCredentialRefreshState, +) -> Result { + let sandbox_id = required_material(&state.material, "sandbox_id")?; + let binding = crate::delegation::get_binding(store, &sandbox_id) + .await? 
+ .ok_or_else(|| { + Status::failed_precondition(format!( + "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" + )) + })?; + if binding.id_token.trim().is_empty() { + return Err(Status::failed_precondition( + "sandbox delegation binding does not contain an OIDC id_token", + )); + } + + let requesting_client_id = required_material(&state.material, "requesting_client_id")?; + let requesting_client_secret = required_material(&state.material, "requesting_client_secret")?; + let resource_client_id = required_material(&state.material, "resource_client_id")?; + let resource_client_secret = required_material(&state.material, "resource_client_secret")?; + let audience = required_material(&state.material, "audience")?; + let resource = material_value(&state.material, &["resource"]) + .unwrap_or_else(|| audience.clone()); + let scope = refresh_scopes(state).join(" "); + + let idp_token_url = oauth2_token_url(state)?; + let mut jag_form = vec![ + ( + "grant_type".to_string(), + "urn:ietf:params:oauth:grant-type:token-exchange".to_string(), + ), + ( + "requested_token_type".to_string(), + "urn:ietf:params:oauth:token-type:id-jag".to_string(), + ), + ("subject_token".to_string(), binding.id_token), + ( + "subject_token_type".to_string(), + "urn:ietf:params:oauth:token-type:id_token".to_string(), + ), + ("audience".to_string(), audience.clone()), + ("resource".to_string(), resource), + ("client_id".to_string(), requesting_client_id), + ("client_secret".to_string(), requesting_client_secret), + ]; + if !scope.is_empty() { + jag_form.push(("scope".to_string(), scope.clone())); + } + + let id_jag = request_token( + &idp_token_url, + &jag_form, + None, + state.max_lifetime_seconds, + ) + .await? 
+ .access_token; + + let resource_token_url = material_value(&state.material, &["resource_token_url"]) + .map(Ok) + .unwrap_or_else(|| token_url_from_issuer(&audience))?; + let mut resource_form = vec![ + ( + "grant_type".to_string(), + "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(), + ), + ("assertion".to_string(), id_jag), + ("client_id".to_string(), resource_client_id), + ("client_secret".to_string(), resource_client_secret), + ]; + if !scope.is_empty() { + resource_form.push(("scope".to_string(), scope)); + } + + request_token( + &resource_token_url, + &resource_form, + None, + state.max_lifetime_seconds, + ) + .await +} + +fn token_url_from_issuer(issuer: &str) -> Result { + let mut url = reqwest::Url::parse(issuer) + .map_err(|_| Status::invalid_argument("issuer must be an absolute URL"))?; + let mut path = url.path().trim_end_matches('/').to_string(); + if path.is_empty() { + path = "/oauth2/v1/token".to_string(); + } else if path.ends_with("/oauth2") { + path.push_str("/v1/token"); + } else if path.ends_with("/oauth2/default") || path.contains("/oauth2/") { + path.push_str("/v1/token"); + } else { + path.push_str("/oauth2/v1/token"); + } + url.set_path(&path); + url.set_query(None); + url.set_fragment(None); + Ok(url.to_string()) +} + fn build_okta_xaa_client_assertion( state: &StoredProviderCredentialRefreshState, token_url: &str, @@ -1292,7 +1409,12 @@ mod tests { .and(body_string_contains( "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token", )) - .and(body_string_contains("resource=jira-connection")) + .and(body_string_contains( + "requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid-jag", + )) + .and(body_string_contains( + "audience=https%3A%2F%2Fnvidia-partner.oktapreview.com", + )) .and(body_string_contains("client_assertion=")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ "access_token": "xaa-delegated-token", @@ -1336,7 +1458,10 @@ mod tests { material: HashMap::from([ 
("client_id".to_string(), "agent-client-id".to_string()), ("sandbox_id".to_string(), "sandbox-xaa".to_string()), - ("resource".to_string(), "jira-connection".to_string()), + ( + "audience".to_string(), + "https://nvidia-partner.oktapreview.com".to_string(), + ), ( "private_key_pem".to_string(), TEST_RSA_PRIVATE_KEY.to_string(), diff --git a/providers/okta-xaa.yaml b/providers/okta-xaa.yaml index ed5a4036e..4347fe057 100644 --- a/providers/okta-xaa.yaml +++ b/providers/okta-xaa.yaml @@ -18,24 +18,31 @@ credentials: refresh_before_seconds: 300 max_lifetime_seconds: 3600 material: - - name: client_id - description: Okta AI agent client ID used for XAA token exchange + - name: requesting_client_id + description: XAA requesting app client ID used for the ID token to ID-JAG exchange required: true + - name: requesting_client_secret + description: XAA requesting app client secret used for the ID token to ID-JAG exchange + required: true + secret: true - name: sandbox_id description: OpenShell sandbox ID bound to the authenticated user's OIDC tokens required: true - - name: resource - description: Okta XAA resource or resource connection identifier + - name: resource_client_id + description: XAA resource app client ID used for the ID-JAG to access-token exchange required: true - - name: private_key_pem - description: RSA private key PEM used to sign the Okta XAA client assertion + - name: resource_client_secret + description: XAA resource app client secret used for the ID-JAG to access-token exchange required: true secret: true - - name: kid - description: Optional key ID for the registered Okta public key + - name: audience + description: Okta XAA resource authorization server issuer used as the audience for the ID-JAG exchange + required: true + - name: resource + description: Optional resource URL sent during the ID token to ID-JAG exchange. Defaults to the audience when omitted. 
required: false - - name: client_assertion_type - description: OAuth client assertion type, usually urn:ietf:params:oauth:client-assertion-type:jwt-bearer + - name: resource_token_url + description: Optional resource app token endpoint. Defaults to the audience issuer with /oauth2/v1/token. required: false - name: scope description: Space-delimited scopes requested for the delegated token diff --git a/providers/xaa-dev.yaml b/providers/xaa-dev.yaml new file mode 100644 index 000000000..294a2f166 --- /dev/null +++ b/providers/xaa-dev.yaml @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +id: xaa-dev +display_name: XAA.dev Sample +description: Sample 2-step XAA flow using idp.xaa.dev and the Todo0 Resource App +category: other +credentials: + - name: xaa_access_token + description: XAA.dev delegated access token for the Todo0 sample resource API + env_vars: [XAA_DEV_ACCESS_TOKEN] + required: true + auth_style: bearer + header_name: authorization + refresh: + strategy: okta_xaa + token_url: https://idp.xaa.dev/token + refresh_before_seconds: 300 + max_lifetime_seconds: 3600 + scopes: [todos.read] + material: + - name: requesting_client_id + description: XAA.dev requesting app client ID used for the ID token to ID-JAG exchange + required: true + - name: requesting_client_secret + description: XAA.dev requesting app client secret used for the ID token to ID-JAG exchange + required: true + secret: true + - name: sandbox_id + description: OpenShell sandbox ID bound to the authenticated user's OIDC id_token + required: true + - name: resource_client_id + description: Resource client ID for the Todo0 sample resource authorization server + required: true + - name: resource_client_secret + description: Resource client secret for the Todo0 sample resource authorization server + required: true + secret: true + - name: audience + description: Resource authorization 
server issuer used when minting the ID-JAG + required: true + - name: resource + description: Optional protected resource URL requested during the ID-JAG exchange + required: false + - name: resource_token_url + description: Optional resource authorization server token endpoint + required: false +endpoints: + - host: idp.xaa.dev + port: 443 + protocol: rest + access: read-write + enforcement: enforce + - host: auth.resource.xaa.dev + port: 443 + protocol: rest + access: read-write + enforcement: enforce + - host: api.resource.xaa.dev + port: 443 + protocol: rest + access: read-only + enforcement: enforce +binaries: + - /usr/bin/curl + - /usr/local/bin/curl From 6ca512eef21d57d5997f95ced2b903bbd25061ee Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 20:44:08 -0700 Subject: [PATCH 11/13] fix(xaa): align sample xaa branch with current provider refresh behavior --- crates/openshell-cli/src/completers.rs | 19 ++++++------ crates/openshell-server/src/grpc/provider.rs | 2 ++ .../openshell-server/src/provider_refresh.rs | 30 ++++++++----------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/crates/openshell-cli/src/completers.rs b/crates/openshell-cli/src/completers.rs index cc7d133ab..ff7d75632 100644 --- a/crates/openshell-cli/src/completers.rs +++ b/crates/openshell-cli/src/completers.rs @@ -98,16 +98,15 @@ async fn completion_grpc_client( Some("oidc") => { if let Some(bundle) = load_oidc_token(gateway_name) { if is_token_expired(&bundle) { - match oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await { - Ok(refreshed) => { - let _ = store_oidc_token(gateway_name, &refreshed); - tls_opts.oidc_token = Some(refreshed.access_token); - tls_opts.oidc_id_token = refreshed.id_token; - } - Err(_) => { - tls_opts.oidc_token = Some(bundle.access_token); - tls_opts.oidc_id_token = bundle.id_token; - } + if let Ok(refreshed) = + oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await + { + let _ = store_oidc_token(gateway_name, 
&refreshed); + tls_opts.oidc_token = Some(refreshed.access_token); + tls_opts.oidc_id_token = refreshed.id_token; + } else { + tls_opts.oidc_token = Some(bundle.access_token); + tls_opts.oidc_id_token = bundle.id_token; } } else { tls_opts.oidc_token = Some(bundle.access_token); diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 4fd153836..d13f7d13d 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -1872,6 +1872,8 @@ mod tests { "google-vertex-ai", "nvidia", "okta-obo", + "okta-xaa", + "xaa-dev", ] ); diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 660493e54..f1640cc19 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -562,7 +562,12 @@ async fn mint_okta_xaa_token_exchange( ) -> Result { if material_value( &state.material, - &["requesting_client_id", "requesting_client_secret", "resource_client_id", "resource_client_secret"], + &[ + "requesting_client_id", + "requesting_client_secret", + "resource_client_id", + "resource_client_secret", + ], ) .is_some() { @@ -640,8 +645,8 @@ async fn mint_okta_xaa_sample_token_exchange( let resource_client_id = required_material(&state.material, "resource_client_id")?; let resource_client_secret = required_material(&state.material, "resource_client_secret")?; let audience = required_material(&state.material, "audience")?; - let resource = material_value(&state.material, &["resource"]) - .unwrap_or_else(|| audience.clone()); + let resource = + material_value(&state.material, &["resource"]).unwrap_or_else(|| audience.clone()); let scope = refresh_scopes(state).join(" "); let idp_token_url = oauth2_token_url(state)?; @@ -668,18 +673,12 @@ async fn mint_okta_xaa_sample_token_exchange( jag_form.push(("scope".to_string(), scope.clone())); } - let id_jag = request_token( - 
&idp_token_url, - &jag_form, - None, - state.max_lifetime_seconds, - ) - .await? - .access_token; + let id_jag = request_token(&idp_token_url, &jag_form, None, state.max_lifetime_seconds) + .await? + .access_token; let resource_token_url = material_value(&state.material, &["resource_token_url"]) - .map(Ok) - .unwrap_or_else(|| token_url_from_issuer(&audience))?; + .map_or_else(|| token_url_from_issuer(&audience), Ok)?; let mut resource_form = vec![ ( "grant_type".to_string(), @@ -708,9 +707,7 @@ fn token_url_from_issuer(issuer: &str) -> Result { let mut path = url.path().trim_end_matches('/').to_string(); if path.is_empty() { path = "/oauth2/v1/token".to_string(); - } else if path.ends_with("/oauth2") { - path.push_str("/v1/token"); - } else if path.ends_with("/oauth2/default") || path.contains("/oauth2/") { + } else if path.ends_with("/oauth2") || path.contains("/oauth2/") { path.push_str("/v1/token"); } else { path.push_str("/oauth2/v1/token"); @@ -1252,7 +1249,6 @@ mod tests { Mock::given(method("POST")) .and(path("/token")) .and(body_string_contains("grant_type=refresh_token")) - .and(body_string_contains("client_id=client-id")) .and(body_string_contains("refresh_token=old-refresh-token")) .and(body_string_contains( "scope=https%3A%2F%2Fgraph.microsoft.com%2F.default", From 1d0f0bad9afde27190cede998011dbef7f5a5752 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Tue, 2 Jun 2026 15:58:54 -0700 Subject: [PATCH 12/13] fix(xaa): align sample flow with provider refresh model --- .../tests/provider_commands_integration.rs | 791 +----------------- crates/openshell-core/src/metadata.rs | 21 - crates/openshell-providers/src/profiles.rs | 142 +--- crates/openshell-server/src/compute/mod.rs | 371 ++++---- crates/openshell-server/src/grpc/provider.rs | 716 +++------------- crates/openshell-server/src/grpc/sandbox.rs | 178 ++-- .../openshell-server/src/provider_refresh.rs | 75 +- docs/get-started/tutorials/okta-obo.mdx | 2 +- docs/get-started/tutorials/xaa-dev.mdx | 11 +- 
providers/okta-xaa.yaml | 7 +- providers/xaa-dev.yaml | 7 +- 11 files changed, 405 insertions(+), 1916 deletions(-) diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index ad1e7dd07..c5849ca71 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -44,10 +44,6 @@ struct ProviderState { profiles: Arc>>, refresh_statuses: Arc>>, refresh_requests: Arc>>, - delete_provider_requests: Arc>>, - fail_configure_refresh_message: Arc>>, - fail_rotate_refresh_message: Arc>>, - fail_delete_provider_message: Arc>>, sandbox_providers: Arc>>>, sandbox_provider_requests: Arc>>, global_settings: Arc>>, @@ -130,6 +126,8 @@ impl OpenShell for TestOpenShell { }), spec: None, status: None, + phase: 0, + current_policy_version: 0, }), })) } @@ -340,28 +338,6 @@ impl OpenShell for TestOpenShell { .into_inner() .provider .ok_or_else(|| Status::invalid_argument("provider is required"))?; - if provider.credentials.is_empty() { - let bootstrap_allowed = - if let Some(profile) = openshell_providers::get_default_profile(&provider.r#type) { - profile.allows_gateway_refresh_bootstrap() - } else { - self.state - .profiles - .lock() - .await - .get(&provider.r#type) - .cloned() - .is_some_and(|profile| { - openshell_providers::ProviderTypeProfile::from_proto(&profile) - .allows_gateway_refresh_bootstrap() - }) - }; - if !bootstrap_allowed { - return Err(Status::invalid_argument( - "provider.credentials must not be empty", - )); - } - } let mut providers = self.state.providers.lock().await; let provider_name = provider.object_name().to_string(); if providers.contains_key(&provider_name) { @@ -593,15 +569,6 @@ impl OpenShell for TestOpenShell { credential_key: request.credential_key.clone(), expires_at_ms: request.expires_at_ms, }); - let configure_failure = self - .state - .fail_configure_refresh_message - .lock() - .await - 
.take(); - if let Some(message) = configure_failure { - return Err(Status::internal(message)); - } let providers = self.state.providers.lock().await; let provider = providers .get(&request.provider) @@ -635,42 +602,21 @@ impl OpenShell for TestOpenShell { request: tonic::Request, ) -> Result, Status> { let request = request.into_inner(); - let provider_name = request.provider.clone(); - let credential_key = request.credential_key.clone(); self.state .refresh_requests .lock() .await .push(ProviderRefreshRequestLog::Rotate { - provider_name: provider_name.clone(), - credential_key: credential_key.clone(), + provider_name: request.provider.clone(), + credential_key: request.credential_key.clone(), }); - let rotate_failure = self.state.fail_rotate_refresh_message.lock().await.take(); - if let Some(message) = rotate_failure { - return Err(Status::internal(message)); - } let mut refresh_statuses = self.state.refresh_statuses.lock().await; let status = refresh_statuses - .get_mut(&(provider_name.clone(), credential_key.clone())) + .get_mut(&(request.provider, request.credential_key)) .ok_or_else(|| Status::not_found("provider refresh state not found"))?; - status.status = "refreshed".to_string(); - status.last_refresh_at_ms = 1; - status.next_refresh_at_ms = 3_600_000; - status.expires_at_ms = 3_600_000; - let status = status.clone(); - drop(refresh_statuses); - let mut providers = self.state.providers.lock().await; - let provider = providers - .get_mut(&provider_name) - .ok_or_else(|| Status::not_found("provider not found"))?; - provider - .credentials - .insert(credential_key.clone(), format!("minted-{credential_key}")); - provider - .credential_expires_at_ms - .insert(credential_key, 3_600_000); + status.status = "rotation_requested".to_string(); Ok(Response::new(RotateProviderCredentialResponse { - status: Some(status), + status: Some(status.clone()), })) } @@ -702,15 +648,6 @@ impl OpenShell for TestOpenShell { request: tonic::Request, ) -> Result, Status> { let name 
= request.into_inner().name; - self.state - .delete_provider_requests - .lock() - .await - .push(name.clone()); - let delete_failure = self.state.fail_delete_provider_message.lock().await.take(); - if let Some(message) = delete_failure { - return Err(Status::internal(message)); - } let deleted = self.state.providers.lock().await.remove(&name).is_some(); Ok(Response::new(DeleteProviderResponse { deleted })) } @@ -987,7 +924,6 @@ async fn provider_cli_run_functions_support_full_crud_flow() { "claude", false, &["API_KEY=abc".to_string()], - false, &["profile=dev".to_string()], &ts.tls, ) @@ -1037,7 +973,6 @@ async fn provider_refresh_cli_run_functions_wire_requests() { "outlook", false, &["MS_GRAPH_ACCESS_TOKEN=token".to_string()], - false, &[], &ts.tls, ) @@ -1120,13 +1055,13 @@ async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { ..Default::default() }, openshell_core::proto::ProviderCredentialRefreshMaterial { - name: "sandbox_id".to_string(), + name: "audience".to_string(), required: true, ..Default::default() }, openshell_core::proto::ProviderCredentialRefreshMaterial { - name: "audience".to_string(), - required: true, + name: "subject_token".to_string(), + secret: true, ..Default::default() }, ], @@ -1158,7 +1093,6 @@ async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { strategy: "oauth2_token_exchange", material: &[ "client_id=client-id".to_string(), - "sandbox_id=sandbox-123".to_string(), "audience=api://downstream".to_string(), "scope=api:access:read".to_string(), ], @@ -1208,7 +1142,6 @@ async fn provider_create_allows_empty_credentials_for_gateway_refresh_profiles() "custom-refresh", false, &[], - false, &[], &ts.tls, ) @@ -1231,7 +1164,6 @@ async fn sandbox_provider_cli_run_functions_wire_requests_and_idempotent_results "github", false, &["GITHUB_TOKEN=ghp-test".to_string()], - false, &[], &ts.tls, ) @@ -1350,7 +1282,6 @@ binaries: [/usr/bin/custom] "custom-api", false, &["CUSTOM_API_KEY=abc".to_string()], - false, 
&[], &ts.tls, ) @@ -1404,7 +1335,6 @@ async fn provider_create_from_existing_uses_profile_discovery_when_v2_enabled() "custom-discovery", true, &[], - false, &[], &ts.tls, ) @@ -1437,7 +1367,6 @@ async fn provider_create_from_existing_uses_registry_discovery_when_v2_disabled( "openai", true, &[], - false, &[], &ts.tls, ) @@ -1459,95 +1388,22 @@ async fn provider_create_from_existing_uses_registry_discovery_when_v2_disabled( ); } -#[tokio::test] -async fn provider_create_from_existing_vertex_discovers_credentials_and_config_when_v2_enabled() { - let ts = run_server().await; - enable_providers_v2(&ts).await; - let _env = EnvVarGuard::set(&[ - ("VERTEX_AI_TOKEN", "ya29.vertex-v2-fallback"), - ("VERTEX_AI_PROJECT_ID", "vertex-v2-project"), - ("VERTEX_AI_REGION", "europe-west4"), - ( - "GOOGLE_VERTEX_AI_BASE_URL", - "https://aiplatform.googleapis.com/v1beta1/projects/vertex-v2-project/locations/global/endpoints/openapi", - ), - ("VERTEX_AI_PUBLISHER", "anthropic"), - ]); - - run::provider_create( - &ts.endpoint, - "vertex-v2-discovered", - "google-vertex-ai", - true, - &[], - false, - &[], - &ts.tls, - ) - .await - .expect("vertex provider create --from-existing with v2 enabled"); - - let provider = ts - .state - .providers - .lock() - .await - .get("vertex-v2-discovered") - .cloned() - .expect("vertex provider should be stored"); - assert_eq!(provider.r#type, "google-vertex-ai"); - assert_eq!( - provider.credentials.get("VERTEX_AI_TOKEN"), - Some(&"ya29.vertex-v2-fallback".to_string()) - ); - assert_eq!( - provider.config.get("VERTEX_AI_PROJECT_ID"), - Some(&"vertex-v2-project".to_string()) - ); - assert_eq!( - provider.config.get("VERTEX_AI_REGION"), - Some(&"europe-west4".to_string()) - ); - assert_eq!( - provider.config.get("GOOGLE_VERTEX_AI_BASE_URL"), - Some( - &"https://aiplatform.googleapis.com/v1beta1/projects/vertex-v2-project/locations/global/endpoints/openapi" - .to_string() - ) - ); - assert_eq!( - provider.config.get("VERTEX_AI_PUBLISHER"), - 
Some(&"anthropic".to_string()) - ); -} - #[tokio::test] async fn provider_create_from_existing_requires_profile_when_v2_enabled() { let ts = run_server().await; enable_providers_v2(&ts).await; - // Use "generic" which is a normalised type but has no built-in provider - // profile, so v2 profile-based discovery fails with the expected message. - let _env = EnvVarGuard::set(&[("GENERIC_API_KEY", "some-secret")]); + let _env = EnvVarGuard::set(&[("OPENAI_API_KEY", "legacy-openai-secret")]); - let err = run::provider_create( - &ts.endpoint, - "v2-generic", - "generic", - true, - &[], - false, - &[], - &ts.tls, - ) - .await - .expect_err("v2 discovery without a profile should fail"); + let err = run::provider_create(&ts.endpoint, "v2-openai", "openai", true, &[], &[], &ts.tls) + .await + .expect_err("v2 discovery without a profile should fail"); assert!( err.to_string() .contains("providers v2 discovery requires a provider profile"), "unexpected error: {err}" ); - assert!(!ts.state.providers.lock().await.contains_key("v2-generic")); + assert!(!ts.state.providers.lock().await.contains_key("v2-openai")); } #[tokio::test] @@ -1578,7 +1434,6 @@ async fn provider_create_from_existing_fails_when_profile_discovery_finds_nothin "empty-discovery", true, &[], - false, &[], &ts.tls, ) @@ -1826,7 +1681,9 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "client_id" && material.required) + .any(|material| material.name == "subject_token" + && !material.required + && material.secret) ); assert!( refresh @@ -1834,12 +1691,6 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { .iter() .any(|material| material.name == "audience" && material.required) ); - assert!( - refresh - .material - .iter() - .any(|material| material.name == "subject_token" && !material.required) - ); } #[tokio::test] @@ -1877,7 +1728,9 @@ async fn 
built_in_okta_xaa_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "sandbox_id" && material.required) + .any(|material| material.name == "subject_token" + && !material.required + && material.secret) ); assert!( refresh @@ -1907,6 +1760,14 @@ async fn built_in_okta_xaa_profile_is_available_via_provider_profile_api() { && material.required && material.secret) ); + assert!( + refresh + .material + .iter() + .any(|material| material.name == "subject_token" + && !material.required + && material.secret) + ); } #[tokio::test] @@ -1993,7 +1854,6 @@ async fn provider_create_rejects_key_only_credentials_without_local_env_value() "claude", false, &["INVALID_PAIR".to_string()], - false, &[], &ts.tls, ) @@ -2018,7 +1878,6 @@ async fn provider_create_supports_generic_type_and_env_lookup_credentials() { "generic", false, &["NAV_GENERIC_TEST_KEY".to_string()], - false, &[], &ts.tls, ) @@ -2053,7 +1912,6 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { "claude", true, &["API_KEY=abc".to_string()], - false, &[], &ts.tls, ) @@ -2067,56 +1925,6 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { ); } -#[tokio::test] -async fn provider_create_rejects_combined_from_gcloud_adc_and_from_existing() { - let ts = run_server().await; - - let err = run::provider_create( - &ts.endpoint, - "bad-vertex-provider", - "google-vertex-ai", - true, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("from-gcloud-adc and from-existing should be mutually exclusive"); - - assert!( - err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), - "unexpected error: {err}" - ); - assert!(ts.state.providers.lock().await.is_empty()); -} - -#[tokio::test] -async fn provider_create_rejects_combined_from_gcloud_adc_and_credentials() { - let ts = run_server().await; - - let err = run::provider_create( - &ts.endpoint, - "bad-vertex-provider", - 
"google-vertex-ai", - false, - &["GOOGLE_VERTEX_AI_TOKEN=token".to_string()], - true, - &[], - &ts.tls, - ) - .await - .expect_err("from-gcloud-adc and credentials should be mutually exclusive"); - - assert!( - err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), - "unexpected error: {err}" - ); - assert!(ts.state.providers.lock().await.is_empty()); -} - #[tokio::test] async fn provider_create_rejects_empty_env_var_for_key_only_credential() { let ts = run_server().await; @@ -2128,7 +1936,6 @@ async fn provider_create_rejects_empty_env_var_for_key_only_credential() { "generic", false, &["NAV_EMPTY_ENV_KEY".to_string()], - false, &[], &ts.tls, ) @@ -2153,7 +1960,6 @@ async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { "nvidia", false, &["NVIDIA_API_KEY".to_string()], - false, &[], &ts.tls, ) @@ -2177,542 +1983,3 @@ async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { Some(&"nvapi-live-test".to_string()) ); } - -// ── --from-gcloud-adc tests ─────────────────────────────────────────────────── - -#[tokio::test] -async fn provider_create_from_gcloud_adc_happy_path() { - let ts = run_server().await; - - // Write a temp ADC file simulating a valid authorized_user credential. - let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - - // Point GOOGLE_APPLICATION_CREDENTIALS at the temp file so read_gcloud_adc - // picks it up without touching the real ~/.config/gcloud/ path. 
- let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - run::provider_create( - &ts.endpoint, - "my-vertex", - "google-vertex-ai", - false, - &[], // no explicit credentials; refresh bootstrap covers it - true, // from_gcloud_adc - &[], - &ts.tls, - ) - .await - .expect("provider_create with --from-gcloud-adc should succeed"); - - // Provider must exist in the server state. - let providers = ts.state.providers.lock().await; - let provider = providers - .get("my-vertex") - .expect("provider should be stored after create"); - assert_eq!(provider.r#type, "google-vertex-ai"); - assert_eq!( - provider - .credentials - .get("GOOGLE_VERTEX_AI_TOKEN") - .map(String::as_str), - Some("minted-GOOGLE_VERTEX_AI_TOKEN"), - "initial rotate should materialize a usable access token" - ); - drop(providers); - - // ADC bootstrap must configure refresh and immediately mint the first token. - let requests = ts.state.refresh_requests.lock().await.clone(); - assert_eq!( - requests.len(), - 2, - "expected configure + rotate refresh requests" - ); - assert_eq!( - requests[0], - ProviderRefreshRequestLog::Configure { - provider_name: "my-vertex".to_string(), - credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), - expires_at_ms: None, - } - ); - assert_eq!( - requests[1], - ProviderRefreshRequestLog::Rotate { - provider_name: "my-vertex".to_string(), - credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), - } - ); - - // The refresh status must record the ADC material keys. 
- let refresh_statuses = ts.state.refresh_statuses.lock().await; - let status = refresh_statuses - .get(&( - "my-vertex".to_string(), - "GOOGLE_VERTEX_AI_TOKEN".to_string(), - )) - .expect("refresh status should be stored"); - assert_eq!( - status.strategy, - ProviderCredentialRefreshStrategy::Oauth2RefreshToken as i32 - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_rejects_service_account() { - let ts = run_server().await; - - // Write a temp ADC file with type=service_account. - let adc_content = serde_json::json!({ - "type": "service_account", - "project_id": "my-project", - "private_key_id": "key-id", - "private_key": "-----BEGIN RSA PRIVATE KEY-----\n...", - "client_email": "sa@my-project.iam.gserviceaccount.com" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "my-vertex-sa", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("service_account ADC should be rejected"); - - assert!( - err.to_string() - .contains("GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), - "error should mention the service-account token key, got: {err}" - ); - - // create_provider must NOT have been called — no provider stored. - let providers = ts.state.providers.lock().await; - assert!( - providers.is_empty(), - "no provider should have been created on pre-flight failure" - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_missing_file() { - let ts = run_server().await; - - // Point to a path that does not exist. 
- let _guard = EnvVarGuard::set(&[( - "GOOGLE_APPLICATION_CREDENTIALS", - "/tmp/nonexistent-adc-file-openshell-test.json", - )]); - - let err = run::provider_create( - &ts.endpoint, - "my-vertex-missing", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("missing ADC file should produce an error"); - - // Error must mention the file path or the read failure. - let msg = err.to_string(); - assert!( - msg.contains("nonexistent-adc-file-openshell-test.json") - || msg.contains("failed to read gcloud ADC file"), - "error should reference the missing file, got: {msg}" - ); - - // create_provider must NOT have been called — no provider stored. - let providers = ts.state.providers.lock().await; - assert!( - providers.is_empty(), - "no provider should have been created on pre-flight failure" - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_rejects_wrong_provider_type_before_credential_check() { - let ts = run_server().await; - - let err = run::provider_create( - &ts.endpoint, - "my-openai-adc", - "openai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("wrong provider type should fail before generic credential validation"); - - assert!( - err.to_string() - .contains("--from-gcloud-adc is only valid for google-vertex-ai providers"), - "unexpected error: {err}" - ); - assert!(ts.state.providers.lock().await.is_empty()); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_rolls_back_provider_when_refresh_configure_fails() { - let ts = run_server().await; - *ts.state.fail_configure_refresh_message.lock().await = - Some("simulated configure failure".to_string()); - - let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, 
&adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-rollback", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("configure_provider_refresh failure should bubble up"); - - assert!( - err.to_string().contains("simulated configure failure"), - "unexpected error: {err}" - ); - assert!( - !ts.state - .providers - .lock() - .await - .contains_key("vertex-rollback"), - "provider should be deleted on rollback" - ); - assert_eq!( - ts.state.delete_provider_requests.lock().await.clone(), - vec!["vertex-rollback".to_string()] - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_warn_path_keeps_provider_when_rollback_delete_fails() { - let ts = run_server().await; - *ts.state.fail_configure_refresh_message.lock().await = - Some("simulated configure failure".to_string()); - *ts.state.fail_delete_provider_message.lock().await = - Some("simulated delete failure".to_string()); - - let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-cleanup-warning", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("cleanup failure path should still return configure error"); - - assert!( - err.to_string().contains("simulated configure failure"), - "unexpected error: {err}" - ); - assert!( - ts.state - .providers - .lock() - .await - 
.contains_key("vertex-cleanup-warning"), - "provider should remain when rollback deletion fails" - ); - assert_eq!( - ts.state.delete_provider_requests.lock().await.clone(), - vec!["vertex-cleanup-warning".to_string()] - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_rolls_back_provider_when_initial_rotate_fails() { - let ts = run_server().await; - *ts.state.fail_rotate_refresh_message.lock().await = - Some("simulated rotate failure".to_string()); - - let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-rotate-rollback", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("initial rotate failure should roll back the provider"); - - assert!( - err.to_string().contains("simulated rotate failure"), - "unexpected error: {err}" - ); - assert!( - !ts.state - .providers - .lock() - .await - .contains_key("vertex-rotate-rollback"), - "provider should be deleted on initial-rotate rollback" - ); - assert_eq!( - ts.state.delete_provider_requests.lock().await.clone(), - vec!["vertex-rotate-rollback".to_string()] - ); -} - -#[tokio::test] -async fn provider_create_from_existing_vertex_config_only_reports_missing_vertex_credentials() { - let ts = run_server().await; - enable_providers_v2(&ts).await; - let _env = EnvVarGuard::set(&[ - ("VERTEX_AI_PROJECT_ID", "vertex-config-only-project"), - ("VERTEX_AI_REGION", "us-central1"), - ]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-config-only", - "google-vertex-ai", - true, - 
&[], - false, - &[], - &ts.tls, - ) - .await - .expect_err("config-only discovery should surface missing credential guidance"); - - let msg = err.to_string(); - assert!( - msg.contains("GOOGLE_VERTEX_AI_TOKEN") && msg.contains("VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), - "unexpected error: {msg}" - ); - assert!( - !ts.state - .providers - .lock() - .await - .contains_key("vertex-config-only") - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_with_config_keys() { - let ts = run_server().await; - - // Write a valid authorized_user ADC file. - let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - run::provider_create( - &ts.endpoint, - "vertex-with-config", - "google-vertex-ai", - false, - &[], // no explicit credentials; ADC flow - true, // from_gcloud_adc - &[ - "VERTEX_AI_PROJECT_ID=my-gcp-project".to_string(), - "VERTEX_AI_REGION=us-east1".to_string(), - ], - &ts.tls, - ) - .await - .expect("provider_create with --from-gcloud-adc and --config keys should succeed"); - - // Verify provider was created with the config keys. 
- let providers = ts.state.providers.lock().await; - let provider = providers - .get("vertex-with-config") - .expect("provider should be stored after create"); - assert_eq!(provider.r#type, "google-vertex-ai"); - assert_eq!( - provider - .config - .get("VERTEX_AI_PROJECT_ID") - .map(String::as_str), - Some("my-gcp-project"), - "VERTEX_AI_PROJECT_ID must be stored in provider config" - ); - assert_eq!( - provider.config.get("VERTEX_AI_REGION").map(String::as_str), - Some("us-east1"), - "VERTEX_AI_REGION must be stored in provider config" - ); - drop(providers); - - // ADC flow should configure refresh and eagerly mint the initial token. - let refresh_requests = ts.state.refresh_requests.lock().await.clone(); - assert_eq!( - refresh_requests.len(), - 2, - "exactly one configure call and one rotate call expected" - ); - assert_eq!( - refresh_requests[0], - ProviderRefreshRequestLog::Configure { - provider_name: "vertex-with-config".to_string(), - credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), - expires_at_ms: None, - } - ); - assert_eq!( - refresh_requests[1], - ProviderRefreshRequestLog::Rotate { - provider_name: "vertex-with-config".to_string(), - credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), - } - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_missing_refresh_token() { - let ts = run_server().await; - - // ADC file is valid authorized_user type but missing refresh_token. 
- let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "client_secret": "test-client-secret" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-missing-refresh", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("missing refresh_token should produce an error"); - - let err_msg = err.to_string(); - assert!( - err_msg.contains("refresh_token"), - "error must mention 'refresh_token', got: {err_msg}" - ); - - // No provider should have been created. - let providers = ts.state.providers.lock().await; - assert!( - providers.is_empty(), - "no provider must be created when ADC validation fails" - ); -} - -#[tokio::test] -async fn provider_create_from_gcloud_adc_missing_client_secret() { - let ts = run_server().await; - - // ADC file is valid authorized_user type but missing client_secret. 
- let adc_content = serde_json::json!({ - "type": "authorized_user", - "client_id": "test-client-id.apps.googleusercontent.com", - "refresh_token": "1//test-refresh-token" - }); - let adc_file = tempfile::NamedTempFile::new().unwrap(); - serde_json::to_writer(&adc_file, &adc_content).unwrap(); - let adc_path = adc_file.path().to_str().unwrap().to_string(); - let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); - - let err = run::provider_create( - &ts.endpoint, - "vertex-missing-secret", - "google-vertex-ai", - false, - &[], - true, - &[], - &ts.tls, - ) - .await - .expect_err("missing client_secret should produce an error"); - - let err_msg = err.to_string(); - assert!( - err_msg.contains("client_secret"), - "error must mention 'client_secret', got: {err_msg}" - ); - - // No provider should have been created. - let providers = ts.state.providers.lock().await; - assert!( - providers.is_empty(), - "no provider must be created when ADC validation fails" - ); -} diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index e86bc17e2..78533e1e0 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -8,7 +8,6 @@ use crate::proto::{ InferenceRoute, ObjectForTest, Provider, Sandbox, ServiceEndpoint, SshSession, StoredProviderCredentialRefreshState, StoredProviderProfile, - SandboxStatus, }; use std::collections::HashMap; @@ -70,26 +69,6 @@ impl GetResourceVersion for Sandbox { } } -impl Sandbox { - pub fn phase(&self) -> i32 { - self.status.as_ref().map_or(0, |s| s.phase) - } - - pub fn set_phase(&mut self, phase: i32) { - self.status.get_or_insert_with(SandboxStatus::default).phase = phase; - } - - pub fn current_policy_version(&self) -> u32 { - self.status.as_ref().map_or(0, |s| s.current_policy_version) - } - - pub fn set_current_policy_version(&mut self, version: u32) { - self.status - .get_or_insert_with(SandboxStatus::default) - .current_policy_version = 
version; - } -} - // Implementations for Provider impl ObjectId for Provider { fn object_id(&self) -> &str { diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 47fa688c3..5f04bd405 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -19,7 +19,6 @@ use std::sync::OnceLock; const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/claude-code.yaml"), include_str!("../../../providers/github.yaml"), - include_str!("../../../providers/google-vertex-ai.yaml"), include_str!("../../../providers/nvidia.yaml"), include_str!("../../../providers/okta-obo.yaml"), include_str!("../../../providers/okta-xaa.yaml"), @@ -310,25 +309,6 @@ impl ProviderTypeProfile { vars } - /// Whether this profile can be created without an initial access token because - /// the gateway can mint at least one credential immediately from refresh - /// material, and no required credential falls outside that gateway-mintable set. 
- #[must_use] - pub fn allows_gateway_refresh_bootstrap(&self) -> bool { - let mut has_gateway_mintable_credential = false; - for credential in &self.credentials { - let is_gateway_mintable = credential - .refresh - .as_ref() - .is_some_and(CredentialRefreshProfile::is_gateway_mintable); - if credential.required && !is_gateway_mintable { - return false; - } - has_gateway_mintable_credential |= is_gateway_mintable; - } - has_gateway_mintable_credential - } - #[must_use] pub fn to_proto(&self) -> ProviderProfile { ProviderProfile { @@ -368,18 +348,6 @@ impl ProviderTypeProfile { } } -impl CredentialRefreshProfile { - #[must_use] - pub fn is_gateway_mintable(&self) -> bool { - matches!( - self.strategy, - ProviderCredentialRefreshStrategy::Oauth2RefreshToken - | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials - | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt - ) - } -} - fn discovery_is_empty(discovery: &DiscoveryProfile) -> bool { discovery.credentials.is_empty() } @@ -626,7 +594,6 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { allow_encoded_slash: endpoint.allow_encoded_slash, websocket_credential_rewrite: endpoint.websocket_credential_rewrite, request_body_credential_rewrite: endpoint.request_body_credential_rewrite, - advisor_proposed: false, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -1220,15 +1187,15 @@ mod tests { refresh .material .iter() - .find(|material| material.name == "subject_token") - .is_some_and(|material| !material.required && material.secret) + .find(|material| material.name == "audience") + .is_some_and(|material| material.required) ); assert!( refresh .material .iter() - .find(|material| material.name == "audience") - .is_some_and(|material| material.required) + .find(|material| material.name == "subject_token") + .is_some_and(|material| !material.required && material.secret) ); } @@ -1264,7 +1231,7 @@ mod tests { vec![ 
"requesting_client_id", "requesting_client_secret", - "sandbox_id", + "subject_token", "resource_client_id", "resource_client_secret", "audience", @@ -1287,6 +1254,13 @@ mod tests { .find(|material| material.name == "resource_client_secret") .is_some_and(|material| material.required && material.secret) ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "subject_token") + .is_some_and(|material| !material.required && material.secret) + ); } #[test] @@ -1319,7 +1293,7 @@ mod tests { vec![ "requesting_client_id", "requesting_client_secret", - "sandbox_id", + "subject_token", "resource_client_id", "resource_client_secret", "audience", @@ -1327,6 +1301,13 @@ mod tests { "resource_token_url", ] ); + assert!( + refresh + .material + .iter() + .find(|material| material.name == "subject_token") + .is_some_and(|material| !material.required && material.secret) + ); assert_eq!(profile.endpoints.len(), 3); assert_eq!(profile.endpoints[0].host, "idp.xaa.dev"); assert_eq!(profile.endpoints[1].host, "auth.resource.xaa.dev"); @@ -1343,89 +1324,6 @@ mod tests { ); } - #[test] - fn vertex_profile_declares_discovery_and_fallback_token_env_vars() { - let profile = get_default_profile("google-vertex-ai").expect("vertex profile"); - let service_account_token = profile - .credentials - .iter() - .find(|credential| credential.name == "service_account_token") - .expect("vertex service-account token credential"); - let adc_credential = profile - .credentials - .iter() - .find(|credential| credential.name == "gcloud_adc_token") - .expect("vertex ADC credential"); - - assert_eq!( - service_account_token.env_vars, - vec![ - "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string(), - "VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string() - ] - ); - assert_eq!( - adc_credential.env_vars, - vec![ - "GOOGLE_VERTEX_AI_TOKEN".to_string(), - "VERTEX_AI_TOKEN".to_string() - ] - ); - assert_eq!( - profile.discovery.credentials, - vec!["service_account_token", "gcloud_adc_token"] - ); - 
assert!( - profile.allows_gateway_refresh_bootstrap(), - "Vertex profile should allow empty-create bootstrap via gateway-mintable credentials" - ); - } - - #[test] - fn refresh_bootstrap_requires_a_gateway_mintable_path_and_no_required_static_credentials() { - let optional_refresh_profile = parse_profile_yaml( - r" -id: optional-refresh -display_name: Optional Refresh -credentials: - - name: access_token - required: false - refresh: - strategy: oauth2_refresh_token -", - ) - .expect("profile"); - assert!(optional_refresh_profile.allows_gateway_refresh_bootstrap()); - - let mixed_required_profile = parse_profile_yaml( - r" -id: mixed-required -display_name: Mixed Required -credentials: - - name: access_token - required: true - refresh: - strategy: oauth2_client_credentials - - name: static_key - required: true -", - ) - .expect("profile"); - assert!(!mixed_required_profile.allows_gateway_refresh_bootstrap()); - - let static_only_profile = parse_profile_yaml( - r" -id: static-only -display_name: Static Only -credentials: - - name: api_key - required: false -", - ) - .expect("profile"); - assert!(!static_only_profile.allows_gateway_refresh_bootstrap()); - } - #[test] fn parse_profile_yaml_reads_single_provider_document() { let profile = parse_profile_yaml( diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 0122f9178..fca0c388f 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -15,7 +15,6 @@ use crate::sandbox_watch::SandboxWatchBus; use crate::supervisor_session::SupervisorSessionRegistry; use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; -use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, @@ -95,6 +94,11 @@ const RECONCILE_INTERVAL: Duration 
= Duration::from_secs(60); /// corresponding backend resource before it is considered orphaned. const ORPHAN_GRACE_PERIOD: Duration = Duration::from_secs(300); +/// Sandbox store updates can race during startup when the driver watch stream +/// and supervisor session callbacks both try to flip a sandbox to Ready. +/// Retry a few times instead of leaving the record stuck in Provisioning. +const SANDBOX_CAS_RETRIES: usize = 4; + // Re-export the shared error type under the name used by this module. pub use openshell_core::ComputeDriverError as ComputeError; @@ -220,7 +224,6 @@ impl ComputeDriver for RemoteComputeDriver { #[derive(Clone)] pub struct ComputeRuntime { driver: SharedComputeDriver, - driver_kind: Option, shutdown_cleanup: Option>, startup_resume: Option>, _driver_process: Option>, @@ -243,7 +246,6 @@ impl fmt::Debug for ComputeRuntime { impl ComputeRuntime { #[allow(clippy::too_many_arguments)] async fn from_driver( - driver_kind: ComputeDriverKind, driver: SharedComputeDriver, shutdown_cleanup: Option>, startup_resume: Option>, @@ -264,7 +266,6 @@ impl ComputeRuntime { .default_image; Ok(Self { driver, - driver_kind: Some(driver_kind), shutdown_cleanup, startup_resume, _driver_process: driver_process, @@ -308,7 +309,6 @@ impl ComputeRuntime { let startup_resume: Arc = driver.clone(); let driver: SharedComputeDriver = driver; Self::from_driver( - ComputeDriverKind::Docker, driver, Some(shutdown_cleanup), Some(startup_resume), @@ -337,7 +337,6 @@ impl ComputeRuntime { .map_err(|err| ComputeError::Message(err.to_string()))?; let driver: SharedComputeDriver = Arc::new(ComputeDriverService::new(driver)); Self::from_driver( - ComputeDriverKind::Kubernetes, driver, None, None, @@ -364,7 +363,6 @@ impl ComputeRuntime { ) -> Result { let driver: SharedComputeDriver = Arc::new(RemoteComputeDriver::new(channel)); Self::from_driver( - ComputeDriverKind::Vm, driver, None, None, @@ -393,7 +391,6 @@ impl ComputeRuntime { .map_err(|err| 
ComputeError::Message(err.to_string()))?; let driver: SharedComputeDriver = Arc::new(PodmanDriverService::new(driver)); Self::from_driver( - ComputeDriverKind::Podman, driver, None, None, @@ -414,11 +411,6 @@ impl ComputeRuntime { &self.default_image } - #[must_use] - pub fn driver_kind(&self) -> Option { - self.driver_kind - } - #[must_use] pub fn gateway_bind_addresses(&self) -> &[SocketAddr] { &self.gateway_bind_addresses @@ -520,8 +512,6 @@ impl ComputeRuntime { } pub async fn delete_sandbox(&self, name: &str) -> Result { - let _guard = self.sync_lock.lock().await; - // Resolve sandbox ID from name let sandbox = self .store @@ -540,7 +530,7 @@ impl ComputeRuntime { let sandbox = self .store .update_message_cas::(&id, 0, |s| { - s.set_phase(SandboxPhase::Deleting as i32); + s.phase = SandboxPhase::Deleting as i32; }) .await .map_err(|e| { @@ -623,7 +613,7 @@ impl ComputeRuntime { } }; - let phase = SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); + let phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); if !sandbox_phase_should_be_running(phase) { continue; } @@ -697,7 +687,7 @@ impl ComputeRuntime { match self .store .update_message_cas::(&sandbox_id, 0, |s| { - s.set_phase(SandboxPhase::Error as i32); + s.phase = SandboxPhase::Error as i32; let name = s.object_name().to_string(); upsert_ready_condition( &mut s.status, @@ -875,25 +865,21 @@ impl ComputeRuntime { use crate::persistence::WriteCondition; let now_ms = openshell_core::time::now_ms(); + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, None); + let session_connected = self.supervisor_sessions.has_session(&incoming.id); let mut phase = derive_phase(incoming.status.as_ref()); - let sandbox_name = incoming.name.clone(); - let supervisor_promoted = session_connected - && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown); - if supervisor_promoted { + let sandbox_name 
= incoming.name.clone(); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) + { + ensure_supervisor_ready_status(&mut status, &sandbox_name); phase = SandboxPhase::Ready; } - let mut status = incoming - .status - .as_ref() - .map(|s| public_status_from_driver(s, phase, 0)); - rewrite_user_facing_conditions(&mut status, None); - if supervisor_promoted { - ensure_supervisor_ready_status(&mut status, &sandbox_name); - } - let mut sandbox = Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: incoming.id.clone(), name: sandbox_name, @@ -903,8 +889,9 @@ impl ComputeRuntime { }), spec: None, status, + phase: phase as i32, + current_policy_version: 0, }; - sandbox.set_phase(phase as i32); self.store .put_if( @@ -932,90 +919,82 @@ impl ComputeRuntime { return Ok(()); } - // Single-attempt CAS: on conflict, the next watch event will naturally retry let session_connected = self.supervisor_sessions.has_session(&incoming.id); let sandbox_name = incoming.name.clone(); + let mut attempts = 0usize; + let sandbox = loop { + match self + .store + .update_message_cas::(&incoming.id, 0, |sandbox| { + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); + + let mut phase = derive_phase(incoming.status.as_ref()); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) + { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } - let sandbox = self - .store - .update_message_cas::(&incoming.id, 0, |sandbox| { - let old_phase = - SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); - let mut phase = incoming - .status - .as_ref() - .map_or(old_phase, |status| derive_phase(Some(status))); - let supervisor_promoted = session_connected - && matches!(phase, SandboxPhase::Provisioning | 
SandboxPhase::Unknown); - if supervisor_promoted { - phase = SandboxPhase::Ready; - } - - let cpv = sandbox.current_policy_version(); - let mut status = incoming - .status - .as_ref() - .map(|s| public_status_from_driver(s, phase, cpv)) - .or_else(|| sandbox.status.clone()); - rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); - if supervisor_promoted { - ensure_supervisor_ready_status(&mut status, &sandbox_name); - } - - if let Some(s) = status.as_mut() - && s.sandbox_name.is_empty() - { - s.sandbox_name.clone_from(&sandbox_name); - } - - if old_phase != phase { - info!( - sandbox_id = %incoming.id, - sandbox_name = %sandbox_name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); - } + let old_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase != phase { + info!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" + ); + } - if phase == SandboxPhase::Error - && let Some(ref status) = status - { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_reason(&condition.reason) - { - warn!( - sandbox_id = %incoming.id, - sandbox_name = %sandbox_name, - reason = %condition.reason, - message = %condition.message, - "Sandbox failed to become ready" - ); + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_reason(&condition.reason) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); + } } } - } - // Update metadata fields - if let Some(metadata) = sandbox.metadata.as_mut() { - metadata.name.clone_from(&sandbox_name); + 
if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.name.clone_from(&sandbox_name); + } + sandbox.status = status; + sandbox.phase = phase as i32; + }) + .await + { + Ok(sandbox) => break sandbox, + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version: _, + }) if attempts + 1 < SANDBOX_CAS_RETRIES => { + attempts += 1; + continue; } - sandbox.status = status; - sandbox.set_phase(phase as i32); - sandbox.set_current_policy_version(cpv); - }) - .await - .map_err(|e| match e { - crate::persistence::PersistenceError::Conflict { + Err(crate::persistence::PersistenceError::Conflict { current_resource_version, - } => format!( - "concurrent modification detected during sandbox reconciliation (current resource_version: {})", - current_resource_version - .map_or_else(|| "unknown".to_string(), |v| v.to_string()) - ), - other => other.to_string(), - })?; + }) => { + return Err(format!( + "concurrent modification detected during sandbox reconciliation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )); + } + Err(other) => return Err(other.to_string()), + } + }; self.sandbox_index.update_from_sandbox(&sandbox); self.sandbox_watch_bus.notify(sandbox.object_id()); @@ -1037,28 +1016,48 @@ impl ComputeRuntime { ) -> Result<(), String> { let _guard = self.sync_lock.lock().await; - // Use CAS to update sandbox phase based on supervisor session state - let result = self - .store - .update_message_cas::(sandbox_id, 0, |sandbox| { - let current_phase = - SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); + let mut attempts = 0usize; + let result = loop { + match self + .store + .update_message_cas::(sandbox_id, 0, |sandbox| { + let current_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + + if current_phase == SandboxPhase::Deleting + || current_phase == SandboxPhase::Error + { + return; + } - // Skip if sandbox is in 
terminal state - if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { - return; + let sandbox_name = sandbox.object_name().to_string(); + if connected { + ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Ready as i32; + } else if current_phase == SandboxPhase::Ready { + ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Provisioning as i32; + } + }) + .await + { + Ok(sandbox) => break Ok(sandbox), + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version: _, + }) if attempts + 1 < SANDBOX_CAS_RETRIES => { + attempts += 1; + continue; } - - let sandbox_name = sandbox.object_name().to_string(); - if connected { - ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.set_phase(SandboxPhase::Ready as i32); - } else if current_phase == SandboxPhase::Ready { - ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.set_phase(SandboxPhase::Provisioning as i32); + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version, + }) => { + break Err(crate::persistence::PersistenceError::Conflict { + current_resource_version, + }); } - }) - .await; + Err(other) => break Err(other), + } + }; // Handle not found gracefully (sandbox may have been deleted) let sandbox = match result { @@ -1255,7 +1254,10 @@ fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { name: sandbox.object_name().to_string(), namespace: String::new(), // Namespace is set by the driver based on its config spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), - status: sandbox.status.as_ref().map(driver_status_from_public), + status: sandbox + .status + .as_ref() + .map(|status| driver_status_from_public(status, sandbox.phase)), } } @@ -1462,7 +1464,7 @@ fn build_platform_resources_config( } } -fn driver_status_from_public(status: &SandboxStatus) -> 
DriverSandboxStatus { +fn driver_status_from_public(status: &SandboxStatus, phase: i32) -> DriverSandboxStatus { DriverSandboxStatus { sandbox_name: status.sandbox_name.clone(), instance_id: status.agent_pod.clone(), @@ -1473,7 +1475,7 @@ fn driver_status_from_public(status: &SandboxStatus) -> DriverSandboxStatus { .iter() .map(driver_condition_from_public) .collect(), - deleting: SandboxPhase::try_from(status.phase) == Ok(SandboxPhase::Deleting), + deleting: SandboxPhase::try_from(phase) == Ok(SandboxPhase::Deleting), } } @@ -1505,11 +1507,7 @@ fn decode_sandbox_record(record: &ObjectRecord) -> Result { Sandbox::decode(record.payload.as_slice()).map_err(|e| e.to_string()) } -fn public_status_from_driver( - status: &DriverSandboxStatus, - phase: SandboxPhase, - current_policy_version: u32, -) -> SandboxStatus { +fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { SandboxStatus { sandbox_name: status.sandbox_name.clone(), agent_pod: status.instance_id.clone(), @@ -1520,8 +1518,6 @@ fn public_status_from_driver( .iter() .map(public_condition_from_driver) .collect(), - phase: phase as i32, - current_policy_version, } } @@ -1560,7 +1556,10 @@ fn upsert_ready_condition( ) { let status = status.get_or_insert_with(|| SandboxStatus { sandbox_name: sandbox_name.to_string(), - ..Default::default() + agent_pod: String::new(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: Vec::new(), }); if let Some(existing) = status @@ -1684,6 +1683,8 @@ impl ComputeDriver for NoopTestDriver { driver_name: "noop-test-driver".to_string(), driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), + supports_gpu: false, + gpu_count: 0, }, )) } @@ -1762,7 +1763,6 @@ impl ComputeDriver for NoopTestDriver { pub async fn new_test_runtime(store: Arc) -> ComputeRuntime { ComputeRuntime { driver: Arc::new(NoopTestDriver), - driver_kind: None, shutdown_cleanup: None, startup_resume: None, _driver_process: None, @@ -1832,6 
+1832,8 @@ mod tests { driver_name: "test-driver".to_string(), driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), + supports_gpu: true, + gpu_count: 0, })) } @@ -1928,7 +1930,6 @@ mod tests { let store = Arc::new(Store::connect("sqlite::memory:").await.unwrap()); ComputeRuntime { driver, - driver_kind: None, shutdown_cleanup: None, startup_resume, _driver_process: None, @@ -1955,7 +1956,7 @@ mod tests { } fn sandbox_record(id: &str, name: &str, phase: SandboxPhase) -> Sandbox { - let mut sandbox = Sandbox { + Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: id.to_string(), name: name.to_string(), @@ -1963,10 +1964,9 @@ mod tests { labels: HashMap::new(), resource_version: 0, }), + phase: phase as i32, ..Default::default() - }; - sandbox.set_phase(phase as i32); - sandbox + } } fn ssh_session_record(id: &str, sandbox_id: &str) -> SshSession { @@ -2245,6 +2245,8 @@ mod tests { let mut status = Some(SandboxStatus { sandbox_name: "test".to_string(), agent_pod: "test-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "False".to_string(), @@ -2252,7 +2254,6 @@ mod tests { message: "0/1 nodes are available: 1 Insufficient nvidia.com/gpu.".to_string(), last_transition_time: String::new(), }], - ..Default::default() }); rewrite_user_facing_conditions( @@ -2276,6 +2277,8 @@ mod tests { let mut status = Some(SandboxStatus { sandbox_name: "test".to_string(), agent_pod: "test-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "False".to_string(), @@ -2283,7 +2286,6 @@ mod tests { message: original.to_string(), last_transition_time: String::new(), }], - ..Default::default() }); rewrite_user_facing_conditions( @@ -2347,65 +2349,9 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - 
SandboxPhase::try_from(stored.phase()).unwrap(), - SandboxPhase::Ready - ); - } - - #[tokio::test] - async fn apply_sandbox_update_without_status_preserves_existing_status() { - let runtime = test_runtime(Arc::new(TestDriver::default())).await; - let mut sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready); - sandbox.status = Some(SandboxStatus { - sandbox_name: "sandbox-a".to_string(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "True".to_string(), - reason: "DependenciesReady".to_string(), - message: "Pod is Ready".to_string(), - last_transition_time: String::new(), - }], - current_policy_version: 7, - ..Default::default() - }); - sandbox.set_phase(SandboxPhase::Ready as i32); - runtime.store.put_message(&sandbox).await.unwrap(); - - runtime - .apply_sandbox_update(DriverSandbox { - id: "sb-1".to_string(), - name: "sandbox-a".to_string(), - namespace: "default".to_string(), - spec: None, - status: None, - }) - .await - .unwrap(); - - let stored = runtime - .store - .get_message::("sb-1") - .await - .unwrap() - .unwrap(); - assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); - assert_eq!(stored.current_policy_version(), 7); - let ready = stored - .status - .as_ref() - .and_then(|status| { - status - .conditions - .iter() - .find(|condition| condition.r#type == "Ready") - }) - .unwrap(); - assert_eq!(ready.status, "True"); - assert_eq!(ready.reason, "DependenciesReady"); - assert_eq!(ready.message, "Pod is Ready"); } #[tokio::test] @@ -2437,7 +2383,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); let ready = stored @@ -2470,7 +2416,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); } @@ -2481,6 
+2427,9 @@ mod tests { let mut sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready); sandbox.status = Some(SandboxStatus { sandbox_name: "sandbox-a".to_string(), + agent_pod: String::new(), + agent_fd: String::new(), + sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "True".to_string(), @@ -2488,9 +2437,7 @@ mod tests { message: "Supervisor session connected".to_string(), last_transition_time: String::new(), }], - ..Default::default() }); - sandbox.set_phase(SandboxPhase::Ready as i32); runtime.store.put_message(&sandbox).await.unwrap(); runtime @@ -2505,7 +2452,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Provisioning ); let ready = stored @@ -2591,7 +2538,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); @@ -2684,7 +2631,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); } @@ -2847,7 +2794,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Error ); let ready = stored @@ -2879,7 +2826,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Error ); let ready = stored @@ -2906,7 +2853,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); } diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 
d13f7d13d..4e74a7df1 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -12,9 +12,6 @@ use crate::persistence::{ ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, }; use openshell_core::proto::{Provider, Sandbox}; -use openshell_core::telemetry::{ - LifecycleOperation, ProviderProfile as TelemetryProviderProfile, TelemetryOutcome, -}; use prost::Message; use tonic::Status; use tracing::warn; @@ -446,14 +443,6 @@ pub(super) async fn resolve_provider_environment( .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; for (key, value) in &provider.credentials { - if is_non_injectable_provider_credential(&provider, key) { - warn!( - provider_name = %name, - key = %key, - "skipping non-injectable provider credential" - ); - continue; - } if is_valid_env_key(key) { let expires_at_ms = provider .credential_expires_at_ms @@ -481,53 +470,6 @@ pub(super) async fn resolve_provider_environment( ); } } - - // For Vertex AI providers, inject agent-specific config env vars so that - // Claude Code, Goose, and OpenCode inside the sandbox can reach Vertex AI - // without additional configuration. Credentials from the loop above take - // precedence via entry().or_insert(), and sandbox --env overrides are - // applied at the process level after this environment is installed, so - // they naturally shadow these values. - if openshell_core::inference::normalize_inference_provider_type(&provider.r#type) - == Some("google-vertex-ai") - { - let project_id = provider - .config - .get(openshell_core::inference::VERTEX_AI_PROJECT_ID_KEY) - .map(String::as_str) - .unwrap_or_default() - .trim(); - let region = provider - .config - .get(openshell_core::inference::VERTEX_AI_REGION_KEY) - .map(String::as_str) - .unwrap_or_default() - .trim(); - - // Static flags -- always present for Vertex AI providers. 
- env.entry("GOOSE_PROVIDER".to_string()) - .or_insert_with(|| "gcp_vertex_ai".to_string()); - - // Project ID derived vars. - if !project_id.is_empty() { - env.entry("ANTHROPIC_VERTEX_PROJECT_ID".to_string()) - .or_insert_with(|| project_id.to_string()); - env.entry("GCP_PROJECT_ID".to_string()) - .or_insert_with(|| project_id.to_string()); - env.entry("GOOGLE_CLOUD_PROJECT".to_string()) - .or_insert_with(|| project_id.to_string()); - } - - // Region derived vars. - if !region.is_empty() { - env.entry("CLOUD_ML_REGION".to_string()) - .or_insert_with(|| region.to_string()); - env.entry("GCP_LOCATION".to_string()) - .or_insert_with(|| region.to_string()); - env.entry("VERTEX_LOCATION".to_string()) - .or_insert_with(|| region.to_string()); - } - } } Ok(ProviderEnvironment { @@ -648,7 +590,6 @@ fn active_provider_credential_keys(provider: &Provider, now_ms: i64) -> Vec Vec bool { - openshell_core::inference::normalize_inference_provider_type(&provider.r#type) - == Some("google-vertex-ai") - && key == "GOOGLE_SERVICE_ACCOUNT_KEY" -} - pub(super) fn is_valid_env_key(key: &str) -> bool { let mut bytes = key.bytes(); let Some(first) = bytes.next() else { @@ -707,7 +642,7 @@ use openshell_core::proto::{ }; use openshell_providers::{ CredentialRefreshProfile, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, - get_default_profile, normalize_profile_id, normalize_provider_type, validate_profile_set, + get_default_profile, normalize_profile_id, validate_profile_set, }; use std::sync::Arc; use tonic::{Request, Response}; @@ -717,36 +652,14 @@ pub(super) async fn handle_create_provider( request: Request, ) -> Result, Status> { let req = request.into_inner(); - let Some(provider) = req.provider else { - emit_provider_lifecycle( - "custom", - LifecycleOperation::Create, - TelemetryOutcome::Failure, - ); - return Err(Status::invalid_argument("provider is required")); - }; - let provider_type = provider.r#type.clone(); - let result = 
create_provider_record(state.store.as_ref(), provider).await; - match result { - Ok(provider) => { - emit_provider_lifecycle( - &provider.r#type, - LifecycleOperation::Create, - TelemetryOutcome::Success, - ); - Ok(Response::new(ProviderResponse { - provider: Some(provider), - })) - } - Err(err) => { - emit_provider_lifecycle( - &provider_type, - LifecycleOperation::Create, - TelemetryOutcome::Failure, - ); - Err(err) - } - } + let provider = req + .provider + .ok_or_else(|| Status::invalid_argument("provider is required"))?; + let provider = create_provider_record(state.store.as_ref(), provider).await?; + + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) } pub(super) async fn handle_get_provider( @@ -983,7 +896,17 @@ async fn provider_type_allows_empty_credentials_for_refresh( let Some(profile) = get_provider_type_profile(store, provider_type).await? else { return Ok(false); }; - Ok(profile.allows_gateway_refresh_bootstrap()) + let required_credentials = profile + .credentials + .iter() + .filter(|credential| credential.required) + .collect::>(); + Ok(!required_credentials.is_empty() + && required_credentials.iter().all(|credential| { + credential.refresh.as_ref().is_some_and(|refresh| { + crate::provider_refresh::is_gateway_mintable_strategy(refresh.strategy) + }) + })) } async fn merged_provider_profiles(store: &Store) -> Result, Status> { @@ -1162,39 +1085,17 @@ pub(super) async fn handle_update_provider( request: Request, ) -> Result, Status> { let req = request.into_inner(); - let Some(mut provider) = req.provider else { - emit_provider_lifecycle( - "custom", - LifecycleOperation::Update, - TelemetryOutcome::Failure, - ); - return Err(Status::invalid_argument("provider is required")); - }; - let provider_type = provider.r#type.clone(); + let mut provider = req + .provider + .ok_or_else(|| Status::invalid_argument("provider is required"))?; provider .credential_expires_at_ms .extend(req.credential_expires_at_ms); - let result = 
update_provider_record(state.store.as_ref(), provider).await; - match result { - Ok(provider) => { - emit_provider_lifecycle( - &provider.r#type, - LifecycleOperation::Update, - TelemetryOutcome::Success, - ); - Ok(Response::new(ProviderResponse { - provider: Some(provider), - })) - } - Err(err) => { - emit_provider_lifecycle( - &provider_type, - LifecycleOperation::Update, - TelemetryOutcome::Failure, - ); - Err(err) - } - } + let provider = update_provider_record(state.store.as_ref(), provider).await?; + + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) } pub(super) async fn handle_get_provider_refresh_status( @@ -1387,45 +1288,88 @@ pub(super) async fn handle_configure_provider_refresh( .await?; let mut material = request.material; let mut secret_material_keys = request.secret_material_keys; - if strategy == ProviderCredentialRefreshStrategy::Oauth2TokenExchange { - match (principal.as_ref(), raw_bearer_token.as_ref()) { - (Some(Principal::User(user)), Some(raw)) - if user.identity.provider == IdentityProvider::Oidc => - { - material.insert("subject_token".to_string(), raw.0.clone()); - if !secret_material_keys - .iter() - .any(|key| key == "subject_token") + match strategy { + ProviderCredentialRefreshStrategy::Oauth2TokenExchange => { + match (principal.as_ref(), raw_bearer_token.as_ref()) { + (Some(Principal::User(user)), Some(raw)) + if user.identity.provider == IdentityProvider::Oidc => { - secret_material_keys.push("subject_token".to_string()); + material.insert("subject_token".to_string(), raw.0.clone()); + if !secret_material_keys + .iter() + .any(|key| key == "subject_token") + { + secret_material_keys.push("subject_token".to_string()); + } + } + _ => { + if let Some(existing) = existing_refresh_state + .as_ref() + .and_then(|state| state.material.get("subject_token")) + { + material + .entry("subject_token".to_string()) + .or_insert_with(|| existing.clone()); + } else { + return Err(Status::failed_precondition( + 
"oauth2_token_exchange refresh requires an authenticated OIDC user bearer token during configuration", + )); + } + if existing_refresh_state.as_ref().is_some_and(|state| { + state + .secret_material_keys + .iter() + .any(|key| key == "subject_token") + }) && !secret_material_keys + .iter() + .any(|key| key == "subject_token") + { + secret_material_keys.push("subject_token".to_string()); + } } } - _ => { - if let Some(existing) = existing_refresh_state - .as_ref() - .and_then(|state| state.material.get("subject_token")) + } + ProviderCredentialRefreshStrategy::OktaXaa => { + match (principal.as_ref(), raw_id_token.as_ref()) { + (Some(Principal::User(user)), Some(raw)) + if user.identity.provider == IdentityProvider::Oidc => { - material - .entry("subject_token".to_string()) - .or_insert_with(|| existing.clone()); - } else { - return Err(Status::failed_precondition( - "oauth2_token_exchange refresh requires an authenticated OIDC user bearer token during configuration", - )); + material.insert("subject_token".to_string(), raw.0.clone()); + if !secret_material_keys + .iter() + .any(|key| key == "subject_token") + { + secret_material_keys.push("subject_token".to_string()); + } } - if existing_refresh_state.as_ref().is_some_and(|state| { - state - .secret_material_keys + _ => { + if let Some(existing) = existing_refresh_state + .as_ref() + .and_then(|state| state.material.get("subject_token")) + { + material + .entry("subject_token".to_string()) + .or_insert_with(|| existing.clone()); + } else { + return Err(Status::failed_precondition( + "okta_xaa refresh requires an authenticated OIDC user id_token during configuration", + )); + } + if existing_refresh_state.as_ref().is_some_and(|state| { + state + .secret_material_keys + .iter() + .any(|key| key == "subject_token") + }) && !secret_material_keys .iter() .any(|key| key == "subject_token") - }) && !secret_material_keys - .iter() - .any(|key| key == "subject_token") - { - 
secret_material_keys.push("subject_token".to_string()); + { + secret_material_keys.push("subject_token".to_string()); + } } } } + _ => {} } let expires_at_ms = request.expires_at_ms.unwrap_or_else(|| { existing_refresh_state @@ -1578,69 +1522,9 @@ pub(super) async fn handle_delete_provider( request: Request, ) -> Result, Status> { let name = request.into_inner().name; - let provider_profile = provider_profile_for_name(state.store.as_ref(), &name).await; - let result = delete_provider_record(state.store.as_ref(), &name).await; - match result { - Ok(deleted) => { - let outcome = TelemetryOutcome::from_success(deleted); - emit_provider_profile_lifecycle( - provider_profile.unwrap_or(TelemetryProviderProfile::Custom), - LifecycleOperation::Delete, - outcome, - ); - Ok(Response::new(DeleteProviderResponse { deleted })) - } - Err(err) => { - emit_provider_profile_lifecycle( - provider_profile.unwrap_or(TelemetryProviderProfile::Custom), - LifecycleOperation::Delete, - TelemetryOutcome::Failure, - ); - Err(err) - } - } -} - -fn emit_provider_lifecycle( - provider_type: &str, - operation: LifecycleOperation, - outcome: TelemetryOutcome, -) { - let provider_profile = telemetry_provider_profile(provider_type); - emit_provider_profile_lifecycle(provider_profile, operation, outcome); -} - -fn emit_provider_profile_lifecycle( - provider_profile: TelemetryProviderProfile, - operation: LifecycleOperation, - outcome: TelemetryOutcome, -) { - openshell_core::telemetry::emit_provider_lifecycle(operation, outcome, provider_profile); -} - -async fn provider_profile_for_name(store: &Store, name: &str) -> Option { - store - .get_message_by_name::(name) - .await - .ok() - .flatten() - .map(|provider| telemetry_provider_profile(&provider.r#type)) -} + let deleted = delete_provider_record(state.store.as_ref(), &name).await?; -fn telemetry_provider_profile(provider_type: &str) -> TelemetryProviderProfile { - match normalize_provider_type(provider_type) { - Some("anthropic") => 
TelemetryProviderProfile::Anthropic, - Some("claude" | "claude-code") => TelemetryProviderProfile::Claude, - Some("codex") => TelemetryProviderProfile::Codex, - Some("copilot") => TelemetryProviderProfile::Copilot, - Some("github") => TelemetryProviderProfile::Github, - Some("gitlab") => TelemetryProviderProfile::Gitlab, - Some("nvidia") => TelemetryProviderProfile::Nvidia, - Some("openai") => TelemetryProviderProfile::Openai, - Some("opencode") => TelemetryProviderProfile::Opencode, - Some("outlook") => TelemetryProviderProfile::Outlook, - _ => TelemetryProviderProfile::Custom, - } + Ok(Response::new(DeleteProviderResponse { deleted })) } // --------------------------------------------------------------------------- @@ -1686,46 +1570,6 @@ mod tests { assert!(!is_valid_env_key("X;rm -rf /")); } - #[test] - fn telemetry_provider_profile_maps_unknown_to_custom() { - assert_eq!( - telemetry_provider_profile("CLAUDE"), - TelemetryProviderProfile::Claude - ); - assert_eq!( - telemetry_provider_profile("github"), - TelemetryProviderProfile::Github - ); - assert_eq!( - telemetry_provider_profile("gh"), - TelemetryProviderProfile::Github - ); - assert_eq!( - telemetry_provider_profile("glab"), - TelemetryProviderProfile::Gitlab - ); - assert_eq!( - telemetry_provider_profile("outlook"), - TelemetryProviderProfile::Outlook - ); - assert_eq!( - telemetry_provider_profile("generic"), - TelemetryProviderProfile::Custom - ); - assert_eq!( - telemetry_provider_profile("unknown-private"), - TelemetryProviderProfile::Custom - ); - assert_eq!( - telemetry_provider_profile("acme-internal"), - TelemetryProviderProfile::Custom - ); - assert_eq!( - telemetry_provider_profile("corp-llm-prod"), - TelemetryProviderProfile::Custom - ); - } - fn provider_with_values(name: &str, provider_type: &str) -> Provider { Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -1869,7 +1713,6 @@ mod tests { vec![ "claude-code", "github", - "google-vertex-ai", "nvidia", 
"okta-obo", "okta-xaa", @@ -1996,15 +1839,13 @@ mod tests { #[tokio::test] async fn import_provider_profile_allows_legacy_provider_type_ids_without_built_in_profiles() { - // Use an ID that is not a built-in profile to test legacy import. - // "custom-llm" is not registered as a built-in and never will be. let state = test_server_state().await; let response = handle_import_provider_profiles( &state, Request::new(ImportProviderProfilesRequest { profiles: vec![ProviderProfileImportItem { - profile: Some(custom_profile("custom-llm")), - source: "custom-llm.yaml".to_string(), + profile: Some(custom_profile("codex")), + source: "codex.yaml".to_string(), }], }), ) @@ -2018,15 +1859,15 @@ mod tests { let imported = handle_get_provider_profile( &state, Request::new(GetProviderProfileRequest { - id: "custom-llm".to_string(), + id: "codex".to_string(), }), ) .await .unwrap() .into_inner() .profile - .expect("custom-llm profile should be returned"); - assert_eq!(imported.id, "custom-llm"); + .expect("codex profile should be returned"); + assert_eq!(imported.id, "codex"); } #[tokio::test] @@ -2444,68 +2285,6 @@ mod tests { ); } - #[tokio::test] - async fn configure_provider_refresh_accepts_vertex_service_account_token_key() { - let state = test_server_state().await; - create_provider_record( - state.store.as_ref(), - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-sa".to_string(), - created_at_ms: 0, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: std::iter::once(( - "GOOGLE_SERVICE_ACCOUNT_KEY".to_string(), - "{\"type\":\"service_account\"}".to_string(), - )) - .collect(), - config: HashMap::new(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - - let response = handle_configure_provider_refresh( - &state, - Request::new(ConfigureProviderRefreshRequest { - provider: "vertex-sa".to_string(), - credential_key: 
"GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string(), - strategy: ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32, - material: HashMap::from([ - ( - "client_email".to_string(), - "sa@test-project.iam.gserviceaccount.com".to_string(), - ), - ( - "private_key".to_string(), - "-----BEGIN PRIVATE KEY-----\nkey\n-----END PRIVATE KEY-----".to_string(), - ), - ]), - secret_material_keys: vec!["private_key".to_string()], - expires_at_ms: None, - }), - ) - .await - .unwrap() - .into_inner() - .status - .expect("status"); - - assert_eq!( - response.credential_key, - "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN" - ); - assert_eq!( - response.strategy, - ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32 - ); - } - #[tokio::test] async fn delete_provider_refresh_preserves_manually_updated_expiry() { let state = test_server_state().await; @@ -3375,26 +3154,6 @@ mod tests { .unwrap(); assert!(optional_static_empty.credentials.is_empty()); - let vertex_empty = create_provider_record( - store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-no-token-yet".to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: HashMap::new(), - config: HashMap::new(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - assert!(vertex_empty.credentials.is_empty()); - let get_err = get_provider_record(store, "").await.unwrap_err(); assert_eq!(get_err.code(), Code::InvalidArgument); @@ -3845,257 +3604,6 @@ mod tests { assert!(err.message().contains("provider-b")); } - #[tokio::test] - async fn resolve_provider_env_injects_vertex_agent_config() { - let store = test_store().await; - create_provider_record( - &store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-local".to_string(), - created_at_ms: 0, - labels: 
HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: std::iter::once(( - "GOOGLE_VERTEX_AI_TOKEN".to_string(), - "ya29.token".to_string(), - )) - .collect(), - config: [ - ( - "VERTEX_AI_PROJECT_ID".to_string(), - "my-gcp-project".to_string(), - ), - ("VERTEX_AI_REGION".to_string(), "us-central1".to_string()), - ] - .into_iter() - .collect(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - - let result = resolve_provider_environment(&store, &["vertex-local".to_string()]) - .await - .unwrap(); - - // Credential still injected. - assert_eq!( - result.get("GOOGLE_VERTEX_AI_TOKEN"), - Some(&"ya29.token".to_string()) - ); - // Static flags. - assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); - assert_eq!( - result.get("GOOSE_PROVIDER"), - Some(&"gcp_vertex_ai".to_string()) - ); - // Project ID derived vars. - assert_eq!( - result.get("ANTHROPIC_VERTEX_PROJECT_ID"), - Some(&"my-gcp-project".to_string()) - ); - assert_eq!( - result.get("GCP_PROJECT_ID"), - Some(&"my-gcp-project".to_string()) - ); - assert_eq!( - result.get("GOOGLE_CLOUD_PROJECT"), - Some(&"my-gcp-project".to_string()) - ); - // Region derived vars. 
- assert_eq!( - result.get("CLOUD_ML_REGION"), - Some(&"us-central1".to_string()) - ); - assert_eq!(result.get("GCP_LOCATION"), Some(&"us-central1".to_string())); - assert_eq!( - result.get("VERTEX_LOCATION"), - Some(&"us-central1".to_string()) - ); - } - - #[tokio::test] - async fn resolve_provider_env_vertex_never_injects_service_account_key() { - let store = test_store().await; - create_provider_record( - &store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-bootstrap".to_string(), - created_at_ms: 0, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: [ - ( - "GOOGLE_SERVICE_ACCOUNT_KEY".to_string(), - r#"{"type":"service_account","private_key":"secret"}"#.to_string(), - ), - ( - "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string(), - "ya29.short-lived".to_string(), - ), - ] - .into_iter() - .collect(), - config: HashMap::new(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - - let result = resolve_provider_environment(&store, &["vertex-bootstrap".to_string()]) - .await - .unwrap(); - - assert!(!result.contains_key("GOOGLE_SERVICE_ACCOUNT_KEY")); - assert_eq!( - result.get("GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), - Some(&"ya29.short-lived".to_string()) - ); - } - - #[tokio::test] - async fn resolve_provider_env_vertex_omits_agent_config_when_project_and_region_absent() { - let store = test_store().await; - create_provider_record( - &store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-no-config".to_string(), - created_at_ms: 0, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: std::iter::once(( - "GOOGLE_VERTEX_AI_TOKEN".to_string(), - "ya29.token".to_string(), - )) - .collect(), - config: HashMap::new(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - 
.unwrap(); - - let result = resolve_provider_environment(&store, &["vertex-no-config".to_string()]) - .await - .unwrap(); - - // Static flags still present. - assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); - assert_eq!( - result.get("GOOSE_PROVIDER"), - Some(&"gcp_vertex_ai".to_string()) - ); - // Project ID and region derived vars are absent. - assert!(!result.contains_key("ANTHROPIC_VERTEX_PROJECT_ID")); - assert!(!result.contains_key("GCP_PROJECT_ID")); - assert!(!result.contains_key("GOOGLE_CLOUD_PROJECT")); - assert!(!result.contains_key("CLOUD_ML_REGION")); - assert!(!result.contains_key("GCP_LOCATION")); - assert!(!result.contains_key("VERTEX_LOCATION")); - } - - #[tokio::test] - async fn resolve_provider_env_vertex_credential_wins_over_agent_config_key() { - // If a credential happens to share a name with one of the injected agent - // config keys, the credential value takes precedence because the credential - // loop runs first and entry().or_insert() does not overwrite. - let store = test_store().await; - create_provider_record( - &store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "vertex-collision".to_string(), - created_at_ms: 0, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "google-vertex-ai".to_string(), - credentials: [ - ( - "GOOGLE_VERTEX_AI_TOKEN".to_string(), - "ya29.token".to_string(), - ), - // Same key as an injected static flag. - ("GOOSE_PROVIDER".to_string(), "custom-value".to_string()), - ] - .into_iter() - .collect(), - config: [ - ("VERTEX_AI_PROJECT_ID".to_string(), "my-project".to_string()), - ("VERTEX_AI_REGION".to_string(), "us-east1".to_string()), - ] - .into_iter() - .collect(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - - let result = resolve_provider_environment(&store, &["vertex-collision".to_string()]) - .await - .unwrap(); - - // Credential value wins over the injected static value. 
- assert_eq!( - result.get("GOOSE_PROVIDER"), - Some(&"custom-value".to_string()) - ); - } - - #[tokio::test] - async fn resolve_provider_env_non_vertex_provider_does_not_inject_agent_config() { - let store = test_store().await; - create_provider_record( - &store, - Provider { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: String::new(), - name: "openai-local".to_string(), - created_at_ms: 0, - labels: HashMap::new(), - resource_version: 0, - }), - r#type: "openai".to_string(), - credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-test".to_string())) - .collect(), - config: HashMap::new(), - credential_expires_at_ms: HashMap::new(), - }, - ) - .await - .unwrap(); - - let result = resolve_provider_environment(&store, &["openai-local".to_string()]) - .await - .unwrap(); - - assert_eq!(result.get("OPENAI_API_KEY"), Some(&"sk-test".to_string())); - assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); - assert!(!result.contains_key("GOOSE_PROVIDER")); - assert!(!result.contains_key("ANTHROPIC_VERTEX_PROJECT_ID")); - assert!(!result.contains_key("GCP_PROJECT_ID")); - assert!(!result.contains_key("GOOGLE_CLOUD_PROJECT")); - assert!(!result.contains_key("CLOUD_ML_REGION")); - assert!(!result.contains_key("GCP_LOCATION")); - assert!(!result.contains_key("VERTEX_LOCATION")); - } - #[tokio::test] async fn update_provider_rejects_credential_key_collision_for_attached_sandbox() { let store = test_store().await; @@ -4216,7 +3724,7 @@ mod tests { .await .unwrap(); - let mut sandbox = Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: "sandbox-001".to_string(), name: "test-sandbox".to_string(), @@ -4229,8 +3737,9 @@ mod tests { ..SandboxSpec::default() }), status: None, + phase: SandboxPhase::Ready as i32, + ..Default::default() }; - sandbox.set_phase(SandboxPhase::Ready as i32); store.put_message(&sandbox).await.unwrap(); let loaded = store @@ -4252,7 +3761,7 @@ mod tests { let store 
= test_store().await; - let mut sandbox = Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: "sandbox-002".to_string(), name: "empty-sandbox".to_string(), @@ -4262,8 +3771,9 @@ mod tests { }), spec: Some(SandboxSpec::default()), status: None, + phase: SandboxPhase::Ready as i32, + ..Default::default() }; - sandbox.set_phase(SandboxPhase::Ready as i32); store.put_message(&sandbox).await.unwrap(); let loaded = store diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 06f67fb4c..0753d39f4 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -23,10 +23,6 @@ use openshell_core::proto::{ TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; -use openshell_core::telemetry::{ - LifecycleOperation, LifecycleResource, SandboxTemplateSource, TelemetryComputeDriver, - TelemetryOutcome, -}; use openshell_core::{ObjectId, ObjectName}; use prost::Message; use std::net::IpAddr; @@ -60,62 +56,6 @@ const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; pub(super) async fn handle_create_sandbox( state: &Arc, request: Request, -) -> Result, Status> { - let create_request = request.get_ref().clone(); - let result = handle_create_sandbox_inner(state, request).await; - emit_sandbox_create_telemetry( - state, - &create_request, - TelemetryOutcome::from_success(result.is_ok()), - ); - result -} - -fn emit_sandbox_create_telemetry( - state: &Arc, - request: &CreateSandboxRequest, - outcome: TelemetryOutcome, -) { - let compute_driver = telemetry_compute_driver(state.compute.driver_kind()); - let Some(spec) = request.spec.as_ref() else { - openshell_core::telemetry::emit_sandbox_create( - outcome, - false, - 0, - false, - SandboxTemplateSource::Undefined, - compute_driver, - ); - return; - }; - let template_source = if spec - .template 
- .as_ref() - .is_some_and(|template| !template.image.trim().is_empty()) - { - SandboxTemplateSource::Image - } else { - SandboxTemplateSource::Default - }; - openshell_core::telemetry::emit_sandbox_create( - outcome, - spec.gpu, - spec.providers.len() as u64, - spec.policy.is_some(), - template_source, - compute_driver, - ); -} - -fn telemetry_compute_driver( - driver_kind: Option, -) -> TelemetryComputeDriver { - TelemetryComputeDriver::from_driver_kind(driver_kind) -} - -async fn handle_create_sandbox_inner( - state: &Arc, - request: Request, ) -> Result, Status> { use crate::persistence::current_time_ms; @@ -167,7 +107,7 @@ async fn handle_create_sandbox_inner( let now_ms = current_time_ms(); - let mut sandbox = Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: id.clone(), name: name.clone(), @@ -177,8 +117,9 @@ async fn handle_create_sandbox_inner( }), spec: Some(spec), status: None, + phase: SandboxPhase::Provisioning as i32, + current_policy_version: 0, }; - sandbox.set_phase(SandboxPhase::Provisioning as i32); // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) super::validation::validate_object_metadata(sandbox.metadata.as_ref(), "sandbox")?; @@ -211,7 +152,11 @@ async fn handle_create_sandbox_inner( Some(Err(status)) => return Err(status), None => None, }; - let sandbox = state.compute.create_sandbox(sandbox, sandbox_token).await?; + + let sandbox = match state.compute.create_sandbox(sandbox, sandbox_token).await { + Ok(sandbox) => sandbox, + Err(err) => return Err(err), + }; info!( sandbox_id = %id, @@ -470,23 +415,6 @@ pub(super) async fn handle_detach_sandbox_provider( pub(super) async fn handle_delete_sandbox( state: &Arc, request: Request, -) -> Result, Status> { - let result = handle_delete_sandbox_inner(state, request).await; - let outcome = match &result { - Ok(response) if response.get_ref().deleted => TelemetryOutcome::Success, - _ => 
TelemetryOutcome::Failure, - }; - openshell_core::telemetry::emit_lifecycle( - LifecycleResource::Sandbox, - LifecycleOperation::Delete, - outcome, - ); - result -} - -async fn handle_delete_sandbox_inner( - state: &Arc, - request: Request, ) -> Result, Status> { let name = request.into_inner().name; if name.is_empty() { @@ -635,7 +563,7 @@ pub(super) async fn handle_watch_sandbox( if stop_on_terminal { let phase = - SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); if phase == SandboxPhase::Ready { return; } @@ -705,7 +633,7 @@ pub(super) async fn handle_watch_sandbox( return; } if stop_on_terminal { - let phase = SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); + let phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); if phase == SandboxPhase::Ready { return; } @@ -808,7 +736,7 @@ pub(super) async fn handle_exec_sandbox( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -922,7 +850,7 @@ pub(super) async fn handle_forward_tcp( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1251,7 +1179,7 @@ pub(super) async fn handle_exec_sandbox_interactive( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
.ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1324,7 +1252,7 @@ pub(super) async fn handle_create_ssh_session( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1927,36 +1855,15 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; + use crate::auth::identity::{Identity, IdentityProvider}; + use crate::auth::oidc::RawBearerToken; + use crate::auth::principal::{Principal, UserPrincipal}; use crate::grpc::test_support::test_server_state; use openshell_core::proto::datamodel::v1::ObjectMeta; use std::collections::HashMap; // ---- shell_escape ---- - #[test] - fn telemetry_compute_driver_uses_resolved_driver_kind() { - assert_eq!( - telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Docker)), - TelemetryComputeDriver::Docker - ); - assert_eq!( - telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Kubernetes)), - TelemetryComputeDriver::Kubernetes - ); - assert_eq!( - telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Podman)), - TelemetryComputeDriver::Podman - ); - assert_eq!( - telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Vm)), - TelemetryComputeDriver::Vm - ); - assert_eq!( - telemetry_compute_driver(None), - TelemetryComputeDriver::Unknown - ); - } - #[test] fn shell_escape_safe_chars_pass_through() { assert_eq!(shell_escape("ls").unwrap(), "ls"); @@ -2190,7 +2097,7 @@ mod tests { } fn test_sandbox(name: &str, providers: Vec) -> Sandbox { - let mut 
sandbox = Sandbox { + Sandbox { metadata: Some(ObjectMeta { id: format!("sandbox-{name}"), name: name.to_string(), @@ -2204,11 +2111,10 @@ mod tests { providers, ..Default::default() }), + phase: SandboxPhase::Ready as i32, + current_policy_version: 7, ..Default::default() - }; - sandbox.set_phase(SandboxPhase::Ready as i32); - sandbox.set_current_policy_version(7); - sandbox + } } #[tokio::test] @@ -2244,11 +2150,11 @@ mod tests { .await .unwrap() .unwrap(); - assert_eq!(sandbox.phase(), SandboxPhase::Ready as i32); - assert_eq!(sandbox.current_policy_version(), 7); let spec = sandbox.spec.unwrap(); assert_eq!(spec.providers, vec!["work-github"]); assert_eq!(spec.log_level, "debug"); + assert_eq!(sandbox.phase, SandboxPhase::Ready as i32); + assert_eq!(sandbox.current_policy_version, 7); } #[tokio::test] @@ -2529,7 +2435,7 @@ mod tests { async fn interactive_exec_rejects_sandbox_not_ready() { let state = test_server_state().await; let mut sandbox = test_sandbox("not-ready", Vec::new()); - sandbox.set_phase(SandboxPhase::Provisioning as i32); + sandbox.phase = SandboxPhase::Provisioning as i32; state.store.put_message(&sandbox).await.unwrap(); let stored = state @@ -2539,7 +2445,7 @@ mod tests { .unwrap() .unwrap(); assert_ne!( - SandboxPhase::try_from(stored.phase()).ok(), + SandboxPhase::try_from(stored.phase).ok(), Some(SandboxPhase::Ready) ); } @@ -2578,6 +2484,40 @@ mod tests { assert!(err.message().contains("provider-b")); } + #[tokio::test] + async fn create_sandbox_succeeds_for_oidc_user_without_persisted_binding() { + let state = test_server_state().await; + let mut request = Request::new(CreateSandboxRequest { + name: "delegated".to_string(), + spec: Some(openshell_core::proto::SandboxSpec::default()), + labels: HashMap::new(), + }); + request + .extensions_mut() + .insert(Principal::User(UserPrincipal { + identity: Identity { + subject: "user-123".to_string(), + display_name: Some("alex".to_string()), + roles: vec!["openshell-user".to_string()], + 
scopes: vec!["sandbox:write".to_string()], + provider: IdentityProvider::Oidc, + }, + })); + request + .extensions_mut() + .insert(RawBearerToken("raw-access-token".to_string())); + + let response = handle_create_sandbox(&state, request) + .await + .expect("sandbox create succeeds") + .into_inner(); + let sandbox = response.sandbox.expect("sandbox present"); + assert_eq!( + sandbox.metadata.as_ref().expect("metadata").name, + "delegated" + ); + } + #[tokio::test] async fn attach_sandbox_provider_rejects_credential_key_collisions() { let state = test_server_state().await; diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index f1640cc19..a5acb86d1 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -331,7 +331,7 @@ pub async fn refresh_provider_credential( "provider credential refresh started" ); - match mint_credential(store, &state).await { + match mint_credential(&state).await { Ok(minted) => { let now_ms = current_time_ms(); if let Err(err) = @@ -449,7 +449,6 @@ async fn apply_minted_credential( } async fn mint_credential( - store: &Store, state: &StoredProviderCredentialRefreshState, ) -> Result { let strategy = ProviderCredentialRefreshStrategy::try_from(state.strategy) @@ -462,12 +461,9 @@ async fn mint_credential( mint_oauth2_client_credentials(state).await } ProviderCredentialRefreshStrategy::Oauth2TokenExchange => { - let _ = store; mint_oauth2_token_exchange(state).await } - ProviderCredentialRefreshStrategy::OktaXaa => { - mint_okta_xaa_token_exchange(store, state).await - } + ProviderCredentialRefreshStrategy::OktaXaa => mint_okta_xaa_token_exchange(state).await, ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => { mint_google_service_account_jwt(state).await } @@ -557,7 +553,6 @@ async fn mint_oauth2_client_credentials( } async fn mint_okta_xaa_token_exchange( - store: &Store, state: 
&StoredProviderCredentialRefreshState, ) -> Result { if material_value( @@ -571,23 +566,11 @@ async fn mint_okta_xaa_token_exchange( ) .is_some() { - return mint_okta_xaa_sample_token_exchange(store, state).await; + return mint_okta_xaa_sample_token_exchange(state).await; } let token_url = oauth2_token_url(state)?; - let sandbox_id = required_material(&state.material, "sandbox_id")?; - let binding = crate::delegation::get_binding(store, &sandbox_id) - .await? - .ok_or_else(|| { - Status::failed_precondition(format!( - "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" - )) - })?; - if binding.id_token.trim().is_empty() { - return Err(Status::failed_precondition( - "sandbox delegation binding does not contain an OIDC id_token", - )); - } + let subject_token = required_material(&state.material, "subject_token")?; let client_id = required_material(&state.material, "client_id")?; let client_assertion = build_okta_xaa_client_assertion(state, &token_url, &client_id)?; @@ -607,7 +590,7 @@ async fn mint_okta_xaa_token_exchange( ("client_id".to_string(), client_id), ("client_assertion".to_string(), client_assertion), ("client_assertion_type".to_string(), client_assertion_type), - ("subject_token".to_string(), binding.id_token), + ("subject_token".to_string(), subject_token), ( "subject_token_type".to_string(), "urn:ietf:params:oauth:token-type:id_token".to_string(), @@ -623,23 +606,9 @@ async fn mint_okta_xaa_token_exchange( } async fn mint_okta_xaa_sample_token_exchange( - store: &Store, state: &StoredProviderCredentialRefreshState, ) -> Result { - let sandbox_id = required_material(&state.material, "sandbox_id")?; - let binding = crate::delegation::get_binding(store, &sandbox_id) - .await? 
- .ok_or_else(|| { - Status::failed_precondition(format!( - "sandbox delegation binding not found for sandbox_id '{sandbox_id}'" - )) - })?; - if binding.id_token.trim().is_empty() { - return Err(Status::failed_precondition( - "sandbox delegation binding does not contain an OIDC id_token", - )); - } - + let subject_token = required_material(&state.material, "subject_token")?; let requesting_client_id = required_material(&state.material, "requesting_client_id")?; let requesting_client_secret = required_material(&state.material, "requesting_client_secret")?; let resource_client_id = required_material(&state.material, "resource_client_id")?; @@ -659,7 +628,7 @@ async fn mint_okta_xaa_sample_token_exchange( "requested_token_type".to_string(), "urn:ietf:params:oauth:token-type:id-jag".to_string(), ), - ("subject_token".to_string(), binding.id_token), + ("subject_token".to_string(), subject_token), ( "subject_token_type".to_string(), "urn:ietf:params:oauth:token-type:id_token".to_string(), @@ -1421,29 +1390,6 @@ mod tests { .await; let store = test_store().await; - let sandbox = Sandbox { - metadata: Some(ObjectMeta { - id: "sandbox-xaa".to_string(), - name: "xaa".to_string(), - created_at_ms: 1_000_000, - labels: HashMap::new(), - resource_version: 0, - }), - spec: Some(SandboxSpec::default()), - ..Default::default() - }; - let binding = new_binding( - &sandbox, - "user-123", - Some("alex"), - "oidc", - "user-access-token", - Some("user-id-token"), - &["jira.read".to_string()], - ) - .unwrap(); - put_binding(&store, &binding).await.unwrap(); - let provider = provider("my-xaa", "okta-xaa"); store.put_message(&provider).await.unwrap(); let state = new_refresh_state( @@ -1453,7 +1399,7 @@ mod tests { strategy: ProviderCredentialRefreshStrategy::OktaXaa, material: HashMap::from([ ("client_id".to_string(), "agent-client-id".to_string()), - ("sandbox_id".to_string(), "sandbox-xaa".to_string()), + ("subject_token".to_string(), "user-id-token".to_string()), ( 
"audience".to_string(), "https://nvidia-partner.oktapreview.com".to_string(), @@ -1464,7 +1410,10 @@ mod tests { ), ("kid".to_string(), "test-key-id".to_string()), ]), - secret_material_keys: vec!["private_key_pem".to_string()], + secret_material_keys: vec![ + "private_key_pem".to_string(), + "subject_token".to_string(), + ], expires_at_ms: 0, token_url: format!("{}/token", mock_server.uri()), scopes: vec!["jira.read".to_string()], diff --git a/docs/get-started/tutorials/okta-obo.mdx b/docs/get-started/tutorials/okta-obo.mdx index 03062886c..cc8035774 100644 --- a/docs/get-started/tutorials/okta-obo.mdx +++ b/docs/get-started/tutorials/okta-obo.mdx @@ -91,7 +91,7 @@ openshell provider create \ ## Configure Token Exchange -Configure the delegated credential refresh: +Configure the delegated credential refresh. When this command runs while you are logged in with an OIDC gateway session, OpenShell captures the current user bearer token as secret refresh material. ```shell openshell provider refresh configure okta-obo-runtime \ diff --git a/docs/get-started/tutorials/xaa-dev.mdx b/docs/get-started/tutorials/xaa-dev.mdx index 17431f6dd..b5c45fe6a 100644 --- a/docs/get-started/tutorials/xaa-dev.mdx +++ b/docs/get-started/tutorials/xaa-dev.mdx @@ -8,7 +8,7 @@ description: "Configure the built-in xaa-dev provider profile and prove the samp keywords: "Generative AI, Cybersecurity, Tutorial, XAA, Cross App Access, ID-JAG, Providers, Sandbox" --- -Use the built-in `xaa-dev` profile to exercise the sample 2-step Cross App Access flow backed by `xaa.dev`. OpenShell binds the logged-in user identity to a sandbox, exchanges that user `id_token` for an `ID-JAG`, exchanges the `ID-JAG` for a delegated resource token, and injects the delegated token into the sandbox for outbound API calls. +Use the built-in `xaa-dev` profile to exercise the sample 2-step Cross App Access flow backed by `xaa.dev`. 
OpenShell captures the logged-in user's OIDC `id_token` during refresh configuration, exchanges that token for an `ID-JAG`, exchanges the `ID-JAG` for a delegated resource token, and injects the delegated token into the sandbox for outbound API calls. After completing this tutorial, you have: @@ -65,21 +65,19 @@ openshell settings set --global --key providers_v2_enabled --value true ## Log In and Create a Sandbox -Log in before creating the sandbox so OpenShell can bind the sandbox to the authenticated user's `id_token`: +Log in before configuring refresh so OpenShell can capture the authenticated user's `id_token`: ```shell openshell gateway login ``` -Create a sandbox and record its UUID: +Create a sandbox: ```shell openshell sandbox create --name xaa-dev-sample openshell sandbox get xaa-dev-sample ``` -The `provider refresh configure` command needs the sandbox UUID, not just the sandbox name. - ## Create the Sample Provider Create the provider from the built-in profile: @@ -92,7 +90,7 @@ openshell provider create \ ## Configure the 2-Step XAA Refresh -Configure the requesting-app and resource-app credentials: +Configure the requesting-app and resource-app credentials. 
When this command runs while you are logged in, OpenShell stores the current OIDC `id_token` as secret refresh material for the provider: ```shell openshell provider refresh configure xaa-dev-runtime \ @@ -100,7 +98,6 @@ openshell provider refresh configure xaa-dev-runtime \ --strategy okta-xaa \ --material requesting_client_id="$XAA_DEV_REQUESTING_CLIENT_ID" \ --material requesting_client_secret="$XAA_DEV_REQUESTING_CLIENT_SECRET" \ - --material sandbox_id="" \ --material resource_client_id="$XAA_DEV_RESOURCE_CLIENT_ID" \ --material resource_client_secret="$XAA_DEV_RESOURCE_CLIENT_SECRET" \ --material audience="https://auth.resource.xaa.dev" \ diff --git a/providers/okta-xaa.yaml b/providers/okta-xaa.yaml index 4347fe057..422a9ae06 100644 --- a/providers/okta-xaa.yaml +++ b/providers/okta-xaa.yaml @@ -25,9 +25,10 @@ credentials: description: XAA requesting app client secret used for the ID token to ID-JAG exchange required: true secret: true - - name: sandbox_id - description: OpenShell sandbox ID bound to the authenticated user's OIDC tokens - required: true + - name: subject_token + description: Authenticated user OIDC id_token captured during refresh configuration + required: false + secret: true - name: resource_client_id description: XAA resource app client ID used for the ID-JAG to access-token exchange required: true diff --git a/providers/xaa-dev.yaml b/providers/xaa-dev.yaml index 294a2f166..ed22d887d 100644 --- a/providers/xaa-dev.yaml +++ b/providers/xaa-dev.yaml @@ -26,9 +26,10 @@ credentials: description: XAA.dev requesting app client secret used for the ID token to ID-JAG exchange required: true secret: true - - name: sandbox_id - description: OpenShell sandbox ID bound to the authenticated user's OIDC id_token - required: true + - name: subject_token + description: OIDC id_token captured during refresh configuration for the sample XAA flow + required: false + secret: true - name: resource_client_id description: Resource client ID for the Todo0 
sample resource authorization server required: true From b21e074eb70b217e52bb452c5ee31cb16b70c54c Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Wed, 3 Jun 2026 09:51:31 -0700 Subject: [PATCH 13/13] chore(okta-xaa): resolve rebase cleanup --- crates/openshell-cli/src/run.rs | 87 +- .../tests/provider_commands_integration.rs | 787 +++++++++++++++++- crates/openshell-core/src/metadata.rs | 22 +- crates/openshell-providers/src/profiles.rs | 35 + crates/openshell-server/src/compute/mod.rs | 371 +++++---- crates/openshell-server/src/grpc/provider.rs | 615 +++++++++++++- crates/openshell-server/src/grpc/sandbox.rs | 178 ++-- crates/openshell-server/src/telemetry.rs | 3 - 8 files changed, 1786 insertions(+), 312 deletions(-) diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index a8c8d9d3f..fe813804d 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -974,25 +974,49 @@ pub async fn gateway_add( eprintln!(" {} oidc", "Auth:".dimmed()); eprintln!(); - match crate::oidc_auth::oidc_browser_auth_flow( - issuer, - oidc_client_id, - oidc_audience, - oidc_scopes, - gateway_insecure, - ) - .await - { - Ok(bundle) => { - openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; - eprintln!("{} Authenticated successfully", "✓".green().bold()); + // Check for client_credentials env var (CI mode). 
+ if std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").is_ok() { + match crate::oidc_auth::oidc_client_credentials_flow( + issuer, + oidc_client_id, + oidc_audience, + oidc_scopes, + gateway_insecure, + ) + .await + { + Ok(bundle) => { + openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; + eprintln!( + "{} Authenticated via client credentials", + "✓".green().bold() + ); + } + Err(e) => { + eprintln!("{} Authentication failed: {e}", "!".yellow()); + } } - Err(e) => { - eprintln!("{} Authentication skipped: {e}", "!".yellow()); - eprintln!( - " Authenticate later with: {}", - "openshell gateway login".dimmed(), - ); + } else { + match crate::oidc_auth::oidc_browser_auth_flow( + issuer, + oidc_client_id, + oidc_audience, + oidc_scopes, + gateway_insecure, + ) + .await + { + Ok(bundle) => { + openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; + eprintln!("{} Authenticated successfully", "✓".green().bold()); + } + Err(e) => { + eprintln!("{} Authentication skipped: {e}", "!".yellow()); + eprintln!( + " Authenticate later with: {}", + "openshell gateway login".dimmed(), + ); + } } } @@ -1177,14 +1201,25 @@ pub async fn gateway_login(name: &str, gateway_insecure: bool) -> Result<()> { let audience = metadata.oidc_audience.as_deref(); let scopes = metadata.oidc_scopes.as_deref(); - let bundle = crate::oidc_auth::oidc_browser_auth_flow( - issuer, - client_id, - audience, - scopes, - gateway_insecure, - ) - .await?; + let bundle = if std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").is_ok() { + crate::oidc_auth::oidc_client_credentials_flow( + issuer, + client_id, + audience, + scopes, + gateway_insecure, + ) + .await? + } else { + crate::oidc_auth::oidc_browser_auth_flow( + issuer, + client_id, + audience, + scopes, + gateway_insecure, + ) + .await? 
+ }; let username = jwt_preferred_username(&bundle.access_token); openshell_bootstrap::oidc_token::store_oidc_token(name, &bundle)?; diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index c5849ca71..06220e251 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -44,6 +44,10 @@ struct ProviderState { profiles: Arc>>, refresh_statuses: Arc>>, refresh_requests: Arc>>, + delete_provider_requests: Arc>>, + fail_configure_refresh_message: Arc>>, + fail_rotate_refresh_message: Arc>>, + fail_delete_provider_message: Arc>>, sandbox_providers: Arc>>>, sandbox_provider_requests: Arc>>, global_settings: Arc>>, @@ -126,8 +130,6 @@ impl OpenShell for TestOpenShell { }), spec: None, status: None, - phase: 0, - current_policy_version: 0, }), })) } @@ -338,6 +340,28 @@ impl OpenShell for TestOpenShell { .into_inner() .provider .ok_or_else(|| Status::invalid_argument("provider is required"))?; + if provider.credentials.is_empty() { + let bootstrap_allowed = + if let Some(profile) = openshell_providers::get_default_profile(&provider.r#type) { + profile.allows_gateway_refresh_bootstrap() + } else { + self.state + .profiles + .lock() + .await + .get(&provider.r#type) + .cloned() + .is_some_and(|profile| { + openshell_providers::ProviderTypeProfile::from_proto(&profile) + .allows_gateway_refresh_bootstrap() + }) + }; + if !bootstrap_allowed { + return Err(Status::invalid_argument( + "provider.credentials must not be empty", + )); + } + } let mut providers = self.state.providers.lock().await; let provider_name = provider.object_name().to_string(); if providers.contains_key(&provider_name) { @@ -569,6 +593,15 @@ impl OpenShell for TestOpenShell { credential_key: request.credential_key.clone(), expires_at_ms: request.expires_at_ms, }); + let configure_failure = self + .state + .fail_configure_refresh_message + .lock() 
+ .await + .take(); + if let Some(message) = configure_failure { + return Err(Status::internal(message)); + } let providers = self.state.providers.lock().await; let provider = providers .get(&request.provider) @@ -602,21 +635,42 @@ impl OpenShell for TestOpenShell { request: tonic::Request, ) -> Result, Status> { let request = request.into_inner(); + let provider_name = request.provider.clone(); + let credential_key = request.credential_key.clone(); self.state .refresh_requests .lock() .await .push(ProviderRefreshRequestLog::Rotate { - provider_name: request.provider.clone(), - credential_key: request.credential_key.clone(), + provider_name: provider_name.clone(), + credential_key: credential_key.clone(), }); + let rotate_failure = self.state.fail_rotate_refresh_message.lock().await.take(); + if let Some(message) = rotate_failure { + return Err(Status::internal(message)); + } let mut refresh_statuses = self.state.refresh_statuses.lock().await; let status = refresh_statuses - .get_mut(&(request.provider, request.credential_key)) + .get_mut(&(provider_name.clone(), credential_key.clone())) .ok_or_else(|| Status::not_found("provider refresh state not found"))?; - status.status = "rotation_requested".to_string(); + status.status = "refreshed".to_string(); + status.last_refresh_at_ms = 1; + status.next_refresh_at_ms = 3_600_000; + status.expires_at_ms = 3_600_000; + let status = status.clone(); + drop(refresh_statuses); + let mut providers = self.state.providers.lock().await; + let provider = providers + .get_mut(&provider_name) + .ok_or_else(|| Status::not_found("provider not found"))?; + provider + .credentials + .insert(credential_key.clone(), format!("minted-{credential_key}")); + provider + .credential_expires_at_ms + .insert(credential_key, 3_600_000); Ok(Response::new(RotateProviderCredentialResponse { - status: Some(status.clone()), + status: Some(status), })) } @@ -648,6 +702,15 @@ impl OpenShell for TestOpenShell { request: tonic::Request, ) -> Result, Status> 
{ let name = request.into_inner().name; + self.state + .delete_provider_requests + .lock() + .await + .push(name.clone()); + let delete_failure = self.state.fail_delete_provider_message.lock().await.take(); + if let Some(message) = delete_failure { + return Err(Status::internal(message)); + } let deleted = self.state.providers.lock().await.remove(&name).is_some(); Ok(Response::new(DeleteProviderResponse { deleted })) } @@ -924,6 +987,7 @@ async fn provider_cli_run_functions_support_full_crud_flow() { "claude", false, &["API_KEY=abc".to_string()], + false, &["profile=dev".to_string()], &ts.tls, ) @@ -973,6 +1037,7 @@ async fn provider_refresh_cli_run_functions_wire_requests() { "outlook", false, &["MS_GRAPH_ACCESS_TOKEN=token".to_string()], + false, &[], &ts.tls, ) @@ -1059,6 +1124,11 @@ async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { required: true, ..Default::default() }, + openshell_core::proto::ProviderCredentialRefreshMaterial { + name: "client_secret".to_string(), + secret: true, + ..Default::default() + }, openshell_core::proto::ProviderCredentialRefreshMaterial { name: "subject_token".to_string(), secret: true, @@ -1079,6 +1149,7 @@ async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { "okta-obo", false, &[], + false, &[], &ts.tls, ) @@ -1094,6 +1165,7 @@ async fn provider_refresh_cli_supports_oauth2_token_exchange_strategy() { material: &[ "client_id=client-id".to_string(), "audience=api://downstream".to_string(), + "subject_token=user-token".to_string(), "scope=api:access:read".to_string(), ], secret_material_keys: &["client_secret".to_string()], @@ -1142,6 +1214,7 @@ async fn provider_create_allows_empty_credentials_for_gateway_refresh_profiles() "custom-refresh", false, &[], + false, &[], &ts.tls, ) @@ -1164,6 +1237,7 @@ async fn sandbox_provider_cli_run_functions_wire_requests_and_idempotent_results "github", false, &["GITHUB_TOKEN=ghp-test".to_string()], + false, &[], &ts.tls, ) @@ -1282,6 +1356,7 @@ 
binaries: [/usr/bin/custom] "custom-api", false, &["CUSTOM_API_KEY=abc".to_string()], + false, &[], &ts.tls, ) @@ -1335,6 +1410,7 @@ async fn provider_create_from_existing_uses_profile_discovery_when_v2_enabled() "custom-discovery", true, &[], + false, &[], &ts.tls, ) @@ -1367,6 +1443,7 @@ async fn provider_create_from_existing_uses_registry_discovery_when_v2_disabled( "openai", true, &[], + false, &[], &ts.tls, ) @@ -1389,21 +1466,94 @@ async fn provider_create_from_existing_uses_registry_discovery_when_v2_disabled( } #[tokio::test] -async fn provider_create_from_existing_requires_profile_when_v2_enabled() { +async fn provider_create_from_existing_vertex_discovers_credentials_and_config_when_v2_enabled() { let ts = run_server().await; enable_providers_v2(&ts).await; - let _env = EnvVarGuard::set(&[("OPENAI_API_KEY", "legacy-openai-secret")]); + let _env = EnvVarGuard::set(&[ + ("VERTEX_AI_TOKEN", "ya29.vertex-v2-fallback"), + ("VERTEX_AI_PROJECT_ID", "vertex-v2-project"), + ("VERTEX_AI_REGION", "europe-west4"), + ( + "GOOGLE_VERTEX_AI_BASE_URL", + "https://aiplatform.googleapis.com/v1beta1/projects/vertex-v2-project/locations/global/endpoints/openapi", + ), + ("VERTEX_AI_PUBLISHER", "anthropic"), + ]); - let err = run::provider_create(&ts.endpoint, "v2-openai", "openai", true, &[], &[], &ts.tls) + run::provider_create( + &ts.endpoint, + "vertex-v2-discovered", + "google-vertex-ai", + true, + &[], + false, + &[], + &ts.tls, + ) + .await + .expect("vertex provider create --from-existing with v2 enabled"); + + let provider = ts + .state + .providers + .lock() .await - .expect_err("v2 discovery without a profile should fail"); + .get("vertex-v2-discovered") + .cloned() + .expect("vertex provider should be stored"); + assert_eq!(provider.r#type, "google-vertex-ai"); + assert_eq!( + provider.credentials.get("VERTEX_AI_TOKEN"), + Some(&"ya29.vertex-v2-fallback".to_string()) + ); + assert_eq!( + provider.config.get("VERTEX_AI_PROJECT_ID"), + 
Some(&"vertex-v2-project".to_string()) + ); + assert_eq!( + provider.config.get("VERTEX_AI_REGION"), + Some(&"europe-west4".to_string()) + ); + assert_eq!( + provider.config.get("GOOGLE_VERTEX_AI_BASE_URL"), + Some( + &"https://aiplatform.googleapis.com/v1beta1/projects/vertex-v2-project/locations/global/endpoints/openapi" + .to_string() + ) + ); + assert_eq!( + provider.config.get("VERTEX_AI_PUBLISHER"), + Some(&"anthropic".to_string()) + ); +} + +#[tokio::test] +async fn provider_create_from_existing_requires_profile_when_v2_enabled() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + // Use "generic" which is a normalised type but has no built-in provider + // profile, so v2 profile-based discovery fails with the expected message. + let _env = EnvVarGuard::set(&[("GENERIC_API_KEY", "some-secret")]); + + let err = run::provider_create( + &ts.endpoint, + "v2-generic", + "generic", + true, + &[], + false, + &[], + &ts.tls, + ) + .await + .expect_err("v2 discovery without a profile should fail"); assert!( err.to_string() .contains("providers v2 discovery requires a provider profile"), "unexpected error: {err}" ); - assert!(!ts.state.providers.lock().await.contains_key("v2-openai")); + assert!(!ts.state.providers.lock().await.contains_key("v2-generic")); } #[tokio::test] @@ -1434,6 +1584,7 @@ async fn provider_create_from_existing_fails_when_profile_discovery_finds_nothin "empty-discovery", true, &[], + false, &[], &ts.tls, ) @@ -1681,9 +1832,7 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { refresh .material .iter() - .any(|material| material.name == "subject_token" - && !material.required - && material.secret) + .any(|material| material.name == "client_id" && material.required) ); assert!( refresh @@ -1691,6 +1840,12 @@ async fn built_in_okta_obo_profile_is_available_via_provider_profile_api() { .iter() .any(|material| material.name == "audience" && material.required) ); + assert!( + refresh + .material + .iter() + 
.any(|material| material.name == "subject_token" && !material.required) + ); } #[tokio::test] @@ -1760,14 +1915,6 @@ async fn built_in_okta_xaa_profile_is_available_via_provider_profile_api() { && material.required && material.secret) ); - assert!( - refresh - .material - .iter() - .any(|material| material.name == "subject_token" - && !material.required - && material.secret) - ); } #[tokio::test] @@ -1854,6 +2001,7 @@ async fn provider_create_rejects_key_only_credentials_without_local_env_value() "claude", false, &["INVALID_PAIR".to_string()], + false, &[], &ts.tls, ) @@ -1878,6 +2026,7 @@ async fn provider_create_supports_generic_type_and_env_lookup_credentials() { "generic", false, &["NAV_GENERIC_TEST_KEY".to_string()], + false, &[], &ts.tls, ) @@ -1912,6 +2061,7 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { "claude", true, &["API_KEY=abc".to_string()], + false, &[], &ts.tls, ) @@ -1925,6 +2075,56 @@ async fn provider_create_rejects_combined_from_existing_and_credentials() { ); } +#[tokio::test] +async fn provider_create_rejects_combined_from_gcloud_adc_and_from_existing() { + let ts = run_server().await; + + let err = run::provider_create( + &ts.endpoint, + "bad-vertex-provider", + "google-vertex-ai", + true, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("from-gcloud-adc and from-existing should be mutually exclusive"); + + assert!( + err.to_string() + .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), + "unexpected error: {err}" + ); + assert!(ts.state.providers.lock().await.is_empty()); +} + +#[tokio::test] +async fn provider_create_rejects_combined_from_gcloud_adc_and_credentials() { + let ts = run_server().await; + + let err = run::provider_create( + &ts.endpoint, + "bad-vertex-provider", + "google-vertex-ai", + false, + &["GOOGLE_VERTEX_AI_TOKEN=token".to_string()], + true, + &[], + &ts.tls, + ) + .await + .expect_err("from-gcloud-adc and credentials should be mutually 
exclusive"); + + assert!( + err.to_string() + .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), + "unexpected error: {err}" + ); + assert!(ts.state.providers.lock().await.is_empty()); +} + #[tokio::test] async fn provider_create_rejects_empty_env_var_for_key_only_credential() { let ts = run_server().await; @@ -1936,6 +2136,7 @@ async fn provider_create_rejects_empty_env_var_for_key_only_credential() { "generic", false, &["NAV_EMPTY_ENV_KEY".to_string()], + false, &[], &ts.tls, ) @@ -1960,6 +2161,7 @@ async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { "nvidia", false, &["NVIDIA_API_KEY".to_string()], + false, &[], &ts.tls, ) @@ -1983,3 +2185,542 @@ async fn provider_create_supports_nvidia_type_with_nvidia_api_key() { Some(&"nvapi-live-test".to_string()) ); } + +// ── --from-gcloud-adc tests ─────────────────────────────────────────────────── + +#[tokio::test] +async fn provider_create_from_gcloud_adc_happy_path() { + let ts = run_server().await; + + // Write a temp ADC file simulating a valid authorized_user credential. + let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + + // Point GOOGLE_APPLICATION_CREDENTIALS at the temp file so read_gcloud_adc + // picks it up without touching the real ~/.config/gcloud/ path. 
+ let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + run::provider_create( + &ts.endpoint, + "my-vertex", + "google-vertex-ai", + false, + &[], // no explicit credentials; refresh bootstrap covers it + true, // from_gcloud_adc + &[], + &ts.tls, + ) + .await + .expect("provider_create with --from-gcloud-adc should succeed"); + + // Provider must exist in the server state. + let providers = ts.state.providers.lock().await; + let provider = providers + .get("my-vertex") + .expect("provider should be stored after create"); + assert_eq!(provider.r#type, "google-vertex-ai"); + assert_eq!( + provider + .credentials + .get("GOOGLE_VERTEX_AI_TOKEN") + .map(String::as_str), + Some("minted-GOOGLE_VERTEX_AI_TOKEN"), + "initial rotate should materialize a usable access token" + ); + drop(providers); + + // ADC bootstrap must configure refresh and immediately mint the first token. + let requests = ts.state.refresh_requests.lock().await.clone(); + assert_eq!( + requests.len(), + 2, + "expected configure + rotate refresh requests" + ); + assert_eq!( + requests[0], + ProviderRefreshRequestLog::Configure { + provider_name: "my-vertex".to_string(), + credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), + expires_at_ms: None, + } + ); + assert_eq!( + requests[1], + ProviderRefreshRequestLog::Rotate { + provider_name: "my-vertex".to_string(), + credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), + } + ); + + // The refresh status must record the ADC material keys. 
+ let refresh_statuses = ts.state.refresh_statuses.lock().await; + let status = refresh_statuses + .get(&( + "my-vertex".to_string(), + "GOOGLE_VERTEX_AI_TOKEN".to_string(), + )) + .expect("refresh status should be stored"); + assert_eq!( + status.strategy, + ProviderCredentialRefreshStrategy::Oauth2RefreshToken as i32 + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_rejects_service_account() { + let ts = run_server().await; + + // Write a temp ADC file with type=service_account. + let adc_content = serde_json::json!({ + "type": "service_account", + "project_id": "my-project", + "private_key_id": "key-id", + "private_key": "-----BEGIN RSA PRIVATE KEY-----\n...", + "client_email": "sa@my-project.iam.gserviceaccount.com" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "my-vertex-sa", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("service_account ADC should be rejected"); + + assert!( + err.to_string() + .contains("GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), + "error should mention the service-account token key, got: {err}" + ); + + // create_provider must NOT have been called — no provider stored. + let providers = ts.state.providers.lock().await; + assert!( + providers.is_empty(), + "no provider should have been created on pre-flight failure" + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_missing_file() { + let ts = run_server().await; + + // Point to a path that does not exist. 
+ let _guard = EnvVarGuard::set(&[( + "GOOGLE_APPLICATION_CREDENTIALS", + "/tmp/nonexistent-adc-file-openshell-test.json", + )]); + + let err = run::provider_create( + &ts.endpoint, + "my-vertex-missing", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("missing ADC file should produce an error"); + + // Error must mention the file path or the read failure. + let msg = err.to_string(); + assert!( + msg.contains("nonexistent-adc-file-openshell-test.json") + || msg.contains("failed to read gcloud ADC file"), + "error should reference the missing file, got: {msg}" + ); + + // create_provider must NOT have been called — no provider stored. + let providers = ts.state.providers.lock().await; + assert!( + providers.is_empty(), + "no provider should have been created on pre-flight failure" + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_rejects_wrong_provider_type_before_credential_check() { + let ts = run_server().await; + + let err = run::provider_create( + &ts.endpoint, + "my-openai-adc", + "openai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("wrong provider type should fail before generic credential validation"); + + assert!( + err.to_string() + .contains("--from-gcloud-adc is only valid for google-vertex-ai providers"), + "unexpected error: {err}" + ); + assert!(ts.state.providers.lock().await.is_empty()); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_rolls_back_provider_when_refresh_configure_fails() { + let ts = run_server().await; + *ts.state.fail_configure_refresh_message.lock().await = + Some("simulated configure failure".to_string()); + + let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, 
&adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-rollback", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("configure_provider_refresh failure should bubble up"); + + assert!( + err.to_string().contains("simulated configure failure"), + "unexpected error: {err}" + ); + assert!( + !ts.state + .providers + .lock() + .await + .contains_key("vertex-rollback"), + "provider should be deleted on rollback" + ); + assert_eq!( + ts.state.delete_provider_requests.lock().await.clone(), + vec!["vertex-rollback".to_string()] + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_warn_path_keeps_provider_when_rollback_delete_fails() { + let ts = run_server().await; + *ts.state.fail_configure_refresh_message.lock().await = + Some("simulated configure failure".to_string()); + *ts.state.fail_delete_provider_message.lock().await = + Some("simulated delete failure".to_string()); + + let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-cleanup-warning", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("cleanup failure path should still return configure error"); + + assert!( + err.to_string().contains("simulated configure failure"), + "unexpected error: {err}" + ); + assert!( + ts.state + .providers + .lock() + .await + 
.contains_key("vertex-cleanup-warning"), + "provider should remain when rollback deletion fails" + ); + assert_eq!( + ts.state.delete_provider_requests.lock().await.clone(), + vec!["vertex-cleanup-warning".to_string()] + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_rolls_back_provider_when_initial_rotate_fails() { + let ts = run_server().await; + *ts.state.fail_rotate_refresh_message.lock().await = + Some("simulated rotate failure".to_string()); + + let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-rotate-rollback", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("initial rotate failure should roll back the provider"); + + assert!( + err.to_string().contains("simulated rotate failure"), + "unexpected error: {err}" + ); + assert!( + !ts.state + .providers + .lock() + .await + .contains_key("vertex-rotate-rollback"), + "provider should be deleted on initial-rotate rollback" + ); + assert_eq!( + ts.state.delete_provider_requests.lock().await.clone(), + vec!["vertex-rotate-rollback".to_string()] + ); +} + +#[tokio::test] +async fn provider_create_from_existing_vertex_config_only_reports_missing_vertex_credentials() { + let ts = run_server().await; + enable_providers_v2(&ts).await; + let _env = EnvVarGuard::set(&[ + ("VERTEX_AI_PROJECT_ID", "vertex-config-only-project"), + ("VERTEX_AI_REGION", "us-central1"), + ]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-config-only", + "google-vertex-ai", + true, + 
&[], + false, + &[], + &ts.tls, + ) + .await + .expect_err("config-only discovery should surface missing credential guidance"); + + let msg = err.to_string(); + assert!( + msg.contains("GOOGLE_VERTEX_AI_TOKEN") && msg.contains("VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), + "unexpected error: {msg}" + ); + assert!( + !ts.state + .providers + .lock() + .await + .contains_key("vertex-config-only") + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_with_config_keys() { + let ts = run_server().await; + + // Write a valid authorized_user ADC file. + let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + run::provider_create( + &ts.endpoint, + "vertex-with-config", + "google-vertex-ai", + false, + &[], // no explicit credentials; ADC flow + true, // from_gcloud_adc + &[ + "VERTEX_AI_PROJECT_ID=my-gcp-project".to_string(), + "VERTEX_AI_REGION=us-east1".to_string(), + ], + &ts.tls, + ) + .await + .expect("provider_create with --from-gcloud-adc and --config keys should succeed"); + + // Verify provider was created with the config keys. 
+ let providers = ts.state.providers.lock().await; + let provider = providers + .get("vertex-with-config") + .expect("provider should be stored after create"); + assert_eq!(provider.r#type, "google-vertex-ai"); + assert_eq!( + provider + .config + .get("VERTEX_AI_PROJECT_ID") + .map(String::as_str), + Some("my-gcp-project"), + "VERTEX_AI_PROJECT_ID must be stored in provider config" + ); + assert_eq!( + provider.config.get("VERTEX_AI_REGION").map(String::as_str), + Some("us-east1"), + "VERTEX_AI_REGION must be stored in provider config" + ); + drop(providers); + + // ADC flow should configure refresh and eagerly mint the initial token. + let refresh_requests = ts.state.refresh_requests.lock().await.clone(); + assert_eq!( + refresh_requests.len(), + 2, + "exactly one configure call and one rotate call expected" + ); + assert_eq!( + refresh_requests[0], + ProviderRefreshRequestLog::Configure { + provider_name: "vertex-with-config".to_string(), + credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), + expires_at_ms: None, + } + ); + assert_eq!( + refresh_requests[1], + ProviderRefreshRequestLog::Rotate { + provider_name: "vertex-with-config".to_string(), + credential_key: "GOOGLE_VERTEX_AI_TOKEN".to_string(), + } + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_missing_refresh_token() { + let ts = run_server().await; + + // ADC file is valid authorized_user type but missing refresh_token. 
+ let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "client_secret": "test-client-secret" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-missing-refresh", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("missing refresh_token should produce an error"); + + let err_msg = err.to_string(); + assert!( + err_msg.contains("refresh_token"), + "error must mention 'refresh_token', got: {err_msg}" + ); + + // No provider should have been created. + let providers = ts.state.providers.lock().await; + assert!( + providers.is_empty(), + "no provider must be created when ADC validation fails" + ); +} + +#[tokio::test] +async fn provider_create_from_gcloud_adc_missing_client_secret() { + let ts = run_server().await; + + // ADC file is valid authorized_user type but missing client_secret. 
+ let adc_content = serde_json::json!({ + "type": "authorized_user", + "client_id": "test-client-id.apps.googleusercontent.com", + "refresh_token": "1//test-refresh-token" + }); + let adc_file = tempfile::NamedTempFile::new().unwrap(); + serde_json::to_writer(&adc_file, &adc_content).unwrap(); + let adc_path = adc_file.path().to_str().unwrap().to_string(); + let _guard = EnvVarGuard::set(&[("GOOGLE_APPLICATION_CREDENTIALS", &adc_path)]); + + let err = run::provider_create( + &ts.endpoint, + "vertex-missing-secret", + "google-vertex-ai", + false, + &[], + true, + &[], + &ts.tls, + ) + .await + .expect_err("missing client_secret should produce an error"); + + let err_msg = err.to_string(); + assert!( + err_msg.contains("client_secret"), + "error must mention 'client_secret', got: {err_msg}" + ); + + // No provider should have been created. + let providers = ts.state.providers.lock().await; + assert!( + providers.is_empty(), + "no provider must be created when ADC validation fails" + ); +} diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index 78533e1e0..af26f73ae 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -6,7 +6,7 @@ //! These traits provide uniform access to `ObjectMeta` fields across all resource types. 
use crate::proto::{ - InferenceRoute, ObjectForTest, Provider, Sandbox, ServiceEndpoint, SshSession, + InferenceRoute, ObjectForTest, Provider, Sandbox, SandboxStatus, ServiceEndpoint, SshSession, StoredProviderCredentialRefreshState, StoredProviderProfile, }; use std::collections::HashMap; @@ -69,6 +69,26 @@ impl GetResourceVersion for Sandbox { } } +impl Sandbox { + pub fn phase(&self) -> i32 { + self.status.as_ref().map_or(0, |s| s.phase) + } + + pub fn set_phase(&mut self, phase: i32) { + self.status.get_or_insert_with(SandboxStatus::default).phase = phase; + } + + pub fn current_policy_version(&self) -> u32 { + self.status.as_ref().map_or(0, |s| s.current_policy_version) + } + + pub fn set_current_policy_version(&mut self, version: u32) { + self.status + .get_or_insert_with(SandboxStatus::default) + .current_policy_version = version; + } +} + // Implementations for Provider impl ObjectId for Provider { fn object_id(&self) -> &str { diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 5f04bd405..912bad4eb 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -19,6 +19,7 @@ use std::sync::OnceLock; const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/claude-code.yaml"), include_str!("../../../providers/github.yaml"), + include_str!("../../../providers/google-vertex-ai.yaml"), include_str!("../../../providers/nvidia.yaml"), include_str!("../../../providers/okta-obo.yaml"), include_str!("../../../providers/okta-xaa.yaml"), @@ -309,6 +310,25 @@ impl ProviderTypeProfile { vars } + /// Whether this profile can be created without an initial access token because + /// the gateway can mint at least one credential immediately from refresh + /// material, and no required credential falls outside that gateway-mintable set. 
+ #[must_use] + pub fn allows_gateway_refresh_bootstrap(&self) -> bool { + let mut has_gateway_mintable_credential = false; + for credential in &self.credentials { + let is_gateway_mintable = credential + .refresh + .as_ref() + .is_some_and(CredentialRefreshProfile::is_gateway_mintable); + if credential.required && !is_gateway_mintable { + return false; + } + has_gateway_mintable_credential |= is_gateway_mintable; + } + has_gateway_mintable_credential + } + #[must_use] pub fn to_proto(&self) -> ProviderProfile { ProviderProfile { @@ -348,6 +368,20 @@ impl ProviderTypeProfile { } } +impl CredentialRefreshProfile { + #[must_use] + pub fn is_gateway_mintable(&self) -> bool { + matches!( + self.strategy, + ProviderCredentialRefreshStrategy::Oauth2RefreshToken + | ProviderCredentialRefreshStrategy::Oauth2ClientCredentials + | ProviderCredentialRefreshStrategy::Oauth2TokenExchange + | ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt + | ProviderCredentialRefreshStrategy::OktaXaa + ) + } +} + fn discovery_is_empty(discovery: &DiscoveryProfile) -> bool { discovery.credentials.is_empty() } @@ -602,6 +636,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { .collect(), graphql_max_body_bytes: endpoint.graphql_max_body_bytes, path: endpoint.path.clone(), + advisor_proposed: false, } } diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index fca0c388f..0122f9178 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -15,6 +15,7 @@ use crate::sandbox_watch::SandboxWatchBus; use crate::supervisor_session::SupervisorSessionRegistry; use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; +use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, 
@@ -94,11 +95,6 @@ const RECONCILE_INTERVAL: Duration = Duration::from_secs(60); /// corresponding backend resource before it is considered orphaned. const ORPHAN_GRACE_PERIOD: Duration = Duration::from_secs(300); -/// Sandbox store updates can race during startup when the driver watch stream -/// and supervisor session callbacks both try to flip a sandbox to Ready. -/// Retry a few times instead of leaving the record stuck in Provisioning. -const SANDBOX_CAS_RETRIES: usize = 4; - // Re-export the shared error type under the name used by this module. pub use openshell_core::ComputeDriverError as ComputeError; @@ -224,6 +220,7 @@ impl ComputeDriver for RemoteComputeDriver { #[derive(Clone)] pub struct ComputeRuntime { driver: SharedComputeDriver, + driver_kind: Option, shutdown_cleanup: Option>, startup_resume: Option>, _driver_process: Option>, @@ -246,6 +243,7 @@ impl fmt::Debug for ComputeRuntime { impl ComputeRuntime { #[allow(clippy::too_many_arguments)] async fn from_driver( + driver_kind: ComputeDriverKind, driver: SharedComputeDriver, shutdown_cleanup: Option>, startup_resume: Option>, @@ -266,6 +264,7 @@ impl ComputeRuntime { .default_image; Ok(Self { driver, + driver_kind: Some(driver_kind), shutdown_cleanup, startup_resume, _driver_process: driver_process, @@ -309,6 +308,7 @@ impl ComputeRuntime { let startup_resume: Arc = driver.clone(); let driver: SharedComputeDriver = driver; Self::from_driver( + ComputeDriverKind::Docker, driver, Some(shutdown_cleanup), Some(startup_resume), @@ -337,6 +337,7 @@ impl ComputeRuntime { .map_err(|err| ComputeError::Message(err.to_string()))?; let driver: SharedComputeDriver = Arc::new(ComputeDriverService::new(driver)); Self::from_driver( + ComputeDriverKind::Kubernetes, driver, None, None, @@ -363,6 +364,7 @@ impl ComputeRuntime { ) -> Result { let driver: SharedComputeDriver = Arc::new(RemoteComputeDriver::new(channel)); Self::from_driver( + ComputeDriverKind::Vm, driver, None, None, @@ -391,6 +393,7 @@ impl 
ComputeRuntime { .map_err(|err| ComputeError::Message(err.to_string()))?; let driver: SharedComputeDriver = Arc::new(PodmanDriverService::new(driver)); Self::from_driver( + ComputeDriverKind::Podman, driver, None, None, @@ -411,6 +414,11 @@ impl ComputeRuntime { &self.default_image } + #[must_use] + pub fn driver_kind(&self) -> Option { + self.driver_kind + } + #[must_use] pub fn gateway_bind_addresses(&self) -> &[SocketAddr] { &self.gateway_bind_addresses @@ -512,6 +520,8 @@ impl ComputeRuntime { } pub async fn delete_sandbox(&self, name: &str) -> Result { + let _guard = self.sync_lock.lock().await; + // Resolve sandbox ID from name let sandbox = self .store @@ -530,7 +540,7 @@ impl ComputeRuntime { let sandbox = self .store .update_message_cas::(&id, 0, |s| { - s.phase = SandboxPhase::Deleting as i32; + s.set_phase(SandboxPhase::Deleting as i32); }) .await .map_err(|e| { @@ -613,7 +623,7 @@ impl ComputeRuntime { } }; - let phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + let phase = SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); if !sandbox_phase_should_be_running(phase) { continue; } @@ -687,7 +697,7 @@ impl ComputeRuntime { match self .store .update_message_cas::(&sandbox_id, 0, |s| { - s.phase = SandboxPhase::Error as i32; + s.set_phase(SandboxPhase::Error as i32); let name = s.object_name().to_string(); upsert_ready_condition( &mut s.status, @@ -865,21 +875,25 @@ impl ComputeRuntime { use crate::persistence::WriteCondition; let now_ms = openshell_core::time::now_ms(); - let mut status = incoming.status.as_ref().map(public_status_from_driver); - rewrite_user_facing_conditions(&mut status, None); - let session_connected = self.supervisor_sessions.has_session(&incoming.id); let mut phase = derive_phase(incoming.status.as_ref()); - let sandbox_name = incoming.name.clone(); - if session_connected - && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) - { - 
ensure_supervisor_ready_status(&mut status, &sandbox_name); + + let supervisor_promoted = session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown); + if supervisor_promoted { phase = SandboxPhase::Ready; } - let sandbox = Sandbox { + let mut status = incoming + .status + .as_ref() + .map(|s| public_status_from_driver(s, phase, 0)); + rewrite_user_facing_conditions(&mut status, None); + if supervisor_promoted { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + } + let mut sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: incoming.id.clone(), name: sandbox_name, @@ -889,9 +903,8 @@ impl ComputeRuntime { }), spec: None, status, - phase: phase as i32, - current_policy_version: 0, }; + sandbox.set_phase(phase as i32); self.store .put_if( @@ -919,82 +932,90 @@ impl ComputeRuntime { return Ok(()); } + // Single-attempt CAS: on conflict, the next watch event will naturally retry let session_connected = self.supervisor_sessions.has_session(&incoming.id); let sandbox_name = incoming.name.clone(); - let mut attempts = 0usize; - let sandbox = loop { - match self - .store - .update_message_cas::(&incoming.id, 0, |sandbox| { - let mut status = incoming.status.as_ref().map(public_status_from_driver); - rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); - - let mut phase = derive_phase(incoming.status.as_ref()); - if session_connected - && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) - { - ensure_supervisor_ready_status(&mut status, &sandbox_name); - phase = SandboxPhase::Ready; - } - let old_phase = - SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase != phase { - info!( - sandbox_id = %incoming.id, - sandbox_name = %sandbox_name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); - } + let sandbox = self + .store + .update_message_cas::(&incoming.id, 0, |sandbox| { + let old_phase = + 
SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); + let mut phase = incoming + .status + .as_ref() + .map_or(old_phase, |status| derive_phase(Some(status))); + let supervisor_promoted = session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown); + if supervisor_promoted { + phase = SandboxPhase::Ready; + } + + let cpv = sandbox.current_policy_version(); + let mut status = incoming + .status + .as_ref() + .map(|s| public_status_from_driver(s, phase, cpv)) + .or_else(|| sandbox.status.clone()); + rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); + if supervisor_promoted { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + } + + if let Some(s) = status.as_mut() + && s.sandbox_name.is_empty() + { + s.sandbox_name.clone_from(&sandbox_name); + } + + if old_phase != phase { + info!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" + ); + } - if phase == SandboxPhase::Error - && let Some(ref status) = status - { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_reason(&condition.reason) - { - warn!( - sandbox_id = %incoming.id, - sandbox_name = %sandbox_name, - reason = %condition.reason, - message = %condition.message, - "Sandbox failed to become ready" - ); - } + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_reason(&condition.reason) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); } } + } - if let Some(metadata) = sandbox.metadata.as_mut() { - metadata.name.clone_from(&sandbox_name); - } - sandbox.status = 
status; - sandbox.phase = phase as i32; - }) - .await - { - Ok(sandbox) => break sandbox, - Err(crate::persistence::PersistenceError::Conflict { - current_resource_version: _, - }) if attempts + 1 < SANDBOX_CAS_RETRIES => { - attempts += 1; - continue; + // Update metadata fields + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.name.clone_from(&sandbox_name); } - Err(crate::persistence::PersistenceError::Conflict { + sandbox.status = status; + sandbox.set_phase(phase as i32); + sandbox.set_current_policy_version(cpv); + }) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { current_resource_version, - }) => { - return Err(format!( - "concurrent modification detected during sandbox reconciliation (current resource_version: {})", - current_resource_version - .map_or_else(|| "unknown".to_string(), |v| v.to_string()) - )); - } - Err(other) => return Err(other.to_string()), - } - }; + } => format!( + "concurrent modification detected during sandbox reconciliation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; self.sandbox_index.update_from_sandbox(&sandbox); self.sandbox_watch_bus.notify(sandbox.object_id()); @@ -1016,48 +1037,28 @@ impl ComputeRuntime { ) -> Result<(), String> { let _guard = self.sync_lock.lock().await; - let mut attempts = 0usize; - let result = loop { - match self - .store - .update_message_cas::(sandbox_id, 0, |sandbox| { - let current_phase = - SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - - if current_phase == SandboxPhase::Deleting - || current_phase == SandboxPhase::Error - { - return; - } + // Use CAS to update sandbox phase based on supervisor session state + let result = self + .store + .update_message_cas::(sandbox_id, 0, |sandbox| { + let current_phase = + SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); - let sandbox_name = 
sandbox.object_name().to_string(); - if connected { - ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Ready as i32; - } else if current_phase == SandboxPhase::Ready { - ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Provisioning as i32; - } - }) - .await - { - Ok(sandbox) => break Ok(sandbox), - Err(crate::persistence::PersistenceError::Conflict { - current_resource_version: _, - }) if attempts + 1 < SANDBOX_CAS_RETRIES => { - attempts += 1; - continue; + // Skip if sandbox is in terminal state + if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { + return; } - Err(crate::persistence::PersistenceError::Conflict { - current_resource_version, - }) => { - break Err(crate::persistence::PersistenceError::Conflict { - current_resource_version, - }); + + let sandbox_name = sandbox.object_name().to_string(); + if connected { + ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.set_phase(SandboxPhase::Ready as i32); + } else if current_phase == SandboxPhase::Ready { + ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.set_phase(SandboxPhase::Provisioning as i32); } - Err(other) => break Err(other), - } - }; + }) + .await; // Handle not found gracefully (sandbox may have been deleted) let sandbox = match result { @@ -1254,10 +1255,7 @@ fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { name: sandbox.object_name().to_string(), namespace: String::new(), // Namespace is set by the driver based on its config spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), - status: sandbox - .status - .as_ref() - .map(|status| driver_status_from_public(status, sandbox.phase)), + status: sandbox.status.as_ref().map(driver_status_from_public), } } @@ -1464,7 +1462,7 @@ fn build_platform_resources_config( } } -fn driver_status_from_public(status: &SandboxStatus, 
phase: i32) -> DriverSandboxStatus { +fn driver_status_from_public(status: &SandboxStatus) -> DriverSandboxStatus { DriverSandboxStatus { sandbox_name: status.sandbox_name.clone(), instance_id: status.agent_pod.clone(), @@ -1475,7 +1473,7 @@ fn driver_status_from_public(status: &SandboxStatus, phase: i32) -> DriverSandbo .iter() .map(driver_condition_from_public) .collect(), - deleting: SandboxPhase::try_from(phase) == Ok(SandboxPhase::Deleting), + deleting: SandboxPhase::try_from(status.phase) == Ok(SandboxPhase::Deleting), } } @@ -1507,7 +1505,11 @@ fn decode_sandbox_record(record: &ObjectRecord) -> Result { Sandbox::decode(record.payload.as_slice()).map_err(|e| e.to_string()) } -fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { +fn public_status_from_driver( + status: &DriverSandboxStatus, + phase: SandboxPhase, + current_policy_version: u32, +) -> SandboxStatus { SandboxStatus { sandbox_name: status.sandbox_name.clone(), agent_pod: status.instance_id.clone(), @@ -1518,6 +1520,8 @@ fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { .iter() .map(public_condition_from_driver) .collect(), + phase: phase as i32, + current_policy_version, } } @@ -1556,10 +1560,7 @@ fn upsert_ready_condition( ) { let status = status.get_or_insert_with(|| SandboxStatus { sandbox_name: sandbox_name.to_string(), - agent_pod: String::new(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: Vec::new(), + ..Default::default() }); if let Some(existing) = status @@ -1683,8 +1684,6 @@ impl ComputeDriver for NoopTestDriver { driver_name: "noop-test-driver".to_string(), driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), - supports_gpu: false, - gpu_count: 0, }, )) } @@ -1763,6 +1762,7 @@ impl ComputeDriver for NoopTestDriver { pub async fn new_test_runtime(store: Arc) -> ComputeRuntime { ComputeRuntime { driver: Arc::new(NoopTestDriver), + driver_kind: None, shutdown_cleanup: None, 
startup_resume: None, _driver_process: None, @@ -1832,8 +1832,6 @@ mod tests { driver_name: "test-driver".to_string(), driver_version: "test".to_string(), default_image: "openshell/sandbox:test".to_string(), - supports_gpu: true, - gpu_count: 0, })) } @@ -1930,6 +1928,7 @@ mod tests { let store = Arc::new(Store::connect("sqlite::memory:").await.unwrap()); ComputeRuntime { driver, + driver_kind: None, shutdown_cleanup: None, startup_resume, _driver_process: None, @@ -1956,7 +1955,7 @@ mod tests { } fn sandbox_record(id: &str, name: &str, phase: SandboxPhase) -> Sandbox { - Sandbox { + let mut sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: id.to_string(), name: name.to_string(), @@ -1964,9 +1963,10 @@ mod tests { labels: HashMap::new(), resource_version: 0, }), - phase: phase as i32, ..Default::default() - } + }; + sandbox.set_phase(phase as i32); + sandbox } fn ssh_session_record(id: &str, sandbox_id: &str) -> SshSession { @@ -2245,8 +2245,6 @@ mod tests { let mut status = Some(SandboxStatus { sandbox_name: "test".to_string(), agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "False".to_string(), @@ -2254,6 +2252,7 @@ mod tests { message: "0/1 nodes are available: 1 Insufficient nvidia.com/gpu.".to_string(), last_transition_time: String::new(), }], + ..Default::default() }); rewrite_user_facing_conditions( @@ -2277,8 +2276,6 @@ mod tests { let mut status = Some(SandboxStatus { sandbox_name: "test".to_string(), agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "False".to_string(), @@ -2286,6 +2283,7 @@ mod tests { message: original.to_string(), last_transition_time: String::new(), }], + ..Default::default() }); rewrite_user_facing_conditions( @@ -2349,9 +2347,65 @@ mod tests { .unwrap() .unwrap(); 
assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), + SandboxPhase::Ready + ); + } + + #[tokio::test] + async fn apply_sandbox_update_without_status_preserves_existing_status() { + let runtime = test_runtime(Arc::new(TestDriver::default())).await; + let mut sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready); + sandbox.status = Some(SandboxStatus { + sandbox_name: "sandbox-a".to_string(), + conditions: vec![SandboxCondition { + r#type: "Ready".to_string(), + status: "True".to_string(), + reason: "DependenciesReady".to_string(), + message: "Pod is Ready".to_string(), + last_transition_time: String::new(), + }], + current_policy_version: 7, + ..Default::default() + }); + sandbox.set_phase(SandboxPhase::Ready as i32); + runtime.store.put_message(&sandbox).await.unwrap(); + + runtime + .apply_sandbox_update(DriverSandbox { + id: "sb-1".to_string(), + name: "sandbox-a".to_string(), + namespace: "default".to_string(), + spec: None, + status: None, + }) + .await + .unwrap(); + + let stored = runtime + .store + .get_message::("sb-1") + .await + .unwrap() + .unwrap(); + assert_eq!( + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); + assert_eq!(stored.current_policy_version(), 7); + let ready = stored + .status + .as_ref() + .and_then(|status| { + status + .conditions + .iter() + .find(|condition| condition.r#type == "Ready") + }) + .unwrap(); + assert_eq!(ready.status, "True"); + assert_eq!(ready.reason, "DependenciesReady"); + assert_eq!(ready.message, "Pod is Ready"); } #[tokio::test] @@ -2383,7 +2437,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); let ready = stored @@ -2416,7 +2470,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready 
); } @@ -2427,9 +2481,6 @@ mod tests { let mut sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready); sandbox.status = Some(SandboxStatus { sandbox_name: "sandbox-a".to_string(), - agent_pod: String::new(), - agent_fd: String::new(), - sandbox_fd: String::new(), conditions: vec![SandboxCondition { r#type: "Ready".to_string(), status: "True".to_string(), @@ -2437,7 +2488,9 @@ mod tests { message: "Supervisor session connected".to_string(), last_transition_time: String::new(), }], + ..Default::default() }); + sandbox.set_phase(SandboxPhase::Ready as i32); runtime.store.put_message(&sandbox).await.unwrap(); runtime @@ -2452,7 +2505,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Provisioning ); let ready = stored @@ -2538,7 +2591,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); @@ -2631,7 +2684,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); } @@ -2794,7 +2847,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Error ); let ready = stored @@ -2826,7 +2879,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Error ); let ready = stored @@ -2853,7 +2906,7 @@ mod tests { .unwrap() .unwrap(); assert_eq!( - SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); } diff --git a/crates/openshell-server/src/grpc/provider.rs 
b/crates/openshell-server/src/grpc/provider.rs index 4e74a7df1..c5a4ab20d 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -12,6 +12,9 @@ use crate::persistence::{ ObjectId, ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, }; use openshell_core::proto::{Provider, Sandbox}; +use openshell_core::telemetry::{ + LifecycleOperation, ProviderProfile as TelemetryProviderProfile, TelemetryOutcome, +}; use prost::Message; use tonic::Status; use tracing::warn; @@ -443,6 +446,14 @@ pub(super) async fn resolve_provider_environment( .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; for (key, value) in &provider.credentials { + if is_non_injectable_provider_credential(&provider, key) { + warn!( + provider_name = %name, + key = %key, + "skipping non-injectable provider credential" + ); + continue; + } if is_valid_env_key(key) { let expires_at_ms = provider .credential_expires_at_ms @@ -470,6 +481,53 @@ pub(super) async fn resolve_provider_environment( ); } } + + // For Vertex AI providers, inject agent-specific config env vars so that + // Claude Code, Goose, and OpenCode inside the sandbox can reach Vertex AI + // without additional configuration. Credentials from the loop above take + // precedence via entry().or_insert(), and sandbox --env overrides are + // applied at the process level after this environment is installed, so + // they naturally shadow these values. + if openshell_core::inference::normalize_inference_provider_type(&provider.r#type) + == Some("google-vertex-ai") + { + let project_id = provider + .config + .get(openshell_core::inference::VERTEX_AI_PROJECT_ID_KEY) + .map(String::as_str) + .unwrap_or_default() + .trim(); + let region = provider + .config + .get(openshell_core::inference::VERTEX_AI_REGION_KEY) + .map(String::as_str) + .unwrap_or_default() + .trim(); + + // Static flags -- always present for Vertex AI providers. 
+ env.entry("GOOSE_PROVIDER".to_string()) + .or_insert_with(|| "gcp_vertex_ai".to_string()); + + // Project ID derived vars. + if !project_id.is_empty() { + env.entry("ANTHROPIC_VERTEX_PROJECT_ID".to_string()) + .or_insert_with(|| project_id.to_string()); + env.entry("GCP_PROJECT_ID".to_string()) + .or_insert_with(|| project_id.to_string()); + env.entry("GOOGLE_CLOUD_PROJECT".to_string()) + .or_insert_with(|| project_id.to_string()); + } + + // Region derived vars. + if !region.is_empty() { + env.entry("CLOUD_ML_REGION".to_string()) + .or_insert_with(|| region.to_string()); + env.entry("GCP_LOCATION".to_string()) + .or_insert_with(|| region.to_string()); + env.entry("VERTEX_LOCATION".to_string()) + .or_insert_with(|| region.to_string()); + } + } } Ok(ProviderEnvironment { @@ -590,6 +648,7 @@ fn active_provider_credential_keys(provider: &Provider, now_ms: i64) -> Vec Vec bool { + openshell_core::inference::normalize_inference_provider_type(&provider.r#type) + == Some("google-vertex-ai") + && key == "GOOGLE_SERVICE_ACCOUNT_KEY" +} + pub(super) fn is_valid_env_key(key: &str) -> bool { let mut bytes = key.bytes(); let Some(first) = bytes.next() else { @@ -642,7 +707,7 @@ use openshell_core::proto::{ }; use openshell_providers::{ CredentialRefreshProfile, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, - get_default_profile, normalize_profile_id, validate_profile_set, + get_default_profile, normalize_profile_id, normalize_provider_type, validate_profile_set, }; use std::sync::Arc; use tonic::{Request, Response}; @@ -652,14 +717,36 @@ pub(super) async fn handle_create_provider( request: Request, ) -> Result, Status> { let req = request.into_inner(); - let provider = req - .provider - .ok_or_else(|| Status::invalid_argument("provider is required"))?; - let provider = create_provider_record(state.store.as_ref(), provider).await?; - - Ok(Response::new(ProviderResponse { - provider: Some(provider), - })) + let Some(provider) = req.provider else { + 
emit_provider_lifecycle( + "custom", + LifecycleOperation::Create, + TelemetryOutcome::Failure, + ); + return Err(Status::invalid_argument("provider is required")); + }; + let provider_type = provider.r#type.clone(); + let result = create_provider_record(state.store.as_ref(), provider).await; + match result { + Ok(provider) => { + emit_provider_lifecycle( + &provider.r#type, + LifecycleOperation::Create, + TelemetryOutcome::Success, + ); + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) + } + Err(err) => { + emit_provider_lifecycle( + &provider_type, + LifecycleOperation::Create, + TelemetryOutcome::Failure, + ); + Err(err) + } + } } pub(super) async fn handle_get_provider( @@ -896,17 +983,7 @@ async fn provider_type_allows_empty_credentials_for_refresh( let Some(profile) = get_provider_type_profile(store, provider_type).await? else { return Ok(false); }; - let required_credentials = profile - .credentials - .iter() - .filter(|credential| credential.required) - .collect::>(); - Ok(!required_credentials.is_empty() - && required_credentials.iter().all(|credential| { - credential.refresh.as_ref().is_some_and(|refresh| { - crate::provider_refresh::is_gateway_mintable_strategy(refresh.strategy) - }) - })) + Ok(profile.allows_gateway_refresh_bootstrap()) } async fn merged_provider_profiles(store: &Store) -> Result, Status> { @@ -1085,17 +1162,39 @@ pub(super) async fn handle_update_provider( request: Request, ) -> Result, Status> { let req = request.into_inner(); - let mut provider = req - .provider - .ok_or_else(|| Status::invalid_argument("provider is required"))?; + let Some(mut provider) = req.provider else { + emit_provider_lifecycle( + "custom", + LifecycleOperation::Update, + TelemetryOutcome::Failure, + ); + return Err(Status::invalid_argument("provider is required")); + }; + let provider_type = provider.r#type.clone(); provider .credential_expires_at_ms .extend(req.credential_expires_at_ms); - let provider = 
update_provider_record(state.store.as_ref(), provider).await?; - - Ok(Response::new(ProviderResponse { - provider: Some(provider), - })) + let result = update_provider_record(state.store.as_ref(), provider).await; + match result { + Ok(provider) => { + emit_provider_lifecycle( + &provider.r#type, + LifecycleOperation::Update, + TelemetryOutcome::Success, + ); + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) + } + Err(err) => { + emit_provider_lifecycle( + &provider_type, + LifecycleOperation::Update, + TelemetryOutcome::Failure, + ); + Err(err) + } + } } pub(super) async fn handle_get_provider_refresh_status( @@ -1522,9 +1621,69 @@ pub(super) async fn handle_delete_provider( request: Request, ) -> Result, Status> { let name = request.into_inner().name; - let deleted = delete_provider_record(state.store.as_ref(), &name).await?; + let provider_profile = provider_profile_for_name(state.store.as_ref(), &name).await; + let result = delete_provider_record(state.store.as_ref(), &name).await; + match result { + Ok(deleted) => { + let outcome = TelemetryOutcome::from_success(deleted); + emit_provider_profile_lifecycle( + provider_profile.unwrap_or(TelemetryProviderProfile::Custom), + LifecycleOperation::Delete, + outcome, + ); + Ok(Response::new(DeleteProviderResponse { deleted })) + } + Err(err) => { + emit_provider_profile_lifecycle( + provider_profile.unwrap_or(TelemetryProviderProfile::Custom), + LifecycleOperation::Delete, + TelemetryOutcome::Failure, + ); + Err(err) + } + } +} + +fn emit_provider_lifecycle( + provider_type: &str, + operation: LifecycleOperation, + outcome: TelemetryOutcome, +) { + let provider_profile = telemetry_provider_profile(provider_type); + emit_provider_profile_lifecycle(provider_profile, operation, outcome); +} + +fn emit_provider_profile_lifecycle( + provider_profile: TelemetryProviderProfile, + operation: LifecycleOperation, + outcome: TelemetryOutcome, +) { + 
openshell_core::telemetry::emit_provider_lifecycle(operation, outcome, provider_profile); +} + +async fn provider_profile_for_name(store: &Store, name: &str) -> Option { + store + .get_message_by_name::(name) + .await + .ok() + .flatten() + .map(|provider| telemetry_provider_profile(&provider.r#type)) +} - Ok(Response::new(DeleteProviderResponse { deleted })) +fn telemetry_provider_profile(provider_type: &str) -> TelemetryProviderProfile { + match normalize_provider_type(provider_type) { + Some("anthropic") => TelemetryProviderProfile::Anthropic, + Some("claude" | "claude-code") => TelemetryProviderProfile::Claude, + Some("codex") => TelemetryProviderProfile::Codex, + Some("copilot") => TelemetryProviderProfile::Copilot, + Some("github") => TelemetryProviderProfile::Github, + Some("gitlab") => TelemetryProviderProfile::Gitlab, + Some("nvidia") => TelemetryProviderProfile::Nvidia, + Some("openai") => TelemetryProviderProfile::Openai, + Some("opencode") => TelemetryProviderProfile::Opencode, + Some("outlook") => TelemetryProviderProfile::Outlook, + _ => TelemetryProviderProfile::Custom, + } } // --------------------------------------------------------------------------- @@ -1570,6 +1729,46 @@ mod tests { assert!(!is_valid_env_key("X;rm -rf /")); } + #[test] + fn telemetry_provider_profile_maps_unknown_to_custom() { + assert_eq!( + telemetry_provider_profile("CLAUDE"), + TelemetryProviderProfile::Claude + ); + assert_eq!( + telemetry_provider_profile("github"), + TelemetryProviderProfile::Github + ); + assert_eq!( + telemetry_provider_profile("gh"), + TelemetryProviderProfile::Github + ); + assert_eq!( + telemetry_provider_profile("glab"), + TelemetryProviderProfile::Gitlab + ); + assert_eq!( + telemetry_provider_profile("outlook"), + TelemetryProviderProfile::Outlook + ); + assert_eq!( + telemetry_provider_profile("generic"), + TelemetryProviderProfile::Custom + ); + assert_eq!( + telemetry_provider_profile("unknown-private"), + TelemetryProviderProfile::Custom + ); 
+ assert_eq!( + telemetry_provider_profile("acme-internal"), + TelemetryProviderProfile::Custom + ); + assert_eq!( + telemetry_provider_profile("corp-llm-prod"), + TelemetryProviderProfile::Custom + ); + } + fn provider_with_values(name: &str, provider_type: &str) -> Provider { Provider { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { @@ -1713,6 +1912,7 @@ mod tests { vec![ "claude-code", "github", + "google-vertex-ai", "nvidia", "okta-obo", "okta-xaa", @@ -1839,13 +2039,15 @@ mod tests { #[tokio::test] async fn import_provider_profile_allows_legacy_provider_type_ids_without_built_in_profiles() { + // Use an ID that is not a built-in profile to test legacy import. + // "custom-llm" is not registered as a built-in and never will be. let state = test_server_state().await; let response = handle_import_provider_profiles( &state, Request::new(ImportProviderProfilesRequest { profiles: vec![ProviderProfileImportItem { - profile: Some(custom_profile("codex")), - source: "codex.yaml".to_string(), + profile: Some(custom_profile("custom-llm")), + source: "custom-llm.yaml".to_string(), }], }), ) @@ -1859,15 +2061,15 @@ mod tests { let imported = handle_get_provider_profile( &state, Request::new(GetProviderProfileRequest { - id: "codex".to_string(), + id: "custom-llm".to_string(), }), ) .await .unwrap() .into_inner() .profile - .expect("codex profile should be returned"); - assert_eq!(imported.id, "codex"); + .expect("custom-llm profile should be returned"); + assert_eq!(imported.id, "custom-llm"); } #[tokio::test] @@ -2285,6 +2487,68 @@ mod tests { ); } + #[tokio::test] + async fn configure_provider_refresh_accepts_vertex_service_account_token_key() { + let state = test_server_state().await; + create_provider_record( + state.store.as_ref(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-sa".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: 
"google-vertex-ai".to_string(), + credentials: std::iter::once(( + "GOOGLE_SERVICE_ACCOUNT_KEY".to_string(), + "{\"type\":\"service_account\"}".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let response = handle_configure_provider_refresh( + &state, + Request::new(ConfigureProviderRefreshRequest { + provider: "vertex-sa".to_string(), + credential_key: "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string(), + strategy: ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32, + material: HashMap::from([ + ( + "client_email".to_string(), + "sa@test-project.iam.gserviceaccount.com".to_string(), + ), + ( + "private_key".to_string(), + "-----BEGIN PRIVATE KEY-----\nkey\n-----END PRIVATE KEY-----".to_string(), + ), + ]), + secret_material_keys: vec!["private_key".to_string()], + expires_at_ms: None, + }), + ) + .await + .unwrap() + .into_inner() + .status + .expect("status"); + + assert_eq!( + response.credential_key, + "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN" + ); + assert_eq!( + response.strategy, + ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt as i32 + ); + } + #[tokio::test] async fn delete_provider_refresh_preserves_manually_updated_expiry() { let state = test_server_state().await; @@ -3154,6 +3418,26 @@ mod tests { .unwrap(); assert!(optional_static_empty.credentials.is_empty()); + let vertex_empty = create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-no-token-yet".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-vertex-ai".to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + assert!(vertex_empty.credentials.is_empty()); + let get_err = get_provider_record(store, "").await.unwrap_err(); 
assert_eq!(get_err.code(), Code::InvalidArgument); @@ -3604,6 +3888,257 @@ mod tests { assert!(err.message().contains("provider-b")); } + #[tokio::test] + async fn resolve_provider_env_injects_vertex_agent_config() { + let store = test_store().await; + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-local".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-vertex-ai".to_string(), + credentials: std::iter::once(( + "GOOGLE_VERTEX_AI_TOKEN".to_string(), + "ya29.token".to_string(), + )) + .collect(), + config: [ + ( + "VERTEX_AI_PROJECT_ID".to_string(), + "my-gcp-project".to_string(), + ), + ("VERTEX_AI_REGION".to_string(), "us-central1".to_string()), + ] + .into_iter() + .collect(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let result = resolve_provider_environment(&store, &["vertex-local".to_string()]) + .await + .unwrap(); + + // Credential still injected. + assert_eq!( + result.get("GOOGLE_VERTEX_AI_TOKEN"), + Some(&"ya29.token".to_string()) + ); + // Static flags. + assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); + assert_eq!( + result.get("GOOSE_PROVIDER"), + Some(&"gcp_vertex_ai".to_string()) + ); + // Project ID derived vars. + assert_eq!( + result.get("ANTHROPIC_VERTEX_PROJECT_ID"), + Some(&"my-gcp-project".to_string()) + ); + assert_eq!( + result.get("GCP_PROJECT_ID"), + Some(&"my-gcp-project".to_string()) + ); + assert_eq!( + result.get("GOOGLE_CLOUD_PROJECT"), + Some(&"my-gcp-project".to_string()) + ); + // Region derived vars. 
+ assert_eq!( + result.get("CLOUD_ML_REGION"), + Some(&"us-central1".to_string()) + ); + assert_eq!(result.get("GCP_LOCATION"), Some(&"us-central1".to_string())); + assert_eq!( + result.get("VERTEX_LOCATION"), + Some(&"us-central1".to_string()) + ); + } + + #[tokio::test] + async fn resolve_provider_env_vertex_never_injects_service_account_key() { + let store = test_store().await; + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-bootstrap".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-vertex-ai".to_string(), + credentials: [ + ( + "GOOGLE_SERVICE_ACCOUNT_KEY".to_string(), + r#"{"type":"service_account","private_key":"secret"}"#.to_string(), + ), + ( + "GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN".to_string(), + "ya29.short-lived".to_string(), + ), + ] + .into_iter() + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let result = resolve_provider_environment(&store, &["vertex-bootstrap".to_string()]) + .await + .unwrap(); + + assert!(!result.contains_key("GOOGLE_SERVICE_ACCOUNT_KEY")); + assert_eq!( + result.get("GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_TOKEN"), + Some(&"ya29.short-lived".to_string()) + ); + } + + #[tokio::test] + async fn resolve_provider_env_vertex_omits_agent_config_when_project_and_region_absent() { + let store = test_store().await; + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-no-config".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-vertex-ai".to_string(), + credentials: std::iter::once(( + "GOOGLE_VERTEX_AI_TOKEN".to_string(), + "ya29.token".to_string(), + )) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + 
.unwrap(); + + let result = resolve_provider_environment(&store, &["vertex-no-config".to_string()]) + .await + .unwrap(); + + // Static flags still present. + assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); + assert_eq!( + result.get("GOOSE_PROVIDER"), + Some(&"gcp_vertex_ai".to_string()) + ); + // Project ID and region derived vars are absent. + assert!(!result.contains_key("ANTHROPIC_VERTEX_PROJECT_ID")); + assert!(!result.contains_key("GCP_PROJECT_ID")); + assert!(!result.contains_key("GOOGLE_CLOUD_PROJECT")); + assert!(!result.contains_key("CLOUD_ML_REGION")); + assert!(!result.contains_key("GCP_LOCATION")); + assert!(!result.contains_key("VERTEX_LOCATION")); + } + + #[tokio::test] + async fn resolve_provider_env_vertex_credential_wins_over_agent_config_key() { + // If a credential happens to share a name with one of the injected agent + // config keys, the credential value takes precedence because the credential + // loop runs first and entry().or_insert() does not overwrite. + let store = test_store().await; + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "vertex-collision".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "google-vertex-ai".to_string(), + credentials: [ + ( + "GOOGLE_VERTEX_AI_TOKEN".to_string(), + "ya29.token".to_string(), + ), + // Same key as an injected static flag. + ("GOOSE_PROVIDER".to_string(), "custom-value".to_string()), + ] + .into_iter() + .collect(), + config: [ + ("VERTEX_AI_PROJECT_ID".to_string(), "my-project".to_string()), + ("VERTEX_AI_REGION".to_string(), "us-east1".to_string()), + ] + .into_iter() + .collect(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let result = resolve_provider_environment(&store, &["vertex-collision".to_string()]) + .await + .unwrap(); + + // Credential value wins over the injected static value. 
+ assert_eq!( + result.get("GOOSE_PROVIDER"), + Some(&"custom-value".to_string()) + ); + } + + #[tokio::test] + async fn resolve_provider_env_non_vertex_provider_does_not_inject_agent_config() { + let store = test_store().await; + create_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "openai-local".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-test".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + + let result = resolve_provider_environment(&store, &["openai-local".to_string()]) + .await + .unwrap(); + + assert_eq!(result.get("OPENAI_API_KEY"), Some(&"sk-test".to_string())); + assert!(!result.contains_key("CLAUDE_CODE_USE_VERTEX")); + assert!(!result.contains_key("GOOSE_PROVIDER")); + assert!(!result.contains_key("ANTHROPIC_VERTEX_PROJECT_ID")); + assert!(!result.contains_key("GCP_PROJECT_ID")); + assert!(!result.contains_key("GOOGLE_CLOUD_PROJECT")); + assert!(!result.contains_key("CLOUD_ML_REGION")); + assert!(!result.contains_key("GCP_LOCATION")); + assert!(!result.contains_key("VERTEX_LOCATION")); + } + #[tokio::test] async fn update_provider_rejects_credential_key_collision_for_attached_sandbox() { let store = test_store().await; @@ -3724,7 +4259,7 @@ mod tests { .await .unwrap(); - let sandbox = Sandbox { + let mut sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: "sandbox-001".to_string(), name: "test-sandbox".to_string(), @@ -3737,9 +4272,8 @@ mod tests { ..SandboxSpec::default() }), status: None, - phase: SandboxPhase::Ready as i32, - ..Default::default() }; + sandbox.set_phase(SandboxPhase::Ready as i32); store.put_message(&sandbox).await.unwrap(); let loaded = store @@ -3761,7 +4295,7 @@ mod tests { let store 
= test_store().await; - let sandbox = Sandbox { + let mut sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: "sandbox-002".to_string(), name: "empty-sandbox".to_string(), @@ -3771,9 +4305,8 @@ mod tests { }), spec: Some(SandboxSpec::default()), status: None, - phase: SandboxPhase::Ready as i32, - ..Default::default() }; + sandbox.set_phase(SandboxPhase::Ready as i32); store.put_message(&sandbox).await.unwrap(); let loaded = store diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 0753d39f4..06f67fb4c 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -23,6 +23,10 @@ use openshell_core::proto::{ TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; +use openshell_core::telemetry::{ + LifecycleOperation, LifecycleResource, SandboxTemplateSource, TelemetryComputeDriver, + TelemetryOutcome, +}; use openshell_core::{ObjectId, ObjectName}; use prost::Message; use std::net::IpAddr; @@ -56,6 +60,62 @@ const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; pub(super) async fn handle_create_sandbox( state: &Arc, request: Request, +) -> Result, Status> { + let create_request = request.get_ref().clone(); + let result = handle_create_sandbox_inner(state, request).await; + emit_sandbox_create_telemetry( + state, + &create_request, + TelemetryOutcome::from_success(result.is_ok()), + ); + result +} + +fn emit_sandbox_create_telemetry( + state: &Arc, + request: &CreateSandboxRequest, + outcome: TelemetryOutcome, +) { + let compute_driver = telemetry_compute_driver(state.compute.driver_kind()); + let Some(spec) = request.spec.as_ref() else { + openshell_core::telemetry::emit_sandbox_create( + outcome, + false, + 0, + false, + SandboxTemplateSource::Undefined, + compute_driver, + ); + return; + }; + let template_source = if spec + .template 
+ .as_ref() + .is_some_and(|template| !template.image.trim().is_empty()) + { + SandboxTemplateSource::Image + } else { + SandboxTemplateSource::Default + }; + openshell_core::telemetry::emit_sandbox_create( + outcome, + spec.gpu, + spec.providers.len() as u64, + spec.policy.is_some(), + template_source, + compute_driver, + ); +} + +fn telemetry_compute_driver( + driver_kind: Option, +) -> TelemetryComputeDriver { + TelemetryComputeDriver::from_driver_kind(driver_kind) +} + +async fn handle_create_sandbox_inner( + state: &Arc, + request: Request, ) -> Result, Status> { use crate::persistence::current_time_ms; @@ -107,7 +167,7 @@ pub(super) async fn handle_create_sandbox( let now_ms = current_time_ms(); - let sandbox = Sandbox { + let mut sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: id.clone(), name: name.clone(), @@ -117,9 +177,8 @@ pub(super) async fn handle_create_sandbox( }), spec: Some(spec), status: None, - phase: SandboxPhase::Provisioning as i32, - current_policy_version: 0, }; + sandbox.set_phase(SandboxPhase::Provisioning as i32); // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) super::validation::validate_object_metadata(sandbox.metadata.as_ref(), "sandbox")?; @@ -152,11 +211,7 @@ pub(super) async fn handle_create_sandbox( Some(Err(status)) => return Err(status), None => None, }; - - let sandbox = match state.compute.create_sandbox(sandbox, sandbox_token).await { - Ok(sandbox) => sandbox, - Err(err) => return Err(err), - }; + let sandbox = state.compute.create_sandbox(sandbox, sandbox_token).await?; info!( sandbox_id = %id, @@ -415,6 +470,23 @@ pub(super) async fn handle_detach_sandbox_provider( pub(super) async fn handle_delete_sandbox( state: &Arc, request: Request, +) -> Result, Status> { + let result = handle_delete_sandbox_inner(state, request).await; + let outcome = match &result { + Ok(response) if response.get_ref().deleted => 
TelemetryOutcome::Success, + _ => TelemetryOutcome::Failure, + }; + openshell_core::telemetry::emit_lifecycle( + LifecycleResource::Sandbox, + LifecycleOperation::Delete, + outcome, + ); + result +} + +async fn handle_delete_sandbox_inner( + state: &Arc, + request: Request, ) -> Result, Status> { let name = request.into_inner().name; if name.is_empty() { @@ -563,7 +635,7 @@ pub(super) async fn handle_watch_sandbox( if stop_on_terminal { let phase = - SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); if phase == SandboxPhase::Ready { return; } @@ -633,7 +705,7 @@ pub(super) async fn handle_watch_sandbox( return; } if stop_on_terminal { - let phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + let phase = SandboxPhase::try_from(sandbox.phase()).unwrap_or(SandboxPhase::Unknown); if phase == SandboxPhase::Ready { return; } @@ -736,7 +808,7 @@ pub(super) async fn handle_exec_sandbox( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -850,7 +922,7 @@ pub(super) async fn handle_forward_tcp( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1179,7 +1251,7 @@ pub(super) async fn handle_exec_sandbox_interactive( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? 
.ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1252,7 +1324,7 @@ pub(super) async fn handle_create_ssh_session( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + if SandboxPhase::try_from(sandbox.phase()).ok() != Some(SandboxPhase::Ready) { return Err(Status::failed_precondition("sandbox is not ready")); } @@ -1855,15 +1927,36 @@ async fn run_exec_with_russh( #[cfg(test)] mod tests { use super::*; - use crate::auth::identity::{Identity, IdentityProvider}; - use crate::auth::oidc::RawBearerToken; - use crate::auth::principal::{Principal, UserPrincipal}; use crate::grpc::test_support::test_server_state; use openshell_core::proto::datamodel::v1::ObjectMeta; use std::collections::HashMap; // ---- shell_escape ---- + #[test] + fn telemetry_compute_driver_uses_resolved_driver_kind() { + assert_eq!( + telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Docker)), + TelemetryComputeDriver::Docker + ); + assert_eq!( + telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Kubernetes)), + TelemetryComputeDriver::Kubernetes + ); + assert_eq!( + telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Podman)), + TelemetryComputeDriver::Podman + ); + assert_eq!( + telemetry_compute_driver(Some(openshell_core::ComputeDriverKind::Vm)), + TelemetryComputeDriver::Vm + ); + assert_eq!( + telemetry_compute_driver(None), + TelemetryComputeDriver::Unknown + ); + } + #[test] fn shell_escape_safe_chars_pass_through() { assert_eq!(shell_escape("ls").unwrap(), "ls"); @@ -2097,7 +2190,7 @@ mod tests { } fn test_sandbox(name: &str, providers: Vec) -> Sandbox { - Sandbox 
{ + let mut sandbox = Sandbox { metadata: Some(ObjectMeta { id: format!("sandbox-{name}"), name: name.to_string(), @@ -2111,10 +2204,11 @@ mod tests { providers, ..Default::default() }), - phase: SandboxPhase::Ready as i32, - current_policy_version: 7, ..Default::default() - } + }; + sandbox.set_phase(SandboxPhase::Ready as i32); + sandbox.set_current_policy_version(7); + sandbox } #[tokio::test] @@ -2150,11 +2244,11 @@ mod tests { .await .unwrap() .unwrap(); + assert_eq!(sandbox.phase(), SandboxPhase::Ready as i32); + assert_eq!(sandbox.current_policy_version(), 7); let spec = sandbox.spec.unwrap(); assert_eq!(spec.providers, vec!["work-github"]); assert_eq!(spec.log_level, "debug"); - assert_eq!(sandbox.phase, SandboxPhase::Ready as i32); - assert_eq!(sandbox.current_policy_version, 7); } #[tokio::test] @@ -2435,7 +2529,7 @@ mod tests { async fn interactive_exec_rejects_sandbox_not_ready() { let state = test_server_state().await; let mut sandbox = test_sandbox("not-ready", Vec::new()); - sandbox.phase = SandboxPhase::Provisioning as i32; + sandbox.set_phase(SandboxPhase::Provisioning as i32); state.store.put_message(&sandbox).await.unwrap(); let stored = state @@ -2445,7 +2539,7 @@ mod tests { .unwrap() .unwrap(); assert_ne!( - SandboxPhase::try_from(stored.phase).ok(), + SandboxPhase::try_from(stored.phase()).ok(), Some(SandboxPhase::Ready) ); } @@ -2484,40 +2578,6 @@ mod tests { assert!(err.message().contains("provider-b")); } - #[tokio::test] - async fn create_sandbox_succeeds_for_oidc_user_without_persisted_binding() { - let state = test_server_state().await; - let mut request = Request::new(CreateSandboxRequest { - name: "delegated".to_string(), - spec: Some(openshell_core::proto::SandboxSpec::default()), - labels: HashMap::new(), - }); - request - .extensions_mut() - .insert(Principal::User(UserPrincipal { - identity: Identity { - subject: "user-123".to_string(), - display_name: Some("alex".to_string()), - roles: vec!["openshell-user".to_string()], - 
scopes: vec!["sandbox:write".to_string()], - provider: IdentityProvider::Oidc, - }, - })); - request - .extensions_mut() - .insert(RawBearerToken("raw-access-token".to_string())); - - let response = handle_create_sandbox(&state, request) - .await - .expect("sandbox create succeeds") - .into_inner(); - let sandbox = response.sandbox.expect("sandbox present"); - assert_eq!( - sandbox.metadata.as_ref().expect("metadata").name, - "delegated" - ); - } - #[tokio::test] async fn attach_sandbox_provider_rejects_credential_key_collisions() { let state = test_server_state().await; diff --git a/crates/openshell-server/src/telemetry.rs b/crates/openshell-server/src/telemetry.rs index 7de21154e..04529bcc4 100644 --- a/crates/openshell-server/src/telemetry.rs +++ b/crates/openshell-server/src/telemetry.rs @@ -21,8 +21,6 @@ impl TelemetryState { pub fn sandbox_session_disconnected(&self, _sandbox_id: &str) {} - pub fn end_sandbox_session(&self, _sandbox_id: &str) {} - pub fn record_network_activity(&self, sandbox_id: &str, summary: &NetworkActivitySummary) { if sandbox_id.is_empty() || !openshell_core::telemetry::enabled() { return; @@ -99,6 +97,5 @@ mod tests { let telemetry = TelemetryState::new(); telemetry.sandbox_session_connected("sb-1"); telemetry.sandbox_session_disconnected("sb-1"); - telemetry.end_sandbox_session("sb-1"); } }