diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index b323c6af3..fd704e09a 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -6,7 +6,7 @@ use std::net::{TcpStream, ToSocketAddrs}; #[cfg(feature = "async")] use std::pin::Pin; #[cfg(feature = "async")] -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering}; #[cfg(feature = "async")] use std::sync::{Arc, Mutex}; #[cfg(feature = "async")] @@ -27,6 +27,9 @@ use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; use crate::Response; use crate::{Error, Method, ResponseLazy}; +#[cfg(feature = "async")] +const BACKING_READ_BUFFER_LENGTH: usize = 16 * 1024; + type UnsecuredStream = TcpStream; #[cfg(feature = "rustls")] @@ -214,13 +217,19 @@ impl AsyncWrite for AsyncHttpStream { } } +/// The starting value of [`AsyncConnectionState::permits`]. +#[cfg(feature = "async")] +const MAX_SEND_PERMITS: isize = isize::MAX; + #[cfg(feature = "async")] struct AsyncConnectionState { write: AsyncMutex>, - read: AsyncMutex>, - /// The ID of the next request we'll send. If this reaches [`usize::MAX`] no further requests - /// can be sent on this socket and a new connection must be made. Thus, in order to limit the - /// connection to sending N new requests, this may be set to [`usize::MAX`] - N. + read: AsyncMutex>>, + /// Remaining permits for sending new requests. Acquired via `fetch_sub(1)`; the counter is + /// alive while it remains positive. Poisoned by storing `0`; capped by keep-alive `max=N` via + /// `fetch_min(N)`. + permits: AtomicIsize, + /// Monotonic ID of the next request to send, used only for pipelining order. next_request_id: AtomicUsize, /// The ID of the next request which is readable from the socket. If we're pipelining this may /// be a few behind [`Self::next_request_id`]. 
If this is [`usize::MAX`], the socket is in an @@ -292,8 +301,12 @@ impl AsyncConnection { let (read, write) = tokio::io::split(stream); Ok(AsyncConnection(Mutex::new(Arc::new(AsyncConnectionState { - read: AsyncMutex::new(read), + read: AsyncMutex::new(tokio::io::BufReader::with_capacity( + BACKING_READ_BUFFER_LENGTH, + read, + )), write: AsyncMutex::new(write), + permits: AtomicIsize::new(MAX_SEND_PERMITS), next_request_id: AtomicUsize::new(0), readable_request_id: AtomicUsize::new(0), min_dropped_reader_id: AtomicUsize::new(usize::MAX), @@ -399,7 +412,7 @@ impl AsyncConnection { let next_read = conn.readable_request_id.load(Ordering::Acquire); let next_request = conn.next_request_id.load(Ordering::Acquire); debug_assert!( - next_request >= next_read, + next_read == usize::MAX || next_request >= next_read, "At all times, the next_request_id should be higher than the readable id" ); } @@ -415,7 +428,7 @@ impl AsyncConnection { if !request.config.pipelining { // If we're not pipelining, wait for any existing pipelined requests to complete. // Specifically, wait until we have both locks and either we're going to build a - // new connection (because `next_request_id` is `usize::MAX`) or there are no + // new connection (because `permits` is exhausted/poisoned) or there are no // pending readers (because `next_request_id` and `readable_request_id` are the // same). 
read = Some(Self::timeout(request.timeout_at, conn.read.lock()).await?); @@ -423,7 +436,9 @@ impl AsyncConnection { while { let next_read = conn.readable_request_id.load(Ordering::Relaxed); let next_request = conn.next_request_id.load(Ordering::Relaxed); - next_request != usize::MAX && next_read < next_request + let has_open_request = next_read < next_request; + let permits_alive = conn.permits.load(Ordering::Relaxed) > 0; + permits_alive && has_open_request } { read.take(); write.take(); @@ -437,14 +452,14 @@ impl AsyncConnection { (CONNECTION_STATE_UNDEFINED) => { // The connection may next return bytes for a request which timed out, thus no // more reads are allowed. - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); conn.readable_request_id.store(usize::MAX, Ordering::Release); retry_new_connection!(_internal); }; (CONNECTION_STILL_READABLE, $write_lock: ident) => { // Make sure new requests don't try to use the old connection (but allow // requests that have already been sent to continue trying to read from it). - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); core::mem::drop($write_lock); retry_new_connection!(_internal); }; @@ -474,12 +489,13 @@ impl AsyncConnection { let socket_timeout = *conn.socket_new_requests_timeout.lock().unwrap(); let socket_timed_out = Instant::now() > socket_timeout; - request_id = conn.next_request_id.fetch_add(1, Ordering::Relaxed); - if request_id == usize::MAX || socket_timed_out { + let permit_acquired = conn.permits.fetch_sub(1, Ordering::Relaxed) > 0; + if !permit_acquired || socket_timed_out { // We can't send additional requests on the socket or the socket timed out and // need to resend the request on a new connection. 
retry_new_connection!(CONNECTION_STILL_READABLE, write); } + request_id = conn.next_request_id.fetch_add(1, Ordering::Relaxed); #[cfg(feature = "log")] log::trace!( "Writing HTTP request id {request_id} on connection to {:?}.", @@ -492,11 +508,11 @@ impl AsyncConnection { Err(e) => { // If we failed to write the request, mark the socket as dead for future // requests. - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); return Err(e); } Ok(Err(ioe)) => { - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); return Err(Error::IoError(ioe)); } Ok(Ok(())) => {} @@ -566,7 +582,7 @@ impl AsyncConnection { } } if !found_keep_alive { - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); conn.readable_request_id.store(usize::MAX, Ordering::Release); } else { conn.readable_request_id.fetch_add(1, Ordering::Release); @@ -585,28 +601,27 @@ impl AsyncConnection { .unwrap_or(Instant::now()); } "max" => { - conn.next_request_id.fetch_max( - usize::MAX.saturating_sub(v), - Ordering::AcqRel, - ); + // Cap remaining permits to at most `v`. + let cap = isize::try_from(v).unwrap_or(MAX_SEND_PERMITS); + conn.permits.fetch_min(cap, Ordering::AcqRel); } _ => { // If we can't parse the keep-alive header, don't send any // new requests over this socket, but don't give up on // reading pending responses. - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); } } } else { // If we can't parse the keep-alive header, don't send any new // requests over this socket, but don't give up on reading pending // responses. - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); } } else { // If we can't parse the keep-alive header, don't send any new requests // over this socket, but don't give up on reading pending responses. 
- conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); } } } @@ -627,7 +642,7 @@ impl AsyncConnection { // If we failed to read the response after reading the request, the socket // is in an indeterminate state. Thus, we have to force every other waiting // request to retry on a new socket. - conn.next_request_id.store(usize::MAX, Ordering::Release); + conn.permits.store(0, Ordering::Release); conn.readable_request_id.store(usize::MAX, Ordering::Relaxed); return Err(e); } diff --git a/bitreq/src/response.rs b/bitreq/src/response.rs index b234de107..eef2c2f5e 100644 --- a/bitreq/src/response.rs +++ b/bitreq/src/response.rs @@ -82,7 +82,7 @@ impl Response { /// In order to avoid changing the API while fixing this, we read the full response but then /// return a "lazy" response that has the full contents pre-read. pub(crate) async fn create_async( - stream: R, + mut stream: R, is_head: bool, max_headers_size: Option, max_status_line_len: Option, @@ -90,8 +90,6 @@ impl Response { ) -> Result { use HttpStreamState::*; - let mut stream = tokio::io::BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream); - let ResponseMetadata { status_code, reason_phrase, diff --git a/bitreq/tests/pipelining.rs b/bitreq/tests/pipelining.rs new file mode 100644 index 000000000..ec8cc04e3 --- /dev/null +++ b/bitreq/tests/pipelining.rs @@ -0,0 +1,100 @@ +#![cfg(all(feature = "std", feature = "async"))] + +extern crate bitreq; + +use std::time::Duration; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::task::JoinSet; + +/// Spawns a TCP server that replies to every request with an empty `Content-Length: 0` response +/// carrying `Keep-Alive: max=1`, advertising that the connection must be closed after a single +/// request. Returns the bound address. 
+async fn spawn_keep_alive_max_one_server() -> std::net::SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + loop { + let (mut sock, _) = match listener.accept().await { + Ok(x) => x, + Err(_) => return, + }; + tokio::spawn(async move { + let mut buf = [0u8; 4096]; + let mut acc: Vec<u8> = Vec::new(); + loop { + let n = match sock.read(&mut buf).await { + Ok(0) | Err(_) => return, + Ok(n) => n, + }; + acc.extend_from_slice(&buf[..n]); + while let Some(end) = find_double_crlf(&acc) { + acc.drain(..end); + let response = b"HTTP/1.1 200 OK\r\n\ + Content-Length: 0\r\n\ + Keep-Alive: max=1\r\n\ + Connection: keep-alive\r\n\ + \r\n"; + if sock.write_all(response).await.is_err() { + return; + } + } + } + }); + } + }); + addr +} + +fn find_double_crlf(buf: &[u8]) -> Option<usize> { + buf.windows(4).position(|w| w == b"\r\n\r\n").map(|i| i + 4) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn pipelined_requests_on_keep_alive_max_one() { + // Number of pipelined requests to fire after the priming request. + const PIPELINED_REQUESTS: usize = 100; + // Maximum number of cached connections in the client's pool. + const POOL_SIZE: usize = 20; + + let addr = spawn_keep_alive_max_one_server().await; + let url = format!("http://{}/", addr); + + let client = bitreq::Client::new(POOL_SIZE); + + // Prime the connection pool with one non-pipelined request. The server's `Keep-Alive: max=1` + // header retires the cached connection after this single use, so the pipelined batch below + // must complete on a fresh connection without hanging. 
+ let _ = client + .send_async(bitreq::Request::new(bitreq::Method::Get, &url)) + .await + .expect("priming request succeeds"); + + let mut set = JoinSet::new(); + for i in 0..PIPELINED_REQUESTS { + let client = client.clone(); + let url = url.clone(); + set.spawn(async move { + println!("Launching request {}", i); + let req = bitreq::Request::new(bitreq::Method::Get, url).with_pipelining(); + let res = client.send_async(req).await.expect("pipelined request succeeds"); + println!("Got response {}", i); + res + }); + } + + let collect = async { + let mut results = Vec::with_capacity(PIPELINED_REQUESTS); + while let Some(res) = set.join_next().await { + results.push(res.expect("task panicked")); + } + results + }; + + let results = tokio::time::timeout(Duration::from_secs(10), collect) + .await + .unwrap_or_else(|_| panic!("{PIPELINED_REQUESTS} pipelined requests did not finish within 10s")); + + assert_eq!(results.len(), PIPELINED_REQUESTS); +}