From 26ca721853f8ee19d34ce8103317d16a28ceef91 Mon Sep 17 00:00:00 2001 From: airton Date: Fri, 6 Mar 2026 22:25:51 +0100 Subject: [PATCH 1/3] native-tls https fix --- bitreq/src/connection.rs | 56 ++++++++++++++++---------- bitreq/src/connection/rustls_stream.rs | 37 +++++++++++------ bitreq/src/error.rs | 7 +++- bitreq/tests/main.rs | 27 +++++++++++++ 4 files changed, 92 insertions(+), 35 deletions(-) diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index f8b98c133..2d9610d1d 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -29,14 +29,14 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(feature = "rustls")] +#[cfg(any(feature = "rustls", feature = "native-tls"))] mod rustls_stream; -#[cfg(feature = "rustls")] +#[cfg(any(feature = "rustls", feature = "native-tls"))] type SecuredStream = rustls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] Secured(Box, Option), #[cfg(feature = "async")] Buffer(std::io::Cursor>), @@ -81,7 +81,7 @@ impl Read for HttpStream { timeout(inner, *timeout_at)?; inner.read(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) @@ -111,7 +111,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.write(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.write(buf) @@ -137,7 +137,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.flush() } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.flush() @@ -158,13 +158,13 @@ impl Write for HttpStream { } } -#[cfg(feature = "tokio-rustls")] +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { Unsecured(AsyncTcpStream), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] Secured(Box), } @@ -177,7 +177,7 @@ impl AsyncRead for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf), } } @@ -192,7 +192,7 @@ impl AsyncWrite for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf), } } @@ -200,7 +200,7 @@ impl AsyncWrite for AsyncHttpStream { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx), } } @@ -208,7 +208,7 @@ impl AsyncWrite for AsyncHttpStream { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx), - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx), } } @@ -271,10 +271,8 @@ impl AsyncConnection { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(feature = "tokio-rustls"))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "tokio-rustls")] - rustls_stream::wrap_async_stream(socket, params.host).await + // temp call + Self::wrap_async_stream(socket, params.host).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -298,6 +296,23 @@ impl AsyncConnection { })))) } + /// Temp Method implementation + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + ) -> Result { + rustls_stream::wrap_async_stream(socket, host).await + } + + /// Temp Method implementation + #[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -653,13 +668,10 @@ impl Connection { let socket = Self::connect(params, timeout_at)?; let stream = if params.https { - #[cfg(not(feature = "rustls"))] + #[cfg(not(any(feature = "rustls", feature = "native-tls")))] return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "rustls")] - { - let tls = rustls_stream::wrap_stream(socket, params.host)?; - HttpStream::Secured(Box::new(tls), timeout_at) - } + #[cfg(any(feature = "rustls", feature = "native-tls"))] + rustls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) }; diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index c21db715b..baaadccab 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -3,6 +3,7 @@ #[cfg(feature = "rustls")] use alloc::sync::Arc; +#[cfg(any(feature = "rustls", feature = "native-tls"))] use std::io; use std::net::TcpStream; use std::sync::OnceLock; @@ -20,9 +21,9 @@ use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; -#[cfg(feature = "tokio-rustls")] -use super::{AsyncHttpStream, AsyncTcpStream}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(any(feature = "rustls", feature = "native-tls"))] +use super::HttpStream; +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] use super::{AsyncHttpStream, AsyncTcpStream}; use crate::Error; @@ -50,7 +51,7 @@ fn build_client_config() -> Arc { } #[cfg(feature = "rustls")] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); let dns_name = ServerName::try_from(host) @@ -58,10 +59,12 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result> = OnceLock::new(); #[cfg(all(feature = "native-tls", not(feature = "rustls")))] fn native_tls_err(e: HandshakeError) -> Error { match e { - HandshakeError::Failure(e) => Error::NativeTlsError(e), + HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), HandshakeError::WouldBlock(_) => { debug_assert!(false, "We shouldn't hit a blocking error"); Error::Other("Got a WouldBlock error from native-tls") @@ -109,22 +112,27 @@ fn native_tls_err(e: HandshakeError) -> Error { #[cfg(all(feature = "native-tls", not(feature = "rustls")))] fn build_tls_connector() -> Result { - TlsConnector::builder().build().map_err(Error::NativeTlsError) + TlsConnector::builder().build().map_err(Error::from) } #[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = CONNECTOR.get_or_init(build_tls_connector)?; + let connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - connector.connect(host, tcp).map_err(native_tls_err) + let tls = connector.connect(host, tcp).map_err(native_tls_err)?; + + Ok(HttpStream::Secured(Box::new(tls), None)) } #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] @@ -140,12 +148,17 @@ pub(super) async fn wrap_async_stream( // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = AsyncTlsConnector::from(CONNECTOR.get_or_init(build_tls_connector)?.clone()); + let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let async_connector = AsyncTlsConnector::from(sync_connector); #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - let tls = connector.connect(host, tcp).await.map_err(native_tls_err)?; + let tls = async_connector.connect(host, tcp).await?; Ok(AsyncHttpStream::Secured(Box::new(tls))) } diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index ca9d1421d..d923c90e0 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -105,7 +105,7 @@ impl fmt::Display for Error { #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), #[cfg(feature = "native-tls")] - NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), + NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"), MalformedContentLength => write!(f, "non-usize content length"), @@ -160,3 +160,8 @@ impl From for Error { impl From for Error { fn from(other: UrlParseError) -> Error { Error::InvalidUrl(other) } } + +#[cfg(feature = "native-tls")] +impl From for Error { + fn from(err: native_tls::Error) -> Error { Error::NativeTlsCreateConnection(err) } +} diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..058e34c46 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -16,6 +16,33 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https() { + // TODO: Implement this locally. + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); + // Test reusing the existing connection in client: + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() { From 00470699509188d330b0b4df193c6d42ee86ebee Mon Sep 17 00:00:00 2001 From: airton Date: Thu, 19 Mar 2026 20:17:09 +0100 Subject: [PATCH 2/3] implement client builder for rustls --- bitreq/src/client.rs | 143 ++++++++++++++++++++++++- bitreq/src/connection.rs | 51 +++++++-- bitreq/src/connection/rustls_stream.rs | 50 ++++++++- bitreq/src/connection/tls_config.rs | 68 ++++++++++++ bitreq/src/error.rs | 12 +++ bitreq/src/request.rs | 2 +- bitreq/tests/ca_cert.der | Bin 0 -> 1413 bytes bitreq/tests/main.rs | 44 +++++++- bitreq/tests/test_cert.der | Bin 0 -> 771 bytes 9 files changed, 354 insertions(+), 16 deletions(-) create mode 100644 bitreq/src/connection/tls_config.rs create mode 100644 bitreq/tests/ca_cert.der create mode 100644 bitreq/tests/test_cert.der diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index b5de6f2fb..0fd793ddf 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -9,10 +9,138 @@ use std::collections::{hash_map, HashMap, VecDeque}; use std::sync::{Arc, Mutex}; +#[cfg(feature = "tokio-rustls")] +use crate::connection::tls_config::{TlsConfig, TlsConfigBuilder}; use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; use crate::{Error, Request, Response}; +#[derive(Clone)] +pub(crate) struct ClientConfig { + #[cfg(feature = "tokio-rustls")] + pub(crate) tls: Option, +} + +pub struct ClientBuilder { + capacity: usize, + #[cfg(feature = "tokio-rustls")] + tls_config: Option, +} + +/// Builder for configuring a `Client` with custom settings. +/// +/// # Example +/// +/// ```no_run +/// # async fn example() -> Result<(), bitreq::Error> { +/// use bitreq::{Client, RequestExt}; +/// +/// let client = Client::builder().with_capacity(20).build()?; +/// +/// let response = bitreq::get("https://example.com") +/// .send_async_with_client(&client) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +impl ClientBuilder { + /// Creates a new `ClientBuilder` with a default pool capacity of 10. + pub fn new() -> Self { + Self { + capacity: 10, + #[cfg(feature = "tokio-rustls")] + tls_config: None, + } + } + + /// Sets the maximum number of connections to keep in the pool. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + #[cfg(feature = "tokio-rustls")] + /// Builds the `Client` with the configured settings. + pub fn build(self) -> Result { + let build_config = if let Some(builder) = self.tls_config { + let tls_config = builder.build()?; + Some(ClientConfig { tls: Some(tls_config) }) + } else { + None + }; + let client_config = build_config.map(Arc::new); + + Ok(Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config, + })), + }) + } + + /// Builds the `Client` with the configured settings. + #[cfg(not(feature = "tokio-rustls"))] + pub fn build(self) -> Result { + Ok(Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config: None, + })), + }) + } + + /// Adds a custom DER-encoded root certificate for TLS verification. + /// The certificate must be provided in DER format. This method accepts any type + /// that can be converted into a `Vec`. + /// The certificate is appended to the default trust store rather than replacing it. + /// The trust store used depends on the TLS backend: system certificates for native-tls, + /// Mozilla's root certificates(rustls-webpki) and/or system certificates(rustls-native-certs) for rustls. + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// # async fn example() -> Result<(), bitreq::Error> { + /// let client = Client::builder() + /// .with_root_certificate(include_bytes!("../tests/test_cert.der"))? + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "tokio-rustls")] + pub fn with_root_certificate>>(mut self, cert_der: T) -> Result { + let cert_der = cert_der.into(); + if let Some(ref mut tls_config) = self.tls_config { + tls_config.append_certificate(cert_der)?; + + return Ok(self); + } + + self.tls_config = Some(TlsConfigBuilder::new(Some(cert_der))?); + Ok(self) + } + + /// Disables default root certificates for TLS connections. + /// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured. + #[cfg(feature = "tokio-rustls")] + pub fn disable_default_certificates(mut self) -> Result { + match self.tls_config { + Some(ref mut tls_config) => tls_config.disable_default_certificates()?, + None => return Err(Error::InvalidTlsConfig), + }; + + Ok(self) + } +} + +impl Default for ClientBuilder { + fn default() -> Self { Self::new() } +} + /// A client that caches connections for reuse. /// /// The client maintains a pool of up to `capacity` connections, evicting @@ -39,10 +167,11 @@ struct ClientImpl { connections: HashMap>, lru_order: VecDeque, capacity: usize, + client_config: Option>, } impl Client { - /// Creates a new `Client` with the specified connection cache capacity. + /// Creates a new `Client` with the specified connection pool capacity. /// /// # Arguments /// @@ -54,10 +183,14 @@ impl Client { connections: HashMap::new(), lru_order: VecDeque::new(), capacity, + client_config: None, })), } } + /// Create a builder for a client + pub fn builder() -> ClientBuilder { ClientBuilder::new() } + /// Sends a request asynchronously using a cached connection if available. pub async fn send_async(&self, request: Request) -> Result { let parsed_request = ParsedRequest::new(request)?; @@ -77,7 +210,13 @@ impl Client { let conn = if let Some(conn) = conn_opt { conn } else { - let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?; + let client_config = { + let state = self.r#async.lock().unwrap(); + state.client_config.as_ref().map(Arc::clone) + }; + + let connection = + AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?; let connection = Arc::new(connection); let mut state = self.r#async.lock().unwrap(); diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index 2d9610d1d..5d8e98ee4 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -22,6 +22,8 @@ use tokio::net::TcpStream as AsyncTcpStream; #[cfg(feature = "async")] use tokio::sync::Mutex as AsyncMutex; +#[cfg(feature = "async")] +use crate::client::ClientConfig; use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; #[cfg(feature = "async")] use crate::Response; @@ -31,6 +33,8 @@ type UnsecuredStream = TcpStream; #[cfg(any(feature = "rustls", feature = "native-tls"))] mod rustls_stream; +#[cfg(feature = "tokio-rustls")] +pub(crate) mod tls_config; #[cfg(any(feature = "rustls", feature = "native-tls"))] type SecuredStream = rustls_stream::SecuredStream; @@ -238,6 +242,7 @@ struct AsyncConnectionState { /// Defaults to 60 seconds after open to align with nginx's default timeout of 75 seconds, but /// can be overridden by the `Keep-Alive` header. socket_new_requests_timeout: Mutex, + client_config: Option>, } #[cfg(feature = "async")] @@ -266,13 +271,15 @@ impl AsyncConnection { pub(crate) async fn new( params: ConnectionParams<'_>, timeout_at: Option, + client_config: Option>, ) -> Result { + let client_config_ref = &client_config; + let future = async move { let socket = Self::connect(params).await?; if params.https { - // temp call - Self::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host, client_config_ref).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -293,26 +300,47 @@ impl AsyncConnection { readable_request_id: AtomicUsize::new(0), min_dropped_reader_id: AtomicUsize::new(usize::MAX), socket_new_requests_timeout: Mutex::new(Instant::now() + Duration::from_secs(60)), + client_config, })))) } - /// Temp Method implementation - #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] + // ======= + /// Temp method. Required to compile + #[cfg(all(feature = "tokio-native-tls", not(feature = "tokio-rustls")))] async fn wrap_async_stream( socket: AsyncTcpStream, host: &str, + _client_config: &Option>, ) -> Result { rustls_stream::wrap_async_stream(socket, host).await } + // ======= + + /// Call the correct wrapper function depending on whether client_configs are present + #[cfg(feature = "tokio-rustls")] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + client_config: &Option>, + ) -> Result { + if let Some(client_config) = client_config { + let tls_config = client_config.tls.as_ref().unwrap().clone(); + rustls_stream::wrap_async_stream_with_configs(socket, host, tls_config).await + } else { + rustls_stream::wrap_async_stream(socket, host).await + } + } - /// Temp Method implementation + /// Error treatment function, should not be called under normal circustances #[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))] async fn wrap_async_stream( _socket: AsyncTcpStream, _host: &str, + _client_config: &Option>, ) -> Result { Err(Error::HttpsFeatureNotEnabled) } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -461,9 +489,13 @@ impl AsyncConnection { retry_new_connection!(_internal); }; (_internal) => { - let new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at) - .await?; + let config = conn.client_config.as_ref().map(Arc::clone); + let new_connection = AsyncConnection::new( + request.connection_params(), + request.timeout_at, + config, + ) + .await?; *self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap()); core::mem::drop(read); // Note that this cannot recurse infinitely as we'll always be able to send at @@ -818,7 +850,8 @@ async fn async_handle_redirects( let new_connection; if needs_new_connection { new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at).await?; + AsyncConnection::new(request.connection_params(), request.timeout_at, None) + .await?; connection = &new_connection; } connection.send(request).await diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index baaadccab..b9bcb9643 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -14,7 +14,7 @@ use native_tls::{HandshakeError, TlsConnector, TlsStream}; use rustls::pki_types::ServerName; #[cfg(feature = "rustls")] use rustls::{self, ClientConfig, ClientConnection, RootCertStore, StreamOwned}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] use tokio_native_tls::TlsConnector as AsyncTlsConnector; #[cfg(feature = "tokio-rustls")] use tokio_rustls::{client::TlsStream, TlsConnector}; @@ -25,6 +25,8 @@ use webpki_roots::TLS_SERVER_ROOTS; use super::HttpStream; #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] use super::{AsyncHttpStream, AsyncTcpStream}; +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +use crate::connection::tls_config::TlsConfig; use crate::Error; #[cfg(feature = "rustls")] @@ -50,6 +52,13 @@ fn build_client_config() -> Arc { Arc::new(config) } +#[cfg(feature = "tokio-rustls")] +fn build_rustls_client_config(certificates: Arc) -> Arc { + let config = ClientConfig::builder().with_root_certificates(certificates).with_no_client_auth(); + + Arc::new(config) +} + #[cfg(feature = "rustls")] pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] @@ -69,10 +78,10 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; -#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +#[cfg(feature = "tokio-rustls")] pub(super) async fn wrap_async_stream( tcp: AsyncTcpStream, host: &str, @@ -93,6 +102,29 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } +#[cfg(feature = "tokio-rustls")] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + tls_config: TlsConfig, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + let dns_name = ServerName::try_from(host) + .map(|name| name.to_owned()) + .map_err(|err| Error::IoError(io::Error::new(io::ErrorKind::Other, err)))?; + let certificates = tls_config.certificates; + let client_config = build_rustls_client_config(certificates); + let connector = TlsConnector::from(client_config); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(dns_name, tcp).await.map_err(Error::IoError)?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} + #[cfg(all(feature = "native-tls", not(feature = "rustls")))] pub type SecuredStream = TlsStream; @@ -138,6 +170,18 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; +// ======= +// Temp method, required for compilation +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + _client_configs: Option<()>, +) -> Result { + wrap_async_stream(tcp, host).await +} +// ======= + #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] pub(super) async fn wrap_async_stream( tcp: AsyncTcpStream, diff --git a/bitreq/src/connection/tls_config.rs b/bitreq/src/connection/tls_config.rs new file mode 100644 index 000000000..5a0827387 --- /dev/null +++ b/bitreq/src/connection/tls_config.rs @@ -0,0 +1,68 @@ +use std::sync::Arc; + +#[cfg(feature = "rustls")] +use rustls::pki_types::CertificateDer; +#[cfg(feature = "rustls")] +use rustls::RootCertStore; +#[cfg(feature = "rustls-webpki")] +use webpki_roots::TLS_SERVER_ROOTS; + +use crate::Error; + +#[cfg(feature = "rustls")] +pub(crate) struct TlsConfigBuilder { + pub(crate) inner: RootCertStore, + pub(crate) disable_default: bool, +} + +#[cfg(feature = "tokio-rustls")] +impl TlsConfigBuilder { + pub(crate) fn new(cert_der: Option>) -> Result { + let mut tls_config = Self { inner: RootCertStore::empty(), disable_default: false }; + + if let Some(cert_der) = cert_der { + tls_config.append_certificate(cert_der)?; + } + + Ok(tls_config) + } + + pub(crate) fn append_certificate(&mut self, cert_der: Vec) -> Result<&mut Self, Error> { + self.inner.add(CertificateDer::from(cert_der)).map_err(Error::RustlsAppendCert)?; + + Ok(self) + } + + fn with_root_certificates(&mut self) -> &mut Self { + // Try to load native certs + #[cfg(feature = "https-rustls-probe")] + for cert in rustls_native_certs::load_native_certs().certs { + let _ = self.inner.add(cert); + } + + #[cfg(feature = "rustls-webpki")] + { + self.inner.extend(TLS_SERVER_ROOTS.iter().cloned()); + } + self + } + + pub(crate) fn disable_default_certificates(&mut self) -> Result<&mut Self, Error> { + self.disable_default = true; + Ok(self) + } + + pub(crate) fn build(mut self) -> Result { + if !self.disable_default { + self.with_root_certificates(); + } + + Ok(TlsConfig { certificates: Arc::new(self.inner) }) + } +} + +#[derive(Clone)] +#[cfg(feature = "rustls")] +pub(crate) struct TlsConfig { + pub(crate) certificates: Arc, +} diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index d923c90e0..0e28d1075 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -22,9 +22,15 @@ pub enum Error { #[cfg(feature = "rustls")] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), + #[cfg(feature = "rustls")] + /// Ran into a rustls error while appending a certificate. + RustlsAppendCert(rustls::Error), #[cfg(feature = "native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), + #[cfg(any(feature = "rustls", feature = "native-tls"))] + /// The current TLS configuration is invalid. + InvalidTlsConfig, /// Ran into an IO problem while loading the response. #[cfg(feature = "std")] IoError(io::Error), @@ -104,8 +110,12 @@ impl fmt::Display for Error { InvalidUtf8InBody(err) => write!(f, "{}", err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err), #[cfg(feature = "native-tls")] NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), + #[cfg(any(feature = "rustls", feature = "native-tls"))] + InvalidTlsConfig => write!(f, "error disabling default certificates. Must have custom cert."), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"), MalformedContentLength => write!(f, "non-usize content length"), @@ -147,6 +157,8 @@ impl error::Error for Error { InvalidUtf8InBody(err) => Some(err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => Some(err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => Some(err), _ => None, } } diff --git a/bitreq/src/request.rs b/bitreq/src/request.rs index d39d6d89a..24411bc0b 100644 --- a/bitreq/src/request.rs +++ b/bitreq/src/request.rs @@ -327,7 +327,7 @@ impl Request { #[cfg(feature = "async")] pub async fn send_async(self) -> Result { let parsed_request = ParsedRequest::new(self)?; - AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at) + AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at, None) .await? .send(parsed_request) .await diff --git a/bitreq/tests/ca_cert.der b/bitreq/tests/ca_cert.der new file mode 100644 index 0000000000000000000000000000000000000000..994da6aa5cb97333c4a5616c1e3238f9730efb5d GIT binary patch literal 1413 zcmXqLVr?{NV$NK^%*4pVB%+TM5?S>}Vh-&TC|5U}R)qXk=h%Y8EBVYYgNXT0prpu(63r3E5kWtPIReO#BQ6 zO-x)&O-zgo%Zlb$Wci35y>TM6NNd5WtGbRnsr}q)iSa9a*)ltNYk8xTFP|0ae5A0X zSY>_qHEx@0)yezsDqWIWv~QQsrkwS^wo4}n+{(Lr?Oz$Eu$SjT zqgd`++HSvkckN4o_?JtT*+)c`Y*f(s{&e^G7yi50OXjOfZo6Q0=cVP%|L=ZHWy_1+ zUHh>r^xhNwsZS2IYD~HB^|mma{rkszx8DDl#xOxBr?Rx|wky-RI3||$`3JvfH6}#w zRyg|chF>}7HsQ!H)s3Qm1HXMpSyrelD8E$XXxXhL&kLTn6s|h^;V|2=Shb+S%$S-Z z@`um22i6#_+VY#{tKT-JHy3p81RRk3?B=^YqLShAQ_++ft~D-S@8}#k7d!u;mBL!J zMT#GI-FB*dXx@6a@PVD*@h4A87rj_Ix9j<)XK!x_yqY_|+cJB*?SWPaak>4q70YKG z)>UFS@KxT*Zo`x^f360*2Z_2i%?+NsQ5kFGHnqF^IUKqdTdc#s>ypdi=XU3VJ$L+2 zw2+-)l3A3ny=u1k!}(`!?O(m^V)Me>)jiv)XLNoPZ(F(j^Xm6P-io~WhZBD7OwCtw z7VcT|>+^@*^VQ`-!?O>Z`?TazazKbpZS)!6*-Xrg42+9|4FU~hfvHuNk420{B)87~ z#D<+;WQ@ElCV#j6AE2x8ah-uYNLrai!a%G6y8?cY0%1nR|17Kq%s>h`*#UDCFxfFO zM6+yJ)%JZyu-B)}{TqI7)w0?0b^46hjAFz1tTuO>0|MbC#q+(F&Ym9G@KI()RqdLK zZ7{bPPj2=g52gkNg8Xa9W~kC2Q1!F zT6$i^J7TudZk?EZ*(-?wrB2AK*Lo(l WQ25g=wRK;m-z`%2PCl@u${YYQZGm5>c(WMyD( zV&rE4x`T_UiII_^cfy=T-EWmm)_y7}5!Sn!%T_GdxqTyF$>i1jkK}ecUYHX&ou8|t z{^^$w$sDpU>)WyfE*9*2Rr6qDMW~mo*nvBSEj?Tf&DA2A%f9HEeTq-Nar(r)m6g*TE;`0! zwsw#8)jN{mSIistYzbcX@Ob5!-JT1px84swUs3D#C+TV9t(g{&9;=ik?Nt$4{!sU! z?W?&j3pVb4&iP2RyU#yhO5D5DxtxU{}42+9|4FU~hf#EF6$0Eie!kaW}cH!6G!FTrT zGZ#D6EEb%w>ArzHNLrai!a%G6y8?cY0%1nR|17Kq%s>h`*nx2g40c8axz<|7NxZjaW=#Q|F67L5C@f415>;Jsbk-aPmpT7TnAaXRsA$MPGSGdX5FHbkS zFt$vHl(^Fxtgg0AdedN0son>iLjxFi9r?o@Fc>1v~5v=WI>sc+<&99q&?5=YF zi&N@T)~|glFY|qxvO9N4Kg+npA)(pB?>MI?0o}1f!rXacDTFlYgj1Pa= F0|4+yHh2I4 literal 0 HcmV?d00001 From 7d50fd37fd9c1b44e7e129110c43c3f58f74b96f Mon Sep 17 00:00:00 2001 From: airton Date: Fri, 20 Mar 2026 01:43:20 +0100 Subject: [PATCH 3/3] implement client builder for native-tls --- bitreq/src/client.rs | 16 ++++----- bitreq/src/connection.rs | 16 ++------- bitreq/src/connection/rustls_stream.rs | 35 ++++++++++-------- bitreq/src/connection/tls_config.rs | 50 +++++++++++++++++++++++++- bitreq/src/error.rs | 7 ++++ bitreq/tests/main.rs | 42 ++++++++++++++++++++++ 6 files changed, 129 insertions(+), 37 deletions(-) diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index 0fd793ddf..3be40a1cd 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -9,7 +9,7 @@ use std::collections::{hash_map, HashMap, VecDeque}; use std::sync::{Arc, Mutex}; -#[cfg(feature = "tokio-rustls")] +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] use crate::connection::tls_config::{TlsConfig, TlsConfigBuilder}; use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; @@ -17,13 +17,13 @@ use crate::{Error, Request, Response}; #[derive(Clone)] pub(crate) struct ClientConfig { - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] pub(crate) tls: Option, } pub struct ClientBuilder { capacity: usize, - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] tls_config: Option, } @@ -48,7 +48,7 @@ impl ClientBuilder { pub fn new() -> Self { Self { capacity: 10, - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] tls_config: None, } } @@ -59,8 +59,8 @@ impl ClientBuilder { self } - #[cfg(feature = "tokio-rustls")] /// Builds the `Client` with the configured settings. + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] pub fn build(self) -> Result { let build_config = if let Some(builder) = self.tls_config { let tls_config = builder.build()?; @@ -81,7 +81,7 @@ impl ClientBuilder { } /// Builds the `Client` with the configured settings. - #[cfg(not(feature = "tokio-rustls"))] + #[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))] pub fn build(self) -> Result { Ok(Client { r#async: Arc::new(Mutex::new(ClientImpl { @@ -111,7 +111,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] pub fn with_root_certificate>>(mut self, cert_der: T) -> Result { let cert_der = cert_der.into(); if let Some(ref mut tls_config) = self.tls_config { @@ -126,7 +126,7 @@ impl ClientBuilder { /// Disables default root certificates for TLS connections. /// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured. - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] pub fn disable_default_certificates(mut self) -> Result { match self.tls_config { Some(ref mut tls_config) => tls_config.disable_default_certificates()?, diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index 5d8e98ee4..0ccb1a4ac 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -33,7 +33,7 @@ type UnsecuredStream = TcpStream; #[cfg(any(feature = "rustls", feature = "native-tls"))] mod rustls_stream; -#[cfg(feature = "tokio-rustls")] +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] pub(crate) mod tls_config; #[cfg(any(feature = "rustls", feature = "native-tls"))] type SecuredStream = rustls_stream::SecuredStream; @@ -304,20 +304,8 @@ impl AsyncConnection { })))) } - // ======= - /// Temp method. Required to compile - #[cfg(all(feature = "tokio-native-tls", not(feature = "tokio-rustls")))] - async fn wrap_async_stream( - socket: AsyncTcpStream, - host: &str, - _client_config: &Option>, - ) -> Result { - rustls_stream::wrap_async_stream(socket, host).await - } - // ======= - /// Call the correct wrapper function depending on whether client_configs are present - #[cfg(feature = "tokio-rustls")] + #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] async fn wrap_async_stream( socket: AsyncTcpStream, host: &str, diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index b9bcb9643..cd4f0730a 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -25,7 +25,7 @@ use webpki_roots::TLS_SERVER_ROOTS; use super::HttpStream; #[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] use super::{AsyncHttpStream, AsyncTcpStream}; -#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))] use crate::connection::tls_config::TlsConfig; use crate::Error; @@ -167,22 +167,10 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; -// ======= -// Temp method, required for compilation #[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] -pub(super) async fn wrap_async_stream_with_configs( - tcp: AsyncTcpStream, - host: &str, - _client_configs: Option<()>, -) -> Result { - wrap_async_stream(tcp, host).await -} -// ======= - -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] pub(super) async fn wrap_async_stream( tcp: AsyncTcpStream, host: &str, @@ -206,3 +194,22 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } + +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + tls_config: TlsConfig, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + let async_connector = tls_config.connector; + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = async_connector.connect(host, tcp).await?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} diff --git a/bitreq/src/connection/tls_config.rs b/bitreq/src/connection/tls_config.rs index 5a0827387..dd20e2818 100644 --- a/bitreq/src/connection/tls_config.rs +++ b/bitreq/src/connection/tls_config.rs @@ -1,9 +1,13 @@ use std::sync::Arc; +#[cfg(not(feature = "rustls"))] +use native_tls::{Certificate, TlsConnector, TlsConnectorBuilder}; #[cfg(feature = "rustls")] use rustls::pki_types::CertificateDer; #[cfg(feature = "rustls")] use rustls::RootCertStore; +#[cfg(not(feature = "rustls"))] +use tokio_native_tls::TlsConnector as AsyncTlsConnector; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; @@ -15,7 +19,12 @@ pub(crate) struct TlsConfigBuilder { pub(crate) disable_default: bool, } -#[cfg(feature = "tokio-rustls")] +#[cfg(not(feature = "rustls"))] +pub(crate) struct TlsConfigBuilder { + pub(crate) inner: TlsConnectorBuilder, +} + +#[cfg(feature = "rustls")] impl TlsConfigBuilder { pub(crate) fn new(cert_der: Option>) -> Result { let mut tls_config = Self { inner: RootCertStore::empty(), disable_default: false }; @@ -61,8 +70,47 @@ impl TlsConfigBuilder { } } +#[cfg(not(feature = "rustls"))] +impl TlsConfigBuilder { + pub(crate) fn new(cert_der: Option>) -> Result { + let builder = TlsConnector::builder(); + let mut tls_config = Self { inner: builder }; + + if let Some(cert_der) = cert_der { + tls_config.append_certificate(cert_der)?; + } + + Ok(tls_config) + } + + pub(crate) fn append_certificate(&mut self, cert_der: Vec) -> Result<&mut Self, Error> { + let certificate = Certificate::from_der(&cert_der)?; + self.inner.add_root_certificate(certificate); + + Ok(self) + } + + pub(crate) fn disable_default_certificates(&mut self) -> Result<&mut Self, Error> { + self.inner.disable_built_in_roots(true); + Ok(self) + } + + pub(crate) fn build(self) -> Result { + let connector = self.inner.build()?; + let async_connector = AsyncTlsConnector::from(connector); + + Ok(TlsConfig { connector: Arc::new(async_connector) }) + } +} + #[derive(Clone)] #[cfg(feature = "rustls")] pub(crate) struct TlsConfig { pub(crate) certificates: Arc, } + +#[derive(Clone)] +#[cfg(not(feature = "rustls"))] +pub(crate) struct TlsConfig { + pub(crate) connector: Arc, +} diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index 0e28d1075..878609c77 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[cfg(feature = "native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), + #[cfg(feature = "native-tls")] + /// Ran into a native-tls error while appending a certificate. + NativeTlsAppendCert, #[cfg(any(feature = "rustls", feature = "native-tls"))] /// The current TLS configuration is invalid. InvalidTlsConfig, @@ -114,6 +117,8 @@ impl fmt::Display for Error { RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err), #[cfg(feature = "native-tls")] NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), + #[cfg(feature = "native-tls")] + NativeTlsAppendCert => write!(f, "error appending certificate"), #[cfg(any(feature = "rustls", feature = "native-tls"))] InvalidTlsConfig => write!(f, "error disabling default certificates. Must have custom cert."), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), @@ -159,6 +164,8 @@ impl error::Error for Error { RustlsCreateConnection(err) => Some(err), #[cfg(feature = "rustls")] RustlsAppendCert(err) => Some(err), + #[cfg(feature = "native-tls")] + NativeTlsCreateConnection(err) => Some(err), _ => None, } } diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 3f652abad..959ce2b3d 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -52,6 +52,15 @@ async fn test_https_with_client_builder() { assert_eq!(response.status_code, 200); } +#[tokio::test] +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] +async fn test_https_with_client_builder() { + setup(); + let client = bitreq::Client::builder().build().unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "tokio-rustls")] async fn test_https_with_client_builder_and_cert() { @@ -66,6 +75,39 @@ async fn test_https_with_client_builder_and_cert() { assert_eq!(response.status_code, 200); } +#[tokio::test] +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] +async fn test_https_with_client_builder_and_cert() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .build() + .unwrap(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "tokio-native-tls", not(feature = "rustls")))] +async fn test_https_with_multiple_certs() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let ca_der = include_bytes!("ca_cert.der"); + + let client = bitreq::Client::builder() + .with_root_certificate(cert_der.as_slice()) + .unwrap() + .with_root_certificate(ca_der.as_slice()) + .unwrap() + .build() + .unwrap(); + + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "tokio-rustls")] async fn test_https_with_multiple_certs() {