Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 22 additions & 33 deletions library/core/src/array/drain.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::marker::{Destruct, PhantomData};
use crate::mem::{ManuallyDrop, SizedTypeProperties, conjure_zst};
use crate::ptr::{NonNull, drop_in_place, from_raw_parts_mut, null_mut};
use crate::ptr::{drop_in_place, slice_from_raw_parts_mut};

impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> {
/// This function returns a function that lets you index the given array in const.
Expand All @@ -18,15 +18,8 @@ impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> {
#[rustc_const_unstable(feature = "array_try_map", issue = "79711")]
pub(super) const unsafe fn new(array: &'l mut ManuallyDrop<[T; N]>, f: &'f mut F) -> Self {
// dont drop the array, transfers "ownership" to Self
let ptr: NonNull<T> = NonNull::from_mut(array).cast();
// SAFETY:
// Adding `slice.len()` to the starting pointer gives a pointer
// at the end of `slice`. `end` will never be dereferenced, only checked
// for direct pointer equality with `ptr` to check if the drainer is done.
unsafe {
let end = if T::IS_ZST { null_mut() } else { ptr.as_ptr().add(N) };
Self { ptr, end, f, l: PhantomData }
}
let end = array.as_mut_ptr_range().end;
Self { end, remaining: N, f, l: PhantomData }
}
}

Expand All @@ -35,20 +28,26 @@ impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> {
#[unstable(feature = "array_try_map", issue = "79711")]
pub(super) struct Drain<'l, 'f, T, const N: usize, F> {
// FIXME(const-hack): This is essentially a slice::IterMut<'static>, replace when possible.
/// The pointer to the next element to return, or the past-the-end location
/// if the drainer is empty.
///
/// This address will be used for all ZST elements, never changed.
/// Pointer to the past-the-end element.
/// As we "own" this array, we dont need to store any lifetime.
ptr: NonNull<T>,
/// For non-ZSTs, the non-null pointer to the past-the-end element.
/// For ZSTs, this is null.
end: *mut T,
/// The number of elements still to be drained.
remaining: usize,

f: &'f mut F,
l: PhantomData<&'l mut [T; N]>,
}

impl<T, const N: usize, F> Drain<'_, '_, T, N, F> {
/// Returns a pointer to the next element to be drained, or the past-the-end element if there
/// are no remaining elements to be drained.
const fn ptr(&mut self) -> *mut T {
// SAFETY: By the type invariants, self.remaining is always the number of elements prior to
// self.end that are still to be drained.
unsafe { self.end.sub(self.remaining) }
}
}

#[rustc_const_unstable(feature = "array_try_map", issue = "79711")]
#[unstable(feature = "array_try_map", issue = "79711")]
impl<T, U, const N: usize, F> const FnOnce<(usize,)> for &mut Drain<'_, '_, T, N, F>
Expand All @@ -73,15 +72,14 @@ where
&mut self,
(_ /* ignore argument */,): (usize,),
) -> Self::Output {
let p = self.ptr();
// decrement before moving; if `f` panics, we drop the rest.
self.remaining -= 1;
if T::IS_ZST {
// its UB to call this more than N times, so returning more ZSTs is valid.
// SAFETY: its a ZST? we conjur.
(self.f)(unsafe { conjure_zst::<T>() })
} else {
// increment before moving; if `f` panics, we drop the rest.
let p = self.ptr;
// SAFETY: caller guarantees never called more than N times (see `Drain::new`)
self.ptr = unsafe { self.ptr.add(1) };
// SAFETY: we are allowed to move this.
(self.f)(unsafe { p.read() })
}
Expand All @@ -91,18 +89,9 @@ where
#[unstable(feature = "array_try_map", issue = "79711")]
impl<T: [const] Destruct, const N: usize, F> const Drop for Drain<'_, '_, T, N, F> {
fn drop(&mut self) {
if !T::IS_ZST {
// SAFETY: we cant read more than N elements
let slice = unsafe {
from_raw_parts_mut::<[T]>(
self.ptr.as_ptr(),
// SAFETY: `start <= end`
self.end.offset_from_unsigned(self.ptr.as_ptr()),
)
};
let slice = slice_from_raw_parts_mut(self.ptr(), self.remaining);

// SAFETY: By the type invariant, we're allowed to drop all these. (we own it, after all)
unsafe { drop_in_place(slice) }
}
// SAFETY: By the type invariant, we're allowed to drop all these. (we own it, after all)
unsafe { drop_in_place(slice) }
}
}
22 changes: 15 additions & 7 deletions library/coretests/tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,26 +313,34 @@ fn array_map() {
#[test]
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
fn array_map_drop_safety() {
static DROPPED: AtomicUsize = AtomicUsize::new(0);
struct DropCounter;
impl Drop for DropCounter {
static OLD_DROPPED: AtomicUsize = AtomicUsize::new(0);
static NEW_DROPPED: AtomicUsize = AtomicUsize::new(0);
struct OldDropCounter;
struct NewDropCounter;
impl Drop for OldDropCounter {
fn drop(&mut self) {
DROPPED.fetch_add(1, Ordering::SeqCst);
OLD_DROPPED.fetch_add(1, Ordering::SeqCst);
}
}
impl Drop for NewDropCounter {
fn drop(&mut self) {
NEW_DROPPED.fetch_add(1, Ordering::SeqCst);
}
}

let num_to_create = 5;
let success = std::panic::catch_unwind(|| {
let items = [0; 10];
let items = [const { OldDropCounter }; 8];
let mut nth = 0;
let _ = items.map(|_| {
assert!(nth < num_to_create);
nth += 1;
DropCounter
NewDropCounter
});
});
assert!(success.is_err());
assert_eq!(DROPPED.load(Ordering::SeqCst), num_to_create);
assert_eq!(OLD_DROPPED.load(Ordering::SeqCst), 8);
assert_eq!(NEW_DROPPED.load(Ordering::SeqCst), num_to_create);
}

#[test]
Expand Down
Loading