diff --git a/Cargo.lock b/Cargo.lock index 1e5941c..38b29ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -968,6 +968,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "filetime" version = "0.2.26" @@ -1283,12 +1289,14 @@ dependencies = [ "futures", "hotfix-message", "mongodb", + "rcgen", "redb", "rustls", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types", "serde", + "tempfile", "testcontainers", "thiserror", "tokio", @@ -2260,6 +2268,16 @@ dependencies = [ "digest", ] +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2619,6 +2637,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redb" version = "3.1.0" @@ -3341,6 +3372,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "termtree" version = "0.5.1" @@ -4450,6 +4494,15 @@ dependencies = [ "rustix", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/crates/hotfix-web/CHANGELOG.md b/crates/hotfix-web/CHANGELOG.md index f66ed26..f4ccbe2 100644 --- a/crates/hotfix-web/CHANGELOG.md +++ b/crates/hotfix-web/CHANGELOG.md @@ -11,12 +11,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Other -- release ([#268](https://github.com/Validus-Risk-Management/hotfix/pull/268)) - -## [0.1.5](https://github.com/Validus-Risk-Management/hotfix/compare/hotfix-web-v0.1.4...hotfix-web-v0.1.5) - 2026-01-20 - -### Other - - updated the following local packages: hotfix, hotfix-web-ui ## [0.1.4](https://github.com/Validus-Risk-Management/hotfix/compare/hotfix-web-v0.1.3...hotfix-web-v0.1.4) - 2025-12-09 @@ -41,7 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- support reconnects in shutdowns initiated via CLI tool ([#250](https://github.com/Validus-Risk-Management/hotfix/pull/250)) +- support reconnects in shutdowns initiated via CLI + tool ([#250](https://github.com/Validus-Risk-Management/hotfix/pull/250)) ## [0.1.0](https://github.com/Validus-Risk-Management/hotfix/releases/tag/hotfix-web-v0.1.0) - 2025-11-26 diff --git a/crates/hotfix/Cargo.toml b/crates/hotfix/Cargo.toml index 58a2441..847be51 100644 --- a/crates/hotfix/Cargo.toml +++ b/crates/hotfix/Cargo.toml @@ -47,5 +47,8 @@ uuid = { workspace = true, features = ["v4"] } [dev-dependencies] hotfix-message = { version = "0.2.8", path = "../hotfix-message", features = ["fix44", "utils-chrono"] } +rcgen = "0.13" +rustls = { workspace = true, features = ["ring"] } +tempfile = "3" testcontainers = { workspace = true } tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/hotfix/clippy.toml b/crates/hotfix/clippy.toml new file mode 100644 index 0000000..c64d65f --- /dev/null +++ b/crates/hotfix/clippy.toml @@ -0,0 +1,3 @@ +allow-expect-in-tests = true +allow-panic-in-tests = true +allow-unwrap-in-tests = true diff --git a/crates/hotfix/src/config.rs b/crates/hotfix/src/config.rs index 8eeb820..858516b 100644 --- a/crates/hotfix/src/config.rs +++ b/crates/hotfix/src/config.rs @@ -19,9 +19,10 @@ pub struct Config { impl Config { /// Load a [Config] from a `toml` file. - pub fn load_from_path>(path: P) -> Self { - let config_str = fs::read_to_string(path).expect("to be able to load config"); - toml::from_str::(&config_str).expect("to be able to parse config") + pub fn load_from_path>(path: P) -> Result { + let config_str = fs::read_to_string(path)?; + let config = toml::from_str::(&config_str)?; + Ok(config) } } @@ -115,10 +116,22 @@ pub struct SessionConfig { pub schedule: Option, } +/// Errors that may occur when loading configuration. +#[derive(Debug, thiserror::Error)] +pub enum ConfigError { + #[error("failed to read config file")] + Io(#[from] std::io::Error), + + #[error("failed to parse config")] + Parse(#[from] toml::de::Error), +} + #[cfg(test)] mod tests { - use crate::config::{Config, TlsConfig}; + use crate::config::{Config, ConfigError, TlsConfig}; use chrono::{NaiveTime, Weekday}; + use std::io::Write; + use tempfile::NamedTempFile; #[test] fn test_simple_config() { @@ -425,4 +438,50 @@ end_day = "Friday" let session_config = config.sessions.first().unwrap(); assert_eq!(session_config.reconnect_interval, 15); } + + #[test] + fn test_load_from_path_success() { + let config_contents = r#" +[[sessions]] +begin_string = "FIX.4.4" +sender_comp_id = "sender" +target_comp_id = "target" +connection_host = "127.0.0.1" +connection_port = 9876 +heartbeat_interval = 30 +"#; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(config_contents.as_bytes()).unwrap(); + + let config = Config::load_from_path(temp_file.path()).unwrap(); + assert_eq!(config.sessions.len(), 1); + + let session = config.sessions.first().unwrap(); + assert_eq!(session.begin_string, "FIX.4.4"); + assert_eq!(session.sender_comp_id, "sender"); + assert_eq!(session.target_comp_id, "target"); + assert_eq!(session.connection_host, "127.0.0.1"); + assert_eq!(session.connection_port, 9876); + assert_eq!(session.heartbeat_interval, 30); + } + + #[test] + fn test_load_from_path_missing_file() { + let result = Config::load_from_path("/nonexistent/path/to/config.toml"); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConfigError::Io(_))); + } + + #[test] + fn test_load_from_path_invalid_toml() { + let invalid_toml = "this is not valid toml {{{{"; + + let mut temp_file = NamedTempFile::new().unwrap(); + temp_file.write_all(invalid_toml.as_bytes()).unwrap(); + + let result = Config::load_from_path(temp_file.path()); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConfigError::Parse(_))); + } } diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index 2f852b5..8f0e1c2 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -94,11 +94,21 @@ async fn establish_connection( completion_tx: watch::Sender, ) { loop { - session_ref.await_active_session_time().await; + if session_ref.await_active_session_time().await.is_err() { + warn!("session task terminated when checking active session time"); + break; + } match connect(&config, session_ref.clone()).await { Ok(conn) => { - session_ref.register_writer(conn.get_writer()).await; + if session_ref + .register_writer(conn.get_writer()) + .await + .is_err() + { + warn!("session task terminated when trying to register writer"); + break; + }; conn.run_until_disconnect().await; warn!("session connection dropped, attempting to reconnect"); } @@ -108,9 +118,18 @@ async fn establish_connection( } }; - if !session_ref.should_reconnect().await { - warn!("session indicated we shouldn't reconnect"); - break; + match session_ref.should_reconnect().await { + Ok(false) => { + warn!("session indicated we shouldn't reconnect"); + break; + } + Ok(true) => { + debug!("session indicated we should reconnect"); + } + Err(_) => { + warn!("session task terminated when making decision to reconnect"); + break; + } } let reconnect_interval = config.reconnect_interval; debug!("waiting for {reconnect_interval} seconds before attempting to reconnect"); diff --git a/crates/hotfix/src/lib.rs b/crates/hotfix/src/lib.rs index 169ab4d..296dace 100644 --- a/crates/hotfix/src/lib.rs +++ b/crates/hotfix/src/lib.rs @@ -20,6 +20,11 @@ //! //! Check out the [examples](https://github.com/Validus-Risk-Management/hotfix/tree/main/examples) //! to get started. + +#![deny(clippy::expect_used)] +#![deny(clippy::panic)] +#![deny(clippy::unwrap_used)] + pub mod application; pub mod config; pub(crate) mod error; diff --git a/crates/hotfix/src/message/reject.rs b/crates/hotfix/src/message/reject.rs index ddf99df..bf81ec0 100644 --- a/crates/hotfix/src/message/reject.rs +++ b/crates/hotfix/src/message/reject.rs @@ -78,7 +78,11 @@ impl OutboundMessage for Reject { impl InboundMessage for Reject { fn parse(message: &Message) -> Self { Self { - ref_seq_num: message.get(REF_SEQ_NUM).unwrap(), + // TODO: how do we handle errors in parsing messages? + #[allow(clippy::expect_used)] + ref_seq_num: message + .get(REF_SEQ_NUM) + .expect("ref_seq_num should be present"), ref_tag_id: message.get(REF_TAG_ID).ok(), ref_msg_type: message.get(REF_MSG_TYPE).ok(), session_reject_reason: message.get(SESSION_REJECT_REASON).ok(), diff --git a/crates/hotfix/src/message/sequence_reset.rs b/crates/hotfix/src/message/sequence_reset.rs index 152d9ed..2167fc0 100644 --- a/crates/hotfix/src/message/sequence_reset.rs +++ b/crates/hotfix/src/message/sequence_reset.rs @@ -16,7 +16,10 @@ impl OutboundMessage for SequenceReset { fn write(&self, msg: &mut Message) { msg.set(GAP_FILL_FLAG, self.gap_fill); msg.set(NEW_SEQ_NO, self.new_seq_no); - let sending_time: Timestamp = msg.header().get(SENDING_TIME).unwrap(); + #[allow(clippy::expect_used)] + let sending_time: Timestamp = msg.header().get(SENDING_TIME).expect( + "sending time should always be present due to previously having validated message", + ); msg.header_mut().set(ORIG_SENDING_TIME, sending_time); msg.header_mut().set(POSS_DUP_FLAG, true); } diff --git a/crates/hotfix/src/session/session_ref.rs b/crates/hotfix/src/session/session_ref.rs index 4c22b19..4bff916 100644 --- a/crates/hotfix/src/session/session_ref.rs +++ b/crates/hotfix/src/session/session_ref.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use thiserror::Error; use tokio::sync::{mpsc, oneshot}; use tracing::debug; @@ -42,44 +43,63 @@ impl InternalSessionRef { }) } - pub async fn register_writer(&self, writer: WriterRef) { + pub async fn register_writer(&self, writer: WriterRef) -> Result<(), SessionGone> { self.event_sender .send(SessionEvent::Connected(writer)) - .await - .expect("be able to register writer"); + .await?; + + Ok(()) } - pub async fn new_fix_message_received(&self, msg: RawFixMessage) { + pub async fn new_fix_message_received(&self, msg: RawFixMessage) -> Result<(), SessionGone> { self.event_sender .send(SessionEvent::FixMessageReceived(msg)) - .await - .expect("be able to receive message"); + .await?; + + Ok(()) } - pub async fn disconnect(&self, reason: String) { + pub async fn disconnect(&self, reason: String) -> Result<(), SessionGone> { self.event_sender .send(SessionEvent::Disconnected(reason)) - .await - .expect("be able to send disconnect"); + .await?; + + Ok(()) } - pub async fn should_reconnect(&self) -> bool { + pub async fn should_reconnect(&self) -> Result { let (sender, receiver) = oneshot::channel(); self.event_sender .send(SessionEvent::ShouldReconnect(sender)) - .await - .unwrap(); - receiver.await.expect("to receive a response") + .await?; + Ok(receiver.await?) } - pub async fn await_active_session_time(&self) { + pub async fn await_active_session_time(&self) -> Result<(), SessionGone> { debug!("awaiting active session time"); let (sender, receiver) = oneshot::channel::(); self.event_sender .send(SessionEvent::AwaitingActiveSession(sender)) - .await - .unwrap(); - receiver.await.expect("to receive a response"); + .await?; + receiver.await?; + debug!("resuming connection as session is active"); + Ok(()) + } +} + +#[derive(Debug, Error)] +#[error("session task terminated")] +pub struct SessionGone(String); + +impl From> for SessionGone { + fn from(err: mpsc::error::SendError) -> Self { + Self(err.to_string()) + } +} + +impl From for SessionGone { + fn from(err: oneshot::error::RecvError) -> Self { + Self(err.to_string()) } } diff --git a/crates/hotfix/src/transport.rs b/crates/hotfix/src/transport.rs index d8e0288..20e1a52 100644 --- a/crates/hotfix/src/transport.rs +++ b/crates/hotfix/src/transport.rs @@ -1,7 +1,9 @@ mod connection; +pub mod error; pub mod reader; pub mod socket; pub mod writer; pub use connection::FixConnection; +pub use error::{ConnectionError, ConnectionResult}; pub use socket::*; diff --git a/crates/hotfix/src/transport/error.rs b/crates/hotfix/src/transport/error.rs new file mode 100644 index 0000000..8ae32a9 --- /dev/null +++ b/crates/hotfix/src/transport/error.rs @@ -0,0 +1,11 @@ +/// Errors that may occur when establishing a connection. +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("IO error")] + IOError(#[from] std::io::Error), + + #[error("Invalid DNS name")] + InvalidDnsName(#[from] rustls_pki_types::InvalidDnsNameError), +} + +pub type ConnectionResult = Result; diff --git a/crates/hotfix/src/transport/reader.rs b/crates/hotfix/src/transport/reader.rs index 5d869c2..46ee36d 100644 --- a/crates/hotfix/src/transport/reader.rs +++ b/crates/hotfix/src/transport/reader.rs @@ -1,4 +1,5 @@ use tokio::sync::oneshot; +use tracing::warn; #[derive(Clone, Debug)] #[allow(dead_code)] @@ -14,8 +15,8 @@ impl ReaderRef { } pub async fn wait_for_disconnect(self) { - self.disconnect_signal - .await - .expect("not to drop signal prematurely"); + if self.disconnect_signal.await.is_err() { + warn!("reader dropped without issuing disconnect notification"); + } } } diff --git a/crates/hotfix/src/transport/socket.rs b/crates/hotfix/src/transport/socket.rs index 4bcf4b8..99481ab 100644 --- a/crates/hotfix/src/transport/socket.rs +++ b/crates/hotfix/src/transport/socket.rs @@ -3,11 +3,11 @@ pub mod socket_writer; pub mod tcp; pub mod tls; -use std::io; use tokio::io::{AsyncRead, AsyncWrite}; use crate::message::OutboundMessage; use crate::session::InternalSessionRef; +use crate::transport::error::ConnectionResult; use crate::{ config::SessionConfig, transport::{ @@ -20,14 +20,17 @@ use crate::{ pub async fn connect( config: &SessionConfig, session_ref: InternalSessionRef, -) -> io::Result { - let use_tls = config.tls_config.is_some(); - - let conn = if use_tls { - let stream = create_tcp_over_tls_connection(config).await?; +) -> ConnectionResult { + let conn = if let Some(tls_config) = config.tls_config.as_ref() { + let stream = create_tcp_over_tls_connection( + config.connection_host.to_owned(), + config.connection_port, + tls_config, + ) + .await?; _create_io_refs(session_ref.clone(), stream).await } else { - let stream = create_tcp_connection(config).await?; + let stream = create_tcp_connection(&config.connection_host, config.connection_port).await?; _create_io_refs(session_ref.clone(), stream).await }; diff --git a/crates/hotfix/src/transport/socket/socket_reader.rs b/crates/hotfix/src/transport/socket/socket_reader.rs index cc28b10..10c497d 100644 --- a/crates/hotfix/src/transport/socket/socket_reader.rs +++ b/crates/hotfix/src/transport/socket/socket_reader.rs @@ -49,30 +49,36 @@ where match actor.reader.read_buf(&mut buf).await { Ok(0) => { - actor + let _ = actor .session_ref .disconnect("received EOF".to_string()) .await; break; } Err(err) => { - actor.session_ref.disconnect(err.to_string()).await; + let _ = actor.session_ref.disconnect(err.to_string()).await; break; } Ok(_) => { let messages = parser.parse(&buf); for msg in messages { - actor.session_ref.new_fix_message_received(msg).await; + if actor + .session_ref + .new_fix_message_received(msg) + .await + .is_err() + { + debug!("reader received message but session has been terminated"); + } } } } } debug!("reader loop is shutting down"); - actor - .dc_sender - .send(()) - .expect("be able to signal disconnect"); + if actor.dc_sender.send(()).is_err() { + debug!("receiver dropped before we could notify them of reader disconnecting"); + } } #[cfg(test)] diff --git a/crates/hotfix/src/transport/socket/tcp.rs b/crates/hotfix/src/transport/socket/tcp.rs index e94d13c..2279690 100644 --- a/crates/hotfix/src/transport/socket/tcp.rs +++ b/crates/hotfix/src/transport/socket/tcp.rs @@ -1,12 +1,7 @@ use std::io; use tokio::net::TcpStream; -use crate::config::SessionConfig; - -pub async fn create_tcp_connection(session_config: &SessionConfig) -> io::Result { - let address = format!( - "{}:{}", - &session_config.connection_host, &session_config.connection_port - ); +pub async fn create_tcp_connection(host: &str, port: u16) -> io::Result { + let address = format!("{}:{}", host, port); TcpStream::connect(address).await } diff --git a/crates/hotfix/src/transport/socket/tls.rs b/crates/hotfix/src/transport/socket/tls.rs index 82cd842..3376689 100644 --- a/crates/hotfix/src/transport/socket/tls.rs +++ b/crates/hotfix/src/transport/socket/tls.rs @@ -1,6 +1,6 @@ +use std::fs; use std::io::BufReader; use std::sync::Arc; -use std::{fs, io}; use rustls::ClientConfig; use rustls::RootCertStore; @@ -9,40 +9,37 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio_rustls::{TlsConnector, client::TlsStream}; -use crate::config::{SessionConfig, TlsConfig}; +use crate::config::TlsConfig; +use crate::transport::error::ConnectionResult; use crate::transport::tcp::create_tcp_connection; pub async fn create_tcp_over_tls_connection( - session_config: &SessionConfig, -) -> io::Result> { - let tls_config = session_config - .tls_config - .as_ref() - .expect("TLS config must be present when creating TLS connection"); - let client_config = get_client_config(tls_config); - let socket = create_tcp_connection(session_config).await?; - wrap_stream( - socket, - session_config.connection_host.clone(), - Arc::new(client_config), - ) - .await + host: String, + port: u16, + tls_config: &TlsConfig, +) -> ConnectionResult> { + let client_config = get_client_config(tls_config)?; + let socket = create_tcp_connection(&host, port).await?; + wrap_stream(socket, host, Arc::new(client_config)).await } -fn get_client_config(tls_config: &TlsConfig) -> ClientConfig { - let root_store = get_root_store(tls_config); - ClientConfig::builder() +/// Create a TLS client configuration from the given TLS config. +pub fn get_client_config(tls_config: &TlsConfig) -> ConnectionResult { + let root_store = get_root_store(tls_config)?; + let client_config = ClientConfig::builder() .with_root_certificates(root_store) - .with_no_client_auth() + .with_no_client_auth(); + + Ok(client_config) } -fn get_root_store(tls_config: &TlsConfig) -> RootCertStore { - match tls_config { +fn get_root_store(tls_config: &TlsConfig) -> ConnectionResult { + let store = match tls_config { TlsConfig::File { ca_certificate_path, } => { let mut root_store = RootCertStore::empty(); - let certs = load_certs_from_file(ca_certificate_path); + let certs = load_certs_from_file(ca_certificate_path)?; root_store.add_parsable_certificates(certs); root_store } @@ -55,26 +52,28 @@ fn get_root_store(tls_config: &TlsConfig) -> RootCertStore { TlsConfig::Webpki => { RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()) } - } + }; + + Ok(store) } -fn load_certs_from_file(filename: &str) -> Vec> { - let certfile = fs::File::open(filename).expect("certificate file to be open"); +fn load_certs_from_file(filename: &str) -> ConnectionResult>> { + let certfile = fs::File::open(filename)?; let mut reader = BufReader::new(certfile); - rustls_pemfile::certs(&mut reader) - .map(|result| result.unwrap()) - .collect() + let certs = rustls_pemfile::certs(&mut reader).collect::, _>>()?; + + Ok(certs) } pub async fn wrap_stream( socket: S, domain: String, config: Arc, -) -> io::Result> +) -> ConnectionResult> where S: 'static + AsyncRead + AsyncWrite + Send + Unpin, { - let domain = ServerName::try_from(domain).unwrap(); + let domain = ServerName::try_from(domain)?; let stream = TlsConnector::from(config); - stream.connect(domain, socket).await + Ok(stream.connect(domain, socket).await?) } diff --git a/crates/hotfix/tests/connection_test_cases/connect_tests.rs b/crates/hotfix/tests/connection_test_cases/connect_tests.rs new file mode 100644 index 0000000..4e67710 --- /dev/null +++ b/crates/hotfix/tests/connection_test_cases/connect_tests.rs @@ -0,0 +1,138 @@ +//! Integration tests for the transport layer's `connect` function. +//! +//! These tests verify that `hotfix::transport::socket::connect` correctly +//! establishes connections over TCP and TLS. + +use crate::helpers::{ + MinimalApplication, MinimalMessage, TestCertificates, TestTcpServer, TestTlsServer, +}; +use hotfix::config::{SessionConfig, TlsConfig}; +use hotfix::session::InternalSessionRef; +use hotfix::store::in_memory::InMemoryMessageStore; +use hotfix::transport::socket::connect; + +fn create_session_config(host: &str, port: u16, tls_config: Option) -> SessionConfig { + SessionConfig { + begin_string: "FIX.4.4".to_string(), + sender_comp_id: "TEST_SENDER".to_string(), + target_comp_id: "TEST_TARGET".to_string(), + data_dictionary_path: None, + connection_host: host.to_string(), + connection_port: port, + tls_config, + heartbeat_interval: 30, + logon_timeout: 10, + logout_timeout: 2, + reconnect_interval: 30, + reset_on_logon: false, + schedule: None, + } +} + +fn create_session_ref() -> InternalSessionRef { + let store = InMemoryMessageStore::default(); + let app = MinimalApplication; + InternalSessionRef::new(create_session_config("", 0, None), app, store) + .expect("Failed to create session ref") +} + +#[tokio::test] +async fn test_connect_with_tls_config() { + // Generate certificates and start TLS server + let certs = TestCertificates::generate(&["localhost"]); + let ca_file = certs.write_ca_to_temp_file(); + let server = TestTlsServer::start(&certs).await; + + // Create session config with TLS + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + let config = create_session_config("localhost", server.port(), Some(tls_config)); + + // Create a session ref + let session_ref = create_session_ref(); + + // Call connect - this should establish a TLS connection and return a FixConnection + let result = connect(&config, session_ref).await; + + assert!( + result.is_ok(), + "connect() with TLS config should succeed: {:?}", + result.err() + ); + + let connection = result.unwrap(); + + // Verify we got a valid connection by checking we can get a writer + let _writer = connection.get_writer(); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_connect_without_tls_config() { + // Start a plain TCP server + let server = TestTcpServer::start().await; + + // Create session config without TLS + let config = create_session_config("127.0.0.1", server.port(), None); + + // Create a session ref + let session_ref = create_session_ref(); + + // Call connect - this should establish a plain TCP connection + let result = connect(&config, session_ref).await; + + assert!( + result.is_ok(), + "connect() without TLS config should succeed: {:?}", + result.err() + ); + + let connection = result.unwrap(); + + // Verify we got a valid connection + let _writer = connection.get_writer(); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_connect_with_tls_fails_on_bad_certificate() { + // Generate two different certificate sets + let server_certs = TestCertificates::generate(&["localhost"]); + let client_certs = TestCertificates::generate(&["localhost"]); // Different CA + let ca_file = client_certs.write_ca_to_temp_file(); + + // Start server with its own certificates + let server = TestTlsServer::start(&server_certs).await; + + // Client trusts a different CA - connection should fail + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + let config = create_session_config("localhost", server.port(), Some(tls_config)); + + let session_ref = create_session_ref(); + + let result = connect(&config, session_ref).await; + + assert!(result.is_err(), "connect() with untrusted CA should fail"); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_connect_fails_when_server_not_running() { + // Try to connect to a port where nothing is listening + let config = create_session_config("127.0.0.1", 59998, None); + + let session_ref = create_session_ref(); + + let result = connect(&config, session_ref).await; + + assert!( + result.is_err(), + "connect() to non-existent server should fail" + ); +} diff --git a/crates/hotfix/tests/connection_test_cases/helpers.rs b/crates/hotfix/tests/connection_test_cases/helpers.rs new file mode 100644 index 0000000..02e6fb5 --- /dev/null +++ b/crates/hotfix/tests/connection_test_cases/helpers.rs @@ -0,0 +1,362 @@ +//! Test helpers for connection-related integration tests. +//! +//! Provides utilities for generating test certificates, spinning up local servers, +//! and minimal Application implementations for testing. + +use std::io::Write; +use std::net::SocketAddr; +use std::sync::{Arc, Once}; + +use hotfix::Application; +use hotfix::application::{InboundDecision, OutboundDecision}; +use hotfix::message::{InboundMessage, OutboundMessage}; +use hotfix_message::message::Message; +use rcgen::{CertificateParams, DnType, IsCa, KeyPair, KeyUsagePurpose, SanType}; +use rustls::ServerConfig; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tempfile::NamedTempFile; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::oneshot; +use tokio_rustls::TlsAcceptor; + +static CRYPTO_PROVIDER_INIT: Once = Once::new(); + +/// Initialize the rustls crypto provider for tests. +/// This must be called before any TLS operations. +pub fn init_crypto_provider() { + CRYPTO_PROVIDER_INIT.call_once(|| { + rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install crypto provider"); + }); +} + +/// A set of test certificates including a CA and server certificate. +pub struct TestCertificates { + /// The CA certificate in PEM format. + pub ca_cert_pem: String, + /// The server certificate in DER format. + pub server_cert_der: CertificateDer<'static>, + /// The server private key in DER format. + pub server_key_der: PrivateKeyDer<'static>, +} + +impl TestCertificates { + /// Generate a new set of test certificates. + /// + /// Creates a self-signed CA certificate and a server certificate signed by that CA. + /// The server certificate will be valid for the specified domain names. + pub fn generate(domains: &[&str]) -> Self { + // Ensure crypto provider is initialized + init_crypto_provider(); + + // Generate CA key pair + let ca_key_pair = KeyPair::generate().expect("Failed to generate CA key pair"); + + // Create CA certificate parameters + let mut ca_params = CertificateParams::default(); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(DnType::CommonName, "Test CA"); + ca_params + .distinguished_name + .push(DnType::OrganizationName, "Test Organization"); + ca_params.key_usages = vec![ + KeyUsagePurpose::KeyCertSign, + KeyUsagePurpose::CrlSign, + KeyUsagePurpose::DigitalSignature, + ]; + + // Generate the CA certificate + let ca_cert = ca_params + .self_signed(&ca_key_pair) + .expect("Failed to generate CA certificate"); + + // Generate server key pair + let server_key_pair = KeyPair::generate().expect("Failed to generate server key pair"); + + // Create server certificate parameters + let mut server_params = CertificateParams::default(); + server_params + .distinguished_name + .push(DnType::CommonName, *domains.first().unwrap_or(&"localhost")); + server_params + .distinguished_name + .push(DnType::OrganizationName, "Test Organization"); + + // Add Subject Alternative Names for all domains + server_params.subject_alt_names = domains + .iter() + .map(|d| { + // Try to parse as IP address first + if let Ok(ip) = d.parse::() { + SanType::IpAddress(ip) + } else { + SanType::DnsName((*d).try_into().expect("Invalid DNS name")) + } + }) + .collect(); + + server_params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + + // Sign the server certificate with the CA + let server_cert = server_params + .signed_by(&server_key_pair, &ca_cert, &ca_key_pair) + .expect("Failed to sign server certificate"); + + TestCertificates { + ca_cert_pem: ca_cert.pem(), + server_cert_der: CertificateDer::from(server_cert.der().to_vec()), + server_key_der: PrivateKeyDer::try_from(server_key_pair.serialize_der()) + .expect("Failed to convert server key"), + } + } + + /// Write the CA certificate to a temporary file and return the file. + /// + /// The returned `NamedTempFile` will keep the file alive as long as it exists. + pub fn write_ca_to_temp_file(&self) -> NamedTempFile { + let mut temp_file = + NamedTempFile::new().expect("Failed to create temporary file for CA cert"); + temp_file + .write_all(self.ca_cert_pem.as_bytes()) + .expect("Failed to write CA cert to temp file"); + temp_file.flush().expect("Failed to flush temp file"); + temp_file + } + + /// Create a rustls ServerConfig from this certificate set. + pub fn server_config(&self) -> ServerConfig { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + vec![self.server_cert_der.clone()], + self.server_key_der.clone_key(), + ) + .expect("Failed to create server config") + } +} + +/// A test TLS server that can be used for integration testing. +pub struct TestTlsServer { + /// The address the server is listening on. + pub addr: SocketAddr, + /// Channel to signal the server to shut down. + shutdown_tx: Option>, + /// Handle to the server task. + task_handle: Option>, +} + +impl TestTlsServer { + /// Start a new TLS server with the given certificates. + /// + /// The server will listen on a random available port on localhost. + /// It echoes back any data it receives. + pub async fn start(certs: &TestCertificates) -> Self { + Self::start_with_behavior(certs, ServerBehavior::Echo).await + } + + /// Start a new TLS server with specified behavior. + pub async fn start_with_behavior(certs: &TestCertificates, behavior: ServerBehavior) -> Self { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind TCP listener"); + let addr = listener.local_addr().expect("Failed to get local address"); + + let server_config = certs.server_config(); + let acceptor = TlsAcceptor::from(Arc::new(server_config)); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + + let task_handle = tokio::spawn(async move { + loop { + tokio::select! { + result = listener.accept() => { + match result { + Ok((tcp_stream, _peer_addr)) => { + let acceptor = acceptor.clone(); + tokio::spawn(async move { + match behavior { + ServerBehavior::Echo => { + if let Ok(mut tls_stream) = acceptor.accept(tcp_stream).await { + let mut buf = [0u8; 1024]; + while let Ok(n) = tls_stream.read(&mut buf).await { + if n == 0 { + break; + } + let _ = tls_stream.write_all(&buf[..n]).await; + } + } + } + ServerBehavior::CloseImmediately => { + // Just drop the connection without completing TLS handshake + drop(tcp_stream); + } + } + }); + } + Err(_) => break, + } + } + _ = &mut shutdown_rx => { + break; + } + } + } + }); + + TestTlsServer { + addr, + shutdown_tx: Some(shutdown_tx), + task_handle: Some(task_handle), + } + } + + /// Get the port the server is listening on. + pub fn port(&self) -> u16 { + self.addr.port() + } + + /// Shutdown the server gracefully. + pub async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Some(handle) = self.task_handle.take() { + let _ = handle.await; + } + } +} + +impl Drop for TestTlsServer { + fn drop(&mut self) { + // Signal shutdown if not already done + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +/// Describes how the test server should behave. +#[derive(Clone, Copy, Debug)] +pub enum ServerBehavior { + /// Echo back any received data (normal operation). + Echo, + /// Close the connection immediately after TCP accept, before TLS handshake. + CloseImmediately, +} + +/// A test TCP server (without TLS) for integration testing. +pub struct TestTcpServer { + pub addr: SocketAddr, + shutdown_tx: Option>, + task_handle: Option>, +} + +impl TestTcpServer { + /// Start a new TCP echo server. + pub async fn start() -> Self { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind TCP listener"); + let addr = listener.local_addr().expect("Failed to get local address"); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + + let task_handle = tokio::spawn(async move { + loop { + tokio::select! { + result = listener.accept() => { + match result { + Ok((mut tcp_stream, _)) => { + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + while let Ok(n) = tcp_stream.read(&mut buf).await { + if n == 0 { + break; + } + let _ = tcp_stream.write_all(&buf[..n]).await; + } + }); + } + Err(_) => break, + } + } + _ = &mut shutdown_rx => { + break; + } + } + } + }); + + TestTcpServer { + addr, + shutdown_tx: Some(shutdown_tx), + task_handle: Some(task_handle), + } + } + + pub fn port(&self) -> u16 { + self.addr.port() + } + + pub async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Some(handle) = self.task_handle.take() { + let _ = handle.await; + } + } +} + +impl Drop for TestTcpServer { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +/// A minimal message type for testing that doesn't require fix44 types. +#[derive(Debug, Clone)] +pub struct MinimalMessage; + +impl OutboundMessage for MinimalMessage { + fn write(&self, _msg: &mut Message) { + // No-op for minimal test message + } + + fn message_type(&self) -> &str { + "0" // Heartbeat type, simplest message + } +} + +impl InboundMessage for MinimalMessage { + fn parse(_message: &Message) -> Self { + MinimalMessage + } +} + +/// A minimal Application implementation for testing transport connectivity. +pub struct MinimalApplication; + +#[async_trait::async_trait] +impl Application for MinimalApplication { + async fn on_outbound_message(&self, _msg: &MinimalMessage) -> OutboundDecision { + OutboundDecision::Send + } + + async fn on_inbound_message(&self, _msg: MinimalMessage) -> InboundDecision { + InboundDecision::Accept + } + + async fn on_logout(&mut self, _reason: &str) {} + + async fn on_logon(&mut self) {} +} diff --git a/crates/hotfix/tests/connection_test_cases/mod.rs b/crates/hotfix/tests/connection_test_cases/mod.rs new file mode 100644 index 0000000..af78ab0 --- /dev/null +++ b/crates/hotfix/tests/connection_test_cases/mod.rs @@ -0,0 +1,3 @@ +mod connect_tests; +pub(crate) mod helpers; +mod tls_tests; diff --git a/crates/hotfix/tests/connection_test_cases/tls_tests.rs b/crates/hotfix/tests/connection_test_cases/tls_tests.rs new file mode 100644 index 0000000..a29f09d --- /dev/null +++ b/crates/hotfix/tests/connection_test_cases/tls_tests.rs @@ -0,0 +1,386 @@ +//! Integration tests for TLS transport functionality. +//! +//! These tests verify the TLS connection logic in `crates/hotfix/src/transport/socket/tls.rs`. + +use std::sync::Arc; + +use hotfix::config::TlsConfig; +use hotfix::transport::error::ConnectionError; +use hotfix::transport::socket::tls::{create_tcp_over_tls_connection, wrap_stream}; +use rustls::ClientConfig; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use crate::helpers::{ServerBehavior, TestCertificates, TestTlsServer, init_crypto_provider}; + +#[tokio::test] +async fn test_tls_connection_with_file_config_succeeds() { + // Generate test certificates valid for localhost + let certs = TestCertificates::generate(&["localhost"]); + let ca_file = certs.write_ca_to_temp_file(); + + // Start a TLS server + let server = TestTlsServer::start(&certs).await; + + // Create TLS config using the CA file + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + // Connect to the server + let mut stream = + create_tcp_over_tls_connection("localhost".to_string(), server.port(), &tls_config) + .await + .expect("TLS connection should succeed"); + + // Verify the connection works by sending and receiving data + let test_data = b"Hello, TLS!"; + stream + .write_all(test_data) + .await + .expect("Write should succeed"); + + let mut buf = vec![0u8; test_data.len()]; + stream + .read_exact(&mut buf) + .await + .expect("Read should succeed"); + + assert_eq!(&buf, test_data); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_tls_connection_with_ip_address() { + // Generate test certificates valid for 127.0.0.1 + let certs = TestCertificates::generate(&["127.0.0.1"]); + let ca_file = certs.write_ca_to_temp_file(); + + // Start a TLS server + let server = TestTlsServer::start(&certs).await; + + // Create TLS config + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + // Connect using IP address + let mut stream = + create_tcp_over_tls_connection("127.0.0.1".to_string(), server.port(), &tls_config) + .await + .expect("TLS connection with IP should succeed"); + + // Verify connection works + let test_data = b"IP address test"; + stream + .write_all(test_data) + .await + .expect("Write should succeed"); + + let mut buf = vec![0u8; test_data.len()]; + stream + .read_exact(&mut buf) + .await + .expect("Read should succeed"); + + assert_eq!(&buf, test_data); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_wrap_stream_with_valid_config() { + // Generate test certificates + let certs = TestCertificates::generate(&["localhost"]); + let ca_file = certs.write_ca_to_temp_file(); + + // Start a TLS server + let server = TestTlsServer::start(&certs).await; + + // Establish raw TCP connection + let tcp_stream = TcpStream::connect(server.addr) + .await + .expect("TCP connection should succeed"); + + // Build client config manually using the CA + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + let client_config = hotfix::transport::socket::tls::get_client_config(&tls_config) + .expect("Client config should be created"); + + // Wrap the stream + let mut tls_stream = wrap_stream(tcp_stream, "localhost".to_string(), Arc::new(client_config)) + .await + .expect("wrap_stream should succeed"); + + // Verify the wrapped stream works + let test_data = b"Wrapped stream test"; + tls_stream + .write_all(test_data) + .await + .expect("Write should succeed"); + + let mut buf = vec![0u8; test_data.len()]; + tls_stream + .read_exact(&mut buf) + .await + .expect("Read should succeed"); + + assert_eq!(&buf, test_data); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_tls_connection_fails_with_untrusted_ca() { + // Generate two separate certificate sets - server will use one, client trusts another + let server_certs = TestCertificates::generate(&["localhost"]); + let untrusted_certs = TestCertificates::generate(&["localhost"]); + let untrusted_ca_file = untrusted_certs.write_ca_to_temp_file(); + + // Start server with its own certificates + let server = TestTlsServer::start(&server_certs).await; + + // Client trusts a different CA + let tls_config = TlsConfig::File { + ca_certificate_path: untrusted_ca_file.path().to_string_lossy().to_string(), + }; + + // Connection should fail due to untrusted certificate + let result = + create_tcp_over_tls_connection("localhost".to_string(), server.port(), &tls_config).await; + + assert!(result.is_err(), "Connection should fail with untrusted CA"); + match result.unwrap_err() { + ConnectionError::IOError(e) => { + let error_string = e.to_string(); + assert!( + error_string.contains("certificate") || error_string.contains("invalid"), + "Error should mention certificate issue: {error_string}" + ); + } + other => panic!("Expected IOError, got: {other:?}"), + } + + server.shutdown().await; +} + +#[tokio::test] +async fn test_tls_connection_fails_with_hostname_mismatch() { + // Generate certificate only valid for "other-host.example.com" + let certs = TestCertificates::generate(&["other-host.example.com"]); + let ca_file = certs.write_ca_to_temp_file(); + + // Start server + let server = TestTlsServer::start(&certs).await; + + // Try to connect using "localhost" - hostname won't match certificate + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + let result = + create_tcp_over_tls_connection("localhost".to_string(), server.port(), &tls_config).await; + + assert!( + result.is_err(), + "Connection should fail with hostname mismatch" + ); + + server.shutdown().await; +} + +#[tokio::test] +async fn test_wrap_stream_invalid_dns_name_empty_string() { + init_crypto_provider(); + + // Create a mock TCP stream (we won't actually connect, just test DNS name validation) + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("Bind should succeed"); + let addr = listener.local_addr().expect("Should have local addr"); + + // Connect to our own listener + let tcp_stream = TcpStream::connect(addr) + .await + .expect("TCP connect should succeed"); + + // Create a minimal client config + let client_config = ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + + // Try to wrap with empty domain name + let result = wrap_stream(tcp_stream, "".to_string(), Arc::new(client_config)).await; + + assert!(result.is_err(), "Empty domain should fail"); + match result.unwrap_err() { + ConnectionError::InvalidDnsName(_) => {} + other => panic!("Expected InvalidDnsName error, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_wrap_stream_invalid_dns_name_formats() { + init_crypto_provider(); + + let invalid_domains = vec![ + "", // Empty + " ", // Whitespace only + "host name", // Contains space + "-invalid.com", // Starts with hyphen + "invalid-.com", // Ends with hyphen + "a]bad[name.com", // Invalid characters + ]; + + for domain in invalid_domains { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("Bind should succeed"); + let addr = listener.local_addr().expect("Should have local addr"); + + let tcp_stream = TcpStream::connect(addr) + .await + .expect("TCP connect should succeed"); + + let client_config = ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + + let result = wrap_stream(tcp_stream, domain.to_string(), Arc::new(client_config)).await; + + assert!( + result.is_err(), + "Domain '{domain}' should fail DNS validation" + ); + match result.unwrap_err() { + ConnectionError::InvalidDnsName(_) => {} + other => panic!("Expected InvalidDnsName error for '{domain}', got: {other:?}"), + } + } +} + +#[tokio::test] +async fn test_tls_connection_to_nonexistent_server() { + let certs = TestCertificates::generate(&["localhost"]); + let ca_file = certs.write_ca_to_temp_file(); + + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + // Try to connect to a port where nothing is listening + // Use a high port number that's unlikely to be in use + let result = create_tcp_over_tls_connection("localhost".to_string(), 59999, &tls_config).await; + + assert!( + result.is_err(), + "Connection to nonexistent server should fail" + ); + match result.unwrap_err() { + ConnectionError::IOError(_) => {} + other => panic!("Expected IOError, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_tls_connection_refused() { + // Bind a port but don't accept connections, then close it + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("Bind should succeed"); + let port = listener + .local_addr() + .expect("Should have local addr") + .port(); + drop(listener); // Close the listener immediately + + let certs = TestCertificates::generate(&["127.0.0.1"]); + let ca_file = certs.write_ca_to_temp_file(); + + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + // Try to connect to the closed port + let result = create_tcp_over_tls_connection("127.0.0.1".to_string(), port, &tls_config).await; + + assert!(result.is_err(), "Connection to closed port should fail"); + match result.unwrap_err() { + ConnectionError::IOError(_) => {} + other => panic!("Expected IOError, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_tls_config_native_root_store() { + init_crypto_provider(); + + // Test that Native config successfully creates a client config + let tls_config = TlsConfig::Native; + let result = hotfix::transport::socket::tls::get_client_config(&tls_config); + + // Should succeed in creating the config (even if the root store may be empty on some systems) + assert!( + result.is_ok(), + "Native root store config should be created successfully" + ); +} + +#[tokio::test] +async fn test_tls_config_webpki_root_store() { + init_crypto_provider(); + + // Test that Webpki config successfully creates a client config + let tls_config = TlsConfig::Webpki; + let result = hotfix::transport::socket::tls::get_client_config(&tls_config); + + // Should succeed - webpki-roots provides bundled certificates + assert!( + result.is_ok(), + "Webpki root store config should be created successfully" + ); +} + +#[tokio::test] +async fn test_tls_config_file_with_nonexistent_path() { + init_crypto_provider(); + + let tls_config = TlsConfig::File { + ca_certificate_path: "/nonexistent/path/to/ca.pem".to_string(), + }; + + let result = hotfix::transport::socket::tls::get_client_config(&tls_config); + + assert!(result.is_err(), "Nonexistent CA file should fail"); + match result.unwrap_err() { + ConnectionError::IOError(_) => {} + other => panic!("Expected IOError for nonexistent file, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_server_closes_after_tcp_accept() { + let certs = TestCertificates::generate(&["localhost"]); + let ca_file = certs.write_ca_to_temp_file(); + + // Start server that closes connections immediately + let server = TestTlsServer::start_with_behavior(&certs, ServerBehavior::CloseImmediately).await; + + let tls_config = TlsConfig::File { + ca_certificate_path: ca_file.path().to_string_lossy().to_string(), + }; + + // Connection should fail when server closes during handshake + let result = + create_tcp_over_tls_connection("localhost".to_string(), server.port(), &tls_config).await; + + assert!( + result.is_err(), + "Connection should fail when server closes during handshake" + ); + + server.shutdown().await; +} diff --git a/crates/hotfix/tests/connection_tests.rs b/crates/hotfix/tests/connection_tests.rs new file mode 100644 index 0000000..a7812c6 --- /dev/null +++ b/crates/hotfix/tests/connection_tests.rs @@ -0,0 +1,3 @@ +mod connection_test_cases; + +use connection_test_cases::helpers; diff --git a/crates/hotfix/tests/common/actions.rs b/crates/hotfix/tests/session_test_cases/common/actions.rs similarity index 93% rename from crates/hotfix/tests/common/actions.rs rename to crates/hotfix/tests/session_test_cases/common/actions.rs index a862f0b..5b164cc 100644 --- a/crates/hotfix/tests/common/actions.rs +++ b/crates/hotfix/tests/session_test_cases/common/actions.rs @@ -63,7 +63,10 @@ impl When<&mut FakeCounterparty> { } pub async fn gets_reconnected(&mut self, reset_store: bool) { - self.target.reconnect(reset_store).await; + self.target + .reconnect(reset_store) + .await + .expect("failed to check whether session reconnected"); } } diff --git a/crates/hotfix/tests/common/assertions.rs b/crates/hotfix/tests/session_test_cases/common/assertions.rs similarity index 100% rename from crates/hotfix/tests/common/assertions.rs rename to crates/hotfix/tests/session_test_cases/common/assertions.rs diff --git a/crates/hotfix/tests/common/cleanup.rs b/crates/hotfix/tests/session_test_cases/common/cleanup.rs similarity index 100% rename from crates/hotfix/tests/common/cleanup.rs rename to crates/hotfix/tests/session_test_cases/common/cleanup.rs diff --git a/crates/hotfix/tests/common/fakes/fake_application.rs b/crates/hotfix/tests/session_test_cases/common/fakes/fake_application.rs similarity index 100% rename from crates/hotfix/tests/common/fakes/fake_application.rs rename to crates/hotfix/tests/session_test_cases/common/fakes/fake_application.rs diff --git a/crates/hotfix/tests/common/fakes/fake_counterparty.rs b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs similarity index 93% rename from crates/hotfix/tests/common/fakes/fake_counterparty.rs rename to crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs index 33cb643..d4738c8 100644 --- a/crates/hotfix/tests/common/fakes/fake_counterparty.rs +++ b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use hotfix::config::SessionConfig; use hotfix::message::logon::{Logon, ResetSeqNumConfig}; use hotfix::message::sequence_reset::SequenceReset; @@ -33,16 +34,16 @@ where pub async fn start( session_ref: InternalSessionRef, session_config: SessionConfig, - ) -> Self { + ) -> Result { let (writer_ref, receiver) = Self::create_writer(); let (reader_ref, dc_sender) = Self::create_reader(); let connection = FixConnection::new(writer_ref, reader_ref); let message_config = MessageConfig::default(); - let message_builder = MessageBuilder::new(Dictionary::fix44(), message_config).unwrap(); + let message_builder = MessageBuilder::new(Dictionary::fix44(), message_config)?; - session_ref.register_writer(connection.get_writer()).await; + session_ref.register_writer(connection.get_writer()).await?; - Self { + let fake_counterparty = Self { receiver, received_messages: vec![], sent_messages: vec![], @@ -52,10 +53,12 @@ where message_config, _connection: connection, _dc_sender: dc_sender, - } + }; + + Ok(fake_counterparty) } - pub async fn reconnect(&mut self, reset_store: bool) { + pub async fn reconnect(&mut self, reset_store: bool) -> Result<()> { let (writer_ref, receiver) = Self::create_writer(); let (reader_ref, dc_sender) = Self::create_reader(); let connection = FixConnection::new(writer_ref, reader_ref); @@ -64,11 +67,13 @@ where self._dc_sender = dc_sender; self.session_ref .register_writer(connection.get_writer()) - .await; + .await?; if reset_store { self.sent_messages.clear(); } + + Ok(()) } pub async fn push_previously_sent_message(&mut self, message: impl OutboundMessage) { @@ -94,7 +99,8 @@ where if skip_updates { self.session_ref .new_fix_message_received(RawFixMessage::new(original_raw)) - .await; + .await + .expect("failed to resend message"); return; } @@ -112,7 +118,8 @@ where Ok(resent_raw) => { self.session_ref .new_fix_message_received(RawFixMessage::new(resent_raw)) - .await; + .await + .expect("failed to resend message"); } Err(err) => { panic!("failed to encode message for resend: {err:?}"); @@ -140,7 +147,8 @@ where .expect("failed to generate message"); self.session_ref .new_fix_message_received(RawFixMessage::new(raw_message)) - .await; + .await + .expect("failed to send sequence reset"); } pub async fn send_logon(&mut self) { @@ -167,7 +175,8 @@ where self.sent_messages.push(raw_message.clone()); self.session_ref .new_fix_message_received(RawFixMessage::new(raw_message)) - .await; + .await + .expect("failed to send message"); } pub fn delete_last_message_from_store(&mut self) -> bool { diff --git a/crates/hotfix/tests/common/fakes/mod.rs b/crates/hotfix/tests/session_test_cases/common/fakes/mod.rs similarity index 100% rename from crates/hotfix/tests/common/fakes/mod.rs rename to crates/hotfix/tests/session_test_cases/common/fakes/mod.rs diff --git a/crates/hotfix/tests/common/fakes/session_spy.rs b/crates/hotfix/tests/session_test_cases/common/fakes/session_spy.rs similarity index 100% rename from crates/hotfix/tests/common/fakes/session_spy.rs rename to crates/hotfix/tests/session_test_cases/common/fakes/session_spy.rs diff --git a/crates/hotfix/tests/common/mod.rs b/crates/hotfix/tests/session_test_cases/common/mod.rs similarity index 100% rename from crates/hotfix/tests/common/mod.rs rename to crates/hotfix/tests/session_test_cases/common/mod.rs diff --git a/crates/hotfix/tests/common/setup.rs b/crates/hotfix/tests/session_test_cases/common/setup.rs similarity index 96% rename from crates/hotfix/tests/common/setup.rs rename to crates/hotfix/tests/session_test_cases/common/setup.rs index 9480fba..5746f1b 100644 --- a/crates/hotfix/tests/common/setup.rs +++ b/crates/hotfix/tests/session_test_cases/common/setup.rs @@ -32,7 +32,9 @@ pub async fn given_a_connected_session_with_store( .expect("session to be created successfully"); let session_spy = SessionSpy::new(session.clone().into(), message_rx); - let mock_counterparty = FakeCounterparty::start(session.clone(), counterparty_config).await; + let mock_counterparty = FakeCounterparty::start(session.clone(), counterparty_config) + .await + .expect("failed to start FakeCounterparty"); (session_spy, mock_counterparty) } diff --git a/crates/hotfix/tests/common/test_messages.rs b/crates/hotfix/tests/session_test_cases/common/test_messages.rs similarity index 100% rename from crates/hotfix/tests/common/test_messages.rs rename to crates/hotfix/tests/session_test_cases/common/test_messages.rs diff --git a/crates/hotfix/tests/session_test_cases/mod.rs b/crates/hotfix/tests/session_test_cases/mod.rs index 4fbc7c3..43a4be7 100644 --- a/crates/hotfix/tests/session_test_cases/mod.rs +++ b/crates/hotfix/tests/session_test_cases/mod.rs @@ -1,5 +1,6 @@ mod admin_request_tests; mod business_tests; +pub(crate) mod common; mod heartbeat_tests; mod invalid_message_tests; mod logon_tests; diff --git a/crates/hotfix/tests/session_tests.rs b/crates/hotfix/tests/session_tests.rs index 3a37486..e1a5b89 100644 --- a/crates/hotfix/tests/session_tests.rs +++ b/crates/hotfix/tests/session_tests.rs @@ -1,2 +1,3 @@ -mod common; mod session_test_cases; + +use session_test_cases::common; diff --git a/examples/simple-new-order/src/main.rs b/examples/simple-new-order/src/main.rs index 5ef58a8..e41eb66 100644 --- a/examples/simple-new-order/src/main.rs +++ b/examples/simple-new-order/src/main.rs @@ -147,7 +147,7 @@ async fn start_session( db_config: &Database, app: TestApplication, ) -> Result> { - let mut config = Config::load_from_path(config_path); + let mut config = Config::load_from_path(config_path)?; let session_config = config .sessions .pop()