Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions codex-rs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ codex-app-server-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
codex-async-utils = { workspace = true }
codex-code-mode = { workspace = true }
codex-client = { workspace = true }
codex-connectors = { workspace = true }
codex-config = { workspace = true }
codex-core-plugins = { workspace = true }
Expand Down
206 changes: 160 additions & 46 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use codex_api::build_session_headers;
use codex_api::create_text_param_for_request;
use codex_api::response_create_client_metadata;
use codex_app_server_protocol::AuthMode;
use codex_client::HttpTransport;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::RefreshTokenError;
Expand Down Expand Up @@ -184,6 +185,99 @@ struct CurrentClientSetup {
api_auth: SharedAuthProvider,
}

/// Hands out codex-api REST clients for the lifetime of a session.
///
/// Only the runtime model provider handle is stored. API provider and auth state are looked up
/// from that handle each time a client is requested, which means a client built after an auth
/// refresh is constructed with the refreshed credentials.
#[derive(Debug, Clone)]
pub struct ApiClientFactory {
    /// Shared model provider handle that auth and API provider state are resolved from.
    provider: SharedModelProvider,
}

impl ApiClientFactory {
    /// Creates a factory backed by the given model provider handle.
    pub(crate) fn new(provider: SharedModelProvider) -> Self {
        Self { provider }
    }

    /// Resolves the current provider/auth state and builds a client of type `C` from it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while resolving the API provider or the API auth.
    pub async fn create<C>(&self) -> Result<C>
    where
        C: ApiClient<ReqwestTransport>,
    {
        let setup = self.current_setup().await?;
        Ok(setup.create::<C>())
    }

    /// Captures auth, API provider and API auth from the provider handle at call time.
    ///
    /// # Errors
    ///
    /// Fails if either the API provider or the API auth cannot be resolved.
    async fn current_setup(&self) -> Result<CurrentClientSetup> {
        let source = &self.provider;
        // Resolution order is kept as auth, then API provider, then API auth.
        let auth = source.auth().await;
        let api_provider = source.api_provider().await?;
        let api_auth = source.api_auth().await?;
        let setup = CurrentClientSetup {
            auth,
            api_provider,
            api_auth,
        };
        Ok(setup)
    }
}

impl CurrentClientSetup {
    /// Builds a client of type `C` on a newly constructed reqwest transport, handing it clones
    /// of this setup's API provider and API auth.
    fn create<C>(&self) -> C
    where
        C: ApiClient<ReqwestTransport>,
    {
        let transport = ReqwestTransport::new(build_reqwest_client());
        let provider = self.api_provider.clone();
        let auth_provider = self.api_auth.clone();
        C::from_api_parts(transport, provider, auth_provider)
    }
}

/// A codex-api REST client that can be assembled from a transport plus already-resolved API
/// provider and auth parts.
///
/// Implementations are expected to be thin adapters that forward to the concrete client's
/// existing constructor. Having this trait lets [`ApiClientFactory`] build any implementing
/// client type generically, without the factory naming each client type itself.
pub trait ApiClient<T>: Sized
where
    T: HttpTransport,
{
    /// Builds the client from `transport`, `api_provider` and `api_auth`.
    fn from_api_parts(
        transport: T,
        api_provider: ApiProvider,
        api_auth: SharedAuthProvider,
    ) -> Self;
}

impl<T: HttpTransport> ApiClient<T> for ApiCompactClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}

impl<T: HttpTransport> ApiClient<T> for ApiMemoriesClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}

impl<T: HttpTransport> ApiClient<T> for ApiRealtimeCallClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}

impl<T: HttpTransport> ApiClient<T> for ApiResponsesClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}

#[derive(Clone, Copy)]
struct RequestRouteTelemetry {
endpoint: &'static str,
Expand Down Expand Up @@ -226,6 +320,7 @@ pub struct ModelClient {
/// contract and can cause routing bugs.
pub struct ModelClientSession {
client: ModelClient,
api_client_factory: ApiClientFactory,
websocket_session: WebsocketSession,
/// Turn state for sticky routing.
///
Expand Down Expand Up @@ -316,6 +411,31 @@ impl ModelClient {
beta_features_header: Option<String>,
) -> Self {
let model_provider = create_model_provider(provider_info, auth_manager);
Self::from_model_provider(
model_provider,
session_id,
thread_id,
installation_id,
session_source,
model_verbosity,
enable_request_compression,
include_timing_metrics,
beta_features_header,
)
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn from_model_provider(
model_provider: SharedModelProvider,
session_id: SessionId,
thread_id: ThreadId,
installation_id: String,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
let codex_api_key_env_enabled = model_provider
.auth_manager()
.as_ref()
Expand Down Expand Up @@ -346,8 +466,17 @@ impl ModelClient {
/// This constructor does not perform network I/O itself; the session opens a websocket lazily
/// when the first stream request is issued.
pub fn new_session(&self) -> ModelClientSession {
    // The default factory shares this client's own provider handle.
    let provider = Arc::clone(&self.state.provider);
    self.new_session_with_client_factory(ApiClientFactory::new(provider))
}

pub(crate) fn new_session_with_client_factory(
&self,
api_client_factory: ApiClientFactory,
) -> ModelClientSession {
ModelClientSession {
client: self.clone(),
api_client_factory,
websocket_session: self.take_cached_websocket_session(),
turn_state: Arc::new(OnceLock::new()),
}
Expand Down Expand Up @@ -422,6 +551,7 @@ impl ModelClient {
/// session-scoped.
pub(crate) async fn compact_conversation_history(
&self,
api_client_factory: &ApiClientFactory,
prompt: &Prompt,
model_info: &ModelInfo,
settings: CompactConversationRequestSettings,
Expand All @@ -431,8 +561,7 @@ impl ModelClient {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = api_client_factory.current_setup().await?;
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
Expand Down Expand Up @@ -463,9 +592,9 @@ impl ModelClient {
text,
..
} = request;
let client =
ApiCompactClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let client = client_setup
.create::<ApiCompactClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry));
let payload = ApiCompactionInput {
model: &model,
input: &input,
Expand Down Expand Up @@ -503,23 +632,23 @@ impl ModelClient {

pub(crate) async fn create_realtime_call_with_headers(
&self,
api_client_factory: &ApiClientFactory,
sdp: String,
session_config: ApiRealtimeSessionConfig,
extra_headers: ApiHeaderMap,
) -> Result<RealtimeWebrtcCallStart> {
// Create the media call over HTTP first, then retain matching auth so realtime can attach
// the server-side control WebSocket to the call id from that HTTP response.
let client_setup = self.current_client_setup().await?;
let client_setup = api_client_factory.current_setup().await?;
let mut sideband_headers = extra_headers.clone();
sideband_headers.extend(sideband_websocket_auth_headers(
client_setup.api_auth.as_ref(),
));
let transport = ReqwestTransport::new(build_reqwest_client());
let response =
ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.create_with_session_and_headers(sdp, session_config, extra_headers)
.await
.map_err(map_api_error)?;
let response = client_setup
.create::<ApiRealtimeCallClient<ReqwestTransport>>()
.create_with_session_and_headers(sdp, session_config, extra_headers)
.await
.map_err(map_api_error)?;
Ok(RealtimeWebrtcCallStart {
sdp: response.sdp,
call_id: response.call_id,
Expand All @@ -535,6 +664,7 @@ impl ModelClient {
/// `ModelClient` session-scoped.
pub async fn summarize_memories(
&self,
api_client_factory: &ApiClientFactory,
raw_memories: Vec<ApiRawMemory>,
model_info: &ModelInfo,
effort: Option<ReasoningEffortConfig>,
Expand All @@ -544,8 +674,7 @@ impl ModelClient {
return Ok(Vec::new());
}

let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = api_client_factory.current_setup().await?;
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
Expand All @@ -556,9 +685,9 @@ impl ModelClient {
RequestRouteTelemetry::for_endpoint(MEMORIES_SUMMARIZE_ENDPOINT),
self.state.auth_env_telemetry.clone(),
);
let client =
ApiMemoriesClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let client = client_setup
.create::<ApiMemoriesClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry));

let payload = ApiMemorySummarizeInput {
model: model_info.slug.clone(),
Expand Down Expand Up @@ -747,21 +876,6 @@ impl ModelClient {
true
}

/// Resolves auth and provider configuration from the session's current auth state.
///
/// Keeping this in one place means the prewarm path and the normal request path pick up any
/// change to auth/provider resolution together.
///
/// # Errors
///
/// Fails if either the API provider or the API auth cannot be resolved.
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
    let source = &self.state.provider;
    // Resolution order is kept as auth, then API provider, then API auth.
    let auth = source.auth().await;
    let api_provider = source.api_provider().await?;
    let api_auth = source.api_auth().await?;
    let setup = CurrentClientSetup {
        auth,
        api_provider,
        api_auth,
    };
    Ok(setup)
}

/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
Expand Down Expand Up @@ -1038,11 +1152,15 @@ impl ModelClientSession {
return Ok(());
}

let client_setup = self.client.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
})?;
let client_setup = self
.api_client_factory
.current_setup()
.await
.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
})?;
let auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
Expand Down Expand Up @@ -1201,8 +1319,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = self.api_client_factory.current_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
Expand All @@ -1227,12 +1344,9 @@ impl ModelClientSession {
)?;
let inference_trace_attempt = inference_trace.start_attempt();
inference_trace_attempt.record_started(&request);
let client = ApiResponsesClient::new(
transport,
client_setup.api_provider,
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let client = client_setup
.create::<ApiResponsesClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;

match stream_result {
Expand Down Expand Up @@ -1314,7 +1428,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.api_client_factory.current_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
Expand Down
6 changes: 6 additions & 0 deletions codex-rs/core/src/client_tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::ApiClientFactory;
use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::PendingUnauthorizedRetry;
Expand All @@ -11,6 +12,7 @@ use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider::create_model_provider;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
Expand Down Expand Up @@ -308,6 +310,10 @@ async fn summarize_memories_returns_empty_for_empty_input() {

let output = client
.summarize_memories(
&ApiClientFactory::new(create_model_provider(
create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses),
/*auth_manager*/ None,
)),
Vec::new(),
&model_info,
/*effort*/ None,
Expand Down
5 changes: 4 additions & 1 deletion codex-rs/core/src/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ async fn run_compact_task_inner_impl(

let max_retries = turn_context.provider.info().stream_max_retries();
let mut retries = 0;
let mut client_session = sess.services.model_client.new_session();
let mut client_session = sess
.services
.model_client
.new_session_with_client_factory(sess.services.api_client_factory.clone());
// Reuse one client session so turn-scoped state (sticky routing, websocket incremental
// request tracking) survives retries within this compact turn.
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/compact_remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async fn run_remote_compact_task_inner_impl(
.services
.model_client
.compact_conversation_history(
&sess.services.api_client_factory,
&prompt,
&turn_context.model_info,
CompactConversationRequestSettings {
Expand Down
Loading
Loading