Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::serialization::Serializable;
use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable};

/// A wrapper around an array of field elements that implements SSZ Encode/Decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct FieldArray<const N: usize>(pub [F; N]);

Expand Down
2 changes: 1 addition & 1 deletion src/signature/generalized_xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl<IE: IncomparableEncoding, TH: TweakableHash> Decode for GeneralizedXMSSSign

/// Public key for GeneralizedXMSSSignatureScheme
/// It contains a Merkle root and a parameter for the tweakable hash
#[derive(Serialize, Deserialize, Clone)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct GeneralizedXMSSPublicKey<TH: TweakableHash> {
root: TH::Domain,
parameter: TH::Parameter,
Expand Down
29 changes: 16 additions & 13 deletions src/signature/generalized_xmss/instantiations_aborting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ pub mod lifetime_2_to_the_32 {
use crate::{
inc_encoding::target_sum::TargetSumEncoding,
signature::generalized_xmss::{
GeneralizedXMSSPublicKey, GeneralizedXMSSSignature, GeneralizedXMSSSignatureScheme,
GeneralizedXMSSPublicKey, GeneralizedXMSSSecretKey, GeneralizedXMSSSignature,
GeneralizedXMSSSignatureScheme,
},
symmetric::{
message_hash::aborting::AbortingHypercubeMessageHash, prf::shake_to_field::ShakePRFtoF,
Expand Down Expand Up @@ -43,16 +44,18 @@ pub mod lifetime_2_to_the_32 {
type PRF = ShakePRFtoF<HASH_LEN_FE, RAND_LEN_FE>;
type IE = TargetSumEncoding<MH, TARGET_SUM>;

pub type SIGAbortingTargetSumLifetime32Dim64Base8 =
pub type SchemeAbortingTargetSumLifetime32Dim64Base8 =
GeneralizedXMSSSignatureScheme<PRF, IE, TH, LOG_LIFETIME>;
pub type PubKeyAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSPublicKey<TH>;
pub type SecretKeyAbortingTargetSumLifetime32Dim64Base8 =
GeneralizedXMSSSecretKey<PRF, IE, TH, LOG_LIFETIME>;
pub type SigAbortingTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignature<IE, TH>;

#[cfg(test)]
mod test {

#[cfg(feature = "slow-tests")]
use super::*;
use super::SchemeAbortingTargetSumLifetime32Dim64Base8;
#[cfg(feature = "slow-tests")]
use crate::signature::SignatureScheme;

Expand All @@ -62,15 +65,15 @@ pub mod lifetime_2_to_the_32 {
#[test]
#[cfg(feature = "slow-tests")]
pub fn test_correctness() {
test_signature_scheme_correctness::<SIGAbortingTargetSumLifetime32Dim64Base8>(
test_signature_scheme_correctness::<SchemeAbortingTargetSumLifetime32Dim64Base8>(
213,
0,
SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize,
SchemeAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize,
);
test_signature_scheme_correctness::<SIGAbortingTargetSumLifetime32Dim64Base8>(
test_signature_scheme_correctness::<SchemeAbortingTargetSumLifetime32Dim64Base8>(
4,
0,
SIGAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize,
SchemeAbortingTargetSumLifetime32Dim64Base8::LIFETIME as usize,
);
}
}
Expand Down Expand Up @@ -121,7 +124,7 @@ pub mod lifetime_2_to_the_6 {
type PRF = ShakePRFtoF<HASH_LEN_FE, RAND_LEN_FE>;
type IE = TargetSumEncoding<MH, TARGET_SUM>;

pub type SIGAbortingTargetSumLifetime6Dim46Base8 =
pub type SchemeAbortingTargetSumLifetime6Dim46Base8 =
GeneralizedXMSSSignatureScheme<PRF, IE, TH, LOG_LIFETIME>;

#[cfg(test)]
Expand All @@ -130,19 +133,19 @@ pub mod lifetime_2_to_the_6 {
SignatureScheme, test_templates::test_signature_scheme_correctness,
};

use super::SIGAbortingTargetSumLifetime6Dim46Base8;
use super::SchemeAbortingTargetSumLifetime6Dim46Base8;

#[test]
pub fn test_correctness() {
test_signature_scheme_correctness::<SIGAbortingTargetSumLifetime6Dim46Base8>(
test_signature_scheme_correctness::<SchemeAbortingTargetSumLifetime6Dim46Base8>(
2,
0,
SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize,
SchemeAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize,
);
test_signature_scheme_correctness::<SIGAbortingTargetSumLifetime6Dim46Base8>(
test_signature_scheme_correctness::<SchemeAbortingTargetSumLifetime6Dim46Base8>(
11,
0,
SIGAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize,
SchemeAbortingTargetSumLifetime6Dim46Base8::LIFETIME as usize,
);
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/symmetric/message_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rand::RngExt;
use crate::MESSAGE_LENGTH;
use crate::serialization::Serializable;

pub use poseidon::encode_message;

/// Trait to model a hash function used for message hashing.
///
/// This is a variant of a tweakable hash function that we use for
Expand Down
1 change: 1 addition & 0 deletions src/symmetric/message_hash/aborting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::array::FieldArray;
/// Given p = Q * w^z + alpha, each Poseidon output field element A_i is:
/// 1) checked to be less than Q * w^z, and if not the hash aborts
/// 2) decomposed as d_i = floor(A_i / Q), then d_i is written in base w with z digits.
#[derive(Debug, Clone, Copy)]
pub struct AbortingHypercubeMessageHash<
const PARAMETER_LEN: usize,
const RAND_LEN_FE: usize,
Expand Down
4 changes: 2 additions & 2 deletions src/symmetric/message_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ pub(crate) fn poseidon_message_hash_fe<
let epoch_fe = encode_epoch::<TWEAK_LEN_FE>(epoch);

// now, we hash randomness, parameters, epoch, message using PoseidonCompress
let combined_input_vec: Vec<F> = randomness
let combined_input_vec: Vec<F> = message_fe
.iter()
.chain(parameter.iter())
.chain(epoch_fe.iter())
.chain(message_fe.iter())
.chain(randomness.iter())
.copied()
.collect();

Expand Down
88 changes: 50 additions & 38 deletions src/symmetric/tweak_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ fn poseidon_safe_domain_separator<const OUT_LEN: usize>(
poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, OUT_LEN>(perm, &input)
}

/// Poseidon Sponge Hash Function
/// Poseidon Sponge with "Replacement" Hash Function
///
/// Absorbs an arbitrary-length input using the Poseidon sponge construction
/// and outputs `OUT_LEN` field elements. Domain separation is achieved by
Expand All @@ -179,13 +179,19 @@ fn poseidon_safe_domain_separator<const OUT_LEN: usize>(
/// - `input`: message to hash (any length).
///
/// ### Sponge Construction
/// This follows the classic sponge structure:
/// - **Absorption**: inputs are added chunk-by-chunk into the first `rate` elements of the state.
/// - **Squeezing**: outputs are read from the first `rate` elements of the state, permuted as needed.
/// This follows the classic sponge structure with capacity-first layout:
/// - The state is `[capacity | rate]`, i.e., the first elements hold the capacity,
/// followed by the rate elements.
/// - **Absorption**: inputs are written into the rate part of the state (`state[cap_len..]`).
/// - **Squeezing**: outputs are read from the rate part of the state, permuted as needed.
///
/// ### "Replacement"
/// This means we "replace" the rate elements of the state with the input chunk, instead
/// of adding (in the sense of finite field addition).
///
/// ### Panics
/// - If `capacity_value.len() >= WIDTH`
fn poseidon_sponge<A, P, const WIDTH: usize, const OUT_LEN: usize>(
fn poseidon_replacement_sponge<A, P, const WIDTH: usize, const OUT_LEN: usize>(
perm: &P,
capacity_value: &[A],
input: &[A],
Expand All @@ -200,11 +206,12 @@ where
capacity_value.len() < WIDTH,
"Capacity length must be smaller than the state width."
);
let rate = WIDTH - capacity_value.len();
let cap_len = capacity_value.len();
let rate = WIDTH - cap_len;

// initialize
let mut state = [A::ZERO; WIDTH];
state[rate..].copy_from_slice(capacity_value);
state[..cap_len].copy_from_slice(capacity_value);

// Instead of converting the input to a vector, resizing and feeding the data into the
// sponge, we instead fill in the vector from all chunks until we are left with a non
Expand All @@ -213,20 +220,22 @@ where
// 1. fill in all full chunks and permute
let mut it = input.chunks_exact(rate);
for chunk in &mut it {
// add chunk elements into the first `rate` many elements of the `state`
for (s, &x) in state.iter_mut().take(rate).zip(chunk) {
*s += x;
// write chunk elements into the `rate` part of the state
for (s, &x) in state[cap_len..].iter_mut().zip(chunk) {
*s = x; // 'replacement' sponge
}
perm.permute_mut(&mut state);
}
// 2. Fill the remainder and pad with zeros.
// NOTE: This zero-padding is secure for constant-size inputs but may be insecure elsewhere.
if !it.remainder().is_empty() {
let num_remainder = it.remainder().len();
for (i, x) in it.remainder().iter().enumerate() {
state[i] += *x;
state[cap_len + i] = *x;
}
for s in &mut state[cap_len + num_remainder..] {
*s = A::ZERO;
}
// Because this sponge *replaces* (rather than adds to) the rate elements,
// the rate positions beyond the remainder are explicitly reset to zero above
// before permuting.
perm.permute_mut(&mut state);
}

Expand All @@ -235,7 +244,7 @@ where
let mut out_index = 0;
while out_index < OUT_LEN {
let chunk_size = (OUT_LEN - out_index).min(rate);
out[out_index..out_index + chunk_size].copy_from_slice(&state[..chunk_size]);
out[out_index..out_index + chunk_size].copy_from_slice(&state[cap_len..][..chunk_size]);
out_index += chunk_size;
if out_index < OUT_LEN {
// no need to permute in last iteration, `state` is local variable
Expand All @@ -249,7 +258,7 @@ where
///
/// Note: HASH_LEN, TWEAK_LEN, CAPACITY, and PARAMETER_LEN must
/// be given in the unit "number of field elements".
#[derive(Clone)]
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct PoseidonTweakHash<
const PARAMETER_LEN: usize,
const HASH_LEN: usize,
Expand Down Expand Up @@ -343,18 +352,17 @@ impl<

match message {
[single] => {
// we compress parameter, tweak, message
// we compress message, parameter, tweak
let perm = poseidon1_16();

// Build input on stack: [parameter | tweak | message]
// Build input on stack: [message | parameter | tweak]
let mut combined_input = [F::ZERO; CHAIN_COMPRESSION_WIDTH];
combined_input[..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[PARAMETER_LEN..PARAMETER_LEN + TWEAK_LEN].copy_from_slice(&tweak_fe);
combined_input[PARAMETER_LEN + TWEAK_LEN..PARAMETER_LEN + TWEAK_LEN + HASH_LEN]
.copy_from_slice(&single.0);
combined_input[..HASH_LEN].copy_from_slice(&single.0);
combined_input[HASH_LEN..][..PARAMETER_LEN].copy_from_slice(&parameter.0);
combined_input[HASH_LEN + PARAMETER_LEN..][..TWEAK_LEN].copy_from_slice(&tweak_fe);

FieldArray(
poseidon_compress::<F, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
poseidon_compress::<_, _, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
&combined_input,
),
Expand All @@ -376,7 +384,7 @@ impl<
.copy_from_slice(&right.0);

FieldArray(
poseidon_compress::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
poseidon_compress::<_, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
&combined_input,
),
Expand All @@ -400,11 +408,12 @@ impl<
HASH_LEN as u32,
];
let capacity_value = poseidon_safe_domain_separator::<CAPACITY>(&perm, &lengths);
FieldArray(poseidon_sponge::<F, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&perm,
&capacity_value,
&combined_input,
))
FieldArray(poseidon_replacement_sponge::<
_,
_,
MERGE_COMPRESSION_WIDTH,
HASH_LEN,
>(&perm, &capacity_value, &combined_input))
}
_ => FieldArray([F::ONE; HASH_LEN]), // Unreachable case, added for safety
}
Expand Down Expand Up @@ -593,9 +602,10 @@ impl<
// Cache strategy: process one chain at a time to maximize locality.
// All epochs for that chain stay in registers across iterations.

// Offsets for chain compression: [parameter | tweak | current_value]
let chain_tweak_offset = PARAMETER_LEN;
let chain_value_offset = PARAMETER_LEN + TWEAK_LEN;
// Offsets for chain compression: [current_value | parameter | tweak]
let chain_value_offset = 0;
let chain_parameter_offset = HASH_LEN;
let chain_tweak_offset = HASH_LEN + PARAMETER_LEN;

for (chain_index, packed_chain) in
packed_chains.iter_mut().enumerate().take(num_chains)
Expand All @@ -607,11 +617,17 @@ impl<
let pos = (step + 1) as u8;

// Assemble the packed input for the hash function.
// Layout: [parameter | tweak | current_value]
// Layout: [current_value | parameter | tweak]
let mut packed_input = [PackedF::ZERO; CHAIN_COMPRESSION_WIDTH];

// Copy current chain value (already packed)
packed_input[chain_value_offset..chain_value_offset + HASH_LEN]
.copy_from_slice(packed_chain);

// Copy pre-packed parameter
packed_input[..PARAMETER_LEN].copy_from_slice(&packed_parameter);
packed_input
[chain_parameter_offset..chain_parameter_offset + PARAMETER_LEN]
.copy_from_slice(&packed_parameter);

// Pack tweaks directly into destination
pack_fn_into::<TWEAK_LEN>(
Expand All @@ -623,10 +639,6 @@ impl<
},
);

// Copy current chain value (already packed)
packed_input[chain_value_offset..chain_value_offset + HASH_LEN]
.copy_from_slice(packed_chain);

// Apply the hash function to advance the chain.
// This single call processes all epochs in parallel.
*packed_chain =
Expand Down Expand Up @@ -678,7 +690,7 @@ impl<

// Apply the sponge hash to produce the leaf.
// This absorbs all chain ends and squeezes out the final hash.
poseidon_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
poseidon_replacement_sponge::<PackedF, _, MERGE_COMPRESSION_WIDTH, HASH_LEN>(
&sponge_perm,
&capacity_val,
packed_leaf_input,
Expand Down
Loading