diff --git a/mea/src/mpsc/bounded.rs b/mea/src/mpsc/bounded.rs index b6ffcb3..1e89588 100644 --- a/mea/src/mpsc/bounded.rs +++ b/mea/src/mpsc/bounded.rs @@ -85,8 +85,8 @@ impl Clone for BoundedSender { } impl fmt::Debug for BoundedSender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("BoundedSender").finish_non_exhaustive() + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoundedSender").finish_non_exhaustive() } } @@ -227,8 +227,8 @@ pub struct BoundedReceiver { unsafe impl Sync for BoundedReceiver {} impl fmt::Debug for BoundedReceiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("BoundedReceiver").finish_non_exhaustive() + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoundedReceiver").finish_non_exhaustive() } } diff --git a/mea/src/mpsc/error.rs b/mea/src/mpsc/error.rs index f3fb9c9..ee104b8 100644 --- a/mea/src/mpsc/error.rs +++ b/mea/src/mpsc/error.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::type_name; use std::fmt; /// An error returned when trying to send on a closed channel. @@ -49,13 +50,13 @@ impl SendError { impl fmt::Display for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - "sending on a closed channel".fmt(f) + f.write_str("sending on a closed channel") } } impl fmt::Debug for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SendError<{}>(..)", stringify!(T)) + write!(f, "SendError<{}>(..)", type_name::()) } } @@ -88,21 +89,20 @@ impl TrySendError { } impl fmt::Display for TrySendError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TrySendError::Full(_) => "sending on a full channel".fmt(fmt), - TrySendError::Disconnected(_) => "sending on a closed channel".fmt(fmt), - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + TrySendError::Full(_) => "sending on a full channel", + TrySendError::Disconnected(_) => "sending on a closed channel", + }) } } impl fmt::Debug for TrySendError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ty = type_name::(); match self { - TrySendError::Full(_) => write!(fmt, "TrySendError<{}>::Full(..)", stringify!(T)), - TrySendError::Disconnected(_) => { - write!(fmt, "TrySendError<{}>::Disconnected(..)", stringify!(T)) - } + TrySendError::Full(_) => write!(f, "TrySendError<{ty}>::Full(..)"), + TrySendError::Disconnected(_) => write!(f, "TrySendError<{ty}>::Disconnected(..)"), } } } @@ -117,8 +117,8 @@ pub enum RecvError { } impl fmt::Display for RecvError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - "receiving on a closed channel".fmt(fmt) + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("receiving on a closed channel") } } @@ -135,11 +135,11 @@ pub enum TryRecvError { } impl fmt::Display for TryRecvError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), - TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + TryRecvError::Empty => "receiving on an empty channel", + TryRecvError::Disconnected => "receiving on a closed channel", + }) } } diff --git a/mea/src/mpsc/unbounded.rs b/mea/src/mpsc/unbounded.rs index 9f7d944..df23b87 100644 --- a/mea/src/mpsc/unbounded.rs +++ b/mea/src/mpsc/unbounded.rs @@ -79,8 +79,8 @@ impl Clone for UnboundedSender { } impl fmt::Debug for UnboundedSender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("UnboundedSender").finish_non_exhaustive() + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UnboundedSender").finish_non_exhaustive() } } @@ -139,9 +139,8 @@ pub struct UnboundedReceiver { unsafe impl Sync for UnboundedReceiver {} impl fmt::Debug for UnboundedReceiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("UnboundedReceiver") - .finish_non_exhaustive() + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UnboundedReceiver").finish_non_exhaustive() } } diff --git a/mea/src/oneshot/mod.rs b/mea/src/oneshot/mod.rs index d6d2908..5e43107 100644 --- a/mea/src/oneshot/mod.rs +++ b/mea/src/oneshot/mod.rs @@ -72,6 +72,7 @@ //! # } //! ``` +use std::any::type_name; use std::cell::UnsafeCell; use std::fmt; use std::future::Future; @@ -114,10 +115,6 @@ unsafe impl Sync for Sender {} #[inline(always)] fn sender_wake_up_receiver(channel: &Channel, state: u8) { - // ORDERING: Synchronizes with writing waker to memory, and prevents the - // taking of the waker from being ordered before this operation. - fence(Ordering::Acquire); - // Take the waker, but critically do not awake it. If we awake it now, the // receiving thread could still observe the AWAKING state and re-await, meaning // that after we change to the MESSAGE state, it would remain waiting indefinitely @@ -125,14 +122,15 @@ fn sender_wake_up_receiver(channel: &Channel, state: u8) { // // SAFETY: at this point we are in the AWAKING state, and the receiving thread // does not access the waker while in this state, nor does it free the channel - // allocation in this state. + // allocation in this state. The caller's acquire ordering establishes a happens-before + // relationship with the writing of the waker. let waker = unsafe { channel.take_waker() }; - // ORDERING: this ordering serves two-fold: it synchronizes with the acquire load - // in the receiving thread, ensuring that both our read of the waker and write of - // the message happen-before the taking of the message and freeing of the channel. - // Furthermore, we need acquire ordering to ensure awaking the receiver - // happens after the channel state is updated. + // ORDERING: this ordering serves two-fold: it synchronizes with the receiver's + // acquire fence after it observes this state, ensuring that both our read of the + // waker and write of the message happen-before the taking of the message and + // freeing of the channel. Furthermore, we need acquire ordering to ensure awaking + // the receiver happens after the channel state is updated. channel.state.swap(state, Ordering::AcqRel); // Note: it is possible that between the store above and this statement that @@ -171,11 +169,11 @@ impl Sender { // * RECEIVING + 1 = AWAKING // * DISCONNECTED + 1 = EMPTY (invalid), however this state is never observed // - // ORDERING: we use release ordering to ensure writing the message is visible to the - // receiving thread. The EMPTY and DISCONNECTED branches do not observe any shared state, - // and thus we do not need an acquire ordering. The RECEIVING branch manages synchronization - // independent of this operation. - match channel.state.fetch_add(1, Ordering::Release) { + // ORDERING: we need release ordering to allow the receiver to synchronize with our write + // of the message and with our final write of the state, in the case where the receiver + // becomes responsible for freeing the channel. We need acquire ordering in the RECEIVING + // and DISCONNECTED branches, as explained further down. + match channel.state.fetch_add(1, Ordering::AcqRel) { // The receiver is alive and has not started waiting. Send done. EMPTY => Ok(()), // The receiver is waiting. Wake it up so it can return the message. @@ -185,8 +183,10 @@ impl Sender { } // The receiver was already dropped. The error is responsible for freeing the channel. // - // SAFETY: since the receiver disconnected it will no longer access `channel_ptr`, so - // we can transfer exclusive ownership of the channel's resources to the error. + // SAFETY: The acquire ordering above synchronizes with the receiver's write of the + // DISCONNECTED state. Since the receiver disconnected it will no longer access + // `channel_ptr`, so we can transfer exclusive ownership of the channel's resources to + // the error. // Moreover, since we just placed the message in the channel, the channel contains a // valid message. DISCONNECTED => Err(SendError { channel_ptr }), @@ -230,11 +230,13 @@ impl Drop for Sender { // * RECEIVING ^ 001 = AWAKING // * DISCONNECTED ^ 001 = EMPTY (invalid), but this state is never observed // - // ORDERING: we need not release ordering here since there are no modifications we - // need to make visible to other thread, and the Err(RECEIVING) branch handles - // synchronization independent of this fetch_xor - match channel.state.fetch_xor(0b001, Ordering::Relaxed) { - // The receiver has not started waiting, nor is it dropped. + // ORDERING: Release is required so that in the states where the receiver becomes + // responsible for deallocating the channel, they can synchronize with this final state + // write from us. Acquire is required by the branches below to synchronize with writes from + // the receiver. + match channel.state.fetch_xor(0b001, Ordering::AcqRel) { + // The receiver is not waiting, nor is it dropped. The receiver is responsible for + // deallocating the channel. EMPTY => {} // The receiver is waiting. Wake it up so it can detect that the channel disconnected. RECEIVING => sender_wake_up_receiver(channel, DISCONNECTED), @@ -243,7 +245,8 @@ impl Drop for Sender { // SAFETY: when the receiver switches the state to DISCONNECTED they have received // the message or will no longer be trying to receive the message, and have // observed that the sender is still alive, meaning that we are responsible for - // freeing the channel allocation. + // freeing the channel allocation. The acquire ordering above synchronizes with + // the receiver's final write of the state. unsafe { dealloc(self.channel_ptr) }; } state => unreachable!("unexpected channel state: {}", state), @@ -335,10 +338,9 @@ impl Receiver { // SAFETY: The channel will not be freed while this method is still running. let channel = unsafe { self.channel_ptr.as_ref() }; - // ORDERING: we use acquire ordering to synchronize with the store of the message. - match channel.state.load(Ordering::Acquire) { - EMPTY => Err(TryRecvError::Empty), - DISCONNECTED => Err(TryRecvError::Disconnected), + // ORDERING: Relaxed is fine since the only branch that needs synchronization is MESSAGE, + // and that branch has its own synchronization. + match channel.state.load(Ordering::Relaxed) { MESSAGE => { // It is okay to break up the load and store since once we are in the MESSAGE state, // the sender no longer modifies the state @@ -347,9 +349,14 @@ impl Receiver { // we need not make any side effects visible to it. channel.state.store(DISCONNECTED, Ordering::Relaxed); - // SAFETY: we are in the MESSAGE state so the message is present + // ORDERING: Synchronize with the sender's write of the message. + fence(Ordering::Acquire); + + // SAFETY: we are in the MESSAGE state so the message is present and synchronized. Ok(unsafe { channel.take_message() }) } + EMPTY => Err(TryRecvError::Empty), + DISCONNECTED => Err(TryRecvError::Disconnected), state => unreachable!("unexpected channel state: {}", state), } } @@ -361,17 +368,29 @@ impl Drop for Receiver { // left deallocating the channel allocation to us. let channel = unsafe { self.channel_ptr.as_ref() }; - // Set the channel state to disconnected and read what state the receiver was in. - match channel.state.swap(DISCONNECTED, Ordering::Acquire) { - // The sender has not sent anything, nor is it dropped. + // Set the channel state to disconnected and read what state the channel was in. + // + // ORDERING: Release is required so that in the states where the sender becomes responsible + // for deallocating the channel, they can synchronize with this final state write from us. + // Acquire is required by the branches below to synchronize with writes from the sender. + match channel.state.swap(DISCONNECTED, Ordering::AcqRel) { + // The sender has not sent anything, nor is it dropped. The sender is responsible for + // deallocating the channel. EMPTY => {} // The sender already sent something. We must drop it, and free the channel. MESSAGE => { + // SAFETY: The MESSAGE state plus acquire ordering guarantees the sender has + // written a message and that it has a happens-before relationship with this drop. unsafe { channel.drop_message() }; + + // SAFETY: The acquire ordering above synchronizes with the sender's final write + // of the state, so we can safely deallocate the channel. unsafe { dealloc(self.channel_ptr) }; } // The sender was already dropped. We are responsible for freeing the channel. DISCONNECTED => { + // SAFETY: The acquire ordering above synchronizes with the sender's final write + // of the state, so we can safely deallocate the channel. unsafe { dealloc(self.channel_ptr) }; } // NOTE: the receiver, unless transformed into a future, will never see the @@ -399,15 +418,20 @@ fn recv_awaken(channel: &Channel) -> Poll> { loop { hint::spin_loop(); - // ORDERING: The load above has already synchronized with writing message. + // ORDERING: The MESSAGE branch below uses a dedicated fence to synchronize with the + // sender. Until then, we only need to observe the state change. match channel.state.load(Ordering::Relaxed) { AWAKING => {} DISCONNECTED => break Poll::Ready(Err(RecvError::Disconnected)), MESSAGE => { - // ORDERING: the sender has been dropped, so this update only - // needs to be visible to us. + // ORDERING: after publishing MESSAGE, the sender no longer uses the channel, so + // this state update only needs to be visible to this receiver. channel.state.store(DISCONNECTED, Ordering::Relaxed); - // SAFETY: We observed the MESSAGE state. + + // ORDERING: Synchronize with the sender's write of the message and final state. + fence(Ordering::Acquire); + + // SAFETY: We observed the MESSAGE state and synchronized with the sender. break Poll::Ready(Ok(unsafe { channel.take_message() })); } state => unreachable!("unexpected channel state: {}", state), @@ -425,8 +449,9 @@ impl Future for Recv { // channel to us, so `self.channel` is valid let channel = unsafe { self.channel_ptr.as_ref() }; - // ORDERING: we use acquire ordering to synchronize with the store of the message. - match channel.state.load(Ordering::Acquire) { + // ORDERING: Relaxed is fine since the branches that need synchronization use dedicated + // fences. + match channel.state.load(Ordering::Relaxed) { // The sender is alive but has not sent anything yet. EMPTY => { let waker = cx.waker().clone(); @@ -435,32 +460,34 @@ impl Future for Recv { } // The sender sent the message. MESSAGE => { - // ORDERING: the sender has been dropped so this update only needs to be - // visible to us. + // ORDERING: after publishing MESSAGE, the sender no longer uses the channel, so + // this state update only needs to be visible to this receiver. channel.state.store(DISCONNECTED, Ordering::Relaxed); + + // ORDERING: Synchronize with the sender's write of the message and final state. + fence(Ordering::Acquire); + + // SAFETY: we are in the MESSAGE state and have synchronized with the sender. Poll::Ready(Ok(unsafe { channel.take_message() })) } // We were polled again while waiting for the sender. Replace the waker with the new // one. RECEIVING => { - // ORDERING: We use relaxed ordering on both success and failure since we have not - // written anything above that must be released, and the individual match arms - // handle any additional synchronization. + // ORDERING: Success synchronizes with the previous write_waker call before we + // drop the stored waker. Failure does not access the stored waker. match channel.state.compare_exchange( RECEIVING, EMPTY, - Ordering::Relaxed, + Ordering::Acquire, Ordering::Relaxed, ) { - // We successfully changed the state back to EMPTY. - // - // This is the most likely branch to be taken, which is why we do not use any - // memory barriers in the compare_exchange above. + // The state is EMPTY again. Ok(_) => { let waker = cx.waker().clone(); - // SAFETY: We wrote the waker in a previous call to poll. We do not need - // a memory barrier since the previous write here was by ourselves. + // SAFETY: The successful exchange makes the state EMPTY, so the sender + // cannot take the stored waker. The acquire ordering synchronizes with the + // waker write. unsafe { channel.drop_waker() }; // SAFETY: We can not be in the forbidden states, and no waker in the @@ -471,11 +498,15 @@ impl Future for Recv { // We take the message and mark the channel disconnected. // The sender has already taken the waker. Err(MESSAGE) => { - // ORDERING: Synchronize with writing message. This branch is - // unlikely to be taken. - channel.state.swap(DISCONNECTED, Ordering::Acquire); + // ORDERING: after publishing MESSAGE, the sender no longer uses the + // channel, so this state update only needs to be visible to this receiver. + channel.state.store(DISCONNECTED, Ordering::Relaxed); + + // ORDERING: Synchronize with the sender's write of the message. + fence(Ordering::Acquire); - // SAFETY: The state tells us the sender has initialized the message. + // SAFETY: The state tells us the sender has initialized the message, and + // the fence above synchronizes with that write. Poll::Ready(Ok(unsafe { channel.take_message() })) } // The sender is currently waking us up. @@ -503,47 +534,70 @@ impl Drop for Recv { // left deallocating the channel allocation to us. let channel = unsafe { self.channel_ptr.as_ref() }; - // Set the channel state to disconnected and read what state the receiver was in. - match channel.state.swap(DISCONNECTED, Ordering::Acquire) { - // The sender has not sent anything, nor is it dropped. - EMPTY => {} - // The sender already sent something. We must drop it, and free the channel. - MESSAGE => { - unsafe { channel.drop_message() }; - unsafe { dealloc(self.channel_ptr) }; - } - // The receiver has been polled. We must drop the waker. - RECEIVING => { - unsafe { channel.drop_waker() }; - } - // The sender was already dropped. We are responsible for freeing the channel. - DISCONNECTED => { - // SAFETY: see safety comment at top of function. - unsafe { dealloc(self.channel_ptr) }; - } - // This receiver was previously polled, so the channel was in the RECEIVING state. - // But the sender has observed the RECEIVING state and is currently reading the waker - // to wake us up. We need to loop here until we observe the MESSAGE or DISCONNECTED - // state. We busy loop here since we know the sender is done very soon. - AWAKING => { - loop { - hint::spin_loop(); - - // ORDERING: The swap above has already synchronized with writing message. - match channel.state.load(Ordering::Relaxed) { - AWAKING => {} - DISCONNECTED => break, - MESSAGE => { - // SAFETY: we are in the message state so the message is initialized. - unsafe { channel.drop_message() }; - break; - } - state => unreachable!("unexpected channel state: {}", state), + loop { + // ORDERING: MESSAGE and DISCONNECTED synchronize with the sender's state writes. + match channel.state.load(Ordering::Acquire) { + // The sender has not sent anything, nor is it dropped. Mark the receiver as + // dropped; the sender is responsible for deallocating the channel. + EMPTY => { + if channel + .state + .compare_exchange(EMPTY, DISCONNECTED, Ordering::Release, Ordering::Relaxed) + .is_ok() + { + break; } } - unsafe { dealloc(self.channel_ptr) }; + // The sender already sent something. We must drop it, and free the channel. + MESSAGE => { + // SAFETY: The MESSAGE state plus acquire ordering guarantees the sender has + // written a message and that it has a happens-before relationship with this + // drop. + unsafe { channel.drop_message() }; + + // SAFETY: The acquire load above synchronizes with the sender's final write of + // the state, so we can safely deallocate the channel. + unsafe { dealloc(self.channel_ptr) }; + break; + } + // This receiver was previously polled, but was not polled to completion. Move away + // from RECEIVING before dropping the waker so the sender cannot take the same + // waker. + // + // A successful exchange creates a short EMPTY window before the next iteration can + // mark DISCONNECTED. This branch owns and drops the stored waker first. A sender + // that observes EMPTY does not touch the waker. It either stores MESSAGE and + // leaves the message and allocation to this loop, or stores DISCONNECTED and + // leaves the allocation to this loop. If this loop marks DISCONNECTED first, the + // sender observes DISCONNECTED and owns any send error cleanup. + RECEIVING => { + if channel + .state + .compare_exchange(RECEIVING, EMPTY, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + // SAFETY: The successful exchange makes the state EMPTY, so the sender + // cannot take the stored waker. The acquire ordering synchronizes with the + // waker write. + unsafe { channel.drop_waker() }; + } + } + // The sender has observed RECEIVING and is taking the waker. Wait until it stores + // MESSAGE or DISCONNECTED. + AWAKING => { + hint::spin_loop(); + } + // The sender was already dropped, or this future was previously polled to + // completion. We are responsible for freeing the channel. + DISCONNECTED => { + // SAFETY: When DISCONNECTED comes from the sender, the acquire load + // synchronizes with the sender's state write. When it comes from our own + // completed poll, the message has already been taken. + unsafe { dealloc(self.channel_ptr) }; + break; + } + state => unreachable!("unexpected channel state: {}", state), } - state => unreachable!("unexpected channel state: {}", state), } } } @@ -605,7 +659,7 @@ impl Channel { /// # Safety /// /// * The `waker` field must not have a waker stored when calling this method. - /// * The `state` must not be in the RECEIVING state when calling this method. + /// * The `state` must not be in the RECEIVING or AWAKING state when calling this method. unsafe fn write_waker(&self, waker: Waker) -> Poll> { // Write the waker instance to the channel. // @@ -629,29 +683,27 @@ impl Channel { // The sender sent the message while we prepared to await. // We take the message and mark the channel disconnected. Err(MESSAGE) => { - // ORDERING: Synchronize with writing message. This branch is unlikely to be - // taken, so it is likely more efficient to use a fence here - // instead of AcqRel ordering on the compare_exchange - // operation. - fence(Ordering::Acquire); - - // SAFETY: we started in the EMPTY state and the sender switched us to the - // MESSAGE state. This means that it did not take the waker, so we're - // responsible for dropping it. + // SAFETY: We wrote a waker above. The sender cannot have observed the RECEIVING + // state, so it has not accessed the waker. We must drop it. unsafe { self.drop_waker() }; // ORDERING: sender does not exist, so this update only needs to be visible to // us. self.state.store(DISCONNECTED, Ordering::Relaxed); - // SAFETY: The MESSAGE state tells us there is a correctly initialized message. + // ORDERING: Synchronize with writing message. This branch is unlikely to be + // taken, so it is likely more efficient to use a fence here instead of AcqRel + // ordering on the compare_exchange operation. + fence(Ordering::Acquire); + + // SAFETY: The MESSAGE state tells us there is a correctly initialized message, + // and the fence above synchronizes with that write. Poll::Ready(Ok(unsafe { self.take_message() })) } // The sender was dropped before sending anything while we prepared to await. Err(DISCONNECTED) => { - // SAFETY: we started in the EMPTY state and the sender switched us to the - // DISCONNECTED state. This means that it did not take the waker, so we are - // responsible for dropping it. + // SAFETY: We wrote a waker above. The sender cannot have observed the RECEIVING + // state, so it has not accessed the waker. We must drop it. unsafe { self.drop_waker() }; Poll::Ready(Err(RecvError::Disconnected)) } @@ -737,13 +789,13 @@ impl Drop for SendError { impl fmt::Display for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - "sending on a closed channel".fmt(f) + f.write_str("sending on a closed channel") } } impl fmt::Debug for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SendError<{}>(..)", core::any::type_name::()) + write!(f, "SendError<{}>(..)", type_name::()) } } @@ -761,10 +813,10 @@ pub enum TryRecvError { impl fmt::Display for TryRecvError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryRecvError::Empty => write!(f, "receiving on an empty channel"), - TryRecvError::Disconnected => write!(f, "receiving on a closed channel"), - } + f.write_str(match self { + TryRecvError::Empty => "receiving on an empty channel", + TryRecvError::Disconnected => "receiving on a closed channel", + }) } } @@ -783,7 +835,7 @@ pub enum RecvError { impl fmt::Display for RecvError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "receiving on a closed channel") + f.write_str("receiving on a closed channel") } } diff --git a/mea/src/oneshot/tests.rs b/mea/src/oneshot/tests.rs index af394c9..bb4c980 100644 --- a/mea/src/oneshot/tests.rs +++ b/mea/src/oneshot/tests.rs @@ -14,6 +14,7 @@ use std::future::Future; use std::future::IntoFuture; +use std::hint::spin_loop; use std::mem; use std::pin::Pin; use std::sync::Arc; @@ -26,6 +27,7 @@ use std::task::RawWaker; use std::task::RawWakerVTable; use std::task::Waker; use std::time::Duration; +use std::time::Instant; use crate::oneshot; use crate::oneshot::TryRecvError; @@ -285,6 +287,59 @@ fn poll_with_different_wakers() { assert_eq!(waker_handle2.wake_count(), 1); } +#[test] +fn poll_with_different_wakers_across_threads() { + let (sender, receiver) = oneshot::channel::(); + let mut receiver = receiver.into_future(); + + let (waker1, waker_handle1) = waker(); + let mut context1 = Context::from_waker(&waker1); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context1), Poll::Pending); + assert_eq!(waker_handle1.clone_count(), 1); + assert_eq!(waker_handle1.drop_count(), 0); + assert_eq!(waker_handle1.wake_count(), 0); + + let receiver_thread = spawn_named("receiver", move || { + let (waker2, waker_handle2) = waker(); + let mut context2 = Context::from_waker(&waker2); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context2), Poll::Pending); + assert_eq!(waker_handle2.clone_count(), 1); + assert_eq!(waker_handle2.drop_count(), 0); + assert_eq!(waker_handle2.wake_count(), 0); + + drop(receiver); + assert_eq!(waker_handle2.drop_count(), 1); + }); + + receiver_thread.join().unwrap(); + assert_eq!(waker_handle1.drop_count(), 1); + assert!(sender.is_closed()); +} + +#[test] +fn drop_pending_receiver_closes_channel_and_drops_waker() { + let (sender, receiver) = oneshot::channel::(); + let mut receiver = receiver.into_future(); + + let (waker, waker_handle) = waker(); + let mut context = Context::from_waker(&waker); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending); + assert_eq!(waker_handle.clone_count(), 1); + assert_eq!(waker_handle.drop_count(), 0); + assert_eq!(waker_handle.wake_count(), 0); + + drop(receiver); + assert_eq!(waker_handle.drop_count(), 1); + assert_eq!(waker_handle.wake_count(), 0); + assert!(sender.is_closed()); + + let error = sender.send(1234).unwrap_err(); + assert_eq!(*error.as_inner(), 1234); +} + #[test] fn poll_then_drop_receiver_during_send() { let (sender, receiver) = oneshot::channel::(); @@ -324,3 +379,130 @@ fn async_receiver_has_message() { assert!(sender.send(19i128).is_ok()); assert!(receiver.has_message()); } + +#[test] +fn concurrent_send_and_try_recv_to_completion() { + let (sender, receiver) = oneshot::channel::(); + + let receiver_thread = spawn_named("receiver", move || { + spin_until("message from sender", || match receiver.try_recv() { + Ok(999) => Some(()), + Ok(value) => panic!("unexpected value: {value}"), + Err(TryRecvError::Empty) => None, + Err(TryRecvError::Disconnected) => panic!("unexpected disconnect"), + }); + }); + + let sender_thread = spawn_named("sender", move || { + sender.send(999).unwrap(); + }); + + receiver_thread.join().unwrap(); + sender_thread.join().unwrap(); +} + +#[test] +fn concurrent_drop_sender_and_try_recv_to_completion() { + let (sender, receiver) = oneshot::channel::(); + + let receiver_thread = spawn_named("receiver", move || { + spin_until("sender disconnect", || match receiver.try_recv() { + Ok(value) => panic!("unexpected value: {value}"), + Err(TryRecvError::Empty) => None, + Err(TryRecvError::Disconnected) => Some(()), + }); + }); + + let sender_thread = spawn_named("sender", move || { + drop(sender); + }); + + receiver_thread.join().unwrap(); + sender_thread.join().unwrap(); +} + +#[test] +fn concurrent_send_and_poll_to_completion() { + let (sender, receiver) = oneshot::channel::(); + + let receiver_thread = spawn_named("receiver", move || { + let mut receiver = receiver.into_future(); + let (waker, _waker_handle) = waker(); + let mut context = Context::from_waker(&waker); + + spin_until("poll ready with message", || { + match Pin::new(&mut receiver).poll(&mut context) { + Poll::Ready(Ok(999)) => Some(()), + Poll::Ready(result) => panic!("unexpected result: {result:?}"), + Poll::Pending => None, + } + }); + }); + + let sender_thread = spawn_named("sender", move || { + sender.send(999).unwrap(); + }); + + receiver_thread.join().unwrap(); + sender_thread.join().unwrap(); +} + +#[test] +fn concurrent_drop_sender_and_poll_to_completion() { + let (sender, receiver) = oneshot::channel::(); + + let receiver_thread = spawn_named("receiver", move || { + let mut receiver = receiver.into_future(); + let (waker, _waker_handle) = waker(); + let mut context = Context::from_waker(&waker); + + spin_until("poll ready with disconnect", || { + match Pin::new(&mut receiver).poll(&mut context) { + Poll::Ready(Err(oneshot::RecvError::Disconnected)) => Some(()), + Poll::Ready(result) => panic!("unexpected result: {result:?}"), + Poll::Pending => None, + } + }); + }); + + let sender_thread = spawn_named("sender", move || { + drop(sender); + }); + + receiver_thread.join().unwrap(); + sender_thread.join().unwrap(); +} + +fn spawn_named(name: &str, f: F) -> std::thread::JoinHandle<()> +where + F: FnOnce() + Send + 'static, +{ + std::thread::Builder::new() + .name(name.to_string()) + .spawn(f) + .unwrap() +} + +fn spin_until(label: &str, mut f: F) +where + F: FnMut() -> Option<()>, +{ + let deadline = Instant::now() + Duration::from_secs(5); + let mut spins = 0usize; + + loop { + if f().is_some() { + break; + } + + assert!(Instant::now() < deadline, "timed out waiting for {label}"); + + if spins % 64 == 0 { + std::thread::yield_now(); + } else { + spin_loop(); + } + + spins += 1; + } +}