From 5aaf392727538950b8fed5dfe9bc21ba7382c93e Mon Sep 17 00:00:00 2001 From: Charles Samuels Date: Thu, 22 Jan 2026 10:07:45 -0800 Subject: [PATCH] verify that postgres is up-and-running on startup Atomically acquire the port by building the TcpListener before postgres even starts. Tell postgres to not even bind to a TCP port; it only creates a Unix Domain Socket. It still names that filename with "5432", but it doesn't bind to a port. Then, proxy the Unix Domain Socket over the TcpListener we previously acquired. We can now be sure which postgres we are talking to and on what port. Finally, as our last trick, we can verify postgres is up and running by running the `pg_isready` CLI tool. We tell pg_isready to check the unix domain socket path (though we could just as well have it go over our proxied TCP listener) --- src/lib.rs | 38 +++++++++++++++------- src/run_db.rs | 84 +++++++++++++++++++++++++++++++++++++++++++++--- tests/startup.rs | 2 +- 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 454da15..0f8edc3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,8 @@ pub struct PgTempDB { // See shutdown implementation for why these are options temp_dir: Option, postgres_process: Option, + /// just dropping this is enough to shutdown the proxy + _shutdown_proxy: tokio::sync::oneshot::Sender<()>, } impl PgTempDB { @@ -44,14 +46,15 @@ impl PgTempDB { pub fn from_builder(mut builder: PgTempDBBuilder) -> PgTempDB { let dbuser = builder.get_user(); let dbpass = builder.get_password(); - let dbport = builder.get_port_or_set_random(); + let (tcp_listener, dbport) = make_listener(builder.port.unwrap_or(0)); + builder.port = Some(dbport); let dbname = builder.get_dbname(); let persist = builder.persist_data_dir; let dump_path = builder.dump_path.clone(); let load_path = builder.load_path.clone(); let temp_dir = run_db::init_db(&mut builder); - let postgres_process = Some(run_db::run_db(&temp_dir, builder)); + let (postgres_process, _shutdown_proxy) = run_db::run_db(&temp_dir, tcp_listener, builder); let temp_dir = Some(temp_dir); let db = PgTempDB { @@ -62,7 +65,8 @@ impl PgTempDB { persist, dump_path, temp_dir, - postgres_process, + postgres_process: Some(postgres_process), + _shutdown_proxy, }; if let Some(path) = load_path { @@ -458,8 +462,17 @@ impl PgTempDBBuilder { /// Unlike the other getters, this getter will try to open a new socket to find an unused port, /// and then set it as the current port. + /// + /// By using this function, you force pgtemp to choose a port non-atomically. This can + /// cause a race condition where another process might choose the same port. You can + /// avoid that problem by simply not calling this function, pgtemp will then atomically + /// reserve a port. pub fn get_port_or_set_random(&mut self) -> u16 { - let port = self.port.as_ref().copied().unwrap_or_else(get_unused_port); + let port = self + .port + .as_ref() + .copied() + .unwrap_or_else(|| make_listener(0).1); self.port = Some(port); port @@ -471,13 +484,14 @@ impl PgTempDBBuilder { } } -fn get_unused_port() -> u16 { - // TODO: relies on Rust's stdlib setting SO_REUSEPORT by default so that postgres can still - // bind to the port afterwards. Also there's a race condition/TOCTOU because there's lag - // between when the port is checked here and when postgres actually tries to bind to it. - let sock = std::net::TcpListener::bind("localhost:0") - .expect("failed to bind to local port when getting unused port"); - sock.local_addr() +/// make a socket listen on the given port, which can be 0 to pick a random one +/// +/// Return a listener and the actual port chosen +fn make_listener(port: u16) -> (std::net::TcpListener, u16) { + let l = std::net::TcpListener::bind(("127.0.0.1", port)).expect("binding port"); + let port = l + .local_addr() .expect("failed to get local addr from socket") - .port() + .port(); + (l, port) } diff --git a/src/run_db.rs b/src/run_db.rs index efe0dce..5cc76b1 100644 --- a/src/run_db.rs +++ b/src/run_db.rs @@ -101,7 +101,38 @@ pub fn init_db(builder: &mut PgTempDBBuilder) -> TempDir { temp_dir } -pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { +async fn make_proxy(listener: std::net::TcpListener, unix_path: std::path::PathBuf) { + listener + .set_nonblocking(true) + .expect("failed to set nonblocking"); + + let listener = + tokio::net::TcpListener::from_std(listener).expect("failed to upgrade to tokio listener"); + + loop { + if let Ok((mut client_conn, _client_addr)) = listener.accept().await { + client_conn + .set_nodelay(true) + .expect("failed to set nodelay on client connection"); + let mut db_conn = tokio::net::UnixStream::connect(&unix_path) + .await + .expect("failed to connect to postgres server"); + tokio::spawn(async move { + let _ = tokio::io::copy_bidirectional(&mut db_conn, &mut client_conn).await; + }); + } else { + println!("idk when this errs"); + } + } +} + +/// start the postgres server. If `port_holder` is specified, drop it after +/// we are sure that postgres is running. +pub fn run_db( + temp_dir: &TempDir, + listener_port: std::net::TcpListener, + mut builder: PgTempDBBuilder, +) -> (Child, tokio::sync::oneshot::Sender<()>) { let data_dir = temp_dir.path().join("pg_data_dir"); let data_dir_str = data_dir.to_str().unwrap(); let port = builder.get_port_or_set_random(); @@ -121,13 +152,13 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { pgcmd .args(["-c", &format!("unix_socket_directories={}", data_dir_str)]) - .args(["-c", &format!("port={port}")]) // https://www.postgresql.org/docs/current/non-durability.html // https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server .args(["-c", "fsync=off"]) .args(["-c", "synchronous_commit=off"]) .args(["-c", "full_page_writes=off"]) .args(["-c", "autovacuum=off"]) + .args(["-c", "listen_addresses="]) .args(["-D", data_dir.to_str().unwrap()]); for (key, val) in &builder.server_configs { pgcmd.args(["-c", &format!("{}={}", key, val)]); @@ -142,11 +173,36 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { .spawn() .expect("Failed to start postgres. Is it installed and on your path?"); - std::thread::sleep(CREATEDB_RETRY_DELAY); + let (stop_proxy_tx, mut stop_proxy_rx) = tokio::sync::oneshot::channel(); + + let socket_path = data_dir.join(".s.PGSQL.5432"); + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .expect("start tokio runtime"); + rt.block_on(async move { + tokio::select! { + _ = &mut stop_proxy_rx => { } + _ = make_proxy(listener_port, socket_path) => {} + } + }); + }); + + let isready_path = builder + .bin_path + .as_ref() + .map_or("pg_isready".into(), |p| p.join("pg_isready")); + + // pg_isready treats its "-h" parameter as a path if it starts with a slash, + // so ensure that it starts with a slash + let unix_path = std::path::absolute(&data_dir).expect("failed to resolve datadir path"); + while !check_postgres_ready(&isready_path, &unix_path) { + std::thread::sleep(CREATEDB_RETRY_DELAY); + } let user = builder.get_user(); //let password = builder.get_password(); - let port = builder.get_port_or_set_random(); let dbname = builder.get_dbname(); if dbname != "postgres" { @@ -195,5 +251,23 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child { } } - postgres_server_process + (postgres_server_process, stop_proxy_tx) +} + +/// connect to postgres via its unix domain socket; and send it an SSL request. +/// +/// If the server is not up and running, it will reply with an error, in which case +/// the we need to wait longer for it to start up. +/// +/// Connect via a unix domain socket, because while we hold the "port locking TcpListener", +/// we will preferentially connect to that one instead of the real postgres. +fn check_postgres_ready( + is_ready_path: &std::path::Path, + unix_socket_path: &std::path::Path, +) -> bool { + let mut cmd = Command::new(is_ready_path); + cmd.arg("-h"); + cmd.arg(unix_socket_path); + let output = cmd.output().expect("child"); + output.status.success() } diff --git a/tests/startup.rs b/tests/startup.rs index 926a553..aa09b73 100644 --- a/tests/startup.rs +++ b/tests/startup.rs @@ -92,7 +92,7 @@ fn test_slow_postgres_startup() { let temp_dir = TempDir::new().expect("Failed to create temp dir"); let dir_path = temp_dir.path().to_owned(); - for cmd in ["postgres", "createdb", "psql", "initdb"] { + for cmd in ["postgres", "createdb", "psql", "initdb", "pg_isready"] { let sleep_cmd = if cmd == "postgres" { "sleep 0.5" } else { "" }; let exec_cmd = format!("exec {cmd} \"$@\""); let wrapper_binary = ["#!/bin/bash", sleep_cmd, exec_cmd.as_str()].join("\n");