diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 0813449e1..a4328ada1 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -845,6 +845,7 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device Result<_, ZKVMError> { let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; + let gpu_cols = hal + .alloc_u32_from_host(&col_offsets, Some(&stream)) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; let (mut cur_x, mut cur_y) = hal .witgen - .extract_ec_points_from_device(device_records, num_records, n, None) + .extract_ec_points_from_device(device_records, num_records, n, Some(&stream)) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU extract_ec_points failed: {e}").into()) })?; @@ -889,6 +892,7 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device = match source { DeferredGpuTrace::Eager(rmm) => rmm, - DeferredGpuTrace::Replay(plan) => plan - .replay_witness() - .map(|witness_rmm| { - assert_eq!( - witness_rmm.height(), - plan.trace_height, - "replayed trace height changed between plan build and deferred commit", - ); - witness_rmm - }) - .map_err(|e| ceno_gpu::HalError::InvalidInput(format!("{e:?}")))?, + DeferredGpuTrace::Replay(plan) => info_span!( + "[ceno] replay_witness_materialize", + phase = "commit_traces", + trace_idx, + kind = ?plan.kind, + rows = plan.trace_height, + num_witin = plan.num_witin, + steps = plan.step_indices.len(), + ) + .in_scope(|| plan.replay_witness()) + .map(|witness_rmm| { + assert_eq!( + witness_rmm.height(), + plan.trace_height, + "replayed trace height changed between plan build and deferred commit", + ); + witness_rmm + }) + .map_err(|e| ceno_gpu::HalError::InvalidInput(format!("{e:?}")))?, }; Ok(unsafe { std::mem::transmute(witness_rmm) }) }) @@ -1501,7 +1509,16 @@ where }; for (trace_idx, replay_plan) in replayable_traces { - let witness_rmm = replay_plan.replay_witness()?; + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "restore_pcs_backing", + trace_idx = *trace_idx, + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness())?; assert_eq!( witness_rmm.height(), replay_plan.trace_height, @@ -2090,6 +2107,143 @@ impl> OpeningProver( + prover: &GpuProver>, + witness_data: as ProverBackend>::PcsData, + fixed_data: Option as ProverBackend>::PcsData>>, + replayable_traces: &[(usize, crate::structs::GpuReplayPlan)], + points: Vec>, + mut evals: Vec>>, + transcript: &mut (impl Transcript + 'static), +) -> PCS::Proof +where + E: ExtensionField, + PCS: PolynomialCommitmentScheme, +{ + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear base field"); + } + + let mut rounds = vec![]; + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } + }) + .collect_vec() + })); + if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points.iter().cloned()) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } + }) + .collect_vec() + })); + } + + let prover_param = &prover.backend.pp; + let pp_gl64: &mpcs::basefold::structure::BasefoldProverParams = + unsafe { std::mem::transmute(prover_param) }; + let rounds_gl64: Vec<_> = rounds + .iter() + .map(|(commitment, point_eval_pairs)| { + let commitment_gl64: &BasefoldCommitmentWithWitnessGpu< + BB31Base, + BufferImpl, + GpuDigestLayer, + GpuMatrix<'static>, + GpuPolynomial<'static>, + > = unsafe { std::mem::transmute(*commitment) }; + let point_eval_pairs_gl64: Vec<_> = point_eval_pairs + .iter() + .map(|(point, evals)| { + let point_gl64: &Vec = unsafe { std::mem::transmute(point) }; + let evals_gl64: &Vec = unsafe { std::mem::transmute(evals) }; + (point_gl64.clone(), evals_gl64.clone()) + }) + .collect(); + (commitment_gl64, point_eval_pairs_gl64) + }) + .collect(); + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear field extension"); + } + + let transcript_any = transcript as &mut dyn std::any::Any; + let basic_transcript = transcript_any + .downcast_mut::>() + .expect("Type should match"); + + let cuda_hal = get_cuda_hal().unwrap(); + let gpu_proof_basefold = cuda_hal + .basefold + .batch_open_with_trace_materializer( + &cuda_hal, + pp_gl64, + rounds_gl64, + basic_transcript, + |round_idx, trace_idx| { + if round_idx != 0 { + return Ok(None); + } + let Some((_, replay_plan)) = replayable_traces + .iter() + .find(|(replay_trace_idx, _)| *replay_trace_idx == trace_idx) + else { + return Ok(None); + }; + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "pcs_opening", + round_idx, + trace_idx, + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness()) + .map_err(|err| { + ceno_gpu::HalError::InvalidInput(format!( + "failed to replay trace {trace_idx} for PCS opening: {err:?}" + )) + })?; + if witness_rmm.height() != replay_plan.trace_height { + return Err(ceno_gpu::HalError::InvalidInput(format!( + "replayed trace {trace_idx} height changed before PCS opening: expected {}, got {}", + replay_plan.trace_height, + witness_rmm.height(), + ))); + } + let witness_rmm_bb31: witness::RowMajorMatrix = + unsafe { std::mem::transmute(witness_rmm) }; + Ok(Some(witness_rmm_bb31)) + }, + ) + .unwrap(); + + let gpu_proof: PCS::Proof = unsafe { std::mem::transmute_copy(&gpu_proof_basefold) }; + std::mem::forget(gpu_proof_basefold); + drop(rounds); + drop(witness_data); + gpu_proof +} + impl> DeviceTransporter> for GpuProver> { @@ -2208,7 +2362,17 @@ impl> ) }) } else { - let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "gpu_task", + circuit = task.circuit_name.as_str(), + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness()) + .expect("GPU raw replay failed"); check_gpu_mem_estimation_with_context( gpu_mem_tracker, estimated_replay_bytes, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c0e4e42e6..267bd4920 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -239,9 +239,7 @@ impl< && !crate::instructions::gpu::config::should_retain_witness_device_backing_after_commit(); #[cfg(feature = "gpu")] let trace_rows_for_estimate = - if !crate::instructions::gpu::config::is_gpu_witgen_enabled() - && witness_rmm.num_instances() > 0 - { + if witness_rmm.num_instances() > 0 && gpu_replay_plan.is_none() { Some(witness_rmm.height()) } else { None @@ -557,24 +555,39 @@ impl< if needs_replay_restore { let replay_cache = current_replay_cache_stats(); tracing::info!( - "[gpu replay cache][before_restore_pcs] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", + "[gpu replay cache][before_pcs_opening] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0), replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0), replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0), replay_cache.total_bytes() as f64 / (1024.0 * 1024.0), ); - crate::scheme::gpu::log_gpu_device_state("before_restore_pcs"); - let gpu_witness_data: &mut as ProverBackend>::PcsData = - unsafe { std::mem::transmute(&mut witness_data) }; - crate::scheme::gpu::restore_replayable_trace_device_backing::( - gpu_witness_data, - &replayable_traces, - )?; - crate::scheme::gpu::log_gpu_device_state("after_restore_pcs"); + crate::scheme::gpu::log_gpu_device_state("before_pcs_opening"); } let mpcs_opening_proof = info_span!("[ceno] pcs_opening").in_scope(|| { #[cfg(feature = "gpu")] - { + if needs_replay_restore { + let gpu_device: &gkr_iop::gpu::GpuProver> = + unsafe { std::mem::transmute(&self.device) }; + let gpu_witness_data: as ProverBackend>::PcsData = + unsafe { std::mem::transmute_copy(&witness_data) }; + std::mem::forget(witness_data); + let fixed_data = self + .get_device_proving_key(shard_ctx) + .map(|dpk| dpk.pcs_data.clone()); + let gpu_fixed_data: Option< + std::sync::Arc< + as ProverBackend>::PcsData, + >, + > = unsafe { std::mem::transmute(fixed_data) }; + return crate::scheme::gpu::open_with_incremental_replay::( + gpu_device, + gpu_witness_data, + gpu_fixed_data, + &replayable_traces, + points, + evaluations, + &mut transcript, + ); } self.device.open( witness_data, @@ -1208,7 +1221,16 @@ where ); log_gpu_device_state(&format!("{name}:before_replay")); log_gpu_pool_usage(&format!("{name}:before_replay")); - let witness_rmm = replay_plan.replay_witness()?; + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "chip_proof", + circuit = name, + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness())?; crate::scheme::gpu::check_gpu_mem_estimation_with_context( gpu_mem_tracker, estimated_replay_bytes,