diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index bf077581..64c8db01 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -19,10 +19,9 @@ File format: - Each key pair is stored in a separate JSON file with hex-encoded SSZ. -- Directory structure: ``test_keys/{scheme}_scheme/{index}.json`` +- Directory structure: `test_keys/{scheme}_scheme/{index}.json` - Each file has four hex-encoded SSZ fields: - ``attestation_public``, ``attestation_secret``, - ``proposal_public``, ``proposal_secret`` + `attestation_public`, `attestation_secret`, `proposal_public`, `proposal_secret` """ from __future__ import annotations @@ -38,11 +37,10 @@ from concurrent.futures import ProcessPoolExecutor from functools import partial from pathlib import Path -from typing import Literal +from typing import ClassVar, Literal from lean_spec.config import LEAN_ENV from lean_spec.subspecs.containers import AttestationData, ValidatorIndex -from lean_spec.subspecs.containers.attestation import AggregationBits from lean_spec.subspecs.containers.block.types import ( AggregatedAttestations, AttestationSignatures, @@ -66,285 +64,310 @@ from lean_spec.types import Bytes32, Uint64 SecretField = Literal["attestation_secret", "proposal_secret"] -"""The two secret key field names on ValidatorKeyPair.""" - -__all__ = [ - "CLI_DEFAULT_MAX_SLOT", - "KEY_DOWNLOAD_URLS", - "LEAN_ENV_TO_SCHEMES", - "LazyKeyDict", - "NUM_VALIDATORS", - "XmssKeyManager", - "create_dummy_signature", - "download_keys", - "get_keys_dir", - "get_shared_key_manager", -] +"""Discriminator for which secret key to load from a validator key pair.""" KEY_DOWNLOAD_URLS = { "test": "https://github.com/leanEthereum/leansig-test-keys/releases/download/leanSpec-ad9a3226/test_scheme.tar.gz", "prod": "https://github.com/leanEthereum/leansig-test-keys/releases/download/leanSpec-ad9a3226/prod_scheme.tar.gz", } -"""URLs for downloading pre-generated keys.""" +""" +GitHub release URLs for pre-generated key archives. + +Keyed by scheme name ("test" or "prod"). +Each URL points to a tar.gz containing per-validator JSON files. +""" -LEAN_ENV_TO_SCHEMES = { +LEAN_ENV_TO_SCHEMES: dict[str, GeneralizedXmssScheme] = { "test": TEST_SIGNATURE_SCHEME, "prod": PROD_SIGNATURE_SCHEME, } """ -Mapping from short name to scheme objects. This mapping is useful for: - -- The CLI argument for choosing the signature scheme to generate -- Deriving the file name for the cached keys -- Caching key managers in test fixtures -""" +Maps short scheme names to their XMSS scheme instances. -_KEY_MANAGER_CACHE: dict[tuple[str, Slot], XmssKeyManager] = {} -"""Cache for key managers: {(scheme_name, max_slot): XmssKeyManager}""" +Used for: -_SHARED_MANAGER_MAX_SLOT: Slot = Slot(10) -"""Default max slot for the shared key manager.""" +- CLI argument validation +- Deriving on-disk directory names for cached keys +- Keying the per-scheme manager cache in test fixtures +""" def create_dummy_signature() -> Signature: """ - Create a structurally valid but cryptographically invalid individual signature. + Create a structurally valid but cryptographically meaningless signature. + + All fields are zero-filled. + The result has correct dimensions so it passes structural checks, + but it will fail any cryptographic verification. - The signature has proper structure (correct number of siblings, hashes, etc.) - but all values are zeros, so it will fail cryptographic verification. + Returns: + A zero-valued signature with correct field sizes. """ - # Create zero-filled hash digests with correct dimensions - zero_digest = HashDigestVector(data=[Fp(0) for _ in range(TARGET_CONFIG.HASH_LEN_FE)]) + # Build a single zero-filled hash digest with the scheme's hash length. + zero_digest = HashDigestVector(data=[Fp(0)] * TARGET_CONFIG.HASH_LEN_FE) - # Path needs LOG_LIFETIME siblings for the Merkle authentication path - siblings = HashDigestList(data=[zero_digest for _ in range(TARGET_CONFIG.LOG_LIFETIME)]) + # The Merkle authentication path needs one sibling per tree level. + # + # The tree height equals the log of the key lifetime. + siblings = HashDigestList(data=[zero_digest] * TARGET_CONFIG.LOG_LIFETIME) - # Hashes need DIMENSION vectors for the Winternitz chain hashes - hashes = HashDigestList(data=[zero_digest for _ in range(TARGET_CONFIG.DIMENSION)]) + # Winternitz one-time signatures use one hash chain per dimension. + hashes = HashDigestList(data=[zero_digest] * TARGET_CONFIG.DIMENSION) + # Assemble a complete signature with all components zeroed out. return Signature( path=HashTreeOpening(siblings=siblings), - rho=Randomness(data=[Fp(0) for _ in range(TARGET_CONFIG.RAND_LEN_FE)]), + rho=Randomness(data=[Fp(0)] * TARGET_CONFIG.RAND_LEN_FE), hashes=hashes, ) -def get_shared_key_manager(max_slot: Slot = _SHARED_MANAGER_MAX_SLOT) -> XmssKeyManager: - """ - Get a shared XMSS key manager for reusing keys across tests. +DEFAULT_MAX_SLOT = Slot(10) +""" +Default max slot for the shared key manager. - Implements caching that reuses key managers with sufficient capacity. - If a cached key manager exists with max slot >= the requested max slot, it will - be reused instead of creating a new one. +Slot 10 is high enough for most unit tests while keeping key generation fast. +""" + +NUM_VALIDATORS: int = 12 +""" +Default number of validator key pairs. + +Twelve validators is enough to exercise committee logic while keeping +key generation and test execution fast. +""" + +CLI_DEFAULT_MAX_SLOT = Slot(100) +""" +Maximum slot when generating keys via CLI (inclusive). + +One hundred slots provides ample signing headroom for typical test scenarios. +""" + + +def get_keys_dir(scheme_name: str) -> Path: + """ + Resolve the on-disk directory that holds key files for a scheme. Args: - max_slot: Maximum slot for which XMSS keys should be valid. Defaults to 10 slots. + scheme_name: Short scheme identifier (e.g. "test" or "prod"). Returns: - Shared XmssKeyManager instance for the target scheme that supports at least max slot. + Absolute path to the scheme's key directory. + """ + return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme" + + +class XmssKeyManager: """ - scheme = LEAN_ENV_TO_SCHEMES[LEAN_ENV] + Stateful manager for XMSS signing in tests. - # Check if we have a cached key manager with sufficient capacity - for (cached_lean_env, cached_max_slot), manager in _KEY_MANAGER_CACHE.items(): - if cached_lean_env == LEAN_ENV and cached_max_slot >= max_slot: - return manager + XMSS is a stateful signature scheme. - # No suitable cached manager found, create a new one - manager = XmssKeyManager(max_slot=max_slot, scheme=scheme) - _KEY_MANAGER_CACHE[(LEAN_ENV, max_slot)] = manager - return manager + Each signing operation consumes a one-time leaf and advances the key state forward. + This manager tracks that state across slots and validators. + Keys are lazily loaded from disk on first access, with a three-tier cache: -NUM_VALIDATORS: int = 12 -"""Default number of validator key pairs.""" + - Raw JSON (lightweight hex strings, ~2.7 KB per validator) + - Deserialized public keys only (avoids the heavy secret key objects) + - Advanced secret key state as compact SSZ bytes + """ -CLI_DEFAULT_MAX_SLOT = Slot(100) -"""Maximum slot for CLI-generated test signatures (inclusive).""" + __slots__ = ( + "max_slot", + "scheme_name", + "scheme", + "_keys_dir", + "_json_cache", + "_public_cache", + "_available_indices", + "_secret_state", + ) + _cache: ClassVar[dict[str, XmssKeyManager]] = {} + """ + Per-scheme singleton cache for shared managers. -def get_keys_dir(scheme_name: str) -> Path: - """Get the keys directory path for the given scheme.""" - return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme" + Replaced when a caller requests a larger max slot than what is cached. + """ + + @classmethod + def shared(cls, max_slot: Slot = DEFAULT_MAX_SLOT) -> XmssKeyManager: + """ + Return a shared manager, creating or replacing it as needed. + + The cache holds one manager per scheme. + If the cached manager already covers the requested slot range, reuse it. + Otherwise, create a fresh one with the wider range. + + Args: + max_slot: Highest slot the manager must support. Defaults to 10. + Returns: + A manager valid for at least the requested slot range. + """ + # A cached manager is usable if its range covers the requested max slot. + cached = cls._cache.get(LEAN_ENV) + if cached is not None and cached.max_slot >= max_slot: + return cached -class LazyKeyDict(Mapping[ValidatorIndex, ValidatorKeyPair]): - """Load pre-generated keys from disk (cached after first call).""" + # No suitable cached manager exists. Build a new one and cache it. + manager = cls(max_slot=max_slot, scheme_name=LEAN_ENV) + cls._cache[LEAN_ENV] = manager + return manager - def __init__(self, scheme_name: str) -> None: - """Initialize with scheme name for locating key files.""" - self._scheme_name = scheme_name + def __init__( + self, + max_slot: Slot = DEFAULT_MAX_SLOT, + scheme_name: str = "test", + ) -> None: + """Initialize with a scheme name and maximum slot for key validity.""" + if scheme_name not in LEAN_ENV_TO_SCHEMES: + raise ValueError(f"Unknown scheme: {scheme_name!r}") + self.max_slot = max_slot + self.scheme_name = scheme_name + self.scheme = LEAN_ENV_TO_SCHEMES[scheme_name] self._keys_dir = get_keys_dir(scheme_name) - self._cache: dict[ValidatorIndex, ValidatorKeyPair] = {} + + # Raw JSON cache: hex-encoded SSZ strings, very lightweight. + self._json_cache: dict[ValidatorIndex, dict[str, str]] = {} + + # Deserialized public key pairs, still avoids secret key overhead. self._public_cache: dict[ValidatorIndex, tuple[PublicKey, PublicKey]] = {} - self._raw_cache: dict[ValidatorIndex, dict[str, str]] = {} + + # Populated lazily on first directory scan. self._available_indices: set[ValidatorIndex] | None = None - def _ensure_dir_exists(self) -> None: - """Raise FileNotFoundError if the keys directory does not exist.""" - if not self._keys_dir.exists(): - raise FileNotFoundError( - f"Keys directory not found: {self._keys_dir} - " - f"Run: python -m consensus_testing.keys --scheme {self._scheme_name}" - ) + # Advanced secret key state cached as raw SSZ bytes. + # Raw bytes (~2.7 KB each) instead of deserialized objects (~370 MB each) + # to avoid holding massive Pydantic model trees in memory. + self._secret_state: dict[tuple[ValidatorIndex, SecretField], bytes] = {} + + def _scan_indices(self) -> set[ValidatorIndex]: + """ + Discover which validator indices have key files on disk. - def _get_available_indices(self) -> set[ValidatorIndex]: - """Scan directory for available key indices (cached).""" + The result is cached after the first call. + + Returns: + Set of validator indices with available key files. + + Raises: + FileNotFoundError: If the directory is missing or empty. + """ if self._available_indices is None: - self._ensure_dir_exists() + # Verify the key directory exists before scanning. + if not self._keys_dir.exists(): + raise FileNotFoundError( + f"Keys directory not found: {self._keys_dir} - " + f"Run: python -m consensus_testing.keys --scheme {self.scheme_name}" + ) + + # Each JSON file is named by its validator index (e.g. "0.json"). self._available_indices = { ValidatorIndex(int(f.stem)) for f in self._keys_dir.glob("*.json") } + + # An empty directory is as bad as a missing one. if not self._available_indices: raise FileNotFoundError( f"No key files found in: {self._keys_dir} - " - f"Run: python -m consensus_testing.keys --scheme {self._scheme_name}" + f"Run: python -m consensus_testing.keys --scheme {self.scheme_name}" ) return self._available_indices - def _load_raw(self, idx: ValidatorIndex) -> dict[str, str]: - """Load raw JSON data from disk (cached).""" - if idx not in self._raw_cache: + def _load_json(self, idx: ValidatorIndex) -> dict[str, str]: + """ + Load raw JSON for a single validator, caching the result. + + The JSON contains four hex-encoded SSZ fields. + Keeping them as strings avoids the cost of deserializing secret keys. + + Args: + idx: Validator index to load. + + Returns: + Dictionary of hex-encoded SSZ field strings. + + Raises: + KeyError: If no key file exists for the index. + """ + if idx not in self._json_cache: + # Resolve the per-validator JSON file path. key_file = self._keys_dir / f"{idx}.json" try: - self._raw_cache[idx] = json.loads(key_file.read_text()) + with key_file.open() as f: + self._json_cache[idx] = json.load(f) except FileNotFoundError: raise KeyError(f"Key file not found: {key_file}") from None - return self._raw_cache[idx] - - def _load_key(self, idx: ValidatorIndex) -> ValidatorKeyPair: - """Load a single key from disk.""" - return ValidatorKeyPair.from_dict(self._load_raw(idx)) - - def _load_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]: - """Load only public keys from disk, skipping expensive SecretKey deserialization.""" - data = self._load_raw(idx) - return ( - PublicKey.decode_bytes(bytes.fromhex(data["attestation_public"])), - PublicKey.decode_bytes(bytes.fromhex(data["proposal_public"])), - ) + return self._json_cache[idx] - def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]: + def _get_secret_key(self, idx: ValidatorIndex, field: SecretField) -> SecretKey: """ - Get (attestation_public, proposal_public) without loading secret keys. + Deserialize a single secret key from disk. - Returns cached public keys if available, otherwise loads only the public - key portions from disk. Avoids deserializing the heavy SecretKey objects - (each ~2.7KB raw with 3 HashSubTree structures) until signing is needed. - """ - if idx in self._cache: - kp = self._cache[idx] - return (kp.attestation_public, kp.proposal_public) - if idx not in self._public_cache: - self._public_cache[idx] = self._load_public_keys(idx) - return self._public_cache[idx] + Only the requested field is decoded into a full Python object. + The other three fields remain as lightweight hex strings in the cache. - def get_secret_key(self, idx: ValidatorIndex, field: SecretField) -> SecretKey: - """ - Load a specific secret key from disk without deserializing the other keys. + Args: + idx: Validator index to look up. + field: Which secret key to decode (attestation or proposal). - Only the requested SecretKey is deserialized (~370 MB in Python objects). - The other three fields remain as lightweight hex strings (~2.7 KB each). + Returns: + The deserialized secret key. """ - if idx in self._cache: - return getattr(self._cache[idx], field) - data = self._load_raw(idx) + # Load the raw JSON (cached), then decode only the requested field. + data = self._load_json(idx) return SecretKey.decode_bytes(bytes.fromhex(data[field])) def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair: - """Get key pair by validator index, loading from disk if needed.""" - if idx not in self._cache: - self._cache[idx] = self._load_key(idx) - # Full pair supersedes raw/public caches for this index. - self._raw_cache.pop(idx, None) - self._public_cache.pop(idx, None) - return self._cache[idx] - - def __contains__(self, idx: object) -> bool: - """Check if a key exists for the given validator index.""" - if not isinstance(idx, ValidatorIndex): - return False - return idx in self._get_available_indices() - - def __len__(self) -> int: - """Return the number of available keys.""" - return len(self._get_available_indices()) - - def __iter__(self) -> Iterator[ValidatorIndex]: - """Iterate over available validator indices in sorted order.""" - return iter(sorted(self._get_available_indices())) - - -_LAZY_KEY_CACHE: dict[str, LazyKeyDict] = {} -"""Cache for lazy key dictionaries by scheme name.""" - - -class XmssKeyManager: - """ - Stateful manager for XMSS signing operations. - - Handles automatic key state advancement for the stateful XMSS scheme. - - Keys are lazily loaded from disk on first access. - """ - - def __init__( - self, - max_slot: Slot, - scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME, - ) -> None: - """Initialize the manager with optional custom configuration.""" - self.max_slot = max_slot - self.scheme = scheme - self._secret_state: dict[tuple[ValidatorIndex, SecretField], bytes] = {} """ - Advanced secret key state cached as raw SSZ bytes. + Fully deserialize a key pair including secrets. - Raw bytes (~2.7 KB each) instead of deserialized SecretKey objects - (~370 MB each) to avoid holding massive Pydantic model trees in memory. + Prefer using the public-key or signing accessors to avoid loading + heavy secret key objects unnecessarily. """ - try: - self.scheme_name = next( - name for name, obj in LEAN_ENV_TO_SCHEMES.items() if obj is scheme - ) - except StopIteration: - raise ValueError(f"Unknown scheme: {scheme}") from None - - @property - def keys(self) -> LazyKeyDict: - """Lazy access to immutable base keys.""" - if self.scheme_name not in _LAZY_KEY_CACHE: - _LAZY_KEY_CACHE[self.scheme_name] = LazyKeyDict(self.scheme_name) - return _LAZY_KEY_CACHE[self.scheme_name] - - def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair: - """Get key pair. Prefer get_public_keys() or signing methods to avoid loading all keys.""" - if idx not in self.keys: - raise KeyError(f"Validator {idx} not found (max: {len(self.keys) - 1})") - return self.keys[idx] + return ValidatorKeyPair.from_dict(self._load_json(idx)) + except KeyError: + raise KeyError(f"Validator {idx} not found (available: {len(self)})") from None def __contains__(self, idx: object) -> bool: - """Check if validator index exists.""" + """Check whether a validator index has keys on disk.""" if not isinstance(idx, ValidatorIndex): return False - return idx in self.keys + return idx in self._scan_indices() def __len__(self) -> int: - """Number of available validators.""" - return len(self.keys) + """Return the number of available validator key pairs.""" + return len(self._scan_indices()) def __iter__(self) -> Iterator[ValidatorIndex]: - """Iterate over validator indices.""" - return iter(self.keys) + """Iterate over validator indices in ascending order.""" + return iter(sorted(self._scan_indices())) def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]: """ - Get (attestation_public, proposal_public) without loading secret keys. + Return attestation and proposal public keys without touching secrets. + + Only the public key portions are deserialized from the hex JSON. + Secret keys (~2.7 KB raw, ~370 MB as Python objects) are not touched. - Delegates to lazy disk loading that skips SecretKey deserialization. + Args: + idx: Validator index to look up. + + Returns: + Tuple of (attestation public key, proposal public key). """ - return self.keys.get_public_keys(idx) + if idx not in self._public_cache: + # Decode only the two public key fields from the raw JSON. + data = self._load_json(idx) + self._public_cache[idx] = ( + PublicKey.decode_bytes(bytes.fromhex(data["attestation_public"])), + PublicKey.decode_bytes(bytes.fromhex(data["proposal_public"])), + ) + return self._public_cache[idx] def _sign_with_secret( self, @@ -354,42 +377,55 @@ def _sign_with_secret( secret_field: SecretField, ) -> Signature: """ - Shared signing logic for attestation/proposal paths. + Core signing logic shared by attestation and proposal paths. + + XMSS keys have a "prepared interval" -- the range of slots the key + can currently sign for. If the target slot falls outside that range, + the key state must be advanced forward until the slot is covered. - Handles XMSS state advancement until the requested slot is within the - prepared interval, caches the updated secret as raw bytes, and produces - the signature. + Memory strategy: - Only the needed SecretKey is deserialized (~370 MB in Python objects). - After signing, the advanced state is re-serialized to compact bytes - (~2.7 KB) so only one SecretKey is in memory at a time. + 1. Deserialize the secret key from cached bytes or disk + 2. Advance and sign (only one full key object in memory) + 3. Re-serialize to compact bytes (~2.7 KB) for caching Args: - validator_id: Validator index whose key should be used. - slot: The slot to sign for. - message: The message bytes to sign. - secret_field: Which secret on the key pair should advance. + validator_id: Which validator's key to use. + slot: Target slot to sign for. + message: The 32-byte message digest to sign. + secret_field: Which secret key (attestation or proposal) to advance. + + Raises: + ValueError: If the slot exceeds the key's total lifetime. """ cache_key = (validator_id, secret_field) - # Deserialize the secret key: from cached bytes or from disk. + # Deserialize the secret key from either the byte cache or disk. if cache_key in self._secret_state: sk = SecretKey.decode_bytes(self._secret_state[cache_key]) else: - sk = self.keys.get_secret_key(validator_id, secret_field) + sk = self._get_secret_key(validator_id, secret_field) - # Advance key state until the slot is ready for signing. + # Advance the key state until the target slot falls within the prepared interval. + # + # Each advancement step extends the interval by consuming the next one-time signing leaf. prepared = self.scheme.get_prepared_interval(sk) while int(slot) not in prepared: activation = self.scheme.get_activation_interval(sk) + + # If the prepared interval already reaches the activation boundary, + # no further advancement is possible, the key is exhausted. if prepared.stop >= activation.stop: raise ValueError(f"Slot {slot} exceeds key lifetime {activation.stop}") + sk = self.scheme.advance_preparation(sk) prepared = self.scheme.get_prepared_interval(sk) + # Produce the signature for the target slot. signature = self.scheme.sign(sk, slot, message) - # Cache advanced state as raw bytes to keep memory compact. + # Re-serialize the advanced key state to compact bytes for caching. + # This drops the full Python object tree from memory immediately. self._secret_state[cache_key] = sk.encode_bytes() return signature @@ -400,21 +436,23 @@ def sign_attestation_data( attestation_data: AttestationData, ) -> Signature: """ - Sign attestation data with the attestation key. + Sign attestation data using the validator's attestation key. - XMSS is stateful: this delegates to the shared helper which advances the - attestation key state as needed while leaving the proposal key untouched. + Advances only the attestation key state. + The proposal key remains untouched. Args: - validator_id: The validator index to sign the attestation data for. - attestation_data: The attestation data to sign. + validator_id: Which validator signs. + attestation_data: The attestation to sign. Returns: - XMSS signature. + XMSS signature over the attestation data root. Raises: - ValueError: If slot exceeds key lifetime. + ValueError: If the attestation slot exceeds key lifetime. """ + # Derive the message digest from the attestation data and delegate + # to the shared signing logic with the attestation secret. return self._sign_with_secret( validator_id, attestation_data.slot, @@ -429,21 +467,21 @@ def sign_block_root( block_root: Bytes32, ) -> Signature: """ - Sign a block root with the proposal key. + Sign a block root using the validator's proposal key. - Advances the proposal key state until the requested slot is within the - prepared interval, then signs the block root. + Advances only the proposal key state. + The attestation key remains untouched. Args: - validator_id: The validator index to sign the block for. - slot: The slot of the block being signed. - block_root: The hash_tree_root(block) to sign. + validator_id: Which validator signs. + slot: Slot of the block being proposed. + block_root: The hash tree root of the block. Returns: - XMSS signature. + XMSS signature over the block root. Raises: - ValueError: If slot exceeds key lifetime. + ValueError: If the slot exceeds key lifetime. """ return self._sign_with_secret(validator_id, slot, block_root, "proposal_secret") @@ -454,38 +492,53 @@ def build_attestation_signatures( | None = None, ) -> AttestationSignatures: """ - Build attestation signatures for already-aggregated attestations. + Produce aggregated signature proofs for a list of attestations. + + For each aggregated attestation: + + 1. Identify participating validators from the aggregation bitfield + 2. Collect each participant's public key and individual signature + 3. Combine them into a single aggregated proof for the leanVM verifier + + Pre-computed signatures can be supplied via the lookup to avoid + redundant signing. Missing signatures are computed on the fly. - For each aggregated attestation, collect the participating validators' public keys and - signatures, then produce a single leanVM aggregated signature proof. + Args: + aggregated_attestations: Attestations with aggregation bitfields set. + signature_lookup: Optional pre-computed signatures keyed by + attestation data then validator index. + + Returns: + One aggregated signature proof per attestation. """ lookup = signature_lookup or {} proofs: list[AggregatedSignatureProof] = [] for agg in aggregated_attestations: + # Decode which validators participated from the bitfield. validator_ids = agg.aggregation_bits.to_validator_indices() - message = agg.data.data_root_bytes() - slot = agg.data.slot - # Look up pre-computed signatures by attestation data and validator ID. + # Try the lookup first for pre-computed signatures. + # Fall back to signing on the fly for any missing entries. sigs_for_data = lookup.get(agg.data, {}) - public_keys: list[PublicKey] = [self.get_public_keys(vid)[0] for vid in validator_ids] - signatures: list[Signature] = [ + # Collect the attestation public key for each participant. + public_keys = [self.get_public_keys(vid)[0] for vid in validator_ids] + + # Gather individual signatures, computing any that are missing. + signatures = [ sigs_for_data.get(vid) or self.sign_attestation_data(vid, agg.data) for vid in validator_ids ] - # If the caller supplied raw signatures and any are invalid, - # aggregation should fail with exception. - participants = AggregationBits.from_validator_indices(validator_ids) - raw_xmss = list(zip(public_keys, signatures, strict=True)) + # Produce a single aggregated proof that the leanVM can verify + # in one pass over all participants. proof = AggregatedSignatureProof.aggregate( - xmss_participants=participants, + xmss_participants=agg.aggregation_bits, children=[], - raw_xmss=raw_xmss, - message=message, - slot=slot, + raw_xmss=list(zip(public_keys, signatures, strict=True)), + message=agg.data.data_root_bytes(), + slot=agg.data.slot, ) proofs.append(proof) @@ -494,31 +547,48 @@ def build_attestation_signatures( def _generate_single_keypair( scheme: GeneralizedXmssScheme, num_slots: int, index: int -) -> dict[str, str]: - """Generate dual key pairs for one validator (module-level for pickling).""" +) -> ValidatorKeyPair: + """ + Generate attestation and proposal key pairs for one validator. + + Defined at module level so it can be pickled for multiprocessing. + + Args: + scheme: XMSS scheme instance to use for key generation. + num_slots: Total number of slots the keys must cover. + index: Validator index (used only for progress logging). + + Returns: + Complete key pair with both attestation and proposal keys. + """ print(f"Starting key #{index} generation...") + + # Generate two independent key pairs: one for attestations, one for proposals. + # + # Separate keys allow signing both roles within the same slot + # without exhausting a one-time leaf. att_pk, att_sk = scheme.key_gen(Slot(0), Uint64(num_slots)) prop_pk, prop_sk = scheme.key_gen(Slot(0), Uint64(num_slots)) + return ValidatorKeyPair( attestation_public=att_pk, attestation_secret=att_sk, proposal_public=prop_pk, proposal_secret=prop_sk, - ).to_dict() + ) def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: """ - Generate XMSS key pairs in parallel and save to individual files. + Generate XMSS key pairs in parallel and write each to a separate file. - Uses ProcessPoolExecutor to saturate CPU cores for faster generation. - Each keypair is saved to a separate file to avoid the keyfile being - very large for production keys. + Each validator gets its own JSON file to keep individual files small, + which matters especially for production-scheme keys. Args: - lean_env: Name of the XMSS signature scheme to use (e.g. "test" or "prod"). - count: Number of validators. - max_slot: Maximum slot (key lifetime = max_slot + 1 slots). + lean_env: Scheme name (e.g. "test" or "prod"). + count: Number of validator key pairs to generate. + max_slot: Maximum signable slot (key lifetime = max_slot + 1 slots). """ scheme = LEAN_ENV_TO_SCHEMES[lean_env] keys_dir = get_keys_dir(lean_env) @@ -530,81 +600,70 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: f"({num_slots} slots) using {num_workers} cores..." ) - with ProcessPoolExecutor(max_workers=num_workers) as executor: - worker_func = partial(_generate_single_keypair, scheme, num_slots) - key_pairs = list(executor.map(worker_func, range(count))) - - # Create keys directory (remove old one if it exists) - if keys_dir.exists(): - shutil.rmtree(keys_dir) + # Ensure the output directory exists. keys_dir.mkdir(parents=True, exist_ok=True) - # Save each keypair to a separate file - for idx, key_pair in enumerate(key_pairs): - key_file = keys_dir / f"{idx}.json" - key_file.write_text(json.dumps(key_pair, indent=2)) + # Remove stale key files from previous runs that may have generated + # a different number of keys. + for old_file in keys_dir.glob("*.json"): + old_file.unlink() - print(f"Saved {len(key_pairs)} key pairs to {keys_dir}/") + # Generate key pairs in parallel across all CPU cores. + # Results arrive in index order thanks to executor.map. + with ProcessPoolExecutor(max_workers=num_workers) as executor: + worker_func = partial(_generate_single_keypair, scheme, num_slots) + for idx, key_pair in enumerate(executor.map(worker_func, range(count))): + # Serialize and write each key pair as a separate JSON file named by index. + key_file = keys_dir / f"{idx}.json" + key_file.write_text(json.dumps(key_pair.to_dict(), indent=2)) - # Clear cache so new keys are loaded - _LAZY_KEY_CACHE.clear() + print(f"Saved {count} key pairs to {keys_dir}/") def download_keys(scheme: str) -> None: """ - Download pre-generated XMSS key pairs from GitHub releases. + Download pre-generated key pairs from a GitHub release. - Downloads and extracts tar.gz archive for the specified scheme - into its respective directory. + Downloads a tar.gz archive for the specified scheme, removes any + existing keys for that scheme, and extracts the archive in place. Args: - scheme: Scheme name to download (e.g., 'test' or 'prod'). + scheme: Scheme name ("test" or "prod"). """ base_dir = Path(__file__).parent / "test_keys" url = KEY_DOWNLOAD_URLS[scheme] print(f"Downloading {scheme} keys from {url}...") - # Download to a temporary file + # Download to a temporary file to avoid partial-write corruption. with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) try: + # Stream the response directly into the temp file. with urllib.request.urlopen(url) as response: - tmp_file.write(response.read()) - tmp_path = tmp_file.name - except Exception as e: - print(f"Failed to download {scheme} keys: {e}") - return - - # Extract the archive - try: - target_dir = base_dir / f"{scheme}_scheme" - - # Remove existing directory if present - if target_dir.exists(): - shutil.rmtree(target_dir) - - # Create parent directory - base_dir.mkdir(parents=True, exist_ok=True) + shutil.copyfileobj(response, tmp_file) - # Extract tar.gz - with tarfile.open(tmp_path, "r:gz") as tar: - tar.extractall(path=base_dir, filter="data") + # Remove any existing keys for this scheme before extracting. + target_dir = base_dir / f"{scheme}_scheme" + if target_dir.exists(): + shutil.rmtree(target_dir) + base_dir.mkdir(parents=True, exist_ok=True) - print(f"Extracted {scheme} keys to {target_dir}/") + # Extract the archive into the base directory. + # The archive root is the scheme directory itself. + with tarfile.open(tmp_path, "r:gz") as tar: + tar.extractall(path=base_dir, filter="data") - except Exception as e: - print(f"Failed to extract {scheme} keys: {e}") - finally: - # Clean up temporary file - os.unlink(tmp_path) + print(f"Extracted {scheme} keys to {target_dir}/") + finally: + # Always clean up the temporary download file. + tmp_path.unlink(missing_ok=True) - # Clear cache so new keys are loaded - _LAZY_KEY_CACHE.clear() print("Download complete!") def main() -> None: - """CLI entry point for key generation.""" + """CLI entry point for generating or downloading test keys.""" parser = argparse.ArgumentParser( description="Generate XMSS key pairs for consensus testing", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -634,11 +693,12 @@ def main() -> None: ) args = parser.parse_args() - # Download keys instead of generating if specified + # Download pre-generated keys instead of generating locally. if args.download: download_keys(scheme=args.scheme) return + # Generate fresh keys with the specified parameters. _generate_keys(lean_env=args.scheme, count=args.count, max_slot=args.max_slot) diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index ba303411..2ca130f4 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -44,7 +44,6 @@ LEAN_ENV_TO_SCHEMES, XmssKeyManager, create_dummy_signature, - get_shared_key_manager, ) from ..test_types import ( AggregatedAttestationSpec, @@ -191,7 +190,7 @@ def make_fixture(self) -> Self: # XMSS keys are expensive to generate. # The shared key manager caches keys across tests. # Tests requiring higher max slot trigger key expansion. - key_manager = get_shared_key_manager(max_slot=self.max_slot) + key_manager = XmssKeyManager.shared(max_slot=self.max_slot) # Validator pubkey synchronization # diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py index f91c1dbd..752a2c5f 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py +++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py @@ -27,7 +27,7 @@ from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof from lean_spec.types import ByteListMiB, Bytes32 -from ..keys import XmssKeyManager, create_dummy_signature, get_shared_key_manager +from ..keys import XmssKeyManager, create_dummy_signature from ..test_types import AggregatedAttestationSpec, BlockSpec from .base import BaseConsensusFixture @@ -122,7 +122,7 @@ def make_fixture(self) -> VerifySignaturesTest: assert self.anchor_state is not None, "anchor state must be set before making the fixture" # Use shared key manager - key_manager = get_shared_key_manager() + key_manager = XmssKeyManager.shared() # Build the signed block signed_block = self._build_block_from_spec(self.block, self.anchor_state, key_manager) diff --git a/packages/testing/src/consensus_testing/test_types/genesis.py b/packages/testing/src/consensus_testing/test_types/genesis.py index b2c7adf8..bbd37cb4 100644 --- a/packages/testing/src/consensus_testing/test_types/genesis.py +++ b/packages/testing/src/consensus_testing/test_types/genesis.py @@ -6,7 +6,7 @@ from lean_spec.subspecs.containers.validator import Validator, ValidatorIndex from lean_spec.types import Bytes52, Uint64 -from ..keys import get_shared_key_manager +from ..keys import XmssKeyManager def generate_pre_state(**kwargs: Any) -> State: @@ -24,7 +24,7 @@ def generate_pre_state(**kwargs: Any) -> State: genesis_time = kwargs.get("genesis_time", Uint64(0)) num_validators = kwargs.get("num_validators", 4) - key_manager = get_shared_key_manager() + key_manager = XmssKeyManager.shared() available_keys = len(key_manager) assert num_validators <= available_keys, ( diff --git a/tests/consensus/devnet/ssz/test_xmss_containers.py b/tests/consensus/devnet/ssz/test_xmss_containers.py index a2c85959..186f529d 100644 --- a/tests/consensus/devnet/ssz/test_xmss_containers.py +++ b/tests/consensus/devnet/ssz/test_xmss_containers.py @@ -2,7 +2,7 @@ import pytest from consensus_testing import SSZTestFiller -from consensus_testing.keys import create_dummy_signature, get_shared_key_manager +from consensus_testing.keys import XmssKeyManager, create_dummy_signature from lean_spec.subspecs.containers import ValidatorIndex from lean_spec.subspecs.containers.attestation import AggregationBits @@ -52,9 +52,9 @@ def test_signature_zero(ssz: SSZTestFiller) -> None: def test_signature_actual(ssz: SSZTestFiller) -> None: """SSZ roundtrip for a cryptographically valid Signature produced by signing.""" - key_manager = get_shared_key_manager() + key_manager = XmssKeyManager.shared() scheme = key_manager.scheme - sk = key_manager.keys[ValidatorIndex(0)].attestation_secret + sk = key_manager[ValidatorIndex(0)].attestation_secret signature = scheme.sign(sk, Slot(0), Bytes32(b"\x42" * 32)) ssz(type_name="Signature", value=signature) diff --git a/tests/lean_spec/conftest.py b/tests/lean_spec/conftest.py index e44b6c54..24457825 100644 --- a/tests/lean_spec/conftest.py +++ b/tests/lean_spec/conftest.py @@ -10,7 +10,7 @@ from collections.abc import Callable import pytest -from consensus_testing.keys import XmssKeyManager, get_shared_key_manager +from consensus_testing.keys import XmssKeyManager from lean_spec.subspecs.containers import Block, State from lean_spec.subspecs.containers.slot import Slot @@ -28,7 +28,7 @@ @pytest.fixture def key_manager() -> XmssKeyManager: """XMSS key manager for signing attestations.""" - return get_shared_key_manager(max_slot=Slot(20)) + return XmssKeyManager.shared(max_slot=Slot(20)) _DEFAULT_VALIDATOR_ID = ValidatorIndex(0) diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index 726bdb2a..028abe7c 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -9,7 +9,7 @@ from typing import NamedTuple, cast -from consensus_testing.keys import XmssKeyManager, get_shared_key_manager +from consensus_testing.keys import XmssKeyManager from lean_spec.subspecs.chain.clock import Interval, SlotClock from lean_spec.subspecs.chain.config import INTERVALS_PER_SLOT @@ -101,20 +101,18 @@ def make_validators(count: int) -> Validators: def make_validators_from_key_manager(key_manager: XmssKeyManager, count: int) -> Validators: """Build a validator registry with real XMSS keys from a key manager.""" - return Validators( - data=[ + validators = [] + for i in range(count): + idx = ValidatorIndex(i) + att_pk, prop_pk = key_manager.get_public_keys(idx) + validators.append( Validator( - attestation_pubkey=Bytes52( - key_manager[ValidatorIndex(i)].attestation_public.encode_bytes() - ), - proposal_pubkey=Bytes52( - key_manager[ValidatorIndex(i)].proposal_public.encode_bytes() - ), - index=ValidatorIndex(i), + attestation_pubkey=Bytes52(att_pk.encode_bytes()), + proposal_pubkey=Bytes52(prop_pk.encode_bytes()), + index=idx, ) - for i in range(count) - ] - ) + ) + return Validators(data=validators) def make_genesis_state( @@ -413,7 +411,7 @@ def make_keyed_genesis_state( ) -> State: """Create a genesis state with real XMSS keys from the shared key manager.""" if key_manager is None: - key_manager = get_shared_key_manager() + key_manager = XmssKeyManager.shared() validators = make_validators_from_key_manager(key_manager, num_validators) return make_genesis_state(validators=validators) @@ -428,7 +426,7 @@ def make_aggregated_proof( xmss_participants = AggregationBits.from_validator_indices(ValidatorIndices(data=participants)) raw_xmss = list( zip( - [key_manager[vid].attestation_public for vid in participants], + [key_manager.get_public_keys(vid)[0] for vid in participants], [key_manager.sign_attestation_data(vid, attestation_data) for vid in participants], strict=True, ) diff --git a/tests/lean_spec/subspecs/containers/conftest.py b/tests/lean_spec/subspecs/containers/conftest.py index b2f717ae..9a5a64f4 100644 --- a/tests/lean_spec/subspecs/containers/conftest.py +++ b/tests/lean_spec/subspecs/containers/conftest.py @@ -3,7 +3,7 @@ from __future__ import annotations import pytest -from consensus_testing.keys import XmssKeyManager, get_shared_key_manager +from consensus_testing.keys import XmssKeyManager from lean_spec.subspecs.containers.slot import Slot @@ -11,4 +11,4 @@ @pytest.fixture def container_key_manager() -> XmssKeyManager: """Key manager for container tests.""" - return get_shared_key_manager(max_slot=Slot(20)) + return XmssKeyManager.shared(max_slot=Slot(20)) diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index 3301bb36..56565c64 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from consensus_testing.keys import XmssKeyManager, get_shared_key_manager +from consensus_testing.keys import XmssKeyManager from lean_spec.subspecs.chain.clock import SlotClock from lean_spec.subspecs.chain.config import MILLISECONDS_PER_INTERVAL @@ -445,7 +445,7 @@ class TestValidatorServiceIntegration: @pytest.fixture def key_manager(self) -> XmssKeyManager: """Key manager with pre-generated test keys.""" - return get_shared_key_manager(max_slot=Slot(10)) + return XmssKeyManager.shared(max_slot=Slot(10)) @pytest.fixture def real_store(self, key_manager: XmssKeyManager) -> Store: