Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions bitreq/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::net::{TcpStream, ToSocketAddrs};
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
#[cfg(feature = "async")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "async")]
Expand All @@ -27,6 +27,9 @@ use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest};
use crate::Response;
use crate::{Error, Method, ResponseLazy};

#[cfg(feature = "async")]
const BACKING_READ_BUFFER_LENGTH: usize = 16 * 1024;

type UnsecuredStream = TcpStream;

#[cfg(feature = "rustls")]
Expand Down Expand Up @@ -214,13 +217,19 @@ impl AsyncWrite for AsyncHttpStream {
}
}

/// The starting value of [`AsyncConnectionState::permits`].
#[cfg(feature = "async")]
const MAX_SEND_PERMITS: isize = isize::MAX;

#[cfg(feature = "async")]
struct AsyncConnectionState {
write: AsyncMutex<WriteHalf<AsyncHttpStream>>,
read: AsyncMutex<ReadHalf<AsyncHttpStream>>,
/// The ID of the next request we'll send. If this reaches [`usize::MAX`] no further requests
/// can be sent on this socket and a new connection must be made. Thus, in order to limit the
/// connection to sending N new requests, this may be set to [`usize::MAX`] - N.
read: AsyncMutex<tokio::io::BufReader<ReadHalf<AsyncHttpStream>>>,
/// Remaining permits for sending new requests. Acquired via `fetch_sub(1)`; the counter is
/// alive while it remains positive. Poisoned by storing `0`; capped by keep-alive `max=N` via
/// `fetch_min(N)`.
permits: AtomicIsize,
/// Monotonic ID of the next request to send, used only for pipelining order.
next_request_id: AtomicUsize,
/// The ID of the next request which is readable from the socket. If we're pipelining this may
/// be a few behind [`Self::next_request_id`]. If this is [`usize::MAX`], the socket is in an
Expand Down Expand Up @@ -292,8 +301,12 @@ impl AsyncConnection {
let (read, write) = tokio::io::split(stream);

Ok(AsyncConnection(Mutex::new(Arc::new(AsyncConnectionState {
read: AsyncMutex::new(read),
read: AsyncMutex::new(tokio::io::BufReader::with_capacity(
BACKING_READ_BUFFER_LENGTH,
read,
)),
write: AsyncMutex::new(write),
permits: AtomicIsize::new(MAX_SEND_PERMITS),
next_request_id: AtomicUsize::new(0),
readable_request_id: AtomicUsize::new(0),
min_dropped_reader_id: AtomicUsize::new(usize::MAX),
Expand Down Expand Up @@ -399,7 +412,7 @@ impl AsyncConnection {
let next_read = conn.readable_request_id.load(Ordering::Acquire);
let next_request = conn.next_request_id.load(Ordering::Acquire);
debug_assert!(
next_request >= next_read,
next_read == usize::MAX || next_request >= next_read,
"At all times, the next_request_id should be higher than the readable id"
);
}
Expand All @@ -415,15 +428,17 @@ impl AsyncConnection {
if !request.config.pipelining {
// If we're not pipelining, wait for any existing pipelined requests to complete.
// Specifically, wait until we have both locks and either we're going to build a
// new connection (because `next_request_id` is `usize::MAX`) or there are no
// new connection (because `permits` is exhausted/poisoned) or there are no
// pending readers (because `next_request_id` and `readable_request_id` are the
// same).
read = Some(Self::timeout(request.timeout_at, conn.read.lock()).await?);
write = Some(Self::timeout(request.timeout_at, conn.write.lock()).await?);
while {
let next_read = conn.readable_request_id.load(Ordering::Relaxed);
let next_request = conn.next_request_id.load(Ordering::Relaxed);
next_request != usize::MAX && next_read < next_request
let has_open_request = next_read < next_request;
let permits_alive = conn.permits.load(Ordering::Relaxed) > 0;
permits_alive && has_open_request
} {
read.take();
write.take();
Expand All @@ -437,14 +452,14 @@ impl AsyncConnection {
(CONNECTION_STATE_UNDEFINED) => {
// The connection may next return bytes for a request which timed out, thus no
// more reads are allowed.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
conn.readable_request_id.store(usize::MAX, Ordering::Release);
retry_new_connection!(_internal);
};
(CONNECTION_STILL_READABLE, $write_lock: ident) => {
// Make sure new requests don't try to use the old connection (but allow
// requests that have already been sent to continue trying to read from it).
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
core::mem::drop($write_lock);
retry_new_connection!(_internal);
};
Expand Down Expand Up @@ -474,12 +489,13 @@ impl AsyncConnection {
let socket_timeout = *conn.socket_new_requests_timeout.lock().unwrap();
let socket_timed_out = Instant::now() > socket_timeout;

request_id = conn.next_request_id.fetch_add(1, Ordering::Relaxed);
if request_id == usize::MAX || socket_timed_out {
let permit_acquired = conn.permits.fetch_sub(1, Ordering::Relaxed) > 0;
if !permit_acquired || socket_timed_out {
// We can't send additional requests on the socket or the socket timed out and
// need to resend the request on a new connection.
retry_new_connection!(CONNECTION_STILL_READABLE, write);
}
request_id = conn.next_request_id.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "log")]
log::trace!(
"Writing HTTP request id {request_id} on connection to {:?}.",
Expand All @@ -492,11 +508,11 @@ impl AsyncConnection {
Err(e) => {
// If we failed to write the request, mark the socket as dead for future
// requests.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
return Err(e);
}
Ok(Err(ioe)) => {
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
return Err(Error::IoError(ioe));
}
Ok(Ok(())) => {}
Expand Down Expand Up @@ -566,7 +582,7 @@ impl AsyncConnection {
}
}
if !found_keep_alive {
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
conn.readable_request_id.store(usize::MAX, Ordering::Release);
} else {
conn.readable_request_id.fetch_add(1, Ordering::Release);
Expand All @@ -585,28 +601,27 @@ impl AsyncConnection {
.unwrap_or(Instant::now());
}
"max" => {
conn.next_request_id.fetch_max(
usize::MAX.saturating_sub(v),
Ordering::AcqRel,
);
// Cap remaining permits to at most `v`.
let cap = isize::try_from(v).unwrap_or(MAX_SEND_PERMITS);
conn.permits.fetch_min(cap, Ordering::AcqRel);
}
_ => {
// If we can't parse the keep-alive header, don't send any
// new requests over this socket, but don't give up on
// reading pending responses.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
}
}
} else {
// If we can't parse the keep-alive header, don't send any new
// requests over this socket, but don't give up on reading pending
// responses.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
}
} else {
// If we can't parse the keep-alive header, don't send any new requests
// over this socket, but don't give up on reading pending responses.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
}
}
}
Expand All @@ -627,7 +642,7 @@ impl AsyncConnection {
// If we failed to read the response after reading the request, the socket
// is in an indeterminate state. Thus, we have to force every other waiting
// request to retry on a new socket.
conn.next_request_id.store(usize::MAX, Ordering::Release);
conn.permits.store(0, Ordering::Release);
conn.readable_request_id.store(usize::MAX, Ordering::Relaxed);
return Err(e);
}
Expand Down
4 changes: 1 addition & 3 deletions bitreq/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,14 @@ impl Response {
/// In order to avoid changing the API while fixing this, we read the full response but then
/// return a "lazy" response that has the full contents pre-read.
pub(crate) async fn create_async<R: AsyncRead + Unpin>(
stream: R,
mut stream: R,
is_head: bool,
max_headers_size: Option<usize>,
max_status_line_len: Option<usize>,
max_body_size: Option<usize>,
) -> Result<Response, Error> {
use HttpStreamState::*;

let mut stream = tokio::io::BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream);

let ResponseMetadata {
status_code,
reason_phrase,
Expand Down
100 changes: 100 additions & 0 deletions bitreq/tests/pipelining.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#![cfg(all(feature = "std", feature = "async"))]

extern crate bitreq;

use std::time::Duration;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::task::JoinSet;

/// Spawns a TCP server that answers every request with an empty `Content-Length: 0` response
/// carrying `Keep-Alive: max=1`, advertising that the connection may serve only this single
/// request. Returns the address the listener is bound to.
async fn spawn_keep_alive_max_one_server() -> std::net::SocketAddr {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound = listener.local_addr().unwrap();
    tokio::spawn(async move {
        // Keep accepting connections until the listener itself errors out.
        while let Ok((mut conn, _peer)) = listener.accept().await {
            tokio::spawn(async move {
                // Canned response for every request on this connection. The trailing `\`
                // line-continuations strip the newline and leading whitespace, so the bytes
                // on the wire are a single contiguous header block.
                const RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\n\
                    Content-Length: 0\r\n\
                    Keep-Alive: max=1\r\n\
                    Connection: keep-alive\r\n\
                    \r\n";
                let mut chunk = [0u8; 4096];
                let mut pending: Vec<u8> = Vec::new();
                loop {
                    let read = match conn.read(&mut chunk).await {
                        // EOF or read error: the peer is gone, stop serving.
                        Ok(0) | Err(_) => return,
                        Ok(read) => read,
                    };
                    pending.extend_from_slice(&chunk[..read]);
                    // Answer every complete request header block buffered so far; pipelined
                    // requests may arrive several-to-a-read.
                    while let Some(header_end) = find_double_crlf(&pending) {
                        pending.drain(..header_end);
                        if conn.write_all(RESPONSE).await.is_err() {
                            return;
                        }
                    }
                }
            });
        }
    });
    bound
}

/// Returns the index one past the first `\r\n\r\n` in `buf` — i.e. the length of the request
/// head including its terminating blank line — or `None` if the delimiter is not present
/// (including when `buf` is shorter than four bytes).
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
    const DELIM: &[u8; 4] = b"\r\n\r\n";
    for (idx, window) in buf.windows(DELIM.len()).enumerate() {
        if window == DELIM {
            return Some(idx + DELIM.len());
        }
    }
    None
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn pipelined_requests_on_keep_alive_max_one() {
    // Number of pipelined requests to fire after the priming request.
    const PIPELINED_REQUESTS: usize = 100;
    // Maximum number of cached connections in the client's pool.
    const POOL_SIZE: usize = 20;

    let addr = spawn_keep_alive_max_one_server().await;
    let url = format!("http://{}/", addr);
    let client = bitreq::Client::new(POOL_SIZE);

    // Prime the connection pool with one non-pipelined request. The server's
    // `Keep-Alive: max=1` header retires the cached connection after this single use, so
    // the pipelined batch below must complete on a fresh connection without hanging.
    let _ = client
        .send_async(bitreq::Request::new(bitreq::Method::Get, &url))
        .await
        .expect("priming request succeeds");

    // Fire the whole batch concurrently, each request flagged for pipelining.
    let mut tasks = JoinSet::new();
    for idx in 0..PIPELINED_REQUESTS {
        let client = client.clone();
        let url = url.clone();
        tasks.spawn(async move {
            println!("Launching request {}", idx);
            let request = bitreq::Request::new(bitreq::Method::Get, url).with_pipelining();
            let response = client
                .send_async(request)
                .await
                .expect("pipelined request succeeds");
            println!("Got response {}", idx);
            response
        });
    }

    // Drain the join set under a deadline so a regression shows up as a test failure
    // rather than a hung run.
    let gather = async {
        let mut responses = Vec::with_capacity(PIPELINED_REQUESTS);
        while let Some(joined) = tasks.join_next().await {
            responses.push(joined.expect("task panicked"));
        }
        responses
    };

    let responses = tokio::time::timeout(Duration::from_secs(10), gather)
        .await
        .unwrap_or_else(|_| panic!("{PIPELINED_REQUESTS} pipelined requests did not finish within 10s"));

    assert_eq!(responses.len(), PIPELINED_REQUESTS);
}
Loading