Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,24 @@ pub struct PgTempDB {
// See shutdown implementation for why these are options
temp_dir: Option<TempDir>,
postgres_process: Option<Child>,
/// just dropping this is enough to shutdown the proxy
_shutdown_proxy: tokio::sync::oneshot::Sender<()>,
}

impl PgTempDB {
/// Start a PgTempDB with the parameters configured from a PgTempDBBuilder
pub fn from_builder(mut builder: PgTempDBBuilder) -> PgTempDB {
let dbuser = builder.get_user();
let dbpass = builder.get_password();
let dbport = builder.get_port_or_set_random();
let (tcp_listener, dbport) = make_listener(builder.port.unwrap_or(0));
builder.port = Some(dbport);
let dbname = builder.get_dbname();
let persist = builder.persist_data_dir;
let dump_path = builder.dump_path.clone();
let load_path = builder.load_path.clone();

let temp_dir = run_db::init_db(&mut builder);
let postgres_process = Some(run_db::run_db(&temp_dir, builder));
let (postgres_process, _shutdown_proxy) = run_db::run_db(&temp_dir, tcp_listener, builder);
let temp_dir = Some(temp_dir);

let db = PgTempDB {
Expand All @@ -62,7 +65,8 @@ impl PgTempDB {
persist,
dump_path,
temp_dir,
postgres_process,
postgres_process: Some(postgres_process),
_shutdown_proxy,
};

if let Some(path) = load_path {
Expand Down Expand Up @@ -458,8 +462,17 @@ impl PgTempDBBuilder {

/// Unlike the other getters, this getter will try to open a new socket to find an unused port,
/// and then set it as the current port.
///
/// By using this function, you force pgtemp to choose a port non-atomically. This can
/// cause a race condition where another process might choose the same port. You can
/// avoid that problem by simply not calling this function, pgtemp will then atomically
/// reserve a port.
pub fn get_port_or_set_random(&mut self) -> u16 {
let port = self.port.as_ref().copied().unwrap_or_else(get_unused_port);
let port = self
.port
.as_ref()
.copied()
.unwrap_or_else(|| make_listener(0).1);

self.port = Some(port);
port
Expand All @@ -471,13 +484,14 @@ impl PgTempDBBuilder {
}
}

fn get_unused_port() -> u16 {
// TODO: relies on Rust's stdlib setting SO_REUSEPORT by default so that postgres can still
// bind to the port afterwards. Also there's a race condition/TOCTOU because there's lag
// between when the port is checked here and when postgres actually tries to bind to it.
let sock = std::net::TcpListener::bind("localhost:0")
.expect("failed to bind to local port when getting unused port");
sock.local_addr()
/// make a socket listen on the given port, which can be 0 to pick a random one
///
/// Return a listener and the actual port chosen
fn make_listener(port: u16) -> (std::net::TcpListener, u16) {
let l = std::net::TcpListener::bind(("127.0.0.1", port)).expect("binding port");
let port = l
.local_addr()
.expect("failed to get local addr from socket")
.port()
.port();
(l, port)
}
84 changes: 79 additions & 5 deletions src/run_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,38 @@ pub fn init_db(builder: &mut PgTempDBBuilder) -> TempDir {
temp_dir
}

pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child {
async fn make_proxy(listener: std::net::TcpListener, unix_path: std::path::PathBuf) {
listener
.set_nonblocking(true)
.expect("failed to set nonblocking");

let listener =
tokio::net::TcpListener::from_std(listener).expect("failed to upgrade to tokio listener");

loop {
if let Ok((mut client_conn, _client_addr)) = listener.accept().await {
client_conn
.set_nodelay(true)
.expect("failed to set nodelay on client connection");
let mut db_conn = tokio::net::UnixStream::connect(&unix_path)
.await
.expect("failed to connect to postgres server");
tokio::spawn(async move {
let _ = tokio::io::copy_bidirectional(&mut db_conn, &mut client_conn).await;
});
} else {
println!("idk when this errs");
}
}
}

/// start the postgres server. If `port_holder` is specified, drop it after
/// we are sure that postgres is running.
pub fn run_db(
temp_dir: &TempDir,
listener_port: std::net::TcpListener,
mut builder: PgTempDBBuilder,
) -> (Child, tokio::sync::oneshot::Sender<()>) {
let data_dir = temp_dir.path().join("pg_data_dir");
let data_dir_str = data_dir.to_str().unwrap();
let port = builder.get_port_or_set_random();
Expand All @@ -121,13 +152,13 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child {

pgcmd
.args(["-c", &format!("unix_socket_directories={}", data_dir_str)])
.args(["-c", &format!("port={port}")])
// https://www.postgresql.org/docs/current/non-durability.html
// https://wiki.postgresql.org/wiki/Tuning_Your_PostgreSQL_Server
.args(["-c", "fsync=off"])
.args(["-c", "synchronous_commit=off"])
.args(["-c", "full_page_writes=off"])
.args(["-c", "autovacuum=off"])
.args(["-c", "listen_addresses="])
.args(["-D", data_dir.to_str().unwrap()]);
for (key, val) in &builder.server_configs {
pgcmd.args(["-c", &format!("{}={}", key, val)]);
Expand All @@ -142,11 +173,36 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child {
.spawn()
.expect("Failed to start postgres. Is it installed and on your path?");

std::thread::sleep(CREATEDB_RETRY_DELAY);
let (stop_proxy_tx, mut stop_proxy_rx) = tokio::sync::oneshot::channel();

let socket_path = data_dir.join(".s.PGSQL.5432");
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.expect("start tokio runtime");
rt.block_on(async move {
tokio::select! {
_ = &mut stop_proxy_rx => { }
_ = make_proxy(listener_port, socket_path) => {}
}
});
});

let isready_path = builder
.bin_path
.as_ref()
.map_or("pg_isready".into(), |p| p.join("pg_isready"));

// pg_isready treats its "-h" parameter as a path if it starts with a slash,
// so ensure that it starts with a slash
let unix_path = std::path::absolute(&data_dir).expect("failed to resolve datadir path");
while !check_postgres_ready(&isready_path, &unix_path) {
std::thread::sleep(CREATEDB_RETRY_DELAY);
}

let user = builder.get_user();
//let password = builder.get_password();
let port = builder.get_port_or_set_random();
let dbname = builder.get_dbname();

if dbname != "postgres" {
Expand Down Expand Up @@ -195,5 +251,23 @@ pub fn run_db(temp_dir: &TempDir, mut builder: PgTempDBBuilder) -> Child {
}
}

postgres_server_process
(postgres_server_process, stop_proxy_tx)
}

/// connect to postgres via its unix domain socket; and send it an SSL request.
///
/// If the server is not up and running, it will reply with an error, in which case
/// the we need to wait longer for it to start up.
///
/// Connect via a unix domain socket, because while we hold the "port locking TcpListener",
/// we will preferentially connect to that one instead of the real postgres.
fn check_postgres_ready(
is_ready_path: &std::path::Path,
unix_socket_path: &std::path::Path,
) -> bool {
let mut cmd = Command::new(is_ready_path);
cmd.arg("-h");
cmd.arg(unix_socket_path);
let output = cmd.output().expect("child");
output.status.success()
}
2 changes: 1 addition & 1 deletion tests/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn test_slow_postgres_startup() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let dir_path = temp_dir.path().to_owned();

for cmd in ["postgres", "createdb", "psql", "initdb"] {
for cmd in ["postgres", "createdb", "psql", "initdb", "pg_isready"] {
let sleep_cmd = if cmd == "postgres" { "sleep 0.5" } else { "" };
let exec_cmd = format!("exec {cmd} \"$@\"");
let wrapper_binary = ["#!/bin/bash", sleep_cmd, exec_cmd.as_str()].join("\n");
Expand Down