Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 91 additions & 18 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from lean_spec.subspecs.koalabear import Fp
from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
from lean_spec.subspecs.xmss.constants import TARGET_CONFIG
from lean_spec.subspecs.xmss.containers import PublicKey, Signature, ValidatorKeyPair
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature, ValidatorKeyPair
from lean_spec.subspecs.xmss.interface import (
PROD_SIGNATURE_SCHEME,
TEST_SIGNATURE_SCHEME,
Expand All @@ -65,6 +65,9 @@
)
from lean_spec.types import Bytes32, Uint64

SecretField = Literal["attestation_secret", "proposal_secret"]
"""The two secret key field names on ValidatorKeyPair."""

__all__ = [
"CLI_DEFAULT_MAX_SLOT",
"KEY_DOWNLOAD_URLS",
Expand Down Expand Up @@ -173,6 +176,8 @@ def __init__(self, scheme_name: str) -> None:
self._scheme_name = scheme_name
self._keys_dir = get_keys_dir(scheme_name)
self._cache: dict[ValidatorIndex, ValidatorKeyPair] = {}
self._public_cache: dict[ValidatorIndex, tuple[PublicKey, PublicKey]] = {}
self._raw_cache: dict[ValidatorIndex, dict[str, str]] = {}
self._available_indices: set[ValidatorIndex] | None = None

def _ensure_dir_exists(self) -> None:
Expand All @@ -197,18 +202,62 @@ def _get_available_indices(self) -> set[ValidatorIndex]:
)
return self._available_indices

def _load_raw(self, idx: ValidatorIndex) -> dict[str, str]:
"""Load raw JSON data from disk (cached)."""
if idx not in self._raw_cache:
key_file = self._keys_dir / f"{idx}.json"
try:
self._raw_cache[idx] = json.loads(key_file.read_text())
except FileNotFoundError:
raise KeyError(f"Key file not found: {key_file}") from None
return self._raw_cache[idx]

def _load_key(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Load a single key from disk."""
key_file = self._keys_dir / f"{idx}.json"
if not key_file.exists():
raise KeyError(f"Key file not found: {key_file}")
data = json.loads(key_file.read_text())
return ValidatorKeyPair.from_dict(data)
return ValidatorKeyPair.from_dict(self._load_raw(idx))

def _load_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
"""Load only public keys from disk, skipping expensive SecretKey deserialization."""
data = self._load_raw(idx)
return (
PublicKey.decode_bytes(bytes.fromhex(data["attestation_public"])),
PublicKey.decode_bytes(bytes.fromhex(data["proposal_public"])),
)

def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
"""
Get (attestation_public, proposal_public) without loading secret keys.

Returns cached public keys if available, otherwise loads only the public
key portions from disk. Avoids deserializing the heavy SecretKey objects
(each ~2.7KB raw with 3 HashSubTree structures) until signing is needed.
"""
if idx in self._cache:
kp = self._cache[idx]
return (kp.attestation_public, kp.proposal_public)
if idx not in self._public_cache:
self._public_cache[idx] = self._load_public_keys(idx)
return self._public_cache[idx]

def get_secret_key(self, idx: ValidatorIndex, field: SecretField) -> SecretKey:
"""
Load a specific secret key from disk without deserializing the other keys.

Only the requested SecretKey is deserialized (~370 MB in Python objects).
The other three fields remain as lightweight hex strings (~2.7 KB each).
"""
if idx in self._cache:
return getattr(self._cache[idx], field)
data = self._load_raw(idx)
return SecretKey.decode_bytes(bytes.fromhex(data[field]))

def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Get key pair by validator index, loading from disk if needed."""
if idx not in self._cache:
self._cache[idx] = self._load_key(idx)
# Full pair supersedes raw/public caches for this index.
self._raw_cache.pop(idx, None)
self._public_cache.pop(idx, None)
return self._cache[idx]

def __contains__(self, idx: object) -> bool:
Expand Down Expand Up @@ -247,7 +296,13 @@ def __init__(
"""Initialize the manager with optional custom configuration."""
self.max_slot = max_slot
self.scheme = scheme
self._state: dict[ValidatorIndex, ValidatorKeyPair] = {}
self._secret_state: dict[tuple[ValidatorIndex, SecretField], bytes] = {}
"""
Advanced secret key state cached as raw SSZ bytes.

Raw bytes (~2.7 KB each) instead of deserialized SecretKey objects
(~370 MB each) to avoid holding massive Pydantic model trees in memory.
"""

try:
self.scheme_name = next(
Expand All @@ -264,9 +319,7 @@ def keys(self) -> LazyKeyDict:
return _LAZY_KEY_CACHE[self.scheme_name]

def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Get key pair, returning advanced state if available."""
if idx in self._state:
return self._state[idx]
"""Get key pair. Prefer get_public_keys() or signing methods to avoid loading all keys."""
if idx not in self.keys:
raise KeyError(f"Validator {idx} not found (max: {len(self.keys) - 1})")
return self.keys[idx]
Expand All @@ -285,27 +338,45 @@ def __iter__(self) -> Iterator[ValidatorIndex]:
"""Iterate over validator indices."""
return iter(self.keys)

def get_public_keys(self, idx: ValidatorIndex) -> tuple[PublicKey, PublicKey]:
"""
Get (attestation_public, proposal_public) without loading secret keys.

Delegates to lazy disk loading that skips SecretKey deserialization.
"""
return self.keys.get_public_keys(idx)

def _sign_with_secret(
self,
validator_id: ValidatorIndex,
slot: Slot,
message: Bytes32,
secret_field: Literal["attestation_secret", "proposal_secret"],
secret_field: SecretField,
) -> Signature:
"""
Shared signing logic for attestation/proposal paths.

Handles XMSS state advancement until the requested slot is within the
prepared interval, caches the updated secret, and produces the signature.
prepared interval, caches the updated secret as raw bytes, and produces
the signature.

Only the needed SecretKey is deserialized (~370 MB in Python objects).
After signing, the advanced state is re-serialized to compact bytes
(~2.7 KB) so only one SecretKey is in memory at a time.

Args:
validator_id: Validator index whose key should be used.
slot: The slot to sign for.
message: The message bytes to sign.
secret_field: Which secret on the key pair should advance.
"""
kp = self[validator_id]
sk = getattr(kp, secret_field)
cache_key = (validator_id, secret_field)

# Deserialize the secret key: from cached bytes or from disk.
if cache_key in self._secret_state:
sk = SecretKey.decode_bytes(self._secret_state[cache_key])
else:
sk = self.keys.get_secret_key(validator_id, secret_field)

# Advance key state until the slot is ready for signing.
prepared = self.scheme.get_prepared_interval(sk)
Expand All @@ -316,10 +387,12 @@ def _sign_with_secret(
sk = self.scheme.advance_preparation(sk)
prepared = self.scheme.get_prepared_interval(sk)

# Cache advanced state (only the selected secret changes).
self._state[validator_id] = kp._replace(**{secret_field: sk})
signature = self.scheme.sign(sk, slot, message)

# Cache advanced state as raw bytes to keep memory compact.
self._secret_state[cache_key] = sk.encode_bytes()

return self.scheme.sign(sk, slot, message)
return signature

def sign_attestation_data(
self,
Expand Down Expand Up @@ -397,7 +470,7 @@ def build_attestation_signatures(
# Look up pre-computed signatures by attestation data and validator ID.
sigs_for_data = lookup.get(agg.data, {})

public_keys: list[PublicKey] = [self[vid].attestation_public for vid in validator_ids]
public_keys: list[PublicKey] = [self.get_public_keys(vid)[0] for vid in validator_ids]
signatures: list[Signature] = [
sigs_for_data.get(vid) or self.sign_attestation_data(vid, agg.data)
for vid in validator_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,11 @@ def make_fixture(self) -> Self:
updated_validators = []
for i, validator in enumerate(self.anchor_state.validators):
idx = ValidatorIndex(i)
attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
validator = validator.model_copy(
update={
"attestation_pubkey": key_manager[idx].attestation_public.encode_bytes(),
"proposal_pubkey": key_manager[idx].proposal_public.encode_bytes(),
"attestation_pubkey": attestation_pubkey.encode_bytes(),
"proposal_pubkey": proposal_pubkey.encode_bytes(),
}
)
updated_validators.append(validator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _build_block_from_spec(
# Valid proof but from wrong validators
# Sign with signer_ids but claim validator_ids as participants
signer_public_keys = [
key_manager[vid].attestation_public for vid in invalid_spec.signer_ids
key_manager.get_public_keys(vid)[0] for vid in invalid_spec.signer_ids
]
signer_signatures = [
key_manager.sign_attestation_data(vid, attestation_data)
Expand Down
22 changes: 10 additions & 12 deletions packages/testing/src/consensus_testing/test_types/genesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,17 @@ def generate_pre_state(**kwargs: Any) -> State:
f"but the key manager has only {available_keys} keys",
)

validators = Validators(
data=[
validator_list = []
for i in range(num_validators):
idx = ValidatorIndex(i)
attestation_pubkey, proposal_pubkey = key_manager.get_public_keys(idx)
validator_list.append(
Validator(
attestation_pubkey=Bytes52(
key_manager[ValidatorIndex(i)].attestation_public.encode_bytes()
),
proposal_pubkey=Bytes52(
key_manager[ValidatorIndex(i)].proposal_public.encode_bytes()
),
index=ValidatorIndex(i),
attestation_pubkey=Bytes52(attestation_pubkey.encode_bytes()),
proposal_pubkey=Bytes52(proposal_pubkey.encode_bytes()),
index=idx,
)
for i in range(num_validators)
]
)
)
validators = Validators(data=validator_list)

return State.generate_genesis(genesis_time=genesis_time, validators=validators)
Loading