diff --git a/src/lib.rs b/src/lib.rs index 454da15..0f8edc3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,8 @@ pub struct PgTempDB { // See shutdown implementation for why these are options temp_dir: Option, postgres_process: Option, + /// dropping this sender is enough to shut down the proxy thread + _shutdown_proxy: tokio::sync::oneshot::Sender<()>, } impl PgTempDB { @@ -44,14 +46,15 @@ impl PgTempDB { pub fn from_builder(mut builder: PgTempDBBuilder) -> PgTempDB { let dbuser = builder.get_user(); let dbpass = builder.get_password(); - let dbport = builder.get_port_or_set_random(); + let (tcp_listener, dbport) = make_listener(builder.port.unwrap_or(0)); + builder.port = Some(dbport); let dbname = builder.get_dbname(); let persist = builder.persist_data_dir; let dump_path = builder.dump_path.clone(); let load_path = builder.load_path.clone(); let temp_dir = run_db::init_db(&mut builder); - let postgres_process = Some(run_db::run_db(&temp_dir, builder)); + let (postgres_process, _shutdown_proxy) = run_db::run_db(&temp_dir, tcp_listener, builder); let temp_dir = Some(temp_dir); let db = PgTempDB { @@ -62,7 +65,8 @@ impl PgTempDB { persist, dump_path, temp_dir, - postgres_process, + postgres_process: Some(postgres_process), + _shutdown_proxy, }; if let Some(path) = load_path { @@ -458,8 +462,17 @@ impl PgTempDBBuilder { /// Unlike the other getters, this getter will try to open a new socket to find an unused port, /// and then set it as the current port. + /// + /// By using this function, you force pgtemp to choose a port non-atomically. This can + /// cause a race condition where another process might choose the same port. You can + /// avoid that problem by simply not calling this function; pgtemp will then atomically + /// reserve a port. 
pub fn get_port_or_set_random(&mut self) -> u16 { - let port = self.port.as_ref().copied().unwrap_or_else(get_unused_port); + let port = self + .port + .as_ref() + .copied() + .unwrap_or_else(|| make_listener(0).1); self.port = Some(port); port @@ -471,13 +484,14 @@ impl PgTempDBBuilder { } } -fn get_unused_port() -> u16 { - // TODO: relies on Rust's stdlib setting SO_REUSEPORT by default so that postgres can still - // bind to the port afterwards. Also there's a race condition/TOCTOU because there's lag - // between when the port is checked here and when postgres actually tries to bind to it. - let sock = std::net::TcpListener::bind("localhost:0") - .expect("failed to bind to local port when getting unused port"); - sock.local_addr() +/// make a socket listen on the given port, which can be 0 to pick a random one +/// +/// Return a listener and the actual port chosen +fn make_listener(port: u16) -> (std::net::TcpListener, u16) { + let l = std::net::TcpListener::bind(("127.0.0.1", port)).expect("binding port"); + let port = l + .local_addr() .expect("failed to get local addr from socket") - .port() + .port(); + (l, port) } diff --git a/src/run_db.rs b/src/run_db.rs index efe0dce..5cc76b1 100644 --- a/src/run_db.rs +++ b/src/run_db.rs @@ -101,7 +101,38 @@ pub fn init_db(builder: &mut PgTempDBBuilder) -> TempDir { temp_dir } -pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { +async fn make_proxy(listener: std::net::TcpListener, unix_path: std::path::PathBuf) { + listener + .set_nonblocking(true) + .expect("failed to set nonblocking"); + + let listener = + tokio::net::TcpListener::from_std(listener).expect("failed to upgrade to tokio listener"); + + loop { + if let Ok((mut client_conn, _client_addr)) = listener.accept().await { + client_conn + .set_nodelay(true) + .expect("failed to set nodelay on client connection"); + let mut db_conn = tokio::net::UnixStream::connect(&unix_path) + .await + .expect("failed to connect to postgres server"); + 
tokio::spawn(async move { + let _ = tokio::io::copy_bidirectional(&mut db_conn, &mut client_conn).await; + }); + } else { + println!("idk when this errs"); + } + } +} + +/// Start the postgres server. `listener_port` is the already-bound TCP listener; it is moved into a +/// background thread that proxies its connections to postgres' unix socket until the returned sender is dropped or used. +pub fn run_db( + temp_dir: &TempDir, + listener_port: std::net::TcpListener, + mut builder: PgTempDBBuilder, +) -> (Child, tokio::sync::oneshot::Sender<()>) { let data_dir = temp_dir.path().join("pg_data_dir"); let data_dir_str = data_dir.to_str().unwrap(); let port = builder.get_port_or_set_random(); @@ -121,13 +152,13 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { pgcmd .args(["-c", &format!("unix_socket_directories={}", data_dir_str)]) - .args(["-c", &format!("port={port}")]) // https://www.postgresql.org/docs/current/non-durability.html // https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server .args(["-c", "fsync=off"]) .args(["-c", "synchronous_commit=off"]) .args(["-c", "full_page_writes=off"]) .args(["-c", "autovacuum=off"]) + .args(["-c", "listen_addresses="]) .args(["-D", data_dir.to_str().unwrap()]); for (key, val) in &builder.server_configs { pgcmd.args(["-c", &format!("{}={}", key, val)]); @@ -142,11 +173,36 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { .spawn() .expect("Failed to start postgres. Is it installed and on your path?"); - std::thread::sleep(CREATEDB_RETRY_DELAY); + let (stop_proxy_tx, mut stop_proxy_rx) = tokio::sync::oneshot::channel(); + + let socket_path = data_dir.join(".s.PGSQL.5432"); + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .expect("start tokio runtime"); + rt.block_on(async move { + tokio::select! 
{ + _ = &mut stop_proxy_rx => { } + _ = make_proxy(listener_port, socket_path) => {} + } + }); + }); + + let isready_path = builder + .bin_path + .as_ref() + .map_or("pg_isready".into(), |p| p.join("pg_isready")); + + // pg_isready treats its "-h" parameter as a path if it starts with a slash, + // so ensure that it starts with a slash + let unix_path = std::path::absolute(&data_dir).expect("failed to resolve datadir path"); + while !check_postgres_ready(&isready_path, &unix_path) { + std::thread::sleep(CREATEDB_RETRY_DELAY); + } let user = builder.get_user(); //let password = builder.get_password(); - let port = builder.get_port_or_set_random(); let dbname = builder.get_dbname(); if dbname != "postgres" { @@ -195,5 +251,23 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { } } - postgres_server_process + (postgres_server_process, stop_proxy_tx) +} + +/// Run `pg_isready` against the given unix socket directory and return whether it exited successfully. +/// +/// If the server is not up and running yet, the check is expected to fail, in which case +/// we need to wait longer for it to start up. +/// +/// The check goes through the unix domain socket because the TCP port is held by our own +/// `TcpListener`, so a TCP connection would reach that listener instead of the real postgres. 
+fn check_postgres_ready( + is_ready_path: &std::path::Path, + unix_socket_path: &std::path::Path, +) -> bool { + let mut cmd = Command::new(is_ready_path); + cmd.arg("-h"); + cmd.arg(unix_socket_path); + let output = cmd.output().expect("child"); + output.status.success() } diff --git a/tests/startup.rs b/tests/startup.rs index 926a553..aa09b73 100644 --- a/tests/startup.rs +++ b/tests/startup.rs @@ -92,7 +92,7 @@ fn test_slow_postgres_startup() { let temp_dir = TempDir::new().expect("Failed to create temp dir"); let dir_path = temp_dir.path().to_owned(); - for cmd in ["postgres", "createdb", "psql", "initdb"] { + for cmd in ["postgres", "createdb", "psql", "initdb", "pg_isready"] { let sleep_cmd = if cmd == "postgres" { "sleep 0.5" } else { "" }; let exec_cmd = format!("exec {cmd} \"$@\""); let wrapper_binary = ["#!/bin/bash", sleep_cmd, exec_cmd.as_str()].join("\n");