diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py
index 0607e94a..bf077581 100755
--- a/packages/testing/src/consensus_testing/keys.py
+++ b/packages/testing/src/consensus_testing/keys.py
@@ -51,7 +51,7 @@
 from lean_spec.subspecs.koalabear import Fp
 from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
 from lean_spec.subspecs.xmss.constants import TARGET_CONFIG
-from lean_spec.subspecs.xmss.containers import PublicKey, Signature, ValidatorKeyPair
+from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature, ValidatorKeyPair
 from lean_spec.subspecs.xmss.interface import (
     PROD_SIGNATURE_SCHEME,
     TEST_SIGNATURE_SCHEME,
@@ -65,6 +65,9 @@
 )
 from lean_spec.types import Bytes32, Uint64
 
+SecretField = Literal["attestation_secret", "proposal_secret"]
+"""The two secret key field names on ValidatorKeyPair."""
+
 __all__ = [
     "CLI_DEFAULT_MAX_SLOT",
     "KEY_DOWNLOAD_URLS",
@@ -173,6 +176,8 @@ def __init__(self, scheme_name: str) -> None:
         self._scheme_name = scheme_name
         self._keys_dir = get_keys_dir(scheme_name)
         self._cache: dict[ValidatorIndex, ValidatorKeyPair] = {}
+        self._public_cache: dict[ValidatorIndex, tuple[PublicKey, PublicKey]] = {}
+        self._raw_cache: dict[ValidatorIndex, dict[str, str]] = {}
         self._available_indices: set[ValidatorIndex] | None = None
 
     def _ensure_dir_exists(self) -> None:
@@ -197,18 +202,62 @@ def _get_available_indices(self) -> set[ValidatorIndex]:
             )
         return self._available_indices
 
+    def _load_raw(self, idx: ValidatorIndex) -> dict[str, str]:
+        """Load raw JSON data from disk (cached)."""
+        if idx not in self._raw_cache:
+            key_file = self._keys_dir / f"{idx}.json"
+            try:
+                self._raw_cache[idx] = json.loads(key_file.read_text())
+            except FileNotFoundError:
+                raise KeyError(f"Key file not found: {key_file}") from None
+        return self._raw_cache[idx]
+
     def _load_key(self, idx: ValidatorIndex) -> ValidatorKeyPair:
         """Load a single key from disk."""
-        key_file = self._keys_dir / f"{idx}.json"
-        if not key_file.exists():
-            raise KeyError(f"Key file not found: {key_file}")
-        data = json.loads(key_file.read_text())
-        return ValidatorKeyPair.from_dict(data)
+        return ValidatorKeyPair.from_dict(self._load_raw(idx))
+
+    def _load_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
+        """Load only public keys from disk, skipping expensive SecretKey deserialization."""
+        data = self._load_raw(idx)
+        return (
+            PublicKey.decode_bytes(bytes.fromhex(data["attestation_public"])),
+            PublicKey.decode_bytes(bytes.fromhex(data["proposal_public"])),
+        )
+
+    def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
+        """
+        Get (attestation_public, proposal_public) without loading secret keys.
+
+        Returns cached public keys if available, otherwise loads only the public
+        key portions from disk. Avoids deserializing the heavy SecretKey objects
+        (each ~2.7KB raw with 3 HashSubTree structures) until signing is needed.
+        """
+        if idx in self._cache:
+            kp = self._cache[idx]
+            return (kp.attestation_public, kp.proposal_public)
+        if idx not in self._public_cache:
+            self._public_cache[idx] = self._load_public_keys(idx)
+        return self._public_cache[idx]
+
+    def get_secret_key(self, idx: ValidatorIndex, field: SecretField) -> SecretKey:
+        """
+        Load a specific secret key from disk without deserializing the other keys.
+
+        Only the requested SecretKey is deserialized (~370 MB in Python objects).
+        The other three fields remain as lightweight hex strings (~2.7 KB each).
+        """
+        if idx in self._cache:
+            return getattr(self._cache[idx], field)
+        data = self._load_raw(idx)
+        return SecretKey.decode_bytes(bytes.fromhex(data[field]))
 
     def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
         """Get key pair by validator index, loading from disk if needed."""
         if idx not in self._cache:
             self._cache[idx] = self._load_key(idx)
+            # Full pair supersedes raw/public caches for this index.
+            self._raw_cache.pop(idx, None)
+            self._public_cache.pop(idx, None)
         return self._cache[idx]
 
     def __contains__(self, idx: object) -> bool:
@@ -247,7 +296,13 @@ def __init__(
         """Initialize the manager with optional custom configuration."""
         self.max_slot = max_slot
         self.scheme = scheme
-        self._state: dict[ValidatorIndex, ValidatorKeyPair] = {}
+        self._secret_state: dict[tuple[ValidatorIndex, SecretField], bytes] = {}
+        """
+        Advanced secret key state cached as raw SSZ bytes.
+
+        Raw bytes (~2.7 KB each) instead of deserialized SecretKey objects
+        (~370 MB each) to avoid holding massive Pydantic model trees in memory.
+        """
 
         try:
             self.scheme_name = next(
@@ -264,9 +319,7 @@ def keys(self) -> LazyKeyDict:
         return _LAZY_KEY_CACHE[self.scheme_name]
 
     def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
-        """Get key pair, returning advanced state if available."""
-        if idx in self._state:
-            return self._state[idx]
+        """Get key pair. Prefer get_public_keys() or signing methods to avoid loading all keys."""
         if idx not in self.keys:
             raise KeyError(f"Validator {idx} not found (max: {len(self.keys) - 1})")
         return self.keys[idx]
@@ -285,18 +338,31 @@ def __iter__(self) -> Iterator[ValidatorIndex]:
         """Iterate over validator indices."""
         return iter(self.keys)
 
+    def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
+        """
+        Get (attestation_public, proposal_public) without loading secret keys.
+
+        Delegates to lazy disk loading that skips SecretKey deserialization.
+        """
+        return self.keys.get_public_keys(idx)
+
     def _sign_with_secret(
         self,
         validator_id: ValidatorIndex,
         slot: Slot,
         message: Bytes32,
-        secret_field: Literal["attestation_secret", "proposal_secret"],
+        secret_field: SecretField,
     ) -> Signature:
         """
         Shared signing logic for attestation/proposal paths.
 
         Handles XMSS state advancement until the requested slot is within the
-        prepared interval, caches the updated secret, and produces the signature.
+        prepared interval, caches the updated secret as raw bytes, and produces
+        the signature.
+
+        Only the needed SecretKey is deserialized (~370 MB in Python objects).
+        After signing, the advanced state is re-serialized to compact bytes
+        (~2.7 KB) so only one SecretKey is in memory at a time.
 
         Args:
             validator_id: Validator index whose key should be used.
@@ -304,8 +370,13 @@ def _sign_with_secret(
             message: The message bytes to sign.
             secret_field: Which secret on the key pair should advance.
         """
-        kp = self[validator_id]
-        sk = getattr(kp, secret_field)
+        cache_key = (validator_id, secret_field)
+
+        # Deserialize the secret key: from cached bytes or from disk.
+        if cache_key in self._secret_state:
+            sk = SecretKey.decode_bytes(self._secret_state[cache_key])
+        else:
+            sk = self.keys.get_secret_key(validator_id, secret_field)
 
         # Advance key state until the slot is ready for signing.
         prepared = self.scheme.get_prepared_interval(sk)
@@ -316,10 +387,12 @@ def _sign_with_secret(
             sk = self.scheme.advance_preparation(sk)
             prepared = self.scheme.get_prepared_interval(sk)
 
-        # Cache advanced state (only the selected secret changes).
-        self._state[validator_id] = kp._replace(**{secret_field: sk})
+        signature = self.scheme.sign(sk, slot, message)
+
+        # Cache advanced state as raw bytes to keep memory compact.
+        self._secret_state[cache_key] = sk.encode_bytes()
 
-        return self.scheme.sign(sk, slot, message)
+        return signature
 
     def sign_attestation_data(
         self,
@@ -397,7 +470,7 @@ def build_attestation_signatures(
             # Look up pre-computed signatures by attestation data and validator ID.
             sigs_for_data = lookup.get(agg.data, {})
 
-            public_keys: list[PublicKey] = [self[vid].attestation_public for vid in validator_ids]
+            public_keys: list[PublicKey] = [self.get_public_keys(vid)[0] for vid in validator_ids]
             signatures: list[Signature] = [
                 sigs_for_data.get(vid) or self.sign_attestation_data(vid, agg.data)
                 for vid in validator_ids
diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py
index 90629e51..19864b45 100644
--- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py
+++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py
@@ -200,10 +200,11 @@ def make_fixture(self) -> Self:
         updated_validators = []
         for i, validator in enumerate(self.anchor_state.validators):
             idx = ValidatorIndex(i)
+            attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
             validator = validator.model_copy(
                 update={
-                    "attestation_pubkey": key_manager[idx].attestation_public.encode_bytes(),
-                    "proposal_pubkey": key_manager[idx].proposal_public.encode_bytes(),
+                    "attestation_pubkey": attestation_pubkey.encode_bytes(),
+                    "proposal_pubkey": proposal_pubkey.encode_bytes(),
                 }
             )
             updated_validators.append(validator)
diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py
index 313dab5f..a8823b99 100644
--- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py
+++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py
@@ -230,7 +230,7 @@ def _build_block_from_spec(
                 # Valid proof but from wrong validators
                 # Sign with signer_ids but claim validator_ids as participants
                 signer_public_keys = [
-                    key_manager[vid].attestation_public for vid in invalid_spec.signer_ids
+                    key_manager.get_public_keys(vid)[0] for vid in invalid_spec.signer_ids
                 ]
                 signer_signatures = [
                     key_manager.sign_attestation_data(vid, attestation_data)
diff --git a/packages/testing/src/consensus_testing/test_types/genesis.py b/packages/testing/src/consensus_testing/test_types/genesis.py
index 2d7f0186..b2c7adf8 100644
--- a/packages/testing/src/consensus_testing/test_types/genesis.py
+++ b/packages/testing/src/consensus_testing/test_types/genesis.py
@@ -33,19 +33,17 @@ def generate_pre_state(**kwargs: Any) -> State:
             f"but the key manager has only {available_keys} keys",
         )
 
-    validators = Validators(
-        data=[
+    validator_list = []
+    for i in range(num_validators):
+        idx = ValidatorIndex(i)
+        attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
+        validator_list.append(
             Validator(
-                attestation_pubkey=Bytes52(
-                    key_manager[ValidatorIndex(i)].attestation_public.encode_bytes()
-                ),
-                proposal_pubkey=Bytes52(
-                    key_manager[ValidatorIndex(i)].proposal_public.encode_bytes()
-                ),
-                index=ValidatorIndex(i),
+                attestation_pubkey=Bytes52(attestation_pubkey.encode_bytes()),
+                proposal_pubkey=Bytes52(proposal_pubkey.encode_bytes()),
+                index=idx,
             )
-            for i in range(num_validators)
-        ]
-    )
+        )
+    validators = Validators(data=validator_list)
 
     return State.generate_genesis(genesis_time=genesis_time, validators=validators)