Skip to content
4 changes: 4 additions & 0 deletions crates/openshell-bootstrap/src/oidc_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub struct OidcTokenBundle {
/// `OAuth2` access token (JWT).
pub access_token: String,

/// Optional OIDC ID token returned by the provider.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id_token: Option<String>,

/// `OAuth2` refresh token. `None` for `client_credentials` grants.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
Expand Down
19 changes: 11 additions & 8 deletions crates/openshell-cli/src/completers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,19 @@ async fn completion_grpc_client(
Some("oidc") => {
if let Some(bundle) = load_oidc_token(gateway_name) {
if is_token_expired(&bundle) {
match oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await {
Ok(refreshed) => {
let _ = store_oidc_token(gateway_name, &refreshed);
tls_opts.oidc_token = Some(refreshed.access_token);
}
Err(_) => {
tls_opts.oidc_token = Some(bundle.access_token);
}
if let Ok(refreshed) =
oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await
{
let _ = store_oidc_token(gateway_name, &refreshed);
tls_opts.oidc_token = Some(refreshed.access_token);
tls_opts.oidc_id_token = refreshed.id_token;
} else {
tls_opts.oidc_token = Some(bundle.access_token);
tls_opts.oidc_id_token = bundle.id_token;
}
} else {
tls_opts.oidc_token = Some(bundle.access_token);
tls_opts.oidc_id_token = bundle.id_token;
}
}
}
Expand All @@ -124,6 +126,7 @@ async fn completion_grpc_client(
let channel = build_channel(server, &tls_opts).await.ok()?;
let interceptor = EdgeAuthInterceptor::new(
tls_opts.oidc_token.as_deref(),
tls_opts.oidc_id_token.as_deref(),
tls_opts.edge_token.as_deref(),
)
.ok()?;
Expand Down
32 changes: 29 additions & 3 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) {
else {
return;
};
let bearer_for_gateway = |access_token: &str, id_token: Option<&String>| {
if access_token.matches('.').count() == 2 {
access_token.to_string()
} else {
id_token
.cloned()
.unwrap_or_else(|| access_token.to_string())
}
};
if openshell_bootstrap::oidc_token::is_token_expired(&bundle) {
let insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE")
.is_ok_and(|v| !v.is_empty() && v != "0" && v != "false");
Expand All @@ -155,17 +164,29 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) {
gateway_name,
&refreshed,
);
tls.oidc_token = Some(refreshed.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&refreshed.access_token,
refreshed.id_token.as_ref(),
));
tls.oidc_id_token = refreshed.id_token;
}
Err(e) => {
tracing::warn!("OIDC token refresh failed: {e}");
// Use the expired token anyway — server will reject it
// with a clear error prompting re-login.
tls.oidc_token = Some(bundle.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&bundle.access_token,
bundle.id_token.as_ref(),
));
tls.oidc_id_token = bundle.id_token;
}
}
} else {
tls.oidc_token = Some(bundle.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&bundle.access_token,
bundle.id_token.as_ref(),
));
tls.oidc_id_token = bundle.id_token;
}
}
_ => {}
Expand Down Expand Up @@ -660,6 +681,8 @@ enum OutputFormat {
enum CliProviderRefreshStrategy {
Oauth2RefreshToken,
Oauth2ClientCredentials,
Oauth2TokenExchange,
OktaXaa,
GoogleServiceAccountJwt,
}

Expand All @@ -668,6 +691,8 @@ impl CliProviderRefreshStrategy {
match self {
Self::Oauth2RefreshToken => "oauth2_refresh_token",
Self::Oauth2ClientCredentials => "oauth2_client_credentials",
Self::Oauth2TokenExchange => "oauth2_token_exchange",
Self::OktaXaa => "okta_xaa",
Self::GoogleServiceAccountJwt => "google_service_account_jwt",
}
}
Expand Down Expand Up @@ -2898,6 +2923,7 @@ async fn main() -> Result<()> {
let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?;
let interceptor = openshell_core::auth::EdgeAuthInterceptor::new(
tls.oidc_token.as_deref(),
tls.oidc_id_token.as_deref(),
tls.edge_token.as_deref(),
)?;
openshell_tui::run(channel, interceptor, &ctx.name, &ctx.endpoint, theme).await?;
Expand Down
97 changes: 79 additions & 18 deletions crates/openshell-cli/src/oidc_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ use hyper::{Method, Response, StatusCode};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use miette::{IntoDiagnostic, Result};
use oauth2::basic::BasicClient;
use oauth2::{
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
AuthType, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EndpointNotSet, ExtraTokenFields, PkceCodeChallenge, RedirectUrl, RefreshToken, Scope,
StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl,
basic::{
BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
BasicTokenType,
},
};
use openshell_bootstrap::oidc_token::OidcTokenBundle;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::{Arc, Mutex};
use std::time::Duration;
Expand All @@ -29,6 +33,37 @@ use tokio::sync::oneshot;
use tracing::debug;

const AUTH_TIMEOUT: Duration = Duration::from_secs(120);
const DEFAULT_OIDC_CALLBACK_BIND: &str = "127.0.0.1:0";
const OIDC_CALLBACK_PORT_ENV: &str = "OPENSHELL_OIDC_CALLBACK_PORT";
const OIDC_CLIENT_SECRET_ENV: &str = "OPENSHELL_OIDC_CLIENT_SECRET";

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
struct OidcExtraTokenFields {
#[serde(default, skip_serializing_if = "Option::is_none")]
id_token: Option<String>,
}

impl ExtraTokenFields for OidcExtraTokenFields {}

type OidcTokenResponse = StandardTokenResponse<OidcExtraTokenFields, BasicTokenType>;
type OidcClient<
HasAuthUrl = EndpointNotSet,
HasDeviceAuthUrl = EndpointNotSet,
HasIntrospectionUrl = EndpointNotSet,
HasRevocationUrl = EndpointNotSet,
HasTokenUrl = EndpointNotSet,
> = Client<
BasicErrorResponse,
OidcTokenResponse,
BasicTokenIntrospectionResponse,
StandardRevocableToken,
BasicRevocationErrorResponse,
HasAuthUrl,
HasDeviceAuthUrl,
HasIntrospectionUrl,
HasRevocationUrl,
HasTokenUrl,
>;

/// OIDC discovery document (subset of fields we need).
#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -95,6 +130,25 @@ fn build_ci_scopes(scopes: Option<&str>) -> Vec<Scope> {
.collect()
}

fn oidc_callback_bind_address() -> Result<String> {
match std::env::var(OIDC_CALLBACK_PORT_ENV) {
Ok(raw) => {
let port = raw.parse::<u16>().map_err(|_| {
miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be a valid TCP port number, got '{raw}'"
)
})?;
if port == 0 {
return Err(miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be greater than 0"
));
}
Ok(format!("127.0.0.1:{port}"))
}
Err(_) => Ok(DEFAULT_OIDC_CALLBACK_BIND.to_string()),
}
}

/// Run the OIDC Authorization Code + PKCE browser flow.
///
/// Opens the user's browser to the Keycloak login page and waits for
Expand All @@ -108,14 +162,21 @@ pub async fn oidc_browser_auth_flow(
) -> Result<OidcTokenBundle> {
let discovery = discover(issuer, insecure).await?;

let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?;
let listener = TcpListener::bind(oidc_callback_bind_address()?)
.await
.into_diagnostic()?;
let port = listener.local_addr().into_diagnostic()?.port();
let redirect_uri = format!("http://127.0.0.1:{port}/callback");

let client = BasicClient::new(ClientId::new(client_id.to_string()))
let mut client = OidcClient::new(ClientId::new(client_id.to_string()))
.set_auth_uri(AuthUrl::new(discovery.authorization_endpoint).into_diagnostic()?)
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?)
.set_redirect_uri(RedirectUrl::new(redirect_uri).into_diagnostic()?);
if let Ok(client_secret) = std::env::var(OIDC_CLIENT_SECRET_ENV) {
client = client
.set_client_secret(ClientSecret::new(client_secret))
.set_auth_type(AuthType::RequestBody);
}

let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

Expand Down Expand Up @@ -167,7 +228,7 @@ pub async fn oidc_browser_auth_flow(
server_handle.abort();

let http = http_client(insecure);
let token_response = client
let token_response: OidcTokenResponse = client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(pkce_verifier)
.request_async(&http)
Expand All @@ -191,15 +252,15 @@ pub async fn oidc_client_credentials_flow(
scopes: Option<&str>,
insecure: bool,
) -> Result<OidcTokenBundle> {
let client_secret = std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").map_err(|_| {
let client_secret = std::env::var(OIDC_CLIENT_SECRET_ENV).map_err(|_| {
miette::miette!(
"OPENSHELL_OIDC_CLIENT_SECRET environment variable is required for client credentials flow"
"{OIDC_CLIENT_SECRET_ENV} environment variable is required for client credentials flow"
)
})?;

let discovery = discover(issuer, insecure).await?;

let client = BasicClient::new(ClientId::new(client_id.to_string()))
let client = OidcClient::new(ClientId::new(client_id.to_string()))
.set_client_secret(ClientSecret::new(client_secret))
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?)
.set_auth_type(AuthType::RequestBody);
Expand All @@ -213,7 +274,7 @@ pub async fn oidc_client_credentials_flow(
}

let http = http_client(insecure);
let token_response = request
let token_response: OidcTokenResponse = request
.request_async(&http)
.await
.map_err(|e| miette::miette!("client credentials token exchange failed: {e}"))?;
Expand Down Expand Up @@ -241,11 +302,11 @@ pub async fn oidc_refresh_token(

let discovery = discover(&bundle.issuer, insecure).await?;

let client = BasicClient::new(ClientId::new(bundle.client_id.clone()))
let client = OidcClient::new(ClientId::new(bundle.client_id.clone()))
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?);

let http = http_client(insecure);
let token_response = client
let token_response: OidcTokenResponse = client
.exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
.request_async(&http)
.await
Expand Down Expand Up @@ -287,7 +348,7 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Resu
// ── Helpers ──────────────────────────────────────────────────────────

fn bundle_from_oauth2_response(
resp: &oauth2::basic::BasicTokenResponse,
resp: &OidcTokenResponse,
issuer: &str,
client_id: &str,
) -> OidcTokenBundle {
Expand All @@ -298,6 +359,7 @@ fn bundle_from_oauth2_response(

OidcTokenBundle {
access_token: resp.access_token().secret().clone(),
id_token: resp.extra_fields().id_token.clone(),
refresh_token: resp.refresh_token().map(|rt| rt.secret().clone()),
expires_at: resp.expires_in().map(|ei| now + ei.as_secs()),
issuer: issuer.to_string(),
Expand Down Expand Up @@ -518,14 +580,13 @@ mod tests {

#[test]
fn bundle_from_response_sets_fields() {
use oauth2::basic::BasicTokenResponse;

let token_response: BasicTokenResponse = serde_json::from_str(
r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh"}"#,
let token_response: OidcTokenResponse = serde_json::from_str(
r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh","id_token":"test-id"}"#,
)
.unwrap();
let bundle = bundle_from_oauth2_response(&token_response, "https://issuer", "my-client");
assert_eq!(bundle.access_token, "test-access");
assert_eq!(bundle.id_token.as_deref(), Some("test-id"));
assert_eq!(bundle.refresh_token.as_deref(), Some("test-refresh"));
assert_eq!(bundle.issuer, "https://issuer");
assert_eq!(bundle.client_id, "my-client");
Expand Down
Loading
Loading