diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2f71fd7..66e9424 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,9 +8,9 @@ on: env: JS_PACKAGES: "['type-length-value-js']" - SBPF_PROGRAM_PACKAGES: "['discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" - RUST_PACKAGES: "['discriminator', 'discriminator-derive', 'discriminator-syn', 'generic-token', 'generic-token-tests', 'list-view', 'pod', 'program-error', 'program-error-derive', 'tlv-account-resolution', 'type-length-value', 'type-length-value-derive', 'type-length-value-derive-test']" - WASM_PACKAGES: "['discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" + SBPF_PROGRAM_PACKAGES: "['collections', 'discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" + RUST_PACKAGES: "['collections', 'discriminator', 'discriminator-derive', 'discriminator-syn', 'generic-token', 'generic-token-tests', 'list-view', 'pod', 'program-error', 'program-error-derive', 'tlv-account-resolution', 'type-length-value', 'type-length-value-derive', 'type-length-value-derive-test']" + WASM_PACKAGES: "['collections', 'discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" jobs: set_env: diff --git a/.github/workflows/publish-rust.yml b/.github/workflows/publish-rust.yml index 5469c66..e69ef74 100644 --- a/.github/workflows/publish-rust.yml +++ b/.github/workflows/publish-rust.yml @@ -9,6 +9,7 @@ on: default: 'discriminator' type: choice options: + - collections - discriminator - discriminator-derive - discriminator-syn diff --git a/Cargo.lock b/Cargo.lock index 95c6adb..da843c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1532,6 +1532,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "spl-collections" +version = "0.0.0" 
+dependencies = [ + "borsh", + "spl-collections", + "wincode", +] + [[package]] name = "spl-discriminator" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index a18b107..521897b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "collections", "discriminator", "discriminator-derive", "discriminator-syn", diff --git a/collections/Cargo.toml b/collections/Cargo.toml new file mode 100644 index 0000000..953e5a2 --- /dev/null +++ b/collections/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "spl-collections" +version = "0.0.0" +description = "Serialization-aware collection wrappers for Solana account data" +authors = ["Anza Maintainers "] +repository = "https://github.com/solana-program/libraries" +license = "Apache-2.0" +edition = "2021" + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] +all-features = true +rustdoc-args = ["--cfg=docsrs"] + +[features] +borsh = ["dep:borsh"] +wincode = ["dep:wincode"] + +[dependencies] +borsh = { version = "1.0", features = ["derive"], default-features = false, optional = true } +wincode = { version = "0.4.4", features = ["alloc", "derive"], default-features = false, optional = true } + +[dev-dependencies] +spl-collections = { path = ".", features = ["borsh", "wincode"] } + +[lib] +crate-type = ["lib"] diff --git a/collections/src/lib.rs b/collections/src/lib.rs new file mode 100644 index 0000000..fb1bc2c --- /dev/null +++ b/collections/src/lib.rs @@ -0,0 +1,16 @@ +//! Serialization-aware collection wrappers for Solana account data. +//! +//! This crate provides wrappers around collection types to support custom serialization +//! logic. This is useful for programs that have specific requirements for how data is +//! stored. 
+ +#![no_std] +#![cfg_attr(docsrs, feature(doc_cfg))] + +extern crate alloc; + +mod str; +mod vec; + +pub use str::*; +pub use vec::*; diff --git a/collections/src/str.rs b/collections/src/str.rs new file mode 100644 index 0000000..9d97a4a --- /dev/null +++ b/collections/src/str.rs @@ -0,0 +1,497 @@ +//! Types for serializing strings. +//! +//! This module provides two types for serializing strings: `TrailingStr` and a +//! set of `PrefixedStr`. +//! +//! `TrailingStr` is serialized without a length prefix, while the `PrefixedStr`s +//! are serialized with a length prefix determined by a type. The length prefix is useful +//! for deserializing strings that are not the last field of a struct, as it allows the +//! deserializer to know how many bytes to read for the string, while allowing for more +//! efficient storage depending on the expected length of the string. +//! +//! The types in this module also implement the `Deref` trait, allowing them to be used +//! as `&str` in most contexts. + +#[cfg(feature = "borsh")] +use borsh::{ + io::{ErrorKind, Read}, + BorshDeserialize, BorshSerialize, +}; +use { + crate::{TrailingVec, U16PrefixedVec, U32PrefixedVec, U64PrefixedVec, U8PrefixedVec}, + core::{ + fmt::{Debug, Formatter}, + ops::Deref, + str::from_utf8_unchecked, + }, +}; +#[cfg(feature = "wincode")] +use { + core::{mem::MaybeUninit, str::from_utf8}, + wincode::{ + config::{Config, ConfigCore}, + io::Reader, + ReadError, ReadResult, SchemaRead, SchemaWrite, UninitBuilder, + }, +}; + +/// A `str` serialized without a length prefix. +/// +/// This is useful for serializing strings that are the last field +/// of a struct, where the length can be inferred from the remaining +/// bytes. +/// +/// Note that this type is not suitable for serializing strings that +/// are not the last field of a struct, as it will consume all +/// remaining bytes. 
+/// +/// # Examples +/// +/// Using `TrailingStr` in a struct results in the string being +/// serialized without a length prefix. +/// +/// ``` +/// use spl_collections::TrailingStr; +/// use wincode::{SchemaRead, SchemaWrite}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// pub struct MyStruct { +/// pub state: u8, +/// pub amount: u64, +/// pub description: TrailingStr, +/// } +/// +/// let my_struct = MyStruct { +/// state: 1, +/// amount: 1_000_000_000, +/// description: TrailingStr::from( +/// "The quick brown fox jumps over the lazy dog" +/// ), +/// }; +/// +/// let bytes = wincode::serialize(&my_struct).unwrap(); +/// // Expected size: +/// // - state (1 byte) +/// // - amount (8 bytes) +/// // - description (remaining bytes without a length prefix) +/// assert_eq!(bytes.len(), 1 + 8 + my_struct.description.len()); +/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); +/// +/// # assert_eq!(deserialized.state, my_struct.state); +/// # assert_eq!(deserialized.amount, my_struct.amount); +/// # assert_eq!(deserialized.description, my_struct.description); +/// ``` +#[cfg_attr(feature = "borsh", derive(BorshSerialize))] +#[cfg_attr(feature = "wincode", derive(SchemaWrite, UninitBuilder))] +#[derive(Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct TrailingStr(TrailingVec); + +impl> From for TrailingStr { + fn from(value: T) -> Self { + Self(TrailingVec::from(value.as_ref().as_bytes())) + } +} + +impl Deref for TrailingStr { + type Target = str; + + fn deref(&self) -> &Self::Target { + // SAFETY: The `TrailingStr` type is only constructed + // from valid UTF-8 strings. 
+ unsafe { from_utf8_unchecked(&self.0) } + } +} + +impl Debug for TrailingStr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.deref())) + } +} + +#[cfg(feature = "borsh")] +impl BorshDeserialize for TrailingStr { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let container = TrailingVec::::deserialize_reader(reader)?; + + // Validate that we got valid UTF-8 bytes, as `TrailingStr` must + // always be valid UTF-8. + if core::str::from_utf8(&container).is_err() { + return Err(ErrorKind::InvalidData.into()); + } + + Ok(Self(container)) + } +} + +#[cfg(feature = "wincode")] +unsafe impl<'de, C: Config> SchemaRead<'de, C> for TrailingStr { + type Dst = Self; + + fn read(reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let mut builder = TrailingStrUninitBuilder::::from_maybe_uninit_mut(dst); + builder.read_0(reader)?; + + let container = unsafe { builder.uninit_0_mut().assume_init_ref() }; + + // Validate that we got valid UTF-8 bytes, as `TrailingStr` must + // always be valid UTF-8. + if from_utf8(container).is_err() { + return Err(ReadError::Custom("invalid UTF-8 bytes")); + } + + builder.finish(); + + Ok(()) + } +} + +/// Macro defining a `PrefixedStr` type with a specified length prefix type. +macro_rules! prefixed_str_type { + ( $name:tt, $container_type:tt, $prefix_type:tt ) => { + #[doc = concat!("A `str` that is serialized with an `", stringify!($prefix_type), "` length prefix.")] + #[cfg_attr(feature = "borsh", derive(BorshSerialize))] + #[cfg_attr(feature = "wincode", derive(SchemaWrite))] + #[derive(Clone, Eq, PartialEq)] + #[repr(transparent)] + pub struct $name($container_type); + + impl> From for $name { + fn from(value: T) -> Self { + Self($container_type::from(value.as_ref().as_bytes())) + } + } + + impl Deref for $name { + type Target = str; + + fn deref(&self) -> &Self::Target { + // SAFETY: `*PrefixedStr` types are only constructed + // from valid UTF-8 strings. 
+ unsafe { from_utf8_unchecked(&self.0) } + } + } + + impl Debug for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.deref())) + } + } + + #[cfg(feature = "borsh")] + impl BorshDeserialize for $name { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let container = $container_type::::deserialize_reader(reader)?; + + // Validate that we got valid UTF-8 bytes, as `*PrefixedStr` types must + // always be valid UTF-8. + if core::str::from_utf8(&container).is_err() { + return Err(ErrorKind::InvalidData.into()); + } + + Ok(Self(container)) + } + } + + #[cfg(feature = "wincode")] + unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for $name { + type Dst = Self; + + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let container = <$container_type:: as SchemaRead>::get(&mut reader)?; + + // Validate that we got valid UTF-8 bytes, as `*PrefixedStr` types must + // always be valid UTF-8. + if from_utf8(&container).is_err() { + return Err(ReadError::Custom("invalid UTF-8 bytes")); + } + + dst.write(Self(container)); + + Ok(()) + } + } + }; +} + +// A `PrefixedStr` with a `u8` length prefix. +prefixed_str_type!(U8PrefixedStr, U8PrefixedVec, u8); + +// A `PrefixedStr` with a `u16` length prefix. +prefixed_str_type!(U16PrefixedStr, U16PrefixedVec, u16); + +// A `PrefixedStr` with a `u32` length prefix. +prefixed_str_type!(U32PrefixedStr, U32PrefixedVec, u32); + +// A `PrefixedStr` with a `u64` length prefix. +prefixed_str_type!(U64PrefixedStr, U64PrefixedVec, u64); + +#[cfg(test)] +mod tests { + use { + alloc::{string::String, vec::Vec}, + borsh::{io::ErrorKind, BorshDeserialize, BorshSerialize}, + core::mem::size_of, + wincode::WriteError, + }; + + use super::*; + + #[test] + fn trailing_str_borsh_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingStr = TrailingStr::from(String::from(DATA)); + // No need to reserve space for a length prefix. 
+ let mut bytes = [0u8; DATA.len()]; + + original.serialize(&mut bytes.as_mut_slice()).unwrap(); + + let serialized = TrailingStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + } + + #[test] + fn trailing_str_wincode_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingStr = TrailingStr::from(String::from(DATA)); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; DATA.len()]; + + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.deref(), DATA); + assert_eq!(serialized, original); + } + + #[test] + fn prefixed_str_borsh_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let string = U8PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u16 length prefix + string bytes + let original = U16PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let string = U16PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u32 length prefix + string bytes + let original = U32PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), + TEXT.len() as u32 + ); 
+ assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let string = U32PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u64 length prefix + string bytes + let original = U64PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], TEXT.as_bytes()); + + let string = U64PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + } + + #[test] + fn prefixed_str_wincode_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u16 length prefix + string bytes + let original = U16PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u32 length prefix + string bytes + let original = U32PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + 
wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), + TEXT.len() as u32 + ); + assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u64 length prefix + string bytes + let original = U64PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + } + + #[test] + fn invalid_prefixed_value() { + let large_text = "a".repeat(256); + + let original = U8PrefixedStr::from(large_text); + + // borsh + let result = borsh::to_vec(&original); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); + + // wincode + let result = wincode::serialize(&original); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + WriteError::LengthEncodingOverflow(_) + )); + } + + #[test] + fn prefixed_str_borsh_with_remaining_bytes() { + let value = "⚙️ serialized data with extra bytes"; + let mut bytes = Vec::::new(); + + bytes.push(value.len() as u8); + bytes.extend_from_slice(value.as_bytes()); + // Extra bytes that should be ignored. 
+ bytes.extend_from_slice(&[255u8; 16]); + + let mut reader = bytes.as_slice(); + let serialized = U8PrefixedStr::deserialize(&mut reader).unwrap(); + + assert_eq!(serialized.len(), value.len()); + assert_eq!(serialized.deref(), value); + } + + #[test] + fn prefixed_str_wincode_with_remaining_bytes() { + let value = "⚙️ serialized data with extra bytes"; + + let mut bytes = Vec::::new(); + bytes.push(value.len() as u8); + bytes.extend_from_slice(value.as_bytes()); + // Extra bytes that should be ignored. + bytes.extend_from_slice(&[255u8; 16]); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), value.len()); + assert_eq!(serialized.deref(), value); + } + + #[test] + fn invalid_utf8_borsh() { + // prefix + 2 invalid UTF-8 bytes + let bytes = [2u8, 255, 255]; + + // For `TrailingStr`, skip the prefix byte and attempt to deserialize the remaining + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let mut reader = bytes[1..].as_ref(); + let maybe_deserialized = TrailingStr::deserialize(&mut reader); + + assert!(maybe_deserialized.is_err()); + + // For `PrefixedStr`, read the length prefix and then read the specified number of + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let mut reader = bytes.as_slice(); + let maybe_deserialized = U8PrefixedStr::deserialize(&mut reader); + + assert!(maybe_deserialized.is_err()); + } + + #[test] + fn invalid_utf8_wincode() { + // prefix + 2 invalid UTF-8 bytes + let bytes = [2u8, 255, 255]; + + // For `TrailingStr`, skip the prefix byte and attempt to deserialize the remaining + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let maybe_deserialized = wincode::deserialize::(&bytes[1..]); + + assert!(maybe_deserialized.is_err()); + + // For `PrefixedStr`, read the length prefix and then read the specified number of + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. 
+ let maybe_deserialized = wincode::deserialize::(&bytes); + + assert!(maybe_deserialized.is_err()); + } +} diff --git a/collections/src/vec.rs b/collections/src/vec.rs new file mode 100644 index 0000000..6bacbc7 --- /dev/null +++ b/collections/src/vec.rs @@ -0,0 +1,518 @@ +//! Types for serializing `Vec` types. +//! +//! This module provides two types for serializing a `Vec`: `TrailingVec` and a +//! set of `PrefixedVec`s with different length prefix types. +//! +//! `TrailingVec` is serialized without a length prefix, while the `PrefixedVec`s +//! are serialized with a length prefix determined by a type. The length prefix is useful +//! for deserializing vectors that are not the last field of a struct, as it allows the +//! deserializer to know how many bytes to read for the vector, while allowing for more +//! efficient storage depending on the expected length of the vector. +//! +//! The types in this module also implement the `Deref` trait, allowing them to be used +//! as regular `Vec` in most contexts. + +#[cfg(feature = "borsh")] +use borsh::{ + io::{ErrorKind, Read, Write}, + BorshDeserialize, BorshSerialize, +}; +use { + alloc::vec::Vec, + core::{ + fmt::{Debug, Formatter}, + ops::Deref, + }, +}; +#[cfg(feature = "wincode")] +use { + core::mem::MaybeUninit, + wincode::{ + config::ConfigCore, + error::{write_length_encoding_overflow, ReadError}, + io::{Reader, Writer}, + ReadResult, SchemaRead, SchemaWrite, WriteResult, + }, +}; + +/// A `Vec` serialized without a length prefix. +/// +/// This is useful for serializing a `Vec` that is the last field +/// of a struct, where the length can be inferred from the remaining +/// bytes. +/// +/// Note that this type is not suitable for serializing `Vec`s that +/// are not the last field of a struct, as it will consume all +/// remaining bytes. +/// +/// # Examples +/// +/// Using `TrailingVec` in a struct results in the vector being +/// serialized without a length prefix. 
+/// +/// ``` +/// use spl_collections::TrailingVec; +/// use wincode::{SchemaRead, SchemaWrite}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// pub struct MyStruct { +/// pub amount: u64, +/// pub items: TrailingVec, +/// } +/// +/// let my_struct = MyStruct { +/// amount: 1_000_000_000, +/// items: TrailingVec::from(vec![1, 2, 3, 4, 5]), +/// }; +/// +/// let bytes = wincode::serialize(&my_struct).unwrap(); +/// // Expected size: +/// // - amount (8 bytes) +/// // - items (remaining `Vec` without a length prefix) +/// assert_eq!(bytes.len(), 8 + my_struct.items.len() * size_of::()); +/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); +/// +/// # assert_eq!(deserialized.amount, my_struct.amount); +/// # assert_eq!(deserialized.items, my_struct.items); +/// ``` +#[derive(Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct TrailingVec(Vec); + +impl From> for TrailingVec { + fn from(value: Vec) -> Self { + Self(value) + } +} + +impl From<&[T]> for TrailingVec { + fn from(value: &[T]) -> Self { + Self(Vec::from(value)) + } +} + +impl From<&[T; N]> for TrailingVec { + fn from(value: &[T; N]) -> Self { + Self(Vec::from(value)) + } +} + +impl Deref for TrailingVec { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Debug for TrailingVec { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } +} + +#[cfg(feature = "borsh")] +impl BorshSerialize for TrailingVec { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + // Serialized items without a length prefix. 
+ self.0.iter().try_for_each(|item| item.serialize(writer)) + } +} + +#[cfg(feature = "borsh")] +impl BorshDeserialize for TrailingVec { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let mut items: Vec = Vec::new(); + + while let Ok(item) = T::deserialize_reader(reader) { + items.push(item); + } + + Ok(Self(items)) + } +} + +#[cfg(feature = "wincode")] +unsafe impl SchemaWrite for TrailingVec +where + C: ConfigCore, + T: SchemaWrite, +{ + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + let expected_size = src.0.len().saturating_mul(core::mem::size_of::()); + + // `Vec` capacity is limited to `isize::MAX`. + if expected_size > isize::MAX as usize { + return Err(write_length_encoding_overflow( + "size of items in TrailingVec", + )); + } + + Ok(expected_size) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + // SAFETY: Serializing a slice `[T]` without a length prefix. + unsafe { + writer + .write_slice_t(src.0.as_slice()) + .map_err(wincode::WriteError::Io) + } + } +} + +#[cfg(feature = "wincode")] +unsafe impl<'de, T, C> SchemaRead<'de, C> for TrailingVec +where + C: ConfigCore, + T: SchemaRead<'de, C, Dst = T>, +{ + type Dst = Self; + + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let mut items = Vec::new(); + + while let Ok(item) = T::get(&mut reader) { + items.push(item); + } + + dst.write(Self(items)); + + Ok(()) + } +} + +/// Macro defining a `PrefixedVec` type with a specified length prefix type. +macro_rules! 
prefixed_vec_type { + ( $name:tt, $prefix_type:tt ) => { + #[doc = concat!("A `Vec` serialized with an `", stringify!($prefix_type), "` length prefix.")] + #[derive(Clone, Eq, PartialEq)] + #[repr(transparent)] + pub struct $name(Vec); + + impl From> for $name { + fn from(value: Vec) -> Self { + Self(value) + } + } + + impl From<&[T]> for $name { + fn from(value: &[T]) -> Self { + Self(Vec::from(value)) + } + } + + impl From<&[T; N]> for $name { + fn from(value: &[T; N]) -> Self { + Self(Vec::from(value)) + } + } + + impl Deref for $name { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl Debug for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } + } + + #[cfg(feature = "borsh")] + impl BorshSerialize for $name { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + BorshSerialize::serialize( + &$prefix_type::try_from(self.0.len()).map_err(|_| ErrorKind::InvalidData)?, + writer, + )?; + self.0.iter().try_for_each(|item| item.serialize(writer)) + } + } + + #[cfg(feature = "borsh")] + impl BorshDeserialize for $name { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let prefix = $prefix_type::deserialize_reader(reader)? as usize; + let mut items: Vec = Vec::with_capacity(prefix); + + while items.len() < prefix { + let Ok(item) = T::deserialize_reader(reader) else { + return Err(ErrorKind::InvalidData.into()); + }; + + items.push(item); + } + + Ok(Self(items)) + } + } + + #[cfg(feature = "wincode")] + unsafe impl SchemaWrite for $name + where + C: ConfigCore, + T: SchemaWrite, + { + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + let expected_size = core::mem::size_of::<$prefix_type>().saturating_add( + src.0.len().saturating_mul(size_of::())); + + // `Vec` capacity is limited to `isize::MAX`. 
+ if expected_size > isize::MAX as usize { + return Err(write_length_encoding_overflow( + concat!("size of items in ", stringify!($name)), + )); + } + + Ok(expected_size) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + <$prefix_type as SchemaWrite>::write( + &mut writer, + &$prefix_type::try_from(src.0.len()) + .map_err(|_| write_length_encoding_overflow(stringify!($prefix_type::MAX)))?, + )?; + // SAFETY: Serializing a slice `[T]`. + unsafe { + writer + .write_slice_t(src.0.as_slice()) + .map_err(wincode::WriteError::Io) + } + } + } + + #[cfg(feature = "wincode")] + unsafe impl<'de, T, C> SchemaRead<'de, C> for $name + where + C: ConfigCore, + T: SchemaRead<'de, C, Dst = T>, + { + type Dst = Self; + + fn read( + mut reader: impl Reader<'de>, + dst: &mut MaybeUninit, + ) -> ReadResult<()> { + let mut prefix = MaybeUninit::<$prefix_type>::uninit(); + <$prefix_type as SchemaRead<'de, C>>::read(&mut reader, &mut prefix)?; + // SAFETY: We have just read the prefix from the reader, so it is initialized. + let prefix = unsafe { prefix.assume_init() } as usize; + + let mut items = Vec::with_capacity(prefix); + + while items.len() < prefix { + let Ok(item) = T::get(&mut reader) else { + return Err(ReadError::Custom("failed to deserialize")); + }; + + items.push(item); + } + + dst.write(Self(items)); + + Ok(()) + } + } + }; +} + +// A `PrefixedVec` with a `u8` length prefix. +prefixed_vec_type!(U8PrefixedVec, u8); + +// A `PrefixedVec` with a `u16` length prefix. +prefixed_vec_type!(U16PrefixedVec, u16); + +// A `PrefixedVec` with a `u32` length prefix. +prefixed_vec_type!(U32PrefixedVec, u32); + +// A `PrefixedVec` with a `u64` length prefix. 
+prefixed_vec_type!(U64PrefixedVec, u64); + +#[cfg(test)] +mod tests { + use borsh::{BorshDeserialize, BorshSerialize}; + use core::mem::size_of; + use wincode::WriteError; + + use super::*; + + #[test] + fn trailing_vec_borsh_round_trip() { + const VALUES: [u64; 5] = [255u64; 5]; + + let original: TrailingVec = TrailingVec::from(&VALUES); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; size_of::() * VALUES.len()]; + + original.serialize(&mut bytes.as_mut_slice()).unwrap(); + + let serialized = TrailingVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.as_slice(), VALUES); + assert_eq!(serialized, original); + } + + #[test] + fn trailing_vec_wincode_round_trip() { + const VALUES: [u64; 5] = [255u64; 5]; + + let original: TrailingVec = TrailingVec::from(&VALUES); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; size_of::() * VALUES.len()]; + + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.as_slice(), VALUES); + assert_eq!(serialized, original); + } + + #[test] + fn prefixed_vec_borsh_round_trip() { + const VALUES: [u64; 10] = [255u64; 10]; + + // u8 length prefix + let original = U8PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U8PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u16 length prefix + let original = U16PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U16PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u64 length prefix + 
let original = U64PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U64PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + } + + #[test] + fn prefixed_vec_wincode_round_trip() { + const VALUES: [u64; 10] = [255u64; 10]; + + // u8 length prefix + let original = U8PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u16 length prefix + let original = U16PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u32 length prefix + let original = U32PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u64 length prefix + let original = U64PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), 
VALUES); + } + + #[test] + fn invalid_prefixed_value() { + const VALUES: [u8; 256] = [255u8; 256]; + + let original = U8PrefixedVec::from(&VALUES); + + // borsh + let result = borsh::to_vec(&original); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); + + // wincode + let result = wincode::serialize(&original); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + WriteError::LengthEncodingOverflow(_) + )); + } + + #[test] + fn prefixed_vec_borsh_with_remaining_bytes() { + // Bytes representation for a `U8PrefixedVec` with 8 `u64` values + // followed by 16 additional bytes. + let mut bytes = [255u8; 81]; + bytes[0] = 8; + + let mut reader = bytes.as_slice(); + let serialized = U8PrefixedVec::::deserialize(&mut reader).unwrap(); + + assert_eq!(serialized.len(), 8); + assert_eq!(serialized.as_slice(), &[!(0u64); 8]); + } + + #[test] + fn prefixed_vec_wincode_with_remaining_bytes() { + // Bytes representation for a `U8PrefixedVec` with 8 `u64` values + // followed by 16 additional bytes. + let mut bytes = [255u8; 81]; + bytes[0] = 8; + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), 8); + assert_eq!(serialized.as_slice(), &[!(0u64); 8]); + } +}