diff --git a/Cargo.lock b/Cargo.lock index a83e37c45..9317978be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2237,7 +2237,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "once_cell", "p3", @@ -3243,7 +3243,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "bincode 1.3.3", "clap", @@ -3267,7 +3267,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "either", "ff_ext", @@ -4558,7 +4558,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "p3-air", "p3-baby-bear", @@ -5126,7 +5126,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "ff_ext", "p3", @@ -6083,7 +6083,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "cfg-if", "dashu", @@ -6208,7 +6208,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "either", "ff_ext", @@ -6226,7 +6226,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "itertools 0.13.0", "p3", @@ -6633,7 +6633,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6927,7 +6927,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "bincode 1.3.3", "clap", @@ -7214,7 +7214,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.24#a3538e3529a7eb87e8867f4a87b760d7ad9991f7" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 8cc5823a5..00e4b0641 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,16 +27,16 @@ version = "0.1.0" ceno_crypto_primitives = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_crypto_primitives", branch = "main" } ceno_syscall = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_syscall", branch = "main" } -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.24" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.24" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.24" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.24" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.24" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.24" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.24" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.24" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.24" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.24" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.25" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.25" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.25" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.25" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.25" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.25" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.25" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.25" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.25" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.25" } anyhow = { version = "1.0", default-features = false } bincode = "1" @@ -129,7 +129,7 @@ lto = "thin" #[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] #ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } - +# #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } #mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 95b15b581..c489567a7 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -5,8 +5,7 @@ use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ Checkpoint, FieldType, MultiProver, PcsKind, Preset, public_io_words_to_digest_words, - run_e2e_full_trace_verify, run_e2e_single_shard_debug_verify, run_e2e_with_checkpoint, - setup_platform, setup_platform_debug, + run_e2e_with_checkpoint, setup_platform, setup_platform_debug, }, scheme::{ ZKVMProof, constants::MAX_NUM_VARIABLES, create_backend, create_prover, hal::ProverDevice, @@ -352,17 +351,10 @@ fn run_inner< fs::write(&vk_file, vk_bytes).unwrap(); if checkpoint > Checkpoint::PrepVerify { + // `run_e2e_with_checkpoint` already performs the real verification for the + // complete flow. Re-running it here without the emulation exit code causes + // a false "Unfinished execution" error to be logged. let verifier = ZKVMVerifier::new(vk); - if target_shard_id.is_some() { - run_e2e_single_shard_debug_verify( - &verifier, - zkvm_proofs.first().cloned().expect("missing shard proof"), - None, - max_steps, - ); - } else { - run_e2e_full_trace_verify(&verifier, zkvm_proofs.clone(), None, max_steps); - } soundness_test(zkvm_proofs.first().cloned().unwrap(), &verifier); } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c887bb920..fcd4a3f90 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -46,6 +46,7 @@ use serde::Serialize; use std::collections::{HashMap, HashSet}; use std::{ collections::{BTreeMap, BTreeSet}, + io::Write, marker::PhantomData, ops::Range, sync::Arc, @@ -2193,9 +2194,18 @@ fn create_proofs_streaming< let transcript = Transcript::new(b"riscv"); let start = std::time::Instant::now(); - let zkvm_proof = prover - .create_proof(&shard_ctx, zkvm_witness, pi, transcript) - .expect("create_proof failed"); + let zkvm_proof = + match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + Ok(proof) => proof, + Err(err) => { + eprintln!( + "create_proof failed for shard {}: {err:?}", + shard_ctx.shard_id + ); + let _ = std::io::stderr().flush(); + std::process::exit(1); + } + }; tracing::debug!( "{}th shard proof created in {:?}", shard_ctx.shard_id, @@ -2254,9 +2264,18 @@ fn create_proofs_streaming< let transcript = Transcript::new(b"riscv"); let start = std::time::Instant::now(); - let zkvm_proof = prover - .create_proof(&shard_ctx, zkvm_witness, pi, transcript) - .expect("create_proof failed"); + let zkvm_proof = + match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + Ok(proof) => proof, + Err(err) => { + eprintln!( + "create_proof failed for shard {}: {err:?}", + shard_ctx.shard_id + ); + let _ = std::io::stderr().flush(); + std::process::exit(1); + } + }; tracing::debug!( "{}th shard proof created in {:?}", shard_ctx.shard_id, diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 565e0dffa..4dc1bf289 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -348,8 +348,7 @@ fn replay_keccak_witness_only_from_packed( ) -> Result, ZKVMError> { use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; - let num_padded_instances = num_instances.next_power_of_two().max(2); - let num_padded_rows = num_padded_instances * 32; + let num_rows = num_instances * 32; let rotation = KECCAK_ROUNDS_CEIL_LOG2; let col_map = info_span!("col_map").in_scope(|| extract_keccak_column_map(config, num_witin)); @@ -358,7 +357,7 @@ fn replay_keccak_witness_only_from_packed( .witgen_keccak( &col_map, packed_instances, - num_padded_rows, + num_rows, shard_offset, fetch_base_pc, fetch_num_slots, @@ -372,9 +371,10 @@ fn replay_keccak_witness_only_from_packed( let raw_witin = if crate::instructions::gpu::config::is_debug_compare_enabled() || !should_materialize_witness_on_gpu() { - info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { + let produced_rows = gpu_result.witness.num_rows; + info_span!("transpose_d2h", rows = produced_rows, cols = num_witin).in_scope(|| { let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .alloc_elems_on_device(produced_rows * num_witin, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) })?; @@ -382,7 +382,7 @@ fn replay_keccak_witness_only_from_packed( &hal.inner, &mut rmm_buffer, &gpu_result.witness.device_buffer, - num_padded_rows, + produced_rows, num_witin, ) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; @@ -445,8 +445,7 @@ fn gpu_assign_keccak_inner( use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; let num_instances = step_indices.len(); - let num_padded_instances = num_instances.next_power_of_two().max(2); - let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance + let num_rows = num_instances * 32; // 2^5 = 32 rows per instance let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 let materialize_initial_witness = crate::instructions::gpu::config::is_debug_compare_enabled() || should_materialize_witness_on_initial_assign(); @@ -479,7 +478,7 @@ fn gpu_assign_keccak_inner( .witgen_keccak( &col_map, &packed_instances, - num_padded_rows, + num_rows, shard_ctx.current_shard_offset_cycle(), fetch_base_pc, fetch_num_slots, @@ -565,9 +564,10 @@ fn gpu_assign_keccak_inner( } else if crate::instructions::gpu::config::is_debug_compare_enabled() || !should_materialize_witness_on_gpu() { - info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { + let produced_rows = gpu_result.witness.num_rows; + info_span!("transpose_d2h", rows = produced_rows, cols = num_witin).in_scope(|| { let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .alloc_elems_on_device(produced_rows * num_witin, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) })?; @@ -575,7 +575,7 @@ fn gpu_assign_keccak_inner( &hal.inner, &mut rmm_buffer, &gpu_result.witness.device_buffer, - num_padded_rows, + produced_rows, num_witin, ) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 21f1f89a0..0813449e1 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -439,11 +439,11 @@ pub(crate) fn try_gpu_assign_shard_ram( { let struct_data = tracing::info_span!( "gpu_shard_ram_structural_transpose_d2h", - num_rows_padded, + rows = gpu_structural.num_rows, num_structural_witin, ) .in_scope(|| -> Result<_, ZKVMError> { - let wit_num_rows = num_rows_padded; + let wit_num_rows = gpu_structural.num_rows; let struct_num_cols = num_structural_witin; let mut struct_rmm_buf = hal .witgen @@ -684,11 +684,11 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( { let struct_data = tracing::info_span!( "gpu_shard_ram_structural_transpose_d2h_from_device", - num_rows_padded, + rows = gpu_structural.num_rows, num_structural_witin, ) .in_scope(|| -> Result<_, ZKVMError> { - let wit_num_rows = num_rows_padded; + let wit_num_rows = gpu_structural.num_rows; let struct_num_cols = num_structural_witin; let mut struct_rmm_buf = hal .witgen diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index be51108c4..0d3e440cb 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -135,6 +135,7 @@ pub(crate) fn try_gpu_assign_instances>( let total_instances = step_indices.len(); if total_instances == 0 { // Empty: just return empty matrices + let num_witin = num_witin.max(1); let num_structural_witin = num_structural_witin.max(1); let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); let raw_structural = @@ -481,7 +482,7 @@ fn gpu_assign_instances_inner>( total_instances, num_witin, I::padding_strategy(), - ) + )? }; if materialize_initial_witness { raw_witin.padding_by_strategy(); @@ -1484,7 +1485,7 @@ fn replay_gpu_witness_from_resident_raw>( total_instances, replay.num_witin, I::padding_strategy(), - ); + )?; // Keep replayed witness immutable after attaching the col-major device backing. // Mutating/padding a RowMajorMatrix clears device metadata, but replay consumers diff --git a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs index fc558046d..5647cef12 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs @@ -303,9 +303,9 @@ pub(crate) fn gpu_witness_to_rmm( num_rows: usize, num_cols: usize, padding: InstancePaddingStrategy, -) -> RowMajorMatrix { +) -> Result, ZKVMError> { let mut rmm = RowMajorMatrix::::new(num_rows, num_cols, padding); // Keep the original col-major witness buffer as the source of truth for GPU commit. rmm.set_device_backing(gpu_result.device_buffer, DeviceMatrixLayout::ColMajor); - rmm + Ok(rmm) } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 0d6a32680..0680141fd 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -75,6 +75,15 @@ pub struct ZKVMChipProof { pub num_instances: [usize; 2], } +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct MainConstraintProof { + pub proof: SumcheckLayerProof, +} + /// each field will be interpret to (constant) polynomial #[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicValues { @@ -202,6 +211,7 @@ pub struct ZKVMProof> { pub public_values: PublicValues, // each circuit may have multiple proof instances pub chip_proofs: BTreeMap>>, + pub main_constraint_proof: MainConstraintProof, pub witin_commit: >::Commitment, pub opening_proof: PCS::Proof, } @@ -210,12 +220,14 @@ impl> ZKVMProof { pub fn new( public_values: PublicValues, chip_proofs: BTreeMap>>, + main_constraint_proof: MainConstraintProof, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { Self { public_values, chip_proofs, + main_constraint_proof, witin_commit, opening_proof, } diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b1640f7b4..5b384e790 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -1,10 +1,12 @@ use super::hal::{ - DeviceTransporter, MainSumcheckEvals, MainSumcheckProver, OpeningProver, ProverDevice, - RotationProver, RotationProverOutput, TowerProver, TraceCommitter, + BatchedMainConstraintProver, DeviceTransporter, MainConstraintJob, MainConstraintResult, + MainSumcheckEvals, MainSumcheckProver, OpeningProver, ProverDevice, RotationProver, + RotationProverOutput, TowerProver, TraceCommitter, }; use crate::{ error::ZKVMError, scheme::{ + MainConstraintProof, constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, @@ -29,7 +31,9 @@ use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ Expression, ToExpr, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, + monomial::Term, util::ceil_log2, + utils::eval_by_expr_with_instance, virtual_poly::{build_eq_x_r_vec, eq_eval}, virtual_polys::VirtualPolynomialsBuilder, }; @@ -44,7 +48,7 @@ use std::{ }; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProverMessage, IOPProverState}, + structs::{IOPProof, IOPProverMessage, IOPProverState}, util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; @@ -1098,6 +1102,329 @@ impl> MainSumcheckProver> + BatchedMainConstraintProver> for CpuProver> +{ + fn prove_batched_main_constraints<'a>( + &self, + jobs: Vec>>, + _pcs_data: & as ProverBackend>::PcsData, + transcript: &mut impl Transcript, + ) -> Result<(MainConstraintProof, Vec>), ZKVMError> { + struct ChipMainData<'a, E: ExtensionField> { + circuit_idx: usize, + layer: &'a gkr_iop::gkr::layer::Layer, + mle_start: usize, + num_mles: usize, + num_var_with_rotation: usize, + pi: Vec, + alpha_start: usize, + } + + if jobs.is_empty() { + return Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof: IOPProof { proofs: vec![] }, + evals: vec![], + }, + }, + vec![], + )); + } + + let mut owned_mles = Vec::>::new(); + let mut chip_data = Vec::with_capacity(jobs.len()); + let mut total_exprs = 0usize; + let mut max_num_variables = 0usize; + let mut max_degree = 0usize; + + for job in &jobs { + let ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + } = job.cs; + + let num_instances = job.input.num_instances(); + let log2_num_instances = job.input.log2_num_instances(); + let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); + max_num_variables = max_num_variables.max(num_var_with_rotation); + + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + max_degree = max_degree.max(first_layer.max_expr_degree + 1); + let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() + { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext { + offset: 0, + num_instances: job.input.num_instances[0], + num_vars: num_var_with_rotation, + } + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext { + offset: job.input.num_instances[0], + num_instances: job.input.num_instances[1], + num_vars: num_var_with_rotation, + } + } else { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } + }) + .collect_vec(); + + let mut out_evals = + vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + + if let Some(rotation) = job.rotation.as_ref() { + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation proof provided for non-rotation layer") + }; + let (left_evals, right_evals, point_evals) = + split_rotation_evals(&rotation.proof.evals); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation.left_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation.right_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation.point, + ); + } + + if let Some(ecc_proof) = job.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) + .expect("invalid internal ecc bridge claims"); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + let eval_and_dedup_points = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, &job.challenges).point) + }) + .collect_vec(); + let selector_eq_pairs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(selector_ctxs.iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + let eq = sel_type.compute(point.as_ref()?, selector_ctx)?; + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + let wit_id = *wit_id as usize; + assert!(wit_id < first_layer.n_structural_witin); + Some((wit_id, eq)) + }) + .collect_vec(); + let mut selector_eq_by_wit_id = vec![None; first_layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } + + let mle_start = owned_mles.len(); + owned_mles.extend(job.input.witness.iter().map(|mle| mle.as_ref().clone())); + owned_mles.extend(job.input.fixed.iter().map(|mle| mle.as_ref().clone())); + for (selector_eq, mle) in selector_eq_by_wit_id + .into_iter() + .zip(job.input.structural_witness.iter()) + { + owned_mles.push(selector_eq.unwrap_or_else(|| mle.as_ref().clone())); + } + let num_mles = + first_layer.n_witin + first_layer.n_fixed + first_layer.n_structural_witin; + assert_eq!(owned_mles.len() - mle_start, num_mles); + + chip_data.push(ChipMainData { + circuit_idx: job.circuit_idx, + layer: first_layer, + mle_start, + num_mles, + num_var_with_rotation, + pi: job + .input + .pi + .iter() + .map(|v| v.map_either(E::from, |v| v).into_inner()) + .collect_vec(), + alpha_start: total_exprs, + }); + total_exprs += first_layer.exprs.len(); + } + + let num_threads = optimal_sumcheck_threads(max_num_variables); + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let mut builder = VirtualPolynomialsBuilder::new(num_threads, max_num_variables); + let global_mle_exprs = owned_mles + .iter() + .map(|mle| builder.lift(Either::Left(mle))) + .collect_vec(); + let mut global_expr = Expression::ZERO; + + for chip in &chip_data { + let main_sumcheck_challenges = chain!( + jobs[0].challenges.iter().copied(), + alpha_pows[chip.alpha_start..chip.alpha_start + chip.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + for Term { + scalar: scalar_expr, + product, + } in chip + .layer + .main_sumcheck_expression_monomial_terms + .as_ref() + .unwrap() + { + let scalar = eval_by_expr_with_instance( + &[], + &[], + &[], + &chip.pi, + &main_sumcheck_challenges, + scalar_expr, + ); + let product_expr = product + .iter() + .map(|expr| { + let Expression::WitIn(wit_id) = expr else { + panic!("main monomial product must be converted to WitIn") + }; + global_mle_exprs[chip.mle_start + *wit_id as usize].clone() + }) + .fold(Expression::ONE, |acc, expr| acc * expr); + global_expr = global_expr + Expression::Constant(scalar) * product_expr; + } + } + + let span = entered_span!("IOPProverState::prove_batched_main", profiling_4 = true); + let (proof, prover_state) = + IOPProverState::prove(builder.to_virtual_polys(&[global_expr], &[]), transcript); + let global_evals = prover_state.get_mle_flatten_final_evaluations(); + transcript.append_field_element_exts(&global_evals); + let global_rt = prover_state.collect_raw_challenges(); + exit_span!(span); + + let mut results = Vec::with_capacity(jobs.len()); + for chip in &chip_data { + let input_opening_point = + global_rt[global_rt.len() - chip.num_var_with_rotation..].to_vec(); + let chip_evals = &global_evals[chip.mle_start..chip.mle_start + chip.num_mles]; + results.push(MainConstraintResult { + circuit_idx: chip.circuit_idx, + input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: chip_evals[..chip.layer.n_witin].to_vec(), + fixed_in_evals: chip_evals + [chip.layer.n_witin..chip.layer.n_witin + chip.layer.n_fixed] + .to_vec(), + }, + }); + } + + Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof, + evals: global_evals, + }, + }, + results, + )) + } +} + impl> OpeningProver> for CpuProver> { diff --git a/ceno_zkvm/src/scheme/gpu/memory.rs b/ceno_zkvm/src/scheme/gpu/memory.rs index 6421ad3c9..0991764f8 100644 --- a/ceno_zkvm/src/scheme/gpu/memory.rs +++ b/ceno_zkvm/src/scheme/gpu/memory.rs @@ -11,11 +11,16 @@ use ceno_gpu::{ estimate_build_tower_memory, estimate_prove_tower_memory, estimate_sumcheck_memory, }; use ff_ext::ExtensionField; -use gkr_iop::gpu::{ - BB31Base, GpuBackend, - gpu_prover::{ - BB31Ext, CacheLevel, CudaHalBB31, MemTracker, get_gpu_cache_level, get_mem_tracking_mode, +use gkr_iop::{ + evaluation::EvalExpression, + gpu::{ + BB31Base, GpuBackend, + gpu_prover::{ + BB31Ext, CacheLevel, CudaHalBB31, MemTracker, get_gpu_cache_level, + get_mem_tracking_mode, + }, }, + hal::MultilinearPolynomial, }; use mpcs::PolynomialCommitmentScheme; @@ -40,15 +45,115 @@ pub fn init_gpu_mem_tracker<'a>( const ESTIMATION_TOLERANCE_BYTES: usize = 2 * 1024 * 1024; // max under-estimation error: 2 MB const ESTIMATION_SAFETY_MARGIN_BYTES: usize = 10 * 1024 * 1024; // reserved headroom / allowed over-estimate margin: 10 MB +const SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES: usize = 512 * 1024 * 1024; +const SHARD_RAM_TOWER_PROVE_TOLERANCE_BYTES: usize = 16 * 1024 * 1024; /// Validate that the estimated GPU memory matches actual usage within tolerance. /// - Under-estimate (actual > estimated): diff must be <= `ESTIMATION_TOLERANCE_BYTES` /// - Over-estimate (estimated > actual): diff must be <= `ESTIMATION_SAFETY_MARGIN_BYTES` pub fn check_gpu_mem_estimation(mem_tracker: Option, estimated_bytes: usize) { + check_gpu_mem_estimation_with_context(mem_tracker, estimated_bytes, None); +} + +pub fn check_gpu_mem_estimation_with_context( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, +) { + check_gpu_mem_estimation_with_margins( + mem_tracker, + estimated_bytes, + context, + ESTIMATION_TOLERANCE_BYTES, + ESTIMATION_SAFETY_MARGIN_BYTES, + ); +} + +pub(crate) fn check_gpu_tower_prove_mem_estimation_with_context( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, +) { + let (under_tolerance_bytes, over_tolerance_bytes) = if context == Some("ShardRamCircuit") { + ( + SHARD_RAM_TOWER_PROVE_TOLERANCE_BYTES, + SHARD_RAM_TOWER_PROVE_TOLERANCE_BYTES, + ) + } else { + (ESTIMATION_TOLERANCE_BYTES, ESTIMATION_SAFETY_MARGIN_BYTES) + }; + check_gpu_mem_estimation_with_margins( + mem_tracker, + estimated_bytes, + context, + under_tolerance_bytes, + over_tolerance_bytes, + ); +} + +pub fn check_gpu_scheduler_mem_estimation_with_context( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, +) { + // Scheduler estimates are admission-control estimates, not exact stage-local allocation + // estimates. They intentionally include safety margins and conservative lifetime overlap, so + // large over-estimates should be surfaced as warnings rather than failing the proof. Under- + // estimates remain hard failures because they can admit unsafe concurrent work. + if let Some(mem_tracker) = mem_tracker { + const ONE_MB: usize = 1024 * 1024; + let label = mem_tracker.name(); + let label = context + .filter(|context| !context.is_empty()) + .map(|context| format!("{label}[{context}]")) + .unwrap_or_else(|| label.to_string()); + let mem_stats = mem_tracker.finish(); + let actual_bytes = mem_stats.mem_occupancy as usize; + let diff = estimated_bytes as isize - actual_bytes as isize; + let to_mb = |b: usize| b as f64 / ONE_MB as f64; + let diff_mb = diff as f64 / ONE_MB as f64; + tracing::info!( + "[memcheck] {label}: scheduler_estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb + ); + if diff < 0 { + assert!( + (-diff) as usize <= ESTIMATION_TOLERANCE_BYTES, + "[memcheck] {label}: scheduler under-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, tolerance={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb, + to_mb(ESTIMATION_TOLERANCE_BYTES), + ); + } else if diff as usize > SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES { + tracing::warn!( + "[memcheck] {label}: scheduler over-estimate warning: estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, warning_margin={:.2}MB", + to_mb(estimated_bytes), + to_mb(actual_bytes), + diff_mb, + to_mb(SCHEDULER_ESTIMATION_WARNING_MARGIN_BYTES), + ); + } + } +} + +fn check_gpu_mem_estimation_with_margins( + mem_tracker: Option, + estimated_bytes: usize, + context: Option<&str>, + under_tolerance_bytes: usize, + over_tolerance_bytes: usize, +) { // `mem_tracker will` be Some only in sequential mode with mem tracking enabled, so if it's None, do nothing if let Some(mem_tracker) = mem_tracker { const ONE_MB: usize = 1024 * 1024; let label = mem_tracker.name(); + let label = context + .filter(|context| !context.is_empty()) + .map(|context| format!("{label}[{context}]")) + .unwrap_or_else(|| label.to_string()); let mem_stats = mem_tracker.finish(); let actual_bytes = mem_stats.mem_occupancy as usize; let diff = estimated_bytes as isize - actual_bytes as isize; @@ -63,22 +168,22 @@ pub fn check_gpu_mem_estimation(mem_tracker: Option, estimated_bytes if diff < 0 { // Under-estimate: actual exceeds estimated assert!( - (-diff) as usize <= ESTIMATION_TOLERANCE_BYTES, + (-diff) as usize <= under_tolerance_bytes, "[memcheck] {label}: under-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, tolerance={:.2}MB", to_mb(estimated_bytes), to_mb(actual_bytes), diff_mb, - to_mb(ESTIMATION_TOLERANCE_BYTES), + to_mb(under_tolerance_bytes), ); } else { // Over-estimate: estimated exceeds actual assert!( - diff as usize <= ESTIMATION_SAFETY_MARGIN_BYTES, + diff as usize <= over_tolerance_bytes, "[memcheck] {label}: over-estimate! estimated={:.2}MB, actual={:.2}MB, diff={:.2}MB, margin={:.2}MB", to_mb(estimated_bytes), to_mb(actual_bytes), diff_mb, - to_mb(ESTIMATION_SAFETY_MARGIN_BYTES), + to_mb(over_tolerance_bytes), ); } } @@ -92,11 +197,14 @@ pub fn estimate_chip_proof_memory>, circuit_name: &str, replay_plan: Option<&GpuReplayPlan>, + witness_trace_rows: Option, structural_cached_on_device: bool, ) -> u64 { let num_var_with_rotation = input.log2_num_instances() + composed_cs.rotation_vars().unwrap_or(0); let witness_replayable = replay_plan.is_some(); + let occupied_rows = + estimate_witness_occupied_rows(composed_cs, input, replay_plan, witness_trace_rows); let structural_resident_bytes = if structural_cached_on_device { 0 } else { @@ -110,12 +218,15 @@ pub fn estimate_chip_proof_memory>( + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'_, GpuBackend>, + replay_plan: Option<&GpuReplayPlan>, + witness_trace_rows: Option, +) -> usize { + if let Some(replay_plan) = replay_plan { + return replay_plan_actual_rows(replay_plan); + } + input + .witness + .first() + .map(|mle| mle.evaluations_len()) + .or(witness_trace_rows) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)) +} + pub(crate) struct TraceEstimate { /// Persistent resident bytes (witness polys + structural MLEs) pub(crate) trace_resident_bytes: usize, @@ -258,6 +387,25 @@ pub(crate) fn estimate_structural_mle_bytes(num_structural_witin: usize, num_var num_structural_witin * mle_len * base_elem_size } +fn replay_plan_actual_rows(replay_plan: &GpuReplayPlan) -> usize { + match replay_plan.kind { + GpuWitgenKind::Keccak => replay_plan + .keccak_instances + .as_ref() + .map(|instances| instances.len() * 32) + .unwrap_or(replay_plan.trace_height), + GpuWitgenKind::ShardRam => replay_plan.trace_height, + _ => replay_plan.step_indices.len(), + } +} + +fn replay_plan_actual_structural_rows(replay_plan: &GpuReplayPlan) -> usize { + match replay_plan.kind { + GpuWitgenKind::ShardRam => replay_plan.shard_ram_num_records, + _ => replay_plan.trace_height, + } +} + pub fn estimate_replay_materialization_bytes( num_witin: usize, _num_structural_witin: usize, @@ -273,7 +421,7 @@ pub fn estimate_replay_materialization_bytes_for_plan( _num_vars: usize, ) -> usize { let elem_size = std::mem::size_of::(); - let witness_bytes = replay_plan.trace_height * replay_plan.num_witin * elem_size; + let witness_bytes = replay_plan_actual_rows(replay_plan) * replay_plan.num_witin * elem_size; let replay_temp_bytes = match replay_plan.kind { GpuWitgenKind::Keccak => replay_plan .keccak_instances @@ -299,8 +447,10 @@ pub fn estimate_replay_materialization_bytes_for_plan( pub(crate) fn estimate_trace_bytes>( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>, + replay_plan: Option<&GpuReplayPlan>, witness_replayable: bool, structural_cached_on_device: bool, + occupied_rows_override: Option, ) -> TraceEstimate { let cs = &composed_cs.zkvm_v1_css; let num_var_with_rotation = @@ -308,14 +458,39 @@ pub(crate) fn estimate_trace_bytes() + }) + .unwrap_or_else(|| { + estimate_structural_mle_bytes( + cs.num_structural_witin as usize, + num_var_with_rotation, + ) + }) } else { estimate_structural_mle_bytes(cs.num_structural_witin as usize, num_var_with_rotation) }; - let (witness_mle_bytes, trace_temporary_bytes) = estimate_trace_extraction_bytes( - cs.num_witin as usize, - num_var_with_rotation, - witness_replayable, - ); + let (witness_mle_bytes, trace_temporary_bytes) = + if should_materialize_witness_on_gpu() && witness_replayable { + let base_elem_size = std::mem::size_of::(); + let actual_rows = replay_plan + .map(replay_plan_actual_rows) + .unwrap_or(1usize << num_var_with_rotation); + (cs.num_witin as usize * actual_rows * base_elem_size, 0) + } else { + estimate_trace_extraction_bytes( + cs.num_witin as usize, + num_var_with_rotation, + occupied_rows_override.unwrap_or_else(|| { + input.num_instances() << composed_cs.rotation_vars().unwrap_or(0) + }), + witness_replayable, + ) + }; TraceEstimate { trace_resident_bytes: witness_mle_bytes + structural_mle_bytes, @@ -325,11 +500,74 @@ pub(crate) fn estimate_trace_bytes( composed_cs: &ComposedConstrainSystem, - num_var_with_rotation: usize, + output_rows: usize, ) -> usize { let elem_size = std::mem::size_of::(); - let record_len = 1usize << num_var_with_rotation; - tower_output_count(composed_cs) * record_len * elem_size + main_witness_materialized_output_count(composed_cs) * output_rows * elem_size +} + +fn main_witness_materialized_output_count( + composed_cs: &ComposedConstrainSystem, +) -> usize { + let Some(gkr_circuit) = composed_cs.gkr_circuit.as_ref() else { + return 0; + }; + let final_layer_output_count = tower_output_count(composed_cs); + + gkr_circuit + .layers + .iter() + .enumerate() + .map(|(layer_index, layer)| { + let final_layer = layer_index == 0; + let out_evals = layer + .out_sel_and_eval_exprs + .iter() + .flat_map(|(_, out_eval)| out_eval.iter()); + + if final_layer { + out_evals + .take(final_layer_output_count) + .filter(|out_eval| main_witness_materializes_output(out_eval)) + .count() + } else { + out_evals + .filter(|out_eval| main_witness_materializes_output(out_eval)) + .count() + } + }) + .sum() +} + +fn main_witness_materializes_output(out_eval: &EvalExpression) -> bool { + matches!( + out_eval, + EvalExpression::Single(_) | EvalExpression::Linear(_, _, _) + ) +} + +pub fn main_witness_output_rows>( + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'_, GpuBackend>, + occupied_rows_override: Option, +) -> usize { + if composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.last()) + .is_some_and(|input_layer| input_layer.in_eval_expr.is_empty()) + { + if let Some(structural_mle) = input.structural_witness.first() { + return structural_mle.evaluations_len(); + } + } + + input + .witness + .first() + .map(|mle| mle.evaluations_len()) + .or(occupied_rows_override) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)) } pub(crate) fn estimate_main_constraints_bytes< @@ -403,7 +641,8 @@ pub(crate) fn estimate_main_constraints_bytes< fn estimate_tower_stage_components>( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>, -) -> (usize, usize, usize) { + occupied_rows_override: Option, +) -> (usize, usize, usize, usize) { let cs = &composed_cs.zkvm_v1_css; let num_prod_towers = composed_cs.num_reads() + composed_cs.num_writes(); let num_logup_towers = if composed_cs.is_with_lk_table() { @@ -418,31 +657,50 @@ fn estimate_tower_stage_components(); let has_logup_numerator = composed_cs.is_with_lk_table(); + let occupied_rows = input + .witness + .first() + .map(|mle| mle.evaluations_len()) + .or(occupied_rows_override) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)); let build_est = estimate_build_tower_memory( num_prod_towers, num_logup_towers, num_vars, num_vars, + occupied_rows, elem_size, has_logup_numerator, ); + let is_shard_ram = composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .is_some_and(|layer| layer.name == "ShardRamCircuit_main"); + let shard_ram_tower_batch_overhead = is_shard_ram.then_some(10 * 1024 * 1024).unwrap_or(0); + let build_bytes = build_est.total_bytes + shard_ram_tower_batch_overhead; let prove_est = estimate_prove_tower_memory( num_prod_towers, num_logup_towers, num_vars, num_vars, + occupied_rows, NUM_FANIN, elem_size, + has_logup_numerator, ); let tower_input_live_bytes = prove_est.prod_tower_buffer_bytes + prove_est.logup_tower_buffer_bytes; + let borrowed_input_bytes = + prove_est.prod_borrowed_input_bytes + prove_est.logup_borrowed_input_bytes; let prove_local_bytes = prove_est.total_bytes.saturating_sub(tower_input_live_bytes); ( - build_est.total_bytes, + build_bytes, prove_local_bytes, tower_input_live_bytes, + borrowed_input_bytes, ) } @@ -452,7 +710,8 @@ pub(crate) fn estimate_tower_stage_bytes, input: &ProofInput<'_, GpuBackend>, ) -> (usize, usize) { - let (build_bytes, prove_local_bytes, _) = estimate_tower_stage_components(composed_cs, input); + let (build_bytes, prove_local_bytes, _, _) = + estimate_tower_stage_components(composed_cs, input, None); (build_bytes, prove_local_bytes) } @@ -460,8 +719,8 @@ pub(crate) fn estimate_tower_bytes, input: &ProofInput<'_, GpuBackend>, ) -> usize { - let (build_bytes, prove_local_bytes, tower_input_live_bytes) = - estimate_tower_stage_components(composed_cs, input); + let (build_bytes, prove_local_bytes, tower_input_live_bytes, _) = + estimate_tower_stage_components(composed_cs, input, None); build_bytes.max(tower_input_live_bytes + prove_local_bytes) } @@ -474,12 +733,13 @@ pub(crate) fn estimate_tower_bytes (usize, usize) { let base_elem_size = std::mem::size_of::(); - let mle_len = 1usize << num_vars; - let poly_bytes = num_witin * mle_len * base_elem_size; + let compact_poly_bytes = num_witin * occupied_rows * base_elem_size; + let transpose_temporary_bytes = 2 * compact_poly_bytes; if should_materialize_witness_on_gpu() { if should_retain_witness_device_backing_after_commit() { @@ -495,19 +755,21 @@ pub(crate) fn estimate_trace_extraction_bytes( // duration of the chip proof. There is no separate extraction temp // buffer, but the replayed witness itself must be accounted for as // resident task memory. - return (poly_bytes, 0); + return (compact_poly_bytes, 0); } // GPU witgen alone does not imply replayability. Non-replayable traces - // still go through basefold::get_trace in cache-none mode, which - // allocates the extracted witness plus a temporary 2x transpose buffer. - return (poly_bytes, 2 * poly_bytes); + // still go through basefold::get_trace in cache-none mode. The fallback + // transpose buffer is 2x the compact RMM backing, not 2x the logical + // domain length. + return (compact_poly_bytes, transpose_temporary_bytes); } if matches!(get_gpu_cache_level(), CacheLevel::None) { // Default cache level is None - // get_trace allocates poly copies (resident) + temp_buffer (2x, freed after) - (poly_bytes, 2 * poly_bytes) + // get_trace allocates poly copies (resident) + temp_buffer over the + // compact RMM backing (2x, freed after). + (compact_poly_bytes, transpose_temporary_bytes) } else { (0, 0) } diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 142f894fe..8f27afd44 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -1,11 +1,13 @@ use super::hal::{ - DeviceTransporter, EccQuarkProver, MainSumcheckProver, OpeningProver, ProverDevice, - RotationProver, TowerProver, TraceCommitter, + BatchedMainConstraintProver, DeviceTransporter, EccQuarkProver, MainConstraintJob, + MainConstraintResult, MainSumcheckProver, OpeningProver, ProverDevice, RotationProver, + TowerProver, TraceCommitter, }; use crate::{ error::ZKVMError, instructions::gpu::cache::current_replay_cache_stats, scheme::{ + MainConstraintProof, constants::SEPTIC_EXTENSION_DEGREE, cpu::TowerRelationOutput, hal::{ @@ -22,11 +24,13 @@ use crate::{ use ceno_gpu::{ Buffer, CudaHal, bb31::{CudaHalBB31, GpuPolynomial}, + common::sumcheck::CommonTermPlan, get_cuda_mem_info, }; use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ + error::BackendError, gkr::{ self, Evaluation, GKRProof, GKRProverOutput, layer::{LayerWitness, gpu::utils::extract_mle_relationships_from_monomial_terms}, @@ -40,25 +44,27 @@ use multilinear_extensions::{ Expression, ToExpr, mle::{FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, + utils::eval_by_expr_constant, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use p3::matrix::Matrix; +use p3::{field::FieldAlgebra, matrix::Matrix}; use rayon::iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; use std::{ collections::BTreeMap, + io::Write, iter::{once, repeat_n}, sync::Arc, }; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverMessage}, - util::optimal_sumcheck_threads, + util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::{BasicTranscript, Transcript}; -use witness::next_pow2_instance_padding; +use witness::{InstancePaddingStrategy, next_pow2_instance_padding}; use ceno_gpu::common::transpose::matrix_transpose; use tracing::info_span; @@ -68,11 +74,14 @@ use witness::DeviceMatrixLayout; use gkr_iop::gpu::gpu_prover::*; mod memory; + mod util; pub(crate) use memory::{ - check_gpu_mem_estimation, estimate_chip_proof_memory, estimate_main_witness_bytes, - estimate_replay_materialization_bytes_for_plan, estimate_tower_bytes, - estimate_tower_stage_bytes, init_gpu_mem_tracker, + check_gpu_mem_estimation, check_gpu_mem_estimation_with_context, + check_gpu_scheduler_mem_estimation_with_context, + check_gpu_tower_prove_mem_estimation_with_context, estimate_chip_proof_memory, + estimate_main_witness_bytes, estimate_replay_materialization_bytes_for_plan, + estimate_tower_bytes, estimate_tower_stage_bytes, init_gpu_mem_tracker, }; use memory::{ estimate_ecc_quark_bytes_from_num_vars, estimate_main_constraints_bytes, @@ -101,6 +110,30 @@ struct PcsResidentStats { total_rmms: usize, } +fn rmm_device_backing_bytes(rmm: &witness::RowMajorMatrix) -> usize +where + T: FieldAlgebra + Default + Sync + Clone + Send + Copy + 'static, +{ + rmm.device_backing_ref::>() + .map(|device_buffer| device_buffer.len() * std::mem::size_of::()) + .unwrap_or(0) +} + +fn rmm_col_major_device_rows(rmm: &witness::RowMajorMatrix) -> Option +where + T: FieldAlgebra + Default + Sync + Clone + Send + Copy + 'static, +{ + if rmm.device_backing_layout() != Some(DeviceMatrixLayout::ColMajor) { + return None; + } + let cols = rmm.width(); + if cols == 0 { + return Some(0); + } + let device_buffer = rmm.device_backing_ref::>()?; + Some(device_buffer.len() / cols) +} + fn pcs_resident_stats( pcs_data_basefold: &BasefoldCommitmentWithWitnessGpu< BB31Base, @@ -141,7 +174,7 @@ fn pcs_resident_stats( ( rmms.iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(), rmms.iter().filter(|rmm| rmm.has_device_backing()).count(), rmms.len(), @@ -275,11 +308,14 @@ pub fn log_gpu_pool_usage(label: &str) { let used_bytes = pool.get_used_size().unwrap_or(0); let reserved_bytes = pool.get_reserved_size().unwrap_or(0); let mb = |bytes: usize| bytes as f64 / (1024.0 * 1024.0); - tracing::info!( + let message = format!( "[gpu pool][{label}] used={:.2}MB reserved={:.2}MB", mb(used_bytes as usize), mb(reserved_bytes as usize), ); + eprintln!("{message}"); + let _ = std::io::stderr().flush(); + tracing::info!("{}", message); } pub fn log_gpu_device_state(label: &str) { @@ -292,7 +328,7 @@ pub fn log_gpu_device_state(label: &str) { let (cuda_free_bytes, cuda_total_bytes) = get_cuda_mem_info().unwrap_or((0usize, 0usize)); let cuda_used_bytes = cuda_total_bytes.saturating_sub(cuda_free_bytes); let mb = |bytes: usize| bytes as f64 / (1024.0 * 1024.0); - tracing::info!( + let message = format!( "[gpu device][{label}] cuda_used={:.2}MB cuda_free={:.2}MB cuda_total={:.2}MB | pool_used={:.2}MB pool_reserved={:.2}MB pool_booked={:.2}MB pool_max={:.2}MB", mb(cuda_used_bytes), mb(cuda_free_bytes), @@ -302,6 +338,9 @@ pub fn log_gpu_device_state(label: &str) { mb(booked_bytes as usize), mb(max_bytes as usize), ); + eprintln!("{message}"); + let _ = std::io::stderr().flush(); + tracing::info!("{}", message); } use crate::scheme::{constants::NUM_FANIN, septic_curve::SepticPoint}; use gkr_iop::{ @@ -318,7 +357,7 @@ pub fn prove_tower_relation_impl as ProverBackend>::E>, cuda_hal: &Arc, -) -> TowerRelationOutput { +) -> Result, ZKVMError> { let stream = gkr_iop::gpu::get_thread_stream(); if std::any::TypeId::of::() != std::any::TypeId::of::() { panic!("GPU backend only supports Goldilocks base field"); @@ -331,25 +370,14 @@ pub fn prove_tower_relation_impl> = Vec::new(); - let mut _ones_buffer: Vec> = Vec::new(); - let mut _view_last_layers: Vec>>> = Vec::new(); - let (prod_gpu, logup_gpu) = info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { - build_tower_witness_gpu( - composed_cs, - input, - records, - challenges, - cuda_hal, - &mut _big_buffers, - &mut _ones_buffer, - &mut _view_last_layers, - ) - .map_err(|e| format!("build_tower_witness_gpu failed: {}", e)) - .unwrap() - }); + let (prod_gpu, logup_gpu) = + info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { + build_tower_witness_gpu(composed_cs, input, records, challenges, cuda_hal) + .map_err(|e| format!("build_tower_witness_gpu failed: {}", e)) + .map_err(|e| ZKVMError::InvalidWitness(e.into())) + })?; exit_span!(span); // GPU optimization: Extract out_evals from GPU-built towers before consuming them @@ -377,12 +405,17 @@ pub fn prove_tower_relation_impl let wit = LayerWitness( chain!(&input.witness, &input.fixed, &input.structural_witness) .cloned() - .collect_vec(), + .map(|mle| unsafe { std::mem::transmute(mle) }) + .collect(), ); let (proof, points) = gkr_iop::gkr::layer::gpu::prove_rotation_gpu::( @@ -694,7 +728,8 @@ pub fn prove_main_constraints_impl< layers: vec![LayerWitness( chain!(&input.witness, &input.fixed, &input.structural_witness,) .cloned() - .collect_vec(), + .map(|mle| unsafe { std::mem::transmute(mle) }) + .collect(), )], }, &out_evals, @@ -1108,6 +1143,18 @@ where }) .map_err(|e| ceno_gpu::HalError::InvalidInput(format!("{e:?}")))?, }; + let witness_rmm = if witness_rmm.width() == 0 { + tracing::warn!( + "[gpu] replacing zero-width deferred witness trace at index {trace_idx} with a dummy column" + ); + witness::RowMajorMatrix::::new( + witness_rmm.num_instances(), + 1, + InstancePaddingStrategy::Default, + ) + } else { + witness_rmm + }; Ok(unsafe { std::mem::transmute(witness_rmm) }) }) .unwrap(); @@ -1154,6 +1201,18 @@ impl> TraceCommitter> = traces.into_values().collect(); + for (trace_idx, trace) in vec_traces.iter_mut().enumerate() { + if trace.width() == 0 { + tracing::warn!( + "[gpu] replacing zero-width witness trace at index {trace_idx} with a dummy column" + ); + *trace = witness::RowMajorMatrix::::new( + trace.num_instances(), + 1, + InstancePaddingStrategy::Default, + ); + } + } if crate::instructions::gpu::config::should_materialize_witness_on_gpu() { let span = entered_span!("[gpu] normalize_trace_backing", profiling_2 = true); @@ -1272,9 +1331,13 @@ impl> TraceCommitter>> = poly_group @@ -1344,6 +1412,144 @@ where mles } +fn shard_ram_compact_physical_rows(col_idx: usize, num_records: usize, full_rows: usize) -> usize { + // ShardRAM witness columns are laid out as: + // 0..7 x EC coordinates + // 7..14 y EC coordinates + // 14..21 EC addition slopes + // 21..30 scalar record fields + // 30.. Poseidon2 trace columns + // + // Only the scalar record fields and Poseidon2 trace are prefix-populated + // on real record rows. EC columns also carry internal tree rows in the + // upper half, so they must keep the full logical backing. + if col_idx < 21 { full_rows } else { num_records } +} + +pub fn extract_shard_ram_witness_mles_for_trace<'a, E, PCS>( + pcs_data: & as ProverBackend>::PcsData, + trace_idx: usize, + expected_num: usize, + num_vars: usize, + num_records: usize, +) -> Vec>> +where + E: ExtensionField, + PCS: PolynomialCommitmentScheme, +{ + assert_eq!( + std::any::TypeId::of::(), + std::any::TypeId::of::(), + "GPU ShardRAM compact extraction only supports BabyBear base field", + ); + + let pcs_data_basefold: &BasefoldCommitmentWithWitnessGpu< + BB31Base, + BufferImpl, + GpuDigestLayer, + GpuMatrix<'static>, + GpuPolynomial<'static>, + > = unsafe { std::mem::transmute(pcs_data) }; + + let Some(rmms) = pcs_data_basefold.rmms.as_ref() else { + return extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + expected_num, + num_vars, + ); + }; + let rmm = &rmms[trace_idx]; + assert_eq!( + rmm.width(), + expected_num, + "ShardRAM trace width mismatch: expected {}, got {}", + expected_num, + rmm.width(), + ); + + let cuda_hal = get_cuda_hal().unwrap(); + let full_rows = rmm.height(); + assert_eq!( + full_rows, + 1usize << num_vars, + "ShardRAM trace height must match logical num_vars", + ); + assert!( + num_records <= full_rows, + "ShardRAM compact rows exceed full rows: {} > {}", + num_records, + full_rows, + ); + + let mles = if rmm.device_backing_layout() == Some(DeviceMatrixLayout::ColMajor) { + let device_buffer = rmm + .device_backing_ref::>() + .unwrap_or_else(|| panic!("ShardRAM col-major device backing type mismatch")); + let elem_size = std::mem::size_of::(); + let col_stride_bytes = full_rows * elem_size; + (0..expected_num) + .map(|col_idx| { + let physical_rows = + shard_ram_compact_physical_rows(col_idx, num_records, full_rows); + let start = col_idx * col_stride_bytes; + let end = start + physical_rows * elem_size; + let view_buf = device_buffer.owned_subrange(start..end); + let view_poly = GpuPolynomial::new(view_buf, num_vars); + let poly_static: GpuPolynomial<'static> = unsafe { std::mem::transmute(view_poly) }; + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(poly_static); + Arc::new(unsafe { + std::mem::transmute::< + MultilinearExtensionGpu<'static, E>, + MultilinearExtensionGpu<'a, E>, + >(mle_static) + }) + }) + .collect::>() + } else { + let values = rmm.values(); + (0..expected_num) + .map(|col_idx| { + let physical_rows = + shard_ram_compact_physical_rows(col_idx, num_records, full_rows); + let mut column = Vec::with_capacity(physical_rows); + column.extend((0..physical_rows).map(|row| values[row * expected_num + col_idx])); + let column_bb31: Vec = unsafe { + let mut column = std::mem::ManuallyDrop::new(column); + Vec::from_raw_parts( + column.as_mut_ptr() as *mut BB31Base, + column.len(), + column.capacity(), + ) + }; + let gpu_poly = cuda_hal + .alloc_elems_from_host(&column_bb31, None) + .map(|buffer| GpuPolynomial::new(buffer, num_vars)) + .unwrap_or_else(|err| panic!("ShardRAM compact H2D failed: {err:?}")); + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(gpu_poly); + Arc::new(unsafe { + std::mem::transmute::< + MultilinearExtensionGpu<'static, E>, + MultilinearExtensionGpu<'a, E>, + >(mle_static) + }) + }) + .collect::>() + }; + + eprintln!( + "[ceno][shard-ram-compact-mle] trace_idx={} cols={} records={} full_rows={} compact_cols={}", + trace_idx, + expected_num, + num_records, + full_rows, + expected_num.saturating_sub(21), + ); + let _ = std::io::stderr().flush(); + + mles +} + pub fn extract_witness_mles_for_trace_rmm<'a, E>( witness_rmm: witness::RowMajorMatrix<::BaseField>, ) -> Vec>> @@ -1367,7 +1573,8 @@ where let device_buffer = witness_rmm .device_backing_ref::>() .unwrap_or_else(|| panic!("col-major replay witness device backing type mismatch")); - let rows = witness_rmm.height(); + let rows = rmm_col_major_device_rows(&witness_rmm) + .unwrap_or_else(|| panic!("col-major replay witness device backing row count mismatch")); let cols = witness_rmm.width(); let poly_len_bytes = rows * std::mem::size_of::(); @@ -1380,14 +1587,11 @@ where (0..cols) .map(|col_idx| { let src_byte_offset = col_idx * poly_len_bytes; - // Keep an owned handle to the parent GPU allocation instead of a - // borrowed CudaView. The resulting MLE outlives this helper. let view_buf = device_buffer.owned_subrange(src_byte_offset..src_byte_offset + poly_len_bytes); - let view_poly = GpuPolynomial::new(view_buf, rows.trailing_zeros() as usize); - let view_poly_static: GpuPolynomial<'static> = - unsafe { std::mem::transmute(view_poly) }; - let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(view_poly_static); + let view_poly = GpuPolynomial::new(view_buf, witness_rmm.num_vars()); + let poly_static: GpuPolynomial<'static> = unsafe { std::mem::transmute(view_poly) }; + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(poly_static); Arc::new(unsafe { std::mem::transmute::< MultilinearExtensionGpu<'static, E>, @@ -1421,7 +1625,7 @@ pub fn clear_replayable_trace_device_backing( let before_device_bytes = rmms .iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(); for (trace_idx, _) in replayable_traces { @@ -1432,7 +1636,7 @@ pub fn clear_replayable_trace_device_backing( let after_device_bytes = rmms .iter() .filter(|rmm| rmm.has_device_backing()) - .map(|rmm| rmm.height() * rmm.width() * std::mem::size_of::()) + .map(rmm_device_backing_bytes) .sum::(); tracing::info!( "[gpu] cleared replayable PCS RMM device backing: replayable_traces={}, rmms_device_before={:.2}MB ({}) -> after={:.2}MB ({})", @@ -1508,7 +1712,8 @@ where let device_buffer = structural_rmm .device_backing_ref::>() .unwrap_or_else(|| panic!("col-major structural device backing type mismatch")); - let rows = structural_rmm.height(); + let rows = rmm_col_major_device_rows(structural_rmm) + .unwrap_or_else(|| panic!("col-major structural device backing row count mismatch")); let cols = structural_rmm.width(); let poly_len_bytes = rows * std::mem::size_of::(); let total_bytes = cols * poly_len_bytes; @@ -1517,8 +1722,7 @@ where total_bytes, "structural col-major buffer size mismatch" ); - let num_vars_in_poly = rows.trailing_zeros() as usize; - assert_eq!(rows, 1usize << num_vars_in_poly); + let num_vars_in_poly = structural_rmm.num_vars(); (0..cols) .map(|col_idx| { @@ -1559,25 +1763,22 @@ where } #[allow(clippy::too_many_arguments)] -pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( +pub(crate) fn build_tower_witness_gpu( composed_cs: &ComposedConstrainSystem, input: &ProofInput<'_, GpuBackend>>, records: &[ArcMultilinearExtensionGpu<'_, E>], challenges: &[E; 2], cuda_hal: &CudaHalBB31, - big_buffers: &'buf mut Vec>, - ones_buffer: &mut Vec>, - view_last_layers: &mut Vec>>>, ) -> Result< ( - Vec>, - Vec>, + Vec>, + Vec>, ), String, > { let stream = gkr_iop::gpu::get_thread_stream(); use crate::scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}; - use ceno_gpu::{CudaHal as _, bb31::GpuPolynomialExt}; + use ceno_gpu::bb31::GpuPolynomialExt; use p3::field::FieldAlgebra; let ComposedConstrainSystem { @@ -1585,7 +1786,7 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( } = composed_cs; let _num_instances_with_rotation = input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); - let _chip_record_alpha = challenges[0]; + let chip_record_alpha: BB31Ext = unsafe { std::mem::transmute_copy(&challenges[0]) }; // SAFETY: The `records` slice is borrowed for the duration of this function call. // The lifetime is erased to 'static only to satisfy GPU API signatures that require @@ -1616,46 +1817,56 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( &records[offset..][..cs.lk_expressions.len()] }; - assert_eq!(big_buffers.len(), 0, "expect no big buffers"); - - // prod: last layes & buffer - let mut is_prod_buffer_exists = false; + // prod: split last layer once, then build compact tower layers. let prod_last_layers = r_set_wit .iter() .chain(w_set_wit.iter()) - .map(|wit| wit.as_view_chunks(NUM_FANIN)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks(&*cuda_hal, poly, NUM_FANIN, BB31Ext::ONE, stream.as_ref()) + .map_err(|e| format!("Failed to split compact prod tower input: {e}")), + _ => return Err("tower witness expects extension-field record MLEs".to_string()), + }) + .collect::, String>>()?; if !prod_last_layers.is_empty() { let first_layer = &prod_last_layers[0]; assert_eq!(first_layer.len(), 2, "prod last_layer must have 2 MLEs"); - let num_vars = first_layer[0].num_vars(); - let num_towers = prod_last_layers.len(); - view_last_layers.push(prod_last_layers); - - // Allocate one big buffer for all product towers and add it to big_buffers - let tower_size = 1 << (num_vars + 1); // 2 * mle_len elements per tower - let total_buffer_size = num_towers * tower_size; - tracing::debug!( - "prod tower request buffer size: {:.2} MB", - (total_buffer_size * std::mem::size_of::()) as f64 / (1024.0 * 1024.0) - ); - let big_buffer = cuda_hal - .alloc_ext_elems_on_device(total_buffer_size, false, stream.as_ref()) - .map_err(|e| format!("Failed to allocate prod GPU buffer: {:?}", e))?; - big_buffers.push(big_buffer); - is_prod_buffer_exists = true; } - // logup: last layes - let mut is_logup_buffer_exists = false; + // logup: split last layer once, then build compact tower layers. let lk_numerator_last_layer = lk_n_wit .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks( + &*cuda_hal, + poly, + NUM_FANIN_LOGUP, + chip_record_alpha, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to split compact logup numerator: {e}")), + _ => Err("tower witness expects extension-field logup numerator MLEs".to_string()), + }) + .collect::, String>>()?; let lk_denominator_last_layer = lk_d_wit .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); + .map(|wit| match wit.inner() { + gkr_iop::gpu::GpuFieldType::Ext(poly) => cuda_hal + .tower + .masked_mle_view_chunks( + &*cuda_hal, + poly, + NUM_FANIN_LOGUP, + chip_record_alpha, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to split compact logup denominator: {e}")), + _ => Err("tower witness expects extension-field logup denominator MLEs".to_string()), + }) + .collect::, String>>()?; let logup_last_layers = if !lk_numerator_last_layer.is_empty() { // Case when we have both numerator and denominator // Combine [p1, p2] from numerator and [q1, q2] from denominator @@ -1665,100 +1876,58 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( .map(|(lk_n_chunks, lk_d_chunks)| { let mut last_layer = lk_n_chunks; last_layer.extend(lk_d_chunks); - last_layer + Ok(last_layer) }) - .collect::>() + .collect::, String>>()? } else if lk_denominator_last_layer.is_empty() { vec![] } else { - // Case when numerator is empty - create shared ones_buffer and use views - // This saves memory by having all p1, p2 polynomials reference the same buffer + // Case when numerator is empty: share one scalar compact polynomial. + // Its tail default is also ONE, so all logical numerator entries read as ONE + // without materializing per-chunk denominator-sized buffers. let nv = lk_denominator_last_layer[0][0].num_vars(); + let ones_poly = GpuPolynomialExt::new_with_scalar_len( + &cuda_hal.inner, + nv, + 1, + BB31Ext::ONE, + stream.as_ref(), + ) + .map_err(|e| format!("Failed to create compact shared ones numerator: {e:?}"))?; + let ones_poly: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; + let one_len_bytes = ones_poly.buf.len() * std::mem::size_of::(); - // Create one shared ones_buffer as Owned (can be 'static) - let ones_poly = - GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE, stream.as_ref()) - .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e)) - .unwrap(); - // SAFETY: Owned buffer can be safely treated as 'static - let ones_poly_static: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; - ones_buffer.push(ones_poly_static); - - // Get reference from storage to ensure proper lifetime - let ones_poly_ref = ones_buffer.last().unwrap(); - let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); - - // Create views referencing the shared ones_buffer for each tower's p1, p2 lk_denominator_last_layer .into_iter() .map(|lk_d_chunks| { - // Create views of ones_buffer for p1 and p2 - let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); - let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); - // SAFETY: views from 'static buffer can be 'static - let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; - let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; - // Use [p1, p2, q1, q2] format for the last layer + let p1_gpu = GpuPolynomialExt::new_with_tail_default( + ones_poly.buf.owned_subrange(0..one_len_bytes), + nv, + BB31Ext::ONE, + ); + let p2_gpu = GpuPolynomialExt::new_with_tail_default( + ones_poly.buf.owned_subrange(0..one_len_bytes), + nv, + BB31Ext::ONE, + ); let mut last_layer = vec![p1_gpu, p2_gpu]; last_layer.extend(lk_d_chunks); - last_layer + Ok(last_layer) }) - .collect::>() + .collect::, String>>()? }; if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; assert_eq!(first_layer.len(), 4, "logup last_layer must have 4 MLEs"); - let num_vars = first_layer[0].num_vars(); - let num_towers = logup_last_layers.len(); - view_last_layers.push(logup_last_layers); - - // Allocate one big buffer for all towers and add it to big_buffers - let tower_size = 1 << (num_vars + 2); // 4 * mle_len elements per tower - let total_buffer_size = num_towers * tower_size; - tracing::debug!( - "logup tower request buffer size: {:.2} MB", - (total_buffer_size * std::mem::size_of::()) as f64 / (1024.0 * 1024.0) - ); - let big_buffer = cuda_hal - .alloc_ext_elems_on_device(total_buffer_size, false, stream.as_ref()) - .unwrap(); - big_buffers.push(big_buffer); - is_logup_buffer_exists = true; } - let (_, pushed_big_buffers) = big_buffers.split_at_mut(0); - let (prod_big_buffer, logup_big_buffer) = match ( - is_prod_buffer_exists, - is_logup_buffer_exists, - pushed_big_buffers, - ) { - (false, false, []) => (None, None), - (true, false, [prod]) => (Some(prod), None), - (false, true, [logup]) => (None, Some(logup)), - (true, true, [prod, logup]) => (Some(prod), Some(logup)), - (prod_flag, logup_flag, slice) => { - panic!( - "unexpected state: prod={}, logup={}, newly_pushed_len={}", - prod_flag, - logup_flag, - slice.len() - ); - } - }; - // Build product GpuProverSpecs let mut prod_gpu_specs = Vec::new(); - if is_prod_buffer_exists { - let prod_last_layers = &view_last_layers[0]; + if !prod_last_layers.is_empty() { let first_layer = &prod_last_layers[0]; assert_eq!(first_layer.len(), 2, "prod last_layer must have 2 MLEs"); let num_vars = first_layer[0].num_vars(); let num_towers = prod_last_layers.len(); - let Some(prod_big_buffer) = prod_big_buffer else { - panic!("prod big buffer not found"); - }; let span_prod = entered_span!( "build_prod_tower", @@ -1768,31 +1937,31 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( let last_layers_refs: Vec<&[GpuPolynomialExt<'_>]> = prod_last_layers.iter().map(|v| v.as_slice()).collect(); let gpu_specs = { - cuda_hal.tower.build_prod_tower_from_gpu_polys_batch( + cuda_hal.tower.build_prod_tower_dense_from_gpu_polys_batch( cuda_hal, - prod_big_buffer, &last_layers_refs, num_vars, num_towers, stream.as_ref(), ) } - .map_err(|e| format!("build_prod_tower_from_gpu_polys_batch failed: {:?}", e))?; + .map_err(|e| { + format!( + "build_prod_tower_dense_from_gpu_polys_batch failed: {:?}", + e + ) + })?; prod_gpu_specs.extend(gpu_specs); exit_span!(span_prod); } // Build logup GpuProverSpecs let mut logup_gpu_specs = Vec::new(); - if is_logup_buffer_exists { - let logup_last_layers = view_last_layers.last().unwrap(); + if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; assert_eq!(first_layer.len(), 4, "logup last_layer must have 4 MLEs"); let num_vars = first_layer[0].num_vars(); let num_towers = logup_last_layers.len(); - let Some(logup_big_buffer) = logup_big_buffer else { - panic!("logup big buffer not found"); - }; let span_logup = entered_span!( "build_logup_tower", @@ -1803,16 +1972,19 @@ pub(crate) fn build_tower_witness_gpu<'buf, E: ExtensionField>( logup_last_layers.iter().map(|v| v.as_slice()).collect(); let gpu_specs = cuda_hal .tower - .build_logup_tower_from_gpu_polys_batch( + .build_logup_tower_dense_from_gpu_polys_batch( cuda_hal, - logup_big_buffer, &last_layers_refs, num_vars, num_towers, stream.as_ref(), ) - .map_err(|e| format!("build_logup_tower_from_gpu_polys_batch failed: {:?}", e))?; - + .map_err(|e| { + format!( + "build_logup_tower_dense_from_gpu_polys_batch failed: {:?}", + e + ) + })?; logup_gpu_specs.extend(gpu_specs); exit_span!(span_logup); } @@ -1875,10 +2047,19 @@ impl> TowerProver(composed_cs, input); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .map(|layer| layer.name.as_str()), + ); res } @@ -1927,12 +2108,548 @@ impl> MainSumcheckProver(composed_cs, input); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.first()) + .map(|layer| layer.name.as_str()), + ); res } } +impl> + BatchedMainConstraintProver> for GpuProver> +{ + fn prove_batched_main_constraints<'a>( + &self, + mut jobs: Vec>>, + pcs_data: & as ProverBackend>::PcsData, + transcript: &mut impl Transcript, + ) -> Result<(MainConstraintProof, Vec>), ZKVMError> { + struct ChipMainData<'a, E: ExtensionField> { + circuit_idx: usize, + layer: &'a gkr_iop::gkr::layer::Layer, + mle_start: usize, + num_mles: usize, + num_var_with_rotation: usize, + pi: Vec>, + alpha_start: usize, + } + + struct HostCommonGroup { + num_vars: usize, + term_terms: Vec, + common_mle_indices: Vec, + } + + if jobs.is_empty() { + return Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof: IOPProof { proofs: vec![] }, + evals: vec![], + }, + }, + vec![], + )); + } + + let stream = gkr_iop::gpu::get_thread_stream(); + let cuda_hal = get_cuda_hal().map_err(hal_to_backend_error)?; + for job in jobs.iter_mut() { + let num_vars = job.input.log2_num_instances() + job.cs.rotation_vars().unwrap_or(0); + if job.input.witness.is_empty() { + if let Some(trace_idx) = job.witness_trace_idx { + job.input.witness = + info_span!("[ceno] extract_main_witness_mles").in_scope(|| { + if job.circuit_name == "ShardRamCircuit" { + extract_shard_ram_witness_mles_for_trace::( + pcs_data, + trace_idx, + job.num_witin, + num_vars, + job.input.num_instances(), + ) + } else { + extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + job.num_witin, + num_vars, + ) + } + }); + } + } + if job.input.structural_witness.is_empty() { + if let Some(rmm) = job.structural_rmm.as_ref() { + let num_structural_witin = job.cs.zkvm_v1_css.num_structural_witin as usize; + job.input.structural_witness = + info_span!("[ceno] transport_main_structural_witness").in_scope(|| { + transport_structural_witness_to_gpu::( + rmm, + num_structural_witin, + num_vars, + ) + }); + } + } + } + let mut selector_eqs_by_chip = Vec::with_capacity(jobs.len()); + let mut chip_data = Vec::with_capacity(jobs.len()); + let mut total_exprs = 0usize; + let mut total_mles = 0usize; + let mut max_num_variables = 0usize; + + for job in &jobs { + let ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + } = job.cs; + let num_instances = job.input.num_instances(); + let log2_num_instances = job.input.log2_num_instances(); + let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); + max_num_variables = max_num_variables.max(num_var_with_rotation); + + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() + { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext { + offset: 0, + num_instances: job.input.num_instances[0], + num_vars: num_var_with_rotation, + } + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext { + offset: job.input.num_instances[0], + num_instances: job.input.num_instances[1], + num_vars: num_var_with_rotation, + } + } else { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } + }) + .collect_vec(); + + let mut out_evals = + vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + + if let Some(rotation) = job.rotation.as_ref() { + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation proof provided for non-rotation layer") + }; + let (left_evals, right_evals, point_evals) = + split_rotation_evals(&rotation.proof.evals); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation.left_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation.right_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation.point, + ); + } + + if let Some(ecc_proof) = job.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) + .expect("invalid internal ecc bridge claims"); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + let eval_and_dedup_points = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, &job.challenges).point) + }) + .collect_vec(); + let selector_eq_pairs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(selector_ctxs.iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + let eq = gkr_iop::gkr::layer::gpu::utils::build_eq_x_r_with_sel_gpu( + &cuda_hal, + point.as_ref()?, + selector_ctx, + sel_type, + ); + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + Some((*wit_id as usize, eq)) + }) + .collect_vec(); + let mut selector_eq_by_wit_id = vec![None; first_layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } + selector_eqs_by_chip.push(selector_eq_by_wit_id); + + let num_mles = + first_layer.n_witin + first_layer.n_fixed + first_layer.n_structural_witin; + chip_data.push(ChipMainData { + circuit_idx: job.circuit_idx, + layer: first_layer, + mle_start: total_mles, + num_mles, + num_var_with_rotation, + pi: job.input.pi.clone(), + alpha_start: total_exprs, + }); + total_mles += num_mles; + total_exprs += first_layer.exprs.len(); + } + let mut all_witins_gpu = Vec::with_capacity(total_mles); + for ((job, chip), selector_eq_by_wit_id) in jobs + .iter() + .zip(chip_data.iter()) + .zip(selector_eqs_by_chip.iter()) + { + all_witins_gpu.extend(job.input.witness.iter().map(|mle| mle.as_ref())); + all_witins_gpu.extend(job.input.fixed.iter().map(|mle| mle.as_ref())); + for (selector_eq, mle) in selector_eq_by_wit_id + .iter() + .zip(job.input.structural_witness.iter()) + { + if let Some(eq) = selector_eq.as_ref() { + all_witins_gpu.push(eq); + } else { + all_witins_gpu.push(mle.as_ref()); + } + } + assert_eq!( + all_witins_gpu.len(), + chip.mle_start + chip.num_mles, + "invalid gpu main witness layout" + ); + } + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let mut term_coefficients = Vec::new(); + let mut mle_indices_per_term = Vec::new(); + let mut mle_size_info = Vec::new(); + let mut common_groups = Vec::new(); + for chip in &chip_data { + let main_sumcheck_challenges = chain!( + jobs[0].challenges.iter().copied(), + alpha_pows[chip.alpha_start..chip.alpha_start + chip.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + let common_plan = chip.layer.main_sumcheck_expression_common_factored.as_ref(); + let monomial_terms = match ( + common_plan, + chip.layer + .main_sumcheck_expression_monomial_terms_excluded_shared + .as_ref(), + ) { + (Some(_), Some(residual_terms)) => residual_terms, + (Some(_), None) => { + panic!("common factoring plan present without residual monomials") + } + (None, Some(terms)) => terms, + (None, None) => chip + .layer + .main_sumcheck_expression_monomial_terms + .as_ref() + .unwrap(), + }; + let term_start = term_coefficients.len(); + for term in monomial_terms { + let scalar = + eval_by_expr_constant(&chip.pi, &main_sumcheck_challenges, &term.scalar) + .map_either(E::from, |v| v) + .into_inner(); + term_coefficients.push(scalar); + let indices = term + .product + .iter() + .map(|expr| { + let Expression::WitIn(wit_id) = expr else { + panic!("main monomial product must be converted to WitIn") + }; + chip.mle_start + *wit_id as usize + }) + .collect_vec(); + let first_idx = indices.first().copied(); + mle_indices_per_term.push(indices); + if let Some(first_idx) = first_idx { + let num_vars = all_witins_gpu[first_idx].mle.num_vars(); + mle_size_info.push((num_vars, num_vars)); + } else { + mle_size_info.push((0, 0)); + } + } + let mut covered_terms = vec![false; monomial_terms.len()]; + if let Some(common_plan) = common_plan { + for group in &common_plan.groups { + assert!( + !group.term_indices.is_empty(), + "common term group must include at least one term" + ); + let mut group_term_terms = Vec::with_capacity(group.term_indices.len()); + for &term_idx in &group.term_indices { + assert!( + term_idx < monomial_terms.len(), + "common term index {} out of range (terms={})", + term_idx, + monomial_terms.len() + ); + covered_terms[term_idx] = true; + group_term_terms.push( + u32::try_from(term_start + term_idx) + .expect("term index exceeds supported range for GPU plan"), + ); + } + + let mut group_mle_indices = Vec::with_capacity(group.witness_indices.len()); + for &wit_idx in &group.witness_indices { + assert!( + wit_idx < chip.num_mles, + "common witness index {} out of range (mles={})", + wit_idx, + chip.num_mles + ); + group_mle_indices.push( + u32::try_from(chip.mle_start + wit_idx) + .expect("witness index exceeds supported range for GPU plan"), + ); + } + common_groups.push(HostCommonGroup { + num_vars: chip.num_var_with_rotation, + term_terms: group_term_terms, + common_mle_indices: group_mle_indices, + }); + } + } + let mut uncovered_terms = Vec::new(); + for (term_idx, covered) in covered_terms.iter().copied().enumerate() { + if !covered { + uncovered_terms.push( + u32::try_from(term_start + term_idx) + .expect("term index exceeds supported range for GPU plan"), + ); + } + } + if !uncovered_terms.is_empty() { + common_groups.push(HostCommonGroup { + num_vars: chip.num_var_with_rotation, + term_terms: uncovered_terms, + common_mle_indices: Vec::new(), + }); + } + } + + common_groups.sort_by(|lhs, rhs| rhs.num_vars.cmp(&lhs.num_vars)); + + let mut common_term_offsets = Vec::with_capacity(common_groups.len() + 1); + let mut common_term_terms = Vec::new(); + let mut common_mle_offsets = Vec::with_capacity(common_groups.len() + 1); + let mut common_mle_indices = Vec::new(); + common_term_offsets.push(0); + common_mle_offsets.push(0); + for group in &common_groups { + common_term_terms.extend(group.term_terms.iter().copied()); + common_term_offsets.push(common_term_terms.len() as u32); + common_mle_indices.extend(group.common_mle_indices.iter().copied()); + common_mle_offsets.push(common_mle_indices.len() as u32); + } + + let active_counts_by_num_vars = (0..=max_num_variables) + .map(|num_vars| { + common_groups + .iter() + .take_while(|group| group.num_vars >= num_vars) + .count() as u32 + }) + .collect_vec(); + + let max_degree = common_groups + .iter() + .map(|group| { + let common_len = group.common_mle_indices.len(); + let max_residual_len = group + .term_terms + .iter() + .map(|&term_idx| mle_indices_per_term[term_idx as usize].len()) + .max() + .unwrap_or(0); + common_len + max_residual_len + }) + .max() + .unwrap_or(0); + let basic_transcript = expect_basic_transcript(transcript); + let common_scalar_offsets = vec![0u32; common_mle_offsets.len()]; + let common_term_plan = CommonTermPlan { + term_offsets: common_term_offsets, + term_terms: common_term_terms, + common_mle_offsets, + common_mle_indices, + common_scalar_offsets, + common_scalar_indices: vec![], + active_counts_by_num_vars, + }; + let term_coefficients_gl64: Vec = + unsafe { std::mem::transmute(term_coefficients) }; + let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = + unsafe { std::mem::transmute(all_witins_gpu) }; + let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); + let (proof_gpu, evals_gpu, challenges_gpu) = cuda_hal + .sumcheck + .prove_batched_main_sumcheck_gpu_v2( + cuda_hal.as_ref(), + all_witins_gpu_type_gl64, + &mle_size_info, + &term_coefficients_gl64, + &mle_indices_per_term, + max_num_variables, + max_degree, + Some(&common_term_plan), + 1, + basic_transcript, + stream.as_ref(), + ) + .map_err(|e| hal_to_backend_error(format!("GPU main sumcheck failed: {e:?}")))?; + let proof: IOPProof = unsafe { std::mem::transmute(proof_gpu) }; + let evals_gpu_e: Vec> = unsafe { std::mem::transmute(evals_gpu) }; + let global_evals = evals_gpu_e.into_iter().flatten().collect_vec(); + let global_rt: Point = unsafe { + std::mem::transmute::, Vec>( + challenges_gpu.iter().map(|c| c.elements).collect(), + ) + }; + + transcript.append_field_element_exts(&global_evals); + + let mut results = Vec::with_capacity(chip_data.len()); + for chip in &chip_data { + let input_opening_point = + gpu_v2_input_opening_point(&global_rt, chip.num_var_with_rotation); + let chip_evals = &global_evals[chip.mle_start..chip.mle_start + chip.num_mles]; + results.push(MainConstraintResult { + circuit_idx: chip.circuit_idx, + input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: chip_evals[..chip.layer.n_witin].to_vec(), + fixed_in_evals: chip_evals + [chip.layer.n_witin..chip.layer.n_witin + chip.layer.n_fixed] + .to_vec(), + }, + }); + } + + Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof, + evals: global_evals, + }, + }, + results, + )) + } +} + +fn gpu_v2_input_opening_point( + global_rt: &[E], + num_var_with_rotation: usize, +) -> Point { + global_rt[global_rt.len() - num_var_with_rotation..].to_vec() +} + impl> RotationProver> for GpuProver> { @@ -1964,7 +2681,15 @@ impl> EccQuarkProver> OpeningProver> OpeningProver( + prover: &GpuProver>, + witness_data: as ProverBackend>::PcsData, + fixed_data: Option as ProverBackend>::PcsData>>, + replayable_traces: &[(usize, crate::structs::GpuReplayPlan)], + points: Vec>, + mut evals: Vec>>, + transcript: &mut (impl Transcript + 'static), +) -> PCS::Proof +where + E: ExtensionField, + PCS: PolynomialCommitmentScheme, +{ + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear base field"); + } + + let mut rounds = vec![]; + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } + }) + .collect_vec() + })); + if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points.iter().cloned()) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } + }) + .collect_vec() + })); + } + + let prover_param = &prover.backend.pp; + let pp_gl64: &mpcs::basefold::structure::BasefoldProverParams = + unsafe { std::mem::transmute(prover_param) }; + let rounds_gl64: Vec<_> = rounds + .iter() + .map(|(commitment, point_eval_pairs)| { + let commitment_gl64: &BasefoldCommitmentWithWitnessGpu< + BB31Base, + BufferImpl, + GpuDigestLayer, + GpuMatrix<'static>, + GpuPolynomial<'static>, + > = unsafe { std::mem::transmute(*commitment) }; + let point_eval_pairs_gl64: Vec<_> = point_eval_pairs + .iter() + .map(|(point, evals)| { + let point_gl64: &Vec = unsafe { std::mem::transmute(point) }; + let evals_gl64: &Vec = unsafe { std::mem::transmute(evals) }; + (point_gl64.clone(), evals_gl64.clone()) + }) + .collect(); + (commitment_gl64, point_eval_pairs_gl64) + }) + .collect(); + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear field extension"); + } + + let transcript_any = transcript as &mut dyn std::any::Any; + let basic_transcript = transcript_any + .downcast_mut::>() + .expect("Type should match"); + + let cuda_hal = get_cuda_hal().unwrap(); + let gpu_proof_basefold = cuda_hal + .basefold + .batch_open_with_trace_materializer( + &cuda_hal, + pp_gl64, + rounds_gl64, + basic_transcript, + |round_idx, trace_idx| { + if round_idx != 0 { + return Ok(None); + } + let Some((_, replay_plan)) = replayable_traces + .iter() + .find(|(replay_trace_idx, _)| *replay_trace_idx == trace_idx) + else { + return Ok(None); + }; + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "pcs_opening", + round_idx, + trace_idx, + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness()) + .map_err(|err| { + ceno_gpu::HalError::InvalidInput(format!( + "failed to replay trace {trace_idx} for PCS opening: {err:?}" + )) + })?; + if witness_rmm.height() != replay_plan.trace_height { + return Err(ceno_gpu::HalError::InvalidInput(format!( + "replayed trace {trace_idx} height changed before PCS opening: expected {}, got {}", + replay_plan.trace_height, + witness_rmm.height(), + ))); + } + let witness_rmm_bb31: witness::RowMajorMatrix = + unsafe { std::mem::transmute(witness_rmm) }; + Ok(Some(witness_rmm_bb31)) + }, + ) + .unwrap(); + + let gpu_proof: PCS::Proof = unsafe { std::mem::transmute_copy(&gpu_proof_basefold) }; + std::mem::forget(gpu_proof_basefold); + drop(rounds); + drop(witness_data); + gpu_proof +} + impl> DeviceTransporter> for GpuProver> { @@ -2158,21 +3020,28 @@ impl> pcs_data: & as gkr_iop::hal::ProverBackend>::PcsData, ) { if let Some(replay_plan) = task.gpu_replay_plan.as_ref() { - let cuda_hal = get_cuda_hal().unwrap(); - let gpu_mem_tracker = init_gpu_mem_tracker(&cuda_hal, "replay_gpu_witness_from_raw"); let num_vars = task.input.log2_num_instances() + task.pk.get_cs().rotation_vars().unwrap_or(0); - let estimated_replay_bytes = - estimate_replay_materialization_bytes_for_plan(replay_plan, num_vars); - tracing::info!( - "[gpu] replaying witness from raw: circuit={}, estimated={:.2}MB", - task.circuit_name, - estimated_replay_bytes as f64 / (1024.0 * 1024.0), - ); - let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_replay_bytes); - task.input.witness = info_span!("[ceno] replay_gpu_witness_from_raw") - .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)); + if task.num_witin > 0 { + let cuda_hal = get_cuda_hal().unwrap(); + let gpu_mem_tracker = + init_gpu_mem_tracker(&cuda_hal, "replay_gpu_witness_from_raw"); + let estimated_replay_bytes = + estimate_replay_materialization_bytes_for_plan(replay_plan, num_vars); + tracing::info!( + "[gpu] replaying witness from raw: circuit={}, estimated={:.2}MB", + task.circuit_name, + estimated_replay_bytes as f64 / (1024.0 * 1024.0), + ); + let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_replay_bytes, + Some(task.circuit_name.as_str()), + ); + task.input.witness = info_span!("[ceno] replay_gpu_witness_from_raw") + .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)); + }; if let Some(rmm) = task.structural_rmm.as_ref() { task.input.structural_witness = info_span!("[ceno] transport_structural_witness") .in_scope(|| { diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 65fe06f2d..5f1b076df 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -20,6 +20,7 @@ pub trait ProverDevice: TraceCommitter + TowerProver + MainSumcheckProver + + BatchedMainConstraintProver + OpeningProver + DeviceTransporter + ProtocolWitnessGeneratorProver @@ -33,6 +34,21 @@ where fn get_pb(&self) -> &PB; } +pub trait BatchedMainConstraintProver { + fn prove_batched_main_constraints<'a>( + &self, + jobs: Vec>, + pcs_data: &PB::PcsData, + transcript: &mut impl Transcript, + ) -> Result< + ( + crate::scheme::MainConstraintProof, + Vec>, + ), + ZKVMError, + >; +} + /// Prepare a chip task's input for proving. /// CPU: no-op (input already fully populated during task building). /// GPU: deferred witness extraction + structural witness transport. @@ -54,6 +70,19 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub has_ecc_ops: bool, } +impl<'a, PB: ProverBackend> Clone for ProofInput<'a, PB> { + fn clone(&self) -> Self { + Self { + witness: self.witness.clone(), + structural_witness: self.structural_witness.clone(), + fixed: self.fixed.clone(), + pi: self.pi.clone(), + num_instances: self.num_instances, + has_ecc_ops: self.has_ecc_ops, + } + } +} + impl<'a, PB: ProverBackend> ProofInput<'a, PB> { pub fn num_instances(&self) -> usize { self.num_instances.iter().sum() @@ -154,6 +183,26 @@ pub struct MainSumcheckEvals { pub fixed_in_evals: Vec, } +pub struct MainConstraintJob<'a, PB: ProverBackend> { + pub circuit_name: String, + pub circuit_idx: usize, + pub input: ProofInput<'static, PB>, + pub witness_trace_idx: Option, + pub num_witin: usize, + pub structural_rmm: Option::BaseField>>, + pub rt_tower: Point, + pub rotation: Option>, + pub ecc_proof: Option>, + pub challenges: [PB::E; 2], + pub cs: &'a ComposedConstrainSystem, +} + +pub struct MainConstraintResult { + pub circuit_idx: usize, + pub input_opening_point: Point, + pub opening_evals: MainSumcheckEvals, +} + #[derive(Clone)] pub struct RotationProverOutput { pub proof: SumcheckLayerProof, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 516dee741..f72956ed8 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,4 +1,6 @@ use ff_ext::ExtensionField; +#[cfg(feature = "gpu")] +use gkr_iop::error::BackendError; use gkr_iop::{ cpu::{CpuBackend, CpuProver}, hal::ProverBackend, @@ -12,7 +14,7 @@ use crate::scheme::gpu::estimate_chip_proof_memory; #[cfg(feature = "gpu")] use crate::scheme::scheduler::get_chip_proving_mode; use crate::scheme::{ - hal::MainSumcheckEvals, + hal::{MainConstraintJob, MainConstraintResult, MainSumcheckEvals}, scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; #[cfg(feature = "gpu")] @@ -45,7 +47,10 @@ use crate::{ structs::{TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, }; -type CreateTableProof = (ZKVMChipProof, MainSumcheckEvals, Point); +type CreateTableProof<'a, PB> = ( + ZKVMChipProof<::E>, + MainConstraintJob<'a, PB>, +); pub type ZkVMCpuProver = ZKVMProver, CpuProver>>; @@ -220,6 +225,8 @@ impl< let mut structural_rmms = Vec::with_capacity(name_and_instances.len()); #[cfg(feature = "gpu")] let mut gpu_replay_plans = Vec::with_capacity(name_and_instances.len()); + #[cfg(feature = "gpu")] + let mut witness_trace_rows = Vec::with_capacity(name_and_instances.len()); // commit to opcode circuits first and then commit to table circuits, sorted by name for (i, chip_input) in witnesses.into_iter_sorted().enumerate() { let crate::structs::ChipInput { @@ -233,26 +240,37 @@ impl< #[cfg(feature = "gpu")] let use_deferred_gpu_commit = crate::instructions::gpu::config::is_gpu_witgen_enabled() && !crate::instructions::gpu::config::should_retain_witness_device_backing_after_commit(); - #[cfg(feature = "gpu")] - if use_deferred_gpu_commit { - if let Some(plan) = gpu_replay_plan.clone() { - deferred_gpu_traces - .insert(i, crate::scheme::gpu::DeferredGpuTrace::Replay(plan)); - } else if witness_rmm.num_instances() > 0 { - deferred_gpu_traces - .insert(i, crate::scheme::gpu::DeferredGpuTrace::Eager(witness_rmm)); - } - } else if witness_rmm.num_instances() > 0 { - wits_rmms.insert(i, witness_rmm); - } - - #[cfg(not(feature = "gpu"))] - if witness_rmm.num_instances() > 0 { - wits_rmms.insert(i, witness_rmm); - } + let trace_rows_for_estimate = + if !crate::instructions::gpu::config::is_gpu_witgen_enabled() + && witness_rmm.num_instances() > 0 + { + Some(witness_rmm.height()) + } else { + None + }; + + #[cfg(feature = "gpu")] + if use_deferred_gpu_commit { + if let Some(plan) = gpu_replay_plan.clone().filter(|plan| plan.num_witin > 0) { + deferred_gpu_traces + .insert(i, crate::scheme::gpu::DeferredGpuTrace::Replay(plan)); + } else if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + deferred_gpu_traces + .insert(i, crate::scheme::gpu::DeferredGpuTrace::Eager(witness_rmm)); + } + } else if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + wits_rmms.insert(i, witness_rmm); + } + + #[cfg(not(feature = "gpu"))] + if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + wits_rmms.insert(i, witness_rmm); + } structural_rmms.push(structural_witness_rmm); #[cfg(feature = "gpu")] + witness_trace_rows.push(trace_rows_for_estimate); + #[cfg(feature = "gpu")] gpu_replay_plans.push(gpu_replay_plan); } @@ -364,6 +382,8 @@ impl< structural_rmms, #[cfg(feature = "gpu")] gpu_replay_plans, + #[cfg(feature = "gpu")] + witness_trace_rows, witness_mles, &witness_data, fixed_mles, @@ -524,7 +544,7 @@ impl< // Phase 3: Collect results let collect_results_span = entered_span!("collect_chip_results", profiling_1 = true); - let (chip_proofs, points, evaluations) = Self::collect_chip_results(results); + let (chip_proofs, main_constraint_jobs) = Self::collect_chip_results(results); exit_span!(collect_results_span); exit_span!(main_proofs_span); @@ -533,31 +553,96 @@ impl< transcript.append_field_element_ext(&sample); } - // batch opening pcs - // generate static info from prover key for expected num variable - let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); #[cfg(feature = "gpu")] if needs_replay_restore { let replay_cache = current_replay_cache_stats(); tracing::info!( - "[gpu replay cache][before_restore_pcs] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", + "[gpu replay cache][before_restore_main] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0), replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0), replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0), replay_cache.total_bytes() as f64 / (1024.0 * 1024.0), ); - crate::scheme::gpu::log_gpu_device_state("before_restore_pcs"); + crate::scheme::gpu::log_gpu_device_state("before_restore_main"); let gpu_witness_data: &mut as ProverBackend>::PcsData = unsafe { std::mem::transmute(&mut witness_data) }; crate::scheme::gpu::restore_replayable_trace_device_backing::( gpu_witness_data, &replayable_traces, )?; - crate::scheme::gpu::log_gpu_device_state("after_restore_pcs"); + crate::scheme::gpu::log_gpu_device_state("after_restore_main"); + } + + let main_constraints_span = + entered_span!("prove_batched_main_constraints", profiling_1 = true); + let (main_constraint_proof, main_constraint_results) = + info_span!("[ceno] prove_batched_main_constraints").in_scope(|| { + self.device + .prove_batched_main_constraints( + main_constraint_jobs, + &witness_data, + &mut transcript, + ) + })?; + let (points, evaluations) = + Self::collect_main_constraint_results(main_constraint_results); + exit_span!(main_constraints_span); + + #[cfg(feature = "gpu")] + if needs_replay_restore { + crate::scheme::gpu::log_gpu_device_state("before_clear_main_backing"); + let gpu_witness_data: &mut as ProverBackend>::PcsData = + unsafe { std::mem::transmute(&mut witness_data) }; + crate::scheme::gpu::clear_replayable_trace_device_backing::( + gpu_witness_data, + &replayable_traces, + ); + crate::scheme::gpu::log_gpu_device_state("after_clear_main_backing"); + } + + // batch opening pcs + // generate static info from prover key for expected num variable + let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); + #[cfg(feature = "gpu")] + if needs_replay_restore { + let replay_cache = current_replay_cache_stats(); + tracing::info!( + "[gpu replay cache][before_pcs_opening] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", + replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0), + replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0), + replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0), + replay_cache.total_bytes() as f64 / (1024.0 * 1024.0), + ); + crate::scheme::gpu::log_gpu_device_state("before_pcs_opening"); } let mpcs_opening_proof = info_span!("[ceno] pcs_opening").in_scope(|| { #[cfg(feature = "gpu")] { + if needs_replay_restore { + let gpu_device: &gkr_iop::gpu::GpuProver< + gkr_iop::gpu::GpuBackend, + > = unsafe { std::mem::transmute(&self.device) }; + let gpu_witness_data: as ProverBackend>::PcsData = + unsafe { std::mem::transmute_copy(&witness_data) }; + std::mem::forget(witness_data); + let fixed_data = self + .get_device_proving_key(shard_ctx) + .map(|dpk| dpk.pcs_data.clone()); + let gpu_fixed_data: Option< + std::sync::Arc< + as ProverBackend>::PcsData, + >, + > = unsafe { std::mem::transmute(fixed_data) }; + return crate::scheme::gpu::open_with_incremental_replay::( + gpu_device, + gpu_witness_data, + gpu_fixed_data, + &replayable_traces, + points, + evaluations, + &mut transcript, + ); + } } self.device.open( witness_data, @@ -570,7 +655,13 @@ impl< }); exit_span!(pcs_opening); - let vm_proof = ZKVMProof::new(pi, chip_proofs, witin_commit, mpcs_opening_proof); + let vm_proof = ZKVMProof::new( + pi, + chip_proofs, + main_constraint_proof, + witin_commit, + mpcs_opening_proof, + ); Ok(vm_proof) }) @@ -587,13 +678,14 @@ impl< tasks: Vec>, transcript: &T, witness_data: &PB::PcsData, - ) -> Result<(Vec>, Vec), ZKVMError> { + ) -> Result<(Vec>, Vec), ZKVMError> { let scheduler = ChipScheduler::new(); #[cfg(feature = "gpu")] { - if std::any::TypeId::of::() - == std::any::TypeId::of::>() + if false + && std::any::TypeId::of::() + == std::any::TypeId::of::>() { let gpu_witness_data: & as gkr_iop::hal::ProverBackend>::PcsData = unsafe { std::mem::transmute(witness_data) }; @@ -603,6 +695,12 @@ impl< task.circuit_idx as u64, )); + let task_name = task.circuit_name.clone(); + let estimated_memory_bytes = task.estimated_memory_bytes as usize; + let cuda_hal = gkr_iop::gpu::get_cuda_hal().expect("Failed to get CUDA HAL"); + let chip_mem_tracker = + crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "create_chip_proof"); + let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = unsafe { std::mem::transmute(task.input) }; @@ -619,6 +717,11 @@ impl< task.num_witin, task.structural_rmm, )?; + crate::scheme::gpu::check_gpu_scheduler_mem_estimation_with_context( + chip_mem_tracker, + estimated_memory_bytes, + Some(task_name.as_str()), + ); Ok(ChipTaskResult { task_id: task.task_id, @@ -626,11 +729,12 @@ impl< proof, opening_evals, input_opening_point, + main_constraint_job: None, has_witness_or_fixed: task.has_witness_or_fixed, }) }; - if ChipScheduler::is_concurrent_mode() { + if false && ChipScheduler::is_concurrent_mode() { // SAFETY: pcs_data is only read (via get_trace) during concurrent execution. use crate::scheme::utils::SyncRef; let gpu_wd = SyncRef(gpu_witness_data); @@ -676,15 +780,18 @@ impl< )); } - let (proof, opening_evals, input_opening_point) = - self.create_chip_proof(&task, transcript)?; + let (proof, main_constraint_job) = self.create_chip_proof(&mut task, transcript)?; Ok(ChipTaskResult { task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - opening_evals, - input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: vec![], + fixed_in_evals: vec![], + }, + input_opening_point: vec![], + main_constraint_job: Some(main_constraint_job), has_witness_or_fixed: task.has_witness_or_fixed, }) }) @@ -696,11 +803,11 @@ impl< /// into a single tower tree, and then feed these trees into tower prover. #[tracing::instrument(skip_all, name = "create_chip_proof", fields(table_name=%task.circuit_name, profiling_2 ), level = "trace")] - pub fn create_chip_proof( + pub fn create_chip_proof<'a>( &self, - task: &ChipTask<'_, PB>, + task: &mut ChipTask<'a, PB>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { + ) -> Result, ZKVMError> { let circuit_pk = task.pk; let input = &task.input; let challenges = &task.challenges; @@ -786,68 +893,49 @@ impl< })?; exit_span!(span); - // 1. prove the main constraints among witness polynomials - // 2. prove the relation between last layer in the tower and read/write/logup records - let span = entered_span!("prove_main_constraints", profiling_2 = true); #[cfg(feature = "gpu")] - if task.gpu_replay_plan.as_ref().is_some_and(|plan| { - matches!( - plan.kind, - crate::instructions::gpu::dispatch::GpuWitgenKind::Keccak - ) - }) { - crate::scheme::gpu::log_gpu_pool_usage(&format!( - "{}:before_prove_main", - task.circuit_name - )); - } - let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = - info_span!("[ceno] prove_main_constraints").in_scope(|| { - self.device.prove_main_constraints( - rt_tower, - rotation.clone(), - ecc_proof.as_ref(), - input, - cs, - challenges, - transcript, - ) - })?; + let main_input = { + let mut input = input.clone(); + if std::any::TypeId::of::() + == std::any::TypeId::of::>() + { + input.witness.clear(); + input.structural_witness.clear(); + } + input + }; + #[cfg(not(feature = "gpu"))] + let main_input = input.clone(); #[cfg(feature = "gpu")] - if task.gpu_replay_plan.as_ref().is_some_and(|plan| { - matches!( - plan.kind, - crate::instructions::gpu::dispatch::GpuWitgenKind::Keccak - ) - }) { - crate::scheme::gpu::log_gpu_pool_usage(&format!( - "{}:after_prove_main", - task.circuit_name - )); - } - let MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, - } = evals; - exit_span!(span); + let structural_rmm = task.structural_rmm.take(); + #[cfg(not(feature = "gpu"))] + let structural_rmm = None; Ok(( ZKVMChipProof { r_out_evals, w_out_evals, lk_out_evals, - main_sumcheck_proofs, - gkr_iop_proof, - rotation_proof: rotation.map(|r| r.proof), + main_sumcheck_proofs: None, + gkr_iop_proof: None, + rotation_proof: rotation.clone().map(|r| r.proof), tower_proof, - ecc_proof, + ecc_proof: ecc_proof.clone(), num_instances: input.num_instances, }, - MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, + MainConstraintJob { + circuit_name: task.circuit_name.clone(), + circuit_idx: task.circuit_idx, + input: main_input, + witness_trace_idx: task.witness_trace_idx, + num_witin: task.num_witin, + structural_rmm, + rt_tower, + rotation, + ecc_proof, + challenges: *challenges, + cs, }, - input_opening_point, )) } @@ -860,6 +948,7 @@ impl< name_and_instances: Vec<(String, [usize; 2])>, structural_rmms: Vec>, #[cfg(feature = "gpu")] gpu_replay_plans: Vec>>, + #[cfg(feature = "gpu")] witness_trace_rows: Vec>, #[allow(unused_mut)] mut witness_mles: Vec>, witness_data: &PB::PcsData, mut fixed_mles: Vec>>, @@ -988,6 +1077,7 @@ impl< gpu_input, &circuit_name, gpu_replay_plans[this_idx].as_ref(), + witness_trace_rows[this_idx], structural_cached_on_device, ) }; @@ -1041,6 +1131,8 @@ impl< witness_trace_idx, #[cfg(feature = "gpu")] gpu_replay_plan, + #[cfg(feature = "gpu")] + witness_trace_rows: witness_trace_rows[this_idx], num_witin: cs.num_witin(), structural_rmm: task_structural_rmm, }); @@ -1054,16 +1146,14 @@ impl< /// Phase 3: Collect chip proof results into proof components. #[allow(clippy::type_complexity)] - fn collect_chip_results( - results: Vec>, + fn collect_chip_results<'a>( + results: Vec>, ) -> ( BTreeMap>>, - Vec>, - Vec>>, + Vec>, ) { let mut chip_proofs = BTreeMap::new(); - let mut points = Vec::new(); - let mut evaluations = Vec::new(); + let mut main_constraint_jobs = Vec::new(); for result in results { tracing::trace!( @@ -1072,12 +1162,8 @@ impl< result.task_id ); - if result.has_witness_or_fixed { - points.push(result.input_opening_point); - evaluations.push(vec![ - result.opening_evals.wits_in_evals, - result.opening_evals.fixed_in_evals, - ]); + if let Some(job) = result.main_constraint_job { + main_constraint_jobs.push(job); } chip_proofs .entry(result.circuit_idx) @@ -1085,7 +1171,26 @@ impl< .push(result.proof); } - (chip_proofs, points, evaluations) + (chip_proofs, main_constraint_jobs) + } + + fn collect_main_constraint_results( + results: Vec>, + ) -> (Vec>, Vec>>) { + let mut points = Vec::new(); + let mut evaluations = Vec::new(); + for result in results { + if !result.opening_evals.wits_in_evals.is_empty() + || !result.opening_evals.fixed_in_evals.is_empty() + { + points.push(result.input_opening_point); + evaluations.push(vec![ + result.opening_evals.wits_in_evals, + result.opening_evals.fixed_in_evals, + ]); + } + } + (points, evaluations) } } @@ -1107,7 +1212,7 @@ pub fn create_chip_proof_gpu_impl<'a, E, PCS>( #[cfg(feature = "gpu")] gpu_replay_plan: Option>, num_witin: usize, structural_rmm: Option::BaseField>>, -) -> Result, ZKVMError> +) -> Result<(ZKVMChipProof, MainSumcheckEvals, Point), ZKVMError> where E: ExtensionField, PCS: PolynomialCommitmentScheme + 'static, @@ -1117,12 +1222,11 @@ where scheme::{ constants::NUM_FANIN, gpu::{ - build_tower_witness_gpu, check_gpu_mem_estimation, - estimate_replay_materialization_bytes_for_plan, estimate_tower_stage_bytes, - extract_out_evals_from_gpu_towers, extract_witness_mles_for_trace, - log_gpu_device_state, log_gpu_pool_usage, prove_ec_sum_quark_impl, - prove_main_constraints_impl, prove_rotation_impl, prove_tower_relation_impl, - transport_structural_witness_to_gpu, + build_tower_witness_gpu, estimate_replay_materialization_bytes_for_plan, + estimate_tower_stage_bytes, extract_out_evals_from_gpu_towers, + extract_witness_mles_for_trace, log_gpu_device_state, log_gpu_pool_usage, + prove_ec_sum_quark_impl, prove_main_constraints_impl, prove_rotation_impl, + prove_tower_relation_impl, transport_structural_witness_to_gpu, }, }, }; @@ -1134,6 +1238,20 @@ where .get_pool_stream() .expect("should acquire stream"); let _thread_stream_guard = gkr_iop::gpu::bind_thread_stream(_stream.clone()); + let sync_concurrent_chip_stream = || -> Result<(), ZKVMError> { + if ChipScheduler::is_concurrent_mode() { + cuda_hal + .inner + .synchronize_stream(_stream.stream()) + .map_err(|e| { + ZKVMError::BackendError(BackendError::CircuitError( + format!("failed to synchronize GPU chip proof stream for {name}: {e:?}") + .into_boxed_str(), + )) + })?; + } + Ok(()) + }; let replay_stage_split = gpu_replay_plan .as_ref() .is_some_and(|plan| matches!(plan.kind, GpuWitgenKind::Keccak | GpuWitgenKind::ShardRam)); @@ -1164,7 +1282,11 @@ where log_gpu_device_state(&format!("{name}:before_replay")); log_gpu_pool_usage(&format!("{name}:before_replay")); let witness_rmm = replay_plan.replay_witness()?; - check_gpu_mem_estimation(gpu_mem_tracker, estimated_replay_bytes); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_replay_bytes, + Some(name), + ); input.witness = info_span!("[ceno] replay_gpu_witness_from_raw") .in_scope(|| crate::scheme::gpu::extract_witness_mles_for_trace_rmm::(witness_rmm)); if let Some(structural_rmm_cached) = structural_rmm.as_ref() { @@ -1254,32 +1376,22 @@ where let span = entered_span!("prove_tower_relation", profiling_2 = true); let r_set_len = cs.zkvm_v1_css.r_expressions.len() + cs.zkvm_v1_css.r_table_expressions.len(); - let (tower_build_estimated_bytes, tower_prove_estimated_bytes) = + let (tower_build_estimated_bytes, tower_prove_prebuild_estimated_bytes) = estimate_tower_stage_bytes::(cs, &input); tracing::info!( "[gpu tower][{}] estimated: build_tower={:.2}MB, prove_tower={:.2}MB", name, tower_build_estimated_bytes as f64 / (1024.0 * 1024.0), - tower_prove_estimated_bytes as f64 / (1024.0 * 1024.0), + tower_prove_prebuild_estimated_bytes as f64 / (1024.0 * 1024.0), ); let tower_build_mem_tracker = crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "build_tower_witness_gpu"); - let mut big_buffers = Vec::new(); - let mut ones_buffer = Vec::new(); - let mut view_last_layers = Vec::new(); log_gpu_device_state(&format!("{name}:before_build_tower_witness")); log_gpu_pool_usage(&format!("{name}:before_build_tower_witness")); let (prod_gpu, logup_gpu, lk_out_evals, w_out_evals, r_out_evals) = info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { let (prod_gpu, logup_gpu) = build_tower_witness_gpu( - cs, - &input, - &records, - challenges, - &cuda_hal, - &mut big_buffers, - &mut ones_buffer, - &mut view_last_layers, + cs, &input, &records, challenges, &cuda_hal, ) .map_err(|e| { ZKVMError::InvalidWitness(format!("build_tower_witness_gpu failed: {e}").into()) @@ -1288,7 +1400,11 @@ where extract_out_evals_from_gpu_towers(&prod_gpu, &logup_gpu, r_set_len); Ok::<_, ZKVMError>((prod_gpu, logup_gpu, lk_out_evals, w_out_evals, r_out_evals)) })?; - check_gpu_mem_estimation(tower_build_mem_tracker, tower_build_estimated_bytes); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + tower_build_mem_tracker, + tower_build_estimated_bytes, + Some(name), + ); log_gpu_device_state(&format!("{name}:after_build_tower_witness")); log_gpu_pool_usage(&format!("{name}:after_build_tower_witness")); @@ -1308,6 +1424,27 @@ where prod_specs: prod_gpu, logup_specs: logup_gpu, }; + let tower_prove_estimate = cuda_hal + .tower + .estimate_memory_requirements(&tower_input, NUM_FANIN); + let tower_input_live_bytes = tower_prove_estimate.prod_tower_buffer_bytes + + tower_prove_estimate.logup_tower_buffer_bytes; + let runtime_layout_prove_bytes = tower_prove_estimate + .total_bytes + .saturating_sub(tower_input_live_bytes); + let release_adjusted_prebuild_bytes = + tower_prove_prebuild_estimated_bytes / NUM_FANIN + 4 * 1024 * 1024; + let tower_prove_estimated_bytes = + runtime_layout_prove_bytes.max(release_adjusted_prebuild_bytes); + tracing::info!( + "[gpu tower][{}] refined prove_tower estimate: prebuild={:.2}MB, runtime_layout={:.2}MB, release_adjusted={:.2}MB, local={:.2}MB, tower_live={:.2}MB", + name, + tower_prove_prebuild_estimated_bytes as f64 / (1024.0 * 1024.0), + runtime_layout_prove_bytes as f64 / (1024.0 * 1024.0), + release_adjusted_prebuild_bytes as f64 / (1024.0 * 1024.0), + tower_prove_estimated_bytes as f64 / (1024.0 * 1024.0), + tower_input_live_bytes as f64 / (1024.0 * 1024.0), + ); let tower_prove_mem_tracker = crate::scheme::gpu::init_gpu_mem_tracker(&cuda_hal, "prove_tower_relation_gpu"); log_gpu_device_state(&format!("{name}:before_prove_tower")); @@ -1318,23 +1455,28 @@ where .tower .create_proof( &cuda_hal, - &tower_input, + tower_input, NUM_FANIN, basic_tr, gkr_iop::gpu::get_thread_stream().as_ref(), ) - .expect("gpu tower create_proof failed") - }); + .map_err(|e| { + ZKVMError::BackendError(BackendError::CircuitError( + format!("gpu tower create_proof failed for {name}: {e:?}") + .into_boxed_str(), + )) + }) + })?; log_gpu_device_state(&format!("{name}:after_prove_tower")); log_gpu_pool_usage(&format!("{name}:after_prove_tower")); let rt_tower: Point = unsafe { std::mem::transmute(rt_tower_gl) }; let tower_proof: TowerProofs = unsafe { std::mem::transmute(tower_proof_gpu) }; - check_gpu_mem_estimation(tower_prove_mem_tracker, tower_prove_estimated_bytes); + crate::scheme::gpu::check_gpu_tower_prove_mem_estimation_with_context( + tower_prove_mem_tracker, + tower_prove_estimated_bytes, + Some(name), + ); drop(records); - drop(tower_input); - drop(big_buffers); - drop(ones_buffer); - drop(view_last_layers); log_gpu_device_state(&format!("{name}:after_drop_tower")); exit_span!(span); @@ -1370,6 +1512,7 @@ where wits_in_evals, fixed_in_evals, } = evals; + sync_concurrent_chip_stream()?; clear_materialized_input(&mut input); log_gpu_device_state(&format!("{name}:after_main_constraints")); exit_span!(span); @@ -1419,7 +1562,7 @@ where prove_tower_relation_impl::( cs, &input, &records, challenges, transcript, &cuda_hal, ) - }); + })?; exit_span!(span); drop(records); @@ -1454,6 +1597,7 @@ where wits_in_evals, fixed_in_evals, } = evals; + sync_concurrent_chip_stream()?; exit_span!(span); Ok(( diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index e792b6fd4..c6cd999cb 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -16,7 +16,7 @@ use crate::{ error::ZKVMError, scheme::{ ZKVMChipProof, - hal::{MainSumcheckEvals, ProofInput}, + hal::{MainConstraintJob, MainSumcheckEvals, ProofInput}, }, structs::ProvingKey, }; @@ -90,6 +90,9 @@ pub struct ChipTask<'a, PB: ProverBackend> { /// Replay witness directly from shard-resident raw GPU data when available. #[cfg(feature = "gpu")] pub gpu_replay_plan: Option>, + /// Actual witness trace rows used for cache-none extraction estimates. + #[cfg(feature = "gpu")] + pub witness_trace_rows: Option, /// Expected number of witness polynomials for this circuit pub num_witin: usize, /// CPU-side structural witness RowMajorMatrix, transported to GPU on-demand @@ -97,32 +100,34 @@ pub struct ChipTask<'a, PB: ProverBackend> { } /// Result from a completed chip proof task -pub struct ChipTaskResult { +pub struct ChipTaskResult<'a, PB: ProverBackend> { /// Task ID for ordering pub task_id: usize, /// Circuit index for proof collection pub circuit_idx: usize, /// The generated proof - pub proof: ZKVMChipProof, + pub proof: ZKVMChipProof, /// Prover-only opening evaluations split by witness/fixed/pi domains. - pub opening_evals: MainSumcheckEvals, + pub opening_evals: MainSumcheckEvals, /// Opening point for this proof - pub input_opening_point: Point, + pub input_opening_point: Point, + /// Deferred main-constraint proving job. + pub main_constraint_job: Option>, /// Whether this circuit has witness or fixed polynomials pub has_witness_or_fixed: bool, } /// Message sent from worker to scheduler on task completion #[cfg(feature = "gpu")] -struct CompletionMessage { +struct CompletionMessage<'a, PB: ProverBackend> { /// The result of the proof - result: Result, ZKVMError>, + result: Result, ZKVMError>, /// Memory that was reserved for this task (to release) memory_reserved: u64, /// Task ID for ordering task_id: usize, /// Sampled value from the forked transcript (for gather phase) - forked_sample: E, + forked_sample: PB::E, } /// Memory-aware parallel chip proof scheduler @@ -149,12 +154,12 @@ impl ChipScheduler { tasks: Vec>, transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, { #[cfg(feature = "gpu")] { @@ -185,12 +190,12 @@ impl ChipScheduler { tasks: Vec>, parent_transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError>, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError>, { if tasks.is_empty() { return Ok((vec![], vec![])); @@ -250,12 +255,12 @@ impl ChipScheduler { mut tasks: Vec>, transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, { if tasks.is_empty() { return Ok((vec![], vec![])); @@ -308,15 +313,15 @@ impl ChipScheduler { // Worker -> Scheduler: CompletionMessage (includes sampled value) let (task_tx, task_rx) = mpsc::channel::>(); let task_rx = Arc::new(Mutex::new(task_rx)); - let (done_tx, done_rx) = mpsc::channel::>(); + let (done_tx, done_rx) = mpsc::channel::>(); // 3. State tracking let mut tasks_inflight = 0usize; - let mut results: Vec> = Vec::with_capacity(total_tasks); + let mut results: Vec> = Vec::with_capacity(total_tasks); let mut samples: Vec<(usize, PB::E)> = Vec::with_capacity(total_tasks); // Helper to handle a completion message - let mut handle_completion = |msg: CompletionMessage, + let mut handle_completion = |msg: CompletionMessage<'a, PB>, mem_pool: &ceno_gpu::common::mem_pool::CudaMemPool, tasks_inflight: &mut usize, label: &str| diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 568d548b6..329cd9835 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -39,10 +39,7 @@ use super::{ utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::{ - e2e::ShardContext, scheme::constants::NUM_FANIN, structs::PointAndEval, - tables::DynamicRangeTableCircuit, -}; +use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -174,7 +171,6 @@ fn test_rw_lk_expression_combination() { zkvm_fixed_traces, ) .unwrap(); - let vk = pk.get_vk_slow(); // generate mock witness let num_instances = 1 << 8; @@ -248,7 +244,7 @@ fn test_rw_lk_expression_combination() { num_instances: [num_instances, 0], has_ecc_ops: false, }; - let task = crate::scheme::scheduler::ChipTask { + let mut task = crate::scheme::scheduler::ChipTask { task_id: 0, circuit_name: name.clone(), circuit_idx: 0, @@ -264,12 +260,10 @@ fn test_rw_lk_expression_combination() { num_witin: 0, structural_rmm: None, }; - let (proof, _, _) = prover - .create_chip_proof(&task, &mut transcript) + let (_proof, _main_job) = prover + .create_chip_proof(&mut task, &mut transcript) .expect("create_proof failed"); - // verify proof - let verifier = ZKVMVerifier::new(vk.clone()); let mut v_transcript = BasicTranscript::new(b"test"); // write commitment into transcript and derive challenges from it Pcs::write_commitment(&witin_commit, &mut v_transcript).unwrap(); @@ -283,18 +277,6 @@ fn test_rw_lk_expression_combination() { { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - let _ = verifier - .verify_chip_proof( - name.as_str(), - verifier.vk.circuit_vks.get(&name).unwrap(), - &proof, - &PublicValues::default(), - &mut v_transcript, - NUM_FANIN, - &PointAndEval::default(), - &verifier_challenges, - ) - .expect("verifier failed"); #[cfg(debug_assertions)] { println!( diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index ead260f7d..4921d7f8c 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -680,7 +680,10 @@ pub fn build_main_witness< .iter() .chain(&input.structural_witness) .chain(&input.fixed) - .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) + .all(|v| { + v.num_vars() == num_var_with_rotation + && v.evaluations_len() <= (1 << num_var_with_rotation) + }) ); // GPU memory estimation @@ -704,9 +707,28 @@ pub fn build_main_witness< // GPU memory check: validate estimation against actual usage #[cfg(feature = "gpu")] { + let input_layer_has_only_structural_inputs = composed_cs + .gkr_circuit + .as_ref() + .and_then(|circuit| circuit.layers.last()) + .is_some_and(|input_layer| input_layer.in_eval_expr.is_empty()); + let output_rows = if input_layer_has_only_structural_inputs { + input + .structural_witness + .first() + .map(|mle| mle.evaluations_len()) + } else { + None + } + .or_else(|| input.witness.first().map(|mle| mle.evaluations_len())) + .unwrap_or_else(|| input.num_instances() << composed_cs.rotation_vars().unwrap_or(0)); let estimated_bytes = - crate::scheme::gpu::estimate_main_witness_bytes(composed_cs, num_var_with_rotation); - crate::scheme::gpu::check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + crate::scheme::gpu::estimate_main_witness_bytes(composed_cs, output_rows); + crate::scheme::gpu::check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_bytes, + gkr_circuit.layers.first().map(|layer| layer.name.as_str()), + ); } gkr_circuit_out.0.0 diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 028c2d551..661ebc094 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -8,7 +8,7 @@ use std::{ #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{PublicValues, ZKVMChipProof, ZKVMProof}; +use super::{MainConstraintProof, PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ @@ -18,7 +18,10 @@ use crate::{ scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, septic_curve::{SepticExtension, SepticPoint}, - utils::{assign_group_evals, derive_ecc_bridge_claims}, + utils::{ + GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, + first_layer_output_group_stage_masks, + }, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -29,19 +32,28 @@ use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; use gkr_iop::{ self, selector::{SelectorContext, SelectorType}, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + }, }; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ Expression, StructuralWitIn, - StructuralWitInType::StackedConstantSequence, + StructuralWitInType::{ + Empty, EqualDistanceDynamicSequence, EqualDistanceSequence, + InnerRepeatingIncrementalSequence, OuterRepeatingIncrementalSequence, + StackedConstantSequence, StackedIncrementalSequence, + }, mle::IntoMLE, util::ceil_log2, + utils::eval_by_expr_with_instance, virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; -use p3::field::FieldAlgebra; +use p3::field::{FieldAlgebra, dot_product}; use sumcheck::{ - structs::{IOPProof, IOPVerifierState}, + structs::{IOPProof, IOPVerifierState, SumCheckSubClaim}, util::get_challenge_pows, }; use transcript::{ForkableTranscript, Transcript}; @@ -59,6 +71,83 @@ pub struct ZKVMVerifier< pub vk: ZKVMVerifyingKey, } +pub(crate) struct PendingMainConstraintVerification<'a, E: ExtensionField> { + circuit_name: &'a str, + circuit_vk: &'a VerifyingKey, + proof: &'a ZKVMChipProof, + num_var_with_rotation: usize, + out_evals: Vec>, + pi: Vec, + selector_ctxs: Vec, +} + +fn validate_batched_main_structural_evals( + circuit_name: &str, + layer: &gkr_iop::gkr::layer::Layer, + eval_and_dedup_points: &[(Vec, Option>)], + selector_ctxs: &[SelectorContext], + pi: &[E], + layer_evals: &[E], + structural_witin_offset: usize, + in_point: &Point, +) -> Result<(), String> { + for (((sel_type, _), (_, out_point)), selector_ctx) in layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(selector_ctxs.iter()) + { + if let Some((expected_eval, wit_id)) = + sel_type.evaluate(out_point.as_ref().unwrap(), in_point, selector_ctx) + { + let wit_id = wit_id as usize + structural_witin_offset; + if layer_evals[wit_id] != expected_eval { + return Err(format!("{circuit_name} selector structural witin mismatch")); + } + } + } + + for StructuralWitIn { id, witin_type } in &layer.structural_witins { + let wit_id = *id as usize + structural_witin_offset; + let expected_eval = match witin_type { + EqualDistanceSequence { + offset, + multi_factor, + descending, + .. + } => eval_wellform_address_vec( + *offset as u64, + *multi_factor as u64, + in_point, + *descending, + ), + EqualDistanceDynamicSequence { + offset_instance_id, + multi_factor, + descending, + .. + } => { + let offset = pi[*offset_instance_id as usize].to_canonical_u64(); + eval_wellform_address_vec(offset, *multi_factor as u64, in_point, *descending) + } + StackedIncrementalSequence { .. } => eval_stacked_wellform_address_vec(in_point), + StackedConstantSequence { .. } => eval_stacked_constant_vec(in_point), + InnerRepeatingIncrementalSequence { k, .. } => { + eval_inner_repeated_incremental_vec(*k as u64, in_point) + } + OuterRepeatingIncrementalSequence { k, .. } => { + eval_outer_repeated_incremental_vec(*k as u64, in_point) + } + Empty => continue, + }; + if expected_eval != layer_evals[wit_id] { + return Err(format!("{circuit_name} structural witin mismatch")); + } + } + + Ok(()) +} + fn bind_active_tower_eval_round( transcript: &mut impl Transcript, tower_proofs: &TowerProofs, @@ -160,39 +249,6 @@ impl> Ok((next_heap_addr_end, next_hint_addr_end)) } - #[allow(clippy::type_complexity)] - fn split_input_opening_evals( - circuit_vk: &VerifyingKey, - proof: &ZKVMChipProof, - ) -> Result<(Vec, Vec), ZKVMError> { - let cs = circuit_vk.get_cs(); - let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { - return Err(ZKVMError::InvalidProof("missing gkr proof".into())); - }; - let Some(last_layer) = gkr_proof.0.last() else { - return Err(ZKVMError::InvalidProof("empty gkr proof layers".into())); - }; - - let evals = &last_layer.main.evals; - let wit_len = cs.num_witin(); - let fixed_len = cs.num_fixed(); - let min_len = wit_len + fixed_len; - if evals.len() < min_len { - return Err(ZKVMError::InvalidProof( - format!( - "insufficient main evals: {} < required {}", - evals.len(), - min_len - ) - .into(), - )); - } - - let wits_in_evals = evals[..wit_len].to_vec(); - let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); - Ok((wits_in_evals, fixed_in_evals)) - } - /// Verify a full zkVM trace from program entry to halt. /// /// This is the production verifier API. It treats a single proof as a @@ -340,6 +396,13 @@ impl> vm_proof: ZKVMProof, mut transcript: impl ForkableTranscript, ) -> Result, ZKVMError> { + tracing::info!( + "verifying shard proof: expected_shard_id={}, proof_shard_id={}, chip_groups={}", + shard_id, + vm_proof.public_values.shard_id, + vm_proof.chip_proofs.len() + ); + // main invariant between opcode circuits and table circuits let mut prod_r = E::ONE; let mut prod_w = E::ONE; @@ -461,6 +524,7 @@ impl> } // fork transcript to support chip concurrently proved + let mut pending_main_constraints = Vec::with_capacity(num_proofs); let mut forked_transcripts = transcript.fork(num_proofs); for ((index, proof), transcript) in vm_proof .chip_proofs @@ -548,6 +612,14 @@ impl> .into(), )); }; + if q1 == E::ZERO || q2 == E::ZERO { + return Err(ZKVMError::InvalidProof( + format!( + "{shard_id}th shard {circuit_name} has zero logup denominator in lk_out_evals: {evals:?}" + ) + .into(), + )); + } Ok(p1 * q1.inverse() + p2 * q2.inverse()) }) .sum::>()?; @@ -570,29 +642,17 @@ impl> // accumulate logup_sum logup_sum += chip_logup_sum; - let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = self - .verify_chip_proof( - circuit_name, - circuit_vk, - proof, - &vm_proof.public_values, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; - if circuit_vk.get_cs().num_witin() > 0 { - witin_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), wits_in_evals), - )); - } - if circuit_vk.get_cs().num_fixed() > 0 { - fixed_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), fixed_in_evals), - )); - } + let (pending_main_constraint, chip_shard_ec_sum) = self.verify_chip_proof_pre_main( + circuit_name, + circuit_vk, + proof, + &vm_proof.public_values, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; + pending_main_constraints.push(pending_main_constraint); prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); tracing::debug!( @@ -622,6 +682,28 @@ impl> transcript.append_field_element_ext(&sample); } + for (input_opening_point, wits_in_evals, fixed_in_evals) in self + .verify_batched_main_constraints( + pending_main_constraints, + &vm_proof.main_constraint_proof, + &mut transcript, + &challenges, + )? + { + if !wits_in_evals.is_empty() { + witin_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), wits_in_evals), + )); + } + if !fixed_in_evals.is_empty() { + fixed_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), fixed_in_evals), + )); + } + } + // verify mpcs let mut rounds = vec![(vm_proof.witin_commit.clone(), witin_openings)]; @@ -661,17 +743,23 @@ impl> /// verify proof and return input opening point #[allow(clippy::too_many_arguments, clippy::type_complexity)] - pub fn verify_chip_proof( + pub(crate) fn verify_chip_proof_pre_main<'a>( &self, - _name: &str, - circuit_vk: &VerifyingKey, - proof: &ZKVMChipProof, + _name: &'a str, + circuit_vk: &'a VerifyingKey, + proof: &'a ZKVMChipProof, public_values: &PublicValues, transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result<(Point, Option>, Vec, Vec), ZKVMError> { + ) -> Result< + ( + PendingMainConstraintVerification<'a, E>, + Option>, + ), + ZKVMError, + > { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -826,11 +914,13 @@ impl> let first_layer = gkr_circuit.layers.first().ok_or_else(|| { ZKVMError::InvalidProof(format!("{_name} empty gkr circuit layers").into()) })?; + let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, gkr_circuit); let selector_ctxs = first_layer .out_sel_and_eval_exprs .iter() - .map(|(selector, _)| { - if cs.ec_final_sum.is_empty() { + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() { SelectorContext::new(0, num_instances, num_var_with_rotation) } else if cs.r_selector.as_ref() == Some(selector) { SelectorContext::new(0, proof.num_instances[0], num_var_with_rotation) @@ -906,76 +996,272 @@ impl> ); } - if let Some(ecc_proof) = proof.ecc_proof.as_ref() { - let Some( - [ - x_group_idx, - y_group_idx, - slope_group_idx, - x3_group_idx, - y3_group_idx, - ], - ) = first_layer.ecc_bridge_group_indices() - else { + let pi = cs + .instance + .iter() + .map(|instance| E::from(public_values.query_by_index::(instance.0))) + .collect_vec(); + Ok(( + PendingMainConstraintVerification { + circuit_name: _name, + circuit_vk, + proof, + num_var_with_rotation, + out_evals, + pi, + selector_ctxs, + }, + shard_ec_sum, + )) + } + + fn verify_batched_main_constraints( + &self, + pending_main_constraints: Vec>, + main_constraint_proof: &MainConstraintProof, + transcript: &mut impl Transcript, + challenges: &[E; 2], + ) -> Result, Vec, Vec)>, ZKVMError> { + if pending_main_constraints.is_empty() { + if !main_constraint_proof.proof.proof.proofs.is_empty() + || !main_constraint_proof.proof.evals.is_empty() + { return Err(ZKVMError::InvalidProof( - "ecc bridge claims expected but selectors are missing".into(), + "empty main constraints with non-empty proof".into(), )); - }; + } + return Ok(vec![]); + } - let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; - let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation)?; + struct PendingLayer<'a, E: ExtensionField> { + pending: PendingMainConstraintVerification<'a, E>, + layer: &'a gkr_iop::gkr::layer::Layer, + eval_and_dedup_points: Vec<(Vec, Option>)>, + eval_start: usize, + eval_len: usize, + alpha_start: usize, + } - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[x_group_idx].1, - &claims.x_evals, - &claims.xy_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[y_group_idx].1, - &claims.y_evals, - &claims.xy_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, - &claims.s_evals, - &claims.s_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, - &claims.x3_evals, - &claims.x3y3_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, - &claims.y3_evals, - &claims.x3y3_point, - ); + let mut layers = Vec::with_capacity(pending_main_constraints.len()); + let mut total_exprs = 0usize; + let mut total_evals = 0usize; + let mut max_num_variables = 0usize; + let mut max_degree = 0usize; + + for pending in pending_main_constraints { + let gkr_circuit = pending + .circuit_vk + .get_cs() + .gkr_circuit + .as_ref() + .ok_or_else(|| { + ZKVMError::InvalidProof( + format!("{} missing gkr circuit in vk", pending.circuit_name).into(), + ) + })?; + let layer = gkr_circuit.layers.first().ok_or_else(|| { + ZKVMError::InvalidProof( + format!("{} empty gkr circuit layers", pending.circuit_name).into(), + ) + })?; + max_num_variables = max_num_variables.max(pending.num_var_with_rotation); + max_degree = max_degree.max(layer.max_expr_degree + 1); + + let mut out_evals = pending.out_evals.clone(); + if let Some(ecc_proof) = pending.proof.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = layer.ecc_bridge_group_indices() + else { + return Err(ZKVMError::InvalidProof( + "ecc bridge claims expected but selectors are missing".into(), + )); + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = + derive_ecc_bridge_claims(ecc_proof, sample_r, pending.num_var_with_rotation)?; + + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + out_evals.resize(gkr_circuit.n_evaluations, PointAndEval::default()); + let eval_and_dedup_points = layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + let evals = out_eval_exprs + .iter() + .map(|out_eval| out_eval.evaluate(&out_evals, challenges).eval) + .collect_vec(); + let point = out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, challenges).point); + (evals, point) + }) + .collect_vec(); + + let eval_len = layer.n_witin + layer.n_fixed + layer.n_structural_witin; + layers.push(PendingLayer { + pending, + layer, + eval_and_dedup_points, + eval_start: total_evals, + eval_len, + alpha_start: total_exprs, + }); + total_evals += eval_len; + total_exprs += layer.exprs.len(); } - let pi = cs - .instance + let main_evals = &main_constraint_proof.proof.evals; + if main_evals.len() != total_evals { + return Err(ZKVMError::InvalidProof( + format!( + "main constraint eval length mismatch: {} != {}", + main_evals.len(), + total_evals + ) + .into(), + )); + } + + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let sigma = layers .iter() - .map(|instance| E::from(public_values.query_by_index::(instance.0))) - .collect_vec(); - let (wits_in_evals, fixed_in_evals) = Self::split_input_opening_evals(circuit_vk, proof)?; - let gkr_iop_proof = proof.gkr_iop_proof.clone().ok_or_else(|| { - ZKVMError::InvalidProof(format!("{_name} missing gkr iop proof").into()) - })?; - let (_, rt) = gkr_circuit.verify( - num_var_with_rotation, - gkr_iop_proof, - &out_evals, - &pi, - challenges, + .map(|pending_layer| { + let alpha = &alpha_pows[pending_layer.alpha_start + ..pending_layer.alpha_start + pending_layer.layer.exprs.len()]; + let local_sigma: E = dot_product( + alpha.iter().copied(), + pending_layer + .eval_and_dedup_points + .iter() + .flat_map(|(sigmas, _)| sigmas) + .copied(), + ); + E::from_canonical_u64( + 1u64 << (max_num_variables - pending_layer.pending.num_var_with_rotation), + ) * local_sigma + }) + .sum::(); + + let SumCheckSubClaim { + point: global_in_point, + expected_evaluation, + } = IOPVerifierState::verify( + sigma, + &main_constraint_proof.proof.proof, + &VPAuxInfo { + max_degree, + max_num_variables, + phantom: PhantomData, + }, transcript, - &selector_ctxs, - )?; - Ok((rt, shard_ec_sum, wits_in_evals, fixed_in_evals)) + ); + let global_in_point = global_in_point + .into_iter() + .map(|challenge| challenge.elements) + .collect_vec(); + transcript.append_field_element_exts(main_evals); + + let mut got_claim = E::ZERO; + let mut results = Vec::with_capacity(layers.len()); + for pending_layer in &layers { + let in_point = global_in_point + [global_in_point.len() - pending_layer.pending.num_var_with_rotation..] + .to_vec(); + let layer_evals = &main_evals + [pending_layer.eval_start..pending_layer.eval_start + pending_layer.eval_len]; + let structural_witin_offset = pending_layer.layer.n_witin + pending_layer.layer.n_fixed; + + validate_batched_main_structural_evals( + pending_layer.pending.circuit_name, + pending_layer.layer, + &pending_layer.eval_and_dedup_points, + &pending_layer.pending.selector_ctxs, + &pending_layer.pending.pi, + layer_evals, + structural_witin_offset, + &in_point, + ) + .map_err(|err| ZKVMError::InvalidProof(err.into()))?; + + let main_sumcheck_challenges = chain!( + challenges.iter().copied(), + alpha_pows[pending_layer.alpha_start + ..pending_layer.alpha_start + pending_layer.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + got_claim += eval_by_expr_with_instance( + &[], + layer_evals, + &[], + &pending_layer.pending.pi, + &main_sumcheck_challenges, + pending_layer + .layer + .main_sumcheck_expression + .as_ref() + .unwrap(), + ) + .map_either(E::from, |v| v) + .into_inner(); + + results.push(( + in_point, + layer_evals[..pending_layer.layer.n_witin].to_vec(), + layer_evals[pending_layer.layer.n_witin + ..pending_layer.layer.n_witin + pending_layer.layer.n_fixed] + .to_vec(), + )); + } + + if got_claim != expected_evaluation { + return Err(ZKVMError::InvalidProof( + format!("main constraint claim mismatch: {expected_evaluation} != {got_claim}") + .into(), + )); + } + + Ok(results) } } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 1521f1f83..f28a4a29b 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -739,9 +739,9 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, - hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, + hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, }, - structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, + structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, }; #[cfg(feature = "gpu")] @@ -873,7 +873,6 @@ mod tests { let pd = create_prover(backend); let zkvm_pk = ZKVMProvingKey::new(pp, vp); - let zkvm_vk = zkvm_pk.get_vk_slow(); let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); @@ -919,7 +918,7 @@ mod tests { }; let mut rng = thread_rng(); let challenges = [E::random(&mut rng), E::random(&mut rng)]; - let task = crate::scheme::scheduler::ChipTask { + let mut task = crate::scheme::scheduler::ChipTask { task_id: 0, circuit_name: ShardRamCircuit::::name(), circuit_idx: 0, @@ -935,24 +934,8 @@ mod tests { num_witin: 0, structural_rmm: None, }; - let (proof, _, point) = zkvm_prover - .create_chip_proof(&task, &mut transcript) + let (_proof, _main_job) = zkvm_prover + .create_chip_proof(&mut task, &mut transcript) .unwrap(); - - let mut transcript = BasicTranscript::new(b"global chip test"); - let verifier = ZKVMVerifier::new(zkvm_vk); - let (vrf_point, _, _, _) = verifier - .verify_chip_proof( - "global", - &pk.vk, - &proof, - &public_value, - &mut transcript, - 2, - &PointAndEval::default(), - &challenges, - ) - .expect("verify global chip proof"); - assert_eq!(vrf_point, point); } } diff --git a/gkr_iop/src/gkr/layer/gpu/utils.rs b/gkr_iop/src/gkr/layer/gpu/utils.rs index e67153cc2..68a283588 100644 --- a/gkr_iop/src/gkr/layer/gpu/utils.rs +++ b/gkr_iop/src/gkr/layer/gpu/utils.rs @@ -115,6 +115,7 @@ pub fn encode_common_term_plan(plan: &CommonFactoredTermPlan, total_mles: usize) common_mle_indices, common_scalar_offsets, common_scalar_indices, + active_counts_by_num_vars: vec![], } } @@ -251,8 +252,9 @@ pub fn build_rotation_mles_gpu panic!("should be base field"), _ => panic!("unimplemented input mle"), }; + let logical_len = 1usize << input_mle.mle.num_vars(); let mut output_buf = cuda_hal - .alloc_elems_on_device(input_buf.len(), false, stream.as_ref()) + .alloc_elems_on_device(logical_len, false, stream.as_ref()) .unwrap(); // Safety: GPU buffers are actually 'static lifetime. We only read from input_buf @@ -294,8 +296,8 @@ pub fn build_rotation_selector_gpu MultilinearExtensionGpu<'static, E> { let stream = crate::gpu::get_thread_stream(); - let total_len = wit[0].evaluations_len(); // Take first mle just to retrieve total length - assert!(total_len.is_power_of_two()); + let num_vars = wit[0].num_vars(); + let total_len = 1usize << num_vars; let mut output_buf = cuda_hal .alloc_ext_elems_on_device(total_len, false, stream.as_ref()) .unwrap(); @@ -322,10 +324,8 @@ pub fn build_rotation_selector_gpu, diff --git a/gkr_iop/src/gpu/mod.rs b/gkr_iop/src/gpu/mod.rs index 54ae3744a..62d19ca85 100644 --- a/gkr_iop/src/gpu/mod.rs +++ b/gkr_iop/src/gpu/mod.rs @@ -222,7 +222,7 @@ impl<'a, E: ExtensionField> MultilinearExtensionGpu<'a, E> { let cpu_evaluations = poly.to_cpu_vec(stream.as_ref()); let cpu_evaluations_base: Vec = unsafe { std::mem::transmute(cpu_evaluations) }; - MultilinearExtension::from_evaluations_vec( + MultilinearExtension::from_evaluations_vec_compact( self.mle.num_vars(), cpu_evaluations_base, ) @@ -230,7 +230,7 @@ impl<'a, E: ExtensionField> MultilinearExtensionGpu<'a, E> { GpuFieldType::Ext(poly) => { let cpu_evaluations = poly.to_cpu_vec(stream.as_ref()); let cpu_evaluations_ext: Vec = unsafe { std::mem::transmute(cpu_evaluations) }; - MultilinearExtension::from_evaluations_ext_vec( + MultilinearExtension::from_evaluations_ext_vec_compact( self.mle.num_vars(), cpu_evaluations_ext, ) @@ -506,13 +506,23 @@ impl> let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = unsafe { std::mem::transmute(all_witins_gpu) }; let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); + // Match the CPU witness inference path: layer outputs are materialized over + // the occupied prefix of the layer witness domain, not the maximum length of + // any referenced structural/fixed MLE. + let output_len = all_witins_gpu_gl64 + .first() + .map(|mle| mle.evaluations_len()) + .unwrap_or(0); + let output_lengths = + std::iter::repeat_n(output_len, mle_indices_per_term.len()).collect_vec(); // buffer for output witness from gpu let cuda_hal = get_cuda_hal().unwrap(); - let mut next_witness_buf = (0..num_non_zero_expr) - .map(|_| { + let mut next_witness_buf = output_lengths + .iter() + .map(|&output_len| { cuda_hal - .alloc_ext_elems_on_device(1 << num_vars, false, stream.as_ref()) + .alloc_ext_elems_on_device(output_len, false, stream.as_ref()) .map_err(|e| format!("Failed to allocate prod GPU buffer: {:?}", e)) }) .collect::, _>>() diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index e1c8d7453..f133ae126 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -5,7 +5,6 @@ use itertools::Itertools; use multilinear_extensions::{ Fixed, WitIn, WitnessId, mle::{ArcMultilinearExtension, MultilinearExtension}, - util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; use p3::field::FieldAlgebra; @@ -49,7 +48,7 @@ pub fn rotation_next_base_mle<'a, E: ExtensionField>( rotate_chunk[to] = original_chunk[from]; } }); - MultilinearExtension::from_evaluation_vec_smart(mle.num_vars(), rotated_mle_evals) + MultilinearExtension::from_evaluation_vec_smart_compact(mle.num_vars(), rotated_mle_evals) } pub fn rotation_selector<'a, E: ExtensionField>( @@ -59,7 +58,6 @@ pub fn rotation_selector<'a, E: ExtensionField>( cyclic_group_log2_size: usize, total_len: usize, ) -> MultilinearExtension<'a, E> { - assert!(total_len.is_power_of_two()); let cyclic_group_size = 1 << cyclic_group_log2_size; assert!(cyclic_subgroup_size <= cyclic_group_size); let rotation_index = bh.into_iter().take(cyclic_subgroup_size).collect_vec(); @@ -74,7 +72,10 @@ pub fn rotation_selector<'a, E: ExtensionField>( rotate_chunk[to] = eq_chunk[to]; } }); - MultilinearExtension::from_evaluation_vec_smart(ceil_log2(total_len), rotated_mle_evals) + MultilinearExtension::from_evaluation_vec_smart_compact( + eq.len().ilog2() as usize, + rotated_mle_evals, + ) } /// sel(rx) diff --git a/summary.md b/summary.md new file mode 100644 index 000000000..0dd66ccd6 --- /dev/null +++ b/summary.md @@ -0,0 +1,329 @@ +# GPU Batched Main Sumcheck Optimization Handoff + +## Current Goal + +Optimize the new `prove_batched_main_constraints` path against the previous per-chip main sumcheck baseline. The expected model is that batching should reduce sumcheck kernel invocations and improve GPU utilization, but the current batched path introduced extra overhead from heterogeneous MLE sizes and non-direct layout handling. + +## Root-Cause Findings + +- Batched main uses one global `max_num_variables`, while chips have heterogeneous `num_vars`. +- The V2 fold kernel previously launched as `num_unique_mles * stride` each round, even when many MLEs were already inactive for that `current_num_vars`. +- This caused substantial overlaunch in large remote runs. Earlier shard-0 diagnostics showed global batched fold work could be about `39x` the per-chip ideal thread work for the fold phase. +- The batched main path also used non-direct MLE layout (`original_layout_flag = 0`) while the older wrapper path uses direct/original layout semantics. +- The remaining larger bottleneck is likely expression-side work/common factoring: the batched main path still passes `None` for common-term factoring and uses full monomial terms. + +## Implemented Changes + +### ceno + +- File: `ceno_zkvm/src/scheme/gpu/mod.rs` +- Change: `prove_batched_main_constraints` now calls `prove_generic_sumcheck_gpu_v2` with direct/original MLE layout flag `1`. +- Reason: avoid unnecessary layout conversion/indirection for batched main MLE inputs. + +### ceno-gpu + +- File: `cuda_hal/src/common/sumcheck/generic_v2.rs` +- Change: added per-round active-MLE worklist generation based on `host_num_vars_curr == current_num_vars`. +- Change: V2 fold launch size is now `active_mle_indices.len() * stride`, not `num_unique_mles * stride`. +- Change: both host-challenge and GPU-transcript challenge fold paths use the active worklist. + +- File: `cpp/common/sumcheck/generic_v2.cuh` +- Change: common V2 fold device routine now receives `active_mle_indices` and `num_active_mles`, maps active index to actual MLE index, and bounds on active work only. + +- Files: + - `cpp/bb31/kernels/sumcheck_generic_v2.cu` + - `cpp/gl64/kernels/sumcheck_generic_v2.cu` +- Change: BB31 and GL64 kernel wrappers forward active-MLE worklist arguments. + +## Validation Completed + +Commands that passed: + +```bash +cargo fmt --manifest-path ../ceno-gpu/cuda_hal/Cargo.toml --all +cargo fmt +cargo check -p ceno_zkvm --features gpu --config 'patch."https://github.com/scroll-tech/ceno-gpu-mock.git".cuda_hal.path="../ceno-gpu/cuda_hal"' +``` + +Downstream local sanity passed through `../ceno-reth-benchmark` using local `ceno` and `ceno-gpu` path patches, WITGEN=1, shard 0, block `23587691`, `--chain-id 1`, and the temporary local max-cell validation knob. Benchmark validation edits were restored afterward. + +Latest sanity log: + +```text +../ceno-reth-benchmark/sanity_23587691_shard0_witgen1_direct_active_fold_20260501_140101.log +``` + +Latest metrics: + +```text +../ceno-reth-benchmark/metrics_23587691_shard0_witgen1_direct_active_fold_20260501_140101.json +``` + +Success markers: + +```text +verifying shard proof: expected_shard_id=0, proof_shard_id=0, chip_groups=61 +single shard segment verified without full-trace continuation checks +``` + +## Local Timing Comparison + +Compared against previous local sanity log: + +```text +../ceno-reth-benchmark/sanity_23587691_shard0_witgen1_batch_alloc_clamp_20260501_133037.log +``` + +| Span | Previous | Latest | Delta | +| --- | ---: | ---: | ---: | +| `reth-block` | `76.2s` | `75.0s` | `-1.2s` | +| `app.prove` | `75.7s` | `74.5s` | `-1.2s` | +| `app_prove.inner` | `74.8s` | `73.6s` | `-1.2s` | +| `create_proof_of_shard` | `71.5s` | `70.4s` | `-1.1s` | +| `commit_traces` | `12.2s` | `12.0s` | `-0.2s` | +| `prove_batched_main_constraints` | `14.7s` | `14.6s` | `-0.1s` | +| `pcs_opening` | `2.59s` | `2.51s` | `-0.08s` | +| `app.verify` | `312ms` | `292ms` | `-20ms` | + +The local payload is small, so fold overlaunch reduction is not expected to dominate locally. The remote payload should show a clearer benefit because heterogeneous large-MLE fold overlaunch scales with global `max_num_variables`. + +## Current Git State Notes + +- `ceno` intended commit includes: + - `ceno_zkvm/src/scheme/gpu/mod.rs` + - `summary.md` +- `ceno-gpu` intended commit includes: + - `cuda_hal/src/common/sumcheck/generic_v2.rs` + - `cpp/common/sumcheck/generic_v2.cuh` + - `cpp/bb31/kernels/sumcheck_generic_v2.cu` + - `cpp/gl64/kernels/sumcheck_generic_v2.cu` +- Pre-existing unrelated local change in `ceno_zkvm/src/structs.rs` was intentionally not included. +- `../ceno-reth-benchmark/Cargo.lock` and `crates/host-bench/src/lib.rs` were restored after sanity validation. + +## Recommended Next Steps + +1. Run or inspect the remote benchmark for this commit pair and compare `prove_batched_main_constraints` against the latest feature run. +2. If batched main is still slower than per-chip baseline, prioritize expression common-term factoring for the batched path. +3. Consider bucketing by `num_var_with_rotation` only if one global sumcheck cannot recover enough performance with active folding and common factoring. +4. Keep V2 active-worklist optimization unless another V2 caller shows regression; it should reduce wasted fold work for any heterogeneous MLE set. + +## Local Baseline Vs Current Harness + +Purpose: keep a smaller local comparison target for future optimization before waiting for remote CI. Both runs use block `23587691`, `--shard-id 0`, WITGEN=1, `--chain-id 1`, local max-cell validation knob `(1 << 30) * 6 / 4 / 2`, and shared target dir `/home/wusm/rust/ceno-reth-benchmark/target`. + +### Pinned Versions + +Baseline: + +- `ceno-reth-benchmark`: `65a757522a7e` +- `ceno`: `9936d96ed51f` +- `ceno-gpu`: `911741992d4a` +- Worktrees: + - `/home/wusm/rust/ceno-reth-benchmark-baseline-65a757` + - `/home/wusm/rust/ceno-baseline-9936d96` + - `/home/wusm/rust/ceno-gpu-baseline-911741` + +Current: + +- `ceno-reth-benchmark`: `5da5ed70f2dd` +- `ceno`: `be5a6f57cdc7` +- `ceno-gpu`: `f884a4728b43` +- Worktrees: + - `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed` + - `/home/wusm/rust/ceno` + - `/home/wusm/rust/ceno-gpu` + +The isolated benchmark worktrees link these existing inputs: + +- `block_data -> ../ceno-reth-benchmark/block_data` +- `app_proof.bitcode -> ../ceno-reth-benchmark/app_proof.bitcode` +- `bin/ceno-client-eth/target -> ../../../ceno-reth-benchmark/bin/ceno-client-eth/target` + +### Latest Local Logs + +Baseline: + +```text +/home/wusm/rust/ceno-reth-benchmark-baseline-65a757/sanity_23587691_shard0_witgen1_baseline_9936d96_911741_20260501_143419.log +``` + +Current: + +```text +/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_current_be5a6f57_f884a472_20260501_143703.log +``` + +Both passed single-shard verification: + +```text +single shard segment verified without full-trace continuation checks +``` + +### Local Comparison + +| Span | Baseline | Current | Delta | +| --- | ---: | ---: | ---: | +| `reth-block` | `70.7s` | `74.8s` | `+4.1s` | +| `app.prove` | `70.2s` | `74.3s` | `+4.1s` | +| `app_prove.inner` | `69.4s` | `73.4s` | `+4.0s` | +| `create_proof_of_shard` | `66.2s` | `70.2s` | `+4.0s` | +| `commit_traces` | `11.7s` | `11.7s` | `0.0s` | +| main sumcheck | `10.788s` total per-chip `prove_main_constraints` | `14.5s` `prove_batched_main_constraints` | `+3.712s` | +| `pcs_opening` | `2.40s` | `2.47s` | `+0.07s` | +| `app.verify` | `304ms` | `294ms` | `-10ms` | + +Interpretation: + +- The smaller local block still shows current batched main slower than the original per-chip baseline by about `4.1s` E2E. +- The gap is concentrated in main sumcheck: per-chip baseline sums to `10.788s`, current batched main is `14.5s`. +- `commit_traces` is unchanged, so this local harness is useful for validating main-sumcheck optimization direction. +- This is not contradictory with the latest remote improvement: the active-MLE fold optimization improved the feature branch versus the previous feature branch, but current batched main has not yet beaten the original per-chip baseline. +- Next optimization should target batched common-term factoring/expression evaluation overhead, not PCS or trace commit. + +### Local Harness Notes + +- The first baseline attempt failed due to missing ceno-gpu submodule; fixed with `git -C /home/wusm/rust/ceno-gpu-baseline-911741 submodule update --init --recursive`. +- A later attempt hit disk exhaustion; disposable `/home/wusm/rust/ceno/target` and partial baseline worktree `target` were removed to free space. +- Shared target reuse between benchmark SHAs can reuse incompatible path-crate metadata. If current compile fails with stale `openvm-client-executor` symbols, clean affected packages with: + +```bash +cd /home/wusm/rust/ceno-reth-benchmark-current-5da5ed +CARGO_TARGET_DIR=/home/wusm/rust/ceno-reth-benchmark/target cargo clean -p openvm-client-executor -p openvm-reth-benchmark -p ceno-reth-benchmark-bin +``` + +- After validation, `Cargo.lock` and `crates/host-bench/src/lib.rs` were restored in both isolated benchmark worktrees. + +### 2026-05-01 Batched Main Sumcheck Optimization Notes + +- `prove_batched_main_constraints` now builds one global `CommonTermPlan` and keeps the verifier-facing shape as one global sumcheck/proof/transcript. +- The best validated default local result remains the common-factored grouped path: + - log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_bucketed_default_off_20260501_152804.log` + - metrics: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/metrics_23587691_shard0_witgen1_bucketed_default_off_20260501_152804.json` + - `reth-block`: `76.0s` + - `prove_batched_main_constraints`: `14.4s` +- A prototype compact-domain eval-bucket path was tried and removed. + - It preserved one global sumcheck by accumulating bucket round-polynomial evals on device before transcript squeeze. + - Validation passed, but performance regressed: + - log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_bucketed_eval_20260501_152343.log` + - `reth-block`: `75.5s` + - `prove_batched_main_constraints`: `15.1s` + - The environment-gated implementation was deleted rather than kept as a disabled path. +- A group-domain gate in the common-factor CUDA evaluator was tried and removed. + - It precomputed each common group domain and skipped inactive duplicated lanes before loading common factors. + - Validation passed, but performance did not improve: + - log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_group_domain_gate_20260501_162629.log` + - `reth-block`: `76.1s` + - `prove_batched_main_constraints`: `14.6s` + - The CUDA changes were reverted rather than kept because this does not beat the current common-factored path or the original per-chip baseline. +- A host-side domain-bucketed global sumcheck prototype was tried and removed. + - It kept one global proof/transcript and skipped smaller-domain buckets until their active rounds, but it initialized/ran separate V2 prover states per bucket and accumulated round evals on CPU. + - Validation passed, but performance regressed: + - log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_bucketed_global_skip_20260501_170025.log` + - metrics: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/metrics_23587691_shard0_witgen1_bucketed_global_skip_20260501_170025.json` + - `reth-block`: `74.8s` + - `prove_batched_main_constraints`: `15.0s` + - The CUDA bucket API and Ceno bucket wiring were removed. This confirms the next direction should not be multiple host-managed prover states; any bucket/worklist optimization must be one fused CUDA shape with persistent metadata and one global eval buffer. +- Next viable optimization should not split proofs. It should fuse work inside the one global sumcheck: + - precompute/upload persistent bucket/worklist metadata instead of per-round small H2D allocations; + - move toward one CUDA worklist kernel mapping `global_tid -> (work_item, local_tid)`; + - accumulate each work item into the same global round eval vector before the single transcript squeeze; + - use thresholds from measured bucket work so late/small rounds stay on the existing grouped common-term kernel. + +### Proposed Next Direction: Staggered-Domain Global Sumcheck + +Protocol idea: + +- Treat each chip/domain bucket as its own lower-dimensional zero-claim sumcheck, but keep one global transcript and one aggregated round polynomial. +- For global rounds where `current_num_vars > bucket_num_vars`, the bucket is not yet active and contributes nothing to the running claim. +- When `current_num_vars == bucket_num_vars`, activate that bucket by adding its initial claim to the verifier/prover running claim. +- For main constraints, each activated bucket initial claim is zero, so activation is protocol-cheap and does not need extra proof data unless future nonzero buckets are introduced. +- After activation, the bucket participates in all subsequent rounds and is folded with the same global transcript challenge. + +Correctness distinction: + +- Skipping arbitrary individual lower-domain monomial terms is invalid because each term may have a nonzero constant contribution and cancellation only holds across the complete zero-claim expression. +- Skipping a complete lower-domain zero-claim bucket before activation is valid because that sub-sumcheck has not been introduced into the global running claim yet. + +Implementation target: + +- Do not revive the removed host-side bucket prototype. +- Build one fused CUDA prover state with persistent domain-bucket metadata: + - buckets sorted by `bucket_num_vars`; + - per-bucket MLE ranges, term ranges, common-term ranges; + - activation offsets by global round/domain; + - one global round eval buffer and one GPU transcript. +- Per round: + - activate newly eligible zero-claim buckets; + - eval kernel processes only active buckets/work items; + - inactive smaller buckets are not traversed; + - fold kernel folds active MLEs, reusing the existing active-MLE worklist idea. +- Verifier/proof shape should remain one global main proof if zero activations are implicit. If nonzero activations are later needed, the proof/verifier must include or derive those claims at activation rounds. + +Expected benefit: + +- Avoids global-round eval traversal over lower-domain chips before their active domain starts. +- Attacks the real bottleneck: heterogeneous-domain expression evaluation, not fold. +- Keeps the long-term CUDA shape: one transcript, one eval buffer, persistent worklist metadata, no per-round host orchestration. + +## Restore Prompt For Next Session + +Use this prompt in the next Codex session: + +```text +We are continuing GPU batched main sumcheck optimization. Please read `/home/wusm/rust/ceno/summary.md` first. Workspaces are `/home/wusm/rust/ceno`, `/home/wusm/rust/ceno-gpu`, and `/home/wusm/rust/ceno-reth-benchmark`. The latest committed state includes direct MLE layout for `prove_batched_main_constraints` in ceno (`be5a6f57`) and active-MLE V2 fold worklist optimization in ceno-gpu (`f884a472`). Do not touch the unrelated local `ceno_zkvm/src/structs.rs` change unless explicitly requested. Use the `reth-gpu-sanity` skill for downstream validation, never print `$CENO_RPC`, and restore benchmark validation-only edits after running sanity. For quick local performance checks, use the baseline/current harness documented in `summary.md`: baseline worktrees are `/home/wusm/rust/ceno-reth-benchmark-baseline-65a757`, `/home/wusm/rust/ceno-baseline-9936d96`, `/home/wusm/rust/ceno-gpu-baseline-911741`; current worktree is `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed` with local `/home/wusm/rust/ceno` and `/home/wusm/rust/ceno-gpu`. Latest local result: current is still about `+4.1s` slower E2E than original per-chip baseline on block `23587691` shard 0, with main sumcheck `14.5s` versus per-chip sum `10.788s`. Important failed paths already removed: `CENO_GPU_SUMCHECK_BUCKET_EVAL`, group-domain gate, and host-side bucketed global skip. Next focus: design and implement the staggered-domain global sumcheck plan from `summary.md`: activate complete zero-claim chip/domain buckets only when `current_num_vars == bucket_num_vars`, keep one global transcript/proof, and implement it as one fused CUDA prover state with persistent bucket/worklist metadata, not multiple host-managed prover states. +``` + +### 2026-05-03 GPU-Feature Correction And Kernel Results + +Important correction: several older local sanity numbers in this file were produced without enabling the benchmark `gpu` feature. Those runs can still pass verification, but they do not exercise `ceno_zkvm/src/scheme/gpu` or local `ceno-gpu` changes. Treat the older `70s/14s` local comparison as stale for GPU-prover performance. Current GPU sanity must use `--features 'jemalloc,gpu'` and confirm `CUDA Backend Enabled` appears in the log. + +Committed experiment state: + +- `ceno`: `27ed8650` (`Experiment staggered batched main sumcheck prover`) +- `ceno-gpu`: `14649f95` (`Add staggered-domain sumcheck V2 entry`) + +Current code shape: + +- `prove_batched_main_constraints` builds one global `CommonTermPlan` for the batched main path. +- Common groups are sorted by descending `num_var_with_rotation`. +- `active_counts_by_num_vars[current_num_vars]` limits the eval kernel to the active common-group prefix, so smaller-domain inactive groups are not traversed. +- CUDA V2 eval has a `staggered_domain` flag; the batch-main entry currently passes `true`. +- Kernel selection remains the default/auto policy: use `sumcheck_round_perterm` only when `active_num_terms >= 256` and `poly_len <= MIN_LEN_FOR_WARP_REDUCTION`; otherwise use `sumcheck_round`. +- The prover still emits one global sumcheck proof/transcript with `max_num_vars` rounds. +- Verifier is not yet updated for staggered activation, so the GPU-feature staggered run is expected to fail with a main-claim mismatch. + +Correct GPU-feature baseline/current comparison on block `23587691`, shard 0, WITGEN=1: + +- Baseline log: `/home/wusm/rust/ceno-reth-benchmark-baseline-65a757/sanity_23587691_shard0_witgen1_gpu_feature_baseline_20260503_144417.log` +- Current staggered log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_gpu_feature_stagger_20260503_143633.log` +- Baseline per-chip `prove_main_constraints` sum: `0.746s` across `61` chips. +- Current `prove_batched_main_constraints`: `1.14s`. +- Main-sumcheck gap: current is `+0.394s`, about `+53%` slower. +- Baseline `reth-block`: `8.96s`, verified passed. +- Current `reth-block`: `9.11s`, verifier failed as expected with `main constraint claim mismatch`. + +Common factoring validation: + +- Log: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_common_factored_no_active_prefix_20260503_151018.log` +- Mode: common-group factoring enabled, but staggered active-prefix disabled for this check. +- Result: verified passed. +- `prove_batched_main_constraints`: `5.71s`. +- Conclusion: common-group factoring shape is algebraically correct. The verifier failure comes from staggered active-prefix/protocol mismatch, not common factoring itself. + +Kernel selection experiment on current staggered prover shape: + +- Auto/default: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_gpu_feature_stagger_20260503_143633.log` + - `prove_batched_main_constraints`: `1.14s` +- Forced normal kernel: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_kernel_normal_20260503_151308.log` + - `prove_batched_main_constraints`: `1.35s` + - Slower; do not force normal. +- Forced per-term for all small rounds: `/home/wusm/rust/ceno-reth-benchmark-current-5da5ed/sanity_23587691_shard0_witgen1_kernel_perterm_20260503_151514.log` + - `prove_batched_main_constraints`: `1.13s` + - Slightly faster locally, but within noise. Keep auto/default unless a larger payload confirms this wins. + +Next direction: + +1. Implement verifier running-claim activation by `current_num_vars` for the staggered single-global-sumcheck protocol. +2. Preserve the current auto kernel selector by default. +3. After verifier support, remeasure on the larger payload and then consider a threshold tweak only if forced per-term consistently wins. +4. The remaining `+53%` gap is not from scanning inactive smaller-domain groups; active prefix already avoids that. Focus next on matching per-chip prover work for active groups: degree sharing overhead, grouped-kernel register/cache behavior, metadata indirection, and common-factor granularity.