From da7b9e30bb46128599097b2a0c5ec409a7d45d92 Mon Sep 17 00:00:00 2001 From: febo Date: Wed, 25 Feb 2026 18:42:39 +0000 Subject: [PATCH 1/6] Add collections crate --- Cargo.lock | 9 + Cargo.toml | 1 + collections/Cargo.toml | 29 +++ collections/src/lib.rs | 21 ++ collections/src/string.rs | 516 ++++++++++++++++++++++++++++++++++++++ collections/src/vec.rs | 465 ++++++++++++++++++++++++++++++++++ 6 files changed, 1041 insertions(+) create mode 100644 collections/Cargo.toml create mode 100644 collections/src/lib.rs create mode 100644 collections/src/string.rs create mode 100644 collections/src/vec.rs diff --git a/Cargo.lock b/Cargo.lock index 95c6adb..da843c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1532,6 +1532,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "spl-collections" +version = "0.0.0" +dependencies = [ + "borsh", + "spl-collections", + "wincode", +] + [[package]] name = "spl-discriminator" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index a18b107..521897b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "collections", "discriminator", "discriminator-derive", "discriminator-syn", diff --git a/collections/Cargo.toml b/collections/Cargo.toml new file mode 100644 index 0000000..1efe025 --- /dev/null +++ b/collections/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "spl-collections" +version = "0.0.0" +description = "Serialization-aware collection wrappers for Solana account data" +authors = ["Anza Maintainers "] +repository = "https://github.com/solana-program/libraries" +license = "Apache-2.0" +edition = "2021" + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] +all-features = true +rustdoc-args = ["--cfg=docsrs"] + +[features] +alloc = [] +borsh = ["dep:borsh", "alloc"] +default = ["alloc"] +wincode = ["dep:wincode", "alloc"] + +[dependencies] +borsh = { version = "1.0", features = ["derive"], default-features = false, optional = true } +wincode = { version = "0.4.4", 
features = ["alloc", "derive"], default-features = false, optional = true } + +[dev-dependencies] +spl-collections = { path = ".", features = ["borsh", "wincode"] } + +[lib] +crate-type = ["lib"] diff --git a/collections/src/lib.rs b/collections/src/lib.rs new file mode 100644 index 0000000..af96d28 --- /dev/null +++ b/collections/src/lib.rs @@ -0,0 +1,21 @@ +//! Serialization-aware collection wrappers for Solana account data. +//! +//! This crate provides wrappers around collection types to support custom serialization +//! logic. This is useful for programs that have specific requirements for how data is +//! stored. + +#![no_std] +#![cfg_attr(docsrs, feature(doc_cfg))] + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(feature = "alloc")] +mod string; +#[cfg(feature = "alloc")] +mod vec; + +#[cfg(feature = "alloc")] +pub use string::*; +#[cfg(feature = "alloc")] +pub use vec::*; diff --git a/collections/src/string.rs b/collections/src/string.rs new file mode 100644 index 0000000..e5c2eb2 --- /dev/null +++ b/collections/src/string.rs @@ -0,0 +1,516 @@ +//! Types for serializing strings types. +//! +//! This module provides two types for serializing strings: `TrailingString` and a +//! set of `PrefixedString`. +//! +//! `TrailingString` is serialized without a length prefix, while the `PrefixedString`s +//! are serialized with a length prefix determined by a type. The length prefix is useful +//! for deserializing strings that are not the last field of a struct, as it allows the +//! deserializer to know how many bytes to read for the string, while allowing for more +//! efficient storage depending on the expected length of the string. +//! +//! The types in this module also implement the `Deref` trait, allowing them to be used +//! as regular `String` in most contexts. 
+ +use { + alloc::string::{String, ToString}, + core::{ + fmt::{Debug, Formatter}, + ops::Deref, + }, +}; +#[cfg(feature = "borsh")] +use { + alloc::vec, + borsh::{ + io::{ErrorKind, Read, Write}, + BorshDeserialize, BorshSerialize, + }, +}; +#[cfg(feature = "wincode")] +use { + core::mem::MaybeUninit, + wincode::{ + config::ConfigCore, + error::{invalid_utf8_encoding, write_length_encoding_overflow}, + io::{Reader, Writer}, + ReadResult, SchemaRead, SchemaWrite, WriteResult, + }, +}; + +#[cfg(feature = "borsh")] +/// Size of the buffer used to read the string. +const BUFFER_SIZE: usize = 1024; + +/// A `String` serialized without a length prefix. +/// +/// This is useful for serializing strings that are the last field +/// of a struct, where the length can be inferred from the remaining +/// bytes. +/// +/// Note that this type is not suitable for serializing strings that +/// are not the last field of a struct, as it will consume all +/// remaining bytes. +/// +/// # Examples +/// +/// Using `TrailingString` in a struct results in the string being +/// serialized without a length prefix. 
+/// +/// ``` +/// use spl_collections::TrailingString; +/// use wincode::{SchemaRead, SchemaWrite}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// pub struct MyStruct { +/// pub state: u8, +/// pub amount: u64, +/// pub description: TrailingString, +/// } +/// +/// let my_struct = MyStruct { +/// state: 1, +/// amount: 1_000_000_000, +/// description: TrailingString::from( +/// "The quick brown fox jumps over the lazy dog" +/// ), +/// }; +/// +/// let bytes = wincode::serialize(&my_struct).unwrap(); +/// // Expected size: +/// // - state (1 byte) +/// // - amount (8 bytes) +/// // - description (remaining bytes without a length prefix) +/// assert_eq!(bytes.len(), 1 + 8 + my_struct.description.len()); +/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); +/// +/// # assert_eq!(deserialized.state, my_struct.state); +/// # assert_eq!(deserialized.amount, my_struct.amount); +/// # assert_eq!(deserialized.description, my_struct.description); +/// ``` +#[derive(Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct TrailingString(String); + +impl From for TrailingString { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for TrailingString { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl Deref for TrailingString { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Debug for TrailingString { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } +} + +#[cfg(feature = "borsh")] +impl BorshSerialize for TrailingString { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + // Serialize the string bytes without a length prefix. + writer.write_all(self.0.as_bytes()) + } +} + +#[cfg(feature = "borsh")] +impl BorshDeserialize for TrailingString { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + // Read the string in chunks until we reach the end of the reader. 
+ let mut buffer = [0u8; BUFFER_SIZE]; + let mut s = String::new(); + + loop { + let bytes_read = reader.read(&mut buffer)?; + + if bytes_read == 0 { + break; + } + + s.push_str( + core::str::from_utf8(&buffer[..bytes_read]).map_err(|_| ErrorKind::InvalidData)?, + ); + } + + Ok(Self(s)) + } +} + +#[cfg(feature = "wincode")] +unsafe impl SchemaWrite for TrailingString { + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + Ok(src.0.len()) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + // Serialize the string bytes without a length prefix. + unsafe { + writer + .write_slice_t(src.0.as_bytes()) + .map_err(wincode::WriteError::Io) + } + } +} + +#[cfg(feature = "wincode")] +unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for TrailingString { + type Dst = Self; + + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let mut s = String::new(); + let mut bytes_read = 0; + + loop { + // SAFETY: Move the reader by `bytes_read` from the previous iteration. + unsafe { reader.consume_unchecked(bytes_read) }; + + // Read the string in chunks until we reach the end of the reader. + let bytes = reader.fill_buf(BUFFER_SIZE)?; + + if bytes.is_empty() { + break; + } + + s.push_str(core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?); + bytes_read = bytes.len(); + } + + dst.write(Self(s)); + + Ok(()) + } +} + +/// Macro defining a `PrefixedStr` type with a specified length prefix type. +macro_rules! 
prefixed_str_type { + ( $name:tt, $prefix_type:tt ) => { + #[doc = concat!("A `String` that is serialized with an `", stringify!($prefix_type), "` length prefix.")] + #[derive(Clone, Eq, PartialEq)] + #[repr(transparent)] + pub struct $name(String); + + impl From for $name { + fn from(value: String) -> Self { + Self(value) + } + } + + impl From<&str> for $name { + fn from(value: &str) -> Self { + Self(value.to_string()) + } + } + + impl Deref for $name { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl Debug for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } + } + + #[cfg(feature = "borsh")] + impl BorshSerialize for $name { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + BorshSerialize::serialize( + &$prefix_type::try_from(self.0.len()).map_err(|_| ErrorKind::InvalidData)?, + writer, + )?; + writer.write_all(self.0.as_bytes()) + } + } + + #[cfg(feature = "borsh")] + impl BorshDeserialize for $name { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let prefix = $prefix_type::deserialize_reader(reader)?; + + let mut buffer = vec![0u8; prefix as usize]; + reader.read_exact(&mut buffer)?; + + Ok(Self::from( + String::from_utf8(buffer).map_err(|_| ErrorKind::InvalidData)?, + )) + } + } + + #[cfg(feature = "wincode")] + unsafe impl SchemaWrite for $name { + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + Ok(core::mem::size_of::<$prefix_type>() + src.0.len()) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + <$prefix_type as SchemaWrite>::write( + &mut writer, + &$prefix_type::try_from(src.0.len()) + .map_err(|_| write_length_encoding_overflow(stringify!($prefix_type::MAX)))?, + )?; + // SAFETY: Serializing a slice of `[u8]`. 
+ unsafe { + writer + .write_slice_t(src.0.as_bytes()) + .map_err(wincode::WriteError::Io) + } + } + } + + #[cfg(feature = "wincode")] + unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for $name { + type Dst = Self; + + fn read( + mut reader: impl Reader<'de>, + dst: &mut MaybeUninit, + ) -> ReadResult<()> { + // Read the length prefix first to determine how many bytes to read for the string. + let mut prefix = MaybeUninit::<$prefix_type>::uninit(); + <$prefix_type as SchemaRead<'de, C>>::read(&mut reader, &mut prefix)?; + // SAFETY: We have just read the prefix from the reader, so it is initialized. + let prefix = unsafe { prefix.assume_init() } as usize; + + let bytes = reader.fill_exact(prefix)?; + dst.write($name::from( + core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?, + )); + + Ok(()) + } + } + }; +} + +// A `PrefixedString` with a `u8` length prefix. +prefixed_str_type!(U8PrefixedString, u8); + +// A `PrefixedString` with a `u16` length prefix. +prefixed_str_type!(U16PrefixedString, u16); + +// A `PrefixedString` with a `u32` length prefix. +prefixed_str_type!(U32PrefixedString, u32); + +// A `PrefixedString` with a `u64` length prefix. +prefixed_str_type!(U64PrefixedString, u64); + +#[cfg(test)] +mod tests { + use borsh::{BorshDeserialize, BorshSerialize}; + use core::mem::size_of; + use wincode::WriteError; + + use super::*; + + #[test] + fn trailing_str_borsh_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingString = TrailingString::from(String::from(DATA)); + // No need to reserve space for a length prefix. 
+ let mut bytes = [0u8; DATA.len()]; + + original.serialize(&mut bytes.as_mut_slice()).unwrap(); + + let serialized = TrailingString::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + } + + #[test] + fn trailing_str_wincode_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingString = TrailingString::from(String::from(DATA)); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; DATA.len()]; + + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.as_str(), DATA); + assert_eq!(serialized, original); + } + + #[test] + fn prefixed_str_borsh_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedString::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let string = U8PrefixedString::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.as_str(), TEXT); + + // u16 length prefix + string bytes + let original = U16PrefixedString::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let string = U16PrefixedString::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.as_str(), TEXT); + + // u32 length prefix + string bytes + let original = U32PrefixedString::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) 
}), + TEXT.len() as u32 + ); + assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let string = U32PrefixedString::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.as_str(), TEXT); + + // u64 length prefix + string bytes + let original = U64PrefixedString::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], TEXT.as_bytes()); + + let string = U64PrefixedString::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.as_str(), TEXT); + } + + #[test] + fn prefixed_str_wincode_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedString::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.as_str(), TEXT); + assert_eq!(serialized, original); + + // u16 length prefix + string bytes + let original = U16PrefixedString::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.as_str(), TEXT); + assert_eq!(serialized, original); + + // u32 length prefix + string bytes + let original = U32PrefixedString::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() 
+ TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), + TEXT.len() as u32 + ); + assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.as_str(), TEXT); + assert_eq!(serialized, original); + + // u64 length prefix + string bytes + let original = U64PrefixedString::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.as_str(), TEXT); + assert_eq!(serialized, original); + } + + #[test] + fn invalid_prefixed_value() { + let large_text = "a".repeat(256); + + let original = U8PrefixedString::from(large_text); + + // borsh + let result = borsh::to_vec(&original); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); + + // wincode + let result = wincode::serialize(&original); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + WriteError::LengthEncodingOverflow(_) + )); + } +} diff --git a/collections/src/vec.rs b/collections/src/vec.rs new file mode 100644 index 0000000..3e2c5d3 --- /dev/null +++ b/collections/src/vec.rs @@ -0,0 +1,465 @@ +//! Types for serializing `Vec` types. +//! +//! This module provides two types for serializing a `Vec`: `TrailingVec` and a +//! set of `PrefixedVec`s with different length prefix types. +//! +//! `TrailingVec` is serialized without a length prefix, while the `PrefixedVec`s +//! are serialized with a length prefix determined by a type. 
The length prefix is useful +//! for deserializing vectors that are not the last field of a struct, as it allows the +//! deserializer to know how many bytes to read for the vector, while allowing for more +//! efficient storage depending on the expected length of the vector. +//! +//! The types in this module also implement the `Deref` trait, allowing them to be used +//! as regular `Vec` in most contexts. + +#[cfg(feature = "borsh")] +use borsh::{ + io::{ErrorKind, Read, Write}, + BorshDeserialize, BorshSerialize, +}; +use { + alloc::vec::Vec, + core::{ + fmt::{Debug, Formatter}, + ops::Deref, + }, +}; +#[cfg(feature = "wincode")] +use { + core::mem::MaybeUninit, + wincode::{ + config::ConfigCore, + error::write_length_encoding_overflow, + io::{Reader, Writer}, + ReadResult, SchemaRead, SchemaWrite, WriteResult, + }, +}; + +/// A `Vec` serialized without a length prefix. +/// +/// This is useful for serializing a `Vec` that is the last field +/// of a struct, where the length can be inferred from the remaining +/// bytes. +/// +/// Note that this type is not suitable for serializing `Vec`s that +/// are not the last field of a struct, as it will consume all +/// remaining bytes. +/// +/// # Examples +/// +/// Using `TrailingVec` in a struct results in the vector being +/// serialized without a length prefix. 
+/// +/// ``` +/// use spl_collections::TrailingVec; +/// use wincode::{SchemaRead, SchemaWrite}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// pub struct MyStruct { +/// pub amount: u64, +/// pub items: TrailingVec, +/// } +/// +/// let my_struct = MyStruct { +/// amount: 1_000_000_000, +/// items: TrailingVec::from(vec![1, 2, 3, 4, 5]), +/// }; +/// +/// let bytes = wincode::serialize(&my_struct).unwrap(); +/// // Expected size: +/// // - amount (8 bytes) +/// // - items (remaining `Vec` without a length prefix) +/// assert_eq!(bytes.len(), 8 + my_struct.items.len() * size_of::()); +/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); +/// +/// # assert_eq!(deserialized.amount, my_struct.amount); +/// # assert_eq!(deserialized.items, my_struct.items); +/// ``` +#[derive(Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct TrailingVec(Vec); + +impl From> for TrailingVec { + fn from(value: Vec) -> Self { + Self(value) + } +} + +impl From<&[T]> for TrailingVec { + fn from(value: &[T]) -> Self { + Self(Vec::from(value)) + } +} + +impl From<&[T; N]> for TrailingVec { + fn from(value: &[T; N]) -> Self { + Self(Vec::from(value)) + } +} + +impl Deref for TrailingVec { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Debug for TrailingVec { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } +} + +#[cfg(feature = "borsh")] +impl BorshSerialize for TrailingVec { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + // Serialized items without a length prefix. 
+ self.0.iter().try_for_each(|item| item.serialize(writer)) + } +} + +#[cfg(feature = "borsh")] +impl BorshDeserialize for TrailingVec { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let mut items: Vec = Vec::new(); + + while let Ok(item) = T::deserialize_reader(reader) { + items.push(item); + } + + Ok(Self(items)) + } +} + +#[cfg(feature = "wincode")] +unsafe impl SchemaWrite for TrailingVec +where + C: ConfigCore, + T: SchemaWrite, +{ + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + Ok(src.0.len() * size_of::()) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + // SAFETY: Serializing a slice `[T]` without a length prefix. + unsafe { + writer + .write_slice_t(src.0.as_slice()) + .map_err(wincode::WriteError::Io) + } + } +} + +#[cfg(feature = "wincode")] +unsafe impl<'de, T, C> SchemaRead<'de, C> for TrailingVec +where + C: ConfigCore, + T: SchemaRead<'de, C, Dst = T>, +{ + type Dst = Self; + + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let mut items = Vec::new(); + + while let Ok(item) = T::get(&mut reader) { + items.push(item); + } + + dst.write(Self(items)); + + Ok(()) + } +} + +/// Macro defining a `PrefixedVec` type with a specified length prefix type. +macro_rules! 
prefixed_vec_type { + ( $name:tt, $prefix_type:tt ) => { + #[doc = concat!("A `Vec` serialized with an `", stringify!($prefix_type), "` length prefix.")] + #[derive(Clone, Eq, PartialEq)] + #[repr(transparent)] + pub struct $name(Vec); + + impl From> for $name { + fn from(value: Vec) -> Self { + Self(value) + } + } + + impl From<&[T]> for $name { + fn from(value: &[T]) -> Self { + Self(Vec::from(value)) + } + } + + impl From<&[T; N]> for $name { + fn from(value: &[T; N]) -> Self { + Self(Vec::from(value)) + } + } + + impl Deref for $name { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl Debug for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } + } + + #[cfg(feature = "borsh")] + impl BorshSerialize for $name { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + BorshSerialize::serialize( + &$prefix_type::try_from(self.0.len()).map_err(|_| ErrorKind::InvalidData)?, + writer, + )?; + self.0.iter().try_for_each(|item| item.serialize(writer)) + } + } + + #[cfg(feature = "borsh")] + impl BorshDeserialize for $name { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let prefix = $prefix_type::deserialize_reader(reader)?; + let mut items: Vec = Vec::with_capacity(prefix as usize); + + while let Ok(item) = T::deserialize_reader(reader) { + items.push(item); + } + + Ok(Self(items)) + } + } + + #[cfg(feature = "wincode")] + unsafe impl SchemaWrite for $name + where + C: ConfigCore, + T: SchemaWrite, + { + type Src = Self; + + #[inline(always)] + fn size_of(src: &Self::Src) -> WriteResult { + Ok(core::mem::size_of::<$prefix_type>() + src.0.len()) + } + + #[inline(always)] + fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + <$prefix_type as SchemaWrite>::write( + &mut writer, + &$prefix_type::try_from(src.0.len()) + .map_err(|_| write_length_encoding_overflow(stringify!($prefix_type::MAX)))?, + )?; + // SAFETY: 
Serializing a slice `[T]`. + unsafe { + writer + .write_slice_t(src.0.as_slice()) + .map_err(wincode::WriteError::Io) + } + } + } + + #[cfg(feature = "wincode")] + unsafe impl<'de, T, C> SchemaRead<'de, C> for $name + where + C: ConfigCore, + T: SchemaRead<'de, C, Dst = T>, + { + type Dst = Self; + + fn read( + mut reader: impl Reader<'de>, + dst: &mut MaybeUninit, + ) -> ReadResult<()> { + // Read the length prefix first to allocate space for `T`s. + let mut prefix = MaybeUninit::<$prefix_type>::uninit(); + <$prefix_type as SchemaRead<'de, C>>::read(&mut reader, &mut prefix)?; + // SAFETY: We have just read the prefix from the reader, so it is initialized. + let prefix = unsafe { prefix.assume_init() } as usize; + + let mut items = Vec::with_capacity(prefix); + + while let Ok(item) = T::get(&mut reader) { + items.push(item); + } + + dst.write(Self(items)); + + Ok(()) + } + } + }; +} + +// A `PrefixedVec` with a `u8` length prefix. +prefixed_vec_type!(U8PrefixedVec, u8); + +// A `PrefixedVec` with a `u16` length prefix. +prefixed_vec_type!(U16PrefixedVec, u16); + +// A `PrefixedVec` with a `u32` length prefix. +prefixed_vec_type!(U32PrefixedVec, u32); + +// A `PrefixedVec` with a `u64` length prefix. +prefixed_vec_type!(U64PrefixedVec, u64); + +#[cfg(test)] +mod tests { + use borsh::{BorshDeserialize, BorshSerialize}; + use core::mem::size_of; + use wincode::WriteError; + + use super::*; + + #[test] + fn trailing_vec_borsh_round_trip() { + const VALUES: [u64; 5] = [255u64; 5]; + + let original: TrailingVec = TrailingVec::from(&VALUES); + // No need to reserve space for a length prefix. 
+ let mut bytes = [0u8; size_of::() * VALUES.len()]; + + original.serialize(&mut bytes.as_mut_slice()).unwrap(); + + let serialized = TrailingVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.as_slice(), VALUES); + assert_eq!(serialized, original); + } + + #[test] + fn trailing_vec_wincode_round_trip() { + const VALUES: [u64; 5] = [255u64; 5]; + + let original: TrailingVec = TrailingVec::from(&VALUES); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; size_of::() * VALUES.len()]; + + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.as_slice(), VALUES); + assert_eq!(serialized, original); + } + + #[test] + fn prefixed_vec_borsh_round_trip() { + const VALUES: [u64; 10] = [255u64; 10]; + + // u8 length prefix + let original = U8PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U8PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u16 length prefix + let original = U16PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U16PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u64 length prefix + let original = U64PrefixedVec::from(&VALUES); + let bytes = borsh::to_vec(&original).unwrap(); + + let serialized = U64PrefixedVec::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + } + + #[test] + fn prefixed_vec_wincode_round_trip() { + const VALUES: [u64; 10] = [255u64; 10]; 
+ + // u8 length prefix + let original = U8PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u16 length prefix + let original = U16PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u32 length prefix + let original = U32PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + + // u64 length prefix + let original = U64PrefixedVec::from(&VALUES); + let mut bytes = [0u8; size_of::() + size_of::() * VALUES.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + assert_eq!(serialized.as_slice(), VALUES); + } + + #[test] + fn invalid_prefixed_value() { + const VALUES: [u8; 256] = [255u8; 256]; + + let original = U8PrefixedVec::from(&VALUES); + + // borsh + let result = borsh::to_vec(&original); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); + + // wincode + let result = wincode::serialize(&original); + + assert!(result.is_err()); + 
assert!(matches!( + result.unwrap_err(), + WriteError::LengthEncodingOverflow(_) + )); + } +} From 3590d2f7feb5fa80ed00976afc4248513b0be158 Mon Sep 17 00:00:00 2001 From: febo Date: Sun, 1 Mar 2026 13:28:10 -0700 Subject: [PATCH 2/6] Address review comments --- collections/Cargo.toml | 6 +- collections/src/lib.rs | 9 +- collections/src/str.rs | 510 +++++++++++++++++++++++++++++++++++++ collections/src/string.rs | 516 -------------------------------------- collections/src/vec.rs | 48 +++- 5 files changed, 555 insertions(+), 534 deletions(-) create mode 100644 collections/src/str.rs delete mode 100644 collections/src/string.rs diff --git a/collections/Cargo.toml b/collections/Cargo.toml index 1efe025..953e5a2 100644 --- a/collections/Cargo.toml +++ b/collections/Cargo.toml @@ -13,10 +13,8 @@ all-features = true rustdoc-args = ["--cfg=docsrs"] [features] -alloc = [] -borsh = ["dep:borsh", "alloc"] -default = ["alloc"] -wincode = ["dep:wincode", "alloc"] +borsh = ["dep:borsh"] +wincode = ["dep:wincode"] [dependencies] borsh = { version = "1.0", features = ["derive"], default-features = false, optional = true } diff --git a/collections/src/lib.rs b/collections/src/lib.rs index af96d28..fb1bc2c 100644 --- a/collections/src/lib.rs +++ b/collections/src/lib.rs @@ -7,15 +7,10 @@ #![no_std] #![cfg_attr(docsrs, feature(doc_cfg))] -#[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "alloc")] -mod string; -#[cfg(feature = "alloc")] +mod str; mod vec; -#[cfg(feature = "alloc")] -pub use string::*; -#[cfg(feature = "alloc")] +pub use str::*; pub use vec::*; diff --git a/collections/src/str.rs b/collections/src/str.rs new file mode 100644 index 0000000..72dffed --- /dev/null +++ b/collections/src/str.rs @@ -0,0 +1,510 @@ +//! Types for serializing strings. +//! +//! This module provides two types for serializing strings: `TrailingStr` and a +//! set of `PrefixedStr`. +//! +//! `TrailingStr` is serialized without a length prefix, while the `PrefixedStr`s +//! 
are serialized with a length prefix determined by a type. The length prefix is useful +//! for deserializing strings that are not the last field of a struct, as it allows the +//! deserializer to know how many bytes to read for the string, while allowing for more +//! efficient storage depending on the expected length of the string. +//! +//! The types in this module also implement the `Deref` trait, allowing them to be used +//! as `&str` in most contexts. + +#[cfg(feature = "borsh")] +use borsh::{ + io::{ErrorKind, Read}, + BorshDeserialize, BorshSerialize, +}; +use { + crate::{TrailingVec, U16PrefixedVec, U32PrefixedVec, U64PrefixedVec, U8PrefixedVec}, + alloc::string::String, + core::{ + fmt::{Debug, Formatter}, + ops::Deref, + str::from_utf8_unchecked, + }, +}; +#[cfg(feature = "wincode")] +use { + core::{mem::MaybeUninit, str::from_utf8}, + wincode::{ + config::{Config, ConfigCore}, + io::Reader, + ReadError, ReadResult, SchemaRead, SchemaWrite, UninitBuilder, + }, +}; + +/// A `str` serialized without a length prefix. +/// +/// This is useful for serializing strings that are the last field +/// of a struct, where the length can be inferred from the remaining +/// bytes. +/// +/// Note that this type is not suitable for serializing strings that +/// are not the last field of a struct, as it will consume all +/// remaining bytes. +/// +/// # Examples +/// +/// Using `TrailingStr` in a struct results in the string being +/// serialized without a length prefix. 
+/// +/// ``` +/// use spl_collections::TrailingStr; +/// use wincode::{SchemaRead, SchemaWrite}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// pub struct MyStruct { +/// pub state: u8, +/// pub amount: u64, +/// pub description: TrailingStr, +/// } +/// +/// let my_struct = MyStruct { +/// state: 1, +/// amount: 1_000_000_000, +/// description: TrailingStr::from( +/// "The quick brown fox jumps over the lazy dog" +/// ), +/// }; +/// +/// let bytes = wincode::serialize(&my_struct).unwrap(); +/// // Expected size: +/// // - state (1 byte) +/// // - amount (8 bytes) +/// // - description (remaining bytes without a length prefix) +/// assert_eq!(bytes.len(), 1 + 8 + my_struct.description.len()); +/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); +/// +/// # assert_eq!(deserialized.state, my_struct.state); +/// # assert_eq!(deserialized.amount, my_struct.amount); +/// # assert_eq!(deserialized.description, my_struct.description); +/// ``` +#[cfg_attr(feature = "borsh", derive(BorshSerialize))] +#[cfg_attr(feature = "wincode", derive(SchemaWrite, UninitBuilder))] +#[derive(Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct TrailingStr(TrailingVec); + +impl From for TrailingStr { + fn from(value: String) -> Self { + Self(TrailingVec::from(value.as_bytes())) + } +} + +impl From<&str> for TrailingStr { + fn from(value: &str) -> Self { + Self(TrailingVec::from(value.as_bytes())) + } +} + +impl Deref for TrailingStr { + type Target = str; + + fn deref(&self) -> &Self::Target { + // SAFETY: The `TrailingStr` type is only constructed + // from valid UTF-8 strings. 
+ unsafe { from_utf8_unchecked(&self.0) } + } +} + +impl Debug for TrailingStr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.deref())) + } +} + +#[cfg(feature = "borsh")] +impl BorshDeserialize for TrailingStr { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let container = TrailingVec::::deserialize_reader(reader)?; + + // Validate that we got valid UTF-8 bytes, as `TrailingStr` must + // always be valid UTF-8. + if from_utf8(&container).is_err() { + return Err(ErrorKind::InvalidData.into()); + } + + Ok(Self(container)) + } +} + +#[cfg(feature = "wincode")] +unsafe impl<'de, C: Config> SchemaRead<'de, C> for TrailingStr { + type Dst = Self; + + fn read(reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let mut builder = TrailingStrUninitBuilder::::from_maybe_uninit_mut(dst); + builder.read_0(reader)?; + + let container = unsafe { builder.uninit_0_mut().assume_init_ref() }; + + // Validate that we got valid UTF-8 bytes, as `TrailingStr` must + // always be valid UTF-8. + if from_utf8(container).is_err() { + return Err(ReadError::Custom("invalid UTF-8 bytes")); + } + + builder.finish(); + + Ok(()) + } +} + +/// Macro defining a `PrefixedStr` type with a specified length prefix type. +macro_rules! 
prefixed_str_type { + ( $name:tt, $container_type:tt, $prefix_type:tt ) => { + #[doc = concat!("A `str` that is serialized with an `", stringify!($prefix_type), "` length prefix.")] + #[cfg_attr(feature = "borsh", derive(BorshSerialize))] + #[cfg_attr(feature = "wincode", derive(SchemaWrite))] + #[derive(Clone, Eq, PartialEq)] + #[repr(transparent)] + pub struct $name($container_type); + + impl From for $name { + fn from(value: String) -> Self { + Self($container_type::from(value.as_bytes())) + } + } + + impl From<&str> for $name { + fn from(value: &str) -> Self { + Self($container_type::from(value.as_bytes())) + } + } + + impl Deref for $name { + type Target = str; + + fn deref(&self) -> &Self::Target { + // SAFETY: `*PrefixedStr` types are only constructed + // from valid UTF-8 strings. + unsafe { from_utf8_unchecked(&self.0) } + } + } + + impl Debug for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self.deref())) + } + } + + #[cfg(feature = "borsh")] + impl BorshDeserialize for $name { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let container = $container_type::::deserialize_reader(reader)?; + + // Validate that we got valid UTF-8 bytes, as `*PrefixedStr` types must + // always be valid UTF-8. + if from_utf8(&container).is_err() { + return Err(ErrorKind::InvalidData.into()); + } + + Ok(Self(container)) + } + } + + #[cfg(feature = "wincode")] + unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for $name { + type Dst = Self; + + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let container = <$container_type:: as SchemaRead>::get(&mut reader)?; + + // Validate that we got valid UTF-8 bytes, as `*PrefixedStr` types must + // always be valid UTF-8. + if from_utf8(&container).is_err() { + return Err(ReadError::Custom("invalid UTF-8 bytes")); + } + + dst.write(Self(container)); + + Ok(()) + } + } + }; +} + +// A `PrefixedStr` with a `u8` length prefix. 
+prefixed_str_type!(U8PrefixedStr, U8PrefixedVec, u8); + +// A `PrefixedStr` with a `u16` length prefix. +prefixed_str_type!(U16PrefixedStr, U16PrefixedVec, u16); + +// A `PrefixedStr` with a `u32` length prefix. +prefixed_str_type!(U32PrefixedStr, U32PrefixedVec, u32); + +// A `PrefixedStr` with a `u64` length prefix. +prefixed_str_type!(U64PrefixedStr, U64PrefixedVec, u64); + +#[cfg(test)] +mod tests { + use { + alloc::vec::Vec, + borsh::{io::ErrorKind, BorshDeserialize, BorshSerialize}, + core::mem::size_of, + wincode::WriteError, + }; + + use super::*; + + #[test] + fn trailing_str_borsh_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingStr = TrailingStr::from(String::from(DATA)); + // No need to reserve space for a length prefix. + let mut bytes = [0u8; DATA.len()]; + + original.serialize(&mut bytes.as_mut_slice()).unwrap(); + + let serialized = TrailingStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized, original); + } + + #[test] + fn trailing_str_wincode_round_trip() { + const DATA: &str = "Trailing strings have many characters"; + + let original: TrailingStr = TrailingStr::from(String::from(DATA)); + // No need to reserve space for a length prefix. 
+ let mut bytes = [0u8; DATA.len()]; + + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), original.len()); + assert_eq!(serialized.deref(), DATA); + assert_eq!(serialized, original); + } + + #[test] + fn prefixed_str_borsh_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let string = U8PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u16 length prefix + string bytes + let original = U16PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let string = U16PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u32 length prefix + string bytes + let original = U32PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), + TEXT.len() as u32 + ); + assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let string = U32PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + + // u64 length prefix + string bytes + let original = U64PrefixedStr::from(String::from(TEXT)); + let bytes = borsh::to_vec(&original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], 
TEXT.as_bytes()); + + let string = U64PrefixedStr::try_from_slice(&bytes).unwrap(); + + assert_eq!(string.len(), TEXT.len()); + assert_eq!(string.deref(), TEXT); + } + + #[test] + fn prefixed_str_wincode_round_trip() { + const TEXT: &str = "Prefixed strings have many characters"; + + // u8 length prefix + string bytes + let original = U8PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!(bytes[0], TEXT.len() as u8); + assert_eq!(&bytes[1..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u16 length prefix + string bytes + let original = U16PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), + TEXT.len() as u16 + ); + assert_eq!(&bytes[2..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u32 length prefix + string bytes + let original = U32PrefixedStr::from(String::from(TEXT)); + let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), + TEXT.len() as u32 + ); + assert_eq!(&bytes[4..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + + // u64 length prefix + string bytes + let original = U64PrefixedStr::from(String::from(TEXT)); + 
let mut bytes = [0u8; size_of::() + TEXT.len()]; + wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); + + assert_eq!( + u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), + TEXT.len() as u64 + ); + assert_eq!(&bytes[8..], TEXT.as_bytes()); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), TEXT.len()); + assert_eq!(serialized.deref(), TEXT); + assert_eq!(serialized, original); + } + + #[test] + fn invalid_prefixed_value() { + let large_text = "a".repeat(256); + + let original = U8PrefixedStr::from(large_text); + + // borsh + let result = borsh::to_vec(&original); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); + + // wincode + let result = wincode::serialize(&original); + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + WriteError::LengthEncodingOverflow(_) + )); + } + + #[test] + fn prefixed_str_borsh_with_remaining_bytes() { + let value = "⚙️ serialized data with extra bytes"; + let mut bytes = Vec::::new(); + + bytes.push(value.len() as u8); + bytes.extend_from_slice(value.as_bytes()); + // Extra bytes that should be ignored. + bytes.extend_from_slice(&[255u8; 16]); + + let mut reader = bytes.as_slice(); + let serialized = U8PrefixedStr::deserialize(&mut reader).unwrap(); + + assert_eq!(serialized.len(), value.len()); + assert_eq!(serialized.deref(), value); + } + + #[test] + fn prefixed_str_wincode_with_remaining_bytes() { + let value = "⚙️ serialized data with extra bytes"; + + let mut bytes = Vec::::new(); + bytes.push(value.len() as u8); + bytes.extend_from_slice(value.as_bytes()); + // Extra bytes that should be ignored. 
+ bytes.extend_from_slice(&[255u8; 16]); + + let serialized = wincode::deserialize::(&bytes).unwrap(); + + assert_eq!(serialized.len(), value.len()); + assert_eq!(serialized.deref(), value); + } + + #[test] + fn invalid_utf8_borsh() { + // prefix + 2 invalid UTF-8 bytes + let bytes = [2u8, 255, 255]; + + // For `TrailingStr`, skip the prefix byte and attempt to deserialize the remaining + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let mut reader = bytes[1..].as_ref(); + let maybe_deserialized = TrailingStr::deserialize(&mut reader); + + assert!(maybe_deserialized.is_err()); + + // For `PrefixedStr`, read the length prefix and then read the specified number of + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let mut reader = bytes.as_slice(); + let maybe_deserialized = U8PrefixedStr::deserialize(&mut reader); + + assert!(maybe_deserialized.is_err()); + } + + #[test] + fn invalid_utf8_wincode() { + // prefix + 2 invalid UTF-8 bytes + let bytes = [2u8, 255, 255]; + + // For `TrailingStr`, skip the prefix byte and attempt to deserialize the remaining + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let maybe_deserialized = wincode::deserialize::(&bytes[1..]); + + assert!(maybe_deserialized.is_err()); + + // For `PrefixedStr`, read the length prefix and then read the specified number of + // bytes as UTF-8. Expect an error due to the invalid UTF-8 bytes. + let maybe_deserialized = wincode::deserialize::(&bytes); + + assert!(maybe_deserialized.is_err()); + } +} diff --git a/collections/src/string.rs b/collections/src/string.rs deleted file mode 100644 index e5c2eb2..0000000 --- a/collections/src/string.rs +++ /dev/null @@ -1,516 +0,0 @@ -//! Types for serializing strings types. -//! -//! This module provides two types for serializing strings: `TrailingString` and a -//! set of `PrefixedString`. -//! -//! `TrailingString` is serialized without a length prefix, while the `PrefixedString`s -//! 
are serialized with a length prefix determined by a type. The length prefix is useful -//! for deserializing strings that are not the last field of a struct, as it allows the -//! deserializer to know how many bytes to read for the string, while allowing for more -//! efficient storage depending on the expected length of the string. -//! -//! The types in this module also implement the `Deref` trait, allowing them to be used -//! as regular `String` in most contexts. - -use { - alloc::string::{String, ToString}, - core::{ - fmt::{Debug, Formatter}, - ops::Deref, - }, -}; -#[cfg(feature = "borsh")] -use { - alloc::vec, - borsh::{ - io::{ErrorKind, Read, Write}, - BorshDeserialize, BorshSerialize, - }, -}; -#[cfg(feature = "wincode")] -use { - core::mem::MaybeUninit, - wincode::{ - config::ConfigCore, - error::{invalid_utf8_encoding, write_length_encoding_overflow}, - io::{Reader, Writer}, - ReadResult, SchemaRead, SchemaWrite, WriteResult, - }, -}; - -#[cfg(feature = "borsh")] -/// Size of the buffer used to read the string. -const BUFFER_SIZE: usize = 1024; - -/// A `String` serialized without a length prefix. -/// -/// This is useful for serializing strings that are the last field -/// of a struct, where the length can be inferred from the remaining -/// bytes. -/// -/// Note that this type is not suitable for serializing strings that -/// are not the last field of a struct, as it will consume all -/// remaining bytes. -/// -/// # Examples -/// -/// Using `TrailingString` in a struct results in the string being -/// serialized without a length prefix. 
-/// -/// ``` -/// use spl_collections::TrailingString; -/// use wincode::{SchemaRead, SchemaWrite}; -/// -/// #[derive(SchemaRead, SchemaWrite)] -/// pub struct MyStruct { -/// pub state: u8, -/// pub amount: u64, -/// pub description: TrailingString, -/// } -/// -/// let my_struct = MyStruct { -/// state: 1, -/// amount: 1_000_000_000, -/// description: TrailingString::from( -/// "The quick brown fox jumps over the lazy dog" -/// ), -/// }; -/// -/// let bytes = wincode::serialize(&my_struct).unwrap(); -/// // Expected size: -/// // - state (1 byte) -/// // - amount (8 bytes) -/// // - description (remaining bytes without a length prefix) -/// assert_eq!(bytes.len(), 1 + 8 + my_struct.description.len()); -/// # let deserialized = wincode::deserialize::(&bytes).unwrap(); -/// -/// # assert_eq!(deserialized.state, my_struct.state); -/// # assert_eq!(deserialized.amount, my_struct.amount); -/// # assert_eq!(deserialized.description, my_struct.description); -/// ``` -#[derive(Clone, Eq, PartialEq)] -#[repr(transparent)] -pub struct TrailingString(String); - -impl From for TrailingString { - fn from(value: String) -> Self { - Self(value) - } -} - -impl From<&str> for TrailingString { - fn from(value: &str) -> Self { - Self(value.to_string()) - } -} - -impl Deref for TrailingString { - type Target = String; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Debug for TrailingString { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - f.write_fmt(format_args!("{:?}", self.0)) - } -} - -#[cfg(feature = "borsh")] -impl BorshSerialize for TrailingString { - fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { - // Serialize the string bytes without a length prefix. - writer.write_all(self.0.as_bytes()) - } -} - -#[cfg(feature = "borsh")] -impl BorshDeserialize for TrailingString { - fn deserialize_reader(reader: &mut R) -> borsh::io::Result { - // Read the string in chunks until we reach the end of the reader. 
- let mut buffer = [0u8; BUFFER_SIZE]; - let mut s = String::new(); - - loop { - let bytes_read = reader.read(&mut buffer)?; - - if bytes_read == 0 { - break; - } - - s.push_str( - core::str::from_utf8(&buffer[..bytes_read]).map_err(|_| ErrorKind::InvalidData)?, - ); - } - - Ok(Self(s)) - } -} - -#[cfg(feature = "wincode")] -unsafe impl SchemaWrite for TrailingString { - type Src = Self; - - #[inline(always)] - fn size_of(src: &Self::Src) -> WriteResult { - Ok(src.0.len()) - } - - #[inline(always)] - fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { - // Serialize the string bytes without a length prefix. - unsafe { - writer - .write_slice_t(src.0.as_bytes()) - .map_err(wincode::WriteError::Io) - } - } -} - -#[cfg(feature = "wincode")] -unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for TrailingString { - type Dst = Self; - - fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { - let mut s = String::new(); - let mut bytes_read = 0; - - loop { - // SAFETY: Move the reader by `bytes_read` from the previous iteration. - unsafe { reader.consume_unchecked(bytes_read) }; - - // Read the string in chunks until we reach the end of the reader. - let bytes = reader.fill_buf(BUFFER_SIZE)?; - - if bytes.is_empty() { - break; - } - - s.push_str(core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?); - bytes_read = bytes.len(); - } - - dst.write(Self(s)); - - Ok(()) - } -} - -/// Macro defining a `PrefixedStr` type with a specified length prefix type. -macro_rules! 
prefixed_str_type { - ( $name:tt, $prefix_type:tt ) => { - #[doc = concat!("A `String` that is serialized with an `", stringify!($prefix_type), "` length prefix.")] - #[derive(Clone, Eq, PartialEq)] - #[repr(transparent)] - pub struct $name(String); - - impl From for $name { - fn from(value: String) -> Self { - Self(value) - } - } - - impl From<&str> for $name { - fn from(value: &str) -> Self { - Self(value.to_string()) - } - } - - impl Deref for $name { - type Target = String; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl Debug for $name { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - f.write_fmt(format_args!("{:?}", self.0)) - } - } - - #[cfg(feature = "borsh")] - impl BorshSerialize for $name { - fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { - BorshSerialize::serialize( - &$prefix_type::try_from(self.0.len()).map_err(|_| ErrorKind::InvalidData)?, - writer, - )?; - writer.write_all(self.0.as_bytes()) - } - } - - #[cfg(feature = "borsh")] - impl BorshDeserialize for $name { - fn deserialize_reader(reader: &mut R) -> borsh::io::Result { - let prefix = $prefix_type::deserialize_reader(reader)?; - - let mut buffer = vec![0u8; prefix as usize]; - reader.read_exact(&mut buffer)?; - - Ok(Self::from( - String::from_utf8(buffer).map_err(|_| ErrorKind::InvalidData)?, - )) - } - } - - #[cfg(feature = "wincode")] - unsafe impl SchemaWrite for $name { - type Src = Self; - - #[inline(always)] - fn size_of(src: &Self::Src) -> WriteResult { - Ok(core::mem::size_of::<$prefix_type>() + src.0.len()) - } - - #[inline(always)] - fn write(mut writer: impl Writer, src: &Self::Src) -> WriteResult<()> { - <$prefix_type as SchemaWrite>::write( - &mut writer, - &$prefix_type::try_from(src.0.len()) - .map_err(|_| write_length_encoding_overflow(stringify!($prefix_type::MAX)))?, - )?; - // SAFETY: Serializing a slice of `[u8]`. 
- unsafe { - writer - .write_slice_t(src.0.as_bytes()) - .map_err(wincode::WriteError::Io) - } - } - } - - #[cfg(feature = "wincode")] - unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for $name { - type Dst = Self; - - fn read( - mut reader: impl Reader<'de>, - dst: &mut MaybeUninit, - ) -> ReadResult<()> { - // Read the length prefix first to determine how many bytes to read for the string. - let mut prefix = MaybeUninit::<$prefix_type>::uninit(); - <$prefix_type as SchemaRead<'de, C>>::read(&mut reader, &mut prefix)?; - // SAFETY: We have just read the prefix from the reader, so it is initialized. - let prefix = unsafe { prefix.assume_init() } as usize; - - let bytes = reader.fill_exact(prefix)?; - dst.write($name::from( - core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?, - )); - - Ok(()) - } - } - }; -} - -// A `PrefixedString` with a `u8` length prefix. -prefixed_str_type!(U8PrefixedString, u8); - -// A `PrefixedString` with a `u16` length prefix. -prefixed_str_type!(U16PrefixedString, u16); - -// A `PrefixedString` with a `u32` length prefix. -prefixed_str_type!(U32PrefixedString, u32); - -// A `PrefixedString` with a `u64` length prefix. -prefixed_str_type!(U64PrefixedString, u64); - -#[cfg(test)] -mod tests { - use borsh::{BorshDeserialize, BorshSerialize}; - use core::mem::size_of; - use wincode::WriteError; - - use super::*; - - #[test] - fn trailing_str_borsh_round_trip() { - const DATA: &str = "Trailing strings have many characters"; - - let original: TrailingString = TrailingString::from(String::from(DATA)); - // No need to reserve space for a length prefix. 
- let mut bytes = [0u8; DATA.len()]; - - original.serialize(&mut bytes.as_mut_slice()).unwrap(); - - let serialized = TrailingString::try_from_slice(&bytes).unwrap(); - - assert_eq!(serialized.len(), original.len()); - assert_eq!(serialized, original); - } - - #[test] - fn trailing_str_wincode_round_trip() { - const DATA: &str = "Trailing strings have many characters"; - - let original: TrailingString = TrailingString::from(String::from(DATA)); - // No need to reserve space for a length prefix. - let mut bytes = [0u8; DATA.len()]; - - wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); - - let serialized = wincode::deserialize::(&bytes).unwrap(); - - assert_eq!(serialized.len(), original.len()); - assert_eq!(serialized.as_str(), DATA); - assert_eq!(serialized, original); - } - - #[test] - fn prefixed_str_borsh_round_trip() { - const TEXT: &str = "Prefixed strings have many characters"; - - // u8 length prefix + string bytes - let original = U8PrefixedString::from(String::from(TEXT)); - let bytes = borsh::to_vec(&original).unwrap(); - - assert_eq!(bytes[0], TEXT.len() as u8); - assert_eq!(&bytes[1..], TEXT.as_bytes()); - - let string = U8PrefixedString::try_from_slice(&bytes).unwrap(); - - assert_eq!(string.len(), TEXT.len()); - assert_eq!(string.as_str(), TEXT); - - // u16 length prefix + string bytes - let original = U16PrefixedString::from(String::from(TEXT)); - let bytes = borsh::to_vec(&original).unwrap(); - - assert_eq!( - u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), - TEXT.len() as u16 - ); - assert_eq!(&bytes[2..], TEXT.as_bytes()); - - let string = U16PrefixedString::try_from_slice(&bytes).unwrap(); - - assert_eq!(string.len(), TEXT.len()); - assert_eq!(string.as_str(), TEXT); - - // u32 length prefix + string bytes - let original = U32PrefixedString::from(String::from(TEXT)); - let bytes = borsh::to_vec(&original).unwrap(); - - assert_eq!( - u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) 
}), - TEXT.len() as u32 - ); - assert_eq!(&bytes[4..], TEXT.as_bytes()); - - let string = U32PrefixedString::try_from_slice(&bytes).unwrap(); - - assert_eq!(string.len(), TEXT.len()); - assert_eq!(string.as_str(), TEXT); - - // u64 length prefix + string bytes - let original = U64PrefixedString::from(String::from(TEXT)); - let bytes = borsh::to_vec(&original).unwrap(); - - assert_eq!( - u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), - TEXT.len() as u64 - ); - assert_eq!(&bytes[8..], TEXT.as_bytes()); - - let string = U64PrefixedString::try_from_slice(&bytes).unwrap(); - - assert_eq!(string.len(), TEXT.len()); - assert_eq!(string.as_str(), TEXT); - } - - #[test] - fn prefixed_str_wincode_round_trip() { - const TEXT: &str = "Prefixed strings have many characters"; - - // u8 length prefix + string bytes - let original = U8PrefixedString::from(String::from(TEXT)); - let mut bytes = [0u8; size_of::() + TEXT.len()]; - wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); - - assert_eq!(bytes[0], TEXT.len() as u8); - assert_eq!(&bytes[1..], TEXT.as_bytes()); - - let serialized = wincode::deserialize::(&bytes).unwrap(); - - assert_eq!(serialized.len(), TEXT.len()); - assert_eq!(serialized.as_str(), TEXT); - assert_eq!(serialized, original); - - // u16 length prefix + string bytes - let original = U16PrefixedString::from(String::from(TEXT)); - let mut bytes = [0u8; size_of::() + TEXT.len()]; - wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); - - assert_eq!( - u16::from_le_bytes(unsafe { *(bytes[0..2].as_ptr() as *const [u8; 2]) }), - TEXT.len() as u16 - ); - assert_eq!(&bytes[2..], TEXT.as_bytes()); - - let serialized = wincode::deserialize::(&bytes).unwrap(); - - assert_eq!(serialized.len(), TEXT.len()); - assert_eq!(serialized.as_str(), TEXT); - assert_eq!(serialized, original); - - // u32 length prefix + string bytes - let original = U32PrefixedString::from(String::from(TEXT)); - let mut bytes = [0u8; size_of::() 
+ TEXT.len()]; - wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); - - assert_eq!( - u32::from_le_bytes(unsafe { *(bytes[0..4].as_ptr() as *const [u8; 4]) }), - TEXT.len() as u32 - ); - assert_eq!(&bytes[4..], TEXT.as_bytes()); - - let serialized = wincode::deserialize::(&bytes).unwrap(); - - assert_eq!(serialized.len(), TEXT.len()); - assert_eq!(serialized.as_str(), TEXT); - assert_eq!(serialized, original); - - // u64 length prefix + string bytes - let original = U64PrefixedString::from(String::from(TEXT)); - let mut bytes = [0u8; size_of::() + TEXT.len()]; - wincode::serialize_into(bytes.as_mut_slice(), &original).unwrap(); - - assert_eq!( - u64::from_le_bytes(unsafe { *(bytes[0..8].as_ptr() as *const [u8; 8]) }), - TEXT.len() as u64 - ); - assert_eq!(&bytes[8..], TEXT.as_bytes()); - - let serialized = wincode::deserialize::(&bytes).unwrap(); - - assert_eq!(serialized.len(), TEXT.len()); - assert_eq!(serialized.as_str(), TEXT); - assert_eq!(serialized, original); - } - - #[test] - fn invalid_prefixed_value() { - let large_text = "a".repeat(256); - - let original = U8PrefixedString::from(large_text); - - // borsh - let result = borsh::to_vec(&original); - - assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData); - - // wincode - let result = wincode::serialize(&original); - - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - WriteError::LengthEncodingOverflow(_) - )); - } -} diff --git a/collections/src/vec.rs b/collections/src/vec.rs index 3e2c5d3..c4161fa 100644 --- a/collections/src/vec.rs +++ b/collections/src/vec.rs @@ -29,7 +29,7 @@ use { core::mem::MaybeUninit, wincode::{ config::ConfigCore, - error::write_length_encoding_overflow, + error::{write_length_encoding_overflow, ReadError}, io::{Reader, Writer}, ReadResult, SchemaRead, SchemaWrite, WriteResult, }, @@ -231,10 +231,14 @@ macro_rules! 
prefixed_vec_type { #[cfg(feature = "borsh")] impl BorshDeserialize for $name { fn deserialize_reader(reader: &mut R) -> borsh::io::Result { - let prefix = $prefix_type::deserialize_reader(reader)?; - let mut items: Vec = Vec::with_capacity(prefix as usize); + let prefix = $prefix_type::deserialize_reader(reader)? as usize; + let mut items: Vec = Vec::with_capacity(prefix); + + while items.len() < prefix { + let Ok(item) = T::deserialize_reader(reader) else { + return Err(ErrorKind::InvalidData.into()); + }; - while let Ok(item) = T::deserialize_reader(reader) { items.push(item); } @@ -252,7 +256,7 @@ macro_rules! prefixed_vec_type { #[inline(always)] fn size_of(src: &Self::Src) -> WriteResult { - Ok(core::mem::size_of::<$prefix_type>() + src.0.len()) + Ok(core::mem::size_of::<$prefix_type>() + size_of::() * src.0.len()) } #[inline(always)] @@ -283,7 +287,6 @@ macro_rules! prefixed_vec_type { mut reader: impl Reader<'de>, dst: &mut MaybeUninit, ) -> ReadResult<()> { - // Read the length prefix first to allocate space for `T`s. let mut prefix = MaybeUninit::<$prefix_type>::uninit(); <$prefix_type as SchemaRead<'de, C>>::read(&mut reader, &mut prefix)?; // SAFETY: We have just read the prefix from the reader, so it is initialized. @@ -291,7 +294,11 @@ macro_rules! prefixed_vec_type { let mut items = Vec::with_capacity(prefix); - while let Ok(item) = T::get(&mut reader) { + while items.len() < prefix { + let Ok(item) = T::get(&mut reader) else { + return Err(ReadError::Custom("failed to deserialize")); + }; + items.push(item); } @@ -462,4 +469,31 @@ mod tests { WriteError::LengthEncodingOverflow(_) )); } + + #[test] + fn prefixed_vec_borsh_with_remaining_bytes() { + // Bytes representation for a `U8PrefixedVec` with 8 `u64` values + // followed by 16 additional bytes. 
+ let mut bytes = [255u8; 81]; + bytes[0] = 8; + + let mut reader = bytes.as_slice(); + let serialized = U8PrefixedVec::::deserialize(&mut reader).unwrap(); + + assert_eq!(serialized.len(), 8); + assert_eq!(serialized.as_slice(), &[!(0u64); 8]); + } + + #[test] + fn prefixed_vec_wincode_with_remaining_bytes() { + // Bytes representation for a `U8PrefixedVec` with 8 `u64` values + // followed by 16 additional bytes. + let mut bytes = [255u8; 81]; + bytes[0] = 8; + + let serialized = wincode::deserialize::>(&bytes).unwrap(); + + assert_eq!(serialized.len(), 8); + assert_eq!(serialized.as_slice(), &[!(0u64); 8]); + } } From dbc4e46ed3ca0f6e91be231d8667259d1499a546 Mon Sep 17 00:00:00 2001 From: febo Date: Sun, 1 Mar 2026 15:09:08 -0700 Subject: [PATCH 3/6] Add collections package to env --- .github/workflows/main.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2f71fd7..66e9424 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,9 +8,9 @@ on: env: JS_PACKAGES: "['type-length-value-js']" - SBPF_PROGRAM_PACKAGES: "['discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" - RUST_PACKAGES: "['discriminator', 'discriminator-derive', 'discriminator-syn', 'generic-token', 'generic-token-tests', 'list-view', 'pod', 'program-error', 'program-error-derive', 'tlv-account-resolution', 'type-length-value', 'type-length-value-derive', 'type-length-value-derive-test']" - WASM_PACKAGES: "['discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" + SBPF_PROGRAM_PACKAGES: "['collections', 'discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" + RUST_PACKAGES: "['collections', 'discriminator', 'discriminator-derive', 'discriminator-syn', 'generic-token', 'generic-token-tests', 'list-view', 
'pod', 'program-error', 'program-error-derive', 'tlv-account-resolution', 'type-length-value', 'type-length-value-derive', 'type-length-value-derive-test']" + WASM_PACKAGES: "['collections', 'discriminator', 'generic-token', 'list-view', 'pod', 'program-error', 'tlv-account-resolution', 'type-length-value']" jobs: set_env: From 224bbeabef903c907d93b3395975220ff04da5e9 Mon Sep 17 00:00:00 2001 From: febo Date: Mon, 2 Mar 2026 14:19:19 +0000 Subject: [PATCH 4/6] Add size check --- collections/src/vec.rs | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/collections/src/vec.rs b/collections/src/vec.rs index c4161fa..6bacbc7 100644 --- a/collections/src/vec.rs +++ b/collections/src/vec.rs @@ -142,7 +142,16 @@ where #[inline(always)] fn size_of(src: &Self::Src) -> WriteResult { - Ok(src.0.len() * size_of::()) + let expected_size = src.0.len().saturating_mul(core::mem::size_of::()); + + // `Vec` capacity is limited to `isize::MAX`. + if expected_size > isize::MAX as usize { + return Err(write_length_encoding_overflow( + "size of items in TrailingVec", + )); + } + + Ok(expected_size) } #[inline(always)] @@ -256,7 +265,17 @@ macro_rules! prefixed_vec_type { #[inline(always)] fn size_of(src: &Self::Src) -> WriteResult { - Ok(core::mem::size_of::<$prefix_type>() + size_of::() * src.0.len()) + let expected_size = core::mem::size_of::<$prefix_type>().saturating_add( + src.0.len().saturating_mul(size_of::())); + + // `Vec` capacity is limited to `isize::MAX`. 
+ if expected_size > isize::MAX as usize { + return Err(write_length_encoding_overflow( + "size of items in TrailingVec", + )); + } + + Ok(expected_size) } #[inline(always)] From 6eeb6267832d4674bbc382b1c8bffd1f5c35a0a6 Mon Sep 17 00:00:00 2001 From: febo Date: Tue, 3 Mar 2026 01:19:20 +0000 Subject: [PATCH 5/6] Add collection to publish --- .github/workflows/publish-rust.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/publish-rust.yml b/.github/workflows/publish-rust.yml index 5469c66..e69ef74 100644 --- a/.github/workflows/publish-rust.yml +++ b/.github/workflows/publish-rust.yml @@ -9,6 +9,7 @@ on: default: 'discriminator' type: choice options: + - collections - discriminator - discriminator-derive - discriminator-syn From 023ee0c2cb669740ddd0b9535e0e97b7b9455c9f Mon Sep 17 00:00:00 2001 From: febo Date: Tue, 3 Mar 2026 13:34:23 +0000 Subject: [PATCH 6/6] Add impl AsRef --- collections/src/str.rs | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/collections/src/str.rs b/collections/src/str.rs index 72dffed..9d97a4a 100644 --- a/collections/src/str.rs +++ b/collections/src/str.rs @@ -19,7 +19,6 @@ use borsh::{ }; use { crate::{TrailingVec, U16PrefixedVec, U32PrefixedVec, U64PrefixedVec, U8PrefixedVec}, - alloc::string::String, core::{ fmt::{Debug, Formatter}, ops::Deref, @@ -88,15 +87,9 @@ use { #[repr(transparent)] pub struct TrailingStr(TrailingVec); -impl From for TrailingStr { - fn from(value: String) -> Self { - Self(TrailingVec::from(value.as_bytes())) - } -} - -impl From<&str> for TrailingStr { - fn from(value: &str) -> Self { - Self(TrailingVec::from(value.as_bytes())) +impl> From for TrailingStr { + fn from(value: T) -> Self { + Self(TrailingVec::from(value.as_ref().as_bytes())) } } @@ -163,15 +156,9 @@ macro_rules! 
prefixed_str_type { #[repr(transparent)] pub struct $name($container_type); - impl From for $name { - fn from(value: String) -> Self { - Self($container_type::from(value.as_bytes())) - } - } - - impl From<&str> for $name { - fn from(value: &str) -> Self { - Self($container_type::from(value.as_bytes())) + impl> From for $name { + fn from(value: T) -> Self { + Self($container_type::from(value.as_ref().as_bytes())) } } @@ -242,7 +229,7 @@ prefixed_str_type!(U64PrefixedStr, U64PrefixedVec, u64); #[cfg(test)] mod tests { use { - alloc::vec::Vec, + alloc::{string::String, vec::Vec}, borsh::{io::ErrorKind, BorshDeserialize, BorshSerialize}, core::mem::size_of, wincode::WriteError,