Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device<E: ExtensionFiel

let n = next_pow2_instance_padding(num_records);
let num_rows_padded = 2 * n;
let stream = hal.inner.stream_or_default(None);

let col_map = extract_shard_ram_column_map(config, num_witin);

Expand All @@ -863,7 +864,7 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device<E: ExtensionFiel
num_local_writes as u32,
num_witin as u32,
num_rows_padded as u32,
None,
Some(&stream),
)
.map_err(|e| {
ZKVMError::InvalidWitness(
Expand All @@ -875,20 +876,23 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device<E: ExtensionFiel
let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree_from_device", n).in_scope(
|| -> Result<_, ZKVMError> {
let col_offsets = col_map.to_flat();
let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| {
ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into())
})?;
let gpu_cols = hal
.alloc_u32_from_host(&col_offsets, Some(&stream))
.map_err(|e| {
ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into())
})?;

let (mut cur_x, mut cur_y) = hal
.witgen
.extract_ec_points_from_device(device_records, num_records, n, None)
.extract_ec_points_from_device(device_records, num_records, n, Some(&stream))
.map_err(|e| {
ZKVMError::InvalidWitness(format!("GPU extract_ec_points failed: {e}").into())
})?;

let mut witness_buf = gpu_witness.device_buffer;
let mut offset = num_rows_padded / 2;
let mut current_layer_len = n;
let mut retained_layers = Vec::new();

loop {
if current_layer_len <= 1 {
Expand All @@ -897,26 +901,31 @@ pub(crate) fn try_gpu_assign_shard_ram_witness_only_from_device<E: ExtensionFiel

let (next_x, next_y) = hal
.witgen
.shard_ram_ec_tree_layer(
.shard_ram_ec_tree_layer_async(
&gpu_cols,
&cur_x,
&cur_y,
&mut witness_buf,
current_layer_len,
offset,
num_rows_padded,
None,
Some(&stream),
)
.map_err(|e| {
ZKVMError::InvalidWitness(format!("GPU EC tree layer failed: {e}").into())
})?;

current_layer_len /= 2;
offset += current_layer_len;
retained_layers.push((cur_x, cur_y));
cur_x = next_x;
cur_y = next_y;
}

hal.inner.synchronize_stream(&stream).map_err(|e| {
ZKVMError::InvalidWitness(format!("GPU EC tree stream sync failed: {e}").into())
})?;

Ok(witness_buf)
},
)?;
Expand Down
190 changes: 177 additions & 13 deletions ceno_zkvm/src/scheme/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1120,17 +1120,25 @@ where
.expect("deferred commit source reused");
let witness_rmm: witness::RowMajorMatrix<E::BaseField> = match source {
DeferredGpuTrace::Eager(rmm) => rmm,
DeferredGpuTrace::Replay(plan) => plan
.replay_witness()
.map(|witness_rmm| {
assert_eq!(
witness_rmm.height(),
plan.trace_height,
"replayed trace height changed between plan build and deferred commit",
);
witness_rmm
})
.map_err(|e| ceno_gpu::HalError::InvalidInput(format!("{e:?}")))?,
DeferredGpuTrace::Replay(plan) => info_span!(
"[ceno] replay_witness_materialize",
phase = "commit_traces",
trace_idx,
kind = ?plan.kind,
rows = plan.trace_height,
num_witin = plan.num_witin,
steps = plan.step_indices.len(),
)
.in_scope(|| plan.replay_witness())
.map(|witness_rmm| {
assert_eq!(
witness_rmm.height(),
plan.trace_height,
"replayed trace height changed between plan build and deferred commit",
);
witness_rmm
})
.map_err(|e| ceno_gpu::HalError::InvalidInput(format!("{e:?}")))?,
};
Ok(unsafe { std::mem::transmute(witness_rmm) })
})
Expand Down Expand Up @@ -1501,7 +1509,16 @@ where
};

for (trace_idx, replay_plan) in replayable_traces {
let witness_rmm = replay_plan.replay_witness()?;
let witness_rmm = info_span!(
"[ceno] replay_witness_materialize",
phase = "restore_pcs_backing",
trace_idx = *trace_idx,
kind = ?replay_plan.kind,
rows = replay_plan.trace_height,
num_witin = replay_plan.num_witin,
steps = replay_plan.step_indices.len(),
)
.in_scope(|| replay_plan.replay_witness())?;
assert_eq!(
witness_rmm.height(),
replay_plan.trace_height,
Expand Down Expand Up @@ -2090,6 +2107,143 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> OpeningProver<GpuBac
}
}

/// Runs the batched PCS opening on the GPU while re-materializing replayable
/// witness traces *incrementally*, via the `batch_open_with_trace_materializer`
/// callback, instead of restoring every trace's device backing up front.
///
/// Round 0 corresponds to the witness commitment; an optional second round is
/// pushed for fixed-column data. For each trace in round 0 that has an entry in
/// `replayable_traces`, the materializer closure replays the witness rows on
/// demand and hands them to the opening kernel.
///
/// # Panics
/// - If `E::BaseField` is not `BB31Base` or `E` is not `BB31Ext` (runtime
///   `TypeId` guards below — the `unsafe` transmutes rely on them).
/// - If `transcript` is not concretely a `BasicTranscript<BB31Ext>`.
/// - If the GPU opening itself fails (`.unwrap()` at the end).
pub fn open_with_incremental_replay<E, PCS>(
    prover: &GpuProver<GpuBackend<E, PCS>>,
    witness_data: <GpuBackend<E, PCS> as ProverBackend>::PcsData,
    fixed_data: Option<Arc<<GpuBackend<E, PCS> as ProverBackend>::PcsData>>,
    replayable_traces: &[(usize, crate::structs::GpuReplayPlan<E>)],
    points: Vec<Point<E>>,
    mut evals: Vec<Vec<Vec<E>>>,
    transcript: &mut (impl Transcript<E> + 'static),
) -> PCS::Proof
where
    E: ExtensionField,
    PCS: PolynomialCommitmentScheme<E>,
{
    // Guard for the base-field transmutes below: the GPU backend is
    // monomorphized for BabyBear only.
    if std::any::TypeId::of::<E::BaseField>() != std::any::TypeId::of::<BB31Base>() {
        panic!("GPU backend only supports BabyBear base field");
    }

    // Round 0: witness commitment. For each opening point, the first entry of
    // its eval vector is the witin evals; `remove(0)` consumes it so that any
    // remaining entry (fixed evals) shifts to position 0 for the next round.
    let mut rounds = vec![];
    rounds.push((&witness_data, {
        evals
            .iter_mut()
            .zip(&points)
            .filter_map(|(evals, point)| {
                let witin_evals = evals.remove(0);
                if !witin_evals.is_empty() {
                    Some((point.clone(), witin_evals))
                } else {
                    None
                }
            })
            .collect_vec()
    }));
    // Optional round 1: fixed-column commitment, consuming the (now first)
    // remaining eval entry per point, skipping points with no fixed evals.
    if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) {
        rounds.push((fixed_data, {
            evals
                .iter_mut()
                .zip(points.iter().cloned())
                .filter_map(|(evals, point)| {
                    if !evals.is_empty() && !evals[0].is_empty() {
                        Some((point.clone(), evals.remove(0)))
                    } else {
                        None
                    }
                })
                .collect_vec()
        }));
    }

    // Reinterpret the generic prover params / commitments / points as their
    // concrete BB31 instantiations expected by the CUDA HAL.
    // NOTE(review): these transmutes are justified by the TypeId guards, but
    // the `E == BB31Ext` guard only runs *after* this block — consider hoisting
    // it next to the base-field check above.
    let prover_param = &prover.backend.pp;
    let pp_gl64: &mpcs::basefold::structure::BasefoldProverParams<BB31Ext, mpcs::BasefoldRSParams> =
        unsafe { std::mem::transmute(prover_param) };
    let rounds_gl64: Vec<_> = rounds
        .iter()
        .map(|(commitment, point_eval_pairs)| {
            let commitment_gl64: &BasefoldCommitmentWithWitnessGpu<
                BB31Base,
                BufferImpl<BB31Base>,
                GpuDigestLayer,
                GpuMatrix<'static>,
                GpuPolynomial<'static>,
            > = unsafe { std::mem::transmute(*commitment) };
            let point_eval_pairs_gl64: Vec<_> = point_eval_pairs
                .iter()
                .map(|(point, evals)| {
                    let point_gl64: &Vec<BB31Ext> = unsafe { std::mem::transmute(point) };
                    let evals_gl64: &Vec<BB31Ext> = unsafe { std::mem::transmute(evals) };
                    (point_gl64.clone(), evals_gl64.clone())
                })
                .collect();
            (commitment_gl64, point_eval_pairs_gl64)
        })
        .collect();

    // Guard for the extension-field transmutes (transcript, proof).
    if std::any::TypeId::of::<E>() != std::any::TypeId::of::<BB31Ext>() {
        panic!("GPU backend only supports BabyBear field extension");
    }

    // Downcast the generic transcript to the concrete type the HAL consumes.
    let transcript_any = transcript as &mut dyn std::any::Any;
    let basic_transcript = transcript_any
        .downcast_mut::<BasicTranscript<BB31Ext>>()
        .expect("Type should match");

    let cuda_hal = get_cuda_hal().unwrap();
    let gpu_proof_basefold = cuda_hal
        .basefold
        .batch_open_with_trace_materializer(
            &cuda_hal,
            pp_gl64,
            rounds_gl64,
            basic_transcript,
            // Materializer: invoked per (round, trace). Returns Ok(None) when
            // the trace needs no replay (non-witness round, or no replay plan),
            // otherwise replays the trace rows on the fly.
            |round_idx, trace_idx| {
                // Only round 0 (witness data) has replayable traces.
                if round_idx != 0 {
                    return Ok(None);
                }
                let Some((_, replay_plan)) = replayable_traces
                    .iter()
                    .find(|(replay_trace_idx, _)| *replay_trace_idx == trace_idx)
                else {
                    return Ok(None);
                };
                // Replay under a span so materialization cost is visible in traces.
                let witness_rmm = info_span!(
                    "[ceno] replay_witness_materialize",
                    phase = "pcs_opening",
                    round_idx,
                    trace_idx,
                    kind = ?replay_plan.kind,
                    rows = replay_plan.trace_height,
                    num_witin = replay_plan.num_witin,
                    steps = replay_plan.step_indices.len(),
                )
                .in_scope(|| replay_plan.replay_witness())
                .map_err(|err| {
                    ceno_gpu::HalError::InvalidInput(format!(
                        "failed to replay trace {trace_idx} for PCS opening: {err:?}"
                    ))
                })?;
                // The opening was committed against `trace_height` rows; a
                // mismatch here means the replay is non-deterministic.
                if witness_rmm.height() != replay_plan.trace_height {
                    return Err(ceno_gpu::HalError::InvalidInput(format!(
                        "replayed trace {trace_idx} height changed before PCS opening: expected {}, got {}",
                        replay_plan.trace_height,
                        witness_rmm.height(),
                    )));
                }
                // Reinterpret the generic base-field matrix as BB31 (guarded above).
                let witness_rmm_bb31: witness::RowMajorMatrix<BB31Base> =
                    unsafe { std::mem::transmute(witness_rmm) };
                Ok(Some(witness_rmm_bb31))
            },
        )
        .unwrap();

    // transmute_copy + forget moves ownership of the concrete proof into the
    // generic `PCS::Proof` without running its destructor twice (double-free
    // guard for the bitwise copy).
    let gpu_proof: PCS::Proof = unsafe { std::mem::transmute_copy(&gpu_proof_basefold) };
    std::mem::forget(gpu_proof_basefold);
    // Keep `rounds` (which borrows `witness_data`) alive through the opening;
    // drop both explicitly only after the proof is produced.
    drop(rounds);
    drop(witness_data);
    gpu_proof
}

impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> DeviceTransporter<GpuBackend<E, PCS>>
for GpuProver<GpuBackend<E, PCS>>
{
Expand Down Expand Up @@ -2208,7 +2362,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>
)
})
} else {
let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed");
let witness_rmm = info_span!(
"[ceno] replay_witness_materialize",
phase = "gpu_task",
circuit = task.circuit_name.as_str(),
kind = ?replay_plan.kind,
rows = replay_plan.trace_height,
num_witin = replay_plan.num_witin,
steps = replay_plan.step_indices.len(),
)
.in_scope(|| replay_plan.replay_witness())
.expect("GPU raw replay failed");
check_gpu_mem_estimation_with_context(
gpu_mem_tracker,
estimated_replay_bytes,
Expand Down
50 changes: 36 additions & 14 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,7 @@ impl<
&& !crate::instructions::gpu::config::should_retain_witness_device_backing_after_commit();
#[cfg(feature = "gpu")]
let trace_rows_for_estimate =
if !crate::instructions::gpu::config::is_gpu_witgen_enabled()
&& witness_rmm.num_instances() > 0
{
if witness_rmm.num_instances() > 0 && gpu_replay_plan.is_none() {
Some(witness_rmm.height())
} else {
None
Expand Down Expand Up @@ -557,24 +555,39 @@ impl<
if needs_replay_restore {
let replay_cache = current_replay_cache_stats();
tracing::info!(
"[gpu replay cache][before_restore_pcs] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB",
"[gpu replay cache][before_pcs_opening] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB",
replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0),
replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0),
replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0),
replay_cache.total_bytes() as f64 / (1024.0 * 1024.0),
);
crate::scheme::gpu::log_gpu_device_state("before_restore_pcs");
let gpu_witness_data: &mut <gkr_iop::gpu::GpuBackend<E, PCS> as ProverBackend>::PcsData =
unsafe { std::mem::transmute(&mut witness_data) };
crate::scheme::gpu::restore_replayable_trace_device_backing::<E, PCS>(
gpu_witness_data,
&replayable_traces,
)?;
crate::scheme::gpu::log_gpu_device_state("after_restore_pcs");
crate::scheme::gpu::log_gpu_device_state("before_pcs_opening");
}
let mpcs_opening_proof = info_span!("[ceno] pcs_opening").in_scope(|| {
#[cfg(feature = "gpu")]
{
if needs_replay_restore {
let gpu_device: &gkr_iop::gpu::GpuProver<gkr_iop::gpu::GpuBackend<E, PCS>> =
unsafe { std::mem::transmute(&self.device) };
let gpu_witness_data: <gkr_iop::gpu::GpuBackend<E, PCS> as ProverBackend>::PcsData =
unsafe { std::mem::transmute_copy(&witness_data) };
std::mem::forget(witness_data);
let fixed_data = self
.get_device_proving_key(shard_ctx)
.map(|dpk| dpk.pcs_data.clone());
let gpu_fixed_data: Option<
std::sync::Arc<
<gkr_iop::gpu::GpuBackend<E, PCS> as ProverBackend>::PcsData,
>,
> = unsafe { std::mem::transmute(fixed_data) };
return crate::scheme::gpu::open_with_incremental_replay::<E, PCS>(
gpu_device,
gpu_witness_data,
gpu_fixed_data,
&replayable_traces,
points,
evaluations,
&mut transcript,
);
}
self.device.open(
witness_data,
Expand Down Expand Up @@ -1208,7 +1221,16 @@ where
);
log_gpu_device_state(&format!("{name}:before_replay"));
log_gpu_pool_usage(&format!("{name}:before_replay"));
let witness_rmm = replay_plan.replay_witness()?;
let witness_rmm = info_span!(
"[ceno] replay_witness_materialize",
phase = "chip_proof",
circuit = name,
kind = ?replay_plan.kind,
rows = replay_plan.trace_height,
num_witin = replay_plan.num_witin,
steps = replay_plan.step_indices.len(),
)
.in_scope(|| replay_plan.replay_witness())?;
crate::scheme::gpu::check_gpu_mem_estimation_with_context(
gpu_mem_tracker,
estimated_replay_bytes,
Expand Down
Loading