From c548dc88fccac94fcb529105169ac44b32c66c0c Mon Sep 17 00:00:00 2001 From: naure Date: Thu, 21 Nov 2024 13:48:30 +0100 Subject: [PATCH] fix/program-size2: refactor padding_zero (#615) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to #611. --------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/instructions.rs | 2 +- ceno_zkvm/src/tables/mod.rs | 49 +++++++++++++++++++-------------- ceno_zkvm/src/tables/program.rs | 22 +++------------ ceno_zkvm/src/witness.rs | 9 +++--- 4 files changed, 39 insertions(+), 43 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 63314cbee..e87675dd6 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -94,7 +94,7 @@ pub trait Instruction { num_padding_instances }; raw_witin - .par_batch_iter_padding_mut(num_padding_instance_per_batch) + .par_batch_iter_padding_mut(None, num_padding_instance_per_batch) .with_min_len(MIN_PAR_SIZE) .for_each(|row| { row.chunks_mut(num_witin) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 2ef7e293a..f498b868e 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -2,8 +2,8 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, scheme::constants::MIN_PAR_SIZE, witness::RowMajorMatrix, }; -use ff::Field; use ff_ext::ExtensionField; +use goldilocks::SmallField; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, mem::MaybeUninit}; mod range; @@ -46,25 +46,34 @@ pub trait TableCircuit { table: &mut RowMajorMatrix, num_witin: usize, ) -> Result<(), ZKVMError> { - // Fill the padding with zeros, if any. - let num_padding_instances = table.num_padding_instances(); - if num_padding_instances > 0 { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); - let padding_instance = vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]; - let num_padding_instance_per_batch = if num_padding_instances > 256 { - num_padding_instances.div_ceil(nthreads) - } else { - num_padding_instances - }; - table - .par_batch_iter_padding_mut(num_padding_instance_per_batch) - .with_min_len(MIN_PAR_SIZE) - .for_each(|row| { - row.chunks_mut(num_witin) - .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); - }); - } + padding_zero(table, num_witin, None); Ok(()) } } + +/// Fill the padding with zeros. Start after the given `num_instances`, or detect it from the table. +pub fn padding_zero( + table: &mut RowMajorMatrix, + num_cols: usize, + num_instances: Option, +) { + // Fill the padding with zeros, if any. + let num_padding_instances = table.num_padding_instances(); + if num_padding_instances > 0 { + let nthreads = + std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let padding_instance = vec![MaybeUninit::new(F::ZERO); num_cols]; + let num_padding_instance_per_batch = if num_padding_instances > 256 { + num_padding_instances.div_ceil(nthreads) + } else { + num_padding_instances + }; + table + .par_batch_iter_padding_mut(num_instances, num_padding_instance_per_batch) + .with_min_len(MIN_PAR_SIZE) + .for_each(|row| { + row.chunks_mut(num_cols) + .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); + }); + } +} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 9572a52cf..6340f94a7 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -7,7 +7,7 @@ use crate::{ scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, - tables::TableCircuit, + tables::{TableCircuit, padding_zero}, utils::i64_to_base, witness::RowMajorMatrix, }; @@ -175,16 +175,8 @@ impl TableCircuit for ProgramTableCircuit { } }); - assert_eq!(INVALID as u64, 0, "cannot use 0 as program padding"); - fixed - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .skip(num_instructions) - .for_each(|row| { - for col in config.record.as_slice() { - set_fixed_val!(row, *col, 0_u64.into()); - } - }); + assert_eq!(INVALID as u64, 0, "0 padding must be invalid instructions"); + padding_zero(&mut fixed, num_fixed, Some(num_instructions)); fixed } @@ -212,13 +204,7 @@ impl TableCircuit for ProgramTableCircuit { set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); }); - witness - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .skip(program.instructions.len()) - .for_each(|row| { - set_val!(row, config.mlt, 0_u64); - }); + padding_zero(&mut witness, num_witin, Some(program.instructions.len())); Ok(witness) } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 7acc9ad50..e85360aee 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -87,12 +87,13 @@ impl RowMajorMatrix { pub fn par_batch_iter_padding_mut( &mut self, - num_rows: usize, + num_instances: Option, + batch_size: usize, ) -> rayon::slice::ChunksMut<'_, MaybeUninit> { - let valid_instance = self.num_instances(); - self.values[valid_instance * self.num_col..] + let num_instances = num_instances.unwrap_or(self.num_instances()); + self.values[num_instances * self.num_col..] .as_mut() - .par_chunks_mut(num_rows * self.num_col) + .par_chunks_mut(batch_size * self.num_col) } pub fn de_interleaving(mut self) -> Vec> {