diff --git a/README.md b/README.md index 309cc13..a2047f6 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,10 @@ Confidence intervals can also be shown via python3 benchmark-mean.py target --intervals ``` +## Deviations from the [original paper](https://eprint.iacr.org/2025/055.pdf) + +- use of 'overwrite' sponge, instead of 'addition' / 'xor' sponge. + ## License Apache Version 2.0. diff --git a/src/array.rs b/src/array.rs index cd0f898..0ef0b3a 100644 --- a/src/array.rs +++ b/src/array.rs @@ -8,7 +8,7 @@ use crate::serialization::Serializable; use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable}; /// A wrapper around an array of field elements that implements SSZ Encode/Decode. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct FieldArray(pub [F; N]); diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 3d6c795..72ab34f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -188,7 +188,7 @@ impl Decode for GeneralizedXMSSSign /// Public key for GeneralizedXMSSSignatureScheme /// It contains a Merkle root and a parameter for the tweakable hash -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct GeneralizedXMSSPublicKey { root: TH::Domain, parameter: TH::Parameter, diff --git a/src/signature/generalized_xmss/instantiations_aborting.rs b/src/signature/generalized_xmss/instantiations_aborting.rs index ddd5942..09348bb 100644 --- a/src/signature/generalized_xmss/instantiations_aborting.rs +++ b/src/signature/generalized_xmss/instantiations_aborting.rs @@ -4,7 +4,8 @@ pub mod lifetime_2_to_the_32 { use crate::{ inc_encoding::target_sum::TargetSumEncoding, signature::generalized_xmss::{ - GeneralizedXMSSPublicKey, GeneralizedXMSSSignature, GeneralizedXMSSSignatureScheme, + GeneralizedXMSSPublicKey, GeneralizedXMSSSecretKey, GeneralizedXMSSSignature, + GeneralizedXMSSSignatureScheme, }, symmetric::{ message_hash::aborting::AbortingHypercubeMessageHash, prf::shake_to_field::ShakePRFtoF, @@ -43,16 +44,18 @@ pub mod lifetime_2_to_the_32 { type PRF = ShakePRFtoF; type IE = TargetSumEncoding; - pub type SIGAbortingTargetSumLifetime32Dim64Base8 = + pub type SchemeAbortingTargetSumLifetime32Dim46Base8 = GeneralizedXMSSSignatureScheme; - pub type PubKeyAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSPublicKey; - pub type SigAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignature; + pub type PubKeyAbortingTargetSumLifetime32Dim46Base8 = GeneralizedXMSSPublicKey; + pub type SecretKeyAbortingTargetSumLifetime32Dim46Base8 = + GeneralizedXMSSSecretKey; + pub type SigAbortingTargetSumLifetime32Dim46Base8 = GeneralizedXMSSSignature; #[cfg(test)] mod test { #[cfg(feature = "slow-tests")] - use super::*; + use super::SchemeAbortingTargetSumLifetime32Dim46Base8; #[cfg(feature = "slow-tests")] use crate::signature::SignatureScheme; @@ -62,38 +65,41 @@ pub mod lifetime_2_to_the_32 { #[test] #[cfg(feature = "slow-tests")] pub fn test_correctness() { - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 213, 0, - SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime32Dim46Base8::LIFETIME as usize, ); - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 4, 0, - SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime32Dim46Base8::LIFETIME as usize, ); } } } -/// Instantiations with Lifetime 2^6. This is for testing purposes only. +/// Instantiations with Lifetime 2^8. This is for testing purposes only. /// /// Warning: Should not be used in production environments. -pub mod lifetime_2_to_the_6 { +pub mod lifetime_2_to_the_8 { use crate::{ inc_encoding::target_sum::TargetSumEncoding, - signature::generalized_xmss::GeneralizedXMSSSignatureScheme, + signature::generalized_xmss::{ + GeneralizedXMSSPublicKey, GeneralizedXMSSSecretKey, GeneralizedXMSSSignature, + GeneralizedXMSSSignatureScheme, + }, symmetric::{ message_hash::aborting::AbortingHypercubeMessageHash, prf::shake_to_field::ShakePRFtoF, tweak_hash::poseidon::PoseidonTweakHash, }, }; - const LOG_LIFETIME: usize = 6; + const LOG_LIFETIME: usize = 8; - const DIMENSION: usize = 46; + const DIMENSION: usize = 4; const BASE: usize = 8; - const TARGET_SUM: usize = 200; + const TARGET_SUM: usize = 6; const Z: usize = 8; const Q: usize = 127; @@ -121,28 +127,32 @@ pub mod lifetime_2_to_the_6 { type PRF = ShakePRFtoF; type IE = TargetSumEncoding; - pub type SIGAbortingTargetSumLifetime6Dim46Base8 = + pub type SchemeAbortingTargetSumLifetime8Dim46Base8 = GeneralizedXMSSSignatureScheme; + pub type PubKeyAbortingTargetSumLifetime8Dim46Base8 = GeneralizedXMSSPublicKey; + pub type SecretKeyAbortingTargetSumLifetime8Dim46Base8 = + GeneralizedXMSSSecretKey; + pub type SigAbortingTargetSumLifetime8Dim46Base8 = GeneralizedXMSSSignature; #[cfg(test)] mod test { use crate::signature::{ - SignatureScheme, test_templates::test_signature_scheme_correctness, + SignatureScheme, + generalized_xmss::instantiations_aborting::lifetime_2_to_the_8::SchemeAbortingTargetSumLifetime8Dim46Base8, + test_templates::test_signature_scheme_correctness, }; - use super::SIGAbortingTargetSumLifetime6Dim46Base8; - #[test] pub fn test_correctness() { - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 2, 0, - SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime8Dim46Base8::LIFETIME as usize, ); - test_signature_scheme_correctness::( + test_signature_scheme_correctness::( 11, 0, - SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize, + SchemeAbortingTargetSumLifetime8Dim46Base8::LIFETIME as usize, ); } } diff --git a/src/symmetric/message_hash.rs b/src/symmetric/message_hash.rs index 95fc4e5..0ba2181 100644 --- a/src/symmetric/message_hash.rs +++ b/src/symmetric/message_hash.rs @@ -5,6 +5,8 @@ use rand::RngExt; use crate::MESSAGE_LENGTH; use crate::serialization::Serializable; +pub use poseidon::encode_message; + /// Trait to model a hash function used for message hashing. /// /// This is a variant of a tweakable hash function that we use for diff --git a/src/symmetric/message_hash/aborting.rs b/src/symmetric/message_hash/aborting.rs index 90818ff..7b8a971 100644 --- a/src/symmetric/message_hash/aborting.rs +++ b/src/symmetric/message_hash/aborting.rs @@ -14,6 +14,7 @@ use crate::array::FieldArray; /// Given p = Q * w^z + alpha, each Poseidon output field element A_i is: /// 1) checked to be less than Q * w^z, and if not the hash aborts /// 2) decomposed as d_i = floor(A_i / Q), then d_i is written in base w with z digits. +#[derive(Debug, Clone, Copy)] pub struct AbortingHypercubeMessageHash< const PARAMETER_LEN: usize, const RAND_LEN_FE: usize, diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index 5605cbf..f4fe81e 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -114,11 +114,11 @@ pub(crate) fn poseidon_message_hash_fe< let epoch_fe = encode_epoch::(epoch); // now, we hash randomness, parameters, epoch, message using PoseidonCompress - let combined_input_vec: Vec = randomness + let combined_input_vec: Vec = message_fe .iter() .chain(parameter.iter()) .chain(epoch_fe.iter()) - .chain(message_fe.iter()) + .chain(randomness.iter()) .copied() .collect(); diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 286d09e..9b8ebc1 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -161,7 +161,7 @@ fn poseidon_safe_domain_separator( poseidon_compress::(perm, &input) } -/// Poseidon Sponge Hash Function +/// Poseidon Sponge with "Replacement" Hash Function /// /// Absorbs an arbitrary-length input using the Poseidon sponge construction /// and outputs `OUT_LEN` field elements. Domain separation is achieved by @@ -179,13 +179,19 @@ fn poseidon_safe_domain_separator( /// - `input`: message to hash (any length). /// /// ### Sponge Construction -/// This follows the classic sponge structure: -/// - **Absorption**: inputs are added chunk-by-chunk into the first `rate` elements of the state. -/// - **Squeezing**: outputs are read from the first `rate` elements of the state, permuted as needed. +/// This follows the classic sponge structure with capacity-first layout: +/// - The state is `[capacity | rate]`, i.e., the first elements hold the capacity, +/// followed by the rate elements. +/// - **Absorption**: inputs are written into the rate part of the state (`state[cap_len..]`). +/// - **Squeezing**: outputs are read from the rate part of the state, permuted as needed. +/// +/// ### "Replacement" +/// This means we "replace" the rate elements of the state with the input chunk, instead +/// of adding (in the sense of finite field addition). /// /// ### Panics /// - If `capacity_value.len() >= WIDTH` -fn poseidon_sponge( +fn poseidon_replacement_sponge( perm: &P, capacity_value: &[A], input: &[A], @@ -200,11 +206,12 @@ where capacity_value.len() < WIDTH, "Capacity length must be smaller than the state width." ); - let rate = WIDTH - capacity_value.len(); + let cap_len = capacity_value.len(); + let rate = WIDTH - cap_len; // initialize let mut state = [A::ZERO; WIDTH]; - state[rate..].copy_from_slice(capacity_value); + state[..cap_len].copy_from_slice(capacity_value); // Instead of converting the input to a vector, resizing and feeding the data into the // sponge, we instead fill in the vector from all chunks until we are left with a non @@ -213,20 +220,22 @@ where // 1. fill in all full chunks and permute let mut it = input.chunks_exact(rate); for chunk in &mut it { - // add chunk elements into the first `rate` many elements of the `state` - for (s, &x) in state.iter_mut().take(rate).zip(chunk) { - *s += x; + // write chunk elements into the `rate` part of the state + for (s, &x) in state[cap_len..].iter_mut().zip(chunk) { + *s = x; // 'replacement' sponge } perm.permute_mut(&mut state); } // 2. Fill the remainder and pad with zeros. // NOTE: This zero-padding is secure for constant-size inputs but may be insecure elsewhere. if !it.remainder().is_empty() { + let num_remainder = it.remainder().len(); for (i, x) in it.remainder().iter().enumerate() { - state[i] += *x; + state[cap_len + i] = *x; + } + for s in &mut state[cap_len + num_remainder..] { + *s = A::ZERO; } - // Since we only *add* to the state, positions beyond the remainder remain zero - // (their initial value), so no explicit zero-padding is needed. perm.permute_mut(&mut state); } @@ -235,7 +244,7 @@ where let mut out_index = 0; while out_index < OUT_LEN { let chunk_size = (OUT_LEN - out_index).min(rate); - out[out_index..out_index + chunk_size].copy_from_slice(&state[..chunk_size]); + out[out_index..out_index + chunk_size].copy_from_slice(&state[cap_len..][..chunk_size]); out_index += chunk_size; if out_index < OUT_LEN { // no need to permute in last iteration, `state` is local variable @@ -249,7 +258,7 @@ where /// /// Note: HASH_LEN, TWEAK_LEN, CAPACITY, and PARAMETER_LEN must /// be given in the unit "number of field elements". -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct PoseidonTweakHash< const PARAMETER_LEN: usize, const HASH_LEN: usize, @@ -343,18 +352,17 @@ impl< match message { [single] => { - // we compress parameter, tweak, message + // we compress message, parameter, tweak let perm = poseidon1_16(); - // Build input on stack: [parameter | tweak | message] + // Build input on stack: [message | parameter | tweak] let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH]; - combined_input[..PARAMETER_LEN].copy_from_slice(¶meter.0); - combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe); - combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN] - .copy_from_slice(&single.0); + combined_input[..HASH_LEN].copy_from_slice(&single.0); + combined_input[HASH_LEN..][..PARAMETER_LEN].copy_from_slice(¶meter.0); + combined_input[HASH_LEN + PARAMETER_LEN..][..TWEAK_LEN].copy_from_slice(&tweak_fe); FieldArray( - poseidon_compress::( + poseidon_compress::<_, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>( &perm, &combined_input, ), @@ -376,7 +384,7 @@ impl< .copy_from_slice(&right.0); FieldArray( - poseidon_compress::( + poseidon_compress::<_, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>( &perm, &combined_input, ), @@ -400,11 +408,12 @@ impl< HASH_LEN as u32, ]; let capacity_value = poseidon_safe_domain_separator::(&perm, &lengths); - FieldArray(poseidon_sponge::( - &perm, - &capacity_value, - &combined_input, - )) + FieldArray(poseidon_replacement_sponge::< + _, + _, + MERGE_COMPRESSION_WIDTH, + HASH_LEN, + >(&perm, &capacity_value, &combined_input)) } _ => FieldArray([F::ONE; HASH_LEN]), // Unreachable case, added for safety } @@ -593,9 +602,10 @@ impl< // Cache strategy: process one chain at a time to maximize locality. // All epochs for that chain stay in registers across iterations. - // Offsets for chain compression: [parameter | tweak | current_value] - let chain_tweak_offset = PARAMETER_LEN; - let chain_value_offset = PARAMETER_LEN + TWEAK_LEN; + // Offsets for chain compression: [current_value | parameter | tweak] + let chain_value_offset = 0; + let chain_parameter_offset = HASH_LEN; + let chain_tweak_offset = HASH_LEN + PARAMETER_LEN; for (chain_index, packed_chain) in packed_chains.iter_mut().enumerate().take(num_chains) @@ -607,11 +617,17 @@ impl< let pos = (step + 1) as u8; // Assemble the packed input for the hash function. - // Layout: [parameter | tweak | current_value] + // Layout: [current_value | parameter | tweak] let mut packed_input = [PackedF::ZERO; CHAIN_COMPRESSION_WIDTH]; + // Copy current chain value (already packed) + packed_input[chain_value_offset..chain_value_offset + HASH_LEN] + .copy_from_slice(packed_chain); + // Copy pre-packed parameter - packed_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter); + packed_input + [chain_parameter_offset..chain_parameter_offset + PARAMETER_LEN] + .copy_from_slice(&packed_parameter); // Pack tweaks directly into destination pack_fn_into::( @@ -623,10 +639,6 @@ impl< }, ); - // Copy current chain value (already packed) - packed_input[chain_value_offset..chain_value_offset + HASH_LEN] - .copy_from_slice(packed_chain); - // Apply the hash function to advance the chain. // This single call processes all epochs in parallel. *packed_chain = @@ -678,7 +690,7 @@ impl< // Apply the sponge hash to produce the leaf. // This absorbs all chain ends and squeezes out the final hash. - poseidon_sponge::( + poseidon_replacement_sponge::( &sponge_perm, &capacity_val, packed_leaf_input,