diff --git a/src/client/config.rs b/src/client/config.rs index fff68bc1..5640b4d2 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -38,6 +38,8 @@ pub struct Config { pub(crate) enum TrustConfig { #[allow(dead_code)] CaCertificateLocation(PathBuf), + #[allow(dead_code)] + CaCertificatePem(String), TrustAll, Default, } @@ -130,14 +132,19 @@ impl Config { /// storage (or use `trust_cert_ca` instead), using this setting is potentially dangerous. /// /// # Panics - /// Will panic in case `trust_cert_ca` was called before. + /// Will panic in case `trust_cert_ca` or `trust_cert_ca_pem` was called before. /// /// - Defaults to `default`, meaning server certificate is validated against system-truststore. pub fn trust_cert(&mut self) { - if let TrustConfig::CaCertificateLocation(_) = &self.trust { - panic!("'trust_cert' and 'trust_cert_ca' are mutual exclusive! Only use one.") - } - self.trust = TrustConfig::TrustAll; + self.trust = match &self.trust { + TrustConfig::Default | TrustConfig::TrustAll => TrustConfig::TrustAll, + TrustConfig::CaCertificatePem(_) => { + panic!("'trust_cert_ca_pem' and 'trust_cert' are mutual exclusive! Only use one.") + } + TrustConfig::CaCertificateLocation(_) => { + panic!("'trust_cert_ca' and 'trust_cert' are mutual exclusive! Only use one.") + } + }; } /// If set, the server certificate will be validated against the given CA certificate in @@ -146,15 +153,44 @@ impl Config { /// trust-chain. /// /// # Panics - /// Will panic in case `trust_cert` was called before. + /// Will panic in case `trust_cert` or `trust_cert_ca_pem` was called before. /// /// - Defaults to validating the server certificate is validated against system's certificate storage. pub fn trust_cert_ca(&mut self, path: impl ToString) { - if let TrustConfig::TrustAll = &self.trust { - panic!("'trust_cert' and 'trust_cert_ca' are mutual exclusive! Only use one.") - } else { - self.trust = TrustConfig::CaCertificateLocation(PathBuf::from(path.to_string())) - } + self.trust = match &self.trust { + TrustConfig::Default | TrustConfig::CaCertificateLocation(_) => { + TrustConfig::CaCertificateLocation(PathBuf::from(path.to_string())) + } + TrustConfig::CaCertificatePem(_) => { + panic!("'trust_cert_ca_pem' and 'trust_cert_ca' are mutual exclusive! Only use one.") + } + TrustConfig::TrustAll => { + panic!("'trust_cert' and 'trust_cert_ca' are mutual exclusive! Only use one.") + } + }; + } + + /// If set, the server certificate will be validated against the given CA certificate in + /// in addition to the system-truststore. + /// Useful when using self-signed certificates on the server without having to disable the + /// trust-chain. + /// + /// # Panics + /// Will panic in case `trust_cert` or `trust_cert_ca` was called before. + /// + /// - Defaults to validating the server certificate is validated against system's certificate storage. + pub fn trust_cert_ca_pem(&mut self, cert: impl ToString) { + self.trust = match &self.trust { + TrustConfig::Default | TrustConfig::CaCertificatePem(_) => { + TrustConfig::CaCertificatePem(cert.to_string()) + } + TrustConfig::CaCertificateLocation(_) => { + panic!("'trust_cert_ca' and 'trust_cert_ca_pem' are mutual exclusive! Only use one.") + } + TrustConfig::TrustAll => { + panic!("'trust_cert' and 'trust_cert_ca_pem' are mutual exclusive! Only use one.") + } + }; } /// Sets the authentication method. diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index cf5591d8..2e9fd918 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -15,6 +15,10 @@ pub(crate) async fn create_tls_stream( let mut builder = TlsConnector::new(); match &config.trust { + TrustConfig::CaCertificatePem(cert) => { + let cert = Certificate::from_pem(cert.as_bytes())?; + builder = builder.add_root_certificate(cert); + } TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { let cert = match path.extension() { diff --git a/src/client/tls_stream/opentls_tls_stream.rs b/src/client/tls_stream/opentls_tls_stream.rs index 1f028669..cf3889c9 100644 --- a/src/client/tls_stream/opentls_tls_stream.rs +++ b/src/client/tls_stream/opentls_tls_stream.rs @@ -15,6 +15,10 @@ pub(crate) async fn create_tls_stream( let mut builder = TlsConnector::new(); match &config.trust { + TrustConfig::CaCertificatePem(cert) => { + let cert = Certificate::from_pem(cert.as_bytes())?; + builder = builder.add_root_certificate(cert); + } TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { let cert = match path.extension() { diff --git a/src/client/tls_stream/rustls_tls_stream.rs b/src/client/tls_stream/rustls_tls_stream.rs index e417583a..1a8c8ffa 100644 --- a/src/client/tls_stream/rustls_tls_stream.rs +++ b/src/client/tls_stream/rustls_tls_stream.rs @@ -77,6 +77,23 @@ impl TlsStream { let builder = ClientConfig::builder().with_safe_defaults(); let client_config = match &config.trust { + TrustConfig::CaCertificatePem(cert) => { + let pem_cert = rustls_pemfile::certs(&mut cert.as_bytes())?; + if pem_cert.len() != 1 { + return Err(crate::Error::Io { + kind: IoErrorKind::InvalidInput, + message: format!("Proided PEM cannot contain more than 1 cert"), + }); + } + + let cert = Certificate(pem_cert.into_iter().next().unwrap()); + + let mut cert_store = RootCertStore::empty(); + cert_store.add(&cert)?; + builder + .with_root_certificates(cert_store) + .with_no_client_auth() + } TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { let cert = match path.extension() { diff --git a/tests/custom-cert.rs b/tests/custom-cert.rs index 8aeed94e..2bf78a13 100644 --- a/tests/custom-cert.rs +++ b/tests/custom-cert.rs @@ -1,5 +1,5 @@ -#![cfg(unix)] -use std::sync::Once; +#[allow(unused_imports)] +use std::{fs::File, io::Read, sync::Once}; use tiberius::{AuthMethod, Client, Config, EncryptionLevel, Result}; use tokio::{net::TcpStream, runtime::Runtime}; use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -102,3 +102,43 @@ fn connect_to_custom_cert_instance_without_ca() -> Result<()> { Ok(()) }) } + +#[test] +#[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" +))] +fn connect_to_customer_cert_instance_with_pem() -> Result<()> { + LOGGER_SETUP.call_once(|| { + env_logger::init(); + }); + let mut cert_file = File::open("docker/certs/customCA.crt")?; + let mut cert = String::new(); + cert_file.read_to_string(&mut cert)?; + + let rt = Runtime::new()?; + + rt.block_on(async { + let mut config = Config::new(); + config.host("localhost"); + config.port(1433); + config.trust_cert_ca_pem(cert); + config.encryption(EncryptionLevel::Required); + config.authentication(AuthMethod::sql_server("sa", "")); + + let tcp = TcpStream::connect(config.get_addr()).await?; + + let mut client = Client::connect(config, tcp.compat_write()).await?; + + let row = client + .query("SELECT @P1", &[&-4i32]) + .await? + .into_row() + .await? + .unwrap(); + + assert_eq!(Some(-4i32), row.get(0)); + Ok(()) + }) +}