diff --git a/src/array.rs b/src/array.rs index cd0f898..0ef0b3a 100644 --- a/src/array.rs +++ b/src/array.rs @@ -8,7 +8,7 @@ use crate::serialization::Serializable; use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable}; /// A wrapper around an array of field elements that implements SSZ Encode/Decode. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct FieldArray(pub [F; N]); diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 3d6c795..72ab34f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -188,7 +188,7 @@ impl Decode for GeneralizedXMSSSign /// Public key for GeneralizedXMSSSignatureScheme /// It contains a Merkle root and a parameter for the tweakable hash -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct GeneralizedXMSSPublicKey { root: TH::Domain, parameter: TH::Parameter, diff --git a/src/signature/generalized_xmss/instantiations_aborting.rs b/src/signature/generalized_xmss/instantiations_aborting.rs index ddd5942..e386055 100644 --- a/src/signature/generalized_xmss/instantiations_aborting.rs +++ b/src/signature/generalized_xmss/instantiations_aborting.rs @@ -4,7 +4,8 @@ pub mod lifetime_2_to_the_32 { use crate::{ inc_encoding::target_sum::TargetSumEncoding, signature::generalized_xmss::{ - GeneralizedXMSSPublicKey, GeneralizedXMSSSignature, GeneralizedXMSSSignatureScheme, + GeneralizedXMSSPublicKey, GeneralizedXMSSSecretKey, GeneralizedXMSSSignature, + GeneralizedXMSSSignatureScheme, }, symmetric::{ message_hash::aborting::AbortingHypercubeMessageHash, prf::shake_to_field::ShakePRFtoF, @@ -43,16 +44,18 @@ pub mod lifetime_2_to_the_32 { type PRF = ShakePRFtoF; type IE = TargetSumEncoding; - pub type SIGAbortingTargetSumLifetime32Dim64Base8 = + pub type SchemeAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignatureScheme; pub type PubKeyAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSPublicKey; + pub type SecretKeyAbortingTargetSumLifetime32Dim64Base8 = + GeneralizedXMSSSecretKey; pub type SigAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignature; #[cfg(test)] mod test { #[cfg(feature = "slow-tests")] - use super::*; + use super::SchemeAbortingTargetSumLifetime32Dim64Base8; #[cfg(feature = "slow-tests")] use crate::signature::SignatureScheme; @@ -62,15 +65,15 @@ pub mod lifetime_2_to_the_32 { #[test] #[cfg(feature = "slow-tests")] pub fn test_correctness() { - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 213, 0, - SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, ); - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 4, 0, - SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, ); } } @@ -121,7 +124,7 @@ pub mod lifetime_2_to_the_6 { type PRF = ShakePRFtoF; type IE = TargetSumEncoding; - pub type SIGAbortingTargetSumLifetime6Dim46Base8 = + pub type SchemeAbortingTargetSumLifetime6Dim46Base8 = GeneralizedXMSSSignatureScheme; #[cfg(test)] @@ -130,19 +133,19 @@ pub mod lifetime_2_to_the_6 { SignatureScheme, test_templates::test_signature_scheme_correctness, }; - use super::SIGAbortingTargetSumLifetime6Dim46Base8; + use super::SchemeAbortingTargetSumLifetime6Dim46Base8; #[test] pub fn test_correctness() { - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 2, 0, - SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, ); - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 11, 0, - SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, ); } } diff --git a/src/symmetric/message_hash.rs b/src/symmetric/message_hash.rs index 95fc4e5..0ba2181 100644 --- a/src/symmetric/message_hash.rs +++ b/src/symmetric/message_hash.rs @@ -5,6 +5,8 @@ use rand::RngExt; use crate::MESSAGE_LENGTH; use crate::serialization::Serializable; +pub use poseidon::encode_message; + /// Trait to model a hash function used for message hashing. /// /// This is a variant of a tweakable hash function that we use for diff --git a/src/symmetric/message_hash/aborting.rs b/src/symmetric/message_hash/aborting.rs index 90818ff..7b8a971 100644 --- a/src/symmetric/message_hash/aborting.rs +++ b/src/symmetric/message_hash/aborting.rs @@ -14,6 +14,7 @@ use crate::array::FieldArray; /// Given p = Q * w^z + alpha, each Poseidon output field element A_i is: /// 1) checked to be less than Q * w^z, and if not the hash aborts /// 2) decomposed as d_i = floor(A_i / Q), then d_i is written in base w with z digits. +#[derive(Debug, Clone, Copy)] pub struct AbortingHypercubeMessageHash< const PARAMETER_LEN: usize, const RAND_LEN_FE: usize, diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index 5605cbf..f4fe81e 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -114,11 +114,11 @@ pub(crate) fn poseidon_message_hash_fe< let epoch_fe = encode_epoch::(epoch); // now, we hash randomness, parameters, epoch, message using PoseidonCompress - let combined_input_vec: Vec = randomness + let combined_input_vec: Vec = message_fe .iter() .chain(parameter.iter()) .chain(epoch_fe.iter()) - .chain(message_fe.iter()) + .chain(randomness.iter()) .copied() .collect(); diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 286d09e..9b8ebc1 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -161,7 +161,7 @@ fn poseidon_safe_domain_separator( poseidon_compress::(perm, &input) } -/// Poseidon Sponge Hash Function +/// Poseidon Sponge with "Replacement" Hash Function /// /// Absorbs an arbitrary-length input using the Poseidon sponge construction /// and outputs `OUT_LEN` field elements. Domain separation is achieved by @@ -179,13 +179,19 @@ fn poseidon_safe_domain_separator( /// - `input`: message to hash (any length). /// /// ### Sponge Construction -/// This follows the classic sponge structure: -/// - **Absorption**: inputs are added chunk-by-chunk into the first `rate` elements of the state. -/// - **Squeezing**: outputs are read from the first `rate` elements of the state, permuted as needed. +/// This follows the classic sponge structure with capacity-first layout: +/// - The state is `[capacity | rate]`, i.e., the first elements hold the capacity, +/// followed by the rate elements. +/// - **Absorption**: inputs are written into the rate part of the state (`state[cap_len..]`). +/// - **Squeezing**: outputs are read from the rate part of the state, permuted as needed. +/// +/// ### "Replacement" +/// This means we "replace" the rate elements of the state with the input chunk, instead +/// of adding (in the sense of finite field addition). /// /// ### Panics /// - If `capacity_value.len() >= WIDTH` -fn poseidon_sponge( +fn poseidon_replacement_sponge( perm: &P, capacity_value: &[A], input: &[A], @@ -200,11 +206,12 @@ where capacity_value.len() < WIDTH, "Capacity length must be smaller than the state width." ); - let rate = WIDTH - capacity_value.len(); + let cap_len = capacity_value.len(); + let rate = WIDTH - cap_len; // initialize let mut state = [A::ZERO; WIDTH]; - state[rate..].copy_from_slice(capacity_value); + state[..cap_len].copy_from_slice(capacity_value); // Instead of converting the input to a vector, resizing and feeding the data into the // sponge, we instead fill in the vector from all chunks until we are left with a non @@ -213,20 +220,22 @@ where // 1. fill in all full chunks and permute let mut it = input.chunks_exact(rate); for chunk in &mut it { - // add chunk elements into the first `rate` many elements of the `state` - for (s, &x) in state.iter_mut().take(rate).zip(chunk) { - *s += x; + // write chunk elements into the `rate` part of the state + for (s, &x) in state[cap_len..].iter_mut().zip(chunk) { + *s = x; // 'replacement' sponge } perm.permute_mut(&mut state); } // 2. Fill the remainder and pad with zeros. // NOTE: This zero-padding is secure for constant-size inputs but may be insecure elsewhere. if !it.remainder().is_empty() { + let num_remainder = it.remainder().len(); for (i, x) in it.remainder().iter().enumerate() { - state[i] += *x; + state[cap_len + i] = *x; + } + for s in &mut state[cap_len + num_remainder..] { + *s = A::ZERO; } - // Since we only *add* to the state, positions beyond the remainder remain zero - // (their initial value), so no explicit zero-padding is needed. perm.permute_mut(&mut state); } @@ -235,7 +244,7 @@ where let mut out_index = 0; while out_index < OUT_LEN { let chunk_size = (OUT_LEN - out_index).min(rate); - out[out_index..out_index + chunk_size].copy_from_slice(&state[..chunk_size]); + out[out_index..out_index + chunk_size].copy_from_slice(&state[cap_len..][..chunk_size]); out_index += chunk_size; if out_index < OUT_LEN { // no need to permute in last iteration, `state` is local variable @@ -249,7 +258,7 @@ where /// /// Note: HASH_LEN, TWEAK_LEN, CAPACITY, and PARAMETER_LEN must /// be given in the unit "number of field elements". -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct PoseidonTweakHash< const PARAMETER_LEN: usize, const HASH_LEN: usize, @@ -343,18 +352,17 @@ impl< match message { [single] => { - // we compress parameter, tweak, message + // we compress message, parameter, tweak let perm = poseidon1_16(); - // Build input on stack: [parameter | tweak | message] + // Build input on stack: [message | parameter | tweak] let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH]; - combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); - combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); - combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] - .copy_from_slice(&single.0); + combined_input[..HASH_LEN].copy_from_slice(&single.0); + combined_input[HASH_LEN..][..PARAMETER_LEN].copy_from_slice(¶meter.0); + combined_input[HASH_LEN + PARAMETER_LEN..][..TWEAK_LEN].copy_from_slice(&tweak_fe); FieldArray( - poseidon_compress::( + poseidon_compress::<_, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>( &perm, &combined_input, ), @@ -376,7 +384,7 @@ impl< .copy_from_slice(&right.0); FieldArray( - poseidon_compress::( + poseidon_compress::<_, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>( &perm, &combined_input, ), @@ -400,11 +408,12 @@ impl< HASH_LEN as u32, ]; let capacity_value = poseidon_safe_domain_separator::(&perm, &lengths); - FieldArray(poseidon_sponge::( - &perm, - &capacity_value, - &combined_input, - )) + FieldArray(poseidon_replacement_sponge::< + _, + _, + MERGE_COMPRESSION_WIDTH, + HASH_LEN, + >(&perm, &capacity_value, &combined_input)) } _ => FieldArray([F::ONE; HASH_LEN]), // Unreachable case, added for safety } @@ -593,9 +602,10 @@ impl< // Cache strategy: process one chain at a time to maximize locality. // All epochs for that chain stay in registers across iterations. - // Offsets for chain compression: [parameter | tweak | current_value] - let chain_tweak_offset = PARAMETER_LEN; - let chain_value_offset = PARAMETER_LEN + TWEAK_LEN; + // Offsets for chain compression: [current_value | parameter | tweak] + let chain_value_offset = 0; + let chain_parameter_offset = HASH_LEN; + let chain_tweak_offset = HASH_LEN + PARAMETER_LEN; for (chain_index, packed_chain) in packed_chains.iter_mut().enumerate().take(num_chains) @@ -607,11 +617,17 @@ impl< let pos = (step + 1) as u8; // Assemble the packed input for the hash function. - // Layout: [parameter | tweak | current_value] + // Layout: [current_value | parameter | tweak] let mut packed_input = [PackedF::ZERO; CHAIN_COMPRESSION_WIDTH]; + // Copy current chain value (already packed) + packed_input[chain_value_offset..chain_value_offset + HASH_LEN] + .copy_from_slice(packed_chain); + // Copy pre-packed parameter - packed_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter); + packed_input + [chain_parameter_offset..chain_parameter_offset + PARAMETER_LEN] + .copy_from_slice(&packed_parameter); // Pack tweaks directly into destination pack_fn_into::( @@ -623,10 +639,6 @@ impl< }, ); - // Copy current chain value (already packed) - packed_input[chain_value_offset..chain_value_offset + HASH_LEN] - .copy_from_slice(packed_chain); - // Apply the hash function to advance the chain. // This single call processes all epochs in parallel. *packed_chain = @@ -678,7 +690,7 @@ impl< // Apply the sponge hash to produce the leaf. // This absorbs all chain ends and squeezes out the final hash. - poseidon_sponge::( + poseidon_replacement_sponge::( &sponge_perm, &capacity_val, packed_leaf_input,